Skip to content

Commit

Permalink
Return an error if there is no DNS address.
Browse files Browse the repository at this point in the history
Signed-off-by: Artem Glazychev <artem.glazychev@xored.com>
  • Loading branch information
glazychev-art committed Feb 20, 2023
1 parent ff8f08b commit 01c987a
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 94 deletions.
76 changes: 50 additions & 26 deletions pkg/networkservice/chains/nsmgr/vl3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,9 @@ import (
"go.uber.org/goleak"

"github.com/networkservicemesh/sdk/pkg/networkservice/chains/client"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/upstreamrefresh"
"github.com/networkservicemesh/sdk/pkg/networkservice/connectioncontext/dnscontext/vl3dns"
"github.com/networkservicemesh/sdk/pkg/networkservice/connectioncontext/ipcontext/vl3"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkconnection"
"github.com/networkservicemesh/sdk/pkg/tools/clock"
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils"
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils/memory"
"github.com/networkservicemesh/sdk/pkg/tools/sandbox"
Expand Down Expand Up @@ -224,7 +223,7 @@ func Test_vl3NSE_ConnectsTo_vl3NSE(t *testing.T) {
require.Error(t, err)
}

func Test_NSC_GetsVl3DnsAddressAfterRefresh(t *testing.T) {
func Test_NSC_GetsVl3DnsAddressDelay(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
Expand Down Expand Up @@ -259,34 +258,59 @@ func Test_NSC_GetsVl3DnsAddressAfterRefresh(t *testing.T) {
vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ .NetworkService }}."),
vl3dns.WithDNSPort(40053)))

refresh := false
refreshCompletedCh := make(chan struct{}, 1)
nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken,
client.WithAdditionalFunctionality(
upstreamrefresh.NewClient(ctx),
checkconnection.NewClient(t, func(t *testing.T, conn *networkservice.Connection) {
if !refresh {
refresh = true
require.Len(t, conn.GetContext().GetDnsContext().GetConfigs(), 0)
} else {
require.Len(t, conn.GetContext().GetDnsContext().GetConfigs(), 1)
require.Len(t, conn.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps, 1)
require.Equal(t, conn.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps[0], "127.0.0.1")
refreshCompletedCh <- struct{}{}
}
}),
))
nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken)

req := defaultRequest(nsReg.Name)
req.Connection.Labels["podName"] = nscName
go func() {
// Add a delay
<-clock.FromContext(ctx).After(time.Millisecond * 200)
dnsServerIPCh <- net.ParseIP("10.0.0.0")
}()
_, err = nsc.Request(ctx, req)
require.NoError(t, err)
}

dnsServerIPCh <- net.ParseIP("127.0.0.1")
func Test_vl3NSE_ConnectsTo_Itself(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

select {
case <-ctx.Done():
case <-refreshCompletedCh:
}
require.NoError(t, ctx.Err())
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()

domain := sandbox.NewBuilder(ctx, t).
SetNodesCount(1).
SetNSMgrProxySupplier(nil).
SetRegistryProxySupplier(nil).
Build()

nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken)

nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService("vl3"))
require.NoError(t, err)

nseReg := defaultRegistryEndpoint(nsReg.Name)

var serverPrefixCh = make(chan *ipam.PrefixResponse, 1)
defer close(serverPrefixCh)

serverPrefixCh <- &ipam.PrefixResponse{Prefix: "10.0.0.1/24"}
dnsServerIPCh := make(chan net.IP, 1)

_ = domain.Nodes[0].NewEndpoint(
ctx,
nseReg,
sandbox.GenerateTestToken,
vl3.NewServer(ctx, serverPrefixCh),
vl3dns.NewServer(ctx,
dnsServerIPCh,
vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ .NetworkService }}."),
vl3dns.WithDNSPort(40053)))

// Connection to itself. This allows us to assign a dns address to ourselves.
nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, client.WithName(nseReg.Name))
req := defaultRequest(nsReg.Name)
req.Connection.Labels["podName"] = nscName

_, err = nsc.Request(ctx, req)
require.NoError(t, err)
}
12 changes: 0 additions & 12 deletions pkg/networkservice/common/monitor/monitor_connection_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,21 +105,9 @@ func (m *monitorConnectionServer) Send(event *networkservice.ConnectionEvent) (_
return nil
}

func (m *monitorConnectionServer) GetConnections() map[string]*networkservice.Connection {
connections := make(map[string]*networkservice.Connection)

<-m.executor.AsyncExec(func() {
for k, v := range m.connections {
connections[k] = v
}
})
return connections
}

// EventConsumer - interface for monitor events sending
type EventConsumer interface {
Send(event *networkservice.ConnectionEvent) (err error)
GetConnections() map[string]*networkservice.Connection
}

var _ EventConsumer = &monitorConnectionServer{}
84 changes: 28 additions & 56 deletions pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"text/template"

Expand All @@ -31,7 +30,6 @@ import (
"github.com/networkservicemesh/api/pkg/api/networkservice"
"github.com/pkg/errors"

"github.com/networkservicemesh/sdk/pkg/networkservice/common/monitor"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata"
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils"
Expand All @@ -42,11 +40,9 @@ import (
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils/noloop"
"github.com/networkservicemesh/sdk/pkg/tools/dnsutils/norecursion"
"github.com/networkservicemesh/sdk/pkg/tools/ippool"
"github.com/networkservicemesh/sdk/pkg/tools/log"
)

type vl3DNSServer struct {
chainCtx context.Context
dnsServerRecords genericsync.Map[string, []net.IP]
dnsConfigs *genericsync.Map[string, []*networkservice.DNSConfig]
domainSchemeTemplates []*template.Template
Expand All @@ -55,8 +51,6 @@ type vl3DNSServer struct {
listenAndServeDNS func(ctx context.Context, handler dnsutils.Handler, listenOn string)
dnsServerIP atomic.Value
dnsServerIPCh <-chan net.IP
monitorEventConsumer monitor.EventConsumer
once sync.Once
}

type clientDNSNameKey struct{}
Expand All @@ -68,7 +62,6 @@ type clientDNSNameKey struct{}
// opts configure vl3dns networkservice instance with specific behavior.
func NewServer(chainCtx context.Context, dnsServerIPCh <-chan net.IP, opts ...Option) networkservice.NetworkServiceServer {
var result = &vl3DNSServer{
chainCtx: chainCtx,
dnsPort: 53,
listenAndServeDNS: dnsutils.ListenAndServe,
dnsConfigs: new(genericsync.Map[string, []*networkservice.DNSConfig]),
Expand All @@ -91,27 +84,37 @@ func NewServer(chainCtx context.Context, dnsServerIPCh <-chan net.IP, opts ...Op

result.listenAndServeDNS(chainCtx, result.dnsServer, fmt.Sprintf(":%v", result.dnsPort))

if len(dnsServerIPCh) > 0 {
result.dnsServerIP.Store(<-dnsServerIPCh)
}
go func() {
for {
select {
case <-chainCtx.Done():
return
case addr, ok := <-dnsServerIPCh:
if !ok {
return
}
result.dnsServerIP.Store(addr)
}
}
}()

return result
}

func (n *vl3DNSServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
n.once.Do(func() {
// We assume here that the monitorEventConsumer is the same for all connections.
// We need the context of any request to pull it out.
go n.checkServerAddressUpdates(ctx)
})

if request.GetConnection().GetContext().GetDnsContext() == nil {
request.Connection.Context.DnsContext = new(networkservice.DNSContext)
}

var clientsConfigs = request.GetConnection().GetContext().GetDnsContext().GetConfigs()

dnsServerIPStr, added := n.addDNSContext(request.GetConnection())
var recordNames, err = n.buildSrcDNSRecords(request.GetConnection())
dnsServerIPStr, err := n.addDNSContext(request.GetConnection())
if err != nil {
return nil, err
}

var recordNames []string
recordNames, err = n.buildSrcDNSRecords(request.GetConnection())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -141,7 +144,7 @@ func (n *vl3DNSServer) Request(ctx context.Context, request *networkservice.Netw
var lastPrefix = srcRoutes[len(srcRoutes)-1].Prefix
for _, config := range clientsConfigs {
for _, serverIP := range config.DnsServerIps {
if added && dnsServerIPStr == serverIP {
if dnsServerIPStr == serverIP {
continue
}
if withinPrefix(serverIP, lastPrefix) {
Expand All @@ -168,7 +171,7 @@ func (n *vl3DNSServer) Close(ctx context.Context, conn *networkservice.Connectio
return next.Server(ctx).Close(ctx, conn)
}

func (n *vl3DNSServer) addDNSContext(c *networkservice.Connection) (added string, ok bool) {
func (n *vl3DNSServer) addDNSContext(c *networkservice.Connection) (added string, err error) {
if ip := n.dnsServerIP.Load(); ip != nil {
dnsServerIP := ip.(net.IP)
var dnsContext = c.GetContext().GetDnsContext()
Expand All @@ -178,9 +181,12 @@ func (n *vl3DNSServer) addDNSContext(c *networkservice.Connection) (added string
if !dnsutils.ContainsDNSConfig(dnsContext.Configs, configToAdd) {
dnsContext.Configs = append(dnsContext.Configs, configToAdd)
}
return dnsServerIP.String(), true
return dnsServerIP.String(), nil
} else if c.GetPath().GetPathSegments()[0].Name == c.GetCurrentPathSegment().Name {
// If it calls itself - this is not an error, but a request to allocate a dns address
return "", nil
}
return "", false
return "", errors.New("DNS address is initializing")
}

func (n *vl3DNSServer) buildSrcDNSRecords(c *networkservice.Connection) ([]string, error) {
Expand All @@ -195,40 +201,6 @@ func (n *vl3DNSServer) buildSrcDNSRecords(c *networkservice.Connection) ([]strin
return result, nil
}

func (n *vl3DNSServer) checkServerAddressUpdates(ctx context.Context) {
n.monitorEventConsumer, _ = monitor.LoadEventConsumer(ctx, metadata.IsClient(n))
for {
select {
case <-n.chainCtx.Done():
return
case addr, ok := <-n.dnsServerIPCh:
if !ok {
return
}

n.updateServerAddress(addr)
}
}
}

func (n *vl3DNSServer) updateServerAddress(address net.IP) {
n.dnsServerIP.Store(address)

if n.monitorEventConsumer != nil {
conns := n.monitorEventConsumer.GetConnections()
for _, c := range conns {
c.State = networkservice.State_REFRESH_REQUESTED
}
_ = n.monitorEventConsumer.Send(&networkservice.ConnectionEvent{
Type: networkservice.ConnectionEventType_UPDATE,
Connections: conns,
})
} else {
log.FromContext(n.chainCtx).WithField("vl3DNSServer", "updateServerAddress").
Debug("eventConsumer is not presented")
}
}

func compareStringSlices(a, b []string) bool {
if len(a) != len(b) {
return false
Expand Down

0 comments on commit 01c987a

Please sign in to comment.