diff --git a/workflows/index-generation/ncbi-compress/src/lib.rs b/workflows/index-generation/ncbi-compress/src/lib.rs
new file mode 100644
index 000000000..f62192aaf
--- /dev/null
+++ b/workflows/index-generation/ncbi-compress/src/lib.rs
@@ -0,0 +1,2 @@
+pub mod ncbi_compress;
+pub mod util;
diff --git a/workflows/index-generation/ncbi-compress/src/main.rs b/workflows/index-generation/ncbi-compress/src/main.rs
index 8b60191a7..d6583678b 100644
--- a/workflows/index-generation/ncbi-compress/src/main.rs
+++ b/workflows/index-generation/ncbi-compress/src/main.rs
@@ -1,392 +1,12 @@
 use std::io::Write;
-use std::ops::AddAssign;
-use std::path::Path;
-use std::{borrow::BorrowMut, fs};
 
-use bio::io::fasta;
 use chrono::Local;
 use clap::Parser;
 use env_logger::Builder;
 use log::LevelFilter;
-use rayon::prelude::*;
-use sourmash::encodings::HashFunctions;
-use sourmash::errors::SourmashError;
-use sourmash::signature::SigsTrait;
-use sourmash::sketch::minhash::KmerMinHash;
-use tempdir::TempDir;
-use trie_rs::{Trie, TrieBuilder};
-
-/// A trie that stores u64 values
-struct TrieStore {
-    trie: Trie<u8>,
-}
-
-impl TrieStore {
-    pub fn get(&self, key: &str) -> Option<u64> {
-        self.trie
-            .predictive_search(key.as_bytes())
-            .first()
-            .map(|bytes| {
-                let mut bytes = bytes.to_vec();
-                let value_bytes = bytes.split_off(bytes.len() - 8);
-                u64::from_be_bytes(value_bytes.try_into().unwrap())
-            })
-    }
-}
-
-struct TrieStoreBuilder {
-    builder: TrieBuilder<u8>,
-}
-
-impl TrieStoreBuilder {
-    pub fn new() -> Self {
-        TrieStoreBuilder {
-            builder: TrieBuilder::new(),
-        }
-    }
-
-    pub fn push(&mut self, key: &str, value: u64) {
-        let mut key = key.as_bytes().to_vec();
-        key.extend_from_slice(&value.to_be_bytes());
-        self.builder.push(key);
-    }
-
-    pub fn build(self) -> TrieStore {
-        let trie = self.builder.build();
-        TrieStore { trie }
-    }
-}
-
-fn containment(needle: &KmerMinHash, haystack: &KmerMinHash) -> Result<f64, SourmashError> {
-    let (intersect_size, _) = needle.intersection_size(haystack)?;
-    Ok(intersect_size as f64 / needle.mins().len() as f64)
-}
-
-struct MinHashTreeNode {
-    own: KmerMinHash,
-    children_aggregate: KmerMinHash,
-}
-
-struct MinHashTree {
-    branching_factor: usize,
-    nodes: Vec<MinHashTreeNode>,
-}
-
-impl MinHashTree {
-    fn new(branching_factor: usize) -> Self {
-        MinHashTree {
-            branching_factor,
-            nodes: Vec::new(),
-        }
-    }
-
-    fn parent_idx(&self, node: usize) -> Option<usize> {
-        if node == 0 {
-            None
-        } else {
-            Some((node - 1) / self.branching_factor)
-        }
-    }
-
-    fn child_idxes(&self, node: usize) -> Vec<usize> {
-        let first = self.branching_factor * node + 0 + 1;
-        let last = self.branching_factor * node + self.branching_factor + 1;
-        (first..last.min(self.nodes.len() - 1)).collect()
-    }
-
-    fn merge_to_parent(
-        &mut self,
-        parent_idx: usize,
-        child_idx: usize,
-    ) -> Result<(), SourmashError> {
-        let (left, right) = self.nodes.split_at_mut(child_idx);
-        left[parent_idx]
-            .children_aggregate
-            .merge(&right[0].children_aggregate)
-    }
-
-    pub fn insert(&mut self, hash: KmerMinHash) -> Result<(), SourmashError> {
-        let node = MinHashTreeNode {
-            own: hash.clone(),
-            children_aggregate: hash.clone(),
-        };
-        let mut current_idx = self.nodes.len();
-        self.nodes.push(node);
-
-        while let Some(parent_idx) = self.parent_idx(current_idx) {
-            // no need to aggregate to the root, it would just contain everything thus providing no information
-            if parent_idx == 0 {
-                break;
-            }
-
-            self.merge_to_parent(parent_idx, current_idx)?;
-            current_idx = parent_idx;
-        }
-        Ok(())
-    }
-
-    pub fn contains(
-        &self,
-        hash: &KmerMinHash,
-        similarity_threshold: f64,
-    ) -> Result<bool, SourmashError> {
-        if self.nodes.is_empty() {
-            return Ok(false);
-        }
-
-        let mut to_visit = vec![0];
-        while !to_visit.is_empty() {
-            let found = to_visit.par_iter().any(|node_idx| {
-                let node = self.nodes.get(*node_idx).unwrap();
-                containment(hash, &node.own).unwrap() >= similarity_threshold
-            });
-
-            if found {
-                return Ok(true);
-            }
-
-            to_visit = to_visit
-                .par_iter()
-                .flat_map(|node_idx| {
-                    let node = self.nodes.get(*node_idx).unwrap();
-                    if containment(hash, &node.children_aggregate).unwrap() >= similarity_threshold
-                    {
-                        self.child_idxes(*node_idx)
-                    } else {
-                        vec![]
-                    }
-                })
-                .collect();
-        }
-        Ok(false)
-    }
-}
-
-fn remove_accesion_version(accession: &str) -> &str {
-    accession.splitn(2, |c| c == '.').next().unwrap()
-}
-
-fn split_accessions_by_taxid<P: AsRef<Path> + std::fmt::Debug, Q: AsRef<Path> + std::fmt::Debug>(
-    input_fasta_path: P,
-    mapping_file_path: Vec<Q>,
-    taxids_to_drop: &Vec<u64>,
-) -> TempDir {
-    log::info!("Creating accession to taxid mapping");
-    let taxid_dir = TempDir::new("accessions_by_taxid").unwrap();
-    let reader = fasta::Reader::from_file(&input_fasta_path).unwrap();
-    let mut builder = TrieBuilder::new();
-    reader.records().enumerate().for_each(|(i, result)| {
-        let record = result.unwrap();
-        let accession_id = record.id().split_whitespace().next().unwrap();
-        let accession_no_version = remove_accesion_version(accession_id);
-        builder.push(accession_no_version);
-        if i % 10_000 == 0 {
-            log::info!("  Processed {} accessions", i);
-        }
-    });
-    log::info!("  Started building accession trie");
-    let accessions_trie = builder.build();
-    log::info!("  Finished building accession trie");
-
-    let mut builder = TrieStoreBuilder::new();
-    mapping_file_path.iter().for_each(|mapping_file_path| {
-        log::info!("  Processing mapping file {:?}", mapping_file_path);
-        let reader = csv::ReaderBuilder::new()
-            .delimiter(b'\t')
-            .from_path(mapping_file_path)
-            .unwrap();
-        let mut added = 0;
-        reader.into_records().enumerate().for_each(|(i, result)| {
-            if i % 10_000 == 0 {
-                log::info!("  Processed {} mappings, added {}", i, added);
-            }
-
-            let record = result.unwrap();
-            let accession = &record[0];
-            let accession_no_version = remove_accesion_version(accession);
-
-            // Only output mappings if the accession is in the source files
-            if !accessions_trie.exact_match(accession_no_version) {
-                return;
-            }
-
-            // If using the prot.accession2taxid.FULL file
-            let taxid = if record.len() < 3 {
-                // The taxid will be at index 1
-                record[1].parse::<u64>().unwrap()
-            } else {
-                // Otherwise there is a versionless accession ID at index 0 and the taxid is at index 2
-                record[2].parse::<u64>().unwrap()
-            };
-
-            if !taxids_to_drop.contains(&taxid) {
-                added += 1;
-                builder.push(accession_no_version, taxid);
-            }
-        });
-        log::info!(
-            "  Finished Processing mapping file {:?}, added {} mappings",
-            mapping_file_path,
-            added
-        );
-    });
-    log::info!("  Started building accession to taxid trie");
-    let accession_to_taxid = builder.build();
-    log::info!("Finished building accession to taxid trie");
-
-    log::info!("Splitting accessions by taxid");
-    let reader = fasta::Reader::from_file(&input_fasta_path).unwrap();
-    for (i, record) in reader.records().enumerate() {
-        if i % 10_000 == 0 {
-            log::info!("  Split {} accessions", i);
-        }
-        let record = record.unwrap();
-        let accession_id = record.id().split_whitespace().next().unwrap();
-        let acccession_no_version = remove_accesion_version(accession_id);
-        let taxid = if let Some(taxid) = accession_to_taxid.get(acccession_no_version) {
-            taxid
-        } else {
-            continue;
-        };
-
-        let file_path = taxid_dir.path().join(format!("{}.fasta", taxid));
-        let file = fs::OpenOptions::new()
-            .create(true)
-            .append(true)
-            .open(file_path)
-            .unwrap();
-        let mut writer = fasta::Writer::new(file);
-        writer.write_record(&record).unwrap();
-    }
-    log::info!("Finished splitting accessions by taxid");
-
-    taxid_dir
-}
-
-fn fasta_compress_taxid<P: AsRef<Path> + std::fmt::Debug>(
-    input_fasta_path: P,
-    writer: &mut fasta::Writer<std::fs::File>,
-    scaled: u64,
-    k: u32,
-    seed: u64,
-    similarity_threshold: f64,
-    chunk_size: usize,
-    branch_factor: usize,
-    accession_count: &mut u64,
-    unique_accession_count: &mut u64,
-) {
-    let reader = fasta::Reader::from_file(&input_fasta_path).unwrap();
-    let mut tree = MinHashTree::new(branch_factor);
-
-    let mut records_iter = reader.records();
-    let mut chunk = records_iter
-        .borrow_mut()
-        .take(chunk_size)
-        .collect::<Vec<_>>();
-    while chunk.len() > 0 {
-        let chunk_signatures = chunk
-            .par_iter()
-            .filter_map(|r| {
-                let record = r.as_ref().unwrap();
-                let mut hash =
-                    KmerMinHash::new(scaled, k, HashFunctions::murmur64_DNA, seed, false, 0);
-                hash.add_sequence(record.seq(), true).unwrap();
-                // Run an initial similarity check here against the full tree, this is slow so we can parallelize it
-                if tree.contains(&hash, similarity_threshold).unwrap() {
-                    None
-                } else {
-                    Some((record, hash))
-                }
-            })
-            .collect::<Vec<_>>();
-
-        let mut tmp = Vec::with_capacity(chunk_signatures.len() / 2);
-        for (record, hash) in chunk_signatures {
-            accession_count.add_assign(1);
-            // Perform a faster similarity check over just this chunk because we may have similarities within a chunk
-            let similar = tmp
-                .par_iter()
-                .any(|other| containment(&hash, other).unwrap() >= similarity_threshold);
-
-            if !similar {
-                unique_accession_count.add_assign(1);
-                tmp.push(hash);
-                writer.write_record(record).unwrap();
-
-                if *unique_accession_count % 10_000 == 0 {
-                    log::info!(
-                        "Processed {} accessions, {} unique",
-                        accession_count,
-                        unique_accession_count
-                    );
-                }
-            }
-        }
-        for hash in tmp {
-            tree.insert(hash).unwrap();
-        }
-        chunk = records_iter
-            .borrow_mut()
-            .take(chunk_size)
-            .collect::<Vec<_>>();
-    }
-}
-
-fn fasta_compress<P: AsRef<Path> + std::fmt::Debug>(
-    input_fasta_path: P,
-    accession_mapping_files: Vec<P>,
-    output_fasta_path: P,
-    taxids_to_drop: Vec<u64>,
-    scaled: u64,
-    k: u32,
-    seed: u64,
-    similarity_threshold: f64,
-    chunk_size: usize,
-    branch_factor: usize,
-) {
-    log::info!("Splitting accessions by taxid");
-    let taxid_dir =
-        split_accessions_by_taxid(&input_fasta_path, accession_mapping_files, &taxids_to_drop);
-    log::info!("Finished splitting accessions by taxid");
-    let mut writer = fasta::Writer::to_file(output_fasta_path).unwrap();
-
-    log::info!("Starting compression by taxid");
-    let mut accession_count = 0;
-    let mut unique_accession_count = 0;
-    for (i, entry) in fs::read_dir(taxid_dir.path()).unwrap().enumerate() {
-        let entry = entry.unwrap();
-        let path = entry.path();
-        let input_fasta_path = path.to_str().unwrap();
-        fasta_compress_taxid(
-            input_fasta_path,
-            &mut writer,
-            scaled,
-            k,
-            seed,
-            similarity_threshold,
-            chunk_size,
-            branch_factor,
-            &mut accession_count,
-            &mut unique_accession_count,
-        );
-
-        if i % 10_000 == 0 {
-            log::info!(
-                "  Compressed {} taxids, {} accessions, {} uniqe accessions",
-                i,
-                accession_count,
-                unique_accession_count
-            );
-        }
-    }
-
-    taxid_dir.close().unwrap();
-    log::info!(
-        "Finished compression by taxid, {} accessions, {} uniqe accessions",
-        accession_count,
-        unique_accession_count
-    );
-}
+use ncbi_compress::ncbi_compress::ncbi_compress::fasta_compress;
 
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
@@ -440,7 +60,8 @@ struct Args {
     branch_factor: usize,
 }
 
-fn main() {
+
+pub fn main() {
     Builder::new()
        .format(|buf, record| {
            writeln!(
diff --git a/workflows/index-generation/ncbi-compress/src/ncbi_compress.rs b/workflows/index-generation/ncbi-compress/src/ncbi_compress.rs
new file mode 100644
index 000000000..059bf204f
--- /dev/null
+++ b/workflows/index-generation/ncbi-compress/src/ncbi_compress.rs
@@ -0,0 +1,387 @@
+pub mod ncbi_compress {
+    use std::ops::AddAssign;
+    use std::path::Path;
+    use std::{borrow::BorrowMut, fs};
+
+    use bio::io::fasta;
+    use rayon::prelude::*;
+    use sourmash::encodings::HashFunctions;
+    use sourmash::errors::SourmashError;
+    use sourmash::signature::SigsTrait;
+    use sourmash::sketch::minhash::KmerMinHash;
+    use tempdir::TempDir;
+    use trie_rs::{Trie, TrieBuilder};
+
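+    // TrieStore packs each value into the trie itself: the u64 is appended to the
+    // key as 8 big-endian bytes, so pushing e.g. ("NC_000001", 9606) stores the
+    // bytes of "NC_000001" followed by 0x00_00_00_00_00_00_25_86. `get` recovers
+    // the value by splitting the last 8 bytes off the first prefix match.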
+    /// A trie that stores u64 values
+    struct TrieStore {
+        trie: Trie<u8>,
+    }
+
+    impl TrieStore {
+        pub fn get(&self, key: &str) -> Option<u64> {
+            self.trie
+                .predictive_search(key.as_bytes())
+                .first()
+                .map(|bytes| {
+                    let mut bytes = bytes.to_vec();
+                    let value_bytes = bytes.split_off(bytes.len() - 8);
+                    u64::from_be_bytes(value_bytes.try_into().unwrap())
+                })
+        }
+    }
+
+    struct TrieStoreBuilder {
+        builder: TrieBuilder<u8>,
+    }
+
+    impl TrieStoreBuilder {
+        pub fn new() -> Self {
+            TrieStoreBuilder {
+                builder: TrieBuilder::new(),
+            }
+        }
+
+        pub fn push(&mut self, key: &str, value: u64) {
+            let mut key = key.as_bytes().to_vec();
+            key.extend_from_slice(&value.to_be_bytes());
+            self.builder.push(key);
+        }
+
+        pub fn build(self) -> TrieStore {
+            let trie = self.builder.build();
+            TrieStore { trie }
+        }
+    }
+
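+    // The tree below is a k-ary heap laid out in a Vec: for branching factor b,
+    // the parent of node i is (i - 1) / b and its children start at b * i + 1
+    // (with b = 2, node 0's children are nodes 1 and 2). Each node keeps its own
+    // sketch plus an aggregate of its descendants' sketches, so a search can skip
+    // an entire subtree when the aggregate's containment falls below the threshold.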
+    struct MinHashTreeNode {
+        own: KmerMinHash,
+        children_aggregate: KmerMinHash,
+    }
+
+    struct MinHashTree {
+        branching_factor: usize,
+        nodes: Vec<MinHashTreeNode>,
+    }
+
+    impl MinHashTree {
+        fn new(branching_factor: usize) -> Self {
+            MinHashTree {
+                branching_factor,
+                nodes: Vec::new(),
+            }
+        }
+
+        fn parent_idx(&self, node: usize) -> Option<usize> {
+            if node == 0 {
+                None
+            } else {
+                Some((node - 1) / self.branching_factor)
+            }
+        }
+
+        fn child_idxes(&self, node: usize) -> Vec<usize> {
+            let first = self.branching_factor * node + 0 + 1;
+            let last = self.branching_factor * node + self.branching_factor + 1;
+            (first..last.min(self.nodes.len() - 1)).collect()
+        }
+
+        fn merge_to_parent(
+            &mut self,
+            parent_idx: usize,
+            child_idx: usize,
+        ) -> Result<(), SourmashError> {
+            let (left, right) = self.nodes.split_at_mut(child_idx);
+            left[parent_idx]
+                .children_aggregate
+                .merge(&right[0].children_aggregate)
+        }
+
+        pub fn insert(&mut self, hash: KmerMinHash) -> Result<(), SourmashError> {
+            let node = MinHashTreeNode {
+                own: hash.clone(),
+                children_aggregate: hash.clone(),
+            };
+            let mut current_idx = self.nodes.len();
+            self.nodes.push(node);
+
+            while let Some(parent_idx) = self.parent_idx(current_idx) {
+                // no need to aggregate to the root, it would just contain everything thus providing no information
+                if parent_idx == 0 {
+                    break;
+                }
+
+                self.merge_to_parent(parent_idx, current_idx)?;
+                current_idx = parent_idx;
+            }
+            Ok(())
+        }
+
+        pub fn contains(
+            &self,
+            hash: &KmerMinHash,
+            similarity_threshold: f64,
+        ) -> Result<bool, SourmashError> {
+            if self.nodes.is_empty() {
+                return Ok(false);
+            }
+
+            let mut to_visit = vec![0];
+            while !to_visit.is_empty() {
+                let found = to_visit.par_iter().any(|node_idx| {
+                    let node = self.nodes.get(*node_idx).unwrap();
+                    containment(hash, &node.own).unwrap() >= similarity_threshold
+                });
+
+                if found {
+                    return Ok(true);
+                }
+
+                to_visit = to_visit
+                    .par_iter()
+                    .flat_map(|node_idx| {
+                        let node = self.nodes.get(*node_idx).unwrap();
+                        if containment(hash, &node.children_aggregate).unwrap() >= similarity_threshold
+                        {
+                            self.child_idxes(*node_idx)
+                        } else {
+                            vec![]
+                        }
+                    })
+                    .collect();
+            }
+            Ok(false)
+        }
+    }
+
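+    // Containment is the fraction of the needle's sketch that also occurs in the
+    // haystack: |needle ∩ haystack| / |needle|. For example, if 30 of a needle's
+    // 50 retained hashes appear in the haystack, containment is 0.6.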
+    fn containment(needle: &KmerMinHash, haystack: &KmerMinHash) -> Result<f64, SourmashError> {
+        let (intersect_size, _) = needle.intersection_size(haystack)?;
+        Ok(intersect_size as f64 / needle.mins().len() as f64)
+    }
+
+    fn remove_accession_version(accession: &str) -> &str {
+        accession.splitn(2, |c| c == '.').next().unwrap()
+    }
+
+    pub fn split_accessions_by_taxid<P: AsRef<Path> + std::fmt::Debug, Q: AsRef<Path> + std::fmt::Debug>(
+        input_fasta_path: P,
+        mapping_file_path: Vec<Q>,
+        taxids_to_drop: &Vec<u64>,
+    ) -> TempDir {
+        log::info!("Creating accession to taxid mapping");
+        let taxid_dir = TempDir::new("accessions_by_taxid").unwrap();
+        let reader = fasta::Reader::from_file(&input_fasta_path).unwrap();
+        let mut builder = TrieBuilder::new();
+        reader.records().enumerate().for_each(|(i, result)| {
+            let record = result.unwrap();
+            let accession_id = record.id().split_whitespace().next().unwrap();
+            let accession_no_version = remove_accession_version(accession_id);
+            builder.push(accession_no_version);
+            if i % 10_000 == 0 {
+                log::info!("  Processed {} accessions", i);
+            }
+        });
+        log::info!("  Started building accession trie");
+        let accessions_trie = builder.build();
+        log::info!("  Finished building accession trie");
+
+        let mut builder = TrieStoreBuilder::new();
+        mapping_file_path.iter().for_each(|mapping_file_path| {
+            log::info!("  Processing mapping file {:?}", mapping_file_path);
+            let reader = csv::ReaderBuilder::new()
+                .delimiter(b'\t')
+                .from_path(mapping_file_path)
+                .unwrap();
+            let mut added = 0;
+            reader.into_records().enumerate().for_each(|(i, result)| {
+                if i % 10_000 == 0 {
+                    log::info!("  Processed {} mappings, added {}", i, added);
+                }
+
+                let record = result.unwrap();
+                let accession = &record[0];
+                let accession_no_version = remove_accession_version(accession);
+
+                // Only output mappings if the accession is in the source files
+                if !accessions_trie.exact_match(accession_no_version) {
+                    return;
+                }
+
+                // If using the prot.accession2taxid.FULL file
+                let taxid = if record.len() < 3 {
+                    // The taxid will be at index 1
+                    record[1].parse::<u64>().unwrap()
+                } else {
+                    // Otherwise there is a versionless accession ID at index 0 and the taxid is at index 2
+                    record[2].parse::<u64>().unwrap()
+                };
+
+                if !taxids_to_drop.contains(&taxid) {
+                    added += 1;
+                    builder.push(accession_no_version, taxid);
+                }
+            });
+            log::info!(
+                "  Finished processing mapping file {:?}, added {} mappings",
+                mapping_file_path,
+                added
+            );
+        });
+        log::info!("  Started building accession to taxid trie");
+        let accession_to_taxid = builder.build();
+        log::info!("Finished building accession to taxid trie");
+
+        log::info!("Splitting accessions by taxid");
+        let reader = fasta::Reader::from_file(&input_fasta_path).unwrap();
+        for (i, record) in reader.records().enumerate() {
+            if i % 10_000 == 0 {
+                log::info!("  Split {} accessions", i);
+            }
+            let record = record.unwrap();
+            let accession_id = record.id().split_whitespace().next().unwrap();
+            let accession_no_version = remove_accession_version(accession_id);
+            let taxid = if let Some(taxid) = accession_to_taxid.get(accession_no_version) {
+                taxid
+            } else {
+                continue;
+            };
+
+            let file_path = taxid_dir.path().join(format!("{}.fasta", taxid));
+            let file = fs::OpenOptions::new()
+                .create(true)
+                .append(true)
+                .open(file_path)
+                .unwrap();
+            let mut writer = fasta::Writer::new(file);
+            writer.write_record(&record).unwrap();
+        }
+        log::info!("Finished splitting accessions by taxid");
+
+        taxid_dir
+    }
+
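+    // Compression within a taxid is greedy and chunked: each chunk of records is
+    // sketched and checked against the tree of everything kept so far (the slow
+    // check, parallelized across the chunk), survivors are then compared pairwise
+    // within the chunk, and only records that clear both checks are written out
+    // and inserted into the tree.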
+    pub fn fasta_compress_taxid<P: AsRef<Path> + std::fmt::Debug>(
+        input_fasta_path: P,
+        writer: &mut fasta::Writer<std::fs::File>,
+        scaled: u64,
+        k: u32,
+        seed: u64,
+        similarity_threshold: f64,
+        chunk_size: usize,
+        branch_factor: usize,
+        accession_count: &mut u64,
+        unique_accession_count: &mut u64,
+    ) {
+        let reader = fasta::Reader::from_file(&input_fasta_path).unwrap();
+        let mut tree = MinHashTree::new(branch_factor);
+
+        let mut records_iter = reader.records();
+        let mut chunk = records_iter
+            .borrow_mut()
+            .take(chunk_size)
+            .collect::<Vec<_>>();
+        while chunk.len() > 0 {
+            let chunk_signatures = chunk
+                .par_iter()
+                .filter_map(|r| {
+                    let record = r.as_ref().unwrap();
+                    let mut hash =
+                        KmerMinHash::new(scaled, k, HashFunctions::murmur64_DNA, seed, false, 0);
+                    hash.add_sequence(record.seq(), true).unwrap();
+                    // Run an initial similarity check here against the full tree, this is slow so we can parallelize it
+                    if tree.contains(&hash, similarity_threshold).unwrap() {
+                        None
+                    } else {
+                        Some((record, hash))
+                    }
+                })
+                .collect::<Vec<_>>();
+
+            let mut tmp = Vec::with_capacity(chunk_signatures.len() / 2);
+            for (record, hash) in chunk_signatures {
+                accession_count.add_assign(1);
+                // Perform a faster similarity check over just this chunk because we may have similarities within a chunk
+                let similar = tmp
+                    .par_iter()
+                    .any(|other| containment(&hash, other).unwrap() >= similarity_threshold);
+
+                if !similar {
+                    unique_accession_count.add_assign(1);
+                    tmp.push(hash);
+                    writer.write_record(record).unwrap();
+
+                    if *unique_accession_count % 10_000 == 0 {
+                        log::info!(
+                            "Processed {} accessions, {} unique",
+                            accession_count,
+                            unique_accession_count
+                        );
+                    }
+                }
+            }
+            for hash in tmp {
+                tree.insert(hash).unwrap();
+            }
+            chunk = records_iter
+                .borrow_mut()
+                .take(chunk_size)
+                .collect::<Vec<_>>();
+        }
+    }
+
+    pub fn fasta_compress<P: AsRef<Path> + std::fmt::Debug>(
+        input_fasta_path: P,
+        accession_mapping_files: Vec<P>,
+        output_fasta_path: P,
+        taxids_to_drop: Vec<u64>,
+        scaled: u64,
+        k: u32,
+        seed: u64,
+        similarity_threshold: f64,
+        chunk_size: usize,
+        branch_factor: usize,
+    ) {
+        log::info!("Splitting accessions by taxid");
+        let taxid_dir =
+            split_accessions_by_taxid(&input_fasta_path, accession_mapping_files, &taxids_to_drop);
+        log::info!("Finished splitting accessions by taxid");
+        let mut writer = fasta::Writer::to_file(output_fasta_path).unwrap();
+
+        log::info!("Starting compression by taxid");
+        let mut accession_count = 0;
+        let mut unique_accession_count = 0;
+        for (i, entry) in fs::read_dir(taxid_dir.path()).unwrap().enumerate() {
+            let entry = entry.unwrap();
+            let path = entry.path();
+            let input_fasta_path = path.to_str().unwrap();
+            fasta_compress_taxid(
+                input_fasta_path,
+                &mut writer,
+                scaled,
+                k,
+                seed,
+                similarity_threshold,
+                chunk_size,
+                branch_factor,
+                &mut accession_count,
+                &mut unique_accession_count,
+            );
+
+            if i % 10_000 == 0 {
+                log::info!(
+                    "  Compressed {} taxids, {} accessions, {} unique accessions",
+                    i,
+                    accession_count,
+                    unique_accession_count
+                );
+            }
+        }
+
+        taxid_dir.close().unwrap();
+        log::info!(
+            "Finished compression by taxid, {} accessions, {} unique accessions",
+            accession_count,
+            unique_accession_count
+        );
+    }
+}
diff --git a/workflows/index-generation/ncbi-compress/src/util.rs b/workflows/index-generation/ncbi-compress/src/util.rs
new file mode 100644
index 000000000..03f094679
--- /dev/null
+++ b/workflows/index-generation/ncbi-compress/src/util.rs
@@ -0,0 +1,39 @@
+pub mod util {
+    use std::fs;
+    use std::io::Read;
+    use std::io::Write;
+
+    pub fn are_files_equal(file_path1: &str, file_path2: &str) -> bool {
+        if let Ok(contents1) = fs::read(file_path1) {
+            if let Ok(contents2) = fs::read(file_path2) {
+                return contents1 == contents2;
+            }
+        }
+        false
+    }
+
+    pub fn write_to_file(filename: &str, content: &str) -> std::io::Result<()> {
+        let mut file = fs::File::create(filename)?;
+        file.write_all(content.as_bytes())?;
+        Ok(())
+    }
+
+    pub fn read_contents(input_fasta_path: &str) -> String {
+        let mut file_content = String::new();
+        let mut file = std::fs::File::open(&input_fasta_path).expect("Failed to open file");
+        file.read_to_string(&mut file_content)
+            .expect("Failed to read from file");
+        file_content
+    }
+
+    pub fn read_and_write_to_file(input_fasta_path: &str, output_fasta_path: &str) {
+        let file_content = read_contents(input_fasta_path);
+        let _ = write_to_file(output_fasta_path, &file_content);
+    }
+}
diff --git a/workflows/index-generation/ncbi-compress/tests/integration_test.rs b/workflows/index-generation/ncbi-compress/tests/integration_test.rs
new file mode 100644
index 000000000..4c83f9eca
--- /dev/null
+++ b/workflows/index-generation/ncbi-compress/tests/integration_test.rs
@@ -0,0 +1,37 @@
+use std::path::Path;
+
+use ::ncbi_compress::ncbi_compress::ncbi_compress::fasta_compress;
+use ::ncbi_compress::util::util;
+
+#[test]
+fn test_fasta_compress() {
+    let pathogens = ["chkv", "streptococcus", "rhinovirus"];
+
+    for pathogen in pathogens.iter() {
+        let mapping_files = vec![
+            Path::new("tests/test_data/accession2taxid/nucl_gb.accession2taxid.subset"),
+            Path::new("tests/test_data/accession2taxid/nucl_wgs.accession2taxid.subset"),
+            Path::new("tests/test_data/accession2taxid/pdb.accession2taxid.subset"),
+        ];
+
+        let sorted_seq = format!("tests/test_data/simulated_seqs/all_simulated_seqs_{}_sorted.fasta", pathogen);
+        let expected_compressed = format!("tests/test_data/expected_compression_results/nt_compressed_0.6_{}.fa", pathogen);
+
+        fasta_compress(
+            Path::new(&sorted_seq),
+            mapping_files,
+            Path::new("tests/test_data/test_output.fasta"),
+            vec![9606],
+            1000,
+            31,
+            42,
+            0.6,
+            1000,
+            1000,
+        );
+        assert!(util::are_files_equal("tests/test_data/test_output.fasta", &expected_compressed));
+    }
+}
diff --git a/workflows/index-generation/ncbi-compress/tests/tests.rs b/workflows/index-generation/ncbi-compress/tests/tests.rs
new file mode 100644
index 000000000..f60082853
--- /dev/null
+++ b/workflows/index-generation/ncbi-compress/tests/tests.rs
@@ -0,0 +1,66 @@
+use std::path::Path;
+use std::fs;
+
+use bio::io::fasta;
+use tempfile::NamedTempFile;
+
+use ::ncbi_compress::ncbi_compress::ncbi_compress;
+use ::ncbi_compress::util::util;
+
+#[test]
+fn test_split_accessions_by_taxid() {
+    let pathogens = ["chkv", "streptococcus", "rhinovirus"];
+
+    for pathogen in pathogens.iter() {
+        let mapping_files = vec![
+            Path::new("tests/test_data/accession2taxid/nucl_gb.accession2taxid.subset"),
+            Path::new("tests/test_data/accession2taxid/nucl_wgs.accession2taxid.subset"),
+            Path::new("tests/test_data/accession2taxid/pdb.accession2taxid.subset"),
+        ];
+        let sorted_seq = format!("tests/test_data/simulated_seqs/all_simulated_seqs_{}_sorted.fasta", pathogen);
+        let taxids_to_drop: Vec<u64> = vec![9606];
+        let taxid_dir = ncbi_compress::split_accessions_by_taxid(
+            sorted_seq,
+            mapping_files,
+            &taxids_to_drop,
+        );
+        for (i, entry) in fs::read_dir(taxid_dir.path()).unwrap().enumerate() {
+            let entry = entry.unwrap();
+            let path = entry.path();
+            let input_fasta_path = path.to_str().unwrap();
+            let expected = format!("tests/test_data/expected_split_accessions_by_taxid/test_file_{}_{}.txt", i, pathogen);
+            assert!(util::are_files_equal(input_fasta_path, &expected));
+        }
+    }
+}
+
+// #[test]
+// fn test_fasta_compress_taxid() {
+//     let input_fasta_path = Path::new("tests/test_data/simulated_seqs/all_simulated_seqs_chkv_sorted.fasta");
+//     let expected_fasta_path = "tests/test_data/expected_fasta_compress_taxid_chkv.fasta";
+//     let mut temp_file = NamedTempFile::new().unwrap();
+//     let temp_file_path = temp_file.path();
+//     let mut writer = fasta::Writer::to_file(temp_file_path).unwrap();
+//     let scaled = 1000;
+//     let k = 31;
+//     let seed = 42;
+//     let similarity_threshold = 0.6;
+//     let chunk_size = 1000;
+//     let branch_factor = 1000;
+//     let mut accession_count = 0;
+//     let mut unique_accession_count = 0;
+//     ncbi_compress::fasta_compress_taxid(
+//         input_fasta_path,
+//         &mut writer,
+//         scaled,
+//         k,
+//         seed,
+//         similarity_threshold,
+//         chunk_size,
+//         branch_factor,
+//         &mut accession_count,
+//         &mut unique_accession_count,
+//     );
+
+//     assert!(util::are_files_equal(temp_file_path.file_name().unwrap().to_str().unwrap(), &expected_fasta_path));
+// }