Skip to content

Commit

Permalink
Fix request throttling per host (#192)
Browse files Browse the repository at this point in the history
* Fix host limiting

* Refactor

* Remove comment

* Rename a field

* Test host throttler

* Fix data race

* Test pool

* Test more

* Calculate rate after getting semaphore
  • Loading branch information
raviqqe authored Nov 11, 2021
1 parent 3a52e27 commit ffe5e1e
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 16 deletions.
1 change: 1 addition & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func (c *command) runWithError(ss []string) (bool, error) {
),
args.RateLimit,
args.MaxConnections,
args.MaxConnectionsPerHost,
)

pp := newPageParser(newLinkFinder(args.ExcludedPatterns))
Expand Down
27 changes: 27 additions & 0 deletions host_throttler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package main

import "go.uber.org/ratelimit"

// hostThrottler combines the two limits applied to requests through one
// throttler: a rate limiter and a semaphore of concurrent connections.
// hostThrottlerPool keeps one instance per host name.
type hostThrottler struct {
// limiter is taken from once per Request call.
limiter ratelimit.Limiter
// connections is acquired in Request and given back in Release.
connections semaphore
}

// newHostThrottler creates a throttler whose rate limiter allows
// requestPerSecond takes per second and whose semaphore is sized by
// maxConnectionsPerHost. A requestPerSecond of zero or less selects an
// unlimited rate limiter.
func newHostThrottler(requestPerSecond, maxConnectionsPerHost int) *hostThrottler {
	var limiter ratelimit.Limiter

	if requestPerSecond > 0 {
		limiter = ratelimit.New(requestPerSecond)
	} else {
		limiter = ratelimit.NewUnlimited()
	}

	return &hostThrottler{
		limiter:     limiter,
		connections: newSemaphore(maxConnectionsPerHost),
	}
}

// Request acquires a slot from the connection semaphore and then takes from
// the rate limiter. The semaphore is acquired first, so any wait for a free
// connection slot happens before the rate-limit take. Each Request should be
// paired with a Release.
// NOTE(review): semaphore is defined elsewhere; the tests expect Request to
// block while the per-host connection limit is reached — confirm.
func (t *hostThrottler) Request() {
t.connections.Request()
t.limiter.Take()
}

// Release gives back the connection slot acquired by a previous Request.
// The rate limiter is not touched here.
func (t *hostThrottler) Release() {
t.connections.Release()
}
23 changes: 23 additions & 0 deletions host_throttler_pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package main

import "sync"

// hostThrottlerPool hands out one hostThrottler per name, creating each with
// the same requestPerSecond and maxConnectionsPerHost settings.
type hostThrottlerPool struct {
// Settings passed to newHostThrottler for every throttler the pool creates.
requestPerSecond, maxConnectionsPerHost int
// hostMap stores *hostThrottler values keyed by the name given to Get.
hostMap sync.Map
}

// newHostThrottlerPool creates an empty pool. Every throttler it later hands
// out is built with the given requestPerSecond and maxConnectionsPerHost.
func newHostThrottlerPool(requestPerSecond, maxConnectionsPerHost int) *hostThrottlerPool {
	// hostMap is left at its zero value, which is an empty, ready-to-use sync.Map.
	return &hostThrottlerPool{
		requestPerSecond:      requestPerSecond,
		maxConnectionsPerHost: maxConnectionsPerHost,
	}
}

// Get returns the throttler stored under name, creating and storing one on
// first use. All callers passing the same name receive the same
// *hostThrottler, including callers that race on the first Get.
func (p *hostThrottlerPool) Get(name string) *hostThrottler {
	// Fast path: the throttler already exists, so avoid building a new rate
	// limiter and semaphore only to throw them away.
	if x, ok := p.hostMap.Load(name); ok {
		return x.(*hostThrottler)
	}

	// Slow path: concurrent first callers may each build a candidate, but
	// LoadOrStore returns whichever value ended up in the map.
	x, _ := p.hostMap.LoadOrStore(
		name,
		newHostThrottler(p.requestPerSecond, p.maxConnectionsPerHost),
	)

	return x.(*hostThrottler)
}
55 changes: 55 additions & 0 deletions host_throttler_pool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package main

import (
"testing"

"github.com/stretchr/testify/assert"
)

// TestNewHostThrottlerPool is a smoke test: constructing a pool must not panic.
func TestNewHostThrottlerPool(t *testing.T) {
	_ = newHostThrottlerPool(1, 1)
}

// TestHostThrottlerPoolGetHost checks that two requests for the same host
// share one throttler: with one connection per host, the second request
// completes only after the first slot is released.
func TestHostThrottlerPoolGetHost(t *testing.T) {
	const workers = 2

	pool := newHostThrottlerPool(1000000, 1)
	acquired := make(chan struct{}, 100)

	acquire := func() {
		pool.Get("foo").Request()
		acquired <- struct{}{}
	}

	for n := 0; n < workers; n++ {
		go acquire()
	}

	// Exactly one worker gets through; the other must still be waiting.
	<-acquired
	assert.Equal(t, 0, len(acquired))

	// Freeing the slot lets the remaining worker finish.
	pool.Get("foo").Release()
	<-acquired
}

// TestHostThrottlerPoolGetHosts checks that hosts are throttled independently:
// with one connection per host, one request per host gets through at once and
// the rest complete only after that host's slot is released.
func TestHostThrottlerPoolGetHosts(t *testing.T) {
	const workersPerHost = 2

	hosts := []string{"foo", "bar"}
	pool := newHostThrottlerPool(1000000, 1)
	acquired := make(chan struct{}, 100)

	acquire := func(host string) {
		pool.Get(host).Request()
		acquired <- struct{}{}
	}

	for _, host := range hosts {
		for n := 0; n < workersPerHost; n++ {
			go acquire(host)
		}
	}

	// One worker per host gets through; the others must still be waiting.
	<-acquired
	<-acquired
	assert.Equal(t, 0, len(acquired))

	// Releasing each host's slot unblocks that host's remaining worker.
	for _, host := range hosts {
		pool.Get(host).Release()
		<-acquired
	}
}
30 changes: 30 additions & 0 deletions host_throttler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package main

import (
"testing"

"github.com/stretchr/testify/assert"
)

// TestNewHostThrottler is a smoke test: constructing a throttler must not panic.
func TestNewHostThrottler(t *testing.T) {
	_ = newHostThrottler(1, 1)
}

// TestHostThrottlerRequest checks that a throttler with a single connection
// slot lets one Request through and holds the second until Release is called.
func TestHostThrottlerRequest(t *testing.T) {
	const workers = 2

	throttler := newHostThrottler(1000000, 1)
	acquired := make(chan struct{}, 100)

	acquire := func() {
		throttler.Request()
		acquired <- struct{}{}
	}

	for n := 0; n < workers; n++ {
		go acquire()
	}

	// Exactly one worker gets through; the other must still be waiting.
	<-acquired
	assert.Equal(t, 0, len(acquired))

	// Freeing the slot lets the remaining worker finish.
	throttler.Release()
	<-acquired
}
29 changes: 13 additions & 16 deletions throttled_http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,29 @@ package main

import (
"net/url"

"go.uber.org/ratelimit"
)

// TODO Throttle requests for each host.
type throttledHttpClient struct {
client httpClient
limiter ratelimit.Limiter
semaphore semaphore
client httpClient
connections semaphore
hostThrottlerPool *hostThrottlerPool
}

func newThrottledHttpClient(c httpClient, rps int, maxConnections int) httpClient {
l := ratelimit.NewUnlimited()

if rps > 0 {
l = ratelimit.New(rps)
func newThrottledHttpClient(c httpClient, requestPerSecond int, maxConnections, maxConnectionsPerHost int) httpClient {
return &throttledHttpClient{
c,
newSemaphore(maxConnections),
newHostThrottlerPool(requestPerSecond, maxConnectionsPerHost),
}

return &throttledHttpClient{c, l, newSemaphore(maxConnections)}
}

func (c *throttledHttpClient) Get(u *url.URL) (httpResponse, error) {
c.semaphore.Request()
defer c.semaphore.Release()
c.connections.Request()
defer c.connections.Release()

c.limiter.Take()
t := c.hostThrottlerPool.Get(u.Hostname())
t.Request()
defer t.Release()

return c.client.Get(u)
}

0 comments on commit ffe5e1e

Please sign in to comment.