✨ Add hpo search using wandb #82

Merged · 8 commits · Apr 6, 2022
3 changes: 2 additions & 1 deletion anomalib/utils/sweep/__init__.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from .config import get_run_config, set_in_nested_config
from .config import flatten_sweep_params, get_run_config, set_in_nested_config
from .helpers import (
    get_meta_data,
    get_openvino_throughput,
@@ -29,4 +29,5 @@
"get_meta_data",
"get_openvino_throughput",
"get_torch_throughput",
"flatten_sweep_params",
]
13 changes: 13 additions & 0 deletions tools/hpo/sweep.yaml
@@ -0,0 +1,13 @@
observation_budget: 10
method: bayes
metric:
  name: pixel_AUROC
  goal: minimize
parameters:
  dataset:
    category: capsule
    image_size:
      values: [128, 256]
  model:
    backbone:
      values: [resnet18, wide_resnet50_2]
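
For orientation, a minimal sketch (not part of the diff) of the sweep specification that wandb_sweep.py below is expected to pass to wandb.sweep once observation_budget has been popped and the nested parameters have been flattened into dot-joined keys:

# Hypothetical, hand-written illustration of what wandb.sweep() receives from
# the run() method in tools/hpo/wandb_sweep.py (added below). Scalar entries
# such as "category: capsule" define no values/min/max key, so
# flatten_hpo_params leaves them out of the search space.
expected_sweep_spec = {
    "method": "bayes",
    "metric": {"name": "pixel_AUROC", "goal": "minimize"},
    "parameters": {
        "dataset.image_size": {"values": [128, 256]},
        "model.backbone": {"values": ["resnet18", "wide_resnet50_2"]},
    },
}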
19 changes: 19 additions & 0 deletions tools/hpo/utils/__init__.py
@@ -0,0 +1,19 @@
"""Utils to help in HPO search."""

# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from .config import flatten_hpo_params

__all__ = ["flatten_hpo_params"]
54 changes: 54 additions & 0 deletions tools/hpo/utils/config.py
@@ -0,0 +1,54 @@
"""Utils to update configuration files."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import List

from omegaconf import DictConfig


def flatten_hpo_params(params_dict: DictConfig) -> DictConfig:
"""Flatten the nested hpo parameter section of the config object.

Args:
params_dict: DictConfig: The dictionary containing the hpo parameters in the original, nested, structure.

Returns:
flattened version of the parameter dictionary.
"""

def process_params(nested_params: DictConfig, keys: List[str], flattened_params: DictConfig):
"""Flatten nested dictionary till the time it reaches the hpo params.

Recursive helper function that traverses the nested config object and stores the leaf nodes in a flattened
dictionary.

Args:
nested_params: DictConfig: config object containing the original parameters.
keys: List[str]: list of keys leading to the current location in the config.
flattened_params: DictConfig: Dictionary in which the flattened parameters are stored.
"""
if len({"values", "min", "max"}.intersection(nested_params.keys())) > 0:
key = ".".join(keys)
flattened_params[key] = nested_params
else:
for name, cfg in nested_params.items():
if isinstance(cfg, DictConfig):
process_params(cfg, keys + [str(name)], flattened_params)

flattened_params_dict = DictConfig({})
process_params(params_dict, [], flattened_params_dict)

return flattened_params_dict
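
A short usage sketch of flatten_hpo_params on the parameter section from tools/hpo/sweep.yaml above; the bare "from utils import" mirrors how wandb_sweep.py imports the helper when the script's directory is on sys.path, and the printed keys are the expected result rather than captured output:

# Sketch: flatten_hpo_params stops at any node that defines values/min/max
# and joins the keys on the way down with dots.
from omegaconf import OmegaConf

from utils import flatten_hpo_params  # same import style as wandb_sweep.py (run from tools/hpo)

nested = OmegaConf.create(
    {
        "dataset": {"image_size": {"values": [128, 256]}},
        "model": {"backbone": {"values": ["resnet18", "wide_resnet50_2"]}},
    }
)

flat = flatten_hpo_params(nested)
print(list(flat.keys()))  # expected: ['dataset.image_size', 'model.backbone']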
96 changes: 96 additions & 0 deletions tools/hpo/wandb_sweep.py
@@ -0,0 +1,96 @@
"""Run wandb sweep."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from argparse import ArgumentParser
from pathlib import Path

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import WandbLogger
from utils import flatten_hpo_params

import wandb
from anomalib.config import get_configurable_parameters, update_input_size_config
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.utils.sweep import flatten_sweep_params, set_in_nested_config


class WandbSweep:
"""wandb sweep.

Args:
config (DictConfig): Original model configuration.
sweep_config (DictConfig): Sweep configuration.
"""

def __init__(self, config: DictConfig, sweep_config: DictConfig) -> None:
self.config = config
self.sweep_config = sweep_config
self.observation_budget = sweep_config.observation_budget
if "observation_budget" in self.sweep_config.keys():
self.sweep_config.pop("observation_budget")

    def run(self):
        """Run the sweep."""
        flattened_hpo_params = flatten_hpo_params(self.sweep_config.parameters)
        self.sweep_config.parameters = flattened_hpo_params
        sweep_id = wandb.sweep(
            OmegaConf.to_object(self.sweep_config),
            project=f"{self.config.model.name}_{self.config.dataset.name}",
        )
        wandb.agent(sweep_id, function=self.sweep, count=self.observation_budget)

    def sweep(self):
        """Load the model, update the config and call fit. The metrics are logged to the ``wandb`` dashboard."""
        wandb_logger = WandbLogger(config=flatten_sweep_params(self.sweep_config), log_model=False)
        sweep_config = wandb_logger.experiment.config

        for param in sweep_config.keys():
            set_in_nested_config(self.config, param.split("."), sweep_config[param])
        config = update_input_size_config(self.config)

        model = get_model(config)
        datamodule = get_datamodule(config)

        # Disable saving checkpoints as all checkpoints from the sweep will get uploaded
        config.trainer.checkpoint_callback = False

        trainer = pl.Trainer(**config.trainer, logger=wandb_logger)
        trainer.fit(model, datamodule=datamodule)


def get_args():
"""Gets parameters from commandline."""
parser = ArgumentParser()
parser.add_argument("--model", type=str, default="padim", help="Name of the algorithm to train/test")
parser.add_argument("--model_config", type=Path, required=False, help="Path to a model config file")
parser.add_argument("--sweep_config", type=Path, required=True, help="Path to sweep configuration")

return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    model_config = get_configurable_parameters(model_name=args.model, model_config_path=args.model_config)
    hpo_config = OmegaConf.load(args.sweep_config)

    if model_config.project.seed != 0:
        seed_everything(model_config.project.seed)

    sweep = WandbSweep(model_config, hpo_config)
    sweep.run()
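
To make the sweep() callback concrete, a small sketch of how the flat, dot-joined parameter names that wandb hands back are written onto the nested model config with set_in_nested_config; the config literal and sampled values are made up for illustration:

# Sketch: mirrors the loop in WandbSweep.sweep(). wandb reports the chosen
# hyperparameters under flat dotted keys; set_in_nested_config walks the
# dotted path and overwrites the corresponding leaf of the model config.
from omegaconf import OmegaConf

from anomalib.utils.sweep import set_in_nested_config

config = OmegaConf.create({"model": {"backbone": "resnet18"}, "dataset": {"image_size": 256}})

# Hypothetical values as wandb would report them for one trial.
sampled = {"model.backbone": "wide_resnet50_2", "dataset.image_size": 128}

for param, value in sampled.items():
    set_in_nested_config(config, param.split("."), value)

print(config.model.backbone)      # wide_resnet50_2
print(config.dataset.image_size)  # 128

In normal use the script itself would be launched as, for example, python tools/hpo/wandb_sweep.py --model padim --sweep_config tools/hpo/sweep.yaml, with padim simply being the default model name from get_args().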