diff --git a/client.go b/client.go index 4d49381..5ae17ce 100644 --- a/client.go +++ b/client.go @@ -14,6 +14,8 @@ import ( "github.com/quic-go/quic-go/quicvarint" ) +var errNoWebTransport = errors.New("server didn't enable WebTransport") + type Dialer struct { // If not set, reasonable defaults will be used. // In order for WebTransport to function, this implementation will: @@ -110,7 +112,25 @@ func (d *Dialer) Dial(ctx context.Context, urlStr string, reqHdr http.Header) (* } req = req.WithContext(ctx) - rsp, err := d.RoundTripper.RoundTripOpt(req, http3.RoundTripOpt{DontCloseRequestStream: true}) + rsp, err := d.RoundTripper.RoundTripOpt(req, http3.RoundTripOpt{ + DontCloseRequestStream: true, + CheckSettings: func(settings http3.Settings) error { + if !settings.EnableExtendedConnect { + return errors.New("server didn't enable Extended CONNECT") + } + if !settings.EnableDatagram { + return errors.New("server didn't enable HTTP/3 datagram support") + } + if settings.Other == nil { + return errNoWebTransport + } + s, ok := settings.Other[settingsEnableWebtransport] + if !ok || s != 1 { + return errNoWebTransport + } + return nil + }, + }) if err != nil { return nil, nil, err } diff --git a/client_test.go b/client_test.go index 87f0a6f..4ea1e86 100644 --- a/client_test.go +++ b/client_test.go @@ -47,18 +47,53 @@ func (c *requestStreamDelayingConn) OpenStreamSync(ctx context.Context) (quic.St return str, nil } +const ( + // Extended CONNECT, RFC 9220 + settingExtendedConnect = 0x8 + // HTTP Datagrams, RFC 9297 + settingDatagram = 0x33 + // WebTransport + settingsEnableWebtransport = 0x2b603742 +) + +// appendSettingsFrame serializes an HTTP/3 SETTINGS frame +// It reimplements the function in the http3 package, in a slightly simplified way. +func appendSettingsFrame(b []byte, values map[uint64]uint64) []byte { + b = quicvarint.Append(b, 0x4) + var l uint64 + for k, val := range values { + l += uint64(quicvarint.Len(k)) + uint64(quicvarint.Len(val)) + } + b = quicvarint.Append(b, l) + for id, val := range values { + b = quicvarint.Append(b, id) + b = quicvarint.Append(b, val) + } + return b +} + func TestClientInvalidResponseHandling(t *testing.T) { tlsConf := tlsConf.Clone() tlsConf.NextProtos = []string{"h3"} - s, err := quic.ListenAddr("localhost:0", tlsConf, nil) + s, err := quic.ListenAddr("localhost:0", tlsConf, &quic.Config{EnableDatagrams: true}) require.NoError(t, err) errChan := make(chan error) go func() { conn, err := s.Accept(context.Background()) require.NoError(t, err) + // send the SETTINGS frame + settingsStr, err := conn.OpenUniStream() + require.NoError(t, err) + _, err = settingsStr.Write(appendSettingsFrame([]byte{0} /* stream type */, map[uint64]uint64{ + settingDatagram: 1, + settingExtendedConnect: 1, + settingsEnableWebtransport: 1, + })) + require.NoError(t, err) + str, err := conn.AcceptStream(context.Background()) require.NoError(t, err) - // write a HTTP3 data frame. This will cause an error, since a HEADERS frame is expected + // write an HTTP/3 data frame. This will cause an error, since a HEADERS frame is expected var b []byte b = quicvarint.Append(b, 0x0) b = quicvarint.Append(b, 1337) @@ -79,11 +114,79 @@ func TestClientInvalidResponseHandling(t *testing.T) { } _, _, err = d.Dial(context.Background(), fmt.Sprintf("https://localhost:%d", s.Addr().(*net.UDPAddr).Port), nil) require.Error(t, err) - sErr := <-errChan + var sErr error + select { + case sErr = <-errChan: + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } require.Error(t, sErr) var appErr *quic.ApplicationError require.True(t, errors.As(sErr, &appErr)) - require.Equal(t, quic.ApplicationErrorCode(0x105), appErr.ErrorCode) // H3_FRAME_UNEXPECTED + require.Equal(t, http3.ErrCodeFrameUnexpected, http3.ErrCode(appErr.ErrorCode)) +} + +func TestClientInvalidSettingsHandling(t *testing.T) { + for _, tc := range []struct { + name string + settings map[uint64]uint64 + errorStr string + }{ + { + name: "Extended CONNECT disabled", + settings: map[uint64]uint64{ + settingDatagram: 1, + settingExtendedConnect: 0, + settingsEnableWebtransport: 1, + }, + errorStr: "server didn't enable Extended CONNECT", + }, + { + name: "HTTP/3 DATAGRAMs disabled", + settings: map[uint64]uint64{ + settingDatagram: 0, + settingExtendedConnect: 1, + settingsEnableWebtransport: 1, + }, + errorStr: "server didn't enable HTTP/3 datagram support", + }, + { + name: "WebTransport disabled", + settings: map[uint64]uint64{ + settingDatagram: 1, + settingExtendedConnect: 1, + settingsEnableWebtransport: 0, + }, + errorStr: "server didn't enable WebTransport", + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + tlsConf := tlsConf.Clone() + tlsConf.NextProtos = []string{"h3"} + s, err := quic.ListenAddr("localhost:0", tlsConf, &quic.Config{EnableDatagrams: true}) + require.NoError(t, err) + go func() { + conn, err := s.Accept(context.Background()) + require.NoError(t, err) + // send the SETTINGS frame + settingsStr, err := conn.OpenUniStream() + require.NoError(t, err) + _, err = settingsStr.Write(appendSettingsFrame([]byte{0} /* stream type */, tc.settings)) + require.NoError(t, err) + }() + + d := webtransport.Dialer{ + RoundTripper: &http3.RoundTripper{ + TLSClientConfig: &tls.Config{RootCAs: certPool}, + }, + } + _, _, err = d.Dial(context.Background(), fmt.Sprintf("https://localhost:%d", s.Addr().(*net.UDPAddr).Port), nil) + require.Error(t, err) + require.ErrorContains(t, err, tc.errorStr) + }) + + } } func TestClientReorderedUpgrade(t *testing.T) {