diff options
Diffstat (limited to 'vendor/google.golang.org/grpc/server.go')
-rw-r--r-- | vendor/google.golang.org/grpc/server.go | 54 |
1 files changed, 35 insertions, 19 deletions
diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go index 65de84b30..b54f5bb57 100644 --- a/vendor/google.golang.org/grpc/server.go +++ b/vendor/google.golang.org/grpc/server.go @@ -73,6 +73,12 @@ func init() { internal.DrainServerTransports = func(srv *Server, addr string) { srv.drainServerTransports(addr) } + internal.AddExtraServerOptions = func(opt ...ServerOption) { + extraServerOptions = opt + } + internal.ClearExtraServerOptions = func() { + extraServerOptions = nil + } } var statusOK = status.New(codes.OK, "") @@ -150,7 +156,7 @@ type serverOptions struct { chainUnaryInts []UnaryServerInterceptor chainStreamInts []StreamServerInterceptor inTapHandle tap.ServerInHandle - statsHandler stats.Handler + statsHandlers []stats.Handler maxConcurrentStreams uint32 maxReceiveMessageSize int maxSendMessageSize int @@ -174,6 +180,7 @@ var defaultServerOptions = serverOptions{ writeBufferSize: defaultWriteBufSize, readBufferSize: defaultReadBufSize, } +var extraServerOptions []ServerOption // A ServerOption sets options such as credentials, codec and keepalive parameters, etc. type ServerOption interface { @@ -435,7 +442,7 @@ func InTapHandle(h tap.ServerInHandle) ServerOption { // StatsHandler returns a ServerOption that sets the stats handler for the server. func StatsHandler(h stats.Handler) ServerOption { return newFuncServerOption(func(o *serverOptions) { - o.statsHandler = h + o.statsHandlers = append(o.statsHandlers, h) }) } @@ -560,6 +567,9 @@ func (s *Server) stopServerWorkers() { // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { opts := defaultServerOptions + for _, o := range extraServerOptions { + o.apply(&opts) + } for _, o := range opt { o.apply(&opts) } @@ -867,7 +877,7 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { ConnectionTimeout: s.opts.connectionTimeout, Credentials: s.opts.creds, InTapHandle: s.opts.inTapHandle, - StatsHandler: s.opts.statsHandler, + StatsHandlers: s.opts.statsHandlers, KeepaliveParams: s.opts.keepaliveParams, KeepalivePolicy: s.opts.keepalivePolicy, InitialWindowSize: s.opts.initialWindowSize, @@ -963,7 +973,7 @@ var _ http.Handler = (*Server)(nil) // Notice: This API is EXPERIMENTAL and may be changed or removed in a // later release. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandler) + st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -1076,8 +1086,10 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize) } err = t.Write(stream, hdr, payload, opts) - if err == nil && s.opts.statsHandler != nil { - s.opts.statsHandler.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now())) + if err == nil { + for _, sh := range s.opts.statsHandlers { + sh.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now())) + } } return err } @@ -1124,13 +1136,13 @@ func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerIn } func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) { - sh := s.opts.statsHandler - if sh != nil || trInfo != nil || channelz.IsOn() { + shs := s.opts.statsHandlers + if len(shs) != 0 || trInfo != nil || channelz.IsOn() { if channelz.IsOn() { s.incrCallsStarted() } var statsBegin *stats.Begin - if sh != nil { + for _, sh := range shs { beginTime := time.Now() statsBegin = &stats.Begin{ BeginTime: beginTime, @@ -1161,7 +1173,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. trInfo.tr.Finish() } - if sh != nil { + for _, sh := range shs { end := &stats.End{ BeginTime: statsBegin.BeginTime, EndTime: time.Now(), @@ -1243,7 +1255,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } var payInfo *payloadInfo - if sh != nil || binlog != nil { + if len(shs) != 0 || binlog != nil { payInfo = &payloadInfo{} } d, err := recvAndDecompress(&parser{r: stream}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp) @@ -1260,7 +1272,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil { return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) } - if sh != nil { + for _, sh := range shs { sh.HandleRPC(stream.Context(), &stats.InPayload{ RecvTime: time.Now(), Payload: v, @@ -1418,16 +1430,18 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if channelz.IsOn() { s.incrCallsStarted() } - sh := s.opts.statsHandler + shs := s.opts.statsHandlers var statsBegin *stats.Begin - if sh != nil { + if len(shs) != 0 { beginTime := time.Now() statsBegin = &stats.Begin{ BeginTime: beginTime, IsClientStream: sd.ClientStreams, IsServerStream: sd.ServerStreams, } - sh.HandleRPC(stream.Context(), statsBegin) + for _, sh := range shs { + sh.HandleRPC(stream.Context(), statsBegin) + } } ctx := NewContextWithServerTransportStream(stream.Context(), stream) ss := &serverStream{ @@ -1439,10 +1453,10 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize, trInfo: trInfo, - statsHandler: sh, + statsHandler: shs, } - if sh != nil || trInfo != nil || channelz.IsOn() { + if len(shs) != 0 || trInfo != nil || channelz.IsOn() { // See comment in processUnaryRPC on defers. defer func() { if trInfo != nil { @@ -1456,7 +1470,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.mu.Unlock() } - if sh != nil { + if len(shs) != 0 { end := &stats.End{ BeginTime: statsBegin.BeginTime, EndTime: time.Now(), @@ -1464,7 +1478,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if err != nil && err != io.EOF { end.Error = toRPCErr(err) } - sh.HandleRPC(stream.Context(), end) + for _, sh := range shs { + sh.HandleRPC(stream.Context(), end) + } } if channelz.IsOn() { |