Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hide members and make RecordSet picklable. #3209

Merged
merged 10 commits into from
Apr 18, 2024
95 changes: 67 additions & 28 deletions src/py/flwr/common/record/recordset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,64 +16,103 @@


from dataclasses import dataclass
from typing import Callable, Dict, Optional, Type, TypeVar
from typing import Dict, Optional, cast

from .configsrecord import ConfigsRecord
from .metricsrecord import MetricsRecord
from .parametersrecord import ParametersRecord
from .typeddict import TypedDict

T = TypeVar("T")

class RecordSetData:
"""Inner data container for the RecordSet class."""

@dataclass
class RecordSet:
"""RecordSet stores groups of parameters, metrics and configs."""

_parameters_records: TypedDict[str, ParametersRecord]
_metrics_records: TypedDict[str, MetricsRecord]
_configs_records: TypedDict[str, ConfigsRecord]
parameters_records: TypedDict[str, ParametersRecord]
metrics_records: TypedDict[str, MetricsRecord]
configs_records: TypedDict[str, ConfigsRecord]

def __init__(
self,
parameters_records: Optional[Dict[str, ParametersRecord]] = None,
metrics_records: Optional[Dict[str, MetricsRecord]] = None,
configs_records: Optional[Dict[str, ConfigsRecord]] = None,
) -> None:
def _get_check_fn(__t: Type[T]) -> Callable[[T], None]:
def _check_fn(__v: T) -> None:
if not isinstance(__v, __t):
raise TypeError(f"Expected `{__t}`, but `{type(__v)}` was passed.")

return _check_fn

self._parameters_records = TypedDict[str, ParametersRecord](
_get_check_fn(str), _get_check_fn(ParametersRecord)
self.parameters_records = TypedDict[str, ParametersRecord](
self._check_fn_str, self._check_fn_params
)
self._metrics_records = TypedDict[str, MetricsRecord](
_get_check_fn(str), _get_check_fn(MetricsRecord)
self.metrics_records = TypedDict[str, MetricsRecord](
self._check_fn_str, self._check_fn_metrics
)
self._configs_records = TypedDict[str, ConfigsRecord](
_get_check_fn(str), _get_check_fn(ConfigsRecord)
self.configs_records = TypedDict[str, ConfigsRecord](
self._check_fn_str, self._check_fn_configs
)
if parameters_records is not None:
self._parameters_records.update(parameters_records)
self.parameters_records.update(parameters_records)
if metrics_records is not None:
self._metrics_records.update(metrics_records)
self.metrics_records.update(metrics_records)
if configs_records is not None:
self._configs_records.update(configs_records)
self.configs_records.update(configs_records)

def _check_fn_str(self, key: str) -> None:
if not isinstance(key, str):
raise TypeError(
f"Expected `{str.__name__}`, but "
f"received `{type(key).__name__}` for the key."
)

def _check_fn_params(self, record: ParametersRecord) -> None:
if not isinstance(record, ParametersRecord):
raise TypeError(
f"Expected `{ParametersRecord.__name__}`, but "
f"received `{type(record).__name__}` for the value."
)

def _check_fn_metrics(self, record: MetricsRecord) -> None:
if not isinstance(record, MetricsRecord):
raise TypeError(
f"Expected `{MetricsRecord.__name__}`, but "
f"received `{type(record).__name__}` for the value."
)

def _check_fn_configs(self, record: ConfigsRecord) -> None:
if not isinstance(record, ConfigsRecord):
raise TypeError(
f"Expected `{ConfigsRecord.__name__}`, but "
f"received `{type(record).__name__}` for the value."
)


@dataclass
class RecordSet:
"""RecordSet stores groups of parameters, metrics and configs."""

def __init__(
self,
parameters_records: Optional[Dict[str, ParametersRecord]] = None,
metrics_records: Optional[Dict[str, MetricsRecord]] = None,
configs_records: Optional[Dict[str, ConfigsRecord]] = None,
) -> None:
data = RecordSetData(
parameters_records=parameters_records,
metrics_records=metrics_records,
configs_records=configs_records,
)
setattr(self, "_data", data) # noqa

@property
def parameters_records(self) -> TypedDict[str, ParametersRecord]:
"""Dictionary holding ParametersRecord instances."""
return self._parameters_records
data = cast(RecordSetData, getattr(self, "_data")) # noqa
return data.parameters_records

@property
def metrics_records(self) -> TypedDict[str, MetricsRecord]:
"""Dictionary holding MetricsRecord instances."""
return self._metrics_records
data = cast(RecordSetData, getattr(self, "_data")) # noqa
return data.metrics_records

@property
def configs_records(self) -> TypedDict[str, ConfigsRecord]:
"""Dictionary holding ConfigsRecord instances."""
return self._configs_records
data = cast(RecordSetData, getattr(self, "_data")) # noqa
return data.configs_records
18 changes: 17 additions & 1 deletion src/py/flwr/common/record/recordset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
"""RecordSet tests."""

import pickle
from copy import deepcopy
from typing import Callable, Dict, List, OrderedDict, Type, Union

Expand All @@ -33,7 +34,7 @@
Parameters,
)

from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord
from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet


def get_ndarrays() -> NDArrays:
Expand Down Expand Up @@ -398,3 +399,18 @@ def test_count_bytes_configsrecord() -> None:

record_bytest_count = c_record.count_bytes()
assert bytes_in_dict == record_bytest_count


def test_record_is_picklable() -> None:
"""Test if RecordSet and *Record are picklable."""
# Prepare
p_record = ParametersRecord()
m_record = MetricsRecord({"aa": 123})
c_record = ConfigsRecord({"cc": bytes(9)})
rs = RecordSet()
rs.parameters_records["params"] = p_record
rs.metrics_records["metrics"] = m_record
rs.configs_records["configs"] = c_record

# Execute
pickle.dumps((p_record, m_record, c_record, rs))