Skip to content

Commit

Permalink
linter
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Mar 14, 2022
1 parent 79edc56 commit 1e2267a
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 11 deletions.
4 changes: 2 additions & 2 deletions anomalib/integration/nncf/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optio
"""
if self.nncf_ctrl:
return
# pylint: disable=attr-defined

init_loader = InitLoader(trainer.datamodule.train_dataloader()) # type: ignore
nncf_config = register_default_init_args(self.nncf_config, init_loader)

self.nncf_ctrl, pl_module.model = wrap_nncf_model(
model=pl_module.model, config=nncf_config, dataloader=trainer.datamodule.train_dataloader()
model=pl_module.model, config=nncf_config, dataloader=trainer.datamodule.train_dataloader() # type: ignore
)

def on_train_batch_start(
Expand Down
14 changes: 6 additions & 8 deletions anomalib/integration/nncf/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import logging
from typing import Any, Dict, Iterator, Tuple

import torch.nn as nn
from nncf import NNCFConfig
from nncf.torch import create_compressed_model, load_state, register_default_init_args
from nncf.torch.compression_method_api import PTCompressionAlgorithmController
from nncf.torch.initialization import PTInitializingDataLoader
from nncf.torch.nncf_network import NNCFNetwork
from ote_anomalib.logging import get_logger
from torch import nn
from torch.utils.data.dataloader import DataLoader

logger = get_logger(__name__)
logger = logging.getLogger(name="NNCF compression")
logger.setLevel(logging.DEBUG)


class InitLoader(PTInitializingDataLoader):
Expand Down Expand Up @@ -68,8 +69,7 @@ def get_target(self, _):
def wrap_nncf_model(
model: nn.Module, config: Dict, dataloader: DataLoader = None, init_state_dict: Dict = None
) -> Tuple[NNCFNetwork, PTCompressionAlgorithmController]:
"""
Wrap model by NNCF.
"""Wrap model by NNCF.
:param model: Anomalib model.
:param config: NNCF config.
Expand Down Expand Up @@ -107,7 +107,5 @@ def wrap_nncf_model(


def is_state_nncf(state: Dict) -> bool:
"""
The function to check if sate was the result of training of NNCF-compressed model.
"""
"""The function to check if sate is the result of NNCF-compressed model."""
return bool(state.get("meta", {}).get("nncf_enable_compression", False))
1 change: 1 addition & 0 deletions anomalib/integration/nncf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def merge_dicts_and_lists_b_into_a(a, b):

def _merge_dicts_and_lists_b_into_a(a, b, cur_key=None):
"""The function is inspired by mmcf.Config._merge_a_into_b.
* works with usual dicts and lists and derived types
* supports merging of lists (by concatenating the lists)
* makes recursive merging for dict + dict case
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

# TODO(AlexanderDokuchaev): Workaround of wrapping by NNCF.
# Can't not wrap `spatial_softmax2d` if use import_module.
from anomalib.models.padim.model import PadimLightning
from anomalib.models.padim.model import PadimLightning # noqa: F401


def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
Expand Down

0 comments on commit 1e2267a

Please sign in to comment.