diff --git a/lib/python/flame/mode/composer.py b/lib/python/flame/mode/composer.py
index c62a854f6..97838bd3d 100644
--- a/lib/python/flame/mode/composer.py
+++ b/lib/python/flame/mode/composer.py
@@ -20,6 +20,7 @@
 from types import TracebackType
 from typing import Optional, Type
 
 from flame.mode.enums import LoopIndicator
+from flame.mode.role import Role
 
 logger = logging.getLogger(__name__)
@@ -34,6 +35,8 @@ def __init__(self) -> None:
 
         self.reverse_chain = dict()
         self.unlinked_tasklets = dict()
+
+        self.mc = Role.mc
 
     def __enter__(self):
         """Enter custom context."""
diff --git a/lib/python/flame/mode/distributed/trainer.py b/lib/python/flame/mode/distributed/trainer.py
index eb45fd233..e37291220 100644
--- a/lib/python/flame/mode/distributed/trainer.py
+++ b/lib/python/flame/mode/distributed/trainer.py
@@ -428,6 +428,7 @@ def put(self, tag: str) -> None:
 
     def save_metrics(self):
         """Save metrics in a model registry."""
+        self.metrics = self.metrics | self.mc.get()
         logger.debug(f"saving metrics: {self.metrics}")
         if self.metrics:
             self.registry_client.save_metrics(self._round, self.metrics)
diff --git a/lib/python/flame/mode/horizontal/coord_syncfl/top_aggregator.py b/lib/python/flame/mode/horizontal/coord_syncfl/top_aggregator.py
index 10fb83097..88ce4220d 100644
--- a/lib/python/flame/mode/horizontal/coord_syncfl/top_aggregator.py
+++ b/lib/python/flame/mode/horizontal/coord_syncfl/top_aggregator.py
@@ -22,7 +22,7 @@
     TopAggregator as BaseTopAggregator,
 )
 from flame.mode.message import MessageType
-from flame.mode.tasklet import Loop, Tasklet
+from flame.mode.tasklet import Tasklet
 
 logger = logging.getLogger(__name__)
 
diff --git a/lib/python/flame/mode/horizontal/feddyn/top_aggregator.py b/lib/python/flame/mode/horizontal/feddyn/top_aggregator.py
index 782d8a683..5654a1f69 100644
--- a/lib/python/flame/mode/horizontal/feddyn/top_aggregator.py
+++ b/lib/python/flame/mode/horizontal/feddyn/top_aggregator.py
@@ -176,33 +176,35 @@ def compose(self) -> None:
         with Composer() as composer:
             self.composer = composer
 
-            task_internal_init = Tasklet("", self.internal_init)
+            task_internal_init = Tasklet("internal_init", self.internal_init)
 
-            task_init = Tasklet("", self.initialize)
+            task_init = Tasklet("initialize", self.initialize)
 
-            task_load_data = Tasklet("", self.load_data)
+            task_load_data = Tasklet("load_data", self.load_data)
 
-            task_get_dataset = Tasklet("", self.get, TAG_GET_DATATSET_SIZE)
+            task_get_dataset = Tasklet("get_dataset_size", self.get, TAG_GET_DATATSET_SIZE)
 
-            task_put = Tasklet("", self.put, TAG_DISTRIBUTE)
+            task_put = Tasklet("distribute", self.put, TAG_DISTRIBUTE)
 
-            task_get = Tasklet("", self.get, TAG_AGGREGATE)
+            task_get = Tasklet("aggregate", self.get, TAG_AGGREGATE)
 
-            task_train = Tasklet("", self.train)
+            task_train = Tasklet("train", self.train)
 
-            task_eval = Tasklet("", self.evaluate)
+            task_eval = Tasklet("evaluate", self.evaluate)
 
-            task_analysis = Tasklet("", self.run_analysis)
+            task_analysis = Tasklet("analysis", self.run_analysis)
 
-            task_save_metrics = Tasklet("", self.save_metrics)
+            task_save_metrics = Tasklet("save_metrics", self.save_metrics)
 
-            task_increment_round = Tasklet("", self.increment_round)
+            task_increment_round = Tasklet("inc_round", self.increment_round)
 
-            task_end_of_training = Tasklet("", self.inform_end_of_training)
+            task_end_of_training = Tasklet(
+                "inform_end_of_training", self.inform_end_of_training
+            )
 
-            task_save_params = Tasklet("", self.save_params)
+            task_save_params = Tasklet("save_params", self.save_params)
 
-            task_save_model = Tasklet("", self.save_model)
Tasklet("", self.save_model) + task_save_model = Tasklet("save_model", self.save_model) # create a loop object with loop exit condition function loop = Loop(loop_check_fn=lambda: self._work_done) diff --git a/lib/python/flame/mode/horizontal/feddyn/trainer.py b/lib/python/flame/mode/horizontal/feddyn/trainer.py index 1f2f43704..8ca7d38b4 100644 --- a/lib/python/flame/mode/horizontal/feddyn/trainer.py +++ b/lib/python/flame/mode/horizontal/feddyn/trainer.py @@ -153,23 +153,23 @@ def compose(self) -> None: with Composer() as composer: self.composer = composer - task_internal_init = Tasklet("", self.internal_init) + task_internal_init = Tasklet("internal_init", self.internal_init) - task_load_data = Tasklet("", self.load_data) + task_load_data = Tasklet("load_data", self.load_data) - task_init = Tasklet("", self.initialize) + task_init = Tasklet("init", self.initialize) - task_put_dataset_size = Tasklet("", self.put, TAG_UPLOAD_DATASET_SIZE) + task_put_dataset_size = Tasklet("upload_dataset_size", self.put, TAG_UPLOAD_DATASET_SIZE) - task_get = Tasklet("", self.get, TAG_FETCH) + task_get = Tasklet("fetch", self.get, TAG_FETCH) - task_train = Tasklet("", self.train) + task_train = Tasklet("train", self.train) - task_eval = Tasklet("", self.evaluate) + task_eval = Tasklet("evaluate", self.evaluate) - task_put = Tasklet("", self.put, TAG_UPLOAD) + task_put = Tasklet("upload", self.put, TAG_UPLOAD) - task_save_metrics = Tasklet("", self.save_metrics) + task_save_metrics = Tasklet("save_metrics", self.save_metrics) # create a loop object with loop exit condition function loop = Loop(loop_check_fn=lambda: self._work_done) diff --git a/lib/python/flame/mode/horizontal/oort/top_aggregator.py b/lib/python/flame/mode/horizontal/oort/top_aggregator.py index a3f46b440..3ed6aa7b3 100644 --- a/lib/python/flame/mode/horizontal/oort/top_aggregator.py +++ b/lib/python/flame/mode/horizontal/oort/top_aggregator.py @@ -21,11 +21,11 @@ from copy import deepcopy from typing import Any, Tuple -from ....common.util import weights_to_device, weights_to_model_device -from ....common.constants import DeviceType -from ....optimizer.train_result import TrainResult -from ...message import MessageType -from ....selector.oort import ( +from flame.common.util import weights_to_device, weights_to_model_device +from flame.common.constants import DeviceType +from flame.optimizer.train_result import TrainResult +from flame.mode.message import MessageType +from flame.selector.oort import ( PROP_ROUND_DURATION, PROP_ROUND_START_TIME, PROP_STAT_UTILITY, diff --git a/lib/python/flame/mode/horizontal/syncfl/top_aggregator.py b/lib/python/flame/mode/horizontal/syncfl/top_aggregator.py index 98e3dfd28..3d017cdf3 100644 --- a/lib/python/flame/mode/horizontal/syncfl/top_aggregator.py +++ b/lib/python/flame/mode/horizontal/syncfl/top_aggregator.py @@ -53,7 +53,7 @@ class TopAggregator(Role, metaclass=ABCMeta): """Top level Aggregator implements an ML aggregation role.""" - + @abstract_attribute def config(self) -> Config: """Abstract attribute for config object.""" @@ -101,7 +101,7 @@ def internal_init(self) -> None: self.datasampler = datasampler_provider.get( self.config.datasampler.sort, **self.config.datasampler.kwargs ).aggregator_data_sampler - + self._round = 1 self._rounds = 1 self._rounds = self.config.hyperparameters.rounds @@ -240,6 +240,8 @@ def run_analysis(self): def save_metrics(self): """Save metrics in a model registry.""" + # update metrics with metrics from metric collector + self.metrics = self.metrics | self.mc.get() 
logger.debug(f"saving metrics: {self.metrics}") if self.metrics: self.registry_client.save_metrics(self._round - 1, self.metrics) diff --git a/lib/python/flame/mode/horizontal/syncfl/trainer.py b/lib/python/flame/mode/horizontal/syncfl/trainer.py index d9c9aa57c..d9a79a280 100644 --- a/lib/python/flame/mode/horizontal/syncfl/trainer.py +++ b/lib/python/flame/mode/horizontal/syncfl/trainer.py @@ -196,6 +196,8 @@ def _send_weights(self, tag: str) -> None: def save_metrics(self): """Save metrics in a model registry.""" + # update self.metrics with metrics from MetricCollector instance + self.metrics = self.metrics | self.mc.get() logger.debug(f"saving metrics: {self.metrics}") if self.metrics: self.registry_client.save_metrics(self._round - 1, self.metrics) diff --git a/lib/python/flame/mode/hybrid/trainer.py b/lib/python/flame/mode/hybrid/trainer.py index c7b383423..cca53137f 100644 --- a/lib/python/flame/mode/hybrid/trainer.py +++ b/lib/python/flame/mode/hybrid/trainer.py @@ -16,7 +16,6 @@ """hybrid FL trainer.""" import logging -import time from flame.channel_manager import ChannelManager from flame.common.constants import DeviceType diff --git a/lib/python/flame/mode/role.py b/lib/python/flame/mode/role.py index 334c724a2..dc8440190 100644 --- a/lib/python/flame/mode/role.py +++ b/lib/python/flame/mode/role.py @@ -17,10 +17,12 @@ """role abstract class.""" from abc import ABC, abstractmethod +from flame.monitor.metric_collector import MetricCollector class Role(ABC): """Abstract base class for role implementation.""" + mc = MetricCollector() ########################################################################### # The following functions need to be implemented the child class diff --git a/lib/python/flame/mode/tasklet.py b/lib/python/flame/mode/tasklet.py index 29de8703c..1f06c50a5 100644 --- a/lib/python/flame/mode/tasklet.py +++ b/lib/python/flame/mode/tasklet.py @@ -23,6 +23,7 @@ from flame.mode.composer import ComposerContext from flame.mode.enums import LoopIndicator +from flame.monitor.runtime import time_tasklet logger = logging.getLogger(__name__) @@ -151,6 +152,7 @@ def get_ender(self) -> Tasklet: return self.loop_ender + @time_tasklet def do(self) -> None: """Execute tasklet.""" self.func(*self.args, **self.kwargs) diff --git a/lib/python/flame/monitor/__init__.py b/lib/python/flame/monitor/__init__.py new file mode 100644 index 000000000..506f034ea --- /dev/null +++ b/lib/python/flame/monitor/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + + diff --git a/lib/python/flame/monitor/metric_collector.py b/lib/python/flame/monitor/metric_collector.py new file mode 100644 index 000000000..17abbdf93 --- /dev/null +++ b/lib/python/flame/monitor/metric_collector.py @@ -0,0 +1,38 @@ +# Copyright 2023 Cisco Systems, Inc. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+"""Metric Collector."""
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+class MetricCollector:
+    def __init__(self):
+        """Initialize Metric Collector."""
+        self.state_dict = dict()
+
+    def save(self, mtype, alias, value):
+        """Save a key-value pair for a metric, keyed as "<mtype>-<alias>"."""
+        key = f"{mtype}-{alias}"
+        self.state_dict[key] = value
+        logger.debug(f"Saving state_dict[{key}] = {value}")
+
+    def get(self):
+        """Return the collected metrics and clear the saved dictionary."""
+        temp = self.state_dict
+        self.state_dict = dict()
+        return temp
+
diff --git a/lib/python/flame/monitor/runtime.py b/lib/python/flame/monitor/runtime.py
new file mode 100644
index 000000000..8d066b533
--- /dev/null
+++ b/lib/python/flame/monitor/runtime.py
@@ -0,0 +1,39 @@
+# Copyright 2023 Cisco Systems, Inc. and its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+"""Runtime measurement for Metric Collector."""
+
+import logging
+import time
+
+logger = logging.getLogger(__name__)
+
+def time_tasklet(func):
+    """Decorator to time the Tasklet.do() function."""
+    def wrapper(*args, **kwargs):
+        s = args[0]  # the Tasklet instance (self)
+        if s.composer.mc:
+            start = time.time()
+            result = func(*args, **kwargs)
+            end = time.time()
+
+            s.composer.mc.save("runtime", s.alias, end - start)
+            logger.debug(f"Runtime of {s.alias} is {end - start}")
+            return result
+        else:
+            logger.debug("No MetricCollector; won't record runtime")
+            return func(*args, **kwargs)
+
+    return wrapper
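
Usage sketch (editor's note, not part of the patch): the pieces above fit together as follows. Tasklet.do() is wrapped by time_tasklet, which records a "runtime-<alias>" entry in the MetricCollector shared via the Role.mc class attribute and the Composer; each role's save_metrics() then merges those entries into self.metrics with the dict-union operator (Python 3.9+) before uploading to the registry. The snippet below exercises only the MetricCollector API introduced in this patch; the "train" alias and the accuracy value are made-up example data.

import time

from flame.monitor.metric_collector import MetricCollector

mc = MetricCollector()

# What time_tasklet effectively does around a tasklet body: measure
# wall-clock time and store it under the key "runtime-<alias>".
start = time.time()
# ... tasklet body would run here ...
end = time.time()
mc.save("runtime", "train", end - start)  # stored as state_dict["runtime-train"]

# What save_metrics() does: merge collected metrics into the role's metrics.
# mc.get() also clears the collector, so the next round starts fresh.
metrics = {"accuracy": 0.9}  # hypothetical pre-existing metrics
metrics = metrics | mc.get()  # dict union requires Python 3.9+
print(metrics)  # e.g. {'accuracy': 0.9, 'runtime-train': 1.2e-05}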