Skip to content

Commit 4579a60

Browse files
authored
pass a single context throughout the device-plugin method call stack (#1284)
This change follows the Go best practices. With a single ctx reference, we allow for the proper propagation of cancellations and graceful terminations across all goroutines of the device-plugin application. Signed-off-by: Tariq Ibrahim <[email protected]>
1 parent 0ad1fdb commit 4579a60

File tree

4 files changed

+12
-8
lines changed

4 files changed

+12
-8
lines changed

cmd/nvidia-device-plugin/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ func startPlugins(c *cli.Context, o *options) ([]plugin.Interface, bool, error)
353353

354354
// Get the set of plugins.
355355
klog.Info("Retrieving plugins.")
356-
plugins, err := GetPlugins(infolib, nvmllib, devicelib, config)
356+
plugins, err := GetPlugins(c.Context, infolib, nvmllib, devicelib, config)
357357
if err != nil {
358358
return nil, false, fmt.Errorf("error getting plugins: %v", err)
359359
}

cmd/nvidia-device-plugin/plugin-manager.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package main
1818

1919
import (
20+
"context"
2021
"fmt"
2122

2223
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
@@ -30,7 +31,7 @@ import (
3031
)
3132

3233
// GetPlugins returns a set of plugins for the specified configuration.
33-
func GetPlugins(infolib info.Interface, nvmllib nvml.Interface, devicelib device.Interface, config *spec.Config) ([]plugin.Interface, error) {
34+
func GetPlugins(ctx context.Context, infolib info.Interface, nvmllib nvml.Interface, devicelib device.Interface, config *spec.Config) ([]plugin.Interface, error) {
3435
// TODO: We could consider passing this as an argument since it should already be used to construct nvmllib.
3536
driverRoot := root(*config.Flags.Plugin.ContainerDriverRoot)
3637

@@ -61,7 +62,7 @@ func GetPlugins(infolib info.Interface, nvmllib nvml.Interface, devicelib device
6162
return nil, fmt.Errorf("unable to create cdi handler: %v", err)
6263
}
6364

64-
plugins, err := plugin.New(infolib, nvmllib, devicelib,
65+
plugins, err := plugin.New(ctx, infolib, nvmllib, devicelib,
6566
plugin.WithCDIHandler(cdiHandler),
6667
plugin.WithConfig(config),
6768
plugin.WithDeviceListStrategies(deviceListStrategies),

internal/plugin/factory.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package plugin
1818

1919
import (
20+
"context"
2021
"fmt"
2122

2223
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
@@ -46,7 +47,7 @@ type options struct {
4647
}
4748

4849
// New a new set of plugins with the supplied options.
49-
func New(infolib info.Interface, nvmllib nvml.Interface, devicelib device.Interface, opts ...Option) ([]Interface, error) {
50+
func New(ctx context.Context, infolib info.Interface, nvmllib nvml.Interface, devicelib device.Interface, opts ...Option) ([]Interface, error) {
5051
o := &options{
5152
infolib: infolib,
5253
nvmllib: nvmllib,
@@ -72,7 +73,7 @@ func New(infolib info.Interface, nvmllib nvml.Interface, devicelib device.Interf
7273

7374
var plugins []Interface
7475
for _, resourceManager := range resourceManagers {
75-
plugin, err := o.devicePluginForResource(resourceManager)
76+
plugin, err := o.devicePluginForResource(ctx, resourceManager)
7677
if err != nil {
7778
return nil, fmt.Errorf("failed to create plugin: %w", err)
7879
}

internal/plugin/server.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ const (
5050

5151
// nvidiaDevicePlugin implements the Kubernetes device plugin API
5252
type nvidiaDevicePlugin struct {
53+
ctx context.Context
5354
rm rm.ResourceManager
5455
config *spec.Config
5556
deviceListStrategies spec.DeviceListStrategies
@@ -68,13 +69,14 @@ type nvidiaDevicePlugin struct {
6869
}
6970

7071
// devicePluginForResource creates a device plugin for the specified resource.
71-
func (o *options) devicePluginForResource(resourceManager rm.ResourceManager) (Interface, error) {
72+
func (o *options) devicePluginForResource(ctx context.Context, resourceManager rm.ResourceManager) (Interface, error) {
7273
mpsOptions, err := o.getMPSOptions(resourceManager)
7374
if err != nil {
7475
return nil, err
7576
}
7677

7778
plugin := nvidiaDevicePlugin{
79+
ctx: ctx,
7880
rm: resourceManager,
7981
config: o.config,
8082
deviceListStrategies: o.deviceListStrategies,
@@ -245,7 +247,7 @@ func (plugin *nvidiaDevicePlugin) Register(kubeletSocket string) error {
245247
},
246248
}
247249

248-
_, err = client.Register(context.Background(), reqt)
250+
_, err = client.Register(plugin.ctx, reqt)
249251
if err != nil {
250252
return err
251253
}
@@ -432,7 +434,7 @@ func (plugin *nvidiaDevicePlugin) PreStartContainer(context.Context, *pluginapi.
432434

433435
// dial establishes the gRPC communication with the registered device plugin.
434436
func (plugin *nvidiaDevicePlugin) dial(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error) {
435-
ctx, cancel := context.WithTimeout(context.Background(), timeout)
437+
ctx, cancel := context.WithTimeout(plugin.ctx, timeout)
436438
defer cancel()
437439
//nolint:staticcheck // TODO: Switch to grpc.NewClient
438440
c, err := grpc.DialContext(ctx, unixSocketPath,

0 commit comments

Comments
 (0)