Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🏷 Rename --model_config_path to config #246

Merged
merged 5 commits into from
Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,18 @@ file, [`config.yaml`](https://gitlab-icv.inn.intel.com/algo_rnd_team/anomaly/-/b
category, the config file is to be provided:

```bash
python tools/train.py --model_config_path <path/to/model/config.yaml>
python tools/train.py --config <path/to/model/config.yaml>
```

For example, to train [PADIM](anomalib/models/padim) you can use

```bash
python tools/train.py --model_config_path anomalib/models/padim/config.yaml
python tools/train.py --config anomalib/models/padim/config.yaml
```

Note that `--model_config_path` will be deprecated in `v0.2.8` and removed
in `v0.2.9`.

Alternatively, a model name could also be provided as an argument, where the scripts automatically finds the corresponding config file.

```bash
Expand Down Expand Up @@ -138,7 +141,7 @@ The following command can be used to run inference from the command line:

```bash
python tools/inference.py \
--model_config_path <path/to/model/config.yaml> \
--config <path/to/model/config.yaml> \
--weight_path <path/to/weight/file> \
--image_path <path/to/image>
```
Expand All @@ -147,7 +150,7 @@ As a quick example:

```bash
python tools/inference.py \
--model_config_path anomalib/models/padim/config.yaml \
--config anomalib/models/padim/config.yaml \
--weight_path results/padim/mvtec/bottle/weights/model.ckpt \
--image_path datasets/MVTec/bottle/test/broken_large/000.png
```
Expand All @@ -164,7 +167,7 @@ Example OpenVINO Inference:

```bash
python tools/inference.py \
--model_config_path \
--config \
anomalib/models/padim/config.yaml \
--weight_path \
results/padim/mvtec/bottle/compressed/compressed_model.xml \
Expand Down Expand Up @@ -200,7 +203,7 @@ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License
| DFM | ResNet-18 | 0.894 | 0.864 | 0.558 | 0.945 | 0.984 | 0.946 | 0.994 | 0.913 | 0.871 | 0.979 | 0.941 | 0.838 | 0.761 | 0.95 | 0.911 | 0.949 |
| DFKDE | Wide ResNet-50 | 0.774 | 0.708 | 0.422 | 0.905 | 0.959 | 0.903 | 0.936 | 0.746 | 0.853 | 0.736 | 0.687 | 0.749 | 0.574 | 0.697 | 0.843 | 0.892 |
| DFKDE | ResNet-18 | 0.762 | 0.646 | 0.577 | 0.669 | 0.965 | 0.863 | 0.951 | 0.751 | 0.698 | 0.806 | 0.729 | 0.607 | 0.694 | 0.767 | 0.839 | 0.866 |
| GANomaly | | 0.421 | 0.203 | 0.404 | 0.413 | 0.408 | 0.744 | 0.251 | 0.457 | 0.682 | 0.537 | 0.270 | 0.472 | 0.231 | 0.372 | 0.440 | 0.434 |
| GANomaly | | 0.421 | 0.203 | 0.404 | 0.413 | 0.408 | 0.744 | 0.251 | 0.457 | 0.682 | 0.537 | 0.270 | 0.472 | 0.231 | 0.372 | 0.440 | 0.434 |

### Pixel-Level AUC

Expand Down Expand Up @@ -229,7 +232,7 @@ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License
| DFM | ResNet-18 | 0.919 | 0.895 | 0.844 | 0.926 | 0.971 | 0.948 | 0.977 | 0.874 | 0.935 | 0.957 | 0.958 | 0.921 | 0.874 | 0.933 | 0.833 | 0.943 |
| DFKDE | Wide ResNet-50 | 0.875 | 0.907 | 0.844 | 0.905 | 0.945 | 0.914 | 0.946 | 0.790 | 0.914 | 0.817 | 0.894 | 0.922 | 0.855 | 0.845 | 0.722 | 0.910 |
| DFKDE | ResNet-18 | 0.872 | 0.864 | 0.844 | 0.854 | 0.960 | 0.898 | 0.942 | 0.793 | 0.908 | 0.827 | 0.894 | 0.916 | 0.859 | 0.853 | 0.756 | 0.916 |
| GANomaly | | 0.834 | 0.864 | 0.844 | 0.852 | 0.836 | 0.863 | 0.863 | 0.760 | 0.905 | 0.777 | 0.894 | 0.916 | 0.853 | 0.833 | 0.571 | 0.881 |
| GANomaly | | 0.834 | 0.864 | 0.844 | 0.852 | 0.836 | 0.863 | 0.863 | 0.760 | 0.905 | 0.777 | 0.894 | 0.916 | 0.853 | 0.833 | 0.571 | 0.881 |

## Reference
If you use this library and love it, use this to cite it 🤗
Expand Down
12 changes: 6 additions & 6 deletions anomalib/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def update_multi_gpu_training_config(config: Union[DictConfig, ListConfig]) -> U

def get_configurable_parameters(
model_name: Optional[str] = None,
model_config_path: Optional[Union[Path, str]] = None,
config_path: Optional[Union[Path, str]] = None,
weight_file: Optional[str] = None,
config_filename: Optional[str] = "config",
config_file_extension: Optional[str] = "yaml",
Expand All @@ -122,24 +122,24 @@ def get_configurable_parameters(

Args:
model_name: Optional[str]: (Default value = None)
model_config_path: Optional[Union[Path, str]]: (Default value = None)
config_path: Optional[Union[Path, str]]: (Default value = None)
weight_file: Path to the weight file
config_filename: Optional[str]: (Default value = "config")
config_file_extension: Optional[str]: (Default value = "yaml")

Returns:
Union[DictConfig, ListConfig]: Configurable parameters in DictConfig object.
"""
if model_name is None and model_config_path is None:
if model_name is None and config_path is None:
raise ValueError(
"Both model_name and model config path cannot be None! "
"Please provide a model name or path to a config file!"
)

if model_config_path is None:
model_config_path = Path(f"anomalib/models/{model_name}/{config_filename}.{config_file_extension}")
if config_path is None:
config_path = Path(f"anomalib/models/{model_name}/{config_filename}.{config_file_extension}")

config = OmegaConf.load(model_config_path)
config = OmegaConf.load(config_path)

# Dataset Configs
if "format" not in config.dataset.keys():
Expand Down
8 changes: 5 additions & 3 deletions anomalib/deploy/inferencers/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import importlib
from importlib.util import find_spec
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

Expand All @@ -26,8 +26,10 @@

from .base import Inferencer

if importlib.util.find_spec("openvino") is not None:
from openvino.inference_engine import IECore # pylint: disable=no-name-in-module
if find_spec("openvino") is not None:
from openvino.inference_engine import ( # type: ignore # pylint: disable=no-name-in-module
IECore,
)


class OpenVINOInferencer(Inferencer):
Expand Down
5 changes: 4 additions & 1 deletion docs/source/guides/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ file is to be provided:

::

python tools/train.py --model_config_path <path/to/model/config.yaml>
python tools/train.py --config <path/to/model/config.yaml>

Note that `--model_config_path` will be deprecated in `v0.2.8` and removed
in `v0.2.9`.

Alternatively, a model name could also be provided as an argument, where
the scripts automatically finds the corresponding config file.
Expand Down
8 changes: 3 additions & 5 deletions tests/helpers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
def get_test_configurable_parameters(
dataset_path: Optional[str] = None,
model_name: Optional[str] = None,
model_config_path: Optional[Union[Path, str]] = None,
config_path: Optional[Union[Path, str]] = None,
weight_file: Optional[str] = None,
config_filename: Optional[str] = "config",
config_file_extension: Optional[str] = "yaml",
Expand All @@ -21,7 +21,7 @@ def get_test_configurable_parameters(
Args:
datset_path: Optional[Path]: Path to dataset
model_name: Optional[str]: (Default value = None)
model_config_path: Optional[Union[Path, str]]: (Default value = None)
config_path: Optional[Union[Path, str]]: (Default value = None)
weight_file: Path to the weight file
config_filename: Optional[str]: (Default value = "config")
config_file_extension: Optional[str]: (Default value = "yaml")
Expand All @@ -30,9 +30,7 @@ def get_test_configurable_parameters(
Union[DictConfig, ListConfig]: Configurable parameters in DictConfig object.
"""

config = get_configurable_parameters(
model_name, model_config_path, weight_file, config_filename, config_file_extension
)
config = get_configurable_parameters(model_name, config_path, weight_file, config_filename, config_file_extension)

# Update path to match the dataset path in the test image/runner
config.dataset.path = get_dataset_path() if dataset_path is None else dataset_path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def run_train_test(config):

@TestDataset(num_train=200, num_test=30, path=get_dataset_path(), seed=42)
def test_normalizer(path=get_dataset_path(), category="shapes"):
config = get_configurable_parameters(model_config_path="anomalib/models/padim/config.yaml")
config = get_configurable_parameters(config_path="anomalib/models/padim/config.yaml")
config.dataset.path = path
config.dataset.category = category
config.model.threshold.adaptive = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_openvino_model_callback():
"""Tests if an optimized model is created."""

config = get_test_configurable_parameters(
model_config_path="tests/pre_merge/utils/callbacks/openvino_callback/dummy_config.yml"
config_path="tests/pre_merge/utils/callbacks/openvino_callback/dummy_config.yml"
)

with tempfile.TemporaryDirectory() as tmp_dir:
Expand Down
2 changes: 1 addition & 1 deletion tests/pre_merge/utils/metrics/test_adaptive_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_non_adaptive_threshold():
Test if the non-adaptive threshold gets used in the F1 score computation when
adaptive thresholding is disabled and no normalization is used.
"""
config = get_test_configurable_parameters(model_config_path="anomalib/models/padim/config.yaml")
config = get_test_configurable_parameters(config_path="anomalib/models/padim/config.yaml")

config.model.normalization_method = "none"
config.model.threshold.adaptive = False
Expand Down
2 changes: 1 addition & 1 deletion tools/benchmarking/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Note: To collect memory read/write numbers, run the script with sudo privileges.
```
sudo -E ./train.sh # Train STFPM on MVTec AD leather

sudo -E ./train.sh --model_config_path <path/to/model/config.yaml>
sudo -E ./train.sh --config <path/to/model/config.yaml>

sudo -E ./train.sh --model stfpm
```
Expand Down
11 changes: 7 additions & 4 deletions tools/hpo/wandb_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

from argparse import ArgumentParser
from pathlib import Path
from typing import Union

import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf
from omegaconf import DictConfig, ListConfig, OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import WandbLogger
from utils import flatten_hpo_params
Expand All @@ -38,12 +39,14 @@ class WandbSweep:
sweep_config (DictConfig): Sweep configuration.
"""

def __init__(self, config: DictConfig, sweep_config: DictConfig) -> None:
def __init__(self, config: Union[DictConfig, ListConfig], sweep_config: Union[DictConfig, ListConfig]) -> None:
self.config = config
self.sweep_config = sweep_config
self.observation_budget = sweep_config.observation_budget
if "observation_budget" in self.sweep_config.keys():
self.sweep_config.pop("observation_budget")
# this instance check is to silence mypy.
if isinstance(self.sweep_config, DictConfig):
self.sweep_config.pop("observation_budget")

def run(self):
"""Run the sweep."""
Expand Down Expand Up @@ -86,7 +89,7 @@ def get_args():

if __name__ == "__main__":
args = get_args()
model_config = get_configurable_parameters(model_name=args.model, model_config_path=args.model_config)
model_config = get_configurable_parameters(model_name=args.model, config_path=args.model_config)
hpo_config = OmegaConf.load(args.sweep_config)

if model_config.project.seed != 0:
Expand Down
4 changes: 2 additions & 2 deletions tools/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_args() -> Namespace:
Namespace: List of arguments.
"""
parser = ArgumentParser()
parser.add_argument("--model_config_path", type=Path, required=True, help="Path to a model config file")
parser.add_argument("--config", type=Path, required=True, help="Path to a model config file")
parser.add_argument("--weight_path", type=Path, required=True, help="Path to a model weights")
parser.add_argument("--image_path", type=Path, required=True, help="Path to an image to infer.")
parser.add_argument("--save_path", type=Path, required=False, help="Path to save the output image.")
Expand Down Expand Up @@ -75,7 +75,7 @@ def stream() -> None:
# This config file is also used for training and contains all the relevant
# information regarding the data, model, train and inference details.
args = get_args()
config = get_configurable_parameters(model_config_path=args.model_config_path)
config = get_configurable_parameters(config_path=args.config)

# Get the inferencer. We use .ckpt extension for Torch models and (onnx, bin)
# for the openvino models.
Expand Down
18 changes: 14 additions & 4 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import warnings
from argparse import ArgumentParser, Namespace

from pytorch_lightning import Trainer
Expand All @@ -32,11 +33,21 @@ def get_args() -> Namespace:
"""
parser = ArgumentParser()
parser.add_argument("--model", type=str, default="stfpm", help="Name of the algorithm to train/test")
# --model_config_path will be deprecated in 0.2.8 and removed in 0.2.9
parser.add_argument("--model_config_path", type=str, required=False, help="Path to a model config file")
parser.add_argument("--config", type=str, required=False, help="Path to a model config file")
parser.add_argument("--weight_file", type=str, default="weights/model.ckpt")
parser.add_argument("--openvino", type=bool, default=False)

return parser.parse_args()
args = parser.parse_args()
if args.model_config_path is not None:
warnings.warn(
message="--model_config_path will be deprecated in v0.2.8 and removed in v0.2.9. Use --config instead.",
category=DeprecationWarning,
stacklevel=2,
)
args.config = args.model_config_path

return args


def test():
Expand All @@ -47,9 +58,8 @@ def test():
args = get_args()
config = get_configurable_parameters(
model_name=args.model,
model_config_path=args.model_config_path,
config_path=args.config,
weight_file=args.weight_file,
openvino=args.openvino,
)

datamodule = get_datamodule(config)
Expand Down
16 changes: 14 additions & 2 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import warnings
from argparse import ArgumentParser, Namespace

from pytorch_lightning import Trainer, seed_everything
Expand All @@ -38,15 +39,26 @@ def get_args() -> Namespace:
"""
parser = ArgumentParser()
parser.add_argument("--model", type=str, default="padim", help="Name of the algorithm to train/test")
# --model_config_path will be deprecated in 0.2.8 and removed in 0.2.9
parser.add_argument("--model_config_path", type=str, required=False, help="Path to a model config file")
parser.add_argument("--config", type=str, required=False, help="Path to a model config file")

return parser.parse_args()
args = parser.parse_args()
if args.model_config_path is not None:
warnings.warn(
message="--model_config_path will be deprecated in v0.2.8 and removed in v0.2.9. Use --config instead.",
category=DeprecationWarning,
stacklevel=2,
)
args.config = args.model_config_path

return args


def train():
"""Train an anomaly classification or segmentation model based on a provided configuration file."""
args = get_args()
config = get_configurable_parameters(model_name=args.model, model_config_path=args.model_config_path)
config = get_configurable_parameters(model_name=args.model, config_path=args.config)

if config.project.seed != 0:
seed_everything(config.project.seed)
Expand Down