From a0432e70afd1a5b0f8f63ecc46406132324e7bee Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Mon, 14 Nov 2022 07:44:25 -0800 Subject: [PATCH] webtransport: use deterministic TLS certificates (#1833) * Use deterministic TLS certificates for webtransport * Update test to work with buckets * Make sure to overlap and use a random offset * Fixup mistaken change in other test * Add QuickCheck tests for cert behavior * Lint fix * Add more tests * Add webtransport integration test * Use same key * Actually offset by at least clockSkew * Use seeded key for certs after reboot test * PR comments * Remove debug code * Fix calculation for cert having been valid Fixes the logic that a cert has been valid for a clockSkew by subtracting the clockSkew from the start time rather than incorporating it into the offset. The offset should be used to shift the buckets. * Update comment * Lint fix * Update TestGetCurrentBucketStartTimeIsWithinBounds to include clockSkew calculation * Rebase fixes --- p2p/test/webtransport/webtransport_test.go | 53 +++++ p2p/transport/webtransport/cert_manager.go | 57 +++-- .../webtransport/cert_manager_test.go | 82 +++++++- p2p/transport/webtransport/crypto.go | 58 ++++- p2p/transport/webtransport/crypto_test.go | 68 ++++++ p2p/transport/webtransport/transport.go | 2 +- p2p/transport/webtransport/transport_test.go | 199 ++++++++++++++++++ 7 files changed, 493 insertions(+), 26 deletions(-) create mode 100644 p2p/test/webtransport/webtransport_test.go diff --git a/p2p/test/webtransport/webtransport_test.go b/p2p/test/webtransport/webtransport_test.go new file mode 100644 index 0000000000..adc64b49a3 --- /dev/null +++ b/p2p/test/webtransport/webtransport_test.go @@ -0,0 +1,53 @@ +package webtransport_test + +import ( + "testing" + "time" + + "github.com/benbjohnson/clock" + "github.com/libp2p/go-libp2p" + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/test" + libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func extractCertHashes(addr ma.Multiaddr) []string { + var certHashesStr []string + ma.ForEach(addr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_CERTHASH { + certHashesStr = append(certHashesStr, c.Value()) + } + return true + }) + return certHashesStr +} + +func TestDeterministicCertsAfterReboot(t *testing.T) { + priv, _, err := test.RandTestKeyPair(ic.Ed25519, 256) + require.NoError(t, err) + + cl := clock.NewMock() + // Move one year ahead to avoid edge cases around epoch + cl.Add(time.Hour * 24 * 365) + h, err := libp2p.New(libp2p.NoTransports, libp2p.Transport(libp2pwebtransport.New, libp2pwebtransport.WithClock(cl)), libp2p.Identity(priv)) + require.NoError(t, err) + err = h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + + prevCerthashes := extractCertHashes(h.Addrs()[0]) + h.Close() + + h, err = libp2p.New(libp2p.NoTransports, libp2p.Transport(libp2pwebtransport.New, libp2pwebtransport.WithClock(cl)), libp2p.Identity(priv)) + require.NoError(t, err) + defer h.Close() + err = h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + + nextCertHashes := extractCertHashes(h.Addrs()[0]) + + for i := range prevCerthashes { + require.Equal(t, prevCerthashes[i], nextCertHashes[i]) + } +} diff --git a/p2p/transport/webtransport/cert_manager.go b/p2p/transport/webtransport/cert_manager.go index da2c69f464..9d9da18835 100644 --- a/p2p/transport/webtransport/cert_manager.go +++ b/p2p/transport/webtransport/cert_manager.go @@ -4,11 +4,13 @@ import ( "context" "crypto/sha256" "crypto/tls" + "encoding/binary" "fmt" "sync" "time" "github.com/benbjohnson/clock" + ic "github.com/libp2p/go-libp2p/core/crypto" ma "github.com/multiformats/go-multiaddr" "github.com/multiformats/go-multihash" ) @@ -17,6 +19,7 @@ import ( // When we generate a certificate, the NotBefore time is set to clockSkewAllowance before the current time. // Similarly, we stop using a certificate one clockSkewAllowance before its expiry time. const clockSkewAllowance = time.Hour +const validityMinusTwoSkew = certValidity - (2 * clockSkewAllowance) type certConfig struct { tlsConf *tls.Config @@ -26,8 +29,8 @@ type certConfig struct { func (c *certConfig) Start() time.Time { return c.tlsConf.Certificates[0].Leaf.NotBefore } func (c *certConfig) End() time.Time { return c.tlsConf.Certificates[0].Leaf.NotAfter } -func newCertConfig(start, end time.Time) (*certConfig, error) { - conf, err := getTLSConf(start, end) +func newCertConfig(key ic.PrivKey, start, end time.Time) (*certConfig, error) { + conf, err := getTLSConf(key, start, end) if err != nil { return nil, err } @@ -57,32 +60,58 @@ type certManager struct { serializedCertHashes [][]byte } -func newCertManager(clock clock.Clock) (*certManager, error) { +func newCertManager(hostKey ic.PrivKey, clock clock.Clock) (*certManager, error) { m := &certManager{clock: clock} m.ctx, m.ctxCancel = context.WithCancel(context.Background()) - if err := m.init(); err != nil { + if err := m.init(hostKey); err != nil { return nil, err } - m.background() + m.background(hostKey) return m, nil } -func (m *certManager) init() error { - start := m.clock.Now().Add(-clockSkewAllowance) - var err error - m.nextConfig, err = newCertConfig(start, start.Add(certValidity)) +// getCurrentTimeBucket returns the canonical start time of the given time as +// bucketed by ranges of certValidity since unix epoch (plus an offset). This +// lets you get the same time ranges across reboots without having to persist +// state. +// ``` +// ... v--- epoch + offset +// ... |--------| |--------| ... +// ... |--------| |--------| ... +// ``` +func getCurrentBucketStartTime(now time.Time, offset time.Duration) time.Time { + currentBucket := (now.UnixMilli() - offset.Milliseconds()) / validityMinusTwoSkew.Milliseconds() + return time.UnixMilli(offset.Milliseconds() + currentBucket*validityMinusTwoSkew.Milliseconds()) +} + +func (m *certManager) init(hostKey ic.PrivKey) error { + start := m.clock.Now() + pubkeyBytes, err := hostKey.GetPublic().Raw() + if err != nil { + return err + } + + // We want to add a random offset to each start time so that not all certs + // rotate at the same time across the network. The offset represents moving + // the bucket start time some `offset` earlier. + offset := (time.Duration(binary.LittleEndian.Uint16(pubkeyBytes)) * time.Minute) % certValidity + + // We want the certificate have been valid for at least one clockSkewAllowance + start = start.Add(-clockSkewAllowance) + startTime := getCurrentBucketStartTime(start, offset) + m.nextConfig, err = newCertConfig(hostKey, startTime, startTime.Add(certValidity)) if err != nil { return err } - return m.rollConfig() + return m.rollConfig(hostKey) } -func (m *certManager) rollConfig() error { +func (m *certManager) rollConfig(hostKey ic.PrivKey) error { // We stop using the current certificate clockSkewAllowance before its expiry time. // At this point, the next certificate needs to be valid for one clockSkewAllowance. nextStart := m.nextConfig.End().Add(-2 * clockSkewAllowance) - c, err := newCertConfig(nextStart, nextStart.Add(certValidity)) + c, err := newCertConfig(hostKey, nextStart, nextStart.Add(certValidity)) if err != nil { return err } @@ -95,7 +124,7 @@ func (m *certManager) rollConfig() error { return m.cacheAddrComponent() } -func (m *certManager) background() { +func (m *certManager) background(hostKey ic.PrivKey) { d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(m.clock.Now()) log.Debugw("setting timer", "duration", d.String()) t := m.clock.Timer(d) @@ -111,7 +140,7 @@ func (m *certManager) background() { return case now := <-t.C: m.mx.Lock() - if err := m.rollConfig(); err != nil { + if err := m.rollConfig(hostKey); err != nil { log.Errorw("rolling config failed", "error", err) } d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(now) diff --git a/p2p/transport/webtransport/cert_manager_test.go b/p2p/transport/webtransport/cert_manager_test.go index 3f2328fbb7..4b4550bb9d 100644 --- a/p2p/transport/webtransport/cert_manager_test.go +++ b/p2p/transport/webtransport/cert_manager_test.go @@ -3,10 +3,14 @@ package libp2pwebtransport import ( "crypto/sha256" "crypto/tls" + "fmt" "testing" + "testing/quick" "time" "github.com/benbjohnson/clock" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/test" ma "github.com/multiformats/go-multiaddr" "github.com/multiformats/go-multibase" "github.com/multiformats/go-multihash" @@ -39,14 +43,16 @@ func certHashFromComponent(t *testing.T, comp ma.Component) []byte { func TestInitialCert(t *testing.T) { cl := clock.NewMock() cl.Add(1234567 * time.Hour) - m, err := newCertManager(cl) + priv, _, err := test.RandTestKeyPair(crypto.Ed25519, 256) + require.NoError(t, err) + m, err := newCertManager(priv, cl) require.NoError(t, err) defer m.Close() conf := m.GetConfig() require.Len(t, conf.Certificates, 1) cert := conf.Certificates[0] - require.Equal(t, cl.Now().Add(-clockSkewAllowance).UTC(), cert.Leaf.NotBefore) + require.GreaterOrEqual(t, cl.Now().Add(-clockSkewAllowance), cert.Leaf.NotBefore) require.Equal(t, cert.Leaf.NotBefore.Add(certValidity), cert.Leaf.NotAfter) addr := m.AddrComponent() components := splitMultiaddr(addr) @@ -59,7 +65,11 @@ func TestInitialCert(t *testing.T) { func TestCertRenewal(t *testing.T) { cl := clock.NewMock() - m, err := newCertManager(cl) + // Add a year to avoid edge cases around the epoch + cl.Add(time.Hour * 24 * 365) + priv, _, err := test.SeededTestKeyPair(crypto.Ed25519, 256, 0) + require.NoError(t, err) + m, err := newCertManager(priv, cl) require.NoError(t, err) defer m.Close() @@ -68,7 +78,7 @@ func TestCertRenewal(t *testing.T) { require.Len(t, first, 2) require.NotEqual(t, first[0].Value(), first[1].Value(), "the hashes should differ") // wait for a new certificate to be generated - cl.Add(certValidity - 2*clockSkewAllowance - time.Second) + cl.Set(m.currentConfig.End().Add(-(clockSkewAllowance + time.Second))) require.Never(t, func() bool { for i, c := range splitMultiaddr(m.AddrComponent()) { if c.Value() != first[i].Value() { @@ -100,3 +110,67 @@ func TestCertRenewal(t *testing.T) { // check that the 2nd certificate from the beginning was rolled over to be the 1st certificate require.Equal(t, second[1].Value(), third[0].Value()) } + +func TestDeterministicCertsAcrossReboots(t *testing.T) { + // Run this test 100 times to make sure it's deterministic + runs := 100 + for i := 0; i < runs; i++ { + t.Run(fmt.Sprintf("Run=%d", i), func(t *testing.T) { + cl := clock.NewMock() + priv, _, err := test.SeededTestKeyPair(crypto.Ed25519, 256, 0) + require.NoError(t, err) + m, err := newCertManager(priv, cl) + require.NoError(t, err) + defer m.Close() + + conf := m.GetConfig() + require.Len(t, conf.Certificates, 1) + oldCerts := m.serializedCertHashes + + m.Close() + + cl.Add(time.Hour) + // reboot + m, err = newCertManager(priv, cl) + require.NoError(t, err) + defer m.Close() + + newCerts := m.serializedCertHashes + + require.Equal(t, oldCerts, newCerts) + }) + } +} + +func TestDeterministicTimeBuckets(t *testing.T) { + cl := clock.NewMock() + cl.Add(time.Hour * 24 * 365) + startA := getCurrentBucketStartTime(cl.Now(), 0) + startB := getCurrentBucketStartTime(cl.Now().Add(time.Hour*24), 0) + require.Equal(t, startA, startB) + + // 15 Days later + startC := getCurrentBucketStartTime(cl.Now().Add(time.Hour*24*15), 0) + require.NotEqual(t, startC, startB) +} + +func TestGetCurrentBucketStartTimeIsWithinBounds(t *testing.T) { + require.NoError(t, quick.Check(func(timeSinceUnixEpoch time.Duration, offset time.Duration) bool { + if offset < 0 { + offset = -offset + } + if timeSinceUnixEpoch < 0 { + timeSinceUnixEpoch = -timeSinceUnixEpoch + } + + offset = offset % certValidity + // Bound this to 100 years + timeSinceUnixEpoch = time.Duration(timeSinceUnixEpoch % (time.Hour * 24 * 365 * 100)) + // Start a bit further in the future to avoid edge cases around epoch + timeSinceUnixEpoch += time.Hour * 24 * 365 + start := time.UnixMilli(timeSinceUnixEpoch.Milliseconds()) + + bucketStart := getCurrentBucketStartTime(start.Add(-clockSkewAllowance), offset) + return !bucketStart.After(start.Add(-clockSkewAllowance)) || bucketStart.Equal(start.Add(-clockSkewAllowance)) + }, nil)) +} diff --git a/p2p/transport/webtransport/crypto.go b/p2p/transport/webtransport/crypto.go index 0ea71323af..dc2c1f03a7 100644 --- a/p2p/transport/webtransport/crypto.go +++ b/p2p/transport/webtransport/crypto.go @@ -4,7 +4,6 @@ import ( "bytes" "crypto/ecdsa" "crypto/elliptic" - "crypto/rand" "crypto/sha256" "crypto/tls" "crypto/x509" @@ -12,14 +11,20 @@ import ( "encoding/binary" "errors" "fmt" + "io" "math/big" "time" + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/multiformats/go-multihash" + "golang.org/x/crypto/hkdf" ) -func getTLSConf(start, end time.Time) (*tls.Config, error) { - cert, priv, err := generateCert(start, end) +const deterministicCertInfo = "determinisitic cert" + +func getTLSConf(key ic.PrivKey, start, end time.Time) (*tls.Config, error) { + cert, priv, err := generateCert(key, start, end) if err != nil { return nil, err } @@ -32,9 +37,20 @@ func getTLSConf(start, end time.Time) (*tls.Config, error) { }, nil } -func generateCert(start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, error) { +// generateCert generates certs deterministically based on the `key` and start +// time passed in. Uses `golang.org/x/crypto/hkdf`. +func generateCert(key ic.PrivKey, start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, error) { + keyBytes, err := key.Raw() + if err != nil { + return nil, nil, err + } + + startTimeSalt := make([]byte, 8) + binary.LittleEndian.PutUint64(startTimeSalt, uint64(start.UnixNano())) + deterministicHKDFReader := newDeterministicReader(keyBytes, startTimeSalt, deterministicCertInfo) + b := make([]byte, 8) - if _, err := rand.Read(b); err != nil { + if _, err := deterministicHKDFReader.Read(b); err != nil { return nil, nil, err } serial := int64(binary.BigEndian.Uint64(b)) @@ -51,11 +67,12 @@ func generateCert(start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, e KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, BasicConstraintsValid: true, } - caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + + caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), deterministicHKDFReader) if err != nil { return nil, nil, err } - caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &caPrivateKey.PublicKey, caPrivateKey) + caBytes, err := x509.CreateCertificate(deterministicHKDFReader, certTempl, certTempl, caPrivateKey.Public(), caPrivateKey) if err != nil { return nil, nil, err } @@ -106,3 +123,30 @@ func verifyRawCerts(rawCerts [][]byte, certHashes []multihash.DecodedMultihash) } return nil } + +// deterministicReader is a hack. It counter-acts the Go library's attempt at +// making ECDSA signatures non-deterministic. Go adds non-determinism by +// randomly dropping a singly byte from the reader stream. This counteracts this +// by detecting when a read is a single byte and using a different reader +// instead. +type deterministicReader struct { + reader io.Reader + singleByteReader io.Reader +} + +func newDeterministicReader(seed []byte, salt []byte, info string) io.Reader { + reader := hkdf.New(sha256.New, seed, salt, []byte(info)) + singleByteReader := hkdf.New(sha256.New, seed, salt, []byte(info+" single byte")) + + return &deterministicReader{ + reader: reader, + singleByteReader: singleByteReader, + } +} + +func (r *deterministicReader) Read(p []byte) (n int, err error) { + if len(p) == 1 { + return r.singleByteReader.Read(p) + } + return r.reader.Read(p) +} diff --git a/p2p/transport/webtransport/crypto_test.go b/p2p/transport/webtransport/crypto_test.go index d6d106202a..143cd6de45 100644 --- a/p2p/transport/webtransport/crypto_test.go +++ b/p2p/transport/webtransport/crypto_test.go @@ -1,6 +1,7 @@ package libp2pwebtransport import ( + "bytes" "crypto" "crypto/ecdsa" "crypto/elliptic" @@ -10,11 +11,13 @@ import ( "crypto/x509" "crypto/x509/pkix" "fmt" + "io" "math/big" mrand "math/rand" "testing" "time" + ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/multiformats/go-multihash" "github.com/stretchr/testify/require" ) @@ -130,3 +133,68 @@ func TestCertificateVerification(t *testing.T) { }) } } + +func TestDeterministicCertHashes(t *testing.T) { + // Run this test 1000 times since we want to make sure the signatures are deterministic + runs := 1000 + for i := 0; i < runs; i++ { + t.Run(fmt.Sprintf("Run=%d", i), func(t *testing.T) { + zeroSeed := [32]byte{} + priv, _, err := ic.GenerateEd25519Key(bytes.NewReader(zeroSeed[:])) + require.NoError(t, err) + cert, certPriv, err := generateCert(priv, time.Time{}, time.Time{}.Add(time.Hour*24*14)) + require.NoError(t, err) + + keyBytes, err := x509.MarshalECPrivateKey(certPriv) + require.NoError(t, err) + + cert2, certPriv2, err := generateCert(priv, time.Time{}, time.Time{}.Add(time.Hour*24*14)) + require.NoError(t, err) + + require.Equal(t, cert2.Signature, cert.Signature) + require.Equal(t, cert2.Raw, cert.Raw) + keyBytes2, err := x509.MarshalECPrivateKey(certPriv2) + require.NoError(t, err) + require.Equal(t, keyBytes, keyBytes2) + }) + } +} + +// TestDeterministicSig tests that our hack around making ECDSA signatures +// deterministic works. If this fails, this means we need to try another +// strategy to make deterministic signatures or try something else entirely. +// See deterministicReader for more context. +func TestDeterministicSig(t *testing.T) { + // Run this test 1000 times since we want to make sure the signatures are deterministic + runs := 1000 + for i := 0; i < runs; i++ { + t.Run(fmt.Sprintf("Run=%d", i), func(t *testing.T) { + zeroSeed := [32]byte{} + deterministicHKDFReader := newDeterministicReader(zeroSeed[:], nil, deterministicCertInfo) + b := [1024]byte{} + io.ReadFull(deterministicHKDFReader, b[:]) + caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), deterministicHKDFReader) + require.NoError(t, err) + + sig, err := caPrivateKey.Sign(deterministicHKDFReader, b[:], crypto.SHA256) + require.NoError(t, err) + + deterministicHKDFReader = newDeterministicReader(zeroSeed[:], nil, deterministicCertInfo) + b2 := [1024]byte{} + io.ReadFull(deterministicHKDFReader, b2[:]) + caPrivateKey2, err := ecdsa.GenerateKey(elliptic.P256(), deterministicHKDFReader) + require.NoError(t, err) + + sig2, err := caPrivateKey2.Sign(deterministicHKDFReader, b2[:], crypto.SHA256) + require.NoError(t, err) + + keyBytes, err := x509.MarshalECPrivateKey(caPrivateKey) + require.NoError(t, err) + keyBytes2, err := x509.MarshalECPrivateKey(caPrivateKey2) + require.NoError(t, err) + + require.Equal(t, sig, sig2) + require.Equal(t, keyBytes, keyBytes2) + }) + } +} diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index bc1b5e1439..8e6ac77dbc 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -293,7 +293,7 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { } if t.staticTLSConf == nil { t.listenOnce.Do(func() { - t.certManager, t.listenOnceErr = newCertManager(t.clock) + t.certManager, t.listenOnceErr = newCertManager(t.privKey, t.clock) }) if t.listenOnceErr != nil { return nil, t.listenOnceErr diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index de4cf2263c..70d4ec4b3d 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -20,14 +20,18 @@ import ( "strings" "sync/atomic" "testing" + "testing/quick" "time" + "github.com/benbjohnson/clock" ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" mocknetwork "github.com/libp2p/go-libp2p/core/network/mocks" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/test" tpt "github.com/libp2p/go-libp2p/core/transport" libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" + "github.com/lucas-clemente/quic-go" "github.com/golang/mock/gomock" quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" @@ -38,6 +42,9 @@ import ( "github.com/stretchr/testify/require" ) +const clockSkewAllowance = time.Hour +const certValidity = 14 * 24 * time.Hour + func newIdentity(t *testing.T) (peer.ID, ic.PrivKey) { key, _, err := ic.GenerateEd25519Key(rand.Reader) require.NoError(t, err) @@ -714,3 +721,195 @@ func TestFlowControlWindowIncrease(t *testing.T) { t.Fatal("timeout") } } + +var errTimeout = errors.New("timeout") + +func serverSendsBackValidCert(timeSinceUnixEpoch time.Duration, keySeed int64, randomClientSkew time.Duration) error { + if timeSinceUnixEpoch < 0 { + timeSinceUnixEpoch = -timeSinceUnixEpoch + } + + // Bound this to 100 years + timeSinceUnixEpoch = time.Duration(timeSinceUnixEpoch % (time.Hour * 24 * 365 * 100)) + // Start a bit further in the future to avoid edge cases around epoch + timeSinceUnixEpoch += time.Hour * 24 * 365 + start := time.UnixMilli(timeSinceUnixEpoch.Milliseconds()) + + randomClientSkew = randomClientSkew % clockSkewAllowance + + cl := clock.NewMock() + cl.Set(start) + + priv, _, err := test.SeededTestKeyPair(ic.Ed25519, 256, keySeed) + if err != nil { + return err + } + tr, err := libp2pwebtransport.New(priv, nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl)) + if err != nil { + return err + } + l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + if err != nil { + return err + } + defer l.Close() + + conn, err := quic.DialAddr(l.Addr().String(), &tls.Config{ + NextProtos: []string{"h3"}, + InsecureSkipVerify: true, + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + for _, c := range rawCerts { + cert, err := x509.ParseCertificate(c) + if err != nil { + return err + } + + for _, clientSkew := range []time.Duration{randomClientSkew, -clockSkewAllowance, clockSkewAllowance} { + clientTime := cl.Now().Add(clientSkew) + if clientTime.After(cert.NotAfter) || clientTime.Before(cert.NotBefore) { + return fmt.Errorf("Times are not valid: server_now=%v client_now=%v certstart=%v certend=%v", cl.Now().UTC(), clientTime.UTC(), cert.NotBefore.UTC(), cert.NotAfter.UTC()) + } + } + + } + return nil + }, + }, &quic.Config{MaxIdleTimeout: time.Second}) + + if err != nil { + if _, ok := err.(*quic.IdleTimeoutError); ok { + return errTimeout + } + return err + } + defer conn.CloseWithError(0, "") + + return nil +} + +func TestServerSendsBackValidCert(t *testing.T) { + var maxTimeoutErrors = 10 + require.NoError(t, quick.Check(func(timeSinceUnixEpoch time.Duration, keySeed int64, randomClientSkew time.Duration) bool { + err := serverSendsBackValidCert(timeSinceUnixEpoch, keySeed, randomClientSkew) + if err == errTimeout { + maxTimeoutErrors -= 1 + if maxTimeoutErrors <= 0 { + fmt.Println("Too many timeout errors") + return false + } + // Sporadic timeout errors on macOS + return true + } else if err != nil { + fmt.Println("Err:", err) + return false + } + + return true + }, nil)) +} + +func TestServerRotatesCertCorrectly(t *testing.T) { + require.NoError(t, quick.Check(func(timeSinceUnixEpoch time.Duration, keySeed int64) bool { + if timeSinceUnixEpoch < 0 { + timeSinceUnixEpoch = -timeSinceUnixEpoch + } + + // Bound this to 100 years + timeSinceUnixEpoch = time.Duration(timeSinceUnixEpoch % (time.Hour * 24 * 365 * 100)) + // Start a bit further in the future to avoid edge cases around epoch + timeSinceUnixEpoch += time.Hour * 24 * 365 + start := time.UnixMilli(timeSinceUnixEpoch.Milliseconds()) + + cl := clock.NewMock() + cl.Set(start) + + priv, _, err := test.SeededTestKeyPair(ic.Ed25519, 256, keySeed) + if err != nil { + return false + } + tr, err := libp2pwebtransport.New(priv, nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl)) + if err != nil { + return false + } + + l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + if err != nil { + return false + } + certhashes := extractCertHashes(l.Multiaddr()) + l.Close() + + // These two certificates together are valid for at most certValidity - (4*clockSkewAllowance) + cl.Add(certValidity - (4 * clockSkewAllowance) - time.Second) + tr, err = libp2pwebtransport.New(priv, nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl)) + if err != nil { + return false + } + + l, err = tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + if err != nil { + return false + } + defer l.Close() + + var found bool + ma.ForEach(l.Multiaddr(), func(c ma.Component) bool { + if c.Protocol().Code == ma.P_CERTHASH { + for _, prevCerthash := range certhashes { + if c.Value() == prevCerthash { + found = true + return false + } + } + } + return true + }) + + return found + + }, nil)) +} + +func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) { + cl := clock.NewMock() + // Move one year ahead to avoid edge cases around epoch + cl.Add(time.Hour * 24 * 365) + + priv, _, err := test.RandTestKeyPair(ic.Ed25519, 256) + require.NoError(t, err) + tr, err := libp2pwebtransport.New(priv, nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl)) + require.NoError(t, err) + + l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + + certhashes := extractCertHashes(l.Multiaddr()) + l.Close() + + // Traverse various time boundaries and make sure we always keep a common certhash. + // e.g. certhash/A/certhash/B ... -> ... certhash/B/certhash/C ... -> ... certhash/C/certhash/D + for i := 0; i < 200; i++ { + cl.Add(24 * time.Hour) + tr, err := libp2pwebtransport.New(priv, nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl)) + require.NoError(t, err) + l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) + require.NoError(t, err) + + var found bool + ma.ForEach(l.Multiaddr(), func(c ma.Component) bool { + if c.Protocol().Code == ma.P_CERTHASH { + for _, prevCerthash := range certhashes { + if prevCerthash == c.Value() { + found = true + return false + } + } + } + return true + }) + certhashes = extractCertHashes(l.Multiaddr()) + l.Close() + + require.True(t, found, "Failed after hour: %v", i) + } +}