diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 15428c5d5c248..0d0c3781c7724 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -81,16 +81,13 @@ def cached_results(self) -> Union[EpochResultStore, None]: return self._cached_results.get(self.trainer._running_stage) # type: ignore def get_metrics(self, key: str) -> Dict: - metrics_holder = getattr(self, f"_{key}", None) - model_ref = self.trainer.lightning_module - metrics_holder.convert( - self.trainer._device_type == DeviceType.TPU, - model_ref.device if model_ref is not None else model_ref, - ) + metrics_holder: MetricsHolder = getattr(self, f"_{key}") + model = self.trainer.lightning_module + metrics_holder.convert(model.device if model is not None else None) return metrics_holder.metrics def set_metrics(self, key: str, val: Dict) -> None: - metrics_holder = getattr(self, f"_{key}", None) + metrics_holder: MetricsHolder = getattr(self, f"_{key}") metrics_holder.reset(val) def reset(self) -> None: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py index 554f1d3faf9ed..1efbcc638674f 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -12,43 +12,52 @@ # See the License for the specific language governing permissions and # limitations under the License. import numbers -from typing import Any +from typing import Any, Dict, Optional, Union import torch from torchmetrics import Metric +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +_METRIC_TYPE = Union[Metric, torch.Tensor, int, float, Any] + class MetricsHolder: """ - This class acts as a dictonary holder. + This class acts as a dictionary holder. It holds metrics and implements conversion functions. Those functions will be triggered within LoggerConnector when the property is being requested from the user. """ - def __init__(self, to_float: bool = False): - self.metrics = {} + def __init__(self, to_float: bool = False) -> None: + self.metrics: Dict[str, _METRIC_TYPE] = {} self._to_float = to_float - def update(self, metrics): + def update(self, metrics: dict) -> None: self.metrics.update(metrics) - def pop(self, key, default): + def pop(self, key: str, default: _METRIC_TYPE) -> _METRIC_TYPE: return self.metrics.pop(key, default) - def reset(self, metrics): + def reset(self, metrics: Dict[str, _METRIC_TYPE]) -> None: self.metrics = metrics - def convert(self, use_tpu: bool, device: torch.device): + def convert(self, device: Optional[torch.device]) -> None: for key, value in self.metrics.items(): - self.metrics[key] = self._convert(value, use_tpu, device) - - def _convert(self, current: Any, use_tpu: bool, device: torch.device): - if self._to_float: - return self._convert_to_float(current, use_tpu, device) - return self._convert_to_tensor(current, use_tpu, device) - - def _convert_to_float(self, current, use_tpu: bool, device: torch.device): + if self._to_float: + if isinstance(value, torch.Tensor) and value.numel() != 1: + raise MisconfigurationException( + f"The metric `{key}` does not contain a single element" + f" thus it cannot be converted to float. Found `{value}`" + ) + converted = self._convert_to_float(value) + else: + converted = self._convert_to_tensor(value, device) + self.metrics[key] = converted + + @staticmethod + def _convert_to_float(current: _METRIC_TYPE) -> float: if isinstance(current, Metric): current = current.compute().detach() @@ -60,16 +69,13 @@ def _convert_to_float(self, current, use_tpu: bool, device: torch.device): return current - def _convert_to_tensor(self, current: Any, use_tpu: bool, device: torch.device): - if current is not None: - if isinstance(current, Metric): - current = current.compute().detach() + @staticmethod + def _convert_to_tensor(current: _METRIC_TYPE, device: Optional[torch.device]) -> torch.Tensor: + if isinstance(current, Metric): + current = current.compute().detach() - elif isinstance(current, numbers.Number): - if device is None: - current = torch.tensor(current, dtype=torch.float) - else: - current = torch.tensor(current, device=device, dtype=torch.float) + elif isinstance(current, numbers.Number): + current = torch.tensor(current, device=device, dtype=torch.float) if isinstance(current, torch.Tensor) and current.device.type == "xla": current = current.cpu() diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 3db0a8eaa065b..d14ed71940328 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -447,13 +447,38 @@ def is_float(value: Any) -> bool: "y": torch.tensor(2), "z": acc(preds, targets), }) - metric_holder.convert(False, device) + metric_holder.convert(device) metrics = metric_holder.metrics assert excepted_function(metrics["x"]) assert excepted_function(metrics["y"]) assert excepted_function(metrics["z"]) +def test_metric_holder_raises(tmpdir): + """Check that an error is raised when trying to convert non-scalar tensors""" + + class TestModel(BoringModel): + + def validation_step(self, batch, *args, **kwargs): + output = self(batch) + return {"test": output} + + def test_step(self, *args, **kwargs): + return self.validation_step(*args, **kwargs) + + model = TestModel() + model.validation_epoch_end = None + model.test_epoch_end = None + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + + match = "The metric `test` does not contain a single element" + with pytest.raises(MisconfigurationException, match=match): + trainer.validate(model) + with pytest.raises(MisconfigurationException, match=match): + trainer.test(model) + + def test_logging_to_progress_bar_with_reserved_key(tmpdir): """ Test that logging a metric with a reserved name to the progress bar raises a warning. """ @@ -465,10 +490,7 @@ def training_step(self, *args, **kwargs): return output model = TestModel() - trainer = Trainer( - default_root_dir=tmpdir, - max_steps=2, - ) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"): trainer.fit(model)