Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return an error if there is no DNS address. #1424

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 48 additions & 26 deletions pkg/networkservice/chains/nsmgr/vl3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,9 @@ import (
"go.uber.org/goleak"

"github.com/networkservicemesh/sdk/pkg/networkservice/chains/client"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/upstreamrefresh"
"github.com/networkservicemesh/sdk/pkg/networkservice/connectioncontext/dnscontext/vl3dns"
"github.com/networkservicemesh/sdk/pkg/networkservice/connectioncontext/ipcontext/vl3"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkconnection"
"github.com/networkservicemesh/sdk/pkg/tools/clock"
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils"
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils/memory"
"github.com/networkservicemesh/sdk/pkg/tools/sandbox"
Expand Down Expand Up @@ -224,7 +223,7 @@ func Test_vl3NSE_ConnectsTo_vl3NSE(t *testing.T) {
require.Error(t, err)
}

func Test_NSC_GetsVl3DnsAddressAfterRefresh(t *testing.T) {
func Test_NSC_GetsVl3DnsAddressDelay(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
Expand Down Expand Up @@ -259,34 +258,57 @@ func Test_NSC_GetsVl3DnsAddressAfterRefresh(t *testing.T) {
vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ .NetworkService }}."),
vl3dns.WithDNSPort(40053)))

refresh := false
refreshCompletedCh := make(chan struct{}, 1)
nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken,
client.WithAdditionalFunctionality(
upstreamrefresh.NewClient(ctx),
checkconnection.NewClient(t, func(t *testing.T, conn *networkservice.Connection) {
if !refresh {
refresh = true
require.Len(t, conn.GetContext().GetDnsContext().GetConfigs(), 0)
} else {
require.Len(t, conn.GetContext().GetDnsContext().GetConfigs(), 1)
require.Len(t, conn.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps, 1)
require.Equal(t, conn.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps[0], "127.0.0.1")
refreshCompletedCh <- struct{}{}
}
}),
))
nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken)

req := defaultRequest(nsReg.Name)
req.Connection.Labels["podName"] = nscName
go func() {
// Add a delay
<-clock.FromContext(ctx).After(time.Millisecond * 200)
dnsServerIPCh <- net.ParseIP("10.0.0.0")
}()
_, err = nsc.Request(ctx, req)
require.NoError(t, err)
}

dnsServerIPCh <- net.ParseIP("127.0.0.1")
func Test_vl3NSE_ConnectsTo_Itself(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

select {
case <-ctx.Done():
case <-refreshCompletedCh:
}
require.NoError(t, ctx.Err())
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()

domain := sandbox.NewBuilder(ctx, t).
SetNodesCount(1).
SetNSMgrProxySupplier(nil).
SetRegistryProxySupplier(nil).
Build()

nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken)

nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService("vl3"))
require.NoError(t, err)

nseReg := defaultRegistryEndpoint(nsReg.Name)

var serverPrefixCh = make(chan *ipam.PrefixResponse, 1)
defer close(serverPrefixCh)

serverPrefixCh <- &ipam.PrefixResponse{Prefix: "10.0.0.1/24"}
dnsServerIPCh := make(chan net.IP, 1)

_ = domain.Nodes[0].NewEndpoint(
ctx,
nseReg,
sandbox.GenerateTestToken,
vl3.NewServer(ctx, serverPrefixCh),
vl3dns.NewServer(ctx,
dnsServerIPCh,
vl3dns.WithDNSPort(40053)))

// Connection to itself. This allows us to assign a dns address to ourselves.
nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, client.WithName(nseReg.Name))
req := defaultRequest(nsReg.Name)

_, err = nsc.Request(ctx, req)
require.NoError(t, err)
}
12 changes: 0 additions & 12 deletions pkg/networkservice/common/monitor/monitor_connection_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,21 +105,9 @@ func (m *monitorConnectionServer) Send(event *networkservice.ConnectionEvent) (_
return nil
}

func (m *monitorConnectionServer) GetConnections() map[string]*networkservice.Connection {
connections := make(map[string]*networkservice.Connection)

<-m.executor.AsyncExec(func() {
for k, v := range m.connections {
connections[k] = v
}
})
return connections
}

// EventConsumer - interface for monitor events sending
type EventConsumer interface {
Send(event *networkservice.ConnectionEvent) (err error)
GetConnections() map[string]*networkservice.Connection
}

var _ EventConsumer = &monitorConnectionServer{}
84 changes: 28 additions & 56 deletions pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"text/template"

Expand All @@ -31,7 +30,6 @@ import (
"github.com/networkservicemesh/api/pkg/api/networkservice"
"github.com/pkg/errors"

"github.com/networkservicemesh/sdk/pkg/networkservice/common/monitor"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata"
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils"
Expand All @@ -42,11 +40,9 @@ import (
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils/noloop"
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils/norecursion"
"github.com/networkservicemesh/sdk/pkg/tools/ippool"
"github.com/networkservicemesh/sdk/pkg/tools/log"
)

type vl3DNSServer struct {
chainCtx context.Context
dnsServerRecords genericsync.Map[string, []net.IP]
dnsConfigs *genericsync.Map[string, []*networkservice.DNSConfig]
domainSchemeTemplates []*template.Template
Expand All @@ -55,8 +51,6 @@ type vl3DNSServer struct {
listenAndServeDNS func(ctx context.Context, handler dnsutils.Handler, listenOn string)
dnsServerIP atomic.Value
dnsServerIPCh <-chan net.IP
monitorEventConsumer monitor.EventConsumer
once sync.Once
}

type clientDNSNameKey struct{}
Expand All @@ -68,7 +62,6 @@ type clientDNSNameKey struct{}
// opts configure vl3dns networkservice instance with specific behavior.
func NewServer(chainCtx context.Context, dnsServerIPCh <-chan net.IP, opts ...Option) networkservice.NetworkServiceServer {
var result = &vl3DNSServer{
chainCtx: chainCtx,
dnsPort: 53,
listenAndServeDNS: dnsutils.ListenAndServe,
dnsConfigs: new(genericsync.Map[string, []*networkservice.DNSConfig]),
Expand All @@ -91,27 +84,37 @@ func NewServer(chainCtx context.Context, dnsServerIPCh <-chan net.IP, opts ...Op

result.listenAndServeDNS(chainCtx, result.dnsServer, fmt.Sprintf(":%v", result.dnsPort))

if len(dnsServerIPCh) > 0 {
result.dnsServerIP.Store(<-dnsServerIPCh)
}
go func() {
for {
select {
case <-chainCtx.Done():
return
case addr, ok := <-dnsServerIPCh:
if !ok {
return
}
result.dnsServerIP.Store(addr)
}
}
}()

return result
}

func (n *vl3DNSServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
n.once.Do(func() {
// We assume here that the monitorEventConsumer is the same for all connections.
// We need the context of any request to pull it out.
go n.checkServerAddressUpdates(ctx)
})

if request.GetConnection().GetContext().GetDnsContext() == nil {
request.Connection.Context.DnsContext = new(networkservice.DNSContext)
}

var clientsConfigs = request.GetConnection().GetContext().GetDnsContext().GetConfigs()

dnsServerIPStr, added := n.addDNSContext(request.GetConnection())
var recordNames, err = n.buildSrcDNSRecords(request.GetConnection())
dnsServerIPStr, err := n.addDNSContext(request.GetConnection())
if err != nil {
return nil, err
}

var recordNames []string
recordNames, err = n.buildSrcDNSRecords(request.GetConnection())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -141,7 +144,7 @@ func (n *vl3DNSServer) Request(ctx context.Context, request *networkservice.Netw
var lastPrefix = srcRoutes[len(srcRoutes)-1].Prefix
for _, config := range clientsConfigs {
for _, serverIP := range config.DnsServerIps {
if added && dnsServerIPStr == serverIP {
if dnsServerIPStr == serverIP {
continue
}
if withinPrefix(serverIP, lastPrefix) {
Expand All @@ -168,7 +171,7 @@ func (n *vl3DNSServer) Close(ctx context.Context, conn *networkservice.Connectio
return next.Server(ctx).Close(ctx, conn)
}

func (n *vl3DNSServer) addDNSContext(c *networkservice.Connection) (added string, ok bool) {
func (n *vl3DNSServer) addDNSContext(c *networkservice.Connection) (added string, err error) {
if ip := n.dnsServerIP.Load(); ip != nil {
dnsServerIP := ip.(net.IP)
var dnsContext = c.GetContext().GetDnsContext()
Expand All @@ -178,9 +181,12 @@ func (n *vl3DNSServer) addDNSContext(c *networkservice.Connection) (added string
if !dnsutils.ContainsDNSConfig(dnsContext.Configs, configToAdd) {
dnsContext.Configs = append(dnsContext.Configs, configToAdd)
}
return dnsServerIP.String(), true
return dnsServerIP.String(), nil
} else if c.GetPath().GetPathSegments()[0].Name == c.GetCurrentPathSegment().Name {
// If it calls itself - this is not an error, but a request to allocate a dns address
return "", nil
}
return "", false
return "", errors.New("DNS address is initializing")
}

func (n *vl3DNSServer) buildSrcDNSRecords(c *networkservice.Connection) ([]string, error) {
Expand All @@ -195,40 +201,6 @@ func (n *vl3DNSServer) buildSrcDNSRecords(c *networkservice.Connection) ([]strin
return result, nil
}

func (n *vl3DNSServer) checkServerAddressUpdates(ctx context.Context) {
n.monitorEventConsumer, _ = monitor.LoadEventConsumer(ctx, metadata.IsClient(n))
for {
select {
case <-n.chainCtx.Done():
return
case addr, ok := <-n.dnsServerIPCh:
if !ok {
return
}

n.updateServerAddress(addr)
}
}
}

func (n *vl3DNSServer) updateServerAddress(address net.IP) {
n.dnsServerIP.Store(address)

if n.monitorEventConsumer != nil {
conns := n.monitorEventConsumer.GetConnections()
for _, c := range conns {
c.State = networkservice.State_REFRESH_REQUESTED
}
_ = n.monitorEventConsumer.Send(&networkservice.ConnectionEvent{
Type: networkservice.ConnectionEventType_UPDATE,
Connections: conns,
})
} else {
log.FromContext(n.chainCtx).WithField("vl3DNSServer", "updateServerAddress").
Debug("eventConsumer is not presented")
}
}

func compareStringSlices(a, b []string) bool {
if len(a) != len(b) {
return false
Expand Down
4 changes: 2 additions & 2 deletions pkg/tools/dnsutils/dnsutils.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022 Cisco and/or its affiliates.
// Copyright (c) 2022-2023 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -37,7 +37,7 @@ func ListenAndServe(ctx context.Context, handler Handler, listenOn string) {

for _, network := range networks {
var server = &dns.Server{Addr: listenOn, Net: network, Handler: dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
var timeoutCtx, cancel = context.WithTimeout(context.Background(), time.Second*5)
var timeoutCtx, cancel = context.WithTimeout(ctx, time.Second*5)
defer cancel()

handler.ServeDNS(timeoutCtx, w, m)
Expand Down