Skip to content

Commit

Permalink
Shared memory API.
Browse files Browse the repository at this point in the history
  • Loading branch information
ncruces committed Apr 21, 2024
1 parent 07241d0 commit 62b79d2
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 35 deletions.
8 changes: 0 additions & 8 deletions internal/util/unwrap.go

This file was deleted.

33 changes: 20 additions & 13 deletions vfs/adiantum/hbsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,19 @@ func (h *hbshVFS) OpenParams(name string, flags vfs.OpenFlag, params url.Values)
// Encrypt everything except super journals.
if flags&vfs.OPEN_SUPER_JOURNAL == 0 {
var key []byte
if t, ok := params["key"]; ok {
if name == "" {
key = h.hbsh.KDF("") // Temporary files get a random key.
} else if t, ok := params["key"]; ok {
key = []byte(t[0])
} else if t, ok := params["hexkey"]; ok {
key, _ = hex.DecodeString(t[0])
} else if t, ok := params["textkey"]; ok {
key = h.hbsh.KDF(t[0])
} else if name == "" {
key = h.hbsh.KDF("")
}

if hbsh = h.hbsh.HBSH(key); hbsh == nil {
return nil, flags, sqlite3.NOTADB
// Can't open without a valid key.
return nil, flags, sqlite3.CANTOPEN
}
}

Expand All @@ -51,6 +52,7 @@ func (h *hbshVFS) OpenParams(name string, flags vfs.OpenFlag, params url.Values)
file, flags, err = h.Open(name, flags)
}
if err != nil || hbsh == nil || flags&vfs.OPEN_MEMORY != 0 {
// Error, or no encryption (super journals, memory files).
return file, flags, err
}
return &hbshFile{File: file, hbsh: hbsh}, flags, err
Expand All @@ -72,8 +74,8 @@ func (h *hbshFile) ReadAt(p []byte, off int64) (n int, err error) {
min := (off) &^ (blockSize - 1) // round down
max := (off + int64(len(p)) + blockSize - 1) &^ (blockSize - 1) // round up

// Read one block at a time.
for ; min < max; min += blockSize {
// Read full block.
m, err := h.File.ReadAt(h.block[:], min)
if m != blockSize {
return n, err
Expand All @@ -98,20 +100,24 @@ func (h *hbshFile) WriteAt(p []byte, off int64) (n int, err error) {
min := (off) &^ (blockSize - 1) // round down
max := (off + int64(len(p)) + blockSize - 1) &^ (blockSize - 1) // round up

// Write one block at a time.
for ; min < max; min += blockSize {
binary.LittleEndian.PutUint64(h.tweak[:], uint64(min))
data := h.block[:]

if off > min || len(p[n:]) < blockSize {
// Read full block.
// Partial block write: read-update-write.
m, err := h.File.ReadAt(h.block[:], min)
if m != blockSize {
if err != io.EOF {
return n, err
}
// Writing past the EOF.
// A partially written block is corrupt,
// and also considered to be past the EOF.
// Writing past the EOF:
// We're either appending an entirely new block,
// or the final block was only partially written.
// A partially written block can't be decripted,
// and is as good as corrupt.
// Either way, zero pad the file to the next block size.
clear(data)
}

Expand All @@ -124,7 +130,6 @@ func (h *hbshFile) WriteAt(p []byte, off int64) (n int, err error) {
t := copy(data, p[n:])
h.hbsh.Encrypt(h.block[:], h.tweak[:])

// Write full block.
m, err := h.File.WriteAt(h.block[:], min)
if m != blockSize {
return n, err
Expand Down Expand Up @@ -155,9 +160,11 @@ func (h *hbshFile) DeviceCharacteristics() vfs.DeviceCharacteristic {
vfs.IOCAP_BATCH_ATOMIC)
}

// This is needed for shared memory.
func (h *hbshFile) Unwrap() vfs.File {
return h.File
func (h *hbshFile) SharedMemory() vfs.SharedMemory {
if shm, ok := h.File.(vfs.FileSharedMemory); ok {
return shm.SharedMemory()
}
return nil
}

// Wrap optional methods.
Expand Down
22 changes: 21 additions & 1 deletion vfs/api.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
// Package vfs wraps the C SQLite VFS API.
package vfs

import "net/url"
import (
"context"
"net/url"

"github.com/tetratelabs/wazero/api"
)

// A VFS defines the interface between the SQLite core and the underlying operating system.
//
Expand Down Expand Up @@ -129,3 +134,18 @@ type FileBatchAtomicWrite interface {
CommitAtomicWrite() error
RollbackAtomicWrite() error
}

// FileSharedMemory extends File to possibly implement shared memory.
// It's OK for SharedMemory to return nil.
type FileSharedMemory interface {
File
SharedMemory() SharedMemory
}

// SharedMemory is a shared memory implementation.
// This cannot be externally implemented.
type SharedMemory interface {
shmMap(context.Context, api.Module, int32, int32, bool) (uint32, error)
shmLock(int32, int32, _ShmFlag) error
shmUnmap(bool)
}
8 changes: 0 additions & 8 deletions vfs/file.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package vfs

import (
"context"
"errors"
"io"
"io/fs"
Expand All @@ -12,7 +11,6 @@ import (
"syscall"

"github.com/ncruces/go-sqlite3/util/osutil"
"github.com/tetratelabs/wazero/api"
)

type vfsOS struct{}
Expand Down Expand Up @@ -215,9 +213,3 @@ func (f *vfsFile) PowersafeOverwrite() bool { return f.psow }
func (f *vfsFile) PersistentWAL() bool { return f.keepWAL }
func (f *vfsFile) SetPowersafeOverwrite(psow bool) { f.psow = psow }
func (f *vfsFile) SetPersistentWAL(keepWAL bool) { f.keepWAL = keepWAL }

type fileShm interface {
shmMap(context.Context, api.Module, int32, int32, bool) (uint32, error)
shmLock(int32, int32, _ShmFlag) error
shmUnmap(bool)
}
2 changes: 2 additions & 0 deletions vfs/shm.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ const (
_SHM_DMS = _SHM_BASE + _SHM_NLOCK
)

func (f *vfsFile) SharedMemory() SharedMemory { return f }

func (f *vfsFile) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (uint32, error) {
// Ensure size is a multiple of the OS page size.
if int(size)&(unix.Getpagesize()-1) != 0 {
Expand Down
12 changes: 7 additions & 5 deletions vfs/vfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,10 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla
util.WriteUint32(mod, pOutFlags, uint32(flags))
}
if pOutVFS != 0 && util.CanMapFiles(ctx) {
if _, ok := util.Unwrap(file).(fileShm); ok {
util.WriteUint32(mod, pOutVFS, 1)
if f, ok := file.(FileSharedMemory); ok {
if f.SharedMemory() != nil {
util.WriteUint32(mod, pOutVFS, 1)
}
}
}
vfsFileRegister(ctx, mod, pFile, file)
Expand Down Expand Up @@ -366,7 +368,7 @@ func vfsShmBarrier(ctx context.Context, mod api.Module, pFile uint32) {
}

func vfsShmMap(ctx context.Context, mod api.Module, pFile uint32, iRegion, szRegion int32, bExtend, pp uint32) _ErrorCode {
file := util.Unwrap(vfsFileGet(ctx, mod, pFile)).(fileShm)
file := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory()
p, err := file.shmMap(ctx, mod, iRegion, szRegion, bExtend != 0)
if err != nil {
return vfsErrorCode(err, _IOERR_SHMMAP)
Expand All @@ -376,13 +378,13 @@ func vfsShmMap(ctx context.Context, mod api.Module, pFile uint32, iRegion, szReg
}

func vfsShmLock(ctx context.Context, mod api.Module, pFile uint32, offset, n int32, flags _ShmFlag) _ErrorCode {
file := util.Unwrap(vfsFileGet(ctx, mod, pFile)).(fileShm)
file := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory()
err := file.shmLock(offset, n, flags)
return vfsErrorCode(err, _IOERR_SHMLOCK)
}

func vfsShmUnmap(ctx context.Context, mod api.Module, pFile, bDelete uint32) _ErrorCode {
file := util.Unwrap(vfsFileGet(ctx, mod, pFile)).(fileShm)
file := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory()
file.shmUnmap(bDelete != 0)
return _OK
}
Expand Down

0 comments on commit 62b79d2

Please sign in to comment.