Skip to content

Commit

Permalink
add protocol negotiation
Browse files Browse the repository at this point in the history
  • Loading branch information
conradludgate committed Jul 9, 2024
1 parent cf0e3f9 commit b2db42b
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 21 deletions.
77 changes: 59 additions & 18 deletions libs/pq_proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,37 @@ pub enum FeMessage {
PasswordMessage(Bytes),
}

#[derive(Clone, Copy, PartialEq)]
pub struct ProtocolVersion(u32);

impl ProtocolVersion {
pub const fn new(major: u16, minor: u16) -> Self {
Self((major as u32) << 16 | minor as u32)
}
pub const fn minor(self) -> u16 {
self.0 as u16
}
pub const fn major(self) -> u16 {
(self.0 >> 16) as u16
}
}

impl fmt::Debug for ProtocolVersion {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list()
.entry(&self.major())
.entry(&self.minor())
.finish()
}
}

#[derive(Debug)]
pub enum FeStartupPacket {
CancelRequest(CancelKeyData),
SslRequest,
GssEncRequest,
StartupMessage {
major_version: u32,
minor_version: u32,
version: ProtocolVersion,
params: StartupMessageParams,
},
}
Expand Down Expand Up @@ -301,11 +324,15 @@ impl FeStartupPacket {
/// different from [`FeMessage::parse`] because startup messages don't have
/// message type byte; otherwise, its comments apply.
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
const CANCEL_REQUEST_CODE: u32 = 5678;
const NEGOTIATE_SSL_CODE: u32 = 5679;
const NEGOTIATE_GSS_CODE: u32 = 5680;
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);

// <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
// First byte indicates standard SSL handshake message
Expand Down Expand Up @@ -346,12 +373,10 @@ impl FeStartupPacket {
let mut msg = buf.split_to(len).freeze();
msg.advance(4); // consume len

let request_code = msg.get_u32();
let req_hi = request_code >> 16;
let req_lo = request_code & ((1 << 16) - 1);
let request_code = ProtocolVersion(msg.get_u32());
// StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
let message = match request_code {
CANCEL_REQUEST_CODE => {
if msg.remaining() != 8 {
return Err(ProtocolError::BadMessage(
"CancelRequest message is malformed, backend PID / secret key missing"
Expand All @@ -363,21 +388,22 @@ impl FeStartupPacket {
cancel_key: msg.get_i32(),
})
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
NEGOTIATE_SSL_CODE => {
// Requested upgrade to SSL (aka TLS)
FeStartupPacket::SslRequest
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
NEGOTIATE_GSS_CODE => {
// Requested upgrade to GSSAPI
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
version if version.major() == RESERVED_INVALID_MAJOR_VERSION => {
return Err(ProtocolError::Protocol(format!(
"Unrecognized request code {unrecognized_code}"
"Unrecognized request code {}",
version.minor()
)));
}
// TODO bail if protocol major_version is not 3?
(major_version, minor_version) => {
version => {
// StartupMessage

let s = str::from_utf8(&msg).map_err(|_e| {
Expand All @@ -390,8 +416,7 @@ impl FeStartupPacket {
})?;

FeStartupPacket::StartupMessage {
major_version,
minor_version,
version,
params: StartupMessageParams {
params: msg.slice_ref(s.as_bytes()),
},
Expand Down Expand Up @@ -530,6 +555,10 @@ pub enum BeMessage<'a> {
RowDescription(&'a [RowDescriptor<'a>]),
XLogData(XLogDataBody<'a>),
NoticeResponse(&'a str),
NegotiateProtocolVersion {
version: ProtocolVersion,
options: &'a [&'a str],
},
KeepAlive(WalSndKeepAlive),
}

Expand Down Expand Up @@ -953,6 +982,18 @@ impl<'a> BeMessage<'a> {
buf.put_u8(u8::from(req.request_reply));
});
}

BeMessage::NegotiateProtocolVersion { version, options } => {
buf.put_u8(b'v');
write_body(buf, |buf| {
buf.put_u32(version.0);
buf.put_u32(options.len() as u32);
for option in options.iter() {
write_cstr(option, buf)?;
}
Ok(())
})?
}
}
Ok(())
}
Expand Down
48 changes: 45 additions & 3 deletions proxy/src/proxy/handshake.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use bytes::Buf;
use pq_proto::{
framed::Framed, BeMessage as Be, CancelKeyData, FeStartupPacket, StartupMessageParams,
framed::Framed, BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion,
StartupMessageParams,
};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
Expand Down Expand Up @@ -73,6 +74,8 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);

const SUPPORTED_VERSION: ProtocolVersion = ProtocolVersion::new(3, 0);

let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
Expand Down Expand Up @@ -176,7 +179,10 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
_ => return Err(HandshakeError::ProtocolViolation),
},
StartupMessage { params, .. } => {
StartupMessage {
params,
version: version @ SUPPORTED_VERSION,
} => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
Expand All @@ -185,9 +191,45 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
.await?;
}

info!(session_type = "normal", "successful handshake");
info!(?version, session_type = "normal", "successful handshake");
break Ok(HandshakeData::Startup(stream, params));
}
StartupMessage { params, version } if version.major() == 3 => {
warn!(?version, "unsupported minor version");

// no protocol extensions are supported.
// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/backend/tcop/backend_startup.c#L744-L753>
let mut unsupported = vec![];
for (k, _) in params.iter() {
if k.starts_with("_pq_.") {
unsupported.push(k);
}
}

// TODO: remove unsupported options so we don't send them to compute.

stream
.write_message(&Be::NegotiateProtocolVersion {
version: SUPPORTED_VERSION,
options: &unsupported,
})
.await?;

info!(
?version,
session_type = "normal",
"successful handshake; unsupported minor version requested"
);
break Ok(HandshakeData::Startup(stream, params));
}
StartupMessage { version, .. } => {
warn!(
?version,
session_type = "normal",
"unsuccessful handshake; unsupported major version"
);
return Err(HandshakeError::ProtocolViolation);
}
CancelRequest(cancel_key_data) => {
info!(session_type = "cancellation", "successful handshake");
break Ok(HandshakeData::Cancel(cancel_key_data));
Expand Down

0 comments on commit b2db42b

Please sign in to comment.