diff options
Diffstat (limited to 'vendor/github.com/onsi/gomega/internal/async_assertion.go')
-rw-r--r-- | vendor/github.com/onsi/gomega/internal/async_assertion.go | 220 |
1 files changed, 220 insertions, 0 deletions
diff --git a/vendor/github.com/onsi/gomega/internal/async_assertion.go b/vendor/github.com/onsi/gomega/internal/async_assertion.go new file mode 100644 index 000000000..ae20c14b8 --- /dev/null +++ b/vendor/github.com/onsi/gomega/internal/async_assertion.go @@ -0,0 +1,220 @@ +package internal + +import ( + "errors" + "fmt" + "reflect" + "runtime" + "time" + + "github.com/onsi/gomega/types" +) + +type AsyncAssertionType uint + +const ( + AsyncAssertionTypeEventually AsyncAssertionType = iota + AsyncAssertionTypeConsistently +) + +type AsyncAssertion struct { + asyncType AsyncAssertionType + + actualIsFunc bool + actualValue interface{} + actualFunc func() ([]reflect.Value, error) + + timeoutInterval time.Duration + pollingInterval time.Duration + offset int + g *Gomega +} + +func NewAsyncAssertion(asyncType AsyncAssertionType, actualInput interface{}, g *Gomega, timeoutInterval time.Duration, pollingInterval time.Duration, offset int) *AsyncAssertion { + out := &AsyncAssertion{ + asyncType: asyncType, + timeoutInterval: timeoutInterval, + pollingInterval: pollingInterval, + offset: offset, + g: g, + } + + switch actualType := reflect.TypeOf(actualInput); { + case actualType.Kind() != reflect.Func: + out.actualValue = actualInput + case actualType.NumIn() == 0 && actualType.NumOut() > 0: + out.actualIsFunc = true + out.actualFunc = func() ([]reflect.Value, error) { + return reflect.ValueOf(actualInput).Call([]reflect.Value{}), nil + } + case actualType.NumIn() == 1 && actualType.In(0).Implements(reflect.TypeOf((*types.Gomega)(nil)).Elem()): + out.actualIsFunc = true + out.actualFunc = func() (values []reflect.Value, err error) { + var assertionFailure error + assertionCapturingGomega := NewGomega(g.DurationBundle).ConfigureWithFailHandler(func(message string, callerSkip ...int) { + skip := 0 + if len(callerSkip) > 0 { + skip = callerSkip[0] + } + _, file, line, _ := runtime.Caller(skip + 1) + assertionFailure = fmt.Errorf("Assertion in callback at %s:%d failed:\n%s", file, line, message) + panic("stop execution") + }) + + defer func() { + if actualType.NumOut() == 0 { + if assertionFailure == nil { + values = []reflect.Value{reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())} + } else { + values = []reflect.Value{reflect.ValueOf(assertionFailure)} + } + } else { + err = assertionFailure + } + if e := recover(); e != nil && assertionFailure == nil { + panic(e) + } + }() + + values = reflect.ValueOf(actualInput).Call([]reflect.Value{reflect.ValueOf(assertionCapturingGomega)}) + return + } + default: + msg := fmt.Sprintf("The function passed to Gomega's async assertions should either take no arguments and return values, or take a single Gomega interface that it can use to make assertions within the body of the function. When taking a Gomega interface the function can optionally return values or return nothing. The function you passed takes %d arguments and returns %d values.", actualType.NumIn(), actualType.NumOut()) + g.Fail(msg, offset+4) + } + + return out +} + +func (assertion *AsyncAssertion) Should(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool { + assertion.g.THelper() + return assertion.match(matcher, true, optionalDescription...) +} + +func (assertion *AsyncAssertion) ShouldNot(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool { + assertion.g.THelper() + return assertion.match(matcher, false, optionalDescription...) +} + +func (assertion *AsyncAssertion) buildDescription(optionalDescription ...interface{}) string { + switch len(optionalDescription) { + case 0: + return "" + case 1: + if describe, ok := optionalDescription[0].(func() string); ok { + return describe() + "\n" + } + } + return fmt.Sprintf(optionalDescription[0].(string), optionalDescription[1:]...) + "\n" +} + +func (assertion *AsyncAssertion) pollActual() (interface{}, error) { + if !assertion.actualIsFunc { + return assertion.actualValue, nil + } + + values, err := assertion.actualFunc() + if err != nil { + return nil, err + } + extras := []interface{}{} + for _, value := range values[1:] { + extras = append(extras, value.Interface()) + } + success, message := vetExtras(extras) + if !success { + return nil, errors.New(message) + } + + return values[0].Interface(), nil +} + +func (assertion *AsyncAssertion) matcherMayChange(matcher types.GomegaMatcher, value interface{}) bool { + if assertion.actualIsFunc { + return true + } + return types.MatchMayChangeInTheFuture(matcher, value) +} + +func (assertion *AsyncAssertion) match(matcher types.GomegaMatcher, desiredMatch bool, optionalDescription ...interface{}) bool { + timer := time.Now() + timeout := time.After(assertion.timeoutInterval) + + var matches bool + var err error + mayChange := true + value, err := assertion.pollActual() + if err == nil { + mayChange = assertion.matcherMayChange(matcher, value) + matches, err = matcher.Match(value) + } + + assertion.g.THelper() + + fail := func(preamble string) { + errMsg := "" + message := "" + if err != nil { + errMsg = "Error: " + err.Error() + } else { + if desiredMatch { + message = matcher.FailureMessage(value) + } else { + message = matcher.NegatedFailureMessage(value) + } + } + assertion.g.THelper() + description := assertion.buildDescription(optionalDescription...) + assertion.g.Fail(fmt.Sprintf("%s after %.3fs.\n%s%s%s", preamble, time.Since(timer).Seconds(), description, message, errMsg), 3+assertion.offset) + } + + if assertion.asyncType == AsyncAssertionTypeEventually { + for { + if err == nil && matches == desiredMatch { + return true + } + + if !mayChange { + fail("No future change is possible. Bailing out early") + return false + } + + select { + case <-time.After(assertion.pollingInterval): + value, err = assertion.pollActual() + if err == nil { + mayChange = assertion.matcherMayChange(matcher, value) + matches, err = matcher.Match(value) + } + case <-timeout: + fail("Timed out") + return false + } + } + } else if assertion.asyncType == AsyncAssertionTypeConsistently { + for { + if !(err == nil && matches == desiredMatch) { + fail("Failed") + return false + } + + if !mayChange { + return true + } + + select { + case <-time.After(assertion.pollingInterval): + value, err = assertion.pollActual() + if err == nil { + mayChange = assertion.matcherMayChange(matcher, value) + matches, err = matcher.Match(value) + } + case <-timeout: + return true + } + } + } + + return false +} |