summaryrefslogtreecommitdiff
path: root/vendor/google.golang.org/grpc/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/google.golang.org/grpc/server.go')
-rw-r--r--vendor/google.golang.org/grpc/server.go105
1 files changed, 52 insertions, 53 deletions
diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go
index 0a151dee4..557f29559 100644
--- a/vendor/google.golang.org/grpc/server.go
+++ b/vendor/google.golang.org/grpc/server.go
@@ -710,13 +710,6 @@ func (s *Server) GetServiceInfo() map[string]ServiceInfo {
// the server being stopped.
var ErrServerStopped = errors.New("grpc: the server has been stopped")
-func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
- if s.opts.creds == nil {
- return rawConn, nil, nil
- }
- return s.opts.creds.ServerHandshake(rawConn)
-}
-
type listenSocket struct {
net.Listener
channelzID int64
@@ -839,28 +832,14 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) {
return
}
rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout))
- conn, authInfo, err := s.useTransportAuthenticator(rawConn)
- if err != nil {
- // ErrConnDispatched means that the connection was dispatched away from
- // gRPC; those connections should be left open.
- if err != credentials.ErrConnDispatched {
- s.mu.Lock()
- s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
- s.mu.Unlock()
- channelz.Warningf(logger, s.channelzID, "grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err)
- rawConn.Close()
- }
- rawConn.SetDeadline(time.Time{})
- return
- }
// Finish handshaking (HTTP2)
- st := s.newHTTP2Transport(conn, authInfo)
+ st := s.newHTTP2Transport(rawConn)
+ rawConn.SetDeadline(time.Time{})
if st == nil {
return
}
- rawConn.SetDeadline(time.Time{})
if !s.addConn(lisAddr, st) {
return
}
@@ -881,10 +860,11 @@ func (s *Server) drainServerTransports(addr string) {
// newHTTP2Transport sets up a http/2 transport (using the
// gRPC http2 server transport in transport/http2_server.go).
-func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) transport.ServerTransport {
+func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
config := &transport.ServerConfig{
MaxStreams: s.opts.maxConcurrentStreams,
- AuthInfo: authInfo,
+ ConnectionTimeout: s.opts.connectionTimeout,
+ Credentials: s.opts.creds,
InTapHandle: s.opts.inTapHandle,
StatsHandler: s.opts.statsHandler,
KeepaliveParams: s.opts.keepaliveParams,
@@ -897,13 +877,22 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr
MaxHeaderListSize: s.opts.maxHeaderListSize,
HeaderTableSize: s.opts.headerTableSize,
}
- st, err := transport.NewServerTransport("http2", c, config)
+ st, err := transport.NewServerTransport(c, config)
if err != nil {
s.mu.Lock()
s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err)
s.mu.Unlock()
- c.Close()
- channelz.Warning(logger, s.channelzID, "grpc: Server.Serve failed to create ServerTransport: ", err)
+ // ErrConnDispatched means that the connection was dispatched away from
+ // gRPC; those connections should be left open.
+ if err != credentials.ErrConnDispatched {
+ c.Close()
+ }
+ // Don't log on ErrConnDispatched and io.EOF to prevent log spam.
+ if err != credentials.ErrConnDispatched {
+ if err != io.EOF {
+ channelz.Warning(logger, s.channelzID, "grpc: Server.Serve failed to create ServerTransport: ", err)
+ }
+ }
return nil
}
@@ -1109,22 +1098,24 @@ func chainUnaryServerInterceptors(s *Server) {
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
- chainedInt = func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
- return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler))
- }
+ chainedInt = chainUnaryInterceptors(interceptors)
}
s.opts.unaryInt = chainedInt
}
-// getChainUnaryHandler recursively generate the chained UnaryHandler
-func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler {
- if curr == len(interceptors)-1 {
- return finalHandler
- }
-
- return func(ctx context.Context, req interface{}) (interface{}, error) {
- return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler))
+func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor {
+ return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
+ var i int
+ var next UnaryHandler
+ next = func(ctx context.Context, req interface{}) (interface{}, error) {
+ if i == len(interceptors)-1 {
+ return interceptors[i](ctx, req, info, handler)
+ }
+ i++
+ return interceptors[i-1](ctx, req, info, next)
+ }
+ return next(ctx, req)
}
}
@@ -1138,7 +1129,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if sh != nil {
beginTime := time.Now()
statsBegin = &stats.Begin{
- BeginTime: beginTime,
+ BeginTime: beginTime,
+ IsClientStream: false,
+ IsServerStream: false,
}
sh.HandleRPC(stream.Context(), statsBegin)
}
@@ -1390,22 +1383,24 @@ func chainStreamServerInterceptors(s *Server) {
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
- chainedInt = func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
- return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler))
- }
+ chainedInt = chainStreamInterceptors(interceptors)
}
s.opts.streamInt = chainedInt
}
-// getChainStreamHandler recursively generate the chained StreamHandler
-func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler {
- if curr == len(interceptors)-1 {
- return finalHandler
- }
-
- return func(srv interface{}, ss ServerStream) error {
- return interceptors[curr+1](srv, ss, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler))
+func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor {
+ return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
+ var i int
+ var next StreamHandler
+ next = func(srv interface{}, ss ServerStream) error {
+ if i == len(interceptors)-1 {
+ return interceptors[i](srv, ss, info, handler)
+ }
+ i++
+ return interceptors[i-1](srv, ss, info, next)
+ }
+ return next(srv, ss)
}
}
@@ -1418,7 +1413,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
if sh != nil {
beginTime := time.Now()
statsBegin = &stats.Begin{
- BeginTime: beginTime,
+ BeginTime: beginTime,
+ IsClientStream: sd.ClientStreams,
+ IsServerStream: sd.ServerStreams,
}
sh.HandleRPC(stream.Context(), statsBegin)
}
@@ -1521,6 +1518,8 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
}
}
+ ss.ctx = newContextWithRPCInfo(ss.ctx, false, ss.codec, ss.cp, ss.comp)
+
if trInfo != nil {
trInfo.tr.LazyLog(&trInfo.firstLine, false)
}
@@ -1588,7 +1587,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
trInfo.tr.SetError()
}
errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
- if err := t.WriteStatus(stream, status.New(codes.ResourceExhausted, errDesc)); err != nil {
+ if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
if trInfo != nil {
trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
trInfo.tr.SetError()