diff --git a/vfs/shm_bsd.go b/vfs/shm_bsd.go index 94e5274c..ed6b8970 100644 --- a/vfs/shm_bsd.go +++ b/vfs/shm_bsd.go @@ -20,11 +20,7 @@ import ( // [EXCLUSIVE locking mode]: https://sqlite.org/pragma.html#pragma_locking_mode const SupportsSharedMemory = true -const ( - _SHM_NLOCK = 8 - _SHM_BASE = 120 - _SHM_DMS = _SHM_BASE + _SHM_NLOCK -) +const _SHM_NLOCK = 8 func (f *vfsFile) SharedMemory() SharedMemory { return f.shm } @@ -100,22 +96,19 @@ func (s *vfsShm) Close() error { return err } -func (s *vfsShm) shmOpen() error { +func (s *vfsShm) shmOpen() (err error) { if s.vfsShmFile != nil { return nil } - var flag int - if s.readOnly { - flag = unix.O_RDONLY - } else { - flag = unix.O_RDWR - } + // Open file read-write, as it will be shared. f, err := os.OpenFile(s.path, - flag|unix.O_CREAT|unix.O_NOFOLLOW, 0666) + unix.O_RDWR|unix.O_CREAT|unix.O_NOFOLLOW, 0666) if err != nil { return _CANTOPEN } + // Close if file if it's not nil. + defer func() { f.Close() }() fi, err := f.Stat() if err != nil { @@ -125,19 +118,34 @@ func (s *vfsShm) shmOpen() error { vfsShmFilesMtx.Lock() defer vfsShmFilesMtx.Unlock() + // Find a shared file, increase the reference count. for _, g := range vfsShmFiles { if g != nil && os.SameFile(fi, g.info) { - f.Close() g.refs++ s.vfsShmFile = g return nil } } + + // Lock and truncate the file, if not readonly. + if s.readOnly { + err = _READONLY_CANTINIT + } else { + if rc := osWriteLock(f, 0, 0, 0); rc != _OK { + return rc + } + if err := f.Truncate(0); err != nil { + return _IOERR_SHMOPEN + } + } + + // Add the new shared file. s.vfsShmFile = &vfsShmFile{ File: f, info: fi, refs: 1, } + f = nil add := true for i, g := range vfsShmFiles { if g == nil { @@ -148,17 +156,7 @@ func (s *vfsShm) shmOpen() error { if add { vfsShmFiles = append(vfsShmFiles, s.vfsShmFile) } - - if s.readOnly { - return _READONLY_CANTINIT - } - if rc := osWriteLock(f, _SHM_DMS, 1, 0); rc != _OK { - return rc - } - if err := f.Truncate(0); err != nil { - return _IOERR_SHMOPEN - } - return nil + return err } func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (uint32, error) {