Skip to content

Commit

Permalink
fix: efficient ad model_size str fixes (#2159)
Browse files Browse the repository at this point in the history
* fix: model str conversion

Signed-off-by: Lukas Hennies <lukas.hennies@anticipate.ml>

* fix: str support for model_size param in efficient ad

Signed-off-by: Lukas Hennies <lukas.hennies@anticipate.ml>

* chore: update type hints for efficient_ad init

also pre-commit
Signed-off-by: Lukas Hennies <lukas.hennies@anticipate.ml>

---------

Signed-off-by: Lukas Hennies <lukas.hennies@anticipate.ml>
Co-authored-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
Gornoka and samet-akcay committed Jul 10, 2024
1 parent b646d1a commit 9582b1d
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/anomalib/models/image/efficient_ad/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
self,
imagenet_dir: Path | str = "./datasets/imagenette",
teacher_out_channels: int = 384,
model_size: EfficientAdModelSize = EfficientAdModelSize.S,
model_size: EfficientAdModelSize | str = EfficientAdModelSize.S,
lr: float = 0.0001,
weight_decay: float = 0.00001,
padding: bool = False,
Expand All @@ -72,24 +72,27 @@ def __init__(
super().__init__()

self.imagenet_dir = Path(imagenet_dir)
self.model_size = model_size
if not isinstance(model_size, EfficientAdModelSize):
model_size = EfficientAdModelSize(model_size)
self.model_size: EfficientAdModelSize = model_size
self.model: EfficientAdModel = EfficientAdModel(
teacher_out_channels=teacher_out_channels,
model_size=model_size,
padding=padding,
pad_maps=pad_maps,
)
self.batch_size = 1 # imagenet dataloader batch_size is 1 according to the paper
self.lr = lr
self.weight_decay = weight_decay
self.batch_size: int = 1 # imagenet dataloader batch_size is 1 according to the paper
self.lr: float = lr
self.weight_decay: float = weight_decay

def prepare_pretrained_model(self) -> None:
"""Prepare the pretrained teacher model."""
pretrained_models_dir = Path("./pre_trained/")
if not (pretrained_models_dir / "efficientad_pretrained_weights").is_dir():
download_and_extract(pretrained_models_dir, WEIGHTS_DOWNLOAD_INFO)
model_size_str = self.model_size.value if isinstance(self.model_size, EfficientAdModelSize) else self.model_size
teacher_path = (
pretrained_models_dir / "efficientad_pretrained_weights" / f"pretrained_teacher_{self.model_size.value}.pth"
pretrained_models_dir / "efficientad_pretrained_weights" / f"pretrained_teacher_{model_size_str}.pth"
)
logger.info(f"Load pretrained teacher model from {teacher_path}")
self.model.teacher.load_state_dict(torch.load(teacher_path, map_location=torch.device(self.device)))
Expand Down

0 comments on commit 9582b1d

Please sign in to comment.