Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add retry attempts to refresh #1583

Merged
merged 2 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions pkg/networkservice/chains/nsmgr/heal_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2020-2022 Doc.ai and/or its affiliates.
//
// Copyright (c) 2023 Cisco and/or its affiliates.
// Copyright (c) 2023-2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -859,9 +859,8 @@ func TestNSMGR_RefreshFailed_ControlPlaneBroken(t *testing.T) {
),
)

requestCtx, requestCalcel := context.WithTimeout(ctx, time.Second)
requestCtx = clock.WithClock(requestCtx, clk)
defer requestCalcel()
requestCtx, requestCancel := context.WithTimeout(ctx, time.Second)
defer requestCancel()

// allow the first Request
syncCh <- struct{}{}
Expand All @@ -871,21 +870,37 @@ func TestNSMGR_RefreshFailed_ControlPlaneBroken(t *testing.T) {

// refresh interval in this test is expected to be 3 minutes and a few milliseconds
clk.Add(time.Second * 190)
// Start a goroutine that advances the mock clock every 50 ms. This is needed so that refresh retries can fire.
go func() {
tickerDuration := time.Millisecond * 50
tickCh := time.Tick(tickerDuration)
for {
select {
case <-ctx.Done():
return
case <-tickCh:
clk.Add(tickerDuration)
}
}
}()

// kill the forwarder during the healing Request (it is stopped by syncCh). Then continue - the healing process will fail.
for _, forwarder := range domain.Nodes[0].Forwarders {
// kill the forwarder during the refresh (it is stopped by syncCh). Then continue - the refresh will fail.
for idx := range domain.Nodes[0].Forwarders {
forwarder := domain.Nodes[0].Forwarders[idx]
forwarder.Cancel()
break
// wait until the forwarder dies
require.Eventually(t, func() bool {
return sandbox.CheckURLFree(forwarder.URL)
}, timeout, tick)
}
syncCh <- struct{}{}
close(syncCh)

// create a new forwarder and allow the healing Request
forwarderReg := &registry.NetworkServiceEndpoint{
Name: sandbox.UniqueName("forwarder-2"),
NetworkServiceNames: []string{"forwarder"},
}
domain.Nodes[0].NewForwarder(ctx, forwarderReg, sandbox.GenerateTestToken)
syncCh <- struct{}{}

// wait till Request reached NSE
require.Eventually(t, func() bool {
Expand Down
27 changes: 14 additions & 13 deletions pkg/networkservice/chains/nsmgr/single_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2020-2022 Doc.ai and/or its affiliates.
//
// Copyright (c) 2023 Cisco and/or its affiliates.
// Copyright (c) 2023-2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -48,7 +48,6 @@ import (
"github.com/networkservicemesh/sdk/pkg/networkservice/common/authorize"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/begin"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/excludedprefixes"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/heal"
"github.com/networkservicemesh/sdk/pkg/networkservice/ipam/point2pointipam"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkcontext"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkrequest"
Expand Down Expand Up @@ -625,7 +624,7 @@ func Test_RestartDuringRefresh(t *testing.T) {
require.NoError(t, err)

var countServer count.Server
var countClint count.Client
var countClientBack count.ClientBackward
var m sync.Once
var clientFactory begin.EventFactory
var destroyFwd atomic.Bool
Expand All @@ -636,16 +635,21 @@ func Test_RestartDuringRefresh(t *testing.T) {
NetworkServiceNames: []string{"ns"},
}, sandbox.GenerateTestToken, &countServer, checkrequest.NewServer(t, func(t *testing.T, nsr *networkservice.NetworkServiceRequest) {
if destroyFwd.Load() {
e.AsyncExec(func() {
for _, fwd := range domain.Nodes[0].Forwarders {
fwd.Cancel()
<-e.AsyncExec(func() {
for idx := range domain.Nodes[0].Forwarders {
forwarder := domain.Nodes[0].Forwarders[idx]
forwarder.Cancel()
// wait until the forwarder dies
require.Eventually(t, func() bool {
return sandbox.CheckURLFree(forwarder.URL)
}, timeout, tick)
}
})
}
}))

var nsc = domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, client.WithAdditionalFunctionality(
&countClint,
&countClientBack,
checkcontext.NewClient(t, func(t *testing.T, ctx context.Context) {
m.Do(func() {
clientFactory = begin.FromContext(ctx)
Expand All @@ -660,7 +664,6 @@ func Test_RestartDuringRefresh(t *testing.T) {
})
}
}),
heal.NewClient(ctx),
))

_, err = nsc.Request(ctx, &networkservice.NetworkServiceRequest{
Expand All @@ -673,16 +676,14 @@ func Test_RestartDuringRefresh(t *testing.T) {
<-clientFactory.Request()
require.Equal(t, 2, countServer.Requests())
require.Never(t, func() bool { return countServer.Requests() > 2 }, time.Second/2, time.Second/20)
destroyFwd.Store(true)
for i := 0; i < 15; i++ {
var cs = countServer.Requests()
destroyFwd.Store(true)
err = <-clientFactory.Request()
require.Error(t, err)
var cc = countClientBack.Requests()
destroyFwd.Store(false)
var cc = countClint.Requests()
require.Eventually(t, func() bool { return cs < countServer.Requests() }, time.Second*2, time.Second/20)
require.Eventually(t, func() bool { return cc < countClint.Requests() }, time.Second*2, time.Second/20)
// Heal must be successful eventually
require.Eventually(t, func() bool { return cc < countClientBack.Requests() }, time.Second*2, time.Second/20)
}
}

Expand Down
4 changes: 3 additions & 1 deletion pkg/networkservice/common/connect/server_test.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) 2020-2022 Doc.ai and/or its affiliates.
//
// Copyright (c) 2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -197,7 +199,7 @@ func TestConnectServer_RequestParallel(t *testing.T) {
connect.NewServer(
next.NewNetworkServiceClient(
dial.NewClient(context.Background(),
dial.WithDialTimeout(time.Second),
dial.WithDialTimeout(time.Second*2),
dial.WithDialOptions(grpc.WithTransportCredentials(insecure.NewCredentials())),
),
serverClient,
Expand Down
11 changes: 5 additions & 6 deletions pkg/networkservice/common/refresh/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2020 Cisco Systems, Inc.
// Copyright (c) 2020-2024 Cisco Systems, Inc.
//
// Copyright (c) 2020-2022 Doc.ai and/or its affiliates.
// Copyright (c) 2020-2024 Doc.ai and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -68,18 +68,17 @@ func (t *refreshClient) Request(ctx context.Context, request *networkservice.Net
store(ctx, metadata.IsClient(t), cancel)

eventFactory := begin.FromContext(ctx)
clockTime := clock.FromContext(ctx)
// Create afterCh *outside* the goroutine. This must be done to avoid picking up a later 'now'
// from the mock clock in tests.
afterTicker := clockTime.Ticker(refreshAfter)
afterCh := clock.FromContext(ctx).After(refreshAfter)
go func() {
defer afterTicker.Stop()
for {
select {
case <-cancelCtx.Done():
return
case <-afterTicker.C():
case <-afterCh:
if err := <-eventFactory.Request(begin.CancelContext(cancelCtx)); err != nil {
afterCh = clock.FromContext(ctx).After(time.Millisecond * 200)
logger.Warnf("refresh failed: %s", err.Error())
continue
}
Expand Down
57 changes: 45 additions & 12 deletions pkg/networkservice/common/refresh/client_test.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) 2020-2021 Doc.ai and/or its affiliates.
//
// Copyright (c) 2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -48,6 +50,7 @@ import (

const (
expireTimeout = 15 * time.Minute
retryTimeout = 200 * time.Millisecond
testWait = 100 * time.Millisecond
testTick = testWait / 100

Expand All @@ -69,19 +72,19 @@ func testTokenFuncWithTimeout(clockTime clock.Clock, timeout time.Duration) toke
}
}

type captureTickerDuration struct {
type captureAfterDuration struct {
*clockmock.Mock

tickerDuration time.Duration
afterDuration time.Duration
}

func (m *captureTickerDuration) Ticker(d time.Duration) clock.Ticker {
m.tickerDuration = d
return m.Mock.Ticker(d)
func (m *captureAfterDuration) After(d time.Duration) <-chan time.Time {
m.afterDuration = d
return m.Mock.After(d)
}

func (m *captureTickerDuration) Reset(t time.Time) {
m.tickerDuration = 0
func (m *captureAfterDuration) Reset(t time.Time) {
m.afterDuration = 0
m.Set(t)
}

Expand Down Expand Up @@ -355,7 +358,7 @@ func TestRefreshClient_CalculatesShortestTokenTimeout(t *testing.T) {

timeNow := time.Date(2009, 11, 10, 23, 0, 0, 0, time.Local)

clockMock := captureTickerDuration{
clockMock := captureAfterDuration{
Mock: clockmock.New(ctx),
}

Expand Down Expand Up @@ -389,14 +392,14 @@ func TestRefreshClient_CalculatesShortestTokenTimeout(t *testing.T) {
})
require.NoError(t, err)

require.Less(t, clockMock.tickerDuration, testDataElement.ExpectedRefreshTimeout+timeoutDelta)
require.Greater(t, clockMock.tickerDuration, testDataElement.ExpectedRefreshTimeout-timeoutDelta)
require.Less(t, clockMock.afterDuration, testDataElement.ExpectedRefreshTimeout+timeoutDelta)
require.Greater(t, clockMock.afterDuration, testDataElement.ExpectedRefreshTimeout-timeoutDelta)
}

require.Equal(t, countClient.Requests(), len(testData))
}

func TestRefreshClient_RefreshOnRefreshFailure(t *testing.T) {
func TestRefreshClient_RetryOnRefreshFailure(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -422,7 +425,37 @@ func TestRefreshClient_RefreshOnRefreshFailure(t *testing.T) {

require.Eventually(t, cloneClient.validator(2), testWait, testTick)

clockMock.Add(expireTimeout)
clockMock.Add(retryTimeout)

require.Eventually(t, cloneClient.validator(3), testWait, testTick)
}

// TestRefreshClient_NoRetryOnRefreshSuccess verifies that, after a refresh triggered by
// advancing the mock clock by expireTimeout has been observed, advancing the clock by a
// further retryTimeout does not produce an additional request.
func TestRefreshClient_NoRetryOnRefreshSuccess(t *testing.T) {
	t.Cleanup(func() { goleak.VerifyNone(t) })

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mockClock := clockmock.New(ctx)
	counter := &countClient{t: t}

	refreshClient := testClient(ctx, testTokenFunc(mockClock), mockClock, counter)

	// Initial request: must succeed before any refresh can be scheduled.
	initialRequest := &networkservice.NetworkServiceRequest{
		Connection: new(networkservice.Connection),
	}
	_, err := refreshClient.Request(ctx, initialRequest)
	require.NoError(t, err)

	// Move the mock clock forward by expireTimeout and wait for the second request.
	// NOTE(review): presumably validator(n) reports whether n requests were seen — confirm against countClient.
	mockClock.Add(expireTimeout)
	require.Eventually(t, counter.validator(2), testWait, testTick)

	// Moving forward by the retry interval must not result in a third request.
	mockClock.Add(retryTimeout)
	require.Never(t, counter.validator(3), testWait, testTick)
}
72 changes: 72 additions & 0 deletions pkg/networkservice/utils/count/client_backward.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (c) 2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package count

import (
"context"
"sync/atomic"

"google.golang.org/grpc"

"github.com/golang/protobuf/ptypes/empty"
"github.com/networkservicemesh/api/pkg/api/networkservice"

"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
)

// ClientBackward counts Requests and Closes on the way back through the chain, i.e. only those the rest of the chain completed without error
type ClientBackward struct {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of creating a new chain element, I'd suggest updating the count chain element to support the backward counters

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Client
}

// Request passes the request to the next client in the chain and, only when that call
// returns no error, increments the total request counter and the counter kept for the
// request's connection ID. The connection and error from the next client are returned as-is.
func (c *ClientBackward) Request(ctx context.Context, request *networkservice.NetworkServiceRequest, opts ...grpc.CallOption) (*networkservice.Connection, error) {
	conn, err := next.Client(ctx).Request(ctx, request, opts...)
	if err == nil {
		// Counting happens on the way back, so failed requests are never recorded.
		c.mu.Lock()
		defer c.mu.Unlock()

		atomic.AddInt32(&c.totalRequests, 1)
		// The per-connection map is created lazily on first use.
		if c.requests == nil {
			c.requests = make(map[string]int32)
		}
		connID := request.GetConnection().GetId()
		c.requests[connID]++
	}

	return conn, err
}

// Close passes the close to the next client in the chain and, only when that call
// returns no error, increments the total close counter and the counter kept for the
// connection's ID. The result and error from the next client are returned as-is.
func (c *ClientBackward) Close(ctx context.Context, connection *networkservice.Connection, opts ...grpc.CallOption) (*empty.Empty, error) {
	result, err := next.Client(ctx).Close(ctx, connection, opts...)
	if err == nil {
		// Counting happens on the way back, so failed closes are never recorded.
		c.mu.Lock()
		defer c.mu.Unlock()

		atomic.AddInt32(&c.totalCloses, 1)
		// The per-connection map is created lazily on first use.
		if c.closes == nil {
			c.closes = make(map[string]int32)
		}
		connID := connection.GetId()
		c.closes[connID]++
	}

	return result, err
}
Loading