From a08801eae680a46e74d82a78dab2c221d25e3511 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 27 Apr 2024 20:16:29 +0200 Subject: [PATCH] add support for datagrams (#142) * add support for datagrams * add a test case --- go.mod | 2 +- mock_stream_test.go | 33 +++++++++++++++++++++++-- session.go | 16 +++++++++---- session_manager.go | 2 +- session_test.go | 5 ++-- webtransport_test.go | 57 +++++++++++++++++++++++++++++++++++++++++++- 6 files changed, 104 insertions(+), 11 deletions(-) diff --git a/go.mod b/go.mod index c72833b..2a85961 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/quic-go/quic-go v0.43.0 github.com/stretchr/testify v1.8.0 go.uber.org/mock v0.4.0 + golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 ) require ( @@ -19,7 +20,6 @@ require ( github.com/quic-go/qpack v0.4.0 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect golang.org/x/crypto v0.14.0 // indirect - golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 // indirect golang.org/x/mod v0.12.0 // indirect golang.org/x/net v0.17.0 // indirect golang.org/x/sys v0.15.0 // indirect diff --git a/mock_stream_test.go b/mock_stream_test.go index 87be243..56bc0f7 100644 --- a/mock_stream_test.go +++ b/mock_stream_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/quic-go/quic-go (interfaces: Stream) +// Source: github.com/quic-go/quic-go/http3 (interfaces: Stream) // // Generated by this command: // -// mockgen -package webtransport -destination mock_stream_test.go github.com/quic-go/quic-go Stream +// mockgen -package webtransport -destination mock_stream_test.go github.com/quic-go/quic-go/http3 Stream // // Package webtransport is a generated GoMock package. @@ -108,6 +108,35 @@ func (mr *MockStreamMockRecorder) Read(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockStream)(nil).Read), arg0) } +// ReceiveDatagram mocks base method. +func (m *MockStream) ReceiveDatagram(arg0 context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReceiveDatagram", arg0) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReceiveDatagram indicates an expected call of ReceiveDatagram. +func (mr *MockStreamMockRecorder) ReceiveDatagram(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceiveDatagram", reflect.TypeOf((*MockStream)(nil).ReceiveDatagram), arg0) +} + +// SendDatagram mocks base method. +func (m *MockStream) SendDatagram(arg0 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendDatagram", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendDatagram indicates an expected call of SendDatagram. +func (mr *MockStreamMockRecorder) SendDatagram(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendDatagram", reflect.TypeOf((*MockStream)(nil).SendDatagram), arg0) +} + // SetDeadline mocks base method. func (m *MockStream) SetDeadline(arg0 time.Time) error { m.ctrl.T.Helper() diff --git a/session.go b/session.go index 33e2425..bd8c876 100644 --- a/session.go +++ b/session.go @@ -63,7 +63,7 @@ func (q *acceptQueue[T]) Chan() <-chan struct{} { return q.c } type Session struct { sessionID sessionID qconn http3.Connection - requestStr quic.Stream + requestStr http3.Stream streamHdr []byte uniStreamHdr []byte @@ -82,7 +82,7 @@ type Session struct { streams streamsMap } -func newSession(sessionID sessionID, qconn http3.Connection, requestStr quic.Stream) *Session { +func newSession(sessionID sessionID, qconn http3.Connection, requestStr http3.Stream) *Session { tracingID := qconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) ctx, ctxCancel := context.WithCancel(context.WithValue(context.Background(), quic.ConnectionTracingKey, tracingID)) c := &Session{ @@ -390,6 +390,14 @@ func (s *Session) CloseWithError(code SessionErrorCode, msg string) error { return err } +func (s *Session) SendDatagram(b []byte) error { + return s.requestStr.SendDatagram(b) +} + +func (s *Session) ReceiveDatagram(ctx context.Context) ([]byte, error) { + return s.requestStr.ReceiveDatagram(ctx) +} + func (s *Session) closeWithError(code SessionErrorCode, msg string) (bool /* first call to close session */, error) { s.closeMx.Lock() defer s.closeMx.Unlock() @@ -413,6 +421,6 @@ func (s *Session) closeWithError(code SessionErrorCode, msg string) (bool /* fir ) } -func (c *Session) ConnectionState() quic.ConnectionState { - return c.qconn.ConnectionState() +func (s *Session) ConnectionState() quic.ConnectionState { + return s.qconn.ConnectionState() } diff --git a/session_manager.go b/session_manager.go index 5361019..7150edb 100644 --- a/session_manager.go +++ b/session_manager.go @@ -164,7 +164,7 @@ func (m *sessionManager) handleUniStream(str quic.ReceiveStream, sess *session) } // AddSession adds a new WebTransport session. -func (m *sessionManager) AddSession(qconn http3.Connection, id sessionID, requestStr quic.Stream) *Session { +func (m *sessionManager) AddSession(qconn http3.Connection, id sessionID, requestStr http3.Stream) *Session { conn := newSession(id, qconn, requestStr) connTracingID := qconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) diff --git a/session_test.go b/session_test.go index 8d3b5dc..fd0287a 100644 --- a/session_test.go +++ b/session_test.go @@ -7,20 +7,21 @@ import ( "time" "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/http3" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) //go:generate sh -c "go run go.uber.org/mock/mockgen -package webtransport -destination mock_connection_test.go github.com/quic-go/quic-go/http3 Connection && cat mock_connection_test.go | sed s@qerr\\.ApplicationErrorCode@quic.ApplicationErrorCode@g > tmp.go && mv tmp.go mock_connection_test.go && goimports -w mock_connection_test.go" -//go:generate sh -c "go run go.uber.org/mock/mockgen -package webtransport -destination mock_stream_test.go github.com/quic-go/quic-go Stream && cat mock_stream_test.go | sed s@protocol\\.StreamID@quic.StreamID@g | sed s@qerr\\.StreamErrorCode@quic.StreamErrorCode@g > tmp.go && mv tmp.go mock_stream_test.go && goimports -w mock_stream_test.go" +//go:generate sh -c "go run go.uber.org/mock/mockgen -package webtransport -destination mock_stream_test.go github.com/quic-go/quic-go/http3 Stream && cat mock_stream_test.go | sed s@protocol\\.StreamID@quic.StreamID@g | sed s@qerr\\.StreamErrorCode@quic.StreamErrorCode@g > tmp.go && mv tmp.go mock_stream_test.go && goimports -w mock_stream_test.go" type mockRequestStream struct { *MockStream c chan struct{} } -func newMockRequestStream(ctrl *gomock.Controller) quic.Stream { +func newMockRequestStream(ctrl *gomock.Controller) http3.Stream { str := NewMockStream(ctrl) str.EXPECT().Close() str.EXPECT().CancelRead(gomock.Any()) diff --git a/webtransport_test.go b/webtransport_test.go index d819191..3d84ca8 100644 --- a/webtransport_test.go +++ b/webtransport_test.go @@ -2,7 +2,6 @@ package webtransport_test import ( "context" - "crypto/rand" "crypto/tls" "errors" "fmt" @@ -15,6 +14,8 @@ import ( "testing" "time" + "golang.org/x/exp/rand" + "github.com/quic-go/webtransport-go" "github.com/quic-go/quic-go" @@ -595,3 +596,57 @@ func TestWriteCloseRace(t *testing.T) { <-ready close(ch) } + +func TestDatagrams(t *testing.T) { + const num = 100 + var mx sync.Mutex + m := make(map[string]bool, num) + + var counter int + done := make(chan struct{}) + serverErrChan := make(chan error, 1) + sess, closeServer := establishSession(t, func(sess *webtransport.Session) { + defer close(done) + for { + b, err := sess.ReceiveDatagram(context.Background()) + if err != nil { + return + } + mx.Lock() + if _, ok := m[string(b)]; !ok { + serverErrChan <- errors.New("received unexpected datagram") + return + } + m[string(b)] = true + mx.Unlock() + counter++ + } + }) + defer closeServer() + + errChan := make(chan error, 1) + + for i := 0; i < num; i++ { + b := make([]byte, 800) + rand.Read(b) + mx.Lock() + m[string(b)] = false + mx.Unlock() + if err := sess.SendDatagram(b); err != nil { + break + } + } + time.Sleep(scaleDuration(10 * time.Millisecond)) + sess.CloseWithError(0, "") + select { + case err := <-serverErrChan: + t.Fatal(err) + case err := <-errChan: + t.Fatal(err) + case <-done: + t.Logf("sent: %d, received: %d", num, counter) + require.Greater(t, counter, num*4/5) + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } +}