diff --git a/component.go b/component.go index 490b8ac..4ee6809 100644 --- a/component.go +++ b/component.go @@ -70,6 +70,9 @@ func (m *Component) UnmarshalJSON(data []byte) error { } func (c *Component) Equal(o Multiaddr) bool { + if o == nil { + return false + } return bytes.Equal(c.bytes, o.Bytes()) } diff --git a/multiaddr.go b/multiaddr.go index 4b3a360..5e60780 100644 --- a/multiaddr.go +++ b/multiaddr.go @@ -48,6 +48,9 @@ func NewMultiaddrBytes(b []byte) (a Multiaddr, err error) { // Equal tests whether two multiaddrs are equal func (m *multiaddr) Equal(m2 Multiaddr) bool { + if m2 == nil { + return false + } return bytes.Equal(m.bytes, m2.Bytes()) } @@ -139,6 +142,10 @@ func (m *multiaddr) Protocols() []Protocol { // Encapsulate wraps a given Multiaddr, returning the resulting joined Multiaddr func (m *multiaddr) Encapsulate(o Multiaddr) Multiaddr { + if o == nil { + return m + } + mb := m.bytes ob := o.Bytes() diff --git a/multiaddr_test.go b/multiaddr_test.go index 6d160b4..67a91f0 100644 --- a/multiaddr_test.go +++ b/multiaddr_test.go @@ -242,6 +242,28 @@ func TestEqual(t *testing.T) { } } +// TestNilInterface makes sure funcs that accept a multiaddr interface don't +// panic if it's passed a nil interface. +func TestNilInterface(t *testing.T) { + m1 := newMultiaddr(t, "/ip4/127.0.0.1/udp/1234") + var m2 Multiaddr + m1.Equal(m2) + m1.Encapsulate(m2) + m1.Decapsulate(m2) + + // Test components + c, _ := SplitFirst(m1) + c.Equal(m2) + c.Encapsulate(m2) + c.Decapsulate(m2) + + // Util funcs + _ = Split(m2) + _, _ = SplitFirst(m2) + _, _ = SplitLast(m2) + ForEach(m2, func(c Component) bool { return true }) +} + func TestStringToBytes(t *testing.T) { testString := func(s string, h string) { diff --git a/util.go b/util.go index b0ac7ee..8757401 100644 --- a/util.go +++ b/util.go @@ -36,6 +36,9 @@ func Join(ms ...Multiaddr) Multiaddr { for _, mb := range ms { bidx += copy(b[bidx:], mb.Bytes()) } + if length == 0 { + return nil + } return &multiaddr{bytes: b} } @@ -59,6 +62,9 @@ func StringCast(s string) Multiaddr { // SplitFirst returns the first component and the rest of the multiaddr. func SplitFirst(m Multiaddr) (*Component, Multiaddr) { + if m == nil { + return nil, nil + } // Shortcut if we already have a component if c, ok := m.(*Component); ok { return c, nil @@ -80,6 +86,10 @@ func SplitFirst(m Multiaddr) (*Component, Multiaddr) { // SplitLast returns the rest of the multiaddr and the last component. func SplitLast(m Multiaddr) (Multiaddr, *Component) { + if m == nil { + return nil, nil + } + // Shortcut if we already have a component if c, ok := m.(*Component); ok { return nil, c @@ -117,6 +127,9 @@ func SplitLast(m Multiaddr) (Multiaddr, *Component) { // component on which the callback first returns will be included in the // *second* multiaddr. func SplitFunc(m Multiaddr, cb func(Component) bool) (Multiaddr, Multiaddr) { + if m == nil { + return nil, nil + } // Shortcut if we already have a component if c, ok := m.(*Component); ok { if cb(*c) { @@ -159,6 +172,9 @@ func SplitFunc(m Multiaddr, cb func(Component) bool) (Multiaddr, Multiaddr) { // This function iterates over components *by value* to avoid allocating. // Return true to continue iteration, false to stop. func ForEach(m Multiaddr, cb func(c Component) bool) { + if m == nil { + return + } // Shortcut if we already have a component if c, ok := m.(*Component); ok { cb(*c)