
Scale batch size for QAT by constant 4 (#186)
* Scale batch size for QAT by constant 4

* Add back logging message
KSGulin committed Mar 2, 2023
1 parent a1ebe6a commit 2c7058d
Showing 2 changed files with 20 additions and 34 deletions.
42 changes: 13 additions & 29 deletions utils/neuralmagic/sparsification_manager.py
@@ -6,7 +6,6 @@
 from sparseml.pytorch.optim import ScheduledModifierManager
 from sparseml.pytorch.utils import SparsificationGroupLogger
 
-from utils.autobatch import check_train_batch_size
 from utils.loggers import Loggers
 from utils.loss import ComputeLoss
 from utils.neuralmagic.quantization import update_model_bottlenecks
@@ -392,37 +391,22 @@ def rescale_gradient_accumulation(
             maintaining the original effective batch size
         """
 
-        # Calculate maximum batch size that will fit in memory
-        batch_size_max = max(
-            check_train_batch_size(self.model, image_size, False) // QAT_BATCH_SCALE, 1
-        )
-
-        if batch_size > batch_size_max:
-            new_batch_size = batch_size_max
-
-            # effective batch size to maintain
-            effective_batch_size = batch_size * accumulate
-
-            # Roughly calculate batch size by rounding. This can result in an effective
-            # batch size that is 1-to-few off from the original
-            new_accumulate = max(round(effective_batch_size / new_batch_size), 1)
-            new_batch_size = max(round(effective_batch_size / new_accumulate), 1)
-
-            self.log_console(
-                f"Batch size rescaled to {new_batch_size} with {new_accumulate} gradient "
+        effective_batch_size = batch_size * accumulate
+        batch_size = max(batch_size // QAT_BATCH_SCALE, 1)
+        accumulate = effective_batch_size // batch_size
+
+        self.log_console(
+            f"Batch size rescaled to {batch_size} with {accumulate} gradient "
             "accumulation steps for QAT"
         )
 
-            if new_accumulate * new_batch_size != effective_batch_size:
-                self.log_console(
-                    "New effective batch size doesn't match previous effective batch size. "
-                    f"Previous effective batch size: {effective_batch_size}. "
-                    f"New effective batch size: {new_batch_size * new_accumulate}",
-                    level="warning",
-                )
-
-            batch_size = new_batch_size
-            accumulate = new_accumulate
+        if accumulate * batch_size != effective_batch_size:
+            self.log_console(
+                "New effective batch size doesn't match previous effective batch size. "
+                f"Previous effective batch size: {effective_batch_size}. "
+                f"New effective batch size: {batch_size * accumulate}",
+                level="warning",
+            )
 
         return batch_size, accumulate

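For illustration, the new rescaling arithmetic as a standalone function — a minimal sketch, not the repository's API: QAT_BATCH_SCALE is assumed to equal the constant 4 named in the commit title, and the function name rescale_for_qat is hypothetical (in the diff this logic lives in a method on the sparsification manager).

from typing import Tuple

QAT_BATCH_SCALE = 4  # assumption: the "constant 4" from the commit title

def rescale_for_qat(batch_size: int, accumulate: int) -> Tuple[int, int]:
    """Shrink the per-step batch for QAT while keeping the effective batch size."""
    effective_batch_size = batch_size * accumulate
    batch_size = max(batch_size // QAT_BATCH_SCALE, 1)  # scale down by the constant
    accumulate = effective_batch_size // batch_size     # compensate with accumulation
    return batch_size, accumulate

# rescale_for_qat(64, 1) -> (16, 4): 16 * 4 = 64, effective batch size preserved.
# rescale_for_qat(9, 1)  -> (2, 4):  2 * 4 = 8 != 9, the mismatch case that the
# warning branch in the diff reports.

Note the design change visible in the two hunks: the old code probed for the largest batch that fits in memory via check_train_batch_size and only shrank when needed, while the new code divides the requested batch size by the constant unconditionally — which is why the first hunk drops that import.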
12 changes: 7 additions & 5 deletions utils/neuralmagic/utils.py
@@ -1,11 +1,11 @@
 import os
 import sys
-import yaml
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
 import numpy
 import torch
+import yaml
 from sparseml.pytorch.optim import ScheduledModifierManager
 from sparseml.pytorch.utils import ModuleExporter, download_framework_model_by_recipe_type
 from sparsezoo import Model
@@ -223,7 +223,7 @@ def export_sample_inputs_outputs(
     number_export_samples=100,
     image_size: int = 640,
     onnx_path: Union[str, Path, None] = None,
-    default_hyp:str = 'data/hyps/hyp.scratch-low.yaml',
+    default_hyp: str = "data/hyps/hyp.scratch-low.yaml",
 ):
     """
     Export sample model input and output for testing with the DeepSparse Engine
@@ -244,9 +244,11 @@
     if model.hyp is None:
         FILE = Path(__file__).resolve()
         ROOT = FILE.parents[2]  # YOLOv5 root directory
-        HYPS_DIR = ROOT/default_hyp
-        nm_log_console(f"The model hyper-parameters are not set, using defaults from {HYPS_DIR}.")
-        with open(HYPS_DIR, errors='ignore') as f:
+        HYPS_DIR = ROOT / default_hyp
+        nm_log_console(
+            f"The model hyper-parameters are not set, using defaults from {HYPS_DIR}."
+        )
+        with open(HYPS_DIR, errors="ignore") as f:
             model.hyp = yaml.safe_load(f)
 
     # Create dataloader
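For context, a minimal sketch of the fallback this hunk implements, pulled out of the exporter. The path and repository layout are taken from the diff; the function name load_default_hyp is hypothetical, and print() stands in for nm_log_console so the sketch is self-contained.

import yaml
from pathlib import Path

def load_default_hyp(default_hyp: str = "data/hyps/hyp.scratch-low.yaml") -> dict:
    """Load default training hyperparameters when the model carries none."""
    FILE = Path(__file__).resolve()
    ROOT = FILE.parents[2]  # YOLOv5 root directory, per the diff's layout
    hyp_path = ROOT / default_hyp
    print(f"The model hyper-parameters are not set, using defaults from {hyp_path}.")
    with open(hyp_path, errors="ignore") as f:
        return yaml.safe_load(f)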
