diff --git a/swarm.go b/swarm.go index a55c8d9c..c5c3bf28 100644 --- a/swarm.go +++ b/swarm.go @@ -41,19 +41,21 @@ var ErrAddrFiltered = errors.New("address filtered") // ErrDialTimeout is returned when one a dial times out due to the global timeout var ErrDialTimeout = errors.New("dial timed out") -type Option func(*Swarm) +type Option func(*Swarm) error // WithConnectionGater sets a connection gater func WithConnectionGater(gater connmgr.ConnectionGater) Option { - return func(s *Swarm) { + return func(s *Swarm) error { s.gater = gater + return nil } } // WithMetrics sets a metrics reporter func WithMetrics(reporter metrics.Reporter) Option { - return func(s *Swarm) { + return func(s *Swarm) error { s.bwc = reporter + return nil } } @@ -114,7 +116,7 @@ type Swarm struct { } // NewSwarm constructs a Swarm. -func NewSwarm(local peer.ID, peers peerstore.Peerstore, opts ...Option) *Swarm { +func NewSwarm(local peer.ID, peers peerstore.Peerstore, opts ...Option) (*Swarm, error) { ctx, cancel := context.WithCancel(context.Background()) s := &Swarm{ local: local, @@ -129,13 +131,15 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, opts ...Option) *Swarm { s.notifs.m = make(map[network.Notifiee]struct{}) for _, opt := range opts { - opt(s) + if err := opt(s); err != nil { + return nil, err + } } s.dsync = newDialSync(s.dialWorkerLoop) s.limiter = newDialLimiter(s.dialAddr) s.backf.init(s.ctx) - return s + return s, nil } func (s *Swarm) Close() error { diff --git a/testing/testing.go b/testing/testing.go index b0d4c218..0e4974f4 100644 --- a/testing/testing.go +++ b/testing/testing.go @@ -3,6 +3,8 @@ package testing import ( "testing" + "github.com/stretchr/testify/require" + csms "github.com/libp2p/go-conn-security-multistream" "github.com/libp2p/go-libp2p-core/connmgr" "github.com/libp2p/go-libp2p-core/control" @@ -117,7 +119,8 @@ func GenSwarm(t *testing.T, opts ...Option) *swarm.Swarm { if cfg.connectionGater != nil { swarmOpts = append(swarmOpts, swarm.WithConnectionGater(cfg.connectionGater)) } - s := swarm.NewSwarm(p.ID, ps, swarmOpts...) + s, err := swarm.NewSwarm(p.ID, ps, swarmOpts...) + require.NoError(t, err) upgrader := GenUpgrader(s) upgrader.ConnGater = cfg.connectionGater