Skip to content

Commit

Permalink
Boolean aggregates.
Browse files Browse the repository at this point in the history
  • Loading branch information
ncruces committed Jun 6, 2024
1 parent 8fd878a commit dbf764a
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 40 deletions.
6 changes: 2 additions & 4 deletions ext/stats/TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,8 @@ https://sqlite.org/windowfunctions.html#builtins

## Boolean aggregates

- [ ] `ALL(boolean)`
- [ ] `ANY(boolean)`
- [ ] `EVERY(boolean)`
- [ ] `SOME(boolean)`
- [X] `EVERY(boolean)`
- [X] `SOME(boolean)`

## Additional aggregates

Expand Down
46 changes: 46 additions & 0 deletions ext/stats/boolean.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package stats

import "github.com/ncruces/go-sqlite3"

const (
every = iota
some
)

func newBoolean(kind int) func() sqlite3.AggregateFunction {
return func() sqlite3.AggregateFunction { return &boolean{kind: kind} }
}

type boolean struct {
count int
total int
kind int
}

func (b *boolean) Value(ctx sqlite3.Context) {
if b.kind == every {
ctx.ResultBool(b.count == b.total)
} else {
ctx.ResultBool(b.count > 0)
}
}

func (b *boolean) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
if arg[0].Type() == sqlite3.NULL {
return
}
if arg[0].Bool() {
b.count++
}
b.total++
}

func (b *boolean) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
if arg[0].Type() == sqlite3.NULL {
return
}
if arg[0].Bool() {
b.count--
}
b.total--
}
74 changes: 74 additions & 0 deletions ext/stats/boolean_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package stats_test

import (
"testing"

"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/stats"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)

func TestRegister_boolean(t *testing.T) {
t.Parallel()

db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()

stats.Register(db)

err = db.Exec(`CREATE TABLE data (x)`)
if err != nil {
t.Fatal(err)
}

err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), (13), (NULL), (16), (3.14)`)
if err != nil {
t.Fatal(err)
}

stmt, _, err := db.Prepare(`
SELECT
every(x > 0),
every(x > 10),
some(x > 10),
some(x > 20)
FROM data`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
if got := stmt.ColumnBool(0); got != true {
t.Errorf("got %v, want true", got)
}
if got := stmt.ColumnBool(1); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnBool(2); got != true {
t.Errorf("got %v, want true", got)
}
if got := stmt.ColumnBool(3); got != false {
t.Errorf("got %v, want false", got)
}
}
stmt.Close()

stmt, _, err = db.Prepare(`SELECT every(x > 10) OVER (ROWS 1 PRECEDING) FROM data`)
if err != nil {
t.Fatal(err)
}

want := [...]bool{false, false, false, true, true, false}
for i := 0; stmt.Step(); i++ {
if got := stmt.ColumnBool(0); got != want[i] {
t.Errorf("got %v, want %v", got, want[i])
}
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
}
stmt.Close()
}
11 changes: 11 additions & 0 deletions ext/stats/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
// - quantile_disc: discrete quantile
// - quantile_cont: continuous quantile
// - median: median value
// - every: boolean and
// - some: boolean or
//
// These join the [Built-in Aggregate Functions]:
// - count: count rows/values
Expand All @@ -29,9 +31,16 @@
// - min: minimum value
// - max: maximum value
//
// And the [Built-in Window Functions]:
// - rank: rank of the current row with gaps
// - dense_rank: rank of the current row without gaps
// - percent_rank: relative rank of the row
// - cume_dist: cumulative distribution
//
// See: [ANSI SQL Aggregate Functions], [DuckDB Aggregate Functions]
//
// [Built-in Aggregate Functions]: https://sqlite.org/lang_aggfunc.html
// [Built-in Window Functions]: https://sqlite.org/windowfunctions.html#builtins
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
// [DuckDB Aggregate Functions]: https://duckdb.org/docs/sql/aggregates.html
package stats
Expand Down Expand Up @@ -61,6 +70,8 @@ func Register(db *sqlite3.Conn) {
db.CreateWindowFunction("median", 1, flags, newQuantile(median))
db.CreateWindowFunction("quantile_cont", 2, flags, newQuantile(quant_cont))
db.CreateWindowFunction("quantile_disc", 2, flags, newQuantile(quant_disc))
db.CreateWindowFunction("every", 1, flags, newBoolean(every))
db.CreateWindowFunction("some", 1, flags, newBoolean(some))
}

const (
Expand Down
58 changes: 26 additions & 32 deletions ext/stats/stats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ func TestRegister_variance(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer stmt.Close()

if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 40 {
t.Errorf("got %v, want 40", got)
Expand All @@ -62,24 +60,23 @@ func TestRegister_variance(t *testing.T) {
t.Errorf("got %v, want √22.5", got)
}
}
stmt.Close()

{
stmt, _, err := db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
stmt, _, err = db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
if err != nil {
t.Fatal(err)
}

want := [...]float64{0, 4.5, 18, 0, 0}
for i := 0; stmt.Step(); i++ {
if got := stmt.ColumnFloat(0); got != want[i] {
t.Errorf("got %v, want %v", got, want[i])
}
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
t.Errorf("got %v, want %v", got, want[i])
}
want := [...]float64{0, 4.5, 18, 0, 0}
for i := 0; stmt.Step(); i++ {
if got := stmt.ColumnFloat(0); got != want[i] {
t.Errorf("got %v, want %v", got, want[i])
}
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
t.Errorf("got %v, want %v", got, want[i])
}
}
stmt.Close()
}

func TestRegister_covariance(t *testing.T) {
Expand Down Expand Up @@ -113,8 +110,6 @@ func TestRegister_covariance(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer stmt.Close()

if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
t.Errorf("got %v, want 0.9881049293224639", got)
Expand Down Expand Up @@ -159,24 +154,23 @@ func TestRegister_covariance(t *testing.T) {
t.Errorf("got %v, want 5", got)
}
}
stmt.Close()

{
stmt, _, err := db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
stmt, _, err = db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
if err != nil {
t.Fatal(err)
}

want := [...]float64{0, 10, 30, 75, 22.5}
for i := 0; stmt.Step(); i++ {
if got := stmt.ColumnFloat(0); got != want[i] {
t.Errorf("got %v, want %v", got, want[i])
}
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
t.Errorf("got %v, want %v", got, want[i])
}
want := [...]float64{0, 10, 30, 75, 22.5}
for i := 0; stmt.Step(); i++ {
if got := stmt.ColumnFloat(0); got != want[i] {
t.Errorf("got %v, want %v", got, want[i])
}
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
t.Errorf("got %v, want %v", got, want[i])
}
}
stmt.Close()
}

func Benchmark_average(b *testing.B) {
Expand Down
4 changes: 2 additions & 2 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,12 +441,12 @@ func (s *Stmt) ColumnOriginName(col int) string {
// ColumnBool returns the value of the result column as a bool.
// The leftmost column of the result set has the index 0.
// SQLite does not have a separate boolean storage class.
// Instead, boolean values are retrieved as integers,
// Instead, boolean values are retrieved as numbers,
// with 0 converted to false and any other value to true.
//
// https://sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnBool(col int) bool {
return s.ColumnInt64(col) != 0
return s.ColumnFloat(col) != 0
}

// ColumnInt returns the value of the result column as an int.
Expand Down
4 changes: 2 additions & 2 deletions value.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ func (v Value) NumericType() Datatype {

// Bool returns the value as a bool.
// SQLite does not have a separate boolean storage class.
// Instead, boolean values are retrieved as integers,
// Instead, boolean values are retrieved as numbers,
// with 0 converted to false and any other value to true.
//
// https://sqlite.org/c3ref/value_blob.html
func (v Value) Bool() bool {
return v.Int64() != 0
return v.Float() != 0
}

// Int returns the value as an int.
Expand Down

0 comments on commit dbf764a

Please sign in to comment.