diff --git a/libs/metrics/src/hll.rs b/libs/metrics/src/hll.rs index 46a623b0e243..dfb4461ce952 100644 --- a/libs/metrics/src/hll.rs +++ b/libs/metrics/src/hll.rs @@ -40,7 +40,7 @@ macro_rules! register_hll { }}; ($N:literal, $NAME:expr, $HELP:expr $(,)?) => {{ - $crate::register_hll!($N, $crate::opts!($NAME, $HELP), $LABELS_NAMES) + $crate::register_hll!($N, $crate::opts!($NAME, $HELP)) }}; } diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 04fe83d8ebc7..e421798067b9 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -12,6 +12,8 @@ use crate::console::errors::GetAuthInfoError; use crate::console::provider::{CachedRoleSecret, ConsoleBackend}; use crate::console::{AuthSecret, NodeInfo}; use crate::context::RequestMonitoring; +use crate::intern::EndpointIdInt; +use crate::metrics::{AUTH_RATE_LIMIT_HITS, ENDPOINTS_AUTH_RATE_LIMITED}; use crate::proxy::connect_compute::ComputeConnectBackend; use crate::proxy::NeonOptions; use crate::stream::Stream; @@ -28,7 +30,7 @@ use crate::{ use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::info; +use tracing::{info, warn}; /// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality pub enum MaybeOwned<'a, T> { @@ -174,6 +176,52 @@ impl TryFrom for ComputeUserInfo { } } +impl AuthenticationConfig { + pub fn check_rate_limit( + &self, + + ctx: &mut RequestMonitoring, + secret: AuthSecret, + endpoint: &EndpointId, + is_cleartext: bool, + ) -> auth::Result { + // we have validated the endpoint exists, so let's intern it. + let endpoint_int = EndpointIdInt::from(endpoint); + + // only count the full hash count if password hack or websocket flow. + // in other words, if proxy needs to run the hashing + let password_weight = if is_cleartext { + match &secret { + #[cfg(any(test, feature = "testing"))] + AuthSecret::Md5(_) => 1, + AuthSecret::Scram(s) => s.iterations + 1, + } + } else { + // validating scram takes just 1 hmac_sha_256 operation. + 1 + }; + + let limit_not_exceeded = self + .rate_limiter + .check((endpoint_int, ctx.peer_addr), password_weight); + + if !limit_not_exceeded { + warn!( + enabled = self.rate_limiter_enabled, + "rate limiting authentication" + ); + AUTH_RATE_LIMIT_HITS.inc(); + ENDPOINTS_AUTH_RATE_LIMITED.measure(endpoint); + + if self.rate_limiter_enabled { + return Err(auth::AuthError::too_many_connections()); + } + } + + Ok(secret) + } +} + /// True to its name, this function encapsulates our current auth trade-offs. /// Here, we choose the appropriate auth flow based on circumstances. /// @@ -214,14 +262,24 @@ async fn auth_quirks( Some(secret) => secret, None => api.get_role_secret(ctx, &info).await?, }; + let (cached_entry, secret) = cached_secret.take_value(); + + let secret = match secret { + Some(secret) => config.check_rate_limit( + ctx, + secret, + &info.endpoint, + unauthenticated_password.is_some() || allow_cleartext, + )?, + None => { + // If we don't have an authentication secret, we mock one to + // prevent malicious probing (possible due to missing protocol steps). + // This mocked secret will never lead to successful authentication. + info!("authentication info not found, mocking it"); + AuthSecret::Scram(scram::ServerSecret::mock(rand::random())) + } + }; - let secret = cached_secret.value.clone().unwrap_or_else(|| { - // If we don't have an authentication secret, we mock one to - // prevent malicious probing (possible due to missing protocol steps). - // This mocked secret will never lead to successful authentication. - info!("authentication info not found, mocking it"); - AuthSecret::Scram(scram::ServerSecret::mock(&info.user, rand::random())) - }); match authenticate_with_secret( ctx, secret, @@ -237,7 +295,7 @@ async fn auth_quirks( Err(e) => { if e.is_auth_failed() { // The password could have been changed, so we invalidate the cache. - cached_secret.invalidate(); + cached_entry.invalidate(); } Err(e) } @@ -415,6 +473,7 @@ mod tests { use bytes::BytesMut; use fallible_iterator::FallibleIterator; + use once_cell::sync::Lazy; use postgres_protocol::{ authentication::sasl::{ChannelBinding, ScramSha256}, message::{backend::Message as PgMessage, frontend}, @@ -432,6 +491,7 @@ mod tests { }, context::RequestMonitoring, proxy::NeonOptions, + rate_limiter::{AuthRateLimiter, RateBucketInfo}, scram::ServerSecret, stream::{PqStream, Stream}, }; @@ -473,9 +533,11 @@ mod tests { } } - static CONFIG: &AuthenticationConfig = &AuthenticationConfig { + static CONFIG: Lazy = Lazy::new(|| AuthenticationConfig { scram_protocol_timeout: std::time::Duration::from_secs(5), - }; + rate_limiter_enabled: true, + rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET), + }); async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage { loop { @@ -544,7 +606,7 @@ mod tests { } }); - let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, false, CONFIG) + let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, false, &CONFIG) .await .unwrap(); @@ -584,7 +646,7 @@ mod tests { client.write_all(&write).await.unwrap(); }); - let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG) + let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, &CONFIG) .await .unwrap(); @@ -624,7 +686,7 @@ mod tests { client.write_all(&write).await.unwrap(); }); - let creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG) + let creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, &CONFIG) .await .unwrap(); diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index d38439c2a0d6..88b847f5f106 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -18,6 +18,7 @@ use proxy::console; use proxy::context::parquet::ParquetUploadArgs; use proxy::http; use proxy::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT; +use proxy::rate_limiter::AuthRateLimiter; use proxy::rate_limiter::EndpointRateLimiter; use proxy::rate_limiter::RateBucketInfo; use proxy::rate_limiter::RateLimiterConfig; @@ -141,10 +142,16 @@ struct ProxyCliArgs { /// /// Provided in the form '@'. /// Can be given multiple times for different bucket sizes. - #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] + #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)] endpoint_rps_limit: Vec, + /// Whether the auth rate limiter actually takes effect (for testing) + #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] + auth_rate_limit_enabled: bool, + /// Authentication rate limiter max number of hashes per second. + #[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)] + auth_rate_limit: Vec, /// Redis rate limiter max number of requests per second. - #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] + #[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)] redis_rps_limit: Vec, /// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`. #[clap(long, default_value_t = 100)] @@ -510,6 +517,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { }; let authentication_config = AuthenticationConfig { scram_protocol_timeout: args.scram_protocol_timeout, + rate_limiter_enabled: args.auth_rate_limit_enabled, + rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), }; let mut endpoint_rps_limit = args.endpoint_rps_limit.clone(); diff --git a/proxy/src/cache/common.rs b/proxy/src/cache/common.rs index 2af6a70e9072..bc1c37512bce 100644 --- a/proxy/src/cache/common.rs +++ b/proxy/src/cache/common.rs @@ -43,6 +43,16 @@ impl Cached { Self { token: None, value } } + pub fn take_value(self) -> (Cached, V) { + ( + Cached { + token: self.token, + value: (), + }, + self.value, + ) + } + /// Drop this entry from a cache if it's still there. pub fn invalidate(self) -> V { if let Some((cache, info)) = &self.token { diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 6e3eb8c1b028..5a3660520bf1 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -373,10 +373,7 @@ mod tests { let endpoint_id = "endpoint".into(); let user1: RoleName = "user1".into(); let user2: RoleName = "user2".into(); - let secret1 = Some(AuthSecret::Scram(ServerSecret::mock( - user1.as_str(), - [1; 32], - ))); + let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); let secret2 = None; let allowed_ips = Arc::new(vec![ "127.0.0.1".parse().unwrap(), @@ -395,10 +392,7 @@ mod tests { // Shouldn't add more than 2 roles. let user3: RoleName = "user3".into(); - let secret3 = Some(AuthSecret::Scram(ServerSecret::mock( - user3.as_str(), - [3; 32], - ))); + let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32]))); cache.insert_role_secret(&project_id, &endpoint_id, &user3, secret3.clone()); assert!(cache.get_role_secret(&endpoint_id, &user3).is_none()); @@ -431,14 +425,8 @@ mod tests { let endpoint_id = "endpoint".into(); let user1: RoleName = "user1".into(); let user2: RoleName = "user2".into(); - let secret1 = Some(AuthSecret::Scram(ServerSecret::mock( - user1.as_str(), - [1; 32], - ))); - let secret2 = Some(AuthSecret::Scram(ServerSecret::mock( - user2.as_str(), - [2; 32], - ))); + let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); + let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32]))); let allowed_ips = Arc::new(vec![ "127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap(), @@ -486,14 +474,8 @@ mod tests { let endpoint_id = "endpoint".into(); let user1: RoleName = "user1".into(); let user2: RoleName = "user2".into(); - let secret1 = Some(AuthSecret::Scram(ServerSecret::mock( - user1.as_str(), - [1; 32], - ))); - let secret2 = Some(AuthSecret::Scram(ServerSecret::mock( - user2.as_str(), - [2; 32], - ))); + let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32]))); + let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32]))); let allowed_ips = Arc::new(vec![ "127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap(), diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 45f8d7614439..361c3ef519c4 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,4 +1,8 @@ -use crate::{auth, rate_limiter::RateBucketInfo, serverless::GlobalConnPoolOptions}; +use crate::{ + auth, + rate_limiter::{AuthRateLimiter, RateBucketInfo}, + serverless::GlobalConnPoolOptions, +}; use anyhow::{bail, ensure, Context, Ok}; use itertools::Itertools; use rustls::{ @@ -50,6 +54,8 @@ pub struct HttpConfig { pub struct AuthenticationConfig { pub scram_protocol_timeout: tokio::time::Duration, + pub rate_limiter_enabled: bool, + pub rate_limiter: AuthRateLimiter, } impl TlsConfig { diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index eed45e421b8f..4172dc19daf9 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -4,7 +4,10 @@ use ::metrics::{ register_int_gauge_vec, Histogram, HistogramVec, HyperLogLogVec, IntCounterPairVec, IntCounterVec, IntGauge, IntGaugeVec, }; -use metrics::{register_int_counter, register_int_counter_pair, IntCounter, IntCounterPair}; +use metrics::{ + register_hll, register_int_counter, register_int_counter_pair, HyperLogLog, IntCounter, + IntCounterPair, +}; use once_cell::sync::Lazy; use tokio::time::{self, Instant}; @@ -358,3 +361,20 @@ pub static TLS_HANDSHAKE_FAILURES: Lazy = Lazy::new(|| { ) .unwrap() }); + +pub static ENDPOINTS_AUTH_RATE_LIMITED: Lazy> = Lazy::new(|| { + register_hll!( + 32, + "proxy_endpoints_auth_rate_limits", + "Number of endpoints affected by authentication rate limits", + ) + .unwrap() +}); + +pub static AUTH_RATE_LIMIT_HITS: Lazy = Lazy::new(|| { + register_int_counter!( + "proxy_requests_auth_rate_limits_total", + "Number of connection requests affected by authentication rate limits", + ) + .unwrap() +}); diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 843bfc08cfa3..6051c0a81242 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -280,7 +280,7 @@ pub async fn handle_client( // check rate limit if let Some(ep) = user_info.get_endpoint() { - if !endpoint_rate_limiter.check(ep) { + if !endpoint_rate_limiter.check(ep, 1) { return stream .throw_error(auth::AuthError::too_many_connections()) .await?; diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 9c3be7361291..a4051447c1b0 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -142,8 +142,8 @@ impl Scram { Ok(Scram(secret)) } - fn mock(user: &str) -> Self { - Scram(scram::ServerSecret::mock(user, rand::random())) + fn mock() -> Self { + Scram(scram::ServerSecret::mock(rand::random())) } } @@ -330,11 +330,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> { let (client_config, server_config) = generate_tls_config("generic-project-name.localhost", "localhost")?; - let proxy = tokio::spawn(dummy_proxy( - client, - Some(server_config), - Scram::mock("user"), - )); + let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock())); use rand::{distributions::Alphanumeric, Rng}; let password: String = rand::thread_rng() diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter.rs index f0da4ead230c..13dffffca01e 100644 --- a/proxy/src/rate_limiter.rs +++ b/proxy/src/rate_limiter.rs @@ -4,4 +4,4 @@ mod limiter; pub use aimd::Aimd; pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig}; pub use limiter::Limiter; -pub use limiter::{EndpointRateLimiter, RateBucketInfo, RedisRateLimiter}; +pub use limiter::{AuthRateLimiter, EndpointRateLimiter, RateBucketInfo, RedisRateLimiter}; diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 3181060e2f93..f590896dd9f4 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -1,6 +1,8 @@ use std::{ + borrow::Cow, collections::hash_map::RandomState, - hash::BuildHasher, + hash::{BuildHasher, Hash}, + net::IpAddr, sync::{ atomic::{AtomicUsize, Ordering}, Arc, Mutex, @@ -15,7 +17,7 @@ use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit}; use tokio::time::{timeout, Duration, Instant}; use tracing::info; -use crate::EndpointId; +use crate::{intern::EndpointIdInt, EndpointId}; use super::{ limit_algorithm::{LimitAlgorithm, Sample}, @@ -49,11 +51,11 @@ impl RedisRateLimiter { .data .iter_mut() .zip(self.info) - .all(|(bucket, info)| bucket.should_allow_request(info, now)); + .all(|(bucket, info)| bucket.should_allow_request(info, now, 1)); if should_allow_request { // only increment the bucket counts if the request will actually be accepted - self.data.iter_mut().for_each(RateBucket::inc); + self.data.iter_mut().for_each(|b| b.inc(1)); } should_allow_request @@ -71,9 +73,14 @@ impl RedisRateLimiter { // saw SNI, before doing TLS handshake. User-side error messages in that case // does not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now // I went with a more expensive way that yields user-friendlier error messages. -pub struct EndpointRateLimiter { - map: DashMap, Hasher>, - info: &'static [RateBucketInfo], +pub type EndpointRateLimiter = BucketRateLimiter; + +// This can't be just per IP because that would limit some PaaS that share IP addresses +pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, IpAddr), StdRng, RandomState>; + +pub struct BucketRateLimiter { + map: DashMap, Hasher>, + info: Cow<'static, [RateBucketInfo]>, access_count: AtomicUsize, rand: Mutex, } @@ -85,9 +92,9 @@ struct RateBucket { } impl RateBucket { - fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant) -> bool { + fn should_allow_request(&mut self, info: &RateBucketInfo, now: Instant, n: u32) -> bool { if now - self.start < info.interval { - self.count < info.max_rpi + self.count + n <= info.max_rpi } else { // bucket expired, reset self.count = 0; @@ -97,8 +104,8 @@ impl RateBucket { } } - fn inc(&mut self) { - self.count += 1; + fn inc(&mut self, n: u32) { + self.count += n; } } @@ -111,7 +118,7 @@ pub struct RateBucketInfo { impl std::fmt::Display for RateBucketInfo { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let rps = self.max_rpi * 1000 / self.interval.as_millis() as u32; + let rps = (self.max_rpi as u64) * 1000 / self.interval.as_millis() as u64; write!(f, "{rps}@{}", humantime::format_duration(self.interval)) } } @@ -136,12 +143,25 @@ impl std::str::FromStr for RateBucketInfo { } impl RateBucketInfo { - pub const DEFAULT_SET: [Self; 3] = [ + pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [ Self::new(300, Duration::from_secs(1)), Self::new(200, Duration::from_secs(60)), Self::new(100, Duration::from_secs(600)), ]; + /// All of these are per endpoint-ip pair. + /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus). + /// + /// First bucket: 300mcpus total per endpoint-ip pair + /// * 1228800 requests per second with 1 hash rounds. (endpoint rate limiter will catch this first) + /// * 300 requests per second with 4096 hash rounds. + /// * 2 requests per second with 600000 hash rounds. + pub const DEFAULT_AUTH_SET: [Self; 3] = [ + Self::new(300 * 4096, Duration::from_secs(1)), + Self::new(200 * 4096, Duration::from_secs(60)), + Self::new(100 * 4096, Duration::from_secs(600)), + ]; + pub fn validate(info: &mut [Self]) -> anyhow::Result<()> { info.sort_unstable_by_key(|info| info.interval); let invalid = info @@ -150,7 +170,7 @@ impl RateBucketInfo { .find(|(a, b)| a.max_rpi > b.max_rpi); if let Some((a, b)) = invalid { bail!( - "invalid endpoint RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})", + "invalid bucket RPS limits. {b} allows fewer requests per bucket than {a} ({} vs {})", b.max_rpi, a.max_rpi, ); @@ -162,19 +182,24 @@ impl RateBucketInfo { pub const fn new(max_rps: u32, interval: Duration) -> Self { Self { interval, - max_rpi: max_rps * interval.as_millis() as u32 / 1000, + max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32, } } } -impl EndpointRateLimiter { - pub fn new(info: &'static [RateBucketInfo]) -> Self { +impl BucketRateLimiter { + pub fn new(info: impl Into>) -> Self { Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new()) } } -impl EndpointRateLimiter { - fn new_with_rand_and_hasher(info: &'static [RateBucketInfo], rand: R, hasher: S) -> Self { +impl BucketRateLimiter { + fn new_with_rand_and_hasher( + info: impl Into>, + rand: R, + hasher: S, + ) -> Self { + let info = info.into(); info!(buckets = ?info, "endpoint rate limiter"); Self { info, @@ -185,7 +210,7 @@ impl EndpointRateLimiter { } /// Check that number of connections to the endpoint is below `max_rps` rps. - pub fn check(&self, endpoint: EndpointId) -> bool { + pub fn check(&self, key: K, n: u32) -> bool { // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map. // worst case memory usage is about: // = 2 * 2048 * 64 * (48B + 72B) @@ -195,7 +220,7 @@ impl EndpointRateLimiter { } let now = Instant::now(); - let mut entry = self.map.entry(endpoint).or_insert_with(|| { + let mut entry = self.map.entry(key).or_insert_with(|| { vec![ RateBucket { start: now, @@ -207,12 +232,12 @@ impl EndpointRateLimiter { let should_allow_request = entry .iter_mut() - .zip(self.info) - .all(|(bucket, info)| bucket.should_allow_request(info, now)); + .zip(&*self.info) + .all(|(bucket, info)| bucket.should_allow_request(info, now, n)); if should_allow_request { // only increment the bucket counts if the request will actually be accepted - entry.iter_mut().for_each(RateBucket::inc); + entry.iter_mut().for_each(|b| b.inc(n)); } should_allow_request @@ -223,7 +248,7 @@ impl EndpointRateLimiter { /// But that way deletion does not aquire mutex on each entry access. pub fn do_gc(&self) { info!( - "cleaning up endpoint rate limiter, current size = {}", + "cleaning up bucket rate limiter, current size = {}", self.map.len() ); let n = self.map.shards().len(); @@ -534,7 +559,7 @@ mod tests { use rustc_hash::FxHasher; use tokio::time; - use super::{EndpointRateLimiter, Limiter, Outcome}; + use super::{BucketRateLimiter, EndpointRateLimiter, Limiter, Outcome}; use crate::{ rate_limiter::{RateBucketInfo, RateLimitAlgorithm}, EndpointId, @@ -672,12 +697,12 @@ mod tests { #[test] fn default_rate_buckets() { - let mut defaults = RateBucketInfo::DEFAULT_SET; + let mut defaults = RateBucketInfo::DEFAULT_ENDPOINT_SET; RateBucketInfo::validate(&mut defaults[..]).unwrap(); } #[test] - #[should_panic = "invalid endpoint RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"] + #[should_panic = "invalid bucket RPS limits. 10@10s allows fewer requests per bucket than 300@1s (100 vs 300)"] fn rate_buckets_validate() { let mut rates: Vec = ["300@1s", "10@10s"] .into_iter() @@ -693,42 +718,42 @@ mod tests { .map(|s| s.parse().unwrap()) .collect(); RateBucketInfo::validate(&mut rates).unwrap(); - let limiter = EndpointRateLimiter::new(Vec::leak(rates)); + let limiter = EndpointRateLimiter::new(rates); let endpoint = EndpointId::from("ep-my-endpoint-1234"); time::pause(); for _ in 0..100 { - assert!(limiter.check(endpoint.clone())); + assert!(limiter.check(endpoint.clone(), 1)); } // more connections fail - assert!(!limiter.check(endpoint.clone())); + assert!(!limiter.check(endpoint.clone(), 1)); // fail even after 500ms as it's in the same bucket time::advance(time::Duration::from_millis(500)).await; - assert!(!limiter.check(endpoint.clone())); + assert!(!limiter.check(endpoint.clone(), 1)); // after a full 1s, 100 requests are allowed again time::advance(time::Duration::from_millis(500)).await; for _ in 1..6 { - for _ in 0..100 { - assert!(limiter.check(endpoint.clone())); + for _ in 0..50 { + assert!(limiter.check(endpoint.clone(), 2)); } time::advance(time::Duration::from_millis(1000)).await; } // more connections after 600 will exceed the 20rps@30s limit - assert!(!limiter.check(endpoint.clone())); + assert!(!limiter.check(endpoint.clone(), 1)); // will still fail before the 30 second limit time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await; - assert!(!limiter.check(endpoint.clone())); + assert!(!limiter.check(endpoint.clone(), 1)); // after the full 30 seconds, 100 requests are allowed again time::advance(time::Duration::from_millis(1)).await; for _ in 0..100 { - assert!(limiter.check(endpoint.clone())); + assert!(limiter.check(endpoint.clone(), 1)); } } @@ -738,14 +763,41 @@ mod tests { let rand = rand::rngs::StdRng::from_seed([1; 32]); let hasher = BuildHasherDefault::::default(); - let limiter = EndpointRateLimiter::new_with_rand_and_hasher( - &RateBucketInfo::DEFAULT_SET, + let limiter = BucketRateLimiter::new_with_rand_and_hasher( + &RateBucketInfo::DEFAULT_ENDPOINT_SET, rand, hasher, ); for i in 0..1_000_000 { - limiter.check(format!("{i}").into()); + limiter.check(i, 1); } assert!(limiter.map.len() < 150_000); } + + #[test] + fn test_default_auth_set() { + // these values used to exceed u32::MAX + assert_eq!( + RateBucketInfo::DEFAULT_AUTH_SET, + [ + RateBucketInfo { + interval: Duration::from_secs(1), + max_rpi: 300 * 4096, + }, + RateBucketInfo { + interval: Duration::from_secs(60), + max_rpi: 200 * 4096 * 60, + }, + RateBucketInfo { + interval: Duration::from_secs(600), + max_rpi: 100 * 4096 * 600, + } + ] + ); + + for x in RateBucketInfo::DEFAULT_AUTH_SET { + let y = x.to_string().parse().unwrap(); + assert_eq!(x, y); + } + } } diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs index f3414cb8ecc8..44c4f9e44aec 100644 --- a/proxy/src/scram/secret.rs +++ b/proxy/src/scram/secret.rs @@ -50,13 +50,13 @@ impl ServerSecret { /// To avoid revealing information to an attacker, we use a /// mocked server secret even if the user doesn't exist. /// See `auth-scram.c : mock_scram_secret` for details. - pub fn mock(user: &str, nonce: [u8; 32]) -> Self { - // Refer to `auth-scram.c : scram_mock_salt`. - let mocked_salt = super::sha256([user.as_bytes(), &nonce]); - + pub fn mock(nonce: [u8; 32]) -> Self { Self { - iterations: 4096, - salt_base64: base64::encode(mocked_salt), + // this doesn't reveal much information as we're going to use + // iteration count 1 for our generated passwords going forward. + // PG16 users can set iteration count=1 already today. + iterations: 1, + salt_base64: base64::encode(nonce), stored_key: ScramKey::default(), server_key: ScramKey::default(), doomed: true, diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 72b55c45f013..f10779d7ba0d 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -42,7 +42,12 @@ impl PoolingBackend { }; let secret = match cached_secret.value.clone() { - Some(secret) => secret, + Some(secret) => self.config.authentication_config.check_rate_limit( + ctx, + secret, + &user_info.endpoint, + true, + )?, None => { // If we don't have an authentication secret, for the http flow we can just return an error. info!("authentication info not found");