[PoC] Add KFold - External Loop. #8715
@@ -0,0 +1,216 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | ||
WARNING: Loop customization is in `pre-alpha release` and the API is likely to change quite a lot ! | ||
Please, open issues with your own particular requests, so the Lightning Team can progressively converge to a great API. | ||
""" | ||
|
||
from typing import Any, Dict, List, Optional, Type

import numpy as np
from sklearn.model_selection import KFold
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset, Subset

from pytorch_lightning import _logger as log
from pytorch_lightning import LightningDataModule, seed_everything
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.loops.external_loop import ExternalLoop
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.boring_model import BoringModel, RandomDataset

seed_everything(42)


class BaseDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.non_picklable = None
[Review comment] curious, what was the idea here? seems left over :)
        self.checkpoint_state: Optional[str] = None
[Review comment] same question here :)

        self._train_dataset: Optional[Dataset] = None
        self._val_dataset: Optional[Dataset] = None
        self._test_dataset: Optional[Dataset] = None
        self._predict_dataset: Optional[Dataset] = None

        self._processed_train_dataset: Optional[Dataset] = None
        self._processed_val_dataset: Optional[Dataset] = None
        self._processed_test_dataset: Optional[Dataset] = None
        self._processed_predict_dataset: Optional[Dataset] = None

    @property
    def train_dataset(self) -> Optional[Dataset]:
        return self._train_dataset

    @property
    def val_dataset(self) -> Optional[Dataset]:
        return self._val_dataset

    @property
    def test_dataset(self) -> Optional[Dataset]:
        return self._test_dataset

    @property
    def predict_dataset(self) -> Optional[Dataset]:
        return self._predict_dataset

    @property
    def processed_train_dataset(self) -> Optional[Dataset]:
        return self._processed_train_dataset or self.train_dataset

    @property
    def processed_val_dataset(self) -> Optional[Dataset]:
        return self._processed_val_dataset or self.val_dataset

    @property
    def processed_test_dataset(self) -> Optional[Dataset]:
        return self._processed_test_dataset or self.test_dataset

    @property
    def processed_predict_dataset(self) -> Optional[Dataset]:
        return self._processed_predict_dataset or self.predict_dataset

    @processed_train_dataset.setter
    def processed_train_dataset(self, processed_train_dataset) -> None:
        self._processed_train_dataset = processed_train_dataset

    @processed_val_dataset.setter
    def processed_val_dataset(self, processed_val_dataset) -> None:
        self._processed_val_dataset = processed_val_dataset
    @processed_predict_dataset.setter
    def processed_predict_dataset(self, processed_predict_dataset) -> None:
        self._processed_predict_dataset = processed_predict_dataset
    @processed_test_dataset.setter
    def processed_test_dataset(self, processed_test_dataset) -> None:
        self._processed_test_dataset = processed_test_dataset

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.processed_train_dataset)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.processed_val_dataset)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.processed_test_dataset)

    def predict_dataloader(self) -> DataLoader:
        return DataLoader(self.processed_predict_dataset)


class BoringDataModule(BaseDataModule):
    def prepare_data(self) -> None:
        self.random_full = RandomDataset(32, 64 * 4)

    def setup(self, stage: Optional[str] = None) -> None:
        if stage == "fit" or stage is None:
            self._train_dataset = Subset(self.random_full, indices=range(64))
            self.dims = self._train_dataset[0].shape

        if stage in ("fit", "validate") or stage is None:
            self._val_dataset = Subset(self.random_full, indices=range(64, 64 * 2))

        if stage == "test" or stage is None:
            self._test_dataset = Subset(self.random_full, indices=range(64 * 2, 64 * 3))
            self.dims = getattr(self, "dims", self._test_dataset[0].shape)

        if stage == "predict" or stage is None:
            self._predict_dataset = Subset(self.random_full, indices=range(64 * 3, 64 * 4))
            self.dims = getattr(self, "dims", self._predict_dataset[0].shape)


class KFoldLoop(ExternalLoop):
    def __init__(
        self,
        num_folds: int,
        best_model_paths: Optional[List[str]] = None,
        restarting: bool = False,
    ):
        super().__init__()
        self.num_folds = num_folds
        # avoid sharing a mutable default list across instances
        self.best_model_paths = best_model_paths if best_model_paths is not None else []
        self.restarting = restarting

    @staticmethod
    def loop_base_callback() -> Type[Callback]:
        class BaseKFoldCallback(Callback):
            @rank_zero_only
            def on_fold_start(self, trainer, pl_module, counter):
                """Override with your own logic"""

        return BaseKFoldCallback
[Review comment on lines +148 to +154] can't we define this outside this class but in the file namespace?
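A minimal sketch of that suggestion, using only names already present in the diff: the base callback moves to module scope and user callbacks subclass it directly, without going through `loop_base_callback()`.

# sketch: base callback defined once at module scope instead of inside
# a static method (the body is copied from the diff; only the location changes)
class BaseKFoldCallback(Callback):
    @rank_zero_only
    def on_fold_start(self, trainer, pl_module, counter):
        """Override with your own logic"""


# user callbacks can then subclass it directly
class KFoldCallback(BaseKFoldCallback):
    @rank_zero_only
    def on_fold_start(self, trainer, pl_module, counter) -> None:
        log.info(f"Starting to train on fold {counter}")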

    @property
    def done(self) -> bool:
        return self.current_fold >= self.num_folds

    def reset(self) -> None:
        if not self.restarting:
            self.current_fold = 0

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        # temporary hack
        self.trainer.datamodule.setup("fit")

    def on_advance_start(self) -> None:
        # re-create a fresh trainer for each fold; more reproducible than reusing one
        self.create_trainer(max_epochs=np.random.randint(10))
        # reload the datasets for the current fold
        dm = self.trainer.datamodule
        dm.processed_train_dataset = self.process_dataset("train", dm.train_dataset)
        dm.processed_val_dataset = self.process_dataset("val", dm.val_dataset)
        # call user hook
        self.trainer.call_hook("on_fold_start", self.current_fold)
        # reset model parameters
        self.trainer.lightning_module.reset_parameters()

    def advance(self) -> Any:
        # dataloaders will be automatically reloaded
        return self.trainer.fit(self.trainer.lightning_module, datamodule=self.trainer.datamodule)

    def on_advance_end(self) -> None:
        self.current_fold += 1
        # store the best weight path for this fold
        self.best_model_paths.append(self.trainer.checkpoint_callback.best_model_path)

    # utilities for creating a fold
    def process_dataset(self, stage: str, dataset: Dataset) -> Subset:
        kfold = KFold(self.num_folds, random_state=42, shuffle=True)
[Review comment] is a dependency for sklearn worth it for just this?
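For comparison, a dependency-free split could look roughly like the sketch below, using only numpy (which the file already imports). It mirrors `KFold(shuffle=True)`'s approach of permuting the indices and spreading the remainder over the first folds, but it is an illustration, not a guarantee of identical index assignments.

def kfold_indices(num_samples: int, num_folds: int, fold: int, seed: int = 42):
    """Return (train_indices, validation_indices) for `fold` without sklearn."""
    # deterministic permutation of all sample indices
    permuted = np.random.RandomState(seed).permutation(num_samples)
    # `array_split` spreads any remainder over the first folds
    folds = np.array_split(permuted, num_folds)
    validation_indices = folds[fold]
    train_indices = np.concatenate([f for i, f in enumerate(folds) if i != fold])
    return train_indices, validation_indices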
        train_indices, validation_indices = list(kfold.split(range(len(dataset))))[self.current_fold]
        indices = train_indices if stage == "train" else validation_indices
        return Subset(dataset, indices.tolist())

    def on_save_checkpoint(self) -> Dict:
        return {"current_fold": self.current_fold}

    def on_load_checkpoint(self, state_dict) -> None:
        self.current_fold = state_dict["current_fold"]


class KFoldCallback(KFoldLoop.loop_base_callback()):
    """This callback demonstrates how to implement your own callback API."""

    @rank_zero_only
    def on_fold_start(self, trainer, pl_module, counter) -> None:
        log.info(f"Starting to train on fold {counter}")


loop = KFoldLoop(5)
model = BoringModel()
datamodule = BoringDataModule()
loop.connect_trainer(max_epochs=10, callbacks=KFoldCallback())
[Review comment] alternatively these could be passed in through init via an argument trainer_kwargs.
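A sketch of that alternative: the kwargs go through `__init__` and are forwarded to `connect_trainer`. The `trainer_kwargs` argument is the reviewer's proposal, not existing API.

class KFoldLoopAlt(KFoldLoop):
    def __init__(self, num_folds: int, trainer_kwargs: Optional[Dict[str, Any]] = None):
        super().__init__(num_folds)
        # forward the trainer arguments instead of requiring a separate call
        self.connect_trainer(**(trainer_kwargs or {}))


loop = KFoldLoopAlt(5, trainer_kwargs={"max_epochs": 10, "callbacks": KFoldCallback()})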
loop.run(model, datamodule=datamodule)
@@ -0,0 +1,104 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Any, Callable, Dict, Optional

import pytorch_lightning as pl
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()


class ExternalLoop(Loop):
    """This loop is meant to wrap trainer calls."""

    def __init__(self):
        super().__init__()
        warning_cache.warn("The ExternalLoop API is a `pre-alpha release` and breaking API changes are expected.")
        self.create_trainer = self._wrap_trainer_wrapper(self.create_trainer)
        self._has_setup = False
        self._restore_external_loop = True

    def _wrap_trainer_wrapper(self, create_trainer: Callable) -> Callable:
        @functools.wraps(create_trainer)
        def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
            trainer = create_trainer(*args, trainer_kwargs=self.trainer_kwargs, **kwargs)
            if not isinstance(trainer, pl.Trainer):
                raise MisconfigurationException("The `create_trainer` hook should return a Trainer")
            self.trainer = trainer
            self.trainer.external_loop = self

            self.trainer.accelerator.connect(self.__lightning_module)

            # link the data to the trainer
            self.trainer.data_connector.attach_data(
                self.trainer.lightning_module,
                train_dataloaders=self.__train_dataloader,
                val_dataloaders=self.__val_dataloaders,
                test_dataloaders=self.__test_dataloaders,
                predict_dataloaders=self.__predict_dataloaders,
                datamodule=self.__datamodule,
            )

            # run the data preparation hooks
            self.trainer.data_connector.prepare_data()

            self.trainer.checkpoint_connector.resume_start()
            self.trainer.checkpoint_connector.restore_loops(restore_external_loop=self._restore_external_loop)
            return trainer

        return wrapped_func

    def connect_trainer(self, **trainer_kwargs: Any) -> None:
        self.trainer_kwargs = trainer_kwargs

    def create_trainer(self, *args, trainer_kwargs: Optional[Dict[str, Any]] = None, **kwargs) -> "pl.Trainer":
        # copy so the caller's dict (and the stored kwargs) are never mutated
        trainer_kwargs = dict(trainer_kwargs or {})
        trainer_kwargs.update(kwargs)
        return pl.Trainer(*args, **trainer_kwargs)

    def run(
        self,
        model: "pl.LightningModule",
        train_dataloader=None,
        val_dataloaders=None,
        test_dataloaders=None,
        predict_dataloaders=None,
        datamodule=None,
    ):
        if model is None:
            raise MisconfigurationException("`model` must be provided to `loop.run()`")

        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(train_dataloader, pl.LightningDataModule):
            datamodule = train_dataloader
            train_dataloader = None

        if train_dataloader is not None and datamodule:
            raise MisconfigurationException("You cannot pass both `train_dataloader` and `datamodule` to `loop.run()`")

        self.__lightning_module = model
        self.__train_dataloader = train_dataloader
        self.__val_dataloaders = val_dataloaders
        self.__test_dataloaders = test_dataloaders
        self.__predict_dataloaders = predict_dataloaders
        self.__datamodule = datamodule

        if self._trainer is None:
            self.create_trainer()
            self._restore_external_loop = False

        return super().run()
@@ -186,7 +186,7 @@ def restore_callbacks(self) -> None:
        )
        self.trainer.on_load_checkpoint(self._loaded_checkpoint)

-    def restore_loops(self) -> None:
+    def restore_loops(self, restore_external_loop: bool = False) -> None:
[Review comment] it's probably enough if this is controlled below by the existence of an external loop, as checked below.
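A sketch of the flagless variant the reviewer describes, keying the restore on the presence of the loop and of its saved state rather than on a parameter (hypothetical simplification, not part of the diff):

# inside restore_loops(), after the built-in loops are restored:
external_loop = getattr(self.trainer, "external_loop", None)
if external_loop is not None and "external_loop" in state_dict:
    external_loop.load_state_dict(state_dict["external_loop"])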
        """
        Restores the loop progress from the pre-loaded checkpoint.
        Calls hooks on the loops to give it a chance to restore its state from the checkpoint.

@@ -226,6 +226,11 @@ def restore_loops(self) -> None:
        self.trainer.test_loop.load_state_dict(state_dict["test_loop"])
        self.trainer.predict_loop.load_state_dict(state_dict["predict_loop"])

+        if restore_external_loop:
+            external_loop = getattr(self.trainer, "external_loop", None)
+            if external_loop:
+                self.trainer.external_loop.load_state_dict(state_dict["external_loop"])
+
    def restore_optimizers_and_schedulers(self) -> None:
        """Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint."""
        if (
@@ -471,9 +476,13 @@ def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
        return state_dict

    def _get_loops_state_dict(self) -> Dict[str, Any]:
-        return {
+        state_dict = {
            "fit_loop": self.trainer.fit_loop.state_dict(),
            "validate_loop": self.trainer.validate_loop.state_dict(),
            "test_loop": self.trainer.test_loop.state_dict(),
            "predict_loop": self.trainer.predict_loop.state_dict(),
        }
+        external_loop = getattr(self.trainer, "external_loop", None)
[Review comment] perhaps trainer can have a property like the one we have for the other loops?
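A sketch of such a property; the backing attribute name `_external_loop` is an assumption here, chosen to mirror how other loops hang off the trainer:

class Trainer:
    @property
    def external_loop(self) -> Optional["ExternalLoop"]:
        """The external loop currently driving this trainer, if any."""
        # `_external_loop` is a hypothetical backing attribute
        return getattr(self, "_external_loop", None)

    @external_loop.setter
    def external_loop(self, loop: "ExternalLoop") -> None:
        self._external_loop = loop

Call sites could then drop the `getattr(self.trainer, "external_loop", None)` dance.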
+        if external_loop:
+            state_dict.update({"external_loop": external_loop.state_dict()})
[Review comment] can there be more than one external loop? I mean, one external loop nested inside another?
[Reply] It could, and Loop will automatically gather their children's states.
+        return state_dict
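Regarding the nesting exchange above: the reply relies on `Loop.state_dict` walking its `Loop`-typed children recursively. Roughly, the mechanism looks like the following simplified sketch, not the actual implementation:

def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
    destination = {} if destination is None else destination
    destination[prefix + "state_dict"] = self.on_save_checkpoint()
    # recurse into child loops so nested external loops are captured too
    for name, child in self.__dict__.items():
        if isinstance(child, Loop):
            child.state_dict(destination, prefix + name + ".")
    return destination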
[Review comment] rather not seed anything globally