[ENH] Introduce stream abstraction and enable concurrency test for blockfile (#2454)

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - The tokio filesystem API panics if it is not invoked from within a tokio runtime. The shuttle concurrency test, however, runs on its own runtime, so it cannot run properly when tokio filesystem APIs are called.
	 - To fix this, the PR introduces a stream abstraction that supports both sync and async local filesystem APIs (see the sketch after this list).
 - New functionality
	 - ...
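
For context, here is a minimal consumer-side sketch of the stream abstraction, matching its use in the hunks below. The concrete definitions of `ByteStreamItem` and its error type are assumptions here (the diff imports the type from `crate::storage::stream` but does not show it); what the diff does show is that `Storage::get` now yields a chunk stream that callers drain with `StreamExt::next`:

```rust
use futures::{Stream, StreamExt};

// Assumed shape of the item type imported from crate::storage::stream;
// the error type below is a placeholder for this sketch only.
type StorageError = std::io::Error;
type ByteStreamItem = Result<Vec<u8>, StorageError>;

// Drain a byte stream (the shape Storage::get returns in this PR) into a
// buffer, mirroring the read loops in provider.rs and hnsw_provider.rs.
async fn collect_stream(
    mut stream: Box<dyn Stream<Item = ByteStreamItem> + Unpin + Send>,
) -> Result<Vec<u8>, StorageError> {
    let mut buf: Vec<u8> = Vec::new();
    while let Some(chunk) = stream.next().await {
        // Callers in the diff log and bail on the first chunk error;
        // `?` does the equivalent here.
        buf.extend(chunk?);
    }
    Ok(buf)
}
```

Because both the sync (`SyncLocal`) and async backends produce the same stream type, callers such as the shuttle test no longer depend on tokio's filesystem APIs.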

## Test plan
*How are these changes tested?*

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?*
Ishiihara committed Jul 16, 2024
1 parent 4644217 commit 2e5bff4
Showing 9 changed files with 412 additions and 129 deletions.
142 changes: 68 additions & 74 deletions rust/worker/src/blockstore/arrow/concurrency_test.rs
@@ -2,88 +2,82 @@
mod tests {
use crate::{
blockstore::arrow::{config::TEST_MAX_BLOCK_SIZE_BYTES, provider::ArrowBlockfileProvider},
storage::{local::LocalStorage, Storage},
cache::{
cache::Cache,
config::{CacheConfig, UnboundedCacheConfig},
},
storage::{sync_local::SyncLocalStorage, Storage},
};
use rand::Rng;
use shuttle::{future, thread};

#[test]
fn test_blockfile_shuttle() {
// shuttle::check_random(
// || {
// let tmp_dir = tempfile::tempdir().unwrap();
// let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
// let blockfile_provider = ArrowBlockfileProvider::new(storage);
// let writer = blockfile_provider.create::<&str, u32>().unwrap();
// let id = writer.id();
// // Generate N datapoints and then have T threads write them to the blockfile
// let range_min = 10;
// let range_max = 10000;
// let n = shuttle::rand::thread_rng().gen_range(range_min..range_max);
// // Make the max threads the number of cores * 2
// let max_threads = num_cpus::get() * 2;
// let t = shuttle::rand::thread_rng().gen_range(2..max_threads);
// let mut join_handles = Vec::with_capacity(t);
// for i in 0..t {
// let range_start = i * n / t;
// let range_end = (i + 1) * n / t;
// let writer = writer.clone();
// let handle = thread::spawn(move || {
// println!("Thread {} writing keys {} to {}", i, range_start, range_end);
// for j in range_start..range_end {
// let key_string = format!("key{}", j);
// future::block_on(async {
// writer
// .set::<&str, u32>("", key_string.as_str(), j as u32)
// .await
// .unwrap_or_else(|e| {
// println!(
// "Expect key to be set successfully, but got error: {:?}",
// e
// )
// });
// });
// }
// });
// join_handles.push(handle);
// }
shuttle::check_random(
|| {
let tmp_dir = tempfile::tempdir().unwrap();
let storage =
Storage::SyncLocal(SyncLocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache =
Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
block_cache,
sparse_index_cache,
);
let writer = blockfile_provider.create::<&str, u32>().unwrap();
let id = writer.id();
// Generate N datapoints and then have T threads write them to the blockfile
let range_min = 10;
let range_max = 10000;
let n = shuttle::rand::thread_rng().gen_range(range_min..range_max);
// Make the max threads the number of cores * 2
let max_threads = num_cpus::get() * 2;
let t = shuttle::rand::thread_rng().gen_range(2..max_threads);
let mut join_handles = Vec::with_capacity(t);
for i in 0..t {
let range_start = i * n / t;
let range_end = (i + 1) * n / t;
let writer = writer.clone();
let handle = thread::spawn(move || {
for j in range_start..range_end {
let key_string = format!("key{}", j);
future::block_on(async {
writer
.set::<&str, u32>("", key_string.as_str(), j as u32)
.await
.unwrap();
});
}
});
join_handles.push(handle);
}

// for handle in join_handles {
// handle.join().unwrap();
// }
for handle in join_handles {
handle.join().unwrap();
}

// // commit the writer
// future::block_on(async {
// let flusher = writer.commit::<&str, u32>().unwrap();
// flusher.flush::<&str, u32>().await.unwrap();
// });
// commit the writer
future::block_on(async {
let flusher = writer.commit::<&str, u32>().unwrap();
flusher.flush::<&str, u32>().await.unwrap();
});

// let reader = future::block_on(async {
// blockfile_provider.open::<&str, u32>(&id).await.unwrap()
// });
// // Read the data back
// for i in 0..n {
// let key_string = format!("key{}", i);
// println!("Reading key {}", key_string);
// future::block_on(async {
// match reader.get("", key_string.as_str()).await {
// Ok(value) => {
// // value.expect("Expect key to exist and there to be no error");
// assert_eq!(value, i as u32);
// }
// Err(e) => {
// println!(
// "Expect key to exist and there to be no error, but got error: {:?}",
// e
// )
// }
// }
// });
// // let value = value.expect("Expect key to exist and there to be no error");
// // assert_eq!(value, i as u32);
// }
// },
// 100,
// );
let reader = future::block_on(async {
blockfile_provider.open::<&str, u32>(&id).await.unwrap()
});
// Read the data back
for i in 0..n {
let key_string = format!("key{}", i);
let value =
future::block_on(async { reader.get("", key_string.as_str()).await });
let value = value.expect("Expect key to exist and there to be no error");
assert_eq!(value, i as u32);
}
},
100,
);
}
}
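
The rewritten test above drives async storage calls with shuttle's own executor rather than tokio. For readers unfamiliar with that pattern, here it is in minimal, self-contained form; the closure body is a toy stand-in, not code from this PR:

```rust
use shuttle::{future, thread};

// shuttle::check_random runs the closure under many randomized thread
// interleavings (100 iterations here, matching the test above).
#[test]
fn shuttle_smoke() {
    shuttle::check_random(
        || {
            let handle = thread::spawn(|| {
                // Async work is driven by shuttle's block_on instead of a
                // tokio runtime, which is why tokio::fs APIs would panic here.
                future::block_on(async { /* storage calls go here */ });
            });
            handle.join().unwrap();
        },
        100,
    );
}
```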
75 changes: 44 additions & 31 deletions rust/worker/src/blockstore/arrow/provider.rs
@@ -1,5 +1,5 @@
use super::{
block::{self, delta::BlockDelta, Block},
block::{delta::BlockDelta, Block},
blockfile::{ArrowBlockfileReader, ArrowBlockfileWriter},
config::ArrowBlockfileProviderConfig,
sparse_index::SparseIndex,
@@ -15,12 +15,12 @@ use crate::{
},
config::Configurable,
errors::{ChromaError, ErrorCodes},
storage::{config::StorageConfig, Storage},
storage::Storage,
};
use async_trait::async_trait;
use core::panic;
use futures::StreamExt;
use thiserror::Error;
use tokio::io::AsyncReadExt;
use tracing::{Instrument, Span};
use uuid::Uuid;

@@ -213,28 +213,37 @@ impl BlockManager {
None => {
async {
let key = format!("block/{}", id);
let bytes = self.storage.get(&key).instrument(
let stream = self.storage.get(&key).instrument(
tracing::trace_span!(parent: Span::current(), "BlockManager storage get"),
).await;
let mut buf: Vec<u8> = Vec::new();
match bytes {
match stream {
Ok(mut bytes) => {
let res = bytes.read_to_end(&mut buf).instrument(
tracing::trace_span!(parent: Span::current(), "BlockManager read bytes to end"),
let read_block_span = tracing::trace_span!(parent: Span::current(), "BlockManager read bytes to end");
let buf = read_block_span.in_scope(|| async {
let mut buf: Vec<u8> = Vec::new();
while let Some(res) = bytes.next().await {
match res {
Ok(chunk) => {
buf.extend(chunk);
}
Err(e) => {
tracing::error!("Error reading block from storage: {}", e);
return None;
}
}
}
Some(buf)
}
).await;
tracing::info!("Read {:?} bytes from s3", buf.len());
match res {
Ok(_) => {}
Err(e) => {
// TODO: Return an error to callsite instead of None.
tracing::error!(
"Error reading block {:?} from s3 {:?}",
key,
e
);
let buf = match buf {
Some(buf) => {
buf
}
None => {
return None;
}
}
};
tracing::info!("Read {:?} bytes from s3", buf.len());
let deserialization_span = tracing::trace_span!(parent: Span::current(), "BlockManager deserialize block");
let block = deserialization_span.in_scope(|| Block::from_bytes(&buf, *id));
match block {
Expand All @@ -252,10 +261,9 @@ impl BlockManager {
None
}
}
}
},
Err(e) => {
// TODO: Return an error to callsite instead of None.
tracing::error!("Error reading block {:?} from s3 {:?}", key, e);
tracing::error!("Error reading block from storage: {}", e);
None
}
}
@@ -330,17 +338,22 @@ impl SparseIndexManager {
tracing::info!("Cache miss - fetching sparse index from storage");
let key = format!("sparse_index/{}", id);
tracing::debug!("Reading sparse index from storage with key: {}", key);
let bytes = self.storage.get(&key).await;
let stream = self.storage.get(&key).await;
let mut buf: Vec<u8> = Vec::new();
match bytes {
match stream {
Ok(mut bytes) => {
let res = bytes.read_to_end(&mut buf).await;
match res {
Ok(_) => {}
Err(e) => {
// TODO: return error
tracing::error!("Error reading sparse index from storage: {}", e);
return None;
while let Some(res) = bytes.next().await {
match res {
Ok(chunk) => {
buf.extend(chunk);
}
Err(e) => {
tracing::error!(
"Error reading sparse index from storage: {}",
e
);
return None;
}
}
}
let block = Block::from_bytes(&buf, *id);
58 changes: 47 additions & 11 deletions rust/worker/src/index/hnsw_provider.rs
@@ -4,12 +4,16 @@ use super::{
};
use crate::errors::ErrorCodes;
use crate::index::types::PersistentIndex;
use crate::storage::stream::ByteStreamItem;
use crate::{errors::ChromaError, storage::Storage, types::Segment};
use futures::stream;
use futures::stream::StreamExt;
use parking_lot::RwLock;
use std::fmt::Debug;
use std::path::Path;
use std::{collections::HashMap, path::PathBuf, sync::Arc};
use thiserror::Error;
use tokio::io::AsyncWriteExt;
use tracing::{instrument, Instrument, Span};
use uuid::Uuid;

@@ -130,8 +134,9 @@ impl HnswIndexProvider {
// Fetch the files from storage and put them in the index storage path
for file in FILES.iter() {
let key = self.format_key(source_id, file);
let res = self.storage.get(&key).await;
let mut reader = match res {
tracing::info!("Loading hnsw index file: {}", key);
let stream = self.storage.get(&key).await;
let reader = match stream {
Ok(reader) => reader,
Err(e) => {
tracing::error!("Failed to load hnsw index file from storage: {}", e);
@@ -142,27 +147,58 @@
let file_path = index_storage_path.join(file);
// For now, we never evict from the cache, so if the index is being loaded, the file does not exist
let file_handle = tokio::fs::File::create(&file_path).await;
let mut file_handle = match file_handle {
let file_handle = match file_handle {
Ok(file) => file,
Err(e) => {
tracing::error!("Failed to create file: {}", e);
return Err(Box::new(HnswIndexProviderFileError::IOError(e)));
}
};
let copy_res = tokio::io::copy(&mut reader, &mut file_handle)
.instrument(tracing::info_span!(parent: Span::current(), "hnsw provider file read", file = file))
.await;
match copy_res {
Ok(bytes_read) => {
tracing::info!("Copied {} bytes to file {:?}", bytes_read, file_path);
let total_bytes_written = self.copy_stream_to_local_file(reader, file_handle).await?;
tracing::info!(
"Copied {} bytes from storage key: {} to file: {}",
total_bytes_written,
key,
file_path.to_str().unwrap()
);
// reader is a byte stream, so we fill and consume it into the file above
tracing::info!("Loaded hnsw index file: {}", file);
}
Ok(())
}

async fn copy_stream_to_local_file(
&self,
stream: Box<dyn stream::Stream<Item = ByteStreamItem> + Unpin + Send>,
file_handle: tokio::fs::File,
) -> Result<u64, Box<HnswIndexProviderFileError>> {
let mut total_bytes_written = 0;
let mut file_handle = file_handle;
let mut stream = stream;
while let Some(res) = stream.next().await {
let chunk = match res {
Ok(chunk) => chunk,
Err(e) => {
return Err(Box::new(HnswIndexProviderFileError::StorageGetError(e)));
}
};

let res = file_handle.write_all(&chunk).await;
match res {
Ok(_) => {
total_bytes_written += chunk.len() as u64;
}
Err(e) => {
tracing::error!("Failed to copy file: {}", e);
return Err(Box::new(HnswIndexProviderFileError::IOError(e)));
}
}
}
Ok(())
match file_handle.flush().await {
Ok(_) => Ok(total_bytes_written),
Err(e) => {
return Err(Box::new(HnswIndexProviderFileError::IOError(e)));
}
}
}

pub(crate) async fn open(
2 changes: 2 additions & 0 deletions rust/worker/src/storage/config.rs
@@ -13,6 +13,8 @@ pub(crate) enum StorageConfig {
S3(S3StorageConfig),
#[serde(alias = "local")]
Local(LocalStorageConfig),
#[serde(alias = "sync_local")]
SyncLocal(LocalStorageConfig),
}

#[derive(Deserialize, PartialEq, Debug)]
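
As a closing illustration of how the new `SyncLocal` config variant could be consumed, here is a hedged sketch. The `root` field on `LocalStorageConfig` and this constructor function are assumptions not shown in the diff; the `Storage::SyncLocal(SyncLocalStorage::new(..))` shape follows the concurrency test above:

```rust
// Hypothetical wiring, not code from this PR: map each config variant to a
// storage backend. Assumes LocalStorageConfig has a `root: String` field.
fn storage_from_config(config: &StorageConfig) -> Storage {
    match config {
        StorageConfig::SyncLocal(cfg) => {
            Storage::SyncLocal(SyncLocalStorage::new(&cfg.root))
        }
        StorageConfig::Local(cfg) => Storage::Local(LocalStorage::new(&cfg.root)),
        StorageConfig::S3(_) => unimplemented!("S3 construction elided in this sketch"),
    }
}
```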