Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for providing a custom Connection ID generator via Config #3452

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,8 @@ type client struct {
logger utils.Logger
}

var (
// make it possible to mock connection ID generation in the tests
generateConnectionID = protocol.GenerateConnectionID
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
)
// make it possible to mock connection ID for initial generation in the tests
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial

// DialAddr establishes a new QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
Expand Down Expand Up @@ -193,7 +190,7 @@ func dialContext(
return nil, err
}
config = populateClientConfig(config, createdPacketConn)
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey, config.Tracer)
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDGenerator.ConnectionIDLen(), config.StatelessResetKey, config.Tracer)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -256,7 +253,7 @@ func newClient(
}
}

srcConnID, err := generateConnectionID(config.ConnectionIDLength)
srcConnID, err := config.ConnectionIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err
}
Expand Down
21 changes: 14 additions & 7 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,16 @@ var _ = Describe("Client", func() {
})

Context("Dialing", func() {
var origGenerateConnectionID func(int) (protocol.ConnectionID, error)
var origGenerateConnectionIDForInitial func() (protocol.ConnectionID, error)

BeforeEach(func() {
origGenerateConnectionID = generateConnectionID
origGenerateConnectionIDForInitial = generateConnectionIDForInitial
generateConnectionID = func(int) (protocol.ConnectionID, error) {
return connID, nil
}
generateConnectionIDForInitial = func() (protocol.ConnectionID, error) {
return connID, nil
}
})

AfterEach(func() {
generateConnectionID = origGenerateConnectionID
generateConnectionIDForInitial = origGenerateConnectionIDForInitial
})

Expand Down Expand Up @@ -524,7 +518,7 @@ var _ = Describe("Client", func() {
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)

config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}}
c := make(chan struct{})
var cconn sendConn
var version protocol.VersionNumber
Expand Down Expand Up @@ -602,10 +596,23 @@ var _ = Describe("Client", func() {
return conn
}

config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
_, err := DialAddr("localhost:7890", tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Expect(counter).To(Equal(2))
})
})
})

type mockedConnIDGenerator struct {
ConnID protocol.ConnectionID
}

func (m *mockedConnIDGenerator) GenerateConnectionID() ([]byte, error) {
return m.ConnID, nil
}

func (m *mockedConnIDGenerator) ConnectionIDLen() int {
return m.ConnID.Len()
}
26 changes: 17 additions & 9 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ func validateConfig(config *Config) error {
// populateServerConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateServerConfig(config *Config) *Config {
config = populateConfig(config)
if config.ConnectionIDLength == 0 {
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
}
config = populateConfig(config, protocol.DefaultConnectionIDLength)
if config.MaxTokenAge == 0 {
config.MaxTokenAge = protocol.TokenValidity
}
Expand All @@ -54,21 +51,27 @@ func populateServerConfig(config *Config) *Config {
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
config = populateConfig(config)
if config.ConnectionIDLength == 0 && !createdPacketConn {
config.ConnectionIDLength = protocol.DefaultConnectionIDLength
defaultConnIDLen := protocol.DefaultConnectionIDLength
if createdPacketConn {
defaultConnIDLen = 0
}

config = populateConfig(config, defaultConnIDLen)
return config
}

func populateConfig(config *Config) *Config {
func populateConfig(config *Config, defaultConnIDLen int) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions
if len(versions) == 0 {
versions = protocol.SupportedVersions
}
conIDLen := config.ConnectionIDLength
if config.ConnectionIDLength == 0 {
conIDLen = defaultConnIDLen
}
handshakeIdleTimeout := protocol.DefaultHandshakeIdleTimeout
if config.HandshakeIdleTimeout != 0 {
handshakeIdleTimeout = config.HandshakeIdleTimeout
Expand Down Expand Up @@ -105,6 +108,10 @@ func populateConfig(config *Config) *Config {
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
connIDGenerator := config.ConnectionIDGenerator
if connIDGenerator == nil {
connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conIDLen}
}

return &Config{
Versions: versions,
Expand All @@ -121,7 +128,8 @@ func populateConfig(config *Config) *Config {
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
ConnectionIDLength: config.ConnectionIDLength,
ConnectionIDLength: conIDLen,
ConnectionIDGenerator: connIDGenerator,
StatelessResetKey: config.StatelessResetKey,
TokenStore: config.TokenStore,
EnableDatagrams: config.EnableDatagrams,
Expand Down
8 changes: 5 additions & 3 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ var _ = Describe("Config", func() {
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
case "ConnectionIDLength":
f.Set(reflect.ValueOf(8))
case "ConnectionIDGenerator":
f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength}))
case "HandshakeIdleTimeout":
f.Set(reflect.ValueOf(time.Second))
case "MaxIdleTimeout":
Expand Down Expand Up @@ -140,18 +142,18 @@ var _ = Describe("Config", func() {
var calledAddrValidation bool
c1 := &Config{}
c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
c2 := populateConfig(c1)
c2 := populateConfig(c1, protocol.DefaultConnectionIDLength)
c2.RequireAddressValidation(&net.UDPAddr{})
Expect(calledAddrValidation).To(BeTrue())
})

It("copies non-function fields", func() {
c := configWithNonZeroNonFunctionFields()
Expect(populateConfig(c)).To(Equal(c))
Expect(populateConfig(c, protocol.DefaultConnectionIDLength)).To(Equal(c))
})

It("populates empty fields with default values", func() {
c := populateConfig(&Config{})
c := populateConfig(&Config{}, protocol.DefaultConnectionIDLength)
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData))
Expand Down
9 changes: 5 additions & 4 deletions conn_id_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

type connIDGenerator struct {
connIDLen int
generator ConnectionIDGenerator
highestSeq uint64

activeSrcConnIDs map[uint64]protocol.ConnectionID
Expand All @@ -35,10 +35,11 @@ func newConnIDGenerator(
retireConnectionID func(protocol.ConnectionID),
replaceWithClosed func(protocol.ConnectionID, packetHandler),
queueControlFrame func(wire.Frame),
generator ConnectionIDGenerator,
version protocol.VersionNumber,
) *connIDGenerator {
m := &connIDGenerator{
connIDLen: initialConnectionID.Len(),
generator: generator,
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
addConnectionID: addConnectionID,
getStatelessResetToken: getStatelessResetToken,
Expand All @@ -54,7 +55,7 @@ func newConnIDGenerator(
}

func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
if m.connIDLen == 0 {
if m.generator.ConnectionIDLen() == 0 {
return nil
}
// The active_connection_id_limit transport parameter is the number of
Expand Down Expand Up @@ -99,7 +100,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
}

func (m *connIDGenerator) issueNewConnID() error {
connID, err := protocol.GenerateConnectionID(m.connIDLen)
connID, err := m.generator.GenerateConnectionID()
if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions conn_id_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ var _ = Describe("Connection ID Generator", func() {
func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
func(c protocol.ConnectionID, h packetHandler) { replacedWithClosed[string(c)] = h },
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
&protocol.DefaultConnectionIDGenerator{ConnLen: initialConnID.Len()},
protocol.VersionDraft29,
)
})
Expand Down
2 changes: 2 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ var newConnection = func(
runner.Retire,
runner.ReplaceWithClosed,
s.queueControlFrame,
s.config.ConnectionIDGenerator,
s.version,
)
s.preSetup()
Expand Down Expand Up @@ -407,6 +408,7 @@ var newClientConnection = func(
runner.Retire,
runner.ReplaceWithClosed,
s.queueControlFrame,
s.config.ConnectionIDGenerator,
s.version,
)
s.preSetup()
Expand Down
37 changes: 35 additions & 2 deletions integrationtests/self/conn_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package self_test

import (
"context"
"crypto/rand"
"fmt"
"io"
"math/rand"
mrand "math/rand"
"net"

"github.com/lucas-clemente/quic-go"
Expand All @@ -14,9 +15,26 @@ import (
. "github.com/onsi/gomega"
)

type connIDGenerator struct {
length int
}

func (c *connIDGenerator) GenerateConnectionID() ([]byte, error) {
b := make([]byte, c.length)
_, err := rand.Read(b)
if err != nil {
fmt.Fprintf(GinkgoWriter, "generating conn ID failed: %s", err)
}
return b, nil
}

func (c *connIDGenerator) ConnectionIDLen() int {
return c.length
}

var _ = Describe("Connection ID lengths tests", func() {
randomConnIDLen := func() int {
return 4 + int(rand.Int31n(15))
return 4 + int(mrand.Int31n(15))
}

runServer := func(conf *quic.Config) quic.Listener {
Expand Down Expand Up @@ -87,4 +105,19 @@ var _ = Describe("Connection ID lengths tests", func() {
defer ln.Close()
runClient(ln.Addr(), clientConf)
})

It("downloads a file when both client and server use a custom connection ID generator", func() {
serverConf := getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS},
ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()},
})
clientConf := getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS},
ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()},
})

ln := runServer(serverConf)
defer ln.Close()
runClient(ln.Addr(), clientConf)
})
})
23 changes: 23 additions & 0 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,24 @@ type EarlyConnection interface {
NextConnection() Connection
}

// A ConnectionIDGenerator is an interface that allows clients to implement their own format
// for the Connection IDs that servers/clients use as SrcConnectionID in QUIC packets.
//
// Connection IDs generated by an implementation should always produce IDs of constant size.
type ConnectionIDGenerator interface {
// GenerateConnectionID generates a new ConnectionID.
joliveirinha marked this conversation as resolved.
Show resolved Hide resolved
// Generated ConnectionIDs should be unique and observers should not be able to correlate two ConnectionIDs.
GenerateConnectionID() ([]byte, error)
joliveirinha marked this conversation as resolved.
Show resolved Hide resolved

// ConnectionIDLen tells what is the length of the ConnectionIDs generated by the implementation of
// this interface.
joliveirinha marked this conversation as resolved.
Show resolved Hide resolved
// Effectively, this means that implementations of ConnectionIDGenerator must always return constant-size
// connection IDs. Valid lengths are between 0 and 20 and calls to GenerateConnectionID.
// 0-length ConnectionsIDs can be used when an endpoint (server or client) does not require multiplexing connections
// in the presence of a connection migration environment.
ConnectionIDLen() int
}

// Config contains all configuration data needed for a QUIC server or client.
type Config struct {
// The QUIC versions that can be negotiated.
Expand All @@ -213,6 +231,11 @@ type Config struct {
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
ConnectionIDLength int
// An optional ConnectionIDGenerator to be used for ConnectionIDs generated during the lifecycle of a QUIC connection.
// The goal is to give some control on how connection IDs, which can be useful in some scenarios, in particular for servers.
// By default, if not provided, random connection IDs with the length given by ConnectionIDLength is used.
// Otherwise, if one is provided, then ConnectionIDLength is ignored.
ConnectionIDGenerator ConnectionIDGenerator
// HandshakeIdleTimeout is the idle timeout before completion of the handshake.
// Specifically, if we don't receive any packet from the peer within this time, the connection attempt is aborted.
// If this value is zero, the timeout is set to 5 seconds.
Expand Down
12 changes: 12 additions & 0 deletions internal/protocol/connection_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,15 @@ func (c ConnectionID) String() string {
}
return fmt.Sprintf("%x", c.Bytes())
}

type DefaultConnectionIDGenerator struct {
ConnLen int
}

func (d *DefaultConnectionIDGenerator) GenerateConnectionID() ([]byte, error) {
return GenerateConnectionID(d.ConnLen)
}

func (d *DefaultConnectionIDGenerator) ConnectionIDLen() int {
return d.ConnLen
}
Loading