Skip to content

Commit

Permalink
LLM Foundry CLI (just registry) (#1043)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Mar 24, 2024
1 parent 813d596 commit 67dcab9
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 20 deletions.
8 changes: 7 additions & 1 deletion REGISTRY.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,10 @@ code_paths:


## Discovering registrable components
Coming soon
To help find and understand registrable components, you can use the `llmfoundry registry` cli command.

We provide two commands:
- `llmfoundry registry get [--group]`: List all registries, and their components, optionally specifying a specific registry. Example usage: `llmfoundry registry get --group loggers` or `llmfoundry registry get`
- `llmfoundry registry find <group> <name>`: Get information about a specific registered component. Example usage: `llmfoundry registry find loggers wandb`

Use `--help` on any of these commands for more information.
2 changes: 2 additions & 0 deletions llmfoundry/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
12 changes: 12 additions & 0 deletions llmfoundry/cli/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import typer

from llmfoundry.cli import registry_cli

app = typer.Typer(pretty_exceptions_show_locals=False)
app.add_typer(registry_cli.app, name='registry')

if __name__ == '__main__':
app()
72 changes: 72 additions & 0 deletions llmfoundry/cli/registry_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import typer
from rich.console import Console
from rich.table import Table

from llmfoundry import registry
from llmfoundry.utils.registry_utils import TypedRegistry

console = Console()
app = typer.Typer(pretty_exceptions_show_locals=False)


def _get_registries(group: Optional[str] = None) -> list[TypedRegistry]:
registry_attr_names = dir(registry)
registry_attrs = [getattr(registry, name) for name in registry_attr_names]
available_registries = [
r for r in registry_attrs if isinstance(r, TypedRegistry)
]

if group is not None and group not in registry_attr_names:
console.print(
f'Group {group} not found in registry. Run `llmfoundry registry get` to see available groups.'
)
return []

if group is not None:
available_registries = [getattr(registry, group)]

return available_registries


@app.command()
def get(group: Optional[str] = None):
"""Get the available registries.
Args:
group (Optional[str], optional): The group to get. If not provided, all groups will be shown. Defaults to None.
"""
available_registries = _get_registries(group)

table = Table('Registry', 'Description', 'Options', show_lines=True)
for r in available_registries:
table.add_row('.'.join(r.namespace), r.description,
', '.join(r.get_all()))

console.print(table)


@app.command()
def find(group: str, name: str):
"""Find a registry entry by name.
Args:
group (str): The group to search.
name (str): The name of the entry to search for.
"""
available_registries = _get_registries(group)
if not available_registries:
return

r = available_registries[0]
find_output = r.find(name)

table = Table('Module', 'File', 'Line number', 'Docstring')
table.add_row(find_output['module'], find_output['file'],
str(find_output['line_no']), find_output['docstring'])

console.print(table)
54 changes: 35 additions & 19 deletions llmfoundry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,54 +11,70 @@
from llmfoundry.interfaces import CallbackWithConfig
from llmfoundry.utils.registry_utils import create_registry

_loggers_description = """The loggers registry is used to register classes that implement the LoggerDestination interface.
These classes are used to log data from the training loop, and will be passed to the loggers arg of the Trainer. The loggers
will be constructed by directly passing along the specified kwargs to the constructor."""
_loggers_description = (
'The loggers registry is used to register classes that implement the LoggerDestination interface. '
+
'These classes are used to log data from the training loop, and will be passed to the loggers arg of the Trainer. The loggers '
+
'will be constructed by directly passing along the specified kwargs to the constructor.'
)
loggers = create_registry('llmfoundry',
'loggers',
generic_type=Type[LoggerDestination],
entry_points=True,
description=_loggers_description)

_callbacks_description = """The callbacks registry is used to register classes that implement the Callback interface.
These classes are used to interact with the Composer event system, and will be passed to the callbacks arg of the Trainer.
The callbacks will be constructed by directly passing along the specified kwargs to the constructor."""
_callbacks_description = (
'The callbacks registry is used to register classes that implement the Callback interface. '
+
'These classes are used to interact with the Composer event system, and will be passed to the callbacks arg of the Trainer. '
+
'The callbacks will be constructed by directly passing along the specified kwargs to the constructor.'
)
callbacks = create_registry('llmfoundry',
'callbacks',
generic_type=Type[Callback],
entry_points=True,
description=_callbacks_description)

_callbacks_with_config_description = """The callbacks_with_config registry is used to register classes that implement the CallbackWithConfig interface.
These are the same as the callbacks registry, except that they additionally take the full training config as an argument to their constructor."""
_callbacks_with_config_description = (
'The callbacks_with_config registry is used to register classes that implement the CallbackWithConfig interface. '
+
'These are the same as the callbacks registry, except that they additionally take the full training config as an argument to their constructor.'
)
callbacks_with_config = create_registry(
'llm_foundry',
'callbacks_with_config',
'llm_foundry.callbacks_with_config',
generic_type=Type[CallbackWithConfig],
entry_points=True,
description=_callbacks_with_config_description)

_optimizers_description = """The optimizers registry is used to register classes that implement the Optimizer interface.
The optimizer will be passed to the optimizers arg of the Trainer. The optimizer will be constructed by directly passing along the
specified kwargs to the constructor, along with the model parameters."""
_optimizers_description = (
'The optimizers registry is used to register classes that implement the Optimizer interface. '
+
'The optimizer will be passed to the optimizers arg of the Trainer. The optimizer will be constructed by directly passing along the '
+ 'specified kwargs to the constructor, along with the model parameters.')
optimizers = create_registry('llmfoundry',
'optimizers',
generic_type=Type[Optimizer],
entry_points=True,
description=_optimizers_description)

_algorithms_description = """The algorithms registry is used to register classes that implement the Algorithm interface.
The algorithm will be passed to the algorithms arg of the Trainer. The algorithm will be constructed by directly passing along the
specified kwargs to the constructor."""
_algorithms_description = (
'The algorithms registry is used to register classes that implement the Algorithm interface. '
+
'The algorithm will be passed to the algorithms arg of the Trainer. The algorithm will be constructed by directly passing along the '
+ 'specified kwargs to the constructor.')
algorithms = create_registry('llmfoundry',
'algorithms',
generic_type=Type[Algorithm],
entry_points=True,
description=_algorithms_description)

_schedulers_description = """The schedulers registry is used to register classes that implement the ComposerScheduler interface.
The scheduler will be passed to the schedulers arg of the Trainer. The scheduler will be constructed by directly passing along the
specified kwargs to the constructor."""
_schedulers_description = (
'The schedulers registry is used to register classes that implement the ComposerScheduler interface. '
+
'The scheduler will be passed to the schedulers arg of the Trainer. The scheduler will be constructed by directly passing along the '
+ 'specified kwargs to the constructor.')
schedulers = create_registry('llmfoundry',
'schedulers',
generic_type=Type[ComposerScheduler],
Expand Down
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
'beautifulsoup4>=4.12.2,<5', # required for model download utils
'tenacity>=8.2.3,<9',
'catalogue>=2,<3',
'typer[all]<1',
]

extra_deps = {}
Expand Down Expand Up @@ -145,4 +146,7 @@
install_requires=install_requires,
extras_require=extra_deps,
python_requires='>=3.9',
entry_points={
'console_scripts': ['llmfoundry = llmfoundry.cli.cli:app'],
},
)
22 changes: 22 additions & 0 deletions tests/cli/test_registry_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry import registry
from llmfoundry.cli.registry_cli import _get_registries
from llmfoundry.utils.registry_utils import TypedRegistry


def test_get_registries():
available_registries = _get_registries()
expected_registries = [
getattr(registry, r)
for r in dir(registry)
if isinstance(getattr(registry, r), TypedRegistry)
]
assert available_registries == expected_registries


def test_get_registries_group():
available_registries = _get_registries('loggers')
assert len(available_registries) == 1
assert available_registries[0].namespace == ('llmfoundry', 'loggers')

0 comments on commit 67dcab9

Please sign in to comment.