Deprecate returning extras with grads (#7994)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
carmocca and awaelchli authored Jun 18, 2021
1 parent f447839 commit a23a699
Showing 4 changed files with 29 additions and 13 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -221,6 +221,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `is_overridden(model=...)` in favor of `is_overridden(instance=...)` ([#7918](https://github.com/PyTorchLightning/pytorch-lightning/pull/7918))
 
 
+- Deprecated automatically detaching returned extras with grads ([#7994](https://github.com/PyTorchLightning/pytorch-lightning/pull/7994))
+
+
 - Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907))


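
The deprecated pattern is a `training_step` that returns extra values still attached to the autograd graph. A minimal before/after sketch of the user-facing change (the layer and key names are illustrative, not from this commit):

```python
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def training_step(self, batch, batch_idx):
        out = self.layer(batch)
        loss = out.sum()
        # `out` still carries a `grad_fn`. Up to v1.5 Lightning detaches it for you
        # and emits a DeprecationWarning; from v1.6 on, detach extras yourself.
        return {"loss": loss, "out": out.detach()}
```
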
13 changes: 11 additions & 2 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -26,12 +26,15 @@
 from pytorch_lightning.utilities.enums import LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.metrics import metrics_to_scalars
+from pytorch_lightning.utilities.warnings import WarningCache
 
 # re-define the ones from pytorch_lightning.utilities.types without the `Number` type
 # TODO(@tchaton): Typing-pickle issue on python<3.7 (https://github.com/cloudpipe/cloudpickle/pull/318)
 _METRIC = Any  # Union[Metric, torch.Tensor]
 _METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]]
 
+warning_cache = WarningCache()
+
 
 class MetricSource(LightningEnum):
     CALLBACK = "callback"
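
`WarningCache` keeps a set of messages it has already emitted, so the new deprecation warns once per process instead of on every training step. A small sketch of that behaviour in isolation (assuming the PL 1.x `WarningCache` API used in this diff):

```python
from pytorch_lightning.utilities.warnings import WarningCache

cache = WarningCache()
for _ in range(3):
    # Only the first call emits; identical messages are skipped afterwards.
    cache.warn("one of the returned values has a `grad_fn`", DeprecationWarning)
```
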
@@ -367,9 +370,15 @@ def extra(self, extra: Mapping[str, Any]) -> None:

         def check_fn(v):
             if v.grad_fn is not None:
-                raise MisconfigurationException(f'You returned a tensor with `grad_fn`. The extra values are {extra}')
+                warning_cache.warn(
+                    f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
+                    " but this behaviour will change in v1.6. Please detach it manually:"
+                    " `return {'loss': ..., 'something': something.detach()}`", DeprecationWarning
+                )
+                return v.detach()
+            return v
 
-        apply_to_collection(extra, torch.Tensor, check_fn)
+        extra = apply_to_collection(extra, torch.Tensor, check_fn)
         self['_extra'] = extra
 
     def log(
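
Note the second change in this hunk: the old call to `apply_to_collection` only validated (`check_fn` raised on any `grad_fn`), so its return value was unused. Because `check_fn` now returns a detached tensor, the result has to be reassigned to `extra` before being stored. A minimal sketch of the detach pass outside the Trainer (illustrative values):

```python
import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection

extra = {"foo": torch.tensor(2.0, requires_grad=True) * 3}  # carries a grad_fn


def check_fn(v: torch.Tensor) -> torch.Tensor:
    # Detach tensors still attached to the autograd graph; pass others through.
    return v.detach() if v.grad_fn is not None else v


extra = apply_to_collection(extra, torch.Tensor, check_fn)
assert extra["foo"].grad_fn is None
```
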
15 changes: 15 additions & 0 deletions tests/deprecated_api/test_remove_1-6.py
@@ -212,3 +212,18 @@ def test_v1_6_0_early_stopping_monitor(tmpdir):
" For backward compatibility, setting this to `early_stop_on`."
):
EarlyStopping()


def test_v1_6_0_extras_with_gradients(tmpdir):

class TestModel(BoringModel):

def training_step(self, *args):
loss = super().training_step(*args)['loss']
return {"loss": loss, 'foo': loss}

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
model = TestModel()
match = r"\{'foo'\} has a `grad_fn`.*behaviour will change in v1\.6"
with pytest.deprecated_call(match=match):
trainer.fit(model)
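
`pytest.deprecated_call(match=...)` passes only if the block emits a `DeprecationWarning` whose message matches the regex. A standalone sketch of the same mechanism (hypothetical `old_api`, for illustration):

```python
import warnings

import pytest


def old_api() -> int:
    warnings.warn("old_api: behaviour will change in v1.6", DeprecationWarning)
    return 42


def test_old_api_warns():
    with pytest.deprecated_call(match=r"behaviour will change in v1\.6"):
        assert old_api() == 42
```
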
11 changes: 0 additions & 11 deletions tests/trainer/logging_/test_train_loop_logging.py
Expand Up @@ -690,17 +690,6 @@ def training_step(self, batch, batch_idx):
     with pytest.raises(MisconfigurationException, match='`self.log` with the key `foo/dataloader_idx_0`'):
         trainer.fit(model)
 
-    class TestModel(BoringModel):
-
-        def training_step(self, *args):
-            loss = super().training_step(*args)['loss']
-            return {"loss": loss, 'foo': loss}
-
-    trainer = Trainer(default_root_dir=tmpdir)
-    model = TestModel()
-    with pytest.raises(MisconfigurationException, match='You returned a tensor with `grad_fn`'):
-        trainer.fit(model)
-
     class TestModel(BoringModel):
 
         def training_step(self, *args):
