Skip to content

Commit

Permalink
Merge pull request #82 from adrianchiris/fix-kubelet-restart
Browse files Browse the repository at this point in the history
Fix kubelet restart
  • Loading branch information
e0ne authored Aug 15, 2023
2 parents fc182dd + ecbdda3 commit b3d2b2a
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 87 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ GOLANGCI_LINT = $(BINDIR)/golangci-lint
# we keep it fixed to prevent it from unexpectedly failing on the project
# in case of a version bump
GOLANGCI_LINT_VER = v1.51.2
TIMEOUT = 15
TIMEOUT = 20
Q = $(if $(filter 1,$V),,@)

.PHONY: all
Expand Down
5 changes: 4 additions & 1 deletion cmd/k8s-rdma-shared-dp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ func main() {

// add version flag
versionOpt := false
var configFilePath string
flag.BoolVar(&versionOpt, "version", false, "Show application version")
flag.BoolVar(&versionOpt, "v", false, "Show application version")
flag.StringVar(
&configFilePath, "config-file", resources.DefaultConfigFilePath, "path to device plugin config file")
flag.Parse()
if versionOpt {
fmt.Printf("%s\n", printVersionString())
Expand All @@ -36,7 +39,7 @@ func main() {

log.Println("Starting K8s RDMA Shared Device Plugin version=", version)

rm := resources.NewResourceManager()
rm := resources.NewResourceManager(configFilePath)

log.Println("resource manager reading configs")
if err := rm.ReadConfig(); err != nil {
Expand Down
6 changes: 3 additions & 3 deletions pkg/resources/resources_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (

const (
// General constants
configFilePath = "/k8s-rdma-shared-dev-plugin/config.json"
DefaultConfigFilePath = "/k8s-rdma-shared-dev-plugin/config.json"
kubeEndPoint = "kubelet.sock"
socketSuffix = "sock"
rdmaHcaResourcePrefix = "rdma"
Expand Down Expand Up @@ -54,15 +54,15 @@ type resourceManager struct {
PeriodicUpdateInterval time.Duration
}

func NewResourceManager() types.ResourceManager {
func NewResourceManager(configFile string) types.ResourceManager {
watcherMode := detectPluginWatchMode(activeSockDir)
if watcherMode {
fmt.Println("Using Kubelet Plugin Registry Mode")
} else {
fmt.Println("Using Deprecated Devie Plugin Registry Path")
}
return &resourceManager{
configFile: configFilePath,
configFile: configFile,
defaultResourcePrefix: rdmaHcaResourcePrefix,
socketSuffix: socketSuffix,
watchMode: watcherMode,
Expand Down
4 changes: 2 additions & 2 deletions pkg/resources/resources_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ var _ = Describe("ResourcesManger", func() {
activeSockDir = activeSockDirBackUP
}()

obj := NewResourceManager()
obj := NewResourceManager(DefaultConfigFilePath)
rm := obj.(*resourceManager)
Expect(rm.watchMode).To(Equal(true))
})
Expand All @@ -58,7 +58,7 @@ var _ = Describe("ResourcesManger", func() {
activeSockDir = activeSockDirBackUP
}()

obj := NewResourceManager()
obj := NewResourceManager(DefaultConfigFilePath)
rm := obj.(*resourceManager)
Expect(rm.watchMode).To(Equal(false))
})
Expand Down
39 changes: 10 additions & 29 deletions pkg/resources/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ type resourceServer struct {
watchMode bool
socketName string
socketPath string
stop chan interface{}
stopWatcher chan bool
updateResource chan bool
health chan *pluginapi.Device
Expand Down Expand Up @@ -84,26 +83,17 @@ func (rsc *resourcesServerPort) Register(client pluginapi.RegistrationClient, re
func (rsc *resourcesServerPort) Dial(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error) {
var c *grpc.ClientConn
var err error
connChannel := make(chan interface{})

ctx, timeoutCancel := context.WithTimeout(context.TODO(), timeout)
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
defer timeoutCancel()
go func() {
c, err = grpc.DialContext(ctx, unixSocketPath, grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return net.Dial("unix", addr)
}),
)
connChannel <- "done"
}()

select {
case <-ctx.Done():
return nil, fmt.Errorf("timout while trying to connect %s", unixSocketPath)
c, err = grpc.DialContext(
ctx, "unix://"+unixSocketPath, grpc.WithBlock(), grpc.WithTransportCredentials(insecure.NewCredentials()))

case <-connChannel:
return c, err
if err != nil {
return nil, fmt.Errorf("failed to connect %s, %w", unixSocketPath, err)
}

return c, nil
}

// newResourceServer returns an initialized server
Expand Down Expand Up @@ -148,7 +138,6 @@ func newResourceServer(config *types.UserConfig, devices []types.PciNetDevice, w
watchMode: watcherMode,
devs: devs,
deviceSpec: deviceSpec,
stop: make(chan interface{}),
stopWatcher: make(chan bool),
updateResource: make(chan bool, 1),
health: make(chan *pluginapi.Device),
Expand Down Expand Up @@ -207,12 +196,11 @@ func (rs *resourceServer) Stop() error {
return nil
}

// Send terminate signal to ListAndWatch()
rs.stop <- true
if !rs.watchMode {
rs.stopWatcher <- true
}

// Note: stopping RPC server will cancel any outstanding ListAndWatch() calls
rs.rsConnector.Stop()
rs.rsConnector.DeleteServer()

Expand All @@ -229,9 +217,6 @@ func (rs *resourceServer) Restart() error {
rs.rsConnector.Stop()
rs.rsConnector.DeleteServer()

// Send terminate signal to ListAndWatch()
rs.stop <- true

return rs.Start()
}

Expand Down Expand Up @@ -282,6 +267,7 @@ func (rs *resourceServer) register() error {

// ListAndWatch lists devices and updates that list according to the health status
func (rs *resourceServer) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error {
log.Printf("ListAndWatch called by kubelet for: %s", rs.resourceName)
resp := new(pluginapi.ListAndWatchResponse)

// Send initial list of devices
Expand All @@ -294,8 +280,6 @@ func (rs *resourceServer) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlu
case <-s.Context().Done():
log.Printf("ListAndWatch stream close: %v", s.Context().Err())
return nil
case <-rs.stop:
return nil
case d := <-rs.health:
// FIXME: there is no way to recover from the Unhealthy state.
d.Health = pluginapi.Unhealthy
Expand Down Expand Up @@ -418,12 +402,11 @@ func (rs *resourceServer) UpdateDevices(devices []types.PciNetDevice) {
}

rs.deviceSpec = deviceSpec
needUpdate = true

// If there are no RDMA resources, report 0 resources
if len(rs.deviceSpec) == 0 {
rs.devs = []*pluginapi.Device{}
needUpdate = true

return
}

Expand All @@ -440,8 +423,6 @@ func (rs *resourceServer) UpdateDevices(devices []types.PciNetDevice) {
}
rs.devs = devs
}

needUpdate = true
}

func (rs *resourceServer) GetPreferredAllocation(
Expand Down
67 changes: 16 additions & 51 deletions pkg/resources/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"path"
"sync"
"time"

"github.com/Mellanox/k8s-rdma-shared-dev-plugin/pkg/types"
Expand Down Expand Up @@ -215,14 +216,8 @@ var _ = Describe("resourceServer tests", func() {
rs := resourceServer{
rsConnector: rsc,
watchMode: true,
stop: make(chan interface{}),
}

go func() {
stop := <-rs.stop
Expect(stop).To(BeTrue())
}()

err := rs.Stop()
Expect(err).ToNot(HaveOccurred())
rsc.AssertExpectations(testCallsAssertionReporter)
Expand All @@ -239,13 +234,10 @@ var _ = Describe("resourceServer tests", func() {
rsConnector: rsc,
watchMode: false,
stopWatcher: stopWatcher,
stop: make(chan interface{}),
}
// Dummy listener on stopWatcher so that the test does not block and fail
go func() {
stop := <-rs.stop
Expect(stop).To(BeTrue())
stop = <-rs.stopWatcher
stop := <-rs.stopWatcher
Expect(stop).To(BeTrue())
}()

Expand All @@ -272,14 +264,8 @@ var _ = Describe("resourceServer tests", func() {
rs := resourceServer{
watchMode: true,
rsConnector: rsc,
stop: make(chan interface{}),
}

go func() {
stop := <-rs.stop
Expect(stop).To(BeTrue())
}()

err := rs.Restart()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("failed in restart"))
Expand All @@ -288,7 +274,6 @@ var _ = Describe("resourceServer tests", func() {
It("Failed to restart server with no grpc server", func() {
rs := resourceServer{
watchMode: true,
stop: make(chan interface{}),
}

err := rs.Restart()
Expand All @@ -308,14 +293,8 @@ var _ = Describe("resourceServer tests", func() {
rs := resourceServer{
watchMode: true,
rsConnector: rsc,
stop: make(chan interface{}),
}

go func() {
stop := <-rs.stop
Expect(stop).To(BeTrue())
}()

err := rs.Restart()
Expect(err).To(HaveOccurred())
rsc.AssertExpectations(testCallsAssertionReporter)
Expand Down Expand Up @@ -344,7 +323,6 @@ var _ = Describe("resourceServer tests", func() {
socketName: fakeSocketName,
socketPath: fakeSocketPath,
stopWatcher: make(chan bool),
stop: make(chan interface{}),
}
go func() {
rs.stopWatcher <- true
Expand All @@ -368,12 +346,10 @@ var _ = Describe("resourceServer tests", func() {
rsConnector: rsc,
socketName: fakeSocketName,
socketPath: "fake deleted",
stop: make(chan interface{}),
stopWatcher: make(chan bool),
}
go func() {
stop := <-rs.stop
Expect(stop).To(BeTrue())
time.Sleep(50 * time.Millisecond)
rs.stopWatcher <- true
}()
rs.Watch()
Expand All @@ -393,18 +369,19 @@ var _ = Describe("resourceServer tests", func() {

rs := obj.(*resourceServer)

rs.stop = make(chan interface{})
rs.health = make(chan *pluginapi.Device)
// Dummy sender
ctx, cancel := context.WithCancel(context.Background())
s := &devPluginListAndWatchServerMock{}
s.SetContext(ctx)

// report unhealthy devices then cancel context
go func() {
rs.health <- rs.devs[5]
// Make sure the health call happens before the stop
time.Sleep(1 * time.Millisecond)
rs.stop <- "stop"
cancel()
}()

s := &devPluginListAndWatchServerMock{}
s.SetContext(context.Background())
err = rs.ListAndWatch(nil, s)
Expect(err).ToNot(HaveOccurred())
Expect(s.devices).To(Equal(rs.devs))
Expand Down Expand Up @@ -599,17 +576,11 @@ var _ = Describe("resourceServer tests", func() {
err = rs.Start()
Expect(err).NotTo(HaveOccurred())

go func() {
stop := <-rs.stop
Expect(stop).To(BeTrue())
}()
err = rs.Restart()
Expect(err).NotTo(HaveOccurred())

go func() {
stop := <-rs.stop
Expect(stop).To(BeTrue())
stop = <-rs.stopWatcher
stop := <-rs.stopWatcher
Expect(stop).To(BeTrue())
}()

Expand All @@ -635,21 +606,12 @@ var _ = Describe("resourceServer tests", func() {
err = registrationServer.registerPlugin()
Expect(err).NotTo(HaveOccurred())

go func() {
stop := <-rs.stop
Expect(stop).To(BeTrue())
}()
err = rs.Restart()
Expect(err).NotTo(HaveOccurred())

err = registrationServer.registerPlugin()
Expect(err).NotTo(HaveOccurred())

go func() {
stop := <-rs.stop
Expect(stop).To(BeTrue())
}()

err = rs.Stop()
Expect(err).NotTo(HaveOccurred())
})
Expand All @@ -673,13 +635,16 @@ var _ = Describe("resourceServer tests", func() {
err = rs.Start()
Expect(err).NotTo(HaveOccurred())
// run socket watcher in background as in real-life
go rs.Watch()
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
stop := <-rs.stop
Expect(stop).To(BeTrue())
defer wg.Done()
rs.Watch()
}()

err = rs.Stop()
Expect(err).NotTo(HaveOccurred())
wg.Wait()
})
})
})
Expand Down

0 comments on commit b3d2b2a

Please sign in to comment.