Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MetricsHolder clean-up + typing #6645

Merged
merged 3 commits into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,13 @@ def cached_results(self) -> Union[EpochResultStore, None]:
return self._cached_results.get(self.trainer._running_stage) # type: ignore

def get_metrics(self, key: str) -> Dict:
metrics_holder = getattr(self, f"_{key}", None)
model_ref = self.trainer.lightning_module
metrics_holder.convert(
self.trainer._device_type == DeviceType.TPU,
model_ref.device if model_ref is not None else model_ref,
)
metrics_holder: MetricsHolder = getattr(self, f"_{key}")
model = self.trainer.lightning_module
metrics_holder.convert(model.device if model is not None else None)
return metrics_holder.metrics

def set_metrics(self, key: str, val: Dict) -> None:
metrics_holder = getattr(self, f"_{key}", None)
metrics_holder: MetricsHolder = getattr(self, f"_{key}")
metrics_holder.reset(val)
carmocca marked this conversation as resolved.
Show resolved Hide resolved

def reset(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,43 +12,52 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
from typing import Any
from typing import Any, Dict, Optional, Union

import torch
from torchmetrics import Metric

from pytorch_lightning.utilities.exceptions import MisconfigurationException

_METRIC_TYPE = Union[Metric, torch.Tensor, int, float, Any]


class MetricsHolder:
"""
This class acts as a dictonary holder.
This class acts as a dictionary holder.
It holds metrics and implements conversion functions.
Those functions will be triggered within LoggerConnector
when the property is being requested from the user.
"""

def __init__(self, to_float: bool = False):
self.metrics = {}
def __init__(self, to_float: bool = False) -> None:
self.metrics: Dict[str, _METRIC_TYPE] = {}
self._to_float = to_float

def update(self, metrics):
def update(self, metrics: dict) -> None:
self.metrics.update(metrics)

def pop(self, key, default):
def pop(self, key: str, default: _METRIC_TYPE) -> _METRIC_TYPE:
return self.metrics.pop(key, default)

def reset(self, metrics):
def reset(self, metrics: Dict[str, _METRIC_TYPE]) -> None:
self.metrics = metrics

def convert(self, use_tpu: bool, device: torch.device):
def convert(self, device: Optional[torch.device]) -> None:
for key, value in self.metrics.items():
self.metrics[key] = self._convert(value, use_tpu, device)

def _convert(self, current: Any, use_tpu: bool, device: torch.device):
if self._to_float:
return self._convert_to_float(current, use_tpu, device)
return self._convert_to_tensor(current, use_tpu, device)

def _convert_to_float(self, current, use_tpu: bool, device: torch.device):
if self._to_float:
if isinstance(value, torch.Tensor) and value.numel() != 1:
raise MisconfigurationException(
f"The metric `{key}` does not contain a single element"
f" thus it cannot be converted to float. Found `{value}`"
)
converted = self._convert_to_float(value)
else:
converted = self._convert_to_tensor(value, device)
self.metrics[key] = converted

@staticmethod
def _convert_to_float(current: _METRIC_TYPE) -> float:
if isinstance(current, Metric):
current = current.compute().detach()

Expand All @@ -60,16 +69,13 @@ def _convert_to_float(self, current, use_tpu: bool, device: torch.device):

return current

def _convert_to_tensor(self, current: Any, use_tpu: bool, device: torch.device):
if current is not None:
if isinstance(current, Metric):
current = current.compute().detach()
@staticmethod
def _convert_to_tensor(current: _METRIC_TYPE, device: Optional[torch.device]) -> torch.Tensor:
if isinstance(current, Metric):
current = current.compute().detach()

elif isinstance(current, numbers.Number):
if device is None:
current = torch.tensor(current, dtype=torch.float)
else:
current = torch.tensor(current, device=device, dtype=torch.float)
elif isinstance(current, numbers.Number):
current = torch.tensor(current, device=device, dtype=torch.float)

if isinstance(current, torch.Tensor) and current.device.type == "xla":
current = current.cpu()
Expand Down
32 changes: 27 additions & 5 deletions tests/trainer/logging_/test_logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,13 +447,38 @@ def is_float(value: Any) -> bool:
"y": torch.tensor(2),
"z": acc(preds, targets),
})
metric_holder.convert(False, device)
metric_holder.convert(device)
metrics = metric_holder.metrics
assert excepted_function(metrics["x"])
assert excepted_function(metrics["y"])
assert excepted_function(metrics["z"])


def test_metric_holder_raises(tmpdir):
"""Check that an error is raised when trying to convert non-scalar tensors"""

class TestModel(BoringModel):

def validation_step(self, batch, *args, **kwargs):
output = self(batch)
return {"test": output}

def test_step(self, *args, **kwargs):
return self.validation_step(*args, **kwargs)

model = TestModel()
model.validation_epoch_end = None
model.test_epoch_end = None

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)

match = "The metric `test` does not contain a single element"
with pytest.raises(MisconfigurationException, match=match):
trainer.validate(model)
with pytest.raises(MisconfigurationException, match=match):
trainer.test(model)


def test_logging_to_progress_bar_with_reserved_key(tmpdir):
""" Test that logging a metric with a reserved name to the progress bar raises a warning. """

Expand All @@ -465,10 +490,7 @@ def training_step(self, *args, **kwargs):
return output

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=2,
)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
trainer.fit(model)

Expand Down