diff --git a/client.go b/client.go index 67dca883..331e5c1f 100644 --- a/client.go +++ b/client.go @@ -574,6 +574,8 @@ func (c *Client) Start() (addr net.Addr, err error) { c.config.TLSConfig = &tls.Config{ Certificates: []tls.Certificate{cert}, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS12, ServerName: "localhost", } } @@ -774,7 +776,7 @@ func (c *Client) Start() (addr net.Addr, err error) { } // loadServerCert is used by AutoMTLS to read an x.509 cert returned by the -// server, and load it as the RootCA for the client TLSConfig. +// server, and load it as the RootCA and ClientCA for the client TLSConfig. func (c *Client) loadServerCert(cert string) error { certPool := x509.NewCertPool() @@ -791,6 +793,7 @@ func (c *Client) loadServerCert(cert string) error { certPool.AddCert(x509Cert) c.config.TLSConfig.RootCAs = certPool + c.config.TLSConfig.ClientCAs = certPool return nil } diff --git a/server.go b/server.go index 7a58cc39..e1349991 100644 --- a/server.go +++ b/server.go @@ -304,13 +304,13 @@ func Serve(opts *ServeConfig) { certPEM, keyPEM, err := generateCert() if err != nil { - logger.Error("failed to generate client certificate", "error", err) + logger.Error("failed to generate server certificate", "error", err) panic(err) } cert, err := tls.X509KeyPair(certPEM, keyPEM) if err != nil { - logger.Error("failed to parse client certificate", "error", err) + logger.Error("failed to parse server certificate", "error", err) panic(err) } @@ -319,6 +319,8 @@ func Serve(opts *ServeConfig) { ClientAuth: tls.RequireAndVerifyClientCert, ClientCAs: clientCertPool, MinVersion: tls.VersionTLS12, + RootCAs: clientCertPool, + ServerName: "localhost", } // We send back the raw leaf cert data for the client rather than the diff --git a/server_test.go b/server_test.go index 5fec4502..1aaa0b8d 100644 --- a/server_test.go +++ b/server_test.go @@ -82,6 +82,75 @@ func TestServer_testMode(t *testing.T) { t.Logf("HELLO") } +func TestServer_testMode_AutoMTLS(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + closeCh := make(chan struct{}) + go Serve(&ServeConfig{ + HandshakeConfig: testVersionedHandshake, + VersionedPlugins: map[int]PluginSet{ + 2: testGRPCPluginMap, + }, + GRPCServer: DefaultGRPCServer, + Logger: hclog.NewNullLogger(), + Test: &ServeTestConfig{ + Context: ctx, + ReattachConfigCh: nil, + CloseCh: closeCh, + }, + }) + + // Connect! + process := helperProcess("test-mtls") + c := NewClient(&ClientConfig{ + Cmd: process, + HandshakeConfig: testVersionedHandshake, + VersionedPlugins: map[int]PluginSet{ + 2: testGRPCPluginMap, + }, + AllowedProtocols: []Protocol{ProtocolGRPC}, + AutoMTLS: true, + }) + client, err := c.Client() + if err != nil { + t.Fatalf("err: %s", err) + } + + // Pinging should work + if err := client.Ping(); err != nil { + t.Fatalf("should not err: %s", err) + } + + // Grab the impl + raw, err := client.Dispense("test") + if err != nil { + t.Fatalf("err should be nil, got %s", err) + } + + tester, ok := raw.(testInterface) + if !ok { + t.Fatalf("bad: %#v", raw) + } + + n := tester.Double(3) + if n != 6 { + t.Fatal("invalid response", n) + } + + // ensure we can make use of bidirectional communication with AutoMTLS + // enabled + err = tester.Bidirectional() + if err != nil { + t.Fatal("invalid response", err) + } + + c.Kill() + // Canceling should cause an exit + cancel() + <-closeCh +} + func TestRmListener_impl(t *testing.T) { var _ net.Listener = new(rmListener) } @@ -145,7 +214,6 @@ func TestProtocolSelection_no_server(t *testing.T) { if protocol != ProtocolNetRPC { t.Fatalf("bad protocol %s", protocol) } - } func TestServer_testStdLogger(t *testing.T) {