Serialisation of (online) state for online detectors #604

Merged
48 commits merged into master on Jan 24, 2023

Conversation

ascillitoe
Contributor

@ascillitoe ascillitoe commented Sep 2, 2022

This PR implements the functionality to save and load state for online detectors. At a given time step, the save_state method can be called to create a "checkpoint". This can later be loaded via the load_state method. At any time, the reset_state method can be used to reset the state back to the t=0 timestep.

Scope

This PR deals with online state only. See #604 (comment) for a discussion on online versus offline state.

Example(s)

Saving and loading state

The state of online detectors can now be saved and loaded via save_state and load_state.

The detector's state may be saved with the save_state method:

from alibi_detect.cd import CVMDriftOnline

cd = CVMDriftOnline(x_ref, ert, window_sizes)  # Instantiate detector at t=0
cd.predict(x_1)  # t=1
cd.save_state('checkpoint_t1')  # Save state at t=1
cd.predict(x_2)  # t=2

The previously saved state may then be loaded via the load_state method:

# Load state at t=1
cd.load_state('checkpoint_t1')

At any point, the state may be reset with the reset method. See also the colab notebook.

Saving and loading detector with state

Calling save_detector with save_state=True will save an online detector's state to state/ within the detector save directory. load_detector will simply attempt to load state if a state/ directory exists.

from alibi_detect.cd import LSDDDriftOnline
from alibi_detect.saving import save_detector, load_detector

# Init detector (self.t = 0)
dd = LSDDDriftOnline(x_ref, ert, window_size)

# Perform predict call to update state (e.g. self.t = self.t + 1)
dd.predict(x)

# Save detector with its state included
save_detector(dd, filepath, save_state=True)

# Load stateful detector (i.e. self.t = 1)
dd_new = load_detector(filepath)

TODOs:

  • Implement for remaining detectors and backends.
  • Integrate with the save_detector and load_detector functions, to allow state to be saved and loaded when the detector itself is serialized/unserialized.
  • Tests.
  • Docs.
  • Changelog
  • Public-facing example: examples added to the Saving and loading page and detector methods pages.
  • Test all detectors and backends in test_saving.py state test.

## Outstanding considerations (specific to LSDD for now but maybe more widely applicable)

There is an open question to resolve regarding what we define "state" to be. This PR currently considers it to be only the attributes that are updated in _update_state (self.t, self.test_window and self.k_xtc). In other words, "state" is defined as any attribute that is dependent on time (updated when a new instance x_t is given via score or predict).

However, there is already a notion of "state" introduced when we initialise a detector (or reinitialise it via the reset method). Here, in addition to the attributes already mentioned, we set self.ref_inds, self.c2s, and self.init_test_inds. This leads to two considerations:

1. Will there be confusion between the reset and reset_state methods, and do we need to change the docstrings or names?
2. There is randomness involved in the initialisation of LSDDDrift (in _configure_ref_subset). It is likely that if the detector is instantiated later on, and load_state is used to restart from a checkpoint, predictions will still be different compared to those that were observed after save_state was called with the original detector. This would only be avoided if random seeds were set both times. With this in mind, do we want to change our definition of "state" to include self.ref_inds, self.c2s, and self.init_test_inds?

@ascillitoe ascillitoe added the WIP PR is a Work in Progress label Sep 2, 2022
@arnaudvl
Contributor

arnaudvl commented Sep 2, 2022

  1. Stateful here relates to whatever state (attributes) changes between prediction calls, not what state is set in the init of the detector, which would make all detectors stateful of course. To avoid possible confusion, one suggestion would be to name the methods reset_detector and reset_state?
  2. I believe that in general randomness should be handled outside of the detector/library, similar to e.g. PyTorch models. Isn't that randomness already eliminated though when just loading a saved detector? What am I missing here?

@ascillitoe
Contributor Author

  • Stateful here relates to whatever state (attributes) changes between prediction calls, not what state is set in the init of the detector, which would make all detectors stateful of course. To avoid possible confusion, one suggestion would be to name the methods reset_detector and reset_state?
  • I believe that in general randomness should be handled outside of the detector/library, similar to e.g. PyTorch models. Isn't that randomness already eliminated though when just loading a saved detector? What am I missing here?

Nice idea with the name change, and yes I agree, "state" does not refer to any attributes set in init, as that would be "config" (with our definitions).

My only concern is that users might expect a detector to give the same predictions as the original when loaded from a "checkpoint" via save/load_state. That is currently not the case (unless seeds are set manually when the original and re-loaded detectors are instantiated). save/load_detector do not help with this.

Maybe the answer is just to make it clear in the docstrings though... as in any case, statistically the detector's behaviour should be the same after the checkpoint even if the exact predictions are not the same?

@ascillitoe ascillitoe added the Type: Serialization Serialization proposals and changes label Sep 29, 2022
@codecov-commenter

codecov-commenter commented Nov 7, 2022

Codecov Report

Merging #604 (828e486) into master (f0b57b4) will increase coverage by 0.17%.
The diff coverage is 95.65%.

Additional details and impacted files


@@            Coverage Diff             @@
##           master     #604      +/-   ##
==========================================
+ Coverage   80.15%   80.32%   +0.17%     
==========================================
  Files         133      137       +4     
  Lines        9177     9292     +115     
==========================================
+ Hits         7356     7464     +108     
- Misses       1821     1828       +7     
Flag Coverage Δ
macos-latest-3.10 76.87% <95.65%> (+0.21%) ⬆️
ubuntu-latest-3.10 80.21% <95.65%> (+0.17%) ⬆️
ubuntu-latest-3.7 80.11% <95.65%> (+0.17%) ⬆️
ubuntu-latest-3.8 80.16% <95.65%> (+0.17%) ⬆️
ubuntu-latest-3.9 80.16% <95.65%> (+0.17%) ⬆️
windows-latest-3.9 76.80% <95.65%> (+0.21%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
alibi_detect/cd/base_online.py 88.23% <82.60%> (-2.96%) ⬇️
alibi_detect/cd/lsdd_online.py 93.61% <85.71%> (+0.75%) ⬆️
alibi_detect/cd/mmd_online.py 94.33% <85.71%> (+0.58%) ⬆️
alibi_detect/utils/state/state.py 97.50% <97.50%> (ø)
alibi_detect/base.py 85.45% <100.00%> (+0.98%) ⬆️
alibi_detect/cd/cvm_online.py 75.63% <100.00%> (+1.52%) ⬆️
alibi_detect/cd/fet_online.py 88.97% <100.00%> (+0.17%) ⬆️
alibi_detect/cd/pytorch/lsdd_online.py 95.78% <100.00%> (+0.18%) ⬆️
alibi_detect/cd/pytorch/mmd_online.py 100.00% <100.00%> (ø)
alibi_detect/cd/tensorflow/lsdd_online.py 95.60% <100.00%> (+0.25%) ⬆️
... and 7 more

@ascillitoe
Contributor Author

ascillitoe commented Nov 7, 2022

Will be conflicts until #618 is merged.

Edit: Resolved.

@ascillitoe ascillitoe added this to the v0.11.0 milestone Nov 22, 2022
@ascillitoe
Contributor Author

Regarding the codecov report, the 4.32% decrease is not accurate. This is based on an old master commit where we were still counting tests in coverage.


@ascillitoe
Contributor Author

@arnaudvl @ojcobb (and @jklaise/@mauicv) could do with your thoughts on this.

In the latest implementation, I have removed the new reset_state method. State can still be reset with the existing reset method. reset calls _initialise internally, which involves a random operation (for the LSDD detector at least). Because of this, we have a lack of determinism even if random seeds are manually set (explanation below).

Question: Shall we keep a reset_state method separate to reset, or is this just confusing for users? @arnaudvl previously suggested we could rename reset to reset_detector to avoid confusion... but it might be worth discussing what we think the use case of each method actually is anyway. If we don't care about determinism in the case of Example 1, we can remove reset_state I think...

As Example 2 shows, determinism in the case of saving/loading of a detector is not affected by this decision anyway...

Difference between reset and reset_state

(all examples are for LSDDDriftOnline)

reset

reset is an existing method, which calls _initialise:

    def reset(self) -> None:
        "Resets the detector but does not reconfigure thresholds."
        self._initialise()

_initialise typically sets some attributes to zero, and calls _configure_ref_subset (this method contains random ops!):

    def _initialise(self) -> None:
        self.t = 0  # corresponds to a test set of ref data
        self.test_stats = np.array([])  # type: ignore[var-annotated]
        self.drift_preds = np.array([])  # type: ignore[var-annotated]
        self._configure_ref_subset()

reset_state

reset_state was/is a new method that specifically only resets the core "stateful" attributes (those updated by _update_state):

    def reset_state(self):
        """
        Reset the detector's state.
        """
        self.t = 0
        self.test_window = self.x_ref_eff[self.init_test_inds]
        self.k_xtc = self.kernel(self.test_window, self.kernel_centers)

This requires _initialise to have been run (so that self.test_window has been set); however, it doesn't re-run _initialise, and therefore doesn't involve random ops.

Examples

Example 1: Determinism when resetting

When resetting an instantiated detector and repeating predictions, test stats will be repeatable if reset_state is used. Example from test_lsdd_online_pt.py:

    # Run for 50 time steps
    test_stats_1 = []
    for t, x_t in enumerate(x):
        preds = dd.predict(x_t)
        test_stats_1.append(preds['data']['test_stat'])
        if t == 20:
            dd.save_state(tmp_path)

    # Clear state and repeat, check that same test_stats both times
    dd.reset_state()
    test_stats_2 = []
    for t, x_t in enumerate(x):
        preds = dd.predict(x_t)
        test_stats_2.append(preds['data']['test_stat'])
    np.testing.assert_array_equal(test_stats_1, test_stats_2)  # passes!

This fails if reset is used, even if torch.manual_seed() is run before instantiating the detector and before reset. Setting seeds externally does not help here because the number of torch.random operations run prior to reaching _initialise is different in both cases (fresh instantiation also involves random operations in _configure_kernel_centers and _configure_thresholds).
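The effect can be illustrated outside of torch with NumPy's global RNG (a sketch, not alibi-detect code): any extra draws between seeding and the operation of interest advance the global state, so the operation no longer reproduces; only an identical sequence of ops after the same seed does:

```python
import numpy as np

# Seed, then permute immediately
np.random.seed(0)
a = np.random.permutation(10)

# Seed identically, but consume some random numbers first (analogous to the random ops
# run in _configure_kernel_centers/_configure_thresholds on a fresh instantiation)
np.random.seed(0)
_ = np.random.rand(3)  # extra random ops cycle the global state
b = np.random.permutation(10)  # will generally differ from `a`

# Re-running the exact same sequence of ops after the same seed reproduces the result
np.random.seed(0)
_ = np.random.rand(3)
c = np.random.permutation(10)  # identical to `b`
```

This is why setting torch.manual_seed externally is not enough when the number of intervening random ops differs between runs.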

Example 2: Determinism when saving/loading

When saving and loading a stateful detector via save_detector(..., save_state=True) and load_detector, predictions following save_detector and load_detector will be consistent as long as seeds are set manually. Example from test_saving.py:

    with fixed_seed(seed):
        dd = detector(X_ref, ert=100, window_size=10, backend='pytorch')
    # Run for 50 time-steps
    test_stats = []
    for t, x_t in enumerate(X_h0[:50]):
        test_stats.append(dd.predict(x_t)['data']['test_stat'])
        if t == 20:
            # Save detector (with state)
            save_detector(dd, tmp_path, save_state=True)

    # Check state/ dir created
    state_path = dd.state_path if detector == CVMDriftOnline else dd._detector.state_path
    assert state_path == tmp_path.joinpath('state')
    assert state_path.is_dir()

    # Load
    with fixed_seed(seed):
        dd_new = load_detector(tmp_path)
    # Check attributes and compare predictions at t=21
    assert dd_new.t == 21
    np.testing.assert_array_equal(dd_new.predict(X_h0[21])['data']['test_stat'], test_stats[21])

This use case does not depend on design of reset/reset_state etc.

@ascillitoe
Contributor Author

Additional side-note: this issue of setting random seeds not giving deterministic behaviour for a given operation (in this case the torch.randperm called by reset -> _initialise -> _configure_ref_subset) is something we've run into a few times before. The issue is that even if a given random state is set externally (i.e. torch.manual_seed), the random state will be different by the time we get to the torch.randperm if a different number of random operations are called before we get there, since each random op cycles the random state. This is made even more difficult for something like LSDDDriftOnline, since we have random ops in a while loop, so we do not know in advance how many random ops will be called.

The only solution I can think of for this is a scikit-learn style approach, where we accept random_state as a kwarg, and then use self.random_state in random ops we want to be deterministic. We would have to be careful with ops like the torch.randperm in _configure_ref_subset, since we do need this to possess a degree of randomness over each iteration in the while loop...
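A sketch of the scikit-learn-style approach described above (hypothetical class and method names, not the alibi-detect API): the detector owns a numpy Generator seeded from a random_state kwarg, so its draws are deterministic regardless of how many global random ops ran beforehand, while repeated draws within a loop still differ from one another:

```python
from typing import Optional

import numpy as np

class RefSubsetConfigurer:
    """Hypothetical sketch: detector-owned RNG for deterministic ref-subset selection."""
    def __init__(self, n: int, random_state: Optional[int] = None):
        self.n = n
        # The detector's own random stream, isolated from the global RNG
        self._rng = np.random.default_rng(random_state)

    def draw_split(self, window_size: int):
        # Each call draws a fresh permutation from the detector's own stream,
        # so successive iterations of a while loop still differ from each other
        perm = self._rng.permutation(self.n)
        return perm[:-window_size], perm[-window_size:]

# Two instances with the same random_state produce identical splits,
# independent of any global np.random usage in between
c1 = RefSubsetConfigurer(n=20, random_state=42)
np.random.rand(5)  # unrelated global random ops have no effect on c2 below
c2 = RefSubsetConfigurer(n=20, random_state=42)
ref1, test1 = c1.draw_split(window_size=4)
ref2, test2 = c2.draw_split(window_size=4)
```

The detector-owned stream also avoids the "leaking" problem between tests, since nothing here touches the global state.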

@ascillitoe
Contributor Author

ascillitoe commented Nov 28, 2022

@arnaudvl @ojcobb a possible alternative strategy to make reset deterministic is to rework _initialise methods to ensure they are deterministic. For example, for LSDDDriftOnline, this would involve avoiding the while loop to find a new self.init_test_inds in _configure_ref_subset if init_test_inds already exists:

    def _configure_ref_subset(self):
        """
        Configure reference subset. If already configured, the stateful attributes `test_window` and `k_xtc` are
        reset without re-configuring a new reference subset.
        """
        etw_size = 2 * self.window_size - 1  # etw = extended test window
        nkc_size = self.n - self.n_kernel_centers  # nkc = non-kernel-centers
        rw_size = nkc_size - etw_size  # rw = ref-window
        # Check if already configured, we will re-initialise stateful attributes w/o searching for new ref split if so
        configure_ref = self.init_test_inds is None
        if configure_ref:
            # Make split and ensure it doesn't cause an initial detection
            lsdd_init = None
            while lsdd_init is None or lsdd_init >= self.get_threshold(0):
                # Make split
                perm = torch.randperm(nkc_size)
                self.ref_inds, self.init_test_inds = perm[:rw_size], perm[-self.window_size:]
                self.test_window = self.x_ref_eff[self.init_test_inds]
                # Compute initial lsdd to check for initial detection
                self.c2s = self.k_xc[self.ref_inds].mean(0)  # (below Eqn 21)
                self.k_xtc = self.kernel(self.test_window, self.kernel_centers)
                h_init = self.c2s - self.k_xtc.mean(0)  # (Eqn 21)
                lsdd_init = h_init[None, :] @ self.H_lam_inv @ h_init[:, None]  # (Eqn 11)
        else:
            # Reset stateful attributes using existing split
            self.test_window = self.x_ref_eff[self.init_test_inds]
            self.k_xtc = self.kernel(self.test_window, self.kernel_centers)

This seems like a reasonable compromise to me? However, the additional duplication/complexity is unnecessary if we truly don't care about repeatable predictions post-reset?

Comment on lines 234 to 236
def save_state(self, filepath): ...

def load_state(self, filepath): ...
Contributor

Should parameters have type hints?

Contributor Author

Added in 15d1dfa. Note: I also updated the pre-existing get_config and set_config methods here.

@ascillitoe
Contributor Author

ascillitoe commented Jan 13, 2023

@ascillitoe Is this file move ok? I thought conftest.py should always be inside a folder named tests? (alibi_detect/saving/tests/conftest.py → alibi_detect/conftest.py) Doesn't this make conftest a public module of alibi-detect, which we wouldn't want?

Mmn good point, thinking again it doesn't seem ideal to have it outside of tests/. I moved it so that we didn't have to duplicate the seed fixture in multiple places. I couldn't think of a better way to have a global conftest that is shared across all tests. Do you know of a way?

Re it becoming a public module I suspect you're mostly right. We do __all__ = ["ad", "cd", "models", "od", "utils", "saving"] in alibi_detect.__init__ so it won't be exposed by dir(alibi_detect) at least. However, seed could technically be imported by:

from alibi_detect.conftest import seed

Weirdly though, with alibi-detect v0.10.4, the following also works! (I think I need to undo the file move, but maybe we also need to double-check this separately)

from alibi_detect.saving.tests.conftest import seed
seed(0)

filepath
The directory to load state from.
"""
self._set_state_dir(filepath)
Contributor

Why is this necessary when loading?

Contributor Author

Just so that self.state_dir is set (and converted from str to pathlib.Path) when load_state is called as well as when save_state is called.

I thought it might be helpful to have state_dir as a public attribute so that a user could interrogate the detector to see where state was loaded from. Although thinking about it more, for the backend detectors one would have to do detector._detector.state_dir (access a private attribute) anyway. I guess we'd probably want to define a @property if we actually want to support this functionality properly...

Happy to just make it private if you think it's better though...
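The @property idea could look like the following sketch (hypothetical class names; only the forwarding pattern is the point): the public wrapper exposes the backend detector's state_dir so users never touch the private _detector attribute:

```python
from pathlib import Path

class _BackendDetector:
    """Stand-in for a backend-specific detector holding the actual state dir."""
    def __init__(self):
        self.state_dir = Path('checkpoint_t1')

class DriftDetectorWrapper:
    """Stand-in for the public detector that wraps a backend detector."""
    def __init__(self):
        self._detector = _BackendDetector()

    @property
    def state_dir(self) -> Path:
        # Read-only view onto the backend detector's state directory
        return self._detector.state_dir

dd = DriftDetectorWrapper()
dd.state_dir  # no need for dd._detector.state_dir
```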

dirpath
The directory to save state file inside.
"""
self.state_dir = Path(dirpath)
Contributor

Should this be a private attribute?

Contributor Author

Comment on lines 390 to 447
def _set_state_dir(self, dirpath: Union[str, os.PathLike]):
"""
Set the directory path to store state in, and create an empty directory if it doesn't already exist.

Parameters
----------
dirpath
The directory to save state file inside.
"""
self.state_dir = Path(dirpath)
self.state_dir.mkdir(parents=True, exist_ok=True)

def save_state(self, filepath: Union[str, os.PathLike]):
"""
Save a detector's state to disk in order to generate a checkpoint.

Parameters
----------
filepath
The directory to save state to.
"""
self._set_state_dir(filepath)
self._save_state()
logger.info('Saved state for t={} to {}'.format(self.t, self.state_dir))

def load_state(self, filepath: Union[str, os.PathLike]):
"""
Load the detector's state from disk, in order to restart from a checkpoint previously generated with
`save_state`.

Parameters
----------
filepath
The directory to load state from.
"""
self._set_state_dir(filepath)
self._load_state()
logger.info('State loaded for t={} from {}'.format(self.t, self.state_dir))

def _save_state(self):
"""
Private method to save a detector's state to disk.

TODO - Method slightly verbose as designed to facilitate saving of "offline" state in follow-up PR.
"""
filename = 'state'
keys = self.online_state_keys
save_state_dict(self, keys, self.state_dir.joinpath(filename + '.npz'))

def _load_state(self, offline: bool = False):
"""
Private method to load a detector's state from disk.

TODO - Method slightly verbose as designed to facilitate loading of "offline" state in follow-up PR.
"""
filename = 'state'
load_state_dict(self, self.state_dir.joinpath(filename + '.npz'), raise_error=True)

Contributor

Seems like a lot of duplicated code that's exactly the same as for BaseMultiDriftOnline, which suggests we may want to refactor using functions instead of methods, or a mixin class? Or perhaps the class hierarchy needs to be updated.

Contributor

Note that this seems to apply to other methods too, so perhaps is a more widespread problem requiring a refactoring later...

Contributor Author

@ascillitoe ascillitoe Jan 13, 2023

Fair point. Having a BaseDriftOnline class for generic methods, or a mix-in, both seem much nicer than the current pattern. I'll have a rethink 👍🏻

Contributor Author

To reduce duplication, 5daf1b1 adds a BaseDriftOnline class. @jklaise @mauicv could I get your thoughts on the design of BaseDriftOnline please? I've gone with a parent class rather than mix-in since it seems strange to define a mix-in in alibi_detect/base.py when it is only to be used in two classes (BaseMultiDriftOnline and BaseUniDriftOnline). I also decided to put it in alibi_detect/cd/base_online.py rather than alibi_detect/base.py since at the moment the concept of "online" detectors is specific to drift (this may change if we decide stateful outlier detectors are in fact "online").

Contributor

@jklaise jklaise Jan 16, 2023

LGTM, however noting that there are quite a few abstract methods, some of which (not all?) are implemented in the Multi/Uni abstract child classes, which come with their own set of abstract methods... Worried that this may become a bit tricky to keep track of. As a minimum, I would group all abstract methods together and add docstrings on the expected implementation and also, where valid, which of the Multi/Uni classes implement these methods (+ type hints as always).

Contributor Author

See #604 (comment) wrt type hints, not sure on the best approach here.

Wrt the abstract methods, if they are missing from the Multi/Uni child classes that will be because they are instead defined in the next subclass down, i.e. LSDDDriftOnlineTorch._initialise_state or CVMDriftOnline._configure_thresholds...

We could move the abstract methods such as _configure_thresholds back to their respective Multi/Uni abstract class, at the cost of more duplication (but maybe less complexity?)

Contributor Author

@ascillitoe ascillitoe Jan 17, 2023

Removed the new base class, and moved state methods to StateMixin. See #604 (comment).

Comment on lines +818 to +820
# Skip if backend not `tensorflow` or `pytorch`
if backend not in ('tensorflow', 'pytorch'):
pytest.skip("Detector doesn't have this backend")
Contributor

Is this due to some keops behaviour? Basically asking why skip here.

Contributor Author

The test_saving file cycles through all possible backends:

backends = ['tensorflow', 'pytorch', 'sklearn']
if has_keops:  # pykeops only installed in Linux CI
    backends.append('keops')
backend = param_fixture("backend", backends)

We have to skip tests if the associated detector doesn't have that backend. In this case, online detectors do not have a keops backend.

@pytest.mark.parametrize('batch_size', batch_size)
@pytest.mark.parametrize('n_feat', n_features)
def test_cvmdriftonline(window_sizes, batch_size, n_feat, seed):
with fixed_seed(seed):
Contributor

Noting that the previous version of tests didn't have a fixed seed, presumably it wasn't needed in this setting as test suite has been passing. Is there a need to introduce a fixed seed here as it seems detrimental to the testing for this particular set of tests?

Contributor

@jklaise jklaise Jan 13, 2023

P.S. same comment applies to tests below and in other modules.

Contributor Author

Mmn yeh I figured it would be good to add since there is randomness in the initialization of these detectors (in _configure_thresholds, and in the generation of x_ref/x_h0/x_h1). Although the tests do currently pass without fixing the seed, this doesn't actually mean they pass for any random seed. I seem to recall that when I looked into this before, np.random.seed's set in one test leaked into others. Presumably, this means we have been implicitly fixing the seed in these tests anyway.

Ideally (IMO), we'd get to a point where any random operations in tests are done inside with fixed_seed(seed)'s, then if a new bug is introduced, we can go back and reproduce it with the same random seed.
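A context manager along these lines (a sketch; alibi-detect's actual fixed_seed also handles the tensorflow/pytorch seeds) can both fix the seed and restore the previous global state on exit, so seeding inside one test cannot leak into another:

```python
import random
from contextlib import contextmanager

import numpy as np

@contextmanager
def fixed_seed(seed: int):
    # Snapshot the global RNG states so the test can't leak its seed
    np_state = np.random.get_state()
    py_state = random.getstate()
    np.random.seed(seed)
    random.seed(seed)
    try:
        yield
    finally:
        # Restore whatever state was active before entering the context
        np.random.set_state(np_state)
        random.setstate(py_state)

# Identical draws across two uses with the same seed
with fixed_seed(0):
    a = np.random.rand(3)
with fixed_seed(0):
    b = np.random.rand(3)
```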

Contributor

I agree that in case a bug happens then it's valuable to be able to reproduce with the same seed. But, on the other hand, "any random operations in tests done inside with fixed_seed(seed)" sounds like the opposite to what we want to do (unless for tests where we compare outputs of the same seed) - as stuff should pass most tests with any seed?

Contributor Author

@ascillitoe ascillitoe Jan 17, 2023

Yeh fair point, tests should generally pass with any seed, especially if they are unit type tests. The problem at the moment is we have lots of functional tests where we are testing a detector's predictions and checking things like Expected Runtime (ERT) for online detectors. We probably want more granular unit tests in lots of places...

Edit: by "any random operations in tests done inside with fixed_seed(seed)", I more meant any random operations that might for some reason affect the outcome of the test.

Contributor

@jklaise jklaise left a comment

Overall LGTM, main question about default behaviour wrt state saving when save_detector is called on online detectors. Regardless of choice, I believe this should be prominent in saving and method docs.

Comment on lines 31 to 33
@abstractmethod
def _update_state(self, x_t):
pass
Contributor

Type hints of parameters and return types required.

Contributor Author

Leaving these untyped is unfortunately necessary, since univariate online detectors have def _update_state(self, x_t: np.ndarray): whilst multivariate have def _update_state(self, x_t: torch.Tensor): (or tf.Tensor). We violate Liskov's substitution principle slightly.

I sort of think it is OK not to add type hints in the abstract method, since we only have it there to signal that sub-classes must have an _update_state method which takes an instance and updates online state, but we don't specify the exact type. However, we could also do def _update_state(self, x_t: Union[np.ndarray, 'torch.Tensor', 'tf.Tensor']): and then add # type: ignore[override] in the sub-class? This is actually what we did previously...

P.S. it's a similar story for the get_threshold method...
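The Union alternative might look like the following sketch (class names illustrative; the torch/tf entries are string forward references so neither library needs to be installed):

```python
from abc import ABC, abstractmethod
from typing import Union

import numpy as np

class BaseOnline(ABC):
    @abstractmethod
    def _update_state(self, x_t: Union[np.ndarray, 'torch.Tensor', 'tf.Tensor']) -> None:
        """Update online state from a single test instance."""
        ...

class UnivariateOnline(BaseOnline):
    def __init__(self):
        self.t = 0

    def _update_state(self, x_t: np.ndarray) -> None:  # type: ignore[override]
        # Narrower type than the base signature, hence the mypy override ignore
        self.t += 1

det = UnivariateOnline()
det._update_state(np.array([1.0]))
```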

Contributor Author

@jklaise @mauicv I have got rid of the new BaseDriftOnline class. There are still some existing issues like the ones above, but this PR now shouldn't introduce new ones at least!

The save_state/load_state methods are now added via StateMixin.

@ascillitoe
Copy link
Contributor Author

d190589 removes save_state from save_detector. The new logic is to always save state if self.t > 0. If this is not desired, .reset() can be called prior to using save_detector.

Contributor

@jklaise jklaise left a comment

LGTM! Should we add some documentation somewhere that save will save the state by default, and if that's not desired one should call reset_state first? Or is it going to be too confusing for now?

@ascillitoe
Copy link
Contributor Author

LGTM! Should we add some documentation somewhere that save will save the state by default, and if that's not desired one should call reset_state first? Or is it going to be too confusing for now?

Thanks! Will add this documentation now, just realised I added it in #628 instead of here. Doh!

[Online drift detectors](../cd/methods.md#online) are stateful, with their state updated upon each `predict` call.
When saving an online detector, the `save_state` option controls whether to include the detector's state:
[Online drift detectors](../cd/methods.md#online) are stateful, with their state updated each timestep `t` (each time
`.predict()` or `.state()` is called). {func}`~alibi_detect.saving.save_detector` will save the state of online
Contributor

Do you mean score() instead of state()? On that note, do we even document the usage/use cases of score()? If not, perhaps should leave it out.

Contributor Author

Good spot, thanks. Also fair point about not really documenting it. I'll remove.

Contributor Author

Removed .state() in d888276

@ascillitoe ascillitoe removed the WIP PR is a Work in Progress label Jan 18, 2023
Collaborator

@mauicv mauicv left a comment

LGTM!

@ascillitoe ascillitoe merged commit 1b3295f into SeldonIO:master Jan 24, 2023