Skip to content

Commit

Permalink
Forward various connection params to compute nodes. (#2336)
Browse files Browse the repository at this point in the history
Previously, proxy didn't forward auxiliary `options` parameter
and other ones to the client's compute node, e.g.

```
$ psql "user=john host=localhost dbname=postgres options='-cgeqo=off'"
postgres=# show geqo;
┌──────┐
│ geqo │
├──────┤
│ on   │
└──────┘
(1 row)
```

With this patch we now forward `options`, `application_name` and `replication`.

Further reading: https://www.postgresql.org/docs/current/libpq-connect.html

Fixes #1287.
  • Loading branch information
funbringer committed Aug 30, 2022
1 parent 60408db commit 96a50e9
Show file tree
Hide file tree
Showing 13 changed files with 272 additions and 128 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

159 changes: 122 additions & 37 deletions libs/utils/src/pq_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ use anyhow::{bail, ensure, Context, Result};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use postgres_protocol::PG_EPOCH;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::future::Future;
use std::io::{self, Cursor};
use std::str;
use std::time::{Duration, SystemTime};
use std::{
borrow::Cow,
collections::HashMap,
future::Future,
io::{self, Cursor},
str,
time::{Duration, SystemTime},
};
use tokio::io::AsyncReadExt;
use tracing::{trace, warn};

Expand Down Expand Up @@ -53,7 +56,67 @@ pub enum FeStartupPacket {
},
}

pub type StartupMessageParams = HashMap<String, String>;
#[derive(Debug)]
pub struct StartupMessageParams {
params: HashMap<String, String>,
}

impl StartupMessageParams {
/// Get parameter's value by its name.
pub fn get(&self, name: &str) -> Option<&str> {
self.params.get(name).map(|s| s.as_str())
}

/// Split command-line options according to PostgreSQL's logic,
/// taking into account all escape sequences but leaving them as-is.
/// [`None`] means that there's no `options` in [`Self`].
pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
// See `postgres: pg_split_opts`.
let mut last_was_escape = false;
let iter = self
.get("options")?
.split(move |c: char| {
// We split by non-escaped whitespace symbols.
let should_split = c.is_ascii_whitespace() && !last_was_escape;
last_was_escape = c == '\\' && !last_was_escape;
should_split
})
.filter(|s| !s.is_empty());

Some(iter)
}

/// Split command-line options according to PostgreSQL's logic,
/// applying all escape sequences (using owned strings as needed).
/// [`None`] means that there's no `options` in [`Self`].
pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
// See `postgres: pg_split_opts`.
let iter = self.options_raw()?.map(|s| {
let mut preserve_next_escape = false;
let escape = |c| {
// We should remove '\\' unless it's preceded by '\\'.
let should_remove = c == '\\' && !preserve_next_escape;
preserve_next_escape = should_remove;
should_remove
};

match s.contains('\\') {
true => Cow::Owned(s.replace(escape, "")),
false => Cow::Borrowed(s),
}
});

Some(iter)
}

// This function is mostly useful in tests.
#[doc(hidden)]
pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
Self {
params: pairs.map(|(k, v)| (k.to_owned(), v.to_owned())).into(),
}
}
}

#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct CancelKeyData {
Expand Down Expand Up @@ -237,9 +300,9 @@ impl FeStartupPacket {
stream.read_exact(params_bytes.as_mut()).await?;

// Parse params depending on request code
let most_sig_16_bits = request_code >> 16;
let least_sig_16_bits = request_code & ((1 << 16) - 1);
let message = match (most_sig_16_bits, least_sig_16_bits) {
let req_hi = request_code >> 16;
let req_lo = request_code & ((1 << 16) - 1);
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
ensure!(params_len == 8, "expected 8 bytes for CancelRequest params");
let mut cursor = Cursor::new(params_bytes);
Expand All @@ -248,49 +311,44 @@ impl FeStartupPacket {
cancel_key: cursor.read_i32().await?,
})
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => FeStartupPacket::SslRequest,
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
// Requested upgrade to SSL (aka TLS)
FeStartupPacket::SslRequest
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
// Requested upgrade to GSSAPI
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
bail!("Unrecognized request code {}", unrecognized_code)
}
// TODO bail if protocol major_version is not 3?
(major_version, minor_version) => {
// TODO bail if protocol major_version is not 3?
// Parse null-terminated (String) pairs of param name / param value
let params_str = str::from_utf8(&params_bytes).unwrap();
let mut params_tokens = params_str.split('\0');
let mut params: HashMap<String, String> = HashMap::new();
while let Some(name) = params_tokens.next() {
let value = params_tokens
// Parse pairs of null-terminated strings (key, value).
// See `postgres: ProcessStartupPacket, build_startup_packet`.
let mut tokens = str::from_utf8(&params_bytes)
.context("StartupMessage params: invalid utf-8")?
.strip_suffix('\0') // drop packet's own null terminator
.context("StartupMessage params: missing null terminator")?
.split_terminator('\0');

let mut params = HashMap::new();
while let Some(name) = tokens.next() {
let value = tokens
.next()
.context("expected even number of params in StartupMessage")?;
if name == "options" {
// parsing options arguments "...&options=<var0>%3D<val0>+<var1>=<var1>..."
// '%3D' is '=' and '+' is ' '

// Note: we allow users that don't have SNI capabilities,
// to pass a special keyword argument 'project'
// to be used to determine the cluster name by the proxy.

//TODO: write unit test for this and refactor in its own function.
for cmdopt in value.split(' ') {
let nameval: Vec<&str> = cmdopt.split('=').collect();
if nameval.len() == 2 {
params.insert(nameval[0].to_string(), nameval[1].to_string());
}
}
} else {
params.insert(name.to_string(), value.to_string());
}
.context("StartupMessage params: key without value")?;

params.insert(name.to_owned(), value.to_owned());
}

FeStartupPacket::StartupMessage {
major_version,
minor_version,
params,
params: StartupMessageParams { params },
}
}
};

Ok(Some(FeMessage::StartupPacket(message)))
})
}
Expand Down Expand Up @@ -967,6 +1025,33 @@ mod tests {
assert_eq!(zf, zf_parsed);
}

#[test]
fn test_startup_message_params_options_escaped() {
fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
params
.options_escaped()
.expect("options are None")
.collect()
}

let make_params = |options| StartupMessageParams::new([("options", options)]);

let params = StartupMessageParams::new([]);
assert!(matches!(params.options_escaped(), None));

let params = make_params("");
assert!(split_options(&params).is_empty());

let params = make_params("foo");
assert_eq!(split_options(&params), ["foo"]);

let params = make_params(" foo bar ");
assert_eq!(split_options(&params), ["foo", "bar"]);

let params = make_params("foo\\ bar \\ \\\\ baz\\ lol");
assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
}

// Make sure that `read` is sync/async callable
async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) {
let _ = FeMessage::read(&mut [].as_ref());
Expand Down
1 change: 1 addition & 0 deletions proxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ hashbrown = "0.12"
hex = "0.4.3"
hmac = "0.12.1"
hyper = "0.14"
itertools = "0.10.3"
once_cell = "1.13.0"
md5 = "0.7.0"
parking_lot = "0.12"
Expand Down
4 changes: 2 additions & 2 deletions proxy/src/auth/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl<T, E> BackendType<Result<T, E>> {
}
}

impl BackendType<ClientCredentials> {
impl BackendType<ClientCredentials<'_>> {
/// Authenticate the client via the requested backend, possibly using credentials.
pub async fn authenticate(
mut self,
Expand All @@ -149,7 +149,7 @@ impl BackendType<ClientCredentials> {

// Finally we may finish the initialization of `creds`.
// TODO: add missing type safety to ClientCredentials.
creds.project = Some(payload.project);
creds.project = Some(payload.project.into());

let mut config = match &self {
Console(creds) => {
Expand Down
8 changes: 4 additions & 4 deletions proxy/src/auth/backend/console.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ pub enum AuthInfo {
#[must_use]
pub(super) struct Api<'a> {
endpoint: &'a ApiUrl,
creds: &'a ClientCredentials,
creds: &'a ClientCredentials<'a>,
}

impl<'a> Api<'a> {
Expand All @@ -143,7 +143,7 @@ impl<'a> Api<'a> {
url.path_segments_mut().push("proxy_get_role_secret");
url.query_pairs_mut()
.append_pair("project", self.creds.project().expect("impossible"))
.append_pair("role", &self.creds.user);
.append_pair("role", self.creds.user);

// TODO: use a proper logger
println!("cplane request: {url}");
Expand Down Expand Up @@ -187,8 +187,8 @@ impl<'a> Api<'a> {
config
.host(host)
.port(port)
.dbname(&self.creds.dbname)
.user(&self.creds.user);
.dbname(self.creds.dbname)
.user(self.creds.user);

Ok(config)
}
Expand Down
12 changes: 6 additions & 6 deletions proxy/src/auth/backend/legacy_console.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,23 @@ enum ProxyAuthResponse {
NotReady { ready: bool }, // TODO: get rid of `ready`
}

impl ClientCredentials {
impl ClientCredentials<'_> {
fn is_existing_user(&self) -> bool {
self.user.ends_with("@zenith")
}
}

async fn authenticate_proxy_client(
auth_endpoint: &reqwest::Url,
creds: &ClientCredentials,
creds: &ClientCredentials<'_>,
md5_response: &str,
salt: &[u8; 4],
psql_session_id: &str,
) -> Result<DatabaseInfo, LegacyAuthError> {
let mut url = auth_endpoint.clone();
url.query_pairs_mut()
.append_pair("login", &creds.user)
.append_pair("database", &creds.dbname)
.append_pair("login", creds.user)
.append_pair("database", creds.dbname)
.append_pair("md5response", md5_response)
.append_pair("salt", &hex::encode(salt))
.append_pair("psql_session_id", psql_session_id);
Expand Down Expand Up @@ -103,7 +103,7 @@ async fn authenticate_proxy_client(
async fn handle_existing_user(
auth_endpoint: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
creds: &ClientCredentials,
creds: &ClientCredentials<'_>,
) -> auth::Result<compute::NodeInfo> {
let psql_session_id = super::link::new_psql_session_id();
let md5_salt = rand::random();
Expand Down Expand Up @@ -136,7 +136,7 @@ async fn handle_existing_user(
pub async fn handle_user(
auth_endpoint: &reqwest::Url,
auth_link_uri: &reqwest::Url,
creds: &ClientCredentials,
creds: &ClientCredentials<'_>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> auth::Result<compute::NodeInfo> {
if creds.is_existing_user() {
Expand Down
6 changes: 3 additions & 3 deletions proxy/src/auth/backend/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
#[must_use]
pub(super) struct Api<'a> {
endpoint: &'a ApiUrl,
creds: &'a ClientCredentials,
creds: &'a ClientCredentials<'a>,
}

// Helps eliminate graceless `.map_err` calls without introducing another ctor.
Expand Down Expand Up @@ -87,8 +87,8 @@ impl<'a> Api<'a> {
config
.host(self.endpoint.host_str().unwrap_or("localhost"))
.port(self.endpoint.port().unwrap_or(5432))
.dbname(&self.creds.dbname)
.user(&self.creds.user);
.dbname(self.creds.dbname)
.user(self.creds.user);

Ok(config)
}
Expand Down
Loading

0 comments on commit 96a50e9

Please sign in to comment.