From 83b4d543b4e4efa94a38e06a2efc425b1fa90fc2 Mon Sep 17 00:00:00 2001
From: Samet Akcay
Date: Thu, 7 Apr 2022 12:51:22 +0100
Subject: [PATCH] Remove `freia` as a dependency and include it in
 `anomalib/models/components` (#174)

* Removed FrEIA from the dependencies
* Added soft permutation as an option to the config
* Update pre-commit
* Added modules
* Updated tox.ini to exclude freia packages.
* Relative imports
* Ignore pylint in sequence_inn
* Addressed pydocstyle
* replaced FrEIA with anomalib freia modules
* Add ignores to freia
* Use nncf version instead
* Added freia to third party programs.
* Bump up pre-commit black
* Address black version and dependency issues
* Added github link to the init
* Added readme file to freia
* removed linting ignores from the modules
* Address Dick's comments
---
 .pre-commit-config.yaml                       |   8 +-
 anomalib/models/cflow/backbone.py             |  10 +-
 anomalib/models/components/freia/README.md    |   5 +
 anomalib/models/components/freia/__init__.py  |  11 +
 .../components/freia/framework/__init__.py    |   9 +
 .../freia/framework/sequence_inn.py           | 120 ++++++++
 .../components/freia/modules/__init__.py      |  10 +
 .../freia/modules/all_in_one_block.py         | 289 ++++++++++++++++++
 .../models/components/freia/modules/base.py   | 112 +++++++
 requirements/base.txt                         |   1 -
 requirements/openvino.txt                     |   2 +-
 third-party-programs.txt                      |   4 +
 tox.ini                                       |   9 +-
 13 files changed, 576 insertions(+), 14 deletions(-)
 create mode 100644 anomalib/models/components/freia/README.md
 create mode 100644 anomalib/models/components/freia/__init__.py
 create mode 100644 anomalib/models/components/freia/framework/__init__.py
 create mode 100644 anomalib/models/components/freia/framework/sequence_inn.py
 create mode 100644 anomalib/models/components/freia/modules/__init__.py
 create mode 100644 anomalib/models/components/freia/modules/all_in_one_block.py
 create mode 100644 anomalib/models/components/freia/modules/base.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index db8840d270..3178f24da9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -37,7 +37,7 @@ repos:
     hooks:
       - id: flake8
         args: [--config=tox.ini]
-        exclude: "tests/"
+        exclude: "tests|anomalib/models/components/freia"

   # python linting
   - repo: local
@@ -47,7 +47,7 @@ repos:
         entry: pylint --score=no --rcfile=tox.ini
         language: system
        types: [python]
-        exclude: "tests|docs"
+        exclude: "tests|docs|anomalib/models/components/freia"

   # python static type checking
   - repo: https://github.com/pre-commit/mirrors-mypy
@@ -56,7 +56,7 @@ repos:
       - id: mypy
         args: [--config-file=tox.ini]
         additional_dependencies: [types-PyYAML]
-        exclude: "tests/"
+        exclude: "tests|anomalib/models/components/freia"

   - repo: https://github.com/PyCQA/pydocstyle
     rev: 6.1.1
@@ -67,4 +67,4 @@ repos:
         language: python
         types: [python]
         args: [--config=tox.ini]
-        exclude: "tests|docs"
+        exclude: "tests|docs|anomalib/models/components/freia"
diff --git a/anomalib/models/cflow/backbone.py b/anomalib/models/cflow/backbone.py
index d0b23c24ee..f4c0f13c6e 100644
--- a/anomalib/models/cflow/backbone.py
+++ b/anomalib/models/cflow/backbone.py
@@ -16,12 +16,12 @@

 import math

-import FrEIA.framework as Ff
-import FrEIA.modules as Fm
 import torch
-from FrEIA.framework.sequence_inn import SequenceINN
 from torch import nn

+from anomalib.models.components.freia.framework import SequenceINN
+from anomalib.models.components.freia.modules import AllInOneBlock
+

 def positional_encoding_2d(condition_vector: int, height: int, width: int) -> torch.Tensor:
    """Creates embedding to store relative position of the feature vector using sine and cosine functions.
@@ -90,11 +90,11 @@ def cflow_head(
     Returns:
         SequenceINN: decoder network block
     """
-    coder = Ff.SequenceINN(n_features)
+    coder = SequenceINN(n_features)
     print("CNF coder:", n_features)
     for _ in range(coupling_blocks):
         coder.append(
-            Fm.AllInOneBlock,
+            AllInOneBlock,
             cond=0,
             cond_shape=(condition_vector,),
             subnet_constructor=subnet_fc,
diff --git a/anomalib/models/components/freia/README.md b/anomalib/models/components/freia/README.md
new file mode 100644
index 0000000000..b9ef90a20b
--- /dev/null
+++ b/anomalib/models/components/freia/README.md
@@ -0,0 +1,5 @@
+## FrEIA
+This sub-package contains the FrEIA modules used by flow-based algorithms such as CFlow.
+
+## Description
+The [FrEIA](https://github.com/VLL-HD/FrEIA) package is currently not available on PyPI; the only way to install it is `pip install git+https://github.com/VLL-HD/FrEIA.git`. PyPI, however, does not support packages that depend on git links, so `anomalib` could not be published to PyPI with FrEIA as a dependency. To avoid this limitation, `anomalib` vendors the [FrEIA](https://github.com/VLL-HD/FrEIA) modules needed for CFlow training/inference.
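For downstream code, the swap is a one-line change per import; a minimal sketch mirroring the `backbone.py` hunk above:

```python
# Before: imports resolved against the external FrEIA distribution.
# import FrEIA.framework as Ff
# import FrEIA.modules as Fm

# After: imports resolved against the copy vendored inside anomalib.
from anomalib.models.components.freia.framework import SequenceINN
from anomalib.models.components.freia.modules import AllInOneBlock
```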
diff --git a/anomalib/models/components/freia/__init__.py b/anomalib/models/components/freia/__init__.py
new file mode 100644
index 0000000000..b52144c7a7
--- /dev/null
+++ b/anomalib/models/components/freia/__init__.py
@@ -0,0 +1,11 @@
+"""Framework for Easily Invertible Architectures.
+
+Module to construct invertible networks with pytorch, based on a graph
+structure of operations.
+
+Link to the original repo: https://github.com/VLL-HD/FrEIA
+"""
+
+# Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg.
+# SPDX-License-Identifier: MIT
+#
diff --git a/anomalib/models/components/freia/framework/__init__.py b/anomalib/models/components/freia/framework/__init__.py
new file mode 100644
index 0000000000..226ceebd51
--- /dev/null
+++ b/anomalib/models/components/freia/framework/__init__.py
@@ -0,0 +1,9 @@
+"""Framework."""
+
+# Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg.
+# SPDX-License-Identifier: MIT
+#
+
+from .sequence_inn import SequenceINN
+
+__all__ = ["SequenceINN"]
diff --git a/anomalib/models/components/freia/framework/sequence_inn.py b/anomalib/models/components/freia/framework/sequence_inn.py
new file mode 100644
index 0000000000..a5c05d8291
--- /dev/null
+++ b/anomalib/models/components/freia/framework/sequence_inn.py
@@ -0,0 +1,120 @@
+"""Sequence INN."""
+
+# Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg.
+# SPDX-License-Identifier: MIT
+#
+
+# pylint: disable=invalid-name
+# flake8: noqa
+# pylint: skip-file
+# type: ignore
+# pydocstyle: noqa
+
+from typing import Iterable, List, Tuple
+
+import torch
+from torch import Tensor, nn
+
+from anomalib.models.components.freia.modules.base import InvertibleModule
+
+
+class SequenceINN(InvertibleModule):
+    """Simpler than FrEIA.framework.GraphINN.
+
+    Only supports a sequential series of modules (no splitting, merging,
+    branching off).
+    Has an append() method to add new blocks in a simpler way than the
+    computation-graph based approach of GraphINN. For example:
+
+    .. code-block:: python
+
+        inn = SequenceINN(channels, dims_H, dims_W)
+        for i in range(n_blocks):
+            inn.append(FrEIA.modules.AllInOneBlock, clamp=2.0, permute_soft=True)
+        inn.append(FrEIA.modules.HaarDownsampling)
+        # and so on
+    """
+
+    def __init__(self, *dims: int, force_tuple_output=False):
+        super().__init__([dims])
+
+        self.shapes = [tuple(dims)]
+        self.conditions = []
+        self.module_list = nn.ModuleList()
+
+        self.force_tuple_output = force_tuple_output
+
+    def append(self, module_class, cond=None, cond_shape=None, **kwargs):
+        """Append a reversible block from FrEIA.modules to the network.
+
+        Args:
+            module_class: Class from FrEIA.modules.
+            cond (int): index of which condition to use (conditions will be passed as list to forward()).
+                Conditioning nodes are not needed for SequenceINN.
+            cond_shape (tuple[int]): the shape of the condition tensor.
+            **kwargs: Further keyword arguments that are passed to the constructor of module_class (see example).
+        """
+
+        dims_in = [self.shapes[-1]]
+        self.conditions.append(cond)
+
+        if cond is not None:
+            kwargs["dims_c"] = [cond_shape]
+
+        module = module_class(dims_in, **kwargs)
+        self.module_list.append(module)
+        output_dims = module.output_dims(dims_in)
+        assert len(output_dims) == 1, "Module has more than one output"
+        self.shapes.append(output_dims[0])
+
+    def __getitem__(self, item):
+        """Get item."""
+        return self.module_list.__getitem__(item)
+
+    def __len__(self):
+        """Get length."""
+        return self.module_list.__len__()
+
+    def __iter__(self):
+        """Iter."""
+        return self.module_list.__iter__()
+
+    def output_dims(self, input_dims: List[Tuple[int]]) -> List[Tuple[int]]:
+        """Output Dims."""
+        if not self.force_tuple_output:
+            raise ValueError("You can only call output_dims on a SequenceINN when setting force_tuple_output=True.")
+        return input_dims
+
+    def forward(
+        self, x_or_z: Tensor, c: Iterable[Tensor] = None, rev: bool = False, jac: bool = True
+    ) -> Tuple[Tensor, Tensor]:
+        """Execute the sequential INN in forward or inverse (rev=True) direction.
+
+        Args:
+            x_or_z: input tensor (in contrast to GraphINN, a list of
+                tensors is not supported, as SequenceINN only has
+                one input).
+            c: list of conditions.
+            rev: whether to compute the network forward or reversed.
+            jac: whether to compute the log Jacobian.
+
+        Returns:
+            z_or_x (Tensor): network output.
+            jac (Tensor): log-Jacobian-determinant.
+        """
+
+        iterator = range(len(self.module_list))
+        log_det_jac = 0
+
+        if rev:
+            iterator = reversed(iterator)
+
+        if torch.is_tensor(x_or_z):
+            x_or_z = (x_or_z,)
+        for i in iterator:
+            if self.conditions[i] is None:
+                x_or_z, j = self.module_list[i](x_or_z, jac=jac, rev=rev)
+            else:
+                x_or_z, j = self.module_list[i](x_or_z, c=[c[self.conditions[i]]], jac=jac, rev=rev)
+            log_det_jac = j + log_det_jac
+
+        return x_or_z if self.force_tuple_output else x_or_z[0], log_det_jac
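A minimal usage sketch (not part of the patch) mirroring how `cflow_head` drives this class; `subnet_fc` here is a placeholder fully-connected constructor and the dimensions are illustrative:

```python
import torch
from torch import nn

from anomalib.models.components.freia.framework import SequenceINN
from anomalib.models.components.freia.modules import AllInOneBlock

def subnet_fc(dims_in: int, dims_out: int) -> nn.Module:
    # Placeholder coupling subnet predicting the scale/shift coefficients.
    return nn.Sequential(nn.Linear(dims_in, 2 * dims_in), nn.ReLU(), nn.Linear(2 * dims_in, dims_out))

n_features, condition_vector = 64, 32
coder = SequenceINN(n_features)
for _ in range(2):
    coder.append(AllInOneBlock, cond=0, cond_shape=(condition_vector,), subnet_constructor=subnet_fc)

features = torch.randn(8, n_features)
condition = torch.randn(8, condition_vector)
z, log_jac_det = coder(features, c=[condition])  # condition index 0 -> c[0]
```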
diff --git a/anomalib/models/components/freia/modules/__init__.py b/anomalib/models/components/freia/modules/__init__.py
new file mode 100644
index 0000000000..4060ed6bfe
--- /dev/null
+++ b/anomalib/models/components/freia/modules/__init__.py
@@ -0,0 +1,10 @@
+"""Modules."""
+
+# Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg.
+# SPDX-License-Identifier: MIT
+#
+
+from .all_in_one_block import AllInOneBlock
+from .base import InvertibleModule
+
+__all__ = ["AllInOneBlock", "InvertibleModule"]
diff --git a/anomalib/models/components/freia/modules/all_in_one_block.py b/anomalib/models/components/freia/modules/all_in_one_block.py
new file mode 100644
index 0000000000..cc35c1c3f6
--- /dev/null
+++ b/anomalib/models/components/freia/modules/all_in_one_block.py
@@ -0,0 +1,289 @@
+"""All in One Block Module."""
+
+# Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg.
+# SPDX-License-Identifier: MIT
+#
+
+# flake8: noqa
+# pylint: skip-file
+# type: ignore
+# pydocstyle: noqa
+
+import warnings
+from typing import Callable
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from scipy.stats import special_ortho_group
+
+from anomalib.models.components.freia.modules.base import InvertibleModule
+
+
+class AllInOneBlock(InvertibleModule):
+    r"""Module combining the most common operations in a normalizing flow or similar model.
+
+    It combines affine coupling, permutation, and global affine transformation
+    ('ActNorm'). It can also be used as a GIN coupling block, perform learned
+    householder permutations, and use an inverted pre-permutation. The affine
+    transformation includes a soft clamping mechanism, first used in Real-NVP.
+    The block as a whole performs the following computation:
+
+    .. math::
+
+        y = V\,R \; \Psi(s_\mathrm{global}) \odot \mathrm{Coupling}\Big(R^{-1} V^{-1} x\Big) + t_\mathrm{global}
+
+    - The inverse pre-permutation of x (i.e. :math:`R^{-1} V^{-1}`) is optional (see
+      ``reverse_permutation`` below).
+    - The learned householder reflection matrix
+      :math:`V` is also optional altogether (see ``learned_householder_permutation``
+      below).
+    - For the coupling, the input is split into :math:`x_1, x_2` along
+      the channel dimension. Then the output of the coupling operation is the
+      two halves :math:`u = \mathrm{concat}(u_1, u_2)`.
+
+      .. math::
+
+          u_1 &= x_1 \odot \exp \Big( \alpha \; \mathrm{tanh}\big( s(x_2) \big)\Big) + t(x_2) \\
+          u_2 &= x_2
+
+      Because :math:`\mathrm{tanh}(s) \in [-1, 1]`, this clamping mechanism prevents
+      exploding values in the exponential. The hyperparameter :math:`\alpha` can be adjusted.
+    """
+
+    def __init__(
+        self,
+        dims_in,
+        dims_c=[],
+        subnet_constructor: Callable = None,
+        affine_clamping: float = 2.0,
+        gin_block: bool = False,
+        global_affine_init: float = 1.0,
+        global_affine_type: str = "SOFTPLUS",
+        permute_soft: bool = False,
+        learned_householder_permutation: int = 0,
+        reverse_permutation: bool = False,
+    ):
+        r"""Initialize.
+
+        Args:
+            dims_in: list of tuples specifying the shape of the inputs to this
+                operator: ``dims_in = [shape_x_0, shape_x_1, ...]``.
+            dims_c (list, optional): list of tuples specifying the shape of the
+                conditions to this operator. Defaults to [].
+            subnet_constructor (Callable, optional): class or callable ``f``, called as ``f(channels_in, channels_out)`` and
+                should return a torch.nn.Module. Predicts coupling coefficients :math:`s, t`. Defaults to None.
+            affine_clamping (float, optional): clamp the output of the multiplicative coefficients before
+                exponentiation to +/- ``affine_clamping`` (see :math:`\alpha` above). Defaults to 2.0.
+            gin_block (bool, optional): Turn the block into a GIN block from Sorrenson et al, 2019.
+                Makes it so that the coupling operation as a whole is volume preserving. Defaults to False.
+            global_affine_init (float, optional): Initial value for the global affine scaling :math:`s_\mathrm{global}`. Defaults to 1.0.
+            global_affine_type (str, optional): ``'SIGMOID'``, ``'SOFTPLUS'``, or ``'EXP'``. Defines the activation to be used
+                on the beta for the global affine scaling (:math:`\Psi` above). Defaults to "SOFTPLUS".
+            permute_soft (bool, optional): whether to sample the permutation matrix :math:`R` from :math:`SO(N)`,
+                or to use hard permutations instead. Note, ``permute_soft=True`` is very slow
+                when working with >512 dimensions. Defaults to False.
+            learned_householder_permutation (int, optional): if >0, turn on the matrix :math:`V` above, that represents
+                multiple learned householder reflections. Slow for a large number of reflections.
+                Dubious whether it actually helps network performance. Defaults to 0.
+            reverse_permutation (bool, optional): Reverse the permutation before the block, as introduced by Putzky
+                et al, 2019. Turns on the :math:`R^{-1} V^{-1}` pre-multiplication above. Defaults to False.
+
+        Raises:
+            ValueError: if the input tensors are not 1D-4D.
+            ValueError: if ``global_affine_type`` is not one of ``'SIGMOID'``, ``'SOFTPLUS'`` or ``'EXP'``.
+            ValueError: if ``subnet_constructor`` is None.
+        """
+
+        super().__init__(dims_in, dims_c)
+
+        channels = dims_in[0][0]
+        # rank of the tensors means 1d, 2d, 3d tensor etc.
+        self.input_rank = len(dims_in[0]) - 1
+        # tuple containing all dims except for batch-dim (used at various points)
+        self.sum_dims = tuple(range(1, 2 + self.input_rank))
+
+        if len(dims_c) == 0:
+            self.conditional = False
+            self.condition_channels = 0
+        else:
+            assert tuple(dims_c[0][1:]) == tuple(
+                dims_in[0][1:]
+            ), f"Dimensions of input and condition don't agree: {dims_c} vs {dims_in}."
+            self.conditional = True
+            self.condition_channels = sum(dc[0] for dc in dims_c)
+
+        split_len1 = channels - channels // 2
+        split_len2 = channels // 2
+        self.splits = [split_len1, split_len2]
+
+        try:
+            self.permute_function = {0: F.linear, 1: F.conv1d, 2: F.conv2d, 3: F.conv3d}[self.input_rank]
+        except KeyError:
+            raise ValueError(f"Data is {1 + self.input_rank}D. Must be 1D-4D.")
+
+        self.in_channels = channels
+        self.clamp = affine_clamping
+        self.GIN = gin_block
+        self.reverse_pre_permute = reverse_permutation
+        self.householder = learned_householder_permutation
+
+        if permute_soft and channels > 512:
+            warnings.warn(
+                (
+                    "Soft permutation will take a very long time to initialize "
+                    f"with {channels} feature channels. Consider using hard permutation instead."
+                )
+            )
+
+        # global_scale is used as the initial value for the global affine scale
+        # (pre-activation). It is computed such that
+        # global_scale_activation(global_scale) = global_affine_init
+        # the 'magic numbers' (specifically for sigmoid) scale the activation to
+        # a sensible range.
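+        # For example, with the default SOFTPLUS type the activation below is
+        # 0.1 * softplus(a) with beta=0.5, i.e. 0.2 * log(1 + exp(0.5 * a));
+        # solving 0.2 * log(1 + exp(0.5 * a)) = global_affine_init for a yields
+        # a = 2 * log(exp(5 * global_affine_init) - 1), which is exactly the
+        # SOFTPLUS branch below.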
+        if global_affine_type == "SIGMOID":
+            global_scale = 2.0 - np.log(10.0 / global_affine_init - 1.0)
+            self.global_scale_activation = lambda a: 10 * torch.sigmoid(a - 2.0)
+        elif global_affine_type == "SOFTPLUS":
+            global_scale = 2.0 * np.log(np.exp(0.5 * 10.0 * global_affine_init) - 1)
+            self.softplus = nn.Softplus(beta=0.5)
+            self.global_scale_activation = lambda a: 0.1 * self.softplus(a)
+        elif global_affine_type == "EXP":
+            global_scale = np.log(global_affine_init)
+            self.global_scale_activation = lambda a: torch.exp(a)
+        else:
+            raise ValueError('Global affine activation must be "SIGMOID", "SOFTPLUS" or "EXP"')
+
+        self.global_scale = nn.Parameter(
+            torch.ones(1, self.in_channels, *([1] * self.input_rank)) * float(global_scale)
+        )
+        self.global_offset = nn.Parameter(torch.zeros(1, self.in_channels, *([1] * self.input_rank)))
+
+        if permute_soft:
+            w = special_ortho_group.rvs(channels)
+        else:
+            w = np.zeros((channels, channels))
+            for i, j in enumerate(np.random.permutation(channels)):
+                w[i, j] = 1.0
+
+        if self.householder:
+            # instead of just the permutation matrix w, the learned householder
+            # permutation keeps track of reflection vectors vk, in addition to a
+            # random initial permutation w_0.
+            self.vk_householder = nn.Parameter(0.2 * torch.randn(self.householder, channels), requires_grad=True)
+            self.w_perm = None
+            self.w_perm_inv = None
+            self.w_0 = nn.Parameter(torch.FloatTensor(w), requires_grad=False)
+        else:
+            self.w_perm = nn.Parameter(
+                torch.FloatTensor(w).view(channels, channels, *([1] * self.input_rank)), requires_grad=False
+            )
+            self.w_perm_inv = nn.Parameter(
+                torch.FloatTensor(w.T).view(channels, channels, *([1] * self.input_rank)), requires_grad=False
+            )
+
+        if subnet_constructor is None:
+            raise ValueError("Please supply a callable subnet_constructor function or object (see docstring)")
+        self.subnet = subnet_constructor(self.splits[0] + self.condition_channels, 2 * self.splits[1])
+        self.last_jac = None
+
+    def _construct_householder_permutation(self):
+        """Compute a permutation matrix.
+
+        Compute a permutation matrix from the reflection vectors that are
+        learned internally as nn.Parameters.
+        """
+        w = self.w_0
+        for vk in self.vk_householder:
+            w = torch.mm(w, torch.eye(self.in_channels).to(w.device) - 2 * torch.ger(vk, vk) / torch.dot(vk, vk))
+
+        for i in range(self.input_rank):
+            w = w.unsqueeze(-1)
+        return w
+
+    def _permute(self, x, rev=False):
+        """Perform permutation.
+
+        Performs the permutation and scaling after the coupling operation.
+        Returns transformed outputs and the LogJacDet of the scaling operation.
+        """
+        if self.GIN:
+            scale = 1.0
+            perm_log_jac = 0.0
+        else:
+            scale = self.global_scale_activation(self.global_scale)
+            perm_log_jac = torch.sum(torch.log(scale))
+
+        if rev:
+            return ((self.permute_function(x, self.w_perm_inv) - self.global_offset) / scale, perm_log_jac)
+        else:
+            return (self.permute_function(x * scale + self.global_offset, self.w_perm), perm_log_jac)
+
+    def _pre_permute(self, x, rev=False):
+        """Permute before the coupling block, only used if reverse_permutation is set."""
+        if rev:
+            return self.permute_function(x, self.w_perm)
+        else:
+            return self.permute_function(x, self.w_perm_inv)
+ """ + + # the entire coupling coefficient tensor is scaled down by a + # factor of ten for stability and easier initialization. + a *= 0.1 + ch = x.shape[1] + + sub_jac = self.clamp * torch.tanh(a[:, :ch]) + if self.GIN: + sub_jac -= torch.mean(sub_jac, dim=self.sum_dims, keepdim=True) + + if not rev: + return (x * torch.exp(sub_jac) + a[:, ch:], torch.sum(sub_jac, dim=self.sum_dims)) + else: + return ((x - a[:, ch:]) * torch.exp(-sub_jac), -torch.sum(sub_jac, dim=self.sum_dims)) + + def forward(self, x, c=[], rev=False, jac=True): + """See base class docstring.""" + if self.householder: + self.w_perm = self._construct_householder_permutation() + if rev or self.reverse_pre_permute: + self.w_perm_inv = self.w_perm.transpose(0, 1).contiguous() + + if rev: + x, global_scaling_jac = self._permute(x[0], rev=True) + x = (x,) + elif self.reverse_pre_permute: + x = (self._pre_permute(x[0], rev=False),) + + x1, x2 = torch.split(x[0], self.splits, dim=1) + + if self.conditional: + x1c = torch.cat([x1, *c], 1) + else: + x1c = x1 + + if not rev: + a1 = self.subnet(x1c) + x2, j2 = self._affine(x2, a1) + else: + a1 = self.subnet(x1c) + x2, j2 = self._affine(x2, a1, rev=True) + + log_jac_det = j2 + x_out = torch.cat((x1, x2), 1) + + if not rev: + x_out, global_scaling_jac = self._permute(x_out, rev=False) + elif self.reverse_pre_permute: + x_out = self._pre_permute(x_out, rev=True) + + # add the global scaling Jacobian to the total. + # trick to get the total number of non-channel dimensions: + # number of elements of the first channel of the first batch member + n_pixels = x_out[0, :1].numel() + log_jac_det += (-1) ** rev * n_pixels * global_scaling_jac + + return (x_out,), log_jac_det + + def output_dims(self, input_dims): + """Output Dims.""" + return input_dims diff --git a/anomalib/models/components/freia/modules/base.py b/anomalib/models/components/freia/modules/base.py new file mode 100644 index 0000000000..0a67d31541 --- /dev/null +++ b/anomalib/models/components/freia/modules/base.py @@ -0,0 +1,112 @@ +"""Base Module.""" + +# Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg. +# SPDX-License-Identifier: MIT +# + +# flake8: noqa +# pylint: skip-file +# type: ignore +# pydocstyle: noqa + +from typing import Iterable, List, Tuple + +import torch.nn as nn +from torch import Tensor + + +class InvertibleModule(nn.Module): + r"""Base class for all invertible modules in FrEIA. + + Given ``module``, an instance of some InvertibleModule. + This ``module`` shall be invertible in its input dimensions, + so that the input can be recovered by applying the module + in backwards mode (``rev=True``), not to be confused with + ``pytorch.backward()`` which computes the gradient of an operation:: + x = torch.randn(BATCH_SIZE, DIM_COUNT) + c = torch.randn(BATCH_SIZE, CONDITION_DIM) + # Forward mode + z, jac = module([x], [c], jac=True) + # Backward mode + x_rev, jac_rev = module(z, [c], rev=True) + The ``module`` returns :math:`\\log \\det J = \\log \\left| \\det \\frac{\\partial f}{\\partial x} \\right|` + of the operation in forward mode, and + :math:`-\\log | \\det J | = \\log \\left| \\det \\frac{\\partial f^{-1}}{\\partial z} \\right| = -\\log \\left| \\det \\frac{\\partial f}{\\partial x} \\right|` + in backward mode (``rev=True``). + Then, ``torch.allclose(x, x_rev) == True`` and ``torch.allclose(jac, -jac_rev) == True``. + """ + + def __init__(self, dims_in: Iterable[Tuple[int]], dims_c: Iterable[Tuple[int]] = None): + """Initialize. 
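As a sanity check, a hedged sketch (not part of the patch) verifying that a single block inverts and follows the Jacobian sign convention documented in `base.py`; the subnet is a placeholder:

```python
import torch
from torch import nn

from anomalib.models.components.freia.modules import AllInOneBlock

def subnet(channels_in: int, channels_out: int) -> nn.Module:
    # Placeholder subnet; any f(channels_in, channels_out) -> nn.Module works.
    return nn.Sequential(nn.Linear(channels_in, 64), nn.ReLU(), nn.Linear(64, channels_out))

block = AllInOneBlock([(16,)], subnet_constructor=subnet)
x = torch.randn(4, 16)

(z,), jac = block((x,), jac=True)            # forward: returns +log|det J|
(x_rev,), jac_rev = block((z,), rev=True)    # inverse: returns -log|det J|

assert torch.allclose(x, x_rev, atol=1e-4)
assert torch.allclose(jac, -jac_rev, atol=1e-4)
```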
diff --git a/anomalib/models/components/freia/modules/base.py b/anomalib/models/components/freia/modules/base.py
new file mode 100644
index 0000000000..0a67d31541
--- /dev/null
+++ b/anomalib/models/components/freia/modules/base.py
@@ -0,0 +1,112 @@
+"""Base Module."""
+
+# Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg.
+# SPDX-License-Identifier: MIT
+#
+
+# flake8: noqa
+# pylint: skip-file
+# type: ignore
+# pydocstyle: noqa
+
+from typing import Iterable, List, Tuple
+
+import torch.nn as nn
+from torch import Tensor
+
+
+class InvertibleModule(nn.Module):
+    r"""Base class for all invertible modules in FrEIA.
+
+    Given ``module``, an instance of some InvertibleModule.
+    This ``module`` shall be invertible in its input dimensions,
+    so that the input can be recovered by applying the module
+    in backwards mode (``rev=True``), not to be confused with
+    ``pytorch.backward()`` which computes the gradient of an operation::
+
+        x = torch.randn(BATCH_SIZE, DIM_COUNT)
+        c = torch.randn(BATCH_SIZE, CONDITION_DIM)
+
+        # Forward mode
+        z, jac = module([x], [c], jac=True)
+
+        # Backward mode
+        x_rev, jac_rev = module(z, [c], rev=True)
+
+    The ``module`` returns :math:`\log \det J = \log \left| \det \frac{\partial f}{\partial x} \right|`
+    of the operation in forward mode, and
+    :math:`-\log | \det J | = \log \left| \det \frac{\partial f^{-1}}{\partial z} \right| = -\log \left| \det \frac{\partial f}{\partial x} \right|`
+    in backward mode (``rev=True``).
+    Then, ``torch.allclose(x, x_rev) == True`` and ``torch.allclose(jac, -jac_rev) == True``.
+    """
+
+    def __init__(self, dims_in: Iterable[Tuple[int]], dims_c: Iterable[Tuple[int]] = None):
+        """Initialize.
+
+        Args:
+            dims_in: list of tuples specifying the shape of the inputs to this
+                operator: ``dims_in = [shape_x_0, shape_x_1, ...]``
+            dims_c: list of tuples specifying the shape of the conditions to
+                this operator.
+        """
+        super().__init__()
+        if dims_c is None:
+            dims_c = []
+        self.dims_in = list(dims_in)
+        self.dims_c = list(dims_c)
+
+    def forward(
+        self, x_or_z: Iterable[Tensor], c: Iterable[Tensor] = None, rev: bool = False, jac: bool = True
+    ) -> Tuple[Tuple[Tensor], Tensor]:
+        r"""Forward/Backward Pass.
+
+        Perform a forward (default, ``rev=False``) or backward pass (``rev=True``) through this module/operator.
+
+        **Note to implementers:**
+
+        - Subclasses MUST return a Jacobian when ``jac=True``, but CAN return a
+          valid Jacobian when ``jac=False`` (not punished). The latter is only recommended
+          if the computation of the Jacobian is trivial.
+        - Subclasses MUST follow the convention that the returned Jacobian be
+          consistent with the evaluation direction. Let's make this more precise:
+          Let :math:`f` be the function that the subclass represents. Then:
+
+          .. math::
+
+              J &= \log \det \frac{\partial f}{\partial x} \\
+              -J &= \log \det \frac{\partial f^{-1}}{\partial z}.
+
+          Any subclass MUST return :math:`J` for forward evaluation (``rev=False``),
+          and :math:`-J` for backward evaluation (``rev=True``).
+
+        Args:
+            x_or_z: input data (array-like of one or more tensors)
+            c: conditioning data (array-like of none or more tensors)
+            rev: perform backward pass
+            jac: return Jacobian associated to the direction
+        """
+        raise NotImplementedError(f"{self.__class__.__name__} does not provide forward(...) method")
+
+    def log_jacobian(self, *args, **kwargs):
+        """This method is deprecated, and does nothing except raise a warning."""
+        raise DeprecationWarning(
+            "module.log_jacobian(...) is deprecated. "
+            "module.forward(..., jac=True) returns a "
+            "tuple (out, jacobian) now."
+        )
+
+    def output_dims(self, input_dims: List[Tuple[int]]) -> List[Tuple[int]]:
+        """Use for shape inference during construction of the graph.
+
+        MUST be implemented for each subclass of ``InvertibleModule``.
+
+        Args:
+            input_dims: A list with one entry for each input to the module.
+                Even if the module only has one input, must be a list with one
+                entry. Each entry is a tuple giving the shape of that input,
+                excluding the batch dimension. For example for a module with one
+                input, which receives a 32x32 pixel RGB image, ``input_dims`` would
+                be ``[(3, 32, 32)]``
+
+        Returns:
+            A list structured in the same way as ``input_dims``. Each entry
+            represents one output of the module, and the entry is a tuple giving
+            the shape of that output. For example if the module splits the image
+            into a right and a left half, the return value should be
+            ``[(3, 16, 32), (3, 16, 32)]``. It is up to the implementor of the
+            subclass to ensure that the total number of elements in all inputs
+            and all outputs is consistent.
+        """
+        raise NotImplementedError(f"{self.__class__.__name__} does not provide output_dims(...)")
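For implementers, a hedged sketch of a minimal subclass satisfying this contract: an elementwise scaling whose log-determinant is constant per sample. The class is illustrative only, not part of the patch:

```python
import math

import torch

from anomalib.models.components.freia.modules import InvertibleModule

class FixedScaling(InvertibleModule):
    """Multiplies its input by a fixed scalar; log|det J| = N * log(scale)."""

    def __init__(self, dims_in, dims_c=None, scale: float = 2.0):
        super().__init__(dims_in, dims_c)
        self.scale = scale
        # number of elements per sample, excluding the batch dimension
        self.n_elements = math.prod(dims_in[0])

    def forward(self, x_or_z, c=None, rev=False, jac=True):
        (x,) = x_or_z
        log_det = x.new_full((x.shape[0],), self.n_elements * math.log(self.scale))
        if rev:
            return (x / self.scale,), -log_det  # backward: -J, per the convention above
        return (x * self.scale,), log_det       # forward: +J

    def output_dims(self, input_dims):
        return input_dims

(z,), jac = FixedScaling([(3, 8, 8)])((torch.randn(2, 3, 8, 8),))
```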
+ """ + raise NotImplementedError(f"{self.__class__.__name__} does not provide output_dims(...)") diff --git a/requirements/base.txt b/requirements/base.txt index 8276af5a5e..63eed41f88 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -2,7 +2,6 @@ albumentations==1.1.0 attrdict==2.0.1 defusedxml==0.7.1 einops==0.3.2 -FrEIA @ git+https://github.com/VLL-HD/FrEIA.git kornia==0.5.6 lxml==4.6.5 matplotlib==3.4.3 diff --git a/requirements/openvino.txt b/requirements/openvino.txt index 4c48b3428b..e172447abe 100644 --- a/requirements/openvino.txt +++ b/requirements/openvino.txt @@ -1,6 +1,6 @@ defusedxml==0.7.1 requests==2.26.0 networkx~=2.5 -nncf@ git+https://github.com/openvinotoolkit/nncf@37a830a412e60ec2fd2d84d7f00e2524e5f62777#egg=nncf +nncf==2.1.0 onnx==1.10.1 openvino-dev==2021.4.2 diff --git a/third-party-programs.txt b/third-party-programs.txt index 4d2b37ec25..23b6dce10f 100644 --- a/third-party-programs.txt +++ b/third-party-programs.txt @@ -18,3 +18,7 @@ terms are listed below. 1. Encoder, Decoder, Discriminator, Generator Copyright (c) 2018-2022 Samet Akcay, Durham University, UK SPDX-License-Identifier: MIT + +2. SequenceINN, InvertibleModule, AllInOneBlock + Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg. + SPDX-License-Identifier: MIT diff --git a/tox.ini b/tox.ini index dd79c0ed1d..88b916fa7a 100644 --- a/tox.ini +++ b/tox.ini @@ -27,7 +27,7 @@ basepython = python3 deps = flake8 mccabe -commands = flake8 anomalib +commands = flake8 anomalib --exclude=anomalib/models/components/freia [testenv:pylint] @@ -36,7 +36,7 @@ basepython = python3 deps = pylint -r{toxinidir}/requirements/base.txt -commands = pylint anomalib --rcfile=tox.ini +commands = pylint anomalib --rcfile=tox.ini --ignore=anomalib/models/components/freia/ [testenv:mypy] basepython = python3 @@ -118,7 +118,7 @@ disable = duplicate-code, generated-members = numpy.*, torch.* good-names = e, i, id -ignore = tests,docs +ignore = tests,docs,anomalib/models/components/freia max-line-length = 120 max-parents = 15 @@ -128,6 +128,9 @@ min-similarity-lines = 5 [mypy] ignore_missing_imports = True show_error_codes = True +exclude = anomalib/models/components/freia/ +[mypy-anomalib.models.components.freia.*] +follow_imports = skip [mypy-torch.*] follow_imports = skip follow_imports_for_stubs = True