Skip to content

Commit

Permalink
chore: update type hints for efficient_ad init
Browse files Browse the repository at this point in the history
also applied pre-commit formatting fixes
  • Loading branch information
Gornoka committed Jun 27, 2024
1 parent 0a075ff commit 9a22446
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions src/anomalib/models/image/efficient_ad/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ class EfficientAd(AnomalyModule):
"""

def __init__(
self,
imagenet_dir: Path | str = "./datasets/imagenette",
teacher_out_channels: int = 384,
model_size: EfficientAdModelSize | str = EfficientAdModelSize.S,
lr: float = 0.0001,
weight_decay: float = 0.00001,
padding: bool = False,
pad_maps: bool = True,
self,
imagenet_dir: Path | str = "./datasets/imagenette",
teacher_out_channels: int = 384,
model_size: EfficientAdModelSize | str = EfficientAdModelSize.S,
lr: float = 0.0001,
weight_decay: float = 0.00001,
padding: bool = False,
pad_maps: bool = True,
) -> None:
super().__init__()

Expand All @@ -82,9 +82,9 @@ def __init__(
padding=padding,
pad_maps=pad_maps,
)
self.batch_size = 1 # imagenet dataloader batch_size is 1 according to the paper
self.lr = lr
self.weight_decay = weight_decay
self.batch_size: int = 1 # imagenet dataloader batch_size is 1 according to the paper
self.lr: float = lr
self.weight_decay: float = weight_decay

def prepare_pretrained_model(self) -> None:
"""Prepare the pretrained teacher model."""
Expand All @@ -93,7 +93,7 @@ def prepare_pretrained_model(self) -> None:
download_and_extract(pretrained_models_dir, WEIGHTS_DOWNLOAD_INFO)
model_size_str = self.model_size.value if isinstance(self.model_size, EfficientAdModelSize) else self.model_size
teacher_path = (
pretrained_models_dir / "efficientad_pretrained_weights" / f"pretrained_teacher_{model_size_str}.pth"
pretrained_models_dir / "efficientad_pretrained_weights" / f"pretrained_teacher_{model_size_str}.pth"
)
logger.info(f"Load pretrained teacher model from {teacher_path}")
self.model.teacher.load_state_dict(torch.load(teacher_path, map_location=torch.device(self.device)))
Expand Down Expand Up @@ -147,15 +147,15 @@ def teacher_channel_mean_std(self, dataloader: DataLoader) -> dict[str, torch.Te

n += y[:, 0].numel()
chanel_sum += torch.sum(y, dim=[0, 2, 3])
chanel_sum_sqr += torch.sum(y ** 2, dim=[0, 2, 3])
chanel_sum_sqr += torch.sum(y**2, dim=[0, 2, 3])

if n is None:
msg = "The value of 'n' cannot be None."
raise ValueError(msg)

channel_mean = chanel_sum / n

channel_std = (torch.sqrt((chanel_sum_sqr / n) - (channel_mean ** 2))).float()[None, :, None, None]
channel_std = (torch.sqrt((chanel_sum_sqr / n) - (channel_mean**2))).float()[None, :, None, None]
channel_mean = channel_mean.float()[None, :, None, None]

return {"mean": channel_mean, "std": channel_std}
Expand Down

0 comments on commit 9a22446

Please sign in to comment.