Skip to content

Commit

Permalink
Update for nncf_task
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Mar 14, 2022
1 parent cebc3a4 commit 5995978
Show file tree
Hide file tree
Showing 8 changed files with 390 additions and 2 deletions.
13 changes: 13 additions & 0 deletions anomalib/integration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
13 changes: 13 additions & 0 deletions anomalib/integration/nncf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (C) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
82 changes: 82 additions & 0 deletions anomalib/integration/nncf/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Calbacks for NNCF optimization"""

# Copyright (C) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import Any, Dict, Optional

import pytorch_lightning as pl
from pytorch_lightning import Callback

from nncf import NNCFConfig
from nncf.torch import register_default_init_args

from anomalib.integration.nncf.compression import wrap_nncf_model
from anomalib.integration.nncf.utils import InitLoader

class NNCFCallback(Callback):
    """Callback for NNCF compression.

    Assumes that the pl module contains a 'model' attribute, which is
    the PyTorch module that must be compressed.

    Args:
        nncf_config (Dict): NNCF Configuration
    """

    def __init__(self, nncf_config: Dict):
        self.nncf_config = NNCFConfig(nncf_config)
        # Compression controller; created lazily in ``setup``.
        self.nncf_ctrl = None

    # pylint: disable=unused-argument
    def setup(self,
              trainer: pl.Trainer,
              pl_module: pl.LightningModule,
              stage: Optional[str] = None) -> None:
        """Call when fit or test begins.

        Takes the pytorch model and wraps it using the compression controller
        so that it is ready for nncf fine-tuning. The wrapping is done only once:
        subsequent calls return immediately when a controller already exists.
        """
        if self.nncf_ctrl is not None:
            return

        # Build the training dataloader once and reuse it for both the
        # initializing loader and the model wrapping below.
        train_loader = trainer.datamodule.train_dataloader()  # type: ignore
        init_loader = InitLoader(train_loader)
        nncf_config = register_default_init_args(self.nncf_config, init_loader)

        self.nncf_ctrl, pl_module.model = wrap_nncf_model(
            model=pl_module.model,
            config=nncf_config,
            dataloader=train_loader,
        )

    def on_train_batch_start(
        self,
        trainer: pl.Trainer,
        _pl_module: pl.LightningModule,
        _batch: Any,
        _batch_idx: int,
        _unused: Optional[int] = 0,
    ) -> None:
        """Call when the train batch begins.

        Prepare compression method to continue training the model in the next step.
        Does nothing when ``setup`` has not created a compression controller.
        """
        if self.nncf_ctrl is not None:
            self.nncf_ctrl.scheduler.step()

    def on_train_epoch_start(self, _trainer: pl.Trainer, _pl_module: pl.LightningModule) -> None:
        """Call when the train epoch starts.

        Prepare compression method to continue training the model in the next epoch.
        Does nothing when ``setup`` has not created a compression controller.
        """
        if self.nncf_ctrl is not None:
            self.nncf_ctrl.scheduler.epoch_step()
120 changes: 120 additions & 0 deletions anomalib/integration/nncf/compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""NNCF functions"""

# Copyright (C) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import Any, Dict, Iterator, Optional, Tuple

import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from nncf import NNCFConfig
from nncf.torch import create_compressed_model, register_default_init_args
from nncf.torch import load_state
from nncf.torch.compression_method_api import PTCompressionAlgorithmController
from nncf.torch.initialization import PTInitializingDataLoader
from nncf.torch.nncf_network import NNCFNetwork
from ote_anomalib.logging import get_logger

logger = get_logger(__name__)


class InitLoader(PTInitializingDataLoader):
    """NNCF initializing data loader that feeds only the ``"image"`` entry of each batch.

    Intended for unsupervised training algorithms, where no target is
    needed during the NNCF initialization pass.
    """

    def __init__(self, data_loader: DataLoader):
        super().__init__(data_loader)
        # Declared here, assigned on every call to ``__iter__``.
        self._data_loader_iter: Iterator

    def __iter__(self):
        """Start a fresh pass over the wrapped dataloader and return this object as the iterator."""
        self._data_loader_iter = iter(self._data_loader)
        return self

    def __next__(self) -> Any:
        """Fetch the next batch and return its ``"image"`` entry."""
        return next(self._data_loader_iter)["image"]

    def get_inputs(self, dataloader_output) -> Tuple[Tuple, Dict]:
        """Build the positional and keyword arguments for the model call.

        Returns:
            Tuple[Tuple, Dict]: A one-element args tuple holding ``dataloader_output``
            and an empty kwargs dict.
        """
        model_args = (dataloader_output,)
        model_kwargs: Dict = {}
        return model_args, model_kwargs

    def get_target(self, _):
        """Return the ground truth for the loss criterion.

        Placeholder implementation: the argument is ignored.

        Returns:
            None
        """
        return None


def wrap_nncf_model(
    model: nn.Module,
    config: Dict,
    dataloader: Optional[DataLoader] = None,
    init_state_dict: Optional[Dict] = None,
) -> Tuple[PTCompressionAlgorithmController, NNCFNetwork]:
    """
    Wrapping model by NNCF

    :param model: Anomalib model.
    :param config: NNCF config.
    :param dataloader: Dataloader for initialization of NNCF model. When given, it is
        wrapped in ``InitLoader`` and registered in the NNCF config.
    :param init_state_dict: Optional checkpoint dict. Its "model" entry (if any) is loaded
        into the compressed model and its "compression_state" entry (if any) is passed
        to ``create_compressed_model``.
    :return: compression controller, compressed model
    """
    nncf_config = NNCFConfig.from_dict(config)

    # Compare against None explicitly: the truth value of a DataLoader is derived
    # from its length, so an empty loader would be treated as "not provided" and
    # a loader without a defined length could raise.
    if dataloader is None and not init_state_dict:
        logger.warning('Either dataloader or NNCF pre-trained '
                       'model checkpoint should be set. Without this, '
                       'quantizers will not be initialized')

    compression_state = None
    resuming_state_dict = None
    if init_state_dict:
        resuming_state_dict = init_state_dict.get("model")
        compression_state = init_state_dict.get("compression_state")

    if dataloader is not None:
        init_loader = InitLoader(dataloader)  # type: ignore
        nncf_config = register_default_init_args(nncf_config, init_loader)

    nncf_ctrl, nncf_model = create_compressed_model(
        model=model,
        config=nncf_config,
        dump_graphs=False,
        compression_state=compression_state,
    )

    if resuming_state_dict:
        load_state(nncf_model, resuming_state_dict, is_resume=True)

    return nncf_ctrl, nncf_model


def is_state_nncf(state: Dict) -> bool:
    """
    Check whether a checkpoint state comes from training an NNCF-compressed model.

    The decision is based on the ``nncf_enable_compression`` flag stored under
    the ``meta`` entry of ``state``.

    :param state: Checkpoint dict.
    :return: True if ``state["meta"]["nncf_enable_compression"]`` exists and is truthy,
        False otherwise (including when ``meta`` is missing or None).
    """
    # `or {}` also covers an explicit `"meta": None` entry.
    meta = state.get('meta') or {}
    return bool(meta.get('nncf_enable_compression', False))
155 changes: 155 additions & 0 deletions anomalib/integration/nncf/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""Utils for NNCf optimization"""

# Copyright (C) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import Any, Dict, Iterator, Tuple, List

from copy import copy
from torch.utils.data.dataloader import DataLoader
from nncf.torch.initialization import PTInitializingDataLoader


class InitLoader(PTInitializingDataLoader):
    """NNCF initializing data loader that feeds only the ``"image"`` entry of each batch.

    Intended for unsupervised training algorithms, where no target is
    needed during the NNCF initialization pass.
    """

    def __init__(self, data_loader: DataLoader):
        super().__init__(data_loader)
        # Declared here, assigned on every call to ``__iter__``.
        self._data_loader_iter: Iterator

    def __iter__(self):
        """Start a fresh pass over the wrapped dataloader and return this object as the iterator."""
        self._data_loader_iter = iter(self._data_loader)
        return self

    def __next__(self) -> Any:
        """Fetch the next batch and return its ``"image"`` entry."""
        return next(self._data_loader_iter)["image"]

    def get_inputs(self, dataloader_output) -> Tuple[Tuple, Dict]:
        """Build the positional and keyword arguments for the model call.

        Returns:
            Tuple[Tuple, Dict]: A one-element args tuple holding ``dataloader_output``
            and an empty kwargs dict.
        """
        model_args = (dataloader_output,)
        model_kwargs: Dict = {}
        return model_args, model_kwargs

    def get_target(self, _):
        """Return the ground truth for the loss criterion.

        Placeholder implementation: the argument is ignored.

        Returns:
            None
        """
        return None


def compose_nncf_config(nncf_config: Dict, enabled_options: List[str]) -> Dict:
    """
    Compose NNCF config by selected options.

    The "base" part is taken as the starting point and every enabled part is merged
    into it, in the order given by the "order_of_parts" entry.

    :param nncf_config: Mapping from part name to config fragment. Must contain "base";
        may contain "order_of_parts", a list of part names defining the merge order.
    :param enabled_options: Names of the optional parts to merge into "base".
    :return: The composed config. When "order_of_parts" is absent no optional part
        is merged and the "base" part itself is returned.
    :raises AssertionError: If "order_of_parts" is not a list, an enabled option is not
        listed in it, or "base" / a selected part is missing from ``nncf_config``.
    :raises RuntimeError: If merging a part into the config fails.
    """
    optimisation_parts = nncf_config

    # Default to "nothing to merge" so that a config without "order_of_parts"
    # yields the base part instead of failing on an undefined name.
    optimisation_parts_to_choose: List[str] = []

    if "order_of_parts" in optimisation_parts:
        # The result of applying the changes from optimisation parts
        # may depend on the order of applying the changes
        # (e.g. if for nncf_quantization it is sufficient to have `total_epochs=2`,
        #  but for sparsity it is required `total_epochs=50`)
        # So, user can define `order_of_parts` in the optimisation_config
        # to specify the order of applying the parts.
        order_of_parts = optimisation_parts["order_of_parts"]
        assert isinstance(order_of_parts, list), 'The field "order_of_parts" in optimisation config should be a list'

        for part in enabled_options:
            assert part in order_of_parts, (
                f"The part {part} is selected, but it is absent in order_of_parts={order_of_parts}"
            )

        optimisation_parts_to_choose = [part for part in order_of_parts if part in enabled_options]

    assert "base" in optimisation_parts, 'Error: the optimisation config does not contain the "base" part'
    nncf_config_part = optimisation_parts["base"]

    for part in optimisation_parts_to_choose:
        assert part in optimisation_parts, f'Error: the optimisation config does not contain the part "{part}"'
        optimisation_part_dict = optimisation_parts[part]
        try:
            nncf_config_part = merge_dicts_and_lists_b_into_a(nncf_config_part, optimisation_part_dict)
        except AssertionError as cur_error:
            err_descr = (
                f"Error during merging the parts of nncf configs:\n"
                f"the current part={part}, "
                f"the order of merging parts into base is {optimisation_parts_to_choose}.\n"
                f"The error is:\n{cur_error}"
            )
            raise RuntimeError(err_descr) from None

    return nncf_config_part

# pylint: disable=invalid-name,missing-function-docstring
def merge_dicts_and_lists_b_into_a(a, b):
    """Merge ``b`` into ``a`` and return the result; the inputs are not modified.

    Both arguments must be dicts or both must be lists. Lists are concatenated,
    dicts are merged recursively and scalars from ``b`` overwrite those in ``a``.

    Raises:
        AssertionError: If the structures (or nested values under the same key)
            have incompatible types.
    """
    return _merge_dicts_and_lists_b_into_a(a, b, "")


def _merge_dicts_and_lists_b_into_a(a, b, cur_key=None):
    """The function is inspired by mmcv.Config._merge_a_into_b,
    but it
    * works with usual dicts and lists and derived types
    * supports merging of lists (by concatenating the lists)
    * makes recursive merging for dict + dict case
    * overwrites when merging scalar into scalar
    Note that we merge b into a (whereas Config makes merge a into b),
    since otherwise the order of list merging is counter-intuitive.

    ``cur_key`` is the dotted path of the value being merged; it is used only
    in error messages. An empty or missing path denotes the root structures.
    """

    def _err_str(_a, _b, _key):
        # The public entry point passes "" for the root, so every empty key
        # (not only None) means that the whole structures are being compared.
        if not _key:
            _key_str = "of whole structures"
        else:
            _key_str = f"during merging for key=`{_key}`"
        return (
            f"Error in merging parts of config: different types {_key_str},"
            f" type(a) = {type(_a)},"
            f" type(b) = {type(_b)}"
        )

    assert isinstance(a, (dict, list)), f"Can merge only dicts and lists, whereas type(a)={type(a)}"
    assert isinstance(b, (dict, list)), _err_str(a, b, cur_key)
    assert isinstance(a, list) == isinstance(b, list), _err_str(a, b, cur_key)
    if isinstance(a, list):
        # the main diff w.r.t. mmcv.Config -- merging of lists
        return a + b

    # Shallow copy: nested values are replaced (never mutated) below,
    # so the caller's ``a`` stays untouched.
    merged = copy(a)
    for key, b_value in b.items():
        if key not in merged:
            merged[key] = copy(b_value)
            continue

        # Formatting (rather than ``+``) keeps non-string keys from raising TypeError.
        new_cur_key = f"{cur_key}.{key}" if cur_key else str(key)
        if isinstance(merged[key], (dict, list)):
            merged[key] = _merge_dicts_and_lists_b_into_a(merged[key], b_value, new_cur_key)
            continue

        assert not isinstance(b_value, (dict, list)), _err_str(merged[key], b_value, new_cur_key)

        # suppose here that both values are scalars, just overwrite
        merged[key] = b_value
    return merged
4 changes: 4 additions & 0 deletions anomalib/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

from anomalib.models.components import AnomalyModule

# TODO(AlexanderDokuchaev): Workaround for wrapping by NNCF.
# `spatial_softmax2d` cannot be wrapped if import_module is used.
from anomalib.models.padim.model import PadimLightning


def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
"""Load model from the configuration file.
Expand Down
Loading

0 comments on commit 5995978

Please sign in to comment.