From c7bdae06064845ccaa1ea0fcdc4277cdb3271f2e Mon Sep 17 00:00:00 2001 From: Quinn Klassen Date: Wed, 8 May 2024 16:36:42 -0700 Subject: [PATCH 1/2] Require update handler to have a context --- internal/internal_update.go | 15 ++++++++++++--- internal/internal_update_test.go | 10 ++++++---- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/internal/internal_update.go b/internal/internal_update.go index ab1b16b99..d06696db3 100644 --- a/internal/internal_update.go +++ b/internal/internal_update.go @@ -379,7 +379,7 @@ func (up *updateProtocol) HasCompleted() bool { // // 1. is a function // 2. has exactly one return parameter -// 3. the one return prarmeter is of type `error` +// 3. the one return parameter is of type `error` func validateValidatorFn(fn interface{}) error { fnType := reflect.TypeOf(fn) if fnType.Kind() != reflect.Func { @@ -405,13 +405,22 @@ func validateValidatorFn(fn interface{}) error { // validateUpdateHandlerFn validates that the supplied interface // // 1. is a function -// 2. has one or two return parameters, the last of which is of type `error` -// 3. if there are two return parameters, the first is a serializable type +// 2. has at least one parameter, the first of which is of type `workflow.Context` +// 3. has one or two return parameters, the last of which is of type `error` +// 4. if there are two return parameters, the first is a serializable type func validateUpdateHandlerFn(fn interface{}) error { fnType := reflect.TypeOf(fn) if fnType.Kind() != reflect.Func { return fmt.Errorf("handler must be function but was %s", fnType.Kind()) } + if fnType.NumIn() == 0 { + return errors.New("first parameter of handler must be a workflow.Context") + } else if !isWorkflowContext(fnType.In(0)) { + return fmt.Errorf( + "first parameter of handler must be a workflow.Context but found %v", + fnType.In(0).Kind(), + ) + } switch fnType.NumOut() { case 1: if !isError(fnType.Out(0)) { diff --git a/internal/internal_update_test.go b/internal/internal_update_test.go index e7759251c..473b50cb8 100644 --- a/internal/internal_update_test.go +++ b/internal/internal_update_test.go @@ -151,10 +151,12 @@ func TestUpdateHandlerFnValidation(t *testing.T) { {require.Error, func() int { return 0 }}, {require.Error, func(Context, int) (int, int, error) { return 0, 0, nil }}, {require.Error, func(int) (chan int, error) { return nil, nil }}, - {require.NoError, func() error { return nil }}, + {require.Error, func() error { return nil }}, + {require.Error, func(int, int, string) error { return nil }}, + {require.Error, func(int) error { return nil }}, + {require.NoError, func(Context, int) error { return nil }}, + {require.NoError, func(Context, int, int, string) error { return nil }}, {require.NoError, func(Context) error { return nil }}, - {require.NoError, func(int) error { return nil }}, - {require.NoError, func(int, int, string) error { return nil }}, {require.NoError, func(Context, int, int, string) error { return nil }}, } { t.Run(reflect.TypeOf(tc.fn).String(), func(t *testing.T) { @@ -219,7 +221,7 @@ func TestDefaultUpdateHandler(t *testing.T) { t, ctx, "unused_handler", - func() error { panic("should not be called") }, + func(ctx Context) error { panic("should not be called") }, UpdateHandlerOptions{}, ) }, From 9b8a67baa19533910b50452cd472c29244eba495 Mon Sep 17 00:00:00 2001 From: Quinn Klassen Date: Wed, 8 May 2024 19:36:16 -0700 Subject: [PATCH 2/2] Fix unit test --- internal/internal_update_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/internal_update_test.go b/internal/internal_update_test.go index 473b50cb8..29dc62650 100644 --- a/internal/internal_update_test.go +++ b/internal/internal_update_test.go @@ -103,7 +103,7 @@ func TestUpdateHandlerPanicHandling(t *testing.T) { interceptor, ctx, err := newWorkflowContext(env, nil) require.NoError(t, err) - panicFunc := func() error { panic("intentional") } + panicFunc := func(ctx Context) error { panic("intentional") } dispatcher, _ := newDispatcher( ctx, interceptor, @@ -123,7 +123,7 @@ func TestUpdateHandlerPanicHandling(t *testing.T) { interceptor, ctx, err := newWorkflowContext(env, nil) require.NoError(t, err) - panicFunc := func() error { panic("intentional") } + panicFunc := func(ctx Context) error { panic("intentional") } dispatcher, _ := newDispatcher( ctx, interceptor,