Skip to content

Commit

Permalink
fix: str support for model_size param in efficient ad
Browse files Browse the repository at this point in the history
Signed-off-by: Lukas Hennies <lukas.hennies@anticipate.ml>
  • Loading branch information
Gornoka committed Jun 27, 2024
1 parent a29fccb commit 3f85fe1
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions src/anomalib/models/image/efficient_ad/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,21 @@ class EfficientAd(AnomalyModule):
"""

def __init__(
self,
imagenet_dir: Path | str = "./datasets/imagenette",
teacher_out_channels: int = 384,
model_size: EfficientAdModelSize = EfficientAdModelSize.S,
lr: float = 0.0001,
weight_decay: float = 0.00001,
padding: bool = False,
pad_maps: bool = True,
self,
imagenet_dir: Path | str = "./datasets/imagenette",
teacher_out_channels: int = 384,
model_size: EfficientAdModelSize | str = EfficientAdModelSize.S,
lr: float = 0.0001,
weight_decay: float = 0.00001,
padding: bool = False,
pad_maps: bool = True,
) -> None:
super().__init__()

self.imagenet_dir = Path(imagenet_dir)
self.model_size = model_size
if not isinstance(model_size, EfficientAdModelSize):
model_size = EfficientAdModelSize(model_size)
self.model_size: EfficientAdModelSize = model_size
self.model: EfficientAdModel = EfficientAdModel(
teacher_out_channels=teacher_out_channels,
model_size=model_size,
Expand All @@ -91,7 +93,7 @@ def prepare_pretrained_model(self) -> None:
download_and_extract(pretrained_models_dir, WEIGHTS_DOWNLOAD_INFO)
model_size_str = self.model_size.value if isinstance(self.model_size, EfficientAdModelSize) else self.model_size
teacher_path = (
pretrained_models_dir / "efficientad_pretrained_weights" / f"pretrained_teacher_{model_size_str}.pth"
pretrained_models_dir / "efficientad_pretrained_weights" / f"pretrained_teacher_{model_size_str}.pth"
)
logger.info(f"Load pretrained teacher model from {teacher_path}")
self.model.teacher.load_state_dict(torch.load(teacher_path, map_location=torch.device(self.device)))
Expand Down Expand Up @@ -145,15 +147,15 @@ def teacher_channel_mean_std(self, dataloader: DataLoader) -> dict[str, torch.Te

n += y[:, 0].numel()
chanel_sum += torch.sum(y, dim=[0, 2, 3])
chanel_sum_sqr += torch.sum(y**2, dim=[0, 2, 3])
chanel_sum_sqr += torch.sum(y ** 2, dim=[0, 2, 3])

if n is None:
msg = "The value of 'n' cannot be None."
raise ValueError(msg)

channel_mean = chanel_sum / n

channel_std = (torch.sqrt((chanel_sum_sqr / n) - (channel_mean**2))).float()[None, :, None, None]
channel_std = (torch.sqrt((chanel_sum_sqr / n) - (channel_mean ** 2))).float()[None, :, None, None]
channel_mean = channel_mean.float()[None, :, None, None]

return {"mean": channel_mean, "std": channel_std}
Expand Down

0 comments on commit 3f85fe1

Please sign in to comment.