Skip to content

Commit

Permalink
fix: major performance progress (memory leak on Candle MacOS implementation)
Browse files Browse the repository at this point in the history
  • Loading branch information
louis030195 committed Sep 1, 2024
1 parent 194eb11 commit a2fbc5d
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ resolver = "2"


[workspace.package]
version = "0.1.70"
version = "0.1.71"
authors = ["louis030195 <hi@louis030195.com>"]
description = ""
repository = "https://github.com/mediar-ai/screenpipe"
Expand Down
2 changes: 1 addition & 1 deletion examples/apps/screenpipe-app-tauri/src-tauri/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "screenpipe-app"
version = "0.1.77"
version = "0.1.78"
description = ""
authors = ["you"]
license = ""
Expand Down
1 change: 1 addition & 0 deletions screenpipe-audio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ screenpipe-core = { path = "../screenpipe-core" }

[target.'cfg(target_os = "macos")'.dependencies]
once_cell = "1.17.1"
objc = "0.2.7"

[dev-dependencies]
tempfile = "3.3.0"
Expand Down
2 changes: 1 addition & 1 deletion screenpipe-audio/src/bin/screenpipe-audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async fn main() -> Result<()> {

let chunk_duration = Duration::from_secs(10);
let output_path = PathBuf::from("output.mp4");
let (whisper_sender, mut whisper_receiver) =
let (whisper_sender, mut whisper_receiver, _) =
create_whisper_channel(Arc::new(AudioTranscriptionEngine::WhisperDistilLargeV3)).await?;
// Spawn threads for each device
let recording_threads: Vec<_> = devices
Expand Down
16 changes: 14 additions & 2 deletions screenpipe-audio/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ pub async fn record_and_transcribe(
);

// TODO: consider a lock-free ring buffer like crossbeam_queue::ArrayQueue (ask AI why)
let (tx, rx) = mpsc::channel(1000); // For audio data
let (tx, rx) = mpsc::channel(100); // For audio data
let is_running_clone = Arc::clone(&is_running);
let is_running_clone_2 = is_running.clone();
let is_running_clone_3 = is_running.clone();
Expand All @@ -294,7 +294,7 @@ pub async fn record_and_transcribe(
}
};
// Spawn a thread to handle the non-Send stream
thread::spawn(move || {
let audio_handle = thread::spawn(move || {
let stream = match config.sample_format() {
cpal::SampleFormat::I8 => cpal_audio_device.build_input_stream(
&config.into(),
Expand Down Expand Up @@ -351,6 +351,8 @@ pub async fn record_and_transcribe(
while is_running_clone.load(Ordering::Relaxed) {
std::thread::sleep(Duration::from_millis(100));
}
s.pause().ok();
drop(s);
}
Err(e) => error!("Failed to build input stream: {}", e),
}
Expand Down Expand Up @@ -381,6 +383,16 @@ pub async fn record_and_transcribe(
// Signal the recording thread to stop
is_running.store(false, Ordering::Relaxed); // TODO: could also just kill the thread

// Wait for the native thread to finish
if let Err(e) = audio_handle.join() {
error!("Error joining audio thread: {:?}", e);
}

tokio::fs::File::open(&output_path_clone_2.to_path_buf())
.await?
.sync_all()
.await?;

debug!("Sending audio to audio model");
if let Err(e) = whisper_sender.send(AudioInput {
path: output_path_clone_2.to_str().unwrap().to_string(),
Expand Down
70 changes: 55 additions & 15 deletions screenpipe-audio/src/stt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -717,11 +717,14 @@ pub struct TranscriptionResult {
pub timestamp: u64,
pub error: Option<String>,
}
use std::sync::atomic::{AtomicBool, Ordering};

pub async fn create_whisper_channel(
audio_transcription_engine: Arc<AudioTranscriptionEngine>,
) -> Result<(
UnboundedSender<AudioInput>,
UnboundedReceiver<TranscriptionResult>,
Arc<AtomicBool>, // Shutdown flag
)> {
let whisper_model = WhisperModel::new(audio_transcription_engine.clone())?;
let (input_sender, mut input_receiver): (
Expand All @@ -733,31 +736,67 @@ pub async fn create_whisper_channel(
UnboundedReceiver<TranscriptionResult>,
) = unbounded_channel();

let shutdown_flag = Arc::new(AtomicBool::new(false));
let shutdown_flag_clone = shutdown_flag.clone();

tokio::spawn(async move {
loop {
if shutdown_flag_clone.load(Ordering::Relaxed) {
info!("Whisper channel shutting down");
break;
}

tokio::select! {
Some(input) = input_receiver.recv() => {
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs();

let transcription_result = match stt(&input.path, &whisper_model, audio_transcription_engine.clone()) {
Ok(transcription) => TranscriptionResult {
input: input.clone(),
transcription: Some(transcription),
timestamp,
error: None,
},
Err(e) => {
error!("STT error for input {}: {:?}", input.path, e);
TranscriptionResult {
#[cfg(target_os = "macos")]
use objc::{rc::autoreleasepool};

let transcription_result = if cfg!(target_os = "macos") {
#[cfg(target_os = "macos")]
{
autoreleasepool(|| {
match stt(&input.path, &whisper_model, audio_transcription_engine.clone()) {
Ok(transcription) => TranscriptionResult {
input: input.clone(),
transcription: Some(transcription),
timestamp,
error: None,
},
Err(e) => {
error!("STT error for input {}: {:?}", input.path, e);
TranscriptionResult {
input: input.clone(),
transcription: None,
timestamp,
error: Some(e.to_string()),
}
},
}
})
}
} else {
match stt(&input.path, &whisper_model, audio_transcription_engine.clone()) {
Ok(transcription) => TranscriptionResult {
input: input.clone(),
transcription: None,
transcription: Some(transcription),
timestamp,
error: Some(e.to_string()),
}
},
error: None,
},
Err(e) => {
error!("STT error for input {}: {:?}", input.path, e);
TranscriptionResult {
input: input.clone(),
transcription: None,
timestamp,
error: Some(e.to_string()),
}
},
}
};

if output_sender.send(transcription_result).is_err() {
Expand All @@ -767,7 +806,8 @@ pub async fn create_whisper_channel(
else => break,
}
}
// Cleanup code here (if needed)
});

Ok((input_sender, output_receiver))
Ok((input_sender, output_receiver, shutdown_flag))
}
22 changes: 18 additions & 4 deletions screenpipe-server/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,23 @@ pub async fn start_continuous_recording(
monitor_id: u32,
use_pii_removal: bool,
) -> Result<()> {
let (whisper_sender, whisper_receiver) = if audio_disabled {
let (whisper_sender, whisper_receiver, whisper_shutdown_flag) = if audio_disabled {
// Create a dummy channel if no audio devices are available, e.g. audio disabled
let (input_sender, _): (UnboundedSender<AudioInput>, UnboundedReceiver<AudioInput>) =
unbounded_channel();
let (_, output_receiver): (
UnboundedSender<TranscriptionResult>,
UnboundedReceiver<TranscriptionResult>,
) = unbounded_channel();
(input_sender, output_receiver)
(
input_sender,
output_receiver,
Arc::new(AtomicBool::new(false)),
)
} else {
create_whisper_channel(audio_transcription_engine.clone()).await?
};
let whisper_sender_clone = whisper_sender.clone();
let db_manager_video = Arc::clone(&db);
let db_manager_audio = Arc::clone(&db);

Expand Down Expand Up @@ -92,16 +97,25 @@ pub async fn start_continuous_recording(
.await
});

let video_result = video_handle.await;
let audio_result = audio_handle.await;
// Wait for both tasks to complete
let (video_result, audio_result) = tokio::join!(video_handle, audio_handle);

// Handle any errors from the tasks
if let Err(e) = video_result {
error!("Video recording error: {:?}", e);
}
if let Err(e) = audio_result {
error!("Audio recording error: {:?}", e);
}

// Shutdown the whisper channel
whisper_shutdown_flag.store(true, Ordering::Relaxed);
drop(whisper_sender_clone); // Close the sender channel

// TODO: process any remaining audio chunks
// TODO: wait a bit for whisper to finish processing
// TODO: any additional cleanup like device controls to release

info!("Stopped recording");
Ok(())
}
Expand Down
1 change: 0 additions & 1 deletion screenpipe-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,6 @@ impl Server {
)
.with_state(app_state);

info!("Starting server on {}", self.addr);

match serve(TcpListener::bind(self.addr).await?, app.into_make_service()).await {
Ok(_) => {
Expand Down
3 changes: 2 additions & 1 deletion screenpipe-vision/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,8 @@ pub async fn process_ocr_task(
Ok(())
}

fn parse_json_output(json_output: &str) -> Vec<HashMap<String, String>> {
fn parse_json_output(json_output: &str) -> Vec<HashMap<String, String>> {
// TODO: this function uses a ton of memory and is inefficient; we should use binary serialization instead
let parsed_output: Vec<HashMap<String, String>> = serde_json::from_str(json_output)
.unwrap_or_else(|e| {
error!("Failed to parse JSON output: {}", e);
Expand Down

0 comments on commit a2fbc5d

Please sign in to comment.