Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pause/resume to protocol #36

Merged
merged 1 commit into from
Jan 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions back_pressure.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package remotedialer

import (
"sync"
)

type backPressure struct {
cond sync.Cond
c *connection
paused bool
}

func newBackPressure(c *connection) *backPressure {
return &backPressure{
cond: sync.Cond{
L: &sync.Mutex{},
},
c: c,
paused: false,
}
}

func (b *backPressure) OnPause() {
b.cond.L.Lock()
defer b.cond.L.Unlock()

b.paused = true
b.cond.Broadcast()
}

func (b *backPressure) OnResume() {
b.cond.L.Lock()
defer b.cond.L.Unlock()

b.paused = false
b.cond.Broadcast()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need an advanced sync lecture for this... 😲

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how do you mean?

}

func (b *backPressure) Pause() error {
b.cond.L.Lock()
defer b.cond.L.Unlock()
if b.paused {
return nil
}
if _, err := b.c.Pause(); err != nil {
return err
}
b.paused = true
return nil
}

func (b *backPressure) Resume() error {
b.cond.L.Lock()
defer b.cond.L.Unlock()
if !b.paused {
return nil
}
if _, err := b.c.Resume(); err != nil {
return err
}
b.paused = false
return nil
}

func (b *backPressure) Wait() {
b.cond.L.Lock()
defer b.cond.L.Unlock()

for b.paused {
b.cond.Wait()
}
}
120 changes: 120 additions & 0 deletions buffer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package remotedialer

import (
"context"
"io/ioutil"
"net"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

func TestExceedBuffer(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

producerAddress, err := newTestProducer(ctx)
if err != nil {
t.Fatal(err)
}

serverAddress, server, err := newTestServer(ctx)
if err != nil {
t.Fatal(err)
}

if err := newTestClient(ctx, "ws://"+serverAddress); err != nil {
t.Fatal(err)
}

client := http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, proto, address string) (net.Conn, error) {
return server.Dialer("client")(ctx, proto, address)
},
},
}

producerURL := "http://" + producerAddress

resp, err := client.Get(producerURL)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()

resp2, err := client.Get(producerURL)
if err != nil {
t.Fatal(err)
}
defer resp2.Body.Close()

resp2Body, err := ioutil.ReadAll(resp2.Body)
if err != nil {
t.Fatal(err)
}

respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}

assert.Equal(t, 4096*4096, len(resp2Body))
assert.Equal(t, 4096*4096, len(respBody))
}

func newTestServer(ctx context.Context) (string, *Server, error) {
auth := func(req *http.Request) (clientKey string, authed bool, err error) {
return "client", true, nil
}

server := New(auth, DefaultErrorWriter)
address, err := newServer(ctx, server)
return address, server, err
}

func newTestClient(ctx context.Context, url string) error {
result := make(chan error, 2)
go func() {
err := ConnectToProxy(ctx, url, nil, func(proto, address string) bool {
return true
}, nil, func(ctx context.Context, session *Session) error {
result <- nil
return nil
})
result <- err
}()
return <-result
}

func newServer(ctx context.Context, handler http.Handler) (string, error) {
server := http.Server{
BaseContext: func(_ net.Listener) context.Context {
return ctx
},
Handler: handler,
}
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
return "", err
}
go func() {
<-ctx.Done()
listener.Close()
server.Shutdown(context.Background())
}()
go server.Serve(listener)
return listener.Addr().String(), nil
}

func newTestProducer(ctx context.Context) (string, error) {
buffer := make([]byte, 4096)
return newServer(ctx, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
for i := 0; i < 4096; i++ {
if _, err := resp.Write(buffer); err != nil {
panic(err)
}
}
}))
}
29 changes: 28 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
type connection struct {
err error
writeDeadline time.Time
backPressure *backPressure
buffer *readBuffer
addr addr
session *Session
Expand All @@ -19,14 +20,15 @@ type connection struct {

func newConnection(connID int64, session *Session, proto, address string) *connection {
c := &connection{
buffer: newReadBuffer(),
addr: addr{
proto: proto,
address: address,
},
connID: connID,
session: session,
}
c.backPressure = newBackPressure(c)
c.buffer = newReadBuffer(c.backPressure)
metrics.IncSMTotalAddConnectionsForWS(session.clientKey, proto, address)
return c
}
Expand Down Expand Up @@ -69,11 +71,36 @@ func (c *connection) Write(b []byte) (int, error) {
if c.err != nil {
return 0, io.ErrClosedPipe
}
c.backPressure.Wait()
msg := newMessage(c.connID, b)
metrics.AddSMTotalTransmitBytesOnWS(c.session.clientKey, float64(len(msg.Bytes())))
return c.session.writeMessage(c.writeDeadline, msg)
}

func (c *connection) OnPause() {
c.backPressure.OnPause()
}

func (c *connection) OnResume() {
c.backPressure.OnResume()
}

func (c *connection) Pause() (int, error) {
if c.err != nil {
return 0, io.ErrClosedPipe
}
msg := newPause(c.connID)
return c.session.writeMessage(c.writeDeadline, msg)
}

func (c *connection) Resume() (int, error) {
if c.err != nil {
return 0, io.ErrClosedPipe
}
msg := newResume(c.connID)
return c.session.writeMessage(c.writeDeadline, msg)
}

func (c *connection) writeErr(err error) {
if err != nil {
msg := newErrorMessage(c.connID, err)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ require (
github.com/pkg/errors v0.8.1
github.com/prometheus/client_golang v1.4.0
github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.4.0
)
6 changes: 0 additions & 6 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuy
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
Expand All @@ -19,7 +18,6 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
Expand Down Expand Up @@ -58,7 +56,6 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.4.0 h1:YVIb/fVcOTMSqtqZWSKnHpSLBxu8DKgxq8z6RuBZwqI=
github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 h1:idejC8f05m9MGOsuEi1ATq9shN03HrxNkD/luQvxCv8=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M=
Expand All @@ -75,9 +72,7 @@ github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
Expand All @@ -91,7 +86,6 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200122134326-e047566fdf82 h1:ywK/j/KkyTHcdyYSZNXGjMwgmDSfjglYZ3vStQ/gSCU=
golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down
22 changes: 22 additions & 0 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ const (
Error
AddClient
RemoveClient
Pause
Resume
)

var (
Expand Down Expand Up @@ -59,6 +61,22 @@ func newMessage(connID int64, bytes []byte) *message {
}
}

func newPause(connID int64) *message {
return &message{
id: nextid(),
connID: connID,
messageType: Pause,
}
}

func newResume(connID int64) *message {
return &message{
id: nextid(),
connID: connID,
messageType: Resume,
}
}

func newConnect(connID int64, proto, address string) *message {
return &message{
id: nextid(),
Expand Down Expand Up @@ -213,6 +231,10 @@ func (m *message) String() string {
return fmt.Sprintf("%d ADDCLIENT [%s]", m.id, m.address)
case RemoveClient:
return fmt.Sprintf("%d REMOVECLIENT [%s]", m.id, m.address)
case Pause:
return fmt.Sprintf("%d PAUSE [%s]", m.id, m.address)
case Resume:
return fmt.Sprintf("%d RESUME [%s]", m.id, m.address)
}
return fmt.Sprintf("%d UNKNOWN[%d]: %d", m.id, m.connID, m.messageType)
}
Loading