
Commit

fixing max 30sec audio bug
tim-roethig-db committed May 20, 2024
1 parent beb8f04 commit dfee424
Showing 1 changed file with 11 additions and 39 deletions.
50 changes: 11 additions & 39 deletions amondin/speech2text.py
@@ -3,7 +3,7 @@
 """
 
 import torch
-from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
 
 
 def speech2text(
@@ -16,44 +16,12 @@ def speech2text(
     Translate audio to text
     :param device: Device to run the model on [cpu, cuda or cuda:x]
     :param audio: dictionary containing audio as numpy array of shape (n) and the sampling rate
-    :param model:
+    :param model_name:
     :param language:
     :return:
     """
     torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-    """
-    # load model from huggingface
-    processor = WhisperProcessor.from_pretrained(model)
-    model = WhisperForConditionalGeneration.from_pretrained(
-        model,
-        torch_dtype=torch_dtype,
-    ).to(device)
-    # specify task and language
-    forced_decoder_ids = processor.get_decoder_prompt_ids(
-        language=language,
-        task="transcribe"
-    )
-    # create input
-    input_features = processor(
-        audio["raw"],
-        sampling_rate=audio["sampling_rate"],
-        return_tensors="pt",
-    ).input_features.to(torch_dtype).to(device)
-    # run inference
-    predicted_ids = model.generate(
-        input_features,
-        forced_decoder_ids=forced_decoder_ids,
-    )
-
-    # convert output to text
-    result = processor.batch_decode(
-        predicted_ids,
-        skip_special_tokens=True
-    )
-    """
-    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
     model = AutoModelForSpeechSeq2Seq.from_pretrained(
         model_name,
@@ -77,9 +45,13 @@
         device=device,
     )
 
-    result = pipe(audio, generate_kwargs={"task": "transcribe", "language": language})
-    print(result)
-    return result["text"]
+    result = pipe(
+        audio,
+        generate_kwargs={
+            "task": "transcribe",
+            "language": language
+        }
+    )
 
     # return string in list
-    return result[0]
+    return result["text"]
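
Whisper checkpoints only look at 30 seconds of audio per forward pass, which is the limit the commit title refers to; longer recordings have to be chunked before transcription. The pipeline construction sits in the collapsed part of this diff, so the exact mechanism is not visible here. The sketch below shows one common way to get long-form transcription from the transformers ASR pipeline, namely passing chunk_length_s; the model id, language value, and dummy waveform are placeholders, not values taken from this repository.

import numpy as np
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

model_name = "openai/whisper-tiny"  # placeholder model id
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name, torch_dtype=torch_dtype)
processor = AutoProcessor.from_pretrained(model_name)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,  # split long audio into 30 s windows instead of truncating it
    torch_dtype=torch_dtype,
    device=device,
)

# 60 s of silence as a stand-in for a recording longer than Whisper's 30 s window
audio = {"raw": np.zeros(16_000 * 60, dtype=np.float32), "sampling_rate": 16_000}

result = pipe(audio, generate_kwargs={"task": "transcribe", "language": "german"})
print(result["text"])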

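For orientation, the updated speech2text function is presumably called along the following lines. The parameter names are taken from the docstring above; the model id, language value, and the use of soundfile for loading the audio are illustrative assumptions, not taken from this repository.

import soundfile as sf  # assumption: any loader that yields a 1-D float array plus its sampling rate works

from amondin.speech2text import speech2text

# load a mono recording into the dict format described in the docstring
data, sampling_rate = sf.read("interview.wav")
audio = {"raw": data.astype("float32"), "sampling_rate": sampling_rate}

text = speech2text(
    device="cpu",                      # or e.g. "cuda:0"
    audio=audio,
    model_name="openai/whisper-tiny",  # assumed Hugging Face model id
    language="german",                 # assumed language tag understood by Whisper
)
print(text)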