Skip to content

Commit

Permalink
Merge pull request #79 from kadirnar/update-test
Browse files Browse the repository at this point in the history
💬 Add new parameters for hqq optimization method
  • Loading branch information
kadirnar authored May 4, 2024
2 parents 317bdc9 + e5ceca7 commit e27f4dd
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
Empty file modified scripts/runpod.sh
100644 → 100755
Empty file.
7 changes: 6 additions & 1 deletion whisperplus/pipelines/whisper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import logging

import torch
from hqq.core.quantize import HQQBackend, HQQLinear
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Select the HQQLinear compute backend when this module is imported.
# NOTE(review): the three calls below run back to back, so presumably each one
# overrides the previous and only the last (ATEN) stays in effect -- confirm
# against the hqq docs, then keep just the backend that is actually intended.
HQQLinear.set_backend(HQQBackend.PYTORCH) # PyTorch backend
HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE) # Compiled PyTorch via dynamo
HQQLinear.set_backend(HQQBackend.ATEN) # C++ ATen/CUDA backend (set automatically by default if available)

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


Expand All @@ -25,7 +30,7 @@ def load_model(self, model_id: str = "distil-whisper/distil-large-v3", quant_con
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
torch_dtype=torch.bfloat16,
device_map='auto',
max_memory={0: "24GiB"})
logging.info("Model loaded successfully.")
Expand Down
20 changes: 13 additions & 7 deletions whisperplus/test.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import time

import torch
from hqq.utils.patching import prepare_for_inference
from pipelines.whisper import SpeechToTextPipeline
from transformers import BitsAndBytesConfig, HqqConfig
from utils.download_utils import download_and_convert_to_mp3

url = "https://www.youtube.com/watch?v=di3rHkEZuUw"
url = "https://www.youtube.com/watch?v=BpN4hEAvDBg"
audio_path = download_and_convert_to_mp3(url)

hqq_config = HqqConfig(
nbits=1, group_size=64, quant_zero=False, quant_scale=False, axis=0) # axis=0 is used by default
nbits=1,
group_size=64,
quant_zero=False,
quant_scale=False,
axis=0,
offload_meta=False,
) # axis=0 is used by default

bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
Expand All @@ -18,14 +23,14 @@
bnb_4bit_use_double_quant=True,
)
model = SpeechToTextPipeline(
model_id="distil-whisper/distil-large-v3", quant_config=bnb_config) # or bnb_config
model_id="distil-whisper/distil-large-v3", quant_config=hqq_config) # or bnb_config

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
transcript = model(
audio_path="testv0.mp3",
audio_path=audio_path,
chunk_length_s=30,
stride_length_s=5,
max_new_tokens=128,
Expand All @@ -36,4 +41,5 @@

torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
print(f"Execution time: {elapsed_time_ms}ms")
seconds = elapsed_time_ms / 1000
print(f"Elapsed time: {seconds} seconds")

0 comments on commit e27f4dd

Please sign in to comment.