Skip to content

Commit

Permalink
fix: factorization machines diverged (#743)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz committed Jul 29, 2023
1 parent e3e44b1 commit 59ab0b4
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 8 deletions.
23 changes: 19 additions & 4 deletions base/progress/progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,15 @@ func (s *Span) Add(n int) {
}

func (s *Span) End() {
s.status = StatusComplete
s.count = s.total
s.finish = time.Now()
if s.status == StatusRunning {
s.status = StatusComplete
s.count = s.total
s.finish = time.Now()
}
}

func (s *Span) Error(err error) {
func (s *Span) Fail(err error) {
s.status = StatusFailed
s.err = err.Error()
}

Expand All @@ -111,6 +114,10 @@ func (s *Span) Progress() Progress {
if progress.Status == StatusRunning {
children = append(children, progress)
}
if s.err == "" && progress.Error != "" {
s.err = progress.Error
s.status = StatusFailed
}
return true
})
// leaf node
Expand Down Expand Up @@ -162,6 +169,14 @@ func Start(ctx context.Context, name string, total int) (context.Context, *Span)
return context.WithValue(ctx, spanKeyName, childSpan), childSpan
}

func Fail(ctx context.Context, err error) {
span, ok := (ctx).Value(spanKeyName).(*Span)
if !ok {
return
}
span.Fail(err)
}

type Progress struct {
Tracer string
Name string
Expand Down
4 changes: 3 additions & 1 deletion master/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func (t *FindItemNeighborsTask) run(ctx context.Context, j *task.JobsAllocator)
numItems := dataset.ItemCount()
numFeedback := dataset.Count()

_, span := t.tracer.Start(ctx, "Find Item Neighbors", dataset.ItemCount())
newCtx, span := t.tracer.Start(ctx, "Find Item Neighbors", dataset.ItemCount())
defer span.End()

if numItems == 0 {
Expand Down Expand Up @@ -312,6 +312,7 @@ func (t *FindItemNeighborsTask) run(ctx context.Context, j *task.JobsAllocator)
close(completed)
if err != nil {
log.Logger().Error("failed to searching neighbors of items", zap.Error(err))
progress.Fail(newCtx, err)
FindItemNeighborsTotalSeconds.Set(0)
} else {
if err := t.CacheClient.Set(ctx, cache.Time(cache.Key(cache.GlobalMeta, cache.LastUpdateItemNeighborsTime), time.Now())); err != nil {
Expand Down Expand Up @@ -641,6 +642,7 @@ func (t *FindUserNeighborsTask) run(ctx context.Context, j *task.JobsAllocator)
close(completed)
if err != nil {
log.Logger().Error("failed to searching neighbors of users", zap.Error(err))
progress.Fail(newCtx, err)
FindUserNeighborsTotalSeconds.Set(0)
} else {
if err := t.CacheClient.Set(ctx, cache.Time(cache.Key(cache.GlobalMeta, cache.LastUpdateUserNeighborsTime), time.Now())); err != nil {
Expand Down
13 changes: 11 additions & 2 deletions model/click/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ func (fm *FM) Fit(ctx context.Context, trainSet, testSet *Dataset, config *FitCo
cost += grad * grad / 2
case FMClassification:
grad = -target * (1 - 1/(1+math32.Exp(-target*prediction)))
cost += (1 + target) * math32.Log(1+math32.Exp(-prediction)) / 2
cost += (1 - target) * math32.Log(1+math32.Exp(prediction)) / 2
cost += (1 + target) * math32.Log1p(exp(-prediction)) / 2
cost += (1 - target) * math32.Log1p(exp(prediction)) / 2
default:
log.Logger().Fatal("unknown task", zap.String("task", string(fm.Task)))
}
Expand Down Expand Up @@ -434,6 +434,7 @@ func (fm *FM) Fit(ctx context.Context, trainSet, testSet *Dataset, config *FitCo
// check NaN
if math32.IsNaN(cost) || math32.IsNaN(score.GetValue()) {
log.Logger().Warn("model diverged", zap.Float32("lr", fm.lr))
span.Fail(errors.New("model diverged"))
break
}
snapshots.AddSnapshot(score, fm.V, fm.W, fm.B)
Expand Down Expand Up @@ -645,3 +646,11 @@ func (fm *FM) Unmarshal(r io.Reader) error {
}
return nil
}

func exp(x float32) float32 {
e := math32.Exp(x)
if math32.IsInf(e, 1) {
return math32.MaxFloat32
}
return e
}
2 changes: 1 addition & 1 deletion model/ranking/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ func (bpr *BPR) Fit(ctx context.Context, trainSet, valSet *DataSet, config *FitC
}
}
diff := bpr.InternalPredict(userIndex, posIndex) - bpr.InternalPredict(userIndex, negIndex)
cost[workerId] += math32.Log(1 + math32.Exp(-diff))
cost[workerId] += math32.Log1p(math32.Exp(-diff))
grad := math32.Exp(-diff) / (1.0 + math32.Exp(-diff))
// Pairwise update
copy(userFactor[workerId], bpr.UserFactor[userIndex])
Expand Down
3 changes: 3 additions & 0 deletions storage/cache/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ func (m MongoDB) Remain(ctx context.Context, name string) (int64, error) {
}

func (m MongoDB) AddDocuments(ctx context.Context, collection, subset string, documents []Document) error {
if len(documents) == 0 {
return nil
}
var models []mongo.WriteModel
for _, document := range documents {
models = append(models, mongo.NewUpdateOneModel().
Expand Down

0 comments on commit 59ab0b4

Please sign in to comment.