diff --git a/libs/postgres_backend/src/lib.rs b/libs/postgres_backend/src/lib.rs index c79ee4e0533a..7c7c6535b338 100644 --- a/libs/postgres_backend/src/lib.rs +++ b/libs/postgres_backend/src/lib.rs @@ -663,11 +663,17 @@ impl PostgresBackend { assert!(self.state < ProtoState::Authentication); let have_tls = self.tls_config.is_some(); match msg { - FeStartupPacket::SslRequest => { + FeStartupPacket::SslRequest { direct } => { debug!("SSL requested"); - self.write_message(&BeMessage::EncryptionResponse(have_tls)) - .await?; + if !direct { + self.write_message(&BeMessage::EncryptionResponse(have_tls)) + .await?; + } else if !have_tls { + return Err(QueryError::Other(anyhow::anyhow!( + "direct SSL negotiation but no TLS support" + ))); + } if have_tls { self.start_tls().await?; diff --git a/libs/pq_proto/src/framed.rs b/libs/pq_proto/src/framed.rs index 6e97b8c2a02c..ccbb90e3842e 100644 --- a/libs/pq_proto/src/framed.rs +++ b/libs/pq_proto/src/framed.rs @@ -44,9 +44,9 @@ impl ConnectionError { /// Wraps async io `stream`, providing messages to write/flush + read Postgres /// messages. pub struct Framed { - stream: S, - read_buf: BytesMut, - write_buf: BytesMut, + pub stream: S, + pub read_buf: BytesMut, + pub write_buf: BytesMut, } impl Framed { diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index cee374201763..a01191bd5de3 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -39,14 +39,39 @@ pub enum FeMessage { PasswordMessage(Bytes), } +#[derive(Clone, Copy, PartialEq, PartialOrd)] +pub struct ProtocolVersion(u32); + +impl ProtocolVersion { + pub const fn new(major: u16, minor: u16) -> Self { + Self((major as u32) << 16 | minor as u32) + } + pub const fn minor(self) -> u16 { + self.0 as u16 + } + pub const fn major(self) -> u16 { + (self.0 >> 16) as u16 + } +} + +impl fmt::Debug for ProtocolVersion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list() + .entry(&self.major()) + .entry(&self.minor()) + .finish() + } +} + #[derive(Debug)] pub enum FeStartupPacket { CancelRequest(CancelKeyData), - SslRequest, + SslRequest { + direct: bool, + }, GssEncRequest, StartupMessage { - major_version: u32, - minor_version: u32, + version: ProtocolVersion, params: StartupMessageParams, }, } @@ -301,11 +326,23 @@ impl FeStartupPacket { /// different from [`FeMessage::parse`] because startup messages don't have /// message type byte; otherwise, its comments apply. pub fn parse(buf: &mut BytesMut) -> Result, ProtocolError> { + /// const MAX_STARTUP_PACKET_LENGTH: usize = 10000; - const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234; - const CANCEL_REQUEST_CODE: u32 = 5678; - const NEGOTIATE_SSL_CODE: u32 = 5679; - const NEGOTIATE_GSS_CODE: u32 = 5680; + const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234; + /// + const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678); + /// + const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679); + /// + const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680); + + // + // First byte indicates standard SSL handshake message + // (It can't be a Postgres startup length because in network byte order + // that would be a startup packet hundreds of megabytes long) + if buf.first() == Some(&0x16) { + return Ok(Some(FeStartupPacket::SslRequest { direct: true })); + } // need at least 4 bytes with packet len if buf.len() < 4 { @@ -338,12 +375,10 @@ impl FeStartupPacket { let mut msg = buf.split_to(len).freeze(); msg.advance(4); // consume len - let request_code = msg.get_u32(); - let req_hi = request_code >> 16; - let req_lo = request_code & ((1 << 16) - 1); + let request_code = ProtocolVersion(msg.get_u32()); // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code. - let message = match (req_hi, req_lo) { - (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => { + let message = match request_code { + CANCEL_REQUEST_CODE => { if msg.remaining() != 8 { return Err(ProtocolError::BadMessage( "CancelRequest message is malformed, backend PID / secret key missing" @@ -355,21 +390,22 @@ impl FeStartupPacket { cancel_key: msg.get_i32(), }) } - (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => { + NEGOTIATE_SSL_CODE => { // Requested upgrade to SSL (aka TLS) - FeStartupPacket::SslRequest + FeStartupPacket::SslRequest { direct: false } } - (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => { + NEGOTIATE_GSS_CODE => { // Requested upgrade to GSSAPI FeStartupPacket::GssEncRequest } - (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => { + version if version.major() == RESERVED_INVALID_MAJOR_VERSION => { return Err(ProtocolError::Protocol(format!( - "Unrecognized request code {unrecognized_code}" + "Unrecognized request code {}", + version.minor() ))); } // TODO bail if protocol major_version is not 3? - (major_version, minor_version) => { + version => { // StartupMessage let s = str::from_utf8(&msg).map_err(|_e| { @@ -382,8 +418,7 @@ impl FeStartupPacket { })?; FeStartupPacket::StartupMessage { - major_version, - minor_version, + version, params: StartupMessageParams { params: msg.slice_ref(s.as_bytes()), }, @@ -522,6 +557,10 @@ pub enum BeMessage<'a> { RowDescription(&'a [RowDescriptor<'a>]), XLogData(XLogDataBody<'a>), NoticeResponse(&'a str), + NegotiateProtocolVersion { + version: ProtocolVersion, + options: &'a [&'a str], + }, KeepAlive(WalSndKeepAlive), } @@ -945,6 +984,18 @@ impl<'a> BeMessage<'a> { buf.put_u8(u8::from(req.request_reply)); }); } + + BeMessage::NegotiateProtocolVersion { version, options } => { + buf.put_u8(b'v'); + write_body(buf, |buf| { + buf.put_u32(version.0); + buf.put_u32(options.len() as u32); + for option in options.iter() { + write_cstr(option, buf)?; + } + Ok(()) + })? + } } Ok(()) } diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 44e880838e07..d7a3eb9a4d18 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -216,10 +216,11 @@ async fn ssl_handshake( use pq_proto::FeStartupPacket::*; match msg { - SslRequest => { + SslRequest { direct: false } => { stream .write_message(&pq_proto::BeMessage::EncryptionResponse(true)) .await?; + // Upgrade raw stream into a secure TLS-backed stream. // NOTE: We've consumed `tls`; this fact will be used later. diff --git a/proxy/src/config.rs b/proxy/src/config.rs index af5511d7ec24..650491976053 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -75,6 +75,9 @@ impl TlsConfig { } } +/// +pub const PG_ALPN_PROTOCOL: &[u8] = b"postgresql"; + /// Configure TLS for the main endpoint. pub fn configure_tls( key_path: &str, @@ -111,16 +114,17 @@ pub fn configure_tls( let cert_resolver = Arc::new(cert_resolver); // allow TLS 1.2 to be compatible with older client libraries - let config = rustls::ServerConfig::builder_with_protocol_versions(&[ + let mut config = rustls::ServerConfig::builder_with_protocol_versions(&[ &rustls::version::TLS13, &rustls::version::TLS12, ]) .with_no_client_auth() - .with_cert_resolver(cert_resolver.clone()) - .into(); + .with_cert_resolver(cert_resolver.clone()); + + config.alpn_protocols = vec![PG_ALPN_PROTOCOL.to_vec()]; Ok(TlsConfig { - config, + config: Arc::new(config), common_names, cert_resolver, }) diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index dd935cc24528..d488aea9275b 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -1,11 +1,17 @@ -use pq_proto::{BeMessage as Be, CancelKeyData, FeStartupPacket, StartupMessageParams}; +use bytes::Buf; +use pq_proto::{ + framed::Framed, BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion, + StartupMessageParams, +}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::info; +use tracing::{info, warn}; use crate::{ - config::TlsConfig, + auth::endpoint_sni, + config::{TlsConfig, PG_ALPN_PROTOCOL}, error::ReportableError, + metrics::Metrics, proxy::ERR_INSECURE_CONNECTION, stream::{PqStream, Stream, StreamUpgradeError}, }; @@ -68,6 +74,9 @@ pub async fn handshake( // Client may try upgrading to each protocol only once let (mut tried_ssl, mut tried_gss) = (false, false); + const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0); + const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0); + let mut stream = PqStream::new(Stream::from_raw(stream)); loop { let msg = stream.read_startup_packet().await?; @@ -75,40 +84,96 @@ pub async fn handshake( use FeStartupPacket::*; match msg { - SslRequest => match stream.get_ref() { + SslRequest { direct } => match stream.get_ref() { Stream::Raw { .. } if !tried_ssl => { tried_ssl = true; // We can't perform TLS handshake without a config - let enc = tls.is_some(); - stream.write_message(&Be::EncryptionResponse(enc)).await?; + let have_tls = tls.is_some(); + if !direct { + stream + .write_message(&Be::EncryptionResponse(have_tls)) + .await?; + } else if !have_tls { + return Err(HandshakeError::ProtocolViolation); + } + if let Some(tls) = tls.take() { // Upgrade raw stream into a secure TLS-backed stream. // NOTE: We've consumed `tls`; this fact will be used later. - let (raw, read_buf) = stream.into_inner(); - // TODO: Normally, client doesn't send any data before - // server says TLS handshake is ok and read_buf is empy. - // However, you could imagine pipelining of postgres - // SSLRequest + TLS ClientHello in one hunk similar to - // pipelining in our node js driver. We should probably - // support that by chaining read_buf with the stream. + let Framed { + stream: raw, + read_buf, + write_buf, + } = stream.framed; + + let Stream::Raw { raw } = raw else { + return Err(HandshakeError::StreamUpgradeError( + StreamUpgradeError::AlreadyTls, + )); + }; + + let mut read_buf = read_buf.reader(); + let mut res = Ok(()); + let accept = tokio_rustls::TlsAcceptor::from(tls.to_server_config()) + .accept_with(raw, |session| { + // push the early data to the tls session + while !read_buf.get_ref().is_empty() { + match session.read_tls(&mut read_buf) { + Ok(_) => {} + Err(e) => { + res = Err(e); + break; + } + } + } + }); + + res?; + + let read_buf = read_buf.into_inner(); if !read_buf.is_empty() { return Err(HandshakeError::EarlyData); } - let tls_stream = raw - .upgrade(tls.to_server_config(), record_handshake_error) - .await?; + + let tls_stream = accept.await.inspect_err(|_| { + if record_handshake_error { + Metrics::get().proxy.tls_handshake_failures.inc() + } + })?; + + let conn_info = tls_stream.get_ref().1; + + // check the ALPN, if exists, as required. + match conn_info.alpn_protocol() { + None | Some(PG_ALPN_PROTOCOL) => {} + Some(other) => { + // try parse ep for better error + let ep = conn_info.server_name().and_then(|sni| { + endpoint_sni(sni, &tls.common_names).ok().flatten() + }); + let alpn = String::from_utf8_lossy(other); + warn!(?ep, %alpn, "unexpected ALPN"); + return Err(HandshakeError::ProtocolViolation); + } + } let (_, tls_server_end_point) = tls .cert_resolver - .resolve(tls_stream.get_ref().1.server_name()) + .resolve(conn_info.server_name()) .ok_or(HandshakeError::MissingCertificate)?; - stream = PqStream::new(Stream::Tls { - tls: Box::new(tls_stream), - tls_server_end_point, - }); + stream = PqStream { + framed: Framed { + stream: Stream::Tls { + tls: Box::new(tls_stream), + tls_server_end_point, + }, + read_buf, + write_buf, + }, + }; } } _ => return Err(HandshakeError::ProtocolViolation), @@ -122,7 +187,9 @@ pub async fn handshake( } _ => return Err(HandshakeError::ProtocolViolation), }, - StartupMessage { params, .. } => { + StartupMessage { params, version } + if PG_PROTOCOL_EARLIEST <= version && version <= PG_PROTOCOL_LATEST => + { // Check that the config has been consumed during upgrade // OR we didn't provide it at all (for dev purposes). if tls.is_some() { @@ -131,9 +198,48 @@ pub async fn handshake( .await?; } - info!(session_type = "normal", "successful handshake"); + info!(?version, session_type = "normal", "successful handshake"); break Ok(HandshakeData::Startup(stream, params)); } + // downgrade protocol version + StartupMessage { params, version } + if version.major() == 3 && version > PG_PROTOCOL_LATEST => + { + warn!(?version, "unsupported minor version"); + + // no protocol extensions are supported. + // + let mut unsupported = vec![]; + for (k, _) in params.iter() { + if k.starts_with("_pq_.") { + unsupported.push(k); + } + } + + // TODO: remove unsupported options so we don't send them to compute. + + stream + .write_message(&Be::NegotiateProtocolVersion { + version: PG_PROTOCOL_LATEST, + options: &unsupported, + }) + .await?; + + info!( + ?version, + session_type = "normal", + "successful handshake; unsupported minor version requested" + ); + break Ok(HandshakeData::Startup(stream, params)); + } + StartupMessage { version, .. } => { + warn!( + ?version, + session_type = "normal", + "unsuccessful handshake; unsupported version" + ); + return Err(HandshakeError::ProtocolViolation); + } CancelRequest(cancel_key_data) => { info!(session_type = "cancellation", "successful handshake"); break Ok(HandshakeData::Cancel(cancel_key_data));