QAT Support for new Framework with QuantizationModifier Testing (#1763)
* filling out quantization modifier for training

* unit tests for quantization modifier oneshot

* pytorch tests

* deleting debug scripts

* add pytorch flag

* fix post quant calib

* move e2e example

* file path issue fix

* fix imports

* quality
Satrat committed Oct 16, 2023
1 parent cb4e02b commit f889bb8
Showing 9 changed files with 580 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -25,7 +25,7 @@ ifneq ($(findstring onnx,$(TARGETS)),onnx)
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/onnx
endif
ifneq ($(findstring pytorch,$(TARGETS)),pytorch)
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch --ignore tests/sparseml/modifiers
endif
ifneq ($(findstring pytorch_models,$(TARGETS)),pytorch_models)
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch/models
75 changes: 75 additions & 0 deletions integrations/torchvision/modifiers_refactor_example/e2e_recipe.yaml
@@ -0,0 +1,75 @@
test_stage:
  quantization_modifiers:
    QuantizationModifier:
      start: eval(start_quant_epoch)
      scheme:
        input_activations:
          num_bits: 8
          symmetric: False
        weights:
          num_bits: 4
          symmetric: True
          strategy: "channel"
      ignore: ['classifier']
  pruning_modifiers:
    MagnitudePruningModifier:
      init_sparsity: 0.0
      final_sparsity: 0.5
      start: eval(warm_up_epochs)
      end: eval(warm_up_epochs + pruning_epochs)
      update_frequency: 0.5
      targets:
        - features.0.0.weight
        - features.1.conv.0.0.weight
        - features.1.conv.1.weight
        - features.2.conv.0.0.weight
        - features.2.conv.1.0.weight
        - features.2.conv.2.weight
        - features.3.conv.0.0.weight
        - features.3.conv.1.0.weight
        - features.3.conv.2.weight
        - features.4.conv.0.0.weight
        - features.4.conv.1.0.weight
        - features.4.conv.2.weight
        - features.5.conv.0.0.weight
        - features.5.conv.1.0.weight
        - features.5.conv.2.weight
        - features.6.conv.0.0.weight
        - features.6.conv.1.0.weight
        - features.6.conv.2.weight
        - features.7.conv.0.0.weight
        - features.7.conv.1.0.weight
        - features.7.conv.2.weight
        - features.8.conv.0.0.weight
        - features.8.conv.1.0.weight
        - features.8.conv.2.weight
        - features.9.conv.0.0.weight
        - features.9.conv.1.0.weight
        - features.9.conv.2.weight
        - features.10.conv.0.0.weight
        - features.10.conv.1.0.weight
        - features.10.conv.2.weight
        - features.11.conv.0.0.weight
        - features.11.conv.1.0.weight
        - features.11.conv.2.weight
        - features.12.conv.0.0.weight
        - features.12.conv.1.0.weight
        - features.12.conv.2.weight
        - features.13.conv.0.0.weight
        - features.13.conv.1.0.weight
        - features.13.conv.2.weight
        - features.14.conv.0.0.weight
        - features.14.conv.1.0.weight
        - features.14.conv.2.weight
        - features.15.conv.0.0.weight
        - features.15.conv.1.0.weight
        - features.15.conv.2.weight
        - features.16.conv.0.0.weight
        - features.16.conv.1.0.weight
        - features.16.conv.2.weight
        - features.17.conv.0.0.weight
        - features.17.conv.1.0.weight
        - features.17.conv.2.weight
        - features.18.0.weight
        - classifier.1.weight
      leave_enabled: True
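
Note: the eval(...) values above are placeholders resolved against the recipe_args dictionary passed to session.initialize (see the e2e script below). A minimal sketch of the substitution, using a hypothetical resolve helper rather than SparseML's actual recipe parser:

# Hypothetical helper for illustration only; SparseML's real recipe
# parsing is not part of this commit.
recipe_args = {"warm_up_epochs": 5, "start_quant_epoch": 3, "pruning_epochs": 5}

def resolve(expression: str, args: dict) -> float:
    # "eval(warm_up_epochs + pruning_epochs)" -> evaluate the inner
    # expression against the supplied recipe args
    inner = expression[len("eval("):-1]
    return float(eval(inner, {}, dict(args)))

print(resolve("eval(start_quant_epoch)", recipe_args))                # 3.0
print(resolve("eval(warm_up_epochs + pruning_epochs)", recipe_args))  # 10.0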
155 changes: 155 additions & 0 deletions integrations/torchvision/modifiers_refactor_example/e2e_test.py
@@ -0,0 +1,155 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def main():
    import os

    import datasets
    import torch
    import torchvision
    from torch.nn import CrossEntropyLoss
    from torch.optim import Adam
    from torch.utils.data import DataLoader
    from torchvision import transforms

    import sparseml.core.session as sml
    from sparseml.core.event import EventType
    from sparseml.core.framework import Framework
    from sparseml.pytorch.utils import (
        ModuleExporter,
        get_prunable_layers,
        tensor_sparsity,
    )

    NUM_LABELS = 3
    BATCH_SIZE = 32
    NUM_EPOCHS = 12
    recipe = "e2e_recipe.yaml"
    device = "cuda:0"

    # set up SparseML session
    sml.create_session()
    session = sml.active_session()

    # download model
    model = torchvision.models.mobilenet_v2(
        weights=torchvision.models.MobileNet_V2_Weights.DEFAULT
    )
    model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_LABELS)
    model.to(device)

    # download data
    beans_dataset = datasets.load_dataset("beans")
    train_folder, _ = os.path.split(beans_dataset["train"][0]["image_file_path"])
    train_path, _ = os.path.split(train_folder)
    val_folder, _ = os.path.split(beans_dataset["validation"][0]["image_file_path"])
    val_path, _ = os.path.split(val_folder)

    # dataloaders
    imagenet_transform = transforms.Compose(
        [
            transforms.Resize(
                size=256,
                interpolation=transforms.InterpolationMode.BILINEAR,
                max_size=None,
                antialias=None,
            ),
            transforms.CenterCrop(size=(224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    train_dataset = torchvision.datasets.ImageFolder(
        root=train_path, transform=imagenet_transform
    )
    train_loader = DataLoader(
        train_dataset, BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=16
    )

    val_dataset = torchvision.datasets.ImageFolder(
        root=val_path, transform=imagenet_transform
    )
    val_loader = DataLoader(
        val_dataset, BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=16
    )

    # loss and optimizer
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=8e-3)

    # initialize session
    recipe_args = {"warm_up_epochs": 5, "start_quant_epoch": 3, "pruning_epochs": 5}
    _ = session.initialize(
        framework=Framework.pytorch,
        recipe=recipe,
        recipe_args=recipe_args,
        model=model,
        teacher_model=None,
        optimizer=optimizer,
        train_data=train_loader,
        val_data=val_loader,
        start=0.0,
        steps_per_epoch=len(train_loader),
    )

    # loop through batches
    for epoch in range(NUM_EPOCHS):
        running_loss = 0.0
        total_correct = 0
        total_predictions = 0
        for step, (inputs, labels) in enumerate(session.state.data.train):
            inputs = inputs.to(device)
            labels = labels.to(device)
            session.state.optimizer.optimizer.zero_grad()
            session.event(event_type=EventType.BATCH_START, batch_data=(inputs, labels))

            outputs = session.state.model.model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            session.event(event_type=EventType.LOSS_CALCULATED, loss=loss)

            session.event(event_type=EventType.OPTIM_PRE_STEP)
            session.state.optimizer.optimizer.step()
            session.event(event_type=EventType.OPTIM_POST_STEP)

            running_loss += loss.item()

            predictions = outputs.argmax(dim=1)
            total_correct += torch.sum(predictions == labels).item()
            total_predictions += inputs.size(0)

            session.event(event_type=EventType.BATCH_END)

        loss = running_loss / (step + 1.0)
        accuracy = total_correct / total_predictions
        print("Epoch: {} Loss: {} Accuracy: {}".format(epoch + 1, loss, accuracy))

    # finalize session
    session.finalize()

    # view sparsities
    for (name, layer) in get_prunable_layers(session.state.model.model):
        print(f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}")

    # save sparsified model
    save_dir = "e2e_experiment"
    exporter = ModuleExporter(model, output_dir=save_dir)
    exporter.export_pytorch(name="mobilenet_v2-sparse-beans.pth")
    exporter.export_onnx(torch.randn(1, 3, 224, 224), name="sparse-model.onnx")


if __name__ == "__main__":
    main()
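
As a quick sanity check after running the script, the exported graph can be loaded back with the onnx package (assuming it is installed and that the export paths above resolve to e2e_experiment/):

import onnx

model_proto = onnx.load("e2e_experiment/sparse-model.onnx")
onnx.checker.check_model(model_proto)  # raises if the exported graph is malformed
print(f"sparse-model.onnx parsed OK: {len(model_proto.graph.node)} nodes")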
2 changes: 1 addition & 1 deletion src/sparseml/core/recipe/recipe.py
@@ -184,7 +184,7 @@ def simplify_combine_recipes(
            )
            combined.version = simplified.version
            combined.stages.extend(simplified.stages)
            combined.args.combine(simplified.args)
            combined.args.update(simplified.args)

        return combined
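
The combine-to-update change switches recipe-arg merging to dict-style semantics, so a later recipe's args override earlier values on key collisions. A minimal sketch of the assumed behavior, with plain dicts standing in for RecipeArgs:

# Plain dicts standing in for RecipeArgs; assumes dict-style update
# semantics, where the later recipe wins on key collisions.
combined_args = {"warm_up_epochs": 5, "pruning_epochs": 5}
simplified_args = {"pruning_epochs": 3, "start_quant_epoch": 2}
combined_args.update(simplified_args)
assert combined_args == {
    "warm_up_epochs": 5,
    "pruning_epochs": 3,
    "start_quant_epoch": 2,
}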

53 changes: 52 additions & 1 deletion src/sparseml/modifiers/quantization/base.py
@@ -14,7 +14,7 @@

from typing import Any, Dict, List, Optional

from sparseml.core import Modifier, State
from sparseml.core import Event, Modifier, State
from sparseml.modifiers.quantization.utils.quantization_scheme import (
    QuantizationScheme,
    QuantizationSchemeLoadable,
@@ -104,6 +104,57 @@ def __init__(self, **kwargs):
        )
        if self.model_fuse_fn_kwargs is None:
            self.model_fuse_fn_kwargs = {}
        if self.ignore is None:
            self.ignore = []

    def calculate_freeze_bn_stats_epoch(self) -> float:
        """
        Get the epoch at which we want to stop updating batch normalization stats
        :return: freeze_bn_stats_epoch if set, else -1
        """
        return (
            self.freeze_bn_stats_epoch if self.freeze_bn_stats_epoch is not None else -1
        )

    def check_should_freeze_bn_stats(self, event: Event) -> bool:
        """
        Given the current index, determine if we should freeze batch normalization stats
        :param event: Event to get index from
        :return: True if stats should be frozen, False otherwise
        """
        freeze_epoch = self.calculate_freeze_bn_stats_epoch()
        if freeze_epoch == -1:
            return False
        if event.current_index >= freeze_epoch:
            return True
        return False

    def calculate_disable_observer_epoch(self) -> float:
        """
        Get the epoch at which we want to disable the quantization observer
        :return: epoch to disable at, or -1 if it is not set
        """
        return (
            self.disable_quantization_observer_epoch
            if self.disable_quantization_observer_epoch is not None
            else -1
        )

    def check_should_disable_observer(self, event: Event) -> bool:
        """
        Given the current index, determine if we should disable the observer
        :param event: Event to get index from
        :return: True if observer should be disabled, False otherwise
        """
        disable_epoch = self.calculate_disable_observer_epoch()
        if disable_epoch == -1:
            return False
        if event.current_index >= disable_epoch:
            return True
        return False

    def on_initialize_structure(self, state: State, **kwargs):
        pass  # nothing needed for this modifier
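
Both check_should_* helpers above reduce to the same comparison; a standalone distillation of that logic (hypothetical function, for illustration only):

from typing import Optional

def should_trigger(current_index: float, trigger_epoch: Optional[float]) -> bool:
    # Mirrors check_should_freeze_bn_stats / check_should_disable_observer:
    # an unset epoch behaves like -1, meaning "never trigger"
    epoch = trigger_epoch if trigger_epoch is not None else -1
    return epoch != -1 and current_index >= epoch

assert should_trigger(9.0, 8.0)       # past the configured epoch -> act
assert not should_trigger(9.0, 10.0)  # not there yet -> wait
assert not should_trigger(9.0, None)  # never configured -> never act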
36 changes: 26 additions & 10 deletions src/sparseml/modifiers/quantization/pytorch.py
@@ -19,10 +19,11 @@
import torch
from torch.nn import Module

from sparseml.core import Event, State
from sparseml.core import Event, EventType, State
from sparseml.modifiers.quantization.base import QuantizationModifier
from sparseml.modifiers.quantization.utils.helpers import (
    configure_module_bn_wrappers,
    freeze_bn_stats,
    fuse_module_conv_bn_relus,
)
from sparseml.modifiers.quantization.utils.quantize import (
@@ -40,6 +41,8 @@ class QuantizationModifierPyTorch(QuantizationModifier):
    calibration_dataloader_: Any = None
    calibration_function_: Any = None
    qat_enabled_: bool = False
    quantization_observer_disabled_: bool = False
    bn_stats_frozen_: bool = False

    def on_initialize(self, state: State, **kwargs) -> bool:
        raise_if_torch_quantization_not_available()
Expand All @@ -51,10 +54,10 @@ def on_initialize(self, state: State, **kwargs) -> bool:

        self.calibration_dataloader_ = state.data.calib
        module = state.model.model
        device = state.hardware.device
        state.model.model.to(device)
        module = state.model.model
        self._enable_module_qat(module)

        if self.calculate_start() == -1:  # one-shot
            self._enable_module_qat(module)
            self._disable_quantization_observer(module)

        return True

Expand All @@ -63,21 +66,34 @@ def on_finalize(self, state: State, **kwargs) -> bool:
        state.model.model.to(state.hardware.device)
        state.model.model.apply(torch.quantization.enable_observer)
        self._calibrate_if_possible(state.model.model)
        state.model.model.apply(torch.quantization.disable_observer)
        self._disable_quantization_observer(state.model.model)
        return True

    def on_start(self, state: State, event: Event, **kwargs):
        pass
        if not self.qat_enabled_:
            self._enable_module_qat(state.model.model)

    def on_update(self, state: State, event: Event, **kwargs):
        pass
        if event.type_ == EventType.BATCH_START:
            if self.check_should_freeze_bn_stats(event):
                self._freeze_bn_stats(state.model.model)
            if self.check_should_disable_observer(event):
                self._disable_quantization_observer(state.model.model)

    def on_end(self, state: State, event: Event, **kwargs):
        pass
        self._disable_quantization_observer(state.model.model)

    def on_event(self, state: State, event: Event, **kwargs):
        pass

    def _freeze_bn_stats(self, model: Module):
        model.apply(freeze_bn_stats)
        self.bn_stats_frozen_ = True

    def _disable_quantization_observer(self, model: Module):
        model.apply(torch.quantization.disable_observer)
        self.quantization_observer_disabled_ = True

    def _enable_module_qat(self, module: Module):
        # fuse conv-bn-relu blocks prior to quantization emulation
        self._fuse(module)
@@ -164,4 +180,4 @@ def _calibrate(self, module: Module):
        if module_training:
            module.train()
        else:
            module.apply(torch.quantization.disable_observer)
            self._disable_quantization_observer(module)
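
Taken together, the hooks above imply the following event-driven lifecycle; a rough outline of the assumed call order (the session driver loop itself is not part of this diff):

# Assumed ordering, not SparseML's actual driver code:
# modifier.on_initialize(state)    -> one-shot recipes (start == -1) enable QAT
#                                     and disable observers immediately
# modifier.on_start(state, event)  -> training recipes enable QAT lazily here
# modifier.on_update(state, event) -> on each BATCH_START, freeze BN stats and/or
#                                     disable observers once their epochs pass
# modifier.on_end(state, event)    -> observers disabled; quantization params freeze
# modifier.on_finalize(state)      -> run calibration if a dataloader was given,
#                                     then disable observers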
(Diff truncated: the remaining 3 of 9 changed files are not shown.)
