diff --git a/cmd/dra-example-kubeletplugin/checkpoint.go b/cmd/dra-example-kubeletplugin/checkpoint.go index 7311e760..70727e85 100644 --- a/cmd/dra-example-kubeletplugin/checkpoint.go +++ b/cmd/dra-example-kubeletplugin/checkpoint.go @@ -3,14 +3,19 @@ package main import ( "encoding/json" + "k8s.io/kubernetes/pkg/kubelet/checkpointmanager" "k8s.io/kubernetes/pkg/kubelet/checkpointmanager/checksum" ) +type PreparedClaims map[string]PreparedDevices + type Checkpoint struct { Checksum checksum.Checksum `json:"checksum"` V1 *CheckpointV1 `json:"v1,omitempty"` } +var _ checkpointmanager.Checkpoint = &Checkpoint{} + type CheckpointV1 struct { PreparedClaims PreparedClaims `json:"preparedClaims,omitempty"` } @@ -25,6 +30,32 @@ func newCheckpoint() *Checkpoint { return pc } +func (cp *Checkpoint) GetPreparedDevices(claimUID string) PreparedDevices { + if cp.V1 == nil { + return nil + } + if devices, ok := cp.V1.PreparedClaims[claimUID]; ok { + return devices + } + return nil +} + +func (cp *Checkpoint) AddPreparedDevices(claimUID string, pds PreparedDevices) { + if cp.V1 == nil { + return + } + + cp.V1.PreparedClaims[claimUID] = pds +} + +func (cp *Checkpoint) RemovePreparedDevices(claimUID string) { + if cp.V1 == nil { + return + } + + delete(cp.V1.PreparedClaims, claimUID) +} + func (cp *Checkpoint) MarshalCheckpoint() ([]byte, error) { cp.Checksum = 0 out, err := json.Marshal(*cp) diff --git a/cmd/dra-example-kubeletplugin/prepared_device.go b/cmd/dra-example-kubeletplugin/prepared_device.go new file mode 100644 index 00000000..6ca0ec90 --- /dev/null +++ b/cmd/dra-example-kubeletplugin/prepared_device.go @@ -0,0 +1,21 @@ +package main + +import ( + drapbv1 "k8s.io/kubelet/pkg/apis/dra/v1beta1" + cdiapi "tags.cncf.io/container-device-interface/pkg/cdi" +) + +type PreparedDevice struct { + drapbv1.Device + ContainerEdits *cdiapi.ContainerEdits +} + +type PreparedDevices []*PreparedDevice + +func (pds PreparedDevices) GetDevices() []*drapbv1.Device { + var devices []*drapbv1.Device + for _, pd := range pds { + devices = append(devices, &pd.Device) + } + return devices +} diff --git a/cmd/dra-example-kubeletplugin/state_test.go b/cmd/dra-example-kubeletplugin/prepared_device_test.go similarity index 100% rename from cmd/dra-example-kubeletplugin/state_test.go rename to cmd/dra-example-kubeletplugin/prepared_device_test.go diff --git a/cmd/dra-example-kubeletplugin/state.go b/cmd/dra-example-kubeletplugin/state.go index c72de7c6..76dc5f75 100644 --- a/cmd/dra-example-kubeletplugin/state.go +++ b/cmd/dra-example-kubeletplugin/state.go @@ -34,8 +34,7 @@ import ( ) type AllocatableDevices map[string]resourceapi.Device -type PreparedDevices []*PreparedDevice -type PreparedClaims map[string]PreparedDevices + type PerDeviceCDIContainerEdits map[string]*cdiapi.ContainerEdits type OpaqueDeviceConfig struct { @@ -43,19 +42,6 @@ type OpaqueDeviceConfig struct { Config runtime.Object } -type PreparedDevice struct { - drapbv1.Device - ContainerEdits *cdiapi.ContainerEdits -} - -func (pds PreparedDevices) GetDevices() []*drapbv1.Device { - var devices []*drapbv1.Device - for _, pd := range pds { - devices = append(devices, &pd.Device) - } - return devices -} - type DeviceState struct { sync.Mutex cdi *CDIHandler @@ -119,10 +105,10 @@ func (s *DeviceState) Prepare(claim *resourceapi.ResourceClaim) ([]*drapbv1.Devi if err := s.checkpointManager.GetCheckpoint(DriverPluginCheckpointFile, checkpoint); err != nil { return nil, fmt.Errorf("unable to sync from checkpoint: %v", err) } - preparedClaims := checkpoint.V1.PreparedClaims - if preparedClaims[claimUID] != nil { - return preparedClaims[claimUID].GetDevices(), nil + preparedDevices := checkpoint.GetPreparedDevices(claimUID) + if preparedDevices != nil { + return preparedDevices.GetDevices(), nil } preparedDevices, err := s.prepareDevices(claim) @@ -134,12 +120,12 @@ func (s *DeviceState) Prepare(claim *resourceapi.ResourceClaim) ([]*drapbv1.Devi return nil, fmt.Errorf("unable to create CDI spec file for claim: %v", err) } - preparedClaims[claimUID] = preparedDevices + checkpoint.AddPreparedDevices(claimUID, preparedDevices) if err := s.checkpointManager.CreateCheckpoint(DriverPluginCheckpointFile, checkpoint); err != nil { return nil, fmt.Errorf("unable to sync to checkpoint: %v", err) } - return preparedClaims[claimUID].GetDevices(), nil + return preparedDevices.GetDevices(), nil } func (s *DeviceState) Unprepare(claimUID string) error { @@ -150,13 +136,13 @@ func (s *DeviceState) Unprepare(claimUID string) error { if err := s.checkpointManager.GetCheckpoint(DriverPluginCheckpointFile, checkpoint); err != nil { return fmt.Errorf("unable to sync from checkpoint: %v", err) } - preparedClaims := checkpoint.V1.PreparedClaims - if preparedClaims[claimUID] == nil { + preparedDevices := checkpoint.GetPreparedDevices(claimUID) + if preparedDevices == nil { return nil } - if err := s.unprepareDevices(claimUID, preparedClaims[claimUID]); err != nil { + if err := s.unprepareDevices(claimUID, preparedDevices); err != nil { return fmt.Errorf("unprepare failed: %v", err) } @@ -165,7 +151,7 @@ func (s *DeviceState) Unprepare(claimUID string) error { return fmt.Errorf("unable to delete CDI spec file for claim: %v", err) } - delete(preparedClaims, claimUID) + checkpoint.RemovePreparedDevices(claimUID) if err := s.checkpointManager.CreateCheckpoint(DriverPluginCheckpointFile, checkpoint); err != nil { return fmt.Errorf("unable to sync to checkpoint: %v", err) }