aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/container-orchestrated-devices/container-device-interface/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/container-orchestrated-devices/container-device-interface/pkg')
-rw-r--r--vendor/github.com/container-orchestrated-devices/container-device-interface/pkg/devices.go180
1 files changed, 180 insertions, 0 deletions
diff --git a/vendor/github.com/container-orchestrated-devices/container-device-interface/pkg/devices.go b/vendor/github.com/container-orchestrated-devices/container-device-interface/pkg/devices.go
new file mode 100644
index 000000000..e66fd36c0
--- /dev/null
+++ b/vendor/github.com/container-orchestrated-devices/container-device-interface/pkg/devices.go
@@ -0,0 +1,180 @@
+package pkg
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strings"
+
+ cdispec "github.com/container-orchestrated-devices/container-device-interface/specs-go"
+ spec "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+const (
+ root = "/etc/cdi"
+)
+
+func collectCDISpecs() (map[string]*cdispec.Spec, error) {
+ var files []string
+ vendor := make(map[string]*cdispec.Spec)
+
+ err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
+ if info == nil || info.IsDir() {
+ return nil
+ }
+
+ if filepath.Ext(path) != ".json" {
+ return nil
+ }
+
+ files = append(files, path)
+ return nil
+ })
+
+ if err != nil {
+ return nil, err
+ }
+
+ for _, path := range files {
+ spec, err := loadCDIFile(path)
+ if err != nil {
+ continue
+ }
+
+ if _, ok := vendor[spec.Kind]; ok {
+ continue
+ }
+
+ vendor[spec.Kind] = spec
+ }
+
+ return vendor, nil
+}
+
+// TODO: Validate (e.g: duplicate device names)
+func loadCDIFile(path string) (*cdispec.Spec, error) {
+ file, err := ioutil.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+
+ var spec *cdispec.Spec
+ err = json.Unmarshal([]byte(file), &spec)
+ if err != nil {
+ return nil, err
+ }
+
+ return spec, nil
+}
+
+/*
+* Pattern "vendor.com/device=myDevice" with the vendor being optional
+ */
+func extractVendor(dev string) (string, string) {
+ if strings.IndexByte(dev, '=') == -1 {
+ return "", dev
+ }
+
+ split := strings.SplitN(dev, "=", 2)
+ return split[0], split[1]
+}
+
+// GetCDIForDevice returns the CDI specification that matches the device name the user provided.
+func GetCDIForDevice(dev string, specs map[string]*cdispec.Spec) (*cdispec.Spec, error) {
+ vendor, device := extractVendor(dev)
+
+ if vendor != "" {
+ s, ok := specs[vendor]
+ if !ok {
+ return nil, fmt.Errorf("Could not find vendor %q for device %q", vendor, device)
+ }
+
+ for _, d := range s.Devices {
+ if d.Name != device {
+ continue
+ }
+
+ return s, nil
+ }
+
+ return nil, fmt.Errorf("Could not find device %q for vendor %q", device, vendor)
+ }
+
+ var found []*cdispec.Spec
+ var vendors []string
+ for vendor, spec := range specs {
+
+ for _, d := range spec.Devices {
+ if d.Name != device {
+ continue
+ }
+
+ found = append(found, spec)
+ vendors = append(vendors, vendor)
+ }
+ }
+
+ if len(found) > 1 {
+ return nil, fmt.Errorf("%q is ambiguous and currently refers to multiple devices from different vendors: %q", dev, vendors)
+ }
+
+ if len(found) == 1 {
+ return found[0], nil
+ }
+
+ return nil, fmt.Errorf("Could not find device %q", dev)
+}
+
+// HasDevice returns true if a device is a CDI device
+// an error may be returned in cases where permissions may be required
+func HasDevice(dev string) (bool, error) {
+ specs, err := collectCDISpecs()
+ if err != nil {
+ return false, err
+ }
+
+ d, err := GetCDIForDevice(dev, specs)
+ if err != nil {
+ return false, err
+ }
+
+ return d != nil, nil
+}
+
+// UpdateOCISpecForDevices updates the given OCI spec based on the requested CDI devices
+func UpdateOCISpecForDevices(ociconfig *spec.Spec, devs []string) error {
+ specs, err := collectCDISpecs()
+ if err != nil {
+ return err
+ }
+
+ return UpdateOCISpecForDevicesWithSpec(ociconfig, devs, specs)
+}
+
+// UpdateOCISpecForDevicesWithLoggerAndSpecs is mainly used for testing
+func UpdateOCISpecForDevicesWithSpec(ociconfig *spec.Spec, devs []string, specs map[string]*cdispec.Spec) error {
+ edits := make(map[string]*cdispec.Spec)
+
+ for _, d := range devs {
+ spec, err := GetCDIForDevice(d, specs)
+ if err != nil {
+ return err
+ }
+
+ edits[spec.Kind] = spec
+ err = cdispec.ApplyOCIEditsForDevice(ociconfig, spec, d)
+ if err != nil {
+ return err
+ }
+ }
+
+ for _, spec := range edits {
+ if err := cdispec.ApplyOCIEdits(ociconfig, spec); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}