Skip to content

Commit

Permalink
Update for nncf_task
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Mar 14, 2022
1 parent cebc3a4 commit 79edc56
Show file tree
Hide file tree
Showing 8 changed files with 385 additions and 2 deletions.
15 changes: 15 additions & 0 deletions anomalib/integration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Integration."""

# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
15 changes: 15 additions & 0 deletions anomalib/integration/nncf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Integration NNCF."""

# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
80 changes: 80 additions & 0 deletions anomalib/integration/nncf/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Callbacks for NNCF optimization."""

# Copyright (C) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import Any, Dict, Optional

import pytorch_lightning as pl
from nncf import NNCFConfig
from nncf.torch import register_default_init_args
from pytorch_lightning import Callback

from anomalib.integration.nncf.compression import wrap_nncf_model
from anomalib.integration.nncf.utils import InitLoader


class NNCFCallback(Callback):
"""Callback for NNCF compression.
Assumes that the pl module contains a 'model' attribute, which is
the PyTorch module that must be compressed.
Args:
config (Dict): NNCF Configuration
"""

def __init__(self, nncf_config: Dict):
self.nncf_config = NNCFConfig(nncf_config)
self.nncf_ctrl = None

# pylint: disable=unused-argument
def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None:
"""Call when fit or test begins.
Takes the pytorch model and wraps it using the compression controller
so that it is ready for nncf fine-tuning.
"""
if self.nncf_ctrl:
return
# pylint: disable=attr-defined
init_loader = InitLoader(trainer.datamodule.train_dataloader()) # type: ignore
nncf_config = register_default_init_args(self.nncf_config, init_loader)

self.nncf_ctrl, pl_module.model = wrap_nncf_model(
model=pl_module.model, config=nncf_config, dataloader=trainer.datamodule.train_dataloader()
)

def on_train_batch_start(
self,
trainer: pl.Trainer,
_pl_module: pl.LightningModule,
_batch: Any,
_batch_idx: int,
_unused: Optional[int] = 0,
) -> None:
"""Call when the train batch begins.
Prepare compression method to continue training the model in the next step.
"""
if self.nncf_ctrl:
self.nncf_ctrl.scheduler.step()

def on_train_epoch_start(self, _trainer: pl.Trainer, _pl_module: pl.LightningModule) -> None:
"""Call when the train epoch starts.
Prepare compression method to continue training the model in the next epoch.
"""
if self.nncf_ctrl:
self.nncf_ctrl.scheduler.epoch_step()
113 changes: 113 additions & 0 deletions anomalib/integration/nncf/compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""NNCF functions."""

# Copyright (C) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import Any, Dict, Iterator, Tuple

import torch.nn as nn
from nncf import NNCFConfig
from nncf.torch import create_compressed_model, load_state, register_default_init_args
from nncf.torch.compression_method_api import PTCompressionAlgorithmController
from nncf.torch.initialization import PTInitializingDataLoader
from nncf.torch.nncf_network import NNCFNetwork
from ote_anomalib.logging import get_logger
from torch.utils.data.dataloader import DataLoader

logger = get_logger(__name__)


class InitLoader(PTInitializingDataLoader):
"""Initializing data loader for NNCF to be used with unsupervised training algorithms."""

def __init__(self, data_loader: DataLoader):
super().__init__(data_loader)
self._data_loader_iter: Iterator

def __iter__(self):
"""Create iterator for dataloader."""
self._data_loader_iter = iter(self._data_loader)
return self

def __next__(self) -> Any:
"""Return next item from dataloader iterator."""
loaded_item = next(self._data_loader_iter)
return loaded_item["image"]

def get_inputs(self, dataloader_output) -> Tuple[Tuple, Dict]:
"""Get input to model.
Returns:
(dataloader_output,), {}: Tuple[Tuple, Dict]: The current model call to be made during
the initialization process
"""
return (dataloader_output,), {}

def get_target(self, _):
"""Return structure for ground truth in loss criterion based on dataloader output.
This implementation does not do anything and is a placeholder.
Returns:
None
"""
return None


def wrap_nncf_model(
model: nn.Module, config: Dict, dataloader: DataLoader = None, init_state_dict: Dict = None
) -> Tuple[NNCFNetwork, PTCompressionAlgorithmController]:
"""
Wrap model by NNCF.
:param model: Anomalib model.
:param config: NNCF config.
:param dataloader: Dataloader for initialization of NNCF model.
:param init_state_dict: Opti
:return: compression controller, compressed model
"""
nncf_config = NNCFConfig.from_dict(config)

if not dataloader and not init_state_dict:
logger.warning(
"Either dataloader or NNCF pre-trained "
"model checkpoint should be set. Without this, "
"quantizers will not be initialized"
)

compression_state = None
resuming_state_dict = None
if init_state_dict:
resuming_state_dict = init_state_dict.get("model")
compression_state = init_state_dict.get("compression_state")

if dataloader:
init_loader = InitLoader(dataloader) # type: ignore
nncf_config = register_default_init_args(nncf_config, init_loader)

nncf_ctrl, nncf_model = create_compressed_model(
model=model, config=nncf_config, dump_graphs=False, compression_state=compression_state
)

if resuming_state_dict:
load_state(nncf_model, resuming_state_dict, is_resume=True)

return nncf_ctrl, nncf_model


def is_state_nncf(state: Dict) -> bool:
"""
The function to check if sate was the result of training of NNCF-compressed model.
"""
return bool(state.get("meta", {}).get("nncf_enable_compression", False))
155 changes: 155 additions & 0 deletions anomalib/integration/nncf/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""Utils for NNCf optimization."""

# Copyright (C) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from copy import copy
from typing import Any, Dict, Iterator, List, Tuple

from nncf.torch.initialization import PTInitializingDataLoader
from torch.utils.data.dataloader import DataLoader


class InitLoader(PTInitializingDataLoader):
"""Initializing data loader for NNCF to be used with unsupervised training algorithms."""

def __init__(self, data_loader: DataLoader):
super().__init__(data_loader)
self._data_loader_iter: Iterator

def __iter__(self):
"""Create iterator for dataloader."""
self._data_loader_iter = iter(self._data_loader)
return self

def __next__(self) -> Any:
"""Return next item from dataloader iterator."""
loaded_item = next(self._data_loader_iter)
return loaded_item["image"]

def get_inputs(self, dataloader_output) -> Tuple[Tuple, Dict]:
"""Get input to model.
Returns:
(dataloader_output,), {}: Tuple[Tuple, Dict]: The current model call to be made during
the initialization process
"""
return (dataloader_output,), {}

def get_target(self, _):
"""Return structure for ground truth in loss criterion based on dataloader output.
This implementation does not do anything and is a placeholder.
Returns:
None
"""
return None


def compose_nncf_config(nncf_config: Dict, enabled_options: List[str]) -> Dict:
"""Compose NNCf config by selected options.
:param nncf_config:
:param enabled_options:
:return: config
"""
optimisation_parts = nncf_config

if "order_of_parts" in optimisation_parts:
# The result of applying the changes from optimisation parts
# may depend on the order of applying the changes
# (e.g. if for nncf_quantization it is sufficient to have `total_epochs=2`,
# but for sparsity it is required `total_epochs=50`)
# So, user can define `order_of_parts` in the optimisation_config
# to specify the order of applying the parts.
order_of_parts = optimisation_parts["order_of_parts"]
assert isinstance(order_of_parts, list), 'The field "order_of_parts" in optimisation config should be a list'

for part in enabled_options:
assert part in order_of_parts, (
f"The part {part} is selected, " "but it is absent in order_of_parts={order_of_parts}"
)

optimisation_parts_to_choose = [part for part in order_of_parts if part in enabled_options]

assert "base" in optimisation_parts, 'Error: the optimisation config does not contain the "base" part'
nncf_config_part = optimisation_parts["base"]

for part in optimisation_parts_to_choose:
assert part in optimisation_parts, f'Error: the optimisation config does not contain the part "{part}"'
optimisation_part_dict = optimisation_parts[part]
try:
nncf_config_part = merge_dicts_and_lists_b_into_a(nncf_config_part, optimisation_part_dict)
except AssertionError as cur_error:
err_descr = (
f"Error during merging the parts of nncf configs:\n"
f"the current part={part}, "
f"the order of merging parts into base is {optimisation_parts_to_choose}.\n"
f"The error is:\n{cur_error}"
)
raise RuntimeError(err_descr) from None

return nncf_config_part


# pylint: disable=invalid-name,missing-function-docstring
def merge_dicts_and_lists_b_into_a(a, b):
"""The fucntion to merge dict configs."""
return _merge_dicts_and_lists_b_into_a(a, b, "")


def _merge_dicts_and_lists_b_into_a(a, b, cur_key=None):
"""The function is inspired by mmcf.Config._merge_a_into_b.
* works with usual dicts and lists and derived types
* supports merging of lists (by concatenating the lists)
* makes recursive merging for dict + dict case
* overwrites when merging scalar into scalar
Note that we merge b into a (whereas Config makes merge a into b),
since otherwise the order of list merging is counter-intuitive.
"""

def _err_str(_a, _b, _key):
if _key is None:
_key_str = "of whole structures"
else:
_key_str = f"during merging for key=`{_key}`"
return (
f"Error in merging parts of config: different types {_key_str},"
f" type(a) = {type(_a)},"
f" type(b) = {type(_b)}"
)

assert isinstance(a, (dict, list)), f"Can merge only dicts and lists, whereas type(a)={type(a)}"
assert isinstance(b, (dict, list)), _err_str(a, b, cur_key)
assert isinstance(a, list) == isinstance(b, list), _err_str(a, b, cur_key)
if isinstance(a, list):
# the main diff w.r.t. mmcf.Config -- merging of lists
return a + b

a = copy(a)
for k in b.keys():
if k not in a:
a[k] = copy(b[k])
continue
new_cur_key = cur_key + "." + k if cur_key else k
if isinstance(a[k], (dict, list)):
a[k] = _merge_dicts_and_lists_b_into_a(a[k], b[k], new_cur_key)
continue

assert not isinstance(b[k], (dict, list)), _err_str(a[k], b[k], new_cur_key)

# suppose here that a[k] and b[k] are scalars, just overwrite
a[k] = b[k]
return a
4 changes: 4 additions & 0 deletions anomalib/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

from anomalib.models.components import AnomalyModule

# TODO(AlexanderDokuchaev): Workaround of wrapping by NNCF.
# Can't not wrap `spatial_softmax2d` if use import_module.
from anomalib.models.padim.model import PadimLightning


def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
"""Load model from the configuration file.
Expand Down
Loading

0 comments on commit 79edc56

Please sign in to comment.