Skip to content

Commit

Permalink
perf: support parallel feedback loading (#753)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz committed Aug 10, 2023
1 parent 2b29104 commit 14e97e6
Show file tree
Hide file tree
Showing 10 changed files with 208 additions and 74 deletions.
19 changes: 19 additions & 0 deletions base/parallel/parallel.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,22 @@ func BatchParallel(nJobs, nWorkers, batchSize int, worker func(workerId, beginJo
}
return nil
}

// Split partitions a into at most n contiguous chunks and keeps the order of
// elements. Chunk sizes differ by at most one; when len(a) is not divisible by
// n, the leading chunks receive the extra elements.
//
// If n is larger than len(a), one chunk per element is returned. If a is empty
// or n is not positive, an empty result is returned (previously this case
// panicked with a division by zero or a negative slice length).
//
// The returned chunks are sub-slices of a and share its backing array.
func Split[T any](a []T, n int) [][]T {
	if n > len(a) {
		n = len(a)
	}
	if n <= 0 {
		// Nothing to split, or no chunks requested: bail out before the
		// division below, which would panic for n == 0.
		return [][]T{}
	}
	minChunkSize := len(a) / n
	// The first maxChunkNum chunks hold one element more than the rest.
	maxChunkNum := len(a) % n
	chunks := make([][]T, n)
	for i, j := 0, 0; i < n; i++ {
		chunkSize := minChunkSize
		if i < maxChunkNum {
			chunkSize++
		}
		chunks[i] = a[j : j+chunkSize]
		j += chunkSize
	}
	return chunks
}
10 changes: 10 additions & 0 deletions base/parallel/parallel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,13 @@ func TestBatchParallelFail(t *testing.T) {
})
assert.Error(t, err)
}

func TestSplit(t *testing.T) {
a := []int{1, 2, 3, 4, 5, 6}
b := Split(a, 3)
assert.Equal(t, [][]int{{1, 2}, {3, 4}, {5, 6}}, b)

a = []int{1, 2, 3, 4, 5, 6, 7}
b = Split(a, 3)
assert.Equal(t, [][]int{{1, 2, 3}, {4, 5}, {6, 7}}, b)
}
2 changes: 1 addition & 1 deletion master/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,7 @@ func (m *Master) importExportFeedback(response http.ResponseWriter, request *htt
return
}
// write rows
feedbackChan, errChan := m.DataClient.GetFeedbackStream(ctx, batchSize, nil, m.Config.Now())
feedbackChan, errChan := m.DataClient.GetFeedbackStream(ctx, batchSize, data.WithEndTime(*m.Config.Now()))
for feedback := range feedbackChan {
for _, v := range feedback {
if _, err = response.Write([]byte(fmt.Sprintf("%s,%s,%s,%v\r\n",
Expand Down
124 changes: 75 additions & 49 deletions master/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -1400,14 +1400,15 @@ func (m *Master) LoadDataFromDatabase(ctx context.Context, database data.Databas

startLoadTime := time.Now()
// setup time limit
var itemTimeLimit, feedbackTimeLimit *time.Time
var feedbackTimeLimit data.ScanOption
var itemTimeLimit *time.Time
if itemTTL > 0 {
temp := time.Now().AddDate(0, 0, -int(itemTTL))
itemTimeLimit = &temp
}
if positiveFeedbackTTL > 0 {
temp := time.Now().AddDate(0, 0, -int(positiveFeedbackTTL))
feedbackTimeLimit = &temp
feedbackTimeLimit = data.WithBeginTime(temp)
}
timeWindowLimit := time.Time{}
if m.Config.Recommend.Popular.PopularWindow > 0 {
Expand Down Expand Up @@ -1544,44 +1545,59 @@ func (m *Master) LoadDataFromDatabase(ctx context.Context, database data.Databas
positiveSet[i] = mapset.NewSet[int32]()
}

// split user groups
users := rankingDataset.UserIndex.GetNames()
sort.Strings(users)
userGroups := parallel.Split(users, m.Config.Master.NumJobs)

// STEP 3: pull positive feedback
var mu sync.Mutex
var feedbackCount float64
var posFeedbackCount int
start = time.Now()
err = parallel.Parallel(int(rankingDataset.UserIndex.Len()), m.Config.Master.NumJobs, func(_, userIndex int) error {
// convert user index to id
userId := rankingDataset.UserIndex.ToName(int32(userIndex))
// load positive feedback from database
feedback, err := database.GetUserFeedback(ctx, userId, feedbackTimeLimit, posFeedbackTypes...)
if err != nil {
return errors.Trace(err)
}

for _, f := range feedback {
// convert item id to index
itemIndex := rankingDataset.ItemIndex.ToNumber(f.ItemId)
if itemIndex == base.NotId {
continue
}
positiveSet[userIndex].Add(itemIndex)
// add feedback to ranking dataset
mu.Lock()
feedbackCount++
rankingDataset.AddRawFeedback(int32(userIndex), itemIndex)
// insert feedback to popularity counter
if f.Timestamp.After(timeWindowLimit) && !rankingDataset.HiddenItems[itemIndex] {
popularCount[itemIndex]++
err = parallel.Parallel(len(userGroups), m.Config.Master.NumJobs, func(_, userIndex int) error {
feedbackChan, errChan := database.GetFeedbackStream(newCtx, batchSize,
data.WithBeginUserId(userGroups[userIndex][0]),
data.WithEndUserId(userGroups[userIndex][len(userGroups[userIndex])-1]),
feedbackTimeLimit,
data.WithEndTime(*m.Config.Now()),
data.WithFeedbackTypes(posFeedbackTypes...))
for feedback := range feedbackChan {
for _, f := range feedback {
// convert user and item id to index
userIndex := rankingDataset.UserIndex.ToNumber(f.UserId)
if userIndex == base.NotId {
continue
}
itemIndex := rankingDataset.ItemIndex.ToNumber(f.ItemId)
if itemIndex == base.NotId {
continue
}
// insert feedback to positive set
positiveSet[userIndex].Add(itemIndex)

mu.Lock()
posFeedbackCount++
// insert feedback to ranking dataset
rankingDataset.AddFeedback(f.UserId, f.ItemId, false)
// insert feedback to popularity counter
if f.Timestamp.After(timeWindowLimit) && !rankingDataset.HiddenItems[itemIndex] {
popularCount[itemIndex]++
}
// insert feedback to evaluator
evaluator.Positive(f.FeedbackType, userIndex, itemIndex, f.Timestamp)
mu.Unlock()
}
evaluator.Positive(f.FeedbackType, int32(userIndex), itemIndex, f.Timestamp)
mu.Unlock()
}
if err = <-errChan; err != nil {
return errors.Trace(err)
}
return nil
})
if err != nil {
return nil, nil, nil, nil, errors.Trace(err)
}
log.Logger().Debug("pulled positive feedback from database",
zap.Int("n_positive_feedback", rankingDataset.Count()),
zap.Int("n_positive_feedback", posFeedbackCount),
zap.Duration("used_time", time.Since(start)))
LoadDatasetStepSecondsVec.WithLabelValues("load_positive_feedback").Set(time.Since(start).Seconds())
span.Add(1)
Expand All @@ -1594,34 +1610,44 @@ func (m *Master) LoadDataFromDatabase(ctx context.Context, database data.Databas

// STEP 4: pull negative feedback
start = time.Now()
err = parallel.Parallel(int(rankingDataset.UserIndex.Len()), m.Config.Master.NumJobs, func(_, userIndex int) error {
// convert user index to id
userId := rankingDataset.UserIndex.ToName(int32(userIndex))
// load negative feedback from database
feedback, err := database.GetUserFeedback(ctx, userId, feedbackTimeLimit, readTypes...)
if err != nil {
return errors.Trace(err)
}
for _, f := range feedback {
itemIndex := rankingDataset.ItemIndex.ToNumber(f.ItemId)
if itemIndex == base.NotId {
continue
}
if !positiveSet[userIndex].Contains(itemIndex) {
negativeSet[userIndex].Add(itemIndex)
var negativeFeedbackCount float64
err = parallel.Parallel(len(userGroups), m.Config.Master.NumJobs, func(_, userIndex int) error {
feedbackChan, errChan := database.GetFeedbackStream(newCtx, batchSize,
data.WithBeginUserId(userGroups[userIndex][0]),
data.WithEndUserId(userGroups[userIndex][len(userGroups[userIndex])-1]),
feedbackTimeLimit,
data.WithEndTime(*m.Config.Now()),
data.WithFeedbackTypes(readTypes...))
for feedback := range feedbackChan {
for _, f := range feedback {
userIndex := rankingDataset.UserIndex.ToNumber(f.UserId)
if userIndex == base.NotId {
continue
}
itemIndex := rankingDataset.ItemIndex.ToNumber(f.ItemId)
if itemIndex == base.NotId {
continue
}
if !positiveSet[userIndex].Contains(itemIndex) {
negativeSet[userIndex].Add(itemIndex)
}

mu.Lock()
negativeFeedbackCount++
evaluator.Read(userIndex, itemIndex, f.Timestamp)
mu.Unlock()
}
mu.Lock()
feedbackCount++
evaluator.Read(int32(userIndex), itemIndex, f.Timestamp)
mu.Unlock()
}
if err = <-errChan; err != nil {
return errors.Trace(err)
}
return nil
})
if err != nil {
return nil, nil, nil, nil, errors.Trace(err)
}
FeedbacksTotal.Set(feedbackCount)
log.Logger().Debug("pulled negative feedback from database",
zap.Int("n_negative_feedback", int(negativeFeedbackCount)),
zap.Duration("used_time", time.Since(start)))
LoadDatasetStepSecondsVec.WithLabelValues("load_negative_feedback").Set(time.Since(start).Seconds())
span.Add(1)
Expand Down
57 changes: 56 additions & 1 deletion storage/data/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,61 @@ func (sorter feedbackSorter) Swap(i, j int) {
sorter[i], sorter[j] = sorter[j], sorter[i]
}

// ScanOptions holds the optional filters of a feedback scan. Every pointer
// field stays nil and FeedbackTypes stays empty unless the matching ScanOption
// has been applied.
type ScanOptions struct {
	BeginUserId   *string    // lower bound on user id, set by WithBeginUserId
	EndUserId     *string    // upper bound on user id, set by WithEndUserId
	BeginTime     *time.Time // lower bound on timestamp, set by WithBeginTime
	EndTime       *time.Time // upper bound on timestamp, set by WithEndTime
	FeedbackTypes []string   // feedback types to keep, set by WithFeedbackTypes
}

// ScanOption is a functional option that mutates a ScanOptions value.
type ScanOption func(options *ScanOptions)

// WithBeginUserId sets the begin user id. The begin user id is included in the result.
func WithBeginUserId(userId string) ScanOption {
	return func(o *ScanOptions) { o.BeginUserId = &userId }
}

// WithEndUserId sets the end user id. The end user id is included in the result.
func WithEndUserId(userId string) ScanOption {
	return func(o *ScanOptions) { o.EndUserId = &userId }
}

// WithBeginTime sets the begin time. The begin time is included in the result.
// NOTE(review): the MongoDB backend in this change filters with "$gt", which
// excludes the begin time itself — confirm which bound is intended.
func WithBeginTime(t time.Time) ScanOption {
	return func(o *ScanOptions) { o.BeginTime = &t }
}

// WithEndTime sets the end time. The end time is included in the result.
func WithEndTime(t time.Time) ScanOption {
	return func(o *ScanOptions) { o.EndTime = &t }
}

// WithFeedbackTypes sets the feedback types. The slice is stored as passed,
// without copying.
func WithFeedbackTypes(feedbackTypes ...string) ScanOption {
	return func(o *ScanOptions) { o.FeedbackTypes = feedbackTypes }
}

// NewScanOptions builds a ScanOptions by applying opts in order to a zero
// value. Nil options are skipped, so callers may pass an unset ScanOption
// variable to mean "no filter".
func NewScanOptions(opts ...ScanOption) ScanOptions {
	var result ScanOptions
	for _, apply := range opts {
		if apply == nil {
			continue
		}
		apply(&result)
	}
	return result
}

type Database interface {
Init() error
Ping() error
Expand All @@ -188,7 +243,7 @@ type Database interface {
GetFeedback(ctx context.Context, cursor string, n int, beginTime, endTime *time.Time, feedbackTypes ...string) (string, []Feedback, error)
GetUserStream(ctx context.Context, batchSize int) (chan []User, chan error)
GetItemStream(ctx context.Context, batchSize int, timeLimit *time.Time) (chan []Item, chan error)
GetFeedbackStream(ctx context.Context, batchSize int, beginTime, endTime *time.Time, feedbackTypes ...string) (chan []Feedback, chan error)
GetFeedbackStream(ctx context.Context, batchSize int, options ...ScanOption) (chan []Feedback, chan error)
}

// Open a connection to a database.
Expand Down
12 changes: 7 additions & 5 deletions storage/data/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ func (suite *baseTestSuite) getFeedback(ctx context.Context, batchSize int, begi
}
}

func (suite *baseTestSuite) getFeedbackStream(ctx context.Context, batchSize int, beginTime, endTime *time.Time, feedbackTypes ...string) []Feedback {
func (suite *baseTestSuite) getFeedbackStream(ctx context.Context, batchSize int, scanOptions ...ScanOption) []Feedback {
var feedbacks []Feedback
feedbackChan, errChan := suite.Database.GetFeedbackStream(ctx, batchSize, beginTime, endTime, feedbackTypes...)
feedbackChan, errChan := suite.Database.GetFeedbackStream(ctx, batchSize, scanOptions...)
for batchFeedback := range feedbackChan {
feedbacks = append(feedbacks, batchFeedback...)
}
Expand Down Expand Up @@ -250,12 +250,14 @@ func (suite *baseTestSuite) TestFeedback() {
ret = suite.getFeedback(ctx, 2, lo.ToPtr(timestamp.Add(time.Second)), lo.ToPtr(time.Now()))
suite.Empty(ret)
// Get feedback stream
feedbackFromStream := suite.getFeedbackStream(ctx, 3, nil, lo.ToPtr(time.Now()), positiveFeedbackType)
feedbackFromStream := suite.getFeedbackStream(ctx, 3, WithEndTime(time.Now()), WithFeedbackTypes(positiveFeedbackType))
suite.ElementsMatch(feedback, feedbackFromStream)
feedbackFromStream = suite.getFeedbackStream(ctx, 3, nil, lo.ToPtr(time.Now()))
feedbackFromStream = suite.getFeedbackStream(ctx, 3, WithEndTime(time.Now()))
suite.Equal(len(feedback)+2, len(feedbackFromStream))
feedbackFromStream = suite.getFeedbackStream(ctx, 3, lo.ToPtr(timestamp.Add(time.Second)), lo.ToPtr(time.Now()))
feedbackFromStream = suite.getFeedbackStream(ctx, 3, WithBeginTime(timestamp.Add(time.Second)), WithEndTime(time.Now()))
suite.Empty(feedbackFromStream)
feedbackFromStream = suite.getFeedbackStream(ctx, 3, WithBeginUserId("1"), WithEndUserId("3"), WithEndTime(time.Now()), WithFeedbackTypes(positiveFeedbackType))
suite.Equal(feedback[1:4], feedbackFromStream)
// Get items
items := suite.getItems(ctx, 3)
suite.Equal(5, len(items))
Expand Down
33 changes: 24 additions & 9 deletions storage/data/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,8 @@ func (db *MongoDB) GetFeedback(ctx context.Context, cursor string, n int, beginT
}

// GetFeedbackStream reads feedback from MongoDB by stream.
func (db *MongoDB) GetFeedbackStream(ctx context.Context, batchSize int, beginTime, endTime *time.Time, feedbackTypes ...string) (chan []Feedback, chan error) {
func (db *MongoDB) GetFeedbackStream(ctx context.Context, batchSize int, scanOptions ...ScanOption) (chan []Feedback, chan error) {
scan := NewScanOptions(scanOptions...)
feedbackChan := make(chan []Feedback, bufSize)
errChan := make(chan error, 1)
go func() {
Expand All @@ -695,18 +696,32 @@ func (db *MongoDB) GetFeedbackStream(ctx context.Context, batchSize int, beginTi
opt := options.Find()
filter := make(bson.M)
// pass feedback type to filter
if len(feedbackTypes) > 0 {
filter["feedbackkey.feedbacktype"] = bson.M{"$in": feedbackTypes}
if len(scan.FeedbackTypes) > 0 {
filter["feedbackkey.feedbacktype"] = bson.M{"$in": scan.FeedbackTypes}
}
// pass time limit to filter
timestampConditions := bson.M{}
if beginTime != nil {
timestampConditions["$gt"] = *beginTime
if scan.BeginTime != nil || scan.EndTime != nil {
timestampConditions := bson.M{}
if scan.BeginTime != nil {
timestampConditions["$gt"] = *scan.BeginTime
}
if scan.EndTime != nil {
timestampConditions["$lte"] = *scan.EndTime
}
filter["timestamp"] = timestampConditions
}
if endTime != nil {
timestampConditions["$lte"] = *endTime
// pass user id to filter
if scan.BeginUserId != nil || scan.EndUserId != nil {
userIdConditions := bson.M{}
if scan.BeginUserId != nil {
userIdConditions["$gte"] = *scan.BeginUserId
}
if scan.EndUserId != nil {
userIdConditions["$lte"] = *scan.EndUserId
}
filter["feedbackkey.userid"] = userIdConditions
}
filter["timestamp"] = timestampConditions

r, err := c.Find(ctx, filter, opt)
if err != nil {
errChan <- errors.Trace(err)
Expand Down
2 changes: 1 addition & 1 deletion storage/data/no_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (NoDatabase) GetFeedback(_ context.Context, _ string, _ int, _, _ *time.Tim
}

// GetFeedbackStream method of NoDatabase returns ErrNoDatabase.
func (NoDatabase) GetFeedbackStream(_ context.Context, _ int, _, _ *time.Time, _ ...string) (chan []Feedback, chan error) {
func (NoDatabase) GetFeedbackStream(_ context.Context, _ int, _ ...ScanOption) (chan []Feedback, chan error) {
feedbackChan := make(chan []Feedback, bufSize)
errChan := make(chan error, 1)
go func() {
Expand Down
2 changes: 1 addition & 1 deletion storage/data/no_database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,6 @@ func TestNoDatabase(t *testing.T) {
assert.ErrorIs(t, err, ErrNoDatabase)
_, err = database.DeleteUserItemFeedback(ctx, "", "")
assert.ErrorIs(t, err, ErrNoDatabase)
_, c = database.GetFeedbackStream(ctx, 0, nil, lo.ToPtr(time.Now()))
_, c = database.GetFeedbackStream(ctx, 0)
assert.ErrorIs(t, <-c, ErrNoDatabase)
}
Loading

0 comments on commit 14e97e6

Please sign in to comment.