Skip to content

Commit

Permalink
Merge pull request #1667 from pafuent/listener_network_configurable
Browse files Browse the repository at this point in the history
Adding Echo#ListenerNetwork as configuration
  • Loading branch information
pafuent committed Dec 12, 2020
2 parents 06a9480 + 78fe222 commit 2b36b3d
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 8 deletions.
22 changes: 14 additions & 8 deletions echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ type (
Renderer Renderer
Logger Logger
IPExtractor IPExtractor
ListenerNetwork string
}

// Route contains a handler and information for matching against requests.
Expand Down Expand Up @@ -281,6 +282,7 @@ var (
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
ErrCookieNotFound = errors.New("cookie not found")
ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
ErrInvalidListenerNetwork = errors.New("invalid listener network")
)

// Error handlers
Expand All @@ -302,9 +304,10 @@ func New() (e *Echo) {
AutoTLSManager: autocert.Manager{
Prompt: autocert.AcceptTOS,
},
Logger: log.New("echo"),
colorer: color.New(),
maxParam: new(int),
Logger: log.New("echo"),
colorer: color.New(),
maxParam: new(int),
ListenerNetwork: "tcp",
}
e.Server.Handler = e
e.TLSServer.Handler = e
Expand Down Expand Up @@ -714,7 +717,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) {

if s.TLSConfig == nil {
if e.Listener == nil {
e.Listener, err = newListener(s.Addr)
e.Listener, err = newListener(s.Addr, e.ListenerNetwork)
if err != nil {
return err
}
Expand All @@ -725,7 +728,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) {
return s.Serve(e.Listener)
}
if e.TLSListener == nil {
l, err := newListener(s.Addr)
l, err := newListener(s.Addr, e.ListenerNetwork)
if err != nil {
return err
}
Expand Down Expand Up @@ -754,7 +757,7 @@ func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) {
}

if e.Listener == nil {
e.Listener, err = newListener(s.Addr)
e.Listener, err = newListener(s.Addr, e.ListenerNetwork)
if err != nil {
return err
}
Expand Down Expand Up @@ -875,8 +878,11 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
return
}

func newListener(address string) (*tcpKeepAliveListener, error) {
l, err := net.Listen("tcp", address)
func newListener(address, network string) (*tcpKeepAliveListener, error) {
if network != "tcp" && network != "tcp4" && network != "tcp6" {
return nil, ErrInvalidListenerNetwork
}
l, err := net.Listen(network, address)
if err != nil {
return nil, err
}
Expand Down
64 changes: 64 additions & 0 deletions echo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
stdContext "context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -658,6 +659,69 @@ func TestEchoShutdown(t *testing.T) {
assert.Equal(t, err.Error(), "http: Server closed")
}

var listenerNetworkTests = []struct {
test string
network string
address string
}{
{"tcp ipv4 address", "tcp", "127.0.0.1:1323"},
{"tcp ipv6 address", "tcp", "[::1]:1323"},
{"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"},
{"tcp6 ipv6 address", "tcp6", "[::1]:1323"},
}

func TestEchoListenerNetwork(t *testing.T) {
for _, tt := range listenerNetworkTests {
t.Run(tt.test, func(t *testing.T) {
e := New()
e.ListenerNetwork = tt.network

// HandlerFunc
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})

errCh := make(chan error)

go func() {
errCh <- e.Start(tt.address)
}()

time.Sleep(200 * time.Millisecond)

if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil {
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)

if body, err := ioutil.ReadAll(resp.Body); err == nil {
assert.Equal(t, "OK", string(body))
} else {
assert.Fail(t, err.Error())
}

} else {
assert.Fail(t, err.Error())
}

if err := e.Close(); err != nil {
t.Fatal(err)
}
})
}
}

func TestEchoListenerNetworkInvalid(t *testing.T) {
e := New()
e.ListenerNetwork = "unix"

// HandlerFunc
e.GET("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK")
})

assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323"))
}

func TestEchoReverse(t *testing.T) {
assert := assert.New(t)

Expand Down

0 comments on commit 2b36b3d

Please sign in to comment.