package pkg import ( "encoding/json" "fmt" "io/ioutil" "os" "path/filepath" "strings" cdispec "github.com/container-orchestrated-devices/container-device-interface/specs-go" spec "github.com/opencontainers/runtime-spec/specs-go" ) const ( root = "/etc/cdi" ) func collectCDISpecs() (map[string]*cdispec.Spec, error) { var files []string vendor := make(map[string]*cdispec.Spec) err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { if info == nil || info.IsDir() { return nil } if filepath.Ext(path) != ".json" { return nil } files = append(files, path) return nil }) if err != nil { return nil, err } for _, path := range files { spec, err := loadCDIFile(path) if err != nil { continue } if _, ok := vendor[spec.Kind]; ok { continue } vendor[spec.Kind] = spec } return vendor, nil } // TODO: Validate (e.g: duplicate device names) func loadCDIFile(path string) (*cdispec.Spec, error) { file, err := ioutil.ReadFile(path) if err != nil { return nil, err } var spec *cdispec.Spec err = json.Unmarshal([]byte(file), &spec) if err != nil { return nil, err } return spec, nil } /* * Pattern "vendor.com/device=myDevice" with the vendor being optional */ func extractVendor(dev string) (string, string) { if strings.IndexByte(dev, '=') == -1 { return "", dev } split := strings.SplitN(dev, "=", 2) return split[0], split[1] } // GetCDIForDevice returns the CDI specification that matches the device name the user provided. func GetCDIForDevice(dev string, specs map[string]*cdispec.Spec) (*cdispec.Spec, error) { vendor, device := extractVendor(dev) if vendor != "" { s, ok := specs[vendor] if !ok { return nil, fmt.Errorf("Could not find vendor %q for device %q", vendor, device) } for _, d := range s.Devices { if d.Name != device { continue } return s, nil } return nil, fmt.Errorf("Could not find device %q for vendor %q", device, vendor) } var found []*cdispec.Spec var vendors []string for vendor, spec := range specs { for _, d := range spec.Devices { if d.Name != device { continue } found = append(found, spec) vendors = append(vendors, vendor) } } if len(found) > 1 { return nil, fmt.Errorf("%q is ambiguous and currently refers to multiple devices from different vendors: %q", dev, vendors) } if len(found) == 1 { return found[0], nil } return nil, fmt.Errorf("Could not find device %q", dev) } // HasDevice returns true if a device is a CDI device // an error may be returned in cases where permissions may be required func HasDevice(dev string) (bool, error) { specs, err := collectCDISpecs() if err != nil { return false, err } d, err := GetCDIForDevice(dev, specs) if err != nil { return false, err } return d != nil, nil } // UpdateOCISpecForDevices updates the given OCI spec based on the requested CDI devices func UpdateOCISpecForDevices(ociconfig *spec.Spec, devs []string) error { specs, err := collectCDISpecs() if err != nil { return err } return UpdateOCISpecForDevicesWithSpec(ociconfig, devs, specs) } // UpdateOCISpecForDevicesWithLoggerAndSpecs is mainly used for testing func UpdateOCISpecForDevicesWithSpec(ociconfig *spec.Spec, devs []string, specs map[string]*cdispec.Spec) error { edits := make(map[string]*cdispec.Spec) for _, d := range devs { spec, err := GetCDIForDevice(d, specs) if err != nil { return err } edits[spec.Kind] = spec err = cdispec.ApplyOCIEditsForDevice(ociconfig, spec, d) if err != nil { return err } } for _, spec := range edits { if err := cdispec.ApplyOCIEdits(ociconfig, spec); err != nil { return err } } return nil }