summaryrefslogtreecommitdiff
path: root/vendor/github.com/onsi/gomega/gstruct/keys.go
blob: 56aed4bab723541b7eb686b9bc94703839b25921 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// untested sections: 6

package gstruct

import (
	"errors"
	"fmt"
	"reflect"
	"runtime/debug"
	"strings"

	"github.com/onsi/gomega/format"
	errorsutil "github.com/onsi/gomega/gstruct/errors"
	"github.com/onsi/gomega/types"
)

// MatchAllKeys returns a matcher that requires the actual map to have exactly
// the keys listed in keys: every value must satisfy its matcher, and both
// extra keys in the map and keys absent from the map are reported as failures.
func MatchAllKeys(keys Keys) types.GomegaMatcher {
	matcher := &KeysMatcher{Keys: keys}
	return matcher
}

// MatchKeys returns a matcher that checks each value of the actual map against
// the matcher registered for its key in keys. The options bit set may include
// IgnoreExtras (map keys without a matcher are not failures) and/or
// IgnoreMissing (matcher keys absent from the map are not failures).
func MatchKeys(options Options, keys Keys) types.GomegaMatcher {
	ignoreExtras := options&IgnoreExtras != 0
	ignoreMissing := options&IgnoreMissing != 0

	return &KeysMatcher{
		IgnoreMissing: ignoreMissing,
		IgnoreExtras:  ignoreExtras,
		Keys:          keys,
	}
}

// KeysMatcher is a Gomega matcher for maps. It applies a separate matcher to
// the value stored under each key listed in Keys. Construct one with
// MatchAllKeys or MatchKeys.
//
// NOTE(review): Match stores its results in the unexported failures field
// without any locking, so one KeysMatcher should not be shared by concurrent
// Match calls.
type KeysMatcher struct {
	// Matchers for each key. An actual map key is looked up here after being
	// converted to interface{}, so its dynamic type must equal the key's type.
	Keys Keys

	// Whether to ignore extra keys or consider it an error.
	IgnoreExtras bool
	// Whether to ignore missing keys or consider it an error.
	IgnoreMissing bool

	// State: errors collected by Match, reported by FailureMessage/Failures.
	failures []error
}

// Keys maps each expected map key to the matcher applied to the value stored
// under that key. Lookups compare the actual map's keys, boxed as interface{},
// with these keys. The dynamic types must therefore be identical: an untyped
// constant key such as 1 is stored as int and will not match an int64 key.
type Keys map[interface{}]types.GomegaMatcher

// Match reports whether actual, which must be a map, satisfies every key
// matcher under the configured IgnoreExtras/IgnoreMissing settings. The
// individual failures are stored on the matcher for FailureMessage and
// Failures. A non-map actual, including a nil interface, yields an error
// instead of a panic.
func (m *KeysMatcher) Match(actual interface{}) (success bool, err error) {
	// Drop results from any earlier call so Failures never reports stale
	// errors after an invalid-input error.
	m.failures = nil

	// reflect.TypeOf(nil) returns a nil Type, and calling Kind on it panics,
	// so a nil actual is rejected explicitly.
	if actual == nil || reflect.TypeOf(actual).Kind() != reflect.Map {
		return false, fmt.Errorf("%v is type %T, expected map", actual, actual)
	}

	m.failures = m.matchKeys(actual)
	return len(m.failures) == 0, nil
}

// matchKeys checks the map held in actual against m.Keys and returns one error
// per problem found (nil when there are none). actual must be a map; Match
// verifies this before calling. The reported problems are:
//   - a key in actual with no matcher, unless IgnoreExtras is set;
//   - a value whose matcher returns an error or does not match (nested
//     matchers contribute their own aggregated failures);
//   - a matcher key absent from actual, unless IgnoreMissing is set.
//
// Per-key errors are nested under a ".<key>" path. Go map iteration order is
// random, so the order of the returned errors is not deterministic.
func (m *KeysMatcher) matchKeys(actual interface{}) (errs []error) {
	actualValue := reflect.ValueOf(actual)
	seen := make(map[interface{}]bool, actualValue.Len())
	for _, keyValue := range actualValue.MapKeys() {
		key := keyValue.Interface()
		seen[key] = true

		err := func() (err error) {
			// This test relies heavily on reflect, which tends to panic.
			// Recover here to provide more useful error messages in that case.
			defer func() {
				if r := recover(); r != nil {
					err = fmt.Errorf("panic checking %+v: %v\n%s", actual, r, debug.Stack())
				}
			}()

			matcher, ok := m.Keys[key]
			if !ok {
				if !m.IgnoreExtras {
					// %v rather than %s: map keys are not necessarily strings,
					// and %s renders e.g. an int key as "%!s(int=1)".
					return fmt.Errorf("unexpected key %v: %+v", key, actual)
				}
				return nil
			}

			valInterface := actualValue.MapIndex(keyValue).Interface()

			match, err := matcher.Match(valInterface)
			if err != nil {
				return err
			}

			if !match {
				if nesting, ok := matcher.(errorsutil.NestingMatcher); ok {
					return errorsutil.AggregateError(nesting.Failures())
				}
				return errors.New(matcher.FailureMessage(valInterface))
			}
			return nil
		}()
		if err != nil {
			errs = append(errs, errorsutil.Nest(fmt.Sprintf(".%#v", key), err))
		}
	}

	if !m.IgnoreMissing {
		for key := range m.Keys {
			if !seen[key] {
				// %v for the same reason as above: keys may be non-strings.
				errs = append(errs, fmt.Errorf("missing expected key %v", key))
			}
		}
	}

	return errs
}

// FailureMessage lists the failures recorded by the last Match call, one per
// line, inside a "to match keys" message headed by the type of actual.
func (m *KeysMatcher) FailureMessage(actual interface{}) (message string) {
	failures := make([]string, len(m.failures))
	for i, failure := range m.failures {
		failures[i] = failure.Error()
	}

	// Unnamed map types (e.g. map[string]int) have an empty Name(), which
	// would leave the header blank; fall back to the full type string. Named
	// types keep their short name as before.
	typeName := "<nil>"
	if t := reflect.TypeOf(actual); t != nil {
		typeName = t.Name()
		if typeName == "" {
			typeName = t.String()
		}
	}

	return format.Message(typeName,
		fmt.Sprintf("to match keys: {\n%v\n}\n", strings.Join(failures, "\n")))
}

// NegatedFailureMessage is reported when a negated assertion (e.g. ShouldNot)
// unexpectedly succeeds, i.e. actual matched every key.
func (m *KeysMatcher) NegatedFailureMessage(actual interface{}) (message string) {
	message = format.Message(actual, "not to match keys")
	return
}

// Failures returns the per-key errors that Match stored on the matcher. It
// allows an enclosing matcher to aggregate nested failures.
func (m *KeysMatcher) Failures() []error {
	recorded := m.failures
	return recorded
}