Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shared memory API. #81

Merged
merged 1 commit into from
Apr 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions internal/util/unwrap.go

This file was deleted.

33 changes: 20 additions & 13 deletions vfs/adiantum/hbsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,19 @@ func (h *hbshVFS) OpenParams(name string, flags vfs.OpenFlag, params url.Values)
// Encrypt everything except super journals.
if flags&vfs.OPEN_SUPER_JOURNAL == 0 {
var key []byte
if t, ok := params["key"]; ok {
if name == "" {
key = h.hbsh.KDF("") // Temporary files get a random key.
} else if t, ok := params["key"]; ok {
key = []byte(t[0])
} else if t, ok := params["hexkey"]; ok {
key, _ = hex.DecodeString(t[0])
} else if t, ok := params["textkey"]; ok {
key = h.hbsh.KDF(t[0])
} else if name == "" {
key = h.hbsh.KDF("")
}

if hbsh = h.hbsh.HBSH(key); hbsh == nil {
return nil, flags, sqlite3.NOTADB
// Can't open without a valid key.
return nil, flags, sqlite3.CANTOPEN
}
}

Expand All @@ -51,6 +52,7 @@ func (h *hbshVFS) OpenParams(name string, flags vfs.OpenFlag, params url.Values)
file, flags, err = h.Open(name, flags)
}
if err != nil || hbsh == nil || flags&vfs.OPEN_MEMORY != 0 {
// Error, or no encryption (super journals, memory files).
return file, flags, err
}
return &hbshFile{File: file, hbsh: hbsh}, flags, err
Expand All @@ -72,8 +74,8 @@ func (h *hbshFile) ReadAt(p []byte, off int64) (n int, err error) {
min := (off) &^ (blockSize - 1) // round down
max := (off + int64(len(p)) + blockSize - 1) &^ (blockSize - 1) // round up

// Read one block at a time.
for ; min < max; min += blockSize {
// Read full block.
m, err := h.File.ReadAt(h.block[:], min)
if m != blockSize {
return n, err
Expand All @@ -98,20 +100,24 @@ func (h *hbshFile) WriteAt(p []byte, off int64) (n int, err error) {
min := (off) &^ (blockSize - 1) // round down
max := (off + int64(len(p)) + blockSize - 1) &^ (blockSize - 1) // round up

// Write one block at a time.
for ; min < max; min += blockSize {
binary.LittleEndian.PutUint64(h.tweak[:], uint64(min))
data := h.block[:]

if off > min || len(p[n:]) < blockSize {
// Read full block.
// Partial block write: read-update-write.
m, err := h.File.ReadAt(h.block[:], min)
if m != blockSize {
if err != io.EOF {
return n, err
}
// Writing past the EOF.
// A partially written block is corrupt,
// and also considered to be past the EOF.
// Writing past the EOF:
// We're either appending an entirely new block,
// or the final block was only partially written.
// A partially written block can't be decripted,
// and is as good as corrupt.
// Either way, zero pad the file to the next block size.
clear(data)
}

Expand All @@ -124,7 +130,6 @@ func (h *hbshFile) WriteAt(p []byte, off int64) (n int, err error) {
t := copy(data, p[n:])
h.hbsh.Encrypt(h.block[:], h.tweak[:])

// Write full block.
m, err := h.File.WriteAt(h.block[:], min)
if m != blockSize {
return n, err
Expand Down Expand Up @@ -155,9 +160,11 @@ func (h *hbshFile) DeviceCharacteristics() vfs.DeviceCharacteristic {
vfs.IOCAP_BATCH_ATOMIC)
}

// This is needed for shared memory.
func (h *hbshFile) Unwrap() vfs.File {
return h.File
func (h *hbshFile) SharedMemory() vfs.SharedMemory {
if shm, ok := h.File.(vfs.FileSharedMemory); ok {
return shm.SharedMemory()
}
return nil
}

// Wrap optional methods.
Expand Down
22 changes: 21 additions & 1 deletion vfs/api.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
// Package vfs wraps the C SQLite VFS API.
package vfs

import "net/url"
import (
"context"
"net/url"

"github.com/tetratelabs/wazero/api"
)

// A VFS defines the interface between the SQLite core and the underlying operating system.
//
Expand Down Expand Up @@ -129,3 +134,18 @@ type FileBatchAtomicWrite interface {
CommitAtomicWrite() error
RollbackAtomicWrite() error
}

// FileSharedMemory extends File to possibly implement shared memory.
// It's OK for SharedMemory to return nil.
type FileSharedMemory interface {
File
SharedMemory() SharedMemory
}

// SharedMemory is a shared memory implementation.
// This cannot be externally implemented.
type SharedMemory interface {
shmMap(context.Context, api.Module, int32, int32, bool) (uint32, error)
shmLock(int32, int32, _ShmFlag) error
shmUnmap(bool)
}
8 changes: 0 additions & 8 deletions vfs/file.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package vfs

import (
"context"
"errors"
"io"
"io/fs"
Expand All @@ -12,7 +11,6 @@ import (
"syscall"

"github.com/ncruces/go-sqlite3/util/osutil"
"github.com/tetratelabs/wazero/api"
)

type vfsOS struct{}
Expand Down Expand Up @@ -215,9 +213,3 @@ func (f *vfsFile) PowersafeOverwrite() bool { return f.psow }
func (f *vfsFile) PersistentWAL() bool { return f.keepWAL }
func (f *vfsFile) SetPowersafeOverwrite(psow bool) { f.psow = psow }
func (f *vfsFile) SetPersistentWAL(keepWAL bool) { f.keepWAL = keepWAL }

type fileShm interface {
shmMap(context.Context, api.Module, int32, int32, bool) (uint32, error)
shmLock(int32, int32, _ShmFlag) error
shmUnmap(bool)
}
2 changes: 2 additions & 0 deletions vfs/shm.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ const (
_SHM_DMS = _SHM_BASE + _SHM_NLOCK
)

func (f *vfsFile) SharedMemory() SharedMemory { return f }

func (f *vfsFile) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (uint32, error) {
// Ensure size is a multiple of the OS page size.
if int(size)&(unix.Getpagesize()-1) != 0 {
Expand Down
12 changes: 7 additions & 5 deletions vfs/vfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,10 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla
util.WriteUint32(mod, pOutFlags, uint32(flags))
}
if pOutVFS != 0 && util.CanMapFiles(ctx) {
if _, ok := util.Unwrap(file).(fileShm); ok {
util.WriteUint32(mod, pOutVFS, 1)
if f, ok := file.(FileSharedMemory); ok {
if f.SharedMemory() != nil {
util.WriteUint32(mod, pOutVFS, 1)
}
}
}
vfsFileRegister(ctx, mod, pFile, file)
Expand Down Expand Up @@ -366,7 +368,7 @@ func vfsShmBarrier(ctx context.Context, mod api.Module, pFile uint32) {
}

func vfsShmMap(ctx context.Context, mod api.Module, pFile uint32, iRegion, szRegion int32, bExtend, pp uint32) _ErrorCode {
file := util.Unwrap(vfsFileGet(ctx, mod, pFile)).(fileShm)
file := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory()
p, err := file.shmMap(ctx, mod, iRegion, szRegion, bExtend != 0)
if err != nil {
return vfsErrorCode(err, _IOERR_SHMMAP)
Expand All @@ -376,13 +378,13 @@ func vfsShmMap(ctx context.Context, mod api.Module, pFile uint32, iRegion, szReg
}

func vfsShmLock(ctx context.Context, mod api.Module, pFile uint32, offset, n int32, flags _ShmFlag) _ErrorCode {
file := util.Unwrap(vfsFileGet(ctx, mod, pFile)).(fileShm)
file := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory()
err := file.shmLock(offset, n, flags)
return vfsErrorCode(err, _IOERR_SHMLOCK)
}

func vfsShmUnmap(ctx context.Context, mod api.Module, pFile, bDelete uint32) _ErrorCode {
file := util.Unwrap(vfsFileGet(ctx, mod, pFile)).(fileShm)
file := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory()
file.shmUnmap(bDelete != 0)
return _OK
}
Expand Down