Skip to content

Commit

Permalink
add authentication rate limiting (#6865)
Browse files Browse the repository at this point in the history
## Problem

neondatabase/cloud#9642

## Summary of changes

1. Make `EndpointRateLimiter` generic, renamed as `BucketRateLimiter`
2. Add support for claiming multiple tokens at once
3. Add `AuthRateLimiter` alias.
4. Check `(Endpoint, IP)` pair during authentication, weighted by how
many hashes proxy would be doing.

TODO: handle ipv6 subnets. will do this in a separate PR.
  • Loading branch information
conradludgate committed Mar 26, 2024
1 parent b3b7ce4 commit 12512f3
Show file tree
Hide file tree
Showing 13 changed files with 241 additions and 99 deletions.
2 changes: 1 addition & 1 deletion libs/metrics/src/hll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ macro_rules! register_hll {
}};

($N:literal, $NAME:expr, $HELP:expr $(,)?) => {{
$crate::register_hll!($N, $crate::opts!($NAME, $HELP), $LABELS_NAMES)
$crate::register_hll!($N, $crate::opts!($NAME, $HELP))
}};
}

Expand Down
90 changes: 76 additions & 14 deletions proxy/src/auth/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use crate::console::errors::GetAuthInfoError;
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
use crate::console::{AuthSecret, NodeInfo};
use crate::context::RequestMonitoring;
use crate::intern::EndpointIdInt;
use crate::metrics::{AUTH_RATE_LIMIT_HITS, ENDPOINTS_AUTH_RATE_LIMITED};
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::NeonOptions;
use crate::stream::Stream;
Expand All @@ -28,7 +30,7 @@ use crate::{
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use tracing::{info, warn};

/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
pub enum MaybeOwned<'a, T> {
Expand Down Expand Up @@ -174,6 +176,52 @@ impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
}
}

impl AuthenticationConfig {
pub fn check_rate_limit(
&self,

ctx: &mut RequestMonitoring,
secret: AuthSecret,
endpoint: &EndpointId,
is_cleartext: bool,
) -> auth::Result<AuthSecret> {
// we have validated the endpoint exists, so let's intern it.
let endpoint_int = EndpointIdInt::from(endpoint);

// only count the full hash count if password hack or websocket flow.
// in other words, if proxy needs to run the hashing
let password_weight = if is_cleartext {
match &secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => 1,
AuthSecret::Scram(s) => s.iterations + 1,
}
} else {
// validating scram takes just 1 hmac_sha_256 operation.
1
};

let limit_not_exceeded = self
.rate_limiter
.check((endpoint_int, ctx.peer_addr), password_weight);

if !limit_not_exceeded {
warn!(
enabled = self.rate_limiter_enabled,
"rate limiting authentication"
);
AUTH_RATE_LIMIT_HITS.inc();
ENDPOINTS_AUTH_RATE_LIMITED.measure(endpoint);

if self.rate_limiter_enabled {
return Err(auth::AuthError::too_many_connections());
}
}

Ok(secret)
}
}

/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
///
Expand Down Expand Up @@ -214,14 +262,24 @@ async fn auth_quirks(
Some(secret) => secret,
None => api.get_role_secret(ctx, &info).await?,
};
let (cached_entry, secret) = cached_secret.take_value();

let secret = match secret {
Some(secret) => config.check_rate_limit(
ctx,
secret,
&info.endpoint,
unauthenticated_password.is_some() || allow_cleartext,
)?,
None => {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthSecret::Scram(scram::ServerSecret::mock(rand::random()))
}
};

let secret = cached_secret.value.clone().unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthSecret::Scram(scram::ServerSecret::mock(&info.user, rand::random()))
});
match authenticate_with_secret(
ctx,
secret,
Expand All @@ -237,7 +295,7 @@ async fn auth_quirks(
Err(e) => {
if e.is_auth_failed() {
// The password could have been changed, so we invalidate the cache.
cached_secret.invalidate();
cached_entry.invalidate();
}
Err(e)
}
Expand Down Expand Up @@ -415,6 +473,7 @@ mod tests {

use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use once_cell::sync::Lazy;
use postgres_protocol::{
authentication::sasl::{ChannelBinding, ScramSha256},
message::{backend::Message as PgMessage, frontend},
Expand All @@ -432,6 +491,7 @@ mod tests {
},
context::RequestMonitoring,
proxy::NeonOptions,
rate_limiter::{AuthRateLimiter, RateBucketInfo},
scram::ServerSecret,
stream::{PqStream, Stream},
};
Expand Down Expand Up @@ -473,9 +533,11 @@ mod tests {
}
}

static CONFIG: &AuthenticationConfig = &AuthenticationConfig {
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
scram_protocol_timeout: std::time::Duration::from_secs(5),
};
rate_limiter_enabled: true,
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
});

async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
loop {
Expand Down Expand Up @@ -544,7 +606,7 @@ mod tests {
}
});

let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, false, CONFIG)
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, false, &CONFIG)
.await
.unwrap();

Expand Down Expand Up @@ -584,7 +646,7 @@ mod tests {
client.write_all(&write).await.unwrap();
});

let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG)
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, &CONFIG)
.await
.unwrap();

Expand Down Expand Up @@ -624,7 +686,7 @@ mod tests {
client.write_all(&write).await.unwrap();
});

let creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG)
let creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, &CONFIG)
.await
.unwrap();

Expand Down
13 changes: 11 additions & 2 deletions proxy/src/bin/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use proxy::console;
use proxy::context::parquet::ParquetUploadArgs;
use proxy::http;
use proxy::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT;
use proxy::rate_limiter::AuthRateLimiter;
use proxy::rate_limiter::EndpointRateLimiter;
use proxy::rate_limiter::RateBucketInfo;
use proxy::rate_limiter::RateLimiterConfig;
Expand Down Expand Up @@ -141,10 +142,16 @@ struct ProxyCliArgs {
///
/// Provided in the form '<Requests Per Second>@<Bucket Duration Size>'.
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
endpoint_rps_limit: Vec<RateBucketInfo>,
/// Whether the auth rate limiter actually takes effect (for testing)
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
auth_rate_limit_enabled: bool,
/// Authentication rate limiter max number of hashes per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
auth_rate_limit: Vec<RateBucketInfo>,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
#[clap(long, default_value_t = 100)]
Expand Down Expand Up @@ -510,6 +517,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
};
let authentication_config = AuthenticationConfig {
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
};

let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
Expand Down
10 changes: 10 additions & 0 deletions proxy/src/cache/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ impl<C: Cache, V> Cached<C, V> {
Self { token: None, value }
}

pub fn take_value(self) -> (Cached<C, ()>, V) {
(
Cached {
token: self.token,
value: (),
},
self.value,
)
}

/// Drop this entry from a cache if it's still there.
pub fn invalidate(self) -> V {
if let Some((cache, info)) = &self.token {
Expand Down
30 changes: 6 additions & 24 deletions proxy/src/cache/project_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,7 @@ mod tests {
let endpoint_id = "endpoint".into();
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
user1.as_str(),
[1; 32],
)));
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
let secret2 = None;
let allowed_ips = Arc::new(vec![
"127.0.0.1".parse().unwrap(),
Expand All @@ -395,10 +392,7 @@ mod tests {

// Shouldn't add more than 2 roles.
let user3: RoleName = "user3".into();
let secret3 = Some(AuthSecret::Scram(ServerSecret::mock(
user3.as_str(),
[3; 32],
)));
let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
cache.insert_role_secret(&project_id, &endpoint_id, &user3, secret3.clone());
assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());

Expand Down Expand Up @@ -431,14 +425,8 @@ mod tests {
let endpoint_id = "endpoint".into();
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
user1.as_str(),
[1; 32],
)));
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock(
user2.as_str(),
[2; 32],
)));
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
let allowed_ips = Arc::new(vec![
"127.0.0.1".parse().unwrap(),
"127.0.0.2".parse().unwrap(),
Expand Down Expand Up @@ -486,14 +474,8 @@ mod tests {
let endpoint_id = "endpoint".into();
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
user1.as_str(),
[1; 32],
)));
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock(
user2.as_str(),
[2; 32],
)));
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
let allowed_ips = Arc::new(vec![
"127.0.0.1".parse().unwrap(),
"127.0.0.2".parse().unwrap(),
Expand Down
8 changes: 7 additions & 1 deletion proxy/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use crate::{auth, rate_limiter::RateBucketInfo, serverless::GlobalConnPoolOptions};
use crate::{
auth,
rate_limiter::{AuthRateLimiter, RateBucketInfo},
serverless::GlobalConnPoolOptions,
};
use anyhow::{bail, ensure, Context, Ok};
use itertools::Itertools;
use rustls::{
Expand Down Expand Up @@ -50,6 +54,8 @@ pub struct HttpConfig {

pub struct AuthenticationConfig {
pub scram_protocol_timeout: tokio::time::Duration,
pub rate_limiter_enabled: bool,
pub rate_limiter: AuthRateLimiter,
}

impl TlsConfig {
Expand Down
22 changes: 21 additions & 1 deletion proxy/src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use ::metrics::{
register_int_gauge_vec, Histogram, HistogramVec, HyperLogLogVec, IntCounterPairVec,
IntCounterVec, IntGauge, IntGaugeVec,
};
use metrics::{register_int_counter, register_int_counter_pair, IntCounter, IntCounterPair};
use metrics::{
register_hll, register_int_counter, register_int_counter_pair, HyperLogLog, IntCounter,
IntCounterPair,
};

use once_cell::sync::Lazy;
use tokio::time::{self, Instant};
Expand Down Expand Up @@ -358,3 +361,20 @@ pub static TLS_HANDSHAKE_FAILURES: Lazy<IntCounter> = Lazy::new(|| {
)
.unwrap()
});

pub static ENDPOINTS_AUTH_RATE_LIMITED: Lazy<HyperLogLog<32>> = Lazy::new(|| {
register_hll!(
32,
"proxy_endpoints_auth_rate_limits",
"Number of endpoints affected by authentication rate limits",
)
.unwrap()
});

pub static AUTH_RATE_LIMIT_HITS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"proxy_requests_auth_rate_limits_total",
"Number of connection requests affected by authentication rate limits",
)
.unwrap()
});
2 changes: 1 addition & 1 deletion proxy/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(

// check rate limit
if let Some(ep) = user_info.get_endpoint() {
if !endpoint_rate_limiter.check(ep) {
if !endpoint_rate_limiter.check(ep, 1) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await?;
Expand Down
10 changes: 3 additions & 7 deletions proxy/src/proxy/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ impl Scram {
Ok(Scram(secret))
}

fn mock(user: &str) -> Self {
Scram(scram::ServerSecret::mock(user, rand::random()))
fn mock() -> Self {
Scram(scram::ServerSecret::mock(rand::random()))
}
}

Expand Down Expand Up @@ -330,11 +330,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> {

let (client_config, server_config) =
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(
client,
Some(server_config),
Scram::mock("user"),
));
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock()));

use rand::{distributions::Alphanumeric, Rng};
let password: String = rand::thread_rng()
Expand Down
2 changes: 1 addition & 1 deletion proxy/src/rate_limiter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ mod limiter;
pub use aimd::Aimd;
pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
pub use limiter::Limiter;
pub use limiter::{EndpointRateLimiter, RateBucketInfo, RedisRateLimiter};
pub use limiter::{AuthRateLimiter, EndpointRateLimiter, RateBucketInfo, RedisRateLimiter};
Loading

1 comment on commit 12512f3

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2810 tests run: 2657 passed, 0 failed, 153 skipped (full report)


Flaky tests (1)

Postgres 15

  • test_deletion_queue_recovery[no-validate-lose]: debug

Code coverage* (full report)

  • functions: 28.2% (6306 of 22367 functions)
  • lines: 47.0% (44287 of 94291 lines)

* collected from Rust tests only


The comment gets automatically updated with the latest test results
12512f3 at 2024-03-26T20:38:19.060Z :recycle:

Please sign in to comment.