Skip to content

Commit

Permalink
Merge branch 'main' into quant-modifier-ux
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed May 13, 2024
2 parents a55f50c + fcf3c77 commit 5b88ad3
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 29 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build-wheel-and-container.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jobs:
needs: build-wheel-and-push
uses: ./.github/workflows/test-wheel-push-to-internal.yml
with:
build-label: aws-avx2-64G
build-label: ubuntu-20.04
whl: ${{ needs.build-wheel-and-push.outputs.wheel }}
python: '3.10'
secrets: inherit
Expand All @@ -70,7 +70,7 @@ jobs:
needs: [set-outputs, test-wheel-and-push-internal]
uses: ./.github/workflows/build-container.yml
with:
build-label: aws-avx2-64G
build-label: k8s-eng-gpu-64G-v100-32G
dev: ${{ needs.set-outputs.outputs.dev }}
release: ${{ needs.set-outputs.outputs.release }}
name: ${{ github.event.number }}
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/test-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ jobs:
- name: "🔬 Running onnx tests"
run: make test TARGETS=onnx
pytorch-tests:
runs-on: aws-avx2-64G
runs-on: k8s-eng-gpu-64G-v100-32G
env:
SPARSEZOO_TEST_MODE: "true"
CLEARML_WEB_HOST: ${{ secrets.CLEARML_WEB_HOST }}
Expand Down Expand Up @@ -169,7 +169,7 @@ jobs:
- name: "🔬 Running pytorch tests"
run: make test TARGETS=pytorch
compat-pytorch-1_9-pytorch-tests:
runs-on: aws-avx2-64G
runs-on: k8s-eng-gpu-64G-v100-32G
env:
SPARSEZOO_TEST_MODE: "true"
CLEARML_WEB_HOST: ${{ secrets.CLEARML_WEB_HOST }}
Expand Down Expand Up @@ -222,7 +222,7 @@ jobs:
- name: "🔬 Running onnx tests"
run: make test TARGETS=onnx
transformers-tests:
runs-on: aws-avx2-64G
runs-on: k8s-eng-gpu-64G-v100-32G
env:
SPARSEZOO_TEST_MODE: "true"
CLEARML_WEB_HOST: ${{ secrets.CLEARML_WEB_HOST }}
Expand Down
9 changes: 0 additions & 9 deletions .github/workflows/test-wheel-push-to-internal.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,6 @@ jobs:
- name: Fetch name of whl
run: |
echo "FILENAME=$(echo dist_s3/*.whl)" >> $GITHUB_ENV
- name: Push to internal pypi
uses: neuralmagic/nm-actions/actions/nm-upload-whl@main
with:
server: ${{ secrets.NM_PRIVATE_PYPI_LOCATION }}
username: ${{ secrets.NM_PRIVATE_PYPI_USER }}
password: ${{ secrets.NM_PRIVATE_PYPI_AUTH }}
whl: ./$FILENAME
port: 8080
- name: Install whl
run: |
Expand Down
43 changes: 28 additions & 15 deletions src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import logging
import math
from copy import copy

import torch
import torch.nn as nn
Expand Down Expand Up @@ -181,30 +182,42 @@ def fasterprune(
fake_quantize,
)

if quant_scheme.weights.strategy == QuantizationStrategy.TENSOR:
strategy = quant_scheme.weights.strategy

if strategy == QuantizationStrategy.TENSOR:
q = fake_quantize(
q,
scale,
zero_point,
self.layer.quantization_scheme.weights,
)
else:
while scale.ndim < 2:
scale = scale.unsqueeze(scale.ndim)
zero_point = zero_point.unsqueeze(zero_point.ndim)

while q.ndim < 2:
q = q.unsqueeze(q.ndim)

elif strategy == QuantizationStrategy.CHANNEL:
# TODO: for channelwise why isn't this just a 1d tensor?
q = fake_quantize(
q,
scale[:, i],
zero_point[:, i],
self.layer.quantization_scheme.weights,
scale[:, 0],
zero_point[:, 0],
quant_scheme.weights,
)
else: # strategy == QuantizationStrategy.CHANNEL
# TODO: for grouped quantization its always 3d but the last
# dim is always 1. Can we just make it 2d instead and avoid?
scale = scale[:, :, 0]
zero_point = zero_point[:, :, 0]

# get the group index for the current column
input_dim_group = i // quant_scheme.weights.group_size

# Since we're only applying quantization to a slice, this
# ends up being a channelwise application
altered_qargs = copy(quant_scheme.weights)
altered_qargs.strategy = QuantizationStrategy.CHANNEL
q = fake_quantize(
q,
scale[:, input_dim_group],
zero_point[:, input_dim_group],
altered_qargs,
)

while q.ndim > 1:
q = q.squeeze()

Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d**2
Expand Down

0 comments on commit 5b88ad3

Please sign in to comment.