// untested sections: 6

package gstruct

import (
	"errors"
	"fmt"
	"reflect"
	"runtime/debug"
	"sort"
	"strings"

	"github.com/onsi/gomega/format"
	errorsutil "github.com/onsi/gomega/gstruct/errors"
	"github.com/onsi/gomega/types"
)

// MatchAllFields succeeds if every field of a struct matches the field matcher
// associated with it, and every field matcher is matched. Both extra struct
// fields (fields without a matcher) and missing fields (matchers without a
// struct field) cause the match to fail.
//
//	actual := struct{
//	  A int
//	  B []bool
//	  C string
//	}{
//	  A: 5,
//	  B: []bool{true, false},
//	  C: "foo",
//	}
//
//	Expect(actual).To(MatchAllFields(Fields{
//	  "A": Equal(5),
//	  "B": ConsistOf(true, false),
//	  "C": Equal("foo"),
//	}))
func MatchAllFields(fields Fields) types.GomegaMatcher {
	// IgnoreExtras and IgnoreMissing are left at their zero value (false), so
	// the match is strict in both directions.
	strict := &FieldsMatcher{Fields: fields}
	return strict
}

// MatchFields succeeds if each field of a struct matches the field matcher
// associated with it. The options bit set selects leniency: IgnoreExtras
// tolerates struct fields that have no matcher, and IgnoreMissing tolerates
// matchers whose field is absent from the struct.
//
//	actual := struct{
//	  A int
//	  B []bool
//	  C string
//	}{
//	  A: 5,
//	  B: []bool{true, false},
//	  C: "foo",
//	}
//
//	Expect(actual).To(MatchFields(IgnoreExtras, Fields{
//	  "A": Equal(5),
//	  "B": ConsistOf(true, false),
//	}))
//	Expect(actual).To(MatchFields(IgnoreMissing, Fields{
//	  "A": Equal(5),
//	  "B": ConsistOf(true, false),
//	  "C": Equal("foo"),
//	  "D": Equal("extra"),
//	}))
func MatchFields(options Options, fields Fields) types.GomegaMatcher {
	// Decode the option flags up front; each is an independent bit.
	ignoreExtras := options&IgnoreExtras != 0
	ignoreMissing := options&IgnoreMissing != 0

	return &FieldsMatcher{
		Fields:        fields,
		IgnoreExtras:  ignoreExtras,
		IgnoreMissing: ignoreMissing,
	}
}

// FieldsMatcher is the matcher built by MatchAllFields and MatchFields. It
// applies a matcher to each named field of a struct value and records the
// per-field failures of the most recent Match call.
type FieldsMatcher struct {
	// Matchers for each field, keyed by struct field name.
	Fields Fields

	// Whether to ignore extra elements (struct fields with no entry in Fields)
	// or consider it an error.
	IgnoreExtras bool
	// Whether to ignore missing elements (entries in Fields with no matching
	// struct field) or consider it an error.
	IgnoreMissing bool

	// State.
	// failures holds the errors found by the most recent Match call; Match
	// overwrites it on every call without any synchronization.
	failures []error
}

// Fields maps a struct field name to the matcher applied to that field's value.
type Fields map[string]types.GomegaMatcher

// Match reports whether actual satisfies the configured field matchers.
//
// actual must be a struct value. Any other kind, including a nil interface,
// yields a non-nil error rather than a failed match. When the match fails, the
// per-field errors are retained and later surfaced by FailureMessage and
// Failures.
func (m *FieldsMatcher) Match(actual interface{}) (success bool, err error) {
	// reflect.TypeOf(nil) returns a nil Type. Calling Kind() on it would panic,
	// so treat nil like any other non-struct input.
	if t := reflect.TypeOf(actual); t == nil || t.Kind() != reflect.Struct {
		return false, fmt.Errorf("%v is type %T, expected struct", actual, actual)
	}

	m.failures = m.matchFields(actual)
	return len(m.failures) == 0, nil
}

// matchFields applies each configured matcher to the corresponding field of
// actual and returns one error per problem found. The caller (Match) has
// already verified that actual is a struct. Problems are:
//   - a struct field with no matcher, unless IgnoreExtras is set;
//   - a field whose matcher errors or does not match (nested under ".Name");
//   - a matcher whose field does not exist in the struct, unless IgnoreMissing
//     is set.
//
// Field errors follow struct declaration order. Missing-field errors are
// sorted by name so failure messages are deterministic across runs.
func (m *FieldsMatcher) matchFields(actual interface{}) (errs []error) {
	val := reflect.ValueOf(actual)
	typ := val.Type()
	fields := make(map[string]bool, val.NumField())
	for i := 0; i < val.NumField(); i++ {
		fieldName := typ.Field(i).Name
		fields[fieldName] = true

		err := func() (err error) {
			// This test relies heavily on reflect, which tends to panic (for
			// example, Interface() on an unexported field). Recover here to
			// provide more useful error messages in that case.
			defer func() {
				if r := recover(); r != nil {
					err = fmt.Errorf("panic checking %+v: %v\n%s", actual, r, debug.Stack())
				}
			}()

			matcher, expected := m.Fields[fieldName]
			if !expected {
				if !m.IgnoreExtras {
					return fmt.Errorf("unexpected field %s: %+v", fieldName, actual)
				}
				return nil
			}

			field := val.Field(i).Interface()

			match, err := matcher.Match(field)
			if err != nil {
				return err
			}
			if !match {
				// Nesting matchers expose their own per-element failures. Surface
				// those instead of a single flattened message.
				if nesting, ok := matcher.(errorsutil.NestingMatcher); ok {
					return errorsutil.AggregateError(nesting.Failures())
				}
				return errors.New(matcher.FailureMessage(field))
			}
			return nil
		}()
		if err != nil {
			errs = append(errs, errorsutil.Nest("."+fieldName, err))
		}
	}

	if !m.IgnoreMissing {
		// Map iteration order is random, so collect and sort the missing names
		// before reporting them.
		missing := make([]string, 0, len(m.Fields))
		for field := range m.Fields {
			if !fields[field] {
				missing = append(missing, field)
			}
		}
		sort.Strings(missing)
		for _, field := range missing {
			errs = append(errs, fmt.Errorf("missing expected field %s", field))
		}
	}

	return errs
}

// FailureMessage describes why the most recent Match failed. It names actual's
// type and lists every recorded field failure, one per line.
func (m *FieldsMatcher) FailureMessage(actual interface{}) (message string) {
	failures := make([]string, len(m.failures))
	for i, failure := range m.failures {
		failures[i] = failure.Error()
	}

	typeName := ""
	if t := reflect.TypeOf(actual); t != nil {
		typeName = t.Name()
		// Unnamed types, such as the anonymous structs in the MatchAllFields and
		// MatchFields examples, have an empty Name(). Fall back to the full type
		// description so the message identifies what was checked.
		if typeName == "" {
			typeName = t.String()
		}
	}

	return format.Message(typeName,
		fmt.Sprintf("to match fields: {\n%v\n}\n", strings.Join(failures, "\n")))
}

// NegatedFailureMessage returns the message used when actual matched the field
// matchers but was expected not to.
func (m *FieldsMatcher) NegatedFailureMessage(actual interface{}) (message string) {
	message = format.Message(actual, "not to match fields")
	return message
}

// Failures returns the per-field errors recorded by the most recent Match call
// (nil or empty after a successful match). Enclosing matchers that check for
// errorsutil.NestingMatcher, as matchFields does, can use this to aggregate
// nested failures.
func (m *FieldsMatcher) Failures() []error {
	recorded := m.failures
	return recorded
}