summaryrefslogtreecommitdiff
path: root/vendor/google.golang.org/protobuf/proto/checkinit.go
blob: 3e9a6a2f66c77cdcb18acb3bb56ffa36be9e5a18 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package proto

import (
	"google.golang.org/protobuf/internal/errors"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/runtime/protoiface"
)

// CheckInitialized returns an error if any required fields in m are not set.
func CheckInitialized(m Message) error {
	// Treat a nil message interface as an "untyped" empty message,
	// which we assume to have no required fields.
	if m == nil {
		return nil
	}

	return checkInitialized(m.ProtoReflect())
}

// CheckInitialized returns an error if any required fields in m are not set.
func checkInitialized(m protoreflect.Message) error {
	if methods := protoMethods(m); methods != nil && methods.CheckInitialized != nil {
		_, err := methods.CheckInitialized(protoiface.CheckInitializedInput{
			Message: m,
		})
		return err
	}
	return checkInitializedSlow(m)
}

func checkInitializedSlow(m protoreflect.Message) error {
	md := m.Descriptor()
	fds := md.Fields()
	for i, nums := 0, md.RequiredNumbers(); i < nums.Len(); i++ {
		fd := fds.ByNumber(nums.Get(i))
		if !m.Has(fd) {
			return errors.RequiredNotSet(string(fd.FullName()))
		}
	}
	var err error
	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
		switch {
		case fd.IsList():
			if fd.Message() == nil {
				return true
			}
			for i, list := 0, v.List(); i < list.Len() && err == nil; i++ {
				err = checkInitialized(list.Get(i).Message())
			}
		case fd.IsMap():
			if fd.MapValue().Message() == nil {
				return true
			}
			v.Map().Range(func(key protoreflect.MapKey, v protoreflect.Value) bool {
				err = checkInitialized(v.Message())
				return err == nil
			})
		default:
			if fd.Message() == nil {
				return true
			}
			err = checkInitialized(v.Message())
		}
		return err == nil
	})
	return err
}