Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server.go: "/" for windows #571

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"io/ioutil"
"os"
"path/filepath"
Expand All @@ -21,6 +22,18 @@ const (
SftpServerWorkerCount = 8
)

type file interface {
Stat() (os.FileInfo, error)
ReadAt(b []byte, off int64) (int, error)
WriteAt(b []byte, off int64) (int, error)
Readdir(int) ([]os.FileInfo, error)
Name() string
Truncate(int64) error
Chmod(mode fs.FileMode) error
Chown(uid, gid int) error
Close() error
}

// Server is an SSH File Transfer Protocol (sftp) server.
// This is intended to provide the sftp subsystem to an ssh server daemon.
// This implementation currently supports most of sftp server protocol version 3,
Expand All @@ -30,13 +43,14 @@ type Server struct {
debugStream io.Writer
readOnly bool
pktMgr *packetManager
openFiles map[string]*os.File
openFiles map[string]file
openFilesLock sync.RWMutex
handleCount int
workDir string
winRoot bool
}

func (svr *Server) nextHandle(f *os.File) string {
func (svr *Server) nextHandle(f file) string {
svr.openFilesLock.Lock()
defer svr.openFilesLock.Unlock()
svr.handleCount++
Expand All @@ -56,7 +70,7 @@ func (svr *Server) closeHandle(handle string) error {
return EBADF
}

func (svr *Server) getHandle(handle string) (*os.File, bool) {
func (svr *Server) getHandle(handle string) (file, bool) {
svr.openFilesLock.RLock()
defer svr.openFilesLock.RUnlock()
f, ok := svr.openFiles[handle]
Expand Down Expand Up @@ -85,7 +99,7 @@ func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error)
serverConn: svrConn,
debugStream: ioutil.Discard,
pktMgr: newPktMgr(svrConn),
openFiles: make(map[string]*os.File),
openFiles: make(map[string]file),
}

for _, o := range options {
Expand Down Expand Up @@ -116,6 +130,14 @@ func ReadOnly() ServerOption {
}
}

// configures a Server to serve a virtual '/' for windows that lists all drives
powellnorma marked this conversation as resolved.
Show resolved Hide resolved
func WindowsRootEnumeratesDrives() ServerOption {
return func(s *Server) error {
s.winRoot = true
return nil
}
}

// WithAllocator enable the allocator.
// After processing a packet we keep in memory the allocated slices
// and we reuse them for new packets.
Expand Down Expand Up @@ -462,7 +484,7 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket {
osFlags |= os.O_EXCL
}

f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, 0o644)
f, err := svr.openfile(svr.toLocalPath(p.Path), osFlags, 0o644)
if err != nil {
return statusFromError(p.ID, err)
}
Expand Down
13 changes: 13 additions & 0 deletions server_posix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//go:build !windows
// +build !windows

package sftp

import (
"io/fs"
"os"
)

func (s *Server) openfile(path string, flag int, mode fs.FileMode) (file, error) {
return os.OpenFile(path, flag, mode)
}
122 changes: 121 additions & 1 deletion server_windows.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package sftp

import (
"fmt"
"io"
"io/fs"
"os"
"path"
"path/filepath"

"golang.org/x/sys/windows"
)

func (s *Server) toLocalPath(p string) string {
Expand All @@ -12,7 +18,11 @@ func (s *Server) toLocalPath(p string) string {

lp := filepath.FromSlash(p)

if path.IsAbs(p) {
if path.IsAbs(p) { // starts with '/'
if len(p) == 1 && s.winRoot {
return `\\.\` // for openfile
}

tmp := lp
for len(tmp) > 0 && tmp[0] == '\\' {
tmp = tmp[1:]
Expand All @@ -33,7 +43,117 @@ func (s *Server) toLocalPath(p string) string {
// e.g. "/C:" to "C:\\"
return tmp
}

if s.winRoot {
// Make it so that "/Windows" is not found, and "/c:/Windows" has to be used
return `\\.\` + tmp
}
}

return lp
}

func bitsToDrives(bitmap uint32) []string {
var drive rune = 'a'
var drives []string

for bitmap != 0 {
if bitmap&1 == 1 {
drives = append(drives, string(drive)+":")
}
drive++
bitmap >>= 1
}

return drives
}

func getDrives() ([]string, error) {
mask, err := windows.GetLogicalDrives()
if err != nil {
return nil, fmt.Errorf("GetLogicalDrives: %w", err)
}
return bitsToDrives(mask), nil
}

type driveInfo struct {
fs.FileInfo
name string
}

func (i *driveInfo) Name() string {
return i.name // since the Name() returned from a os.Stat("C:\\") is "\\"
}

type winRoot struct {
drives []string
}

func newWinRoot() (*winRoot, error) {
drives, err := getDrives()
if err != nil {
return nil, err
}
return &winRoot{
drives: drives,
}, nil
}

func (f *winRoot) Readdir(n int) ([]os.FileInfo, error) {
drives := f.drives
if n > 0 && len(drives) > n {
drives = drives[:n]
}
f.drives = f.drives[len(drives):]
if len(drives) == 0 {
return nil, io.EOF
}

var infos []os.FileInfo
for _, drive := range drives {
fi, err := os.Stat(drive + `\`)
if err != nil {
return nil, err
}

di := &driveInfo{
FileInfo: fi,
name: drive,
}
infos = append(infos, di)
}

return infos, nil
}

func (f *winRoot) Stat() (os.FileInfo, error) {
return nil, os.ErrPermission
}
func (f *winRoot) ReadAt(b []byte, off int64) (int, error) {
return 0, os.ErrPermission
}
func (f *winRoot) WriteAt(b []byte, off int64) (int, error) {
return 0, os.ErrPermission
}
func (f *winRoot) Name() string {
return "/"
}
func (f *winRoot) Truncate(int64) error {
return os.ErrPermission
}
func (f *winRoot) Chmod(mode fs.FileMode) error {
return os.ErrPermission
}
func (f *winRoot) Chown(uid, gid int) error {
return os.ErrPermission
}
func (f *winRoot) Close() error {
return nil
}

func (s *Server) openfile(path string, flag int, mode fs.FileMode) (file, error) {
if path == `\\.\` && s.winRoot {
return newWinRoot()
}
return os.OpenFile(path, flag, mode)
}