
Internal API

Peter Hausamann edited this page Jun 19, 2020 · 6 revisions

If you want to add new functionality, either by implementing new estimators or by wrapping existing sklearn estimators, this page will guide you through the necessary steps.

Wrapping an existing estimator

Let's say you want to wrap an estimator called AwesomeEstimator from sklearn.awesome_module that implements the usual fit, predict and score methods.

All sklearn-xarray estimators should inherit from sklearn_xarray.common.wrappers.EstimatorWrapper, which provides the fit method that every sklearn estimator must have. All other methods are optional and are provided by mixins. The constructor of the wrapper should create an instance of the wrapped estimator and pass it to the superclass constructor:

from sklearn_xarray.common.wrappers import (
    EstimatorWrapper, _ImplementsPredictMixin, _ImplementsScoreMixin)
from sklearn.awesome_module import AwesomeEstimator as _AwesomeEstimator

class AwesomeEstimator(EstimatorWrapper, _ImplementsPredictMixin,
                       _ImplementsScoreMixin):
    """ An example class demonstrating the internal API. """

    def __init__(self, param_1=None, param_2=None, reshapes=None,
                 sample_dim='sample', compat=False):
        estimator = _AwesomeEstimator(param_1=param_1, param_2=param_2)
        super(AwesomeEstimator, self).__init__(
            estimator, reshapes=reshapes, sample_dim=sample_dim, compat=compat)

Notice that the keyword arguments reshapes, sample_dim and compat are passed on to the superclass constructor as well. If the wrapped estimator changes the number of features when calling predict, you can specify reshapes='feature', but this is not strictly necessary. Check out the documentation for more details.
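The division of labour described above (a base class that stores the wrapped estimator and provides fit, plus one mixin per additional method) can be sketched in plain Python. All class names below (ToyWrapperBase, ToyPredictMixin, ToyScoreMixin, MeanRegressor, AwesomeToy) are hypothetical stand-ins, not sklearn-xarray classes, and the sketch ignores DataArray handling entirely. It only shows how the mixins share one helper that checks the estimator has been fitted and then forwards the call to a private method.

```python
class ToyWrapperBase:
    """Hypothetical stand-in for EstimatorWrapper: owns and fits the wrapped estimator."""

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y=None):
        # Fit the wrapped estimator and keep the fitted instance.
        self.estimator_ = self.estimator.fit(X, y)
        return self

    def _call_fitted(self, method, X, *args):
        # Shared helper for all mixins: refuse to run before fit(), then
        # forward to the matching private method, e.g. "predict" -> _predict.
        if not hasattr(self, "estimator_"):
            raise RuntimeError("call fit() before " + method + "()")
        return getattr(self, "_" + method)(self.estimator_, X, *args)


class ToyPredictMixin:
    """Adds predict() on top of the shared helper."""

    def _predict(self, estimator, X):
        return estimator.predict(X)

    def predict(self, X):
        return self._call_fitted("predict", X)


class ToyScoreMixin:
    """Adds score() on top of the shared helper."""

    def _score(self, estimator, X, y):
        return estimator.score(X, y)

    def score(self, X, y):
        return self._call_fitted("score", X, y)


class MeanRegressor:
    """Trivial wrapped estimator: always predicts the mean training target."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]

    def score(self, X, y):
        # Negative mean absolute error, so that higher is better.
        errors = [abs(p - t) for p, t in zip(self.predict(X), y)]
        return -sum(errors) / len(errors)


class AwesomeToy(ToyWrapperBase, ToyPredictMixin, ToyScoreMixin):
    """Same composition as AwesomeEstimator above: base class plus mixins."""

    def __init__(self):
        super().__init__(MeanRegressor())


model = AwesomeToy().fit([[0], [1], [2]], [1.0, 2.0, 3.0])
print(model.predict([[5], [6]]))            # [2.0, 2.0]
print(model.score([[0], [1]], [2.0, 4.0]))  # -1.0
```

Keeping each method in its own mixin means a wrapper only exposes the methods its wrapped estimator actually supports, which is the reason the real wrapper lists only the mixins it needs.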

Wrapping estimator methods

Now let's say AwesomeEstimator also has a do_stuff method that does not have a corresponding wrapper yet. You'll have to implement a mixin in sklearn_xarray.common.base that wraps this method. It should include a public do_stuff method as well as an internal _do_stuff method that receives the fitted estimator and a DataArray, calls the estimator's do_stuff and returns the resulting data together with the updated dims. _CommonEstimatorWrapper._call_fitted will then take care of the rest:

import numpy as np
from sklearn_xarray.common.base import _CommonEstimatorWrapper

class _ImplementsDoStuffMixin(_CommonEstimatorWrapper):

    def _do_stuff(self, estimator, X):
        """ Do stuff with ``self.estimator`` and update dims. """

        if self.sample_dim is not None:
            # transpose to sample dim first, do stuff and transpose back
            order = self._get_transpose_order(X)
            X_arr = np.transpose(X.data, order)
            X_d = estimator.do_stuff(X_arr)
            if X_d.ndim == X.ndim:
                X_d = np.transpose(X_d, np.argsort(order))
        else:
            X_d = estimator.do_stuff(X.data)

        # update dims
        dims_new = self._update_dims(X, X_d)

        return X_d, dims_new

    def do_stuff(self, X):
        """ A wrapper around the do_stuff function. """

        return self._call_fitted('do_stuff', X)
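The transpose-and-restore step in _do_stuff works because np.argsort(order) is the inverse of the permutation order. The standalone numpy sketch below uses a hypothetical array with dims ('feature', 'time', 'sample') and a hypothetical helper, transpose_order_sample_first, that mimics what _get_transpose_order is assumed to return (the sample axis moved to the front, the remaining axes kept in their original order); the actual method may differ in detail.

```python
import numpy as np


def transpose_order_sample_first(dims, sample_dim):
    """Hypothetical stand-in for _get_transpose_order: sample axis moves to the front."""
    sample_axis = dims.index(sample_dim)
    others = [i for i in range(len(dims)) if i != sample_axis]
    return [sample_axis] + others


dims = ("feature", "time", "sample")
X = np.arange(2 * 4 * 3).reshape(2, 4, 3)

order = transpose_order_sample_first(dims, "sample")  # [2, 0, 1]
X_arr = np.transpose(X, order)                        # shape (3, 2, 4)

# Stand-in for estimator.do_stuff: an elementwise operation keeps ndim.
X_d = X_arr * 10

# np.argsort(order) is the inverse permutation, so this restores the
# original axis order; it only makes sense if the number of dims is unchanged.
if X_d.ndim == X.ndim:
    X_d = np.transpose(X_d, np.argsort(order))

print(X_d.shape)  # (2, 4, 3)
```

If do_stuff instead returned fewer dimensions (for example one value per sample), the permutation would no longer match the array, which is why the original code only transposes back when X_d.ndim == X.ndim.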

Implementing new estimators

If you want to implement new estimators, you could of course implement a standard sklearn-compatible estimator and wrap it as described above. However, you will probably want functionality that makes use of xarray's coordinates or other features, in which case you'll have to implement the estimator from scratch.

First of all, your estimator should implement a fit method that determines the type of the training data:

from sklearn_xarray.utils import is_dataarray, is_dataset

def fit(self, X, y=None, **fit_params):
    """ Fit estimator to data. """

    if is_dataset(X):
        self.type_ = 'Dataset'
        # implement fitting procedure for Datasets
    elif is_dataarray(X):
        self.type_ = 'DataArray'
        # implement fitting procedure for DataArrays
    else:
        self.type_ = 'other'
        # implement fitting procedure for numpy-like

    return self

All other methods that work on fitted estimators should call check_is_fitted first:

from sklearn.utils.validation import check_is_fitted

check_is_fitted(self, ['type_'])
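Putting the two pieces together, a from-scratch estimator records the input type during fit and checks for that attribute before doing anything else. To stay self-contained, the sketch below uses hypothetical duck-typed helpers in place of sklearn_xarray.utils.is_dataarray and is_dataset (the attributes they test for are an assumption), and a minimal check_fitted function in place of sklearn's check_is_fitted, which raises sklearn.exceptions.NotFittedError rather than RuntimeError. CenterEstimator and its centering behaviour are purely illustrative, and only the numpy-like branch is implemented.

```python
import numpy as np


def looks_like_dataset(X):
    # Hypothetical duck-typed stand-in for sklearn_xarray.utils.is_dataset.
    return hasattr(X, "data_vars") and hasattr(X, "dims")


def looks_like_dataarray(X):
    # Hypothetical duck-typed stand-in for sklearn_xarray.utils.is_dataarray.
    return hasattr(X, "coords") and hasattr(X, "dims") and not hasattr(X, "data_vars")


def check_fitted(estimator, attributes):
    # Minimal stand-in for sklearn.utils.validation.check_is_fitted.
    if not all(hasattr(estimator, attr) for attr in attributes):
        raise RuntimeError(type(estimator).__name__ + " is not fitted yet")


class CenterEstimator:
    """Hypothetical estimator that subtracts the per-column training mean."""

    def fit(self, X, y=None, **fit_params):
        """Record the input type, then fit (numpy-like branch only)."""
        if looks_like_dataset(X):
            kind = "Dataset"
        elif looks_like_dataarray(X):
            kind = "DataArray"
        else:
            kind = "other"
        if kind != "other":
            # The xarray branches are left out to keep the sketch self-contained.
            raise NotImplementedError(kind + " input is not handled in this sketch")
        self.mean_ = np.asarray(X, dtype=float).mean(axis=0)
        self.type_ = kind
        return self

    def transform(self, X):
        # Every method that needs a fitted estimator checks first.
        check_fitted(self, ["type_"])
        return np.asarray(X, dtype=float) - self.mean_


est = CenterEstimator()
try:
    est.transform([[1.0, 2.0]])
except RuntimeError as err:
    print(err)  # CenterEstimator is not fitted yet

est.fit([[1.0, 2.0], [3.0, 6.0]])
print(est.type_)                     # other
print(est.transform([[2.0, 4.0]]))  # [[0. 0.]]
```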

You should probably implement the functionality for DataArrays in a separate method. For Datasets, you can then call that method on each of the data_vars and combine the results afterwards.
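The per-variable pattern for Datasets can be sketched without xarray by treating a Dataset as a mapping from variable name to array. In real code the loop would run over X.data_vars and the results would be assembled back into an xarray.Dataset (for example by passing a dict of DataArrays to the xarray.Dataset constructor); here a plain dict of numpy arrays is an assumed stand-in, and demean_dataarray is a hypothetical DataArray-level method.

```python
import numpy as np


def demean_dataarray(arr):
    """Hypothetical DataArray-level method: subtract the mean along axis 0."""
    return arr - arr.mean(axis=0)


def demean_dataset(ds):
    """Apply the DataArray-level method to each variable and join the results.

    ``ds`` is a plain dict standing in for an xarray.Dataset.
    """
    return {name: demean_dataarray(var) for name, var in ds.items()}


ds = {
    "temperature": np.array([[1.0, 2.0], [3.0, 4.0]]),
    "pressure": np.array([10.0, 20.0, 30.0]),
}
result = demean_dataset(ds)
print(result["pressure"].tolist())  # [-10.0, 0.0, 10.0]
```

Because each variable is processed independently, variables with different shapes (as here) pose no problem, which is one reason to keep the DataArray logic in its own method.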

Take a look at the estimators in the preprocessing module for a couple of examples.
