diff --git a/pkg/transport/listener.go b/pkg/transport/listener.go index 7371947ef767..60a2f7a3345d 100644 --- a/pkg/transport/listener.go +++ b/pkg/transport/listener.go @@ -100,7 +100,7 @@ func (info TLSInfo) Empty() bool { return info.CertFile == "" && info.KeyFile == "" } -func SelfCert(dirpath string, hosts []string) (info TLSInfo, err error) { +func SelfCert(dirpath string, hosts []string, additionalUsages ...x509.ExtKeyUsage) (info TLSInfo, err error) { if err = os.MkdirAll(dirpath, 0700); err != nil { return } @@ -129,7 +129,7 @@ func SelfCert(dirpath string, hosts []string) (info TLSInfo, err error) { NotAfter: time.Now().Add(365 * (24 * time.Hour)), KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + ExtKeyUsage: append([]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, additionalUsages...), BasicConstraintsValid: true, } diff --git a/pkg/transport/listener_test.go b/pkg/transport/listener_test.go index 6cc44a118f9f..7ef4b5a9e031 100644 --- a/pkg/transport/listener_test.go +++ b/pkg/transport/listener_test.go @@ -16,20 +16,26 @@ package transport import ( "crypto/tls" + "crypto/x509" "errors" "io/ioutil" + "net" "net/http" "os" "testing" "time" ) -func createSelfCert() (*TLSInfo, func(), error) { +func createSelfCert(hosts ...string) (*TLSInfo, func(), error) { + return createSelfCertEx("127.0.0.1") +} + +func createSelfCertEx(host string, additionalUsages ...x509.ExtKeyUsage) (*TLSInfo, func(), error) { d, terr := ioutil.TempDir("", "etcd-test-tls-") if terr != nil { return nil, nil, terr } - info, err := SelfCert(d, []string{"127.0.0.1"}) + info, err := SelfCert(d, []string{host + ":0"}, additionalUsages...) if err != nil { return nil, nil, err } @@ -70,10 +76,108 @@ func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) { } defer conn.Close() if _, ok := conn.(*tls.Conn); !ok { - t.Errorf("failed to accept *tls.Conn") + t.Error("failed to accept *tls.Conn") } } +// TestNewListenerTLSInfoSkipClientSANVerify tests that if client IP address mismatches +// with specified address in its certificate the connection is still accepted +// if the flag SkipClientSANVerify is set (i.e. checkSAN() is disabled for the client side) +func TestNewListenerTLSInfoSkipClientSANVerify(t *testing.T) { + tests := []struct { + skipClientSANVerify bool + goodClientHost bool + acceptExpected bool + }{ + {false, true, true}, + {false, false, false}, + {true, true, true}, + {true, false, true}, + } + for _, test := range tests { + testNewListenerTLSInfoClientCheck(t, test.skipClientSANVerify, test.goodClientHost, test.acceptExpected) + } +} + +func testNewListenerTLSInfoClientCheck(t *testing.T, skipClientSANVerify, goodClientHost, acceptExpected bool) { + tlsInfo, del, err := createSelfCert() + if err != nil { + t.Fatalf("unable to create cert: %v", err) + } + defer del() + + host := "127.0.0.222" + if goodClientHost { + host = "127.0.0.1" + } + clientTLSInfo, del2, err := createSelfCertEx(host, x509.ExtKeyUsageClientAuth) + if err != nil { + t.Fatalf("unable to create cert: %v", err) + } + defer del2() + + tlsInfo.SkipClientSANVerify = skipClientSANVerify + tlsInfo.CAFile = clientTLSInfo.CertFile + + rootCAs := x509.NewCertPool() + loaded, err := ioutil.ReadFile(tlsInfo.CertFile) + if err != nil { + t.Fatalf("unexpected missing certfile: %v", err) + } + rootCAs.AppendCertsFromPEM(loaded) + + clientCert, err := tls.LoadX509KeyPair(clientTLSInfo.CertFile, clientTLSInfo.KeyFile) + if err != nil { + t.Fatalf("unable to create peer cert: %v", err) + } + + tlsConfig := &tls.Config{} + tlsConfig.InsecureSkipVerify = false + tlsConfig.Certificates = []tls.Certificate{clientCert} + tlsConfig.RootCAs = rootCAs + + ln, err := NewListener("127.0.0.1:0", "https", tlsInfo) + if err != nil { + t.Fatalf("unexpected NewListener error: %v", err) + } + defer ln.Close() + + tr := &http.Transport{TLSClientConfig: tlsConfig} + cli := &http.Client{Transport: tr} + chClientErr := make(chan error) + go func() { + _, err := cli.Get("https://" + ln.Addr().String()) + chClientErr <- err + }() + + chAcceptErr := make(chan error) + chAcceptConn := make(chan net.Conn) + go func() { + conn, err := ln.Accept() + if err != nil { + chAcceptErr <- err + } else { + chAcceptConn <- conn + } + }() + + select { + case <-chClientErr: + if acceptExpected { + t.Errorf("accepted for good client address: skipClientSANVerify=%t, goodClientHost=%t", skipClientSANVerify, goodClientHost) + } + case acceptErr := <-chAcceptErr: + t.Fatalf("unexpected Accept error: %v", acceptErr) + case conn := <-chAcceptConn: + defer conn.Close() + if _, ok := conn.(*tls.Conn); !ok { + t.Errorf("failed to accept *tls.Conn") + } + if !acceptExpected { + t.Errorf("accepted for bad client address: skipClientSANVerify=%t, goodClientHost=%t", skipClientSANVerify, goodClientHost) + } + } +} func TestNewListenerTLSEmptyInfo(t *testing.T) { _, err := NewListener("127.0.0.1:0", "https", nil) if err == nil {