Skip to content

Commit

Permalink
fix(storage/rollout): ensure tx rollback and fixed rollout type on up…
Browse files Browse the repository at this point in the history
…date
  • Loading branch information
GeorgeMac committed Jul 3, 2023
1 parent 0f269a3 commit 37a1166
Show file tree
Hide file tree
Showing 4 changed files with 515 additions and 452 deletions.
45 changes: 42 additions & 3 deletions build/testing/integration/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,22 +558,23 @@ func API(t *testing.T, ctx context.Context, client sdk.SDK, namespace string, au
rolloutPercentage, err := client.Flipt().CreateRollout(ctx, &flipt.CreateRolloutRequest{
NamespaceKey: namespace,
FlagKey: "boolean_disabled",
Description: "50% enabled",
Description: "50% disabled",
Rank: 2,
Rule: &flipt.CreateRolloutRequest_Percentage{
Percentage: &flipt.RolloutPercentage{
Percentage: 50,
Value: true,
Value: false,
},
},
})
require.NoError(t, err)

assert.Equal(t, namespace, rolloutPercentage.NamespaceKey)
assert.Equal(t, "boolean_disabled", rolloutPercentage.FlagKey)
assert.Equal(t, "50% disabled", rolloutPercentage.Description)
assert.Equal(t, int32(2), rolloutPercentage.Rank)
assert.Equal(t, float32(50.0), rolloutPercentage.Rule.(*flipt.Rollout_Percentage).Percentage.Percentage)
assert.Equal(t, true, rolloutPercentage.Rule.(*flipt.Rollout_Percentage).Percentage.Value)
assert.Equal(t, false, rolloutPercentage.Rule.(*flipt.Rollout_Percentage).Percentage.Value)

rollouts, err := client.Flipt().ListRollouts(ctx, &flipt.ListRolloutRequest{
NamespaceKey: namespace,
Expand All @@ -585,6 +586,44 @@ func API(t *testing.T, ctx context.Context, client sdk.SDK, namespace string, au
rolloutSegment,
rolloutPercentage,
}, rollouts.Rules, protocmp.Transform()))

updatedRollout, err := client.Flipt().UpdateRollout(ctx, &flipt.UpdateRolloutRequest{
NamespaceKey: namespace,
FlagKey: "boolean_disabled",
Id: rolloutPercentage.Id,
Description: "50% enabled",
Rule: &flipt.UpdateRolloutRequest_Percentage{
Percentage: &flipt.RolloutPercentage{
Percentage: 50,
Value: true,
},
},
})
require.NoError(t, err)

assert.Equal(t, namespace, updatedRollout.NamespaceKey)
assert.Equal(t, "boolean_disabled", updatedRollout.FlagKey)
assert.Equal(t, "50% enabled", updatedRollout.Description)
assert.Equal(t, int32(2), updatedRollout.Rank)
assert.Equal(t, float32(50.0), updatedRollout.Rule.(*flipt.Rollout_Percentage).Percentage.Percentage)
assert.Equal(t, true, updatedRollout.Rule.(*flipt.Rollout_Percentage).Percentage.Value)

t.Run("Cannot change rollout type", func(t *testing.T) {
_, err := client.Flipt().UpdateRollout(ctx, &flipt.UpdateRolloutRequest{
NamespaceKey: namespace,
FlagKey: "boolean_disabled",
Id: rolloutPercentage.Id,
Description: "50% enabled",
Rule: &flipt.UpdateRolloutRequest_Segment{
Segment: &flipt.RolloutSegment{
SegmentKey: "everyone",
Value: true,
},
},
})

require.EqualError(t, err, "rpc error: code = InvalidArgument desc = cannot change type of rollout: have \"PERCENTAGE_ROLLOUT_TYPE\" attempted \"SEGMENT_ROLLOUT_TYPE\"")
})
})

t.Run("Legacy", func(t *testing.T) {
Expand Down
67 changes: 52 additions & 15 deletions internal/storage/sql/common/rollout.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ const (
)

func (s *Store) GetRollout(ctx context.Context, namespaceKey, id string) (*flipt.Rollout, error) {
return getRollout(ctx, s.builder, namespaceKey, id)
}

func getRollout(ctx context.Context, builder sq.StatementBuilderType, namespaceKey, id string) (*flipt.Rollout, error) {
if namespaceKey == "" {
namespaceKey = storage.DefaultNamespace
}
Expand All @@ -33,7 +37,7 @@ func (s *Store) GetRollout(ctx context.Context, namespaceKey, id string) (*flipt

rollout = &flipt.Rollout{}

err = s.builder.Select("id, namespace_key, flag_key, \"type\", \"rank\", description, created_at, updated_at").
err = builder.Select("id, namespace_key, flag_key, \"type\", \"rank\", description, created_at, updated_at").
From(tableRollouts).
Where(sq.And{sq.Eq{"id": id}, sq.Eq{"namespace_key": namespaceKey}}).
QueryRowContext(ctx).
Expand Down Expand Up @@ -65,7 +69,7 @@ func (s *Store) GetRollout(ctx context.Context, namespaceKey, id string) (*flipt
Segment: &flipt.RolloutSegment{},
}

if err := s.builder.Select("segment_key, \"value\"").
if err := builder.Select("segment_key, \"value\"").
From(tableRolloutSegments).
Where(sq.And{sq.Eq{"rollout_id": rollout.Id}, sq.Eq{"namespace_key": rollout.NamespaceKey}}).
Limit(1).
Expand All @@ -82,7 +86,7 @@ func (s *Store) GetRollout(ctx context.Context, namespaceKey, id string) (*flipt
Percentage: &flipt.RolloutPercentage{},
}

if err := s.builder.Select("percentage, \"value\"").
if err := builder.Select("percentage, \"value\"").
From(tableRolloutPercentages).
Where(sq.And{sq.Eq{"rollout_id": rollout.Id}, sq.Eq{"namespace_key": rollout.NamespaceKey}}).
Limit(1).
Expand Down Expand Up @@ -400,16 +404,32 @@ func (s *Store) CreateRollout(ctx context.Context, r *flipt.CreateRolloutRequest
return rollout, tx.Commit()
}

func (s *Store) UpdateRollout(ctx context.Context, r *flipt.UpdateRolloutRequest) (*flipt.Rollout, error) {
func (s *Store) UpdateRollout(ctx context.Context, r *flipt.UpdateRolloutRequest) (_ *flipt.Rollout, err error) {
if r.NamespaceKey == "" {
r.NamespaceKey = storage.DefaultNamespace
}

if r.Id == "" {
return nil, errs.ErrInvalid("rollout ID not supplied")
}

tx, err := s.db.Begin()
if err != nil {
return nil, err
}

defer func() {
if err != nil {
_ = tx.Rollback()
}
}()

// get current state for rollout
rollout, err := getRollout(ctx, s.builder.RunWith(tx), r.NamespaceKey, r.Id)
if err != nil {
return nil, err
}

whereClause := sq.And{sq.Eq{"id": r.Id}, sq.Eq{"flag_key": r.FlagKey}, sq.Eq{"namespace_key": r.NamespaceKey}}

query := s.builder.Update(tableRollouts).
Expand All @@ -420,54 +440,71 @@ func (s *Store) UpdateRollout(ctx context.Context, r *flipt.UpdateRolloutRequest

res, err := query.ExecContext(ctx)
if err != nil {
_ = tx.Rollback()
return nil, err
}

count, err := res.RowsAffected()
if err != nil {
_ = tx.Rollback()
return nil, err
}

if count != 1 {
_ = tx.Rollback()
return nil, errs.ErrNotFoundf(`rollout "%s/%s"`, r.NamespaceKey, r.Id)
}

switch r.GetType() {
case flipt.RolloutType_SEGMENT_ROLLOUT_TYPE:
switch r.Rule.(type) {
case *flipt.UpdateRolloutRequest_Segment:
// enforce that rollout type is consistent with the DB
if err := ensureRolloutType(rollout, flipt.RolloutType_SEGMENT_ROLLOUT_TYPE); err != nil {
return nil, err
}

var segmentRule = r.GetSegment()

if _, err := s.builder.Update(tableRolloutSegments).
RunWith(tx).
Set("segment_key", segmentRule.SegmentKey).
Set("value", segmentRule.Value).
Where(sq.Eq{"rollout_id": r.Id}).ExecContext(ctx); err != nil {
_ = tx.Rollback()
return nil, err
}
case flipt.RolloutType_PERCENTAGE_ROLLOUT_TYPE:
case *flipt.UpdateRolloutRequest_Percentage:
// enforce that rollout type is consistent with the DB
if err := ensureRolloutType(rollout, flipt.RolloutType_PERCENTAGE_ROLLOUT_TYPE); err != nil {
return nil, err
}

var percentageRule = r.GetPercentage()

if _, err := s.builder.Update(tableRolloutPercentages).
RunWith(tx).
Set("percentage", percentageRule.Percentage).
Set("value", percentageRule.Value).
Where(sq.Eq{"rollout_id": r.Id}).ExecContext(ctx); err != nil {
_ = tx.Rollback()
return nil, err
}
default:
_ = tx.Rollback()
return nil, errs.InvalidFieldError("rule", "invalid rollout rule type")
}

if err := tx.Commit(); err != nil {
rollout, err = getRollout(ctx, s.builder.RunWith(tx), r.NamespaceKey, r.Id)
if err != nil {
return nil, err
}

return s.GetRollout(ctx, r.NamespaceKey, r.Id)
return rollout, tx.Commit()
}

func ensureRolloutType(rollout *flipt.Rollout, typ flipt.RolloutType) error {
if rollout.Type == typ {
return nil
}

return errs.ErrInvalidf(
"cannot change type of rollout: have %q attempted %q",
rollout.Type,
typ,
)
}

func (s *Store) DeleteRollout(ctx context.Context, r *flipt.DeleteRolloutRequest) error {
Expand Down
Loading

0 comments on commit 37a1166

Please sign in to comment.