aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/onsi/gomega/matchers/be_numerically_matcher.go
blob: f72591a1a8d5ef3a2f6ecbef15b16af4308e1775 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// untested sections: 4

package matchers

import (
	"fmt"
	"math"

	"github.com/onsi/gomega/format"
)

// BeNumericallyMatcher compares a numeric actual value against a reference
// number using the operator named by Comparator.
type BeNumericallyMatcher struct {
	// Comparator is the comparison operator. Match accepts "==", "~", ">",
	// ">=", "<" and "<=" and returns an error for anything else.
	Comparator string
	// CompareTo holds one or two numbers. CompareTo[0] is the value actual is
	// compared against. The optional CompareTo[1] is a threshold: the integer
	// comparisons apply it to both "==" and "~", the float comparison only
	// to "~".
	CompareTo  []interface{}
}

// FailureMessage builds the message for a failed positive assertion on
// actual by delegating to FormatFailureMessage with negation turned off.
func (matcher *BeNumericallyMatcher) FailureMessage(actual interface{}) (message string) {
	const negated = false
	message = matcher.FormatFailureMessage(actual, negated)
	return message
}

// NegatedFailureMessage builds the message for a failed negated assertion on
// actual by delegating to FormatFailureMessage with negation turned on.
func (matcher *BeNumericallyMatcher) NegatedFailureMessage(actual interface{}) (message string) {
	const negated = true
	message = matcher.FormatFailureMessage(actual, negated)
	return message
}

// FormatFailureMessage renders the failure text for actual. With a single
// CompareTo value the expectation reads "to be <comparator>"; otherwise it
// reads "to be within <CompareTo[1]> of <comparator>". When negated is true
// the expectation is prefixed with "not ". The result is produced by
// format.Message with CompareTo[0] as the expected value.
func (matcher *BeNumericallyMatcher) FormatFailureMessage(actual interface{}, negated bool) (message string) {
	var expectation string
	switch len(matcher.CompareTo) {
	case 1:
		expectation = fmt.Sprintf("to be %s", matcher.Comparator)
	default:
		expectation = fmt.Sprintf("to be within %v of %s", matcher.CompareTo[1], matcher.Comparator)
	}

	if negated {
		expectation = "not " + expectation
	}

	reference := matcher.CompareTo[0]
	return format.Message(actual, expectation, reference)
}

// Match reports whether actual satisfies the configured comparison against
// CompareTo[0].
//
// It returns an error, rather than a failed match, when CompareTo does not
// hold exactly one or two values, when actual or any CompareTo value is not a
// number, or when Comparator is not one of "==", "~", ">", ">=", "<", "<=".
//
// If either actual or CompareTo[0] is a float, the comparison is done on
// float64 values with a default threshold of 1e-8. Otherwise it is done on
// int64 or uint64 values, chosen by the kind of actual, with a default
// threshold of 0. An explicit CompareTo[1] replaces the default threshold.
func (matcher *BeNumericallyMatcher) Match(actual interface{}) (success bool, err error) {
	if len(matcher.CompareTo) == 0 || len(matcher.CompareTo) > 2 {
		return false, fmt.Errorf("BeNumerically requires 1 or 2 CompareTo arguments.  Got:\n%s", format.Object(matcher.CompareTo, 1))
	}
	if !isNumber(actual) {
		return false, fmt.Errorf("Expected a number.  Got:\n%s", format.Object(actual, 1))
	}
	if !isNumber(matcher.CompareTo[0]) {
		return false, fmt.Errorf("Expected a number.  Got:\n%s", format.Object(matcher.CompareTo[0], 1))
	}
	if len(matcher.CompareTo) == 2 && !isNumber(matcher.CompareTo[1]) {
		// Report the offending threshold itself, not the reference value.
		return false, fmt.Errorf("Expected a number.  Got:\n%s", format.Object(matcher.CompareTo[1], 1))
	}

	// Reject unknown comparators up front so the typed helpers below only
	// ever see operators they implement.
	switch matcher.Comparator {
	case "==", "~", ">", ">=", "<", "<=":
	default:
		return false, fmt.Errorf("Unknown comparator: %s", matcher.Comparator)
	}

	if isFloat(actual) || isFloat(matcher.CompareTo[0]) {
		var secondOperand float64 = 1e-8
		if len(matcher.CompareTo) == 2 {
			secondOperand = toFloat(matcher.CompareTo[1])
		}
		success = matcher.matchFloats(toFloat(actual), toFloat(matcher.CompareTo[0]), secondOperand)
	} else if isInteger(actual) {
		var secondOperand int64 = 0
		if len(matcher.CompareTo) == 2 {
			secondOperand = toInteger(matcher.CompareTo[1])
		}
		success = matcher.matchIntegers(toInteger(actual), toInteger(matcher.CompareTo[0]), secondOperand)
	} else if isUnsignedInteger(actual) {
		var secondOperand uint64 = 0
		if len(matcher.CompareTo) == 2 {
			secondOperand = toUnsignedInteger(matcher.CompareTo[1])
		}
		success = matcher.matchUnsignedIntegers(toUnsignedInteger(actual), toUnsignedInteger(matcher.CompareTo[0]), secondOperand)
	} else {
		return false, fmt.Errorf("Failed to compare:\n%s\n%s:\n%s", format.Object(actual, 1), matcher.Comparator, format.Object(matcher.CompareTo[0], 1))
	}

	return success, nil
}

// matchIntegers applies the configured comparator to two signed integers.
//
// For "==" and "~" it reports whether actual and compareTo differ by at most
// threshold; a negative threshold never matches. The ordering comparators
// (">", ">=", "<", "<=") ignore threshold. Any other comparator yields false.
func (matcher *BeNumericallyMatcher) matchIntegers(actual, compareTo, threshold int64) (success bool) {
	switch matcher.Comparator {
	case "==", "~":
		if threshold < 0 {
			return false
		}
		// Measure the distance in uint64. A signed subtraction wraps around
		// when the operands are far apart (MaxInt64 - MinInt64 becomes -1),
		// which would make distant values look adjacent.
		var distance uint64
		if actual >= compareTo {
			distance = uint64(actual) - uint64(compareTo)
		} else {
			distance = uint64(compareTo) - uint64(actual)
		}
		return distance <= uint64(threshold)
	case ">":
		return actual > compareTo
	case ">=":
		return actual >= compareTo
	case "<":
		return actual < compareTo
	case "<=":
		return actual <= compareTo
	}
	return false
}

// matchUnsignedIntegers applies the configured comparator to two unsigned
// integers. "==" and "~" succeed when the two values are at most threshold
// apart; the ordering comparators ignore threshold. Any other comparator
// yields false.
func (matcher *BeNumericallyMatcher) matchUnsignedIntegers(actual, compareTo, threshold uint64) (success bool) {
	switch matcher.Comparator {
	case "<":
		success = actual < compareTo
	case "<=":
		success = actual <= compareTo
	case ">":
		success = actual > compareTo
	case ">=":
		success = actual >= compareTo
	case "==", "~":
		// Subtract the smaller value from the larger one so the unsigned
		// difference cannot wrap around.
		larger, smaller := actual, compareTo
		if larger < smaller {
			larger, smaller = smaller, larger
		}
		success = larger-smaller <= threshold
	}
	return success
}

// matchFloats applies the configured comparator to two float64 values.
// Only "~" uses threshold, succeeding when the absolute difference is at most
// threshold; "==" is an exact comparison. Any other comparator yields false.
func (matcher *BeNumericallyMatcher) matchFloats(actual, compareTo, threshold float64) (success bool) {
	switch matcher.Comparator {
	case "==":
		success = actual == compareTo
	case "~":
		distance := math.Abs(actual - compareTo)
		success = distance <= threshold
	case "<":
		success = actual < compareTo
	case "<=":
		success = actual <= compareTo
	case ">":
		success = actual > compareTo
	case ">=":
		success = actual >= compareTo
	}
	return success
}