Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: local registry #421

Merged
merged 1 commit into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 2 additions & 13 deletions lib/python/flame/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@

from pip._internal.cli.main import main as pipmain

from ..config import Config
from .typing import ModelWeights
from .constants import DeviceType
from flame.common.typing import ModelWeights
from flame.common.constants import DeviceType

PYTORCH = "torch"
TENSORFLOW = "tensorflow"
Expand Down Expand Up @@ -150,16 +149,6 @@ def install_package(package: str) -> bool:
return False


def mlflow_runname(config: Config) -> str:
    """Compose an mlflow run name: role, matching group-by values, task-id prefix."""
    # collect every group-by value (across all channels) that occurs in the realm
    matched = [
        val + "-"
        for chan in config.channels.values()
        for val in chan.group_by.value
        if val in config.realm
    ]
    return config.role + "-" + "".join(matched) + config.task_id[:8]


def delta_weights_pytorch(
a: ModelWeights, b: ModelWeights
) -> Union[ModelWeights, None]:
Expand Down
1 change: 1 addition & 0 deletions lib/python/flame/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class RegistryType(str, Enum):
"""Define model registry types."""

DUMMY = "dummy"
LOCAL = "local"
myungjin marked this conversation as resolved.
Show resolved Hide resolved
MLFLOW = "mlflow"


Expand Down
5 changes: 2 additions & 3 deletions lib/python/flame/mode/distributed/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
delta_weights_pytorch,
delta_weights_tensorflow,
get_ml_framework_in_use,
mlflow_runname,
valid_frameworks,
weights_to_device,
weights_to_model_device,
Expand Down Expand Up @@ -70,7 +69,7 @@ def internal_init(self) -> None:
"""Initialize internal state for role."""
self.registry_client = registry_provider.get(self.config.registry.sort)
# initialize registry client
self.registry_client(self.config.registry.uri, self.config.job.job_id)
self.registry_client(self.config)

base_model = self.config.base_model
if base_model and base_model.name != "" and base_model.version > 0:
Expand All @@ -80,7 +79,7 @@ def internal_init(self) -> None:
self.ring_weights = None # latest model weights from ring all-reduce
self.weights = None

self.registry_client.setup_run(mlflow_runname(self.config))
self.registry_client.setup_run()
self.metrics = dict()

self._round = 1
Expand Down
9 changes: 4 additions & 5 deletions lib/python/flame/mode/horizontal/syncfl/top_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from flame.common.util import (
MLFramework,
get_ml_framework_in_use,
mlflow_runname,
valid_frameworks,
weights_to_device,
weights_to_model_device,
Expand Down Expand Up @@ -81,22 +80,22 @@ def internal_init(self) -> None:

self.registry_client = registry_provider.get(self.config.registry.sort)
# initialize registry client
self.registry_client(self.config.registry.uri, self.config.job.job_id)
self.registry_client(self.config)

base_model = self.config.base_model
if base_model and base_model.name != "" and base_model.version > 0:
self.model = self.registry_client.load_model(
base_model.name, base_model.version
)

self.registry_client.setup_run(mlflow_runname(self.config))
self.registry_client.setup_run()
self.metrics = dict()

# disk cache is used for saving memory in case model is large
# automatic eviction of disk cache is disabled with cull_limit 0
self.cache = Cache()
self.cache.reset('size_limit', 1e15)
self.cache.reset('cull_limit', 0)
self.cache.reset("size_limit", 1e15)
self.cache.reset("cull_limit", 0)

self.optimizer = optimizer_provider.get(
self.config.optimizer.sort, **self.config.optimizer.kwargs
Expand Down
5 changes: 2 additions & 3 deletions lib/python/flame/mode/horizontal/syncfl/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
delta_weights_pytorch,
delta_weights_tensorflow,
get_ml_framework_in_use,
mlflow_runname,
valid_frameworks,
weights_to_device,
weights_to_model_device,
Expand Down Expand Up @@ -71,9 +70,9 @@ def internal_init(self) -> None:

self.registry_client = registry_provider.get(self.config.registry.sort)
# initialize registry client
self.registry_client(self.config.registry.uri, self.config.job.job_id)
self.registry_client(self.config)

self.registry_client.setup_run(mlflow_runname(self.config))
self.registry_client.setup_run()
self.metrics = dict()

# needed for trainer-side optimization algorithms such as fedprox
Expand Down
15 changes: 9 additions & 6 deletions lib/python/flame/registries.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,24 @@

from typing import Union

from .config import RegistryType
from .object_factory import ObjectFactory
from .registry.dummy import DummyRegistryClient
from .registry.mlflow import MLflowRegistryClient
from flame.config import RegistryType
from flame.object_factory import ObjectFactory
from flame.registry.dummy import DummyRegistryClient
from flame.registry.mlflow import MLflowRegistryClient
from flame.registry.local import LocalRegistryClient


class RegistryProvider(ObjectFactory):
"""Model registry provider."""

def get(self, registry_name,
**kwargs) -> Union[DummyRegistryClient, MLflowRegistryClient]:
def get(
self, registry_name, **kwargs
) -> Union[DummyRegistryClient, MLflowRegistryClient]:
"""Return a registry client for a given registry name."""
return self.create(registry_name, **kwargs)


registry_provider = RegistryProvider()
registry_provider.register(RegistryType.DUMMY, DummyRegistryClient)
registry_provider.register(RegistryType.MLFLOW, MLflowRegistryClient)
registry_provider.register(RegistryType.LOCAL, LocalRegistryClient)
11 changes: 6 additions & 5 deletions lib/python/flame/registry/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,26 @@
from abc import ABC, abstractmethod
from typing import Any, Optional

from flame.config import Hyperparameters, Config


class AbstractRegistryClient(ABC):
"""Abstract registry client."""

@abstractmethod
def __call__(self, uri: str, job_id: str) -> None:
def __call__(self, config: Config) -> None:
"""Abstract method for initializing a registry client."""

@abstractmethod
def setup_run(self, name: str) -> None:
def setup_run(self) -> None:
"""Abstract method for setup a run."""

@abstractmethod
def save_metrics(self, epoch: int, metrics: Optional[dict[str,
float]]) -> None:
def save_metrics(self, epoch: int, metrics: Optional[dict[str, float]]) -> None:
"""Abstract method for saving metrics in a model registry."""

@abstractmethod
def save_params(self, hyperparameters: Optional[dict[str, float]]) -> None:
def save_params(self, hyperparameters: Optional[Hyperparameters]) -> None:
"""Abstract method for saving hyperparameters in a model registry."""

@abstractmethod
Expand Down
12 changes: 6 additions & 6 deletions lib/python/flame/registry/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,26 @@

from typing import Any, Optional

from .abstract import AbstractRegistryClient
from flame.config import Hyperparameters, Config
from flame.registry.abstract import AbstractRegistryClient


class DummyRegistryClient(AbstractRegistryClient):
"""Dummy registry client."""

def __call__(self, uri: str, job_id: str) -> None:
def __call__(self, config: Config) -> None:
"""Initialize the instance."""
pass

def setup_run(self, name: str) -> None:
def setup_run(self) -> None:
"""Set up a run."""
pass

def save_metrics(self, epoch: int, metrics: Optional[dict[str,
float]]) -> None:
def save_metrics(self, epoch: int, metrics: Optional[dict[str, float]]) -> None:
"""Save metrics in a model registry."""
pass

def save_params(self, hyperparameters: Optional[dict[str, float]]) -> None:
def save_params(self, hyperparameters: Optional[Hyperparameters]) -> None:
"""Save hyperparameters in a model registry."""
pass

Expand Down
133 changes: 133 additions & 0 deletions lib/python/flame/registry/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Local registry client."""

from typing import Any, Optional
import logging
import os
import csv
import pickle
from collections import defaultdict
from pathlib import Path
import shutil


from flame.common.util import get_ml_framework_in_use, MLFramework
from flame.config import Hyperparameters
from flame.config import Config
from flame.registry.abstract import AbstractRegistryClient

logger = logging.getLogger(__name__)

# specify folder names
METRICS = "metrics"
MODEL = "model"
PARAMS = "params"
FLAME_LOG = "flame-log"


class LocalRegistryClient(AbstractRegistryClient):
    """Local registry client.

    Persists metrics, hyperparameters and models on the local filesystem
    under ~/flame-log/<job_id>/<task_id>.
    """

    # singleton instance shared by all creations of this class
    _instance = None

    def __new__(cls):
        """Create a singleton instance."""
        if cls._instance is None:
            logger.info("Create a local registry client instance")
            cls._instance = super().__new__(cls)
        return cls._instance

    def __call__(self, config: Config) -> None:
        """Initialize the instance from the job configuration.

        The registry path is computed here (not in setup_run()) so that
        load_model() really can be used without a prior setup_run() call,
        as its docstring promises.
        """
        self.job_id = config.job.job_id
        self.task_id = config.task_id
        self.registry_path = os.path.join(
            str(Path.home()), FLAME_LOG, self.job_id, self.task_id
        )

    def setup_run(self) -> None:
        """Set up a run: reset the on-disk registry layout for this task."""
        # start from a clean slate for this job/task
        if os.path.exists(self.registry_path):
            shutil.rmtree(self.registry_path)

        for directory in (METRICS, MODEL, PARAMS):
            os.makedirs(os.path.join(self.registry_path, directory))

        # version tracking
        self.param_version = 1
        self.model_versions = defaultdict(lambda: 1)

    def save_metrics(self, epoch: int, metrics: Optional[dict[str, float]]) -> None:
        """Append per-round metric values, one CSV file per metric name."""
        if not metrics:
            # metrics is Optional; nothing to record
            return

        for metric, value in metrics.items():
            filename = os.path.join(self.registry_path, METRICS, metric)
            write_header = not os.path.exists(filename)
            # newline="" is required for files handed to csv.writer
            with open(filename, "a+", newline="") as file:
                csv_writer = csv.writer(file)
                if write_header:
                    csv_writer.writerow(["round", metric])
                csv_writer.writerow([epoch, value])

    def save_params(self, hyperparameters: Optional[Hyperparameters]) -> None:
        """Persist hyperparameters as a pickled dict, one file per version."""
        if hyperparameters is None:
            # hyperparameters is Optional; nothing to record
            return

        path = os.path.join(self.registry_path, PARAMS, str(self.param_version))
        with open(path, "wb") as file:
            pickle.dump(hyperparameters.dict(), file)
        self.param_version += 1

    def cleanup(self) -> None:
        """Clean up resources (no-op for the local registry)."""
        pass

    def save_model(self, name: str, model: Any) -> None:
        """Save a model under model/<name>/<version> using the active framework."""
        model_folder = os.path.join(self.registry_path, MODEL, name)
        os.makedirs(model_folder, exist_ok=True)

        version_path = os.path.join(model_folder, str(self.model_versions[name]))

        ml_framework = get_ml_framework_in_use()
        if ml_framework == MLFramework.PYTORCH:
            import torch

            torch.save(model, version_path)
        elif ml_framework == MLFramework.TENSORFLOW:
            model.save(version_path)
        # version counter advances per save, matching the original behavior
        self.model_versions[name] += 1

    def load_model(self, name: str, version: int) -> object:
        """
        Load a model.

        This method can be called without calling self.setup_run(),
        since the registry path is established in __call__().
        """
        model_path = os.path.join(self.registry_path, MODEL, name, str(version))

        ml_framework = get_ml_framework_in_use()
        if ml_framework == MLFramework.PYTORCH:
            import torch

            # the class definition for the model must be available for this
            return torch.load(model_path)
        elif ml_framework == MLFramework.TENSORFLOW:
            import tensorflow

            return tensorflow.keras.models.load_model(model_path)

        raise ModuleNotFoundError("Module for loading model not found")
Loading