QAT Support for new Framework with QuantizationModifier Testing #1763

Merged: 15 commits, Oct 16, 2023
2 changes: 1 addition & 1 deletion Makefile
@@ -25,7 +25,7 @@ ifneq ($(findstring onnx,$(TARGETS)),onnx)
 	PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/onnx
 endif
 ifneq ($(findstring pytorch,$(TARGETS)),pytorch)
-	PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch
+	PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch --ignore tests/sparseml/modifiers
 endif
 ifneq ($(findstring pytorch_models,$(TARGETS)),pytorch_models)
 	PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparseml/pytorch/models
@@ -0,0 +1,75 @@
test_stage:
  quantization_modifiers:
    QuantizationModifier:
      start: eval(start_quant_epoch)
      scheme:
        input_activations:
          num_bits: 8
          symmetric: False
        weights:
          num_bits: 4
          symmetric: True
          strategy: "channel"
      ignore: ['classifier']
  pruning_modifiers:
    MagnitudePruningModifier:
      init_sparsity: 0.0
      final_sparsity: 0.5
      start: eval(warm_up_epochs)
      end: eval(warm_up_epochs + pruning_epochs)
      update_frequency: 0.5
      targets:
        - features.0.0.weight
        - features.1.conv.0.0.weight
        - features.1.conv.1.weight
        - features.2.conv.0.0.weight
        - features.2.conv.1.0.weight
        - features.2.conv.2.weight
        - features.3.conv.0.0.weight
        - features.3.conv.1.0.weight
        - features.3.conv.2.weight
        - features.4.conv.0.0.weight
        - features.4.conv.1.0.weight
        - features.4.conv.2.weight
        - features.5.conv.0.0.weight
        - features.5.conv.1.0.weight
        - features.5.conv.2.weight
        - features.6.conv.0.0.weight
        - features.6.conv.1.0.weight
        - features.6.conv.2.weight
        - features.7.conv.0.0.weight
        - features.7.conv.1.0.weight
        - features.7.conv.2.weight
        - features.8.conv.0.0.weight
        - features.8.conv.1.0.weight
        - features.8.conv.2.weight
        - features.9.conv.0.0.weight
        - features.9.conv.1.0.weight
        - features.9.conv.2.weight
        - features.10.conv.0.0.weight
        - features.10.conv.1.0.weight
        - features.10.conv.2.weight
        - features.11.conv.0.0.weight
        - features.11.conv.1.0.weight
        - features.11.conv.2.weight
        - features.12.conv.0.0.weight
        - features.12.conv.1.0.weight
        - features.12.conv.2.weight
        - features.13.conv.0.0.weight
        - features.13.conv.1.0.weight
        - features.13.conv.2.weight
        - features.14.conv.0.0.weight
        - features.14.conv.1.0.weight
        - features.14.conv.2.weight
        - features.15.conv.0.0.weight
        - features.15.conv.1.0.weight
        - features.15.conv.2.weight
        - features.16.conv.0.0.weight
        - features.16.conv.1.0.weight
        - features.16.conv.2.weight
        - features.17.conv.0.0.weight
        - features.17.conv.1.0.weight
        - features.17.conv.2.weight
        - features.18.0.weight
        - classifier.1.weight
      leave_enabled: True
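
The eval() expressions in this recipe are placeholders that are filled in from the recipe_args mapping supplied at session initialization; a minimal sketch of how they resolve, using the values passed in e2e_test.py below (illustrative only, not part of the diff):

# recipe_args as passed to session.initialize() in e2e_test.py
recipe_args = {"warm_up_epochs": 5, "start_quant_epoch": 3, "pruning_epochs": 5}
# with these values the recipe behaves as if it read:
#   QuantizationModifier.start:     3   (start_quant_epoch)
#   MagnitudePruningModifier.start: 5   (warm_up_epochs)
#   MagnitudePruningModifier.end:   10  (warm_up_epochs + pruning_epochs)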
155 changes: 155 additions & 0 deletions integrations/torchvision/modifiers_refactor_example/e2e_test.py
@@ -0,0 +1,155 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def main():
    import os

    import datasets
    import torch
    import torchvision
    from torch.nn import CrossEntropyLoss
    from torch.optim import Adam
    from torch.utils.data import DataLoader
    from torchvision import transforms

    import sparseml.core.session as sml
    from sparseml.core.event import EventType
    from sparseml.core.framework import Framework
    from sparseml.pytorch.utils import (
        ModuleExporter,
        get_prunable_layers,
        tensor_sparsity,
    )

    NUM_LABELS = 3
    BATCH_SIZE = 32
    NUM_EPOCHS = 12
    recipe = "e2e_recipe.yaml"
    device = "cuda:0"

    # set up SparseML session
    sml.create_session()
    session = sml.active_session()

    # download model
    model = torchvision.models.mobilenet_v2(
        weights=torchvision.models.MobileNet_V2_Weights.DEFAULT
    )
    model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_LABELS)
    model.to(device)

    # download data
    beans_dataset = datasets.load_dataset("beans")
    train_folder, _ = os.path.split(beans_dataset["train"][0]["image_file_path"])
    train_path, _ = os.path.split(train_folder)
    val_folder, _ = os.path.split(beans_dataset["validation"][0]["image_file_path"])
    val_path, _ = os.path.split(val_folder)

    # dataloaders
    imagenet_transform = transforms.Compose(
        [
            transforms.Resize(
                size=256,
                interpolation=transforms.InterpolationMode.BILINEAR,
                max_size=None,
                antialias=None,
            ),
            transforms.CenterCrop(size=(224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    train_dataset = torchvision.datasets.ImageFolder(
        root=train_path, transform=imagenet_transform
    )
    train_loader = DataLoader(
        train_dataset, BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=16
    )

    val_dataset = torchvision.datasets.ImageFolder(
        root=val_path, transform=imagenet_transform
    )
    val_loader = DataLoader(
        val_dataset, BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=16
    )

    # loss and optimizer
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=8e-3)

    # initialize session
    recipe_args = {"warm_up_epochs": 5, "start_quant_epoch": 3, "pruning_epochs": 5}
    _ = session.initialize(
        framework=Framework.pytorch,
        recipe=recipe,
        recipe_args=recipe_args,
        model=model,
        teacher_model=None,
        optimizer=optimizer,
        train_data=train_loader,
        val_data=val_loader,
        start=0.0,
        steps_per_epoch=len(train_loader),
    )

    # loop through batches
    for epoch in range(NUM_EPOCHS):
        running_loss = 0.0
        total_correct = 0
        total_predictions = 0
        for step, (inputs, labels) in enumerate(session.state.data.train):
            inputs = inputs.to(device)
            labels = labels.to(device)
            session.state.optimizer.optimizer.zero_grad()
            session.event(event_type=EventType.BATCH_START, batch_data=(inputs, labels))

            outputs = session.state.model.model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            session.event(event_type=EventType.LOSS_CALCULATED, loss=loss)

            session.event(event_type=EventType.OPTIM_PRE_STEP)
            session.state.optimizer.optimizer.step()
            session.event(event_type=EventType.OPTIM_POST_STEP)

            running_loss += loss.item()

            predictions = outputs.argmax(dim=1)
            total_correct += torch.sum(predictions == labels).item()
            total_predictions += inputs.size(0)

            session.event(event_type=EventType.BATCH_END)

        loss = running_loss / (step + 1.0)
        accuracy = total_correct / total_predictions
        print("Epoch: {} Loss: {} Accuracy: {}".format(epoch + 1, loss, accuracy))

    # finalize session
    session.finalize()

    # view sparsities
    for (name, layer) in get_prunable_layers(session.state.model.model):
        print(f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}")

    # save sparsified model
    save_dir = "e2e_experiment"
    exporter = ModuleExporter(model, output_dir=save_dir)
    exporter.export_pytorch(name="mobilenet_v2-sparse-beans.pth")
    exporter.export_onnx(torch.randn(1, 3, 224, 224), name="sparse-model.onnx")


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion src/sparseml/core/recipe/recipe.py
@@ -184,7 +184,7 @@ def simplify_combine_recipes(
             )
             combined.version = simplified.version
             combined.stages.extend(simplified.stages)
-            combined.args.combine(simplified.args)
+            combined.args.update(simplified.args)
 
         return combined
53 changes: 52 additions & 1 deletion src/sparseml/modifiers/quantization/base.py
@@ -14,7 +14,7 @@
 
 from typing import Any, Dict, List, Optional
 
-from sparseml.core import Modifier, State
+from sparseml.core import Event, Modifier, State
 from sparseml.modifiers.quantization.utils.quantization_scheme import (
     QuantizationScheme,
     QuantizationSchemeLoadable,
@@ -104,6 +104,57 @@ def __init__(self, **kwargs):
         )
         if self.model_fuse_fn_kwargs is None:
             self.model_fuse_fn_kwargs = {}
+        if self.ignore is None:
+            self.ignore = []
+
+    def calculate_freeze_bn_stats_epoch(self) -> float:
+        """
+        Get the epoch at which we want to stop updating batch normalization stats
+
+        :return: freeze_bn_stats_epoch if set, else -1
+        """
+        return (
+            self.freeze_bn_stats_epoch if self.freeze_bn_stats_epoch is not None else -1
+        )
+
+    def check_should_freeze_bn_stats(self, event: Event) -> bool:
+        """
+        Given the current index, determine if we should freeze batch normalization
+        stats
+
+        :param event: Event to get index from
+        :return: True if stats should be frozen, False otherwise
+        """
+        freeze_epoch = self.calculate_freeze_bn_stats_epoch()
+        if freeze_epoch == -1:
+            return False
+        if event.current_index >= freeze_epoch:
+            return True
+        return False
+
+    def calculate_disable_observer_epoch(self) -> float:
+        """
+        Get the epoch at which we want to disable the quantization observer
+
+        :return: epoch to disable at, or -1 if it is not set
+        """
+        return (
+            self.disable_quantization_observer_epoch
+            if self.disable_quantization_observer_epoch is not None
+            else -1
+        )
+
+    def check_should_disable_observer(self, event: Event) -> bool:
+        """
+        Given the current index, determine if we should disable the observer
+
+        :param event: Event to get index from
+        :return: True if observer should be disabled, False otherwise
+        """
+        disable_epoch = self.calculate_disable_observer_epoch()
+        if disable_epoch == -1:
+            return False
+        if event.current_index >= disable_epoch:
+            return True
+        return False
 
     def on_initialize_structure(self, state: State, **kwargs):
         pass  # nothing needed for this modifier
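A rough usage sketch of the two new threshold checks (illustrative, not part of the diff; the event is faked with a stand-in object since only current_index is consulted here, and the constructor arguments are hypothetical):

from types import SimpleNamespace

from sparseml.modifiers.quantization.base import QuantizationModifier

# hypothetical: a modifier with only the BN-freeze threshold set
modifier = QuantizationModifier(freeze_bn_stats_epoch=5.0)

event = SimpleNamespace(current_index=4.5)  # stand-in for sparseml.core.Event
print(modifier.check_should_freeze_bn_stats(event))   # False: before epoch 5
event.current_index = 5.0
print(modifier.check_should_freeze_bn_stats(event))   # True: at/after epoch 5

# unset thresholds resolve to -1, so their checks short-circuit to False
print(modifier.calculate_disable_observer_epoch())    # -1
print(modifier.check_should_disable_observer(event))  # False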
36 changes: 26 additions & 10 deletions src/sparseml/modifiers/quantization/pytorch.py
@@ -19,10 +19,11 @@
 import torch
 from torch.nn import Module
 
-from sparseml.core import Event, State
+from sparseml.core import Event, EventType, State
 from sparseml.modifiers.quantization.base import QuantizationModifier
 from sparseml.modifiers.quantization.utils.helpers import (
     configure_module_bn_wrappers,
+    freeze_bn_stats,
     fuse_module_conv_bn_relus,
 )
 from sparseml.modifiers.quantization.utils.quantize import (
@@ -40,6 +41,8 @@ class QuantizationModifierPyTorch(QuantizationModifier):
     calibration_dataloader_: Any = None
     calibration_function_: Any = None
     qat_enabled_: bool = False
+    quantization_observer_disabled_: bool = False
+    bn_stats_frozen_: bool = False
 
     def on_initialize(self, state: State, **kwargs) -> bool:
         raise_if_torch_quantization_not_available()
@@ -51,10 +54,10 @@ def on_initialize(self, state: State, **kwargs) -> bool:

         self.calibration_dataloader_ = state.data.calib
-        device = state.hardware.device
-        state.model.model.to(device)
         module = state.model.model
-        self._enable_module_qat(module)
+        if self.calculate_start() == -1:  # one-shot
+            self._enable_module_qat(module)
+            self._disable_quantization_observer(module)
 
         return True

@@ -63,21 +66,34 @@ def on_finalize(self, state: State, **kwargs) -> bool:
         state.model.model.to(state.hardware.device)
         state.model.model.apply(torch.quantization.enable_observer)
         self._calibrate_if_possible(state.model.model)
-        state.model.model.apply(torch.quantization.disable_observer)
+        self._disable_quantization_observer(state.model.model)
         return True
 
     def on_start(self, state: State, event: Event, **kwargs):
-        pass
+        if not self.qat_enabled_:
+            self._enable_module_qat(state.model.model)
 
     def on_update(self, state: State, event: Event, **kwargs):
-        pass
+        if event.type_ == EventType.BATCH_START:
+            if self.check_should_freeze_bn_stats(event):
+                self._freeze_bn_stats(state.model.model)
+            if self.check_should_disable_observer(event):
+                self._disable_quantization_observer(state.model.model)
 
     def on_end(self, state: State, event: Event, **kwargs):
-        pass
+        self._disable_quantization_observer(state.model.model)
 
     def on_event(self, state: State, event: Event, **kwargs):
         pass
 
+    def _freeze_bn_stats(self, model: Module):
+        model.apply(freeze_bn_stats)
+        self.bn_stats_frozen_ = True
+
+    def _disable_quantization_observer(self, model: Module):
+        model.apply(torch.quantization.disable_observer)
+        self.quantization_observer_disabled_ = True
+
     def _enable_module_qat(self, module: Module):
         # fuse conv-bn-relu blocks prior to quantization emulation
         self._fuse(module)
@@ -164,4 +180,4 @@ def _calibrate(self, module: Module):
         if module_training:
             module.train()
         else:
-            module.apply(torch.quantization.disable_observer)
+            self._disable_quantization_observer(module)
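
Taken together, the new hooks give the modifier a training-time lifecycle: QAT is enabled at on_start, the BN-freeze and observer-disable thresholds are checked on every BATCH_START, and observers are disabled at on_end. A sketch of a recipe stage that would exercise both thresholds, written as an embedded YAML string (hypothetical epoch values; the keys mirror the modifier attributes added in base.py):

# sketch only: a recipe stage driving the new training-time hooks
illustrative_stage = """
test_stage:
  quantization_modifiers:
    QuantizationModifier:
      start: 3.0
      freeze_bn_stats_epoch: 5.0
      disable_quantization_observer_epoch: 6.0
"""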