-
Notifications
You must be signed in to change notification settings - Fork 1
/
accuracy.go
77 lines (65 loc) · 1.64 KB
/
accuracy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
package mlmetrics
import (
"sync"
)
// Accuracy is a basic classification metric. It measures how often the
// classifier makes the correct prediction, expressed as the ratio between
// the weight of correct predictions and the total weight of predictions.
//
// The zero value holds no observations and is ready to use. Because it
// embeds a sync.RWMutex, an Accuracy must not be copied after first use.
type Accuracy struct {
	// observed is the summed weight of every recorded observation; correct
	// is the summed weight of the observations whose prediction matched.
	observed, correct float64

	mu sync.RWMutex // guards observed and correct
}
// NewAccuracy allocates a new Accuracy metric with no recorded observations.
// The returned value is identical to a zero-value Accuracy.
func NewAccuracy() *Accuracy {
	return new(Accuracy)
}
// Reset discards every recorded observation, zeroing both the total and the
// correct weight so the metric can be reused from scratch.
func (m *Accuracy) Reset() {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.observed, m.correct = 0, 0
}
// Observe records a single observation of the actual vs the predicted
// category. It is shorthand for ObserveWeight with a weight of 1.
func (m *Accuracy) Observe(actual, predicted int) {
	const unitWeight = 1.0
	m.ObserveWeight(actual, predicted, unitWeight)
}
// ObserveWeight records an observation of the actual vs the predicted
// category carrying the given weight. The weight is always added to the
// total, and additionally to the correct weight when actual == predicted.
// Observations rejected by isValidCategory (for either category) or by
// isValidWeight are silently ignored.
func (m *Accuracy) ObserveWeight(actual, predicted int, weight float64) {
	valid := isValidCategory(actual) && isValidCategory(predicted) && isValidWeight(weight)
	if !valid {
		return
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	m.observed += weight
	if actual == predicted {
		m.correct += weight
	}
}
// TotalWeight reports the summed weight of all accepted observations.
func (m *Accuracy) TotalWeight() float64 {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.observed
}
// CorrectWeight reports the summed weight of accepted observations whose
// predicted category matched the actual one.
func (m *Accuracy) CorrectWeight() float64 {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.correct
}
// Rate reports the fraction of observed weight that came from correct
// predictions (correct weight divided by total weight). It returns 0 when
// no weight has been observed yet instead of dividing by zero.
func (m *Accuracy) Rate() float64 {
	// Both fields are read inside one critical section so the ratio is
	// computed from a consistent snapshot, never a torn pair.
	m.mu.RLock()
	defer m.mu.RUnlock()

	if m.observed == 0 {
		return 0
	}
	return m.correct / m.observed
}