Skip to content

Commit

Permalink
refactor: implement context based progress tracker (#741)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Jul 29, 2023
1 parent 17f3f73 commit e3e44b1
Show file tree
Hide file tree
Showing 41 changed files with 1,038 additions and 1,442 deletions.
2 changes: 1 addition & 1 deletion base/parallel/parallel.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func DynamicParallel(nJobs int, jobsAlloc *task.JobsAllocator, worker func(worke
// consumer
for {
exit := atomic.NewBool(true)
numJobs := jobsAlloc.AvailableJobs(nil)
numJobs := jobsAlloc.AvailableJobs()
var wg sync.WaitGroup
wg.Add(numJobs)
errs := make([]error, nJobs)
Expand Down
174 changes: 174 additions & 0 deletions base/progress/progress.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Copyright 2023 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package progress

import (
"context"
"sort"
"sync"
"time"

"github.com/google/uuid"
"modernc.org/mathutil"
)

type spanKeyType string

var spanKeyName = spanKeyType(uuid.New().String())

type Status string

const (
StatusPending Status = "Pending"
StatusComplete Status = "Complete"
StatusRunning Status = "Running"
StatusSuspended Status = "Suspended"
StatusFailed Status = "Failed"
)

type Tracer struct {
name string
spans sync.Map
}

func NewTracer(name string) *Tracer {
return &Tracer{name: name}
}

// Start creates a root span.
func (t *Tracer) Start(ctx context.Context, name string, total int) (context.Context, *Span) {
span := &Span{
name: name,
status: StatusRunning,
total: total,
start: time.Now(),
}
t.spans.Store(name, span)
return context.WithValue(ctx, spanKeyName, span), span
}

func (t *Tracer) List() []Progress {
var progress []Progress
t.spans.Range(func(key, value interface{}) bool {
span := value.(*Span)
progress = append(progress, span.Progress())
return true
})
// sort by start time
sort.Slice(progress, func(i, j int) bool {
return progress[i].StartTime.Before(progress[j].StartTime)
})
return progress
}

type Span struct {
name string
status Status
total int
count int
err string
start time.Time
finish time.Time
children sync.Map
}

func (s *Span) Add(n int) {
s.count = mathutil.Min(s.count+n, s.total)
}

func (s *Span) End() {
s.status = StatusComplete
s.count = s.total
s.finish = time.Now()
}

func (s *Span) Error(err error) {
s.err = err.Error()
}

func (s *Span) Count() int {
return s.count
}

func (s *Span) Progress() Progress {
// find running children
var children []Progress
s.children.Range(func(key, value interface{}) bool {
child := value.(*Span)
progress := child.Progress()
if progress.Status == StatusRunning {
children = append(children, progress)
}
return true
})
// leaf node
if len(children) == 0 {
return Progress{
Name: s.name,
Status: s.status,
Error: s.err,
Count: s.count,
Total: s.total,
StartTime: s.start,
FinishTime: s.finish,
}
}
// non-leaf node
childTotal := children[0].Total
parentTotal := s.total * childTotal
parentCount := s.count * childTotal
for _, child := range children {
parentCount += childTotal * child.Count / child.Total
}
return Progress{
Name: s.name,
Status: s.status,
Error: s.err,
Count: parentCount,
Total: parentTotal,
StartTime: s.start,
FinishTime: s.finish,
}
}

func Start(ctx context.Context, name string, total int) (context.Context, *Span) {
childSpan := &Span{
name: name,
status: StatusRunning,
total: total,
count: 0,
start: time.Now(),
}
if ctx == nil {
return nil, childSpan
}
span, ok := (ctx).Value(spanKeyName).(*Span)
if !ok {
return nil, childSpan
}
span.children.Store(name, childSpan)
return context.WithValue(ctx, spanKeyName, childSpan), childSpan
}

type Progress struct {
Tracer string
Name string
Status Status
Error string
Count int
Total int
StartTime time.Time
FinishTime time.Time
}
34 changes: 34 additions & 0 deletions base/progress/progress_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2023 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package progress

import (
"testing"

"github.com/stretchr/testify/suite"
)

type ProgressTestSuite struct {
suite.Suite
tracer Tracer
}

func (suite *ProgressTestSuite) SetupTest() {
suite.tracer = Tracer{}
}

func TestProgressTestSuite(t *testing.T) {
suite.Run(t, new(ProgressTestSuite))
}
4 changes: 3 additions & 1 deletion base/search/bruteforce.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package search

import (
"context"

"github.com/zhenghaoz/gorse/base/heap"
)

Expand All @@ -26,7 +28,7 @@ type Bruteforce struct {
}

// Build a vector index on data.
func (b *Bruteforce) Build() {}
func (b *Bruteforce) Build(_ context.Context) {}

// NewBruteforce creates a Bruteforce vector index.
func NewBruteforce(vectors []Vector) *Bruteforce {
Expand Down
23 changes: 12 additions & 11 deletions base/search/hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package search

import (
"context"
"math/rand"
"runtime"
"sync"
Expand All @@ -26,7 +27,7 @@ import (
"github.com/zhenghaoz/gorse/base/heap"
"github.com/zhenghaoz/gorse/base/log"
"github.com/zhenghaoz/gorse/base/parallel"
"github.com/zhenghaoz/gorse/base/task"
"github.com/zhenghaoz/gorse/base/progress"
"go.uber.org/zap"
"modernc.org/mathutil"
)
Expand All @@ -49,7 +50,6 @@ type HNSW struct {
maxConnection0 int
efConstruction int
numJobs int
task *task.SubTask
}

// HNSWConfig is the configuration function for HNSW.
Expand Down Expand Up @@ -123,24 +123,26 @@ func (h *HNSW) knnSearch(q Vector, k, ef int) *heap.PriorityQueue {
}

// Build a vector index on data.
func (h *HNSW) Build() {
func (h *HNSW) Build(ctx context.Context) {
completed := make(chan struct{}, h.numJobs)
go func() {
defer base.CheckPanic()
completedCount, previousCount := 0, 0
ticker := time.NewTicker(10 * time.Second)
_, span := progress.Start(ctx, "HNSW.Build", len(h.vectors))
for {
select {
case _, ok := <-completed:
if !ok {
span.End()
return
}
completedCount++
case <-ticker.C:
throughput := completedCount - previousCount
previousCount = completedCount
h.task.Add(throughput * len(h.vectors))
if throughput > 0 {
span.Add(throughput)
log.Logger().Info("building index",
zap.Int("n_indexed_vectors", completedCount),
zap.Int("n_vectors", len(h.vectors)),
Expand Down Expand Up @@ -320,7 +322,7 @@ func NewHNSWBuilder(data []Vector, k, numJobs int) *HNSWBuilder {
rng: base.NewRandomGenerator(0),
numJobs: numJobs,
}
b.bruteForce.Build()
b.bruteForce.Build(context.Background())
return b
}

Expand Down Expand Up @@ -361,20 +363,19 @@ func (b *HNSWBuilder) evaluate(idx *HNSW, prune0 bool) float32 {
return result / count
}

func (b *HNSWBuilder) Build(recall float32, trials int, prune0 bool, t *task.Task) (idx *HNSW, score float32) {
buildTask := t.SubTask(EstimateHNSWBuilderComplexity(len(b.data), trials))
defer buildTask.Finish()
func (b *HNSWBuilder) Build(ctx context.Context, recall float32, trials int, prune0 bool) (idx *HNSW, score float32) {
ef := 1 << int(math32.Ceil(math32.Log2(float32(b.k))))
newCtx, span := progress.Start(ctx, "HNSWBuilder.Build", trials)
defer span.End()
for i := 0; i < trials; i++ {
start := time.Now()
idx = NewHNSW(b.data,
SetEFConstruction(ef),
SetHNSWNumJobs(b.numJobs))
idx.task = buildTask
idx.Build()
idx.Build(newCtx)
buildTime := time.Since(start)
score = b.evaluate(idx, prune0)
idx.task.Add(b.testSize * len(b.data))
span.Add(1)
log.Logger().Info("try to build vector index",
zap.String("index_type", "HNSW"),
zap.Int("ef_construction", ef),
Expand Down
8 changes: 5 additions & 3 deletions base/search/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
package search

import (
"context"
"fmt"
"reflect"
"sort"

"github.com/chewxy/math32"
"github.com/zhenghaoz/gorse/base/floats"
"github.com/zhenghaoz/gorse/base/log"
"go.uber.org/zap"
"modernc.org/sortutil"
"reflect"
"sort"
)

type Vector interface {
Expand Down Expand Up @@ -182,7 +184,7 @@ func (v *DictionaryCentroidVector) Distance(vector Vector) float32 {
}

type VectorIndex interface {
Build()
Build(ctx context.Context)
Search(q Vector, n int, prune0 bool) ([]int32, []float32)
MultiSearch(q Vector, terms []string, n int, prune0 bool) (map[string][]int32, map[string][]float32)
}
Loading

0 comments on commit e3e44b1

Please sign in to comment.