Support HPU on ASR/TTS (#290)
* hpu asr support

* add hpu support for tts

* enhance device condition

* fix lint and coverage

* remove env variables

* remove cpu_pool

---------

Co-authored-by: Haihao Shen <haihao.shen@intel.com>
Spycsh and hshen14 committed Sep 13, 2023
1 parent daff796 commit fb619e5
Showing 3 changed files with 45 additions and 26 deletions.
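For readers skimming the diff: the device-selection logic this commit repeats in both test files below reduces to one small helper. A minimal sketch (the helper name `detect_device` is ours, not from the commit; the `habana_frameworks.torch.hpu` import is the same probe the tests use):

```python
import torch

def detect_device() -> str:
    """Mirror the device selection this commit adds to both test files:
    prefer HPU when habana_frameworks is importable, then CUDA, then CPU."""
    try:
        import habana_frameworks.torch.hpu  # noqa: F401 -- registers the "hpu" device
        return "hpu"
    except ImportError:
        return "cuda" if torch.cuda.is_available() else "cpu"
```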
File 1 of 3 — TTS pipeline implementation
@@ -25,7 +25,6 @@
 import soundfile as sf
 import numpy as np
 import contextlib
-import intel_extension_for_pytorch as ipex
 
 from .utils.english_normalizer import EnglishNormalizer

@@ -40,19 +39,18 @@ def __init__(self, output_audio_path="./response.wav", voice="default", stream_m
                  asset_path="/intel-extension-for-transformers/intel_extension_for_transformers/neural_chat/assets"):
         """Make sure your export LD_PRELOAD=<path to libiomp5.so and libtcmalloc> beforehand."""
         # default setting
-        self.original_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
-        self.processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
         self.device = device
+        self.original_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(self.device)
+        self.processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
         self.voice = voice
         self.output_audio_path = output_audio_path
         self.stream_mode = stream_mode
         self.spk_model_name = "speechbrain/spkrec-xvect-voxceleb"
         self.speaker_model = EncoderClassifier.from_hparams(
             source=self.spk_model_name,
-            run_opts={"device": self.device},
-            savedir=os.path.join("/tmp", self.spk_model_name)
-        )
-        self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
+            run_opts={"device": "cpu"},
+            savedir=os.path.join("/tmp", self.spk_model_name))
+        self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(self.device)
         self.vocoder.eval()
         script_dir = os.path.dirname(os.path.abspath(__file__))
         if os.path.exists(os.path.join(script_dir, '../../../assets/speaker_embeddings/spk_embed_default.pt')):
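The hunk above moves the SpeechT5 acoustic model and the HiFi-GAN vocoder onto the selected device while pinning the SpeechBrain x-vector encoder to CPU. A condensed sketch of the resulting placement (the `EncoderClassifier` import path is assumed from contemporary SpeechBrain releases; it is not shown in this diff):

```python
import os
import torch
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
from speechbrain.pretrained import EncoderClassifier  # import path assumed, not shown in the diff

device = "cpu"  # or "cuda" / "hpu", chosen by the caller as in the tests below

# The SpeechT5 acoustic model and HiFi-GAN vocoder follow the target device...
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
vocoder.eval()

# ...while the SpeechBrain x-vector encoder is now pinned to CPU, per the hunk above.
spk_model_name = "speechbrain/spkrec-xvect-voxceleb"
speaker_model = EncoderClassifier.from_hparams(
    source=spk_model_name,
    run_opts={"device": "cpu"},
    savedir=os.path.join("/tmp", spk_model_name))
```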
@@ -77,15 +75,6 @@
         elif os.path.exists(os.path.join(asset_path, 'speaker_embeddings/spk_embed_male.pt')):
             self.male_speaker_embeddings = torch.load(os.path.join(asset_path, 'speaker_embeddings/spk_embed_male.pt'))
 
-        self.cpu_pool = None
-        if not torch.cuda.is_available():
-            # ipex IOMP hardware resources
-            if 'LD_PRELOAD' in os.environ and 'libiomp' in os.environ['LD_PRELOAD']:
-                import intel_extension_for_pytorch as ipex
-                self.cpu_pool = ipex.cpu.runtime.CPUPool([i for i in range(24)])
-            else:
-                print("Warning! You have not preloaded iomp beforehand and that may lead to performance issue")
-
         self.normalizer = EnglishNormalizer()
 
     def create_speaker_embedding(self, driven_audio_path):
@@ -97,10 +86,10 @@ def create_speaker_embedding(self, driven_audio_path):
             [driven_audio_path]}).cast_column("audio", Audio(sampling_rate=16000))
         waveform = audio_dataset[0]["audio"]['array']
         with torch.no_grad():
-            speaker_embeddings = self.speaker_model.encode_batch(torch.tensor(waveform))
+            speaker_embeddings = self.speaker_model.encode_batch(torch.tensor(waveform).to("cpu"))
             speaker_embeddings = torch.nn.functional.normalize(speaker_embeddings, dim=2) # [1,1,512]
             speaker_embeddings = speaker_embeddings[0] # [1,512]
-        return speaker_embeddings.cpu()
+        return speaker_embeddings.to(self.device)
 
     def _lookup_voice_embedding(self, voice,
                                 asset_path="/intel-extension-for-transformers/intel_extension_for_transformers/neural_chat/assets"):
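After this hunk, `create_speaker_embedding` computes the x-vector on CPU (where the speaker model now lives) and returns the result on the pipeline's device instead of forcing it to CPU. A standalone sketch of that flow (the free-function signature is ours, flattened for illustration):

```python
import torch
from datasets import Audio, Dataset

def create_speaker_embedding(speaker_model, driven_audio_path, device):
    """Sketch of the updated flow: encode the x-vector on CPU
    (where the speaker model lives) and return it on the pipeline device."""
    audio_dataset = Dataset.from_dict({"audio": [driven_audio_path]}).cast_column(
        "audio", Audio(sampling_rate=16000))
    waveform = audio_dataset[0]["audio"]["array"]
    with torch.no_grad():
        embeddings = speaker_model.encode_batch(torch.tensor(waveform).to("cpu"))
        embeddings = torch.nn.functional.normalize(embeddings, dim=2)  # [1, 1, 512]
    return embeddings[0].to(device)  # [1, 512], moved to the target device
```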
@@ -179,8 +168,8 @@ def text2speech(self, text, output_audio_path, voice="default", do_batch_tts=Fal
             for text_in in texts:
                 inputs = self.processor(text=text_in, return_tensors="pt")
                 with torch.no_grad():
-                    with ipex.cpu.runtime.pin(self.cpu_pool) if self.cpu_pool else contextlib.nullcontext():
-                        spectrogram = model.generate_speech(inputs["input_ids"], speaker_embeddings)
+                    spectrogram = model.generate_speech(
+                        inputs["input_ids"].to(self.device), speaker_embeddings.to(self.device))
                     speech = self.vocoder(spectrogram)
                 all_speech = np.concatenate([all_speech, speech.cpu().numpy()])
                 all_speech = np.concatenate([all_speech, np.array([0 for i in range(8000)])]) # pad after each end
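The `text2speech` hunk drops the IPEX CPU-pool pinning context and instead moves the input ids and speaker embeddings to `self.device` before generation. The same loop body as a minimal free function (names mirror the diff; `synthesize_chunk` itself is ours):

```python
import torch

def synthesize_chunk(model, processor, vocoder, speaker_embeddings, text, device):
    """Sketch of the per-chunk loop body after this commit: input ids and
    speaker embeddings are moved to the target device before generation."""
    inputs = processor(text=text, return_tensors="pt")
    with torch.no_grad():
        spectrogram = model.generate_speech(
            inputs["input_ids"].to(device), speaker_embeddings.to(device))
        speech = vocoder(spectrogram)
    return speech.cpu().numpy()  # bring the waveform back to host memory
```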
File 2 of 3 — ASR unit test (TestASR)
@@ -23,10 +23,25 @@
 class TestASR(unittest.TestCase):
     @classmethod
     def setUpClass(self):
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.asr = AudioSpeechRecognition("openai/whisper-small", device=device)
-        if not torch.cuda.is_available():
-            self.asr_bf16 = AudioSpeechRecognition("openai/whisper-small", bf16=True)
+        try:
+            import habana_frameworks.torch.hpu as hthpu
+            self.is_hpu_available = True
+        except ImportError:
+            self.is_hpu_available = False
+        try:
+            import intel_extension_for_pytorch as intel_ipex
+            self.is_ipex_available = True
+        except ImportError:
+            self.is_ipex_available = False
+        if self.is_hpu_available:
+            self.device = "hpu"
+        else:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.asr = AudioSpeechRecognition("openai/whisper-small", device=self.device)
+        if self.device == "cpu" and self.is_ipex_available:
+            self.asr_bf16 = AudioSpeechRecognition("openai/whisper-small", bf16=True)
+        else:
+            self.asr_bf16 = None
 
     def test_audio2text(self):
         audio_path = "/intel-extension-for-transformers/intel_extension_for_transformers/neural_chat/assets/audio/welcome.wav"
@@ -37,7 +52,7 @@ def test_audio2text(self):
         self.assertEqual(text.lower(), "Welcome to Neural Chat".lower())
 
     def test_audio2text_bf16(self):
-        if torch.cuda.is_available():
+        if self.asr_bf16 is None:
             return
         audio_path = "/intel-extension-for-transformers/intel_extension_for_transformers/neural_chat/assets/audio/welcome.wav"
         audio_path = "/intel-extension-for-transformers/intel_extension_for_transformers/neural_chat/assets/audio/welcome.wav"
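The setup added above probes for HPU and IPEX by import, then builds the bf16 pipeline only on CPU with IPEX present; `test_audio2text_bf16` now keys off that instead of re-checking CUDA. The same pattern as a standalone function (the class is injected as a parameter because its import path is not part of this diff):

```python
import torch

def build_asr(AudioSpeechRecognition):
    """Construct ASR pipelines the way the updated test does: pick the device,
    then only build a bf16 variant on CPU when IPEX is importable."""
    try:
        import habana_frameworks.torch.hpu  # noqa: F401
        is_hpu_available = True
    except ImportError:
        is_hpu_available = False
    try:
        import intel_extension_for_pytorch  # noqa: F401
        is_ipex_available = True
    except ImportError:
        is_ipex_available = False

    device = "hpu" if is_hpu_available else ("cuda" if torch.cuda.is_available() else "cpu")
    asr = AudioSpeechRecognition("openai/whisper-small", device=device)
    asr_bf16 = (AudioSpeechRecognition("openai/whisper-small", bf16=True)
                if device == "cpu" and is_ipex_available else None)
    return asr, asr_bf16
```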
File 3 of 3 — TTS unit test (TestTTS)
@@ -26,8 +26,23 @@
 class TestTTS(unittest.TestCase):
     @classmethod
     def setUpClass(self):
-        self.tts = TextToSpeech(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-        self.asr = AudioSpeechRecognition("openai/whisper-small")
+        try:
+            import habana_frameworks.torch.hpu as hthpu
+            self.is_hpu_available = True
+        except ImportError:
+            self.is_hpu_available = False
+        try:
+            import intel_extension_for_pytorch as ipex
+            self.is_ipex_available = True
+        except ImportError:
+            self.is_ipex_available = False
+        if self.is_hpu_available:
+            self.device = "hpu"
+        else:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.tts = TextToSpeech(device=self.device)
+        self.asr = AudioSpeechRecognition("openai/whisper-small", device=self.device)
         shutil.rmtree('./tmp_audio', ignore_errors=True)
         os.mkdir('./tmp_audio')
 
     @classmethod
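For context, a hedged sketch of the round trip this test class exercises: `text2speech`'s signature is visible in the first file's diff, while `audio2text` is inferred from the ASR test names and is an assumption, not confirmed by this diff:

```python
import os
import shutil

def roundtrip(tts, asr, text="Welcome to Neural Chat"):
    """Synthesize speech with TextToSpeech.text2speech, then transcribe it
    back with ASR. `audio2text` is an assumed method name (see lead-in)."""
    shutil.rmtree("./tmp_audio", ignore_errors=True)
    os.mkdir("./tmp_audio")
    out_path = "./tmp_audio/welcome.wav"
    tts.text2speech(text, out_path)
    return asr.audio2text(out_path)
```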
