diff --git a/cmd/flowlogs-pipeline/main.go b/cmd/flowlogs-pipeline/main.go index a2fa5246e..f24e13589 100644 --- a/cmd/flowlogs-pipeline/main.go +++ b/cmd/flowlogs-pipeline/main.go @@ -199,7 +199,7 @@ func run() { } // Start health report server - operational.NewHealthServer(&opts, mainPipeline.IsAlive, mainPipeline.IsReady) + healthServer := operational.NewHealthServer(&opts, mainPipeline.IsAlive, mainPipeline.IsReady) // Starts the flows pipeline mainPipeline.Run() @@ -207,6 +207,7 @@ func run() { if promServer != nil { _ = promServer.Shutdown(context.Background()) } + _ = healthServer.Shutdown(context.Background()) // Give all threads a chance to exit and then exit the process time.Sleep(time.Second) diff --git a/pkg/operational/health.go b/pkg/operational/health.go index a252ebc46..c4510a695 100644 --- a/pkg/operational/health.go +++ b/pkg/operational/health.go @@ -24,35 +24,28 @@ import ( "github.com/heptiolabs/healthcheck" "github.com/netobserv/flowlogs-pipeline/pkg/config" + "github.com/netobserv/flowlogs-pipeline/pkg/server" log "github.com/sirupsen/logrus" ) -type Server struct { - handler healthcheck.Handler - Address string -} - -func (hs *Server) Serve() { - for { - err := http.ListenAndServe(hs.Address, hs.handler) - log.Errorf("http.ListenAndServe error %v", err) - time.Sleep(60 * time.Second) - } -} - -func NewHealthServer(opts *config.Options, isAlive healthcheck.Check, isReady healthcheck.Check) *Server { - +func NewHealthServer(opts *config.Options, isAlive healthcheck.Check, isReady healthcheck.Check) *http.Server { handler := healthcheck.NewHandler() address := net.JoinHostPort(opts.Health.Address, opts.Health.Port) handler.AddLivenessCheck("PipelineCheck", isAlive) handler.AddReadinessCheck("PipelineCheck", isReady) - server := &Server{ - handler: handler, - Address: address, - } - - go server.Serve() + server := server.Default(&http.Server{ + Handler: handler, + Addr: address, + }) + + go func() { + for { + err := server.ListenAndServe() + log.Errorf("http.ListenAndServe error %v", err) + time.Sleep(60 * time.Second) + } + }() return server } diff --git a/pkg/pipeline/health_test.go b/pkg/pipeline/health_test.go index 54bd6a649..af16c42ba 100644 --- a/pkg/pipeline/health_test.go +++ b/pkg/pipeline/health_test.go @@ -58,7 +58,7 @@ func TestNewHealthServer(t *testing.T) { expectedAddr := fmt.Sprintf("%s:%s", opts.Health.Address, opts.Health.Port) server := operational.NewHealthServer(&opts, tt.args.pipeline.IsAlive, tt.args.pipeline.IsReady) require.NotNil(t, server) - require.Equal(t, expectedAddr, server.Address) + require.Equal(t, expectedAddr, server.Addr) client := &http.Client{} diff --git a/pkg/prometheus/prom_server.go b/pkg/prometheus/prom_server.go index d40000bc0..19a226583 100644 --- a/pkg/prometheus/prom_server.go +++ b/pkg/prometheus/prom_server.go @@ -24,6 +24,7 @@ import ( "github.com/netobserv/flowlogs-pipeline/pkg/api" "github.com/netobserv/flowlogs-pipeline/pkg/config" + "github.com/netobserv/flowlogs-pipeline/pkg/server" prom "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sirupsen/logrus" @@ -63,7 +64,7 @@ func StartServerAsync(conn *api.PromConnectionInfo, registry *prom.Registry) *ht addr := fmt.Sprintf("%s:%v", conn.Address, port) plog.Infof("StartServerAsync: addr = %s", addr) - httpServer := http.Server{ + httpServer := &http.Server{ Addr: addr, // TLS clients must use TLS 1.2 or higher TLSConfig: &tls.Config{ @@ -79,6 +80,7 @@ func StartServerAsync(conn *api.PromConnectionInfo, registry *prom.Registry) *ht mux.Handle("/metrics", promhttp.HandlerFor(registry, promhttp.HandlerOpts{})) } httpServer.Handler = mux + httpServer = server.Default(httpServer) go func() { var err error @@ -92,5 +94,5 @@ func StartServerAsync(conn *api.PromConnectionInfo, registry *prom.Registry) *ht } }() - return &httpServer + return httpServer } diff --git a/pkg/prometheus/prom_server_test.go b/pkg/prometheus/prom_server_test.go new file mode 100644 index 000000000..e413e7f85 --- /dev/null +++ b/pkg/prometheus/prom_server_test.go @@ -0,0 +1,81 @@ +package prometheus + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/netobserv/flowlogs-pipeline/pkg/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStartPromServer(t *testing.T) { + srv := InitializePrometheus(&config.MetricsSettings{}) + + serverURL := "http://0.0.0.0:9090" + t.Logf("Started test http server: %v", serverURL) + + httpClient := &http.Client{} + + // wait for our test http server to come up + checkHTTPReady(httpClient, serverURL) + + r, err := http.NewRequest("GET", serverURL+"/metrics", nil) + require.NoError(t, err) + + resp, err := httpClient.Do(r) + require.NoError(t, err) + defer resp.Body.Close() + + bodyBytes, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + bodyString := string(bodyBytes) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Contains(t, bodyString, "go_gc_duration_seconds") + + _ = srv.Shutdown(context.Background()) +} + +func TestStartPromServer_HeadersLimit(t *testing.T) { + srv := InitializePrometheus(&config.MetricsSettings{}) + + serverURL := "http://0.0.0.0:9090" + t.Logf("Started test http server: %v", serverURL) + + httpClient := &http.Client{} + + // wait for our test http server to come up + checkHTTPReady(httpClient, serverURL) + + r, err := http.NewRequest("GET", serverURL+"/metrics", nil) + require.NoError(t, err) + + // Set many headers + oneKBString := strings.Repeat(".", 1024) + for i := 0; i < 1025; i++ { + r.Header.Set(fmt.Sprintf("test-header-%d", i), oneKBString) + } + + resp, err := httpClient.Do(r) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusRequestHeaderFieldsTooLarge, resp.StatusCode) + + _ = srv.Shutdown(context.Background()) +} + +func checkHTTPReady(httpClient *http.Client, url string) { + for i := 0; i < 60; i++ { + if r, err := httpClient.Get(url); err == nil { + r.Body.Close() + break + } + time.Sleep(time.Second) + } +} diff --git a/pkg/server/common.go b/pkg/server/common.go new file mode 100644 index 000000000..92fdb7fb5 --- /dev/null +++ b/pkg/server/common.go @@ -0,0 +1,46 @@ +package server + +import ( + "crypto/tls" + "net/http" + "time" + + "github.com/sirupsen/logrus" +) + +var slog = logrus.WithField("module", "server") + +func Default(srv *http.Server) *http.Server { + // defaults taken from https://bruinsslot.jp/post/go-secure-webserver/ can be overriden by caller + if srv.Handler != nil { + // No more than 2MB body + srv.Handler = http.MaxBytesHandler(srv.Handler, 2<<20) + } else { + slog.Warnf("Handler not yet set on server while securing defaults. Make sure a MaxByte middleware is used.") + } + if srv.ReadTimeout == 0 { + srv.ReadTimeout = 10 * time.Second + } + if srv.ReadHeaderTimeout == 0 { + srv.ReadHeaderTimeout = 5 * time.Second + } + if srv.WriteTimeout == 0 { + srv.WriteTimeout = 10 * time.Second + } + if srv.IdleTimeout == 0 { + srv.IdleTimeout = 120 * time.Second + } + if srv.MaxHeaderBytes == 0 { + srv.MaxHeaderBytes = 1 << 20 // 1MB + } + if srv.TLSConfig == nil { + srv.TLSConfig = &tls.Config{} + } + if srv.TLSConfig.MinVersion == 0 { + srv.TLSConfig.MinVersion = tls.VersionTLS13 + } + // Disable http/2 + srv.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler), 0) + + return srv +}