diff options
Diffstat (limited to 'cmd/podman/shared/parse')
-rw-r--r-- | cmd/podman/shared/parse/parse.go | 28 | ||||
-rw-r--r-- | cmd/podman/shared/parse/parse_test.go | 53 |
2 files changed, 81 insertions, 0 deletions
diff --git a/cmd/podman/shared/parse/parse.go b/cmd/podman/shared/parse/parse.go index 3a75ff7a8..79449029d 100644 --- a/cmd/podman/shared/parse/parse.go +++ b/cmd/podman/shared/parse/parse.go @@ -79,6 +79,34 @@ func ValidateDomain(val string) (string, error) { return "", fmt.Errorf("%s is not a valid domain", val) } +// GetAllLabels retrieves all labels given a potential label file and a number +// of labels provided from the command line. +func GetAllLabels(labelFile, inputLabels []string) (map[string]string, error) { + labels := make(map[string]string) + for _, file := range labelFile { + // Use of parseEnvFile still seems safe, as it's missing the + // extra parsing logic of parseEnv. + // There's an argument that we SHOULD be doing that parsing for + // all environment variables, even those sourced from files, but + // that would require a substantial rework. + if err := parseEnvFile(labels, file); err != nil { + return nil, err + } + } + for _, label := range inputLabels { + split := strings.SplitN(label, "=", 2) + if split[0] == "" { + return nil, errors.Errorf("invalid label format: %q", label) + } + value := "" + if len(split) > 1 { + value = split[1] + } + labels[split[0]] = value + } + return labels, nil +} + // reads a file of line terminated key=value pairs, and overrides any keys // present in the file with additional pairs specified in the override parameter // for env-file and labels-file flags diff --git a/cmd/podman/shared/parse/parse_test.go b/cmd/podman/shared/parse/parse_test.go index 1359076a0..a6ddc2be9 100644 --- a/cmd/podman/shared/parse/parse_test.go +++ b/cmd/podman/shared/parse/parse_test.go @@ -4,9 +4,33 @@ package parse import ( + "io/ioutil" + "os" "testing" + + "github.com/stretchr/testify/assert" +) + +var ( + Var1 = []string{"ONE=1", "TWO=2"} ) +func createTmpFile(content []byte) (string, error) { + tmpfile, err := ioutil.TempFile(os.TempDir(), "unittest") + if err != nil { + return "", err + } + + if _, err := tmpfile.Write(content); err != nil { + return "", err + + } + if err := tmpfile.Close(); err != nil { + return "", err + } + return tmpfile.Name(), nil +} + func TestValidateExtraHost(t *testing.T) { type args struct { val string @@ -97,3 +121,32 @@ func TestValidateFileName(t *testing.T) { }) } } + +func TestGetAllLabels(t *testing.T) { + fileLabels := []string{} + labels, _ := GetAllLabels(fileLabels, Var1) + assert.Equal(t, len(labels), 2) +} + +func TestGetAllLabelsBadKeyValue(t *testing.T) { + inLabels := []string{"=badValue", "="} + fileLabels := []string{} + _, err := GetAllLabels(fileLabels, inLabels) + assert.Error(t, err, assert.AnError) +} + +func TestGetAllLabelsBadLabelFile(t *testing.T) { + fileLabels := []string{"/foobar5001/be"} + _, err := GetAllLabels(fileLabels, Var1) + assert.Error(t, err, assert.AnError) +} + +func TestGetAllLabelsFile(t *testing.T) { + content := []byte("THREE=3") + tFile, err := createTmpFile(content) + defer os.Remove(tFile) + assert.NoError(t, err) + fileLabels := []string{tFile} + result, _ := GetAllLabels(fileLabels, Var1) + assert.Equal(t, len(result), 3) +} |