From dc37a357a34396482f91a9921a3c470b17b17c9e Mon Sep 17 00:00:00 2001 From: David Anyatonwu Date: Sat, 14 Sep 2024 17:18:22 +0100 Subject: [PATCH 1/5] feat(audio): implement MKL-accelerated speech-to-text for Mac Signed-off-by: David Anyatonwu --- .github/workflows/benchmark.yml | 35 +++++-------- screenpipe-audio/Cargo.toml | 7 +-- screenpipe-audio/benches/stt_benchmark.rs | 60 +++++++---------------- screenpipe-audio/src/stt.rs | 18 ++++++- 4 files changed, 49 insertions(+), 71 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 0bf15d9d..160475ee 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -72,59 +72,46 @@ jobs: stt_benchmark: name: Run STT benchmark - runs-on: ubuntu-latest + runs-on: macos-latest steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable - name: Install dependencies run: | - sudo apt-get update - sudo apt-get install -y ffmpeg tesseract-ocr libtesseract-dev libavformat-dev libavfilter-dev libavdevice-dev ffmpeg libasound2-dev libgtk-3-dev libsoup-3.0-dev libjavascriptcoregtk-4.1-dev libwebkit2gtk-4.1-dev + brew install cmake openblas lapack - - name: Run STT benchmarks + - name: Run STT benchmarks (MKL) run: | - cargo bench --bench stt_benchmark -- --output-format bencher | tee -a stt_output.txt + cargo bench --bench stt_benchmark --features mkl -- --output-format bencher | tee -a stt_output_mkl.txt - name: Upload STT benchmark artifact uses: actions/upload-artifact@v3 with: - name: stt-benchmark-data - path: stt_output.txt + name: stt-benchmark-data-macos + path: stt_output_mkl.txt analyze_benchmarks: - needs: - [ - apple_ocr_benchmark, - tesseract_ocr_benchmark, - windows_ocr_benchmark, - stt_benchmark, - ] + needs: [stt_benchmark] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Download benchmark data - uses: actions/download-artifact@v3 - with: - name: ocr-benchmark-data - path: ./cache/ocr - - name: Download STT benchmark data uses: actions/download-artifact@v3 with: - name: stt-benchmark-data + name: stt-benchmark-data-macos path: ./cache/stt - name: List contents of cache directory run: ls -R ./cache - - name: Analyze OCR benchmarks + - name: Analyze STT benchmarks uses: benchmark-action/github-action-benchmark@v1 with: - name: OCR Benchmarks + name: STT Benchmarks tool: "cargo" - output-file-path: ./cache/ocr/ocr_output.txt + output-file-path: ./cache/stt/stt_output_mkl.txt github-token: ${{ secrets.GH_PAGES_TOKEN }} auto-push: true alert-threshold: "200%" diff --git a/screenpipe-audio/Cargo.toml b/screenpipe-audio/Cargo.toml index 06e0ba9b..1c759ebe 100644 --- a/screenpipe-audio/Cargo.toml +++ b/screenpipe-audio/Cargo.toml @@ -31,9 +31,9 @@ chrono = { version = "0.4.31", features = ["serde"] } # Local Embeddings + STT # TODO: feature metal, cuda, etc. see https://github.com/huggingface/candle/blob/main/candle-core/Cargo.toml -candle = { workspace = true } -candle-nn = { workspace = true } -candle-transformers = { workspace = true } +candle = { workspace = true, features = ["mkl"] } +candle-nn = { workspace = true, features = ["mkl"] } +candle-transformers = { workspace = true, features = ["mkl"] } vad-rs = "0.1.3" tokenizers = { workspace = true } anyhow = "1.0.86" @@ -80,6 +80,7 @@ criterion = { workspace = true } memory-stats = "1.0" [features] +default = ["mkl"] metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"] cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] mkl = ["candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] diff --git a/screenpipe-audio/benches/stt_benchmark.rs b/screenpipe-audio/benches/stt_benchmark.rs index 38da8044..66b57f59 100644 --- a/screenpipe-audio/benches/stt_benchmark.rs +++ b/screenpipe-audio/benches/stt_benchmark.rs @@ -1,12 +1,12 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use memory_stats::memory_stats; -use screenpipe_audio::vad_engine::SileroVad; use screenpipe_audio::{ - create_whisper_channel, stt, AudioTranscriptionEngine, VadEngineEnum, WhisperModel, + stt, AudioInput, AudioTranscriptionEngine, WhisperModel, vad_engine::SileroVad }; -use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; +use std::path::PathBuf; +use std::fs::File; +use std::io::Read; fn criterion_benchmark(c: &mut Criterion) { let audio_transcription_engine = Arc::new(AudioTranscriptionEngine::WhisperTiny); @@ -14,59 +14,35 @@ fn criterion_benchmark(c: &mut Criterion) { let test_file_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("test_data") .join("selah.mp4"); + let mut audio_data = Vec::new(); + File::open(&test_file_path).unwrap().read_to_end(&mut audio_data).unwrap(); let mut group = c.benchmark_group("whisper_benchmarks"); group.sample_size(10); group.measurement_time(Duration::from_secs(60)); - group.bench_function("create_whisper_channel", |b| { - b.iter(|| { - let _ = create_whisper_channel( - black_box(audio_transcription_engine.clone()), - black_box(VadEngineEnum::Silero), - None, - ); - }) - }); - - group.bench_function("stt", |b| { + group.bench_function("stt_mkl", |b| { b.iter(|| { let mut vad_engine = Box::new(SileroVad::new().unwrap()); + let audio_input = AudioInput { + data: audio_data.clone().into_iter().map(|x| x as f32).collect(), + sample_rate: 16000, + channels: 1, + device: "test".to_string(), + }; let _ = stt( - black_box(test_file_path.to_string_lossy().as_ref()), + black_box(&audio_input), black_box(&whisper_model), black_box(audio_transcription_engine.clone()), - &mut *vad_engine, - None, + black_box(&mut *vad_engine), + black_box(None), + black_box(&PathBuf::from("test_output")), ); }) }); - group.bench_function("memory_usage_stt", |b| { - b.iter_custom(|iters| { - let mut total_duration = Duration::new(0, 0); - for _ in 0..iters { - let start = std::time::Instant::now(); - let before = memory_stats().unwrap().physical_mem; - let mut vad_engine = Box::new(SileroVad::new().unwrap()); - let _ = stt( - test_file_path.to_string_lossy().as_ref(), - &whisper_model, - audio_transcription_engine.clone(), - &mut *vad_engine, - None, - ); - let after = memory_stats().unwrap().physical_mem; - let duration = start.elapsed(); - total_duration += duration; - println!("Memory used: {} bytes", after - before); - } - total_duration - }) - }); - group.finish(); } criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); +criterion_main!(benches); \ No newline at end of file diff --git a/screenpipe-audio/src/stt.rs b/screenpipe-audio/src/stt.rs index 87d5f0ee..587fe739 100644 --- a/screenpipe-audio/src/stt.rs +++ b/screenpipe-audio/src/stt.rs @@ -40,8 +40,8 @@ pub struct WhisperModel { impl WhisperModel { pub fn new(engine: Arc) -> Result { debug!("Initializing WhisperModel"); - let device = Device::new_metal(0).unwrap_or(Device::new_cuda(0).unwrap_or(Device::Cpu)); - info!("device = {:?}", device); + let device = Self::get_optimal_device()?; + info!("Using device: {:?}", device); debug!("Fetching model files"); let (config_filename, tokenizer_filename, weights_filename) = { @@ -86,6 +86,20 @@ impl WhisperModel { device, }) } + + fn get_optimal_device() -> Result { + #[cfg(feature = "mkl")] + { + info!("Using MKL-accelerated CPU"); + Ok(Device::Cpu) + } + #[cfg(not(feature = "mkl"))] + { + info!("Using standard CPU"); + Ok(Device::Cpu) + } + } + } #[derive(Debug, Clone)] From f873fba1ef83787ab284b8973a38f2afb8ebfdda Mon Sep 17 00:00:00 2001 From: David Anyatonwu Date: Mon, 16 Sep 2024 12:23:22 +0100 Subject: [PATCH 2/5] feat(build): add MKL support for Windows and Linux builds - Update Cargo.toml to include MKL feature - Modify release-app.yml to use MKL feature for Windows and Linux - Keep Metal feature for macOS builds - Set RUSTFLAGS for optimized builds on Windows and Linux Signed-off-by: David Anyatonwu --- .github/workflows/benchmark.yml | 56 ++++++++++------------ .github/workflows/release-app.yml | 19 ++++---- screenpipe-audio/Cargo.toml | 11 +++-- screenpipe-audio/benches/stt_benchmark.rs | 58 ++++++++++++++++------- screenpipe-audio/src/stt.rs | 13 ++--- 5 files changed, 88 insertions(+), 69 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 160475ee..821cdf6a 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -72,63 +72,57 @@ jobs: stt_benchmark: name: Run STT benchmark - runs-on: macos-latest + runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable - name: Install dependencies run: | - brew install cmake openblas lapack - - - name: Run STT benchmarks (MKL) + sudo apt-get update + sudo apt-get install -y ffmpeg tesseract-ocr libtesseract-dev libavformat-dev libavfilter-dev libavdevice-dev ffmpeg libasound2-dev libgtk-3-dev libsoup-3.0-dev libjavascriptcoregtk-4.1-dev libwebkit2gtk-4.1-dev + - name: Run STT benchmarks run: | - cargo bench --bench stt_benchmark --features mkl -- --output-format bencher | tee -a stt_output_mkl.txt - + cargo bench --bench stt_benchmark -- --output-format bencher | tee -a stt_output.txt - name: Upload STT benchmark artifact uses: actions/upload-artifact@v3 with: - name: stt-benchmark-data-macos - path: stt_output_mkl.txt + name: stt-benchmark-data + path: stt_output.txt analyze_benchmarks: - needs: [stt_benchmark] + needs: + [ + apple_ocr_benchmark, + tesseract_ocr_benchmark, + windows_ocr_benchmark, + stt_benchmark, + ] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + - name: Download benchmark data + uses: actions/download-artifact@v3 + with: + name: ocr-benchmark-data + path: ./cache/ocr + - name: Download STT benchmark data uses: actions/download-artifact@v3 with: - name: stt-benchmark-data-macos + name: stt-benchmark-data path: ./cache/stt - name: List contents of cache directory run: ls -R ./cache - - name: Analyze STT benchmarks + - name: Analyze OCR benchmarks uses: benchmark-action/github-action-benchmark@v1 with: - name: STT Benchmarks + name: OCR Benchmarks tool: "cargo" - output-file-path: ./cache/stt/stt_output_mkl.txt + output-file-path: ./cache/ocr/ocr_output.txt github-token: ${{ secrets.GH_PAGES_TOKEN }} auto-push: true - alert-threshold: "200%" - comment-on-alert: true - fail-on-alert: true - alert-comment-cc-users: "@louis030195" - - # todo broken - # - name: Analyze STT benchmarks - # uses: benchmark-action/github-action-benchmark@v1 - # with: - # name: STT Benchmarks - # tool: "cargo" - # output-file-path: ./cache/stt/stt_output.txt - # github-token: ${{ secrets.GH_PAGES_TOKEN }} - # auto-push: true - # alert-threshold: "200%" - # comment-on-alert: true - # fail-on-alert: true - # alert-comment-cc-users: "@louis030195" + alert-threshold: "200%" \ No newline at end of file diff --git a/.github/workflows/release-app.yml b/.github/workflows/release-app.yml index 147f48a4..b727ed28 100644 --- a/.github/workflows/release-app.yml +++ b/.github/workflows/release-app.yml @@ -33,18 +33,19 @@ jobs: fail-fast: false matrix: include: - - platform: "macos-latest" # for Arm based macs (M1 and above). + - platform: "macos-latest" args: "--target aarch64-apple-darwin --features metal" target: aarch64-apple-darwin - - platform: "macos-latest" # for Intel based macs. + - platform: "macos-latest" args: "--target x86_64-apple-darwin --features metal" target: x86_64-apple-darwin - - platform: "ubuntu-22.04" # Ubuntu x86_64 - args: "" # TODO CUDA, mkl - - platform: "windows-latest" # Windows x86_64 - args: "--target x86_64-pc-windows-msvc" # TODO CUDA, mkl? --features "openblas" - pre-build-args: "" # --openblas - # windows arm: https://github.com/ahqsoftwares/tauri-ahq-store/blob/2fbc2103c222662b3c6ee0cd71fcde664824f0ef/.github/workflows/publish.yml#L136 + - platform: "ubuntu-22.04" + args: "--features mkl" + target: x86_64-unknown-linux-gnu + - platform: "windows-latest" + args: "--target x86_64-pc-windows-msvc --features mkl" + target: x86_64-pc-windows-msvc + pre-build-args: "" runs-on: ${{ matrix.platform }} steps: @@ -150,6 +151,8 @@ jobs: export PKG_CONFIG_PATH="/usr/local/opt/ffmpeg/lib/pkgconfig:$PKG_CONFIG_PATH" export PKG_CONFIG_ALLOW_CROSS=1 export RUSTFLAGS="-C link-arg=-Wl,-rpath,@executable_path/../Frameworks -C link-arg=-Wl,-rpath,@loader_path/../Frameworks -C link-arg=-Wl,-install_name,@rpath/libscreenpipe.dylib" + elif [[ "${{ matrix.platform }}" == "ubuntu-22.04" || "${{ matrix.platform }}" == "windows-latest" ]]; then + export RUSTFLAGS="-C target-cpu=native" fi cargo build --release ${{ matrix.args }} ls -R target diff --git a/screenpipe-audio/Cargo.toml b/screenpipe-audio/Cargo.toml index 1c759ebe..017d9daa 100644 --- a/screenpipe-audio/Cargo.toml +++ b/screenpipe-audio/Cargo.toml @@ -31,9 +31,9 @@ chrono = { version = "0.4.31", features = ["serde"] } # Local Embeddings + STT # TODO: feature metal, cuda, etc. see https://github.com/huggingface/candle/blob/main/candle-core/Cargo.toml -candle = { workspace = true, features = ["mkl"] } -candle-nn = { workspace = true, features = ["mkl"] } -candle-transformers = { workspace = true, features = ["mkl"] } +candle = { workspace = true } +candle-nn = { workspace = true } +candle-transformers = { workspace = true } vad-rs = "0.1.3" tokenizers = { workspace = true } anyhow = "1.0.86" @@ -80,11 +80,12 @@ criterion = { workspace = true } memory-stats = "1.0" [features] -default = ["mkl"] +default = ["metal"] metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"] -cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] mkl = ["candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] + + [[bin]] name = "screenpipe-audio" path = "src/bin/screenpipe-audio.rs" diff --git a/screenpipe-audio/benches/stt_benchmark.rs b/screenpipe-audio/benches/stt_benchmark.rs index 66b57f59..6df4fea4 100644 --- a/screenpipe-audio/benches/stt_benchmark.rs +++ b/screenpipe-audio/benches/stt_benchmark.rs @@ -1,12 +1,12 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use memory_stats::memory_stats; +use screenpipe_audio::vad_engine::SileroVad; use screenpipe_audio::{ - stt, AudioInput, AudioTranscriptionEngine, WhisperModel, vad_engine::SileroVad + create_whisper_channel, stt, AudioTranscriptionEngine, VadEngineEnum, WhisperModel, }; +use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; -use std::path::PathBuf; -use std::fs::File; -use std::io::Read; fn criterion_benchmark(c: &mut Criterion) { let audio_transcription_engine = Arc::new(AudioTranscriptionEngine::WhisperTiny); @@ -14,33 +14,57 @@ fn criterion_benchmark(c: &mut Criterion) { let test_file_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("test_data") .join("selah.mp4"); - let mut audio_data = Vec::new(); - File::open(&test_file_path).unwrap().read_to_end(&mut audio_data).unwrap(); let mut group = c.benchmark_group("whisper_benchmarks"); group.sample_size(10); group.measurement_time(Duration::from_secs(60)); - group.bench_function("stt_mkl", |b| { + group.bench_function("create_whisper_channel", |b| { + b.iter(|| { + let _ = create_whisper_channel( + black_box(audio_transcription_engine.clone()), + black_box(VadEngineEnum::Silero), + None, + ); + }) + }); + + group.bench_function("stt", |b| { b.iter(|| { let mut vad_engine = Box::new(SileroVad::new().unwrap()); - let audio_input = AudioInput { - data: audio_data.clone().into_iter().map(|x| x as f32).collect(), - sample_rate: 16000, - channels: 1, - device: "test".to_string(), - }; let _ = stt( - black_box(&audio_input), + black_box(test_file_path.to_string_lossy().as_ref()), black_box(&whisper_model), black_box(audio_transcription_engine.clone()), - black_box(&mut *vad_engine), - black_box(None), - black_box(&PathBuf::from("test_output")), + &mut *vad_engine, + None, ); }) }); + group.bench_function("memory_usage_stt", |b| { + b.iter_custom(|iters| { + let mut total_duration = Duration::new(0, 0); + for _ in 0..iters { + let start = std::time::Instant::now(); + let before = memory_stats().unwrap().physical_mem; + let mut vad_engine = Box::new(SileroVad::new().unwrap()); + let _ = stt( + test_file_path.to_string_lossy().as_ref(), + &whisper_model, + audio_transcription_engine.clone(), + &mut *vad_engine, + None, + ); + let after = memory_stats().unwrap().physical_mem; + let duration = start.elapsed(); + total_duration += duration; + println!("Memory used: {} bytes", after - before); + } + total_duration + }) + }); + group.finish(); } diff --git a/screenpipe-audio/src/stt.rs b/screenpipe-audio/src/stt.rs index 587fe739..814430e9 100644 --- a/screenpipe-audio/src/stt.rs +++ b/screenpipe-audio/src/stt.rs @@ -88,14 +88,11 @@ impl WhisperModel { } fn get_optimal_device() -> Result { - #[cfg(feature = "mkl")] - { - info!("Using MKL-accelerated CPU"); - Ok(Device::Cpu) - } - #[cfg(not(feature = "mkl"))] - { - info!("Using standard CPU"); + if let Ok(device) = Device::new_metal(0) { + info!("Using Metal GPU"); + Ok(device) + } else { + info!("Metal not available, falling back to CPU"); Ok(Device::Cpu) } } From 3ecb0fb2cdbba903d2b2767b451d4d410bcb6531 Mon Sep 17 00:00:00 2001 From: David Anyatonwu Date: Mon, 16 Sep 2024 12:33:04 +0100 Subject: [PATCH 3/5] revert benchmark to what it was initially Signed-off-by: David Anyatonwu --- .github/workflows/benchmark.yml | 21 ++++++++++++++++++++- .github/workflows/release-app.yml | 11 ++++++----- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 821cdf6a..01d85399 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -81,9 +81,11 @@ jobs: run: | sudo apt-get update sudo apt-get install -y ffmpeg tesseract-ocr libtesseract-dev libavformat-dev libavfilter-dev libavdevice-dev ffmpeg libasound2-dev libgtk-3-dev libsoup-3.0-dev libjavascriptcoregtk-4.1-dev libwebkit2gtk-4.1-dev + - name: Run STT benchmarks run: | cargo bench --bench stt_benchmark -- --output-format bencher | tee -a stt_output.txt + - name: Upload STT benchmark artifact uses: actions/upload-artifact@v3 with: @@ -125,4 +127,21 @@ jobs: output-file-path: ./cache/ocr/ocr_output.txt github-token: ${{ secrets.GH_PAGES_TOKEN }} auto-push: true - alert-threshold: "200%" \ No newline at end of file + alert-threshold: "200%" + comment-on-alert: true + fail-on-alert: true + alert-comment-cc-users: "@louis030195" + + # todo broken + # - name: Analyze STT benchmarks + # uses: benchmark-action/github-action-benchmark@v1 + # with: + # name: STT Benchmarks + # tool: "cargo" + # output-file-path: ./cache/stt/stt_output.txt + # github-token: ${{ secrets.GH_PAGES_TOKEN }} + # auto-push: true + # alert-threshold: "200%" + # comment-on-alert: true + # fail-on-alert: true + # alert-comment-cc-users: "@louis030195" \ No newline at end of file diff --git a/.github/workflows/release-app.yml b/.github/workflows/release-app.yml index b727ed28..a879147a 100644 --- a/.github/workflows/release-app.yml +++ b/.github/workflows/release-app.yml @@ -33,19 +33,20 @@ jobs: fail-fast: false matrix: include: - - platform: "macos-latest" + - platform: "macos-latest" # for Arm based macs (M1 and above). args: "--target aarch64-apple-darwin --features metal" target: aarch64-apple-darwin - - platform: "macos-latest" + - platform: "macos-latest" # for Intel based macs. args: "--target x86_64-apple-darwin --features metal" target: x86_64-apple-darwin - - platform: "ubuntu-22.04" + - platform: "ubuntu-22.04" # Ubuntu x86_64 args: "--features mkl" target: x86_64-unknown-linux-gnu - - platform: "windows-latest" + - platform: "windows-latest" # Windows x86_64 args: "--target x86_64-pc-windows-msvc --features mkl" target: x86_64-pc-windows-msvc - pre-build-args: "" + pre-build-args: "" # --openblas + # windows arm: https://github.com/ahqsoftwares/tauri-ahq-store/blob/2fbc2103c222662b3c6ee0cd71fcde664824f0ef/.github/workflows/publish.yml#L136 runs-on: ${{ matrix.platform }} steps: From 7c46091fdc796f9b4c9ff3b100f3b4213c6c04cf Mon Sep 17 00:00:00 2001 From: David Anyatonwu Date: Tue, 17 Sep 2024 07:51:33 +0100 Subject: [PATCH 4/5] implemented suggested comments Signed-off-by: David Anyatonwu --- .github/workflows/release-app.yml | 7 ++++++- screenpipe-audio/Cargo.toml | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release-app.yml b/.github/workflows/release-app.yml index a879147a..a8ce65d4 100644 --- a/.github/workflows/release-app.yml +++ b/.github/workflows/release-app.yml @@ -149,10 +149,15 @@ jobs: shell: bash run: | if [[ "${{ matrix.platform }}" == "macos-latest" ]]; then + # These settings are specific to macOS builds export PKG_CONFIG_PATH="/usr/local/opt/ffmpeg/lib/pkgconfig:$PKG_CONFIG_PATH" export PKG_CONFIG_ALLOW_CROSS=1 export RUSTFLAGS="-C link-arg=-Wl,-rpath,@executable_path/../Frameworks -C link-arg=-Wl,-rpath,@loader_path/../Frameworks -C link-arg=-Wl,-install_name,@rpath/libscreenpipe.dylib" - elif [[ "${{ matrix.platform }}" == "ubuntu-22.04" || "${{ matrix.platform }}" == "windows-latest" ]]; then + elif [[ "${{ matrix.platform }}" == "ubuntu-22.04" ]]; then + # Linux-specific settings (if any) + export RUSTFLAGS="-C target-cpu=native" + elif [[ "${{ matrix.platform }}" == "windows-latest" ]]; then + # Windows-specific settings (if any) export RUSTFLAGS="-C target-cpu=native" fi cargo build --release ${{ matrix.args }} diff --git a/screenpipe-audio/Cargo.toml b/screenpipe-audio/Cargo.toml index 017d9daa..ee79e403 100644 --- a/screenpipe-audio/Cargo.toml +++ b/screenpipe-audio/Cargo.toml @@ -80,8 +80,8 @@ criterion = { workspace = true } memory-stats = "1.0" [features] -default = ["metal"] metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"] +cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] mkl = ["candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] From d8b8621b646fca564460d07e7fee4d083f1a5d93 Mon Sep 17 00:00:00 2001 From: David Anyatonwu Date: Tue, 17 Sep 2024 19:47:29 +0100 Subject: [PATCH 5/5] reintroduced cuda support Signed-off-by: David Anyatonwu --- screenpipe-audio/src/stt.rs | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/screenpipe-audio/src/stt.rs b/screenpipe-audio/src/stt.rs index 814430e9..3d2bb79b 100644 --- a/screenpipe-audio/src/stt.rs +++ b/screenpipe-audio/src/stt.rs @@ -88,13 +88,24 @@ impl WhisperModel { } fn get_optimal_device() -> Result { - if let Ok(device) = Device::new_metal(0) { - info!("Using Metal GPU"); - Ok(device) - } else { - info!("Metal not available, falling back to CPU"); - Ok(Device::Cpu) + #[cfg(target_os = "macos")] + { + if let Ok(device) = Device::new_metal(0) { + info!("Using Metal GPU"); + return Ok(device); + } + } + + #[cfg(not(target_os = "macos"))] + { + if let Ok(device) = Device::new_cuda(0) { + info!("Using CUDA GPU"); + return Ok(device); + } } + + info!("GPU not available, falling back to CPU"); + Ok(Device::Cpu) } }