From 6b2ca0a9e6191f55a9c4623e6d2de2290fd9c84c Mon Sep 17 00:00:00 2001
From: tro16
Date: Sat, 18 May 2024 15:40:15 +0200
Subject: [PATCH] testing

---
 amondin/diarize_speakers.py |  4 ++--
 amondin/main.py             | 13 +++++++++----
 amondin/speech2text.py      |  4 ++--
 3 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/amondin/diarize_speakers.py b/amondin/diarize_speakers.py
index 8d6d6db..93bad7c 100644
--- a/amondin/diarize_speakers.py
+++ b/amondin/diarize_speakers.py
@@ -7,17 +7,17 @@


 def diarize_speakers(
-    file_path: str, hf_token: str, num_speakers: int = None, tolerance: float = 1.0
+    file_path: str, hf_token: str, device: str, num_speakers: int, tolerance: float = 1.0
 ) -> list[dict]:
     """
     Detect speakers in audio.wav file and label the segments of each speaker accordingly
+    :param device: Device to run the model on
     :param file_path:
     :param hf_token: HF token since the pyanote model needs authentication
     :param num_speakers: Set to None to self detect the number of speakers
     :param tolerance:
     :return:
     """
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"

     pipeline = Pipeline.from_pretrained(
         "pyannote/speaker-diarization-3.1",
diff --git a/amondin/main.py b/amondin/main.py
index 708652e..52201f5 100644
--- a/amondin/main.py
+++ b/amondin/main.py
@@ -9,11 +9,12 @@


 def transcribe(
-    input_file_path: str, output_file_path: str, hf_token: str, language: str = "german", num_speakers: int = None,
-    s2t_model: str = "openai/whisper-tiny"
+    input_file_path: str, output_file_path: str, hf_token: str, device: str = "cpu",
+    language: str = "german", num_speakers: int = None, s2t_model: str = "openai/whisper-tiny"
 ):
     """
     Transcribe a give audio.wav file.
+    :param device: Device to run the model on [cpu, cuda or cuda:x]
     :param output_file_path:
     :param input_file_path:
     :param hf_token:
@@ -23,17 +24,21 @@ def transcribe(
     :param s2t_model:
     :return:
     """
+
+    print(f"Running on {device}.")
+
     print("Diarizing speakers...")
     diarized_speakers = diarize_speakers(
         input_file_path, hf_token=hf_token, num_speakers=num_speakers,
+        device=device
     )

-    print("Transcripting audio...")
+    print("Transcribing audio...")

     transcript = []

     for i, speaker_section in enumerate(diarized_speakers):
-        print(f"Transcripting part {i+1} of {len(diarized_speakers)}")
+        print(f"Transcribing part {i+1} of {len(diarized_speakers)}")
         text = speech2text(
             speaker_section["audio"],
             model=s2t_model,
diff --git a/amondin/speech2text.py b/amondin/speech2text.py
index 3fb4e32..d431088 100644
--- a/amondin/speech2text.py
+++ b/amondin/speech2text.py
@@ -6,15 +6,15 @@
 from transformers import WhisperProcessor, WhisperForConditionalGeneration


-def speech2text(audio: dict, model: str = "openai/whisper-tiny", language: str = "german") -> str:
+def speech2text(audio: dict, device: str, model: str = "openai/whisper-tiny", language: str = "german") -> str:
     """
     Translate audio to text
+    :param device: Device to run the model on [cpu, cuda or cuda:x]
     :param audio: dictionary containing audio as numpy array of shape (n,) and the sampling rate
     :param model:
     :param language:
     :return:
     """
-    device = "cuda:0" if torch.cuda.is_available() else "cpu"
     torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

     # load model from huggingface
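
Usage note (not part of the patch): a minimal sketch of how a caller might drive the updated transcribe() signature after this change, assuming the package is importable as amondin.main and treating the paths, token, and output format below as placeholders. The device check at the call site mirrors the torch.cuda.is_available() logic that the patch removes from the individual modules.

import torch

from amondin.main import transcribe  # assumed import path based on amondin/main.py

# Choose the device once at the call site; previously diarize_speakers()
# and speech2text() each ran this check internally.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

transcribe(
    input_file_path="input.wav",        # placeholder path
    output_file_path="transcript.txt",  # placeholder path
    hf_token="hf_...",                  # placeholder Hugging Face access token
    device=device,
    language="german",
    num_speakers=None,                  # let the diarization pipeline infer the speaker count
    s2t_model="openai/whisper-tiny",
)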