Skip to content

WIP: Add firmware-search-path option to nvcdi package #317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/nvcdi/common-nvml.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (l *nvmllib) newCommonNVMLDiscoverer() (discover.Discover, error) {
l.logger.Warningf("failed to create discoverer for graphics mounts: %v", err)
}

driverFiles, err := NewDriverDiscoverer(l.logger, l.driver, l.nvidiaCTKPath, l.nvmllib)
driverFiles, err := NewDriverDiscoverer(l.logger, l.driver, l.nvidiaCTKPath, l.nvmllib, l.firmwareSearchPaths...)
if err != nil {
return nil, fmt.Errorf("failed to create discoverer for driver files: %v", err)
}
Expand Down
21 changes: 13 additions & 8 deletions pkg/nvcdi/driver-nvml.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,18 @@ import (
"strings"

"github.com/NVIDIA/go-nvlib/pkg/nvml"
"golang.org/x/sys/unix"

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/cuda"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"golang.org/x/sys/unix"
)

// NewDriverDiscoverer creates a discoverer for the libraries and binaries associated with a driver installation.
// The supplied NVML Library is used to query the expected driver version.
func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) {
func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, nvmllib nvml.Interface, fwSearchPaths ...string) (discover.Discover, error) {
if r := nvmllib.Init(); r != nvml.SUCCESS {
return nil, fmt.Errorf("failed to initialize NVML: %v", r)
}
Expand All @@ -48,10 +49,10 @@ func NewDriverDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTK
return nil, fmt.Errorf("failed to determine driver version: %v", r)
}

return newDriverVersionDiscoverer(logger, driver, nvidiaCTKPath, version)
return newDriverVersionDiscoverer(logger, driver, nvidiaCTKPath, version, fwSearchPaths...)
}

func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, version string) (discover.Discover, error) {
func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nvidiaCTKPath string, version string, fwSearchPaths ...string) (discover.Discover, error) {
libraries, err := NewDriverLibraryDiscoverer(logger, driver, nvidiaCTKPath, version)
if err != nil {
return nil, fmt.Errorf("failed to create discoverer for driver libraries: %v", err)
Expand All @@ -62,7 +63,7 @@ func newDriverVersionDiscoverer(logger logger.Interface, driver *root.Driver, nv
return nil, fmt.Errorf("failed to create discoverer for IPC sockets: %v", err)
}

firmwares, err := NewDriverFirmwareDiscoverer(logger, driver.Root, version)
firmwares, err := NewDriverFirmwareDiscoverer(logger, driver.Root, version, fwSearchPaths...)
if err != nil {
return nil, fmt.Errorf("failed to create discoverer for GSP firmware: %v", err)
}
Expand Down Expand Up @@ -114,7 +115,11 @@ func getUTSRelease() (string, error) {
return unix.ByteSliceToString(utsname.Release[:]), nil
}

func getFirmwareSearchPaths(logger logger.Interface) ([]string, error) {
func getFirmwareSearchPaths(logger logger.Interface, fwSearchPaths ...string) ([]string, error) {
Copy link
Preview

Copilot AI May 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] While the fallback behavior is clear from the implementation, adding a brief note in the function comment about how custom firmware search paths override defaults would improve clarity.

Copilot uses AI. Check for mistakes.

if len(fwSearchPaths) > 0 {
logger.Debugf("using custom firmware search paths configured with the library: %v", fwSearchPaths)
return fwSearchPaths, nil
}

var firmwarePaths []string
if p := getCustomFirmwareClassPath(logger); p != "" {
Expand Down Expand Up @@ -149,8 +154,8 @@ func getCustomFirmwareClassPath(logger logger.Interface) string {
}

// NewDriverFirmwareDiscoverer creates a discoverer for GSP firmware associated with the specified driver version.
func NewDriverFirmwareDiscoverer(logger logger.Interface, driverRoot string, version string) (discover.Discover, error) {
gspFirmwareSearchPaths, err := getFirmwareSearchPaths(logger)
func NewDriverFirmwareDiscoverer(logger logger.Interface, driverRoot string, version string, fwSearchPaths ...string) (discover.Discover, error) {
gspFirmwareSearchPaths, err := getFirmwareSearchPaths(logger, fwSearchPaths...)
if err != nil {
return nil, fmt.Errorf("failed to get firmware search paths: %v", err)
}
Expand Down
20 changes: 11 additions & 9 deletions pkg/nvcdi/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
"github.com/NVIDIA/go-nvlib/pkg/nvml"

"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
Expand All @@ -39,15 +40,16 @@ type wrapper struct {
}

type nvcdilib struct {
logger logger.Interface
nvmllib nvml.Interface
mode string
devicelib device.Interface
deviceNamer DeviceNamer
driverRoot string
devRoot string
nvidiaCTKPath string
librarySearchPaths []string
logger logger.Interface
nvmllib nvml.Interface
mode string
devicelib device.Interface
deviceNamer DeviceNamer
driverRoot string
devRoot string
nvidiaCTKPath string
librarySearchPaths []string
firmwareSearchPaths []string

csvFiles []string
csvIgnorePatterns []string
Expand Down
7 changes: 4 additions & 3 deletions pkg/nvcdi/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ import (
"strings"

"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
"tags.cncf.io/container-device-interface/pkg/cdi"
"tags.cncf.io/container-device-interface/specs-go"

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/cuda"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
"tags.cncf.io/container-device-interface/pkg/cdi"
"tags.cncf.io/container-device-interface/specs-go"
)

type managementlib nvcdilib
Expand Down Expand Up @@ -65,7 +66,7 @@ func (m *managementlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
return nil, fmt.Errorf("failed to get CUDA version: %v", err)
}

driver, err := newDriverVersionDiscoverer(m.logger, m.driver, m.nvidiaCTKPath, version)
driver, err := newDriverVersionDiscoverer(m.logger, m.driver, m.nvidiaCTKPath, version, m.firmwareSearchPaths...)
if err != nil {
return nil, fmt.Errorf("failed to create driver library discoverer: %v", err)
}
Expand Down
9 changes: 9 additions & 0 deletions pkg/nvcdi/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package nvcdi
import (
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
"github.com/NVIDIA/go-nvlib/pkg/nvml"

"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
)
Expand Down Expand Up @@ -125,3 +126,11 @@ func WithLibrarySearchPaths(paths []string) Option {
o.librarySearchPaths = paths
}
}

// WithFirmwareSearchPaths sets the firmware search paths.
// When non-empty, these paths are used instead of the default firmware search paths.
// This is currently only used in the NVML and management modes.
func WithFirmwareSearchPaths(paths []string) Option {
Copy link
Preview

Copilot AI May 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Augment the function comment to include expected input formats and examples for firmware search paths to help clarify usage for future maintainers.

Copilot uses AI. Check for mistakes.

return func(o *nvcdilib) {
o.firmwareSearchPaths = paths
}
}
27 changes: 18 additions & 9 deletions tools/container/toolkit/toolkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@ import (
"path/filepath"
"strings"

"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
toml "github.com/pelletier/go-toml"
log "github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
"tags.cncf.io/container-device-interface/pkg/cdi"
"tags.cncf.io/container-device-interface/pkg/parser"

"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/system/nvdevices"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
)

const (
Expand Down Expand Up @@ -63,11 +64,12 @@ type options struct {
ContainerCLIDebug string
toolkitRoot string

cdiEnabled bool
cdiOutputDir string
cdiKind string
cdiVendor string
cdiClass string
cdiEnabled bool
cdiOutputDir string
cdiKind string
cdiVendor string
cdiClass string
cdiFirmwareSearchPaths cli.StringSlice

acceptNVIDIAVisibleDevicesWhenUnprivileged bool
acceptNVIDIAVisibleDevicesAsVolumeMounts bool
Expand Down Expand Up @@ -216,6 +218,12 @@ func main() {
Destination: &opts.cdiKind,
EnvVars: []string{"CDI_KIND"},
},
&cli.StringSliceFlag{
Name: "cdi-firmware-search-paths",
Usage: "specify custom firmware search paths to be used during generation of a CDI specification",
Destination: &opts.cdiFirmwareSearchPaths,
EnvVars: []string{"CDI_FIRMWARE_SEARCH_PATHS"},
},
&cli.BoolFlag{
Name: "ignore-errors",
Usage: "ignore errors when installing the NVIDIA Container toolkit. This is used for testing purposes only.",
Expand Down Expand Up @@ -701,6 +709,7 @@ func generateCDISpec(opts *options, nvidiaCTKPath string) error {
nvcdi.WithMode(nvcdi.ModeManagement),
nvcdi.WithDriverRoot(opts.DriverRootCtrPath),
nvcdi.WithNVIDIACTKPath(nvidiaCTKPath),
nvcdi.WithFirmwareSearchPaths(opts.cdiFirmwareSearchPaths.Value()),
nvcdi.WithVendor(opts.cdiVendor),
nvcdi.WithClass(opts.cdiClass),
)
Expand Down