Skip to content

Commit

Permalink
Feature Extractor Refactor (#451)
Browse files Browse the repository at this point in the history
  • Loading branch information
ashishbdatta authored Aug 11, 2022
1 parent a0e040d commit 27a31cd
Show file tree
Hide file tree
Showing 11 changed files with 95 additions and 58 deletions.
5 changes: 2 additions & 3 deletions anomalib/models/cflow/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import einops
import torch
import torchvision
from torch import nn

from anomalib.models.cflow.anomaly_map import AnomalyMapGenerator
Expand All @@ -33,13 +32,13 @@ def __init__(
):
super().__init__()

self.backbone = getattr(torchvision.models, backbone)
self.backbone = backbone
self.fiber_batch_size = fiber_batch_size
self.condition_vector: int = condition_vector
self.dec_arch = decoder
self.pool_layers = layers

self.encoder = FeatureExtractor(backbone=self.backbone(pretrained=pre_trained), layers=self.pool_layers)
self.encoder = FeatureExtractor(backbone=self.backbone, layers=self.pool_layers, pre_trained=pre_trained)
self.pool_dims = self.encoder.out_dims
self.decoders = nn.ModuleList(
[
Expand Down
68 changes: 36 additions & 32 deletions anomalib/models/components/feature_extractors/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Dict, Iterable
import warnings
from typing import Dict, List

import timm
import torch
from torch import Tensor, nn

Expand All @@ -21,10 +23,9 @@ class FeatureExtractor(nn.Module):
Example:
>>> import torch
>>> import torchvision
>>> from anomalib.core.model.feature_extractor import FeatureExtractor
>>> model = FeatureExtractor(model=torchvision.models.resnet18(), layers=['layer1', 'layer2', 'layer3'])
>>> model = FeatureExtractor(model="resnet18", layers=['layer1', 'layer2', 'layer3'])
>>> input = torch.rand((32, 3, 256, 256))
>>> features = model(input)
Expand All @@ -34,42 +35,46 @@ class FeatureExtractor(nn.Module):
[torch.Size([32, 64, 64, 64]), torch.Size([32, 128, 32, 32]), torch.Size([32, 256, 16, 16])]
"""

def __init__(self, backbone: nn.Module, layers: Iterable[str]):
def __init__(self, backbone: str, layers: List[str], pre_trained: bool = True):
    """Initialize a timm-backed feature extractor.

    Args:
        backbone (str): Name of a ``timm`` model architecture (e.g. ``"resnet18"``).
        layers (List[str]): Names of the backbone layers to extract features from.
            NOTE: layers not found in the backbone are dropped from this list
            (with a warning) by ``_map_layer_to_idx``.
        pre_trained (bool): Whether to load pre-trained weights. Defaults to True.
    """
    super().__init__()
    self.backbone = backbone
    self.layers = layers
    # Translate the requested layer names into timm feature indices.
    self.idx = self._map_layer_to_idx()
    self.feature_extractor = timm.create_model(
        backbone,
        pretrained=pre_trained,
        features_only=True,
        # exportable presumably restricts the graph to export-friendly ops
        # (ONNX/torchscript) — TODO confirm against timm docs
        exportable=True,
        out_indices=self.idx,
    )
    # Number of output channels for each requested layer, in layer order.
    self.out_dims = self.feature_extractor.feature_info.channels()
    self._features = {layer: torch.empty(0) for layer in self.layers}
self.out_dims = []

for layer_id in layers:
layer = dict([*self.backbone.named_modules()])[layer_id]
layer.register_forward_hook(self.get_features(layer_id))
# get output dimension of features if available
layer_modules = [*layer.modules()]
for idx in reversed(range(len(layer_modules))):
if hasattr(layer_modules[idx], "out_channels"):
self.out_dims.append(layer_modules[idx].out_channels)
break

def get_features(self, layer_id: str) -> Callable:
"""Get layer features.
def _map_layer_to_idx(self, offset: int = 3) -> List[int]:
    """Map the requested layer names to ``timm`` feature indices.

    Args:
        offset (int): ``timm`` ignores the first few stem modules when
            indexing features; adjust this offset if the backbone family
            needs a different value.

    Returns:
        List[int]: One index per found layer, suitable for timm's
        ``out_indices``. Layers not present in the backbone are removed
        from ``self.layers`` after emitting a warning.
    """
    idx: List[int] = []
    # Instantiate the full (non-features_only) model only to enumerate its
    # top-level children; weights are irrelevant here, so pretrained=False.
    model = timm.create_model(
        self.backbone,
        pretrained=False,
        features_only=False,
        exportable=True,
    )
    layer_names = list(dict(model.named_children()).keys())
    # Iterate over a copy: the original looped over ``self.layers`` while
    # calling ``self.layers.remove(...)`` inside the loop, which skips the
    # element following every removal when several layers are missing.
    for layer in list(self.layers):
        try:
            idx.append(layer_names.index(layer) - offset)
        except ValueError:
            warnings.warn(f"Layer {layer} not found in model {self.backbone}")
            # Remove unfound key from layer dict
            self.layers.remove(layer)

    return idx

def forward(self, input_tensor: Tensor) -> Dict[str, Tensor]:
"""Forward-pass input tensor into the CNN.
Expand All @@ -80,6 +85,5 @@ def forward(self, input_tensor: Tensor) -> Dict[str, Tensor]:
Returns:
Feature map extracted from the CNN
"""
self._features = {layer: torch.empty(0) for layer in self.layers}
_ = self.backbone(input_tensor)
return self._features
features = dict(zip(self.layers, self.feature_extractor(input_tensor)))
return features
2 changes: 1 addition & 1 deletion anomalib/models/dfkde/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ model:
threshold_offset: 12
normalization_method: min_max # options: [null, min_max, cdf]
layers:
- avgpool
- layer4

metrics:
image:
Expand Down
10 changes: 7 additions & 3 deletions anomalib/models/dfkde/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import List, Optional, Tuple

import torch
import torchvision
import torch.nn.functional as F
from torch import Tensor, nn

from anomalib.models.components import PCA, FeatureExtractor, GaussianKDE
Expand Down Expand Up @@ -48,8 +48,8 @@ def __init__(
self.threshold_steepness = threshold_steepness
self.threshold_offset = threshold_offset

_backbone = getattr(torchvision.models, backbone)
self.feature_extractor = FeatureExtractor(backbone=_backbone(pretrained=pre_trained), layers=layers).eval()
_backbone = backbone
self.feature_extractor = FeatureExtractor(backbone=_backbone, pre_trained=pre_trained, layers=layers).eval()

self.pca_model = PCA(n_components=self.n_components)
self.kde_model = GaussianKDE()
Expand All @@ -68,6 +68,10 @@ def get_features(self, batch: Tensor) -> Tensor:
"""
self.feature_extractor.eval()
layer_outputs = self.feature_extractor(batch)
for layer in layer_outputs:
batch_size = len(layer_outputs[layer])
layer_outputs[layer] = F.adaptive_avg_pool2d(input=layer_outputs[layer], output_size=(1, 1))
layer_outputs[layer] = layer_outputs[layer].view(batch_size, -1)
layer_outputs = torch.cat(list(layer_outputs.values())).detach()
return layer_outputs

Expand Down
7 changes: 4 additions & 3 deletions anomalib/models/dfm/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import torch
import torch.nn.functional as F
import torchvision
from torch import Tensor, nn

from anomalib.models.components import PCA, DynamicBufferModule, FeatureExtractor
Expand Down Expand Up @@ -92,13 +91,15 @@ def __init__(
score_type: str = "fre",
):
super().__init__()
self.backbone = getattr(torchvision.models, backbone)
self.backbone = backbone
self.pooling_kernel_size = pooling_kernel_size
self.n_components = n_comps
self.pca_model = PCA(n_components=self.n_components)
self.gaussian_model = SingleClassGaussian()
self.score_type = score_type
self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=pre_trained), layers=[layer]).eval()
self.feature_extractor = FeatureExtractor(
backbone=self.backbone, pre_trained=pre_trained, layers=[layer]
).eval()

def fit(self, dataset: Tensor) -> None:
"""Fit a pca transformation and a Gaussian model to dataset.
Expand Down
6 changes: 2 additions & 4 deletions anomalib/models/padim/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import torch
import torch.nn.functional as F
import torchvision
from torch import Tensor, nn

from anomalib.models.components import FeatureExtractor, MultiVariateGaussian
Expand Down Expand Up @@ -41,9 +40,9 @@ def __init__(
super().__init__()
self.tiler: Optional[Tiler] = None

self.backbone = getattr(torchvision.models, backbone)
self.backbone = backbone
self.layers = layers
self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=pre_trained), layers=self.layers)
self.feature_extractor = FeatureExtractor(backbone=self.backbone, layers=layers, pre_trained=pre_trained)
self.dims = DIMS[backbone]
# pylint: disable=not-callable
# Since idx is randomly selected, save it with model to get same results
Expand Down Expand Up @@ -98,7 +97,6 @@ def forward(self, input_tensor: Tensor) -> Tensor:
output = self.anomaly_map_generator(
embedding=embeddings, mean=self.gaussian.mean, inv_covariance=self.gaussian.inv_covariance
)

return output

def generate_embedding(self, features: Dict[str, Tensor]) -> Tensor:
Expand Down
5 changes: 2 additions & 3 deletions anomalib/models/patchcore/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import torch
import torch.nn.functional as F
import torchvision
from torch import Tensor, nn

from anomalib.models.components import (
Expand All @@ -33,12 +32,12 @@ def __init__(
super().__init__()
self.tiler: Optional[Tiler] = None

self.backbone = getattr(torchvision.models, backbone)
self.backbone = backbone
self.layers = layers
self.input_size = input_size
self.num_neighbors = num_neighbors

self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=pre_trained), layers=self.layers)
self.feature_extractor = FeatureExtractor(backbone=self.backbone, pre_trained=pre_trained, layers=self.layers)
self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)

Expand Down
6 changes: 2 additions & 4 deletions anomalib/models/reverse_distillation/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from typing import List, Optional, Tuple, Union

import torchvision
from torch import Tensor, nn

from anomalib.models.components import FeatureExtractor
Expand Down Expand Up @@ -39,9 +38,8 @@ def __init__(
super().__init__()
self.tiler: Optional[Tiler] = None

encoder_backbone = getattr(torchvision.models, backbone)
# TODO replace with TIMM feature extractor
self.encoder = FeatureExtractor(backbone=encoder_backbone(pretrained=pre_trained), layers=layers)
encoder_backbone = backbone
self.encoder = FeatureExtractor(backbone=encoder_backbone, pre_trained=pre_trained, layers=layers)
self.bottleneck = get_bottleneck_layer(backbone)
self.decoder = get_decoder(backbone)

Expand Down
7 changes: 3 additions & 4 deletions anomalib/models/stfpm/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from typing import Dict, List, Optional, Tuple

import torchvision
from torch import Tensor, nn

from anomalib.models.components import FeatureExtractor
Expand All @@ -31,9 +30,9 @@ def __init__(
super().__init__()
self.tiler: Optional[Tiler] = None

self.backbone = getattr(torchvision.models, backbone)
self.teacher_model = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=layers)
self.student_model = FeatureExtractor(backbone=self.backbone(pretrained=False), layers=layers)
self.backbone = backbone
self.teacher_model = FeatureExtractor(backbone=self.backbone, pre_trained=True, layers=layers)
self.student_model = FeatureExtractor(backbone=self.backbone, pre_trained=False, layers=layers)

# teacher model is fixed
for parameters in self.teacher_model.parameters():
Expand Down
2 changes: 1 addition & 1 deletion configs/model/dfkde.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ model:
backbone: resnet18
pre_trained: true
layers:
- avgpool
- layer4
max_training_points: 40000
pre_processing: scale
n_components: 16
Expand Down
35 changes: 35 additions & 0 deletions tests/pre_merge/models/test_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest
import torch

from anomalib.models.components.feature_extractors import FeatureExtractor


class TestFeatureExtractor:
    """Sanity checks for the timm-backed ``FeatureExtractor``."""

    @pytest.mark.parametrize(
        "backbone",
        ["resnet18", "wide_resnet50_2"],
    )
    @pytest.mark.parametrize(
        "pretrained",
        [True, False],
    )
    def test_feature_extraction(self, backbone, pretrained):
        # Per-backbone expected channel counts for layer1..layer3.
        expected_channels = {
            "resnet18": [64, 128, 256],
            "wide_resnet50_2": [256, 512, 1024],
        }
        layer_names = ["layer1", "layer2", "layer3"]
        extractor = FeatureExtractor(backbone=backbone, layers=layer_names, pre_trained=pretrained)
        batch = torch.rand((32, 3, 256, 256))
        features = extractor(batch)

        channels = expected_channels.get(backbone)
        if channels is None:
            # Unrecognised backbone: nothing to assert (mirrors original behaviour).
            return
        # Spatial resolution halves at each stage: 64 -> 32 -> 16 for a 256px input.
        for name, n_channels, spatial in zip(layer_names, channels, (64, 32, 16)):
            assert features[name].shape == torch.Size((32, n_channels, spatial, spatial))
        assert extractor.out_dims == channels
        assert extractor.idx == [1, 2, 3]

0 comments on commit 27a31cd

Please sign in to comment.