Utility function to better understand schema of a skops file #301
Comments
Yes, this is also in line with what I had as |
Yeah, good point. Unsafe items could be highlighted in some way |
Here is a quick and dirty implementation, it's not exhaustive but intended as a basis for discussion:

import io
import json

from skops.io._audit import Node, get_tree
from skops.io._utils import LoadContext
from skops.io._numpy import NdArrayNode
from skops.io._scipy import SparseMatrixNode
from skops.io._general import FunctionNode


def _check_array_schema(node):
    assert isinstance(node.children, dict)
    assert len(node.children) == 1
    assert 'content' in node.children
    assert isinstance(node.children['content'], io.BytesIO)


def _check_function_schema(node):
    assert isinstance(node.children, dict)
    assert len(node.children) == 1
    assert 'content' in node.children
    children = node.children['content']
    assert isinstance(children, dict)
    assert len(children) == 2
    assert "module_path" in children
    assert "function" in children
    assert isinstance(children["module_path"], str)
    assert isinstance(children["function"], str)


def _print_node(node: Node, node_name: str, level: int, unsafe_set: set[str]):
    name = f"{node.module_name}.{node.class_name}"
    is_unsafe = name in unsafe_set
    if isinstance(node, FunctionNode):
        # if a FunctionNode, children are not visited, but safety should still be checked
        child = node.children['content']
        fn_name = f"{child['module_path']}.{child['function']}"
        is_unsafe = fn_name in unsafe_set
        name = name + "=>" + fn_name

    prefix = ""
    if level > 0:
        prefix += "├-"
    if level > 1:
        prefix += "--" * (level - 1)
    #text = f"{level * ' '}{node_name}: {name}{' [UNSAFE]' if is_unsafe else ''}"
    text = f"{prefix}{node_name}: {name}{' [UNSAFE]' if is_unsafe else ''}"
    print(text)


def print_tree(node, node_name="root", level=0, unsafe_set=None, skip=None, sink=_print_node):
    # helper function to pretty-print the nodes
    unsafe_set = unsafe_set if unsafe_set is not None else node.get_unsafe_set()
    skip = skip if skip is not None else set()

    # TODO: let's skip "key_types" but check its schema first
    if node_name in skip:
        return

    # COMPOSITE TYPES: CHECK ALL ITEMS
    if isinstance(node, dict):
        for key, val in node.items():
            print_tree(val, node_name=key, level=level, unsafe_set=unsafe_set, skip=skip, sink=sink)
        return

    if isinstance(node, (list, tuple)):
        for val in node:
            print_tree(val, node_name=node_name, level=level, unsafe_set=unsafe_set, skip=skip, sink=sink)
        return

    # NO MATCH: RAISE ERROR
    if not isinstance(node, Node):
        raise TypeError(f"{type(node)}")

    # TRIGGER SIDE-EFFECT (call the sink that was passed in, so it can be swapped out)
    sink(node=node, node_name=node_name, level=level, unsafe_set=unsafe_set)

    # TYPES WHOSE CHILDREN IT MAKES NO SENSE TO VISIT
    if isinstance(node, (NdArrayNode, SparseMatrixNode)):
        _check_array_schema(node)
        return

    if isinstance(node, FunctionNode):
        _check_function_schema(node)
        return

    # RECURSE
    print_tree(node.children, node_name=node_name, level=level+1, unsafe_set=unsafe_set, skip=skip, sink=sink)

As can be seen, the core functionality is not terribly complex: visit the node, then recursively visit its children. For some types (numpy arrays, functions), visiting the children makes no sense, so those are special-cased, which adds some complexity. When we encounter them, we return early but still check the children's schema to ensure there are no shenanigans there. Some more special cases probably need to be added. Below is a working example:

import numpy as np
import skops.io as sio
from zipfile import ZipFile
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.preprocessing import FunctionTransformer, MinMaxScaler, PolynomialFeatures, StandardScaler

estimator = Pipeline([
    ("features", FeatureUnion([
        ("scaler", StandardScaler()),
        ("scaled-poly", Pipeline([
            ("polys", FeatureUnion([
                ("poly1", PolynomialFeatures()),
                ("poly2", PolynomialFeatures(degree=3, include_bias=False))
            ])),
            ("square-root", FunctionTransformer(np.sqrt)),
            ("scale", MinMaxScaler()),
        ])),
    ])),
    ("clf", LogisticRegression(random_state=0, solver="liblinear")),
]).fit([[0, 1], [2, 3], [4, 5]], [0, 1, 2])

skops_file = sio.dumps(estimator)
with ZipFile(io.BytesIO(skops_file), "r") as zip_file:
    schema = json.loads(zip_file.read("schema.json"))
    tree = get_tree(schema, load_context=LoadContext(src=zip_file))
    print_tree(tree, skip={"key_types"})

This prints:
|
That looks pretty okay, we could add a filter like |
To clarify:
Let's say I only want unknown, what would the result look like?
Maybe not so useful if there is already WDYT of the "schema checks"? Useful, overkill? Should we use a library for that? jsonschema, pydantic? |
I wasn't sure what As for how it'd look like, we would print all parents of unsafe types, for the user to know where they're happening. And that's the advantage over As for schema checks, I'd hold on for now:
|
I mainly added that because we have the
Yes, makes sense, but it'll be harder to implement.
If the schema changes, we could theoretically catch that and perform a different schema check. My concern here is that if we return early because we encounter, say, a numpy array, we should at least check that its schema is what we expect or else an attacker could hide something in there and it would not be visible when printing out the tree. If we treat this feature as more of a convenience function and make it clear it doesn't necessarily help auditing the file, we could skip the checks though. |
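To make the "print only unsafe types together with all their parents" idea more concrete, here is a rough sketch. It does not use the skops node classes. ToyNode and unsafe_lines are hypothetical stand-ins, and the type names in the example are made up. The idea is to collect the lines of the children first and keep a node only if it is unsafe itself or has something to show below it.

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    """Hypothetical stand-in for an audit node: a name, a type, and children."""
    name: str
    type_name: str
    children: list["ToyNode"] = field(default_factory=list)


def unsafe_lines(node: ToyNode, unsafe_set: set[str], level: int = 0) -> list[str]:
    """Return printable lines for unsafe nodes and all of their ancestors."""
    # Visit children first so we know whether this subtree contains anything unsafe.
    child_lines: list[str] = []
    for child in node.children:
        child_lines.extend(unsafe_lines(child, unsafe_set, level + 1))

    is_unsafe = node.type_name in unsafe_set
    if not is_unsafe and not child_lines:
        return []  # completely safe subtree: nothing to show

    marker = " [UNSAFE]" if is_unsafe else ""
    return [f"{'  ' * level}{node.name}: {node.type_name}{marker}"] + child_lines


tree = ToyNode("root", "sklearn.pipeline.Pipeline", [
    ToyNode("scaler", "sklearn.preprocessing.StandardScaler"),
    ToyNode("square-root", "sklearn.preprocessing.FunctionTransformer", [
        ToyNode("func", "mymodule.my_func"),
    ]),
])
lines = unsafe_lines(tree, unsafe_set={"mymodule.my_func"})
print("\n".join(lines))
```

In this toy tree, the safe "scaler" branch is dropped, while "root" and "square-root" are kept because they lead to the unsafe function.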
Resolves skops-dev#301 This is not finished, just a basis for discussion.
This is not high priority, but more of a nice to have.
The schema.json generated by the skops persistence is more or less human readable. This is good, because a curious user may want to inspect it before loading a model, so that they have an idea of what they're about to load. It does, however, contain a lot of fluff that is not necessary to understand what's going on. Also, it requires a few lines of code to even read the schema.json. I wonder if we could provide some helper function to assist with that.
Here is an example:
Maybe we could turn it into:
That is, instead of showing module and class separately, we show them on a single line. We also strip other fields like __loader__, __id__, key_types, etc. This results in a much more compact and readable view of the schema.
On top of that, we could think about having a more graphical view of the resulting tree, like with the tree command (not sure how this would work with circular references though...).
Of course, this would not be a "security feature", in the sense that looking at this simplified schema would be necessary or sufficient to determine if the object is secure.
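As a rough illustration of the helper proposed above, the sketch below reads schema.json out of the zip archive and compacts it. The names read_schema, compact, and STRIP_KEYS are hypothetical. It also assumes that module and class are stored under the keys "__module__" and "__class__", and the toy schema is a hand-written stand-in, not real skops output.

```python
import io
import json
from zipfile import ZipFile

# Assumption: bookkeeping keys to hide; adjust to match the real schema.
STRIP_KEYS = {"__loader__", "__id__", "key_types"}


def read_schema(skops_bytes: bytes) -> dict:
    """Load schema.json from the bytes of a skops (zip) file."""
    with ZipFile(io.BytesIO(skops_bytes), "r") as zip_file:
        return json.loads(zip_file.read("schema.json"))


def compact(obj):
    """Recursively merge module and class into one entry and drop bookkeeping keys."""
    if isinstance(obj, list):
        return [compact(item) for item in obj]
    if not isinstance(obj, dict):
        return obj
    result = {}
    # Assumption: "__module__" and "__class__" hold the type information.
    if "__module__" in obj and "__class__" in obj:
        result["type"] = f"{obj['__module__']}.{obj['__class__']}"
    for key, val in obj.items():
        if key in STRIP_KEYS or key in ("__module__", "__class__"):
            continue
        result[key] = compact(val)
    return result


# Toy stand-in for a skops file: a zip archive that only holds a schema.json.
toy_schema = {
    "__class__": "StandardScaler",
    "__module__": "sklearn.preprocessing._data",
    "__loader__": "ObjectNode",
    "__id__": 1,
    "content": {
        "with_mean": {
            "__class__": "bool",
            "__module__": "builtins",
            "__loader__": "JsonNode",
            "content": "true",
        },
    },
}
buffer = io.BytesIO()
with ZipFile(buffer, "w") as zf:
    zf.writestr("schema.json", json.dumps(toy_schema))

compact_schema = compact(read_schema(buffer.getvalue()))
print(json.dumps(compact_schema, indent=2))
```

As with the proposal itself, this is only a readability aid. It hides fields, so it should not be used to judge whether a file is safe to load.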