diff --git a/client.go b/client.go index 1c5654f6dca..9f91d1e41cc 100644 --- a/client.go +++ b/client.go @@ -22,9 +22,10 @@ type client struct { tlsConf *tls.Config config *Config - connIDGenerator ConnectionIDGenerator - srcConnID protocol.ConnectionID - destConnID protocol.ConnectionID + connIDGenerator ConnectionIDGenerator + maxUDPPayloadSize protocol.ByteCount + srcConnID protocol.ConnectionID + destConnID protocol.ConnectionID initialPacketNumber protocol.PacketNumber hasNegotiatedVersion bool @@ -142,8 +143,9 @@ func dial( config *Config, onClose func(), use0RTT bool, + maxUDPPayloadSize protocol.ByteCount, ) (quicConn, error) { - c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT) + c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT, maxUDPPayloadSize) if err != nil { return nil, err } @@ -162,7 +164,7 @@ func dial( return c.conn, nil } -func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) { +func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool, maxUDPPayloadSize protocol.ByteCount) (*client, error) { srcConnID, err := connIDGenerator.GenerateConnectionID() if err != nil { return nil, err @@ -172,17 +174,18 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config return nil, err } c := &client{ - connIDGenerator: connIDGenerator, - srcConnID: srcConnID, - destConnID: destConnID, - sendConn: sendConn, - use0RTT: use0RTT, - onClose: onClose, - tlsConf: tlsConf, - config: config, - version: config.Versions[0], - handshakeChan: make(chan struct{}), - logger: utils.DefaultLogger.WithPrefix("client"), + connIDGenerator: connIDGenerator, + srcConnID: srcConnID, + destConnID: destConnID, + sendConn: sendConn, + use0RTT: use0RTT, + onClose: onClose, + tlsConf: tlsConf, + config: config, + maxUDPPayloadSize: maxUDPPayloadSize, + version: config.Versions[0], + handshakeChan: make(chan struct{}), + logger: utils.DefaultLogger.WithPrefix("client"), } return c, nil } @@ -205,6 +208,7 @@ func (c *client) dial(ctx context.Context) error { c.tracer, c.logger, c.version, + c.maxUDPPayloadSize, ) c.packetHandlers.Add(c.srcConnID, c.conn) diff --git a/client_test.go b/client_test.go index 98d6e2a84b5..ff4254869e3 100644 --- a/client_test.go +++ b/client_test.go @@ -47,6 +47,7 @@ var _ = Describe("Client", func() { tracer *logging.ConnectionTracer, logger utils.Logger, v protocol.Version, + maxUDPPayloadSize protocol.ByteCount, ) quicConn ) @@ -126,6 +127,7 @@ var _ = Describe("Client", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { Expect(enable0RTT).To(BeFalse()) conn := NewMockQUICConn(mockCtrl) @@ -135,7 +137,7 @@ var _ = Describe("Client", func() { conn.EXPECT().HandshakeComplete().Return(c) return conn } - cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false) + cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false, protocol.DefaultMaxUDPPayloadSize) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -163,6 +165,7 @@ var _ = Describe("Client", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { Expect(enable0RTT).To(BeTrue()) conn := NewMockQUICConn(mockCtrl) @@ -172,7 +175,7 @@ var _ = Describe("Client", func() { return conn } - cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true) + cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true, protocol.DefaultMaxUDPPayloadSize) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -200,6 +203,7 @@ var _ = Describe("Client", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { conn := NewMockQUICConn(mockCtrl) conn.EXPECT().run().Return(testErr) @@ -208,7 +212,7 @@ var _ = Describe("Client", func() { return conn } var closed bool - cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true) + cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true, protocol.DefaultMaxUDPPayloadSize) Expect(err).ToNot(HaveOccurred()) cl.packetHandlers = manager Expect(cl).ToNot(BeNil()) @@ -285,6 +289,7 @@ var _ = Describe("Client", func() { _ *logging.ConnectionTracer, _ utils.Logger, versionP protocol.Version, + _ protocol.ByteCount, ) quicConn { version = versionP conf = configP @@ -328,6 +333,7 @@ var _ = Describe("Client", func() { _ *logging.ConnectionTracer, _ utils.Logger, versionP protocol.Version, + _ protocol.ByteCount, ) quicConn { conn := NewMockQUICConn(mockCtrl) conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) diff --git a/config.go b/config.go index bdc3a1b62ae..d42bdc1c5f4 100644 --- a/config.go +++ b/config.go @@ -45,12 +45,6 @@ func validateConfig(config *Config) error { if config.InitialPacketSize > protocol.MaxPacketBufferSize { config.InitialPacketSize = protocol.MaxPacketBufferSize } - if config.MaxUDPPayloadSize > 0 && config.MaxUDPPayloadSize < protocol.MinInitialPacketSize { - config.MaxUDPPayloadSize = protocol.MinInitialPacketSize - } - if config.MaxUDPPayloadSize > protocol.MaxPacketBufferSize { - config.MaxUDPPayloadSize = protocol.MaxPacketBufferSize - } // check that all QUIC versions are actually supported for _, v := range config.Versions { if !protocol.IsValidVersion(v) { @@ -110,10 +104,6 @@ func populateConfig(config *Config) *Config { if initialPacketSize == 0 { initialPacketSize = protocol.InitialPacketSize } - maxUDPPayloadSize := config.MaxUDPPayloadSize - if maxUDPPayloadSize == 0 { - maxUDPPayloadSize = protocol.DefaultMaxUDPPayloadSize - } return &Config{ GetConfigForClient: config.GetConfigForClient, @@ -131,7 +121,6 @@ func populateConfig(config *Config) *Config { TokenStore: config.TokenStore, EnableDatagrams: config.EnableDatagrams, InitialPacketSize: initialPacketSize, - MaxUDPPayloadSize: maxUDPPayloadSize, DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, Allow0RTT: config.Allow0RTT, Tracer: config.Tracer, diff --git a/config_test.go b/config_test.go index 56d778d5d42..6008a527361 100644 --- a/config_test.go +++ b/config_test.go @@ -68,24 +68,6 @@ var _ = Describe("Config", func() { Expect(validateConfig(conf)).To(Succeed()) Expect(conf.InitialPacketSize).To(BeZero()) }) - - It("increases too small UDP payload sizes", func() { - conf := &Config{MaxUDPPayloadSize: 10} - Expect(validateConfig(conf)).To(Succeed()) - Expect(conf.MaxUDPPayloadSize).To(BeEquivalentTo(1200)) - }) - - It("clips too large UDP payload sizes", func() { - conf := &Config{MaxUDPPayloadSize: protocol.MaxPacketBufferSize + 1} - Expect(validateConfig(conf)).To(Succeed()) - Expect(conf.MaxUDPPayloadSize).To(BeEquivalentTo(protocol.MaxPacketBufferSize)) - }) - - It("doesn't modify the MaxUDPPayloadSize if it is unset", func() { - conf := &Config{MaxUDPPayloadSize: 0} - Expect(validateConfig(conf)).To(Succeed()) - Expect(conf.MaxUDPPayloadSize).To(BeZero()) - }) }) configWithNonZeroNonFunctionFields := func() *Config { @@ -137,8 +119,6 @@ var _ = Describe("Config", func() { f.Set(reflect.ValueOf(true)) case "InitialPacketSize": f.Set(reflect.ValueOf(uint16(1350))) - case "MaxUDPPayloadSize": - f.Set(reflect.ValueOf(uint16(1400))) case "DisablePathMTUDiscovery": f.Set(reflect.ValueOf(true)) case "Allow0RTT": diff --git a/connection.go b/connection.go index 57afcb5a2b2..712159a9fe0 100644 --- a/connection.go +++ b/connection.go @@ -239,6 +239,7 @@ var newConnection = func( tracer *logging.ConnectionTracer, logger utils.Logger, v protocol.Version, + maxUDPPayloadSize protocol.ByteCount, ) quicConn { s := &connection{ ctx: ctx, @@ -298,7 +299,7 @@ var newConnection = func( MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), MaxAckDelay: protocol.MaxAckDelayInclGranularity, AckDelayExponent: protocol.AckDelayExponent, - MaxUDPPayloadSize: protocol.ByteCount(s.config.MaxUDPPayloadSize), + MaxUDPPayloadSize: maxUDPPayloadSize, DisableActiveMigration: true, StatelessResetToken: &statelessResetToken, OriginalDestinationConnectionID: origDestConnID, @@ -354,6 +355,7 @@ var newClientConnection = func( tracer *logging.ConnectionTracer, logger utils.Logger, v protocol.Version, + maxUDPPayloadSize protocol.ByteCount, ) quicConn { s := &connection{ conn: conn, @@ -408,7 +410,7 @@ var newClientConnection = func( MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), MaxAckDelay: protocol.MaxAckDelayInclGranularity, - MaxUDPPayloadSize: protocol.ByteCount(s.config.MaxUDPPayloadSize), + MaxUDPPayloadSize: maxUDPPayloadSize, AckDelayExponent: protocol.AckDelayExponent, DisableActiveMigration: true, // For interoperability with quic-go versions before May 2023, this value must be set to a value diff --git a/connection_test.go b/connection_test.go index 88a9ec85dca..caf6b73cd76 100644 --- a/connection_test.go +++ b/connection_test.go @@ -135,6 +135,7 @@ var _ = Describe("Connection", func() { tr, utils.DefaultLogger, protocol.Version1, + protocol.DefaultMaxUDPPayloadSize, ).(*connection) streamManager = NewMockStreamManager(mockCtrl) conn.streamsMap = streamManager @@ -2589,6 +2590,7 @@ var _ = Describe("Client Connection", func() { tr, utils.DefaultLogger, protocol.Version1, + protocol.DefaultMaxUDPPayloadSize, ).(*connection) packer = NewMockPacker(mockCtrl) conn.packer = packer diff --git a/interface.go b/interface.go index e362d0a98fc..a3e670f0a0b 100644 --- a/interface.go +++ b/interface.go @@ -333,10 +333,6 @@ type Config struct { // If set too high, the path might not support packets that large, leading to a timeout of the QUIC handshake. // Values below 1200 are invalid. InitialPacketSize uint16 - // MaxUDPPayloadSize configures the max_udp_payload_size transport parameter. This is the the limit on much data this - // endpoint is willing to receive. - // This is an experimental config option, it might be removed once PMTU can account for the path changing to lower values - MaxUDPPayloadSize uint16 // DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899). // This allows the sending of QUIC packets that fully utilize the available MTU of the path. // Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit. diff --git a/logging/types.go b/logging/types.go index 0d79b0a90a4..387f788eb93 100644 --- a/logging/types.go +++ b/logging/types.go @@ -56,6 +56,8 @@ const ( PacketDropUnexpectedVersion // PacketDropDuplicate is used when a duplicate packet is received PacketDropDuplicate + // PacketTooBig is used when packet size > max UDP payload + PacketTooBig ) // TimerType is the type of the loss detection timer diff --git a/server.go b/server.go index a55bdd43b21..f9e7f82fc2c 100644 --- a/server.go +++ b/server.go @@ -59,8 +59,9 @@ type baseServer struct { disableVersionNegotiation bool acceptEarlyConns bool - tlsConf *tls.Config - config *Config + tlsConf *tls.Config + config *Config + maxUDPPayloadSize protocol.ByteCount conn rawConn @@ -98,6 +99,7 @@ type baseServer struct { *logging.ConnectionTracer, utils.Logger, protocol.Version, + protocol.ByteCount, /* max UDP payload size */ ) quicConn closeMx sync.Mutex @@ -244,12 +246,14 @@ func newServer( verifySourceAddress func(net.Addr) bool, disableVersionNegotiation bool, acceptEarly bool, + maxUDPPayloadSize protocol.ByteCount, ) *baseServer { s := &baseServer{ conn: conn, connContext: connContext, tlsConf: tlsConf, config: config, + maxUDPPayloadSize: maxUDPPayloadSize, tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey), maxTokenAge: maxTokenAge, verifySourceAddress: verifySourceAddress, @@ -689,6 +693,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error tracer, s.logger, hdr.Version, + s.maxUDPPayloadSize, ) conn.handlePacket(p) // Adding the connection will fail if the client's chosen Destination Connection ID is already in use. diff --git a/server_test.go b/server_test.go index 2bede3c08d0..35c5dd1ac66 100644 --- a/server_test.go +++ b/server_test.go @@ -305,6 +305,7 @@ var _ = Describe("Server", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))) Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) @@ -508,6 +509,7 @@ var _ = Describe("Server", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { Expect(origDestConnID).To(Equal(hdr.DestConnectionID)) Expect(retrySrcConnID).To(BeNil()) @@ -577,6 +579,7 @@ var _ = Describe("Server", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { <-acceptConn counter.Add(1) @@ -633,6 +636,7 @@ var _ = Describe("Server", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { conn := NewMockQUICConn(mockCtrl) conn.EXPECT().handlePacket(gomock.Any()) @@ -683,6 +687,7 @@ var _ = Describe("Server", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { conn := <-connChan conn.EXPECT().handlePacket(gomock.Any()) @@ -745,6 +750,7 @@ var _ = Describe("Server", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { c := NewMockQUICConn(mockCtrl) c.EXPECT().handlePacket(gomock.Any()) @@ -979,6 +985,7 @@ var _ = Describe("Server", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { conn := NewMockQUICConn(mockCtrl) conn.EXPECT().handlePacket(gomock.Any()) @@ -1047,6 +1054,7 @@ var _ = Describe("Server", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(1234)) conn.EXPECT().handlePacket(gomock.Any()) @@ -1118,6 +1126,7 @@ var _ = Describe("Server", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { conn.EXPECT().handlePacket(gomock.Any()) conn.EXPECT().HandshakeComplete().Return(handshakeChan) @@ -1189,6 +1198,7 @@ var _ = Describe("Server", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { conn.EXPECT().handlePacket(gomock.Any()) conn.EXPECT().run() @@ -1231,6 +1241,7 @@ var _ = Describe("Server", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { ready := make(chan struct{}) close(ready) @@ -1291,6 +1302,7 @@ var _ = Describe("Server", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { conn.EXPECT().handlePacket(p) conn.EXPECT().run() @@ -1414,6 +1426,7 @@ var _ = Describe("Server", func() { _ *logging.ConnectionTracer, _ utils.Logger, _ protocol.Version, + _ protocol.ByteCount, ) quicConn { conn := NewMockQUICConn(mockCtrl) var calls []any diff --git a/transport.go b/transport.go index 059f30f5beb..f74cce14fe6 100644 --- a/transport.go +++ b/transport.go @@ -104,6 +104,11 @@ type Transport struct { // Tracer.Close is called when the transport is closed. Tracer *logging.Tracer + // MaxUDPPayloadSize configures the max_udp_payload_size transport parameter. This is the the limit on much data this + // endpoint is willing to receive. + // This is an experimental config option, it might be removed once PMTU can account for the path changing to lower values + MaxUDPPayloadSize uint16 + handlerMap packetHandlerManager mutex sync.Mutex @@ -189,6 +194,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo t.VerifySourceAddress, t.DisableVersionNegotiationPackets, allow0RTT, + protocol.ByteCount(t.MaxUDPPayloadSize), ) t.server = s return s, nil @@ -218,7 +224,7 @@ func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsCon } tlsConf = tlsConf.Clone() setTLSConfigServerName(tlsConf, addr, host) - return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT) + return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT, protocol.ByteCount(t.MaxUDPPayloadSize)) } func (t *Transport) init(allowZeroLengthConnIDs bool) error { @@ -263,6 +269,10 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error { t.connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.connIDLen} } + if t.MaxUDPPayloadSize == 0 { + t.MaxUDPPayloadSize = protocol.DefaultMaxUDPPayloadSize + } + getMultiplexer().AddConn(t.Conn) go t.listen(conn) go t.runSendQueue() @@ -398,6 +408,15 @@ func (t *Transport) handlePacket(p receivedPacket) { t.handleNonQUICPacket(p) return } + if len(p.data) > int(t.MaxUDPPayloadSize) { + // Peer didn't respect our max_udp_payload_size transport parameter. Drop the packet + t.logger.Debugf("received packet size %s great than the max UDP payload %s", len(p.data), t.MaxUDPPayloadSize) + if t.Tracer != nil && t.Tracer.DroppedPacket != nil { + t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketTooBig) + } + p.buffer.MaybeRelease() + return + } connID, err := wire.ParseConnectionID(p.data, t.connIDLen) if err != nil { t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)