-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #107 from libp2p/feat/better-dialsync
Improve swarm dial sync code
- Loading branch information
Showing
7 changed files
with
328 additions
and
127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
package swarm | ||
|
||
import ( | ||
"context" | ||
"sync" | ||
|
||
peer "github.com/ipfs/go-libp2p-peer" | ||
) | ||
|
||
type DialFunc func(context.Context, peer.ID) (*Conn, error) | ||
|
||
func NewDialSync(dfn DialFunc) *DialSync { | ||
return &DialSync{ | ||
dials: make(map[peer.ID]*activeDial), | ||
dialFunc: dfn, | ||
} | ||
} | ||
|
||
type DialSync struct { | ||
dials map[peer.ID]*activeDial | ||
dialsLk sync.Mutex | ||
dialFunc DialFunc | ||
} | ||
|
||
type activeDial struct { | ||
id peer.ID | ||
refCnt int | ||
refCntLk sync.Mutex | ||
cancel func() | ||
|
||
err error | ||
conn *Conn | ||
waitch chan struct{} | ||
|
||
ds *DialSync | ||
} | ||
|
||
func (dr *activeDial) wait(ctx context.Context) (*Conn, error) { | ||
defer dr.decref() | ||
select { | ||
case <-dr.waitch: | ||
return dr.conn, dr.err | ||
case <-ctx.Done(): | ||
return nil, ctx.Err() | ||
} | ||
} | ||
|
||
func (ad *activeDial) incref() { | ||
ad.refCntLk.Lock() | ||
defer ad.refCntLk.Unlock() | ||
ad.refCnt++ | ||
} | ||
|
||
func (ad *activeDial) decref() { | ||
ad.refCntLk.Lock() | ||
defer ad.refCntLk.Unlock() | ||
ad.refCnt-- | ||
if ad.refCnt <= 0 { | ||
ad.cancel() | ||
ad.ds.dialsLk.Lock() | ||
delete(ad.ds.dials, ad.id) | ||
ad.ds.dialsLk.Unlock() | ||
} | ||
} | ||
|
||
func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { | ||
ds.dialsLk.Lock() | ||
|
||
actd, ok := ds.dials[p] | ||
if !ok { | ||
ctx, cancel := context.WithCancel(context.Background()) | ||
actd = &activeDial{ | ||
id: p, | ||
cancel: cancel, | ||
waitch: make(chan struct{}), | ||
ds: ds, | ||
} | ||
ds.dials[p] = actd | ||
|
||
go func(ctx context.Context, p peer.ID, ad *activeDial) { | ||
ad.conn, ad.err = ds.dialFunc(ctx, p) | ||
close(ad.waitch) | ||
ad.cancel() | ||
ad.waitch = nil // to ensure nobody tries reusing this | ||
}(ctx, p, actd) | ||
} | ||
|
||
actd.incref() | ||
ds.dialsLk.Unlock() | ||
|
||
return actd.wait(ctx) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
package swarm | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"sync" | ||
"testing" | ||
"time" | ||
|
||
peer "github.com/ipfs/go-libp2p-peer" | ||
) | ||
|
||
func getMockDialFunc() (DialFunc, func(), context.Context, <-chan struct{}) { | ||
dfcalls := make(chan struct{}, 512) // buffer it large enough that we won't care | ||
dialctx, cancel := context.WithCancel(context.Background()) | ||
ch := make(chan struct{}) | ||
f := func(ctx context.Context, p peer.ID) (*Conn, error) { | ||
dfcalls <- struct{}{} | ||
defer cancel() | ||
select { | ||
case <-ch: | ||
return new(Conn), nil | ||
case <-ctx.Done(): | ||
return nil, ctx.Err() | ||
} | ||
} | ||
|
||
o := new(sync.Once) | ||
|
||
return f, func() { o.Do(func() { close(ch) }) }, dialctx, dfcalls | ||
} | ||
|
||
func TestBasicDialSync(t *testing.T) { | ||
df, done, _, callsch := getMockDialFunc() | ||
|
||
dsync := NewDialSync(df) | ||
|
||
p := peer.ID("testpeer") | ||
|
||
ctx := context.Background() | ||
|
||
finished := make(chan struct{}) | ||
go func() { | ||
_, err := dsync.DialLock(ctx, p) | ||
if err != nil { | ||
t.Error(err) | ||
} | ||
finished <- struct{}{} | ||
}() | ||
|
||
go func() { | ||
_, err := dsync.DialLock(ctx, p) | ||
if err != nil { | ||
t.Error(err) | ||
} | ||
finished <- struct{}{} | ||
}() | ||
|
||
// short sleep just to make sure we've moved around in the scheduler | ||
time.Sleep(time.Millisecond * 20) | ||
done() | ||
|
||
<-finished | ||
<-finished | ||
|
||
if len(callsch) > 1 { | ||
t.Fatal("should only have called dial func once!") | ||
} | ||
} | ||
|
||
func TestDialSyncCancel(t *testing.T) { | ||
df, done, _, dcall := getMockDialFunc() | ||
|
||
dsync := NewDialSync(df) | ||
|
||
p := peer.ID("testpeer") | ||
|
||
ctx1, cancel1 := context.WithCancel(context.Background()) | ||
|
||
finished := make(chan struct{}) | ||
go func() { | ||
_, err := dsync.DialLock(ctx1, p) | ||
if err != ctx1.Err() { | ||
t.Error("should have gotten context error") | ||
} | ||
finished <- struct{}{} | ||
}() | ||
|
||
// make sure the above makes it through the wait code first | ||
select { | ||
case <-dcall: | ||
case <-time.After(time.Second): | ||
t.Fatal("timed out waiting for dial to start") | ||
} | ||
|
||
// Add a second dialwait in so two actors are waiting on the same dial | ||
go func() { | ||
_, err := dsync.DialLock(context.Background(), p) | ||
if err != nil { | ||
t.Error(err) | ||
} | ||
finished <- struct{}{} | ||
}() | ||
|
||
time.Sleep(time.Millisecond * 20) | ||
|
||
// cancel the first dialwait, it should not affect the second at all | ||
cancel1() | ||
select { | ||
case <-finished: | ||
case <-time.After(time.Second): | ||
t.Fatal("timed out waiting for wait to exit") | ||
} | ||
|
||
// short sleep just to make sure we've moved around in the scheduler | ||
time.Sleep(time.Millisecond * 20) | ||
done() | ||
|
||
<-finished | ||
} | ||
|
||
func TestDialSyncAllCancel(t *testing.T) { | ||
df, done, dctx, _ := getMockDialFunc() | ||
|
||
dsync := NewDialSync(df) | ||
|
||
p := peer.ID("testpeer") | ||
|
||
ctx1, cancel1 := context.WithCancel(context.Background()) | ||
|
||
finished := make(chan struct{}) | ||
go func() { | ||
_, err := dsync.DialLock(ctx1, p) | ||
if err != ctx1.Err() { | ||
t.Error("should have gotten context error") | ||
} | ||
finished <- struct{}{} | ||
}() | ||
|
||
// Add a second dialwait in so two actors are waiting on the same dial | ||
go func() { | ||
_, err := dsync.DialLock(ctx1, p) | ||
if err != ctx1.Err() { | ||
t.Error("should have gotten context error") | ||
} | ||
finished <- struct{}{} | ||
}() | ||
|
||
cancel1() | ||
for i := 0; i < 2; i++ { | ||
select { | ||
case <-finished: | ||
case <-time.After(time.Second): | ||
t.Fatal("timed out waiting for wait to exit") | ||
} | ||
} | ||
|
||
// the dial should have exited now | ||
select { | ||
case <-dctx.Done(): | ||
case <-time.After(time.Second): | ||
t.Fatal("timed out waiting for dial to return") | ||
} | ||
|
||
// should be able to successfully dial that peer again | ||
done() | ||
_, err := dsync.DialLock(context.Background(), p) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
} | ||
|
||
func TestFailFirst(t *testing.T) { | ||
var count int | ||
f := func(ctx context.Context, p peer.ID) (*Conn, error) { | ||
if count > 0 { | ||
return new(Conn), nil | ||
} | ||
count++ | ||
return nil, fmt.Errorf("gophers ate the modem") | ||
} | ||
|
||
ds := NewDialSync(f) | ||
|
||
p := peer.ID("testing") | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) | ||
defer cancel() | ||
|
||
_, err := ds.DialLock(ctx, p) | ||
if err == nil { | ||
t.Fatal("expected gophers to have eaten the modem") | ||
} | ||
|
||
c, err := ds.DialLock(ctx, p) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
if c == nil { | ||
t.Fatal("should have gotten a 'real' conn back") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.