diff --git a/config/config.go b/config/config.go index adbcf1d122..81cfff293d 100644 --- a/config/config.go +++ b/config/config.go @@ -13,8 +13,8 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/pnet" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/routing" - "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/autonat" "github.com/libp2p/go-libp2p/p2p/host/autorelay" @@ -28,13 +28,12 @@ import ( relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" - logging "github.com/ipfs/go-log/v2" ma "github.com/multiformats/go-multiaddr" madns "github.com/multiformats/go-multiaddr-dns" + "go.uber.org/fx" + "go.uber.org/fx/fxevent" ) -var log = logging.Logger("p2p-config") - // AddrsFactory is a function that takes a set of multiaddrs we're listening on and // returns the set of multiaddrs we should advertise to the network. type AddrsFactory = bhost.AddrsFactory @@ -71,9 +70,9 @@ type Config struct { PeerKey crypto.PrivKey - Transports []TptC - Muxers []MsMuxC - SecurityTransports []MsSecC + Transports []fx.Option + Muxers []Muxer + SecurityTransports []fx.Option Insecure bool PSK pnet.PSK @@ -168,51 +167,65 @@ func (cfg *Config) addTransports(h host.Host) error { // Should probably skip this if no transports. return fmt.Errorf("swarm does not support transports") } - var secure sec.SecureMuxer + + muxers := make([]protocol.ID, 0, len(cfg.Muxers)) + for _, m := range cfg.Muxers { + muxers = append(muxers, m.ID) + } + + var security []fx.Option if cfg.Insecure { - secure = makeInsecureTransport(h.ID(), cfg.PeerKey) + security = append(security, fx.Provide(makeInsecureTransport)) } else { - var err error - secure, err = makeSecurityMuxer(h, cfg.SecurityTransports, cfg.Muxers) - if err != nil { - return err - } - } - muxer, err := makeMuxer(h, cfg.Muxers) - if err != nil { - return err - } - var opts []tptu.Option - if len(cfg.PSK) > 0 { - opts = append(opts, tptu.WithPSK(cfg.PSK)) - } - if cfg.ConnectionGater != nil { - opts = append(opts, tptu.WithConnectionGater(cfg.ConnectionGater)) - } - if cfg.ResourceManager != nil { - opts = append(opts, tptu.WithResourceManager(cfg.ResourceManager)) - } - upgrader, err := tptu.New(secure, muxer, opts...) - if err != nil { - return err + security = cfg.SecurityTransports } - tpts, err := makeTransports(h, upgrader, cfg.ConnectionGater, cfg.PSK, cfg.ResourceManager, cfg.MultiaddrResolver, cfg.Transports) + muxer, err := makeMuxer(cfg.Muxers) if err != nil { return err } - for _, t := range tpts { - if err := swrm.AddTransport(t); err != nil { - return err - } - } + fxopts := []fx.Option{ + fx.WithLogger(func() fxevent.Logger { return getFXLogger() }), + fx.Provide(tptu.New), + fx.Provide(func() network.Multiplexer { return muxer }), + fx.Provide(fx.Annotate( + makeSecurityMuxer, + fx.ParamTags(`group:"security"`), + )), + fx.Supply(muxers), + fx.Provide(func() host.Host { return h }), + fx.Provide(func() crypto.PrivKey { return h.Peerstore().PrivKey(h.ID()) }), + fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }), + fx.Provide(func() pnet.PSK { return cfg.PSK }), + fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }), + fx.Provide(func() *madns.Resolver { return cfg.MultiaddrResolver }), + } + fxopts = append(fxopts, cfg.Transports...) + if !cfg.Insecure { + fxopts = append(fxopts, security...) + } + + fxopts = append(fxopts, fx.Invoke( + fx.Annotate( + func(tpts []transport.Transport) error { + for _, t := range tpts { + if err := swrm.AddTransport(t); err != nil { + return err + } + } + return nil + }, + fx.ParamTags(`group:"transport"`), + )), + ) if cfg.Relay { - if err := circuitv2.AddTransport(h, upgrader); err != nil { - h.Close() - return err - } + fxopts = append(fxopts, fx.Invoke(circuitv2.AddTransport)) + } + app := fx.New(fxopts...) + if err := app.Err(); err != nil { + h.Close() + return err } - return nil } diff --git a/config/constructor_types.go b/config/constructor_types.go deleted file mode 100644 index 9c14df2e2c..0000000000 --- a/config/constructor_types.go +++ /dev/null @@ -1,91 +0,0 @@ -package config - -import ( - "fmt" - "reflect" - - "github.com/libp2p/go-libp2p/core/connmgr" - "github.com/libp2p/go-libp2p/core/crypto" - "github.com/libp2p/go-libp2p/core/host" - "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/core/peerstore" - "github.com/libp2p/go-libp2p/core/pnet" - "github.com/libp2p/go-libp2p/core/protocol" - "github.com/libp2p/go-libp2p/core/sec" - "github.com/libp2p/go-libp2p/core/transport" - - madns "github.com/multiformats/go-multiaddr-dns" -) - -var ( - // interfaces - hostType = reflect.TypeOf((*host.Host)(nil)).Elem() - networkType = reflect.TypeOf((*network.Network)(nil)).Elem() - transportType = reflect.TypeOf((*transport.Transport)(nil)).Elem() - muxType = reflect.TypeOf((*network.Multiplexer)(nil)).Elem() - securityType = reflect.TypeOf((*sec.SecureTransport)(nil)).Elem() - privKeyType = reflect.TypeOf((*crypto.PrivKey)(nil)).Elem() - pubKeyType = reflect.TypeOf((*crypto.PubKey)(nil)).Elem() - pstoreType = reflect.TypeOf((*peerstore.Peerstore)(nil)).Elem() - connGaterType = reflect.TypeOf((*connmgr.ConnectionGater)(nil)).Elem() - upgraderType = reflect.TypeOf((*transport.Upgrader)(nil)).Elem() - rcmgrType = reflect.TypeOf((*network.ResourceManager)(nil)).Elem() - - // concrete types - peerIDType = reflect.TypeOf((peer.ID)("")) - pskType = reflect.TypeOf((pnet.PSK)(nil)) - resolverType = reflect.TypeOf((*madns.Resolver)(nil)) - muxersType = reflect.TypeOf(([]protocol.ID)(nil)) -) - -var argTypes = map[reflect.Type]constructor{ - upgraderType: func(_ host.Host, u transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { - return u - }, - hostType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { - return h - }, - networkType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { - return h.Network() - }, - pskType: func(_ host.Host, _ transport.Upgrader, psk pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { - return psk - }, - connGaterType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, cg connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { - return cg - }, - peerIDType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { - return h.ID() - }, - privKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { - return h.Peerstore().PrivKey(h.ID()) - }, - pubKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { - return h.Peerstore().PubKey(h.ID()) - }, - pstoreType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { - return h.Peerstore() - }, - rcmgrType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, rcmgr network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { - return rcmgr - }, - resolverType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, r *madns.Resolver, _ []protocol.ID) interface{} { - return r - }, - muxersType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, muxers []protocol.ID) interface{} { - return muxers - }, -} - -func newArgTypeSet(types ...reflect.Type) map[reflect.Type]constructor { - result := make(map[reflect.Type]constructor, len(types)) - for _, ty := range types { - c, ok := argTypes[ty] - if !ok { - panic(fmt.Sprintf("missing constructor for type %s", ty)) - } - result[ty] = c - } - return result -} diff --git a/config/log.go b/config/log.go new file mode 100644 index 0000000000..3b74c38c7d --- /dev/null +++ b/config/log.go @@ -0,0 +1,28 @@ +package config + +import ( + "strings" + "sync" + + logging "github.com/ipfs/go-log/v2" + "go.uber.org/fx/fxevent" +) + +var log = logging.Logger("p2p-config") + +var ( + fxLogger fxevent.Logger + logInitOnce sync.Once +) + +type fxLogWriter struct{} + +func (l *fxLogWriter) Write(b []byte) (int, error) { + log.Debug(strings.TrimSuffix(string(b), "\n")) + return len(b), nil +} + +func getFXLogger() fxevent.Logger { + logInitOnce.Do(func() { fxLogger = &fxevent.ConsoleLogger{W: &fxLogWriter{}} }) + return fxLogger +} diff --git a/config/muxer.go b/config/muxer.go index e7b64c1345..448db65e84 100644 --- a/config/muxer.go +++ b/config/muxer.go @@ -3,61 +3,27 @@ package config import ( "fmt" - "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/protocol" msmux "github.com/libp2p/go-libp2p/p2p/muxer/muxer-multistream" ) -// MuxC is a stream multiplex transport constructor. -type MuxC func(h host.Host) (network.Multiplexer, error) - -// MsMuxC is a tuple containing a multiplex transport constructor and a protocol -// ID. -type MsMuxC struct { - MuxC - ID string +type Muxer struct { + ID protocol.ID + Multiplexer network.Multiplexer } -var muxArgTypes = newArgTypeSet(hostType, networkType, peerIDType, pstoreType) - -// MuxerConstructor creates a multiplex constructor from the passed parameter -// using reflection. -func MuxerConstructor(m interface{}) (MuxC, error) { - // Already constructed? - if t, ok := m.(network.Multiplexer); ok { - return func(_ host.Host) (network.Multiplexer, error) { - return t, nil - }, nil - } - - ctor, err := makeConstructor(m, muxType, muxArgTypes) - if err != nil { - return nil, err - } - return func(h host.Host) (network.Multiplexer, error) { - t, err := ctor(h, nil, nil, nil, nil, nil, nil) - if err != nil { - return nil, err - } - return t.(network.Multiplexer), nil - }, nil -} - -func makeMuxer(h host.Host, tpts []MsMuxC) (network.Multiplexer, error) { +func makeMuxer(muxers []Muxer) (network.Multiplexer, error) { muxMuxer := msmux.NewBlankTransport() - transportSet := make(map[string]struct{}, len(tpts)) - for _, tptC := range tpts { - if _, ok := transportSet[tptC.ID]; ok { - return nil, fmt.Errorf("duplicate muxer transport: %s", tptC.ID) + transportSet := make(map[protocol.ID]struct{}, len(muxers)) + for _, m := range muxers { + if _, ok := transportSet[m.ID]; ok { + return nil, fmt.Errorf("duplicate muxer transport: %s", m.ID) } - transportSet[tptC.ID] = struct{}{} + transportSet[m.ID] = struct{}{} } - for _, tptC := range tpts { - tpt, err := tptC.MuxC(h) - if err != nil { - return nil, err - } - muxMuxer.AddTransport(tptC.ID, tpt) + for _, m := range muxers { + muxMuxer.AddTransport(string(m.ID), m.Multiplexer) } return muxMuxer, nil } diff --git a/config/muxer_test.go b/config/muxer_test.go deleted file mode 100644 index 772b43294c..0000000000 --- a/config/muxer_test.go +++ /dev/null @@ -1,102 +0,0 @@ -package config - -import ( - "testing" - - "github.com/libp2p/go-libp2p/core/host" - "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/peer" - bhost "github.com/libp2p/go-libp2p/p2p/host/basic" - "github.com/libp2p/go-libp2p/p2p/muxer/yamux" - swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" -) - -func TestMuxerSimple(t *testing.T) { - // single - _, err := MuxerConstructor(func(_ peer.ID) network.Multiplexer { return nil }) - if err != nil { - t.Fatal(err) - } -} - -func TestMuxerByValue(t *testing.T) { - _, err := MuxerConstructor(yamux.DefaultTransport) - if err != nil { - t.Fatal(err) - } -} -func TestMuxerDuplicate(t *testing.T) { - _, err := MuxerConstructor(func(_ peer.ID, _ peer.ID) network.Multiplexer { return nil }) - if err != nil { - t.Fatal(err) - } -} - -func TestMuxerError(t *testing.T) { - _, err := MuxerConstructor(func() (network.Multiplexer, error) { return nil, nil }) - if err != nil { - t.Fatal(err) - } -} - -func TestMuxerBadTypes(t *testing.T) { - for i, f := range []interface{}{ - func() error { return nil }, - func() string { return "" }, - func() {}, - func(string) network.Multiplexer { return nil }, - func(string) (network.Multiplexer, error) { return nil, nil }, - nil, - "testing", - } { - - if _, err := MuxerConstructor(f); err == nil { - t.Fatalf("constructor %d with type %T should have failed", i, f) - } - } -} - -func TestCatchDuplicateTransportsMuxer(t *testing.T) { - h, err := bhost.NewHost(swarmt.GenSwarm(t), nil) - if err != nil { - t.Fatal(err) - } - yamuxMuxer, err := MuxerConstructor(yamux.DefaultTransport) - if err != nil { - t.Fatal(err) - } - - var tests = map[string]struct { - h host.Host - transports []MsMuxC - expectedError string - }{ - "no duplicate transports": { - h: h, - transports: []MsMuxC{{yamuxMuxer, "yamux"}}, - expectedError: "", - }, - "duplicate transports": { - h: h, - transports: []MsMuxC{ - {yamuxMuxer, "yamux"}, - {yamuxMuxer, "yamux"}, - }, - expectedError: "duplicate muxer transport: yamux", - }, - } - for testName, test := range tests { - t.Run(testName, func(t *testing.T) { - _, err = makeMuxer(test.h, test.transports) - if err != nil { - if err.Error() != test.expectedError { - t.Errorf( - "\nexpected: [%v]\nactual: [%v]\n", - test.expectedError, - err, - ) - } - } - }) - } -} diff --git a/config/reflection_magic.go b/config/reflection_magic.go deleted file mode 100644 index 0189872abb..0000000000 --- a/config/reflection_magic.go +++ /dev/null @@ -1,176 +0,0 @@ -package config - -import ( - "errors" - "fmt" - "reflect" - "runtime" - - "github.com/libp2p/go-libp2p/core/connmgr" - "github.com/libp2p/go-libp2p/core/host" - "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/pnet" - "github.com/libp2p/go-libp2p/core/protocol" - "github.com/libp2p/go-libp2p/core/transport" - - madns "github.com/multiformats/go-multiaddr-dns" -) - -var errorType = reflect.TypeOf((*error)(nil)).Elem() - -// checks if a function returns either the specified type or the specified type -// and an error. -func checkReturnType(fnType, tptType reflect.Type) error { - switch fnType.NumOut() { - case 2: - if fnType.Out(1) != errorType { - return fmt.Errorf("expected (optional) second return value from transport constructor to be an error") - } - - fallthrough - case 1: - if !fnType.Out(0).Implements(tptType) { - return fmt.Errorf("transport constructor returns %s which doesn't implement %s", fnType.Out(0), tptType) - } - default: - return fmt.Errorf("expected transport constructor to return a transport and, optionally, an error") - } - return nil -} - -// Handles return values with optional errors. That is, return values of the -// form `(something, error)` or just `something`. -// -// Panics if the return value isn't of the correct form. -func handleReturnValue(out []reflect.Value) (interface{}, error) { - switch len(out) { - case 2: - err := out[1] - if err != (reflect.Value{}) && !err.IsNil() { - return nil, err.Interface().(error) - } - fallthrough - case 1: - tpt := out[0] - - // Check for nil value and nil error. - if tpt == (reflect.Value{}) { - return nil, fmt.Errorf("unspecified error") - } - switch tpt.Kind() { - case reflect.Ptr, reflect.Interface, reflect.Func: - if tpt.IsNil() { - return nil, fmt.Errorf("unspecified error") - } - } - - return tpt.Interface(), nil - default: - panic("expected 1 or 2 return values from transport constructor") - } -} - -// calls the transport constructor and annotates the error with the name of the constructor. -func callConstructor(c reflect.Value, args []reflect.Value) (interface{}, error) { - val, err := handleReturnValue(c.Call(args)) - if err != nil { - name := runtime.FuncForPC(c.Pointer()).Name() - if name != "" { - // makes debugging easier - return nil, fmt.Errorf("transport constructor %s failed: %s", name, err) - } - } - return val, err -} - -type constructor func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater, network.ResourceManager, *madns.Resolver, []protocol.ID) interface{} - -func makeArgumentConstructors(fnType reflect.Type, argTypes map[reflect.Type]constructor) ([]constructor, error) { - params := fnType.NumIn() - if fnType.IsVariadic() { - params-- - } - out := make([]constructor, params) - for i := range out { - argType := fnType.In(i) - c, ok := argTypes[argType] - if !ok { - return nil, fmt.Errorf("argument %d has an unexpected type %s", i, argType.Name()) - } - out[i] = c - } - return out, nil -} - -func getConstructorOpts(t reflect.Type, opts ...interface{}) ([]reflect.Value, error) { - if !t.IsVariadic() { - if len(opts) > 0 { - return nil, errors.New("constructor doesn't accept any options") - } - return nil, nil - } - if len(opts) == 0 { - return nil, nil - } - // variadic parameters always go last - wantType := t.In(t.NumIn() - 1).Elem() - values := make([]reflect.Value, 0, len(opts)) - for _, opt := range opts { - val := reflect.ValueOf(opt) - if opt == nil { - return nil, errors.New("expected a transport option, got nil") - } - if val.Type() != wantType { - return nil, fmt.Errorf("expected option of type %s, got %s", wantType, reflect.TypeOf(opt)) - } - values = append(values, val.Convert(wantType)) - } - return values, nil -} - -// makes a transport constructor. -func makeConstructor( - tpt interface{}, - tptType reflect.Type, - argTypes map[reflect.Type]constructor, - opts ...interface{}, -) (func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater, network.ResourceManager, *madns.Resolver, []protocol.ID) (interface{}, error), error) { - v := reflect.ValueOf(tpt) - // avoid panicing on nil/zero value. - if v == (reflect.Value{}) { - return nil, fmt.Errorf("expected a transport or transport constructor, got a %T", tpt) - } - t := v.Type() - if t.Kind() != reflect.Func { - return nil, fmt.Errorf("expected a transport or transport constructor, got a %T", tpt) - } - - if err := checkReturnType(t, tptType); err != nil { - return nil, err - } - - argConstructors, err := makeArgumentConstructors(t, argTypes) - if err != nil { - return nil, err - } - optValues, err := getConstructorOpts(t, opts...) - if err != nil { - return nil, err - } - - return func(h host.Host, u transport.Upgrader, psk pnet.PSK, cg connmgr.ConnectionGater, rcmgr network.ResourceManager, resolver *madns.Resolver, muxers []protocol.ID) (interface{}, error) { - arguments := make([]reflect.Value, 0, len(argConstructors)+len(opts)) - for i, makeArg := range argConstructors { - if arg := makeArg(h, u, psk, cg, rcmgr, resolver, muxers); arg != nil { - arguments = append(arguments, reflect.ValueOf(arg)) - } else { - // ValueOf an un-typed nil yields a zero reflect - // value. However, we _want_ the zero value of - // the _type_. - arguments = append(arguments, reflect.Zero(t.In(i))) - } - } - arguments = append(arguments, optValues...) - return callConstructor(v, arguments) - }, nil -} diff --git a/config/reflection_magic_test.go b/config/reflection_magic_test.go deleted file mode 100644 index ce9f3986f3..0000000000 --- a/config/reflection_magic_test.go +++ /dev/null @@ -1,142 +0,0 @@ -package config - -import ( - "errors" - "reflect" - "strings" - "testing" -) - -func TestHandleReturnValue(t *testing.T) { - // one value - v, err := handleReturnValue([]reflect.Value{reflect.ValueOf(1)}) - if v.(int) != 1 { - t.Fatal("expected value") - } - if err != nil { - t.Fatal(err) - } - - // Nil value - v, err = handleReturnValue([]reflect.Value{reflect.ValueOf(nil)}) - if v != nil { - t.Fatal("expected no value") - } - if err == nil { - t.Fatal("expected an error") - } - - // Nil value, nil err - v, err = handleReturnValue([]reflect.Value{reflect.ValueOf(nil), reflect.ValueOf(nil)}) - if v != nil { - t.Fatal("expected no value") - } - if err == nil { - t.Fatal("expected an error") - } - - // two values - v, err = handleReturnValue([]reflect.Value{reflect.ValueOf(1), reflect.ValueOf(nil)}) - if v, ok := v.(int); !ok || v != 1 { - t.Fatalf("expected value of 1, got %v", v) - } - if err != nil { - t.Fatal("expected no error") - } - - // an error - myError := errors.New("my error") - _, err = handleReturnValue([]reflect.Value{reflect.ValueOf(1), reflect.ValueOf(myError)}) - if err != myError { - t.Fatal(err) - } - - for _, vals := range [][]reflect.Value{ - {reflect.ValueOf(1), reflect.ValueOf("not an error")}, - {}, - {reflect.ValueOf(1), reflect.ValueOf(myError), reflect.ValueOf(myError)}, - } { - func() { - defer func() { recover() }() - handleReturnValue(vals) - t.Fatal("expected a panic") - }() - } -} - -type foo interface { - foo() foo -} - -var fooType = reflect.TypeOf((*foo)(nil)).Elem() - -func TestCheckReturnType(t *testing.T) { - for i, fn := range []interface{}{ - func() { panic("") }, - func() error { panic("") }, - func() (error, error) { panic("") }, - func() (foo, error, error) { panic("") }, - func() (foo, foo) { panic("") }, - } { - if checkReturnType(reflect.TypeOf(fn), fooType) == nil { - t.Errorf("expected falure for case %d (type %T)", i, fn) - } - } - - for i, fn := range []interface{}{ - func() foo { panic("") }, - func() (foo, error) { panic("") }, - } { - if err := checkReturnType(reflect.TypeOf(fn), fooType); err != nil { - t.Errorf("expected success for case %d (type %T), got: %s", i, fn, err) - } - } -} - -func constructFoo() foo { - return nil -} - -type fooImpl struct{} - -func (f *fooImpl) foo() foo { return nil } - -func TestCallConstructor(t *testing.T) { - _, err := callConstructor(reflect.ValueOf(constructFoo), nil) - if err == nil { - t.Fatal("expected constructor to fail") - } - - if !strings.Contains(err.Error(), "constructFoo") { - t.Errorf("expected error to contain the constructor name: %s", err) - } - - v, err := callConstructor(reflect.ValueOf(func() foo { return &fooImpl{} }), nil) - if err != nil { - t.Fatal(err) - } - if _, ok := v.(*fooImpl); !ok { - t.Fatal("expected a fooImpl") - } - - v, err = callConstructor(reflect.ValueOf(func() *fooImpl { return new(fooImpl) }), nil) - if err != nil { - t.Fatal(err) - } - if _, ok := v.(*fooImpl); !ok { - t.Fatal("expected a fooImpl") - } - - _, err = callConstructor(reflect.ValueOf(func() (*fooImpl, error) { return nil, nil }), nil) - if err == nil { - t.Fatal("expected error") - } - - v, err = callConstructor(reflect.ValueOf(func() (*fooImpl, error) { return new(fooImpl), nil }), nil) - if err != nil { - t.Fatal(err) - } - if _, ok := v.(*fooImpl); !ok { - t.Fatal("expected a fooImpl") - } -} diff --git a/config/security.go b/config/security.go index 6f15b63462..1a88575f5d 100644 --- a/config/security.go +++ b/config/security.go @@ -1,84 +1,23 @@ package config import ( - "fmt" - "github.com/libp2p/go-libp2p/core/crypto" - "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/sec/insecure" csms "github.com/libp2p/go-libp2p/p2p/net/conn-security-multistream" ) -// SecC is a security transport constructor. -type SecC func(h host.Host, muxers []protocol.ID) (sec.SecureTransport, error) - -// MsSecC is a tuple containing a security transport constructor and a protocol -// ID. -type MsSecC struct { - SecC - ID string -} - -var securityArgTypes = newArgTypeSet( - hostType, networkType, peerIDType, - privKeyType, pubKeyType, pstoreType, - muxersType, -) - -// SecurityConstructor creates a security constructor from the passed parameter -// using reflection. -func SecurityConstructor(security interface{}) (SecC, error) { - // Already constructed? - if t, ok := security.(sec.SecureTransport); ok { - return func(_ host.Host, _ []protocol.ID) (sec.SecureTransport, error) { - return t, nil - }, nil - } - - ctor, err := makeConstructor(security, securityType, securityArgTypes) - if err != nil { - return nil, err - } - return func(h host.Host, muxers []protocol.ID) (sec.SecureTransport, error) { - t, err := ctor(h, nil, nil, nil, nil, nil, muxers) - if err != nil { - return nil, err - } - return t.(sec.SecureTransport), nil - }, nil -} - func makeInsecureTransport(id peer.ID, privKey crypto.PrivKey) sec.SecureMuxer { secMuxer := new(csms.SSMuxer) - secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(id, privKey)) + secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(insecure.ID, id, privKey)) return secMuxer } -func makeSecurityMuxer(h host.Host, tpts []MsSecC, muxers []MsMuxC) (sec.SecureMuxer, error) { +func makeSecurityMuxer(tpts []sec.SecureTransport) sec.SecureMuxer { secMuxer := new(csms.SSMuxer) - transportSet := make(map[string]struct{}, len(tpts)) - for _, tptC := range tpts { - if _, ok := transportSet[tptC.ID]; ok { - return nil, fmt.Errorf("duplicate security transport: %s", tptC.ID) - } - transportSet[tptC.ID] = struct{}{} + for _, tpt := range tpts { + secMuxer.AddTransport(string(tpt.ID()), tpt) } - muxIds := make([]protocol.ID, 0, len(muxers)) - for _, muxc := range muxers { - muxIds = append(muxIds, protocol.ID(muxc.ID)) - } - for _, tptC := range tpts { - tpt, err := tptC.SecC(h, muxIds) - if err != nil { - return nil, err - } - if _, ok := tpt.(*insecure.Transport); ok { - return nil, fmt.Errorf("cannot construct libp2p with an insecure transport, set the Insecure config option instead") - } - secMuxer.AddTransport(tptC.ID, tpt) - } - return secMuxer, nil + return secMuxer } diff --git a/config/transport.go b/config/transport.go deleted file mode 100644 index 6105e77f13..0000000000 --- a/config/transport.go +++ /dev/null @@ -1,71 +0,0 @@ -package config - -import ( - "github.com/libp2p/go-libp2p/core/connmgr" - "github.com/libp2p/go-libp2p/core/host" - "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/pnet" - "github.com/libp2p/go-libp2p/core/transport" - - madns "github.com/multiformats/go-multiaddr-dns" -) - -// TptC is the type for libp2p transport constructors. You probably won't ever -// implement this function interface directly. Instead, pass your transport -// constructor to TransportConstructor. -type TptC func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater, network.ResourceManager, *madns.Resolver) (transport.Transport, error) - -var transportArgTypes = argTypes - -// TransportConstructor uses reflection to turn a function that constructs a -// transport into a TptC. -// -// You can pass either a constructed transport (something that implements -// `transport.Transport`) or a function that takes any of: -// -// * The local peer ID. -// * A transport connection upgrader. -// * A private key. -// * A public key. -// * A Host. -// * A Network. -// * A Peerstore. -// * An address filter. -// * A security transport. -// * A stream multiplexer transport. -// * A private network protection key. -// * A connection gater. -// -// And returns a type implementing transport.Transport and, optionally, an error -// (as the second argument). -func TransportConstructor(tpt interface{}, opts ...interface{}) (TptC, error) { - // Already constructed? - if t, ok := tpt.(transport.Transport); ok { - return func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) (transport.Transport, error) { - return t, nil - }, nil - } - ctor, err := makeConstructor(tpt, transportType, transportArgTypes, opts...) - if err != nil { - return nil, err - } - return func(h host.Host, u transport.Upgrader, psk pnet.PSK, cg connmgr.ConnectionGater, rcmgr network.ResourceManager, resolver *madns.Resolver) (transport.Transport, error) { - t, err := ctor(h, u, psk, cg, rcmgr, resolver, nil) - if err != nil { - return nil, err - } - return t.(transport.Transport), nil - }, nil -} - -func makeTransports(h host.Host, u transport.Upgrader, cg connmgr.ConnectionGater, psk pnet.PSK, rcmgr network.ResourceManager, resolver *madns.Resolver, tpts []TptC) ([]transport.Transport, error) { - transports := make([]transport.Transport, len(tpts)) - for i, tC := range tpts { - t, err := tC(h, u, psk, cg, rcmgr, resolver) - if err != nil { - return nil, err - } - transports[i] = t - } - return transports, nil -} diff --git a/config/transport_test.go b/config/transport_test.go deleted file mode 100644 index 5684521e98..0000000000 --- a/config/transport_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package config - -import ( - "testing" - - "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/core/transport" - "github.com/libp2p/go-libp2p/p2p/transport/tcp" - - "github.com/stretchr/testify/require" -) - -func TestTransportVariadicOptions(t *testing.T) { - _, err := TransportConstructor(func(_ peer.ID, _ ...int) transport.Transport { return nil }) - require.NoError(t, err) -} - -func TestConstructorWithoutOptsCalledWithOpts(t *testing.T) { - _, err := TransportConstructor(func(_ transport.Upgrader) transport.Transport { - return nil - }, 42) - require.EqualError(t, err, "constructor doesn't accept any options") -} - -func TestConstructorWithOptsTypeMismatch(t *testing.T) { - _, err := TransportConstructor(func(_ transport.Upgrader, opts ...int) transport.Transport { - return nil - }, 42, "foo") - require.EqualError(t, err, "expected option of type int, got string") -} - -func TestConstructorWithOpts(t *testing.T) { - var options []int - c, err := TransportConstructor(func(_ transport.Upgrader, opts ...int) (transport.Transport, error) { - options = opts - return tcp.NewTCPTransport(nil, nil) - }, 42, 1337) - require.NoError(t, err) - _, err = c(nil, nil, nil, nil, nil, nil) - require.NoError(t, err) - require.Equal(t, []int{42, 1337}, options) -} diff --git a/core/sec/insecure/insecure.go b/core/sec/insecure/insecure.go index 2d94f43804..d2487a3b0a 100644 --- a/core/sec/insecure/insecure.go +++ b/core/sec/insecure/insecure.go @@ -12,6 +12,7 @@ import ( ci "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" pb "github.com/libp2p/go-libp2p/core/sec/insecure/pb" @@ -28,18 +29,22 @@ const ID = "/plaintext/2.0.0" // peer presents as their ID and public key. // No authentication of the remote identity is performed. type Transport struct { - id peer.ID - key ci.PrivKey + id peer.ID + key ci.PrivKey + protocolID protocol.ID } +var _ sec.SecureTransport = &Transport{} + // NewWithIdentity constructs a new insecure transport. The provided private key // is stored and returned from LocalPrivateKey to satisfy the // SecureTransport interface, and the public key is sent to // remote peers. No security is provided. -func NewWithIdentity(id peer.ID, key ci.PrivKey) *Transport { +func NewWithIdentity(protocolID protocol.ID, id peer.ID, key ci.PrivKey) *Transport { return &Transport{ - id: id, - key: key, + protocolID: protocolID, + id: id, + key: key, } } @@ -108,6 +113,10 @@ func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p pee return conn, nil } +func (t *Transport) ID() protocol.ID { + return t.protocolID +} + // Conn is the connection type returned by the insecure transport. type Conn struct { net.Conn diff --git a/core/sec/insecure/insecure_test.go b/core/sec/insecure/insecure_test.go index a3ce8314f4..8663f1d975 100644 --- a/core/sec/insecure/insecure_test.go +++ b/core/sec/insecure/insecure_test.go @@ -61,7 +61,7 @@ func newTestTransport(t *testing.T, typ, bits int) *Transport { require.NoError(t, err) id, err := peer.IDFromPublicKey(pub) require.NoError(t, err) - return NewWithIdentity(id, priv) + return NewWithIdentity("/test/1.0.0", id, priv) } // Create a new pair of connected TCP sockets. diff --git a/core/sec/security.go b/core/sec/security.go index c192a56a91..8b733b5d06 100644 --- a/core/sec/security.go +++ b/core/sec/security.go @@ -7,6 +7,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" ) // SecureConn is an authenticated, encrypted connection. @@ -24,6 +25,9 @@ type SecureTransport interface { // SecureOutbound secures an outbound connection. SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error) + + // ID is the protocol ID of the security protocol. + ID() protocol.ID } // A SecureMuxer is a wrapper around SecureTransport which can select security protocols diff --git a/error_util.go b/error_util.go deleted file mode 100644 index 86827f4eac..0000000000 --- a/error_util.go +++ /dev/null @@ -1,17 +0,0 @@ -package libp2p - -import ( - "fmt" - "runtime" -) - -func traceError(err error, skip int) error { - if err == nil { - return nil - } - _, file, line, ok := runtime.Caller(skip + 1) - if !ok { - return err - } - return fmt.Errorf("%s:%d: %s", file, line, err) -} diff --git a/go.mod b/go.mod index 3c6128f8d9..371298868d 100644 --- a/go.mod +++ b/go.mod @@ -51,6 +51,7 @@ require ( github.com/raulk/go-watchdog v1.3.0 github.com/stretchr/testify v1.8.0 go.opencensus.io v0.23.0 + go.uber.org/fx v1.18.2 go.uber.org/goleak v1.1.12 golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 @@ -105,6 +106,7 @@ require ( github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/syndtr/goleveldb v1.0.0 // indirect go.uber.org/atomic v1.10.0 // indirect + go.uber.org/dig v1.15.0 // indirect go.uber.org/multierr v1.8.0 // indirect go.uber.org/zap v1.23.0 // indirect golang.org/x/exp v0.0.0-20220916125017-b168a2c6b86b // indirect diff --git a/go.sum b/go.sum index c9161da572..268c140d05 100644 --- a/go.sum +++ b/go.sum @@ -532,6 +532,10 @@ go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/dig v1.15.0 h1:vq3YWr8zRj1eFGC7Gvf907hE0eRjPTZ1d3xHadD6liE= +go.uber.org/dig v1.15.0/go.mod h1:pKHs0wMynzL6brANhB2hLMro+zalv1osARTviTcqHLM= +go.uber.org/fx v1.18.2 h1:bUNI6oShr+OVFQeU8cDNbnN7VFsu+SsjHzUF51V/GAU= +go.uber.org/fx v1.18.2/go.mod h1:g0V1KMQ66zIRk8bLu3Ea5Jt2w/cHlOIp4wdRsgh0JaY= go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= diff --git a/libp2p_test.go b/libp2p_test.go index 094fdefbc4..9a469bd1a5 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -12,8 +12,13 @@ import ( "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/net/swarm" + "github.com/libp2p/go-libp2p/p2p/security/noise" + tls "github.com/libp2p/go-libp2p/p2p/security/tls" + quic "github.com/libp2p/go-libp2p/p2p/transport/quic" "github.com/libp2p/go-libp2p/p2p/transport/tcp" + ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" ) @@ -165,3 +170,122 @@ func TestChainOptions(t *testing.T) { } } } + +func TestTransportConstructorTCP(t *testing.T) { + h, err := New( + Transport(tcp.NewTCPTransport), + DisableRelay(), + ) + require.NoError(t, err) + defer h.Close() + require.NoError(t, h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))) + err = h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic")) + require.Error(t, err) + require.Contains(t, err.Error(), swarm.ErrNoTransport.Error()) +} + +func TestTransportConstructorQUIC(t *testing.T) { + h, err := New( + Transport(quic.NewTransport, quic.DisableReuseport()), + DisableRelay(), + ) + require.NoError(t, err) + defer h.Close() + require.NoError(t, h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic"))) + err = h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) + require.Error(t, err) + require.Contains(t, err.Error(), swarm.ErrNoTransport.Error()) +} + +type mockTransport struct{} + +func (m mockTransport) Dial(context.Context, ma.Multiaddr, peer.ID) (transport.CapableConn, error) { + panic("implement me") +} + +func (m mockTransport) CanDial(ma.Multiaddr) bool { panic("implement me") } +func (m mockTransport) Listen(ma.Multiaddr) (transport.Listener, error) { panic("implement me") } +func (m mockTransport) Protocols() []int { return []int{1337} } +func (m mockTransport) Proxy() bool { panic("implement me") } + +var _ transport.Transport = &mockTransport{} + +func TestTransportConstructorWithoutOpts(t *testing.T) { + t.Run("successful", func(t *testing.T) { + var called bool + constructor := func() transport.Transport { + called = true + return &mockTransport{} + } + + h, err := New( + Transport(constructor), + DisableRelay(), + ) + require.NoError(t, err) + require.True(t, called, "expected constructor to be called") + defer h.Close() + }) + + t.Run("with options", func(t *testing.T) { + var called bool + constructor := func() transport.Transport { + called = true + return &mockTransport{} + } + + _, err := New( + Transport(constructor, tcp.DisableReuseport()), + DisableRelay(), + ) + require.EqualError(t, err, "transport constructor doesn't take any options") + require.False(t, called, "didn't expected constructor to be called") + }) +} + +func TestTransportConstructorWithWrongOpts(t *testing.T) { + _, err := New( + Transport(quic.NewTransport, tcp.DisableReuseport()), + DisableRelay(), + ) + require.EqualError(t, err, "transport option of type tcp.Option not assignable to libp2pquic.Option") +} + +func TestSecurityConstructor(t *testing.T) { + h, err := New( + Transport(tcp.NewTCPTransport), + Security("/noisy", noise.New), + Security("/tls", tls.New), + DefaultListenAddrs, + DisableRelay(), + ) + require.NoError(t, err) + defer h.Close() + + h1, err := New( + NoListenAddrs, + Transport(tcp.NewTCPTransport), + Security("/noise", noise.New), // different name + DisableRelay(), + ) + require.NoError(t, err) + defer h1.Close() + + h2, err := New( + NoListenAddrs, + Transport(tcp.NewTCPTransport), + Security("/noisy", noise.New), + DisableRelay(), + ) + require.NoError(t, err) + defer h2.Close() + + ai := peer.AddrInfo{ + ID: h.ID(), + Addrs: h.Addrs(), + } + err = h1.Connect(context.Background(), ai) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to negotiate security protocol") + require.NoError(t, h2.Connect(context.Background(), ai)) +} diff --git a/options.go b/options.go index 5d54fe751d..440427aecb 100644 --- a/options.go +++ b/options.go @@ -4,8 +4,11 @@ package libp2p // those are in defaults.go). import ( + "crypto/rand" + "encoding/binary" "errors" "fmt" + "reflect" "time" "github.com/libp2p/go-libp2p/config" @@ -16,6 +19,9 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/pnet" + "github.com/libp2p/go-libp2p/core/protocol" + "github.com/libp2p/go-libp2p/core/sec" + "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/autorelay" bhost "github.com/libp2p/go-libp2p/p2p/host/basic" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" @@ -23,6 +29,7 @@ import ( ma "github.com/multiformats/go-multiaddr" madns "github.com/multiformats/go-multiaddr-dns" + "go.uber.org/fx" ) // ListenAddrStrings configures libp2p to listen on the given (unparsed) @@ -61,17 +68,27 @@ func ListenAddrs(addrs ...ma.Multiaddr) Option { // * Host // * Network // * Peerstore -func Security(name string, tpt interface{}) Option { - stpt, err := config.SecurityConstructor(tpt) - err = traceError(err, 1) +func Security(name string, constructor interface{}) Option { return func(cfg *Config) error { - if err != nil { - return err - } if cfg.Insecure { return fmt.Errorf("cannot use security transports with an insecure libp2p configuration") } - cfg.SecurityTransports = append(cfg.SecurityTransports, config.MsSecC{SecC: stpt, ID: name}) + fxName := fmt.Sprintf(`name:"%s"`, name) + // provide the name of the security transport + cfg.SecurityTransports = append(cfg.SecurityTransports, + fx.Provide(fx.Annotate( + func() protocol.ID { return protocol.ID(name) }, + fx.ResultTags(fxName), + )), + ) + cfg.SecurityTransports = append(cfg.SecurityTransports, + fx.Provide(fx.Annotate( + constructor, + fx.ParamTags(fxName), + fx.As(new(sec.SecureTransport)), + fx.ResultTags(`group:"security"`), + )), + ) return nil } } @@ -86,25 +103,11 @@ var NoSecurity Option = func(cfg *Config) error { return nil } -// Muxer configures libp2p to use the given stream multiplexer (or stream -// multiplexer constructor). -// -// Name is the protocol name. -// -// The transport can be a constructed mux.Transport or a function taking any -// subset of this libp2p node's: -// * Peer ID -// * Host -// * Network -// * Peerstore -func Muxer(name string, tpt interface{}) Option { - mtpt, err := config.MuxerConstructor(tpt) - err = traceError(err, 1) +// Muxer configures libp2p to use the given stream multiplexer. +// name is the protocol name. +func Muxer(name string, muxer network.Multiplexer) Option { return func(cfg *Config) error { - if err != nil { - return err - } - cfg.Muxers = append(cfg.Muxers, config.MsMuxC{MuxC: mtpt, ID: name}) + cfg.Muxers = append(cfg.Muxers, config.Muxer{Multiplexer: muxer, ID: protocol.ID(name)}) return nil } } @@ -124,14 +127,55 @@ func Muxer(name string, tpt interface{}) Option { // * Public Key // * Address filter (filter.Filter) // * Peerstore -func Transport(tpt interface{}, opts ...interface{}) Option { - tptc, err := config.TransportConstructor(tpt, opts...) - err = traceError(err, 1) +func Transport(constructor interface{}, opts ...interface{}) Option { return func(cfg *Config) error { - if err != nil { - return err + // generate a random identifier, so that fx can associate the constructor with its options + b := make([]byte, 8) + rand.Read(b) + id := binary.BigEndian.Uint64(b) + + tag := fmt.Sprintf(`group:"transportopt_%d"`, id) + + typ := reflect.ValueOf(constructor).Type() + numParams := typ.NumIn() + isVariadic := typ.IsVariadic() + + if !isVariadic && len(opts) > 0 { + return errors.New("transport constructor doesn't take any options") + } + if isVariadic && numParams >= 1 { + paramType := typ.In(numParams - 1).Elem() + for _, opt := range opts { + if typ := reflect.TypeOf(opt); !typ.AssignableTo(paramType) { + return fmt.Errorf("transport option of type %s not assignable to %s", typ, paramType) + } + } + } + + var params []string + if isVariadic && len(opts) > 0 { + // If there are transport options, apply the tag. + // Since options are variadic, they have to be the last argument of the constructor. + params = make([]string, numParams) + params[len(params)-1] = tag + } + + cfg.Transports = append(cfg.Transports, fx.Provide( + fx.Annotate( + constructor, + fx.ParamTags(params...), + fx.As(new(transport.Transport)), + fx.ResultTags(`group:"transport"`), + ), + )) + for _, opt := range opts { + cfg.Transports = append(cfg.Transports, fx.Supply( + fx.Annotate( + opt, + fx.ResultTags(tag), + ), + )) } - cfg.Transports = append(cfg.Transports, tptc) return nil } } @@ -412,7 +456,7 @@ var NoListenAddrs = func(cfg *Config) error { // This will both clear any configured transports (specified in prior libp2p // options) and prevent libp2p from applying the default transports. var NoTransports = func(cfg *Config) error { - cfg.Transports = []config.TptC{} + cfg.Transports = []fx.Option{} return nil } diff --git a/p2p/net/conn-security-multistream/ssms_test.go b/p2p/net/conn-security-multistream/ssms_test.go index 5aa5db352d..3ccf4a7f26 100644 --- a/p2p/net/conn-security-multistream/ssms_test.go +++ b/p2p/net/conn-security-multistream/ssms_test.go @@ -44,8 +44,8 @@ func TestCommonProto(t *testing.T) { var at, bt SSMuxer - atInsecure := insecure.NewWithIdentity(idA, privA) - btInsecure := insecure.NewWithIdentity(idB, privB) + atInsecure := insecure.NewWithIdentity(insecure.ID, idA, privA) + btInsecure := insecure.NewWithIdentity(insecure.ID, idB, privB) at.AddTransport("/plaintext/1.0.0", atInsecure) bt.AddTransport("/plaintext/1.1.0", btInsecure) bt.AddTransport("/plaintext/1.0.0", btInsecure) @@ -88,8 +88,8 @@ func TestNoCommonProto(t *testing.T) { privB, idB := newPeer(t) var at, bt SSMuxer - atInsecure := insecure.NewWithIdentity(idA, privA) - btInsecure := insecure.NewWithIdentity(idB, privB) + atInsecure := insecure.NewWithIdentity(insecure.ID, idA, privA) + btInsecure := insecure.NewWithIdentity(insecure.ID, idB, privB) at.AddTransport("/plaintext/1.0.0", atInsecure) bt.AddTransport("/plaintext/1.1.0", btInsecure) diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index bf7a5ca3ad..728c81da32 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -77,11 +77,11 @@ func makeUpgrader(t *testing.T, n *Swarm) transport.Upgrader { id := n.LocalPeer() pk := n.Peerstore().PrivKey(id) secMuxer := new(csms.SSMuxer) - secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(id, pk)) + secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(insecure.ID, id, pk)) stMuxer := msmux.NewBlankTransport() stMuxer.AddTransport("/yamux/1.0.0", yamux.DefaultTransport) - u, err := tptu.New(secMuxer, stMuxer) + u, err := tptu.New(secMuxer, stMuxer, nil, nil, nil) require.NoError(t, err) return u } diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index 604d1d0a47..a28b488efb 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -99,15 +99,15 @@ func OptPeerPrivateKey(sk crypto.PrivKey) Option { } // GenUpgrader creates a new connection upgrader for use with this swarm. -func GenUpgrader(t *testing.T, n *swarm.Swarm, opts ...tptu.Option) transport.Upgrader { +func GenUpgrader(t *testing.T, n *swarm.Swarm, connGater connmgr.ConnectionGater, opts ...tptu.Option) transport.Upgrader { id := n.LocalPeer() pk := n.Peerstore().PrivKey(id) secMuxer := new(csms.SSMuxer) - secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(id, pk)) + secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(insecure.ID, id, pk)) stMuxer := msmux.NewBlankTransport() stMuxer.AddTransport("/yamux/1.0.0", yamux.DefaultTransport) - u, err := tptu.New(secMuxer, stMuxer, opts...) + u, err := tptu.New(secMuxer, stMuxer, nil, nil, connGater, opts...) require.NoError(t, err) return u } @@ -145,7 +145,7 @@ func GenSwarm(t *testing.T, opts ...Option) *swarm.Swarm { s, err := swarm.NewSwarm(id, ps, swarmOpts...) require.NoError(t, err) - upgrader := GenUpgrader(t, s, tptu.WithConnectionGater(cfg.connectionGater)) + upgrader := GenUpgrader(t, s, cfg.connectionGater) if !cfg.disableTCP { var tcpOpts []tcp.Option diff --git a/p2p/net/swarm/testing/testing_test.go b/p2p/net/swarm/testing/testing_test.go index ef62570224..d4a43dfb59 100644 --- a/p2p/net/swarm/testing/testing_test.go +++ b/p2p/net/swarm/testing/testing_test.go @@ -9,5 +9,5 @@ import ( func TestGenSwarm(t *testing.T) { swarm := GenSwarm(t) require.NoError(t, swarm.Close()) - GenUpgrader(t, swarm) + GenUpgrader(t, swarm, nil) } diff --git a/p2p/net/upgrader/listener_test.go b/p2p/net/upgrader/listener_test.go index fcb7bf9c07..67fb292b16 100644 --- a/p2p/net/upgrader/listener_test.go +++ b/p2p/net/upgrader/listener_test.go @@ -100,7 +100,7 @@ func TestConnectionsClosedIfNotAccepted(t *testing.T) { timeout = 500 * time.Millisecond } - id, u := createUpgrader(t, upgrader.WithAcceptTimeout(timeout)) + id, u := createUpgraderWithOpts(t, upgrader.WithAcceptTimeout(timeout)) ln := createListener(t, u) defer ln.Close() @@ -134,7 +134,7 @@ func TestConnectionsClosedIfNotAccepted(t *testing.T) { func TestFailedUpgradeOnListen(t *testing.T) { require := require.New(t) - id, u := createUpgraderWithMuxer(t, &errorMuxer{}) + id, u := createUpgraderWithMuxer(t, &errorMuxer{}, nil, nil) ln := createListener(t, u) errCh := make(chan error) @@ -225,7 +225,7 @@ func TestConcurrentAccept(t *testing.T) { var num = 3 * upgrader.AcceptQueueLength blockingMuxer := newBlockingMuxer() - id, u := createUpgraderWithMuxer(t, blockingMuxer) + id, u := createUpgraderWithMuxer(t, blockingMuxer, nil, nil) ln := createListener(t, u) defer ln.Close() @@ -309,7 +309,7 @@ func TestListenerConnectionGater(t *testing.T) { require := require.New(t) testGater := &testGater{} - id, u := createUpgrader(t, upgrader.WithConnectionGater(testGater)) + id, u := createUpgraderWithConnGater(t, testGater) ln := createListener(t, u) defer ln.Close() @@ -354,7 +354,7 @@ func TestListenerResourceManagement(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() rcmgr := mocknetwork.NewMockResourceManager(ctrl) - id, upgrader := createUpgrader(t, upgrader.WithResourceManager(rcmgr)) + id, upgrader := createUpgraderWithResourceManager(t, rcmgr) ln := createListener(t, upgrader) defer ln.Close() @@ -380,7 +380,7 @@ func TestListenerResourceManagementDenied(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() rcmgr := mocknetwork.NewMockResourceManager(ctrl) - id, upgrader := createUpgrader(t, upgrader.WithResourceManager(rcmgr)) + id, upgrader := createUpgraderWithResourceManager(t, rcmgr) ln := createListener(t, upgrader) rcmgr.EXPECT().OpenConnection(network.DirInbound, true, gomock.Not(ln.Multiaddr())).Return(nil, errors.New("nope")) diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index ea19d2f2d6..f189720681 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -30,13 +30,6 @@ const defaultAcceptTimeout = 15 * time.Second type Option func(*upgrader) error -func WithPSK(psk ipnet.PSK) Option { - return func(u *upgrader) error { - u.psk = psk - return nil - } -} - func WithAcceptTimeout(t time.Duration) Option { return func(u *upgrader) error { u.acceptTimeout = t @@ -44,20 +37,6 @@ func WithAcceptTimeout(t time.Duration) Option { } } -func WithConnectionGater(g connmgr.ConnectionGater) Option { - return func(u *upgrader) error { - u.connGater = g - return nil - } -} - -func WithResourceManager(m network.ResourceManager) Option { - return func(u *upgrader) error { - u.rcmgr = m - return nil - } -} - // Upgrader is a multistream upgrader that can upgrade an underlying connection // to a full transport connection (secure and multiplexed). type upgrader struct { @@ -78,11 +57,14 @@ type upgrader struct { var _ transport.Upgrader = &upgrader{} -func New(secureMuxer sec.SecureMuxer, muxer network.Multiplexer, opts ...Option) (transport.Upgrader, error) { +func New(secureMuxer sec.SecureMuxer, muxer network.Multiplexer, psk ipnet.PSK, rcmgr network.ResourceManager, connGater connmgr.ConnectionGater, opts ...Option) (transport.Upgrader, error) { u := &upgrader{ secure: secureMuxer, muxer: muxer, acceptTimeout: defaultAcceptTimeout, + rcmgr: rcmgr, + connGater: connGater, + psk: psk, } for _, opt := range opts { if err := opt(u); err != nil { diff --git a/p2p/net/upgrader/upgrader_test.go b/p2p/net/upgrader/upgrader_test.go index 16188bd5bb..d39d360223 100644 --- a/p2p/net/upgrader/upgrader_test.go +++ b/p2p/net/upgrader/upgrader_test.go @@ -6,6 +6,7 @@ import ( "net" "testing" + "github.com/libp2p/go-libp2p/core/connmgr" "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" mocknetwork "github.com/libp2p/go-libp2p/core/network/mocks" @@ -22,16 +23,28 @@ import ( "github.com/stretchr/testify/require" ) -func createUpgrader(t *testing.T, opts ...upgrader.Option) (peer.ID, transport.Upgrader) { - return createUpgraderWithMuxer(t, &negotiatingMuxer{}, opts...) +func createUpgrader(t *testing.T) (peer.ID, transport.Upgrader) { + return createUpgraderWithMuxer(t, &negotiatingMuxer{}, nil, nil) } -func createUpgraderWithMuxer(t *testing.T, muxer network.Multiplexer, opts ...upgrader.Option) (peer.ID, transport.Upgrader) { +func createUpgraderWithConnGater(t *testing.T, connGater connmgr.ConnectionGater) (peer.ID, transport.Upgrader) { + return createUpgraderWithMuxer(t, &negotiatingMuxer{}, nil, connGater) +} + +func createUpgraderWithResourceManager(t *testing.T, rcmgr network.ResourceManager) (peer.ID, transport.Upgrader) { + return createUpgraderWithMuxer(t, &negotiatingMuxer{}, rcmgr, nil) +} + +func createUpgraderWithOpts(t *testing.T, opts ...upgrader.Option) (peer.ID, transport.Upgrader) { + return createUpgraderWithMuxer(t, &negotiatingMuxer{}, nil, nil, opts...) +} + +func createUpgraderWithMuxer(t *testing.T, muxer network.Multiplexer, rcmgr network.ResourceManager, connGater connmgr.ConnectionGater, opts ...upgrader.Option) (peer.ID, transport.Upgrader) { priv, _, err := test.RandTestKeyPair(crypto.Ed25519, 256) require.NoError(t, err) id, err := peer.IDFromPrivateKey(priv) require.NoError(t, err) - u, err := upgrader.New(&MuxAdapter{tpt: insecure.NewWithIdentity(id, priv)}, muxer, opts...) + u, err := upgrader.New(&MuxAdapter{tpt: insecure.NewWithIdentity(insecure.ID, id, priv)}, muxer, nil, rcmgr, connGater, opts...) require.NoError(t, err) return id, u } @@ -54,7 +67,7 @@ func (m *negotiatingMuxer) NewConn(c net.Conn, isServer bool, scope network.Peer return yamux.DefaultTransport.NewConn(c, isServer, scope) } -// blockingMuxer blocks the muxer negotiation until the contain chan is closed +// blockingMuxer blocks the muxer negotiation until the contained chan is closed type blockingMuxer struct { unblock chan struct{} } @@ -120,7 +133,7 @@ func TestOutboundConnectionGating(t *testing.T) { defer ln.Close() testGater := &testGater{} - _, dialUpgrader := createUpgrader(t, upgrader.WithConnectionGater(testGater)) + _, dialUpgrader := createUpgraderWithConnGater(t, testGater) conn, err := dial(t, dialUpgrader, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) require.NotNil(conn) @@ -164,7 +177,7 @@ func TestOutboundResourceManagement(t *testing.T) { }) t.Run("failed negotiation", func(t *testing.T) { - id, upgrader := createUpgraderWithMuxer(t, &errorMuxer{}) + id, upgrader := createUpgraderWithMuxer(t, &errorMuxer{}, nil, nil) ln := createListener(t, upgrader) defer ln.Close() diff --git a/p2p/protocol/circuitv2/relay/relay_test.go b/p2p/protocol/circuitv2/relay/relay_test.go index 33a850167d..fc797c42c3 100644 --- a/p2p/protocol/circuitv2/relay/relay_test.go +++ b/p2p/protocol/circuitv2/relay/relay_test.go @@ -53,7 +53,7 @@ func getNetHosts(t *testing.T, ctx context.Context, n int) (hosts []host.Host, u t.Fatal(err) } - upgrader := swarmt.GenUpgrader(t, netw) + upgrader := swarmt.GenUpgrader(t, netw, nil) upgraders = append(upgraders, upgrader) tpt, err := tcp.NewTCPTransport(upgrader, nil) diff --git a/p2p/protocol/internal/circuitv1-deprecated/relay_test.go b/p2p/protocol/internal/circuitv1-deprecated/relay_test.go index 9b76cffcdc..30a80f4f95 100644 --- a/p2p/protocol/internal/circuitv1-deprecated/relay_test.go +++ b/p2p/protocol/internal/circuitv1-deprecated/relay_test.go @@ -44,7 +44,7 @@ func getNetHosts(t *testing.T, n int) []host.Host { } func newTestRelay(t *testing.T, host host.Host, opts ...RelayOpt) *Relay { - r, err := NewRelay(host, swarmt.GenUpgrader(t, host.Network().(*swarm.Swarm)), opts...) + r, err := NewRelay(host, swarmt.GenUpgrader(t, host.Network().(*swarm.Swarm), nil), opts...) if err != nil { t.Fatal(err) } diff --git a/p2p/protocol/internal/circuitv1-deprecated/transport_test.go b/p2p/protocol/internal/circuitv1-deprecated/transport_test.go index 02b3716155..c1cb1a9c4b 100644 --- a/p2p/protocol/internal/circuitv1-deprecated/transport_test.go +++ b/p2p/protocol/internal/circuitv1-deprecated/transport_test.go @@ -27,17 +27,17 @@ var msg = []byte("relay works!") func testSetupRelay(t *testing.T) []host.Host { hosts := getNetHosts(t, 3) - err := AddRelayTransport(hosts[0], swarmt.GenUpgrader(t, hosts[0].Network().(*swarm.Swarm))) + err := AddRelayTransport(hosts[0], swarmt.GenUpgrader(t, hosts[0].Network().(*swarm.Swarm), nil)) if err != nil { t.Fatal(err) } - err = AddRelayTransport(hosts[1], swarmt.GenUpgrader(t, hosts[1].Network().(*swarm.Swarm)), OptHop) + err = AddRelayTransport(hosts[1], swarmt.GenUpgrader(t, hosts[1].Network().(*swarm.Swarm), nil), OptHop) if err != nil { t.Fatal(err) } - err = AddRelayTransport(hosts[2], swarmt.GenUpgrader(t, hosts[2].Network().(*swarm.Swarm))) + err = AddRelayTransport(hosts[2], swarmt.GenUpgrader(t, hosts[2].Network().(*swarm.Swarm), nil)) if err != nil { t.Fatal(err) } diff --git a/p2p/security/noise/benchmark_test.go b/p2p/security/noise/benchmark_test.go index 836275b954..d59a1cb979 100644 --- a/p2p/security/noise/benchmark_test.go +++ b/p2p/security/noise/benchmark_test.go @@ -39,7 +39,7 @@ func makeTransport(b *testing.B) *Transport { if err != nil { b.Fatal(err) } - tpt, err := New(priv, nil) + tpt, err := New(ID, priv, nil) if err != nil { b.Fatalf("error constructing transport: %v", err) } diff --git a/p2p/security/noise/session_transport.go b/p2p/security/noise/session_transport.go index c42271bd91..0f26f3fa8e 100644 --- a/p2p/security/noise/session_transport.go +++ b/p2p/security/noise/session_transport.go @@ -6,6 +6,7 @@ import ( "github.com/libp2p/go-libp2p/core/canonicallog" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/p2p/security/noise/pb" @@ -71,6 +72,8 @@ type SessionTransport struct { prologue []byte disablePeerIDCheck bool + protocolID protocol.ID + initiatorEarlyDataHandler, responderEarlyDataHandler EarlyDataHandler } @@ -92,3 +95,7 @@ func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.initiatorEarlyDataHandler, i.responderEarlyDataHandler, true, !i.disablePeerIDCheck) } + +func (i *SessionTransport) ID() protocol.ID { + return i.protocolID +} diff --git a/p2p/security/noise/transport.go b/p2p/security/noise/transport.go index 6878929505..e436c46aac 100644 --- a/p2p/security/noise/transport.go +++ b/p2p/security/noise/transport.go @@ -15,24 +15,21 @@ import ( ) // ID is the protocol ID for noise -const ( - ID = "/noise" - maxProtoNum = 100 -) - -var _ sec.SecureTransport = &Transport{} +const ID = "/noise" +const maxProtoNum = 100 -// Transport implements the interface sec.SecureTransport -// https://godoc.org/github.com/libp2p/go-libp2p/core/sec#SecureConn type Transport struct { + protocolID protocol.ID localID peer.ID privateKey crypto.PrivKey muxers []string } +var _ sec.SecureTransport = &Transport{} + // New creates a new Noise transport using the given private key as its // libp2p identity key. -func New(privkey crypto.PrivKey, muxers []protocol.ID) (*Transport, error) { +func New(id protocol.ID, privkey crypto.PrivKey, muxers []protocol.ID) (*Transport, error) { localID, err := peer.IDFromPrivateKey(privkey) if err != nil { return nil, err @@ -44,6 +41,7 @@ func New(privkey crypto.PrivKey, muxers []protocol.ID) (*Transport, error) { } return &Transport{ + protocolID: id, localID: localID, privateKey: privkey, muxers: smuxers, @@ -75,7 +73,7 @@ func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p pee } func (t *Transport) WithSessionOptions(opts ...SessionOption) (*SessionTransport, error) { - st := &SessionTransport{t: t} + st := &SessionTransport{t: t, protocolID: t.protocolID} for _, opt := range opts { if err := opt(st); err != nil { return nil, err @@ -84,6 +82,10 @@ func (t *Transport) WithSessionOptions(opts ...SessionOption) (*SessionTransport return st, nil } +func (t *Transport) ID() protocol.ID { + return t.protocolID +} + func matchMuxers(initiatorMuxers, responderMuxers []string) string { for _, muxer := range responderMuxers { for _, initMuxer := range initiatorMuxers { diff --git a/p2p/security/tls/cmd/tlsdiag/client.go b/p2p/security/tls/cmd/tlsdiag/client.go index 2292bfe0e9..3868afebb0 100644 --- a/p2p/security/tls/cmd/tlsdiag/client.go +++ b/p2p/security/tls/cmd/tlsdiag/client.go @@ -34,7 +34,7 @@ func StartClient() error { return err } fmt.Printf(" Peer ID: %s\n", id.Pretty()) - tp, err := libp2ptls.New(priv, nil) + tp, err := libp2ptls.New(libp2ptls.ID, priv, nil) if err != nil { return err } diff --git a/p2p/security/tls/cmd/tlsdiag/server.go b/p2p/security/tls/cmd/tlsdiag/server.go index 05e4be3f16..76c45a155e 100644 --- a/p2p/security/tls/cmd/tlsdiag/server.go +++ b/p2p/security/tls/cmd/tlsdiag/server.go @@ -27,7 +27,7 @@ func StartServer() error { return err } fmt.Printf(" Peer ID: %s\n", id.Pretty()) - tp, err := libp2ptls.New(priv, nil) + tp, err := libp2ptls.New(libp2ptls.ID, priv, nil) if err != nil { return err } diff --git a/p2p/security/tls/transport.go b/p2p/security/tls/transport.go index 695f648465..754a335139 100644 --- a/p2p/security/tls/transport.go +++ b/p2p/security/tls/transport.go @@ -26,21 +26,25 @@ const ID = "/tls/1.0.0" type Transport struct { identity *Identity - localPeer peer.ID - privKey ci.PrivKey - muxers []protocol.ID + localPeer peer.ID + privKey ci.PrivKey + muxers []protocol.ID + protocolID protocol.ID } +var _ sec.SecureTransport = &Transport{} + // New creates a TLS encrypted transport -func New(key ci.PrivKey, muxers []protocol.ID) (*Transport, error) { - id, err := peer.IDFromPrivateKey(key) +func New(id protocol.ID, key ci.PrivKey, muxers []protocol.ID) (*Transport, error) { + localPeer, err := peer.IDFromPrivateKey(key) if err != nil { return nil, err } t := &Transport{ - localPeer: id, - privKey: key, - muxers: muxers, + protocolID: id, + localPeer: localPeer, + privKey: key, + muxers: muxers, } identity, err := NewIdentity(key) @@ -51,8 +55,6 @@ func New(key ci.PrivKey, muxers []protocol.ID) (*Transport, error) { return t, nil } -var _ sec.SecureTransport = &Transport{} - // SecureInbound runs the TLS handshake as a server. // If p is empty, connections from any peer are accepted. func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { @@ -148,3 +150,7 @@ func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.Se connectionState: network.ConnectionState{NextProto: nextProto}, }, nil } + +func (t *Transport) ID() protocol.ID { + return t.protocolID +} diff --git a/p2p/security/tls/transport_test.go b/p2p/security/tls/transport_test.go index 59fa8bdae8..8c9fa7ceda 100644 --- a/p2p/security/tls/transport_test.go +++ b/p2p/security/tls/transport_test.go @@ -126,9 +126,9 @@ func TestHandshakeSucceeds(t *testing.T) { } // Use standard transports with default TLS configuration - clientTransport, err := New(clientKey, nil) + clientTransport, err := New(ID, clientKey, nil) require.NoError(t, err) - serverTransport, err := New(serverKey, nil) + serverTransport, err := New(ID, serverKey, nil) require.NoError(t, err) t.Run("standard TLS with extension not critical", func(t *testing.T) { @@ -240,9 +240,9 @@ func TestHandshakeWithNextProtoSucceeds(t *testing.T) { // Iterate through the NextProto combinations. for _, test := range tests { - clientTransport, err := New(clientKey, test.clientProtos) + clientTransport, err := New(ID, clientKey, test.clientProtos) require.NoError(t, err) - serverTransport, err := New(serverKey, test.serverProtos) + serverTransport, err := New(ID, serverKey, test.serverProtos) require.NoError(t, err) t.Run("TLS handshake with ALPN extension", func(t *testing.T) { @@ -268,9 +268,9 @@ func TestHandshakeConnectionCancellations(t *testing.T) { _, clientKey := createPeer(t) serverID, serverKey := createPeer(t) - clientTransport, err := New(clientKey, nil) + clientTransport, err := New(ID, clientKey, nil) require.NoError(t, err) - serverTransport, err := New(serverKey, nil) + serverTransport, err := New(ID, serverKey, nil) require.NoError(t, err) t.Run("cancel outgoing connection", func(t *testing.T) { @@ -320,9 +320,9 @@ func TestPeerIDMismatch(t *testing.T) { _, clientKey := createPeer(t) serverID, serverKey := createPeer(t) - serverTransport, err := New(serverKey, nil) + serverTransport, err := New(ID, serverKey, nil) require.NoError(t, err) - clientTransport, err := New(clientKey, nil) + clientTransport, err := New(ID, clientKey, nil) require.NoError(t, err) t.Run("for outgoing connections", func(t *testing.T) { @@ -597,9 +597,9 @@ func TestInvalidCerts(t *testing.T) { tr := transforms[i] t.Run(fmt.Sprintf("client offending: %s", tr.name), func(t *testing.T) { - serverTransport, err := New(serverKey, nil) + serverTransport, err := New(ID, serverKey, nil) require.NoError(t, err) - clientTransport, err := New(clientKey, nil) + clientTransport, err := New(ID, clientKey, nil) require.NoError(t, err) tr.apply(clientTransport.identity) @@ -640,10 +640,10 @@ func TestInvalidCerts(t *testing.T) { }) t.Run(fmt.Sprintf("server offending: %s", tr.name), func(t *testing.T) { - serverTransport, err := New(serverKey, nil) + serverTransport, err := New(ID, serverKey, nil) require.NoError(t, err) tr.apply(serverTransport.identity) - clientTransport, err := New(clientKey, nil) + clientTransport, err := New(ID, clientKey, nil) require.NoError(t, err) clientInsecureConn, serverInsecureConn := connect(t) diff --git a/p2p/transport/tcp/tcp_test.go b/p2p/transport/tcp/tcp_test.go index 3d49a253aa..eec1657dd7 100644 --- a/p2p/transport/tcp/tcp_test.go +++ b/p2p/transport/tcp/tcp_test.go @@ -27,11 +27,11 @@ func TestTcpTransport(t *testing.T) { peerA, ia := makeInsecureMuxer(t) _, ib := makeInsecureMuxer(t) - ua, err := tptu.New(ia, yamux.DefaultTransport) + ua, err := tptu.New(ia, yamux.DefaultTransport, nil, nil, nil) require.NoError(t, err) ta, err := NewTCPTransport(ua, nil) require.NoError(t, err) - ub, err := tptu.New(ib, yamux.DefaultTransport) + ub, err := tptu.New(ib, yamux.DefaultTransport, nil, nil, nil) require.NoError(t, err) tb, err := NewTCPTransport(ub, nil) require.NoError(t, err) @@ -48,11 +48,11 @@ func TestTcpTransportWithMetrics(t *testing.T) { peerA, ia := makeInsecureMuxer(t) _, ib := makeInsecureMuxer(t) - ua, err := tptu.New(ia, yamux.DefaultTransport) + ua, err := tptu.New(ia, yamux.DefaultTransport, nil, nil, nil) require.NoError(t, err) ta, err := NewTCPTransport(ua, nil, WithMetrics()) require.NoError(t, err) - ub, err := tptu.New(ib, yamux.DefaultTransport) + ub, err := tptu.New(ib, yamux.DefaultTransport, nil, nil, nil) require.NoError(t, err) tb, err := NewTCPTransport(ub, nil, WithMetrics()) require.NoError(t, err) @@ -68,7 +68,7 @@ func TestResourceManager(t *testing.T) { peerA, ia := makeInsecureMuxer(t) _, ib := makeInsecureMuxer(t) - ua, err := tptu.New(ia, yamux.DefaultTransport) + ua, err := tptu.New(ia, yamux.DefaultTransport, nil, nil, nil) require.NoError(t, err) ta, err := NewTCPTransport(ua, nil) require.NoError(t, err) @@ -76,7 +76,7 @@ func TestResourceManager(t *testing.T) { require.NoError(t, err) defer ln.Close() - ub, err := tptu.New(ib, yamux.DefaultTransport) + ub, err := tptu.New(ib, yamux.DefaultTransport, nil, nil, nil) require.NoError(t, err) rcmgr := mocknetwork.NewMockResourceManager(ctrl) tb, err := NewTCPTransport(ub, rcmgr) @@ -153,6 +153,6 @@ func makeInsecureMuxer(t *testing.T) (peer.ID, sec.SecureMuxer) { id, err := peer.IDFromPrivateKey(priv) require.NoError(t, err) var secMuxer csms.SSMuxer - secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(id, priv)) + secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(insecure.ID, id, priv)) return id, &secMuxer } diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 645dc82452..714fe89a68 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -39,7 +39,7 @@ import ( func newUpgrader(t *testing.T) (peer.ID, transport.Upgrader) { t.Helper() id, m := newInsecureMuxer(t) - u, err := tptu.New(m, yamux.DefaultTransport) + u, err := tptu.New(m, yamux.DefaultTransport, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -49,7 +49,7 @@ func newUpgrader(t *testing.T) (peer.ID, transport.Upgrader) { func newSecureUpgrader(t *testing.T) (peer.ID, transport.Upgrader) { t.Helper() id, m := newSecureMuxer(t) - u, err := tptu.New(m, yamux.DefaultTransport) + u, err := tptu.New(m, yamux.DefaultTransport, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -67,7 +67,7 @@ func newInsecureMuxer(t *testing.T) (peer.ID, sec.SecureMuxer) { t.Fatal(err) } var secMuxer csms.SSMuxer - secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(id, priv)) + secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(insecure.ID, id, priv)) return id, &secMuxer } @@ -82,7 +82,7 @@ func newSecureMuxer(t *testing.T) (peer.ID, sec.SecureMuxer) { t.Fatal(err) } var secMuxer csms.SSMuxer - noiseTpt, err := noise.New(priv, nil) + noiseTpt, err := noise.New(noise.ID, priv, nil) require.NoError(t, err) secMuxer.AddTransport(noise.ID, noiseTpt) return id, &secMuxer diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 3d902b4edc..bc1b5e1439 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -110,7 +110,7 @@ func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceMa return nil, err } } - n, err := noise.New(key, nil) + n, err := noise.New(noise.ID, key, nil) if err != nil { return nil, err }