diff --git a/lock/etcd_test.go b/lock/etcd_test.go index 6ff6611..6ae2323 100644 --- a/lock/etcd_test.go +++ b/lock/etcd_test.go @@ -16,132 +16,171 @@ package lock import ( "errors" - "reflect" "testing" - "go.etcd.io/etcd/client" + pb "go.etcd.io/etcd/api/v3/mvccpb" + client "go.etcd.io/etcd/client/v3" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/net/context" ) +// testEtcdClient is the struct used to mock the `etcd` +// client type testEtcdClient struct { - err error - resp *client.Response + err error + putResp *client.PutResponse + getResp *client.GetResponse } -func (t *testEtcdClient) Get(ctx context.Context, key string, opts *client.GetOptions) (*client.Response, error) { - return t.resp, t.err -} +func (t *testEtcdClient) Put(ctx context.Context, key, val string, opts ...client.OpOption) (*client.PutResponse, error) { + return t.putResp, t.err -func (t *testEtcdClient) Set(ctx context.Context, key, value string, opts *client.SetOptions) (*client.Response, error) { - return t.resp, t.err } -func (t *testEtcdClient) Create(ctx context.Context, key, value string) (*client.Response, error) { - return t.resp, t.err +func (t *testEtcdClient) Get(ctx context.Context, key string, opts ...client.OpOption) (*client.GetResponse, error) { + return t.getResp, t.err + } func TestEtcdLockClientInit(t *testing.T) { - for i, tt := range []struct { - ee error - want bool - group string - keypath string - }{ - {nil, false, "", SemaphorePrefix}, - {client.Error{Code: client.ErrorCodeNodeExist}, false, "", SemaphorePrefix}, - {client.Error{Code: client.ErrorCodeKeyNotFound}, true, "", SemaphorePrefix}, - {errors.New("some random error"), true, "", SemaphorePrefix}, - {client.Error{Code: client.ErrorCodeKeyNotFound}, true, "database", "coreos.com/updateengine/rebootlock/groups/database/semaphore"}, - {nil, false, "prod/database", "coreos.com/updateengine/rebootlock/groups/prod%2Fdatabase/semaphore"}, - } { - elc, got := NewEtcdLockClient(&testEtcdClient{err: tt.ee}, tt.group) - if (got != nil) != tt.want { - t.Errorf("case %d: unexpected error state initializing Client: got %v", i, got) - continue - } - - if got != nil { - continue - } - - if elc.keypath != tt.keypath { - t.Errorf("case %d: unexpected etcd key path: got %v want %v", i, elc.keypath, tt.keypath) - } - } -} + t.Run("Success", func(t *testing.T) { + elc, err := NewEtcdLockClient(&testEtcdClient{ + err: nil, + getResp: &client.GetResponse{Count: 0}, + putResp: &client.PutResponse{}, + }, + "", + ) + require.Nil(t, err) -func makeResponse(idx int, val string) *client.Response { - return &client.Response{ - Node: &client.Node{ - Value: val, - ModifiedIndex: uint64(idx), + assert.NotNil(t, elc) + }) + t.Run("SuccessWithExistingValue", func(t *testing.T) { + elc, err := NewEtcdLockClient(&testEtcdClient{ + err: nil, + getResp: &client.GetResponse{ + Count: 1, + Kvs: []*pb.KeyValue{ + &pb.KeyValue{ + Key: []byte(SemaphorePrefix), + Value: []byte(`{"semaphore": 1, "max": 2, "holders": ["foo", "bar"]}`), + }, + }, + }, + putResp: &client.PutResponse{}, }, - } + "", + ) + require.Nil(t, err) + + assert.NotNil(t, elc) + }) + t.Run("Error", func(t *testing.T) { + elc, err := NewEtcdLockClient(&testEtcdClient{ + err: errors.New("connection refused"), + getResp: &client.GetResponse{Count: 0}, + putResp: &client.PutResponse{}, + }, + "", + ) + require.Nil(t, elc) + + assert.Equal(t, "unable to init etcd lock client: unable to get semaphore: connection refused", err.Error()) + }) } func TestEtcdLockClientGet(t *testing.T) { - for i, tt := range []struct { - ee error - er *client.Response - ws *Semaphore - we bool - }{ - // errors returned from etcd - {errors.New("some error"), nil, nil, true}, - {client.Error{Code: client.ErrorCodeKeyNotFound}, nil, nil, true}, - // bad JSON should cause errors - {nil, makeResponse(0, "asdf"), nil, true}, - {nil, makeResponse(0, `{"semaphore:`), nil, true}, - // successful calls - {nil, makeResponse(10, `{"semaphore": 1}`), &Semaphore{Index: 10, Semaphore: 1}, false}, - {nil, makeResponse(1024, `{"semaphore": 1, "max": 2, "holders": ["foo", "bar"]}`), &Semaphore{Index: 1024, Semaphore: 1, Max: 2, Holders: []string{"foo", "bar"}}, false}, - // index should be set from etcd, not json! - {nil, makeResponse(1234, `{"semaphore": 89, "index": 4567}`), &Semaphore{Index: 1234, Semaphore: 89}, false}, - } { - elc := &EtcdLockClient{ - keyapi: &testEtcdClient{ - err: tt.ee, - resp: tt.er, + t.Run("Success", func(t *testing.T) { + elc, err := NewEtcdLockClient(&testEtcdClient{ + err: nil, + putResp: &client.PutResponse{}, + getResp: &client.GetResponse{ + Count: 1, + Kvs: []*pb.KeyValue{ + &pb.KeyValue{ + Key: []byte(SemaphorePrefix), + // index should be set from etcd, not json (backported from legacy test) + Value: []byte(`{"index": 12, "semaphore": 1, "max": 2, "holders": ["foo", "bar"]}`), + Version: 1234, + }, + }, + }, + }, + "", + ) + require.Nil(t, err) + + res, err := elc.Get() + require.Nil(t, err) + + assert.Equal(t, 1, res.Semaphore) + assert.Equal(t, uint64(1234), res.Index) + }) + t.Run("SuccessNotFound", func(t *testing.T) { + elc, err := NewEtcdLockClient(&testEtcdClient{ + err: nil, + getResp: &client.GetResponse{Count: 0}, + putResp: &client.PutResponse{}, + }, + "", + ) + require.Nil(t, err) + + res, err := elc.Get() + require.Nil(t, res) + assert.ErrorIs(t, err, ErrNotFound) + }) + t.Run("ErrorWithMalformedJSON", func(t *testing.T) { + elc, err := NewEtcdLockClient(&testEtcdClient{ + err: nil, + getResp: &client.GetResponse{ + Count: 1, + Kvs: []*pb.KeyValue{ + &pb.KeyValue{ + Key: []byte(SemaphorePrefix), + // notice the missing `,` in the array + Value: []byte(`{"semaphore": 1, "max": 2, "holders": ["foo" "bar"]}`), + }, + }, }, - } - gs, ge := elc.Get() - if tt.we { - if ge == nil { - t.Fatalf("case %d: expected error but got nil!", i) - } - } else { - if ge != nil { - t.Fatalf("case %d: unexpected error: %v", i, ge) - } - } - if !reflect.DeepEqual(gs, tt.ws) { - t.Fatalf("case %d: bad semaphore: got %#v, want %#v", i, gs, tt.ws) - } - } + putResp: &client.PutResponse{}, + }, + "", + ) + require.Nil(t, elc) + + assert.Equal(t, "unable to init etcd lock client: unable to get semaphore: invalid character '\"' after array element", err.Error()) + }) } func TestEtcdLockClientSet(t *testing.T) { - for i, tt := range []struct { - sem *Semaphore - ee error // error returned from etcd - want bool // do we expect Set to return an error - }{ - // nil semaphore cannot be set - {nil, nil, true}, - // empty semaphore is OK - {&Semaphore{}, nil, false}, - {&Semaphore{Index: uint64(1234)}, nil, false}, - // all errors returned from etcd should propagate - {&Semaphore{}, client.Error{Code: client.ErrorCodeNodeExist}, true}, - {&Semaphore{}, client.Error{Code: client.ErrorCodeKeyNotFound}, true}, - {&Semaphore{}, errors.New("some random error"), true}, - } { - elc := &EtcdLockClient{ - keyapi: &testEtcdClient{err: tt.ee}, - } - got := elc.Set(tt.sem) - if (got != nil) != tt.want { - t.Errorf("case %d: unexpected error state calling Set: got %v", i, got) - } - } + t.Run("Success", func(t *testing.T) { + elc, err := NewEtcdLockClient(&testEtcdClient{ + err: nil, + getResp: &client.GetResponse{Count: 0}, + putResp: &client.PutResponse{}, + }, + "", + ) + require.Nil(t, err) + + err = elc.Set(&Semaphore{}) + assert.Nil(t, err) + + }) + t.Run("ErrorNilSemaphore", func(t *testing.T) { + elc, err := NewEtcdLockClient(&testEtcdClient{ + err: nil, + getResp: &client.GetResponse{Count: 0}, + putResp: &client.PutResponse{}, + }, + "", + ) + require.Nil(t, err) + + err = elc.Set(nil) + assert.Equal(t, "cannot set nil semaphore", err.Error()) + }) }