diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 82a0e6b0436a6..b804241aa1b7e 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -14,11 +14,11 @@ class Callback(abc.ABC):
     Abstract base class used to build new callbacks.
     """

-    def setup(self, trainer, stage: str):
+    def setup(self, trainer, pl_module, stage: str):
         """Called when fit or test begins"""
         pass

-    def teardown(self, trainer, stage: str):
+    def teardown(self, trainer, pl_module, stage: str):
         """Called when fit or test ends"""
         pass

@@ -30,11 +30,11 @@ def on_init_end(self, trainer):
         """Called when the trainer initialization ends, model has not yet been set."""
         pass

-    def on_fit_start(self, trainer):
+    def on_fit_start(self, trainer, pl_module):
         """Called when fit begins"""
         pass

-    def on_fit_end(self, trainer):
+    def on_fit_end(self, trainer, pl_module):
         """Called when fit ends"""
         pass

@@ -46,11 +46,11 @@ def on_sanity_check_end(self, trainer, pl_module):
         """Called when the validation sanity check ends."""
         pass

-    def on_train_batch_start(self, trainer, pl_module):
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
         """Called when the validation batch begins."""
         pass

-    def on_train_batch_end(self, trainer, pl_module):
+    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
         """Called when the validation batch ends."""
         pass

@@ -90,19 +90,19 @@ def on_batch_start(self, trainer, pl_module):
         """Called when the training batch begins."""
         pass

-    def on_validation_batch_start(self, trainer, pl_module):
+    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
         """Called when the validation batch begins."""
         pass

-    def on_validation_batch_end(self, trainer, pl_module):
+    def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
         """Called when the validation batch ends."""
         pass

-    def on_test_batch_start(self, trainer, pl_module):
+    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
         """Called when the test batch begins."""
         pass

-    def on_test_batch_end(self, trainer, pl_module):
+    def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
         """Called when the test batch ends."""
         pass

@@ -118,6 +118,14 @@ def on_train_end(self, trainer, pl_module):
         """Called when the train ends."""
         pass

+    def on_pretrain_routine_start(self, trainer, pl_module):
+        """Called when the pretrain routine begins."""
+        pass
+
+    def on_pretrain_routine_end(self, trainer, pl_module):
+        """Called when the pretrain routine ends."""
+        pass
+
     def on_validation_start(self, trainer, pl_module):
         """Called when the validation loop begins."""
         pass
diff --git a/pytorch_lightning/callbacks/lr_logger.py b/pytorch_lightning/callbacks/lr_logger.py
index 7ec73b8c88811..27fbb81800241 100755
--- a/pytorch_lightning/callbacks/lr_logger.py
+++ b/pytorch_lightning/callbacks/lr_logger.py
@@ -64,7 +64,7 @@ def on_train_start(self, trainer, pl_module):
         # Initialize for storing values
         self.lrs = {name: [] for name in names}

-    def on_train_batch_start(self, trainer, pl_module):
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
         latest_stat = self._extract_lr(trainer, 'step')
         if trainer.logger and latest_stat:
             trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py
index 4ab990f74724e..c3cd9137c9ed7 100644
--- a/pytorch_lightning/callbacks/progress.py
+++ b/pytorch_lightning/callbacks/progress.py
@@ -138,19 +138,19 @@ def on_train_start(self, trainer, pl_module):
     def on_epoch_start(self, trainer, pl_module):
         self._train_batch_idx = 0

-    def on_train_batch_end(self, trainer, pl_module):
+    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
         self._train_batch_idx += 1

     def on_validation_start(self, trainer, pl_module):
         self._val_batch_idx = 0

-    def on_validation_batch_end(self, trainer, pl_module):
+    def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
         self._val_batch_idx += 1

     def on_test_start(self, trainer, pl_module):
         self._test_batch_idx = 0

-    def on_test_batch_end(self, trainer, pl_module):
+    def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
         self._test_batch_idx += 1

@@ -318,8 +318,8 @@ def on_epoch_start(self, trainer, pl_module):
         self.main_progress_bar.reset(convert_inf(total_batches))
         self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}')

-    def on_train_batch_end(self, trainer, pl_module):
-        super().on_train_batch_end(trainer, pl_module)
+    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+        super().on_train_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
         if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0:
             self.main_progress_bar.update(self.refresh_rate)
             self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
@@ -329,8 +329,8 @@ def on_validation_start(self, trainer, pl_module):
         self.val_progress_bar = self.init_validation_tqdm()
         self.val_progress_bar.total = convert_inf(self.total_val_batches)

-    def on_validation_batch_end(self, trainer, pl_module):
-        super().on_validation_batch_end(trainer, pl_module)
+    def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+        super().on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
         if self.is_enabled and self.val_batch_idx % self.refresh_rate == 0:
             self.val_progress_bar.update(self.refresh_rate)
             self.main_progress_bar.update(self.refresh_rate)
@@ -349,8 +349,8 @@ def on_test_start(self, trainer, pl_module):
         self.test_progress_bar = self.init_test_tqdm()
         self.test_progress_bar.total = convert_inf(self.total_test_batches)

-    def on_test_batch_end(self, trainer, pl_module):
-        super().on_test_batch_end(trainer, pl_module)
+    def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+        super().on_test_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
         if self.is_enabled and self.test_batch_idx % self.refresh_rate == 0:
             self.test_progress_bar.update(self.refresh_rate)
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 1218dcbe6760f..922f500ce8c25 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -77,7 +77,31 @@ def on_train_end(self) -> None:
         """
         # do something at the end of training

-    def on_train_batch_start(self, batch: Any) -> None:
+    def on_pretrain_routine_start(self) -> None:
+        """
+        Called at the beginning of the pretrain routine (between fit and train start).
+
+        - fit
+        - pretrain_routine start
+        - pretrain_routine end
+        - training_start
+
+        """
+        # do something at the start of the pretrain routine
+
+    def on_pretrain_routine_end(self) -> None:
+        """
+        Called at the end of the pretrain routine (between fit and train start).
+
+        - fit
+        - pretrain_routine start
+        - pretrain_routine end
+        - training_start
+
+        """
+        # do something at the end of the pretrain routine
+
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """
         Called in the training loop before anything happens for that batch.

@@ -85,12 +109,19 @@ def on_train_batch_start(self, batch: Any) -> None:

         Args:
             batch: The batched data as it is returned by the training DataLoader.
+            batch_idx: the index of the batch
+            dataloader_idx: the index of the dataloader
         """
         # do something when the batch starts

-    def on_train_batch_end(self) -> None:
+    def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """
         Called in the training loop after the batch.
+
+        Args:
+            batch: The batched data as it is returned by the training DataLoader.
+            batch_idx: the index of the batch
+            dataloader_idx: the index of the dataloader
         """
         # do something when the batch end
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 7c62743455317..b10c86af451b0 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -14,12 +14,12 @@ class TrainerCallbackHookMixin(ABC):
     def setup(self, stage: str):
         """Called in the beginning of fit and test"""
         for callback in self.callbacks:
-            callback.setup(self, stage)
+            callback.setup(self, self.get_model(), stage)

     def teardown(self, stage: str):
         """Called at the end of fit and test"""
         for callback in self.callbacks:
-            callback.teardown(self, stage)
+            callback.teardown(self, self.get_model(), stage)

     def on_init_start(self):
         """Called when the trainer initialization begins, model has not yet been set."""
@@ -31,15 +31,15 @@ def on_init_end(self):
         for callback in self.callbacks:
             callback.on_init_end(self)

-    def on_fit_start(self):
+    def on_fit_start(self, model):
         """Called when the trainer initialization begins, model has not yet been set."""
         for callback in self.callbacks:
-            callback.on_fit_start(self)
+            callback.on_fit_start(self, model)

     def on_fit_end(self):
         """Called when the trainer initialization begins, model has not yet been set."""
         for callback in self.callbacks:
-            callback.on_fit_end(self)
+            callback.on_fit_end(self, self.get_model())

     def on_sanity_check_start(self):
         """Called when the validation sanity check starts."""
@@ -101,6 +101,16 @@ def on_train_end(self):
         for callback in self.callbacks:
             callback.on_train_end(self, self.get_model())

+    def on_pretrain_routine_start(self, model):
+        """Called when the pretrain routine begins."""
+        for callback in self.callbacks:
+            callback.on_pretrain_routine_start(self, model)
+
+    def on_pretrain_routine_end(self, model):
+        """Called when the pretrain routine ends."""
+        for callback in self.callbacks:
+            callback.on_pretrain_routine_end(self, model)
+
     def on_batch_start(self):
         """Called when the training batch begins."""
         for callback in self.callbacks:
             callback.on_batch_start(self, self.get_model())
@@ -111,35 +121,35 @@ def on_batch_end(self):
         for callback in self.callbacks:
             callback.on_batch_end(self, self.get_model())

-    def on_train_batch_start(self):
+    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
         """Called when the training batch begins."""
         for callback in self.callbacks:
-            callback.on_train_batch_start(self, self.get_model())
+            callback.on_train_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)

-    def on_train_batch_end(self):
+    def on_train_batch_end(self, batch, batch_idx, dataloader_idx):
         """Called when the training batch ends."""
         for callback in self.callbacks:
-            callback.on_train_batch_end(self, self.get_model())
+            callback.on_train_batch_end(self, self.get_model(), batch, batch_idx, dataloader_idx)

-    def on_validation_batch_start(self):
+    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
         """Called when the validation batch begins."""
         for callback in self.callbacks:
-            callback.on_validation_batch_start(self, self.get_model())
+            callback.on_validation_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)

-    def on_validation_batch_end(self):
+    def on_validation_batch_end(self, batch, batch_idx, dataloader_idx):
         """Called when the validation batch ends."""
         for callback in self.callbacks:
-            callback.on_validation_batch_end(self, self.get_model())
+            callback.on_validation_batch_end(self, self.get_model(), batch, batch_idx, dataloader_idx)

-    def on_test_batch_start(self):
+    def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
         """Called when the test batch begins."""
         for callback in self.callbacks:
-            callback.on_test_batch_start(self, self.get_model())
+            callback.on_test_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)

-    def on_test_batch_end(self):
+    def on_test_batch_end(self, batch, batch_idx, dataloader_idx):
         """Called when the test batch ends."""
         for callback in self.callbacks:
-            callback.on_test_batch_end(self, self.get_model())
+            callback.on_test_batch_end(self, self.get_model(), batch, batch_idx, dataloader_idx)

     def on_validation_start(self):
         """Called when the validation loop begins."""
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index add9bb24c672a..75d09db2cec65 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -312,9 +312,9 @@ def _evaluate(

                 # callbacks
                 if test_mode:
-                    self.on_test_batch_start()
+                    self.on_test_batch_start(batch, batch_idx, dataloader_idx)
                 else:
-                    self.on_validation_batch_start()
+                    self.on_validation_batch_start(batch, batch_idx, dataloader_idx)

                 # -----------------
                 # RUN EVALUATION STEP
@@ -336,13 +336,13 @@ def _evaluate(
                         model_ref = self.get_model()
                         with self.profiler.profile('test_step_end'):
                             output = model_ref.test_step_end(output)
-                    self.on_test_batch_end()
+                    self.on_test_batch_end(batch, batch_idx, dataloader_idx)
                 else:
                     if self.is_overridden('validation_step_end'):
                         model_ref = self.get_model()
                         with self.profiler.profile('validation_step_end'):
                             output = model_ref.validation_step_end(output)
-                    self.on_validation_batch_end()
+                    self.on_validation_batch_end(batch, batch_idx, dataloader_idx)

                 # track outputs for collation
                 if output is not None:
diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py
index 23ad702956e84..e028bf60eda7e 100755
--- a/pytorch_lightning/trainer/lr_finder.py
+++ b/pytorch_lightning/trainer/lr_finder.py
@@ -382,7 +382,7 @@ def on_batch_start(self, trainer, pl_module):

         self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])

-    def on_train_batch_end(self, trainer, pl_module):
+    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
         """ Called when the training batch ends, logs the calculated loss """
         if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
             return
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 4b342328df297..0d596d7dd611c 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -958,7 +958,7 @@ def fit(
         self.config_validator.verify_loop_configurations(model)

         # callbacks
-        self.on_fit_start()
+        self.on_fit_start(model)
         if self.is_function_implemented('on_fit_start', model):
             model.on_fit_start()

@@ -1055,13 +1055,12 @@ def fit(
             self.accelerator_backend.setup(model)
             results = self.accelerator_backend.train(model)

-        # callbacks
+        # on fit end callback
         self.on_fit_end()
-
-        # model hooks
         if self.is_function_implemented('on_fit_end'):
             model.on_fit_end()

+        # teardown callback
         self.teardown('fit')
         if self.is_function_implemented('teardown'):
             model.teardown('fit')
@@ -1156,6 +1155,11 @@ def run_pretrain_routine(self, model: LightningModule):
         # register auto-resubmit when on SLURM
         self.register_slurm_signal_handlers()

+        # on pretrain routine start
+        self.on_pretrain_routine_start(ref_model)
+        if self.is_function_implemented('on_pretrain_routine_start'):
+            ref_model.on_pretrain_routine_start()
+
         # print model summary
         if self.is_global_zero and self.weights_summary is not None and not self.testing:
             if self.weights_summary in ModelSummary.MODES:
@@ -1198,6 +1202,11 @@ def run_pretrain_routine(self, model: LightningModule):
             with torch.cuda.device(f'cuda:{self.root_gpu}'):
                 torch.cuda.empty_cache()

+        # on pretrain routine end
+        self.on_pretrain_routine_end(ref_model)
+        if self.is_function_implemented('on_pretrain_routine_end'):
+            ref_model.on_pretrain_routine_end()
+
         # CORE TRAINING LOOP
         self.train()
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 993e8ccd53fd0..e1e7737247106 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -703,11 +703,12 @@ def run_training_batch(self, batch, batch_idx):
             return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)

         with self.profiler.profile('on_train_batch_start'):
-            # callbacks
-            self.on_train_batch_start()
+            # forward support for multiple loaders
+            dataloader_idx = 0
+            self.on_train_batch_start(batch, batch_idx, dataloader_idx)
             # hooks
             if self.is_function_implemented('on_train_batch_start'):
-                response = self.get_model().on_train_batch_start(batch)
+                response = self.get_model().on_train_batch_start(batch, batch_idx, dataloader_idx)
                 if response == -1:
                     return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)
@@ -798,11 +799,12 @@ def run_training_batch(self, batch, batch_idx):
                 self.get_model().on_batch_end()

         with self.profiler.profile('on_train_batch_end'):
-            # callbacks
-            self.on_train_batch_end()
+            # forward support for multiple loaders
+            dataloader_idx = 0
+            self.on_train_batch_end(batch, batch_idx, dataloader_idx)
             # model hooks
             if self.is_function_implemented('on_train_batch_end'):
-                self.get_model().on_train_batch_end()
+                self.get_model().on_train_batch_end(batch, batch_idx, dataloader_idx)

         # collapse all metrics into one dict
         batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 83de82c71de67..90a3d9117c376 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -36,16 +36,18 @@ def __init__(self):
             self.on_test_batch_end_called = False
             self.on_train_start_called = False
             self.on_train_end_called = False
+            self.on_pretrain_routine_start_called = False
+            self.on_pretrain_routine_end_called = False
             self.on_validation_start_called = False
             self.on_validation_end_called = False
             self.on_test_start_called = False
             self.on_test_end_called = False

-        def setup(self, trainer, stage: str):
+        def setup(self, trainer, pl_module, stage: str):
             assert isinstance(trainer, Trainer)
             self.setup_called = True

-        def teardown(self, trainer, step: str):
+        def teardown(self, trainer, pl_module, step: str):
             assert isinstance(trainer, Trainer)
             self.teardown_called = True

@@ -57,11 +59,11 @@ def on_init_end(self, trainer):
             assert isinstance(trainer, Trainer)
             self.on_init_end_called = True

-        def on_fit_start(self, trainer):
+        def on_fit_start(self, trainer, pl_module):
             assert isinstance(trainer, Trainer)
             self.on_fit_start_called = True

-        def on_fit_end(self, trainer):
+        def on_fit_end(self, trainer, pl_module):
             assert isinstance(trainer, Trainer)
             self.on_fit_end_called = True

@@ -89,27 +91,27 @@ def on_batch_end(self, trainer, pl_module):
             _check_args(trainer, pl_module)
             self.on_batch_end_called = True

-        def on_train_batch_start(self, trainer, pl_module):
+        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
             _check_args(trainer, pl_module)
             self.on_train_batch_start_called = True

-        def on_train_batch_end(self, trainer, pl_module):
+        def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
             _check_args(trainer, pl_module)
             self.on_train_batch_end_called = True

-        def on_validation_batch_start(self, trainer, pl_module):
+        def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
             _check_args(trainer, pl_module)
             self.on_validation_batch_start_called = True

-        def on_validation_batch_end(self, trainer, pl_module):
+        def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
             _check_args(trainer, pl_module)
             self.on_validation_batch_end_called = True

-        def on_test_batch_start(self, trainer, pl_module):
+        def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
             _check_args(trainer, pl_module)
             self.on_test_batch_start_called = True

-        def on_test_batch_end(self, trainer, pl_module):
+        def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
             _check_args(trainer, pl_module)
             self.on_test_batch_end_called = True

@@ -121,6 +123,14 @@ def on_train_end(self, trainer, pl_module):
             _check_args(trainer, pl_module)
             self.on_train_end_called = True

+        def on_pretrain_routine_start(self, trainer, pl_module):
+            _check_args(trainer, pl_module)
+            self.on_pretrain_routine_start_called = True
+
+        def on_pretrain_routine_end(self, trainer, pl_module):
+            _check_args(trainer, pl_module)
+            self.on_pretrain_routine_end_called = True
+
         def on_validation_start(self, trainer, pl_module):
             _check_args(trainer, pl_module)
             self.on_validation_start_called = True
@@ -168,6 +178,8 @@ def on_test_end(self, trainer, pl_module):
     assert not test_callback.on_test_batch_end_called
     assert not test_callback.on_train_start_called
     assert not test_callback.on_train_end_called
+    assert not test_callback.on_pretrain_routine_start_called
+    assert not test_callback.on_pretrain_routine_end_called
     assert not test_callback.on_validation_start_called
     assert not test_callback.on_validation_end_called
     assert not test_callback.on_test_start_called
@@ -197,6 +209,8 @@ def on_test_end(self, trainer, pl_module):
     assert not test_callback.on_test_batch_end_called
     assert not test_callback.on_train_start_called
     assert not test_callback.on_train_end_called
+    assert not test_callback.on_pretrain_routine_start_called
+    assert not test_callback.on_pretrain_routine_end_called
     assert not test_callback.on_validation_start_called
     assert not test_callback.on_validation_end_called
     assert not test_callback.on_test_start_called
@@ -222,6 +236,8 @@ def on_test_end(self, trainer, pl_module):
     assert test_callback.on_validation_batch_end_called
     assert test_callback.on_train_start_called
     assert test_callback.on_train_end_called
+    assert test_callback.on_pretrain_routine_start_called
+    assert test_callback.on_pretrain_routine_end_called
     assert test_callback.on_validation_start_called
     assert test_callback.on_validation_end_called
     assert not test_callback.on_test_batch_start_called
diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py
index 779077c437585..713bdf3c3c2c4 100644
--- a/tests/callbacks/test_progress_bar.py
+++ b/tests/callbacks/test_progress_bar.py
@@ -153,25 +153,25 @@ class CurrentProgressBar(ProgressBar):
         val_batches_seen = 0
         test_batches_seen = 0

-        def on_train_batch_start(self, trainer, pl_module):
-            super().on_train_batch_start(trainer, pl_module)
+        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+            super().on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
             assert self.train_batch_idx == trainer.batch_idx

-        def on_train_batch_end(self, trainer, pl_module):
-            super().on_train_batch_end(trainer, pl_module)
+        def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+            super().on_train_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
             assert self.train_batch_idx == trainer.batch_idx + 1
             if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0:
                 assert self.main_progress_bar.n == self.train_batch_idx
             self.train_batches_seen += 1

-        def on_validation_batch_end(self, trainer, pl_module):
-            super().on_validation_batch_end(trainer, pl_module)
+        def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+            super().on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
             if not self.is_disabled and self.val_batch_idx % self.refresh_rate == 0:
                 assert self.val_progress_bar.n == self.val_batch_idx
             self.val_batches_seen += 1

-        def on_test_batch_end(self, trainer, pl_module):
-            super().on_test_batch_end(trainer, pl_module)
+        def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+            super().on_test_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
             if not self.is_disabled and self.test_batch_idx % self.refresh_rate == 0:
                 assert self.test_progress_bar.n == self.test_batch_idx
             self.test_batches_seen += 1
diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py
index 7978aa8e41ace..90d0f3a89ee9a 100644
--- a/tests/loggers/test_all.py
+++ b/tests/loggers/test_all.py
@@ -5,13 +5,11 @@
 import platform
 from unittest import mock

-import cloudpickle
 import pytest

 import tests.base.develop_utils as tutils
 from pytorch_lightning import Trainer, Callback
 from pytorch_lightning.loggers import (
-    CSVLogger,
     TensorBoardLogger,
     MLFlowLogger,
     NeptuneLogger,
@@ -36,7 +34,6 @@ def _get_logger_args(logger_class, save_dir):

 @pytest.mark.parametrize("logger_class", [
     TensorBoardLogger,
-    CSVLogger,
     CometLogger,
     MLFlowLogger,
     NeptuneLogger,
@@ -88,7 +85,6 @@ def log_metrics(self, metrics, step):

 @pytest.mark.parametrize("logger_class", [
-    CSVLogger,
     TensorBoardLogger,
     CometLogger,
     MLFlowLogger,
@@ -152,7 +148,6 @@ def name(self):

 @pytest.mark.parametrize("logger_class", [
     TensorBoardLogger,
-    CSVLogger,
     CometLogger,
     MLFlowLogger,
     NeptuneLogger,
@@ -175,7 +170,6 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):

     # test pickling loggers
     pickle.dumps(logger)
-    cloudpickle.dumps(logger)

     trainer = Trainer(
         max_epochs=1,
@@ -220,7 +214,7 @@ class RankZeroLoggerCheck(Callback):
     # this class has to be defined outside the test function, otherwise we get pickle error
     # due to the way ddp process is launched

-    def on_train_batch_start(self, trainer, pl_module):
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
         is_dummy = isinstance(trainer.logger.experiment, DummyExperiment)
         if trainer.is_global_zero:
             assert not is_dummy
@@ -232,7 +226,6 @@ def on_train_batch_start(self, trainer, pl_module):

 @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.parametrize("logger_class", [
     TensorBoardLogger,
-    # CSVLogger, # todo
     CometLogger,
     MLFlowLogger,
     NeptuneLogger,
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index d6641c2f7ab24..2444d9905f6a1 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -368,7 +368,7 @@ def _new_model():
     def increment_epoch(self):
         self.num_epochs_seen += 1

-    def increment_batch(self, _):
+    def increment_batch(self, batch, batch_idx, dataloader_idx):
         self.num_batches_seen += 1

     def increment_on_load_checkpoint(self, _):
@@ -691,7 +691,7 @@ class InterruptCallback(Callback):
         def __init__(self):
             super().__init__()

-        def on_train_batch_start(self, trainer, pl_module):
+        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
             raise KeyboardInterrupt

     class HandleInterruptCallback(Callback):
@@ -988,17 +988,3 @@ def setup(self, stage):
     trainer.test(ckpt_path=None)
     assert trainer.stage == 'test'
     assert trainer.get_model().stage == 'test'
-
-
-def test_trainer_ddp_spawn_none_checkpoint(tmpdir):
-    model = EvalModelTemplate()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        checkpoint_callback=None,
-        distributed_backend="ddp_spawn"
-    )
-    assert trainer.checkpoint_callback is None
-    result = trainer.fit(model)
-    assert trainer.checkpoint_callback is None
-    assert result == 1
diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py
index 08f808bda9ceb..ac6e4923a72ff 100644
--- a/tests/utilities/test_dtype_device_mixin.py
+++ b/tests/utilities/test_dtype_device_mixin.py
@@ -27,7 +27,7 @@ def __init__(self, *args, **kwargs):

 class DeviceAssertCallback(Callback):

-    def on_train_batch_start(self, trainer, model):
+    def on_train_batch_start(self, trainer, model, batch, batch_idx, dataloader_idx):
         rank = trainer.local_rank
         assert isinstance(model, TopModule)
         # index = None also means first device
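
For anyone updating their own callbacks against this change, the sketch below (not part of the diff; the `PrintingCallback` name and the print statements are invented for illustration) shows the signatures the trainer now calls: the batch-level hooks receive `batch`, `batch_idx` and `dataloader_idx`, and the new `on_pretrain_routine_start` / `on_pretrain_routine_end` hooks run between the start of `fit()` and the core training loop.

```python
from pytorch_lightning import Callback


class PrintingCallback(Callback):
    """Illustrative only: exercises the hook signatures introduced in this diff."""

    def on_pretrain_routine_start(self, trainer, pl_module):
        # runs after fit begins but before the core training loop starts
        print("pretrain routine started")

    def on_pretrain_routine_end(self, trainer, pl_module):
        print("pretrain routine finished, training is about to begin")

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        # the trainer now forwards the batch and its indices to the callback
        print(f"starting batch {batch_idx} (dataloader {dataloader_idx})")

    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        print(f"finished batch {batch_idx}")
```

Such a callback would be passed as usual via `Trainer(callbacks=[PrintingCallback()])`. Note that the training loop in this diff always passes `dataloader_idx = 0`; multiple training loaders are only forward-supported by the new signature, not implemented here.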