Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Apr 17, 2020
1 parent e0ddda5 commit dee5261
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from pytorch_lightning.metrics.converters import tensor_metric, numpy_metric
from pytorch_lightning.utilities.apply_func import apply_to_collection

__all__ = ['Metric', 'TensorMetric', 'NumpyMetric']
__all__ = ['MetricBase', 'TensorMetricBase', 'NumpyMetricBase']


class Metric(torch.nn.Module, ABC):
class MetricBase(torch.nn.Module, ABC):
"""
Abstract Base Class for metric implementation.
Expand Down Expand Up @@ -95,7 +95,7 @@ def to(self, *args, **kwargs) -> torch.nn.Module:
Module: self
Example::
>>> class ExampleMetric(Metric):
>>> class ExampleMetric(MetricBase):
... def __init__(self, weight: torch.Tensor):
... super().__init__('example')
... self.register_buffer('weight', weight)
Expand Down Expand Up @@ -199,7 +199,7 @@ def half(self) -> torch.nn.Module:
return super().half()


class TensorMetric(Metric):
class TensorMetricBase(MetricBase):
"""
Base class for metric implementation operating directly on tensors.
All inputs and outputs will be casted to tensors if necessary.
Expand Down Expand Up @@ -229,7 +229,7 @@ def _to_device_dtype(x: torch.Tensor) -> torch.Tensor:
_to_device_dtype)


class NumpyMetric(Metric):
class NumpyMetricBase(MetricBase):
"""
Base class for metric implementation operating on numpy arrays.
All inputs will be casted to numpy if necessary and all outputs will
Expand Down
8 changes: 4 additions & 4 deletions tests/metrics/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import numpy as np
import torch

from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
from pytorch_lightning.metrics.metric import MetricBase, TensorMetricBase, NumpyMetricBase


class DummyTensorMetric(TensorMetric):
class DummyTensorMetric(TensorMetricBase):
def __init__(self):
super().__init__('dummy')

Expand All @@ -14,7 +14,7 @@ def forward(self, input1, input2):
return 1.


class DummyNumpyMetric(NumpyMetric):
class DummyNumpyMetric(NumpyMetricBase):
def __init__(self):
super().__init__('dummy')

Expand All @@ -24,7 +24,7 @@ def forward(self, input1, input2):
return 1.


def _test_metric(metric: Metric):
def _test_metric(metric: MetricBase):
input1, input2 = torch.tensor([1.]), torch.tensor([2.])

def change_and_check_device_dtype(device, dtype):
Expand Down

0 comments on commit dee5261

Please sign in to comment.