diff --git a/CHANGELOG.md b/CHANGELOG.md index 23ba6c4f26411..9f81e97b3efbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -221,6 +221,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `is_overridden(model=...)` in favor of `is_overridden(instance=...)` ([#7918](https://github.com/PyTorchLightning/pytorch-lightning/pull/7918)) +- Deprecated automatically detaching returned extras with grads ([#7994](https://github.com/PyTorchLightning/pytorch-lightning/pull/7994)) + + - Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907)) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index f751916804ad1..096741d4a7486 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -26,12 +26,15 @@ from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.metrics import metrics_to_scalars +from pytorch_lightning.utilities.warnings import WarningCache # re-define the ones from pytorch_lightning.utilities.types without the `Number` type # TODO(@tchaton): Typing-pickle issue on python<3.7 (https://github.com/cloudpipe/cloudpickle/pull/318) _METRIC = Any # Union[Metric, torch.Tensor] _METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]] +warning_cache = WarningCache() + class MetricSource(LightningEnum): CALLBACK = "callback" @@ -367,9 +370,15 @@ def extra(self, extra: Mapping[str, Any]) -> None: def check_fn(v): if v.grad_fn is not None: - raise MisconfigurationException(f'You returned a tensor with `grad_fn`. The extra values are {extra}') + warning_cache.warn( + f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically" + " but this behaviour will change in v1.6. Please detach it manually:" + " `return {'loss': ..., 'something': something.detach()}`", DeprecationWarning + ) + return v.detach() + return v - apply_to_collection(extra, torch.Tensor, check_fn) + extra = apply_to_collection(extra, torch.Tensor, check_fn) self['_extra'] = extra def log( diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index cb150cb013ec2..2fa54f3b253fb 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -212,3 +212,18 @@ def test_v1_6_0_early_stopping_monitor(tmpdir): " For backward compatibility, setting this to `early_stop_on`." ): EarlyStopping() + + +def test_v1_6_0_extras_with_gradients(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, *args): + loss = super().training_step(*args)['loss'] + return {"loss": loss, 'foo': loss} + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + model = TestModel() + match = r"\{'foo'\} has a `grad_fn`.*behaviour will change in v1\.6" + with pytest.deprecated_call(match=match): + trainer.fit(model) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index bff558e81b29e..8ca51b2dee3ef 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -690,17 +690,6 @@ def training_step(self, batch, batch_idx): with pytest.raises(MisconfigurationException, match='`self.log` with the key `foo/dataloader_idx_0`'): trainer.fit(model) - class TestModel(BoringModel): - - def training_step(self, *args): - loss = super().training_step(*args)['loss'] - return {"loss": loss, 'foo': loss} - - trainer = Trainer(default_root_dir=tmpdir) - model = TestModel() - with pytest.raises(MisconfigurationException, match='You returned a tensor with `grad_fn`'): - trainer.fit(model) - class TestModel(BoringModel): def training_step(self, *args):