Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically load extensions. #115

Merged
merged 7 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
c.arena = c.newArena(1024)
c.ctx = context.WithValue(c.ctx, connKey{}, c)
c.handle, err = c.openDB(filename, flags)
if err == nil {
err = initExtensions(c)
}
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions ext/array/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (
// The argument must be bound to a Go slice or array of
// ints, floats, bools, strings or byte slices,
// using [sqlite3.BindPointer] or [sqlite3.Pointer].
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule(db, "array", nil,
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "array", nil,
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (array, error) {
err := db.DeclareVTab(`CREATE TABLE x(value, array HIDDEN)`)
return array{}, err
Expand Down
19 changes: 8 additions & 11 deletions ext/array/array_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@ import (
)

func Example_driver() {
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
array.Register(c)
return nil
})
db, err := driver.Open(":memory:", array.Register)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -53,14 +50,14 @@ func Example_driver() {
}

func Example() {
sqlite3.AutoExtension(array.Register)

db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()

array.Register(db)

stmt, _, err := db.Prepare(`
SELECT name
FROM pragma_function_list
Expand Down Expand Up @@ -91,10 +88,7 @@ func Example() {
func Test_cursor_Column(t *testing.T) {
t.Parallel()

db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
array.Register(c)
return nil
})
db, err := driver.Open(":memory:", array.Register)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -139,7 +133,10 @@ func Test_array_errors(t *testing.T) {
}
defer db.Close()

array.Register(db)
err = array.Register(db)
if err != nil {
t.Fatal(err)
}

err = db.Exec(`SELECT * FROM array()`)
if err == nil {
Expand Down
9 changes: 5 additions & 4 deletions ext/blobio/blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ import (
// along with the [sqlite3.Blob] handle.
//
// https://sqlite.org/c3ref/blob.html
func Register(db *sqlite3.Conn) {
db.CreateFunction("readblob", 6, 0, readblob)
db.CreateFunction("writeblob", 6, 0, writeblob)
db.CreateFunction("openblob", -1, 0, openblob)
func Register(db *sqlite3.Conn) error {
return errors.Join(
db.CreateFunction("readblob", 6, 0, readblob),
db.CreateFunction("writeblob", 6, 0, writeblob),
db.CreateFunction("openblob", -1, 0, openblob))
}

// OpenCallback is the type for the openblob callback.
Expand Down
16 changes: 6 additions & 10 deletions ext/blobio/blob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@ import (

func Example() {
// Open the database, registering the extension.
db, err := driver.Open("file:/test.db?vfs=memdb", func(conn *sqlite3.Conn) error {
blobio.Register(conn)
return nil
})
db, err := driver.Open("file:/test.db?vfs=memdb", blobio.Register)

if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -60,6 +57,11 @@ func Example() {
// Hello BLOB!
}

func init() {
sqlite3.AutoExtension(blobio.Register)
sqlite3.AutoExtension(array.Register)
}

func Test_readblob(t *testing.T) {
t.Parallel()

Expand All @@ -69,9 +71,6 @@ func Test_readblob(t *testing.T) {
}
defer db.Close()

blobio.Register(db)
array.Register(db)

err = db.Exec(`SELECT readblob()`)
if err == nil {
t.Fatal("want error")
Expand Down Expand Up @@ -129,9 +128,6 @@ func Test_openblob(t *testing.T) {
}
defer db.Close()

blobio.Register(db)
array.Register(db)

err = db.Exec(`SELECT openblob()`)
if err == nil {
t.Fatal("want error")
Expand Down
4 changes: 2 additions & 2 deletions ext/bloom/bloom.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import (
// Register registers the bloom_filter virtual table:
//
// CREATE VIRTUAL TABLE foo USING bloom_filter(nElements, falseProb, kHashes)
func Register(db *sqlite3.Conn) {
sqlite3.CreateModule(db, "bloom_filter", create, connect)
func Register(db *sqlite3.Conn) error {
return sqlite3.CreateModule(db, "bloom_filter", create, connect)
}

type bloom struct {
Expand Down
8 changes: 4 additions & 4 deletions ext/bloom/bloom_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ import (
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
)

func init() {
sqlite3.AutoExtension(bloom.Register)
}

func TestRegister(t *testing.T) {
t.Parallel()

Expand All @@ -21,8 +25,6 @@ func TestRegister(t *testing.T) {
}
defer db.Close()

bloom.Register(db)

err = db.Exec(`
CREATE VIRTUAL TABLE sports_cars USING bloom_filter(20);
INSERT INTO sports_cars VALUES ('ferrari'), ('lamborghini'), ('alfa romeo')
Expand Down Expand Up @@ -90,8 +92,6 @@ func Test_compatible(t *testing.T) {
}
defer db.Close()

bloom.Register(db)

query, _, err := db.Prepare(`SELECT COUNT(*) FROM plants(?)`)
if err != nil {
t.Fatal(err)
Expand Down
8 changes: 4 additions & 4 deletions ext/csv/csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ import (

// Register registers the CSV virtual table.
// If a filename is specified, [os.Open] is used to open the file.
func Register(db *sqlite3.Conn) {
RegisterFS(db, osutil.FS{})
func Register(db *sqlite3.Conn) error {
return RegisterFS(db, osutil.FS{})
}

// RegisterFS registers the CSV virtual table.
// If a filename is specified, fsys is used to open the file.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
var (
filename string
Expand Down Expand Up @@ -118,7 +118,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
return table, nil
}

sqlite3.CreateModule(db, "csv", declare, declare)
return sqlite3.CreateModule(db, "csv", declare, declare)
}

type table struct {
Expand Down
15 changes: 8 additions & 7 deletions ext/csv/csv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ func Example() {
}
defer db.Close()

csv.Register(db)
err = csv.Register(db)
if err != nil {
log.Fatal(err)
}

err = db.Exec(`
CREATE VIRTUAL TABLE eurofxref USING csv(
Expand Down Expand Up @@ -51,6 +54,10 @@ func Example() {
// On Twosday, 1€ = $1.1342
}

func init() {
sqlite3.AutoExtension(csv.Register)
}

func TestRegister(t *testing.T) {
t.Parallel()

Expand All @@ -60,8 +67,6 @@ func TestRegister(t *testing.T) {
}
defer db.Close()

csv.Register(db)

const data = `
# Comment
"Rob" "Pike" rob
Expand Down Expand Up @@ -124,8 +129,6 @@ func TestAffinity(t *testing.T) {
}
defer db.Close()

csv.Register(db)

const data = "01\n0.10\ne"
err = db.Exec(`
CREATE VIRTUAL TABLE temp.nums USING csv(
Expand Down Expand Up @@ -168,8 +171,6 @@ func TestRegister_errors(t *testing.T) {
}
defer db.Close()

csv.Register(db)

err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv()`)
if err == nil {
t.Fatal("want error")
Expand Down
24 changes: 13 additions & 11 deletions ext/fileio/fileio.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,26 @@ import (

// Register registers SQL functions readfile, writefile, lsmode,
// and the table-valued function fsdir.
func Register(db *sqlite3.Conn) {
RegisterFS(db, nil)
func Register(db *sqlite3.Conn) error {
return RegisterFS(db, nil)
}

// Register registers SQL functions readfile, lsmode,
// and the table-valued function fsdir;
// fsys will be used to read files and list directories.
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode)
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys))
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
var err error
if fsys == nil {
db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile)
err = db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile)
}
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
return fsdir{fsys}, err
})
return errors.Join(err,
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys)),
db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode),
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
return fsdir{fsys}, err
}))
}

func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) {
Expand Down
5 changes: 1 addition & 4 deletions ext/fileio/fileio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ import (
func Test_lsmode(t *testing.T) {
t.Parallel()

db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
fileio.Register(c)
return nil
})
db, err := driver.Open(":memory:", fileio.Register)
if err != nil {
t.Fatal(err)
}
Expand Down
5 changes: 4 additions & 1 deletion ext/fileio/fsdir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ func Test_fsdir_errors(t *testing.T) {
}
defer db.Close()

fileio.Register(db)
err = fileio.Register(db)
if err != nil {
t.Fatal(err)
}

err = db.Exec(`SELECT name FROM fsdir()`)
if err == nil {
Expand Down
6 changes: 1 addition & 5 deletions ext/fileio/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"testing"
"time"

"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
Expand All @@ -16,10 +15,7 @@ import (
func Test_writefile(t *testing.T) {
t.Parallel()

db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
Register(c)
return nil
})
db, err := driver.Open(":memory:", Register)
if err != nil {
t.Fatal(err)
}
Expand Down
Loading
Loading