Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vocab feature #124

Merged
merged 12 commits into from
Jun 30, 2023
16 changes: 15 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_audio_request() -> None:
assert request.source_lang == "en"
assert request.timestamps == "s"
assert request.use_batch is False
assert request.vocab == []
assert request.word_timestamps is False


Expand All @@ -56,6 +57,7 @@ def test_audio_response() -> None:
source_lang="en",
timestamps="s",
use_batch=False,
vocab=["custom company", "custom product"],
word_timestamps=False,
)
assert response.utterances == []
Expand All @@ -65,6 +67,7 @@ def test_audio_response() -> None:
assert response.source_lang == "en"
assert response.timestamps == "s"
assert response.use_batch is False
assert response.vocab == ["custom company", "custom product"]
assert response.word_timestamps is False

response = AudioResponse(
Expand All @@ -78,6 +81,7 @@ def test_audio_response() -> None:
source_lang="en",
timestamps="s",
use_batch=False,
vocab=["custom company", "custom product"],
word_timestamps=True,
)
assert response.utterances == [
Expand All @@ -90,6 +94,7 @@ def test_audio_response() -> None:
assert response.source_lang == "en"
assert response.timestamps == "s"
assert response.use_batch is False
assert response.vocab == ["custom company", "custom product"]
assert response.word_timestamps is True


Expand Down Expand Up @@ -135,6 +140,7 @@ def test_base_response() -> None:
source_lang="en",
timestamps="s",
use_batch=False,
vocab=["custom company", "custom product"],
word_timestamps=False,
)
assert response.utterances == [
Expand All @@ -146,6 +152,7 @@ def test_base_response() -> None:
assert response.source_lang == "en"
assert response.timestamps == "s"
assert response.use_batch is False
assert response.vocab == ["custom company", "custom product"]
assert response.word_timestamps is False


Expand All @@ -157,7 +164,7 @@ def test_cortex_error() -> None:
assert error.message == "This is a test error"


def test_corxet_payload() -> None:
def test_cortex_payload() -> None:
"""Test the CortexPayload model."""
payload = CortexPayload(
url_type="youtube",
Expand All @@ -182,6 +189,7 @@ def test_corxet_payload() -> None:
assert payload.source_lang == "en"
assert payload.timestamps == "s"
assert payload.use_batch is False
assert payload.vocab == []
assert payload.word_timestamps is False
assert payload.job_name == "test_job"
assert payload.ping is False
Expand All @@ -199,6 +207,7 @@ def test_cortex_url_response() -> None:
source_lang="en",
timestamps="s",
use_batch=False,
vocab=["custom company", "custom product"],
word_timestamps=False,
dual_channel=False,
job_name="test_job",
Expand All @@ -213,6 +222,7 @@ def test_cortex_url_response() -> None:
assert response.source_lang == "en"
assert response.timestamps == "s"
assert response.use_batch is False
assert response.vocab == ["custom company", "custom product"]
assert response.word_timestamps is False
assert response.dual_channel is False
assert response.job_name == "test_job"
Expand All @@ -231,6 +241,7 @@ def test_cortex_youtube_response() -> None:
source_lang="en",
timestamps="s",
use_batch=False,
vocab=["custom company", "custom product"],
word_timestamps=False,
video_url="https://www.youtube.com/watch?v=dQw4w9WgXcQ",
job_name="test_job",
Expand All @@ -245,6 +256,7 @@ def test_cortex_youtube_response() -> None:
assert response.source_lang == "en"
assert response.timestamps == "s"
assert response.use_batch is False
assert response.vocab == ["custom company", "custom product"]
assert response.word_timestamps is False
assert response.video_url == "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
assert response.job_name == "test_job"
Expand All @@ -263,6 +275,7 @@ def test_youtube_response() -> None:
source_lang="en",
timestamps="s",
use_batch=False,
vocab=["custom company", "custom product"],
word_timestamps=False,
video_url="https://www.youtube.com/watch?v=dQw4w9WgXcQ",
)
Expand All @@ -275,5 +288,6 @@ def test_youtube_response() -> None:
assert response.source_lang == "en"
assert response.timestamps == "s"
assert response.use_batch is False
assert response.vocab == ["custom company", "custom product"]
assert response.word_timestamps is False
assert response.video_url == "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
47 changes: 47 additions & 0 deletions wordcab_transcribe/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class BaseResponse(BaseModel):
source_lang: str
timestamps: str
use_batch: bool
vocab: List[str]
word_timestamps: bool


Expand Down Expand Up @@ -59,6 +60,11 @@ class Config:
"source_lang": "en",
"timestamps": "s",
"use_batch": False,
"vocab": [
"custom company name",
"custom product name",
"custom co-worker name",
],
"word_timestamps": False,
"dual_channel": False,
}
Expand Down Expand Up @@ -94,6 +100,11 @@ class Config:
"source_lang": "en",
"timestamps": "s",
"use_batch": False,
"vocab": [
"custom company name",
"custom product name",
"custom co-worker name",
],
"word_timestamps": False,
"video_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
}
Expand Down Expand Up @@ -127,6 +138,7 @@ class CortexPayload(BaseModel):
source_lang: Optional[str] = "en"
timestamps: Optional[str] = "s"
use_batch: Optional[bool] = False
vocab: Optional[List[str]] = []
word_timestamps: Optional[bool] = False
job_name: Optional[str] = None
ping: Optional[bool] = False
Expand Down Expand Up @@ -159,6 +171,11 @@ class Config:
"source_lang": "en",
"timestamps": "s",
"use_batch": False,
"vocab": [
"custom company name",
"custom product name",
"custom co-worker name",
],
"word_timestamps": False,
"job_name": "job_abc123",
"ping": False,
Expand Down Expand Up @@ -196,6 +213,11 @@ class Config:
"source_lang": "en",
"timestamps": "s",
"use_batch": False,
"vocab": [
"custom company name",
"custom product name",
"custom co-worker name",
],
"word_timestamps": False,
"dual_channel": False,
"job_name": "job_name",
Expand Down Expand Up @@ -234,6 +256,11 @@ class Config:
"source_lang": "en",
"timestamps": "s",
"use_batch": False,
"vocab": [
"custom company name",
"custom product name",
"custom co-worker name",
],
"word_timestamps": False,
"video_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
"job_name": "job_name",
Expand All @@ -250,6 +277,7 @@ class BaseRequest(BaseModel):
source_lang: str = "en"
timestamps: str = "s"
use_batch: bool = False
vocab: List[str] = []
word_timestamps: bool = False

@validator("timestamps")
Expand All @@ -259,6 +287,15 @@ def validate_timestamps_values(cls, value: str) -> str: # noqa: B902, N805
raise ValueError("timestamps must be one of 'hms', 'ms', 's'.")
return value

@validator("vocab")
def validate_each_vocab_value(
    cls, value: List[str]  # noqa: B902, N805
) -> List[str]:
    """Check that every entry of the vocab list is a string.

    Args:
        value (List[str]): Vocab entries to check.

    Returns:
        List[str]: The same list object, untouched, when all entries are strings.

    Raises:
        ValueError: If at least one entry is not a `str` instance.
    """
    for entry in value:
        # Bail out on the first offending entry; the rest need no inspection.
        if not isinstance(entry, str):
            raise ValueError("vocab must be a list of strings.")

    return value

class Config:
"""Pydantic config class."""

Expand All @@ -269,6 +306,11 @@ class Config:
"source_lang": "en",
"timestamps": "s",
"use_batch": False,
"vocab": [
"custom company name",
"custom product name",
"custom co-worker name",
],
"word_timestamps": False,
}
}
Expand All @@ -289,6 +331,11 @@ class Config:
"source_lang": "en",
"timestamps": "s",
"use_batch": False,
"vocab": [
"custom company name",
"custom product name",
"custom co-worker name",
],
"word_timestamps": False,
"dual_channel": False,
}
Expand Down
6 changes: 5 additions & 1 deletion wordcab_transcribe/router/v1/audio_file_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Audio file endpoint for the Wordcab Transcribe API."""

import asyncio
from typing import Union
from typing import List, Union

import shortuuid
from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, UploadFile
Expand Down Expand Up @@ -45,6 +45,7 @@ async def inference_with_audio(
source_lang: str = Form("en"), # noqa: B008
timestamps: str = Form("s"), # noqa: B008
use_batch: bool = Form(False), # noqa: B008
vocab: List[str] = Form([]), # noqa: B008
word_timestamps: bool = Form(False), # noqa: B008
file: UploadFile = File(...), # noqa: B008
) -> AudioResponse:
Expand All @@ -64,6 +65,7 @@ async def inference_with_audio(
source_lang=source_lang,
timestamps=timestamps,
use_batch=use_batch,
vocab=vocab,
word_timestamps=word_timestamps,
dual_channel=dual_channel,
)
Expand Down Expand Up @@ -91,6 +93,7 @@ async def inference_with_audio(
source_lang=data.source_lang,
timestamps_format=data.timestamps,
use_batch=data.use_batch,
vocab=data.vocab,
word_timestamps=data.word_timestamps,
)
)
Expand All @@ -113,5 +116,6 @@ async def inference_with_audio(
source_lang=data.source_lang,
timestamps=data.timestamps,
use_batch=data.use_batch,
vocab=data.vocab,
word_timestamps=data.word_timestamps,
)
2 changes: 2 additions & 0 deletions wordcab_transcribe/router/v1/audio_url_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ async def inference_with_audio_url(
source_lang=data.source_lang,
timestamps_format=data.timestamps,
use_batch=data.use_batch,
vocab=data.vocab,
word_timestamps=data.word_timestamps,
)
)
Expand All @@ -92,5 +93,6 @@ async def inference_with_audio_url(
source_lang=data.source_lang,
timestamps=data.timestamps,
use_batch=data.use_batch,
vocab=data.vocab,
word_timestamps=data.word_timestamps,
)
2 changes: 2 additions & 0 deletions wordcab_transcribe/router/v1/cortex_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ async def run_cortex(
source_lang=payload.source_lang,
timestamps=payload.timestamps,
use_batch=payload.use_batch,
vocab=payload.vocab,
word_timestamps=payload.word_timestamps,
)
utterances: AudioResponse = await inference_with_audio_url(
Expand All @@ -86,6 +87,7 @@ async def run_cortex(
source_lang=payload.source_lang,
timestamps=payload.timestamps,
use_batch=payload.use_batch,
vocab=payload.vocab,
word_timestamps=payload.word_timestamps,
)
utterances: YouTubeResponse = await inference_with_youtube(
Expand Down
5 changes: 4 additions & 1 deletion wordcab_transcribe/services/asr_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ async def process_input(
source_lang: str,
timestamps_format: str,
use_batch: bool,
vocab: List[str],
word_timestamps: bool,
) -> Union[List[dict], Exception]:
"""Process the input request and return the results.
Expand All @@ -190,6 +191,7 @@ async def process_input(
source_lang (str): Source language of the audio file.
timestamps_format (str): Timestamps format to use.
use_batch (bool): Whether to use batch processing or not.
vocab (List[str]): List of words to use for the vocabulary.
word_timestamps (bool): Whether to return word timestamps or not.

Returns:
Expand All @@ -203,8 +205,8 @@ async def process_input(
"source_lang": source_lang,
"timestamps_format": timestamps_format,
"use_batch": use_batch,
"vocab": vocab,
"word_timestamps": word_timestamps,
"post_processed": False,
"transcription_result": None,
"transcription_done": asyncio.Event(),
"diarization_result": None,
Expand Down Expand Up @@ -324,6 +326,7 @@ def process_transcription(
task["input"],
source_lang=task["source_lang"],
suppress_blank=False,
vocab=None if task["vocab"] == [] else task["vocab"],
word_timestamps=True,
vad_service=self.services["vad"] if task["dual_channel"] else None,
use_batch=task["use_batch"],
Expand Down
4 changes: 4 additions & 0 deletions wordcab_transcribe/services/transcribe_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ def __call__(
audio: Union[str, torch.Tensor, Tuple[str, str]],
source_lang: str,
suppress_blank: bool = False,
vocab: Optional[List[str]] = None,
word_timestamps: bool = True,
vad_service: Optional[VadService] = None,
use_batch: bool = True,
Expand All @@ -335,6 +336,7 @@ def __call__(
audio files.
source_lang (str): Language of the audio file.
suppress_blank (bool): Whether to suppress blank at the beginning of the sampling.
vocab (Optional[List[str]]): Vocabulary to use during generation if not None.
word_timestamps (bool): Whether to return word timestamps.
vad_service (Optional[VadService]): VadService to use for voice activity detection in the dual_channel case.
use_batch (bool): Whether to use batch inference.
Expand Down Expand Up @@ -365,9 +367,11 @@ def __call__(
self.loaded_model_lang = "multi"

if not use_batch:
prompt = ", ".join(vocab) if vocab else None
segments, _ = self.model.transcribe(
audio,
language=source_lang,
initial_prompt=prompt,
suppress_blank=False,
word_timestamps=True,
)
Expand Down