Skip to content

Commit

Permalink
Merge pull request #318 from Wordcab/317-increase-timeout-for-remote-…
Browse files Browse the repository at this point in the history
…diarization-auth

Increasing timeout for remote diarization auth
  • Loading branch information
aleksandr-smechov authored Jun 29, 2024
2 parents 0bfaeb3 + c147d82 commit 61be1ab
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 15 deletions.
4 changes: 2 additions & 2 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ DEBUG=True
# Then in your Dockerfile, copy the converted models to the /app/src/wordcab_transcribe/whisper_models folder.
# Example for WHISPER_MODEL: COPY cloned_wordcab_transcribe_repo/src/wordcab_transcribe/whisper_models/large-v3 /app/src/wordcab_transcribe/whisper_models/large-v3
# Example for ALIGN_MODEL: COPY cloned_wordcab_transcribe_repo/src/wordcab_transcribe/whisper_models/tiny /app/src/wordcab_transcribe/whisper_models/tiny
WHISPER_MODEL="large-v3"
WHISPER_MODEL="medium"
# You can specify one of two engines, "faster-whisper" or "tensorrt-llm". At the moment, "faster-whisper" is more
# stable, adjustable, and accurate, while "tensorrt-llm" is faster but less accurate and adjustable.
WHISPER_ENGINE="tensorrt-llm"
WHISPER_ENGINE="faster-whisper-batched"
# This helps adjust some build during the conversion of the Whisper model to TensorRT. If you change this, be sure to
# it in pre_requirements.txt. The only available options are "0.9.0.dev2024032600" and "0.11.0.dev2024052100".
# Note that version "0.11.0.dev2024052100" is not compatible with T4 or V100 GPUs.
Expand Down
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ RUN curl -L ${RELEASE_URL} | tar -zx -C /tmp \

RUN python -m pip install pip --upgrade

COPY faster-whisper /app/faster-whisper

COPY pre_requirements.txt .
COPY requirements.txt .

Expand Down
3 changes: 2 additions & 1 deletion pre_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ tensorrt_llm==0.9.0.dev2024032600
Cython==3.0.10
youtokentome @ git+https://github.com/gburlet/YouTokenToMe.git@dependencies
deepmultilingualpunctuation==1.0.1
pyannote.audio==3.2.0
pyannote.audio==3.2.0
ipython==8.24.0
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
aiohttp==3.9.3
aiofiles==23.2.1
boto3
faster-whisper @ https://github.com/SYSTRAN/faster-whisper/archive/refs/heads/master.tar.gz
-e /app/faster-whisper
ffmpeg-python==0.2.0
transformers==4.38.2
librosa==0.10.1
Expand Down
4 changes: 2 additions & 2 deletions src/wordcab_transcribe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ def whisper_model_compatibility_check(cls, value: str): # noqa: B902, N805
@field_validator("whisper_engine")
def whisper_engine_compatibility_check(cls, value: str): # noqa: B902, N805
"""Check that the whisper engine is compatible."""
if value.lower() not in ["faster-whisper", "tensorrt-llm"]:
if value.lower() not in ["faster-whisper", "faster-whisper-batched", "tensorrt-llm"]:
raise ValueError(
"The whisper engine must be one of `faster-whisper` or `tensorrt-llm`."
"The whisper engine must be one of `faster-whisper`, `faster-whisper-batched`, or `tensorrt-llm`."
)

return value
Expand Down
1 change: 0 additions & 1 deletion src/wordcab_transcribe/router/v1/audio_file_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ async def inference_with_audio( # noqa: C901
)

background_tasks.add_task(delete_file, filepath=filename)

task = asyncio.create_task(
asr.process_input(
filepath=filepath,
Expand Down
22 changes: 16 additions & 6 deletions src/wordcab_transcribe/services/asr_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ async def process_input( # noqa: C901
filepath (Union[str, List[str]]):
Path to the audio file or list of paths to the audio files to process.
batch_size (Union[int, None]):
The batch size to use for the transcription. For tensorrt-llm whisper engine only.
The batch size to use for the transcription. For tensorrt-llm and faster-whisper-batch engines only.
offset_start (Union[float, None]):
The start time of the audio file to process.
offset_end (Union[float, None]):
Expand Down Expand Up @@ -611,11 +611,20 @@ async def process_transcription(self, task: ASRTask, debug_mode: bool) -> None:
if isinstance(task.transcription.execution, LocalExecution):
out = await time_and_tell_async(
lambda: self.local_services.transcription(
task.audio,
audio=task.audio,
model_index=task.transcription.execution.index,
suppress_blank=False,
word_timestamps=True,
**task.transcription.options.model_dump(),
source_lang=task.transcription.options.source_lang,
batch_size=task.batch_size,
num_beams=task.transcription.options.num_beams,
suppress_blank=False, # TODO: Add this to the options
vocab=task.transcription.options.vocab,
word_timestamps=task.word_timestamps,
internal_vad=task.transcription.options.internal_vad,
repetition_penalty=task.transcription.options.repetition_penalty,
compression_ratio_threshold=task.transcription.options.compression_ratio_threshold,
log_prob_threshold=task.transcription.options.log_prob_threshold,
no_speech_threshold=task.transcription.options.no_speech_threshold,
condition_on_previous_text=task.transcription.options.condition_on_previous_text,
),
func_name="transcription",
debug_mode=debug_mode,
Expand Down Expand Up @@ -880,7 +889,8 @@ async def remote_diarization(
if not settings.debug:
headers = {"Content-Type": "application/x-www-form-urlencoded"}
auth_url = f"{url}/api/v1/auth"
async with aiohttp.ClientSession() as session:
diarization_timeout = aiohttp.ClientTimeout(total=10)
async with AsyncLocationTrustedRedirectSession(timeout=diarization_timeout) as session:
async with session.post(
url=auth_url,
data={"username": settings.username, "password": settings.password},
Expand Down
29 changes: 27 additions & 2 deletions src/wordcab_transcribe/services/transcribe_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from typing import Iterable, List, NamedTuple, Optional, Union

import torch
from faster_whisper import WhisperModel
from loguru import logger
from tensorshare import Backend, TensorShare
from faster_whisper import WhisperModel, BatchedInferencePipeline

from wordcab_transcribe.config import settings
from wordcab_transcribe.engines.tensorrt_llm.model import WhisperModelTRT
Expand Down Expand Up @@ -87,6 +87,16 @@ def __init__(
device_index=device_index,
compute_type=self.compute_type,
)
elif self.model_engine == "faster-whisper-batched":
logger.info("Using faster-whisper-batched model engine.")
self.model = BatchedInferencePipeline(
model=WhisperModel(
self.model_path,
device=self.device,
device_index=device_index,
compute_type=self.compute_type,
)
)
elif self.model_engine == "tensorrt-llm":
logger.info("Using tensorrt-llm model engine.")
if "v3" in self.model_path:
Expand Down Expand Up @@ -126,7 +136,7 @@ def __call__(
],
source_lang: str,
model_index: int,
batch_size: int = 1,
batch_size: int,
num_beams: int = 1,
suppress_blank: bool = False,
vocab: Union[List[str], None] = None,
Expand Down Expand Up @@ -220,6 +230,21 @@ def __call__(
"window_size_samples": 512,
},
)
elif self.model_engine == "faster-whisper-batched":
print("Batch size: ", batch_size)
segments, _ = self.model.transcribe(
audio,
language=source_lang,
hotwords=prompt,
beam_size=num_beams,
repetition_penalty=repetition_penalty,
compression_ratio_threshold=compression_ratio_threshold,
log_prob_threshold=log_prob_threshold,
no_speech_threshold=no_speech_threshold,
suppress_blank=suppress_blank,
word_timestamps=word_timestamps,
batch_size=batch_size,
)
elif self.model_engine == "tensorrt-llm":
segments = self.model.transcribe(
audio_data=[audio],
Expand Down

0 comments on commit 61be1ab

Please sign in to comment.