summaryrefslogtreecommitdiff
path: root/vendor/github.com/gogo/protobuf/proto/extensions.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/gogo/protobuf/proto/extensions.go')
-rw-r--r--vendor/github.com/gogo/protobuf/proto/extensions.go283
1 files changed, 97 insertions, 186 deletions
diff --git a/vendor/github.com/gogo/protobuf/proto/extensions.go b/vendor/github.com/gogo/protobuf/proto/extensions.go
index 0dfcb538e..44ebd457c 100644
--- a/vendor/github.com/gogo/protobuf/proto/extensions.go
+++ b/vendor/github.com/gogo/protobuf/proto/extensions.go
@@ -38,6 +38,7 @@ package proto
import (
"errors"
"fmt"
+ "io"
"reflect"
"strconv"
"sync"
@@ -69,12 +70,6 @@ type extendableProtoV1 interface {
ExtensionMap() map[int32]Extension
}
-type extensionsBytes interface {
- Message
- ExtensionRangeArray() []ExtensionRange
- GetExtensions() *[]byte
-}
-
// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto.
type extensionAdapter struct {
extendableProtoV1
@@ -97,14 +92,31 @@ func (n notLocker) Unlock() {}
// extendable returns the extendableProto interface for the given generated proto message.
// If the proto message has the old extension format, it returns a wrapper that implements
// the extendableProto interface.
-func extendable(p interface{}) (extendableProto, bool) {
- if ep, ok := p.(extendableProto); ok {
- return ep, ok
- }
- if ep, ok := p.(extendableProtoV1); ok {
- return extensionAdapter{ep}, ok
+func extendable(p interface{}) (extendableProto, error) {
+ switch p := p.(type) {
+ case extendableProto:
+ if isNilPtr(p) {
+ return nil, fmt.Errorf("proto: nil %T is not extendable", p)
+ }
+ return p, nil
+ case extendableProtoV1:
+ if isNilPtr(p) {
+ return nil, fmt.Errorf("proto: nil %T is not extendable", p)
+ }
+ return extensionAdapter{p}, nil
+ case extensionsBytes:
+ return slowExtensionAdapter{p}, nil
}
- return nil, false
+ // Don't allocate a specific error containing %T:
+ // this is the hot path for Clone and MarshalText.
+ return nil, errNotExtendable
+}
+
+var errNotExtendable = errors.New("proto: not an extendable proto.Message")
+
+func isNilPtr(x interface{}) bool {
+ v := reflect.ValueOf(x)
+ return v.Kind() == reflect.Ptr && v.IsNil()
}
// XXX_InternalExtensions is an internal representation of proto extensions.
@@ -149,16 +161,6 @@ func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Loc
return e.p.extensionMap, &e.p.mu
}
-type extensionRange interface {
- Message
- ExtensionRangeArray() []ExtensionRange
-}
-
-var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
-var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem()
-var extendableBytesType = reflect.TypeOf((*extensionsBytes)(nil)).Elem()
-var extensionRangeType = reflect.TypeOf((*extensionRange)(nil)).Elem()
-
// ExtensionDesc represents an extension specification.
// Used in generated code from the protocol compiler.
type ExtensionDesc struct {
@@ -198,8 +200,8 @@ func SetRawExtension(base Message, id int32, b []byte) {
*ext = append(*ext, b...)
return
}
- epb, ok := extendable(base)
- if !ok {
+ epb, err := extendable(base)
+ if err != nil {
return
}
extmap := epb.extensionsWrite()
@@ -207,7 +209,7 @@ func SetRawExtension(base Message, id int32, b []byte) {
}
// isExtensionField returns true iff the given field number is in an extension range.
-func isExtensionField(pb extensionRange, field int32) bool {
+func isExtensionField(pb extendableProto, field int32) bool {
for _, er := range pb.ExtensionRangeArray() {
if er.Start <= field && field <= er.End {
return true
@@ -223,8 +225,11 @@ func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
if ea, ok := pbi.(extensionAdapter); ok {
pbi = ea.extendableProtoV1
}
+ if ea, ok := pbi.(slowExtensionAdapter); ok {
+ pbi = ea.extensionsBytes
+ }
if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b {
- return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
+ return fmt.Errorf("proto: bad extended type; %v does not extend %v", b, a)
}
// Check the range.
if !isExtensionField(pb, extension.Field) {
@@ -269,80 +274,6 @@ func extensionProperties(ed *ExtensionDesc) *Properties {
return prop
}
-// encode encodes any unmarshaled (unencoded) extensions in e.
-func encodeExtensions(e *XXX_InternalExtensions) error {
- m, mu := e.extensionsRead()
- if m == nil {
- return nil // fast path
- }
- mu.Lock()
- defer mu.Unlock()
- return encodeExtensionsMap(m)
-}
-
-// encode encodes any unmarshaled (unencoded) extensions in e.
-func encodeExtensionsMap(m map[int32]Extension) error {
- for k, e := range m {
- if e.value == nil || e.desc == nil {
- // Extension is only in its encoded form.
- continue
- }
-
- // We don't skip extensions that have an encoded form set,
- // because the extension value may have been mutated after
- // the last time this function was called.
-
- et := reflect.TypeOf(e.desc.ExtensionType)
- props := extensionProperties(e.desc)
-
- p := NewBuffer(nil)
- // If e.value has type T, the encoder expects a *struct{ X T }.
- // Pass a *T with a zero field and hope it all works out.
- x := reflect.New(et)
- x.Elem().Set(reflect.ValueOf(e.value))
- if err := props.enc(p, props, toStructPointer(x)); err != nil {
- return err
- }
- e.enc = p.buf
- m[k] = e
- }
- return nil
-}
-
-func extensionsSize(e *XXX_InternalExtensions) (n int) {
- m, mu := e.extensionsRead()
- if m == nil {
- return 0
- }
- mu.Lock()
- defer mu.Unlock()
- return extensionsMapSize(m)
-}
-
-func extensionsMapSize(m map[int32]Extension) (n int) {
- for _, e := range m {
- if e.value == nil || e.desc == nil {
- // Extension is only in its encoded form.
- n += len(e.enc)
- continue
- }
-
- // We don't skip extensions that have an encoded form set,
- // because the extension value may have been mutated after
- // the last time this function was called.
-
- et := reflect.TypeOf(e.desc.ExtensionType)
- props := extensionProperties(e.desc)
-
- // If e.value has type T, the encoder expects a *struct{ X T }.
- // Pass a *T with a zero field and hope it all works out.
- x := reflect.New(et)
- x.Elem().Set(reflect.ValueOf(e.value))
- n += props.size(props, toStructPointer(x))
- }
- return
-}
-
// HasExtension returns whether the given extension is present in pb.
func HasExtension(pb Message, extension *ExtensionDesc) bool {
if epb, doki := pb.(extensionsBytes); doki {
@@ -366,8 +297,8 @@ func HasExtension(pb Message, extension *ExtensionDesc) bool {
return false
}
// TODO: Check types, field numbers, etc.?
- epb, ok := extendable(pb)
- if !ok {
+ epb, err := extendable(pb)
+ if err != nil {
return false
}
extmap, mu := epb.extensionsRead()
@@ -375,46 +306,26 @@ func HasExtension(pb Message, extension *ExtensionDesc) bool {
return false
}
mu.Lock()
- _, ok = extmap[extension.Field]
+ _, ok := extmap[extension.Field]
mu.Unlock()
return ok
}
-func deleteExtension(pb extensionsBytes, theFieldNum int32, offset int) int {
- ext := pb.GetExtensions()
- for offset < len(*ext) {
- tag, n1 := DecodeVarint((*ext)[offset:])
- fieldNum := int32(tag >> 3)
- wireType := int(tag & 0x7)
- n2, err := size((*ext)[offset+n1:], wireType)
- if err != nil {
- panic(err)
- }
- newOffset := offset + n1 + n2
- if fieldNum == theFieldNum {
- *ext = append((*ext)[:offset], (*ext)[newOffset:]...)
- return offset
- }
- offset = newOffset
- }
- return -1
-}
-
// ClearExtension removes the given extension from pb.
func ClearExtension(pb Message, extension *ExtensionDesc) {
clearExtension(pb, extension.Field)
}
func clearExtension(pb Message, fieldNum int32) {
- if epb, doki := pb.(extensionsBytes); doki {
+ if epb, ok := pb.(extensionsBytes); ok {
offset := 0
for offset != -1 {
offset = deleteExtension(epb, fieldNum, offset)
}
return
}
- epb, ok := extendable(pb)
- if !ok {
+ epb, err := extendable(pb)
+ if err != nil {
return
}
// TODO: Check types, field numbers, etc.?
@@ -422,39 +333,33 @@ func clearExtension(pb Message, fieldNum int32) {
delete(extmap, fieldNum)
}
-// GetExtension parses and returns the given extension of pb.
-// If the extension is not present and has no default value it returns ErrMissingExtension.
+// GetExtension retrieves a proto2 extended field from pb.
+//
+// If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
+// then GetExtension parses the encoded field and returns a Go value of the specified type.
+// If the field is not present, then the default value is returned (if one is specified),
+// otherwise ErrMissingExtension is reported.
+//
+// If the descriptor is not type complete (i.e., ExtensionDesc.ExtensionType is nil),
+// then GetExtension returns the raw encoded bytes of the field extension.
func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
if epb, doki := pb.(extensionsBytes); doki {
ext := epb.GetExtensions()
- o := 0
- for o < len(*ext) {
- tag, n := DecodeVarint((*ext)[o:])
- fieldNum := int32(tag >> 3)
- wireType := int(tag & 0x7)
- l, err := size((*ext)[o+n:], wireType)
- if err != nil {
- return nil, err
- }
- if int32(fieldNum) == extension.Field {
- v, err := decodeExtension((*ext)[o:o+n+l], extension)
- if err != nil {
- return nil, err
- }
- return v, nil
- }
- o += n + l
- }
- return defaultExtensionValue(extension)
+ return decodeExtensionFromBytes(extension, *ext)
}
- epb, ok := extendable(pb)
- if !ok {
- return nil, errors.New("proto: not an extendable proto")
- }
- if err := checkExtensionTypes(epb, extension); err != nil {
+
+ epb, err := extendable(pb)
+ if err != nil {
return nil, err
}
+ if extension.ExtendedType != nil {
+ // can only check type if this is a complete descriptor
+ if cerr := checkExtensionTypes(epb, extension); cerr != nil {
+ return nil, cerr
+ }
+ }
+
emap, mu := epb.extensionsRead()
if emap == nil {
return defaultExtensionValue(extension)
@@ -479,6 +384,11 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
return e.value, nil
}
+ if extension.ExtensionType == nil {
+ // incomplete descriptor
+ return e.enc, nil
+ }
+
v, err := decodeExtension(e.enc, extension)
if err != nil {
return nil, err
@@ -496,6 +406,11 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
// defaultExtensionValue returns the default value for extension.
// If no default for an extension is defined ErrMissingExtension is returned.
func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
+ if extension.ExtensionType == nil {
+ // incomplete descriptor, so no default
+ return nil, ErrMissingExtension
+ }
+
t := reflect.TypeOf(extension.ExtensionType)
props := extensionProperties(extension)
@@ -530,31 +445,28 @@ func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
// decodeExtension decodes an extension encoded in b.
func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
- o := NewBuffer(b)
-
t := reflect.TypeOf(extension.ExtensionType)
-
- props := extensionProperties(extension)
+ unmarshal := typeUnmarshaler(t, extension.Tag)
// t is a pointer to a struct, pointer to basic type or a slice.
- // Allocate a "field" to store the pointer/slice itself; the
- // pointer/slice will be stored here. We pass
- // the address of this field to props.dec.
- // This passes a zero field and a *t and lets props.dec
- // interpret it as a *struct{ x t }.
+ // Allocate space to store the pointer/slice.
value := reflect.New(t).Elem()
+ var err error
for {
- // Discard wire type and field number varint. It isn't needed.
- if _, err := o.DecodeVarint(); err != nil {
- return nil, err
+ x, n := decodeVarint(b)
+ if n == 0 {
+ return nil, io.ErrUnexpectedEOF
}
+ b = b[n:]
+ wire := int(x) & 7
- if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil {
+ b, err = unmarshal(b, valToPointer(value.Addr()), wire)
+ if err != nil {
return nil, err
}
- if o.index >= len(o.buf) {
+ if len(b) == 0 {
break
}
}
@@ -564,9 +476,13 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
// GetExtensions returns a slice of the extensions present in pb that are also listed in es.
// The returned slice has the same length as es; missing extensions will appear as nil elements.
func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
+ epb, err := extendable(pb)
+ if err != nil {
+ return nil, err
+ }
extensions = make([]interface{}, len(es))
for i, e := range es {
- extensions[i], err = GetExtension(pb, e)
+ extensions[i], err = GetExtension(epb, e)
if err == ErrMissingExtension {
err = nil
}
@@ -581,9 +497,9 @@ func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, e
// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing
// just the Field field, which defines the extension's field number.
func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
- epb, ok := extendable(pb)
- if !ok {
- return nil, fmt.Errorf("proto: %T is not an extendable proto.Message", pb)
+ epb, err := extendable(pb)
+ if err != nil {
+ return nil, err
}
registeredExtensions := RegisteredExtensions(pb)
@@ -610,23 +526,18 @@ func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
// SetExtension sets the specified extension of pb to the specified value.
func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
- if epb, doki := pb.(extensionsBytes); doki {
- ClearExtension(pb, extension)
- ext := epb.GetExtensions()
- et := reflect.TypeOf(extension.ExtensionType)
- props := extensionProperties(extension)
- p := NewBuffer(nil)
- x := reflect.New(et)
- x.Elem().Set(reflect.ValueOf(value))
- if err := props.enc(p, props, toStructPointer(x)); err != nil {
+ if epb, ok := pb.(extensionsBytes); ok {
+ newb, err := encodeExtension(extension, value)
+ if err != nil {
return err
}
- *ext = append(*ext, p.buf...)
+ bb := epb.GetExtensions()
+ *bb = append(*bb, newb...)
return nil
}
- epb, ok := extendable(pb)
- if !ok {
- return errors.New("proto: not an extendable proto")
+ epb, err := extendable(pb)
+ if err != nil {
+ return err
}
if err := checkExtensionTypes(epb, extension); err != nil {
return err
@@ -656,8 +567,8 @@ func ClearAllExtensions(pb Message) {
*ext = []byte{}
return
}
- epb, ok := extendable(pb)
- if !ok {
+ epb, err := extendable(pb)
+ if err != nil {
return
}
m := epb.extensionsWrite()