Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove freia as dependency and include it in anomalib/models/components #174

Merged
merged 20 commits into from
Apr 7, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:

# python code formatting
- repo: https://github.com/psf/black
rev: 20.8b1
rev: 22.3.0
hooks:
- id: black
args: [--line-length, "120"]
Expand All @@ -37,7 +37,7 @@ repos:
hooks:
- id: flake8
args: [--config=tox.ini]
exclude: "tests/"
exclude: "tests|anomalib/models/components/freia"

# python linting
- repo: local
Expand All @@ -47,7 +47,7 @@ repos:
entry: pylint --score=no --rcfile=tox.ini
language: system
types: [python]
exclude: "tests|docs"
exclude: "tests|docs|anomalib/models/components/freia"

# python static type checking
- repo: https://github.com/pre-commit/mirrors-mypy
Expand All @@ -56,7 +56,7 @@ repos:
- id: mypy
args: [--config-file=tox.ini]
additional_dependencies: [types-PyYAML]
exclude: "tests/"
exclude: "tests|anomalib/models/components/freia"

- repo: https://github.com/PyCQA/pydocstyle
rev: 6.1.1
Expand All @@ -67,4 +67,4 @@ repos:
language: python
types: [python]
args: [--config=tox.ini]
exclude: "tests|docs"
exclude: "tests|docs|anomalib/models/components/freia"
16 changes: 9 additions & 7 deletions anomalib/models/cflow/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

import math

import FrEIA.framework as Ff
import FrEIA.modules as Fm
import torch
from FrEIA.framework.sequence_inn import SequenceINN
from torch import nn

from anomalib.models.components.freia.framework import SequenceINN
from anomalib.models.components.freia.modules import AllInOneBlock


def positional_encoding_2d(condition_vector: int, height: int, width: int) -> torch.Tensor:
"""Creates embedding to store relative position of the feature vector using sine and cosine functions.
Expand Down Expand Up @@ -73,7 +73,9 @@ def subnet_fc(dims_in: int, dims_out: int):
return nn.Sequential(nn.Linear(dims_in, 2 * dims_in), nn.ReLU(), nn.Linear(2 * dims_in, dims_out))


def cflow_head(condition_vector: int, coupling_blocks: int, clamp_alpha: float, n_features: int) -> SequenceINN:
def cflow_head(
condition_vector: int, coupling_blocks: int, clamp_alpha: float, n_features: int, permute_soft: bool = False
) -> SequenceINN:
"""Create invertible decoder network.

Args:
Expand All @@ -85,16 +87,16 @@ def cflow_head(condition_vector: int, coupling_blocks: int, clamp_alpha: float,
Returns:
SequenceINN: decoder network block
"""
coder = Ff.SequenceINN(n_features)
coder = SequenceINN(n_features)
print("CNF coder:", n_features)
for _ in range(coupling_blocks):
coder.append(
Fm.AllInOneBlock,
AllInOneBlock,
cond=0,
cond_shape=(condition_vector,),
subnet_constructor=subnet_fc,
affine_clamping=clamp_alpha,
global_affine_type="SOFTPLUS",
permute_soft=True,
permute_soft=permute_soft,
)
return coder
1 change: 1 addition & 0 deletions anomalib/models/cflow/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ model:
condition_vector: 128
coupling_blocks: 8
clamp_alpha: 1.9
soft_permutation: false
lr: 0.0001
early_stopping:
patience: 2
Expand Down
11 changes: 9 additions & 2 deletions anomalib/models/cflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_logp(dim_feature_vector: int, p_u: torch.Tensor, logdet_j: torch.Tensor)
torch.Tensor: Log probability
"""
ln_sqrt_2pi = -np.log(np.sqrt(2 * np.pi)) # ln(sqrt(2*pi))
logp = dim_feature_vector * ln_sqrt_2pi - 0.5 * torch.sum(p_u ** 2, 1) + logdet_j
logp = dim_feature_vector * ln_sqrt_2pi - 0.5 * torch.sum(p_u**2, 1) + logdet_j
return logp


Expand Down Expand Up @@ -139,9 +139,16 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]):

self.encoder = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.pool_layers)
self.pool_dims = self.encoder.out_dims
permute_soft = hparams.model.soft_permutation
self.decoders = nn.ModuleList(
[
cflow_head(self.condition_vector, hparams.model.coupling_blocks, hparams.model.clamp_alpha, pool_dim)
cflow_head( # type: ignore # pylint:disable=too-many-function-args
self.condition_vector,
hparams.model.coupling_blocks,
hparams.model.clamp_alpha,
pool_dim,
permute_soft,
)
for pool_dim in self.pool_dims
]
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def johnson_lindenstrauss_min_dim(self, n_samples: int, eps: float = 0.1):
eps (float, optional): Minimum distortion rate. Defaults to 0.1.
"""

denominator = (eps ** 2 / 2) - (eps ** 3 / 3)
denominator = (eps**2 / 2) - (eps**3 / 3)
return (4 * np.log(n_samples) / denominator).astype(np.int64)

def fit(self, embedding: Tensor) -> "SparseRandomProjection":
Expand Down
5 changes: 5 additions & 0 deletions anomalib/models/components/freia/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
## FrEIA
This sub-package contains freia packages to use within flow-based algorithms such as Cflow.

## Description
[FrEIA](https://github.com/VLL-HD/FrEIA) package is currently not available in pypi to install via pip. The only way to install it is `pip install git+https://github.com/VLL-HD/FrEIA.git`. PyPI, however, does not support installing packages from git links. Due to this limitation, anomalib cannot be updated on PyPI. To avoid this, `anomalib` contains some of the [FrEIA](https://github.com/VLL-HD/FrEIA) modules to facilitate CFlow training/inference.
11 changes: 11 additions & 0 deletions anomalib/models/components/freia/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Framework for Easily Invertible Architectures.

Module to construct invertible networks with pytorch, based on a graph
structure of operations.
LeonidBeynenson marked this conversation as resolved.
Show resolved Hide resolved

Link to the original repo: https://github.com/VLL-HD/FrEIA
"""

# Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg.
# SPDX-License-Identifier: MIT
#
9 changes: 9 additions & 0 deletions anomalib/models/components/freia/framework/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Framework."""

# Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg.
# SPDX-License-Identifier: MIT
#

from .sequence_inn import SequenceINN

__all__ = ["SequenceINN"]
120 changes: 120 additions & 0 deletions anomalib/models/components/freia/framework/sequence_inn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Sequence INN."""

# Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg.
# SPDX-License-Identifier: MIT
#

# pylint: disable=invalid-name
# flake8: noqa
# pylint: skip-file
# type: ignore
# pydocstyle: noqa

from typing import Iterable, List, Tuple

import torch
from torch import Tensor, nn

from anomalib.models.components.freia.modules.base import InvertibleModule


class SequenceINN(InvertibleModule):
"""Simpler than FrEIA.framework.GraphINN.

Only supports a sequential series of modules (no splitting, merging,
branching off).
Has an append() method, to add new blocks in a more simple way than the
computation-graph based approach of GraphINN. For example:
.. code-block:: python
inn = SequenceINN(channels, dims_H, dims_W)
for i in range(n_blocks):
inn.append(FrEIA.modules.AllInOneBlock, clamp=2.0, permute_soft=True)
inn.append(FrEIA.modules.HaarDownsampling)
# and so on
"""

def __init__(self, *dims: int, force_tuple_output=False):
super().__init__([dims])

self.shapes = [tuple(dims)]
self.conditions = []
self.module_list = nn.ModuleList()

self.force_tuple_output = force_tuple_output

def append(self, module_class, cond=None, cond_shape=None, **kwargs):
"""Append a reversible block from FrEIA.modules to the network.

Args:
module_class: Class from FrEIA.modules.
cond (int): index of which condition to use (conditions will be passed as list to forward()).
Conditioning nodes are not needed for SequenceINN.
cond_shape (tuple[int]): the shape of the condition tensor.
**kwargs: Further keyword arguments that are passed to the constructor of module_class (see example).
"""

dims_in = [self.shapes[-1]]
self.conditions.append(cond)

if cond is not None:
kwargs["dims_c"] = [cond_shape]

module = module_class(dims_in, **kwargs)
self.module_list.append(module)
ouput_dims = module.output_dims(dims_in)
assert len(ouput_dims) == 1, "Module has more than one output"
self.shapes.append(ouput_dims[0])

def __getitem__(self, item):
"""Get item."""
return self.module_list.__getitem__(item)

def __len__(self):
"""Get length."""
return self.module_list.__len__()

def __iter__(self):
"""Iter."""
return self.module_list.__iter__()

def output_dims(self, input_dims: List[Tuple[int]]) -> List[Tuple[int]]:
"""Output Dims."""
if not self.force_tuple_output:
raise ValueError(
"You can only call output_dims on a SequentialINN " "when setting force_tuple_output=True."
)
return input_dims

def forward(
self, x_or_z: Tensor, c: Iterable[Tensor] = None, rev: bool = False, jac: bool = True
) -> Tuple[Tensor, Tensor]:
"""Execute the sequential INN in forward or inverse (rev=True) direction.

Args:
x_or_z: input tensor (in contrast to GraphINN, a list of
tensors is not supported, as SequenceINN only has
one input).
c: list of conditions.
rev: whether to compute the network forward or reversed.
jac: whether to compute the log jacobian
Returns:
z_or_x (Tensor): network output.
jac (Tensor): log-jacobian-determinant.
"""

iterator = range(len(self.module_list))
log_det_jac = 0

if rev:
iterator = reversed(iterator)

if torch.is_tensor(x_or_z):
x_or_z = (x_or_z,)
for i in iterator:
if self.conditions[i] is None:
x_or_z, j = self.module_list[i](x_or_z, jac=jac, rev=rev)
else:
x_or_z, j = self.module_list[i](x_or_z, c=[c[self.conditions[i]]], jac=jac, rev=rev)
log_det_jac = j + log_det_jac

return x_or_z if self.force_tuple_output else x_or_z[0], log_det_jac
10 changes: 10 additions & 0 deletions anomalib/models/components/freia/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Modules."""

# Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg.
# SPDX-License-Identifier: MIT
#

from .all_in_one_block import AllInOneBlock
from .base import InvertibleModule

__all__ = ["AllInOneBlock", "InvertibleModule"]
Loading