diff --git a/README.md b/README.md index e513ed5a..ee647d37 100644 --- a/README.md +++ b/README.md @@ -377,7 +377,9 @@ Finally, you can run the following to cleanup your environment and delete the ## Anatomy of a DRA resource driver -TBD +For usage and configuration options, prefer: +- CLI help: run `./dra-example-kubeletplugin --help` for flags and examples +- Helm values: consult `deployments/helm/dra-example-driver/values.yaml` for configurable settings and inline docs ## Code Organization diff --git a/cmd/dra-example-kubeletplugin/discovery.go b/cmd/dra-example-kubeletplugin/discovery.go index 0c45431f..2ef68de8 100644 --- a/cmd/dra-example-kubeletplugin/discovery.go +++ b/cmd/dra-example-kubeletplugin/discovery.go @@ -20,36 +20,53 @@ import ( "fmt" "math/rand" "os" + "strconv" + "strings" resourceapi "k8s.io/api/resource/v1" "k8s.io/apimachinery/pkg/api/resource" "k8s.io/utils/ptr" + semver "github.com/Masterminds/semver/v3" "github.com/google/uuid" ) -func enumerateAllPossibleDevices(numGPUs int) (AllocatableDevices, error) { +func enumerateAllPossibleDevices(numGPUs int, deviceAttributes []string) (AllocatableDevices, error) { seed := os.Getenv("NODE_NAME") uuids := generateUUIDs(seed, numGPUs) + // Parse additional device attributes from the flag + additionalAttributes, err := parseDeviceAttributes(deviceAttributes) + if err != nil { + return nil, fmt.Errorf("error parsing device attributes: %w", err) + } + alldevices := make(AllocatableDevices) for i, uuid := range uuids { - device := resourceapi.Device{ - Name: fmt.Sprintf("gpu-%d", i), - Attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{ - "index": { - IntValue: ptr.To(int64(i)), - }, - "uuid": { - StringValue: ptr.To(uuid), - }, - "model": { - StringValue: ptr.To("LATEST-GPU-MODEL"), - }, - "driverVersion": { - VersionValue: ptr.To("1.0.0"), - }, + // Start with default attributes + attributes := map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{ + "index": { + IntValue: ptr.To(int64(i)), + }, + "uuid": { + StringValue: ptr.To(uuid), + }, + "model": { + StringValue: ptr.To("LATEST-GPU-MODEL"), }, + "driverVersion": { + VersionValue: ptr.To("1.0.0"), + }, + } + + // Add additional attributes from the flag + for key, value := range additionalAttributes { + attributes[key] = value + } + + device := resourceapi.Device{ + Name: fmt.Sprintf("gpu-%d", i), + Attributes: attributes, Capacity: map[resourceapi.QualifiedName]resourceapi.DeviceCapacity{ "memory": { Value: resource.MustParse("80Gi"), @@ -61,6 +78,102 @@ func enumerateAllPossibleDevices(numGPUs int) (AllocatableDevices, error) { return alldevices, nil } +// parseDeviceAttributes parses a comma-separated string of key=value pairs +// and returns a map of device attributes with automatic type detection. +// Supported value types: +// - int: integer values (e.g., "count=5") +// - bool: boolean values (e.g., "enabled=true", "disabled=false") +// - version: semantic version values (e.g., "driver_version=1.2.3") +// - string: any other value (e.g., "productName=NVIDIA GeForce RTX 5090", "architecture=Blackwell") +func parseDeviceAttributes(deviceAttributes []string) (map[resourceapi.QualifiedName]resourceapi.DeviceAttribute, error) { + attributes := make(map[resourceapi.QualifiedName]resourceapi.DeviceAttribute) + + if len(deviceAttributes) == 0 { + return attributes, nil + } + + for _, pair := range deviceAttributes { + pair = strings.TrimSpace(pair) + if pair == "" { + continue + } + + parts := strings.SplitN(pair, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid device attribute format: %s (expected key=value)", pair) + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + if key == "" { + return nil, fmt.Errorf("device attribute key cannot be empty") + } + + // Detect value type and create appropriate DeviceAttribute + attr, err := createDeviceAttribute(value) + if err != nil { + return nil, fmt.Errorf("invalid value for attribute %s: %w", key, err) + } + + attributes[resourceapi.QualifiedName(key)] = attr + } + + return attributes, nil +} + +// createDeviceAttribute creates a DeviceAttribute with the appropriate value type +// based on the input string. It tries to detect the type in this order: +// 1. bool (true/false) +// 2. int (integer) +// 3. version (semantic version pattern) +// 4. string (default) +func createDeviceAttribute(value string) (resourceapi.DeviceAttribute, error) { + // Check for boolean values + if value == "true" { + return resourceapi.DeviceAttribute{ + BoolValue: ptr.To(true), + }, nil + } + if value == "false" { + return resourceapi.DeviceAttribute{ + BoolValue: ptr.To(false), + }, nil + } + + // Check for integer values + if intVal, err := strconv.ParseInt(value, 10, 64); err == nil { + return resourceapi.DeviceAttribute{ + IntValue: ptr.To(intVal), + }, nil + } + + // Check for semantic version pattern (basic check for x.y.z format) + if isSemanticVersion(value) { + return resourceapi.DeviceAttribute{ + VersionValue: ptr.To(value), + }, nil + } + + // Default to string value + // Validate string length (max 64 characters as per API spec) + if len(value) > 64 { + return resourceapi.DeviceAttribute{}, fmt.Errorf("string value too long (max 64 characters): %s", value) + } + + return resourceapi.DeviceAttribute{ + StringValue: ptr.To(value), + }, nil +} + +// isSemanticVersion checks whether the string is a valid semantic version per https://semver.org/. +// It accepts versions like 1.2.3, 1.0.0-beta.1, and allows build metadata like +exp.sha. +func isSemanticVersion(value string) bool { + // Enforce strict SemVer (MAJOR.MINOR.PATCH) per semver.org + _, err := semver.StrictNewVersion(value) + return err == nil +} + func generateUUIDs(seed string, count int) []string { rand := rand.New(rand.NewSource(hash(seed))) diff --git a/cmd/dra-example-kubeletplugin/discovery_test.go b/cmd/dra-example-kubeletplugin/discovery_test.go new file mode 100644 index 00000000..28160c7b --- /dev/null +++ b/cmd/dra-example-kubeletplugin/discovery_test.go @@ -0,0 +1,91 @@ +package main + +import ( + "testing" + + resourceapi "k8s.io/api/resource/v1" +) + +func TestParseDeviceAttributes(t *testing.T) { + tests := []struct { + name string + input []string + wantErr bool + wantVals map[resourceapi.QualifiedName]func(resourceapi.DeviceAttribute) bool + }{ + { + name: "empty", + input: []string{}, + wantErr: false, + wantVals: map[resourceapi.QualifiedName]func(resourceapi.DeviceAttribute) bool{}, + }, + { + name: "invalid format (missing '=')", + input: []string{"invalid"}, + wantErr: true, + }, + { + name: "typed values", + input: []string{"boolTrue=true", "boolFalse=false", "count=42", "driverVersion=1.2.3", "pre=1.0.0-beta.1", "name=LATEST-GPU-MODEL"}, + wantVals: map[resourceapi.QualifiedName]func(resourceapi.DeviceAttribute) bool{ + "boolTrue": func(a resourceapi.DeviceAttribute) bool { return a.BoolValue != nil && *a.BoolValue }, + "boolFalse": func(a resourceapi.DeviceAttribute) bool { return a.BoolValue != nil && !*a.BoolValue }, + "count": func(a resourceapi.DeviceAttribute) bool { return a.IntValue != nil && *a.IntValue == 42 }, + "driverVersion": func(a resourceapi.DeviceAttribute) bool { return a.VersionValue != nil && *a.VersionValue == "1.2.3" }, + "pre": func(a resourceapi.DeviceAttribute) bool { + return a.VersionValue != nil && *a.VersionValue == "1.0.0-beta.1" + }, + "name": func(a resourceapi.DeviceAttribute) bool { + return a.StringValue != nil && *a.StringValue == "LATEST-GPU-MODEL" + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseDeviceAttributes(tt.input) + if tt.wantErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for k, check := range tt.wantVals { + a, ok := got[k] + if !ok { + t.Fatalf("missing key %q", k) + } + if !check(a) { + t.Fatalf("value check failed for %q: %+v", k, a) + } + } + }) + } +} + +func TestIsSemanticVersion(t *testing.T) { + cases := []struct { + name string + v string + ok bool + }{ + {"basic", "1.2.3", true}, + {"prerelease", "1.0.0-beta", true}, + {"prerelease+build", "1.0.0-beta+build", true}, + {"zeros", "0.0.1", true}, + {"missing patch", "1.2", false}, + {"too many parts", "1.2.3.4", false}, + {"invalid char", "1.0.x", false}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if got := isSemanticVersion(c.v); got != c.ok { + t.Fatalf("isSemanticVersion(%q)=%v, want %v", c.v, got, c.ok) + } + }) + } +} diff --git a/cmd/dra-example-kubeletplugin/main.go b/cmd/dra-example-kubeletplugin/main.go index e0cfaf9c..c096bf0f 100644 --- a/cmd/dra-example-kubeletplugin/main.go +++ b/cmd/dra-example-kubeletplugin/main.go @@ -46,6 +46,7 @@ type Flags struct { nodeName string cdiRoot string numDevices int + deviceAttributes []string kubeletRegistrarDirectoryPath string kubeletPluginsDirectoryPath string healthcheckPort int @@ -94,6 +95,12 @@ func newApp() *cli.App { Destination: &flags.numDevices, EnvVars: []string{"NUM_DEVICES"}, }, + &cli.StringSliceFlag{ + Name: "device-attributes", + Usage: "Additional device attributes as repeated key=value pairs. May be specified multiple times. Examples: --device-attributes productName=NVIDIA GeForce RTX 5090 --device-attributes architecture=Blackwell. Note: when using DEVICE_ATTRIBUTES env var, provide key=value entries separated by commas; values containing commas are not supported via env and should be passed using repeated --device-attributes flags.", + Value: cli.NewStringSlice(), + EnvVars: []string{"DEVICE_ATTRIBUTES"}, + }, &cli.StringFlag{ Name: "kubelet-registrar-directory-path", Usage: "Absolute path to the directory where kubelet stores plugin registrations.", @@ -138,6 +145,8 @@ func newApp() *cli.App { return fmt.Errorf("create client: %v", err) } + flags.deviceAttributes = c.StringSlice("device-attributes") + config := &Config{ flags: flags, coreclient: clientSets.Core, diff --git a/cmd/dra-example-kubeletplugin/state.go b/cmd/dra-example-kubeletplugin/state.go index 7cffe388..97e24961 100644 --- a/cmd/dra-example-kubeletplugin/state.go +++ b/cmd/dra-example-kubeletplugin/state.go @@ -64,7 +64,7 @@ type DeviceState struct { } func NewDeviceState(config *Config) (*DeviceState, error) { - allocatable, err := enumerateAllPossibleDevices(config.flags.numDevices) + allocatable, err := enumerateAllPossibleDevices(config.flags.numDevices, config.flags.deviceAttributes) if err != nil { return nil, fmt.Errorf("error enumerating all possible devices: %v", err) } diff --git a/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml b/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml index bcf44ffd..cee24700 100644 --- a/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml +++ b/deployments/helm/dra-example-driver/templates/kubeletplugin.yaml @@ -76,6 +76,9 @@ spec: # Simulated number of devices the example driver will pretend to have. - name: NUM_DEVICES value: {{ .Values.kubeletPlugin.numDevices | quote }} + # Additional device attributes to be added to resource slices. + - name: DEVICE_ATTRIBUTES + value: {{ .Values.kubeletPlugin.deviceAttributes | quote }} {{- if .Values.kubeletPlugin.containers.plugin.healthcheckPort }} - name: HEALTHCHECK_PORT value: {{ .Values.kubeletPlugin.containers.plugin.healthcheckPort | quote }} diff --git a/deployments/helm/dra-example-driver/values.yaml b/deployments/helm/dra-example-driver/values.yaml index e58c9d33..c3fd1740 100644 --- a/deployments/helm/dra-example-driver/values.yaml +++ b/deployments/helm/dra-example-driver/values.yaml @@ -46,6 +46,13 @@ controller: kubeletPlugin: numDevices: 8 + # Additional device attributes to be added to resource slices. + # When setting via env var (DEVICE_ATTRIBUTES), provide a comma-separated list of key=value entries: + # DEVICE_ATTRIBUTES: "productName=NVIDIA GeForce RTX 5090,architecture=Blackwell" + # Values containing commas are not supported via env var. Prefer repeated CLI flags, e.g.: + # --device-attributes productName=NVIDIA GeForce RTX 5090 \ + # --device-attributes architecture=Blackwell + deviceAttributes: "" priorityClassName: "system-node-critical" updateStrategy: type: RollingUpdate diff --git a/go.mod b/go.mod index 5dc4428e..659d3eee 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.24.0 toolchain go1.24.2 require ( + github.com/Masterminds/semver/v3 v3.2.1 github.com/google/uuid v1.6.0 github.com/spf13/pflag v1.0.6 github.com/stretchr/testify v1.10.0 diff --git a/go.sum b/go.sum index 28825073..4215c9bd 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= +github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=