aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/gogo/protobuf/proto/message_set.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/gogo/protobuf/proto/message_set.go')
-rw-r--r--vendor/github.com/gogo/protobuf/proto/message_set.go81
1 files changed, 42 insertions, 39 deletions
diff --git a/vendor/github.com/gogo/protobuf/proto/message_set.go b/vendor/github.com/gogo/protobuf/proto/message_set.go
index fd982decd..3b6ca41d5 100644
--- a/vendor/github.com/gogo/protobuf/proto/message_set.go
+++ b/vendor/github.com/gogo/protobuf/proto/message_set.go
@@ -42,6 +42,7 @@ import (
"fmt"
"reflect"
"sort"
+ "sync"
)
// errNoMessageTypeID occurs when a protocol buffer does not have a message type ID.
@@ -94,10 +95,7 @@ func (ms *messageSet) find(pb Message) *_MessageSet_Item {
}
func (ms *messageSet) Has(pb Message) bool {
- if ms.find(pb) != nil {
- return true
- }
- return false
+ return ms.find(pb) != nil
}
func (ms *messageSet) Unmarshal(pb Message) error {
@@ -150,46 +148,42 @@ func skipVarint(buf []byte) []byte {
// MarshalMessageSet encodes the extension map represented by m in the message set wire format.
// It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option.
func MarshalMessageSet(exts interface{}) ([]byte, error) {
- var m map[int32]Extension
+ return marshalMessageSet(exts, false)
+}
+
+// marshaMessageSet implements above function, with the opt to turn on / off deterministic during Marshal.
+func marshalMessageSet(exts interface{}, deterministic bool) ([]byte, error) {
switch exts := exts.(type) {
case *XXX_InternalExtensions:
- if err := encodeExtensions(exts); err != nil {
- return nil, err
- }
- m, _ = exts.extensionsRead()
+ var u marshalInfo
+ siz := u.sizeMessageSet(exts)
+ b := make([]byte, 0, siz)
+ return u.appendMessageSet(b, exts, deterministic)
+
case map[int32]Extension:
- if err := encodeExtensionsMap(exts); err != nil {
- return nil, err
+ // This is an old-style extension map.
+ // Wrap it in a new-style XXX_InternalExtensions.
+ ie := XXX_InternalExtensions{
+ p: &struct {
+ mu sync.Mutex
+ extensionMap map[int32]Extension
+ }{
+ extensionMap: exts,
+ },
}
- m = exts
+
+ var u marshalInfo
+ siz := u.sizeMessageSet(&ie)
+ b := make([]byte, 0, siz)
+ return u.appendMessageSet(b, &ie, deterministic)
+
default:
return nil, errors.New("proto: not an extension map")
}
-
- // Sort extension IDs to provide a deterministic encoding.
- // See also enc_map in encode.go.
- ids := make([]int, 0, len(m))
- for id := range m {
- ids = append(ids, int(id))
- }
- sort.Ints(ids)
-
- ms := &messageSet{Item: make([]*_MessageSet_Item, 0, len(m))}
- for _, id := range ids {
- e := m[int32(id)]
- // Remove the wire type and field number varint, as well as the length varint.
- msg := skipVarint(skipVarint(e.enc))
-
- ms.Item = append(ms.Item, &_MessageSet_Item{
- TypeId: Int32(int32(id)),
- Message: msg,
- })
- }
- return Marshal(ms)
}
// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
-// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
+// It is called by Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
func UnmarshalMessageSet(buf []byte, exts interface{}) error {
var m map[int32]Extension
switch exts := exts.(type) {
@@ -235,7 +229,15 @@ func MarshalMessageSetJSON(exts interface{}) ([]byte, error) {
var m map[int32]Extension
switch exts := exts.(type) {
case *XXX_InternalExtensions:
- m, _ = exts.extensionsRead()
+ var mu sync.Locker
+ m, mu = exts.extensionsRead()
+ if m != nil {
+ // Keep the extensions map locked until we're done marshaling to prevent
+ // races between marshaling and unmarshaling the lazily-{en,de}coded
+ // values.
+ mu.Lock()
+ defer mu.Unlock()
+ }
case map[int32]Extension:
m = exts
default:
@@ -253,15 +255,16 @@ func MarshalMessageSetJSON(exts interface{}) ([]byte, error) {
for i, id := range ids {
ext := m[id]
- if i > 0 {
- b.WriteByte(',')
- }
-
msd, ok := messageSetMap[id]
if !ok {
// Unknown type; we can't render it, so skip it.
continue
}
+
+ if i > 0 && b.Len() > 1 {
+ b.WriteByte(',')
+ }
+
fmt.Fprintf(&b, `"[%s]":`, msd.name)
x := ext.value