Skip to content

Commit

Permalink
fix(storage-scrubber): use default AWS authentication
Browse files: browse the repository at this point in the history
Signed-off-by: Alex Chi Z <chi@neon.tech>
  • Loading branch information
skyzh committed Jul 5, 2024
1 parent e3dd5ae commit ddebdb6
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 71 deletions.
2 changes: 1 addition & 1 deletion storage_scrubber/src/find_large_objects.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub async fn find_large_objects(
min_size: u64,
ignore_deltas: bool,
) -> anyhow::Result<LargeObjectListing> {
let (s3_client, target) = init_remote(bucket_config.clone(), NodeKind::Pageserver)?;
let (s3_client, target) = init_remote(bucket_config.clone(), NodeKind::Pageserver).await?;
let mut tenants = std::pin::pin!(stream_tenants(&s3_client, &target));
let mut objects = Vec::new();
let mut tenant_ctr = 0u64;
Expand Down
4 changes: 2 additions & 2 deletions storage_scrubber/src/garbage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ async fn find_garbage_inner(
node_kind: NodeKind,
) -> anyhow::Result<GarbageList> {
// Construct clients for S3 and for Console API
let (s3_client, target) = init_remote(bucket_config.clone(), node_kind)?;
let (s3_client, target) = init_remote(bucket_config.clone(), node_kind).await?;
let cloud_admin_api_client = Arc::new(CloudAdminApiClient::new(console_config));

// Build a set of console-known tenants, for quickly eliminating known-active tenants without having
Expand Down Expand Up @@ -432,7 +432,7 @@ pub async fn purge_garbage(
);

let (s3_client, target) =
init_remote(garbage_list.bucket_config.clone(), garbage_list.node_kind)?;
init_remote(garbage_list.bucket_config.clone(), garbage_list.node_kind).await?;

// Sanity checks on the incoming list
if garbage_list.active_tenant_count == 0 {
Expand Down
70 changes: 9 additions & 61 deletions storage_scrubber/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,8 @@ use std::sync::Arc;
use std::time::Duration;

use anyhow::Context;
use aws_config::environment::EnvironmentVariableCredentialsProvider;
use aws_config::imds::credentials::ImdsCredentialsProvider;
use aws_config::meta::credentials::CredentialsProviderChain;
use aws_config::profile::ProfileFileCredentialsProvider;
use aws_config::retry::RetryConfig;
use aws_config::sso::SsoCredentialsProvider;
use aws_config::BehaviorVersion;
use aws_sdk_s3::config::{AsyncSleep, Region, SharedAsyncSleep};
use aws_sdk_s3::{Client, Config};
use aws_smithy_async::rt::sleep::TokioSleep;
use aws_sdk_s3::config::Region;
use aws_sdk_s3::Client;

use camino::{Utf8Path, Utf8PathBuf};
use clap::ValueEnum;
Expand Down Expand Up @@ -277,65 +269,21 @@ pub fn init_logging(file_name: &str) -> Option<WorkerGuard> {
}
}

pub fn init_s3_client(bucket_region: Region) -> Client {
let credentials_provider = {
// uses "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"
let chain = CredentialsProviderChain::first_try(
"env",
EnvironmentVariableCredentialsProvider::new(),
)
// uses "AWS_PROFILE" / `aws sso login --profile <profile>`
.or_else(
"profile-sso",
ProfileFileCredentialsProvider::builder().build(),
);

// Use SSO if we were given an account ID
match std::env::var("SSO_ACCOUNT_ID").ok() {
Some(sso_account) => chain.or_else(
"sso",
SsoCredentialsProvider::builder()
.account_id(sso_account)
.role_name("PowerUserAccess")
.start_url("https://neondb.awsapps.com/start")
.region(bucket_region.clone())
.build(),
),
None => chain,
}
.or_else(
// Finally try IMDS
"imds",
ImdsCredentialsProvider::builder().build(),
)
};

let sleep_impl: Arc<dyn AsyncSleep> = Arc::new(TokioSleep::new());

let mut builder = Config::builder()
.behavior_version(
#[allow(deprecated)] /* TODO: https://github.com/neondatabase/neon/issues/7665 */
BehaviorVersion::v2023_11_09(),
)
pub async fn init_s3_client(bucket_region: Region) -> Client {
let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
.region(bucket_region)
.retry_config(RetryConfig::adaptive().with_max_attempts(3))
.sleep_impl(SharedAsyncSleep::from(sleep_impl))
.credentials_provider(credentials_provider);

if let Ok(endpoint) = env::var("AWS_ENDPOINT_URL") {
builder = builder.endpoint_url(endpoint)
}

Client::from_conf(builder.build())
.load()
.await;
Client::new(&config)
}

fn init_remote(
async fn init_remote(
bucket_config: BucketConfig,
node_kind: NodeKind,
) -> anyhow::Result<(Arc<Client>, RootTarget)> {
let bucket_region = Region::new(bucket_config.region);
let delimiter = "/".to_string();
let s3_client = Arc::new(init_s3_client(bucket_region));
let s3_client = Arc::new(init_s3_client(bucket_region).await);

let s3_root = match node_kind {
NodeKind::Pageserver => RootTarget::Pageserver(S3Target {
Expand Down
2 changes: 1 addition & 1 deletion storage_scrubber/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ async fn main() -> anyhow::Result<()> {
concurrency,
} => {
let downloader =
SnapshotDownloader::new(bucket_config, tenant_id, output_path, concurrency)?;
SnapshotDownloader::new(bucket_config, tenant_id, output_path, concurrency).await?;
downloader.download().await
}
Command::PageserverPhysicalGc {
Expand Down
2 changes: 1 addition & 1 deletion storage_scrubber/src/pageserver_physical_gc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ pub async fn pageserver_physical_gc(
min_age: Duration,
mode: GcMode,
) -> anyhow::Result<GcSummary> {
let (s3_client, target) = init_remote(bucket_config.clone(), NodeKind::Pageserver)?;
let (s3_client, target) = init_remote(bucket_config.clone(), NodeKind::Pageserver).await?;

let tenants = if tenant_ids.is_empty() {
futures::future::Either::Left(stream_tenants(&s3_client, &target))
Expand Down
2 changes: 1 addition & 1 deletion storage_scrubber/src/scan_pageserver_metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ pub async fn scan_metadata(
bucket_config: BucketConfig,
tenant_ids: Vec<TenantShardId>,
) -> anyhow::Result<MetadataSummary> {
let (s3_client, target) = init_remote(bucket_config, NodeKind::Pageserver)?;
let (s3_client, target) = init_remote(bucket_config, NodeKind::Pageserver).await?;

let tenants = if tenant_ids.is_empty() {
futures::future::Either::Left(stream_tenants(&s3_client, &target))
Expand Down
2 changes: 1 addition & 1 deletion storage_scrubber/src/scan_safekeeper_metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ pub async fn scan_safekeeper_metadata(
let timelines = client.query(&query, &[]).await?;
info!("loaded {} timelines", timelines.len());

let (s3_client, target) = init_remote(bucket_config, NodeKind::Safekeeper)?;
let (s3_client, target) = init_remote(bucket_config, NodeKind::Safekeeper).await?;
let console_config = ConsoleConfig::from_env()?;
let cloud_admin_api_client = CloudAdminApiClient::new(console_config);

Expand Down
7 changes: 4 additions & 3 deletions storage_scrubber/src/tenant_snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ pub struct SnapshotDownloader {
}

impl SnapshotDownloader {
pub fn new(
pub async fn new(
bucket_config: BucketConfig,
tenant_id: TenantId,
output_path: Utf8PathBuf,
concurrency: usize,
) -> anyhow::Result<Self> {
let (s3_client, s3_root) = init_remote(bucket_config.clone(), NodeKind::Pageserver)?;
let (s3_client, s3_root) = init_remote(bucket_config.clone(), NodeKind::Pageserver).await?;
Ok(Self {
s3_client,
s3_root,
Expand Down Expand Up @@ -215,7 +215,8 @@ impl SnapshotDownloader {
}

pub async fn download(&self) -> anyhow::Result<()> {
let (s3_client, target) = init_remote(self.bucket_config.clone(), NodeKind::Pageserver)?;
let (s3_client, target) =
init_remote(self.bucket_config.clone(), NodeKind::Pageserver).await?;

// Generate a stream of TenantShardId
let shards = stream_tenant_shards(&s3_client, &target, self.tenant_id).await?;
Expand Down

0 comments on commit ddebdb6

Please sign in to comment.