summaryrefslogtreecommitdiff
path: root/vendor/github.com/gogo/protobuf/proto/extensions_gogo.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/gogo/protobuf/proto/extensions_gogo.go')
-rw-r--r--vendor/github.com/gogo/protobuf/proto/extensions_gogo.go162
1 files changed, 118 insertions, 44 deletions
diff --git a/vendor/github.com/gogo/protobuf/proto/extensions_gogo.go b/vendor/github.com/gogo/protobuf/proto/extensions_gogo.go
index ea6478f00..53ebd8cca 100644
--- a/vendor/github.com/gogo/protobuf/proto/extensions_gogo.go
+++ b/vendor/github.com/gogo/protobuf/proto/extensions_gogo.go
@@ -32,12 +32,36 @@ import (
"bytes"
"errors"
"fmt"
+ "io"
"reflect"
"sort"
"strings"
"sync"
)
+type extensionsBytes interface {
+ Message
+ ExtensionRangeArray() []ExtensionRange
+ GetExtensions() *[]byte
+}
+
+type slowExtensionAdapter struct {
+ extensionsBytes
+}
+
+func (s slowExtensionAdapter) extensionsWrite() map[int32]Extension {
+ panic("Please report a bug to github.com/gogo/protobuf if you see this message: Writing extensions is not supported for extensions stored in a byte slice field.")
+}
+
+func (s slowExtensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) {
+ b := s.GetExtensions()
+ m, err := BytesToExtensionsMap(*b)
+ if err != nil {
+ panic(err)
+ }
+ return m, notLocker{}
+}
+
func GetBoolExtension(pb Message, extension *ExtensionDesc, ifnotset bool) bool {
if reflect.ValueOf(pb).IsNil() {
return ifnotset
@@ -56,19 +80,28 @@ func GetBoolExtension(pb Message, extension *ExtensionDesc, ifnotset bool) bool
}
func (this *Extension) Equal(that *Extension) bool {
+ if err := this.Encode(); err != nil {
+ return false
+ }
+ if err := that.Encode(); err != nil {
+ return false
+ }
return bytes.Equal(this.enc, that.enc)
}
func (this *Extension) Compare(that *Extension) int {
+ if err := this.Encode(); err != nil {
+ return 1
+ }
+ if err := that.Encode(); err != nil {
+ return -1
+ }
return bytes.Compare(this.enc, that.enc)
}
func SizeOfInternalExtension(m extendableProto) (n int) {
- return SizeOfExtensionMap(m.extensionsWrite())
-}
-
-func SizeOfExtensionMap(m map[int32]Extension) (n int) {
- return extensionsMapSize(m)
+ info := getMarshalInfo(reflect.TypeOf(m))
+ return info.sizeV1Extensions(m.extensionsWrite())
}
type sortableMapElem struct {
@@ -122,28 +155,26 @@ func EncodeInternalExtension(m extendableProto, data []byte) (n int, err error)
}
func EncodeExtensionMap(m map[int32]Extension, data []byte) (n int, err error) {
- if err := encodeExtensionsMap(m); err != nil {
- return 0, err
- }
- keys := make([]int, 0, len(m))
- for k := range m {
- keys = append(keys, int(k))
- }
- sort.Ints(keys)
- for _, k := range keys {
- n += copy(data[n:], m[int32(k)].enc)
+ o := 0
+ for _, e := range m {
+ if err := e.Encode(); err != nil {
+ return 0, err
+ }
+ n := copy(data[o:], e.enc)
+ if n != len(e.enc) {
+ return 0, io.ErrShortBuffer
+ }
+ o += n
}
- return n, nil
+ return o, nil
}
func GetRawExtension(m map[int32]Extension, id int32) ([]byte, error) {
- if m[id].value == nil || m[id].desc == nil {
- return m[id].enc, nil
- }
- if err := encodeExtensionsMap(m); err != nil {
+ e := m[id]
+ if err := e.Encode(); err != nil {
return nil, err
}
- return m[id].enc, nil
+ return e.enc, nil
}
func size(buf []byte, wire int) (int, error) {
@@ -218,35 +249,58 @@ func AppendExtension(e Message, tag int32, buf []byte) {
}
}
-func encodeExtension(e *Extension) error {
- if e.value == nil || e.desc == nil {
- // Extension is only in its encoded form.
- return nil
+func encodeExtension(extension *ExtensionDesc, value interface{}) ([]byte, error) {
+ u := getMarshalInfo(reflect.TypeOf(extension.ExtendedType))
+ ei := u.getExtElemInfo(extension)
+ v := value
+ p := toAddrPointer(&v, ei.isptr)
+ siz := ei.sizer(p, SizeVarint(ei.wiretag))
+ buf := make([]byte, 0, siz)
+ return ei.marshaler(buf, p, ei.wiretag, false)
+}
+
+func decodeExtensionFromBytes(extension *ExtensionDesc, buf []byte) (interface{}, error) {
+ o := 0
+ for o < len(buf) {
+ tag, n := DecodeVarint((buf)[o:])
+ fieldNum := int32(tag >> 3)
+ wireType := int(tag & 0x7)
+ if o+n > len(buf) {
+ return nil, fmt.Errorf("unable to decode extension")
+ }
+ l, err := size((buf)[o+n:], wireType)
+ if err != nil {
+ return nil, err
+ }
+ if int32(fieldNum) == extension.Field {
+ if o+n+l > len(buf) {
+ return nil, fmt.Errorf("unable to decode extension")
+ }
+ v, err := decodeExtension((buf)[o:o+n+l], extension)
+ if err != nil {
+ return nil, err
+ }
+ return v, nil
+ }
+ o += n + l
}
- // We don't skip extensions that have an encoded form set,
- // because the extension value may have been mutated after
- // the last time this function was called.
-
- et := reflect.TypeOf(e.desc.ExtensionType)
- props := extensionProperties(e.desc)
-
- p := NewBuffer(nil)
- // If e.value has type T, the encoder expects a *struct{ X T }.
- // Pass a *T with a zero field and hope it all works out.
- x := reflect.New(et)
- x.Elem().Set(reflect.ValueOf(e.value))
- if err := props.enc(p, props, toStructPointer(x)); err != nil {
- return err
+ return defaultExtensionValue(extension)
+}
+
+func (this *Extension) Encode() error {
+ if this.enc == nil {
+ var err error
+ this.enc, err = encodeExtension(this.desc, this.value)
+ if err != nil {
+ return err
+ }
}
- e.enc = p.buf
return nil
}
func (this Extension) GoString() string {
- if this.enc == nil {
- if err := encodeExtension(&this); err != nil {
- panic(err)
- }
+ if err := this.Encode(); err != nil {
+ return fmt.Sprintf("error encoding extension: %v", err)
}
return fmt.Sprintf("proto.NewExtension(%#v)", this.enc)
}
@@ -292,3 +346,23 @@ func GetUnsafeExtensionsMap(extendable Message) map[int32]Extension {
pb := extendable.(extendableProto)
return pb.extensionsWrite()
}
+
+func deleteExtension(pb extensionsBytes, theFieldNum int32, offset int) int {
+ ext := pb.GetExtensions()
+ for offset < len(*ext) {
+ tag, n1 := DecodeVarint((*ext)[offset:])
+ fieldNum := int32(tag >> 3)
+ wireType := int(tag & 0x7)
+ n2, err := size((*ext)[offset+n1:], wireType)
+ if err != nil {
+ panic(err)
+ }
+ newOffset := offset + n1 + n2
+ if fieldNum == theFieldNum {
+ *ext = append((*ext)[:offset], (*ext)[newOffset:]...)
+ return offset
+ }
+ offset = newOffset
+ }
+ return -1
+}