Skip to content

Commit

Permalink
Merge branch 'persist' of github.com:adrinjalali/skops into persist
Browse files Browse the repository at this point in the history
  • Loading branch information
adrinjalali committed Sep 15, 2022
2 parents 7266673 + ff109a5 commit 45e8e7b
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 69 deletions.
69 changes: 69 additions & 0 deletions skops/io/_general.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from functools import partial
from types import FunctionType

Expand Down Expand Up @@ -169,6 +170,72 @@ def slice_get_instance(obj, src):
return slice(start, stop, step)


def object_get_state(obj, dst):
# This method is for objects which can either be persisted with json, or
# the ones for which we can get/set attributes through
# __getstate__/__setstate__ or reading/writing to __dict__.
try:
# if we can simply use json, then we're done.
return json.dumps(obj)
except Exception:
pass

res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
}

# __getstate__ takes priority over __dict__, and if non exist, we only save
# the type of the object, and loading would mean instantiating the object.
if hasattr(obj, "__getstate__"):
attrs = obj.__getstate__()
elif hasattr(obj, "__dict__"):
attrs = obj.__dict__
else:
return res

content = {}
for key, value in attrs.items():
if isinstance(getattr(type(obj), key, None), property):
continue
if key == "C":
pass
content[key] = _get_state(value, dst)

res["content"] = content

return res


def object_get_instance(state, src):
try:
return json.loads(state)
except Exception:
pass

cls = gettype(state)

# Instead of simply constructing the instance, we use __new__, which
# bypasses the __init__, and then we set the attributes. This solves
# the issue of required init arguments.
instance = cls.__new__(cls)

content = state.get("content", {})
if not len(content):
return instance

attrs = {}
for key, value in content.items():
attrs[key] = _get_instance(value, src)

if hasattr(instance, "__setstate__"):
instance.__setstate__(attrs)
else:
instance.__dict__.update(attrs)

return instance


# tuples of type and function that gets the state of that type
GET_STATE_DISPATCH_FUNCTIONS = [
(dict, dict_get_state),
Expand All @@ -178,6 +245,7 @@ def slice_get_instance(obj, src):
(FunctionType, function_get_state),
(partial, partial_get_state),
(type, type_get_state),
(object, object_get_state),
]
# tuples of type and function that creates the instance of that type
GET_INSTANCE_DISPATCH_FUNCTIONS = [
Expand All @@ -188,4 +256,5 @@ def slice_get_instance(obj, src):
(FunctionType, function_get_instance),
(partial, partial_get_instance),
(type, type_get_instance),
(object, object_get_instance),
]
69 changes: 0 additions & 69 deletions skops/io/_sklearn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
import json

from sklearn.covariance._graph_lasso import _DictWithDeprecatedKeys
from sklearn.linear_model._sgd_fast import (
Expand Down Expand Up @@ -38,72 +37,6 @@
}


def generic_get_state(obj, dst):
# This method is for objects which can either be persisted with json, or
# the ones for which we can get/set attributes through
# __getstate__/__setstate__ or reading/writing to __dict__.
try:
# if we can simply use json, then we're done.
return json.dumps(obj)
except Exception:
pass

res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
}

# __getstate__ takes priority over __dict__, and if non exist, we only save
# the type of the object, and loading would mean instantiating the object.
if hasattr(obj, "__getstate__"):
attrs = obj.__getstate__()
elif hasattr(obj, "__dict__"):
attrs = obj.__dict__
else:
return res

content = {}
for key, value in attrs.items():
if isinstance(getattr(type(obj), key, None), property):
continue
if key == "C":
pass
content[key] = _get_state(value, dst)

res["content"] = content

return res


def generic_get_instance(state, src):
try:
return json.loads(state)
except Exception:
pass

cls = gettype(state)

# Instead of simply constructing the instance, we use __new__, which
# bypasses the __init__, and then we set the attributes. This solves
# the issue of required init arguments.
instance = cls.__new__(cls)

content = state.get("content", {})
if not len(content):
return instance

attrs = {}
for key, value in content.items():
attrs[key] = _get_instance(value, src)

if hasattr(instance, "__setstate__"):
instance.__setstate__(attrs)
else:
instance.__dict__.update(attrs)

return instance


def reduce_get_state(obj, dst):
# This method is for objects for which we have to use the __reduce__
# method to get the state.
Expand Down Expand Up @@ -220,13 +153,11 @@ def _DictWithDeprecatedKeys_get_instance(state, src):
(LossFunction, reduce_get_state),
(Tree, reduce_get_state),
(_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_state),
(object, generic_get_state),
]
# tuples of type and function that creates the instance of that type
GET_INSTANCE_DISPATCH_FUNCTIONS = [
(LossFunction, sgd_loss_get_instance),
(Tree, Tree_get_instance),
(Bunch, bunch_get_instance),
(_DictWithDeprecatedKeys, _DictWithDeprecatedKeys_get_instance),
(object, generic_get_instance),
]

0 comments on commit 45e8e7b

Please sign in to comment.