diff --git a/src/libraries/System.IO/tests/BufferedStream/BufferedStreamTests.cs b/src/libraries/System.IO/tests/BufferedStream/BufferedStreamTests.cs index 8395a64a98de9..0826875520205 100644 --- a/src/libraries/System.IO/tests/BufferedStream/BufferedStreamTests.cs +++ b/src/libraries/System.IO/tests/BufferedStream/BufferedStreamTests.cs @@ -128,6 +128,94 @@ public async Task ShouldNotFlushUnderlyingStreamIfReadOnly(bool underlyingCanSee Assert.Equal(0, wrapper.TimesCalled(nameof(wrapper.FlushAsync))); } + [Theory] + [MemberData(nameof(SetPosMethods))] + public void SetPositionInsideBufferRange_Read_WillNotReadUnderlyingStreamAgain(int sharedBufSize, Action setPos) + { + var trackingStream = new CallTrackingStream(new MemoryStream()); + var bufferedStream = new BufferedStream(trackingStream, sharedBufSize); + bufferedStream.Write(Enumerable.Range(0, sharedBufSize * 2).Select(i => (byte)i).ToArray(), 0, sharedBufSize * 2); + setPos(bufferedStream, 0); + + var readBuf = new byte[sharedBufSize - 1]; + + // First half part verification + byte[] expectedReadBuf = Enumerable.Range(0, sharedBufSize - 1).Select(i => (byte)i).ToArray(); + + // Call Read() to fill shared read buffer + int readBytes = bufferedStream.Read(readBuf, 0, readBuf.Length); + Assert.Equal(readBuf.Length, readBytes); + Assert.Equal(sharedBufSize - 1, bufferedStream.Position); + Assert.Equal(expectedReadBuf, readBuf); + Assert.Equal(1, trackingStream.TimesCalled(nameof(trackingStream.Read))); + + // Set position inside range of shared read buffer + for (int pos = 0; pos < sharedBufSize - 1; pos++) + { + setPos(bufferedStream, pos); + + readBytes = bufferedStream.Read(readBuf, pos, readBuf.Length - pos); + Assert.Equal(readBuf.Length - pos, readBytes); + Assert.Equal(sharedBufSize - 1, bufferedStream.Position); + Assert.Equal(expectedReadBuf, readBuf); + // Should not trigger underlying stream's Read() + Assert.Equal(1, trackingStream.TimesCalled(nameof(trackingStream.Read))); + } + + Assert.Equal(sharedBufSize - 1, bufferedStream.ReadByte()); + Assert.Equal(sharedBufSize, bufferedStream.Position); + // Should not trigger underlying stream's Read() + Assert.Equal(1, trackingStream.TimesCalled(nameof(trackingStream.Read))); + + // Second half part verification + expectedReadBuf = Enumerable.Range(sharedBufSize, sharedBufSize - 1).Select(i => (byte)i).ToArray(); + // Call Read() to fill shared read buffer + readBytes = bufferedStream.Read(readBuf, 0, readBuf.Length); + Assert.Equal(readBuf.Length, readBytes); + Assert.Equal(sharedBufSize * 2 - 1, bufferedStream.Position); + Assert.Equal(expectedReadBuf, readBuf); + Assert.Equal(2, trackingStream.TimesCalled(nameof(trackingStream.Read))); + + // Set position inside range of shared read buffer + for (int pos = 0; pos < sharedBufSize - 1; pos++) + { + setPos(bufferedStream, sharedBufSize + pos); + + readBytes = bufferedStream.Read(readBuf, pos, readBuf.Length - pos); + Assert.Equal(readBuf.Length - pos, readBytes); + Assert.Equal(sharedBufSize * 2 - 1, bufferedStream.Position); + Assert.Equal(expectedReadBuf, readBuf); + // Should not trigger underlying stream's Read() + Assert.Equal(2, trackingStream.TimesCalled(nameof(trackingStream.Read))); + } + + Assert.Equal(sharedBufSize * 2 - 1, bufferedStream.ReadByte()); + Assert.Equal(sharedBufSize * 2, bufferedStream.Position); + // Should not trigger underlying stream's Read() + Assert.Equal(2, trackingStream.TimesCalled(nameof(trackingStream.Read))); + } + + public static IEnumerable SetPosMethods + { + get + { + var setByPosition = (Action)((stream, pos) => stream.Position = pos); + var seekFromBegin = (Action)((stream, pos) => stream.Seek(pos, SeekOrigin.Begin)); + var seekFromCurrent = (Action)((stream, pos) => stream.Seek(pos - stream.Position, SeekOrigin.Current)); + var seekFromEnd = (Action)((stream, pos) => stream.Seek(pos - stream.Length, SeekOrigin.End)); + + yield return new object[] { 3, setByPosition }; + yield return new object[] { 3, seekFromBegin }; + yield return new object[] { 3, seekFromCurrent }; + yield return new object[] { 3, seekFromEnd }; + + yield return new object[] { 10, setByPosition }; + yield return new object[] { 10, seekFromBegin }; + yield return new object[] { 10, seekFromCurrent }; + yield return new object[] { 10, seekFromEnd }; + } + } + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] public async Task ConcurrentOperationsAreSerialized() { diff --git a/src/libraries/System.Private.CoreLib/src/System/IO/BufferedStream.cs b/src/libraries/System.Private.CoreLib/src/System/IO/BufferedStream.cs index 1b1f5ad44a46a..bcdbc292ab415 100644 --- a/src/libraries/System.Private.CoreLib/src/System/IO/BufferedStream.cs +++ b/src/libraries/System.Private.CoreLib/src/System/IO/BufferedStream.cs @@ -205,15 +205,7 @@ public override long Position if (value < 0) ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.value, ExceptionResource.ArgumentOutOfRange_NeedNonNegNum); - EnsureNotClosed(); - EnsureCanSeek(); - - if (_writePos > 0) - FlushWrite(); - - _readPos = 0; - _readLen = 0; - _stream!.Seek(value, SeekOrigin.Begin); + Seek(value, SeekOrigin.Begin); } }