diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 56e3867..d84645a 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -12,10 +12,14 @@ """Define the CEBRA model.""" import copy -from typing import Callable, Iterable, List, Literal, Optional, Tuple, Union +import itertools +import warnings +from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple, + Union) import numpy as np import numpy.typing as npt +import pkg_resources import sklearn.utils.validation as sklearn_utils_validation import torch from sklearn.base import BaseEstimator @@ -56,7 +60,7 @@ def _init_loader( algorithm, which might (depending on the arguments for this function) be passed to the data loader. - Raises + Raises: ValueError: If an argument is missing in ``extra_kwargs`` or ``shared_kwargs`` needed to run the requested configuration. NotImplementedError: If the requested combinations of arguments is not yet @@ -64,7 +68,7 @@ def _init_loader( is implemented in :py:mod:`cebra.data`, and consider using the CEBRA PyTorch API directly. - Returns + Returns: the data loader and name of a suitable solver Note: @@ -260,6 +264,94 @@ def _require_arg(key): f"information to your bug report: \n" + error_message) +def _check_type_checkpoint(checkpoint): + if not isinstance(checkpoint, cebra.CEBRA): + raise RuntimeError("Model loaded from file is not compatible with " + "the current CEBRA version.") + if not sklearn_utils.check_fitted(checkpoint): + raise ValueError( + "CEBRA model is not fitted. Loading it is not supported.") + + return checkpoint + + +def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": + """Loads a CEBRA model with a Sklearn backend. + + Args: + cebra_info: A dictionary containing information about the CEBRA object, + including the arguments, the state of the object and the state + dictionary of the model. + + Returns: + The loaded CEBRA object. 
+ + Raises: + ValueError: If the loaded CEBRA model was not already fit, indicating that loading it is not supported. + """ + required_keys = ['args', 'state', 'state_dict'] + missing_keys = [key for key in required_keys if key not in cebra_info] + if missing_keys: + raise ValueError( + f"Missing keys in data dictionary: {', '.join(missing_keys)}. " + f"You can try loading the CEBRA model with the torch backend.") + + args, state, state_dict = cebra_info['args'], cebra_info[ + 'state'], cebra_info['state_dict'] + cebra_ = cebra.CEBRA(**args) + + for key, value in state.items(): + setattr(cebra_, key, value) + + state_and_args = {**args, **state} + + if not sklearn_utils.check_fitted(cebra_): + raise ValueError( + "CEBRA model was not already fit. Loading it is not supported.") + + if cebra_.num_sessions_ is None: + model = cebra.models.init( + args["model_architecture"], + num_neurons=state["n_features_in_"], + num_units=args["num_hidden_units"], + num_output=args["output_dimension"], + ).to(state['device_']) + + elif isinstance(cebra_.num_sessions_, int): + model = nn.ModuleList([ + cebra.models.init( + args["model_architecture"], + num_neurons=n_features, + num_units=args["num_hidden_units"], + num_output=args["output_dimension"], + ) for n_features in state["n_features_in_"] + ]).to(state['device_']) + + criterion = cebra_._prepare_criterion() + criterion.to(state['device_']) + + optimizer = torch.optim.Adam( + itertools.chain(model.parameters(), criterion.parameters()), + lr=args['learning_rate'], + **dict(args['optimizer_kwargs']), + ) + + solver = cebra.solver.init( + state['solver_name_'], + model=model, + criterion=criterion, + optimizer=optimizer, + tqdm_on=args['verbose'], + ) + solver.load_state_dict(state_dict) + solver.to(state['device_']) + + cebra_.model_ = model + cebra_.solver_ = solver + + return cebra_ + + class CEBRA(BaseEstimator, TransformerMixin): """CEBRA model defined as part of a ``scikit-learn``-like API. 
@@ -735,16 +827,16 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None): """ n_idx = len(y) # Check that same number of index - if len(self._label_types) != n_idx: + if len(self.label_types_) != n_idx: raise ValueError( f"Number of index invalid: labels must have the same number of index as for fitting," - f"expects {len(self._label_types)}, got {n_idx} idx.") + f"expects {len(self.label_types_)}, got {n_idx} idx.") - for i in range(len(self._label_types)): # for each index + for i in range(len(self.label_types_)): # for each index if self.num_sessions is None: - label_types_idx = self._label_types[i] + label_types_idx = self.label_types_[i] else: - label_types_idx = self._label_types[i][session_id] + label_types_idx = self.label_types_[i][session_id] if (len(label_types_idx[1]) > 1 and len(y[i].shape) > 1): # is there more than one feature in the index @@ -794,7 +886,7 @@ def _prepare_fit( criterion = self._prepare_criterion() criterion.to(self.device_) optimizer = torch.optim.Adam( - list(model.parameters()) + list(criterion.parameters()), + itertools.chain(model.parameters(), criterion.parameters()), lr=self.learning_rate, **dict(self.optimizer_kwargs), ) @@ -807,8 +899,9 @@ def _prepare_fit( tqdm_on=self.verbose, ) solver.to(self.device_) + self.solver_name_ = solver_name - self._label_types = ([[(y_session.dtype, y_session.shape) + self.label_types_ = ([[(y_session.dtype, y_session.shape) for y_session in y_index] for y_index in y] if is_multisession else [(y_.dtype, y_.shape) for y_ in y]) @@ -1191,19 +1284,49 @@ def _more_tags(self): # current version of CEBRA. 
return {"non_deterministic": True}
 
-    def save(self, filename: str, backend: Literal["torch"] = "torch"):
+    def _get_state(self):
+        cebra_dict = self.__dict__
+        state = {
+            'label_types_': cebra_dict['label_types_'],
+            'device_': cebra_dict['device_'],
+            'n_features_': cebra_dict['n_features_'],
+            'n_features_in_': cebra_dict['n_features_in_'],
+            'num_sessions_': cebra_dict['num_sessions_'],
+            'offset_': cebra_dict['offset_'],
+            'solver_name_': cebra_dict['solver_name_'],
+        }
+        return state
+
+    def save(self,
+             filename: str,
+             backend: Literal["torch", "sklearn"] = "sklearn"):
         """Save the model to disk.
 
         Args:
             filename: The path to the file in which to save the trained model.
-            backend: A string identifying the used backend.
+            backend: A string identifying the used backend. Default is "sklearn".
 
         Returns:
             The saved model checkpoint.
 
         Note:
-            Experimental functionality. Do not expect the save/load functionalities to be
-            backward compatible yet between CEBRA versions!
+            The save/load functionalities may change in a future version.
+
+        File Format:
+            The saved model checkpoint file format depends on the specified backend.
+
+            "sklearn" backend (default):
+                The model is saved in a PyTorch-compatible format using `torch.save`. The saved checkpoint
+                is a dictionary containing the following elements:
+                - 'args': A dictionary of parameters used to initialize the CEBRA model.
+                - 'state': The state of the CEBRA model, which includes various internal attributes.
+                - 'state_dict': The state dictionary of the underlying solver used by CEBRA.
+                - 'metadata': Additional metadata about the saved model, including the backend used and the version of CEBRA, PyTorch, NumPy and scikit-learn.
+
+            "torch" backend:
+                The model is directly saved using `torch.save` with no additional information. The saved
+                file contains the entire CEBRA model state.
+ Example: @@ -1216,15 +1339,41 @@ def save(self, filename: str, backend: Literal["torch"] = "torch"): >>> cebra_model.save('/tmp/foo.pt') """ - if backend != "torch": - raise NotImplementedError(f"Unsupported backend: {backend}") - checkpoint = torch.save(self, filename) + if sklearn_utils.check_fitted(self): + if backend == "torch": + checkpoint = torch.save(self, filename) + + elif backend == "sklearn": + checkpoint = torch.save( + { + 'args': self.get_params(), + 'state': self._get_state(), + 'state_dict': self.solver_.state_dict(), + 'metadata': { + 'backend': + backend, + 'cebra_version': + cebra.__version__, + 'torch_version': + torch.__version__, + 'numpy_version': + np.__version__, + 'sklearn_version': + pkg_resources.get_distribution("scikit-learn" + ).version + } + }, filename) + else: + raise NotImplementedError(f"Unsupported backend: {backend}") + else: + raise ValueError("CEBRA object is not fitted. " + "Saving a non-fitted model is not supported.") return checkpoint @classmethod def load(cls, filename: str, - backend: Literal["torch"] = "torch", + backend: Literal["auto", "sklearn", "torch"] = "auto", **kwargs) -> "CEBRA": """Load a model from disk. @@ -1240,6 +1389,8 @@ def load(cls, Experimental functionality. Do not expect the save/load functionalities to be backward compatible yet between CEBRA versions! + For information about the file format please refer to :py:meth:`cebra.CEBRA.save`. + Example: >>> import cebra @@ -1249,13 +1400,32 @@ def load(cls, >>> embedding = loaded_model.transform(dataset) """ - if backend != "torch": - raise NotImplementedError(f"Unsupported backend: {backend}") - model = torch.load(filename, **kwargs) - if not isinstance(model, cls): - raise RuntimeError("Model loaded from file is not compatible with " - "the current CEBRA version.") - return model + + supported_backends = ["auto", "sklearn", "torch"] + if backend not in supported_backends: + raise NotImplementedError( + f"Unsupported backend: '{backend}'. 
Supported backends are: {', '.join(supported_backends)}"
+            )
+
+        checkpoint = torch.load(filename, **kwargs)
+
+        if backend == "auto":
+            backend = "sklearn" if isinstance(checkpoint, dict) else "torch"
+
+        if isinstance(checkpoint, dict) and backend == "torch":
+            raise RuntimeError(
+                f"Cannot use 'torch' backend with a dictionary-based checkpoint. "
+                f"Please try a different backend.")
+        if not isinstance(checkpoint, dict) and backend == "sklearn":
+            raise RuntimeError(f"Cannot use 'sklearn' backend with a non-dictionary-based checkpoint. "
+                               f"Please try a different backend.")
+
+        if backend == "sklearn":
+            cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
+        else:
+            cebra_ = _check_type_checkpoint(checkpoint)
+
+        return cebra_
 
     def to(self, device: Union[str, torch.device]):
         """Moves the cebra model to the specified device.
@@ -1282,7 +1452,7 @@ def to(self, device: Union[str, torch.device]):
             raise TypeError(
                 "The 'device' parameter must be a string or torch.device object."
             )
-
+
         if isinstance(device, str):
             if (not device == 'cpu') and (not device.startswith('cuda')) and (
                     not device == 'mps'):
@@ -1292,7 +1462,8 @@ def to(self, device: Union[str, torch.device]):
 
         elif isinstance(device, torch.device):
             if (not device.type == 'cpu') and (
-                    not device.type.startswith('cuda')) and (not device == 'mps'):
+                    not device.type.startswith('cuda')) and (not device.type
+                                                             == 'mps'):
                 raise ValueError(
                     "The 'device' parameter must be a valid device string or device object."
) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index ff07979..f60a834 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -19,6 +19,7 @@ import pytest import sklearn.utils.estimator_checks import torch +import torch.nn as nn import cebra.data as cebra_data import cebra.helper @@ -26,6 +27,8 @@ import cebra.integrations.sklearn.dataset as cebra_sklearn_dataset import cebra.integrations.sklearn.utils as cebra_sklearn_utils import cebra.models +import cebra.models.model +from cebra.models import parametrize if torch.cuda.is_available(): _DEVICES = "cpu", "cuda" @@ -805,23 +808,23 @@ def _assert_same_state_dict(first, second): assert first[key] == second[key] -def _assert_equal(original_model, loaded_model): - assert original_model.get_params() == loaded_model.get_params() +def check_if_fit(model): + """Check if a model was already fit. - def check_fitted(model): - """Check if a model is fitted. + Args: + model: The model to check. - Args: - model: The model to assess. + Returns: + True if the model was already fit. + """ + return hasattr(model, "n_features_") - Returns: - True if fitted. 
-        """
-        return hasattr(model, "n_features_")
-    assert check_fitted(loaded_model) == check_fitted(original_model)
+
+def _assert_equal(original_model, loaded_model):
+    assert original_model.get_params() == loaded_model.get_params()
+    assert check_if_fit(loaded_model) == check_if_fit(original_model)
 
-    if check_fitted(loaded_model):
+    if check_if_fit(loaded_model):
         _assert_same_state_dict(original_model.state_dict_,
                                 loaded_model.state_dict_)
         X = np.random.normal(0, 1, (100, 1))
@@ -834,17 +837,72 @@ def check_fitted(model):
                             original_model.transform(X))
 
 
+@parametrize(
+    "parametrized-model-{output_dim}",
+    output_dim=(5, 10),
+)
+class ParametrizedModelExample(cebra.models.model._OffsetModel):
+    """CEBRA model with a single sample receptive field, without output normalization."""
+
+    def __init__(
+        self,
+        num_neurons,
+        num_units,
+        num_output,
+        output_dim,
+        normalize=False,
+    ):
+        super().__init__(
+            nn.Flatten(start_dim=1, end_dim=-1),
+            nn.Linear(
+                num_neurons,
+                output_dim,
+            ),
+            num_input=num_neurons,
+            num_output=num_output,
+            normalize=normalize,
+        )
+
+    def get_offset(self) -> cebra.data.datatypes.Offset:
+        """See :py:meth:`~.Model.get_offset`"""
+        return cebra.data.Offset(0, 1)
+
 @pytest.mark.parametrize("action", _iterate_actions())
-def test_save_and_load(action):
-    model_architecture = "offset10-model"
+@pytest.mark.parametrize("backend_save", ["torch", "sklearn"])
+@pytest.mark.parametrize("backend_load", ["auto", "torch", "sklearn"])
+@pytest.mark.parametrize("model_architecture",
+                         ["offset1-model", "parametrized-model-5"])
+@pytest.mark.parametrize("device", ["cpu"] +
+                         (["cuda"] if torch.cuda.is_available() else []))
+def test_save_and_load(action, backend_save, backend_load, model_architecture,
+                       device):
     original_model = cebra_sklearn_cebra.CEBRA(
-        model_architecture=model_architecture, max_iterations=5, batch_size=42)
+        model_architecture=model_architecture,
+        max_iterations=5,
+        batch_size=100,
+        device=device)
+
+    original_model = 
action(original_model) with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile: - original_model.save(savefile.name) - loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name) - _assert_equal(original_model, loaded_model) - + if not check_if_fit(original_model): + with pytest.raises(ValueError): + original_model.save(savefile.name, backend=backend_save) + else: + if "parametrized" in original_model.model_architecture and backend_save == "torch": + with pytest.raises(AttributeError): + original_model.save(savefile.name, backend=backend_save) + else: + original_model.save(savefile.name, backend=backend_save) + + if (backend_load != "auto") and (backend_save != backend_load): + with pytest.raises(RuntimeError): + cebra_sklearn_cebra.CEBRA.load(savefile.name, + backend_load) + else: + loaded_model = cebra_sklearn_cebra.CEBRA.load( + savefile.name, backend_load) + _assert_equal(original_model, loaded_model) + action(loaded_model) def get_ordered_cuda_devices(): available_devices = ['cuda']