Skip to content

Commit

Permalink
Override SDK-set headers for client name and others
Browse files Browse the repository at this point in the history
Fixes #440
  • Loading branch information
cretz committed Feb 13, 2024
1 parent 678bd0b commit 46c677a
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 0 deletions.
28 changes: 28 additions & 0 deletions temporalcli/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"go.temporal.io/sdk/converter"
"go.temporal.io/sdk/log"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

func (c *ClientOptions) dialClient(cctx *CommandContext) (client.Client, error) {
Expand Down Expand Up @@ -53,6 +54,14 @@ func (c *ClientOptions) dialClient(cctx *CommandContext) (client.Client, error)
clientOptions.ConnectionOptions.DialOptions, grpc.WithChainUnaryInterceptor(interceptor))
}

// Fixed header overrides
clientOptions.ConnectionOptions.DialOptions = append(
clientOptions.ConnectionOptions.DialOptions, grpc.WithChainUnaryInterceptor(fixedHeaderOverrideInterceptor))

// Additional gRPC options
clientOptions.ConnectionOptions.DialOptions = append(
clientOptions.ConnectionOptions.DialOptions, cctx.Options.AdditionalClientGRPCDialOptions...)

// TLS
var err error
if clientOptions.ConnectionOptions.TLS, err = c.tlsConfig(); err != nil {
Expand Down Expand Up @@ -92,6 +101,25 @@ func (c *ClientOptions) tlsConfig() (*tls.Config, error) {
return conf, nil
}

func fixedHeaderOverrideInterceptor(
ctx context.Context,
method string, req, reply any,
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
) error {
// The SDK sets some values on the outgoing metadata that we can't override
// via normal headers, so we have to replace directly on the metadata
md, _ := metadata.FromOutgoingContext(ctx)
if md == nil {
md = metadata.MD{}
}
md.Set("client-name", "temporal-cli")
md.Set("client-version", Version)
md.Set("supported-server-versions", ">=1.0.0 <2.0.0")
md.Set("caller-type", "operator")
ctx = metadata.NewOutgoingContext(ctx, md)
return invoker(ctx, method, req, reply, cc, opts...)
}

func payloadCodecInterceptor(namespace, codecEndpoint, codecAuth string) (grpc.UnaryClientInterceptor, error) {
codecEndpoint = strings.ReplaceAll(codecEndpoint, "{namespace}", namespace)

Expand Down
3 changes: 3 additions & 0 deletions temporalcli/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"go.temporal.io/api/failure/v1"
"go.temporal.io/api/temporalproto"
"go.temporal.io/server/common/headers"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"gopkg.in/yaml.v3"
Expand Down Expand Up @@ -67,6 +68,8 @@ type CommandOptions struct {

// Defaults to logging error then os.Exit(1)
Fail func(error)

AdditionalClientGRPCDialOptions []grpc.DialOption
}

func NewCommandContext(ctx context.Context, options CommandOptions) (*CommandContext, context.CancelFunc, error) {
Expand Down
50 changes: 50 additions & 0 deletions temporalcli/commands.workflow_exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package temporalcli_test

import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http/httptest"
"os"
"strconv"
"sync"

"github.com/google/uuid"
"go.temporal.io/api/common/v1"
Expand All @@ -16,6 +18,8 @@ import (
"go.temporal.io/sdk/converter"
"go.temporal.io/sdk/worker"
"go.temporal.io/sdk/workflow"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
)

Expand Down Expand Up @@ -149,6 +153,52 @@ func (s *SharedServerSuite) TestWorkflow_Execute_SimpleFailure() {
jsonPath(jsonOut, "closeEvent", "workflowExecutionFailedEventAttributes", "failure", "message"))
}

func (s *SharedServerSuite) TestWorkflow_Execute_ClientHeaders() {
// Capture headers
var lastHeadersClient metadata.MD
var lastHeadersLock sync.Mutex
// Capture from client
s.CommandHarness.Options.AdditionalClientGRPCDialOptions = append(
s.CommandHarness.Options.AdditionalClientGRPCDialOptions,
grpc.WithChainUnaryInterceptor(func(
ctx context.Context,
method string, req, reply any,
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
) error {
lastHeadersLock.Lock()
lastHeadersClient, _ = metadata.FromOutgoingContext(ctx)
lastHeadersLock.Unlock()
return invoker(ctx, method, req, reply, cc, opts...)
}),
)

// Capture from server
// TODO(cretz): Pending fix on server for gRPC interceptors
// var lastHeadersServer metadata.MD
// s.SetServerInterceptor(
// func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
// lastHeadersLock.Lock()
// lastHeadersServer, _ = metadata.FromIncomingContext(ctx)
// lastHeadersLock.Unlock()
// return handler(ctx, req)
// },
// )

// Exec workflow
res := s.Execute(
"workflow", "execute",
"--address", s.Address(),
"--task-queue", s.Worker.Options.TaskQueue,
"--type", "DevWorkflow",
"--workflow-id", "my-id1",
"-i", `["val1", "val2"]`,
)
s.NoError(res.Err)

// Check that the client name is there
s.Equal("temporal-cli", lastHeadersClient["client-name"][0])
}

func (s *SharedServerSuite) TestWorkflow_Execute_EnvVars() {
s.CommandHarness.Options.LookupEnv = func(key string) (string, bool) {
if key == "TEMPORAL_ADDRESS" {
Expand Down
22 changes: 22 additions & 0 deletions temporalcli/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"go.temporal.io/sdk/client"
"go.temporal.io/sdk/worker"
"go.temporal.io/sdk/workflow"
"google.golang.org/grpc"
)

type CommandHarness struct {
Expand Down Expand Up @@ -221,6 +222,9 @@ type DevServer struct {

logOutput bytes.Buffer
logOutputLock sync.RWMutex

serverInterceptor grpc.UnaryServerInterceptor
serverInterceptorLock sync.RWMutex
}

type DevServerOptions struct {
Expand Down Expand Up @@ -267,6 +271,18 @@ func StartDevServer(t *testing.T, options DevServerOptions) *DevServer {
d.Options.DynamicConfigValues = map[string]any{}
}
d.Options.DynamicConfigValues["system.forceSearchAttributesCacheRefreshOnRead"] = true
d.Options.GRPCInterceptors = append(
d.Options.GRPCInterceptors,
func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
d.serverInterceptorLock.RLock()
serverInterceptor := d.serverInterceptor
d.serverInterceptorLock.RUnlock()
if serverInterceptor != nil {
return serverInterceptor(ctx, req, info, handler)
}
return handler(ctx, req)
},
)

// Start
var err error
Expand Down Expand Up @@ -336,6 +352,12 @@ func (d *DevServer) Namespace() string {
return d.Options.ClientOptions.Namespace
}

func (d *DevServer) SetServerInterceptor(serverInterceptor grpc.UnaryServerInterceptor) {
d.serverInterceptorLock.Lock()
defer d.serverInterceptorLock.Unlock()
d.serverInterceptor = serverInterceptor
}

type DevWorker struct {
Worker worker.Worker
// Has defaults populated
Expand Down
8 changes: 8 additions & 0 deletions temporalcli/devserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import (
"go.temporal.io/server/schema/sqlite"
sqliteschema "go.temporal.io/server/schema/sqlite"
"go.temporal.io/server/temporal"
"google.golang.org/grpc"
"gopkg.in/yaml.v3"
)

Expand All @@ -71,6 +72,7 @@ type StartOptions struct {
FrontendHTTPPort int
DynamicConfigValues map[string]any
LogConfig func([]byte)
GRPCInterceptors []grpc.UnaryServerInterceptor
}

type Server struct {
Expand Down Expand Up @@ -193,6 +195,12 @@ func (s *StartOptions) buildServerOptions() ([]temporal.ServerOption, error) {
}
opts = append(opts, temporal.WithDynamicConfigClient(dynConf))
}

// gRPC interceptors if set
if len(s.GRPCInterceptors) > 0 {
opts = append(opts, temporal.WithChainedFrontendGrpcInterceptors(s.GRPCInterceptors...))
}

return opts, nil
}

Expand Down

0 comments on commit 46c677a

Please sign in to comment.