diff options
Diffstat (limited to 'utils')
-rw-r--r-- | utils/testdata/cgroup.empty | 0 | ||||
-rw-r--r-- | utils/testdata/cgroup.other | 1 | ||||
-rw-r--r-- | utils/testdata/cgroup.root | 1 | ||||
-rw-r--r-- | utils/utils.go | 2 | ||||
-rw-r--r-- | utils/utils_supported.go | 14 | ||||
-rw-r--r-- | utils/utils_test.go | 26 | ||||
-rw-r--r-- | utils/utils_windows.go | 4 |
7 files changed, 42 insertions, 6 deletions
diff --git a/utils/testdata/cgroup.empty b/utils/testdata/cgroup.empty new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/utils/testdata/cgroup.empty diff --git a/utils/testdata/cgroup.other b/utils/testdata/cgroup.other new file mode 100644 index 000000000..239a7cded --- /dev/null +++ b/utils/testdata/cgroup.other @@ -0,0 +1 @@ +0::/other diff --git a/utils/testdata/cgroup.root b/utils/testdata/cgroup.root new file mode 100644 index 000000000..1e027b2a3 --- /dev/null +++ b/utils/testdata/cgroup.root @@ -0,0 +1 @@ +0::/ diff --git a/utils/utils.go b/utils/utils.go index d0e3dbb46..fd66ac2ed 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -192,7 +192,7 @@ func moveProcessPIDFileToScope(pidPath, slice, scope string) error { } func moveProcessToScope(pid int, slice, scope string) error { - err := RunUnderSystemdScope(int(pid), slice, scope) + err := RunUnderSystemdScope(pid, slice, scope) // If the PID is not valid anymore, do not return an error. if dbusErr, ok := err.(dbus.Error); ok { if dbusErr.Name == "org.freedesktop.DBus.Error.UnixProcessIdUnknown" { diff --git a/utils/utils_supported.go b/utils/utils_supported.go index 493ea61ce..c2dcc4631 100644 --- a/utils/utils_supported.go +++ b/utils/utils_supported.go @@ -64,7 +64,7 @@ func RunUnderSystemdScope(pid int, slice string, unitName string) error { return nil } -func getCgroupProcess(procFile string) (string, error) { +func getCgroupProcess(procFile string, allowRoot bool) (string, error) { f, err := os.Open(procFile) if err != nil { return "", err @@ -72,7 +72,7 @@ func getCgroupProcess(procFile string) (string, error) { defer f.Close() scanner := bufio.NewScanner(f) - cgroup := "/" + cgroup := "" for scanner.Scan() { line := scanner.Text() parts := strings.SplitN(line, ":", 3) @@ -87,7 +87,7 @@ func getCgroupProcess(procFile string) (string, error) { cgroup = parts[2] } } - if cgroup == "/" { + if len(cgroup) == 0 || (!allowRoot && cgroup == "/") { return "", errors.Errorf("could not find cgroup mount in %q", procFile) } return cgroup, nil @@ -95,12 +95,16 @@ func getCgroupProcess(procFile string) (string, error) { // GetOwnCgroup returns the cgroup for the current process. func GetOwnCgroup() (string, error) { - return getCgroupProcess("/proc/self/cgroup") + return getCgroupProcess("/proc/self/cgroup", true) +} + +func GetOwnCgroupDisallowRoot() (string, error) { + return getCgroupProcess("/proc/self/cgroup", false) } // GetCgroupProcess returns the cgroup for the specified process process. func GetCgroupProcess(pid int) (string, error) { - return getCgroupProcess(fmt.Sprintf("/proc/%d/cgroup", pid)) + return getCgroupProcess(fmt.Sprintf("/proc/%d/cgroup", pid), true) } // MoveUnderCgroupSubtree moves the PID under a cgroup subtree. diff --git a/utils/utils_test.go b/utils/utils_test.go new file mode 100644 index 000000000..f34dbdd7e --- /dev/null +++ b/utils/utils_test.go @@ -0,0 +1,26 @@ +//go:build linux || darwin +// +build linux darwin + +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCgroupProcess(t *testing.T) { + val, err := getCgroupProcess("testdata/cgroup.root", true) + assert.Nil(t, err) + assert.Equal(t, "/", val) + + _, err = getCgroupProcess("testdata/cgroup.root", false) + assert.NotNil(t, err) + + val, err = getCgroupProcess("testdata/cgroup.other", true) + assert.Nil(t, err) + assert.Equal(t, "/other", val) + + _, err = getCgroupProcess("testdata/cgroup.empty", true) + assert.NotNil(t, err) +} diff --git a/utils/utils_windows.go b/utils/utils_windows.go index 2c159ab06..1d017f5ae 100644 --- a/utils/utils_windows.go +++ b/utils/utils_windows.go @@ -17,6 +17,10 @@ func GetOwnCgroup() (string, error) { return "", errors.New("not implemented for windows") } +func GetOwnCgroupDisallowRoot() (string, error) { + return "", errors.New("not implemented for windows") +} + func GetCgroupProcess(pid int) (string, error) { return "", errors.New("not implemented for windows") } |