Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transport/http: Add fix for HTTP request stream not seekable #357

Merged
merged 1 commit into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions transport/http/checksum_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ func (m *contentMD5Checksum) HandleBuild(
stream := req.GetStream()
// compute checksum if payload is explicit
if stream != nil {
if !req.IsStreamSeekable() {
return out, metadata, fmt.Errorf(
"unseekable stream is not supported for computing md5 checksum")
}

v, err := computeMD5Checksum(stream)
if err != nil {
return out, metadata, fmt.Errorf("error computing md5 checksum, %w", err)
Expand Down
3 changes: 2 additions & 1 deletion transport/http/checksum_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestChecksumMiddleware(t *testing.T) {
"nil body": {},
"unseekable payload": {
payload: bytes.NewBuffer([]byte(`xyz`)),
expectError: "error rewinding request stream",
expectError: "unseekable stream is not supported",
},
}

Expand All @@ -61,6 +61,7 @@ func TestChecksumMiddleware(t *testing.T) {
if e, a := c.expectError, err.Error(); !strings.Contains(a, e) {
t.Fatalf("expect error to contain %q, got %v", e, a)
}
return
} else if err != nil {
t.Fatalf("expect no error, got %v", err)
}
Expand Down
40 changes: 29 additions & 11 deletions transport/http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,23 @@ func (r *Request) Clone() *Request {
// to the request and ok set. If the length cannot be determined, an error will
// be returned.
func (r *Request) StreamLength() (size int64, ok bool, err error) {
if r.stream == nil {
return streamLength(r.stream, r.isStreamSeekable, r.streamStartPos)
}

func streamLength(stream io.Reader, seekable bool, startPos int64) (size int64, ok bool, err error) {
if stream == nil {
return 0, true, nil
}

if l, ok := r.stream.(interface{ Len() int }); ok {
if l, ok := stream.(interface{ Len() int }); ok {
return int64(l.Len()), true, nil
}

if !r.isStreamSeekable {
if !seekable {
return 0, false, nil
}

s := r.stream.(io.Seeker)
s := stream.(io.Seeker)
endOffset, err := s.Seek(0, io.SeekEnd)
if err != nil {
return 0, false, err
Expand All @@ -69,12 +73,12 @@ func (r *Request) StreamLength() (size int64, ok bool, err error) {
// file, and wants to skip the first N bytes uploading the rest. The
// application would move the file's offset N bytes, then hand it off to
// the SDK to send the remaining. The SDK should respect that initial offset.
_, err = s.Seek(r.streamStartPos, io.SeekStart)
_, err = s.Seek(startPos, io.SeekStart)
if err != nil {
return 0, false, err
}

return endOffset - r.streamStartPos, true, nil
return endOffset - startPos, true, nil
}

// RewindStream will rewind the io.Reader to the relative start position if it
Expand Down Expand Up @@ -103,27 +107,41 @@ func (r *Request) IsStreamSeekable() bool {
return r.isStreamSeekable
}

// SetStream returns a clone of the request with the stream set to the provided reader.
// May return an error if the provided reader is seekable but returns an error.
// SetStream returns a clone of the request with the stream set to the provided
// reader. May return an error if the provided reader is seekable but returns
// an error.
func (r *Request) SetStream(reader io.Reader) (rc *Request, err error) {
rc = r.Clone()

if reader == http.NoBody {
reader = nil
}

var isStreamSeekable bool
var streamStartPos int64
switch v := reader.(type) {
case io.Seeker:
n, err := v.Seek(0, io.SeekCurrent)
if err != nil {
return r, err
}
rc.isStreamSeekable = true
rc.streamStartPos = n
isStreamSeekable = true
streamStartPos = n
default:
rc.isStreamSeekable = false
// If the stream length can be determined, and is determined to be empty,
// use a nil stream to prevent confusion between empty vs not-empty
// streams.
length, ok, err := streamLength(reader, false, 0)
if err != nil {
return nil, err
} else if ok && length == 0 {
reader = nil
}
}

rc.stream = reader
rc.isStreamSeekable = isStreamSeekable
rc.streamStartPos = streamStartPos

return rc, err
}
Expand Down
10 changes: 7 additions & 3 deletions transport/http/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@ func TestRequestRewindable(t *testing.T) {
"rewindable": {
Stream: bytes.NewReader([]byte{}),
},
"not rewindable": {
Stream: bytes.NewBuffer([]byte{}),
"empty not rewindable": {
Stream: bytes.NewBuffer([]byte{}),
// ExpectErr: "stream is not seekable",
},
"not empty not rewindable": {
Stream: bytes.NewBuffer([]byte("abc123")),
ExpectErr: "stream is not seekable",
},
"nil stream": {},
Expand Down Expand Up @@ -121,7 +125,7 @@ func TestRequestSetStream(t *testing.T) {
},
"empty unseekable stream": {
reader: bytes.NewBuffer([]byte{}),
expectNilStream: false,
expectNilStream: true,
expectNilBody: true,
},
"empty seekable stream": {
Expand Down