Skip to content

Commit

Permalink
linter
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Mar 14, 2022
1 parent 79edc56 commit 646dc61
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 10 deletions.
4 changes: 2 additions & 2 deletions anomalib/integration/nncf/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optio
"""
if self.nncf_ctrl:
return
# pylint: disable=attr-defined
init_loader = InitLoader(trainer.datamodule.train_dataloader()) # type: ignore

init_loader = InitLoader(trainer.datamodule.train_dataloader())
nncf_config = register_default_init_args(self.nncf_config, init_loader)

self.nncf_ctrl, pl_module.model = wrap_nncf_model(
Expand Down
12 changes: 5 additions & 7 deletions anomalib/integration/nncf/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import logging
from typing import Any, Dict, Iterator, Tuple

import torch.nn as nn
Expand All @@ -22,10 +23,10 @@
from nncf.torch.compression_method_api import PTCompressionAlgorithmController
from nncf.torch.initialization import PTInitializingDataLoader
from nncf.torch.nncf_network import NNCFNetwork
from ote_anomalib.logging import get_logger
from torch.utils.data.dataloader import DataLoader

logger = get_logger(__name__)
# Module-level logger for the NNCF compression integration.
# NOTE(review): library code conventionally uses logging.getLogger(__name__) and
# leaves the level to the application; hard-coding a display name and DEBUG here
# forces verbose output on every consumer — confirm this is intentional.
logger = logging.getLogger(name="NNCF compression")
logger.setLevel(logging.DEBUG)


class InitLoader(PTInitializingDataLoader):
Expand Down Expand Up @@ -68,8 +69,7 @@ def get_target(self, _):
def wrap_nncf_model(
model: nn.Module, config: Dict, dataloader: DataLoader = None, init_state_dict: Dict = None
) -> Tuple[NNCFNetwork, PTCompressionAlgorithmController]:
"""
Wrap model by NNCF.
"""Wrap model by NNCF.
:param model: Anomalib model.
:param config: NNCF config.
Expand Down Expand Up @@ -107,7 +107,5 @@ def wrap_nncf_model(


def is_state_nncf(state: Dict) -> bool:
    """Check whether a checkpoint state came from an NNCF-compressed model.

    Fixes the docstring typo ("sate" -> "state") from the original.

    :param state: Checkpoint state dict; compression metadata is expected under
        ``state["meta"]["nncf_enable_compression"]`` when present.
    :return: True if the state was saved with NNCF compression enabled,
        False otherwise (including when the keys are absent).
    """
    # .get chains make missing "meta" or missing flag default safely to False.
    return bool(state.get("meta", {}).get("nncf_enable_compression", False))
1 change: 1 addition & 0 deletions anomalib/integration/nncf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def merge_dicts_and_lists_b_into_a(a, b):

def _merge_dicts_and_lists_b_into_a(a, b, cur_key=None):
"""The function is inspired by mmcf.Config._merge_a_into_b.
* works with usual dicts and lists and derived types
* supports merging of lists (by concatenating the lists)
* makes recursive merging for dict + dict case
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

# TODO(AlexanderDokuchaev): Workaround of wrapping by NNCF.
# Can't not wrap `spatial_softmax2d` if use import_module.
from anomalib.models.padim.model import PadimLightning
from anomalib.models.padim.model import PadimLightning # noqa: F401


def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
Expand Down

0 comments on commit 646dc61

Please sign in to comment.