Merge branch 'main' into prod_smooth_quant
Satrat committed Oct 19, 2023
2 parents b03ab84 + cae298c commit 8244bc5
Showing 28 changed files with 1,753 additions and 1,033 deletions.
@@ -0,0 +1,75 @@
test_stage:
  quantization_modifiers:
    QuantizationModifier:
      start: eval(start_quant_epoch)
      scheme:
        input_activations:
          num_bits: 8
          symmetric: False
        weights:
          num_bits: 4
          symmetric: True
          strategy: "channel"
      ignore: ['classifier']
  pruning_modifiers:
    MagnitudePruningModifier:
      init_sparsity: 0.0
      final_sparsity: 0.5
      start: eval(warm_up_epochs)
      end: eval(warm_up_epochs + pruning_epochs)
      update_frequency: 0.5
      targets:
        - features.0.0.weight
        - features.1.conv.0.0.weight
        - features.1.conv.1.weight
        - features.2.conv.0.0.weight
        - features.2.conv.1.0.weight
        - features.2.conv.2.weight
        - features.3.conv.0.0.weight
        - features.3.conv.1.0.weight
        - features.3.conv.2.weight
        - features.4.conv.0.0.weight
        - features.4.conv.1.0.weight
        - features.4.conv.2.weight
        - features.5.conv.0.0.weight
        - features.5.conv.1.0.weight
        - features.5.conv.2.weight
        - features.6.conv.0.0.weight
        - features.6.conv.1.0.weight
        - features.6.conv.2.weight
        - features.7.conv.0.0.weight
        - features.7.conv.1.0.weight
        - features.7.conv.2.weight
        - features.8.conv.0.0.weight
        - features.8.conv.1.0.weight
        - features.8.conv.2.weight
        - features.9.conv.0.0.weight
        - features.9.conv.1.0.weight
        - features.9.conv.2.weight
        - features.10.conv.0.0.weight
        - features.10.conv.1.0.weight
        - features.10.conv.2.weight
        - features.11.conv.0.0.weight
        - features.11.conv.1.0.weight
        - features.11.conv.2.weight
        - features.12.conv.0.0.weight
        - features.12.conv.1.0.weight
        - features.12.conv.2.weight
        - features.13.conv.0.0.weight
        - features.13.conv.1.0.weight
        - features.13.conv.2.weight
        - features.14.conv.0.0.weight
        - features.14.conv.1.0.weight
        - features.14.conv.2.weight
        - features.15.conv.0.0.weight
        - features.15.conv.1.0.weight
        - features.15.conv.2.weight
        - features.16.conv.0.0.weight
        - features.16.conv.1.0.weight
        - features.16.conv.2.weight
        - features.17.conv.0.0.weight
        - features.17.conv.1.0.weight
        - features.17.conv.2.weight
        - features.18.0.weight
        - classifier.1.weight
      leave_enabled: True
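
The eval(...) entries in the recipe above are placeholders that SparseML fills in from recipe_args when the recipe is passed to session.initialize (see e2e_test.py later in this commit). A minimal sketch of the resolved schedule, assuming the same recipe_args values used there; the arithmetic is illustrative only, not the SparseML resolution API:

    # values taken from recipe_args in e2e_test.py below
    recipe_args = {"warm_up_epochs": 5, "start_quant_epoch": 3, "pruning_epochs": 5}

    # what the eval() expressions evaluate to once substituted
    quant_start = recipe_args["start_quant_epoch"]                              # 3
    prune_start = recipe_args["warm_up_epochs"]                                 # 5
    prune_end = recipe_args["warm_up_epochs"] + recipe_args["pruning_epochs"]   # 10
    print(quant_start, prune_start, prune_end)  # 3 5 10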
155 changes: 155 additions & 0 deletions integrations/torchvision/modifiers_refactor_example/e2e_test.py
@@ -0,0 +1,155 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def main():
    import os

    import datasets
    import torch
    import torchvision
    from torch.nn import CrossEntropyLoss
    from torch.optim import Adam
    from torch.utils.data import DataLoader
    from torchvision import transforms

    import sparseml.core.session as sml
    from sparseml.core.event import EventType
    from sparseml.core.framework import Framework
    from sparseml.pytorch.utils import (
        ModuleExporter,
        get_prunable_layers,
        tensor_sparsity,
    )

    NUM_LABELS = 3
    BATCH_SIZE = 32
    NUM_EPOCHS = 12
    recipe = "e2e_recipe.yaml"
    device = "cuda:0"

    # set up SparseML session
    sml.create_session()
    session = sml.active_session()

    # download model
    model = torchvision.models.mobilenet_v2(
        weights=torchvision.models.MobileNet_V2_Weights.DEFAULT
    )
    model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, NUM_LABELS)
    model.to(device)

    # download data
    beans_dataset = datasets.load_dataset("beans")
    train_folder, _ = os.path.split(beans_dataset["train"][0]["image_file_path"])
    train_path, _ = os.path.split(train_folder)
    val_folder, _ = os.path.split(beans_dataset["validation"][0]["image_file_path"])
    val_path, _ = os.path.split(val_folder)

    # dataloaders
    imagenet_transform = transforms.Compose(
        [
            transforms.Resize(
                size=256,
                interpolation=transforms.InterpolationMode.BILINEAR,
                max_size=None,
                antialias=None,
            ),
            transforms.CenterCrop(size=(224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    train_dataset = torchvision.datasets.ImageFolder(
        root=train_path, transform=imagenet_transform
    )
    train_loader = DataLoader(
        train_dataset, BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=16
    )

    val_dataset = torchvision.datasets.ImageFolder(
        root=val_path, transform=imagenet_transform
    )
    val_loader = DataLoader(
        val_dataset, BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=16
    )

    # loss and optimizer
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=8e-3)

    # initialize session
    recipe_args = {"warm_up_epochs": 5, "start_quant_epoch": 3, "pruning_epochs": 5}
    _ = session.initialize(
        framework=Framework.pytorch,
        recipe=recipe,
        recipe_args=recipe_args,
        model=model,
        teacher_model=None,
        optimizer=optimizer,
        train_data=train_loader,
        val_data=val_loader,
        start=0.0,
        steps_per_epoch=len(train_loader),
    )

    # loop through batches
    for epoch in range(NUM_EPOCHS):
        running_loss = 0.0
        total_correct = 0
        total_predictions = 0
        for step, (inputs, labels) in enumerate(session.state.data.train):
            inputs = inputs.to(device)
            labels = labels.to(device)
            session.state.optimizer.optimizer.zero_grad()
            session.event(event_type=EventType.BATCH_START, batch_data=(inputs, labels))

            outputs = session.state.model.model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            session.event(event_type=EventType.LOSS_CALCULATED, loss=loss)

            session.event(event_type=EventType.OPTIM_PRE_STEP)
            session.state.optimizer.optimizer.step()
            session.event(event_type=EventType.OPTIM_POST_STEP)

            running_loss += loss.item()

            predictions = outputs.argmax(dim=1)
            total_correct += torch.sum(predictions == labels).item()
            total_predictions += inputs.size(0)

            session.event(event_type=EventType.BATCH_END)

        loss = running_loss / (step + 1.0)
        accuracy = total_correct / total_predictions
        print("Epoch: {} Loss: {} Accuracy: {}".format(epoch + 1, loss, accuracy))

    # finalize session
    session.finalize()

    # view sparsities
    for (name, layer) in get_prunable_layers(session.state.model.model):
        print(f"{name}.weight: {tensor_sparsity(layer.weight).item():.4f}")

    # save sparsified model
    save_dir = "e2e_experiment"
    exporter = ModuleExporter(model, output_dir=save_dir)
    exporter.export_pytorch(name="mobilenet_v2-sparse-beans.pth")
    exporter.export_onnx(torch.randn(1, 3, 224, 224), name="sparse-model.onnx")


if __name__ == "__main__":
    main()
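
Not part of this commit, but a quick way to sanity-check the ONNX file exported above; a minimal sketch assuming onnxruntime is installed and that ModuleExporter writes the model under the given output_dir and name:

    import numpy as np
    import onnxruntime as ort

    # path is an assumption based on save_dir and the export name used in e2e_test.py
    sess = ort.InferenceSession("e2e_experiment/sparse-model.onnx")

    # dummy input matching the (1, 3, 224, 224) shape used for export
    dummy = np.random.rand(1, 3, 224, 224).astype(np.float32)
    input_name = sess.get_inputs()[0].name
    logits = sess.run(None, {input_name: dummy})[0]
    print(logits.shape)  # expected (1, 3): one score per beans class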
2 changes: 1 addition & 1 deletion research/information_retrieval/doc2query/requirements.txt
@@ -127,7 +127,7 @@ traitlets==5.0.5
transformers==4.7.0
typer==0.3.2
typing-extensions==3.10.0.0
urllib3==1.26.17
urllib3==1.26.18
wandb==0.10.32
wasabi==0.8.2
wcwidth==0.2.5
1 change: 1 addition & 0 deletions src/sparseml/core/__init__.py
@@ -20,6 +20,7 @@
from .framework import *
from .framework_object import *
from .lifecycle import *
from .logger import *
from .model import *
from .modifier import *
from .optimizer import *
