
Utility function to better understand schema of a skops file #301

Closed
BenjaminBossan opened this issue Feb 16, 2023 · 7 comments · Fixed by #317
Labels
persistence Secure persistence feature

Comments

@BenjaminBossan
Collaborator

This is not high priority, but more of a nice to have.

The schema.json generated by the skops persistence is more or less human readable. This is good, because a curious user may want to inspect it before loading a model, so that they have an idea of what they're about to load. It does, however, contain a lot of fluff that is not necessary to understand what's going on. Also, it requires a few lines of code to even read the schema.json. I wonder if we could provide some helper function to assist with that.
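For context, the "few lines of code" needed today look roughly like this. A .skops file is a zip archive with a schema.json member; the in-memory zip below just stands in for a real file so the snippet is self-contained:

```python
import io
import json
from zipfile import ZipFile

# Build a stand-in .skops file in memory (a real one comes from sio.dump).
buffer = io.BytesIO()
with ZipFile(buffer, "w") as zf:
    zf.writestr("schema.json", json.dumps({"__class__": "dict"}))

# The part a curious user has to write: open the zip, read schema.json.
with ZipFile(io.BytesIO(buffer.getvalue()), "r") as zip_file:
    schema = json.loads(zip_file.read("schema.json"))

print(schema["__class__"])  # prints: dict
```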

Here is an example:

# contents of a schema.json I have lying around:
{
  "__class__": "StackingRegressor",
  "__module__": "sklearn.ensemble._stacking",
  "__loader__": "ObjectNode",
  "content": {
    "__class__": "dict",
    "__module__": "builtins",
    "__loader__": "DictNode",
    "content": {
      "estimators": {
        "__class__": "list",
        "__module__": "builtins",
        "__loader__": "ListNode",
        "content": [
          {
            "__class__": "tuple",
            "__module__": "builtins",
            "__loader__": "TupleNode",
            "content": [
              {
                "__class__": "str",
                "__module__": "builtins",
                "__loader__": "JsonNode",
                "content": "\"knn@5\"",
                "is_json": true,
                "__id__": 140165569561648
              },
              {
                "__class__": "Pipeline",
                "__module__": "sklearn.pipeline",
                "__loader__": "ObjectNode",
                "content": {
                ...
# 20000 more lines

Maybe we could turn it into:

>>> schema = sio.load_simple_schema(...)
>>> print(schema)
{
  "class": "sklearn.ensemble._stacking.StackingRegressor",
  "content": {
    "class": "builtins.dict",
    "content": {
      "estimators": {
        "class": "builtins.list",
        "content": [
          {
            "class": "builtins.tuple",
            "content": [
              {
                "class": "builtins.str",
                "content": "knn@5",
              },
              {
                "class": "sklearn.pipeline.Pipeline",
                "content": ...
...

That is, instead of showing module and class separately, we show them combined on a single line. We also strip other fields like __loader__, __id__, key_types, etc. This results in a much more compact and readable view of the schema.
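A rough sketch of the transformation (not an existing skops API; `simplify` and `DROP_KEYS` are made-up names): merge __module__/__class__ into one "class" entry, decode JSON leaves, and drop the bookkeeping keys.

```python
import json

# Bookkeeping keys that are not needed to understand the schema.
DROP_KEYS = {"__loader__", "__id__", "key_types", "is_json"}

def simplify(schema):
    """Recursively collapse a skops-style schema into the compact form above."""
    if isinstance(schema, list):
        return [simplify(item) for item in schema]
    if not isinstance(schema, dict):
        return schema
    if "__class__" not in schema:
        # plain mapping of child names, e.g. the dict under "content"
        return {k: simplify(v) for k, v in schema.items() if k not in DROP_KEYS}
    out = {"class": f"{schema.get('__module__', 'builtins')}.{schema['__class__']}"}
    if "content" in schema:
        content = schema["content"]
        if schema.get("is_json") and isinstance(content, str):
            out["content"] = json.loads(content)  # undo the double encoding
        else:
            out["content"] = simplify(content)
    return out
```

Applied to the str node from the example, this would turn `"content": "\"knn@5\""` plus its bookkeeping fields into `{"class": "builtins.str", "content": "knn@5"}`.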

On top of that, we could think about having a more graphical view of the resulting tree, like with the tree command (not sure how this would work with circular references though...).

>>> sio.visualize_schema(...)
sklearn.ensemble._stacking.StackingRegressor
├── estimators
│   ├── builtins.tuple
│   │   └── builtins.str "knn@5"
│   │   └── sklearn.pipeline.Pipeline
│   │   │   └── ...
...
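The circular-reference worry could be handled by tracking visited __id__ values during the walk. A toy sketch over plain dicts (not the actual skops node types; the traversal rules here are simplified):

```python
def walk(node, seen=None, depth=0):
    """Depth-first print of a schema dict; repeated __id__ values are
    reported as circular references instead of being revisited."""
    seen = set() if seen is None else seen
    if not isinstance(node, dict):
        return
    label = node.get("__class__", "?")
    node_id = node.get("__id__")
    if node_id is not None and node_id in seen:
        print("  " * depth + f"{label} (circular, already shown)")
        return
    if node_id is not None:
        seen.add(node_id)
    print("  " * depth + label)
    content = node.get("content")
    if isinstance(content, dict):
        children = [content] if "__class__" in content else list(content.values())
    elif isinstance(content, list):
        children = content
    else:
        children = []
    for child in children:
        walk(child, seen, depth + 1)

walk({"__class__": "A", "__id__": 1,
      "content": [{"__class__": "B", "__id__": 2},
                  {"__class__": "B", "__id__": 2}]})
# prints:
# A
#   B
#   B (circular, already shown)
```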

Of course, this would not be a "security feature": looking at this simplified schema is neither necessary nor sufficient to determine whether the object is secure.

@BenjaminBossan BenjaminBossan added the persistence Secure persistence feature label Feb 16, 2023
@adrinjalali
Member

Yes, this is also in line with the get_unsafe_tree part of the code which we removed. I was also thinking we should start implementing something like this sooner rather than later. It would also help people understand where an "unsafe" item is and whether it makes sense for it to be there.

@BenjaminBossan
Collaborator Author

Yeah, good point. Unsafe items could be highlighted in some way.

@BenjaminBossan
Collaborator Author

Here is a quick and dirty implementation, it's not exhaustive but intended as a basis for discussion:

import io
import json

from skops.io._audit import Node, get_tree
from skops.io._utils import LoadContext
from skops.io._numpy import NdArrayNode
from skops.io._scipy import SparseMatrixNode
from skops.io._general import FunctionNode

def _check_array_schema(node):
    assert isinstance(node.children, dict)
    assert len(node.children) == 1
    assert 'content' in node.children
    assert isinstance(node.children['content'], io.BytesIO)


def _check_function_schema(node):
    assert isinstance(node.children, dict)
    assert len(node.children) == 1
    assert 'content' in node.children

    children = node.children['content']
    assert isinstance(children, dict)
    assert len(children) == 2
    assert "module_path" in children
    assert "function" in children
    assert isinstance(children["module_path"], str)
    assert isinstance(children["function"], str)


def _print_node(node: Node, node_name: str, level: int, unsafe_set: set[str]):
    name = f"{node.module_name}.{node.class_name}"
    
    is_unsafe = name in unsafe_set

    if isinstance(node, FunctionNode):
        # if a FunctionNode, children are not visited, but safety should still be checked
        child = node.children['content']
        fn_name = f"{child['module_path']}.{child['function']}"
        is_unsafe = fn_name in unsafe_set
        name = name + "=>" + fn_name

    prefix = ""
    if level > 0:
        prefix += "├-"
    if level > 1:
        prefix += "--" * (level - 1)

    text = f"{prefix}{node_name}: {name}{' [UNSAFE]' if is_unsafe else ''}"
    print(text)


def print_tree(node, node_name="root", level=0, unsafe_set=None, skip=None, sink=_print_node):
    # helper function to pretty-print the nodes
    unsafe_set = unsafe_set if unsafe_set is not None else node.get_unsafe_set()
    skip = skip if skip is not None else set()

    # TODO: let's skip "key_types", but check its schema first
    if node_name in skip:
        return

    # COMPOSITE TYPES: CHECK ALL ITEMS
    if isinstance(node, dict):
        for key, val in node.items():
            print_tree(val, node_name=key, level=level, unsafe_set=unsafe_set, skip=skip, sink=sink)
        return

    if isinstance(node, (list, tuple)):
        for val in node:
            print_tree(val, node_name=node_name, level=level, unsafe_set=unsafe_set, skip=skip, sink=sink)
        return

    # NO MATCH: RAISE ERROR
    if not isinstance(node, Node):
        raise TypeError(f"{type(node)}")

    # TRIGGER SIDE-EFFECT
    sink(node=node, node_name=node_name, level=level, unsafe_set=unsafe_set)

    # TYPES WHOSE CHILDREN IT MAKES NO SENSE TO VISIT
    if isinstance(node, (NdArrayNode, SparseMatrixNode)):
        _check_array_schema(node)
        return

    if isinstance(node, FunctionNode):
        _check_function_schema(node)
        return

    # RECURSE
    print_tree(node.children, node_name=node_name, level=level + 1, unsafe_set=unsafe_set, skip=skip, sink=sink)

As can be seen, the core functionality is not terribly complex: visit the node, then recursively visit its children. For some types (numpy arrays, functions), it makes no sense to visit the children, so those are special-cased, which adds some complexity. When we encounter them, we return early but still check the children's schema to ensure there are no shenanigans hiding there. Probably some more special cases need to be added.

Below is a working example:

import io
import json

import numpy as np
import skops.io as sio
from zipfile import ZipFile
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.preprocessing import FunctionTransformer, MinMaxScaler, PolynomialFeatures, StandardScaler

estimator = Pipeline([
    ("features", FeatureUnion([
        ("scaler", StandardScaler()),
        ("scaled-poly", Pipeline([
            ("polys", FeatureUnion([
                ("poly1", PolynomialFeatures()),
                ("poly2", PolynomialFeatures(degree=3, include_bias=False))
            ])),
            ("square-root", FunctionTransformer(np.sqrt)),
            ("scale", MinMaxScaler()),
        ])),
    ])),
    ("clf", LogisticRegression(random_state=0, solver="liblinear")),
]).fit([[0, 1], [2, 3], [4, 5]], [0, 1, 2])

skops_file = sio.dumps(estimator)
with ZipFile(io.BytesIO(skops_file), "r") as zip_file:
    schema = json.loads(zip_file.read("schema.json"))
    tree = get_tree(schema, load_context=LoadContext(src=zip_file))

print_tree(tree, skip={"key_types"})

This prints:

root: sklearn.pipeline.Pipeline
├-attrs: builtins.dict
├---steps: builtins.list
├-----content: builtins.tuple
├-------content: builtins.str
├-------content: sklearn.pipeline.FeatureUnion
├---------attrs: builtins.dict
├-----------transformer_list: builtins.list
├-------------content: builtins.tuple
├---------------content: builtins.str
├---------------content: sklearn.preprocessing._data.StandardScaler
├-----------------attrs: builtins.dict
├-------------------with_mean: builtins.str
├-------------------with_std: builtins.str
├-------------------copy: builtins.str
├-------------------n_features_in_: builtins.str
├-------------------n_samples_seen_: numpy.int64 [UNSAFE]
├-------------------mean_: numpy.ndarray
├-------------------var_: numpy.ndarray
├-------------------scale_: numpy.ndarray
├-------------------_sklearn_version: builtins.str
├-------------content: builtins.tuple
├---------------content: builtins.str
├---------------content: sklearn.pipeline.Pipeline
├-----------------attrs: builtins.dict
├-------------------steps: builtins.list
├---------------------content: builtins.tuple
├-----------------------content: builtins.str
├-----------------------content: sklearn.pipeline.FeatureUnion
├-------------------------attrs: builtins.dict
├---------------------------transformer_list: builtins.list
├-----------------------------content: builtins.tuple
├-------------------------------content: builtins.str
├-------------------------------content: sklearn.preprocessing._polynomial.PolynomialFeatures
├---------------------------------attrs: builtins.dict
├-----------------------------------degree: builtins.str
├-----------------------------------interaction_only: builtins.str
├-----------------------------------include_bias: builtins.str
├-----------------------------------order: builtins.str
├-----------------------------------n_features_in_: builtins.str
├-----------------------------------_min_degree: builtins.str
├-----------------------------------_max_degree: builtins.str
├-----------------------------------n_output_features_: builtins.str
├-----------------------------------_n_out_full: builtins.str
├-----------------------------------_sklearn_version: builtins.str
├-----------------------------content: builtins.tuple
├-------------------------------content: builtins.str
├-------------------------------content: sklearn.preprocessing._polynomial.PolynomialFeatures
├---------------------------------attrs: builtins.dict
├-----------------------------------degree: builtins.str
├-----------------------------------interaction_only: builtins.str
├-----------------------------------include_bias: builtins.str
├-----------------------------------order: builtins.str
├-----------------------------------n_features_in_: builtins.str
├-----------------------------------_min_degree: builtins.str
├-----------------------------------_max_degree: builtins.str
├-----------------------------------n_output_features_: builtins.str
├-----------------------------------_n_out_full: builtins.str
├-----------------------------------_sklearn_version: builtins.str
├---------------------------n_jobs: builtins.str
├---------------------------transformer_weights: builtins.str
├---------------------------verbose: builtins.str
├---------------------------_sklearn_version: builtins.str
├---------------------content: builtins.tuple
├-----------------------content: builtins.str
├-----------------------content: sklearn.preprocessing._function_transformer.FunctionTransformer
├-------------------------attrs: builtins.dict
├---------------------------func: numpy.ufunc=>numpy.core._multiarray_umath.sqrt [UNSAFE]
├---------------------------inverse_func: builtins.str
├---------------------------validate: builtins.str
├---------------------------accept_sparse: builtins.str
├---------------------------check_inverse: builtins.str
├---------------------------feature_names_out: builtins.str
├---------------------------kw_args: builtins.str
├---------------------------inv_kw_args: builtins.str
├---------------------------n_features_in_: builtins.str
├---------------------------_sklearn_version: builtins.str
├---------------------content: builtins.tuple
├-----------------------content: builtins.str
├-----------------------content: sklearn.preprocessing._data.MinMaxScaler
├-------------------------attrs: builtins.dict
├---------------------------feature_range: builtins.tuple
├-----------------------------content: builtins.str
├-----------------------------content: builtins.str
├---------------------------copy: builtins.str
├---------------------------clip: builtins.str
├---------------------------n_features_in_: builtins.str
├---------------------------n_samples_seen_: builtins.str
├---------------------------scale_: numpy.ndarray
├---------------------------min_: numpy.ndarray
├---------------------------data_min_: numpy.ndarray
├---------------------------data_max_: numpy.ndarray
├---------------------------data_range_: numpy.ndarray
├---------------------------_sklearn_version: builtins.str
├-------------------memory: builtins.str
├-------------------verbose: builtins.str
├-------------------_sklearn_version: builtins.str
├-----------n_jobs: builtins.str
├-----------transformer_weights: builtins.str
├-----------verbose: builtins.str
├-----------_sklearn_version: builtins.str
├-----content: builtins.tuple
├-------content: builtins.str
├-------content: sklearn.linear_model._logistic.LogisticRegression
├---------attrs: builtins.dict
├-----------penalty: builtins.str
├-----------dual: builtins.str
├-----------tol: builtins.str
├-----------C: builtins.str
├-----------fit_intercept: builtins.str
├-----------intercept_scaling: builtins.str
├-----------class_weight: builtins.str
├-----------random_state: builtins.str
├-----------solver: builtins.str
├-----------max_iter: builtins.str
├-----------multi_class: builtins.str
├-----------verbose: builtins.str
├-----------warm_start: builtins.str
├-----------n_jobs: builtins.str
├-----------l1_ratio: builtins.str
├-----------n_features_in_: builtins.str
├-----------classes_: numpy.ndarray
├-----------coef_: numpy.ndarray
├-----------intercept_: numpy.ndarray
├-----------n_iter_: numpy.ndarray
├-----------_sklearn_version: builtins.str
├---memory: builtins.str
├---verbose: builtins.str
├---_sklearn_version: builtins.str

@adrinjalali
Member

That looks pretty okay, we could add a filter like filter={"known"/"unknown"/"all"} or something

@BenjaminBossan
Collaborator Author

That looks pretty okay, we could add a filter like filter={"known"/"unknown"/"all"} or something

To clarify:

  1. This is to show only known, i.e. considered safe / unsafe / all types? Would we ever want the first option?
  2. Is this in addition to the skip argument?

Let's say I only want unknown, what would the result look like?

root: sklearn.pipeline.Pipeline
├-------------------n_samples_seen_: numpy.int64 [UNSAFE]
├---------------------------func: numpy.ufunc=>numpy.core._multiarray_umath.sqrt [UNSAFE]

Maybe not so useful if there is already get_untrusted_types?

WDYT of the "schema checks"? Useful, overkill? Should we use a library for that? jsonschema, pydantic?
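For comparison, the kind of declarative check that jsonschema or pydantic would buy us can be approximated without a dependency. A toy sketch (the spec format and names here are made up for illustration):

```python
# Expected shape of a FunctionNode's children, expressed as data
# rather than as inline asserts.
EXPECTED_FUNCTION_NODE = {
    "type": dict,
    "keys": {"module_path": str, "function": str},
}

def check_node(children, spec):
    """Return True if `children` matches the tiny spec format above."""
    if not isinstance(children, spec["type"]):
        return False
    if set(children) != set(spec["keys"]):
        return False
    return all(isinstance(children[k], t) for k, t in spec["keys"].items())

print(check_node({"module_path": "numpy", "function": "sqrt"}, EXPECTED_FUNCTION_NODE))  # True
print(check_node({"module_path": 123}, EXPECTED_FUNCTION_NODE))  # False
```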

@adrinjalali
Member

I wasn't sure what skip does; it's probably the same thing. And yeah, we probably don't want to ever print only the known types, though I'm not sure, we might.

As for how it would look, we would print all parents of unsafe types, so the user knows where they occur. That's the advantage over get_untrusted_types: right now, we tell the user that a type is not okay, but for better context, the user should also know where that type is used.
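Printing only unsafe nodes plus their ancestors could be done in two passes: first collect the paths to unsafe nodes, then print any node lying on such a path. A hypothetical sketch over plain dicts (the `unsafe` marker and `name` keys are made up, not the skops node attributes):

```python
def unsafe_paths(node, path=()):
    """Collect key paths leading to nodes marked unsafe."""
    paths = []
    if isinstance(node, dict):
        if node.get("unsafe"):
            paths.append(path)
        for key, child in node.items():
            paths.extend(unsafe_paths(child, path + (key,)))
    return paths

def print_unsafe_with_parents(node, keep, path=(), depth=0):
    """Print only nodes whose path is a prefix of some unsafe path."""
    if not isinstance(node, dict):
        return
    if not any(p[:len(path)] == path for p in keep):
        return
    print("  " * depth + node.get("name", "?"))
    for key, child in node.items():
        print_unsafe_with_parents(child, keep, path + (key,), depth + 1)

tree = {
    "name": "Pipeline",
    "scaler": {"name": "StandardScaler",
               "n_samples_seen_": {"name": "numpy.int64", "unsafe": True}},
    "clf": {"name": "LogisticRegression"},
}
print_unsafe_with_parents(tree, unsafe_paths(tree))
# prints:
# Pipeline
#   StandardScaler
#     numpy.int64
```

Note how LogisticRegression is dropped entirely while the chain from the root down to the unsafe numpy.int64 is kept.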

As for schema checks, I'd hold on for now:

  • we haven't started on letting third parties extend this format, so the format/schema might change
  • I don't think we should add a dependency for a schema check.

@BenjaminBossan
Collaborator Author

I wasn't sure what skip does, it's probably the same thing

I mainly added that because we have the key_types field, which is really noisy and mostly exists for technical reasons that a user doesn't care about. But maybe we can just hard-code always skipping it?

As for how it'd look like, we would print all parents of unsafe types, for the user to know where they're happening.

Yes, makes sense, but it'll be harder to implement.

As for schema checks, I'd hold on for now

If the schema changes, we could theoretically detect that and perform a different schema check. My concern is that if we return early upon encountering, say, a numpy array, we should at least check that its schema is what we expect; otherwise an attacker could hide something in there that would not be visible when printing the tree. If we treat this feature as more of a convenience function and make it clear that it doesn't necessarily help with auditing the file, we could skip the checks, though.

BenjaminBossan added a commit to BenjaminBossan/skops that referenced this issue Mar 10, 2023
Resolves skops-dev#301

This is not finished, just a basis for discussion.