Skip to content

Commit

Permalink
fix: escape special characters for RediSearch (#761)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz committed Aug 19, 2023
1 parent 14e97e6 commit 634d24d
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
--health-retries 5
redis:
image: redis/redis-stack
image: redis/redis-stack:6.2.6-v9
ports:
- 6379

Expand Down
50 changes: 28 additions & 22 deletions storage/cache/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,9 @@ func (r *Redis) SearchDocuments(ctx context.Context, collection, subset string,
return nil, nil
}
var builder strings.Builder
builder.WriteString(fmt.Sprintf("@collection:{ %s } @is_hidden:[0 0]", collection))
builder.WriteString(fmt.Sprintf("@collection:{ %s } @is_hidden:[0 0]", escape(collection)))
if subset != "" {
builder.WriteString(fmt.Sprintf(" @subset:{ %s }", subset))
builder.WriteString(fmt.Sprintf(" @subset:{ %s }", escape(subset)))
}
for _, q := range query {
builder.WriteString(fmt.Sprintf(" @categories:{ %s }", encdodeCategory(q)))
Expand Down Expand Up @@ -287,8 +287,8 @@ func (r *Redis) UpdateDocuments(ctx context.Context, collections []string, id st
return nil
}
var builder strings.Builder
builder.WriteString(fmt.Sprintf("@collection:{ %s }", strings.Join(collections, " | ")))
builder.WriteString(fmt.Sprintf(" @id:{ %s }", id))
builder.WriteString(fmt.Sprintf("@collection:{ %s }", escape(strings.Join(collections, " | "))))
builder.WriteString(fmt.Sprintf(" @id:{ %s }", escape(id)))
for {
// search documents
result, err := r.client.Do(ctx, "FT.SEARCH", r.DocumentTable(), builder.String(), "SORTBY", "score", "DESC", "LIMIT", 0, 10000).Result()
Expand Down Expand Up @@ -335,12 +335,12 @@ func (r *Redis) DeleteDocuments(ctx context.Context, collections []string, condi
return errors.Trace(err)
}
var builder strings.Builder
builder.WriteString(fmt.Sprintf("@collection:{ %s }", strings.Join(collections, " | ")))
builder.WriteString(fmt.Sprintf("@collection:{ %s }", escape(strings.Join(collections, " | "))))
if condition.Subset != nil {
builder.WriteString(fmt.Sprintf(" @subset:{ %s }", *condition.Subset))
builder.WriteString(fmt.Sprintf(" @subset:{ %s }", escape(*condition.Subset)))
}
if condition.Id != nil {
builder.WriteString(fmt.Sprintf(" @id:{ %s }", *condition.Id))
builder.WriteString(fmt.Sprintf(" @id:{ %s }", escape(*condition.Id)))
}
if condition.Before != nil {
builder.WriteString(fmt.Sprintf(" @timestamp:[-inf,%d]", condition.Before.UnixMicro()))
Expand Down Expand Up @@ -375,21 +375,21 @@ func (r *Redis) DeleteDocuments(ctx context.Context, collections []string, condi
func parseSearchDocumentsResult(result any) (count int64, keys []string, documents []Document, err error) {
rows, ok := result.([]any)
if !ok {
return 0, nil, nil, errors.New("invalid FT.SEARCH result")
return 0, nil, nil, errors.Errorf("invalid FT.SEARCH result: %#v", result)
}
count, ok = rows[0].(int64)
if !ok {
return 0, nil, nil, errors.New("invalid FT.SEARCH result")
return 0, nil, nil, errors.Errorf("invalid FT.SEARCH result: %#v", rows[0])
}
for i := 1; i < len(rows); i += 2 {
key, ok := rows[i].(string)
if !ok {
return 0, nil, nil, errors.New("invalid FT.SEARCH result")
return 0, nil, nil, errors.Errorf("invalid FT.SEARCH result: %#v", rows[i])
}
keys = append(keys, key)
row, ok := rows[i+1].([]any)
if !ok {
return 0, nil, nil, errors.New("invalid FT.SEARCH result")
return 0, nil, nil, errors.Errorf("invalid FT.SEARCH result: %#v", rows[i+1])
}
fields := make(map[string]any)
for j := 0; j < len(row); j += 2 {
Expand All @@ -398,27 +398,27 @@ func parseSearchDocumentsResult(result any) (count int64, keys []string, documen
var document Document
document.Id, ok = fields["id"].(string)
if !ok {
return 0, nil, nil, errors.New("invalid FT.SEARCH result")
return 0, nil, nil, errors.Errorf("invalid FT.SEARCH result: %#v", fields["id"])
}
score, ok := fields["score"].(string)
if !ok {
return 0, nil, nil, errors.New("invalid FT.SEARCH result")
return 0, nil, nil, errors.Errorf("invalid FT.SEARCH result: %#v", fields["score"])
}
document.Score, err = strconv.ParseFloat(score, 64)
if err != nil {
return 0, nil, nil, errors.Trace(err)
}
categories, ok := fields["categories"].(string)
if !ok {
return 0, nil, nil, errors.New("invalid FT.SEARCH result")
return 0, nil, nil, errors.Errorf("invalid FT.SEARCH result: %#v", fields["categories"])
}
document.Categories, err = decodeCategories(categories)
if err != nil {
return 0, nil, nil, errors.Trace(err)
}
timestamp, ok := fields["timestamp"].(string)
if !ok {
return 0, nil, nil, errors.New("invalid FT.SEARCH result")
return 0, nil, nil, errors.Errorf("invalid FT.SEARCH result: %#v", fields["timestamp"])
}
timestampMicros, err := strconv.ParseInt(timestamp, 10, 64)
if err != nil {
Expand Down Expand Up @@ -463,21 +463,21 @@ func (r *Redis) GetTimeSeriesPoints(ctx context.Context, name string, begin, end
func parseGetTimeSeriesPointsResult(result any) (count int64, keys []string, points []TimeSeriesPoint, err error) {
rows, ok := result.([]any)
if !ok {
return 0, nil, nil, errors.New("invalid FT.SEARCH result")
return 0, nil, nil, errors.Errorf("invalid FT.SEARCH result: %#v", result)
}
count, ok = rows[0].(int64)
if !ok {
return 0, nil, nil, errors.New("invalid FT.SEARCH result")
return 0, nil, nil, errors.Errorf("invalid FT.SEARCH result: %#v", rows[0])
}
for i := 1; i < len(rows); i += 2 {
key, ok := rows[i].(string)
if !ok {
return 0, nil, nil, errors.New("invalid FT.SEARCH result")
return 0, nil, nil, errors.Errorf("invalid FT.SEARCH result: %#v", rows[i])
}
keys = append(keys, key)
row, ok := rows[i+1].([]any)
if !ok {
return 0, nil, nil, errors.New("invalid FT.SEARCH result")
return 0, nil, nil, errors.Errorf("invalid FT.SEARCH result: %#v", rows[i+1])
}
fields := make(map[string]any)
for j := 0; j < len(row); j += 2 {
Expand All @@ -486,19 +486,19 @@ func parseGetTimeSeriesPointsResult(result any) (count int64, keys []string, poi
var point TimeSeriesPoint
point.Name, ok = fields["name"].(string)
if !ok {
return 0, nil, nil, errors.New("invalid FT.SEARCH result")
return 0, nil, nil, errors.Errorf("invalid FT.SEARCH result: %#v", fields["name"])
}
value, ok := fields["value"].(string)
if !ok {
return 0, nil, nil, errors.New("invalid FT.SEARCH result")
return 0, nil, nil, errors.Errorf("invalid FT.SEARCH result: %#v", fields["value"])
}
point.Value, err = strconv.ParseFloat(value, 64)
if err != nil {
return 0, nil, nil, errors.Trace(err)
}
timestamp, ok := fields["timestamp"].(string)
if !ok {
return 0, nil, nil, errors.New("invalid FT.SEARCH result")
return 0, nil, nil, errors.Errorf("invalid FT.SEARCH result: %#v", fields["timestamp"])
}
timestampMicros, err := strconv.ParseInt(timestamp, 10, 64)
if err != nil {
Expand Down Expand Up @@ -544,3 +544,9 @@ func decodeCategories(s string) ([]string, error) {
}
return categories, nil
}

// escape -:.
func escape(s string) string {
r := strings.NewReplacer("-", "\\-", ":", "\\:", ".", "\\.")
return r.Replace(s)
}
41 changes: 41 additions & 0 deletions storage/cache/redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ package cache

import (
"context"
"fmt"
"math"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/zhenghaoz/gorse/base/log"
"google.golang.org/protobuf/proto"
)

var (
Expand Down Expand Up @@ -55,6 +59,43 @@ func (suite *RedisTestSuite) SetupSuite() {
suite.NoError(err)
}

func (suite *RedisTestSuite) TestEscapeCharacters() {
ts := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)
ctx := context.Background()
for _, c := range []string{"-", ":", "."} {
suite.Run(c, func() {
collection := fmt.Sprintf("a%s1", c)
subset := fmt.Sprintf("b%s2", c)
id := fmt.Sprintf("c%s3", c)
err := suite.AddDocuments(ctx, collection, subset, []Document{{
Id: id,
Score: math.MaxFloat64,
Categories: []string{"a", "b"},
Timestamp: ts,
}})
suite.NoError(err)
documents, err := suite.SearchDocuments(ctx, collection, subset, []string{"b"}, 0, -1)
suite.NoError(err)
suite.Equal([]Document{{Id: id, Score: math.MaxFloat64, Categories: []string{"a", "b"}, Timestamp: ts}}, documents)

err = suite.UpdateDocuments(ctx, []string{collection}, id, DocumentPatch{Score: proto.Float64(1)})
suite.NoError(err)
documents, err = suite.SearchDocuments(ctx, collection, subset, []string{"b"}, 0, -1)
suite.NoError(err)
suite.Equal([]Document{{Id: id, Score: 1, Categories: []string{"a", "b"}, Timestamp: ts}}, documents)

err = suite.DeleteDocuments(ctx, []string{collection}, DocumentCondition{
Subset: proto.String(subset),
Id: proto.String(id),
})
suite.NoError(err)
documents, err = suite.SearchDocuments(ctx, collection, subset, []string{"b"}, 0, -1)
suite.NoError(err)
suite.Empty(documents)
})
}
}

func TestRedis(t *testing.T) {
suite.Run(t, new(RedisTestSuite))
}
Expand Down
2 changes: 1 addition & 1 deletion storage/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
version: "3"
services:
redis:
image: redis/redis-stack
image: redis/redis-stack:6.2.6-v9
ports:
- 6379:6379

Expand Down

0 comments on commit 634d24d

Please sign in to comment.