Skip to content

Commit

Permalink
track tasklet runtime (#405)
Browse files Browse the repository at this point in the history
Using a decorator created in flame/monitor/runtime.py, tasklets can now send runtime information to a MetricCollector in a composer.
The composer is initialized only once per process in the Role abstract class.
This allows for all runtime metrics to be stored in the same MetricCollector instance within a process.

Additionally, some imports in flame/mode have been adjusted, and aliases for feddyn were added as well.
  • Loading branch information
GustavBaumgart committed May 11, 2023
1 parent f89e054 commit 0b030d1
Show file tree
Hide file tree
Showing 14 changed files with 139 additions and 32 deletions.
3 changes: 3 additions & 0 deletions lib/python/flame/mode/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from types import TracebackType
from typing import Optional, Type
from flame.mode.enums import LoopIndicator
from flame.mode.role import Role

logger = logging.getLogger(__name__)

Expand All @@ -34,6 +35,8 @@ def __init__(self) -> None:
self.reverse_chain = dict()

self.unlinked_tasklets = dict()

self.mc = Role.mc

def __enter__(self):
"""Enter custom context."""
Expand Down
1 change: 1 addition & 0 deletions lib/python/flame/mode/distributed/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ def put(self, tag: str) -> None:

def save_metrics(self):
"""Save metrics in a model registry."""
self.metrics = self.metrics | self.mc.get()
logger.debug(f"saving metrics: {self.metrics}")
if self.metrics:
self.registry_client.save_metrics(self._round, self.metrics)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
TopAggregator as BaseTopAggregator,
)
from flame.mode.message import MessageType
from flame.mode.tasklet import Loop, Tasklet
from flame.mode.tasklet import Tasklet

logger = logging.getLogger(__name__)

Expand Down
30 changes: 16 additions & 14 deletions lib/python/flame/mode/horizontal/feddyn/top_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,33 +176,35 @@ def compose(self) -> None:
with Composer() as composer:
self.composer = composer

task_internal_init = Tasklet("", self.internal_init)
task_internal_init = Tasklet("internal_init", self.internal_init)

task_init = Tasklet("", self.initialize)
task_init = Tasklet("initialize", self.initialize)

task_load_data = Tasklet("", self.load_data)
task_load_data = Tasklet("load_data", self.load_data)

task_get_dataset = Tasklet("", self.get, TAG_GET_DATATSET_SIZE)
task_get_dataset = Tasklet("get_dataset_size", self.get, TAG_GET_DATATSET_SIZE)

task_put = Tasklet("", self.put, TAG_DISTRIBUTE)
task_put = Tasklet("distribute", self.put, TAG_DISTRIBUTE)

task_get = Tasklet("", self.get, TAG_AGGREGATE)
task_get = Tasklet("aggregate", self.get, TAG_AGGREGATE)

task_train = Tasklet("", self.train)
task_train = Tasklet("train", self.train)

task_eval = Tasklet("", self.evaluate)
task_eval = Tasklet("evaluate", self.evaluate)

task_analysis = Tasklet("", self.run_analysis)
task_analysis = Tasklet("analysis", self.run_analysis)

task_save_metrics = Tasklet("", self.save_metrics)
task_save_metrics = Tasklet("save_metrics", self.save_metrics)

task_increment_round = Tasklet("", self.increment_round)
task_increment_round = Tasklet("inc_round", self.increment_round)

task_end_of_training = Tasklet("", self.inform_end_of_training)
task_end_of_training = Tasklet(
"inform_end_of_training", self.inform_end_of_training
)

task_save_params = Tasklet("", self.save_params)
task_save_params = Tasklet("save_params", self.save_params)

task_save_model = Tasklet("", self.save_model)
task_save_model = Tasklet("save_model", self.save_model)

# create a loop object with loop exit condition function
loop = Loop(loop_check_fn=lambda: self._work_done)
Expand Down
18 changes: 9 additions & 9 deletions lib/python/flame/mode/horizontal/feddyn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,23 +153,23 @@ def compose(self) -> None:
with Composer() as composer:
self.composer = composer

task_internal_init = Tasklet("", self.internal_init)
task_internal_init = Tasklet("internal_init", self.internal_init)

task_load_data = Tasklet("", self.load_data)
task_load_data = Tasklet("load_data", self.load_data)

task_init = Tasklet("", self.initialize)
task_init = Tasklet("init", self.initialize)

task_put_dataset_size = Tasklet("", self.put, TAG_UPLOAD_DATASET_SIZE)
task_put_dataset_size = Tasklet("upload_dataset_size", self.put, TAG_UPLOAD_DATASET_SIZE)

task_get = Tasklet("", self.get, TAG_FETCH)
task_get = Tasklet("fetch", self.get, TAG_FETCH)

task_train = Tasklet("", self.train)
task_train = Tasklet("train", self.train)

task_eval = Tasklet("", self.evaluate)
task_eval = Tasklet("evaluate", self.evaluate)

task_put = Tasklet("", self.put, TAG_UPLOAD)
task_put = Tasklet("upload", self.put, TAG_UPLOAD)

task_save_metrics = Tasklet("", self.save_metrics)
task_save_metrics = Tasklet("save_metrics", self.save_metrics)

# create a loop object with loop exit condition function
loop = Loop(loop_check_fn=lambda: self._work_done)
Expand Down
10 changes: 5 additions & 5 deletions lib/python/flame/mode/horizontal/oort/top_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from copy import deepcopy
from typing import Any, Tuple

from ....common.util import weights_to_device, weights_to_model_device
from ....common.constants import DeviceType
from ....optimizer.train_result import TrainResult
from ...message import MessageType
from ....selector.oort import (
from flame.common.util import weights_to_device, weights_to_model_device
from flame.common.constants import DeviceType
from flame.optimizer.train_result import TrainResult
from flame.mode.message import MessageType
from flame.selector.oort import (
PROP_ROUND_DURATION,
PROP_ROUND_START_TIME,
PROP_STAT_UTILITY,
Expand Down
6 changes: 4 additions & 2 deletions lib/python/flame/mode/horizontal/syncfl/top_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

class TopAggregator(Role, metaclass=ABCMeta):
"""Top level Aggregator implements an ML aggregation role."""

@abstract_attribute
def config(self) -> Config:
"""Abstract attribute for config object."""
Expand Down Expand Up @@ -101,7 +101,7 @@ def internal_init(self) -> None:
self.datasampler = datasampler_provider.get(
self.config.datasampler.sort, **self.config.datasampler.kwargs
).aggregator_data_sampler

self._round = 1
self._rounds = 1
self._rounds = self.config.hyperparameters.rounds
Expand Down Expand Up @@ -240,6 +240,8 @@ def run_analysis(self):

def save_metrics(self):
"""Save metrics in a model registry."""
# update metrics with metrics from metric collector
self.metrics = self.metrics | self.mc.get()
logger.debug(f"saving metrics: {self.metrics}")
if self.metrics:
self.registry_client.save_metrics(self._round - 1, self.metrics)
Expand Down
2 changes: 2 additions & 0 deletions lib/python/flame/mode/horizontal/syncfl/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ def _send_weights(self, tag: str) -> None:

def save_metrics(self):
"""Save metrics in a model registry."""
# update self.metrics with metrics from MetricCollector instance
self.metrics = self.metrics | self.mc.get()
logger.debug(f"saving metrics: {self.metrics}")
if self.metrics:
self.registry_client.save_metrics(self._round - 1, self.metrics)
Expand Down
1 change: 0 additions & 1 deletion lib/python/flame/mode/hybrid/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
"""hybrid FL trainer."""

import logging
import time

from flame.channel_manager import ChannelManager
from flame.common.constants import DeviceType
Expand Down
2 changes: 2 additions & 0 deletions lib/python/flame/mode/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
"""role abstract class."""

from abc import ABC, abstractmethod
from flame.monitor.metric_collector import MetricCollector


class Role(ABC):
"""Abstract base class for role implementation."""
mc = MetricCollector()

###########################################################################
# The following functions need to be implemented the child class
Expand Down
2 changes: 2 additions & 0 deletions lib/python/flame/mode/tasklet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from flame.mode.composer import ComposerContext
from flame.mode.enums import LoopIndicator
from flame.monitor.runtime import time_tasklet

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -151,6 +152,7 @@ def get_ender(self) -> Tasklet:

return self.loop_ender

@time_tasklet
def do(self) -> None:
"""Execute tasklet."""
self.func(*self.args, **self.kwargs)
Expand Down
17 changes: 17 additions & 0 deletions lib/python/flame/monitor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


38 changes: 38 additions & 0 deletions lib/python/flame/monitor/metric_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Metric Collector."""

import logging

logger = logging.getLogger(__name__)


class MetricCollector:
    """Accumulate metrics as key-value pairs until they are fetched.

    Metrics are stored in ``state_dict`` under keys of the form
    ``"<mtype>-<alias>"``. Saving a metric whose key already exists
    overwrites the previous value. Calling :meth:`get` hands the collected
    metrics to the caller and starts over with an empty dictionary.
    """

    def __init__(self) -> None:
        """Initialize Metric Collector with an empty metric dictionary."""
        self.state_dict: dict = {}

    def save(self, mtype: str, alias: str, value) -> None:
        """Save a key-value pair for a metric.

        Args:
            mtype: metric type; used as the key prefix (e.g. "runtime").
            alias: name of the entity the metric belongs to; used as the
                key suffix.
            value: metric value stored under ``"<mtype>-<alias>"``.
        """
        key = f"{mtype}-{alias}"
        self.state_dict[key] = value
        # Lazy %-style arguments: the message is only formatted when the
        # DEBUG level is actually enabled for this logger.
        logger.debug("Saving state_dict[%s] = %s", key, value)

    def get(self) -> dict:
        """Return the metrics collected so far and clear the saved ones.

        Returns:
            The dictionary of metrics saved since the previous call. The
            collector switches to a brand-new dictionary, so the returned
            object is not affected by later ``save()`` calls.
        """
        collected, self.state_dict = self.state_dict, {}
        return collected

39 changes: 39 additions & 0 deletions lib/python/flame/monitor/runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Runtime for Metric Collector."""

import functools
import logging
import time

logger = logging.getLogger(__name__)


def time_tasklet(func):
    """Decorator to time the Tasklet.do() function.

    The decorated callable must be a method: its first positional argument
    is the tasklet instance, from which ``composer.mc`` (the metric
    collector) and ``alias`` are read. When a metric collector is present,
    the elapsed time of the call is saved through
    ``mc.save("runtime", alias, elapsed)``; otherwise the call is simply
    forwarded without being timed.

    Args:
        func: the method to wrap.

    Returns:
        A wrapper that forwards all arguments to ``func`` and returns its
        result unchanged.
    """

    # functools.wraps keeps the wrapped method's __name__/__doc__ intact so
    # that introspection and debugging still show the original method.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        tasklet = args[0]
        collector = tasklet.composer.mc

        if not collector:
            logger.debug("No MetricCollector; won't record runtime")
            return func(*args, **kwargs)

        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed (or turn negative) by system clock adjustments.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start

        collector.save("runtime", tasklet.alias, elapsed)
        logger.debug("Runtime of %s is %s", tasklet.alias, elapsed)
        return result

    return wrapper

0 comments on commit 0b030d1

Please sign in to comment.