Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add algorithm registry #433

Merged
merged 7 commits into from
Dec 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion bagua/torch_api/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,33 @@
#!/usr/bin/env python3

from .base import Algorithm, AlgorithmImpl # noqa: F401
from .base import Algorithm, AlgorithmImpl, GlobalAlgorithmRegistry # noqa: F401
from . import bytegrad, decentralized, gradient_allreduce # noqa: F401
from . import q_adam, async_model_average # noqa: F401


GlobalAlgorithmRegistry.register(
"gradient_allreduce",
gradient_allreduce.GradientAllReduceAlgorithm,
description="Gradient AllReduce Algorithm",
)
GlobalAlgorithmRegistry.register(
"bytegrad", bytegrad.ByteGradAlgorithm, description="ByteGrad Algorithm"
)
GlobalAlgorithmRegistry.register(
"decentralized",
decentralized.DecentralizedAlgorithm,
description="Decentralized SGD Algorithm",
)
GlobalAlgorithmRegistry.register(
"low_precision_decentralized",
decentralized.LowPrecisionDecentralizedAlgorithm,
description="Low Precision Decentralized SGD Algorithm",
)
GlobalAlgorithmRegistry.register(
"qadam", q_adam.QAdamAlgorithm, description="QAdam Algorithm"
)
GlobalAlgorithmRegistry.register(
"async",
async_model_average.AsyncModelAverageAlgorithm,
description="Asynchronous Model Average Algorithm",
)
90 changes: 87 additions & 3 deletions bagua/torch_api/algorithms/base.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,52 @@
from __future__ import annotations
from bagua.torch_api.data_parallel.bagua_distributed import BaguaDistributedDataParallel
from bagua.torch_api.bucket import BaguaBucket
from bagua.torch_api.tensor import BaguaTensor
from bagua.torch_api.communication import BaguaProcessGroup
from typing import List
from typing import Any, Callable, Dict, List, Optional
import torch


__all__ = ["Algorithm", "AlgorithmImpl"]


class Algorithm:
"""
This is the base class that all Bagua algorithms inherit.
"""

def reify(self, process_group: BaguaProcessGroup):
"""
Create an algorithm instance.
Create an algorithm implementation instance. See :class:`AlgorithmImpl`.

Args:
process_group: The process group to work on.

Returns:
An instance of Bagua algorithm implementation.
"""
pass

@classmethod
def init(cls, name, **kwargs) -> Algorithm:
"""Helper class to initialize a registered Bagua algorithm.

Args:
name: Name of the registered Bagua algorithm.
kwargs: Arguments to initialize the registered Bagua algorithm.

Returns:
An instance of a registered Bagua algorithm.

Example::
>>> from bagua.torch_api.algorithms import Algorithm
>>> algorithm = Algorithm.init("gradient_allreduce", hierarchical=True)

.. note::
Call ``str(bagua.torch_api.algorithms.GlobalAlgorithmRegistry)`` to see all registered Bagua algorithms.
"""
return GlobalAlgorithmRegistry.get(name)(**kwargs)


class AlgorithmImpl:
"""
Expand All @@ -43,7 +70,9 @@ def need_reset(self) -> bool:
"""
return False

def init_tensors(self, bagua_ddp: BaguaDistributedDataParallel) -> List[BaguaTensor]:
def init_tensors(
self, bagua_ddp: BaguaDistributedDataParallel
) -> List[BaguaTensor]:
"""
Given a :class:`~bagua.torch_api.data_parallel.BaguaDistributedDataParallel`, return Bagua tensors to be used in Bagua for later
operations.
Expand Down Expand Up @@ -177,3 +206,58 @@ def init_operations(
bagua_ddp: :class:`bagua.torch_api.data_parallel.BaguaDistributedDataParallel`.
bucket: A single bucket to register operations.
"""


class _AlgorithmRegistry(dict):
def register(
self,
name: str,
algorithm: Callable,
description: Optional[str] = None,
):
"""Registers an Bagua Algorithm mapped to a name and with required metadata.

Args:
name: The name that identifies a Bagua algorithm, e.g. "gradient_allreduce".
algorithm: Class of the Bagua algorithm.
description: Description of the Bagua algorithm.
"""
if not (name is None or isinstance(name, str)):
raise TypeError(f"`name` must be a str, found {name}")

if name in self:
raise ValueError(f"'{name}' is already present in the registry.")

data: Dict[str, Any] = {}
data["algorithm"] = algorithm
data["description"] = description if description is not None else ""

self[name] = data

def get(self, name: str) -> Callable:
"""Calls the registered Bagua algorithm with the name and returns the algorithm class.

Args:
name: The name that identifies a Bagua algorithm, e.g. "gradient_allreduce".

Returns:
The class of the Bagua algorithm.
"""

if name in self:
data = self[name]
return data["algorithm"]

err_msg = "'{}' not found in registry. Available names: {}"
available_names = ", ".join(sorted(self.keys())) or "none"
raise KeyError(err_msg.format(name, available_names))

def available_algorithms(self) -> List[str]:
"""Returns a list of registered Bagua algorithms."""
return list(self.keys())

def __str__(self) -> str:
return "Registered Algorithms: {}".format(", ".join(self.keys()))


GlobalAlgorithmRegistry = _AlgorithmRegistry()