✨ Add hpo search using wandb #82

Merged · 8 commits · Apr 6, 2022

Changes from 1 commit
3 changes: 2 additions & 1 deletion anomalib/loggers/wandb.py
@@ -17,11 +17,12 @@
from typing import Any, List, Optional, Union

import numpy as np
import wandb
from matplotlib.figure import Figure
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.utilities import rank_zero_only

import wandb

from .base import ImageLoggerBase


13 changes: 13 additions & 0 deletions tools/hpo/sweep.yaml
@@ -0,0 +1,13 @@
observation_budget: 10
method: bayes
metric:
name: pixel_AUROC
goal: minimize
parameters:
dataset:
category: capsule
image_size:
values: [128, 256]
model:
backbone:
values: [resnet18, wide_resnet50_2]
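
As an aside (this note is not part of the diff): the file above follows the wandb sweep schema (method, metric, parameters), plus one extra key, observation_budget, which the WandbSweep class introduced later in this PR pops off and uses as the trial count for wandb.agent rather than passing it to wandb.sweep. A minimal sketch of that handling, assuming the file lives at tools/hpo/sweep.yaml:

from omegaconf import OmegaConf

# Illustrative only; mirrors what WandbSweep.__init__ and WandbSweep.run do further down in this PR.
sweep_config = OmegaConf.load("tools/hpo/sweep.yaml")  # path is an assumption
budget = sweep_config.pop("observation_budget")        # later used as wandb.agent(..., count=budget)
sweep_dict = OmegaConf.to_object(sweep_config)         # plain dict with method, metric, parameters
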
15 changes: 15 additions & 0 deletions tools/hpo/utils/__init__.py
@@ -0,0 +1,15 @@
"""Utils to help in HPO search."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
111 changes: 111 additions & 0 deletions tools/hpo/utils/config.py
@@ -0,0 +1,111 @@
"""Utils to update configuration files."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

import operator
from functools import reduce
from typing import Any, List, Union

from omegaconf import DictConfig, ListConfig


def flatten_sweep_params(params_dict: DictConfig) -> DictConfig:
"""Flatten the nested parameters section of the config object.

Args:
params_dict: DictConfig: The dictionary containing the hpo parameters in the original, nested, structure.

Returns:
flattened version of the parameter dictionary.
"""

def process_params(nested_params: DictConfig, keys: List[str], flattened_params: DictConfig):
"""Flatten nested dictionary.

Recursive helper function that traverses the nested config object and stores the leaf nodes in a flattened
dictionary.

Args:
nested_params: DictConfig: config object containing the original parameters.
keys: List[str]: list of keys leading to the current location in the config.
flattened_params: DictConfig: Dictionary in which the flattened parameters are stored.
"""
for name, cfg in nested_params.items():
if isinstance(cfg, DictConfig):
process_params(cfg, keys + [str(name)], flattened_params)
else:
key = ".".join(keys + [str(name)])
flattened_params[key] = cfg

flattened_params_dict = DictConfig({})
process_params(params_dict, [], flattened_params_dict)

return flattened_params_dict


def flatten_hpo_params(params_dict: DictConfig) -> DictConfig:
"""Flatten the nested hpo parameter section of the config object.

Args:
params_dict: DictConfig: The dictionary containing the hpo parameters in the original, nested, structure.

Returns:
flattened version of the parameter dictionary.
"""

def process_params(nested_params: DictConfig, keys: List[str], flattened_params: DictConfig):
"""Flatten nexted dictionary till the time it reaches the hpo params.

Recursive helper function that traverses the nested config object and stores the leaf nodes in a flattened
dictionary.

Args:
nested_params: DictConfig: config object containing the original parameters.
keys: List[str]: list of keys leading to the current location in the config.
flattened_params: DictConfig: Dictionary in which the flattened parameters are stored.
"""
if len({"values", "min", "max"}.intersection(nested_params.keys())) > 0:
key = ".".join(keys)
flattened_params[key] = nested_params
else:
for name, cfg in nested_params.items():
if isinstance(cfg, DictConfig):
process_params(cfg, keys + [str(name)], flattened_params)

flattened_params_dict = DictConfig({})
process_params(params_dict, [], flattened_params_dict)

return flattened_params_dict


def get_from_nested_config(config: Union[DictConfig, ListConfig], keymap: List) -> Any:
"""Retrieves an item from a nested config object using a list of keys.

Args:
config: DictConfig: nested DictConfig object
keymap: List[str]: list of keys corresponding to item that should be retrieved.
"""
return reduce(operator.getitem, keymap, config)


def set_in_nested_config(config: Union[DictConfig, ListConfig], keymap: List, value: Any):
"""Set an item in a nested config object using a list of keys.

Args:
config: DictConfig: nested DictConfig object
keymap: List[str]: list of keys corresponding to item that should be set.
value: Any: Value that should be assigned to the dictionary item at the specified location.
"""
get_from_nested_config(config, keymap[:-1])[keymap[-1]] = value
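
For illustration only (this snippet is not part of the PR), here is how the helpers above behave on the parameters section of tools/hpo/sweep.yaml; the relative import assumes the snippet is run from tools/hpo/, matching the "from utils.config import ..." style used by wandb_sweep.py below:

from omegaconf import OmegaConf
from utils.config import flatten_hpo_params, flatten_sweep_params, set_in_nested_config

params = OmegaConf.create(
    {
        "dataset": {"category": "capsule", "image_size": {"values": [128, 256]}},
        "model": {"backbone": {"values": ["resnet18", "wide_resnet50_2"]}},
    }
)

# flatten_sweep_params collapses every leaf to a dotted key:
#   {"dataset.category": "capsule",
#    "dataset.image_size.values": [128, 256],
#    "model.backbone.values": ["resnet18", "wide_resnet50_2"]}
flat_all = flatten_sweep_params(params)

# flatten_hpo_params stops at nodes that define a search space ("values", "min" or "max"),
# so only the tunable entries survive (plain scalars such as dataset.category are skipped):
#   {"dataset.image_size": {"values": [128, 256]},
#    "model.backbone": {"values": ["resnet18", "wide_resnet50_2"]}}
flat_hpo = flatten_hpo_params(params)

# set_in_nested_config writes a sampled value back into the nested model config:
model_config = OmegaConf.create({"model": {"backbone": "resnet18"}})
set_in_nested_config(model_config, ["model", "backbone"], "wide_resnet50_2")
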
95 changes: 95 additions & 0 deletions tools/hpo/wandb_sweep.py
@@ -0,0 +1,95 @@
"""Run wandb sweep."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from argparse import ArgumentParser
from pathlib import Path

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import WandbLogger
from utils.config import flatten_hpo_params, flatten_sweep_params, set_in_nested_config

import wandb
from anomalib.config import get_configurable_parameters, update_input_size_config
from anomalib.data import get_datamodule
from anomalib.models import get_model


class WandbSweep:
"""wandb sweep.

Args:
config (DictConfig): Original model configuration.
sweep_config (DictConfig): Sweep configuration.
"""

def __init__(self, config: DictConfig, sweep_config: DictConfig) -> None:
self.config = config
self.sweep_config = sweep_config
self.observation_budget = sweep_config.observation_budget
if "observation_budget" in self.sweep_config.keys():
self.sweep_config.pop("observation_budget")

def run(self):
"""Run the sweep."""
flattened_hpo_params = flatten_hpo_params(self.sweep_config.parameters)
self.sweep_config.parameters = flattened_hpo_params
sweep_id = wandb.sweep(
OmegaConf.to_object(self.sweep_config),
project=f"{self.config.model.name}_{self.config.dataset.name}",
)
wandb.agent(sweep_id, function=self.sweep, count=self.observation_budget)

def sweep(self):
"""Method to load the model, update config and call fit. The metrics are logged to ```wandb``` dashboard."""
wandb_logger = WandbLogger(config=flatten_sweep_params(self.sweep_config), log_model=False)
sweep_config = wandb_logger.experiment.config

for param in sweep_config.keys():
set_in_nested_config(self.config, param.split("."), sweep_config[param])
config = update_input_size_config(self.config)

model = get_model(config)
datamodule = get_datamodule(config)

# Disable saving checkpoints as all checkpoints from the sweep will get uploaded
config.trainer.checkpoint_callback = False

trainer = pl.Trainer(**config.trainer, logger=wandb_logger)
trainer.fit(model, datamodule=datamodule)


def get_args():
"""Gets parameters from commandline."""
parser = ArgumentParser()
parser.add_argument("--model", type=str, default="stfpm", help="Name of the algorithm to train/test")
parser.add_argument("--model_config_path", type=Path, required=False, help="Path to a model config file")
parser.add_argument("--sweep_config_path", type=Path, required=True, help="Path to sweep configuration")

return parser.parse_args()


if __name__ == "__main__":
args = get_args()
model_config = get_configurable_parameters(model_name=args.model, model_config_path=args.model_config_path)
hpo_config = OmegaConf.load(args.sweep_config_path)

if model_config.project.seed != 0:
seed_everything(model_config.project.seed)

sweep = WandbSweep(model_config, hpo_config)
sweep.run()
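
Usage sketch (illustrative, not from the PR): assuming anomalib and wandb are installed and wandb login has been run, the sweep can be launched from tools/hpo/ as python wandb_sweep.py --model stfpm --sweep_config_path sweep.yaml. A programmatic equivalent of that command, mirroring the __main__ block above:

# Programmatic equivalent of:  python wandb_sweep.py --model stfpm --sweep_config_path sweep.yaml
# (run from tools/hpo/; the model name, the relative path, and an installed anomalib are assumptions).
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from wandb_sweep import WandbSweep

from anomalib.config import get_configurable_parameters

model_config = get_configurable_parameters(model_name="stfpm", model_config_path=None)
hpo_config = OmegaConf.load("sweep.yaml")

if model_config.project.seed != 0:
    seed_everything(model_config.project.seed)

WandbSweep(model_config, hpo_config).run()
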