diff --git a/lib/backend/firestore/firestorebk.go b/lib/backend/firestore/firestorebk.go index bac9b056ab251..e246a996d546e 100644 --- a/lib/backend/firestore/firestorebk.go +++ b/lib/backend/firestore/firestorebk.go @@ -127,13 +127,13 @@ type Backend struct { } type record struct { - Key backend.Key `firestore:"key,omitempty"` - Timestamp int64 `firestore:"timestamp,omitempty"` - Expires int64 `firestore:"expires,omitempty"` - ID int64 `firestore:"id,omitempty"` - Value []byte `firestore:"value,omitempty"` - RevisionV2 string `firestore:"revision,omitempty"` - RevisionV1 string `firestore:"-"` + Key []byte `firestore:"key,omitempty"` + Timestamp int64 `firestore:"timestamp,omitempty"` + Expires int64 `firestore:"expires,omitempty"` + ID int64 `firestore:"id,omitempty"` + Value []byte `firestore:"value,omitempty"` + RevisionV2 string `firestore:"revision,omitempty"` + RevisionV1 string `firestore:"-"` } func (r *record) updates() []firestore.Update { @@ -163,9 +163,21 @@ type legacyRecord struct { Value string `firestore:"value,omitempty"` } +// brokenRecord is an incorrect version of record used to marshal backend.Items. +// The Key type was inadvertently changed from a []byte to backend.Key which +// causes problems reading existing data prior to the conversion. +type brokenRecord struct { + Key backend.Key `firestore:"key,omitempty"` + Timestamp int64 `firestore:"timestamp,omitempty"` + Expires int64 `firestore:"expires,omitempty"` + Value []byte `firestore:"value,omitempty"` + ID int64 `firestore:"id,omitempty"` + RevisionV2 string `firestore:"revision,omitempty"` +} + func newRecord(from backend.Item, clock clockwork.Clock) record { r := record{ - Key: from.Key, + Key: []byte(from.Key.String()), Value: from.Value, Timestamp: clock.Now().UTC().Unix(), ID: id(clock.Now()), @@ -184,23 +196,48 @@ func newRecord(from backend.Item, clock clockwork.Clock) record { } func newRecordFromDoc(doc *firestore.DocumentSnapshot) (*record, error) { + k, err := doc.DataAt(keyDocProperty) + if err != nil { + return nil, trace.Wrap(err) + } + var r record - if err := doc.DataTo(&r); err != nil { - // If unmarshal failed, try using the old format of records, where - // Value was a string. This document could've been written by an older - // version of our code. - var rl legacyRecord - if doc.DataTo(&rl) != nil { + switch k.(type) { + case []any: + // If the key is a slice of any, then the key was mistakenly persisted + // as a backend.Key directly. + var br brokenRecord + if doc.DataTo(&br) != nil { return nil, ConvertGRPCError(err) } + r = record{ - Key: backend.Key(rl.Key), - Value: []byte(rl.Value), - Timestamp: rl.Timestamp, - Expires: rl.Expires, - ID: rl.ID, + Key: br.Key, + Value: br.Value, + Timestamp: br.Timestamp, + Expires: br.Expires, + RevisionV2: br.RevisionV2, + ID: br.ID, + } + default: + if err := doc.DataTo(&r); err != nil { + // If unmarshal failed, try using the old format of records, where + // Value was a string. This document could've been written by an older + // version of our code. + var rl legacyRecord + if doc.DataTo(&rl) != nil { + return nil, ConvertGRPCError(err) + } + r = record{ + Key: backend.Key(rl.Key), + Value: []byte(rl.Value), + Timestamp: rl.Timestamp, + Expires: rl.Expires, + ID: rl.ID, + } } } + if r.RevisionV2 == "" { r.RevisionV1 = toRevisionV1(doc.UpdateTime) } @@ -218,7 +255,7 @@ func (r *record) isExpired(now time.Time) bool { func (r *record) backendItem() backend.Item { bi := backend.Item{ - Key: r.Key, + Key: backend.Key(r.Key), Value: r.Value, ID: r.ID, } @@ -444,23 +481,31 @@ func (b *Backend) getRangeDocs(ctx context.Context, startKey, endKey backend.Key limit = backend.DefaultRangeLimit } docs, err := b.svc.Collection(b.CollectionName). - Where(keyDocProperty, ">=", startKey). - Where(keyDocProperty, "<=", endKey). + Where(keyDocProperty, ">=", []byte(startKey.String())). + Where(keyDocProperty, "<=", []byte(endKey.String())). Limit(limit). Documents(ctx).GetAll() if err != nil { return nil, trace.Wrap(err) } legacyDocs, err := b.svc.Collection(b.CollectionName). - Where(keyDocProperty, ">=", string(startKey)). - Where(keyDocProperty, "<=", string(endKey)). + Where(keyDocProperty, ">=", startKey.String()). + Where(keyDocProperty, "<=", endKey.String()). + Limit(limit). + Documents(ctx).GetAll() + if err != nil { + return nil, trace.Wrap(err) + } + brokenDocs, err := b.svc.Collection(b.CollectionName). + Where(keyDocProperty, ">=", startKey). + Where(keyDocProperty, "<=", endKey). Limit(limit). Documents(ctx).GetAll() if err != nil { return nil, trace.Wrap(err) } - allDocs := append(docs, legacyDocs...) + allDocs := append(append(docs, legacyDocs...), brokenDocs...) if len(allDocs) >= backend.DefaultRangeLimit { b.Warnf("Range query hit backend limit. (this is a bug!) startKey=%q,limit=%d", startKey, backend.DefaultRangeLimit) } diff --git a/lib/backend/firestore/firestorebk_test.go b/lib/backend/firestore/firestorebk_test.go index 98430ae02c278..e3bf31cfa1517 100644 --- a/lib/backend/firestore/firestorebk_test.go +++ b/lib/backend/firestore/firestorebk_test.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/api/option" "google.golang.org/genproto/googleapis/rpc/code" @@ -215,6 +216,101 @@ func TestReadLegacyRecord(t *testing.T) { require.Equal(t, item.Expires, got.Expires) } +func TestReadBrokenRecord(t *testing.T) { + cfg := firestoreParams() + ensureTestsEnabled(t) + ensureEmulatorRunning(t, cfg) + + uut := newBackend(t, cfg) + + ctx := context.Background() + + prefix := test.MakePrefix() + + // Create a valid record with the correct key type. + item := backend.Item{ + Key: prefix("valid-record"), + Value: []byte("llamas"), + } + _, err := uut.Put(ctx, item) + require.NoError(t, err) + + // Create a legacy record with a string key type. + lr := legacyRecord{ + Key: prefix("legacy-record").String(), + Value: "sheep", + } + _, err = uut.svc.Collection(uut.CollectionName).Doc(uut.keyToDocumentID(backend.Key(lr.Key))).Set(ctx, lr) + require.NoError(t, err) + + // Create a broken record with a backend.Key key type. + brokenItem := backend.Item{ + Key: prefix("broken-record"), + Value: []byte("foo"), + Expires: uut.clock.Now().Add(time.Minute).Round(time.Second).UTC(), + } + + // Write using broken record format, emulating data written by an older + // version of this backend. + br := brokenRecord{ + Key: brokenItem.Key, + Value: brokenItem.Value, + Expires: brokenItem.Expires.UTC().Unix(), + Timestamp: uut.clock.Now().UTC().Unix(), + } + _, err = uut.svc.Collection(uut.CollectionName).Doc(uut.keyToDocumentID(brokenItem.Key)).Set(ctx, br) + require.NoError(t, err) + + // Read the data back and make sure it matches the original item. + got, err := uut.Get(ctx, brokenItem.Key) + require.NoError(t, err) + require.Equal(t, brokenItem.Key, got.Key) + require.Equal(t, brokenItem.Value, got.Value) + require.Equal(t, brokenItem.Expires, got.Expires) + + // Read the data back using a range query too. + gotRange, err := uut.GetRange(ctx, brokenItem.Key, brokenItem.Key, 1) + require.NoError(t, err) + require.Len(t, gotRange.Items, 1) + + got = &gotRange.Items[0] + require.Equal(t, brokenItem.Key, got.Key) + require.Equal(t, brokenItem.Value, got.Value) + require.Equal(t, brokenItem.Expires, got.Expires) + + // Retrieve the entire key range to validate that there are no duplicate records + results, err := uut.GetRange(ctx, prefix(""), backend.RangeEnd(prefix("")), 5) + require.NoError(t, err) + + require.Len(t, results.Items, 3) + + for _, result := range results.Items { + switch r := result.Key.String(); r { + case item.Key.String(): + assert.Equal(t, item.Value, result.Value) + case br.Key.String(): + assert.Equal(t, br.Value, result.Value) + case lr.Key: + assert.Equal(t, lr.Value, string(result.Value)) + default: + t.Errorf("GetRange returned unexpected item key %s", r) + } + } + + // Update the value and ensure that it's set to the correct key value + item.Value = []byte("llama") + _, err = uut.Update(ctx, item) + require.NoError(t, err) + + doc, err := uut.svc.Collection(uut.CollectionName).Doc(uut.keyToDocumentID(item.Key)).Get(ctx) + require.NoError(t, err) + + var r record + require.NoError(t, doc.DataTo(&r)) + require.Equal(t, []byte(item.Key.String()), r.Key) + require.Equal(t, item.Value, r.Value) +} + type mockFirestoreServer struct { // Embed for forward compatibility. // Tests will keep working if more methods are added