Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: get_instance method saved in state #200

Conversation

BenjaminBossan
Copy link
Collaborator

Resolves #197

Description

Currently, during the dispatch of the get_instance functions, the class stored in the state is being loaded to determine which function to dispatch to. This is bad because module loading can be dangerous. We will add auditing but it is intended to be on the level of get_instance itself, not for the dispatch mechanism.

In this PR, the state returned by get_state functions is augmented with the name of the get_instance method required to load the object. This way, we can look up the correct method based on the state and don't need to use the modified singledispatch mechanism, thus avoiding loading modules during dispatching.

Implementation

Whereas for get_state, we still rely in singledispatch, for get_instance we now use a simple dictionary that looks up the function based on its name (which in turn is stored in the state). The dictionary, going by the name of GET_INSTANCE_MAPPING, is populated similarly to how the get_instance functions were registered previously with singledispatch.

There was an issue with circular imports (e.g. get_instance > GET_INSTANCE_MAPPING > ndarray_get_instance > get_instance), hence the get_instance function was moved to its own module, _dispatch.py (better name suggestions are welcome).

For some types, we now need extra get_state functions because they have specific get_instance methods. So e.g. sgd_loss_get_state just wraps reduce_get_state and adds sgd_loss_get_instance as its loader.

Coincidental changes

Since we no longer have to inspect the contents of state to determine the function to dispatch to for get_instance, we can fall back to the Python implementation of singledispatch instead of rolling our own. This side effect is a big win.

The function Tree_get_instance was renamed to tree_get_instance for consistency.

In the debug_dispatch_functions, there was some code from a time when the state was allowed not to be a dict (json-serializable objects). Now we always have a dict, so this dead code was removed.

Also, this fixture was elevated to module-level scope, since it only needs to run once.

Resolves skops-dev#197

Description

Currently, during the dispatch of the get_instance functions, the class
stored in the state is being loaded to determine which function to
dispatch to. This is bad because module loading can be dangerous. We
will add auditing but it is intended to be on the level of
get_instance itself, not for the dispatch mechanism.

In this PR, the state returned by get_state functions is augmented with
the name of the get_instance method required to load the object. This
way, we can look up the correct method based on the state and don't need
to use the modified singledispatch mechanism, thus avoiding loading
modules during dispatching.

Implementation

Whereas for get_state, we still rely in singledispatch, for get_instance
we now use a simple dictionary that looks up the function based on its
name (which in turn is stored in the state). The dictionary, going by
the name of GET_INSTANCE_MAPPING, is populated similarly to how the
get_instance functions were registered previously with singledispatch.

There was an issue with circular imports (e.g. get_instance >
GET_INSTANCE_MAPPING > ndarray_get_instance > get_instance), hence the
get_instance function was moved to its own module, _dispatch.py (other
name suggestions welcome).

For some types, we now need extra get_state functions because they
have specific get_instance methods. So e.g. sgd_loss_get_state just
wraps reduce_get_state and adds sgd_loss_get_instance as its loader.

Coincidental changes

Since we no longer have to inspect the contents of state to determine
the function to dispatch to for get_instance, we can fall back to the
Python implementation of singledispatch instead of rolling our own. This
side effect is a big win.

The function Tree_get_instance was renamed to tree_get_instance for
consistency.

In the debug_dispatch_functions, there was some code from a time when
the state was allowed not to be a dict (json-serializable objects). Now
we always have a dict, so this dead code was removed.

Also, this fixture was elevated to module-level scope, since it only
needs to run once.
@BenjaminBossan
Copy link
Collaborator Author

@skops-dev/maintainers ready for review

Tests pass without it.
@BenjaminBossan
Copy link
Collaborator Author

Addendum: I noticed (thanks codecov) that I had accidentally not registered bunch_get_state. Therefore, bunch_get_instance also was not used. Still, the tests passed. At first, I thought we might not be testing correctly but Bunch attributes are indeed loaded as Bunch objects. Therefore, it seems we can remove the special treatment for Bunch.

I'm not exactly sure which change we did that resulted in this, since I'm pretty sure it used to be necessary, but it's not anymore. Therefore, I removed the corresponding code.

Copy link
Member

@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks really nice!

from typing import Any, Callable
from zipfile import ZipFile

GET_INSTANCE_MAPPING: dict[str, Callable[[dict[str, Any], ZipFile], Any]] = {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we not have the type here? we're gonna change signature later and we'd have to keep these in sync.

get_instance_func = GET_INSTANCE_MAPPING[state["__loader__"]]
except KeyError:
raise TypeError(
f"Creating an instance of type {type(state)} is not supported yet"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type of state is a dict. We should say something like:

f" Can't find loader {state["__loader__"]} for type {state[__module__]}.{state["__class__"]}."

@@ -21,10 +21,10 @@
SquaredLoss,
)
from sklearn.tree._tree import Tree
from sklearn.utils import Bunch
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is because we load and create the class of the object whatever it was now in dict_get_instance instead of creating a dict.

- Remove type annotation for GET_INSTANCE_MAPPING
- Better error message if loader not found

Also changed:

- Misleading var names in get_state
@BenjaminBossan
Copy link
Collaborator Author

@adrinjalali I addressed your comments, also added a test for the error message.

Copy link
Member

@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm mostly happy about the code we're removing here lol

@adrinjalali adrinjalali merged commit cdf5d57 into skops-dev:main Oct 24, 2022
@BenjaminBossan BenjaminBossan deleted the persist-refactor-store-loader-in-state branch October 25, 2022 08:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Refactor get_instance functions to not use dispatch
2 participants