diff --git a/echo.go b/echo.go index c43998016..381604180 100644 --- a/echo.go +++ b/echo.go @@ -92,6 +92,7 @@ type ( Renderer Renderer Logger Logger IPExtractor IPExtractor + ListenerNetwork string } // Route contains a handler and information for matching against requests. @@ -281,6 +282,7 @@ var ( ErrInvalidRedirectCode = errors.New("invalid redirect status code") ErrCookieNotFound = errors.New("cookie not found") ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") + ErrInvalidListenerNetwork = errors.New("invalid listener network") ) // Error handlers @@ -302,9 +304,10 @@ func New() (e *Echo) { AutoTLSManager: autocert.Manager{ Prompt: autocert.AcceptTOS, }, - Logger: log.New("echo"), - colorer: color.New(), - maxParam: new(int), + Logger: log.New("echo"), + colorer: color.New(), + maxParam: new(int), + ListenerNetwork: "tcp", } e.Server.Handler = e e.TLSServer.Handler = e @@ -714,7 +717,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) { if s.TLSConfig == nil { if e.Listener == nil { - e.Listener, err = newListener(s.Addr) + e.Listener, err = newListener(s.Addr, e.ListenerNetwork) if err != nil { return err } @@ -725,7 +728,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) { return s.Serve(e.Listener) } if e.TLSListener == nil { - l, err := newListener(s.Addr) + l, err := newListener(s.Addr, e.ListenerNetwork) if err != nil { return err } @@ -754,7 +757,7 @@ func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { } if e.Listener == nil { - e.Listener, err = newListener(s.Addr) + e.Listener, err = newListener(s.Addr, e.ListenerNetwork) if err != nil { return err } @@ -875,8 +878,11 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { return } -func newListener(address string) (*tcpKeepAliveListener, error) { - l, err := net.Listen("tcp", address) +func newListener(address, network string) (*tcpKeepAliveListener, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, ErrInvalidListenerNetwork + } + l, err := net.Listen(network, address) if err != nil { return nil, err } diff --git a/echo_test.go b/echo_test.go index 71dc1ac07..bf3ecfe91 100644 --- a/echo_test.go +++ b/echo_test.go @@ -4,6 +4,7 @@ import ( "bytes" stdContext "context" "errors" + "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -658,6 +659,69 @@ func TestEchoShutdown(t *testing.T) { assert.Equal(t, err.Error(), "http: Server closed") } +var listenerNetworkTests = []struct { + test string + network string + address string +}{ + {"tcp ipv4 address", "tcp", "127.0.0.1:1323"}, + {"tcp ipv6 address", "tcp", "[::1]:1323"}, + {"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"}, + {"tcp6 ipv6 address", "tcp6", "[::1]:1323"}, +} + +func TestEchoListenerNetwork(t *testing.T) { + for _, tt := range listenerNetworkTests { + t.Run(tt.test, func(t *testing.T) { + e := New() + e.ListenerNetwork = tt.network + + // HandlerFunc + e.GET("/ok", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + errCh := make(chan error) + + go func() { + errCh <- e.Start(tt.address) + }() + + time.Sleep(200 * time.Millisecond) + + if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil { + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + if body, err := ioutil.ReadAll(resp.Body); err == nil { + assert.Equal(t, "OK", string(body)) + } else { + assert.Fail(t, err.Error()) + } + + } else { + assert.Fail(t, err.Error()) + } + + if err := e.Close(); err != nil { + t.Fatal(err) + } + }) + } +} + +func TestEchoListenerNetworkInvalid(t *testing.T) { + e := New() + e.ListenerNetwork = "unix" + + // HandlerFunc + e.GET("/ok", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323")) +} + func TestEchoReverse(t *testing.T) { assert := assert.New(t)