diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 7dc304d7ac7f..d7a8edca7947 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -1,6 +1,8 @@ //! User credentials used in authentication. -use crate::{auth::password_hack::parse_endpoint_param, error::UserFacingError}; +use crate::{ + auth::password_hack::parse_endpoint_param, error::UserFacingError, proxy::neon_options, +}; use itertools::Itertools; use pq_proto::StartupMessageParams; use std::collections::HashSet; @@ -38,6 +40,8 @@ pub struct ClientCredentials<'a> { pub user: &'a str, // TODO: this is a severe misnomer! We should think of a new name ASAP. pub project: Option, + + pub cache_key: String, } impl ClientCredentials<'_> { @@ -53,6 +57,7 @@ impl<'a> ClientCredentials<'a> { ClientCredentials { user: "", project: None, + cache_key: "".to_string(), } } @@ -120,7 +125,17 @@ impl<'a> ClientCredentials<'a> { info!(user, project = project.as_deref(), "credentials"); - Ok(Self { user, project }) + let cache_key = format!( + "{}{}", + project.as_deref().unwrap_or(""), + neon_options(params).unwrap_or("".to_string()) + ); + + Ok(Self { + user, + project, + cache_key, + }) } } @@ -176,6 +191,7 @@ mod tests { let creds = ClientCredentials::parse(&options, sni, common_names)?; assert_eq!(creds.user, "john_doe"); assert_eq!(creds.project.as_deref(), Some("foo")); + assert_eq!(creds.cache_key, "foo"); Ok(()) } @@ -303,4 +319,23 @@ mod tests { _ => panic!("bad error: {err:?}"), } } + + #[test] + fn parse_neon_options() -> anyhow::Result<()> { + let options = StartupMessageParams::new([ + ("user", "john_doe"), + ("options", "neon_lsn:0/2 neon_endpoint_type:read_write"), + ]); + + let sni = Some("project.localhost"); + let common_names = Some(["localhost".into()].into()); + let creds = ClientCredentials::parse(&options, sni, common_names)?; + assert_eq!(creds.project.as_deref(), Some("project")); + assert_eq!( + creds.cache_key, + "projectneon_endpoint_type:read_write neon_lsn:0/2" + ); + + Ok(()) + } } diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index e96b79ed924a..53eb0e3a76a8 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -3,6 +3,7 @@ use crate::{ cancellation::CancelClosure, console::errors::WakeComputeError, error::{io_error, UserFacingError}, + proxy::is_neon_param, }; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; @@ -278,7 +279,7 @@ fn filtered_options(params: &StartupMessageParams) -> Option { #[allow(unstable_name_collisions)] let options: String = params .options_raw()? - .filter(|opt| parse_endpoint_param(opt).is_none()) + .filter(|opt| parse_endpoint_param(opt).is_none() && !is_neon_param(opt)) .intersperse(" ") // TODO: use impl from std once it's stabilized .collect(); @@ -313,5 +314,11 @@ mod tests { let params = StartupMessageParams::new([("options", "project = foo")]); assert_eq!(filtered_options(¶ms).as_deref(), Some("project = foo")); + + let params = StartupMessageParams::new([( + "options", + "project = foo neon_endpoint_type:read_write neon_lsn:0/2", + )]); + assert_eq!(filtered_options(¶ms).as_deref(), Some("project = foo")); } } diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 32c34670923f..c7cfc88c759e 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -178,6 +178,7 @@ pub struct ConsoleReqExtra<'a> { pub session_id: uuid::Uuid, /// Name of client application, if set. pub application_name: Option<&'a str>, + pub options: Option<&'a str>, } /// Auth secret which is managed by the cloud. diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 927fea0a134a..6229840c466c 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -99,6 +99,7 @@ impl Api { .query(&[ ("application_name", extra.application_name), ("project", Some(project)), + ("options", extra.options), ]) .build()?; @@ -151,7 +152,7 @@ impl super::Api for Api { extra: &ConsoleReqExtra<'_>, creds: &ClientCredentials, ) -> Result { - let key = creds.project().expect("impossible"); + let key: &str = &creds.cache_key; // Every time we do a wakeup http request, the compute node will stay up // for some time (highly depends on the console's scale-to-zero policy); diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 884aae165103..54c3503c9305 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -15,10 +15,12 @@ use crate::{ use anyhow::{bail, Context}; use async_trait::async_trait; use futures::TryFutureExt; +use itertools::Itertools; use metrics::{exponential_buckets, register_int_counter_vec, IntCounterVec}; -use once_cell::sync::Lazy; +use once_cell::sync::{Lazy, OnceCell}; use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams}; use prometheus::{register_histogram_vec, HistogramVec}; +use regex::Regex; use std::{error::Error, io, ops::ControlFlow, sync::Arc, time::Instant}; use tokio::{ io::{AsyncRead, AsyncWrite, AsyncWriteExt}, @@ -881,9 +883,12 @@ impl Client<'_, S> { allow_self_signed_compute, } = self; + let console_options = neon_options(params); + let extra = console::ConsoleReqExtra { session_id, // aka this connection's id application_name: params.get("application_name"), + options: console_options.as_deref(), }; let mut latency_timer = LatencyTimer::new(mode.protocol_label()); @@ -945,3 +950,27 @@ impl Client<'_, S> { proxy_pass(stream, node.stream, &aux).await } } + +pub fn neon_options(params: &StartupMessageParams) -> Option { + #[allow(unstable_name_collisions)] + let options: String = params + .options_raw()? + .filter(|opt| is_neon_param(opt)) + .sorted() // we sort it to use as cache key + .intersperse(" ") // TODO: use impl from std once it's stabilized + .collect(); + + // Don't even bother with empty options. + if options.is_empty() { + return None; + } + + Some(options) +} + +pub fn is_neon_param(bytes: &str) -> bool { + static RE: OnceCell = OnceCell::new(); + RE.get_or_init(|| Regex::new(r"^neon_\w+:").unwrap()); + + RE.get().unwrap().is_match(bytes) +} diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 142c32fb84da..3ae4df46ef83 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -440,6 +440,7 @@ fn helper_create_connect_info( let extra = console::ConsoleReqExtra { session_id: uuid::Uuid::new_v4(), application_name: Some("TEST"), + options: None, }; let creds = auth::BackendType::Test(mechanism); (cache, extra, creds) diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index c5bfc325682c..d09554a922b8 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -22,7 +22,10 @@ use tokio_postgres::{AsyncMessage, ReadyForQueryStatus}; use crate::{ auth, console, - proxy::{LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER, NUM_DB_CONNECTIONS_OPENED_COUNTER}, + proxy::{ + neon_options, LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER, + NUM_DB_CONNECTIONS_OPENED_COUNTER, + }, usage_metrics::{Ids, MetricCounter, USAGE_METRICS}, }; use crate::{compute, config}; @@ -41,6 +44,7 @@ pub struct ConnInfo { pub dbname: String, pub hostname: String, pub password: String, + pub options: Option, } impl ConnInfo { @@ -401,26 +405,25 @@ async fn connect_to_compute( let tls = config.tls_config.as_ref(); let common_names = tls.and_then(|tls| tls.common_names.clone()); - let credential_params = StartupMessageParams::new([ + let params = StartupMessageParams::new([ ("user", &conn_info.username), ("database", &conn_info.dbname), ("application_name", APP_NAME), + ("options", conn_info.options.as_deref().unwrap_or("")), ]); let creds = config .auth_backend .as_ref() - .map(|_| { - auth::ClientCredentials::parse( - &credential_params, - Some(&conn_info.hostname), - common_names, - ) - }) + .map(|_| auth::ClientCredentials::parse(¶ms, Some(&conn_info.hostname), common_names)) .transpose()?; + + let console_options = neon_options(¶ms); + let extra = console::ConsoleReqExtra { session_id: uuid::Uuid::new_v4(), application_name: Some(APP_NAME), + options: console_options.as_deref(), }; let node_info = creds diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 8f7cf7fbaf10..16736ac00de6 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -174,11 +174,23 @@ fn get_conn_info( } } + let pairs = connection_url.query_pairs(); + + let mut options = Option::None; + + for (key, value) in pairs { + if key == "options" { + options = Some(value.to_string()); + break; + } + } + Ok(ConnInfo { username: username.to_owned(), dbname: dbname.to_owned(), hostname: hostname.to_owned(), password: password.to_owned(), + options, }) }