Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature Extractor Refactor #451

Merged
merged 16 commits into from
Aug 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions anomalib/models/cflow/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import einops
import torch
import torchvision
from torch import nn

from anomalib.models.cflow.anomaly_map import AnomalyMapGenerator
Expand All @@ -44,13 +43,13 @@ def __init__(
):
super().__init__()

self.backbone = getattr(torchvision.models, backbone)
self.backbone = backbone
self.fiber_batch_size = fiber_batch_size
self.condition_vector: int = condition_vector
self.dec_arch = decoder
self.pool_layers = layers

self.encoder = FeatureExtractor(backbone=self.backbone(pretrained=pre_trained), layers=self.pool_layers)
self.encoder = FeatureExtractor(backbone=self.backbone, layers=self.pool_layers, pre_trained=pre_trained)
self.pool_dims = self.encoder.out_dims
self.decoders = nn.ModuleList(
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import Callable, Dict, Iterable
import warnings
from typing import Dict, List

import timm
import torch
from torch import Tensor, nn

Expand All @@ -32,10 +34,9 @@ class FeatureExtractor(nn.Module):

Example:
>>> import torch
>>> import torchvision
>>> from anomalib.core.model.feature_extractor import FeatureExtractor

>>> model = FeatureExtractor(model=torchvision.models.resnet18(), layers=['layer1', 'layer2', 'layer3'])
>>> model = FeatureExtractor(backbone="resnet18", layers=['layer1', 'layer2', 'layer3'])
>>> input = torch.rand((32, 3, 256, 256))
>>> features = model(input)

Expand All @@ -45,42 +46,46 @@ class FeatureExtractor(nn.Module):
[torch.Size([32, 64, 64, 64]), torch.Size([32, 128, 32, 32]), torch.Size([32, 256, 16, 16])]
"""

def __init__(self, backbone: nn.Module, layers: Iterable[str]):
def __init__(self, backbone: str, layers: List[str], pre_trained: bool = True):
super().__init__()
self.backbone = backbone
self.layers = layers
self.idx = self._map_layer_to_idx()
self.feature_extractor = timm.create_model(
backbone,
pretrained=pre_trained,
features_only=True,
exportable=True,
out_indices=self.idx,
)
self.out_dims = self.feature_extractor.feature_info.channels()
self._features = {layer: torch.empty(0) for layer in self.layers}
self.out_dims = []

for layer_id in layers:
layer = dict([*self.backbone.named_modules()])[layer_id]
layer.register_forward_hook(self.get_features(layer_id))
# get output dimension of features if available
layer_modules = [*layer.modules()]
for idx in reversed(range(len(layer_modules))):
if hasattr(layer_modules[idx], "out_channels"):
self.out_dims.append(layer_modules[idx].out_channels)
break

def get_features(self, layer_id: str) -> Callable:
"""Get layer features.
def _map_layer_to_idx(self, offset: int = 3) -> List[int]:
    """Map the requested layer names to ``timm`` feature indices.

    A plain (``features_only=False``) instance of the backbone is created
    without pretrained weights, only to read the names of its top-level
    child modules. The index of a layer is its position among those
    children minus ``offset``.

    Layers that are not top-level children of the backbone trigger a
    warning and are removed from ``self.layers`` in place, so that
    ``self.layers`` and the returned indices stay aligned one-to-one.

    Args:
        offset (int): Value subtracted from a layer's position among the
            backbone's children. `timm` ignores the first few layers when
            indexing; adjust the offset to the backbone if needed.
            Defaults to 3.

    Returns:
        List[int]: One feature index per layer left in ``self.layers``,
        in the same order.
    """
    model = timm.create_model(
        self.backbone,
        pretrained=False,
        features_only=False,
        exportable=True,
    )
    # Loop-invariant: the child names do not change between lookups.
    child_names = [name for name, _ in model.named_children()]

    idx = []
    # Iterate over a snapshot. Removing from ``self.layers`` while looping
    # over it directly skips the element that follows each removed entry,
    # which would leave ``idx`` shorter than (and misaligned with) the
    # remaining layer names.
    for layer in list(self.layers):
        try:
            idx.append(child_names.index(layer) - offset)
        except ValueError:
            warnings.warn(f"Layer {layer} not found in model {self.backbone}")
            # Drop the unknown name so forward() can zip layers with features.
            self.layers.remove(layer)

    return idx

def forward(self, input_tensor: Tensor) -> Dict[str, Tensor]:
"""Forward-pass input tensor into the CNN.
Expand All @@ -91,6 +96,5 @@ def forward(self, input_tensor: Tensor) -> Dict[str, Tensor]:
Returns:
Feature map extracted from the CNN
"""
self._features = {layer: torch.empty(0) for layer in self.layers}
_ = self.backbone(input_tensor)
return self._features
features = dict(zip(self.layers, self.feature_extractor(input_tensor)))
return features
2 changes: 1 addition & 1 deletion anomalib/models/dfkde/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ model:
threshold_offset: 12
normalization_method: min_max # options: [null, min_max, cdf]
layers:
- avgpool
- layer4

metrics:
image:
Expand Down
10 changes: 7 additions & 3 deletions anomalib/models/dfkde/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import List, Optional, Tuple

import torch
import torchvision
import torch.nn.functional as F
from torch import Tensor, nn

from anomalib.models.components import PCA, FeatureExtractor, GaussianKDE
Expand Down Expand Up @@ -59,8 +59,8 @@ def __init__(
self.threshold_steepness = threshold_steepness
self.threshold_offset = threshold_offset

_backbone = getattr(torchvision.models, backbone)
self.feature_extractor = FeatureExtractor(backbone=_backbone(pretrained=pre_trained), layers=layers).eval()
_backbone = backbone
self.feature_extractor = FeatureExtractor(backbone=_backbone, pre_trained=pre_trained, layers=layers).eval()

self.pca_model = PCA(n_components=self.n_components)
self.kde_model = GaussianKDE()
Expand All @@ -79,6 +79,10 @@ def get_features(self, batch: Tensor) -> Tensor:
"""
self.feature_extractor.eval()
layer_outputs = self.feature_extractor(batch)
for layer in layer_outputs:
batch_size = len(layer_outputs[layer])
layer_outputs[layer] = F.adaptive_avg_pool2d(input=layer_outputs[layer], output_size=(1, 1))
layer_outputs[layer] = layer_outputs[layer].view(batch_size, -1)
layer_outputs = torch.cat(list(layer_outputs.values())).detach()
return layer_outputs

Expand Down
7 changes: 4 additions & 3 deletions anomalib/models/dfm/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import torch
import torch.nn.functional as F
import torchvision
from torch import Tensor, nn

from anomalib.models.components import PCA, DynamicBufferModule, FeatureExtractor
Expand Down Expand Up @@ -103,13 +102,15 @@ def __init__(
score_type: str = "fre",
):
super().__init__()
self.backbone = getattr(torchvision.models, backbone)
self.backbone = backbone
self.pooling_kernel_size = pooling_kernel_size
self.n_components = n_comps
self.pca_model = PCA(n_components=self.n_components)
self.gaussian_model = SingleClassGaussian()
self.score_type = score_type
self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=pre_trained), layers=[layer]).eval()
self.feature_extractor = FeatureExtractor(
backbone=self.backbone, pre_trained=pre_trained, layers=[layer]
).eval()

def fit(self, dataset: Tensor) -> None:
"""Fit a pca transformation and a Gaussian model to dataset.
Expand Down
6 changes: 2 additions & 4 deletions anomalib/models/padim/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import torch
import torch.nn.functional as F
import torchvision
from torch import Tensor, nn

from anomalib.models.components import FeatureExtractor, MultiVariateGaussian
Expand Down Expand Up @@ -52,9 +51,9 @@ def __init__(
super().__init__()
self.tiler: Optional[Tiler] = None

self.backbone = getattr(torchvision.models, backbone)
self.backbone = backbone
self.layers = layers
self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=pre_trained), layers=self.layers)
self.feature_extractor = FeatureExtractor(backbone=self.backbone, layers=layers, pre_trained=pre_trained)
self.dims = DIMS[backbone]
# pylint: disable=not-callable
# Since idx is randomly selected, save it with model to get same results
Expand Down Expand Up @@ -109,7 +108,6 @@ def forward(self, input_tensor: Tensor) -> Tensor:
output = self.anomaly_map_generator(
embedding=embeddings, mean=self.gaussian.mean, inv_covariance=self.gaussian.inv_covariance
)

return output

def generate_embedding(self, features: Dict[str, Tensor]) -> Tensor:
Expand Down
5 changes: 2 additions & 3 deletions anomalib/models/patchcore/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import torch
import torch.nn.functional as F
import torchvision
from torch import Tensor, nn

from anomalib.models.components import (
Expand All @@ -44,12 +43,12 @@ def __init__(
super().__init__()
self.tiler: Optional[Tiler] = None

self.backbone = getattr(torchvision.models, backbone)
self.backbone = backbone
self.layers = layers
self.input_size = input_size
self.num_neighbors = num_neighbors

self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=pre_trained), layers=self.layers)
self.feature_extractor = FeatureExtractor(backbone=self.backbone, pre_trained=pre_trained, layers=self.layers)
self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)

Expand Down
6 changes: 2 additions & 4 deletions anomalib/models/reverse_distillation/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from typing import List, Optional, Tuple, Union

import torchvision
from torch import Tensor, nn

from anomalib.models.components import FeatureExtractor
Expand Down Expand Up @@ -50,9 +49,8 @@ def __init__(
super().__init__()
self.tiler: Optional[Tiler] = None

encoder_backbone = getattr(torchvision.models, backbone)
# TODO replace with TIMM feature extractor
self.encoder = FeatureExtractor(backbone=encoder_backbone(pretrained=pre_trained), layers=layers)
encoder_backbone = backbone
self.encoder = FeatureExtractor(backbone=encoder_backbone, pre_trained=pre_trained, layers=layers)
self.bottleneck = get_bottleneck_layer(backbone)
self.decoder = get_decoder(backbone)

Expand Down
7 changes: 3 additions & 4 deletions anomalib/models/stfpm/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from typing import Dict, List, Optional, Tuple

import torchvision
from torch import Tensor, nn

from anomalib.models.components import FeatureExtractor
Expand All @@ -42,9 +41,9 @@ def __init__(
super().__init__()
self.tiler: Optional[Tiler] = None

self.backbone = getattr(torchvision.models, backbone)
self.teacher_model = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=layers)
self.student_model = FeatureExtractor(backbone=self.backbone(pretrained=False), layers=layers)
self.backbone = backbone
self.teacher_model = FeatureExtractor(backbone=self.backbone, pre_trained=True, layers=layers)
self.student_model = FeatureExtractor(backbone=self.backbone, pre_trained=False, layers=layers)

# teacher model is fixed
for parameters in self.teacher_model.parameters():
Expand Down
2 changes: 1 addition & 1 deletion configs/model/dfkde.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ model:
backbone: resnet18
pre_trained: true
layers:
- avgpool
- layer4
max_training_points: 40000
pre_processing: scale
n_components: 16
Expand Down
35 changes: 35 additions & 0 deletions tests/pre_merge/models/test_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest
import torch

from anomalib.models.components.feature_extractors import FeatureExtractor


class TestFeatureExtractor:
    """Checks the feature maps produced by ``FeatureExtractor`` per backbone."""

    @pytest.mark.parametrize(
        "backbone",
        ["resnet18", "wide_resnet50_2"],
    )
    @pytest.mark.parametrize(
        "pretrained",
        [True, False],
    )
    def test_feature_extraction(self, backbone, pretrained):
        """Compare feature shapes, ``out_dims`` and ``idx`` with known values."""
        batch = 32
        extractor = FeatureExtractor(
            backbone=backbone,
            layers=["layer1", "layer2", "layer3"],
            pre_trained=pretrained,
        )
        outputs = extractor(torch.rand((batch, 3, 256, 256)))

        # Channel counts of layer1..layer3 for each backbone under test.
        channels_by_backbone = {
            "resnet18": [64, 128, 256],
            "wide_resnet50_2": [256, 512, 1024],
        }
        expected_channels = channels_by_backbone.get(backbone)
        if expected_channels is None:
            # Backbones without reference values are not checked.
            return

        # Spatial side length of each feature map for a 256x256 input.
        expected_sides = (64, 32, 16)
        layer_names = ("layer1", "layer2", "layer3")
        for name, channels, side in zip(layer_names, expected_channels, expected_sides):
            assert outputs[name].shape == torch.Size((batch, channels, side, side))

        assert extractor.out_dims == expected_channels
        assert extractor.idx == [1, 2, 3]