Skip to content

Commit

Permalink
Remove oxy connections limiter
Browse files Browse the repository at this point in the history
lib/limiter composes a maximum simultaneous connections limiter
as well as a rate limiter. This commit replaces the connections
limiter from oxy with built-in code.

Note: this also introduces a bug fix that may change behavior.
Prior to this change, the connection limiter kept a separate
connection account for HTTP connections than it did for connections
managed manually with acquire/release. We no longer maintain separate
counts - all connections (HTTP or not) contribute to the number
of allowed connections.
  • Loading branch information
zmb3 committed Sep 28, 2024
1 parent ed26cb1 commit 57f2bec
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 103 deletions.
4 changes: 1 addition & 3 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,7 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
cfg.Logger = slog.With(teleport.ComponentKey, teleport.ComponentAuth)
}

limiter, err := limiter.NewConnectionsLimiter(limiter.Config{
MaxConnections: defaults.LimiterMaxConcurrentSignatures,
})
limiter, err := limiter.NewConnectionsLimiter(defaults.LimiterMaxConcurrentSignatures)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
59 changes: 39 additions & 20 deletions lib/limiter/connlimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,48 +19,68 @@
package limiter

import (
"context"
"log/slog"
"net/http"
"strings"
"sync"

"github.com/gravitational/oxy/connlimit"
"github.com/gravitational/oxy/utils"
"github.com/gravitational/teleport"
"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
)

// ConnectionsLimiter is a network connection limiter and tracker
// ConnectionsLimiter is a network connection limiter.
type ConnectionsLimiter struct {
*connlimit.ConnLimiter
maxConnections int64
log *slog.Logger

next http.Handler

sync.Mutex
connections map[string]int64
}

// NewConnectionsLimiter returns new connection limiter, in case if connection
// limits are not set, they won't be tracked
func NewConnectionsLimiter(config Config) (*ConnectionsLimiter, error) {
limiter := ConnectionsLimiter{
maxConnections: config.MaxConnections,
func NewConnectionsLimiter(maxConnections int64) (*ConnectionsLimiter, error) {
// TODO: remove error from return type
return &ConnectionsLimiter{
maxConnections: maxConnections,
log: slog.With(teleport.ComponentKey, "limiter"),
connections: make(map[string]int64),
}, nil
}

// Wrap wraps an HTTP handler.
func (l *ConnectionsLimiter) Wrap(h http.Handler) {
l.next = h
}

func (l *ConnectionsLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if l.next == nil {
sc := http.StatusInternalServerError
http.Error(w, http.StatusText(sc), sc)
return
}

ipExtractor, err := utils.NewExtractor("client.ip")
if err != nil {
return nil, trace.Wrap(err)
// TODO: use net.SplitHostPort to be more compatible with IPv6
token, _, _ := strings.Cut(r.RemoteAddr, ":")
if token == "" {
l.log.WarnContext(context.Background(), "failed to extract source IP", "remote_addr", r.RemoteAddr)
sc := http.StatusInternalServerError
http.Error(w, http.StatusText(sc), sc)
return
}

limiter.ConnLimiter, err = connlimit.New(nil, ipExtractor, config.MaxConnections)
if err != nil {
return nil, trace.Wrap(err)
if err := l.AcquireConnection(token); err != nil {
l.log.InfoContext(context.Background(), "limiting request", "token", token, "error", err)
trace.WriteError(w, err)
return
}

return &limiter, nil
}
defer l.ReleaseConnection(token)

// WrapHandle adds connection limiter to the handle
func (l *ConnectionsLimiter) WrapHandle(h http.Handler) {
l.ConnLimiter.Wrap(h)
l.next.ServeHTTP(w, r)
}

// AcquireConnection acquires connection and bumps counter
Expand Down Expand Up @@ -97,7 +117,6 @@ func (l *ConnectionsLimiter) ReleaseConnection(token string) {

numberOfConnections, exists := l.connections[token]
if !exists {
log.Errorf("Trying to set negative number of connections")
return
}

Expand Down
71 changes: 71 additions & 0 deletions lib/limiter/connlimiter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package limiter_test

import (
"testing"

"github.com/gravitational/teleport/lib/limiter"
"github.com/stretchr/testify/require"
)

func TestConnectionsLimiter(t *testing.T) {
l, err := limiter.NewConnectionsLimiter(0)
require.NoError(t, err)

for i := 0; i < 10; i++ {
require.NoError(t, l.AcquireConnection("token1"))
}
for i := 0; i < 5; i++ {
require.NoError(t, l.AcquireConnection("token2"))
}

for i := 0; i < 10; i++ {
l.ReleaseConnection("token1")
}
for i := 0; i < 5; i++ {
l.ReleaseConnection("token2")
}

l, err = limiter.NewConnectionsLimiter(5)
require.NoError(t, err)

for i := 0; i < 5; i++ {
require.NoError(t, l.AcquireConnection("token1"))
}

for i := 0; i < 5; i++ {
require.NoError(t, l.AcquireConnection("token2"))
}
for i := 0; i < 5; i++ {
require.Error(t, l.AcquireConnection("token2"))
}

for i := 0; i < 10; i++ {
l.ReleaseConnection("token1")
require.NoError(t, l.AcquireConnection("token1"))
}

for i := 0; i < 5; i++ {
l.ReleaseConnection("token2")
}
for i := 0; i < 5; i++ {
require.NoError(t, l.AcquireConnection("token2"))
}
}
47 changes: 22 additions & 25 deletions lib/limiter/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ package limiter

import (
"context"
"encoding/json"
"net"
"net/http"

Expand All @@ -34,8 +33,8 @@ import (

// Limiter helps limiting connections and request rates
type Limiter struct {
// ConnectionsLimiter limits simultaneous connection
*ConnectionsLimiter
// connectionLimiter limits simultaneous connection
connectionLimiter *ConnectionsLimiter
// rateLimiter limits request rate
rateLimiter *RateLimiter
}
Expand All @@ -52,21 +51,11 @@ type Config struct {
Clock timetools.TimeProvider
}

// SetEnv reads LimiterConfig from JSON string
func (l *Config) SetEnv(v string) error {
if err := json.Unmarshal([]byte(v), l); err != nil {
return trace.Wrap(err, "expected JSON encoded remote certificate")
}
return nil
}

// NewLimiter returns new rate and connection limiter
func NewLimiter(config Config) (*Limiter, error) {
if config.MaxConnections < 0 {
config.MaxConnections = 0
}
config.MaxConnections = max(config.MaxConnections, 0)

connectionsLimiter, err := NewConnectionsLimiter(config)
connectionsLimiter, err := NewConnectionsLimiter(config.MaxConnections)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -77,11 +66,15 @@ func NewLimiter(config Config) (*Limiter, error) {
}

return &Limiter{
ConnectionsLimiter: connectionsLimiter,
rateLimiter: rateLimiter,
connectionLimiter: connectionsLimiter,
rateLimiter: rateLimiter,
}, nil
}

func (l *Limiter) GetNumConnection(token string) (int64, error) {
return l.connectionLimiter.GetNumConnection(token)
}

func (l *Limiter) RegisterRequest(token string) error {
return l.rateLimiter.RegisterRequest(token, nil)
}
Expand All @@ -90,10 +83,14 @@ func (l *Limiter) RegisterRequestWithCustomRate(token string, customRate *rateli
return l.rateLimiter.RegisterRequest(token, customRate)
}

func (l *Limiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
l.connectionLimiter.ServeHTTP(w, r)
}

// WrapHandle adds limiter to the handle
func (l *Limiter) WrapHandle(h http.Handler) {
l.rateLimiter.Wrap(h)
l.ConnLimiter.Wrap(l.rateLimiter)
l.connectionLimiter.Wrap(l.rateLimiter)
}

// RegisterRequestAndConnection register a rate and connection limiter for a given token. Close function is returned,
Expand All @@ -112,11 +109,11 @@ func (l *Limiter) RegisterRequestAndConnection(token string) (func(), error) {
}

// Apply connection limiting.
if err := l.AcquireConnection(token); err != nil {
if err := l.connectionLimiter.AcquireConnection(token); err != nil {
return func() {}, trace.LimitExceeded("exceeded connection limit for %q", token)
}

return func() { l.ReleaseConnection(token) }, nil
return func() { l.connectionLimiter.ReleaseConnection(token) }, nil
}

// UnaryServerInterceptor returns a gRPC unary interceptor which
Expand Down Expand Up @@ -148,10 +145,10 @@ func (l *Limiter) UnaryServerInterceptorWithCustomRate(customRate CustomRateFunc
if err := l.RegisterRequestWithCustomRate(clientIP, customRate(info.FullMethod)); err != nil {
return nil, trace.LimitExceeded("rate limit exceeded")
}
if err := l.ConnLimiter.Acquire(clientIP, 1); err != nil {
if err := l.connectionLimiter.AcquireConnection(clientIP); err != nil {
return nil, trace.LimitExceeded("connection limit exceeded")
}
defer l.ConnLimiter.Release(clientIP, 1)
defer l.connectionLimiter.ReleaseConnection(clientIP)
return handler(ctx, req)
}
}
Expand All @@ -171,17 +168,17 @@ func (l *Limiter) StreamServerInterceptor(srv interface{}, serverStream grpc.Ser
if err := l.RegisterRequest(clientIP); err != nil {
return trace.LimitExceeded("rate limit exceeded")
}
if err := l.ConnLimiter.Acquire(clientIP, 1); err != nil {
if err := l.connectionLimiter.AcquireConnection(clientIP); err != nil {
return trace.LimitExceeded("connection limit exceeded")
}
defer l.ConnLimiter.Release(clientIP, 1)
defer l.connectionLimiter.ReleaseConnection(clientIP)
return handler(srv, serverStream)
}

// WrapListener returns a [Listener] that wraps the provided listener
// with one that limits connections
func (l *Limiter) WrapListener(ln net.Listener) (*Listener, error) {
return NewListener(ln, l.ConnectionsLimiter)
return NewListener(ln, l.connectionLimiter)
}

type handlerWrapper interface {
Expand Down
55 changes: 1 addition & 54 deletions lib/limiter/limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,59 +43,6 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}

func TestConnectionsLimiter(t *testing.T) {
limiter, err := NewLimiter(
Config{
MaxConnections: 0,
},
)
require.NoError(t, err)

for i := 0; i < 10; i++ {
require.NoError(t, limiter.AcquireConnection("token1"))
}
for i := 0; i < 5; i++ {
require.NoError(t, limiter.AcquireConnection("token2"))
}

for i := 0; i < 10; i++ {
limiter.ReleaseConnection("token1")
}
for i := 0; i < 5; i++ {
limiter.ReleaseConnection("token2")
}

limiter, err = NewLimiter(
Config{
MaxConnections: 5,
},
)
require.NoError(t, err)

for i := 0; i < 5; i++ {
require.NoError(t, limiter.AcquireConnection("token1"))
}

for i := 0; i < 5; i++ {
require.NoError(t, limiter.AcquireConnection("token2"))
}
for i := 0; i < 5; i++ {
require.Error(t, limiter.AcquireConnection("token2"))
}

for i := 0; i < 10; i++ {
limiter.ReleaseConnection("token1")
require.NoError(t, limiter.AcquireConnection("token1"))
}

for i := 0; i < 5; i++ {
limiter.ReleaseConnection("token2")
}
for i := 0; i < 5; i++ {
require.NoError(t, limiter.AcquireConnection("token2"))
}
}

func TestRateLimiter(t *testing.T) {
// TODO: this test fails
clock := &timetools.FreezedTime{
Expand Down Expand Up @@ -404,7 +351,7 @@ func TestListener(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
limiter, err := NewConnectionsLimiter(test.config)
limiter, err := NewConnectionsLimiter(test.config.MaxConnections)
require.NoError(t, err)

ln, err := NewListener(test.listener, limiter)
Expand Down
2 changes: 1 addition & 1 deletion lib/service/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func (process *TeleportProcess) initWindowsDesktopServiceRegistered(logger *slog
return tlsCopy, nil
}

connLimiter, err := limiter.NewConnectionsLimiter(cfg.WindowsDesktop.ConnLimiter)
connLimiter, err := limiter.NewConnectionsLimiter(cfg.WindowsDesktop.ConnLimiter.MaxConnections)
if err != nil {
return trace.Wrap(err)
}
Expand Down

0 comments on commit 57f2bec

Please sign in to comment.