MetricsHolder clean-up + typing (#6645)
* Metrics holder cleanup and better error message

* Update pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

* _VALUE -> _METRIC_TYPE
carmocca authored Mar 24, 2021
1 parent d471fa3 commit 2dd6f9e
Showing 3 changed files with 62 additions and 37 deletions.
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

@@ -81,16 +81,13 @@ def cached_results(self) -> Union[EpochResultStore, None]:
         return self._cached_results.get(self.trainer._running_stage)  # type: ignore
 
     def get_metrics(self, key: str) -> Dict:
-        metrics_holder = getattr(self, f"_{key}", None)
-        model_ref = self.trainer.lightning_module
-        metrics_holder.convert(
-            self.trainer._device_type == DeviceType.TPU,
-            model_ref.device if model_ref is not None else model_ref,
-        )
+        metrics_holder: MetricsHolder = getattr(self, f"_{key}")
+        model = self.trainer.lightning_module
+        metrics_holder.convert(model.device if model is not None else None)
         return metrics_holder.metrics
 
     def set_metrics(self, key: str, val: Dict) -> None:
-        metrics_holder = getattr(self, f"_{key}", None)
+        metrics_holder: MetricsHolder = getattr(self, f"_{key}")
         metrics_holder.reset(val)
 
     def reset(self) -> None:
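The dropped `use_tpu` argument is not lost functionality: `MetricsHolder._convert_to_tensor` (next file) now detects XLA tensors from `device.type` directly. A minimal sketch of the new call path in tensor mode — the module path is assumed from this commit's layout:

import torch

from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder

holder = MetricsHolder()  # to_float=False: plain numbers get promoted to tensors
holder.update({"acc": 0.91, "loss": torch.tensor(0.25)})

# In LoggerConnector.get_metrics this argument is `model.device`, or None when
# no LightningModule is attached; torch.tensor(..., device=None) defaults to CPU.
holder.convert(torch.device("cpu"))
print(holder.metrics)  # {'acc': tensor(0.9100), 'loss': tensor(0.2500)}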
pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py

@@ -12,43 +12,52 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import numbers
-from typing import Any
+from typing import Any, Dict, Optional, Union
 
 import torch
 from torchmetrics import Metric
 
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+_METRIC_TYPE = Union[Metric, torch.Tensor, int, float, Any]
+
 
 class MetricsHolder:
     """
-    This class acts as a dictonary holder.
+    This class acts as a dictionary holder.
     It holds metrics and implements conversion functions.
     Those functions will be triggered within LoggerConnector
     when the property is being requested from the user.
     """
 
-    def __init__(self, to_float: bool = False):
-        self.metrics = {}
+    def __init__(self, to_float: bool = False) -> None:
+        self.metrics: Dict[str, _METRIC_TYPE] = {}
         self._to_float = to_float
 
-    def update(self, metrics):
+    def update(self, metrics: dict) -> None:
         self.metrics.update(metrics)
 
-    def pop(self, key, default):
+    def pop(self, key: str, default: _METRIC_TYPE) -> _METRIC_TYPE:
         return self.metrics.pop(key, default)
 
-    def reset(self, metrics):
+    def reset(self, metrics: Dict[str, _METRIC_TYPE]) -> None:
         self.metrics = metrics
 
-    def convert(self, use_tpu: bool, device: torch.device):
+    def convert(self, device: Optional[torch.device]) -> None:
         for key, value in self.metrics.items():
-            self.metrics[key] = self._convert(value, use_tpu, device)
-
-    def _convert(self, current: Any, use_tpu: bool, device: torch.device):
-        if self._to_float:
-            return self._convert_to_float(current, use_tpu, device)
-        return self._convert_to_tensor(current, use_tpu, device)
-
-    def _convert_to_float(self, current, use_tpu: bool, device: torch.device):
+            if self._to_float:
+                if isinstance(value, torch.Tensor) and value.numel() != 1:
+                    raise MisconfigurationException(
+                        f"The metric `{key}` does not contain a single element"
+                        f" thus it cannot be converted to float. Found `{value}`"
+                    )
+                converted = self._convert_to_float(value)
+            else:
+                converted = self._convert_to_tensor(value, device)
+            self.metrics[key] = converted
+
+    @staticmethod
+    def _convert_to_float(current: _METRIC_TYPE) -> float:
         if isinstance(current, Metric):
             current = current.compute().detach()
 
@@ -60,16 +69,13 @@ def _convert_to_float(self, current, use_tpu: bool, device: torch.device):
 
         return current
 
-    def _convert_to_tensor(self, current: Any, use_tpu: bool, device: torch.device):
-        if current is not None:
-            if isinstance(current, Metric):
-                current = current.compute().detach()
+    @staticmethod
+    def _convert_to_tensor(current: _METRIC_TYPE, device: Optional[torch.device]) -> torch.Tensor:
+        if isinstance(current, Metric):
+            current = current.compute().detach()
 
-            elif isinstance(current, numbers.Number):
-                if device is None:
-                    current = torch.tensor(current, dtype=torch.float)
-                else:
-                    current = torch.tensor(current, device=device, dtype=torch.float)
+        elif isinstance(current, numbers.Number):
+            current = torch.tensor(current, device=device, dtype=torch.float)
 
         if isinstance(current, torch.Tensor) and current.device.type == "xla":
             current = current.cpu()
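Two simplifications here are easy to miss: the old `if device is None` branch could go because `torch.tensor(..., device=None)` already falls back to the default device, and the `use_tpu` flag was redundant because XLA tensors identify themselves via `device.type == "xla"`. The other change is the eager shape check in `convert`, which now fails with the offending key in the message. A minimal sketch of the float mode, with the module path assumed as above and assuming the elided middle of `_convert_to_float` reduces a scalar tensor to a Python float, as its `-> float` annotation indicates:

import torch

from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Float mode, as used for progress-bar metrics: every value becomes a Python float.
holder = MetricsHolder(to_float=True)
holder.update({"acc": torch.tensor(0.91)})
holder.convert(device=None)
assert isinstance(holder.metrics["acc"], float)

# A multi-element tensor cannot be reduced to a single float, so convert()
# now raises a MisconfigurationException naming the offending key.
holder.reset({"preds": torch.zeros(4)})
try:
    holder.convert(device=None)
except MisconfigurationException as err:
    print(err)  # The metric `preds` does not contain a single element ...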
tests/trainer/logging_/test_logger_connector.py (27 additions & 5 deletions)

@@ -447,13 +447,38 @@ def is_float(value: Any) -> bool:
         "y": torch.tensor(2),
         "z": acc(preds, targets),
     })
-    metric_holder.convert(False, device)
+    metric_holder.convert(device)
     metrics = metric_holder.metrics
     assert excepted_function(metrics["x"])
     assert excepted_function(metrics["y"])
     assert excepted_function(metrics["z"])
 
 
+def test_metric_holder_raises(tmpdir):
+    """Check that an error is raised when trying to convert non-scalar tensors"""
+
+    class TestModel(BoringModel):
+
+        def validation_step(self, batch, *args, **kwargs):
+            output = self(batch)
+            return {"test": output}
+
+        def test_step(self, *args, **kwargs):
+            return self.validation_step(*args, **kwargs)
+
+    model = TestModel()
+    model.validation_epoch_end = None
+    model.test_epoch_end = None
+
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+
+    match = "The metric `test` does not contain a single element"
+    with pytest.raises(MisconfigurationException, match=match):
+        trainer.validate(model)
+    with pytest.raises(MisconfigurationException, match=match):
+        trainer.test(model)
+
+
 def test_logging_to_progress_bar_with_reserved_key(tmpdir):
     """ Test that logging a metric with a reserved name to the progress bar raises a warning. """
 
@@ -465,10 +490,7 @@ def training_step(self, *args, **kwargs):
             return output
 
     model = TestModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_steps=2,
-    )
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
         trainer.fit(model)
 
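The `max_steps=2` → `fast_dev_run=True` swap in the last hunk is incidental cleanup: `fast_dev_run` pushes a single batch through the training and evaluation loops and skips checkpointing, which is enough to surface the warning under test while finishing faster. In the new test above, nulling the `*_epoch_end` hooks presumably lets the dict returned from the step flow straight into the logged metrics, which is what routes the non-scalar tensor into `MetricsHolder.convert`. Roughly, for the Trainer swap:

from pytorch_lightning import Trainer

# Before: two real optimizer steps over the training dataloader.
trainer = Trainer(default_root_dir="/tmp/run", max_steps=2)

# After: one batch through each loop, no checkpointing -- a quick smoke test
# that still emits the duplicate-'loss' progress-bar warning.
trainer = Trainer(default_root_dir="/tmp/run", fast_dev_run=True)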
