Skip to content

Commit

Permalink
Revert "🏷 Rename --model_config_path to config (#246)" (#247)
Browse files Browse the repository at this point in the history
This reverts commit ee6a112.
  • Loading branch information
samet-akcay committed Apr 21, 2022
1 parent ee6a112 commit ca18913
Show file tree
Hide file tree
Showing 13 changed files with 38 additions and 69 deletions.
17 changes: 7 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,15 @@ file, [`config.yaml`](https://gitlab-icv.inn.intel.com/algo_rnd_team/anomaly/-/b
category, the config file is to be provided:

```bash
python tools/train.py --config <path/to/model/config.yaml>
python tools/train.py --model_config_path <path/to/model/config.yaml>
```

For example, to train [PADIM](anomalib/models/padim) you can use

```bash
python tools/train.py --config anomalib/models/padim/config.yaml
python tools/train.py --model_config_path anomalib/models/padim/config.yaml
```

Note that `--model_config_path` will be deprecated in `v0.2.8` and removed
in `v0.2.9`.

Alternatively, a model name could also be provided as an argument, where the scripts automatically finds the corresponding config file.

```bash
Expand Down Expand Up @@ -141,7 +138,7 @@ The following command can be used to run inference from the command line:

```bash
python tools/inference.py \
--config <path/to/model/config.yaml> \
--model_config_path <path/to/model/config.yaml> \
--weight_path <path/to/weight/file> \
--image_path <path/to/image>
```
Expand All @@ -150,7 +147,7 @@ As a quick example:

```bash
python tools/inference.py \
--config anomalib/models/padim/config.yaml \
--model_config_path anomalib/models/padim/config.yaml \
--weight_path results/padim/mvtec/bottle/weights/model.ckpt \
--image_path datasets/MVTec/bottle/test/broken_large/000.png
```
Expand All @@ -167,7 +164,7 @@ Example OpenVINO Inference:

```bash
python tools/inference.py \
--config \
--model_config_path \
anomalib/models/padim/config.yaml \
--weight_path \
results/padim/mvtec/bottle/compressed/compressed_model.xml \
Expand Down Expand Up @@ -203,7 +200,7 @@ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License
| DFM | ResNet-18 | 0.894 | 0.864 | 0.558 | 0.945 | 0.984 | 0.946 | 0.994 | 0.913 | 0.871 | 0.979 | 0.941 | 0.838 | 0.761 | 0.95 | 0.911 | 0.949 |
| DFKDE | Wide ResNet-50 | 0.774 | 0.708 | 0.422 | 0.905 | 0.959 | 0.903 | 0.936 | 0.746 | 0.853 | 0.736 | 0.687 | 0.749 | 0.574 | 0.697 | 0.843 | 0.892 |
| DFKDE | ResNet-18 | 0.762 | 0.646 | 0.577 | 0.669 | 0.965 | 0.863 | 0.951 | 0.751 | 0.698 | 0.806 | 0.729 | 0.607 | 0.694 | 0.767 | 0.839 | 0.866 |
| GANomaly | | 0.421 | 0.203 | 0.404 | 0.413 | 0.408 | 0.744 | 0.251 | 0.457 | 0.682 | 0.537 | 0.270 | 0.472 | 0.231 | 0.372 | 0.440 | 0.434 |
| GANomaly | | 0.421 | 0.203 | 0.404 | 0.413 | 0.408 | 0.744 | 0.251 | 0.457 | 0.682 | 0.537 | 0.270 | 0.472 | 0.231 | 0.372 | 0.440 | 0.434 |

### Pixel-Level AUC

Expand Down Expand Up @@ -232,7 +229,7 @@ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License
| DFM | ResNet-18 | 0.919 | 0.895 | 0.844 | 0.926 | 0.971 | 0.948 | 0.977 | 0.874 | 0.935 | 0.957 | 0.958 | 0.921 | 0.874 | 0.933 | 0.833 | 0.943 |
| DFKDE | Wide ResNet-50 | 0.875 | 0.907 | 0.844 | 0.905 | 0.945 | 0.914 | 0.946 | 0.790 | 0.914 | 0.817 | 0.894 | 0.922 | 0.855 | 0.845 | 0.722 | 0.910 |
| DFKDE | ResNet-18 | 0.872 | 0.864 | 0.844 | 0.854 | 0.960 | 0.898 | 0.942 | 0.793 | 0.908 | 0.827 | 0.894 | 0.916 | 0.859 | 0.853 | 0.756 | 0.916 |
| GANomaly | | 0.834 | 0.864 | 0.844 | 0.852 | 0.836 | 0.863 | 0.863 | 0.760 | 0.905 | 0.777 | 0.894 | 0.916 | 0.853 | 0.833 | 0.571 | 0.881 |
| GANomaly | | 0.834 | 0.864 | 0.844 | 0.852 | 0.836 | 0.863 | 0.863 | 0.760 | 0.905 | 0.777 | 0.894 | 0.916 | 0.853 | 0.833 | 0.571 | 0.881 |

## Reference
If you use this library and love it, use this to cite it 🤗
Expand Down
12 changes: 6 additions & 6 deletions anomalib/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def update_multi_gpu_training_config(config: Union[DictConfig, ListConfig]) -> U

def get_configurable_parameters(
model_name: Optional[str] = None,
config_path: Optional[Union[Path, str]] = None,
model_config_path: Optional[Union[Path, str]] = None,
weight_file: Optional[str] = None,
config_filename: Optional[str] = "config",
config_file_extension: Optional[str] = "yaml",
Expand All @@ -122,24 +122,24 @@ def get_configurable_parameters(
Args:
model_name: Optional[str]: (Default value = None)
config_path: Optional[Union[Path, str]]: (Default value = None)
model_config_path: Optional[Union[Path, str]]: (Default value = None)
weight_file: Path to the weight file
config_filename: Optional[str]: (Default value = "config")
config_file_extension: Optional[str]: (Default value = "yaml")
Returns:
Union[DictConfig, ListConfig]: Configurable parameters in DictConfig object.
"""
if model_name is None and config_path is None:
if model_name is None and model_config_path is None:
raise ValueError(
"Both model_name and model config path cannot be None! "
"Please provide a model name or path to a config file!"
)

if config_path is None:
config_path = Path(f"anomalib/models/{model_name}/{config_filename}.{config_file_extension}")
if model_config_path is None:
model_config_path = Path(f"anomalib/models/{model_name}/{config_filename}.{config_file_extension}")

config = OmegaConf.load(config_path)
config = OmegaConf.load(model_config_path)

# Dataset Configs
if "format" not in config.dataset.keys():
Expand Down
8 changes: 3 additions & 5 deletions anomalib/deploy/inferencers/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from importlib.util import find_spec
import importlib
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

Expand All @@ -26,10 +26,8 @@

from .base import Inferencer

if find_spec("openvino") is not None:
from openvino.inference_engine import ( # type: ignore # pylint: disable=no-name-in-module
IECore,
)
if importlib.util.find_spec("openvino") is not None:
from openvino.inference_engine import IECore # pylint: disable=no-name-in-module


class OpenVINOInferencer(Inferencer):
Expand Down
5 changes: 1 addition & 4 deletions docs/source/guides/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ file is to be provided:

::

python tools/train.py --config <path/to/model/config.yaml>

Note that `--model_config_path` will be deprecated in `v0.2.8` and removed
in `v0.2.9`.
python tools/train.py --model_config_path <path/to/model/config.yaml>

Alternatively, a model name could also be provided as an argument, where
the scripts automatically finds the corresponding config file.
Expand Down
8 changes: 5 additions & 3 deletions tests/helpers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
def get_test_configurable_parameters(
dataset_path: Optional[str] = None,
model_name: Optional[str] = None,
config_path: Optional[Union[Path, str]] = None,
model_config_path: Optional[Union[Path, str]] = None,
weight_file: Optional[str] = None,
config_filename: Optional[str] = "config",
config_file_extension: Optional[str] = "yaml",
Expand All @@ -21,7 +21,7 @@ def get_test_configurable_parameters(
Args:
datset_path: Optional[Path]: Path to dataset
model_name: Optional[str]: (Default value = None)
config_path: Optional[Union[Path, str]]: (Default value = None)
model_config_path: Optional[Union[Path, str]]: (Default value = None)
weight_file: Path to the weight file
config_filename: Optional[str]: (Default value = "config")
config_file_extension: Optional[str]: (Default value = "yaml")
Expand All @@ -30,7 +30,9 @@ def get_test_configurable_parameters(
Union[DictConfig, ListConfig]: Configurable parameters in DictConfig object.
"""

config = get_configurable_parameters(model_name, config_path, weight_file, config_filename, config_file_extension)
config = get_configurable_parameters(
model_name, model_config_path, weight_file, config_filename, config_file_extension
)

# Update path to match the dataset path in the test image/runner
config.dataset.path = get_dataset_path() if dataset_path is None else dataset_path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def run_train_test(config):

@TestDataset(num_train=200, num_test=30, path=get_dataset_path(), seed=42)
def test_normalizer(path=get_dataset_path(), category="shapes"):
config = get_configurable_parameters(config_path="anomalib/models/padim/config.yaml")
config = get_configurable_parameters(model_config_path="anomalib/models/padim/config.yaml")
config.dataset.path = path
config.dataset.category = category
config.model.threshold.adaptive = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_openvino_model_callback():
"""Tests if an optimized model is created."""

config = get_test_configurable_parameters(
config_path="tests/pre_merge/utils/callbacks/openvino_callback/dummy_config.yml"
model_config_path="tests/pre_merge/utils/callbacks/openvino_callback/dummy_config.yml"
)

with tempfile.TemporaryDirectory() as tmp_dir:
Expand Down
2 changes: 1 addition & 1 deletion tests/pre_merge/utils/metrics/test_adaptive_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_non_adaptive_threshold():
Test if the non-adaptive threshold gets used in the F1 score computation when
adaptive thresholding is disabled and no normalization is used.
"""
config = get_test_configurable_parameters(config_path="anomalib/models/padim/config.yaml")
config = get_test_configurable_parameters(model_config_path="anomalib/models/padim/config.yaml")

config.model.normalization_method = "none"
config.model.threshold.adaptive = False
Expand Down
2 changes: 1 addition & 1 deletion tools/benchmarking/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Note: To collect memory read/write numbers, run the script with sudo privileges.
```
sudo -E ./train.sh # Train STFPM on MVTec AD leather
sudo -E ./train.sh --config <path/to/model/config.yaml>
sudo -E ./train.sh --model_config_path <path/to/model/config.yaml>
sudo -E ./train.sh --model stfpm
```
Expand Down
11 changes: 4 additions & 7 deletions tools/hpo/wandb_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@

from argparse import ArgumentParser
from pathlib import Path
from typing import Union

import pytorch_lightning as pl
from omegaconf import DictConfig, ListConfig, OmegaConf
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import WandbLogger
from utils import flatten_hpo_params
Expand All @@ -39,14 +38,12 @@ class WandbSweep:
sweep_config (DictConfig): Sweep configuration.
"""

def __init__(self, config: Union[DictConfig, ListConfig], sweep_config: Union[DictConfig, ListConfig]) -> None:
def __init__(self, config: DictConfig, sweep_config: DictConfig) -> None:
self.config = config
self.sweep_config = sweep_config
self.observation_budget = sweep_config.observation_budget
if "observation_budget" in self.sweep_config.keys():
# this instance check is to silence mypy.
if isinstance(self.sweep_config, DictConfig):
self.sweep_config.pop("observation_budget")
self.sweep_config.pop("observation_budget")

def run(self):
"""Run the sweep."""
Expand Down Expand Up @@ -89,7 +86,7 @@ def get_args():

if __name__ == "__main__":
args = get_args()
model_config = get_configurable_parameters(model_name=args.model, config_path=args.model_config)
model_config = get_configurable_parameters(model_name=args.model, model_config_path=args.model_config)
hpo_config = OmegaConf.load(args.sweep_config)

if model_config.project.seed != 0:
Expand Down
4 changes: 2 additions & 2 deletions tools/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_args() -> Namespace:
Namespace: List of arguments.
"""
parser = ArgumentParser()
parser.add_argument("--config", type=Path, required=True, help="Path to a model config file")
parser.add_argument("--model_config_path", type=Path, required=True, help="Path to a model config file")
parser.add_argument("--weight_path", type=Path, required=True, help="Path to a model weights")
parser.add_argument("--image_path", type=Path, required=True, help="Path to an image to infer.")
parser.add_argument("--save_path", type=Path, required=False, help="Path to save the output image.")
Expand Down Expand Up @@ -75,7 +75,7 @@ def stream() -> None:
# This config file is also used for training and contains all the relevant
# information regarding the data, model, train and inference details.
args = get_args()
config = get_configurable_parameters(config_path=args.config)
config = get_configurable_parameters(model_config_path=args.model_config_path)

# Get the inferencer. We use .ckpt extension for Torch models and (onnx, bin)
# for the openvino models.
Expand Down
18 changes: 4 additions & 14 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import warnings
from argparse import ArgumentParser, Namespace

from pytorch_lightning import Trainer
Expand All @@ -33,21 +32,11 @@ def get_args() -> Namespace:
"""
parser = ArgumentParser()
parser.add_argument("--model", type=str, default="stfpm", help="Name of the algorithm to train/test")
# --model_config_path will be deprecated in 0.2.8 and removed in 0.2.9
parser.add_argument("--model_config_path", type=str, required=False, help="Path to a model config file")
parser.add_argument("--config", type=str, required=False, help="Path to a model config file")
parser.add_argument("--weight_file", type=str, default="weights/model.ckpt")
parser.add_argument("--openvino", type=bool, default=False)

args = parser.parse_args()
if args.model_config_path is not None:
warnings.warn(
message="--model_config_path will be deprecated in v0.2.8 and removed in v0.2.9. Use --config instead.",
category=DeprecationWarning,
stacklevel=2,
)
args.config = args.model_config_path

return args
return parser.parse_args()


def test():
Expand All @@ -58,8 +47,9 @@ def test():
args = get_args()
config = get_configurable_parameters(
model_name=args.model,
config_path=args.config,
model_config_path=args.model_config_path,
weight_file=args.weight_file,
openvino=args.openvino,
)

datamodule = get_datamodule(config)
Expand Down
16 changes: 2 additions & 14 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import warnings
from argparse import ArgumentParser, Namespace

from pytorch_lightning import Trainer, seed_everything
Expand All @@ -39,26 +38,15 @@ def get_args() -> Namespace:
"""
parser = ArgumentParser()
parser.add_argument("--model", type=str, default="padim", help="Name of the algorithm to train/test")
# --model_config_path will be deprecated in 0.2.8 and removed in 0.2.9
parser.add_argument("--model_config_path", type=str, required=False, help="Path to a model config file")
parser.add_argument("--config", type=str, required=False, help="Path to a model config file")

args = parser.parse_args()
if args.model_config_path is not None:
warnings.warn(
message="--model_config_path will be deprecated in v0.2.8 and removed in v0.2.9. Use --config instead.",
category=DeprecationWarning,
stacklevel=2,
)
args.config = args.model_config_path

return args
return parser.parse_args()


def train():
"""Train an anomaly classification or segmentation model based on a provided configuration file."""
args = get_args()
config = get_configurable_parameters(model_name=args.model, config_path=args.config)
config = get_configurable_parameters(model_name=args.model, model_config_path=args.model_config_path)

if config.project.seed != 0:
seed_everything(config.project.seed)
Expand Down

0 comments on commit ca18913

Please sign in to comment.