🐞 Fix comet HPO #597

Merged 2 commits on Sep 29, 2022
@@ -3,6 +3,6 @@
 # Copyright (C) 2022 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-from .config import flatten_hpo_params
+from .runners import CometSweep, WandbSweep
 
-__all__ = ["flatten_hpo_params"]
+__all__ = ["CometSweep", "WandbSweep"]
File renamed without changes.
132 changes: 132 additions & 0 deletions anomalib/utils/hpo/runners.py
@@ -0,0 +1,132 @@
"""Sweep Backends."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Union

import pytorch_lightning as pl
from comet_ml import Optimizer
from omegaconf import DictConfig, ListConfig, OmegaConf
from pytorch_lightning.loggers import CometLogger, WandbLogger

import wandb
from anomalib.config import update_input_size_config
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.utils.sweep import (
flatten_sweep_params,
get_sweep_callbacks,
set_in_nested_config,
)

from .config import flatten_hpo_params


class WandbSweep:
"""wandb sweep.

Args:
config (DictConfig): Original model configuration.
sweep_config (DictConfig): Sweep configuration.
entity (str, optional): Username or workspace to send the project to. Defaults to None.
"""

def __init__(
self,
config: Union[DictConfig, ListConfig],
sweep_config: Union[DictConfig, ListConfig],
entity: Optional[str] = None,
) -> None:
self.config = config
self.sweep_config = sweep_config
self.observation_budget = sweep_config.observation_budget
self.entity = entity
if "observation_budget" in self.sweep_config.keys():
# this instance check is to silence mypy.
if isinstance(self.sweep_config, DictConfig):
self.sweep_config.pop("observation_budget")

def run(self):
"""Run the sweep."""
flattened_hpo_params = flatten_hpo_params(self.sweep_config.parameters)
self.sweep_config.parameters = flattened_hpo_params
sweep_id = wandb.sweep(
OmegaConf.to_object(self.sweep_config),
project=f"{self.config.model.name}_{self.config.dataset.name}",
entity=self.entity,
)
wandb.agent(sweep_id, function=self.sweep, count=self.observation_budget)

def sweep(self):
"""Method to load the model, update config and call fit. The metrics are logged to ```wandb``` dashboard."""
wandb_logger = WandbLogger(config=flatten_sweep_params(self.sweep_config), log_model=False)
sweep_config = wandb_logger.experiment.config

for param in sweep_config.keys():
set_in_nested_config(self.config, param.split("."), sweep_config[param])
config = update_input_size_config(self.config)

model = get_model(config)
datamodule = get_datamodule(config)
callbacks = get_sweep_callbacks(config)

# Disable saving checkpoints as all checkpoints from the sweep will get uploaded
config.trainer.checkpoint_callback = False

trainer = pl.Trainer(**config.trainer, logger=wandb_logger, callbacks=callbacks)
trainer.fit(model, datamodule=datamodule)


class CometSweep:
"""comet sweep.

Args:
config (DictConfig): Original model configuration.
sweep_config (DictConfig): Sweep configuration.
entity (str, optional): Username or workspace to send the project to. Defaults to None.
"""

def __init__(
self,
config: Union[DictConfig, ListConfig],
sweep_config: Union[DictConfig, ListConfig],
entity: Optional[str] = None,
) -> None:
self.config = config
self.sweep_config = sweep_config
self.entity = entity

def run(self):
"""Run the sweep."""
flattened_hpo_params = flatten_hpo_params(self.sweep_config.parameters)
self.sweep_config.parameters = flattened_hpo_params

# comet's Optimizer takes dict as an input, not DictConfig
std_dict = OmegaConf.to_object(self.sweep_config)

opt = Optimizer(std_dict)

project_name = f"{self.config.model.name}_{self.config.dataset.name}"

for experiment in opt.get_experiments(project_name=project_name):
comet_logger = CometLogger(workspace=self.entity)

# allow pytorch-lightning to use the experiment from optimizer
comet_logger._experiment = experiment # pylint: disable=protected-access
run_params = experiment.params
for param in run_params.keys():
# this check is needed as comet also returns model and sweep_config as keys
if param in self.sweep_config.parameters.keys():
set_in_nested_config(self.config, param.split("."), run_params[param])
config = update_input_size_config(self.config)

model = get_model(config)
datamodule = get_datamodule(config)
callbacks = get_sweep_callbacks(config)

# Disable saving checkpoints as all checkpoints from the sweep will get uploaded
config.trainer.checkpoint_callback = False

trainer = pl.Trainer(**config.trainer, logger=comet_logger, callbacks=callbacks)
trainer.fit(model, datamodule=datamodule)
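For reference, a rough sketch of the two sweep-config shapes these runners consume. The parameter names and values are placeholders; only `observation_budget` and the `spec`/`parameters` layout are taken from the code above, while the remaining keys follow the public comet Optimizer and wandb sweep schemas.

```python
from omegaconf import OmegaConf

# Comet-style sweep config (sketch). Optimizer() receives it as a plain dict
# via OmegaConf.to_object(); "parameters" is flattened by flatten_hpo_params.
comet_sweep = OmegaConf.create(
    {
        "algorithm": "bayes",
        "spec": {"maxCombo": 10, "metric": "image_AUROC", "objective": "maximize"},
        "parameters": {"model": {"lr": {"type": "float", "min": 1e-4, "max": 1e-2}}},
    }
)

# wandb-style sweep config (sketch). "observation_budget" is read and popped in
# WandbSweep.__init__ and reused as the agent's run count.
wandb_sweep = OmegaConf.create(
    {
        "observation_budget": 10,
        "method": "bayes",
        "metric": {"name": "image_AUROC", "goal": "maximize"},
        "parameters": {"model": {"lr": {"min": 1e-4, "max": 1e-2}}},
    }
)
```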
12 changes: 7 additions & 5 deletions anomalib/utils/sweep/helpers/callbacks.py
@@ -33,11 +33,13 @@ def get_sweep_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]
         config.metrics.threshold.pixel_default if "pixel_default" in config.metrics.threshold.keys() else None
     )
     metrics_callback = MetricsConfigurationCallback(
-        config.metrics.threshold.adaptive,
-        image_threshold,
-        pixel_threshold,
-        image_metric_names,
-        pixel_metric_names,
+        adaptive_threshold=config.metrics.threshold.adaptive,
+        task=config.dataset.task,
+        default_image_threshold=image_threshold,
+        default_pixel_threshold=pixel_threshold,
+        image_metric_names=image_metric_names,
+        pixel_metric_names=pixel_metric_names,
+        normalization_method=config.model.normalization_method,
     )
     callbacks.append(metrics_callback)
 
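The switch from positional to keyword arguments also threads the task type and normalization method from the model config into the callback. A sketch of the config fields this call reads, with hypothetical values; field names outside the visible hunk (the metric-name lists) are assumptions:

```python
from omegaconf import OmegaConf

# Hypothetical excerpt of a model config covering the fields referenced above.
config = OmegaConf.create(
    {
        "dataset": {"task": "segmentation"},
        "model": {"normalization_method": "min_max"},
        "metrics": {
            "image": ["AUROC"],  # assumed source of image_metric_names
            "pixel": ["AUROC"],  # assumed source of pixel_metric_names
            "threshold": {"adaptive": True, "image_default": 3.0, "pixel_default": 3.0},
        },
    }
)
```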
119 changes: 11 additions & 108 deletions tools/hpo/sweep.py
@@ -7,114 +7,11 @@
 from pathlib import Path
 from typing import Union
 
-import pytorch_lightning as pl
-from comet_ml import Optimizer
-from omegaconf import DictConfig, ListConfig, OmegaConf
+from omegaconf import OmegaConf
 from pytorch_lightning import seed_everything
-from pytorch_lightning.loggers import CometLogger, WandbLogger
-from utils import flatten_hpo_params
 
-import wandb
-from anomalib.config import get_configurable_parameters, update_input_size_config
-from anomalib.data import get_datamodule
-from anomalib.models import get_model
-from anomalib.utils.sweep import (
-    flatten_sweep_params,
-    get_sweep_callbacks,
-    set_in_nested_config,
-)
-
-
-class WandbSweep:
-    """wandb sweep.
-
-    Args:
-        config (DictConfig): Original model configuration.
-        sweep_config (DictConfig): Sweep configuration.
-    """
-
-    def __init__(self, config: Union[DictConfig, ListConfig], sweep_config: Union[DictConfig, ListConfig]) -> None:
-        self.config = config
-        self.sweep_config = sweep_config
-        self.observation_budget = sweep_config.observation_budget
-        if "observation_budget" in self.sweep_config.keys():
-            # this instance check is to silence mypy.
-            if isinstance(self.sweep_config, DictConfig):
-                self.sweep_config.pop("observation_budget")
-
-    def run(self):
-        """Run the sweep."""
-        flattened_hpo_params = flatten_hpo_params(self.sweep_config.parameters)
-        self.sweep_config.parameters = flattened_hpo_params
-        sweep_id = wandb.sweep(
-            OmegaConf.to_object(self.sweep_config),
-            project=f"{self.config.model.name}_{self.config.dataset.name}",
-        )
-        wandb.agent(sweep_id, function=self.sweep, count=self.observation_budget)
-
-    def sweep(self):
-        """Method to load the model, update config and call fit. The metrics are logged to ```wandb``` dashboard."""
-        wandb_logger = WandbLogger(config=flatten_sweep_params(self.sweep_config), log_model=False)
-        sweep_config = wandb_logger.experiment.config
-
-        for param in sweep_config.keys():
-            set_in_nested_config(self.config, param.split("."), sweep_config[param])
-        config = update_input_size_config(self.config)
-
-        model = get_model(config)
-        datamodule = get_datamodule(config)
-        callbacks = get_sweep_callbacks(config)
-
-        # Disable saving checkpoints as all checkpoints from the sweep will get uploaded
-        config.trainer.checkpoint_callback = False
-
-        trainer = pl.Trainer(**config.trainer, logger=wandb_logger, callbacks=callbacks)
-        trainer.fit(model, datamodule=datamodule)
-
-
-class CometSweep:
-    """comet sweep.
-
-    Args:
-        config (DictConfig): Original model configuration.
-        sweep_config (DictConfig): Sweep configuration.
-    """
-
-    def __init__(self, config: Union[DictConfig, ListConfig], sweep_config: Union[DictConfig, ListConfig]) -> None:
-        self.config = config
-        self.sweep_config = sweep_config
-
-    def run(self):
-        """Run the sweep."""
-        flattened_hpo_params = flatten_hpo_params(self.sweep_config.parameters)
-        self.sweep_config.parameters = flattened_hpo_params
-
-        # comet's Optmizer cannot takes dict as an input, not DictConfig
-        std_dict = OmegaConf.to_object(self.sweep_config)
-
-        opt = Optimizer(std_dict)
-
-        project_name = f"{self.config.model.name}_{self.config.dataset.name}"
-
-        for exp in opt.get_experiments(project_name=project_name):
-            comet_logger = CometLogger()
-
-            # allow pytorch-lightning to use the experiment from optimizer
-            comet_logger._experiment = exp  # pylint: disable=protected-access
-            run_params = exp.params
-            for param in run_params.keys():
-                set_in_nested_config(self.config, param.split("."), run_params[param])
-            config = update_input_size_config(self.config)
-
-            model = get_model(config)
-            datamodule = get_datamodule(config)
-            callbacks = get_sweep_callbacks(config)
-
-            # Disable saving checkpoints as all checkpoints from the sweep will get uploaded
-            config.trainer.checkpoint_callback = False
-
-            trainer = pl.Trainer(**config.trainer, logger=comet_logger, callbacks=callbacks)
-            trainer.fit(model, datamodule=datamodule)
+from anomalib.config import get_configurable_parameters
+from anomalib.utils.hpo import CometSweep, WandbSweep
 
 
 def get_args():
@@ -123,6 +20,12 @@ def get_args():
     parser.add_argument("--model", type=str, default="padim", help="Name of the algorithm to train/test")
     parser.add_argument("--model_config", type=Path, required=False, help="Path to a model config file")
    parser.add_argument("--sweep_config", type=Path, required=True, help="Path to sweep configuration")
+    parser.add_argument(
+        "--entity",
+        type=str,
+        required=False,
+        help="Username or workspace where you want to send your runs to. If not set, the default workspace is used.",
+    )
 
     return parser.parse_args()

@@ -138,7 +41,7 @@
     # check hpo config structure to see whether it adheres to comet or wandb format
     sweep: Union[CometSweep, WandbSweep]
     if "spec" in hpo_config.keys():
-        sweep = CometSweep(model_config, hpo_config)
+        sweep = CometSweep(model_config, hpo_config, entity=args.entity)
     else:
-        sweep = WandbSweep(model_config, hpo_config)
+        sweep = WandbSweep(model_config, hpo_config, entity=args.entity)
     sweep.run()
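With the new `--entity` flag, a sweep can be launched against a specific comet or wandb workspace. A hedged invocation example; the config paths and workspace name are placeholders, not values from this PR:

```bash
python tools/hpo/sweep.py --model padim \
    --model_config path/to/model_config.yaml \
    --sweep_config path/to/sweep.yaml \
    --entity my-workspace
```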