Skip to content

Commit

Permalink
Add support for BallTree, BinaryTree (skops-dev#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminBossan authored Sep 15, 2022
1 parent 45e8e7b commit 1bfffc9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 42 deletions.
19 changes: 6 additions & 13 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def dict_get_state(obj, dst):
key_types = _get_state([type(key) for key in obj.keys()], dst)
content = {}
for key, value in obj.items():
if isinstance(value, property):
continue
if np.isscalar(key) and hasattr(key, "item"):
# convert numpy value to python object
key = key.item()
Expand Down Expand Up @@ -194,16 +196,10 @@ def object_get_state(obj, dst):
else:
return res

content = {}
for key, value in attrs.items():
if isinstance(getattr(type(obj), key, None), property):
continue
if key == "C":
pass
content[key] = _get_state(value, dst)

content = _get_state(attrs, dst)
# it's sufficient to store the "content" because we know that this dict can
# only have str type keys
res["content"] = content

return res


Expand All @@ -224,10 +220,7 @@ def object_get_instance(state, src):
if not len(content):
return instance

attrs = {}
for key, value in content.items():
attrs[key] = _get_instance(value, src)

attrs = _get_instance(content, src)
if hasattr(instance, "__setstate__"):
instance.__setstate__(attrs)
else:
Expand Down
43 changes: 15 additions & 28 deletions skops/io/_sklearn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import inspect

from sklearn.covariance._graph_lasso import _DictWithDeprecatedKeys
from sklearn.linear_model._sgd_fast import (
EpsilonInsensitive,
Expand All @@ -16,14 +14,7 @@
from sklearn.utils import Bunch

from ._general import dict_get_instance, dict_get_state
from ._utils import (
_get_instance,
_get_state,
get_instance,
get_module,
get_state,
gettype,
)
from ._utils import _get_instance, _get_state, get_module, gettype

ALLOWED_SGD_LOSSES = {
ModifiedHuber,
Expand All @@ -42,7 +33,7 @@ def reduce_get_state(obj, dst):
# method to get the state.
res = {
"__class__": obj.__class__.__name__,
"__module__": inspect.getmodule(type(obj)).__name__,
"__module__": get_module(type(obj)),
}

# We get the output of __reduce__ and use it to reconstruct the object.
Expand All @@ -61,7 +52,7 @@ def reduce_get_state(obj, dst):
# As a good example, this makes Tree object to be serializable.
reduce = obj.__reduce__()
res["__reduce__"] = {}
res["__reduce__"]["args"] = get_state(reduce[1], dst)
res["__reduce__"]["args"] = _get_state(reduce[1], dst)

if len(reduce) == 3:
# reduce includes what's needed for __getstate__ and we don't need to
Expand All @@ -72,31 +63,27 @@ def reduce_get_state(obj, dst):
elif hasattr(obj, "__dict__"):
attrs = obj.__dict__
else:
return res
attrs = {}

content = {}
for key, value in attrs.items():
if isinstance(getattr(type(obj), key, None), property):
continue
content[key] = _get_state(value, dst)

res["content"] = content
if not isinstance(attrs, dict):
raise TypeError(f"Objects of type {res['__class__']} not supported yet")

res["content"] = {"attrs": _get_state(attrs, dst)}
return res


def reduce_get_instance(state, src, constructor):
reduce = state["__reduce__"]
args = get_instance(reduce["args"], src)
args = _get_instance(reduce["args"], src)
instance = constructor(*args)

if "content" not in state:
attrs = _get_instance(state["content"]["attrs"], src)
if not attrs:
# nothing more to do
return instance

content = state["content"]
attrs = {}
for key, value in content.items():
attrs[key] = _get_instance(value, src)
if isinstance(args, tuple) and not hasattr(instance, "__setstate__"):
raise TypeError(f"Objects of type {constructor} are not supported yet")

if hasattr(instance, "__setstate__"):
instance.__setstate__(attrs)
Expand All @@ -107,14 +94,14 @@ def reduce_get_instance(state, src, constructor):


def Tree_get_instance(state, src):
return reduce_get_instance(state, src, Tree)
return reduce_get_instance(state, src, constructor=Tree)


def sgd_loss_get_instance(state, src):
cls = gettype(state)
if cls not in ALLOWED_SGD_LOSSES:
raise TypeError(f"Expected LossFunction, got {cls}")
return reduce_get_instance(state, src, cls)
return reduce_get_instance(state, src, constructor=cls)


def bunch_get_instance(state, src):
Expand Down
9 changes: 8 additions & 1 deletion skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
StratifiedGroupKFold,
check_cv,
)
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.preprocessing import (
FunctionTransformer,
Expand Down Expand Up @@ -136,6 +137,9 @@ def _tested_estimators(type_filter=None):
inverse_func=partial(np.add, -10),
)

yield KNeighborsClassifier(algorithm="kd_tree")
yield KNeighborsRegressor(algorithm="ball_tree")

yield ColumnTransformer(
[
("norm1", Normalizer(norm="l1"), [0]),
Expand Down Expand Up @@ -545,7 +549,10 @@ def fit(self, X, y=None, **fit_params):
},
}
# check both the top level state and the nested state
states = schema["content"], schema["content"]["nested_"]["content"]
states = (
schema["content"]["content"],
schema["content"]["content"]["nested_"]["content"],
)
for key, val_expected in expected.items():
for state in states:
val_state = state[key]
Expand Down

0 comments on commit 1bfffc9

Please sign in to comment.