Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RegistryMixin improved alias management #404

Merged
merged 11 commits into from
Jan 23, 2024
129 changes: 113 additions & 16 deletions src/sparsezoo/utils/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,32 @@
"register",
"get_from_registry",
"registered_names",
"registered_aliases",
"standardize_lookup_name",
]


_ALIAS_REGISTRY: Dict[Type, Dict[str, str]] = defaultdict(dict)
_REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict)


def standardize_lookup_name(name: str) -> str:
    """
    Normalize a name into the canonical form used for registry lookups.

    Every underscore and every space is turned into a hyphen, and the
    result is lowercased.

    example:
    ```
    standardize_lookup_name("Foo_bar baz") == "foo-bar-baz"
    ```

    :param name: name to standardize
    :return: standardized name
    """
    hyphenated = name
    # underscores and spaces are both treated as word separators
    for separator in ("_", " "):
        hyphenated = hyphenated.replace(separator, "-")
    return hyphenated.lower()


class RegistryMixin:
"""
Universal registry to support registration and loading of child classes and plugins
Expand Down Expand Up @@ -64,10 +84,16 @@ class ImageNetDataset(Dataset):
class Cifar(Dataset):
pass

Note: the name will be standardized for lookup in the registry.
For example, if a class is registered as "cifar_dataset" or
"cifar dataset", it will be stored as "cifar-dataset". The user
will be able to load the class with any of the three name variants.

# register with multiple aliases
@Dataset.register(name=["cifar-10-dataset", "cifar-100-dataset"])
@Dataset.register(alias=["cifar-10-dataset", "cifar_100_dataset"])
Satrat marked this conversation as resolved.
Show resolved Hide resolved
class Cifar(Dataset):
pass
Note: aliases will NOT be standardized for lookup in the registry.

# load as "cifar-dataset"
cifar = Dataset.load_from_registry("cifar-dataset")
Expand All @@ -82,39 +108,45 @@ class Cifar(Dataset):
registry_requires_subclass: bool = False

@classmethod
def register(cls, name: Union[List[str], str, None] = None):
def register(
cls, name: Optional[str] = None, alias: Union[List[str], str, None] = None
):
"""
Decorator for registering a value (ie class or function) wrapped by this
decorator to the base class (class that .register is called from)

:param name: name to register the wrapped value as,
defaults to value.__name__
:param alias: alias or list of aliases to register the wrapped value as,
defaults to None
:return: register decorator
"""

def decorator(value: Any):
cls.register_value(value, name=name)
cls.register_value(value, name=name, alias=alias)
return value

return decorator

@classmethod
def register_value(cls, value: Any, name: Union[List[str], str, None] = None):
def register_value(
cls, value: Any, name: str, alias: Union[str, List[str], None] = None
):
"""
Registers the given value to the class `.register_value` is called from
:param value: value to register
:param name: name or list of names to register the wrapped value as,
:param name: name to register the wrapped value as,
defaults to value.__name__
:param alias: alias or list of aliases to register the wrapped value as,
defaults to None
"""
names = name if isinstance(name, list) else [name]

for name in names:
register(
parent_class=cls,
value=value,
name=name,
require_subclass=cls.registry_requires_subclass,
)
register(
parent_class=cls,
value=value,
name=name,
alias=alias,
require_subclass=cls.registry_requires_subclass,
)

@classmethod
def load_from_registry(cls, name: str, **constructor_kwargs) -> object:
Expand Down Expand Up @@ -149,24 +181,37 @@ def registered_names(cls) -> List[str]:
"""
return registered_names(cls)

@classmethod
def registered_aliases(cls) -> List[str]:
"""
:return: list of all aliases registered to this class
"""
return registered_aliases(cls)


def register(
parent_class: Type,
value: Any,
name: Optional[str] = None,
alias: Union[List[str], str, None] = None,
require_subclass: bool = False,
):
"""
:param parent_class: class to register the name under
:param value: the value to register
:param name: name to register the wrapped value as, defaults to value.__name__
:param alias: alias or list of aliases to register the wrapped value as,
defaults to None
:param require_subclass: require that value is a subclass of the class this
method is called from
"""
if name is None:
# default name
name = value.__name__

name = standardize_lookup_name(name)
register_alias(name=name, alias=alias, parent_class=parent_class)
Satrat marked this conversation as resolved.
Show resolved Hide resolved
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved

if require_subclass:
_validate_subclass(parent_class, value)

Expand Down Expand Up @@ -194,19 +239,24 @@ def get_from_registry(
:return: value from retrieved the registry for the given name, raises
error if not found
"""
name = standardize_lookup_name(name)

if ":" in name:
# user specifying specific module to load and value to import
module_path, value_name = name.split(":")
retrieved_value = _import_and_get_value_from_module(module_path, value_name)
else:
# look up name in alias registry
name = _ALIAS_REGISTRY[parent_class].get(name)
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
# look up name in registry
retrieved_value = _REGISTRY[parent_class].get(name)
if retrieved_value is None:
raise KeyError(
f"Unable to find {name} registered under type {parent_class}. "
f"Unable to find {name} registered under type {parent_class}.\n"
f"Registered values for {parent_class}: "
f"{registered_names(parent_class)}"
f"{registered_names(parent_class)}\n"
f"Registered aliases for {parent_class}: "
f"{registered_aliases(parent_class)}"
)

if require_subclass:
Expand All @@ -223,6 +273,53 @@ def registered_names(parent_class: Type) -> List[str]:
return list(_REGISTRY[parent_class].keys())


def registered_aliases(parent_class: Type) -> List[str]:
    """
    List the aliases registered to the given class, excluding the
    registered names themselves.

    The alias registry also holds every registered name (mapped to
    itself), so those entries are filtered out here. Aliases are returned
    in the order they were registered, so the output (and any error
    message built from it) is stable between runs.

    :param parent_class: class to look up the registry of
    :return: all aliases registered to the given class
    """
    names = set(registered_names(parent_class))
    # iterate the dict directly to keep insertion order; a set difference
    # would return the aliases in arbitrary order
    return [
        alias_name
        for alias_name in _ALIAS_REGISTRY[parent_class]
        if alias_name not in names
    ]


def register_alias(
    name: str, parent_class: Type, alias: Union[str, List[str], None] = None
):
    """
    Updates the mapping from the alias(es) to the given name.
    The name itself is always added to the mapping as well (pointing to
    itself), whether or not any alias is given.

    :param name: name that the alias refers to
    :param parent_class: class that the name is registered under
    :param alias: single alias or list of aliases that
        refer to the name, defaults to None
    :raises KeyError: if one of the aliases is identical to the name, or if
        the name or one of the aliases is already present in the alias
        registry of the parent class
    """
    if alias is None:
        aliases = []
    elif isinstance(alias, list):
        # copy, so appending the name below does not mutate the caller's list
        aliases = list(alias)
    else:
        aliases = [alias]

    if name in aliases:
        raise KeyError(
            f"Attempting to register alias {name}, "
            f"that is identical to the standardized name: {name}."
        )
    # the name is registered as an alias of itself
    aliases.append(name)

    class_aliases = _ALIAS_REGISTRY[parent_class]
    for alias_name in aliases:
        if alias_name in class_aliases:
            # report the name the alias currently points to; the lookup must
            # go through the per-class mapping, not the top-level registry
            raise KeyError(
                f"Attempting to register alias {alias_name} as {name} "
                f"however {alias_name} has already been registered as "
                f"{class_aliases[alias_name]}"
            )
        class_aliases[alias_name] = name


def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any:
# import the given module path and try to get the value_name if it is included
# in the module
Expand Down
Loading
Loading