From 4e5a26183ecb4f9a0f85c8f8dbe7982885435436 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 12 Dec 2023 13:04:53 +0000 Subject: [PATCH] ssh: close net.Conn on all NewServerConn errors This PR ensures that the net.Conn passed to ssh.NewServerConn is closed on all error return paths, not just after a failed handshake. This matches the behavior of ssh.NewClientConn. Change-Id: Id8a51d10ae8d575cbbe26f2ef6b37de7cca840ec GitHub-Last-Rev: 81bb2e58a881a9a85935740bda06b034b32a8ce3 GitHub-Pull-Request: golang/crypto#279 Reviewed-on: https://go-review.googlesource.com/c/crypto/+/549095 Run-TryBot: Nicola Murino Auto-Submit: Nicola Murino Reviewed-by: Roland Shoemaker Reviewed-by: Nicola Murino LUCI-TryBot-Result: Go LUCI Reviewed-by: Michael Pratt TryBot-Result: Gopher Robot --- ssh/server.go | 2 ++ ssh/server_test.go | 73 ++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 66 insertions(+), 9 deletions(-) diff --git a/ssh/server.go b/ssh/server.go index 7f0c236a9a..c2dfe3268c 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -213,6 +213,7 @@ func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewCha } else { for _, algo := range fullConf.PublicKeyAuthAlgorithms { if !contains(supportedPubKeyAuthAlgos, algo) { + c.Close() return nil, nil, nil, fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo) } } @@ -220,6 +221,7 @@ func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewCha // Check if the config contains any unsupported key exchanges for _, kex := range fullConf.KeyExchanges { if _, ok := serverForbiddenKexAlgos[kex]; ok { + c.Close() return nil, nil, nil, fmt.Errorf("ssh: unsupported key exchange %s for server", kex) } } diff --git a/ssh/server_test.go b/ssh/server_test.go index 2145dce06f..a9b2bce0c1 100644 --- a/ssh/server_test.go +++ b/ssh/server_test.go @@ -5,7 +5,11 @@ package ssh import ( + "io" + "net" + "sync/atomic" "testing" + "time" ) func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) { @@ -59,27 +63,78 @@ func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) { } func TestNewServerConnValidationErrors(t *testing.T) { - c1, c2, err := netPipe() - if err != nil { - t.Fatalf("netPipe: %v", err) - } - defer c1.Close() - defer c2.Close() - serverConf := &ServerConfig{ PublicKeyAuthAlgorithms: []string{CertAlgoRSAv01}, } - _, _, _, err = NewServerConn(c1, serverConf) + c := &markerConn{} + _, _, _, err := NewServerConn(c, serverConf) if err == nil { t.Fatal("NewServerConn with invalid public key auth algorithms succeeded") } + if !c.isClosed() { + t.Fatal("NewServerConn with invalid public key auth algorithms left connection open") + } + if c.isUsed() { + t.Fatal("NewServerConn with invalid public key auth algorithms used connection") + } + serverConf = &ServerConfig{ Config: Config{ KeyExchanges: []string{kexAlgoDHGEXSHA256}, }, } - _, _, _, err = NewServerConn(c1, serverConf) + c = &markerConn{} + _, _, _, err = NewServerConn(c, serverConf) if err == nil { t.Fatal("NewServerConn with unsupported key exchange succeeded") } + if !c.isClosed() { + t.Fatal("NewServerConn with unsupported key exchange left connection open") + } + if c.isUsed() { + t.Fatal("NewServerConn with unsupported key exchange used connection") + } +} + +type markerConn struct { + closed uint32 + used uint32 } + +func (c *markerConn) isClosed() bool { + return atomic.LoadUint32(&c.closed) != 0 +} + +func (c *markerConn) isUsed() bool { + return atomic.LoadUint32(&c.used) != 0 +} + +func (c *markerConn) Close() error { + atomic.StoreUint32(&c.closed, 1) + return nil +} + +func (c *markerConn) Read(b []byte) (n int, err error) { + atomic.StoreUint32(&c.used, 1) + if atomic.LoadUint32(&c.closed) != 0 { + return 0, net.ErrClosed + } else { + return 0, io.EOF + } +} + +func (c *markerConn) Write(b []byte) (n int, err error) { + atomic.StoreUint32(&c.used, 1) + if atomic.LoadUint32(&c.closed) != 0 { + return 0, net.ErrClosed + } else { + return 0, io.ErrClosedPipe + } +} + +func (*markerConn) LocalAddr() net.Addr { return nil } +func (*markerConn) RemoteAddr() net.Addr { return nil } + +func (*markerConn) SetDeadline(t time.Time) error { return nil } +func (*markerConn) SetReadDeadline(t time.Time) error { return nil } +func (*markerConn) SetWriteDeadline(t time.Time) error { return nil }