diff --git a/pkg/nvcdi/common-nvml.go b/pkg/nvcdi/common-nvml.go index 4c634a72d..e21d92b06 100644 --- a/pkg/nvcdi/common-nvml.go +++ b/pkg/nvcdi/common-nvml.go @@ -41,7 +41,7 @@ func (l *nvmllib) newCommonNVMLDiscoverer() (discover.Discover, error) { l.logger.Warningf("failed to create discoverer for graphics mounts: %v", err) } - driverFiles, err := NewDriverDiscoverer(l.logger, l.driver, l.nvidiaCTKPath, l.nvmllib) + driverFiles, err := NewDriverDiscoverer(l.logger, l.driver, l.nvidiaCTKPath, l.nvmllib, l.firmwareSearchPaths...) if err != nil { return nil, fmt.Errorf("failed to create discoverer for driver files: %v", err) } diff --git a/pkg/nvcdi/driver-nvml.go b/pkg/nvcdi/driver-nvml.go index 28bd0704a..e61858788 100644 --- a/pkg/nvcdi/driver-nvml.go +++ b/pkg/nvcdi/driver-nvml.go @@ -23,17 +23,18 @@ import ( "strings" "github.com/NVIDIA/go-nvlib/pkg/nvml" + "golang.org/x/sys/unix" + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/cuda" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" - "golang.org/x/sys/unix" ) // NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver installation. // The supplied NVML Library is used to query the expected driver version. -func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) { +func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, nvmllib nvml.Interface, fwSearchPaths ...string) (discover.Discover, error) { if r := nvmllib.Init(); r != nvml.SUCCESS { return nil, fmt.Errorf("failed to initialize NVML: %v", r) } @@ -48,10 +49,10 @@ func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTK return nil, fmt.Errorf("failed to determine driver version: %v", r) } - return newDriverVersionDiscoverer(logger, driver, nvidiaCTKPath, version) + return newDriverVersionDiscoverer(logger, driver, nvidiaCTKPath, version, fwSearchPaths...) } -func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, version string) (discover.Discover, error) { +func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, version string, fwSearchPaths ...string) (discover.Discover, error) { libraries, err := NewDriverLibraryDiscoverer(logger, driver, nvidiaCTKPath, version) if err != nil { return nil, fmt.Errorf("failed to create discoverer for driver libraries: %v", err) @@ -62,7 +63,7 @@ func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nv return nil, fmt.Errorf("failed to create discoverer for IPC sockets: %v", err) } - firmwares, err := NewDriverFirmwareDiscoverer(logger, driver.Root, version) + firmwares, err := NewDriverFirmwareDiscoverer(logger, driver.Root, version, fwSearchPaths...) if err != nil { return nil, fmt.Errorf("failed to create discoverer for GSP firmware: %v", err) } @@ -114,7 +115,11 @@ func getUTSRelease() (string, error) { return unix.ByteSliceToString(utsname.Release[:]), nil } -func getFirmwareSearchPaths(logger logger.Interface) ([]string, error) { +func getFirmwareSearchPaths(logger logger.Interface, fwSearchPaths ...string) ([]string, error) { + if len(fwSearchPaths) > 0 { + logger.Debugf("using custom firmware search paths configured with the library: %v", fwSearchPaths) + return fwSearchPaths, nil + } var firmwarePaths []string if p := getCustomFirmwareClassPath(logger); p != "" { @@ -149,8 +154,8 @@ func getCustomFirmwareClassPath(logger logger.Interface) string { } // NewDriverFirmwareDiscoverer creates a discoverer for GSP firmware associated with the specified driver version. -func NewDriverFirmwareDiscoverer(logger logger.Interface, driverRoot string, version string) (discover.Discover, error) { - gspFirmwareSearchPaths, err := getFirmwareSearchPaths(logger) +func NewDriverFirmwareDiscoverer(logger logger.Interface, driverRoot string, version string, fwSearchPaths ...string) (discover.Discover, error) { + gspFirmwareSearchPaths, err := getFirmwareSearchPaths(logger, fwSearchPaths...) if err != nil { return nil, fmt.Errorf("failed to get firmware search paths: %v", err) } diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 3839697cd..842154b8f 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -22,6 +22,7 @@ import ( "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" "github.com/NVIDIA/go-nvlib/pkg/nvml" + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root" "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" @@ -39,15 +40,16 @@ type wrapper struct { } type nvcdilib struct { - logger logger.Interface - nvmllib nvml.Interface - mode string - devicelib device.Interface - deviceNamer DeviceNamer - driverRoot string - devRoot string - nvidiaCTKPath string - librarySearchPaths []string + logger logger.Interface + nvmllib nvml.Interface + mode string + devicelib device.Interface + deviceNamer DeviceNamer + driverRoot string + devRoot string + nvidiaCTKPath string + librarySearchPaths []string + firmwareSearchPaths []string csvFiles []string csvIgnorePatterns []string diff --git a/pkg/nvcdi/management.go b/pkg/nvcdi/management.go index 460a48739..284cf7788 100644 --- a/pkg/nvcdi/management.go +++ b/pkg/nvcdi/management.go @@ -22,12 +22,13 @@ import ( "strings" "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" + "tags.cncf.io/container-device-interface/pkg/cdi" + "tags.cncf.io/container-device-interface/specs-go" + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/cuda" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" - "tags.cncf.io/container-device-interface/pkg/cdi" - "tags.cncf.io/container-device-interface/specs-go" ) type managementlib nvcdilib @@ -65,7 +66,7 @@ func (m *managementlib) GetCommonEdits() (*cdi.ContainerEdits, error) { return nil, fmt.Errorf("failed to get CUDA version: %v", err) } - driver, err := newDriverVersionDiscoverer(m.logger, m.driver, m.nvidiaCTKPath, version) + driver, err := newDriverVersionDiscoverer(m.logger, m.driver, m.nvidiaCTKPath, version, m.firmwareSearchPaths...) if err != nil { return nil, fmt.Errorf("failed to create driver library discoverer: %v", err) } diff --git a/pkg/nvcdi/options.go b/pkg/nvcdi/options.go index 86bb877de..a4a288385 100644 --- a/pkg/nvcdi/options.go +++ b/pkg/nvcdi/options.go @@ -19,6 +19,7 @@ package nvcdi import ( "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" "github.com/NVIDIA/go-nvlib/pkg/nvml" + "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" ) @@ -125,3 +126,11 @@ func WithLibrarySearchPaths(paths []string) Option { o.librarySearchPaths = paths } } + +// WithFirmwareSearchPaths sets the firmware search paths. +// This is currently only used for NVML- and Management-Mode. +func WithFirmwareSearchPaths(paths []string) Option { + return func(o *nvcdilib) { + o.firmwareSearchPaths = paths + } +} diff --git a/tools/container/toolkit/toolkit.go b/tools/container/toolkit/toolkit.go index d57ba9c1f..03446ce57 100644 --- a/tools/container/toolkit/toolkit.go +++ b/tools/container/toolkit/toolkit.go @@ -23,15 +23,16 @@ import ( "path/filepath" "strings" - "github.com/NVIDIA/nvidia-container-toolkit/internal/config" - "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices" - "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" - "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" toml "github.com/pelletier/go-toml" log "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" "tags.cncf.io/container-device-interface/pkg/cdi" "tags.cncf.io/container-device-interface/pkg/parser" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/config" + "github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" ) const ( @@ -63,11 +64,12 @@ type options struct { ContainerCLIDebug string toolkitRoot string - cdiEnabled bool - cdiOutputDir string - cdiKind string - cdiVendor string - cdiClass string + cdiEnabled bool + cdiOutputDir string + cdiKind string + cdiVendor string + cdiClass string + cdiFirmwareSearchPaths cli.StringSlice acceptNVIDIAVisibleDevicesWhenUnprivileged bool acceptNVIDIAVisibleDevicesAsVolumeMounts bool @@ -216,6 +218,12 @@ func main() { Destination: &opts.cdiKind, EnvVars: []string{"CDI_KIND"}, }, + &cli.StringSliceFlag{ + Name: "cdi-firmware-search-paths", + Usage: "specify custom firmware search paths to be used during generation of a CDI specification", + Destination: &opts.cdiFirmwareSearchPaths, + EnvVars: []string{"CDI_FIRMWARE_SEARCH_PATHS"}, + }, &cli.BoolFlag{ Name: "ignore-errors", Usage: "ignore errors when installing the NVIDIA Container toolkit. This is used for testing purposes only.", @@ -701,6 +709,7 @@ func generateCDISpec(opts *options, nvidiaCTKPath string) error { nvcdi.WithMode(nvcdi.ModeManagement), nvcdi.WithDriverRoot(opts.DriverRootCtrPath), nvcdi.WithNVIDIACTKPath(nvidiaCTKPath), + nvcdi.WithFirmwareSearchPaths(opts.cdiFirmwareSearchPaths.Value()), nvcdi.WithVendor(opts.cdiVendor), nvcdi.WithClass(opts.cdiClass), )