Skip to content

Commit

Permalink
✨ Add hpo search using wandb (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
ashwinvaidya17 committed Apr 6, 2022
1 parent fd692ea commit aa08a99
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 1 deletion.
3 changes: 2 additions & 1 deletion anomalib/utils/sweep/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from .config import get_run_config, set_in_nested_config
from .config import flatten_sweep_params, get_run_config, set_in_nested_config
from .helpers import (
get_meta_data,
get_openvino_throughput,
Expand All @@ -29,4 +29,5 @@
"get_meta_data",
"get_openvino_throughput",
"get_torch_throughput",
"flatten_sweep_params",
]
13 changes: 13 additions & 0 deletions tools/hpo/sweep.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
observation_budget: 10
method: bayes
metric:
name: pixel_AUROC
goal: minimize
parameters:
dataset:
category: capsule
image_size:
values: [128, 256]
model:
backbone:
values: [resnet18, wide_resnet50_2]
19 changes: 19 additions & 0 deletions tools/hpo/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Utils to help in HPO search."""

# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from .config import flatten_hpo_params

__all__ = ["flatten_hpo_params"]
54 changes: 54 additions & 0 deletions tools/hpo/utils/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Utils to update configuration files."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import List

from omegaconf import DictConfig


def flatten_hpo_params(params_dict: DictConfig) -> DictConfig:
"""Flatten the nested hpo parameter section of the config object.
Args:
params_dict: DictConfig: The dictionary containing the hpo parameters in the original, nested, structure.
Returns:
flattened version of the parameter dictionary.
"""

def process_params(nested_params: DictConfig, keys: List[str], flattened_params: DictConfig):
"""Flatten nested dictionary till the time it reaches the hpo params.
Recursive helper function that traverses the nested config object and stores the leaf nodes in a flattened
dictionary.
Args:
nested_params: DictConfig: config object containing the original parameters.
keys: List[str]: list of keys leading to the current location in the config.
flattened_params: DictConfig: Dictionary in which the flattened parameters are stored.
"""
if len({"values", "min", "max"}.intersection(nested_params.keys())) > 0:
key = ".".join(keys)
flattened_params[key] = nested_params
else:
for name, cfg in nested_params.items():
if isinstance(cfg, DictConfig):
process_params(cfg, keys + [str(name)], flattened_params)

flattened_params_dict = DictConfig({})
process_params(params_dict, [], flattened_params_dict)

return flattened_params_dict
96 changes: 96 additions & 0 deletions tools/hpo/wandb_sweep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Run wandb sweep."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from argparse import ArgumentParser
from pathlib import Path

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import WandbLogger
from utils import flatten_hpo_params

import wandb
from anomalib.config import get_configurable_parameters, update_input_size_config
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.utils.sweep import flatten_sweep_params, set_in_nested_config


class WandbSweep:
"""wandb sweep.
Args:
config (DictConfig): Original model configuration.
sweep_config (DictConfig): Sweep configuration.
"""

def __init__(self, config: DictConfig, sweep_config: DictConfig) -> None:
self.config = config
self.sweep_config = sweep_config
self.observation_budget = sweep_config.observation_budget
if "observation_budget" in self.sweep_config.keys():
self.sweep_config.pop("observation_budget")

def run(self):
"""Run the sweep."""
flattened_hpo_params = flatten_hpo_params(self.sweep_config.parameters)
self.sweep_config.parameters = flattened_hpo_params
sweep_id = wandb.sweep(
OmegaConf.to_object(self.sweep_config),
project=f"{self.config.model.name}_{self.config.dataset.name}",
)
wandb.agent(sweep_id, function=self.sweep, count=self.observation_budget)

def sweep(self):
"""Method to load the model, update config and call fit. The metrics are logged to ```wandb``` dashboard."""
wandb_logger = WandbLogger(config=flatten_sweep_params(self.sweep_config), log_model=False)
sweep_config = wandb_logger.experiment.config

for param in sweep_config.keys():
set_in_nested_config(self.config, param.split("."), sweep_config[param])
config = update_input_size_config(self.config)

model = get_model(config)
datamodule = get_datamodule(config)

# Disable saving checkpoints as all checkpoints from the sweep will get uploaded
config.trainer.checkpoint_callback = False

trainer = pl.Trainer(**config.trainer, logger=wandb_logger)
trainer.fit(model, datamodule=datamodule)


def get_args():
"""Gets parameters from commandline."""
parser = ArgumentParser()
parser.add_argument("--model", type=str, default="padim", help="Name of the algorithm to train/test")
parser.add_argument("--model_config", type=Path, required=False, help="Path to a model config file")
parser.add_argument("--sweep_config", type=Path, required=True, help="Path to sweep configuration")

return parser.parse_args()


if __name__ == "__main__":
args = get_args()
model_config = get_configurable_parameters(model_name=args.model, model_config_path=args.model_config)
hpo_config = OmegaConf.load(args.sweep_config)

if model_config.project.seed != 0:
seed_everything(model_config.project.seed)

sweep = WandbSweep(model_config, hpo_config)
sweep.run()

0 comments on commit aa08a99

Please sign in to comment.