From dfee424718b6ac2333f2d8e04e0d617d3a9adf2c Mon Sep 17 00:00:00 2001 From: tro16 Date: Mon, 20 May 2024 17:59:06 +0200 Subject: [PATCH] fixing max 30sec audio bug --- amondin/speech2text.py | 50 ++++++++++-------------------------------- 1 file changed, 11 insertions(+), 39 deletions(-) diff --git a/amondin/speech2text.py b/amondin/speech2text.py index a75d899..b2bb358 100644 --- a/amondin/speech2text.py +++ b/amondin/speech2text.py @@ -3,7 +3,7 @@ """ import torch -from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline +from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline def speech2text( @@ -16,44 +16,12 @@ def speech2text( Translate audio to text :param device: Device to run the model on [cpu, cuda or cuda:x] :param audio: dictionary containing audio as numpy array of shape (n) and the sampling rate - :param model: + :param model_name: :param language: :return: """ - torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 - """ - # load model from huggingface - processor = WhisperProcessor.from_pretrained(model) - model = WhisperForConditionalGeneration.from_pretrained( - model, - torch_dtype=torch_dtype, - ).to(device) - - # specify task and language - forced_decoder_ids = processor.get_decoder_prompt_ids( - language=language, - task="transcribe" - ) - - # create input - input_features = processor( - audio["raw"], - sampling_rate=audio["sampling_rate"], - return_tensors="pt", - ).input_features.to(torch_dtype).to(device) - - # run inference - predicted_ids = model.generate( - input_features, - forced_decoder_ids=forced_decoder_ids, - ) - # convert output to text - result = processor.batch_decode( - predicted_ids, - skip_special_tokens=True - ) - """ + torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 model = AutoModelForSpeechSeq2Seq.from_pretrained( model_name, @@ -77,9 +45,13 @@ def speech2text( device=device, ) - result = pipe(audio, generate_kwargs={"task": "transcribe", "language": language}) - print(result) - return result["text"] + result = pipe( + audio, + generate_kwargs={ + "task": "transcribe", + "language": language + } + ) # return sting in list - return result[0] + return result["text"]