summaryrefslogtreecommitdiff
path: root/vendor/google.golang.org/grpc/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/server.go')
-rw-r--r--vendor/google.golang.org/grpc/server.go77
1 files changed, 47 insertions, 30 deletions
diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go
index 0a151dee4..0251f48da 100644
--- a/vendor/google.golang.org/grpc/server.go
+++ b/vendor/google.golang.org/grpc/server.go
@@ -844,10 +844,16 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) {
// ErrConnDispatched means that the connection was dispatched away from
// gRPC; those connections should be left open.
if err != credentials.ErrConnDispatched {
- s.mu.Lock()
- s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
- s.mu.Unlock()
- channelz.Warningf(logger, s.channelzID, "grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err)
+ // In deployments where a gRPC server runs behind a cloud load
+ // balancer which performs regular TCP level health checks, the
+ // connection is closed immediately by the latter. Skipping the
+ // error here will help reduce log clutter.
+ if err != io.EOF {
+ s.mu.Lock()
+ s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
+ s.mu.Unlock()
+ channelz.Warningf(logger, s.channelzID, "grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err)
+ }
rawConn.Close()
}
rawConn.SetDeadline(time.Time{})
@@ -857,6 +863,7 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) {
// Finish handshaking (HTTP2)
st := s.newHTTP2Transport(conn, authInfo)
if st == nil {
+ conn.Close()
return
}
@@ -897,7 +904,7 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr
MaxHeaderListSize: s.opts.maxHeaderListSize,
HeaderTableSize: s.opts.headerTableSize,
}
- st, err := transport.NewServerTransport("http2", c, config)
+ st, err := transport.NewServerTransport(c, config)
if err != nil {
s.mu.Lock()
s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err)
@@ -1109,22 +1116,24 @@ func chainUnaryServerInterceptors(s *Server) {
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
- chainedInt = func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
- return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler))
- }
+ chainedInt = chainUnaryInterceptors(interceptors)
}
s.opts.unaryInt = chainedInt
}
-// getChainUnaryHandler recursively generate the chained UnaryHandler
-func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler {
- if curr == len(interceptors)-1 {
- return finalHandler
- }
-
- return func(ctx context.Context, req interface{}) (interface{}, error) {
- return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler))
+func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor {
+ return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
+ var i int
+ var next UnaryHandler
+ next = func(ctx context.Context, req interface{}) (interface{}, error) {
+ if i == len(interceptors)-1 {
+ return interceptors[i](ctx, req, info, handler)
+ }
+ i++
+ return interceptors[i-1](ctx, req, info, next)
+ }
+ return next(ctx, req)
}
}
@@ -1138,7 +1147,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if sh != nil {
beginTime := time.Now()
statsBegin = &stats.Begin{
- BeginTime: beginTime,
+ BeginTime: beginTime,
+ IsClientStream: false,
+ IsServerStream: false,
}
sh.HandleRPC(stream.Context(), statsBegin)
}
@@ -1390,22 +1401,24 @@ func chainStreamServerInterceptors(s *Server) {
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
- chainedInt = func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
- return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler))
- }
+ chainedInt = chainStreamInterceptors(interceptors)
}
s.opts.streamInt = chainedInt
}
-// getChainStreamHandler recursively generate the chained StreamHandler
-func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler {
- if curr == len(interceptors)-1 {
- return finalHandler
- }
-
- return func(srv interface{}, ss ServerStream) error {
- return interceptors[curr+1](srv, ss, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler))
+func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor {
+ return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
+ var i int
+ var next StreamHandler
+ next = func(srv interface{}, ss ServerStream) error {
+ if i == len(interceptors)-1 {
+ return interceptors[i](srv, ss, info, handler)
+ }
+ i++
+ return interceptors[i-1](srv, ss, info, next)
+ }
+ return next(srv, ss)
}
}
@@ -1418,7 +1431,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
if sh != nil {
beginTime := time.Now()
statsBegin = &stats.Begin{
- BeginTime: beginTime,
+ BeginTime: beginTime,
+ IsClientStream: sd.ClientStreams,
+ IsServerStream: sd.ServerStreams,
}
sh.HandleRPC(stream.Context(), statsBegin)
}
@@ -1521,6 +1536,8 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
}
}
+ ss.ctx = newContextWithRPCInfo(ss.ctx, false, ss.codec, ss.cp, ss.comp)
+
if trInfo != nil {
trInfo.tr.LazyLog(&trInfo.firstLine, false)
}
@@ -1588,7 +1605,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
trInfo.tr.SetError()
}
errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
- if err := t.WriteStatus(stream, status.New(codes.ResourceExhausted, errDesc)); err != nil {
+ if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
trInfo.tr.SetError()