
FEAT Audit before loading a skops file #204

Merged: 36 commits into skops-dev:main from audit-tree, Nov 25, 2022

Conversation

@adrinjalali (Member) commented Oct 26, 2022

This adds auditing before load for the skops file format.

It creates a tree of nodes by traversing the state JSON stored in the .skops file, and loads that information into memory without importing any modules or constructing any instances.

We can then check this tree for the types and functions it references and report the ones that are not trusted. The user has to pass this list to a load/loads function to successfully load the .skops file:

    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from skops.io import dump, dumps, get_untrusted_types, load, loads

    X_train, y_train = load_iris(return_X_y=True)  # example data
    clf = LogisticRegression(random_state=0, solver="liblinear")
    clf.fit(X_train, y_train)

    serialized = dumps(clf)
    # either blindly trust the source
    loaded = loads(serialized, trusted=True)

    # or dump to a file and load by passing a trusted list
    dump(clf, "my-logistic-regression.skops")
    untrusted_types = get_untrusted_types(file="my-logistic-regression.skops")
    print(untrusted_types)

    loaded = load("my-logistic-regression.skops", trusted=untrusted_types)

@adrinjalali (Member, Author)

@skops-dev/maintainers working on some tests and doing a self review, which will add a bunch to the PR. But how much should we include in this PR? There is a ton of work which can be done later. This PR as is allows the tests to pass as long as the user trusts the input.

Otherwise there are issues which I'm fixing, and there is also work to actually fill a bunch of security holes. I'm not sure how far we want to go in the same PR.

@BenjaminBossan (Collaborator)

> Otherwise there are issues which I'm fixing, and there is also work to actually fill a bunch of security holes. I'm not sure how far we want to go in the same PR.

I don't think we need to be too strict for the initial release, as long as it's clear to the user what they can and what they can't expect. Regarding this, could you please summarize what you want to have in this PR? E.g. the load docstring says:

    trusted: bool, or list of str, default=False
        If ``True``, the object will be loaded without any security checks. If
        ``False``, the object will be loaded only if there are only trusted
        objects in the dumped file. If a list of strings, the object will be
        loaded only if there are only trusted objects and objects of types
        listed in ``trusted`` are in the dumped file.

but at least in the tests, I don't see the list-of-strings option being used. Does it work? How would I specify the types? And if False, how is it determined what the "trusted objects" are? Those would be some of the questions I would have as a user looking to adopt the feature.

@adrinjalali (Member, Author)

@BenjaminBossan I completely agree with all you said, I was gonna add those to this PR anyway. So I think we're on the same page. I'll get this PR to a reviewable state and ping back.

@adrinjalali (Member, Author)

@BenjaminBossan API question. Which one would you prefer?

load("file.skops", trusted=["mymodule.myclass"])
# or
load("file.skops", trusted=["mymodule.myclass", "builtins.int", "builtins.dict", ...])

As in, should we always trust the types we usually trust, or should we require the user to pass even those if they want to customize trusted types?

I personally lean towards the first option.

@BenjaminBossan (Collaborator)

> As in, should we always trust the types we usually trust, or should we require the user to pass even those if they want to customize trusted types?
>
> I personally lean towards the first option.

I think this type of question pops up quite often in different disguises, like each time you have some allow list and disallow list. In this case, the second option would be far too much work and would be error prone, so option 1 looks better. I wonder, however, if we can make it so that if needed, a user could still choose not to trust all the defaults.

@adrinjalali (Member, Author) left a comment

> 1. Some lines appear not to be covered, even though it looks like they should be. We should probably understand why that is before merging.

I've left comments on those lines now.

> 2. I wonder if we can somehow ensure that during a node's __init__, no object is constructed or module imported. Yes, of course we could just look hard at the code and make sure that nothing ever slips through, but I would like some automatic checking. I have no concrete idea honestly, but maybe we can somehow use the LoadContext and have separate modes on it, audit mode and construct mode, and during construction, we check that construct mode is set or else fail. Then, we only set construct mode temporarily inside of construct, so if some code accidentally constructs an object during the auditing part, it would fail.

We certainly should work on that, but I've been thinking of a series of follow-up "hardening" PRs which improve the security of this implementation. We could think of a few ways; for instance, patching import_module to prevent imports during __init__, maybe.
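
Roughly what that guard could look like (a minimal sketch of the idea, not the skops implementation; forbid_imports is a made-up name):

    import builtins
    from contextlib import contextmanager

    @contextmanager
    def forbid_imports():
        # Swap out builtins.__import__ so any import attempted while the
        # audit runs raises instead of silently executing module code.
        original_import = builtins.__import__

        def blocked(name, *args, **kwargs):
            raise RuntimeError(f"import of {name!r} attempted during audit")

        builtins.__import__ = blocked
        try:
            yield
        finally:
            builtins.__import__ = original_import

The audit phase would then run under `with forbid_imports(): ...`, so a node's __init__ that accidentally imports something fails loudly.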

> 3. Is there something we can do to ensure that the children of a node are set correctly? It seems to me that if we make a mistake there, it would leave a backdoor open. IIUC, we only have 3 possibilities for child nodes: a list, a dict, or something else. Can we maybe infer the right way to check just from looking at state? I think that would be a huge simplification. Then we don't need custom __init__ in the node subclasses, we just store self.state = state in Node, thus we don't need to think about what all the children are, which can be error prone. For this to work, we might have to ensure that our states always have the same structure, and during loading verify that this structure is found, but it would be a worthwhile trade-off IMO (if it works).

Another thing we can do is to pass allowed_types to get_tree to limit the types of nodes which can be created. I don't think we can/should leave state as is; we should parse the data during __init__ and load those values, e.g. we read the numpy data into memory from the zipfile in __init__ and construct the array later.

> 4. API-wise, as a user, should I be able to say load(..., trusted=["sklearn.*"]) or something like that? I.e. have a way to blanket trust a module as a whole. Maybe as a future addition?

I don't think users should ever do trusted=["sklearn.*"] because one can always find things like os through that path. But API-wise, adding it later would be easy.

> 5. I think the tests should be extended a bit to cover a handful of realistic use cases we could encounter in the wild. As an example, I would like to see a Pipeline containing a FeatureUnion, consisting of a couple of estimators in total. The test should check that if I only trust a subset of those estimators, the whole pipeline is not trusted. Or if I have a FunctionTransformer with a custom function or numpy function, it is not trusted unless that function is allowed.

I agree, but I feel like doing that while working on adding sklearn safe types would be easier. Trying to limit this PR as much as we can. If you think we should add a specific test, we can do that, though. Also, I feel a bit exhausted working on this PR; if you add those tests here, I definitely wouldn't complain :P
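
For reference, a sketch of the kind of test described above, assuming the dumps/get_untrusted_types API from this PR (that get_untrusted_types accepts in-memory bytes via a data argument, and the exact reported type strings, are assumptions):

    import numpy as np
    from sklearn.decomposition import PCA
    from sklearn.pipeline import FeatureUnion, Pipeline
    from sklearn.preprocessing import FunctionTransformer, StandardScaler

    from skops.io import dumps, get_untrusted_types

    def test_pipeline_with_custom_function_is_untrusted():
        pipe = Pipeline([
            ("union", FeatureUnion([("scale", StandardScaler()), ("pca", PCA())])),
            ("log", FunctionTransformer(np.log1p)),
        ])
        untrusted = get_untrusted_types(data=dumps(pipe))
        # np.log1p is not on the default allow list, so it should be reported
        assert any("log1p" in name for name in untrusted)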

docs/persistence.rst (resolved)
docs/persistence.rst (outdated, resolved)
    type_name : str
        The class name of the type.

    trusted : list of str
@adrinjalali (Member, Author)

I'm happier w/o the types :D

skops/io/_audit.py (outdated, resolved)
skops/io/_audit.py (outdated, resolved)

    def get_unsafe_set(self):
        if self.is_safe:
            return set()
@adrinjalali (Member, Author)

Same here: once we pass trusted down, or when we start trusting a few functions from numpy/scipy, this would be covered.

            ]
            self.children = {"shape": Node, "content": list}
        else:
            raise ValueError(f"Unknown type {self.type}.")
@adrinjalali (Member, Author)

this should only happen if the given file is malformed, which we're not testing for now.

        self.type = state["type"]
        self.trusted = self._get_trusted(trusted, ["scipy.sparse.spmatrix"])
        if self.type != "scipy":
            raise TypeError(
@adrinjalali (Member, Author)

this should also only happen when input is malformed.

        return instance

    if isinstance(args, tuple) and not hasattr(instance, "__setstate__"):
        raise UnsupportedTypeException(
@adrinjalali (Member, Author)

this would be the case for some odd types which we're not testing.

    if hasattr(instance, "__setstate__"):
        instance.__setstate__(attrs)
    else:
        instance.__dict__.update(attrs)
@adrinjalali (Member, Author)

same here about odd types.

@BenjaminBossan (Collaborator) left a comment

> We certainly should work on that, but I've been thinking of a series of follow-up "hardening" PRs which improve the security of this implementation.

Okay, we can do that as a follow up.

> I don't think we can/should leave state as is; we should parse the data during __init__ and load those values, e.g. we read the numpy data into memory from the zipfile in __init__ and construct the array later.

Could you please elaborate on that, I don't understand it. Probably it also relates to this comment elsewhere:

> we need to do the traversal in __init__, and this traversal is node-specific. We almost always call get_tree on children in __init__.

I checked each self.children = ... and none of them are calling get_tree.

I tried to modify the code as follows to avoid having a bunch of different attributes for every node type, and it worked (only tried for DictNode):

class DictNode(Node):
    def __init__(self, state, load_context: LoadContext, trusted=False):
        super().__init__(state, load_context, trusted)
        self.trusted = self._get_trusted(trusted, ["builtins.dict"])
        # ideally, we could automatically infer the children from the state...
        self.children = {
            "key_types": get_tree(state["key_types"], load_context),
            "content": {
                key: get_tree(value, load_context)
                for key, value in state["content"].items()
            },
        }
        # no other custom attributes stored

    def _construct(self):
        key_types = self.children["key_types"].construct()
        content = gettype(self.module_name, self.class_name)()
        for k_type, (key, val) in zip(key_types, self.children["content"].items()):
            content[k_type(key)] = val.construct()
        return content

# inside of get_unsafe_set, the loop is replaced with:

    def get_unsafe_set(self):
        ...
        for child in self.children.values():
            if child is None:
                continue

            # Get the safety set based on the type of the child. In most cases
            # other than ListNode and DictNode, children are all of type Node.
            if isinstance(child, list):
                for value in child:
                    res.update(value.get_unsafe_set())
            elif isinstance(child, dict):
                for value in child.values():
                    res.update(value.get_unsafe_set())
            elif isinstance(child, Node):
                res.update(child.get_unsafe_set())
            else:
                raise ValueError(f"Unknown type {type(child)}.")
        ...

It's not exactly my initial proposal with saving state but still addresses my concern. WDYT?

> I don't think users should ever do trusted=["sklearn.*"] because one can always find things like os through that path.

Even if a module in sklearn imports os, the module name would still not be sklearn, so there is no match, right?

> I agree, but I feel like doing that while working on adding sklearn safe types would be easier.

About that, what's the plan? Enumerating each and every sklearn class and function that's safe? And keeping it up-to-date with each release, across multiple versions? Sounds infeasible.

skops/io/_audit.py (outdated, resolved)
skops/io/_dispatch.py (outdated, resolved)
skops/io/_persist.py (outdated, resolved)
skops/io/tests/test_audit.py (resolved)
skops/io/_dispatch.py (resolved)
@E-Aho (Collaborator) left a comment

Had a few more thoughts :)

Comment on lines +159 to +166
# TODO: This should help with fixing recursive references.
# if id(value) in save_context.memo:
# return {
# "__module__": None,
# "__class__": None,
# "__id__": id(value),
# "__loader__": "CachedNode",
# }
@E-Aho (Collaborator) commented Nov 20, 2022

I think this should work ok, but we would need to restructure how things get memoized.

Right now, things only get memoized once they've been fully constructed, so if an object has a recursive attribute, it never actually gets into the memo.

I thought about this originally, and started exploring if there's a nice way to hold a reference to an object that isn't initialized yet, so we could do something like:

    _obj = Object()  # pseudocode, not the right way to do this

    save_context.memoize(_obj, id)
    res = _get_state(value, save_context)

    _obj.assign(res)   

    return _obj

Still haven't fully thought this through, but I think we might need to do something like this to allow circular self-references to work the right way.
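
For comparison, this is roughly how pickle itself copes with cycles: the object is allocated and memoized before its state is filled in, so recursive references resolve to the half-built object. A toy illustration (all names made up):

    class Shell:
        """Placeholder allocated before its state exists."""

    memo = {}

    def construct(obj_id, fill_state):
        shell = Shell()        # allocate without running user code
        memo[obj_id] = shell   # memoize *before* filling in state
        fill_state(shell)      # may look up memo[obj_id] and find shell
        return shell

    node = construct(1, lambda obj: setattr(obj, "me", memo[1]))
    assert node.me is node  # the self-reference round-trips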

@adrinjalali (Member, Author)

yeah I've also thought exactly along those lines, I still don't know how exactly to do it though. Worst case scenario we end up doing the DAG work lol. But even that shouldn't require much refactoring.

The reason I commented out this code is that small integers share the same id in Python (and I don't know what else does), and that somehow when saving, we don't save in the right order and the cached object gets loaded before the actual object does.
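
For the record, the small-integer issue is CPython's int cache: distinct occurrences of the same small value are literally the same object, so id() cannot distinguish them:

    # CPython caches ints in roughly the range -5..256, so id-based
    # memoization would conflate logically separate values.
    a, b = 7, 7
    assert a is b and id(a) == id(b)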

Comment on lines +60 to +62
def gettype(module_name, cls_or_func):
    if module_name and cls_or_func:
        return _import_obj(module_name, cls_or_func)
@E-Aho (Collaborator)

I wanted to check something here:

As it stands, is there any way loading a .skops file could redefine something in the global namespace?

My understanding (correct me if I'm wrong), is that this could only happen if:

  • Code is defined that does that before calling loads
  • Something does this in a user's imported module

In either of those cases, this isn't a vulnerability with skops itself, so I think it's ok, but I wanted to make sure I've not missed somewhere that global namespaces could be changed during load, as that could lead to a vulnerability.

In other words, there's not currently a way someone could structure a .skops file that redefines a type we deem "trusted", like np.random.Generator, right?

@adrinjalali (Member, Author)

Yes, that's exactly the idea, unless an import statement would do that, in which case the user is already compromised anyway.

I don't see a way with the current format for anybody to be able to modify globals the way one can do with pickle.

skops/io/_persist.py (outdated, resolved)
@adrinjalali (Member, Author) left a comment

> Even if a module in sklearn imports os, the module name would still not be sklearn, so there is no match, right?

when we persist, yes, but somebody could curate a .skops file where they set sklearn.datasets._base as the module and os as the class name, and we'd import it.
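
To illustrate the concern (assuming sklearn.datasets._base still does a top-level import os, as it does at the time of writing):

    # Reaching the real os module through an sklearn-prefixed path:
    import importlib

    mod = importlib.import_module("sklearn.datasets._base")
    os_module = getattr(mod, "os")  # the genuine os module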

> About that, what's the plan? Enumerating each and every sklearn class and function that's safe? And keeping it up-to-date with each release, across multiple versions? Sounds infeasible.

It sounds feasible to me, but it would require some automation to streamline it. By enumerating, I would think of getting all estimators through sklearn's API rather than adding them manually, though.
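
A sketch of what that automation could look like, using sklearn's public discovery helper (whether this would be the actual mechanism is open):

    # Enumerate estimator classes via sklearn's API instead of by hand.
    from sklearn.utils import all_estimators

    trusted_sklearn = sorted(
        f"{cls.__module__}.{cls.__name__}" for _, cls in all_estimators()
    )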

> It's not exactly my initial proposal with saving state but still addresses my concern. WDYT?

This doesn't look neat to me, but while doing it I found one bug, so it's a good pattern I'd say. Also, I had to modify get_unsafe_set quite a bit as a result.

skops/io/_dispatch.py (resolved)

@adrinjalali (Member, Author) left a comment

I think this is ready for another review @BenjaminBossan

# this means we're already computing this node's unsafe set, so we
# return an empty set and let the computation of the parent node
# continue. This is to avoid infinite recursion.
return set()
@adrinjalali (Member, Author)

this isn't tested since we haven't figured out the whole recursive-pointers thing.

    # conditions about BytesIO, etc. should be ignored.
    if not check_type(get_module(child), child.__name__, self.trusted):
        # if the child is a type, we check its safety
        res.add(get_module(child) + "." + child.__name__)
@adrinjalali (Member, Author)

For all cases where the child is a type, we have it as trusted. This is only used in reduce.

super().__init__(state, load_context, trusted, memoize=False)
self.trusted = True
self.cached = load_context.get_object(state.get("__id__"))
self.children = {} # type: ignore
@adrinjalali (Member, Author)

cache node is used for recursive pointers as well.

@BenjaminBossan (Collaborator) left a comment

Pretty good work, I'm happy with the result. There are some small things left to do, please check my comments, but overall exceptional work.

> when we persist, yes, but somebody could curate a .skops file where they set sklearn.datasets._base as the module and os as the class name, and we'd import it.

Yes, good point. But I wonder how much of a problem that is. If sklearn imports foo somewhere, we cannot guarantee that foo will not be imported at some point, even if users pass fine-grained types to allow. Therefore, we already need to assume that importing foo is safe.

Now, using foo could still be dangerous. However, when we create the instances, we could check if the path of foo corresponds to sklearn.* and if not, raise an error. That way, we could still prevent its usage, if I'm not missing something.
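
A hedged sketch of that check, using __module__ as the notion of "path" (defined_under is a made-up helper, not skops API):

    def defined_under(obj, prefix: str) -> bool:
        # Trust only objects actually *defined* under the prefix, not merely
        # reachable through a module that happens to import them.
        module = getattr(obj, "__module__", "") or ""
        return module == prefix or module.startswith(prefix + ".")

    import numpy as np
    assert defined_under(np.ndarray, "numpy")
    assert not defined_under(np.ndarray, "sklearn")  # reachable via sklearn, not defined there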

> It sounds feasible to me, but it would require some automation to streamline it. By enumerating, I would think of getting all estimators through sklearn's API rather than adding them manually, though.

Okay, let's see how it'll work out in practice. Do you plan to include that before the next release?

> This doesn't look neat to me, but while doing it I found one bug, so it's a good pattern I'd say. Also, I had to modify get_unsafe_set quite a bit as a result.

Not sure what part didn't look neat, but the way you refactored corresponds to my intent and is cleaner IMO, so I'm happy with the outcome.

skops/io/_audit.py (outdated, resolved)
skops/io/_audit.py (outdated, resolved)
skops/io/_dispatch.py (outdated, resolved)
skops/io/_dispatch.py (outdated, resolved)
skops/io/_dispatch.py (outdated, resolved)
skops/io/_numpy.py (resolved)
skops/io/_numpy.py (outdated, resolved)
    return state


def tree_get_instance(state, load_context):
    return reduce_get_instance(state, load_context, constructor=Tree)


class TreeNode(ReduceNode):
@BenjaminBossan (Collaborator)

A bit of an unfortunate name now that "Tree" can have another meaning in skops.io.

@adrinjalali (Member, Author)

we could call it SklearnTreeTreeNode 😁

skops/io/_sklearn.py (outdated, resolved)
skops/io/tests/test_audit.py (outdated, resolved)
@adrinjalali (Member, Author) left a comment

> Yes, good point. But I wonder how much of a problem that is. If sklearn imports foo somewhere, we cannot guarantee that foo will not be imported at some point, even if users pass fine-grained types to allow. Therefore, we already need to assume that importing foo is safe.
>
> Now, using foo could still be dangerous. However, when we create the instances, we could check if the path of foo corresponds to sklearn.* and if not, raise an error. That way, we could still prevent its usage, if I'm not missing something.

Yes, that'd be interesting, but we should do the check before creating the instance. We can do that in a follow-up PR, and add extra checks for it.

> It sounds feasible to me, but it would require some automation to streamline it. By enumerating, I would think of getting all estimators through sklearn's API rather than adding them manually, though.
>
> Okay, let's see how it'll work out in practice. Do you plan to include that before the next release?

No, I think we can release a first version w/o trusting much from sklearn.

        return content_type([item.construct() for item in self.children["content"]])


def set_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
@adrinjalali (Member, Author)

good question, but I saw it being an issue during my tests.


@adrinjalali (Member, Author)

I think I'm happy with this now. It could be merged and we could release early and get feedback, and in the meantime work on the lot that's left to improve.

cc @skops-dev/maintainers

@BenjaminBossan (Collaborator) left a comment

Great work, I think this is now ready to be merged. Everything that's still open, we can work on later, e.g. I'd like to add some light typing (-:

@E-Aho Do you want to give this another pass too?

@E-Aho (Collaborator) commented Nov 25, 2022

Sure! I can give it a final look tonight and hopefully we can merge this in!

@E-Aho (Collaborator) left a comment

I caught a tiny typo in a docstring, but overall LGTM!

skops/io/_dispatch.py (outdated, resolved)
@adrinjalali (Member, Author)

Feel free to hit the "Squash and merge" button then @E-Aho :) (we almost always squash and merge)

@BenjaminBossan (Collaborator)

> Feel free to hit the "Squash and merge" button then @E-Aho :) (we almost always squash and merge)

We usually try to craft a nice commit message, not just using the GH suggestion.

@E-Aho merged commit ca93021 into skops-dev:main on Nov 25, 2022
@E-Aho (Collaborator) commented Nov 25, 2022

All set :)

Side note, is there a list anywhere of the commit prefixes? [FEAT, FIX, DOC, etc]

@adrinjalali deleted the audit-tree branch on November 28, 2022
@adrinjalali (Member, Author)

I don't have a list, but I use:

  • DOC: documentation changes
  • FEAT/FEA: new major features
  • ENH: enhancements to existing features
  • CI: continuous integration, sometimes overlaps with MNT
  • MNT/MAINT: maintenance, technical debt, etc
  • FIX: bug fixes
  • TST: new tests, refactoring tests
  • PERF: performance improvements

@BenjaminBossan (Collaborator)

I was also searching for such a list. How about adding it to the contribution guide?

@adrinjalali (Member, Author)

We should add it to the maintainers guide instead of the contributing guide (we don't have the separation now). Maintainers can/should fix commit messages/titles before merging, and I don't think we should burden first-time contributors with such details.

@adrinjalali (Member, Author)

Created #217

Labels: persistence (Secure persistence feature)