From 6dee5452a71ce0fd5a5b6a0bfdaee990dfa8b4db Mon Sep 17 00:00:00 2001 From: Matt Briggs Date: Thu, 6 May 2021 12:59:21 -0700 Subject: [PATCH] pubsys: limit threads during validate-repo In a previous quick fix we spawned a thread for every target during pubsys validate-repo. Now we limit the number of threads with a rayon thread pool. --- tools/Cargo.lock | 87 ++++++++++++++++++++++ tools/pubsys/Cargo.toml | 2 + tools/pubsys/src/repo/validate_repo/mod.rs | 53 +++++++------ 3 files changed, 121 insertions(+), 21 deletions(-) diff --git a/tools/Cargo.lock b/tools/Cargo.lock index 7edc7392860..1726aab5d8a 100644 --- a/tools/Cargo.lock +++ b/tools/Cargo.lock @@ -301,6 +301,51 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94af6efb46fef72616855b036a624cf27ba656ffc9be1b9a3c931cfc7749a9a9" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52fb27eab85b17fbb9f6fd667089e07d6a2eb8743d02639ee7f6a7a7729c9c94" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "lazy_static", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4feb231f0d4d6af81aed15928e58ecf5816aa62a2393e2c82f46973e92a9a278" +dependencies = [ + "autocfg", + "cfg-if", + "lazy_static", +] + [[package]] name = "crypto-mac" version = "0.10.0" @@ -386,6 +431,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee2626afccd7561a06cf1367e2950c4718ea04565e20fb5029b6c7d8ad09abcf" +[[package]] +name = "either" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" + [[package]] name = "encode_unicode" version = "0.3.6" @@ -794,6 +845,15 @@ version = "2.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ee1c47aaa256ecabcaea351eae4a9b01ef39ed810004e298d2511ed284b1525" +[[package]] +name = "memoffset" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83fb6581e8ed1f85fd45c116db8405483899489e38406156c25eb743554361d" +dependencies = [ + "autocfg", +] + [[package]] name = "mime" version = "0.3.16" @@ -1082,8 +1142,10 @@ dependencies = [ "indicatif", "lazy_static", "log", + "num_cpus", "parse-datetime", "pubsys-config", + "rayon", "reqwest", "rusoto_core", "rusoto_credential", @@ -1189,6 +1251,31 @@ dependencies = [ "rand_core", ] +[[package]] +name = "rayon" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b0d8e0819fadc20c74ea8373106ead0600e3a67ef1fe8da56e39b9ae7275674" +dependencies = [ + "autocfg", + "crossbeam-deque", + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ab346ac5921dc62ffa9f89b7a773907511cdfa5490c572ae9be1be33e8afa4a" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "lazy_static", + "num_cpus", +] + [[package]] name = "redox_syscall" version = "0.2.7" diff --git a/tools/pubsys/Cargo.toml b/tools/pubsys/Cargo.toml index 19482f42038..e435b2cc325 100644 --- a/tools/pubsys/Cargo.toml +++ b/tools/pubsys/Cargo.toml @@ -16,7 +16,9 @@ futures = "0.3.5" indicatif = "0.15.0" lazy_static = "1.4" log = "0.4" +num_cpus = "1" parse-datetime = { path = "../../sources/parse-datetime" } +rayon = "1" # Need to bring in reqwest with a TLS feature so tough can support TLS repos. reqwest = { version = "0.11.1", default-features = false, features = ["rustls-tls", "blocking"] } rusoto_core = { version = "0.46.0", default-features = false, features = ["rustls"] } diff --git a/tools/pubsys/src/repo/validate_repo/mod.rs b/tools/pubsys/src/repo/validate_repo/mod.rs index 26ea168d845..46aff5a7a50 100644 --- a/tools/pubsys/src/repo/validate_repo/mod.rs +++ b/tools/pubsys/src/repo/validate_repo/mod.rs @@ -6,10 +6,11 @@ use crate::Args; use log::{info, trace}; use pubsys_config::InfraConfig; use snafu::{OptionExt, ResultExt}; +use std::cmp::min; use std::fs::File; use std::io; use std::path::PathBuf; -use std::thread::spawn; +use std::sync::mpsc; use structopt::StructOpt; use tough::{Repository, RepositoryLoader}; use url::Url; @@ -38,13 +39,25 @@ pub(crate) struct ValidateRepoArgs { validate_targets: bool, } -/// Retrieves listed targets and attempts to download them for validation purposes +/// If we are on a machine with a large number of cores, then we limit the number of simultaneous +/// downloads to this arbitrarily chosen maximum. +const MAX_DOWNLOAD_THREADS: usize = 16; + +/// Retrieves listed targets and attempts to download them for validation purposes. We use a Rayon +/// thread pool instead of tokio for async execution because `reqwest::blocking` creates a tokio +/// runtime (and multiple tokio runtimes are not supported). fn retrieve_targets(repo: &Repository) -> Result<(), Error> { let targets = &repo.targets().signed.targets; + let thread_pool = rayon::ThreadPoolBuilder::new() + .num_threads(min(num_cpus::get(), MAX_DOWNLOAD_THREADS)) + .build() + .context(error::ThreadPool)?; + + // create the channels through which our download results will be passed + let (tx, rx) = mpsc::channel(); - let mut tasks = Vec::new(); for target in targets.keys().cloned() { - let target = target.to_string(); + let tx = tx.clone(); let mut reader = repo .read_target(&target) .with_context(|| repo_error::ReadTarget { @@ -54,24 +67,22 @@ fn retrieve_targets(repo: &Repository) -> Result<(), Error> { target: target.to_string(), })?; info!("Downloading target: {}", target); - // TODO - limit threads https://github.com/bottlerocket-os/bottlerocket/issues/1522 - tasks.push(spawn(move || { - // tough's `Read` implementation validates the target as it's being downloaded - io::copy(&mut reader, &mut io::sink()).context(error::TargetDownload { - target: target.to_string(), + thread_pool.spawn(move || { + tx.send({ + // tough's `Read` implementation validates the target as it's being downloaded + io::copy(&mut reader, &mut io::sink()).context(error::TargetDownload { + target: target.to_string(), + }) }) - })); + // inability to send on this channel is unrecoverable + .unwrap(); + }); } + // close all senders + drop(tx); - // ensure that we join all threads before checking the results - let mut results = Vec::new(); - for task in tasks { - let result = task.join().map_err(|e| error::Error::Join { - // the join function is returning an error type that does not implement error or display - inner: format!("{:?}", e), - })?; - results.push(result); - } + // block and await all downloads + let results: Vec> = rx.into_iter().collect(); // check all results and return the first error we see for result in results { @@ -164,8 +175,8 @@ mod error { #[snafu(display("Missing target: {}", target))] TargetMissing { target: String }, - #[snafu(display("Failed to join thread: {}", inner))] - Join { inner: String }, + #[snafu(display("Unable to create thread pool: {}", source))] + ThreadPool { source: rayon::ThreadPoolBuildError }, } } pub(crate) use error::Error;