Skip to content

Commit

Permalink
Add retry attempts to refresh
Browse files Browse the repository at this point in the history
Signed-off-by: Artem Glazychev <artem.glazychev@xored.com>
  • Loading branch information
glazychev-art committed Feb 5, 2024
1 parent 90f0c79 commit 40ba527
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 40 deletions.
33 changes: 24 additions & 9 deletions pkg/networkservice/chains/nsmgr/heal_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2020-2022 Doc.ai and/or its affiliates.
//
// Copyright (c) 2023 Cisco and/or its affiliates.
// Copyright (c) 2023-2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -859,9 +859,8 @@ func TestNSMGR_RefreshFailed_ControlPlaneBroken(t *testing.T) {
),
)

requestCtx, requestCalcel := context.WithTimeout(ctx, time.Second)
requestCtx = clock.WithClock(requestCtx, clk)
defer requestCalcel()
requestCtx, requestCancel := context.WithTimeout(ctx, time.Second)
defer requestCancel()

// allow the first Request
syncCh <- struct{}{}
Expand All @@ -871,21 +870,37 @@ func TestNSMGR_RefreshFailed_ControlPlaneBroken(t *testing.T) {

// refresh interval in this test is expected to be 3 minutes and a few milliseconds
clk.Add(time.Second * 190)
// start goroutine that will update mock clock every 50 ms. It is needed for retry refresh
go func() {
tickerDuration := time.Millisecond * 50
tickCh := time.Tick(tickerDuration)
for {
select {
case <-ctx.Done():
return
case <-tickCh:
clk.Add(tickerDuration)
}
}
}()

// kill the forwarder during the healing Request (it is stopped by syncCh). Then continue - the healing process will fail.
for _, forwarder := range domain.Nodes[0].Forwarders {
// kill the forwarder during the refresh (it is stopped by syncCh). Then continue - the refresh will fail.
for idx := range domain.Nodes[0].Forwarders {
forwarder := domain.Nodes[0].Forwarders[idx]
forwarder.Cancel()
break
// wait until the forwarder dies
require.Eventually(t, func() bool {
return sandbox.CheckURLFree(forwarder.URL)
}, timeout, tick)
}
syncCh <- struct{}{}
close(syncCh)

// create a new forwarder and allow the healing Request
forwarderReg := &registry.NetworkServiceEndpoint{
Name: sandbox.UniqueName("forwarder-2"),
NetworkServiceNames: []string{"forwarder"},
}
domain.Nodes[0].NewForwarder(ctx, forwarderReg, sandbox.GenerateTestToken)
syncCh <- struct{}{}

// wait till Request reached NSE
require.Eventually(t, func() bool {
Expand Down
27 changes: 14 additions & 13 deletions pkg/networkservice/chains/nsmgr/single_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2020-2022 Doc.ai and/or its affiliates.
//
// Copyright (c) 2023 Cisco and/or its affiliates.
// Copyright (c) 2023-2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -48,7 +48,6 @@ import (
"github.com/networkservicemesh/sdk/pkg/networkservice/common/authorize"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/begin"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/excludedprefixes"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/heal"
"github.com/networkservicemesh/sdk/pkg/networkservice/ipam/point2pointipam"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkcontext"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkrequest"
Expand Down Expand Up @@ -625,7 +624,7 @@ func Test_RestartDuringRefresh(t *testing.T) {
require.NoError(t, err)

var countServer count.Server
var countClint count.Client
var countClientBack count.ClientBackward
var m sync.Once
var clientFactory begin.EventFactory
var destroyFwd atomic.Bool
Expand All @@ -636,16 +635,21 @@ func Test_RestartDuringRefresh(t *testing.T) {
NetworkServiceNames: []string{"ns"},
}, sandbox.GenerateTestToken, &countServer, checkrequest.NewServer(t, func(t *testing.T, nsr *networkservice.NetworkServiceRequest) {
if destroyFwd.Load() {
e.AsyncExec(func() {
for _, fwd := range domain.Nodes[0].Forwarders {
fwd.Cancel()
<-e.AsyncExec(func() {
for idx := range domain.Nodes[0].Forwarders {
forwarder := domain.Nodes[0].Forwarders[idx]
forwarder.Cancel()
// wait until the forwarder dies
require.Eventually(t, func() bool {
return sandbox.CheckURLFree(forwarder.URL)
}, timeout, tick)
}
})
}
}))

var nsc = domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, client.WithAdditionalFunctionality(
&countClint,
&countClientBack,
checkcontext.NewClient(t, func(t *testing.T, ctx context.Context) {
m.Do(func() {
clientFactory = begin.FromContext(ctx)
Expand All @@ -660,7 +664,6 @@ func Test_RestartDuringRefresh(t *testing.T) {
})
}
}),
heal.NewClient(ctx),
))

_, err = nsc.Request(ctx, &networkservice.NetworkServiceRequest{
Expand All @@ -673,16 +676,14 @@ func Test_RestartDuringRefresh(t *testing.T) {
<-clientFactory.Request()
require.Equal(t, 2, countServer.Requests())
require.Never(t, func() bool { return countServer.Requests() > 2 }, time.Second/2, time.Second/20)
destroyFwd.Store(true)
for i := 0; i < 15; i++ {
var cs = countServer.Requests()
destroyFwd.Store(true)
err = <-clientFactory.Request()
require.Error(t, err)
var cc = countClientBack.Requests()
destroyFwd.Store(false)
var cc = countClint.Requests()
require.Eventually(t, func() bool { return cs < countServer.Requests() }, time.Second*2, time.Second/20)
require.Eventually(t, func() bool { return cc < countClint.Requests() }, time.Second*2, time.Second/20)
// Heal must be successful eventually
require.Eventually(t, func() bool { return cc < countClientBack.Requests() }, time.Second*2, time.Second/20)
}
}

Expand Down
11 changes: 5 additions & 6 deletions pkg/networkservice/common/refresh/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2020 Cisco Systems, Inc.
// Copyright (c) 2020-2024 Cisco Systems, Inc.
//
// Copyright (c) 2020-2022 Doc.ai and/or its affiliates.
// Copyright (c) 2020-2024 Doc.ai and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -68,18 +68,17 @@ func (t *refreshClient) Request(ctx context.Context, request *networkservice.Net
store(ctx, metadata.IsClient(t), cancel)

eventFactory := begin.FromContext(ctx)
clockTime := clock.FromContext(ctx)
// Create the afterCh *outside* the go routine. This must be done to avoid picking up a later 'now'
// from mockClock in testing
afterTicker := clockTime.Ticker(refreshAfter)
afterCh := clock.FromContext(ctx).After(refreshAfter)
go func() {
defer afterTicker.Stop()
for {
select {
case <-cancelCtx.Done():
return
case <-afterTicker.C():
case <-afterCh:
if err := <-eventFactory.Request(begin.CancelContext(cancelCtx)); err != nil {
afterCh = clock.FromContext(ctx).After(time.Millisecond * 200)
logger.Warnf("refresh failed: %s", err.Error())
continue
}
Expand Down
57 changes: 45 additions & 12 deletions pkg/networkservice/common/refresh/client_test.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) 2020-2021 Doc.ai and/or its affiliates.
//
// Copyright (c) 2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -48,6 +50,7 @@ import (

const (
expireTimeout = 15 * time.Minute
retryTimeout = 200 * time.Millisecond
testWait = 100 * time.Millisecond
testTick = testWait / 100

Expand All @@ -69,19 +72,19 @@ func testTokenFuncWithTimeout(clockTime clock.Clock, timeout time.Duration) toke
}
}

type captureTickerDuration struct {
type captureAfterDuration struct {
*clockmock.Mock

tickerDuration time.Duration
afterDuration time.Duration
}

func (m *captureTickerDuration) Ticker(d time.Duration) clock.Ticker {
m.tickerDuration = d
return m.Mock.Ticker(d)
func (m *captureAfterDuration) After(d time.Duration) <-chan time.Time {
m.afterDuration = d
return m.Mock.After(d)
}

func (m *captureTickerDuration) Reset(t time.Time) {
m.tickerDuration = 0
func (m *captureAfterDuration) Reset(t time.Time) {
m.afterDuration = 0
m.Set(t)
}

Expand Down Expand Up @@ -355,7 +358,7 @@ func TestRefreshClient_CalculatesShortestTokenTimeout(t *testing.T) {

timeNow := time.Date(2009, 11, 10, 23, 0, 0, 0, time.Local)

clockMock := captureTickerDuration{
clockMock := captureAfterDuration{
Mock: clockmock.New(ctx),
}

Expand Down Expand Up @@ -389,14 +392,14 @@ func TestRefreshClient_CalculatesShortestTokenTimeout(t *testing.T) {
})
require.NoError(t, err)

require.Less(t, clockMock.tickerDuration, testDataElement.ExpectedRefreshTimeout+timeoutDelta)
require.Greater(t, clockMock.tickerDuration, testDataElement.ExpectedRefreshTimeout-timeoutDelta)
require.Less(t, clockMock.afterDuration, testDataElement.ExpectedRefreshTimeout+timeoutDelta)
require.Greater(t, clockMock.afterDuration, testDataElement.ExpectedRefreshTimeout-timeoutDelta)
}

require.Equal(t, countClient.Requests(), len(testData))
}

func TestRefreshClient_RefreshOnRefreshFailure(t *testing.T) {
func TestRefreshClient_RetryOnRefreshFailure(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -422,7 +425,37 @@ func TestRefreshClient_RefreshOnRefreshFailure(t *testing.T) {

require.Eventually(t, cloneClient.validator(2), testWait, testTick)

clockMock.Add(expireTimeout)
clockMock.Add(retryTimeout)

require.Eventually(t, cloneClient.validator(3), testWait, testTick)
}

func TestRefreshClient_NoRetryOnRefreshSuccess(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

clockMock := clockmock.New(ctx)

cloneClient := &countClient{
t: t,
}
client := testClient(ctx, testTokenFunc(clockMock),
clockMock,
cloneClient,
)

_, err := client.Request(ctx, &networkservice.NetworkServiceRequest{
Connection: new(networkservice.Connection),
})
require.NoError(t, err)

clockMock.Add(expireTimeout)

require.Eventually(t, cloneClient.validator(2), testWait, testTick)

clockMock.Add(retryTimeout)

require.Never(t, cloneClient.validator(3), testWait, testTick)
}
72 changes: 72 additions & 0 deletions pkg/networkservice/utils/count/client_backward.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (c) 2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package count

import (
"context"
"sync/atomic"

"google.golang.org/grpc"

"github.com/golang/protobuf/ptypes/empty"
"github.com/networkservicemesh/api/pkg/api/networkservice"

"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
)

// ClientBackward checks the connection on the way back
type ClientBackward struct {
Client
}

// Request performs request and increments requests count
func (c *ClientBackward) Request(ctx context.Context, request *networkservice.NetworkServiceRequest, opts ...grpc.CallOption) (*networkservice.Connection, error) {
conn, err := next.Client(ctx).Request(ctx, request, opts...)
if err != nil {
return conn, err
}

c.mu.Lock()
defer c.mu.Unlock()

atomic.AddInt32(&c.totalRequests, 1)
if c.requests == nil {
c.requests = make(map[string]int32)
}
c.requests[request.GetConnection().GetId()]++

return conn, err
}

// Close performs close and increments closes count
func (c *ClientBackward) Close(ctx context.Context, connection *networkservice.Connection, opts ...grpc.CallOption) (*empty.Empty, error) {
r, err := next.Client(ctx).Close(ctx, connection, opts...)
if err != nil {
return r, err
}

c.mu.Lock()
defer c.mu.Unlock()

atomic.AddInt32(&c.totalCloses, 1)
if c.closes == nil {
c.closes = make(map[string]int32)
}
c.closes[connection.GetId()]++

return r, err
}

0 comments on commit 40ba527

Please sign in to comment.