Skip to content

Commit

Permalink
Rework some of the result classes (#288)
Browse files (browse the repository at this point in the history)
* Introduce a better hierarchy in the result base classes

Signed-off-by: Niels Nuyttens <niels@nannyml.com>

* Fix weird breaking test

Signed-off-by: Niels Nuyttens <niels@nannyml.com>

* Same treatment for 2D tests

Signed-off-by: Niels Nuyttens <niels@nannyml.com>

---------

Signed-off-by: Niels Nuyttens <niels@nannyml.com>
  • Loading branch information
nnansters authored May 4, 2023
1 parent 18fe766 commit 0a2bc14
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 82 deletions.
98 changes: 54 additions & 44 deletions nannyml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,9 @@ def _get_property_for_key(self, key: Key, property_name: str) -> Optional[pd.Ser
return self.data.get(key.properties + (property_name,), default=None)


class Abstract1DResult(AbstractResult, ABC, Generic[MetricLike]):
def __init__(self, results_data: pd.DataFrame, metrics: list[MetricLike] = [], *args, **kwargs):
class Abstract1DResult(AbstractResult, ABC):
def __init__(self, results_data: pd.DataFrame, *args, **kwargs):
super().__init__(results_data)
self.metrics = metrics

@property
def chunk_keys(self) -> pd.Series:
Expand All @@ -149,23 +148,38 @@ def chunk_indices(self) -> pd.Series:
def chunk_periods(self) -> pd.Series:
return self.data[('chunk', 'period')]

def _filter(self, period: str, *args, **kwargs) -> Self:
data = self.data
if period != 'all':
data = self.data.loc[self.data.loc[:, ('chunk', 'period')] == period, :]
data = data.reset_index(drop=True)

res = copy.deepcopy(self)
res.data = data
return res


class PerMetricResult(Abstract1DResult, ABC, Generic[MetricLike]):
def __init__(self, results_data: pd.DataFrame, metrics: list[MetricLike] = [], *args, **kwargs):
super().__init__(results_data)
self.metrics = metrics

def _filter(self, period: str, metrics: Optional[List[str]] = None, *args, **kwargs) -> Self:
if metrics is None:
metrics = [metric.column_name for metric in self.metrics]

data = pd.concat([self.data.loc[:, (['chunk'])], self.data.loc[:, (metrics,)]], axis=1)
if period != 'all':
data = data.loc[self.data.loc[:, ('chunk', 'period')] == period, :]
res = super()._filter(period, args, kwargs)

data = pd.concat([res.data.loc[:, (['chunk'])], res.data.loc[:, (metrics,)]], axis=1)
data = data.reset_index(drop=True)

res = copy.deepcopy(self)
res.data = data
res.metrics = [metric for metric in self.metrics if metric.column_name in metrics]

return res


class Abstract1DColumnsResult(AbstractResult, ABC, Generic[MetricLike]):
class PerColumnResult(Abstract1DResult, ABC):
def __init__(self, results_data: pd.DataFrame, column_names: Union[str, List[str]] = [], *args, **kwargs):
super().__init__(results_data)
if isinstance(column_names, str):
Expand All @@ -175,26 +189,6 @@ def __init__(self, results_data: pd.DataFrame, column_names: Union[str, List[str
else:
raise TypeError("column_names should be either a column name string or a list of strings.")

@property
def chunk_keys(self) -> pd.Series:
return self.data[('chunk', 'key')]

@property
def chunk_start_dates(self) -> pd.Series:
return self.data[('chunk', 'start_date')]

@property
def chunk_end_dates(self) -> pd.Series:
return self.data[('chunk', 'end_date')]

@property
def chunk_indices(self) -> pd.Series:
return self.data[('chunk', 'chunk_index')]

@property
def chunk_periods(self) -> pd.Series:
return self.data[('chunk', 'period')]

def _filter(
self,
period: str,
Expand All @@ -212,27 +206,19 @@ def _filter(
else:
raise TypeError("column_names should be either a column name string or a list of strings.")

# is column names loc argument correct? likely
# data = pd.concat([self.data.loc[:, (['chunk'])], self.data.loc[:, (metrics,)]], axis=1)
data = pd.concat([self.data.loc[:, (['chunk'])], self.data.loc[:, (column_names,)]], axis=1)
if period != 'all':
data = data.loc[self.data.loc[:, ('chunk', 'period')] == period, :]
res = super()._filter(period, args, kwargs)

data = pd.concat([res.data.loc[:, (['chunk'])], res.data.loc[:, (column_names,)]], axis=1)
data = data.reset_index(drop=True)

res = copy.deepcopy(self)
res.data = data
res.column_names = [c for c in self.column_names if c in column_names]
return res


class Abstract2DResult(AbstractResult, ABC, Generic[MetricLike]):
def __init__(
self, results_data: pd.DataFrame, metrics: list[MetricLike] = [], column_names: List[str] = [], *args, **kwargs
):
class Abstract2DResult(AbstractResult, ABC):
def __init__(self, results_data: pd.DataFrame, *args, **kwargs):
super().__init__(results_data)
self.metrics = metrics
self.column_names = column_names

@property
def chunk_keys(self) -> pd.Series:
Expand All @@ -254,6 +240,31 @@ def chunk_indices(self) -> pd.Series:
def chunk_periods(self) -> pd.Series:
return self.data[('chunk', 'chunk', 'period')]

def _filter(
self,
period: str,
*args,
**kwargs,
) -> Self:
data = self.data
if period != 'all':
data = data.loc[self.data.loc[:, ('chunk', 'chunk', 'period')] == period, :]
data = data.reset_index(drop=True)

res = copy.deepcopy(self)
res.data = data

return res


class PerMetricPerColumnResult(Abstract2DResult, ABC, Generic[MetricLike]):
def __init__(
self, results_data: pd.DataFrame, metrics: list[MetricLike] = [], column_names: List[str] = [], *args, **kwargs
):
super().__init__(results_data)
self.metrics = metrics
self.column_names = column_names

def _filter(
self,
period: str,
Expand All @@ -267,16 +278,15 @@ def _filter(
if column_names is None:
column_names = self.column_names

data = pd.concat([self.data.loc[:, (['chunk'])], self.data.loc[:, (column_names, metrics)]], axis=1)
if period != 'all':
data = data.loc[self.data.loc[:, ('chunk', 'chunk', 'period')] == period, :]
res = super()._filter(period, args, kwargs)

data = pd.concat([res.data.loc[:, (['chunk'])], res.data.loc[:, (column_names, metrics)]], axis=1)
data = data.reset_index(drop=True)

res = copy.deepcopy(self)
res.data = data
res.metrics = [metric for metric in self.metrics if metric.column_name in metrics]
res.column_names = [c for c in self.column_names if c in column_names]

return res


Expand Down
4 changes: 2 additions & 2 deletions nannyml/data_quality/missing/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
import plotly.graph_objects as go

from nannyml._typing import Key
from nannyml.base import Abstract1DColumnsResult
from nannyml.base import PerColumnResult
from nannyml.chunk import Chunker
from nannyml.plots.blueprints.comparisons import ResultCompareMixin
from nannyml.plots.blueprints.metrics import plot_metrics
from nannyml.usage_logging import UsageEvent, log_usage


class Result(Abstract1DColumnsResult, ResultCompareMixin):
class Result(PerColumnResult, ResultCompareMixin):
"""Contains the results of the univariate statistical drift calculation and provides plotting functionality."""

def __init__(
Expand Down
4 changes: 2 additions & 2 deletions nannyml/data_quality/unseen/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import plotly.graph_objects as go

from nannyml._typing import Key
from nannyml.base import Abstract1DColumnsResult
from nannyml.base import PerColumnResult
from nannyml.chunk import Chunker

# from nannyml.exceptions import InvalidArgumentsException
Expand All @@ -26,7 +26,7 @@
from nannyml.usage_logging import UsageEvent, log_usage


class Result(Abstract1DColumnsResult, ResultCompareMixin):
class Result(PerColumnResult, ResultCompareMixin):
"""Contains the results of the univariate statistical drift calculation and provides plotting functionality."""

def __init__(
Expand Down
4 changes: 2 additions & 2 deletions nannyml/drift/multivariate/data_reconstruction/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import plotly.graph_objects as go

from nannyml._typing import Key
from nannyml.base import Abstract1DResult
from nannyml.base import PerMetricResult
from nannyml.exceptions import InvalidArgumentsException
from nannyml.plots.blueprints.comparisons import ResultCompareMixin
from nannyml.plots.blueprints.metrics import plot_metric
Expand All @@ -21,7 +21,7 @@
Metric = namedtuple("Metric", "display_name column_name")


class Result(Abstract1DResult[Metric], ResultCompareMixin):
class Result(PerMetricResult[Metric], ResultCompareMixin):
"""Class wrapping the results of the data reconstruction drift calculator and providing plotting functionality."""

def __init__(
Expand Down
4 changes: 2 additions & 2 deletions nannyml/drift/univariate/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import plotly.graph_objects as go

from nannyml._typing import Key
from nannyml.base import Abstract2DResult
from nannyml.base import PerMetricPerColumnResult
from nannyml.chunk import Chunker, DefaultChunker
from nannyml.drift.univariate.methods import FeatureType, Method, MethodFactory
from nannyml.exceptions import InvalidArgumentsException
Expand All @@ -28,7 +28,7 @@
from nannyml.usage_logging import UsageEvent, log_usage


class Result(Abstract2DResult[Method], ResultCompareMixin):
class Result(PerMetricPerColumnResult[Method], ResultCompareMixin):
"""Class wrapping the results of the univariate drift calculator and providing plotting functionality."""

def __init__(
Expand Down
16 changes: 5 additions & 11 deletions nannyml/performance_calculation/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import pandas as pd
import plotly.graph_objects as go

from nannyml._typing import Key, ProblemType
from nannyml.base import Abstract1DResult
from nannyml._typing import Key, ProblemType, Self
from nannyml.base import PerMetricResult
from nannyml.exceptions import InvalidArgumentsException
from nannyml.performance_calculation import SUPPORTED_METRIC_FILTER_VALUES
from nannyml.performance_calculation.metrics.base import Metric
Expand All @@ -21,7 +21,7 @@
from nannyml.usage_logging import UsageEvent, log_usage


class Result(Abstract1DResult[Metric], ResultCompareMixin):
class Result(PerMetricResult[Metric], ResultCompareMixin):
"""Wraps performance calculation results and provides filtering and plotting functionality."""

metrics: List[Metric]
Expand Down Expand Up @@ -164,7 +164,7 @@ def plot(
else:
raise InvalidArgumentsException(f"unknown plot kind '{kind}'. " f"Please provide on of: ['performance'].")

def _filter(self, period: str, metrics: Optional[List[str]] = None, *args, **kwargs) -> Result:
def _filter(self, period: str, metrics: Optional[List[str]] = None, *args, **kwargs) -> Self:
"""Filter the results based on the specified period and metrics."""
if metrics is None:
filtered_metrics = self.metrics
Expand All @@ -183,13 +183,7 @@ def _filter(self, period: str, metrics: Optional[List[str]] = None, *args, **kwa

metric_column_names = [name for metric in filtered_metrics for name in metric.column_names]

data = pd.concat([self.data.loc[:, (['chunk'])], self.data.loc[:, (metric_column_names,)]], axis=1)
if period != 'all':
data = data.loc[data.loc[:, ('chunk', 'period')] == period, :]

data = data.reset_index(drop=True)
res = copy.deepcopy(self)
res.data = data
res = super()._filter(period, metric_column_names, args, kwargs)
res.metrics = filtered_metrics

return res
Expand Down
2 changes: 1 addition & 1 deletion nannyml/performance_estimation/confidence_based/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def display_name(self) -> str:

@property
def column_name(self) -> str:
return self.components[0][0]
return self.components[0][1]

@property
def display_names(self):
Expand Down
17 changes: 5 additions & 12 deletions nannyml/performance_estimation/confidence_based/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import pandas as pd
from plotly import graph_objects as go

from nannyml._typing import Key, ModelOutputsType, ProblemType
from nannyml.base import Abstract1DResult
from nannyml._typing import Key, ModelOutputsType, ProblemType, Self
from nannyml.base import PerMetricResult
from nannyml.chunk import Chunker
from nannyml.exceptions import InvalidArgumentsException
from nannyml.performance_estimation.confidence_based import SUPPORTED_METRIC_FILTER_VALUES
Expand All @@ -23,7 +23,7 @@
from nannyml.usage_logging import UsageEvent, log_usage


class Result(Abstract1DResult[Metric], ResultCompareMixin):
class Result(PerMetricResult[Metric], ResultCompareMixin):
"""Contains results for CBPE estimation and adds filtering and plotting functionality."""

def __init__(
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(
self.problem_type = problem_type
self.chunker = chunker

def _filter(self, period: str, metrics: Optional[List[str]] = None, *args, **kwargs) -> Result:
def _filter(self, period: str, metrics: Optional[List[str]] = None, *args, **kwargs) -> Self:
"""Filter the results based on the specified period and metrics.
This function begins by expanding the metrics to all the metrics that were specified
Expand Down Expand Up @@ -102,14 +102,7 @@ def _filter(self, period: str, metrics: Optional[List[str]] = None, *args, **kwa

metric_column_names = [name for metric in filtered_metrics for name in metric.column_names]

data = pd.concat([self.data.loc[:, (['chunk'])], self.data.loc[:, (metric_column_names,)]], axis=1)
if period != 'all':
data = data.loc[data.loc[:, ('chunk', 'period')] == period, :]

data = data.reset_index(drop=True)
res = copy.deepcopy(self)
res.data = data
res.metrics = filtered_metrics
res = super()._filter(period, metric_column_names, args, kwargs)

return res

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from plotly.graph_objects import Figure

from nannyml._typing import Key
from nannyml.base import Abstract1DResult
from nannyml.base import PerMetricResult
from nannyml.chunk import Chunker
from nannyml.exceptions import InvalidArgumentsException
from nannyml.performance_estimation.direct_loss_estimation.metrics import Metric
Expand All @@ -16,7 +16,7 @@
from nannyml.usage_logging import UsageEvent, log_usage


class Result(Abstract1DResult[Metric], ResultCompareMixin):
class Result(PerMetricResult[Metric], ResultCompareMixin):
"""Contains results for CBPE estimation and adds filtering and plotting functionality."""

def __init__(
Expand Down
5 changes: 2 additions & 3 deletions tests/drift/test_univariate_drift_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,9 @@ def test_method_logs_warning_when_lower_threshold_is_overridden_by_metric_limits
reference = pd.Series(np.random.normal(0, 1, 1000), name='A')
method.fit(reference)

assert len(caplog.messages) == 1
assert (
caplog.messages[0] == f'{method.display_name} lower threshold value -1 overridden by '
f'lower threshold value limit {method.lower_threshold_value_limit}'
f'{method.display_name} lower threshold value -1 overridden by '
f'lower threshold value limit {method.lower_threshold_value_limit}' in caplog.messages
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@ def test_metric_logs_warning_when_lower_threshold_is_overridden_by_metric_limits
)
metric.fit(reference, chunker=DefaultChunker())

assert len(caplog.messages) == 1
assert (
f'{metric.display_name} lower threshold value -1 overridden by '
f'lower threshold value limit {metric.lower_threshold_value_limit}' in caplog.messages
Expand Down

0 comments on commit 0a2bc14

Please sign in to comment.