diff --git a/go/pkg/client/cas.go b/go/pkg/client/cas.go index da24af4d..bdcf5adc 100644 --- a/go/pkg/client/cas.go +++ b/go/pkg/client/cas.go @@ -51,7 +51,7 @@ func (c *Client) UploadIfMissing(ctx context.Context, data ...*chunker.Chunker) log.V(2).Infof("%d items to store", len(missing)) var batches [][]digest.Digest if c.useBatchOps { - batches = makeBatches(missing) + batches = c.makeBatches(missing) } else { log.V(2).Info("Uploading them individually") for i := range missing { @@ -134,35 +134,25 @@ func (c *Client) WriteBlob(ctx context.Context, blob []byte) (digest.Digest, err return dg, c.WriteChunked(ctx, c.ResourceNameWrite(dg.Hash, dg.Size), ch) } -const ( - // MaxBatchSz is the maximum size of a batch to upload with BatchWriteBlobs. We set it to slightly - // below 4 MB, because that is the limit of a message size in gRPC - MaxBatchSz = 4*1024*1024 - 1024 - - // MaxBatchDigests is a suggested approximate limit based on current RBE implementation. - // Above that BatchUpdateBlobs calls start to exceed a typical minute timeout. - MaxBatchDigests = 4000 -) - // BatchWriteBlobs uploads a number of blobs to the CAS. They must collectively be below the -// maximum total size for a batch upload, which is about 4 MB (see MaxBatchSz). Digests must be +// maximum total size for a batch upload, which is about 4 MB (see MaxBatchSize). Digests must be // computed in advance by the caller. In case multiple errors occur during the blob upload, the // last error will be returned. func (c *Client) BatchWriteBlobs(ctx context.Context, blobs map[digest.Digest][]byte) error { var reqs []*repb.BatchUpdateBlobsRequest_Request var sz int64 for k, b := range blobs { - sz += k.Size + sz += int64(k.Size) reqs = append(reqs, &repb.BatchUpdateBlobsRequest_Request{ Digest: k.ToProto(), Data: b, }) } - if sz > MaxBatchSz { - return fmt.Errorf("batch update of %d total bytes exceeds maximum of %d", sz, MaxBatchSz) + if sz > int64(c.MaxBatchSize) { + return fmt.Errorf("batch update of %d total bytes exceeds maximum of %d", sz, c.MaxBatchSize) } - if len(blobs) > MaxBatchDigests { - return fmt.Errorf("batch update of %d total blobs exceeds maximum of %d", len(blobs), MaxBatchDigests) + if len(blobs) > int(c.MaxBatchDigests) { + return fmt.Errorf("batch update of %d total blobs exceeds maximum of %d", len(blobs), c.MaxBatchDigests) } closure := func() error { var resp *repb.BatchUpdateBlobsResponse @@ -212,12 +202,12 @@ func (c *Client) BatchWriteBlobs(ctx context.Context, blobs map[digest.Digest][] } // BatchDownloadBlobs downloads a number of blobs from the CAS to memory. They must collectively be below the -// maximum total size for a batch read, which is about 4 MB (see MaxBatchSz). Digests must be +// maximum total size for a batch read, which is about 4 MB (see MaxBatchSize). Digests must be // computed in advance by the caller. In case multiple errors occur during the blob read, the // last error will be returned. func (c *Client) BatchDownloadBlobs(ctx context.Context, dgs []digest.Digest) (map[digest.Digest][]byte, error) { - if len(dgs) > MaxBatchDigests { - return nil, fmt.Errorf("batch read of %d total blobs exceeds maximum of %d", len(dgs), MaxBatchDigests) + if len(dgs) > int(c.MaxBatchDigests) { + return nil, fmt.Errorf("batch read of %d total blobs exceeds maximum of %d", len(dgs), c.MaxBatchDigests) } req := &repb.BatchReadBlobsRequest{InstanceName: c.InstanceName} var sz int64 @@ -227,11 +217,11 @@ func (c *Client) BatchDownloadBlobs(ctx context.Context, dgs []digest.Digest) (m foundEmpty = true continue } - sz += dg.Size + sz += int64(dg.Size) req.Digests = append(req.Digests, dg.ToProto()) } - if sz > MaxBatchSz { - return nil, fmt.Errorf("batch read of %d total bytes exceeds maximum of %d", sz, MaxBatchSz) + if sz > int64(c.MaxBatchSize) { + return nil, fmt.Errorf("batch read of %d total bytes exceeds maximum of %d", sz, c.MaxBatchSize) } res := make(map[digest.Digest][]byte) if foundEmpty { @@ -291,7 +281,7 @@ func (c *Client) BatchDownloadBlobs(ctx context.Context, dgs []digest.Digest) (m // The input list is sorted in-place; additionally, any blob bigger than the maximum will be put in // a batch of its own and the caller will need to ensure that it is uploaded with Write, not batch // operations. -func makeBatches(dgs []digest.Digest) [][]digest.Digest { +func (c *Client) makeBatches(dgs []digest.Digest) [][]digest.Digest { var batches [][]digest.Digest log.V(2).Infof("Batching %d digests", len(dgs)) sort.Slice(dgs, func(i, j int) bool { @@ -300,11 +290,19 @@ func makeBatches(dgs []digest.Digest) [][]digest.Digest { for len(dgs) > 0 { batch := []digest.Digest{dgs[len(dgs)-1]} dgs = dgs[:len(dgs)-1] - sz := batch[0].Size - for len(dgs) > 0 && len(batch) < MaxBatchDigests && dgs[0].Size <= MaxBatchSz-sz { // dg.Size+sz possibly overflows so subtract instead. - sz += dgs[0].Size + requestOverhead := marshalledFieldSize(int64(len(c.InstanceName))) + sz := requestOverhead + marshalledRequestSize(batch[0]) + var nextSize int64 + if len(dgs) > 0 { + nextSize = marshalledRequestSize(dgs[0]) + } + for len(dgs) > 0 && len(batch) < int(c.MaxBatchDigests) && nextSize <= int64(c.MaxBatchSize)-sz { // nextSize+sz possibly overflows so subtract instead. + sz += nextSize batch = append(batch, dgs[0]) dgs = dgs[1:] + if len(dgs) > 0 { + nextSize = marshalledRequestSize(dgs[0]) + } } log.V(3).Infof("Created batch of %d blobs with total size %d", len(batch), sz) batches = append(batches, batch) @@ -313,6 +311,29 @@ func makeBatches(dgs []digest.Digest) [][]digest.Digest { return batches } +func marshalledFieldSize(size int64) int64 { + return 1 + int64(proto.SizeVarint(uint64(size))) + size +} + +func marshalledRequestSize(d digest.Digest) int64 { + // An additional BatchUpdateBlobsRequest_Request includes the Digest and data fields, + // as well as the message itself. Every field has a 1-byte size tag, followed by + // the varint field size for variable-sized fields (digest hash and data). + // Note that the BatchReadBlobsResponse_Response field is similar, but includes + // and additional Status proto which can theoretically be unlimited in size. + // We do not account for it here, relying on the Client setting a large (100MB) + // limit for incoming messages. + digestSize := marshalledFieldSize(int64(len(d.Hash))) + if d.Size > 0 { + digestSize += 1 + int64(proto.SizeVarint(uint64(d.Size))) + } + reqSize := marshalledFieldSize(digestSize) + if d.Size > 0 { + reqSize += marshalledFieldSize(int64(d.Size)) + } + return marshalledFieldSize(reqSize) +} + // ReadBlob fetches a blob from the CAS into a byte slice. func (c *Client) ReadBlob(ctx context.Context, d digest.Digest) ([]byte, error) { return c.readBlob(ctx, d.Hash, d.Size, 0, 0) @@ -649,7 +670,7 @@ func (c *Client) downloadFiles(ctx context.Context, execRoot string, outputs map log.V(2).Infof("%d items to download", len(dgs)) var batches [][]digest.Digest if c.useBatchOps { - batches = makeBatches(dgs) + batches = c.makeBatches(dgs) } else { log.V(2).Info("Downloading them individually") for i := range dgs { diff --git a/go/pkg/client/cas_test.go b/go/pkg/client/cas_test.go index 56b173a4..449573f9 100644 --- a/go/pkg/client/cas_test.go +++ b/go/pkg/client/cas_test.go @@ -27,7 +27,6 @@ import ( const ( instance = "instance" - thirdBatchSz = client.MaxBatchSz / 3 ) func TestSplitEndpoints(t *testing.T) { @@ -494,6 +493,10 @@ func TestWriteBlobsBatching(t *testing.T) { defer cleanup() fake := e.Server.CAS c := e.Client.GrpcClient + c.MaxBatchSize = 500 + c.MaxBatchDigests = 4 + // Each batch request frame overhead is 13 bytes. + // A per-blob overhead is 74 bytes. tests := []struct { name string @@ -509,26 +512,26 @@ func TestWriteBlobsBatching(t *testing.T) { }, { name: "large and small blobs hitting max exactly", - sizes: []int{client.MaxBatchSz - 1, client.MaxBatchSz - 1, client.MaxBatchSz - 1, 1, 1, 1}, + sizes: []int{338, 338, 338, 1, 1, 1}, batchReqs: 3, writeReqs: 0, }, { name: "small batches of big blobs", - sizes: []int{thirdBatchSz, thirdBatchSz, thirdBatchSz, thirdBatchSz, thirdBatchSz, thirdBatchSz, thirdBatchSz}, + sizes: []int{88, 88, 88, 88, 88, 88, 88}, batchReqs: 2, writeReqs: 1, }, { name: "batch with blob that's too big", - sizes: []int{client.MaxBatchSz + 1, thirdBatchSz, thirdBatchSz, thirdBatchSz}, + sizes: []int{400, 88, 88, 88}, batchReqs: 1, writeReqs: 1, }, { - name: "many small blobs", - sizes: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - batchReqs: 1, + name: "many small blobs hitting max digests", + sizes: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + batchReqs: 4, writeReqs: 0, }, } @@ -836,6 +839,10 @@ func TestDownloadActionOutputsBatching(t *testing.T) { defer cleanup() fake := e.Server.CAS c := e.Client.GrpcClient + c.MaxBatchSize = 500 + c.MaxBatchDigests = 4 + // Each batch request frame overhead is 13 bytes. + // A per-blob overhead is 74 bytes. tests := []struct { name string @@ -849,23 +856,23 @@ func TestDownloadActionOutputsBatching(t *testing.T) { }, { name: "large and small blobs hitting max exactly", - sizes: []int{client.MaxBatchSz - 1, client.MaxBatchSz - 1, client.MaxBatchSz - 1, 1, 1, 1}, + sizes: []int{338, 338, 338, 1, 1, 1}, batchReqs: 3, }, { name: "small batches of big blobs", - sizes: []int{thirdBatchSz, thirdBatchSz, thirdBatchSz, thirdBatchSz, thirdBatchSz, thirdBatchSz, thirdBatchSz}, + sizes: []int{88, 88, 88, 88, 88, 88, 88}, batchReqs: 2, }, { name: "batch with blob that's too big", - sizes: []int{client.MaxBatchSz + 1, thirdBatchSz, thirdBatchSz, thirdBatchSz}, + sizes: []int{400, 88, 88, 88}, batchReqs: 1, }, { - name: "many small blobs", - sizes: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - batchReqs: 1, + name: "many small blobs hitting max digests", + sizes: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + batchReqs: 4, }, } diff --git a/go/pkg/client/client.go b/go/pkg/client/client.go index 28b4dcba..49500e42 100644 --- a/go/pkg/client/client.go +++ b/go/pkg/client/client.go @@ -62,7 +62,11 @@ type Client struct { Connection *grpc.ClientConn CASConnection *grpc.ClientConn // Can be different from Connection a separate CAS endpoint is provided. // ChunkMaxSize is maximum chunk size to use for CAS uploads/downloads. - ChunkMaxSize ChunkMaxSize + ChunkMaxSize ChunkMaxSize + // MaxBatchDigests is maximum amount of digests to batch in batched operations. + MaxBatchDigests MaxBatchDigests + // MaxBatchSize is maximum size in bytes of a batch request for batch operations. + MaxBatchSize MaxBatchSize useBatchOps UseBatchOps casUploaders chan bool casDownloaders chan bool @@ -70,6 +74,16 @@ type Client struct { creds credentials.PerRPCCredentials } +const ( + // DefaultMaxBatchSize is the maximum size of a batch to upload with BatchWriteBlobs. We set it to slightly + // below 4 MB, because that is the limit of a message size in gRPC + DefaultMaxBatchSize = 4*1024*1024 - 1024 + + // DefaultMaxBatchDigests is a suggested approximate limit based on current RBE implementation. + // Above that BatchUpdateBlobs calls start to exceed a typical minute timeout. + DefaultMaxBatchDigests = 4000 +) + // Close closes the underlying gRPC connection(s). func (c *Client) Close() error { err := c.Connection.Close() @@ -95,6 +109,22 @@ func (s ChunkMaxSize) Apply(c *Client) { c.ChunkMaxSize = s } +// MaxBatchDigests is maximum amount of digests to batch in batched operations. +type MaxBatchDigests int + +// Apply sets the client's maximal batch digests to s. +func (s MaxBatchDigests) Apply(c *Client) { + c.MaxBatchDigests = s +} + +// MaxBatchSize is maximum size in bytes of a batch request for batch operations. +type MaxBatchSize int64 + +// Apply sets the client's maximum batch size to s. +func (s MaxBatchSize) Apply(c *Client) { + c.MaxBatchSize = s +} + // UseBatchOps can be set to true to use batch CAS operations when uploading multiple blobs, or // false to always use individual ByteStream requests. type UseBatchOps bool @@ -288,20 +318,22 @@ func NewClient(ctx context.Context, instanceName string, params DialParams, opts return nil, err } client := &Client{ - InstanceName: instanceName, - actionCache: regrpc.NewActionCacheClient(casConn), - byteStream: bsgrpc.NewByteStreamClient(casConn), - cas: regrpc.NewContentAddressableStorageClient(casConn), - execution: regrpc.NewExecutionClient(conn), - operations: opgrpc.NewOperationsClient(conn), - rpcTimeout: time.Minute, - Connection: conn, - CASConnection: casConn, - ChunkMaxSize: chunker.DefaultChunkSize, - useBatchOps: true, - casUploaders: make(chan bool, DefaultCASConcurrency), - casDownloaders: make(chan bool, DefaultCASConcurrency), - Retrier: RetryTransient(), + InstanceName: instanceName, + actionCache: regrpc.NewActionCacheClient(casConn), + byteStream: bsgrpc.NewByteStreamClient(casConn), + cas: regrpc.NewContentAddressableStorageClient(casConn), + execution: regrpc.NewExecutionClient(conn), + operations: opgrpc.NewOperationsClient(conn), + rpcTimeout: time.Minute, + Connection: conn, + CASConnection: casConn, + ChunkMaxSize: chunker.DefaultChunkSize, + MaxBatchDigests: DefaultMaxBatchDigests, + MaxBatchSize: DefaultMaxBatchSize, + useBatchOps: true, + casUploaders: make(chan bool, DefaultCASConcurrency), + casDownloaders: make(chan bool, DefaultCASConcurrency), + Retrier: RetryTransient(), } for _, o := range opts { o.Apply(client) @@ -323,10 +355,12 @@ func (d RPCTimeout) Apply(c *Client) { // // This method is logically "protected" and is intended for use by extensions of Client. func (c *Client) RPCOpts() []grpc.CallOption { + // Set a high limit on receiving large messages from the server. + opts := []grpc.CallOption{grpc.MaxCallRecvMsgSize(100 * 1024 * 1024)} if c.creds == nil { - return nil + return opts } - return []grpc.CallOption{grpc.PerRPCCredentials(c.creds)} + return append(opts, grpc.PerRPCCredentials(c.creds)) } // CallWithTimeout executes the given function f with a context that times out after an RPC timeout. diff --git a/go/pkg/fakes/cas.go b/go/pkg/fakes/cas.go index aa233dd9..488c44b7 100644 --- a/go/pkg/fakes/cas.go +++ b/go/pkg/fakes/cas.go @@ -12,6 +12,7 @@ import ( "github.com/bazelbuild/remote-apis-sdks/go/pkg/chunker" "github.com/bazelbuild/remote-apis-sdks/go/pkg/client" "github.com/bazelbuild/remote-apis-sdks/go/pkg/digest" + "github.com/golang/protobuf/proto" "github.com/pborman/uuid" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -198,6 +199,8 @@ func (f *Writer) QueryWriteStatus(context.Context, *bspb.QueryWriteStatusRequest // CAS is a fake CAS that implements FindMissingBlobs, Read and Write, storing stored blobs // in a map. It also counts the number of requests to store received, for validating batching logic. type CAS struct { + // Maximum batch byte size to verify requests against. + BatchSize int blobs map[digest.Digest][]byte reads map[digest.Digest]int writes map[digest.Digest]int @@ -210,7 +213,7 @@ type CAS struct { // NewCAS returns a new empty fake CAS. func NewCAS() *CAS { - c := &CAS{} + c := &CAS{BatchSize: client.DefaultMaxBatchSize} c.Clear() return c } @@ -304,12 +307,10 @@ func (f *CAS) BatchUpdateBlobs(ctx context.Context, req *repb.BatchUpdateBlobsRe return nil, status.Error(codes.InvalidArgument, "test fake expected instance name \"instance\"") } - var tot int64 - for _, r := range req.Requests { - tot += r.Digest.SizeBytes - } - if tot > client.MaxBatchSz { - return nil, status.Errorf(codes.InvalidArgument, "test fake received batch update for more than the maximum of %d bytes: %d bytes", client.MaxBatchSz, tot) + reqBlob, _ := proto.Marshal(req) + size := len(reqBlob) + if size > f.BatchSize { + return nil, status.Errorf(codes.InvalidArgument, "test fake received batch update for more than the maximum of %d bytes: %d bytes", f.BatchSize, size) } var resps []*repb.BatchUpdateBlobsResponse_Response @@ -355,12 +356,10 @@ func (f *CAS) BatchReadBlobs(ctx context.Context, req *repb.BatchReadBlobsReques return nil, status.Error(codes.InvalidArgument, "test fake expected instance name \"instance\"") } - var tot int64 - for _, dg := range req.Digests { - tot += dg.SizeBytes - } - if tot > client.MaxBatchSz { - return nil, status.Errorf(codes.InvalidArgument, "test fake received batch read for more than the maximum of %d bytes: %d bytes", client.MaxBatchSz, tot) + reqBlob, _ := proto.Marshal(req) + size := len(reqBlob) + if size > f.BatchSize { + return nil, status.Errorf(codes.InvalidArgument, "test fake received batch read for more than the maximum of %d bytes: %d bytes", f.BatchSize, size) } var resps []*repb.BatchReadBlobsResponse_Response