From a2fbc5d08aace57ae3bfa546636567c0f9297c8f Mon Sep 17 00:00:00 2001
From: Louis Beaumont
Date: Sun, 1 Sep 2024 16:25:45 -0700
Subject: [PATCH] fix: major performance progress (memory leak on Candle MacOS implementation)

---
 Cargo.toml                                    |  2 +-
 .../screenpipe-app-tauri/src-tauri/Cargo.toml |  2 +-
 screenpipe-audio/Cargo.toml                   |  1 +
 screenpipe-audio/src/bin/screenpipe-audio.rs  |  2 +-
 screenpipe-audio/src/core.rs                  | 16 ++++-
 screenpipe-audio/src/stt.rs                   | 70 +++++++++++++++----
 screenpipe-server/src/core.rs                 | 22 ++++--
 screenpipe-server/src/server.rs               |  1 -
 screenpipe-vision/src/core.rs                 |  3 +-
 9 files changed, 93 insertions(+), 26 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 87f7f4a6..ed837e52 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,7 +15,7 @@ resolver = "2"
 
 [workspace.package]
-version = "0.1.70"
+version = "0.1.71"
 authors = ["louis030195 "]
 description = ""
 repository = "https://github.com/mediar-ai/screenpipe"
diff --git a/examples/apps/screenpipe-app-tauri/src-tauri/Cargo.toml b/examples/apps/screenpipe-app-tauri/src-tauri/Cargo.toml
index 2627bf8c..ceb5179f 100644
--- a/examples/apps/screenpipe-app-tauri/src-tauri/Cargo.toml
+++ b/examples/apps/screenpipe-app-tauri/src-tauri/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "screenpipe-app"
-version = "0.1.77"
+version = "0.1.78"
 description = ""
 authors = ["you"]
 license = ""
diff --git a/screenpipe-audio/Cargo.toml b/screenpipe-audio/Cargo.toml
index 6e8fc29e..cfb974c5 100644
--- a/screenpipe-audio/Cargo.toml
+++ b/screenpipe-audio/Cargo.toml
@@ -63,6 +63,7 @@ screenpipe-core = { path = "../screenpipe-core" }
 
 [target.'cfg(target_os = "macos")'.dependencies]
 once_cell = "1.17.1"
+objc = "0.2.7"
 
 [dev-dependencies]
 tempfile = "3.3.0"
diff --git a/screenpipe-audio/src/bin/screenpipe-audio.rs b/screenpipe-audio/src/bin/screenpipe-audio.rs
index 0785a34a..c034bcd9 100644
--- a/screenpipe-audio/src/bin/screenpipe-audio.rs
+++ b/screenpipe-audio/src/bin/screenpipe-audio.rs
@@ -78,7 +78,7 @@ async fn main() -> Result<()> {
     let chunk_duration = Duration::from_secs(10);
     let output_path = PathBuf::from("output.mp4");
 
-    let (whisper_sender, mut whisper_receiver) =
+    let (whisper_sender, mut whisper_receiver, _) =
         create_whisper_channel(Arc::new(AudioTranscriptionEngine::WhisperDistilLargeV3)).await?;
     // Spawn threads for each device
     let recording_threads: Vec<_> = devices
diff --git a/screenpipe-audio/src/core.rs b/screenpipe-audio/src/core.rs
index 1d547241..ce4fe8b0 100644
--- a/screenpipe-audio/src/core.rs
+++ b/screenpipe-audio/src/core.rs
@@ -276,7 +276,7 @@ pub async fn record_and_transcribe(
     );
 
     // TODO: consider a lock-free ring buffer like crossbeam_queue::ArrayQueue (ask AI why)
-    let (tx, rx) = mpsc::channel(1000); // For audio data
+    let (tx, rx) = mpsc::channel(100); // For audio data
     let is_running_clone = Arc::clone(&is_running);
     let is_running_clone_2 = is_running.clone();
     let is_running_clone_3 = is_running.clone();
@@ -294,7 +294,7 @@ pub async fn record_and_transcribe(
         }
     };
     // Spawn a thread to handle the non-Send stream
-    thread::spawn(move || {
+    let audio_handle = thread::spawn(move || {
        let stream = match config.sample_format() {
            cpal::SampleFormat::I8 => cpal_audio_device.build_input_stream(
                &config.into(),
@@ -351,6 +351,8 @@ pub async fn record_and_transcribe(
                while is_running_clone.load(Ordering::Relaxed) {
                    std::thread::sleep(Duration::from_millis(100));
                }
+                s.pause().ok();
+                drop(s);
            }
            Err(e) => error!("Failed to build input stream: {}", e),
        }
@@ -381,6 +383,16 @@ pub async fn record_and_transcribe(
     // Signal the recording thread to stop
     is_running.store(false, Ordering::Relaxed);
     // TODO: could also just kill the thread..
+    // Wait for the native thread to finish
+    if let Err(e) = audio_handle.join() {
+        error!("Error joining audio thread: {:?}", e);
+    }
+
+    tokio::fs::File::open(&output_path_clone_2.to_path_buf())
+        .await?
+        .sync_all()
+        .await?;
+
     debug!("Sending audio to audio model");
     if let Err(e) = whisper_sender.send(AudioInput {
         path: output_path_clone_2.to_str().unwrap().to_string(),
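
The core.rs change above is a lifecycle fix as much as a performance one: the capture thread's JoinHandle is now kept, the cpal stream is explicitly paused and dropped inside that thread, and the recorder joins the thread and asks the OS to flush the output file before handing it to the transcription channel. A minimal sketch of that stop-join-then-hand-off shape, with cpal and the real file plumbing left out (all names here are illustrative, not from the patch):

    use std::sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    };
    use std::thread;
    use std::time::Duration;

    fn main() {
        let is_running = Arc::new(AtomicBool::new(true));
        let is_running_worker = Arc::clone(&is_running);

        // The worker owns a non-Send resource (the cpal stream in the real code),
        // so everything that touches it stays on this one thread.
        let audio_handle = thread::spawn(move || {
            // ... build the input stream here ...
            while is_running_worker.load(Ordering::Relaxed) {
                thread::sleep(Duration::from_millis(100));
            }
            // Explicitly stop and release the resource before the thread exits,
            // mirroring `s.pause().ok(); drop(s);` in the patch.
        });

        // Later, on shutdown: flip the flag, then wait for the worker to finish
        // before using anything it produced (the recorded file, in the patch).
        is_running.store(false, Ordering::Relaxed);
        if let Err(e) = audio_handle.join() {
            eprintln!("error joining audio thread: {:?}", e);
        }
    }

Joining before the send is what guarantees the file handed to the whisper channel is complete; the added sync_all() call then asks the OS to flush it to disk.
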
diff --git a/screenpipe-audio/src/stt.rs b/screenpipe-audio/src/stt.rs
index 54705506..16c25d3d 100644
--- a/screenpipe-audio/src/stt.rs
+++ b/screenpipe-audio/src/stt.rs
@@ -717,11 +717,14 @@ pub struct TranscriptionResult {
     pub timestamp: u64,
     pub error: Option<String>,
 }
+use std::sync::atomic::{AtomicBool, Ordering};
+
 pub async fn create_whisper_channel(
     audio_transcription_engine: Arc<AudioTranscriptionEngine>,
 ) -> Result<(
     UnboundedSender<AudioInput>,
     UnboundedReceiver<TranscriptionResult>,
+    Arc<AtomicBool>, // Shutdown flag
 )> {
     let whisper_model = WhisperModel::new(audio_transcription_engine.clone())?;
     let (input_sender, mut input_receiver): (
@@ -733,8 +736,16 @@ pub async fn create_whisper_channel(
         UnboundedReceiver<TranscriptionResult>,
     ) = unbounded_channel();
 
+    let shutdown_flag = Arc::new(AtomicBool::new(false));
+    let shutdown_flag_clone = shutdown_flag.clone();
+
     tokio::spawn(async move {
         loop {
+            if shutdown_flag_clone.load(Ordering::Relaxed) {
+                info!("Whisper channel shutting down");
+                break;
+            }
+
             tokio::select! {
                 Some(input) = input_receiver.recv() => {
                     let timestamp = SystemTime::now()
@@ -742,22 +753,50 @@ pub async fn create_whisper_channel(
                         .expect("Time went backwards")
                         .as_secs();
 
-                    let transcription_result = match stt(&input.path, &whisper_model, audio_transcription_engine.clone()) {
-                        Ok(transcription) => TranscriptionResult {
-                            input: input.clone(),
-                            transcription: Some(transcription),
-                            timestamp,
-                            error: None,
-                        },
-                        Err(e) => {
-                            error!("STT error for input {}: {:?}", input.path, e);
-                            TranscriptionResult {
+                    #[cfg(target_os = "macos")]
+                    use objc::{rc::autoreleasepool};
+
+                    let transcription_result = if cfg!(target_os = "macos") {
+                        #[cfg(target_os = "macos")]
+                        {
+                            autoreleasepool(|| {
+                                match stt(&input.path, &whisper_model, audio_transcription_engine.clone()) {
+                                    Ok(transcription) => TranscriptionResult {
+                                        input: input.clone(),
+                                        transcription: Some(transcription),
+                                        timestamp,
+                                        error: None,
+                                    },
+                                    Err(e) => {
+                                        error!("STT error for input {}: {:?}", input.path, e);
+                                        TranscriptionResult {
+                                            input: input.clone(),
+                                            transcription: None,
+                                            timestamp,
+                                            error: Some(e.to_string()),
+                                        }
+                                    },
+                                }
+                            })
+                        }
+                    } else {
+                        match stt(&input.path, &whisper_model, audio_transcription_engine.clone()) {
+                            Ok(transcription) => TranscriptionResult {
                                 input: input.clone(),
-                                transcription: None,
+                                transcription: Some(transcription),
                                 timestamp,
-                                error: Some(e.to_string()),
-                            }
-                        },
+                                error: None,
+                            },
+                            Err(e) => {
+                                error!("STT error for input {}: {:?}", input.path, e);
+                                TranscriptionResult {
+                                    input: input.clone(),
+                                    transcription: None,
+                                    timestamp,
+                                    error: Some(e.to_string()),
+                                }
+                            },
+                        }
                     };
 
                     if output_sender.send(transcription_result).is_err() {
@@ -767,7 +806,8 @@ pub async fn create_whisper_channel(
                 else => break,
             }
         }
+        // Cleanup code here (if needed)
     });
 
-    Ok((input_sender, output_receiver))
+    Ok((input_sender, output_receiver, shutdown_flag))
 }
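
The stt.rs hunk is the heart of the macOS leak fix. The use of objc::rc::autoreleasepool (from the objc crate added to Cargo.toml above) suggests the leaked memory is autoreleased Objective-C objects created during Candle inference on macOS; a long-lived tokio task never drains an autorelease pool on its own, so those objects pile up with every chunk. Wrapping each transcription in a pool drains them once per iteration. A reduced sketch of the pattern, assuming objc 0.2; do_inference is a stand-in for the real stt() call and is not part of the patch:

    // Stand-in for the real `stt(...)` call, only here so the sketch compiles.
    fn do_inference(path: &str) -> Result<String, String> {
        Ok(format!("transcribed {}", path))
    }

    fn transcribe_chunk(path: &str) -> Result<String, String> {
        #[cfg(target_os = "macos")]
        {
            // Everything autoreleased inside the closure is released when it returns,
            // so per-chunk Objective-C allocations cannot accumulate across iterations.
            objc::rc::autoreleasepool(|| do_inference(path))
        }
        #[cfg(not(target_os = "macos"))]
        {
            do_inference(path)
        }
    }

    fn main() {
        match transcribe_chunk("chunk-001.mp4") {
            Ok(text) => println!("{}", text),
            Err(e) => eprintln!("stt failed: {}", e),
        }
    }

The sketch uses compile-time #[cfg] / #[cfg(not(...))] on the two branches so that each arm type-checks on every platform, and the objc dependency is only needed where it is actually compiled in.
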
diff --git a/screenpipe-server/src/core.rs b/screenpipe-server/src/core.rs
index 028e89ca..dbe26220 100644
--- a/screenpipe-server/src/core.rs
+++ b/screenpipe-server/src/core.rs
@@ -33,7 +33,7 @@ pub async fn start_continuous_recording(
     monitor_id: u32,
     use_pii_removal: bool,
 ) -> Result<()> {
-    let (whisper_sender, whisper_receiver) = if audio_disabled {
+    let (whisper_sender, whisper_receiver, whisper_shutdown_flag) = if audio_disabled {
         // Create a dummy channel if no audio devices are available, e.g. audio disabled
         let (input_sender, _): (UnboundedSender<AudioInput>, UnboundedReceiver<AudioInput>) =
             unbounded_channel();
@@ -41,10 +41,15 @@ pub async fn start_continuous_recording(
             UnboundedSender<TranscriptionResult>,
             UnboundedReceiver<TranscriptionResult>,
         ) = unbounded_channel();
-        (input_sender, output_receiver)
+        (
+            input_sender,
+            output_receiver,
+            Arc::new(AtomicBool::new(false)),
+        )
     } else {
         create_whisper_channel(audio_transcription_engine.clone()).await?
     };
+
     let whisper_sender_clone = whisper_sender.clone();
     let db_manager_video = Arc::clone(&db);
     let db_manager_audio = Arc::clone(&db);
@@ -92,9 +97,10 @@ pub async fn start_continuous_recording(
             .await
     });
 
-    let video_result = video_handle.await;
-    let audio_result = audio_handle.await;
+    // Wait for both tasks to complete
+    let (video_result, audio_result) = tokio::join!(video_handle, audio_handle);
 
+    // Handle any errors from the tasks
     if let Err(e) = video_result {
         error!("Video recording error: {:?}", e);
     }
@@ -102,6 +108,14 @@ pub async fn start_continuous_recording(
         error!("Audio recording error: {:?}", e);
     }
 
+    // Shutdown the whisper channel
+    whisper_shutdown_flag.store(true, Ordering::Relaxed);
+    drop(whisper_sender_clone); // Close the sender channel
+
+    // TODO: process any remaining audio chunks
+    // TODO: wait a bit for whisper to finish processing
+    // TODO: any additional cleanup like device controls to release
+
     info!("Stopped recording");
     Ok(())
 }
diff --git a/screenpipe-server/src/server.rs b/screenpipe-server/src/server.rs
index 4380df90..a045b73e 100644
--- a/screenpipe-server/src/server.rs
+++ b/screenpipe-server/src/server.rs
@@ -539,7 +539,6 @@
             )
             .with_state(app_state);
 
-        info!("Starting server on {}", self.addr);
 
         match serve(TcpListener::bind(self.addr).await?, app.into_make_service()).await {
             Ok(_) => {
diff --git a/screenpipe-vision/src/core.rs b/screenpipe-vision/src/core.rs
index f79c9099..d0316256 100644
--- a/screenpipe-vision/src/core.rs
+++ b/screenpipe-vision/src/core.rs
@@ -269,7 +269,8 @@ pub async fn process_ocr_task(
     Ok(())
 }
 
-fn parse_json_output(json_output: &str) -> Vec> {
+fn parse_json_output(json_output: &str) -> Vec> {
+    // TODO: this function uses a TON of memory and is inefficient; we should use binary serialization instead
     let parsed_output: Vec> = serde_json::from_str(json_output)
         .unwrap_or_else(|e| {
             error!("Failed to parse JSON output: {}", e);
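
Stepping back to the screenpipe-server/src/core.rs change above: the new whisper_shutdown_flag pairs flag-based shutdown with channel closure. Both recording tasks are awaited together with tokio::join!, then the flag is set and the cloned sender is dropped, so the whisper loop can exit either through its flag check or through the select!'s else branch once the channel is empty and closed. A standalone sketch of that flag-plus-select shutdown shape, using a String payload instead of the real AudioInput/TranscriptionResult types (names here are illustrative only):

    use std::sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    };
    use tokio::sync::mpsc::unbounded_channel;

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = unbounded_channel::<String>();
        let shutdown = Arc::new(AtomicBool::new(false));
        let shutdown_worker = shutdown.clone();

        let worker = tokio::spawn(async move {
            loop {
                // Checked at the top of every iteration, like the whisper loop in the patch.
                if shutdown_worker.load(Ordering::Relaxed) {
                    break;
                }
                tokio::select! {
                    Some(msg) = rx.recv() => println!("processing {msg}"),
                    else => break, // every sender dropped and the queue is empty
                }
            }
        });

        tx.send("chunk-1.mp4".to_string()).ok();

        // Shutdown: set the flag, drop the sender so recv() stops yielding,
        // then wait for the worker task to finish.
        shutdown.store(true, Ordering::Relaxed);
        drop(tx);
        let _ = worker.await;
    }

Dropping the last sender is what makes recv() return None and the else branch fire; the flag gives the loop an explicit exit point at the top of each iteration as well.
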