Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT Audit before loading a skops file #204

Merged
merged 36 commits into from
Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
4e8141e
FEAT Audit before loading a skops file
adrinjalali Oct 26, 2022
cdedde5
WIP
adrinjalali Nov 1, 2022
aaec391
merge upstream/main
adrinjalali Nov 1, 2022
635d721
numpy loaders
adrinjalali Nov 1, 2022
4644745
scipy and src issue
adrinjalali Nov 1, 2022
262e334
sklearn
adrinjalali Nov 2, 2022
b3934c9
make tests pass
adrinjalali Nov 2, 2022
347c12a
remove pickle.py
adrinjalali Nov 2, 2022
23bccb4
fix a few issues
adrinjalali Nov 8, 2022
3699947
Merge remote-tracking branch 'upstream/main' into audit-tree
adrinjalali Nov 8, 2022
9505e6a
add missing files
adrinjalali Nov 9, 2022
b584eb5
add get_untrusted_types and docs
adrinjalali Nov 11, 2022
006229e
minor fix
adrinjalali Nov 11, 2022
c322c92
add more tests
adrinjalali Nov 11, 2022
f244059
add a smoke test, failing though
adrinjalali Nov 11, 2022
66e83bc
implement safety for functions
adrinjalali Nov 11, 2022
da1b0b5
add missing Tree children for audit
adrinjalali Nov 11, 2022
3d6a905
add missing SparseMatrixNode children for audit
adrinjalali Nov 11, 2022
3fe8f0a
tests pass
adrinjalali Nov 14, 2022
3b60b1e
remove safety tree code
adrinjalali Nov 14, 2022
2a9fbd5
minor test
adrinjalali Nov 14, 2022
a357192
fix ids in test
adrinjalali Nov 14, 2022
f67e63f
Merge remote-tracking branch 'upstream/main' into audit-tree
adrinjalali Nov 15, 2022
939ed29
move type ignore
adrinjalali Nov 15, 2022
97cfb7d
add more tests and some docs
adrinjalali Nov 15, 2022
f5574fc
more comments
adrinjalali Nov 16, 2022
14f4ebd
fix recursive dump and get_untrusted_set
adrinjalali Nov 17, 2022
786e742
Ben's comments
adrinjalali Nov 17, 2022
2ef858e
address comments: sentinel, contextmanager, sorted
adrinjalali Nov 21, 2022
829771d
move all children to the children attribute
adrinjalali Nov 22, 2022
d3488e9
Merge remote-tracking branch 'upstream/main' into audit-tree
adrinjalali Nov 22, 2022
375e862
add complex pipeline test
adrinjalali Nov 23, 2022
00e38bb
apply Ben's suggestions
adrinjalali Nov 24, 2022
f3c0dd0
Merge remote-tracking branch 'upstream/main' into audit-tree
adrinjalali Nov 24, 2022
8a07580
Card object should pass trusted to load
adrinjalali Nov 24, 2022
a309112
Update skops/io/_dispatch.py
E-Aho Nov 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions docs/persistence.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,37 @@ The code snippet below illustrates how to use :func:`skops.io.dump` and
clf.fit(X_train, y_train)
dump(clf, "my-logistic-regression.skops")
# ...
loaded = load("my-logistic-regression.skops")
loaded = load("my-logistic-regression.skops", trusted=True)
loaded.predict(X_test)

# in memory
from skops.io import dumps, loads
serialized = dumps(clf)
loaded = loads(serialized)
loaded = loads(serialized, trusted=True)

At the moment, we support the vast majority of sklearn estimators. This includes
complex use cases such as :class:`sklearn.pipeline.Pipeline`,
Note that you should only load files with ``trusted=True`` if you trust the
source. Otherwise you can get a list of untrusted types present in the dump
using :func:`skops.io.get_untrusted_types`:

.. code:: python

from skops.io import get_untrusted_types
unknown_types = get_untrusted_types(file="my-logistic-regression.skops")
print(unknown_types)

Once you check the list and you validate that everything in the list is safe,
BenjaminBossan marked this conversation as resolved.
Show resolved Hide resolved
you can load the file with ``trusted=unknown_types``:

.. code:: python

loaded = load("my-logistic-regression.skops", trusted=unknown_types)

At the moment, we support the vast majority of sklearn estimators. This
includes complex use cases such as :class:`sklearn.pipeline.Pipeline`,
:class:`sklearn.model_selection.GridSearchCV`, classes using Cython code, such
as :class:`sklearn.tree.DecisionTreeClassifier`, and more. If you discover an sklearn
estimator that does not work, please open an issue on the skops `GitHub page
<https://github.com/skops-dev/skops/issues>`_ and let us know.
as :class:`sklearn.tree.DecisionTreeClassifier`, and more. If you discover an
sklearn estimator that does not work, please open an issue on the skops `GitHub
page <https://github.com/skops-dev/skops/issues>`_ and let us know.

In contrast to ``pickle``, skops cannot persist arbitrary Python code. This
means if you have custom functions (say, a custom function to be used with
Expand All @@ -74,16 +91,16 @@ Roadmap
-------

Currently, it is still possible to run insecure code when using skops
persistence. For example, it's possible to load a save file that evaluates arbitrary
code using :func:`eval`. However, we have concrete plans on how to mitigate
this, so please stay updated.
persistence. For example, it's possible to load a save file that evaluates
arbitrary code using :func:`eval`. However, we have concrete plans on how to
mitigate this, so please stay updated.

On top of trying to support persisting all relevant sklearn objects, we plan on
making persistence extensible for other libraries. As a user, this means that if
you trust a certain library, you will be able to tell skops to load code from
that library. As a library author, there will be a clear path of what needs to
be done to add secure persistence to your library, such that skops can save and
load code from your library.
making persistence extensible for other libraries. As a user, this means that
if you trust a certain library, you will be able to tell skops to load code
from that library. As a library author, there will be a clear path of what
needs to be done to add secure persistence to your library, such that skops can
save and load code from your library.

To follow what features are currently planned, filter for the `"persistence"
label <https://github.com/skops-dev/skops/labels/persistence>`_ in our GitHub
Expand Down
4 changes: 2 additions & 2 deletions skops/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from ._persist import dump, dumps, load, loads
from ._persist import dump, dumps, get_untrusted_types, load, loads

__all__ = ["dumps", "load", "loads", "dump"]
__all__ = ["dumps", "load", "loads", "dump", "get_untrusted_types"]
61 changes: 61 additions & 0 deletions skops/io/_audit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from skops.io.exceptions import UntrustedTypesFoundException


def check_type(module_name, type_name, trusted):
"""Check if a type is safe to load.

A type is safe to load only if it's present in the trusted list.

Parameters
----------
module_name : str
The module name of the type.

type_name : str
The class name of the type.

trusted : bool, or list of str
If ``True``, the tree is considered safe. Otherwise trusted has to be
a list of trusted types.

Returns
-------
is_safe : bool
True if the type is safe, False otherwise.
"""
if trusted is True:
return True
return module_name + "." + type_name in trusted


def audit_tree(tree, trusted):
"""Audit a tree of nodes.

A tree is safe only if it contains trusted types. Audit is skipped if
adrinjalali marked this conversation as resolved.
Show resolved Hide resolved
trusted is ``True``.

Parameters
----------
tree : Node
adrinjalali marked this conversation as resolved.
Show resolved Hide resolved
The tree to audit.

trusted : bool, or list of str
If ``True``, the tree is considered safe. Otherwise trusted has to be
a list of trusted types names.

An entry in the list is typically of the form
``skops.io._utils.get_module(obj) + "." + obj.__class__.__name__``.

Raises
------
TypeError
adrinjalali marked this conversation as resolved.
Show resolved Hide resolved
If the tree contains an untrusted type.
"""
if trusted is True:
return

unsafe = tree.get_unsafe_set()
if isinstance(trusted, (list, set)):
unsafe -= set(trusted)
if unsafe:
raise UntrustedTypesFoundException(unsafe)
226 changes: 200 additions & 26 deletions skops/io/_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,209 @@
from __future__ import annotations

import json
from ._audit import check_type
from ._utils import LoadContext

from skops.io._utils import LoadContext
NODE_TYPE_MAPPING = {} # type: ignore

GET_INSTANCE_MAPPING = {} # type: ignore

class Node:
"""A node in the tree of objects.

def get_instance(state, load_context: LoadContext):
"""Create instance based on the state, using json if possible"""
This class is a parent class for all nodes in the tree of objects. Each
type of object (e.g. dict, list, etc.) has its own subclass of Node.

Each child class has to implement two methods: ``__init__`` and
``_construct``.

``__init__`` takes care of traversing the state tree and to create the
corresponding ``Node`` objects. It has access to the ``load_context`` which
in turn has access to the source zip file. The child class's ``__init__``
must also set the ``children`` attribute, which is a dictionary of
``{child_name: child_type}``. ``child_name`` is the name of the attribute
adrinjalali marked this conversation as resolved.
Show resolved Hide resolved
which can be checked for safety, and ``child_type`` is the type of the
attribute. ``child_type`` can be ``list``, ``dict``, or ``Node``. Note that
primitives are persisted as a ``JsonNode``.

``_construct`` takes care of constructing the object. It is only called
once and the result is cached in ``construct`` which is implemented in this
class. All required data to construct an instance should be loaded during
``__init__``.

The separation of ``__init__`` and ``_construct`` is necessary because
audit methods are called after ``__init__`` and before ``construct``.
Therefore ``__init__`` should avoid creating any instances or importing
any modules, to avoid running potentially untrusted code.

Parameters
----------
state : dict
A dict representing the state of the dumped object.

load_context : LoadContext
The context of the loading process.

trusted : bool or list of str, default=False
If ``True``, the object will be loaded without any security checks. If
``False``, the object will be loaded only if there are only trusted
objects in the dumped file. If a list of strings, the object will be
loaded only if all of its required types are listed in ``trusted``
or are trusted by default.
"""
adrinjalali marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, state, load_context: LoadContext, trusted=False, memoize=True):
self.class_name, self.module_name = state["__class__"], state["__module__"]
self.trusted = trusted
self._is_safe = None
self._constructed = None
saved_id = state.get("__id__")
if saved_id and memoize:
# hold reference to obj in case same instance encountered again in
# save state
load_context.memoize(self, saved_id)

def construct(self):
"""Construct the object.

We only construct the object once, and then cache the result.
"""
if self._constructed is not None:
adrinjalali marked this conversation as resolved.
Show resolved Hide resolved
return self._constructed
self._constructed = self._construct()
return self._constructed

@staticmethod
def _get_trusted(trusted, default):
"""Return a trusted list, or True.

If ``trusted`` is ``False``, we return the ``default``, otherwise the
``trusted`` value is used.

This is a convenience method called by child classes.
"""
if trusted is True:
# if trusted is True, we trust the node
return True

if trusted is False:
# if trusted is False, we only trust the defaults
return default

# otherwise we trust the given list
return trusted

def is_self_safe(self):
"""True only if the node's type is considered safe.

This property only checks the type of the node, not its children.
"""
return check_type(self.module_name, self.class_name, self.trusted)

def is_safe(self):
"""True only if the node and all its children are safe."""
# if trusted is set to True, we don't do any safety checks.
if self.trusted is True:
return True
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for now we're not passing trusted to these classes, when we do, this like would be covered.


return len(self.get_unsafe_set()) == 0

def get_unsafe_set(self):
"""Get the set of unsafe types.

This method returns all types which are not trusted, including this
node and all its children.

Returns
-------
unsafe_set : set
A set of unsafe types.
"""
if hasattr(self, "_computing_unsafe_set"):
adrinjalali marked this conversation as resolved.
Show resolved Hide resolved
# this means we're already computing this node's unsafe set, so we
# return an empty set and let the computation of the parent node
# continue. This is to avoid infinite recursion.
return set()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is isn't tested since we haven't figured the whole recursive pointers thing out.

self._computing_unsafe_set = True

res = set()
if not self.is_self_safe():
res.add(self.module_name + "." + self.class_name)

for child, ch_type in self.children.items():
adrinjalali marked this conversation as resolved.
Show resolved Hide resolved
if getattr(self, child) is None:
continue

# Get the safety set based on the type of the child. In most cases
# other than ListNode and DictNode, children are all of type Node.
if ch_type is list:
for value in getattr(self, child):
res.update(value.get_unsafe_set())
elif ch_type is dict:
for value in getattr(self, child).values():
res.update(value.get_unsafe_set())
elif issubclass(ch_type, Node):
res.update(getattr(self, child).get_unsafe_set())
else:
raise ValueError(f"Unknown type {ch_type}.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, how about a custom exception type?


del self._computing_unsafe_set
return res


class CachedNode(Node):
def __init__(self, state, load_context: LoadContext, trusted=False):
# we pass memoize as False because we don't want to memoize the cached
# node.
super().__init__(state, load_context, trusted, memoize=False)
self.trusted = True
self.cached = load_context.get_object(state.get("__id__"))
self.children = {} # type: ignore
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cache node is used for recursive pointers as well.


def _construct(self):
# TODO: FIXME This causes a recursion error when loading a cached
# object if we call the cached object's `construct``. Some refactoring
# is needed to fix this.
return self.cached.construct()


NODE_TYPE_MAPPING.update({"CachedNode": CachedNode})
adrinjalali marked this conversation as resolved.
Show resolved Hide resolved


def get_tree(state, load_context: LoadContext):
"""Get the tree of nodes.

This function returns the root node of the tree of nodes. The tree is
constructed recursively by traversing the state tree. No instances are
created during this process. One would need to call ``construct`` on the
root node to create the instances.

This function also handles memoization of the nodes. If a node has already
been created, it is returned instead of creating a new one.

Parameters
----------
state : dict
The state of the dumped object.

load_context : LoadContext
The context of the loading process.
"""
saved_id = state.get("__id__")
if saved_id in load_context.memo:
# an instance has already been loaded, just return the loaded instance
return load_context.get_instance(saved_id)

if state.get("is_json"):
loaded_obj = json.loads(state["content"])
else:
try:
get_instance_func = GET_INSTANCE_MAPPING[state["__loader__"]]
except KeyError:
type_name = f"{state['__module__']}.{state['__class__']}"
raise TypeError(
f" Can't find loader {state['__loader__']} for type {type_name}."
)

loaded_obj = get_instance_func(state, load_context)

# hold reference to obj in case same instance encountered again in save state
if saved_id:
load_context.memoize(loaded_obj, saved_id)

return loaded_obj
# The Node is already loaded. We return the node. Note that the node is
adrinjalali marked this conversation as resolved.
Show resolved Hide resolved
# not constructed at this point. It will be constructed when the parent
# node's ``construct`` method is called, and for this node it'll be
# called more than once. But that's not an issue since the node's
# ``construct`` method caches the instance.
return load_context.get_object(saved_id)

try:
node_cls = NODE_TYPE_MAPPING[state["__loader__"]]
except KeyError:
type_name = f"{state['__module__']}.{state['__class__']}"
raise TypeError(
f" Can't find loader {state['__loader__']} for type {type_name}."
)

loaded_tree = node_cls(state, load_context, trusted=False) # type: ignore

return loaded_tree
Loading