Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve saving and loading of models #69

Merged
merged 38 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
72098be
first attempt of saving/loading models the right way
gonlairo Jun 11, 2023
ed8d477
some progress
gonlairo Jun 14, 2023
d33a790
first working proposal
gonlairo Jun 15, 2023
49c5518
small progress
gonlairo Jun 19, 2023
9805499
improve the API. Found self.get_params()
gonlairo Jun 26, 2023
fdf6fb7
simplify code + first test pass
gonlairo Jun 28, 2023
f4f9b47
fix spelling
gonlairo Jun 28, 2023
3e83688
raise ValueError + add multisession support
gonlairo Jun 30, 2023
0331c75
fix name
gonlairo Jun 30, 2023
0401c7f
fix typo
gonlairo Jun 30, 2023
945eb94
improve test + code
gonlairo Jul 4, 2023
bec32e1
organize code better + add more checks
gonlairo Jul 8, 2023
4f4e970
improve tests
gonlairo Jul 8, 2023
02bb3c0
fix import in test
gonlairo Jul 13, 2023
4efc27e
improve docs
gonlairo Jul 13, 2023
7126283
add suggestions
gonlairo Jul 13, 2023
622277a
remove comments in test
gonlairo Jul 13, 2023
8482987
Merge remote-tracking branch 'origin/public' into rodrigo/save-models
gonlairo Jul 14, 2023
efc2aa2
fix typo test
gonlairo Jul 14, 2023
d302e21
fix docs + delete unnecessary check
gonlairo Jul 17, 2023
6ce7213
remove args from state
gonlairo Jul 17, 2023
3380b9c
improvements
gonlairo Jul 20, 2023
d96b935
fix typo
gonlairo Jul 23, 2023
bb75e82
fix typo
gonlairo Jul 23, 2023
3d8f8ad
fix typo
gonlairo Jul 23, 2023
8aab47e
change _label_types to label_types_
gonlairo Jul 23, 2023
ac6fbf9
fix docs
gonlairo Jul 23, 2023
1eae71a
Merge branch 'public' into rodrigo/save-models
gonlairo Jul 24, 2023
0a3ee9b
Merge branch 'public' into save-models
stes Sep 13, 2023
a4177bf
Update cebra/integrations/sklearn/cebra.py
gonlairo Sep 13, 2023
baa4d72
Update cebra/integrations/sklearn/cebra.py
gonlairo Sep 13, 2023
067fac7
improve typing
gonlairo Sep 13, 2023
8ca848c
improve typing
gonlairo Sep 13, 2023
f0d33e6
fix typo + pre-commit
gonlairo Sep 13, 2023
12e8833
better docstrings
gonlairo Sep 18, 2023
f289e0b
fix unindent error in docs
gonlairo Sep 25, 2023
38c90a8
Merge branch 'main' into save-models
stes Sep 25, 2023
fb2bde5
Update cebra.py
MMathisLab Oct 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 188 additions & 22 deletions cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
"""Define the CEBRA model."""

import copy
import itertools
import warnings
from typing import Callable, Iterable, List, Literal, Optional, Tuple, Union
gonlairo marked this conversation as resolved.
Show resolved Hide resolved

import numpy as np
import numpy.typing as npt
import pkg_resources
import sklearn.utils.validation as sklearn_utils_validation
import torch
from sklearn.base import BaseEstimator
Expand Down Expand Up @@ -56,15 +59,15 @@ def _init_loader(
algorithm, which might (depending on the arguments for this function)
be passed to the data loader.

Raises
Raises:
ValueError: If an argument is missing in ``extra_kwargs`` or ``shared_kwargs``
needed to run the requested configuration.
NotImplementedError: If the requested combinations of arguments is not yet
implemented. If this error occurs, check if the desired functionality
is implemented in :py:mod:`cebra.data`, and consider using the CEBRA
PyTorch API directly.

Returns
Returns:
the data loader and name of a suitable solver

Note:
Expand Down Expand Up @@ -260,6 +263,92 @@ def _require_arg(key):
f"information to your bug report: \n" + error_message)


def _check_type_checkpoint(checkpoint):
    """Validate that a deserialized checkpoint is a fitted CEBRA model.

    Args:
        checkpoint: The object obtained by deserializing a saved model file.

    Returns:
        The validated checkpoint, unchanged.

    Raises:
        RuntimeError: If ``checkpoint`` is not a :py:class:`cebra.CEBRA` instance.
        ValueError: If the model was saved before being fitted.
    """
    is_cebra_model = isinstance(checkpoint, cebra.CEBRA)
    if not is_cebra_model:
        raise RuntimeError("Model loaded from file is not compatible with "
                           "the current CEBRA version.")

    is_fitted = sklearn_utils.check_fitted(checkpoint)
    if not is_fitted:
        raise ValueError(
            "CEBRA model is not fitted. Loading it is not supported.")

    return checkpoint


def _load_cebra_with_sklearn_backend(cebra_info: dict) -> "CEBRA":
    """Loads a CEBRA model with a Sklearn backend.

    Args:
        cebra_info: A dictionary containing the ``args``, ``state`` and
            ``state_dict`` of a saved CEBRA model (see :py:meth:`CEBRA.save`).

    Returns:
        The loaded CEBRA object.

    Raises:
        ValueError: If required keys are missing from ``cebra_info``, if the
            loaded CEBRA model is not fitted, or if the restored
            ``num_sessions_`` attribute has an unexpected type.
    """
    required_keys = ['args', 'state', 'state_dict']
    missing_keys = [key for key in required_keys if key not in cebra_info]
    if missing_keys:
        raise ValueError(
            f"Missing keys in data dictionary: {', '.join(missing_keys)}. "
            f"You can try loading the CEBRA model with a different backend.")

    args, state, state_dict = cebra_info['args'], cebra_info[
        'state'], cebra_info['state_dict']
    cebra_ = cebra.CEBRA(**args)

    # Restore the fitted attributes (label_types_, device_, ...) first; the
    # fitted-check and the model reconstruction below rely on them.
    for key, value in state.items():
        setattr(cebra_, key, value)

    if not sklearn_utils.check_fitted(cebra_):
        raise ValueError(
            "CEBRA model is not fitted. Loading it is not supported.")

    if cebra_.num_sessions_ is None:
        # Single-session model: one network over all input features.
        model = cebra.models.init(
            args["model_architecture"],
            num_neurons=state["n_features_in_"],
            num_units=args["num_hidden_units"],
            num_output=args["output_dimension"],
        ).to(state['device_'])
    elif isinstance(cebra_.num_sessions_, int):
        # Multi-session model: one network per session.
        model = nn.ModuleList([
            cebra.models.init(
                args["model_architecture"],
                num_neurons=n_features,
                num_units=args["num_hidden_units"],
                num_output=args["output_dimension"],
            ) for n_features in state["n_features_in_"]
        ]).to(state['device_'])
    else:
        # Previously this case fell through and left ``model`` unbound,
        # producing a confusing NameError further down.
        raise ValueError(
            f"Invalid 'num_sessions_' in the saved model: "
            f"expected None or int, got {type(cebra_.num_sessions_).__name__}."
        )

    criterion = cebra_._prepare_criterion()
    criterion.to(state['device_'])

    optimizer = torch.optim.Adam(
        itertools.chain(model.parameters(), criterion.parameters()),
        lr=args['learning_rate'],
        **dict(args['optimizer_kwargs']),
    )

    solver = cebra.solver.init(
        state['solver_name_'],
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        tqdm_on=args['verbose'],
    )
    solver.load_state_dict(state_dict)
    solver.to(state['device_'])

    cebra_.model_ = model
    cebra_.solver_ = solver

    return cebra_


class CEBRA(BaseEstimator, TransformerMixin):
"""CEBRA model defined as part of a ``scikit-learn``-like API.

Expand Down Expand Up @@ -735,16 +824,16 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
"""
n_idx = len(y)
# Check that same number of index
if len(self._label_types) != n_idx:
if len(self.label_types_) != n_idx:
raise ValueError(
f"Number of index invalid: labels must have the same number of index as for fitting,"
f"expects {len(self._label_types)}, got {n_idx} idx.")
f"expects {len(self.label_types_)}, got {n_idx} idx.")

for i in range(len(self._label_types)): # for each index
for i in range(len(self.label_types_)): # for each index
if self.num_sessions is None:
label_types_idx = self._label_types[i]
label_types_idx = self.label_types_[i]
else:
label_types_idx = self._label_types[i][session_id]
label_types_idx = self.label_types_[i][session_id]

if (len(label_types_idx[1]) > 1 and len(y[i].shape)
> 1): # is there more than one feature in the index
Expand Down Expand Up @@ -794,7 +883,7 @@ def _prepare_fit(
criterion = self._prepare_criterion()
criterion.to(self.device_)
optimizer = torch.optim.Adam(
list(model.parameters()) + list(criterion.parameters()),
itertools.chain(model.parameters(), criterion.parameters()),
lr=self.learning_rate,
**dict(self.optimizer_kwargs),
)
Expand All @@ -807,8 +896,9 @@ def _prepare_fit(
tqdm_on=self.verbose,
)
solver.to(self.device_)
self.solver_name_ = solver_name

self._label_types = ([[(y_session.dtype, y_session.shape)
self.label_types_ = ([[(y_session.dtype, y_session.shape)
for y_session in y_index]
for y_index in y] if is_multisession else
[(y_.dtype, y_.shape) for y_ in y])
Expand Down Expand Up @@ -1191,12 +1281,25 @@ def _more_tags(self):
# current version of CEBRA.
return {"non_deterministic": True}

def save(self, filename: str, backend: Literal["torch"] = "torch"):
def _get_state(self):
cebra_dict = self.__dict__
state = {
'label_types_': cebra_dict['label_types_'],
'device_': cebra_dict['device_'],
'n_features_': cebra_dict['n_features_'],
'n_features_in_': cebra_dict['n_features_in_'],
'num_sessions_': cebra_dict['num_sessions_'],
'offset_': cebra_dict['offset_'],
'solver_name_': cebra_dict['solver_name_'],
}
return state

def save(self, filename: str, backend: str = "sklearn"):
    """Save the model to disk.

    Args:
        filename: The path to the file in which to save the trained model.
        backend: A string identifying the used backend. Default is "sklearn".

    Returns:
        The return value of :py:func:`torch.save` (currently ``None``); the
        checkpoint itself is written to ``filename``.

    Raises:
        ValueError: If the model has not been fitted yet.
        NotImplementedError: If ``backend`` is neither "sklearn" nor "torch".

    Note:
        Experimental functionality. Do not expect the save/load functionalities to be
        backward compatible yet between CEBRA versions!

    File Format:
        The saved model checkpoint file format depends on the specified backend.

        "sklearn" backend (default):
            The model is saved in a PyTorch-compatible format using `torch.save`. The saved checkpoint
            is a dictionary containing the following elements:
            - 'args': A dictionary of parameters used to initialize the CEBRA model.
            - 'state': The state of the CEBRA model, which includes various internal attributes.
            - 'state_dict': The state dictionary of the underlying solver used by CEBRA.
            - 'metadata': Additional metadata about the saved model, including the backend used and the version of CEBRA PyTorch, NumPy and scikit-learn.

        "torch" backend:
            The model is directly saved using `torch.save` with no additional information. The saved
            file contains the entire CEBRA model state.

    Example:

        >>> import cebra
        ...
        >>> cebra_model.save('/tmp/foo.pt')

    """
    # Guard clause: refuse to serialize an unfitted estimator regardless of
    # the requested backend.
    if not sklearn_utils.check_fitted(self):
        raise ValueError("CEBRA object is not fitted. "
                         "Saving a non-fitted model is not supported.")

    if backend == "torch":
        # Pickle the entire estimator object as-is.
        checkpoint = torch.save(self, filename)
    elif backend == "sklearn":
        # Store args/state/state_dict separately so the model can be
        # rebuilt on load; record library versions for diagnostics.
        checkpoint = torch.save(
            {
                'args': self.get_params(),
                'state': self._get_state(),
                'state_dict': self.solver_.state_dict(),
                'metadata': {
                    'backend': backend,
                    'cebra_version': cebra.__version__,
                    'torch_version': torch.__version__,
                    'numpy_version': np.__version__,
                    'sklearn_version': pkg_resources.get_distribution(
                        "scikit-learn").version,
                },
            }, filename)
    else:
        raise NotImplementedError(f"Unsupported backend: {backend}")

    return checkpoint

@classmethod
def load(cls,
         filename: str,
         backend: Literal["auto", "sklearn", "torch"] = "auto",
         **kwargs) -> "CEBRA":
    """Load a model from disk.

    Args:
        filename: The path to the file from which to load the model.
        backend: A string identifying the used backend. With "auto", the
            backend is inferred from the checkpoint's type.
        kwargs: Optional keyword arguments passed directly to
            :py:func:`torch.load`.

    Returns:
        The loaded CEBRA object.

    Raises:
        NotImplementedError: If ``backend`` is not one of "auto", "sklearn"
            or "torch".
        RuntimeError: If the checkpoint format does not match the requested
            backend, or the loaded object is not a fitted CEBRA model.

    Note:
        Experimental functionality. Do not expect the save/load functionalities to be
        backward compatible yet between CEBRA versions!

        For information about the file format we refer to :py:meth:`CEBRA.save`.

    Example:

        >>> import cebra
        ...
        >>> loaded_model = cebra.CEBRA.load('/tmp/foo.pt')
        >>> embedding = loaded_model.transform(dataset)

    """
    supported_backends = ["auto", "sklearn", "torch"]
    if backend not in supported_backends:
        raise NotImplementedError(
            f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}"
        )

    checkpoint = torch.load(filename, **kwargs)

    # Dict checkpoints are produced by the "sklearn" backend; pickled CEBRA
    # objects are produced by the "torch" backend.
    if backend == "auto":
        backend = "sklearn" if isinstance(checkpoint, dict) else "torch"

    if isinstance(checkpoint, dict) and backend == "torch":
        raise RuntimeError(
            f"Cannot use 'torch' backend with a dictionary-based checkpoint. "
            f"Please try a different backend.")
    if not isinstance(checkpoint, dict) and backend == "sklearn":
        raise RuntimeError(f"Cannot use 'sklearn' backend. "
                           f"Please try a different backend.")

    if backend == "sklearn":
        cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
    else:
        cebra_ = _check_type_checkpoint(checkpoint)

    return cebra_

def to(self, device: Union[str, torch.device]):
"""Moves the cebra model to the specified device.
Expand Down
Loading
Loading