diff --git a/internal/quic/atomic_bits.go b/internal/quic/atomic_bits.go new file mode 100644 index 0000000000..e1e2594d15 --- /dev/null +++ b/internal/quic/atomic_bits.go @@ -0,0 +1,33 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import "sync/atomic" + +// atomicBits is an atomic uint32 that supports setting individual bits. +type atomicBits[T ~uint32] struct { + bits atomic.Uint32 +} + +// set sets the bits in mask to the corresponding bits in v. +// It returns the new value. +func (a *atomicBits[T]) set(v, mask T) T { + if v&^mask != 0 { + panic("BUG: bits in v are not in mask") + } + for { + o := a.bits.Load() + n := (o &^ uint32(mask)) | uint32(v) + if a.bits.CompareAndSwap(o, n) { + return T(n) + } + } +} + +func (a *atomicBits[T]) load() T { + return T(a.bits.Load()) +} diff --git a/internal/quic/conn_streams.go b/internal/quic/conn_streams.go index dd35e34cf6..0ede284e23 100644 --- a/internal/quic/conn_streams.go +++ b/internal/quic/conn_streams.go @@ -185,24 +185,46 @@ func (c *Conn) appendStreamFrames(w *packetWriter, pnum packetNumber, pto bool) for { s := c.streams.sendHead const pto = false - if !s.appendInFrames(w, pnum, pto) { - return false + + state := s.state.load() + if state&streamInSend != 0 { + s.ingate.lock() + ok := s.appendInFramesLocked(w, pnum, pto) + state = s.inUnlockNoQueue() + if !ok { + return false + } } - avail := w.avail() - if !s.appendOutFrames(w, pnum, pto) { - // We've sent some data for this stream, but it still has more to send. - // If the stream got a reasonable chance to put data in a packet, - // advance sendHead to the next stream in line, to avoid starvation. - // We'll come back to this stream after going through the others. - // - // If the packet was already mostly out of space, leave sendHead alone - // and come back to this stream again on the next packet. - if avail > 512 { - c.streams.sendHead = s.next - c.streams.sendTail = s + + if state&streamOutSend != 0 { + avail := w.avail() + s.outgate.lock() + ok := s.appendOutFramesLocked(w, pnum, pto) + state = s.outUnlockNoQueue() + if !ok { + // We've sent some data for this stream, but it still has more to send. + // If the stream got a reasonable chance to put data in a packet, + // advance sendHead to the next stream in line, to avoid starvation. + // We'll come back to this stream after going through the others. + // + // If the packet was already mostly out of space, leave sendHead alone + // and come back to this stream again on the next packet. + if avail > 512 { + c.streams.sendHead = s.next + c.streams.sendTail = s + } + return false } - return false } + + if state == streamInDone|streamOutDone { + // Stream is finished, remove it from the conn. + s.state.set(streamConnRemoved, streamConnRemoved) + delete(c.streams.streams, s.id) + + // TODO: Provide the peer with additional stream quota (MAX_STREAMS). + } + next := s.next s.next = nil if (next == s) != (s == c.streams.sendTail) { @@ -231,10 +253,16 @@ func (c *Conn) appendStreamFramesPTO(w *packetWriter, pnum packetNumber) bool { defer c.streams.sendMu.Unlock() for _, s := range c.streams.streams { const pto = true - if !s.appendInFrames(w, pnum, pto) { + s.ingate.lock() + inOK := s.appendInFramesLocked(w, pnum, pto) + s.inUnlockNoQueue() + if !inOK { return false } - if !s.appendOutFrames(w, pnum, pto) { + s.outgate.lock() + outOK := s.appendOutFramesLocked(w, pnum, pto) + s.outUnlockNoQueue() + if !outOK { return false } } diff --git a/internal/quic/conn_streams_test.go b/internal/quic/conn_streams_test.go index 877dbb94fc..9bbc994b11 100644 --- a/internal/quic/conn_streams_test.go +++ b/internal/quic/conn_streams_test.go @@ -8,6 +8,8 @@ package quic import ( "context" + "fmt" + "io" "testing" ) @@ -253,3 +255,90 @@ func TestStreamsWriteQueueFairness(t *testing.T) { } } } + +func TestStreamsShutdown(t *testing.T) { + // These tests verify that a stream is removed from the Conn's map of live streams + // after it is fully shut down. + // + // Each case consists of a setup step, after which one stream should exist, + // and a shutdown step, after which no streams should remain in the Conn. + for _, test := range []struct { + name string + side streamSide + styp streamType + setup func(*testing.T, *testConn, *Stream) + shutdown func(*testing.T, *testConn, *Stream) + }{{ + name: "closed", + side: localStream, + styp: uniStream, + setup: func(t *testing.T, tc *testConn, s *Stream) { + s.CloseContext(canceledContext()) + }, + shutdown: func(t *testing.T, tc *testConn, s *Stream) { + tc.writeAckForAll() + }, + }, { + name: "local close", + side: localStream, + styp: bidiStream, + setup: func(t *testing.T, tc *testConn, s *Stream) { + tc.writeFrames(packetType1RTT, debugFrameResetStream{ + id: s.id, + }) + s.CloseContext(canceledContext()) + }, + shutdown: func(t *testing.T, tc *testConn, s *Stream) { + tc.writeAckForAll() + }, + }, { + name: "remote reset", + side: localStream, + styp: bidiStream, + setup: func(t *testing.T, tc *testConn, s *Stream) { + s.CloseContext(canceledContext()) + tc.wantIdle("all frames after CloseContext are ignored") + tc.writeAckForAll() + }, + shutdown: func(t *testing.T, tc *testConn, s *Stream) { + tc.writeFrames(packetType1RTT, debugFrameResetStream{ + id: s.id, + }) + }, + }, { + name: "local close", + side: remoteStream, + styp: uniStream, + setup: func(t *testing.T, tc *testConn, s *Stream) { + ctx := canceledContext() + tc.writeFrames(packetType1RTT, debugFrameStream{ + id: s.id, + fin: true, + }) + if n, err := s.ReadContext(ctx, make([]byte, 16)); n != 0 || err != io.EOF { + t.Errorf("ReadContext() = %v, %v; want 0, io.EOF", n, err) + } + }, + shutdown: func(t *testing.T, tc *testConn, s *Stream) { + s.CloseRead() + }, + }} { + name := fmt.Sprintf("%v/%v/%v", test.side, test.styp, test.name) + t.Run(name, func(t *testing.T) { + tc, s := newTestConnAndStream(t, serverSide, test.side, test.styp, + permissiveTransportParameters) + tc.ignoreFrame(frameTypeStreamBase) + tc.ignoreFrame(frameTypeStopSending) + test.setup(t, tc, s) + tc.wantIdle("conn should be idle after setup") + if got, want := len(tc.conn.streams.streams), 1; got != want { + t.Fatalf("after setup: %v streams in Conn's map; want %v", got, want) + } + test.shutdown(t, tc, s) + tc.wantIdle("conn should be idle after shutdown") + if got, want := len(tc.conn.streams.streams), 0; got != want { + t.Fatalf("after shutdown: %v streams in Conn's map; want %v", got, want) + } + }) + } +} diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index d8c44558dc..ea720d5754 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -394,6 +394,7 @@ func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) { // writeAckForAll sends the Conn a datagram containing an ack for all packets up to the // last one received. func (tc *testConn) writeAckForAll() { + tc.t.Helper() if tc.lastPacket == nil { return } @@ -405,6 +406,7 @@ func (tc *testConn) writeAckForAll() { // writeAckForLatest sends the Conn a datagram containing an ack for the // most recent packet received. func (tc *testConn) writeAckForLatest() { + tc.t.Helper() if tc.lastPacket == nil { return } diff --git a/internal/quic/stream.go b/internal/quic/stream.go index 1033cbb401..2dbf4461ba 100644 --- a/internal/quic/stream.go +++ b/internal/quic/stream.go @@ -49,9 +49,38 @@ type Stream struct { outresetcode uint64 // reset code to send in RESET_STREAM outdone chan struct{} // closed when all data sent + // Atomic stream state bits. + // + // These bits provide a fast way to coordinate between the + // send and receive sides of the stream, and the conn's loop. + // + // streamIn* bits must be set with ingate held. + // streamOut* bits must be set with outgate held. + // streamConn* bits are set by the conn's loop. + state atomicBits[streamState] + prev, next *Stream // guarded by streamsState.sendMu } +type streamState uint32 + +const ( + // streamInSend and streamOutSend are set when there are + // frames to send for the inbound or outbound sides of the stream. + // For example, MAX_STREAM_DATA or STREAM_DATA_BLOCKED. + streamInSend = streamState(1 << iota) + streamOutSend + + // streamInDone and streamOutDone are set when the inbound or outbound + // sides of the stream are finished. When both are set, the stream + // can be removed from the Conn and forgotten. + streamInDone + streamOutDone + + // streamConnRemoved is set when the stream has been removed from the conn. + streamConnRemoved +) + // newStream returns a new stream. // // The stream's ingate and outgate are locked. @@ -289,15 +318,34 @@ func (s *Stream) CloseWrite() { // that the stream was terminated abruptly. // Any blocked writes will be unblocked and return errors. // -// Reset sends the application protocol error code to the peer. +// Reset sends the application protocol error code, which must be +// less than 2^62, to the peer. // It does not wait for the peer to acknowledge receipt of the error. // Use CloseContext to wait for the peer's acknowledgement. +// +// Reset does not affect reads. +// Use CloseRead to abort reads on the stream. func (s *Stream) Reset(code uint64) { + const userClosed = true + s.resetInternal(code, userClosed) +} + +func (s *Stream) resetInternal(code uint64, userClosed bool) { s.outgate.lock() defer s.outUnlock() + if s.IsReadOnly() { + return + } + if userClosed { + // Mark that the user closed the stream. + s.outclosed.set() + } if s.outreset.isSet() { return } + if code > maxVarint { + code = maxVarint + } // We could check here to see if the stream is closed and the // peer has acked all the data and the FIN, but sending an // extra RESET_STREAM in this case is harmless. @@ -310,44 +358,67 @@ func (s *Stream) Reset(code uint64) { // inUnlock unlocks s.ingate. // It sets the gate condition if reads from s will not block. -// If s has receive-related frames to write, it notifies the Conn. +// If s has receive-related frames to write or if both directions +// are done and the stream should be removed, it notifies the Conn. func (s *Stream) inUnlock() { - if s.inUnlockNoQueue() { + state := s.inUnlockNoQueue() + if state&streamInSend != 0 || state == streamInDone|streamOutDone { s.conn.queueStreamForSend(s) } } // inUnlockNoQueue is inUnlock, // but reports whether s has frames to write rather than notifying the Conn. -func (s *Stream) inUnlockNoQueue() (shouldSend bool) { +func (s *Stream) inUnlockNoQueue() streamState { canRead := s.inset.contains(s.in.start) || // data available to read s.insize == s.in.start || // at EOF s.inresetcode != -1 || // reset by peer s.inclosed.isSet() // closed locally defer s.ingate.unlock(canRead) - return s.insendmax.shouldSend() || // STREAM_MAX_DATA - s.inclosed.shouldSend() // STOP_SENDING + var state streamState + switch { + case s.IsWriteOnly(): + state = streamInDone + case s.inresetcode != -1: // reset by peer + fallthrough + case s.in.start == s.insize: // all data received and read + // We don't increase MAX_STREAMS until the user calls ReadClose or Close, + // so the receive side is not finished until inclosed is set. + if s.inclosed.isSet() { + state = streamInDone + } + case s.insendmax.shouldSend(): // STREAM_MAX_DATA + state = streamInSend + case s.inclosed.shouldSend(): // STOP_SENDING + state = streamInSend + } + const mask = streamInDone | streamInSend + return s.state.set(state, mask) } // outUnlock unlocks s.outgate. // It sets the gate condition if writes to s will not block. -// If s has send-related frames to write, it notifies the Conn. +// If s has send-related frames to write or if both directions +// are done and the stream should be removed, it notifies the Conn. func (s *Stream) outUnlock() { - if s.outUnlockNoQueue() { + state := s.outUnlockNoQueue() + if state&streamOutSend != 0 || state == streamInDone|streamOutDone { s.conn.queueStreamForSend(s) } } // outUnlockNoQueue is outUnlock, // but reports whether s has frames to write rather than notifying the Conn. -func (s *Stream) outUnlockNoQueue() (shouldSend bool) { +func (s *Stream) outUnlockNoQueue() streamState { isDone := s.outclosed.isReceived() && s.outacked.isrange(0, s.out.end) || // all data acked s.outreset.isSet() // reset locally if isDone { select { case <-s.outdone: default: - close(s.outdone) + if !s.IsReadOnly() { + close(s.outdone) + } } } lim := min(s.out.start+s.outmaxbuf, s.outwin) @@ -355,14 +426,32 @@ func (s *Stream) outUnlockNoQueue() (shouldSend bool) { s.outclosed.isSet() || // closed locally s.outreset.isSet() // reset locally defer s.outgate.unlock(canWrite) - if s.outreset.isSet() { - // If the stream is reset locally, the only frame we'll send is RESET_STREAM. - return s.outreset.shouldSend() - } - return len(s.outunsent) > 0 || // STREAM frame with data - s.outclosed.shouldSend() || // STREAM frame with FIN bit - s.outopened.shouldSend() || // STREAM frame with no data - s.outblocked.shouldSend() // STREAM_DATA_BLOCKED + var state streamState + switch { + case s.IsReadOnly(): + state = streamOutDone + case s.outclosed.isReceived() && s.outacked.isrange(0, s.out.end): // all data sent and acked + fallthrough + case s.outreset.isReceived(): // RESET_STREAM sent and acked + // We don't increase MAX_STREAMS until the user calls WriteClose or Close, + // so the send side is not finished until outclosed is set. + if s.outclosed.isSet() { + state = streamOutDone + } + case s.outreset.shouldSend(): // RESET_STREAM + state = streamOutSend + case s.outreset.isSet(): // RESET_STREAM sent but not acknowledged + case len(s.outunsent) > 0: // STREAM frame with data + state = streamOutSend + case s.outclosed.shouldSend(): // STREAM frame with FIN bit + state = streamOutSend + case s.outopened.shouldSend(): // STREAM frame with no data + state = streamOutSend + case s.outblocked.shouldSend(): // STREAM_DATA_BLOCKED + state = streamOutSend + } + const mask = streamOutDone | streamOutSend + return s.state.set(state, mask) } // handleData handles data received in a STREAM frame. @@ -431,7 +520,8 @@ func (s *Stream) checkStreamBounds(end int64, fin bool) error { func (s *Stream) handleStopSending(code uint64) error { // Peer requests that we reset this stream. // https://www.rfc-editor.org/rfc/rfc9000#section-3.5-4 - s.Reset(code) + const userReset = false + s.resetInternal(code, userReset) return nil } @@ -504,14 +594,12 @@ func (s *Stream) ackOrLossData(pnum packetNumber, start, end int64, fin bool, fa } } -// appendInFrames appends STOP_SENDING and MAX_STREAM_DATA frames +// appendInFramesLocked appends STOP_SENDING and MAX_STREAM_DATA frames // to the current packet. // // It returns true if no more frames need appending, // false if not everything fit in the current packet. -func (s *Stream) appendInFrames(w *packetWriter, pnum packetNumber, pto bool) bool { - s.ingate.lock() - defer s.inUnlockNoQueue() +func (s *Stream) appendInFramesLocked(w *packetWriter, pnum packetNumber, pto bool) bool { if s.inclosed.shouldSendPTO(pto) { // We don't currently have an API for setting the error code. // Just send zero. @@ -534,14 +622,12 @@ func (s *Stream) appendInFrames(w *packetWriter, pnum packetNumber, pto bool) bo return true } -// appendOutFrames appends RESET_STREAM, STREAM_DATA_BLOCKED, and STREAM frames +// appendOutFramesLocked appends RESET_STREAM, STREAM_DATA_BLOCKED, and STREAM frames // to the current packet. // // It returns true if no more frames need appending, // false if not everything fit in the current packet. -func (s *Stream) appendOutFrames(w *packetWriter, pnum packetNumber, pto bool) bool { - s.outgate.lock() - defer s.outUnlockNoQueue() +func (s *Stream) appendOutFramesLocked(w *packetWriter, pnum packetNumber, pto bool) bool { if s.outreset.isSet() { // RESET_STREAM if s.outreset.shouldSendPTO(pto) { diff --git a/internal/quic/stream_test.go b/internal/quic/stream_test.go index 79377c6a4a..e22e0432ef 100644 --- a/internal/quic/stream_test.go +++ b/internal/quic/stream_test.go @@ -1111,6 +1111,24 @@ func TestStreamPeerResetFollowedByData(t *testing.T) { }) } +func TestStreamResetInvalidCode(t *testing.T) { + tc, s := newTestConnAndLocalStream(t, serverSide, uniStream) + s.Reset(1 << 62) + tc.wantFrame("reset with invalid code sends a RESET_STREAM anyway", + packetType1RTT, debugFrameResetStream{ + id: s.id, + // The code we send here isn't specified, + // so this could really be any value. + code: (1 << 62) - 1, + }) +} + +func TestStreamResetReceiveOnly(t *testing.T) { + tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream) + s.Reset(0) + tc.wantIdle("resetting a receive-only stream has no effect") +} + func TestStreamPeerStopSendingForActiveStream(t *testing.T) { // "An endpoint that receives a STOP_SENDING frame MUST send a RESET_STREAM frame if // the stream is in the "Ready" or "Send" state." @@ -1145,6 +1163,21 @@ func TestStreamPeerStopSendingForActiveStream(t *testing.T) { }) } +type streamSide string + +const ( + localStream = streamSide("local") + remoteStream = streamSide("remote") +) + +func newTestConnAndStream(t *testing.T, side connSide, sside streamSide, styp streamType, opts ...any) (*testConn, *Stream) { + if sside == localStream { + return newTestConnAndLocalStream(t, side, styp, opts...) + } else { + return newTestConnAndRemoteStream(t, side, styp, opts...) + } +} + func newTestConnAndLocalStream(t *testing.T, side connSide, styp streamType, opts ...any) (*testConn, *Stream) { t.Helper() ctx := canceledContext()