Skip to content

Commit

Permalink
Plugin/mlflow custom flavor - fixed small bug (#1068)
Browse files Browse the repository at this point in the history
Thanks for adding the custom flavor option to mlflow. I tested it, and the only problem was that the flavor-specific kwargs somehow got mixed up and ended up being used in a call somewhere else, which led to a fatal error.

Changes
renamed class attribute to avoid conflicting calls

How I tested this
added unit test that doubles as minimal example for reproducing the error

Notes
only tested the DataSaver (due to internal legacy structure, DataLoader is not utilised/hard to implement)

* working mlflow plugin need to add kwargs for custom flavor

* flavor kwargs used outside of mlflow saver/loader class

* updated class description

* flavor kwargs used outside of mlflow saver/loader class

* removed incidental files

---------

Co-authored-by: Jernej Frank <jernej.frank@oxehealth.com>
  • Loading branch information
2 people authored and skrawcz committed Jul 27, 2024
1 parent 92e80f1 commit f26ddfb
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 11 deletions.
1 change: 1 addition & 0 deletions examples/hamilton_ui/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def run(load_from_parquet: bool, username: str, project_id: int):
"template": "machine_learning",
"loading_data_from": "parquet" if load_from_parquet else "api",
"TODO": "add_more_tags_to_find_your_run_later",
"custom_tag": "adding custom tag",
},
)

Expand Down
18 changes: 9 additions & 9 deletions hamilton/plugins/mlflow_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ class MLFlowModelSaver(DataSaver):
:param register_as: If not None, register the model under the specified name.
:param flavor: Library format to save the model (sklearn, xgboost, etc.). Automatically inferred if None.
:param run_id: Log model to a specific run. Leave to `None` if using the `MLFlowTracker`
:param kwargs: Arguments for `.log_model()`. Can be flavor-specific.
:param mlflow_kwargs: Arguments for `.log_model()`. Can be flavor-specific.
"""

path: Union[str, pathlib.Path] = "model"
register_as: Optional[str] = None
flavor: Optional[Union[str, ModuleType]] = None
run_id: Optional[str] = None
kwargs: Dict[str, Any] = None
mlflow_kwargs: Dict[str, Any] = None

def __post_init__(self):
self.kwargs = self.kwargs if self.kwargs else {}
self.mlflow_kwargs = self.mlflow_kwargs if self.mlflow_kwargs else {}

@classmethod
def name(cls) -> str:
Expand Down Expand Up @@ -69,11 +69,11 @@ def save_data(self, data) -> Dict[str, Any]:

# save to active run
if mlflow.active_run():
model_info = flavor_module.log_model(data, self.path, **self.kwargs)
model_info = flavor_module.log_model(data, self.path, **self.mlflow_kwargs)
# create a run with `run_id` and save to it
else:
with mlflow.start_run(run_id=self.run_id):
model_info = flavor_module.log_model(data, self.path, **self.kwargs)
model_info = flavor_module.log_model(data, self.path, **self.mlflow_kwargs)

# create metadata from ModelInfo object
metadata = {k.strip("_"): v for k, v in model_info.__dict__.items()}
Expand Down Expand Up @@ -105,7 +105,7 @@ class MLFlowModelLoader(DataLoader):
:param version: Version of the registered model. Can pass as string `v1` or integer `1`
:param version_alias: Version alias of the registered model. Specify either this or `version`
:param flavor: Library format to load the model (sklearn, xgboost, etc.). Automatically inferred if None.
:param kwargs: Arguments for `.load_model()`. Can be flavor-specific.
:param mlflow_kwargs: Arguments for `.load_model()`. Can be flavor-specific.
"""

model_uri: Optional[str] = None
Expand All @@ -116,14 +116,14 @@ class MLFlowModelLoader(DataLoader):
version: Optional[Union[str, int]] = None
version_alias: Optional[str] = None
flavor: Optional[Union[ModuleType, str]] = None
kwargs: Dict[str, Any] = None
mlflow_kwargs: Dict[str, Any] = None

# __post_init__ is required to set `mlflow_kwargs` to an empty dict because we
# can't set: mlflow_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
# otherwise it raises `InvalidDecoratorException` because the materializer factory checks
# for all params being set and `mlflow_kwargs` would be unset until instantiation.
def __post_init__(self):
self.kwargs = self.kwargs if self.kwargs else {}
self.mlflow_kwargs = self.mlflow_kwargs if self.mlflow_kwargs else {}

if self.model_uri:
return
Expand Down Expand Up @@ -180,7 +180,7 @@ def load_data(self, type_: Type) -> Tuple[Any, Dict[str, Any]]:
except ImportError:
raise ImportError(f"Flavor {flavor} is unsupported by MLFlow")

model = flavor_module.load_model(model_uri=self.model_uri)
model = flavor_module.load_model(model_uri=self.model_uri, **self.mlflow_kwargs)
return model, metadata


Expand Down
25 changes: 23 additions & 2 deletions tests/plugins/test_mlflow_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from sklearn.base import BaseEstimator
from sklearn.linear_model import LinearRegression

from hamilton.io.materialization import to

# TODO move these tests to `plugin_tests` because the required read-writes can get
# complicated and tests are time consuming.

Expand Down Expand Up @@ -170,11 +172,30 @@ def test_mlflow_specify_flavor_using_module(fitted_sklearn_model: BaseEstimator,
def test_mlflow_handle_saver_kwargs():
path = "tmp/path"
flavor = "sklearn"
saver = MLFlowModelSaver(path=path, flavor=flavor, kwargs=dict(unknown_kwarg=True))
saver = MLFlowModelSaver(path=path, flavor=flavor, mlflow_kwargs=dict(unknown_kwarg=True))

assert saver.path == path
assert saver.flavor == flavor
assert saver.kwargs.get("unknown_kwarg") is True
assert saver.mlflow_kwargs.get("unknown_kwarg") is True


def test_io_to_mlflow_handle_saver_kwargs_():
    """Check that `to.mlflow(...)` accepts and retains `mlflow_kwargs`.

    Builds a materializer through `to.mlflow` and inspects the
    `data_saver_kwargs` mapping stored on the returned object. Each entry is
    expected to expose the originally supplied argument via `.value`.
    Doubles as a minimal example of passing flavor-specific keyword
    arguments through the materialization API.
    """
    path = "tmp/path"
    flavor = "sklearn"
    # Named `saver_id` (not `id`) so the builtin `id` is not shadowed.
    saver_id = "saver_id"
    dependencies = ["tmp_node"]
    saver = to.mlflow(
        path=path,
        flavor=flavor,
        id=saver_id,
        dependencies=dependencies,
        mlflow_kwargs=dict(unknown_kwarg=True),
    )
    # This is the mapping of saver arguments held by the factory, not a
    # saver instance, hence the name.
    saver_kwargs = vars(saver)["data_saver_kwargs"]

    assert saver_kwargs["path"].value == path
    assert saver_kwargs["flavor"].value == flavor
    assert saver_kwargs["mlflow_kwargs"].value.get("unknown_kwarg") is True


def test_mlflow_registered_model_metadata(fitted_sklearn_model: BaseEstimator, tmp_path: Path):
Expand Down

0 comments on commit f26ddfb

Please sign in to comment.