Skip to content

Commit

Permalink
proxy: update tokio-postgres to allow arbitrary config params (#8076)
Browse files Browse the repository at this point in the history
## Problem

Fixes #1287

## Summary of changes

tokio-postgres now supports arbitrary server params through the
`param(key, value)` method. Some keys are special so we explicitly
filter them out.
  • Loading branch information
conradludgate committed Jun 24, 2024
1 parent 75747cd commit 78d9059
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 92 deletions.
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 23 additions & 27 deletions libs/postgres_connection/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,20 +144,7 @@ impl PgConnectionConfig {
// implement and this function is hardly a bottleneck. The function is only called around
// establishing a new connection.
#[allow(unstable_name_collisions)]
config.options(
&self
.options
.iter()
.map(|s| {
if s.contains(['\\', ' ']) {
Cow::Owned(s.replace('\\', "\\\\").replace(' ', "\\ "))
} else {
Cow::Borrowed(s.as_str())
}
})
.intersperse(Cow::Borrowed(" ")) // TODO: use impl from std once it's stabilized
.collect::<String>(),
);
config.options(&encode_options(&self.options));
}
config
}
Expand All @@ -178,6 +165,21 @@ impl PgConnectionConfig {
}
}

#[allow(unstable_name_collisions)]
fn encode_options(options: &[String]) -> String {
options
.iter()
.map(|s| {
if s.contains(['\\', ' ']) {
Cow::Owned(s.replace('\\', "\\\\").replace(' ', "\\ "))
} else {
Cow::Borrowed(s.as_str())
}
})
.intersperse(Cow::Borrowed(" ")) // TODO: use impl from std once it's stabilized
.collect::<String>()
}

impl fmt::Display for PgConnectionConfig {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// The password is intentionally hidden and not part of this display string.
Expand Down Expand Up @@ -206,7 +208,7 @@ impl fmt::Debug for PgConnectionConfig {

#[cfg(test)]
mod tests_pg_connection_config {
use crate::PgConnectionConfig;
use crate::{encode_options, PgConnectionConfig};
use once_cell::sync::Lazy;
use url::Host;

Expand Down Expand Up @@ -255,18 +257,12 @@ mod tests_pg_connection_config {

#[test]
fn test_with_options() {
let cfg = PgConnectionConfig::new_host_port(STUB_HOST.clone(), 123).extend_options([
"hello",
"world",
"with space",
"and \\ backslashes",
let options = encode_options(&[
"hello".to_owned(),
"world".to_owned(),
"with space".to_owned(),
"and \\ backslashes".to_owned(),
]);
assert_eq!(cfg.host(), &*STUB_HOST);
assert_eq!(cfg.port(), 123);
assert_eq!(cfg.raw_address(), "stub.host.example:123");
assert_eq!(
cfg.to_tokio_postgres_config().get_options(),
Some("hello world with\\ space and\\ \\\\\\ backslashes")
);
assert_eq!(options, "hello world with\\ space and\\ \\\\\\ backslashes");
}
}
129 changes: 68 additions & 61 deletions proxy/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,8 @@ impl ConnCfg {

/// Reuse password or auth keys from the other config.
pub fn reuse_password(&mut self, other: Self) {
if let Some(password) = other.get_password() {
self.password(password);
}

if let Some(keys) = other.get_auth_keys() {
self.auth_keys(keys);
if let Some(password) = other.get_auth() {
self.auth(password);
}
}

Expand All @@ -124,48 +120,64 @@ impl ConnCfg {

/// Apply startup message params to the connection config.
pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
// Only set `user` if it's not present in the config.
// Link auth flow takes username from the console's response.
if let (None, Some(user)) = (self.get_user(), params.get("user")) {
self.user(user);
}

// Only set `dbname` if it's not present in the config.
// Link auth flow takes dbname from the console's response.
if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
self.dbname(dbname);
}

// Don't add `options` if they were only used for specifying a project.
// Connection pools don't support `options`, because they affect backend startup.
if let Some(options) = filtered_options(params) {
self.options(&options);
}

if let Some(app_name) = params.get("application_name") {
self.application_name(app_name);
}

// TODO: This is especially ugly...
if let Some(replication) = params.get("replication") {
use tokio_postgres::config::ReplicationMode;
match replication {
"true" | "on" | "yes" | "1" => {
self.replication_mode(ReplicationMode::Physical);
let mut client_encoding = false;
for (k, v) in params.iter() {
match k {
"user" => {
// Only set `user` if it's not present in the config.
// Link auth flow takes username from the console's response.
if self.get_user().is_none() {
self.user(v);
}
}
"database" => {
self.replication_mode(ReplicationMode::Logical);
// Only set `dbname` if it's not present in the config.
// Link auth flow takes dbname from the console's response.
if self.get_dbname().is_none() {
self.dbname(v);
}
}
"options" => {
// Don't add `options` if they were only used for specifying a project.
// Connection pools don't support `options`, because they affect backend startup.
if let Some(options) = filtered_options(v) {
self.options(&options);
}
}

// the special ones in tokio-postgres that we don't want being set by the user
"dbname" => {}
"password" => {}
"sslmode" => {}
"host" => {}
"port" => {}
"connect_timeout" => {}
"keepalives" => {}
"keepalives_idle" => {}
"keepalives_interval" => {}
"keepalives_retries" => {}
"target_session_attrs" => {}
"channel_binding" => {}
"max_backend_message_size" => {}

"client_encoding" => {
client_encoding = true;
// only error should be from bad null bytes,
// but we've already checked for those.
_ = self.param("client_encoding", v);
}

_ => {
// only error should be from bad null bytes,
// but we've already checked for those.
_ = self.param(k, v);
}
_other => {}
}
}

// TODO: extend the list of the forwarded startup parameters.
// Currently, tokio-postgres doesn't allow us to pass
// arbitrary parameters, but the ones above are a good start.
//
// This and the reverse params problem can be better addressed
// in a bespoke connection machinery (a new library for that sake).
if !client_encoding {
// for compatibility since we removed it from tokio-postgres
self.param("client_encoding", "UTF8").unwrap();
}
}
}

Expand Down Expand Up @@ -338,10 +350,9 @@ impl ConnCfg {
}

/// Retrieve `options` from a startup message, dropping all proxy-secific flags.
fn filtered_options(params: &StartupMessageParams) -> Option<String> {
fn filtered_options(options: &str) -> Option<String> {
#[allow(unstable_name_collisions)]
let options: String = params
.options_raw()?
let options: String = StartupMessageParams::parse_options_raw(options)
.filter(|opt| parse_endpoint_param(opt).is_none() && neon_option(opt).is_none())
.intersperse(" ") // TODO: use impl from std once it's stabilized
.collect();
Expand Down Expand Up @@ -413,27 +424,23 @@ mod tests {
#[test]
fn test_filtered_options() {
// Empty options is unlikely to be useful anyway.
let params = StartupMessageParams::new([("options", "")]);
assert_eq!(filtered_options(&params), None);
assert_eq!(filtered_options(""), None);

// It's likely that clients will only use options to specify endpoint/project.
let params = StartupMessageParams::new([("options", "project=foo")]);
assert_eq!(filtered_options(&params), None);
let params = "project=foo";
assert_eq!(filtered_options(params), None);

// Same, because unescaped whitespaces are no-op.
let params = StartupMessageParams::new([("options", " project=foo ")]);
assert_eq!(filtered_options(&params).as_deref(), None);
let params = " project=foo ";
assert_eq!(filtered_options(params), None);

let params = StartupMessageParams::new([("options", r"\ project=foo \ ")]);
assert_eq!(filtered_options(&params).as_deref(), Some(r"\ \ "));
let params = r"\ project=foo \ ";
assert_eq!(filtered_options(params).as_deref(), Some(r"\ \ "));

let params = StartupMessageParams::new([("options", "project = foo")]);
assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
let params = "project = foo";
assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));

let params = StartupMessageParams::new([(
"options",
"project = foo neon_endpoint_type:read_write neon_lsn:0/2",
)]);
assert_eq!(filtered_options(&params).as_deref(), Some("project = foo"));
let params = "project = foo neon_endpoint_type:read_write neon_lsn:0/2";
assert_eq!(filtered_options(params).as_deref(), Some("project = foo"));
}
}
4 changes: 4 additions & 0 deletions proxy/src/serverless/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ impl ConnectMechanism for TokioMechanism {
.dbname(&self.conn_info.dbname)
.connect_timeout(timeout);

config
.param("client_encoding", "UTF8")
.expect("client encoding UTF8 is always valid");

let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
let res = config.connect(tokio_postgres::NoTls).await;
drop(pause);
Expand Down
1 change: 1 addition & 0 deletions proxy/src/serverless/sql_over_http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ fn get_conn_info(
options = Some(NeonOptions::parse_options_raw(&value));
}
}
ctx.set_db_options(params.freeze());

let user_info = ComputeUserInfo {
endpoint,
Expand Down
19 changes: 19 additions & 0 deletions test_runner/regress/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@ def test_proxy_select_1(static_proxy: NeonProxy):
assert out[0][0] == 42


def test_proxy_server_params(static_proxy: NeonProxy):
"""
Test that server params are passing through to postgres
"""

out = static_proxy.safe_psql(
"select to_json('0 seconds'::interval)", options="-c intervalstyle=iso_8601"
)
assert out[0][0] == "PT0S"
out = static_proxy.safe_psql(
"select to_json('0 seconds'::interval)", options="-c intervalstyle=sql_standard"
)
assert out[0][0] == "0"
out = static_proxy.safe_psql(
"select to_json('0 seconds'::interval)", options="-c intervalstyle=postgres"
)
assert out[0][0] == "00:00:00"


def test_password_hack(static_proxy: NeonProxy):
"""
Check the PasswordHack auth flow: an alternative to SCRAM auth for
Expand Down

1 comment on commit 78d9059

@github-actions
Copy link

@github-actions github-actions bot commented on 78d9059 Jun 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3035 tests run: 2909 passed, 0 failed, 126 skipped (full report)


Flaky tests (2)

Postgres 16

  • test_isolation[None]: debug

Postgres 14

  • test_storage_controller_many_tenants[github-actions-selfhosted]: release

Code coverage* (full report)

  • functions: 32.4% (6869 of 21174 functions)
  • lines: 49.9% (53391 of 107090 lines)

* collected from Rust tests only


The comment gets automatically updated with the latest test results
78d9059 at 2024-06-24T12:40:25.452Z :recycle:

Please sign in to comment.