aboutsummaryrefslogtreecommitdiff
path: root/pkg/hooks/hooks.go
blob: dbcd7b773e5eee923485bf5846f31ea57e94053a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package hooks

import (
	"encoding/json"
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
	"regexp"
	"strings"
	"syscall"

	spec "github.com/opencontainers/runtime-spec/specs-go"
	"github.com/opencontainers/runtime-tools/generate"
	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
)

// Well-known locations of OCI hook configuration files.
const (
	// DefaultHooksDir is the default directory that holds hook
	// configuration (*.json) files.
	DefaultHooksDir = "/usr/share/containers/oci/hooks.d"

	// OverrideHooksDir is the directory where an administrator can place
	// hook configuration files that override the default configuration.
	OverrideHooksDir = "/etc/containers/oci/hooks.d"
)

// HookParams describes a single hook as decoded from a JSON hook
// configuration file.
type HookParams struct {
	// Hook is the path of the hook executable; it is used as the OCI hook
	// Path and as argv[0].
	Hook string `json:"hook"`
	// Stage lists the OCI lifecycle stages (prestart, poststart, poststop)
	// at which the hook is added.
	Stage []string `json:"stage"`
	// Cmds holds regular expressions; presumably matched against the
	// container command elsewhere — confirm with callers.
	Cmds []string `json:"cmd"`
	// Annotations holds regular expressions; presumably matched against
	// container annotations elsewhere — confirm with callers.
	Annotations []string `json:"annotation"`
	// HasBindMounts is decoded from "hasbindmounts"; it is not consulted in
	// this file, so its exact meaning should be confirmed with its callers.
	HasBindMounts bool `json:"hasbindmounts"`
	// Arguments are appended after the hook path when building the OCI hook
	// argument list.
	Arguments []string `json:"arguments"`
}

// readHook reads the hook configuration file at hookPath, decodes it as JSON
// and validates it.
//
// Validation checks that:
//   - the executable named in the "hook" field can be stat'ed,
//   - every "cmd" and "annotation" entry compiles as a regular expression,
//   - every "stage" entry is one of "prestart", "poststart" or "poststop".
//
// On error the returned HookParams may be partially populated and must not be
// used.
func readHook(hookPath string) (HookParams, error) {
	var hook HookParams
	raw, err := ioutil.ReadFile(hookPath)
	if err != nil {
		return hook, errors.Wrapf(err, "error Reading hook %q", hookPath)
	}
	if err := json.Unmarshal(raw, &hook); err != nil {
		return hook, errors.Wrapf(err, "error Unmarshalling JSON for %q", hookPath)
	}
	if _, err := os.Stat(hook.Hook); err != nil {
		return hook, errors.Wrapf(err, "unable to stat hook %q in hook config %q", hook.Hook, hookPath)
	}
	for _, cmd := range hook.Cmds {
		if _, err := regexp.Compile(cmd); err != nil {
			return hook, errors.Wrapf(err, "invalid cmd regular expression %q defined in hook config %q", cmd, hookPath)
		}
	}
	for _, annotation := range hook.Annotations {
		if _, err := regexp.Compile(annotation); err != nil {
			return hook, errors.Wrapf(err, "invalid annotation regular expression %q defined in hook config %q", annotation, hookPath)
		}
	}
	for _, stage := range hook.Stage {
		switch stage {
		case "prestart", "poststart", "poststop":
			// valid stage
		default:
			// Wrap a non-nil sentinel: pkg/errors.Wrapf returns nil when its
			// cause is nil, which would silently accept an unknown stage.
			return hook, errors.Wrapf(syscall.EINVAL, "unknown stage %q defined in hook config %q", stage, hookPath)
		}
	}
	return hook, nil
}

// readHooks reads every *.json hook configuration file in hooksPath and adds
// each one to hooks, keyed by its file name.
//
// A missing hooksPath is not an error: a warning is logged and nil is
// returned. A file whose name matches a key already present in hooks replaces
// that entry, which allows a later directory to override same-named files
// from an earlier one. It is an error (wrapping EINVAL) for a differently
// named file to configure a hook executable that is already present in hooks.
func readHooks(hooksPath string, hooks map[string]HookParams) error {
	if _, err := os.Stat(hooksPath); err != nil {
		if os.IsNotExist(err) {
			logrus.Warnf("hooks path: %q does not exist", hooksPath)
			return nil
		}
		return errors.Wrapf(err, "unable to stat hooks path %q", hooksPath)
	}

	files, err := ioutil.ReadDir(hooksPath)
	if err != nil {
		return errors.Wrapf(err, "unable to read hooks path %q", hooksPath)
	}

	for _, file := range files {
		name := file.Name()
		if !strings.HasSuffix(name, ".json") {
			continue
		}
		hook, err := readHook(filepath.Join(hooksPath, name))
		if err != nil {
			return err
		}
		for key, existing := range hooks {
			// hook.Hook can only be defined in one hook file, unless it has
			// the same file name (i.e. it is overriding that entry).
			if hook.Hook == existing.Hook && key != name {
				return errors.Wrapf(syscall.EINVAL, "duplicate path, hook %q from %q already defined in %q", hook.Hook, hooksPath, key)
			}
		}
		hooks[name] = hook
	}
	return nil
}

// SetupHooks reads all hook configuration files in hooksPath and returns them
// in a map keyed by file name.
//
// When hooksPath refers to DefaultHooksDir, OverrideHooksDir is read
// afterwards so that its files replace same-named files from the default
// directory. hooksPath is cleaned before that comparison so equivalent
// spellings (e.g. with a trailing slash) still pick up the override directory.
func SetupHooks(hooksPath string) (map[string]HookParams, error) {
	hooksMap := make(map[string]HookParams)
	if err := readHooks(hooksPath, hooksMap); err != nil {
		return nil, err
	}
	if filepath.Clean(hooksPath) == DefaultHooksDir {
		if err := readHooks(OverrideHooksDir, hooksMap); err != nil {
			return nil, err
		}
	}

	return hooksMap, nil
}

// AddOCIHook adds hook to the OCI runtime spec being built by g, once for each
// stage listed in hook.Stage.
//
// Each generated spec.Hook runs hook.Hook with an argument list of the hook
// path followed by hook.Arguments, and with "stage=<stage>" as its only
// environment entry. Stages other than prestart, poststart and poststop are
// silently ignored. The returned error is currently always nil.
func AddOCIHook(g *generate.Generator, hook HookParams) error {
	for _, stage := range hook.Stage {
		h := spec.Hook{
			Path: hook.Hook,
			// A fresh slice per stage so generated hooks never share a
			// backing array with hook.Arguments or with each other.
			Args: append([]string{hook.Hook}, hook.Arguments...),
			Env:  []string{fmt.Sprintf("stage=%s", stage)},
		}
		logrus.Debugf("AddOCIHook: %+v", h)
		switch stage {
		case "prestart":
			g.AddPreStartHook(h)

		case "poststart":
			g.AddPostStartHook(h)

		case "poststop":
			g.AddPostStopHook(h)
		}
	}
	return nil
}