Skip to content

Commit

Permalink
Eliminate grpc dispatch proto marshaling overhead with Codec (#5)
Browse files Browse the repository at this point in the history
* changing fiber response request to byte. add new implementation of new codec

* tidy up codec and interface changes for request response

* fix import syntax

* add default codec as part of fibercodec

* patch fiber dispatcher servce method syntax

* fix test syntax for method name

* minor update on codec syntax

Co-authored-by: ningjie.lee <ningjie.lee@gojek.com>
  • Loading branch information
leonlnj and leonlnj authored Sep 19, 2022
1 parent cc45f83 commit e1bb222
Show file tree
Hide file tree
Showing 18 changed files with 145 additions and 265 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ traffic mixers from a set of composable abstract network components.

Core components of fiber are transport agnostic, however, there is
Go's `net/http`-based implementation provided in [fiber/http](http) package
and a grpc implementation using server reflection in [fiber/grpc](grpc).
and a grpc implementation in [fiber/grpc](grpc).

The grpc implementation will return a [dynamicpb message](https://pkg.go.dev/google.golang.org/protobuf/types/dynamicpb)
and it is expected that the client [marshal](https://pkg.go.dev/github.com/golang/protobuf/proto#Marshal) the message and unmarshall into the intended proto response.
The grpc implementation will use the byte payload from the request and response using a custom codec to minimize marshaling overhead.
It is expected that the client [marshal](https://pkg.go.dev/github.com/golang/protobuf/proto#Marshal) the message and unmarshall into the intended proto response.

## Usage

Expand Down
2 changes: 1 addition & 1 deletion cached_payload.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ type CachedPayload struct {
}

// Payload returns the cached []byte contents
func (b *CachedPayload) Payload() interface{} {
func (b *CachedPayload) Payload() []byte {
return b.data
}
2 changes: 1 addition & 1 deletion eager_router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func TestEagerRouter_Dispatch(t *testing.T) {

assert.Equal(t, len(tt.expected), len(received), tt.name)
for i := 0; i < len(tt.expected); i++ {
assert.Equal(t, string(tt.expected[i].Payload().([]byte)), string(received[i].Payload().([]byte)), tt.name)
assert.Equal(t, string(tt.expected[i].Payload()), string(received[i].Payload()), tt.name)
assert.Equal(t, tt.expected[i].StatusCode(), received[i].StatusCode(), tt.name)
}
strategy.AssertExpectations(t)
Expand Down
35 changes: 12 additions & 23 deletions example/simplegrpc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"context"
"fmt"
"log"

"github.com/gojek/fiber"
Expand Down Expand Up @@ -61,41 +60,31 @@ func main() {
"route-b": proxy2,
})

var req = &grpc.Request{
Message: &testproto.PredictValuesRequest{
PredictionRows: []*testproto.PredictionRow{
{
RowId: "1",
},
{
RowId: "2",
},
bytePayload, _ := proto.Marshal(&testproto.PredictValuesRequest{
PredictionRows: []*testproto.PredictionRow{
{
RowId: "1",
},
{
RowId: "2",
},
},
})
var req = &grpc.Request{
Message: bytePayload,
}

resp, ok := <-component.Dispatch(context.Background(), req).Iter()
if ok {
if resp.StatusCode() == int(codes.OK) {
log.Print(resp.Payload())

//values can be retrieved using protoReflect or marshalled into proto
payload, ok := resp.Payload().(proto.Message)
if !ok {
log.Fatalf("fail to convert response to proto")
}
payloadByte, err := proto.Marshal(payload)
if err != nil {
log.Fatalf("fail to marshal to proto")
}
responseProto := &testproto.PredictValuesResponse{}
err = proto.Unmarshal(payloadByte, responseProto)
err := proto.Unmarshal(resp.Payload(), responseProto)
if err != nil {
log.Fatalf("fail to unmarshal to proto")
}
log.Print(responseProto.String())
} else {
log.Fatalf(fmt.Sprintf("%s", resp.Payload()))
log.Fatalf(string(resp.Payload()))
}
} else {
log.Fatalf("fail to receive response queue")
Expand Down
35 changes: 12 additions & 23 deletions example/simplegrpcfromconfig/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"context"
"fmt"
"log"

"github.com/gojek/fiber/config"
Expand Down Expand Up @@ -32,42 +31,32 @@ func main() {
if err != nil {
log.Fatalf("\nerror: %v\n", err)
}

var req = &grpc.Request{
Message: &testproto.PredictValuesRequest{
PredictionRows: []*testproto.PredictionRow{
{
RowId: "1",
},
{
RowId: "2",
},
bytePayload, _ := proto.Marshal(&testproto.PredictValuesRequest{
PredictionRows: []*testproto.PredictionRow{
{
RowId: "1",
},
{
RowId: "2",
},
},
})
var req = &grpc.Request{
Message: bytePayload,
}

resp, ok := <-component.Dispatch(context.Background(), req).Iter()
if ok {
if resp.StatusCode() == int(codes.OK) {
log.Print(resp.Payload())

//values can be retrieved using protoReflect or marshalled into proto
payload, ok := resp.Payload().(proto.Message)
if !ok {
log.Fatalf("fail to convert response to proto")
}
payloadByte, err := proto.Marshal(payload)
if err != nil {
log.Fatalf("fail to marshal to proto")
}
responseProto := &testproto.PredictValuesResponse{}
err = proto.Unmarshal(payloadByte, responseProto)
err = proto.Unmarshal(resp.Payload(), responseProto)
if err != nil {
log.Fatalf("fail to unmarshal to proto")
}
log.Print(responseProto.String())
} else {
log.Fatalf(fmt.Sprintf("%s", resp.Payload()))
log.Fatalf(string(resp.Payload()))
}
} else {
log.Fatalf("fail to receive response queue")
Expand Down
48 changes: 48 additions & 0 deletions grpc/codec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package grpc

import (
"io"

"google.golang.org/grpc/encoding"
)

// CodecName is the name registered for the proto compressor.
const codecName = "fiber"

// FiberCodec is a custom codec to prevent marshaling and unmarshalling
// when unnecessary, base on the inputs
type FiberCodec struct {
defaultCodec encoding.Codec
}

// Marshal will attempt to pass the request directly if it is a byte slice,
// otherwise unmarshal the request proto using the default implementation
func (fc *FiberCodec) Marshal(v interface{}) ([]byte, error) {
b, ok := v.([]byte)
if ok {
return b, nil
}
return fc.getDefaultCodec().Marshal(v)
}

// Unmarshal will attempt to write the request directly if it is a writer,
// otherwise unmarshal the request proto using the default implementation
func (fc *FiberCodec) Unmarshal(data []byte, v interface{}) error {
writer, ok := v.(io.Writer)
if ok {
_, err := writer.Write(data)
return err
}
return fc.getDefaultCodec().Unmarshal(data, v)
}

func (*FiberCodec) Name() string {
return codecName
}

func (fc *FiberCodec) getDefaultCodec() encoding.Codec {
if fc.defaultCodec == nil {
fc.defaultCodec = encoding.GetCodec("proto")
}
return fc.defaultCodec
}
140 changes: 21 additions & 119 deletions grpc/dispatcher.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package grpc

import (
"bytes"
"context"
"errors"
"fmt"
"strings"
"time"

"github.com/gojek/fiber"
Expand All @@ -13,17 +13,15 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/dynamicpb"
)

func init() {
encoding.RegisterCodec(&FiberCodec{})
}

const (
TimeoutDefault = time.Second
)
Expand All @@ -36,8 +34,6 @@ type Dispatcher struct {
endpoint string
// conn is the grpc connection dialed upon creation of dispatcher
conn *grpc.ClientConn
// ResponseProto is the proto return type of the service.
responseProto proto.Message
}

type DispatcherConfig struct {
Expand All @@ -61,9 +57,20 @@ func (d *Dispatcher) Do(request fiber.Request) fiber.Response {
defer cancel()
ctx = metadata.NewOutgoingContext(ctx, grpcRequest.Metadata)

responseProto := proto.Clone(d.responseProto)
response := new(bytes.Buffer)
var responseHeader metadata.MD
err := d.conn.Invoke(ctx, d.serviceMethod, grpcRequest.Payload(), responseProto, grpc.Header(&responseHeader))

// Dispatcher will send both request and payload as bytes, with the use of codec
// to prevent marshaling. The codec content type will be sent with request and
// the server will attempt to unmarshal with the codec.
err := d.conn.Invoke(
ctx,
d.serviceMethod,
grpcRequest.Payload(),
response,
grpc.Header(&responseHeader),
grpc.CallContentSubtype(codecName),
)
if err != nil {
// if ok is false, unknown codes.Unknown and Status msg is returned in Status
responseStatus, _ := status.FromError(err)
Expand All @@ -76,7 +83,7 @@ func (d *Dispatcher) Do(request fiber.Request) fiber.Response {

return &Response{
Metadata: responseHeader,
Message: responseProto,
Message: response.Bytes(),
Status: *status.New(codes.OK, "Success"),
}
}
Expand Down Expand Up @@ -104,116 +111,11 @@ func NewDispatcher(config DispatcherConfig) (*Dispatcher, error) {
errors.New("grpc dispatcher: "+responseStatus.String()))
}

// Get reflection response from reflection server, which contain FileDescriptorProtos
reflectionResponse, err := getReflectionResponse(conn, config.Service)
if err != nil {
return nil, err
}
fileDescriptorProtoBytes := reflectionResponse.GetFileDescriptorResponse().GetFileDescriptorProto()

messageDescriptor, err := getMessageDescriptor(fileDescriptorProtoBytes, config.Service, config.Method)
if err != nil {
return nil, err
}

dispatcher := &Dispatcher{
timeout: configuredTimeout,
serviceMethod: fmt.Sprintf("%s/%s", config.Service, config.Method),
serviceMethod: fmt.Sprintf("/%s/%s", config.Service, config.Method),
endpoint: config.Endpoint,
conn: conn,
responseProto: dynamicpb.NewMessage(messageDescriptor),
}
return dispatcher, nil
}

func getReflectionResponse(conn *grpc.ClientConn, serviceName string) (*grpc_reflection_v1alpha.ServerReflectionResponse, error) {
// create a reflection client and get FileDescriptorProtos
reflectionClient := grpc_reflection_v1alpha.NewServerReflectionClient(conn)
req := &grpc_reflection_v1alpha.ServerReflectionRequest{
MessageRequest: &grpc_reflection_v1alpha.ServerReflectionRequest_FileContainingSymbol{
FileContainingSymbol: serviceName,
},
}
reflectionInfoClient, err := reflectionClient.ServerReflectionInfo(context.Background())
if err != nil {
return nil, fiberError.NewFiberError(
protocol.GRPC,
errors.New("grpc dispatcher: unable to get reflection information, ensure server reflection is enable and config are correct"))
}
if err = reflectionInfoClient.Send(req); err != nil {
return nil, fiberError.NewFiberError(protocol.GRPC, err)
}
reflectionResponse, err := reflectionInfoClient.Recv()
if err != nil {
return nil, fiberError.NewFiberError(protocol.GRPC, err)
}

return reflectionResponse, nil
}

func getMessageDescriptor(fileDescriptorProtoBytes [][]byte, serviceName string, methodName string) (protoreflect.MessageDescriptor, error) {
fileDescriptorProto, outputProtoName, err := getFileDescriptorProto(fileDescriptorProtoBytes, serviceName, methodName)
if err != nil {
return nil, err
}

messageDescriptor, err := getMessageDescriptorByName(fileDescriptorProto, outputProtoName)
if err != nil {
return nil, err
}
return messageDescriptor, nil
}

func getFileDescriptorProto(fileDescriptorProtoBytes [][]byte, serviceName string, methodName string) (*descriptorpb.FileDescriptorProto, string, error) {
var fileDescriptorProto *descriptorpb.FileDescriptorProto
var outputProtoName string

for _, fdpByte := range fileDescriptorProtoBytes {
fdp := &descriptorpb.FileDescriptorProto{}
if err := proto.Unmarshal(fdpByte, fdp); err != nil {
return nil, "", fiberError.NewFiberError(protocol.GRPC, err)
}

for _, service := range fdp.Service {
// find matching service descriptors from file descriptor
if serviceName == fmt.Sprintf("%s.%s", fdp.GetPackage(), service.GetName()) {
// find matching method from service descriptor
for _, method := range service.Method {
if method.GetName() == methodName {
outputType := method.GetOutputType()
//Get the proto name without package
outputProtoName = outputType[strings.LastIndex(outputType, ".")+1:]
fileDescriptorProto = fdp
break
}
}
}
if fileDescriptorProto != nil {
break
}
}
if fileDescriptorProto != nil {
break
}
}

if fileDescriptorProto == nil {
return nil, "", fiberError.NewFiberError(
protocol.GRPC,
errors.New("grpc dispatcher: unable to fetch file descriptors, ensure config are correct"))
}
return fileDescriptorProto, outputProtoName, nil
}

func getMessageDescriptorByName(fileDescriptorProto *descriptorpb.FileDescriptorProto, outputProtoName string) (protoreflect.MessageDescriptor, error) {
// Create a FileDescriptor from FileDescriptorProto, and get MessageDescriptor to create a dynamic message
// Note: It might be required to register new proto using protoregistry.Files.RegisterFile() at runtime
fileDescriptor, err := protodesc.NewFile(fileDescriptorProto, protoregistry.GlobalFiles)
if err != nil {
return nil, fiberError.NewFiberError(
protocol.GRPC,
errors.New("grpc dispatcher: unable to find proto in registry"))
}
messageDescriptor := fileDescriptor.Messages().ByName(protoreflect.Name(outputProtoName))
return messageDescriptor, nil
}
Loading

0 comments on commit e1bb222

Please sign in to comment.