diff --git a/p2p/host/eventbus/basic.go b/p2p/host/eventbus/basic.go index 8cb030d4f2..ca102b5637 100644 --- a/p2p/host/eventbus/basic.go +++ b/p2p/host/eventbus/basic.go @@ -21,21 +21,21 @@ type basicBus struct { var _ event.Bus = (*basicBus)(nil) -type Emitter struct { +type emitter struct { n *node typ reflect.Type closed int32 dropper func(reflect.Type) } -func (e *Emitter) Emit(evt interface{}) { +func (e *emitter) Emit(evt interface{}) { if atomic.LoadInt32(&e.closed) != 0 { panic("emitter is closed") } e.n.emit(evt) } -func (e *Emitter) Close() error { +func (e *emitter) Close() error { if !atomic.CompareAndSwapInt32(&e.closed, 0, 1) { panic("closed an emitter more than once") } @@ -93,16 +93,42 @@ func (b *basicBus) tryDropNode(typ reflect.Type) { b.lk.Unlock() } +type sub struct { + ch chan interface{} + nodes []*node + dropper func(reflect.Type) +} + +func (s *sub) Out() <-chan interface{} { + return s.ch +} + +func (s *sub) Close() error { + close(s.ch) + for _, n := range s.nodes { + n.lk.Lock() + for i := 0; i < len(n.sinks); i++ { + if n.sinks[i] == s.ch { + n.sinks[i], n.sinks[len(n.sinks)-1] = n.sinks[len(n.sinks)-1], nil + n.sinks = n.sinks[:len(n.sinks)-1] + break + } + } + tryDrop := len(n.sinks) == 0 && atomic.LoadInt32(&n.nEmitters) == 0 + n.lk.Unlock() + if tryDrop { + s.dropper(n.typ) + } + } + return nil +} + +var _ event.Subscription = (*sub)(nil) + // Subscribe creates new subscription. Failing to drain the channel will cause // publishers to get blocked. CancelFunc is guaranteed to return after last send // to the channel -// -// Example: -// ch := make(chan EventT, 10) -// defer close(ch) -// cancel, err := eventbus.Subscribe(ch) -// defer cancel() -func (b *basicBus) Subscribe(typedChan interface{}, opts ...event.SubscriptionOpt) (c event.CancelFunc, err error) { +func (b *basicBus) Subscribe(evtTypes interface{}, opts ...event.SubscriptionOpt) (_ event.Subscription, err error) { var settings subSettings for _, opt := range opts { if err := opt(&settings); err != nil { @@ -110,50 +136,40 @@ func (b *basicBus) Subscribe(typedChan interface{}, opts ...event.SubscriptionOp } } - refCh := reflect.ValueOf(typedChan) - typ := refCh.Type() - if typ.Kind() != reflect.Chan { - return nil, errors.New("expected a channel") + types, ok := evtTypes.([]interface{}) + if !ok { + types = []interface{}{evtTypes} } - if typ.ChanDir()&reflect.SendDir == 0 { - return nil, errors.New("channel doesn't allow send") + + out := &sub{ + ch: make(chan interface{}, settings.buffer), + nodes: make([]*node, len(types)), + + dropper: b.tryDropNode, } - if settings.forcedType != nil { - if settings.forcedType.Elem().AssignableTo(typ) { - return nil, fmt.Errorf("forced type %s cannot be sent to chan %s", settings.forcedType, typ) + for i, etyp := range types { + typ := reflect.TypeOf(etyp) + + if typ.Kind() != reflect.Ptr { + return nil, errors.New("subscribe called with non-pointer type") } - typ = settings.forcedType - } - - err = b.withNode(typ.Elem(), func(n *node) { - n.sinks = append(n.sinks, refCh) - c = func() { - n.lk.Lock() - for i := 0; i < len(n.sinks); i++ { - if n.sinks[i] == refCh { - n.sinks[i], n.sinks[len(n.sinks)-1] = n.sinks[len(n.sinks)-1], reflect.Value{} - n.sinks = n.sinks[:len(n.sinks)-1] - break + + err = b.withNode(typ.Elem(), func(n *node) { + n.sinks = append(n.sinks, out.ch) + out.nodes[i] = n + }, func(n *node) { + if n.keepLast { + l := n.last.Load() + if l == nil { + return } + out.ch <- l } - tryDrop := len(n.sinks) == 0 && atomic.LoadInt32(&n.nEmitters) == 0 - n.lk.Unlock() - if tryDrop { - b.tryDropNode(typ.Elem()) - } - } - }, func(n *node) { - if n.keepLast { - lastVal, ok := n.last.Load().(reflect.Value) - if !ok { - return - } + }) + } - refCh.Send(lastVal) - } - }) - return + return out, nil } // Emitter creates new emitter @@ -183,7 +199,7 @@ func (b *basicBus) Emitter(evtType interface{}, opts ...event.EmitterOpt) (e eve err = b.withNode(typ, func(n *node) { atomic.AddInt32(&n.nEmitters, 1) n.keepLast = n.keepLast || settings.makeStateful - e = &Emitter{n: n, typ: typ, dropper: b.tryDropNode} + e = &emitter{n: n, typ: typ, dropper: b.tryDropNode} }, func(_ *node) {}) return } @@ -203,7 +219,7 @@ type node struct { keepLast bool last atomic.Value - sinks []reflect.Value + sinks []chan interface{} } func newNode(typ reflect.Type) *node { @@ -220,11 +236,11 @@ func (n *node) emit(event interface{}) { n.lk.RLock() if n.keepLast { - n.last.Store(eval) + n.last.Store(event) } for _, ch := range n.sinks { - ch.Send(eval) + ch <- event } n.lk.RUnlock() } diff --git a/p2p/host/eventbus/basic_test.go b/p2p/host/eventbus/basic_test.go index 8d2bfae820..fc23e61290 100644 --- a/p2p/host/eventbus/basic_test.go +++ b/p2p/host/eventbus/basic_test.go @@ -27,15 +27,14 @@ func (EventA) String() string { func TestEmit(t *testing.T) { bus := NewBus() - events := make(chan EventA) - cancel, err := bus.Subscribe(events) + sub, err := bus.Subscribe(new(EventA)) if err != nil { t.Fatal(err) } go func() { - defer cancel() - <-events + defer sub.Close() + <-sub.Out() }() em, err := bus.Emitter(new(EventA)) @@ -49,8 +48,7 @@ func TestEmit(t *testing.T) { func TestSub(t *testing.T) { bus := NewBus() - events := make(chan EventB) - cancel, err := bus.Subscribe(events) + sub, err := bus.Subscribe(new(EventB)) if err != nil { t.Fatal(err) } @@ -61,8 +59,8 @@ func TestSub(t *testing.T) { wait.Add(1) go func() { - defer cancel() - event = <-events + defer sub.Close() + event = (<-sub.Out()).(EventB) wait.Done() }() @@ -131,9 +129,9 @@ func TestClosingRaces(t *testing.T) { lk.RLock() defer lk.RUnlock() - cancel, _ := b.Subscribe(make(chan EventA)) + sub, _ := b.Subscribe(new(EventA)) time.Sleep(10 * time.Millisecond) - cancel() + sub.Close() wg.Done() }() @@ -174,15 +172,14 @@ func TestSubMany(t *testing.T) { for i := 0; i < n; i++ { go func() { - events := make(chan EventB) - cancel, err := bus.Subscribe(events) + sub, err := bus.Subscribe(new(EventB)) if err != nil { panic(err) } - defer cancel() + defer sub.Close() ready.Done() - atomic.AddInt32(&r, int32(<-events)) + atomic.AddInt32(&r, int32((<-sub.Out()).(EventB))) wait.Done() }() } @@ -205,8 +202,7 @@ func TestSubMany(t *testing.T) { func TestSubType(t *testing.T) { bus := NewBus() - events := make(chan fmt.Stringer) - cancel, err := bus.Subscribe(events, ForceSubType(new(EventA))) + sub, err := bus.Subscribe([]interface{}{new(EventA), new(EventB)}) if err != nil { t.Fatal(err) } @@ -217,8 +213,8 @@ func TestSubType(t *testing.T) { wait.Add(1) go func() { - defer cancel() - event = <-events + defer sub.Close() + event = (<-sub.Out()).(EventA) wait.Done() }() @@ -244,15 +240,14 @@ func TestNonStateful(t *testing.T) { } defer em.Close() - eventsA := make(chan EventB, 1) - cancelS, err := bus.Subscribe(eventsA) + sub1, err := bus.Subscribe(new(EventB), BufSize(1)) if err != nil { t.Fatal(err) } - defer cancelS() + defer sub1.Close() select { - case <-eventsA: + case <-sub1.Out(): t.Fatal("didn't expect to get an event") default: } @@ -260,23 +255,22 @@ func TestNonStateful(t *testing.T) { em.Emit(EventB(1)) select { - case e := <-eventsA: - if e != 1 { + case e := <-sub1.Out(): + if e.(EventB) != 1 { t.Fatal("got wrong event") } default: t.Fatal("expected to get an event") } - eventsB := make(chan EventB, 1) - cancelS2, err := bus.Subscribe(eventsB) + sub2, err := bus.Subscribe(new(EventB), BufSize(1)) if err != nil { t.Fatal(err) } - defer cancelS2() + defer sub2.Close() select { - case <-eventsA: + case <-sub2.Out(): t.Fatal("didn't expect to get an event") default: } @@ -292,14 +286,13 @@ func TestStateful(t *testing.T) { em.Emit(EventB(2)) - eventsA := make(chan EventB, 1) - cancelS, err := bus.Subscribe(eventsA) + sub, err := bus.Subscribe(new(EventB), BufSize(1)) if err != nil { t.Fatal(err) } - defer cancelS() + defer sub.Close() - if <-eventsA != 2 { + if (<-sub.Out()).(EventB) != 2 { t.Fatal("got wrong event") } } @@ -320,16 +313,19 @@ func testMany(t testing.TB, subs, emits, msgs int, stateful bool) { for i := 0; i < subs; i++ { go func() { - events := make(chan EventB) - cancel, err := bus.Subscribe(events) + sub, err := bus.Subscribe(new(EventB)) if err != nil { panic(err) } - defer cancel() + defer sub.Close() ready.Done() for i := 0; i < emits*msgs; i++ { - atomic.AddInt64(&r, int64(<-events)) + e, ok := <-sub.Out() + if !ok { + panic("wat") + } + atomic.AddInt64(&r, int64(e.(EventB))) } wait.Done() }() diff --git a/p2p/host/eventbus/opts.go b/p2p/host/eventbus/opts.go index 38265557b3..50c673685b 100644 --- a/p2p/host/eventbus/opts.go +++ b/p2p/host/eventbus/opts.go @@ -1,40 +1,12 @@ package eventbus -import ( - "errors" - "reflect" - - "github.com/libp2p/go-libp2p-core/event" -) - type subSettings struct { - forcedType reflect.Type + buffer int } -// ForceSubType is a Subscribe option which overrides the type to which -// the subscription will be done. Note that the evtType must be assignable -// to channel type. -// -// This also allows for subscribing to multiple eventbus channels with one -// Go channel to get better ordering guarantees. -// -// Example: -// type Event struct{} -// func (Event) String() string { -// return "event" -// } -// -// eventCh := make(chan fmt.Stringer) // interface { String() string } -// cancel, err := eventbus.Subscribe(eventCh, event.ForceSubType(new(Event))) -// [...] -func ForceSubType(evtType interface{}) event.SubscriptionOpt { - return func(settings interface{}) error { - s := settings.(*subSettings) - typ := reflect.TypeOf(evtType) - if typ.Kind() != reflect.Ptr { - return errors.New("ForceSubType called with non-pointer type") - } - s.forcedType = typ +func BufSize(n int) func(interface{}) error { + return func(s interface{}) error { + s.(*subSettings).buffer = n return nil } }