diff options
Diffstat (limited to 'vendor/google.golang.org/grpc/internal/transport')
7 files changed, 396 insertions, 71 deletions
diff --git a/vendor/google.golang.org/grpc/internal/transport/controlbuf.go b/vendor/google.golang.org/grpc/internal/transport/controlbuf.go index 40ef23923..f63a01376 100644 --- a/vendor/google.golang.org/grpc/internal/transport/controlbuf.go +++ b/vendor/google.golang.org/grpc/internal/transport/controlbuf.go @@ -20,13 +20,17 @@ package transport import ( "bytes" + "errors" "fmt" "runtime" + "strconv" "sync" "sync/atomic" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" + "google.golang.org/grpc/internal/grpcutil" + "google.golang.org/grpc/status" ) var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) { @@ -128,6 +132,14 @@ type cleanupStream struct { func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM +type earlyAbortStream struct { + streamID uint32 + contentSubtype string + status *status.Status +} + +func (*earlyAbortStream) isTransportResponseFrame() bool { return false } + type dataFrame struct { streamID uint32 endStream bool @@ -749,6 +761,24 @@ func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error { return nil } +func (l *loopyWriter) earlyAbortStreamHandler(eas *earlyAbortStream) error { + if l.side == clientSide { + return errors.New("earlyAbortStream not handled on client") + } + + headerFields := []hpack.HeaderField{ + {Name: ":status", Value: "200"}, + {Name: "content-type", Value: grpcutil.ContentType(eas.contentSubtype)}, + {Name: "grpc-status", Value: strconv.Itoa(int(eas.status.Code()))}, + {Name: "grpc-message", Value: encodeGrpcMessage(eas.status.Message())}, + } + + if err := l.writeHeader(eas.streamID, true, headerFields, nil); err != nil { + return err + } + return nil +} + func (l *loopyWriter) incomingGoAwayHandler(*incomingGoAway) error { if l.side == clientSide { l.draining = true @@ -787,6 +817,8 @@ func (l *loopyWriter) handle(i interface{}) error { return l.registerStreamHandler(i) case *cleanupStream: return l.cleanupStreamHandler(i) + case *earlyAbortStream: + return l.earlyAbortStreamHandler(i) case *incomingGoAway: return l.incomingGoAwayHandler(i) case *dataFrame: diff --git a/vendor/google.golang.org/grpc/internal/transport/http2_client.go b/vendor/google.golang.org/grpc/internal/transport/http2_client.go index e73b77a15..48c5e52ed 100644 --- a/vendor/google.golang.org/grpc/internal/transport/http2_client.go +++ b/vendor/google.golang.org/grpc/internal/transport/http2_client.go @@ -32,13 +32,14 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" - "google.golang.org/grpc/internal/grpcutil" - "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" + icredentials "google.golang.org/grpc/internal/credentials" + "google.golang.org/grpc/internal/grpcutil" + imetadata "google.golang.org/grpc/internal/metadata" "google.golang.org/grpc/internal/syscall" + "google.golang.org/grpc/internal/transport/networktype" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" @@ -59,7 +60,7 @@ type http2Client struct { cancel context.CancelFunc ctxDone <-chan struct{} // Cache the ctx.Done() chan. userAgent string - md interface{} + md metadata.MD conn net.Conn // underlying communication channel loopy *loopyWriter remoteAddr net.Addr @@ -114,6 +115,9 @@ type http2Client struct { // goAwayReason records the http2.ErrCode and debug data received with the // GoAway frame. goAwayReason GoAwayReason + // goAwayDebugMessage contains a detailed human readable string about a + // GoAway frame, useful for error messages. + goAwayDebugMessage string // A condition variable used to signal when the keepalive goroutine should // go dormant. The condition for dormancy is based on the number of active // streams and the `PermitWithoutStream` keepalive client parameter. And @@ -137,11 +141,27 @@ type http2Client struct { connectionID uint64 } -func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) { +func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, useProxy bool, grpcUA string) (net.Conn, error) { + address := addr.Addr + networkType, ok := networktype.Get(addr) if fn != nil { - return fn(ctx, addr) + if networkType == "unix" && !strings.HasPrefix(address, "\x00") { + // For backward compatibility, if the user dialed "unix:///path", + // the passthrough resolver would be used and the user's custom + // dialer would see "unix:///path". Since the unix resolver is used + // and the address is now "/path", prepend "unix://" so the user's + // custom dialer sees the same address. + return fn(ctx, "unix://"+address) + } + return fn(ctx, address) } - return (&net.Dialer{}).DialContext(ctx, "tcp", addr) + if !ok { + networkType, address = parseDialTarget(address) + } + if networkType == "tcp" && useProxy { + return proxyDial(ctx, address, grpcUA) + } + return (&net.Dialer{}).DialContext(ctx, networkType, address) } func isTemporary(err error) bool { @@ -172,7 +192,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts } }() - conn, err := dial(connectCtx, opts.Dialer, addr.Addr) + conn, err := dial(connectCtx, opts.Dialer, addr, opts.UseProxy, opts.UserAgent) if err != nil { if opts.FailOnNonTempDialError { return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err) @@ -220,12 +240,23 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts // Attributes field of resolver.Address, which is shoved into connectCtx // and passed to the credential handshaker. This makes it possible for // address specific arbitrary data to reach the credential handshaker. - contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) - connectCtx = contextWithHandshakeInfo(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) + connectCtx = icredentials.NewClientHandshakeInfoContext(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, conn) if err != nil { return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err) } + for _, cd := range perRPCCreds { + if cd.RequireTransportSecurity() { + if ci, ok := authInfo.(interface { + GetCommonAuthInfo() credentials.CommonAuthInfo + }); ok { + secLevel := ci.GetCommonAuthInfo().SecurityLevel + if secLevel != credentials.InvalidSecurityLevel && secLevel < credentials.PrivacyAndIntegrity { + return nil, connectionErrorf(true, nil, "transport: cannot send secure credentials on an insecure connection") + } + } + } + } isSecure = true if transportCreds.Info().SecurityProtocol == "tls" { scheme = "https" @@ -248,7 +279,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts ctxDone: ctx.Done(), // Cache Done chan. cancel: cancel, userAgent: opts.UserAgent, - md: addr.Metadata, conn: conn, remoteAddr: conn.RemoteAddr(), localAddr: conn.LocalAddr(), @@ -276,6 +306,12 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts keepaliveEnabled: keepaliveEnabled, bufferPool: newBufferPool(), } + + if md, ok := addr.Metadata.(*metadata.MD); ok { + t.md = *md + } else if md := imetadata.Get(addr); md != nil { + t.md = md + } t.controlBuf = newControlBuffer(t.ctxDone) if opts.InitialWindowSize >= defaultWindowSize { t.initialWindowSize = opts.InitialWindowSize @@ -312,12 +348,14 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts // Send connection preface to server. n, err := t.conn.Write(clientPreface) if err != nil { - t.Close() - return nil, connectionErrorf(true, err, "transport: failed to write client preface: %v", err) + err = connectionErrorf(true, err, "transport: failed to write client preface: %v", err) + t.Close(err) + return nil, err } if n != len(clientPreface) { - t.Close() - return nil, connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) + err = connectionErrorf(true, nil, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) + t.Close(err) + return nil, err } var ss []http2.Setting @@ -335,14 +373,16 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts } err = t.framer.fr.WriteSettings(ss...) if err != nil { - t.Close() - return nil, connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err) + err = connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err) + t.Close(err) + return nil, err } // Adjust the connection flow control window if needed. if delta := uint32(icwz - defaultWindowSize); delta > 0 { if err := t.framer.fr.WriteWindowUpdate(0, delta); err != nil { - t.Close() - return nil, connectionErrorf(true, err, "transport: failed to write window update: %v", err) + err = connectionErrorf(true, err, "transport: failed to write window update: %v", err) + t.Close(err) + return nil, err } } @@ -379,6 +419,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { buf: newRecvBuffer(), headerChan: make(chan struct{}), contentSubtype: callHdr.ContentSubtype, + doneFunc: callHdr.DoneFunc, } s.wq = newWriteQuota(defaultWriteQuota, s.done) s.requestRead = func(n int) { @@ -418,7 +459,7 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) Method: callHdr.Method, AuthInfo: t.authInfo, } - ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) + ctxWithRequestInfo := icredentials.NewRequestInfoContext(ctx, ri) authData, err := t.getTrAuthData(ctxWithRequestInfo, aud) if err != nil { return nil, err @@ -481,25 +522,23 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) for _, vv := range added { for i, v := range vv { if i%2 == 0 { - k = v + k = strings.ToLower(v) continue } // HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set. if isReservedHeader(k) { continue } - headerFields = append(headerFields, hpack.HeaderField{Name: strings.ToLower(k), Value: encodeMetadataHeader(k, v)}) + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } } } - if md, ok := t.md.(*metadata.MD); ok { - for k, vv := range *md { - if isReservedHeader(k) { - continue - } - for _, v := range vv { - headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) - } + for k, vv := range t.md { + if isReservedHeader(k) { + continue + } + for _, v := range vv { + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } } return headerFields, nil @@ -549,8 +588,11 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call // Note: if these credentials are provided both via dial options and call // options, then both sets of credentials will be applied. if callCreds := callHdr.Creds; callCreds != nil { - if !t.isSecure && callCreds.RequireTransportSecurity() { - return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection") + if callCreds.RequireTransportSecurity() { + ri, _ := credentials.RequestInfoFromContext(ctx) + if !t.isSecure || credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity) != nil { + return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection") + } } data, err := callCreds.GetRequestMetadata(ctx, audience) if err != nil { @@ -796,6 +838,9 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. t.controlBuf.executeAndPut(addBackStreamQuota, cleanup) // This will unblock write. close(s.done) + if s.doneFunc != nil { + s.doneFunc() + } } // Close kicks off the shutdown process of the transport. This should be called @@ -805,12 +850,12 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. // This method blocks until the addrConn that initiated this transport is // re-connected. This happens because t.onClose() begins reconnect logic at the // addrConn level and blocks until the addrConn is successfully connected. -func (t *http2Client) Close() error { +func (t *http2Client) Close(err error) { t.mu.Lock() // Make sure we only Close once. if t.state == closing { t.mu.Unlock() - return nil + return } // Call t.onClose before setting the state to closing to prevent the client // from attempting to create new streams ASAP. @@ -826,13 +871,19 @@ func (t *http2Client) Close() error { t.mu.Unlock() t.controlBuf.finish() t.cancel() - err := t.conn.Close() + t.conn.Close() if channelz.IsOn() { channelz.RemoveEntry(t.channelzID) } + // Append info about previous goaways if there were any, since this may be important + // for understanding the root cause for this connection to be closed. + _, goAwayDebugMessage := t.GetGoAwayReason() + if len(goAwayDebugMessage) > 0 { + err = fmt.Errorf("closing transport due to: %v, received prior goaway: %v", err, goAwayDebugMessage) + } // Notify all active streams. for _, s := range streams { - t.closeStream(s, ErrConnClosing, false, http2.ErrCodeNo, status.New(codes.Unavailable, ErrConnClosing.Desc), nil, false) + t.closeStream(s, err, false, http2.ErrCodeNo, status.New(codes.Unavailable, err.Error()), nil, false) } if t.statsHandler != nil { connEnd := &stats.ConnEnd{ @@ -840,7 +891,6 @@ func (t *http2Client) Close() error { } t.statsHandler.HandleConn(t.ctx, connEnd) } - return err } // GracefulClose sets the state to draining, which prevents new streams from @@ -859,7 +909,7 @@ func (t *http2Client) GracefulClose() { active := len(t.activeStreams) t.mu.Unlock() if active == 0 { - t.Close() + t.Close(ErrConnClosing) return } t.controlBuf.put(&incomingGoAway{}) @@ -1105,9 +1155,9 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { } } id := f.LastStreamID - if id > 0 && id%2 != 1 { + if id > 0 && id%2 == 0 { t.mu.Unlock() - t.Close() + t.Close(connectionErrorf(true, nil, "received goaway with non-zero even-numbered numbered stream id: %v", id)) return } // A client can receive multiple GoAways from the server (see @@ -1125,7 +1175,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { // If there are multiple GoAways the first one should always have an ID greater than the following ones. if id > t.prevGoAwayID { t.mu.Unlock() - t.Close() + t.Close(connectionErrorf(true, nil, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID)) return } default: @@ -1155,7 +1205,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { active := len(t.activeStreams) t.mu.Unlock() if active == 0 { - t.Close() + t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) } } @@ -1171,12 +1221,13 @@ func (t *http2Client) setGoAwayReason(f *http2.GoAwayFrame) { t.goAwayReason = GoAwayTooManyPings } } + t.goAwayDebugMessage = fmt.Sprintf("code: %s, debug data: %v", f.ErrCode, string(f.DebugData())) } -func (t *http2Client) GetGoAwayReason() GoAwayReason { +func (t *http2Client) GetGoAwayReason() (GoAwayReason, string) { t.mu.Lock() defer t.mu.Unlock() - return t.goAwayReason + return t.goAwayReason, t.goAwayDebugMessage } func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { @@ -1273,7 +1324,8 @@ func (t *http2Client) reader() { // Check the validity of server preface. frame, err := t.framer.fr.ReadFrame() if err != nil { - t.Close() // this kicks off resetTransport, so must be last before return + err = connectionErrorf(true, err, "error reading server preface: %v", err) + t.Close(err) // this kicks off resetTransport, so must be last before return return } t.conn.SetReadDeadline(time.Time{}) // reset deadline once we get the settings frame (we didn't time out, yay!) @@ -1282,7 +1334,8 @@ func (t *http2Client) reader() { } sf, ok := frame.(*http2.SettingsFrame) if !ok { - t.Close() // this kicks off resetTransport, so must be last before return + // this kicks off resetTransport, so must be last before return + t.Close(connectionErrorf(true, nil, "initial http2 frame from server is not a settings frame: %T", frame)) return } t.onPrefaceReceipt() @@ -1318,7 +1371,7 @@ func (t *http2Client) reader() { continue } else { // Transport error. - t.Close() + t.Close(connectionErrorf(true, err, "error reading from server: %v", err)) return } } @@ -1377,7 +1430,7 @@ func (t *http2Client) keepalive() { continue } if outstandingPing && timeoutLeft <= 0 { - t.Close() + t.Close(connectionErrorf(true, nil, "keepalive ping failed to receive ACK within timeout")) return } t.mu.Lock() diff --git a/vendor/google.golang.org/grpc/internal/transport/http2_server.go b/vendor/google.golang.org/grpc/internal/transport/http2_server.go index 0cf1cc320..11be5599c 100644 --- a/vendor/google.golang.org/grpc/internal/transport/http2_server.go +++ b/vendor/google.golang.org/grpc/internal/transport/http2_server.go @@ -26,6 +26,7 @@ import ( "io" "math" "net" + "net/http" "strconv" "sync" "sync/atomic" @@ -355,26 +356,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( if state.data.statsTrace != nil { s.ctx = stats.SetIncomingTrace(s.ctx, state.data.statsTrace) } - if t.inTapHandle != nil { - var err error - info := &tap.Info{ - FullMethodName: state.data.method, - } - s.ctx, err = t.inTapHandle(s.ctx, info) - if err != nil { - if logger.V(logLevel) { - logger.Warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err) - } - t.controlBuf.put(&cleanupStream{ - streamID: s.id, - rst: true, - rstCode: http2.ErrCodeRefusedStream, - onWrite: func() {}, - }) - s.cancel() - return false - } - } t.mu.Lock() if t.state != reachable { t.mu.Unlock() @@ -402,6 +383,39 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( return true } t.maxStreamID = streamID + if state.data.httpMethod != http.MethodPost { + t.mu.Unlock() + if logger.V(logLevel) { + logger.Warningf("transport: http2Server.operateHeaders parsed a :method field: %v which should be POST", state.data.httpMethod) + } + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: http2.ErrCodeProtocol, + onWrite: func() {}, + }) + s.cancel() + return false + } + if t.inTapHandle != nil { + var err error + if s.ctx, err = t.inTapHandle(s.ctx, &tap.Info{FullMethodName: state.data.method}); err != nil { + t.mu.Unlock() + if logger.V(logLevel) { + logger.Infof("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err) + } + stat, ok := status.FromError(err) + if !ok { + stat = status.New(codes.PermissionDenied, err.Error()) + } + t.controlBuf.put(&earlyAbortStream{ + streamID: s.id, + contentSubtype: s.contentSubtype, + status: stat, + }) + return false + } + } t.activeStreams[streamID] = s if len(t.activeStreams) == 1 { t.idle = time.Time{} diff --git a/vendor/google.golang.org/grpc/internal/transport/http_util.go b/vendor/google.golang.org/grpc/internal/transport/http_util.go index 4d15afbf7..c7dee140c 100644 --- a/vendor/google.golang.org/grpc/internal/transport/http_util.go +++ b/vendor/google.golang.org/grpc/internal/transport/http_util.go @@ -27,6 +27,7 @@ import ( "math" "net" "net/http" + "net/url" "strconv" "strings" "time" @@ -110,6 +111,7 @@ type parsedHeaderData struct { timeoutSet bool timeout time.Duration method string + httpMethod string // key-value metadata map from the peer. mdata map[string][]string statsTags []byte @@ -362,6 +364,8 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { } d.data.statsTrace = v d.addMetadata(f.Name, string(v)) + case ":method": + d.data.httpMethod = f.Value default: if isReservedHeader(f.Name) && !isWhitelistedHeader(f.Name) { break @@ -598,3 +602,31 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, maxHeaderList f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil) return f } + +// parseDialTarget returns the network and address to pass to dialer. +func parseDialTarget(target string) (string, string) { + net := "tcp" + m1 := strings.Index(target, ":") + m2 := strings.Index(target, ":/") + // handle unix:addr which will fail with url.Parse + if m1 >= 0 && m2 < 0 { + if n := target[0:m1]; n == "unix" { + return n, target[m1+1:] + } + } + if m2 >= 0 { + t, err := url.Parse(target) + if err != nil { + return net, target + } + scheme := t.Scheme + addr := t.Path + if scheme == "unix" { + if addr == "" { + addr = t.Host + } + return scheme, addr + } + } + return net, target +} diff --git a/vendor/google.golang.org/grpc/internal/transport/networktype/networktype.go b/vendor/google.golang.org/grpc/internal/transport/networktype/networktype.go new file mode 100644 index 000000000..96967428b --- /dev/null +++ b/vendor/google.golang.org/grpc/internal/transport/networktype/networktype.go @@ -0,0 +1,46 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package networktype declares the network type to be used in the default +// dailer. Attribute of a resolver.Address. +package networktype + +import ( + "google.golang.org/grpc/resolver" +) + +// keyType is the key to use for storing State in Attributes. +type keyType string + +const key = keyType("grpc.internal.transport.networktype") + +// Set returns a copy of the provided address with attributes containing networkType. +func Set(address resolver.Address, networkType string) resolver.Address { + address.Attributes = address.Attributes.WithValues(key, networkType) + return address +} + +// Get returns the network type in the resolver.Address and true, or "", false +// if not present. +func Get(address resolver.Address) (string, bool) { + v := address.Attributes.Value(key) + if v == nil { + return "", false + } + return v.(string), true +} diff --git a/vendor/google.golang.org/grpc/internal/transport/proxy.go b/vendor/google.golang.org/grpc/internal/transport/proxy.go new file mode 100644 index 000000000..a662bf39a --- /dev/null +++ b/vendor/google.golang.org/grpc/internal/transport/proxy.go @@ -0,0 +1,142 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "bufio" + "context" + "encoding/base64" + "fmt" + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" +) + +const proxyAuthHeaderKey = "Proxy-Authorization" + +var ( + // The following variable will be overwritten in the tests. + httpProxyFromEnvironment = http.ProxyFromEnvironment +) + +func mapAddress(ctx context.Context, address string) (*url.URL, error) { + req := &http.Request{ + URL: &url.URL{ + Scheme: "https", + Host: address, + }, + } + url, err := httpProxyFromEnvironment(req) + if err != nil { + return nil, err + } + return url, nil +} + +// To read a response from a net.Conn, http.ReadResponse() takes a bufio.Reader. +// It's possible that this reader reads more than what's need for the response and stores +// those bytes in the buffer. +// bufConn wraps the original net.Conn and the bufio.Reader to make sure we don't lose the +// bytes in the buffer. +type bufConn struct { + net.Conn + r io.Reader +} + +func (c *bufConn) Read(b []byte) (int, error) { + return c.r.Read(b) +} + +func basicAuth(username, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + +func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL, grpcUA string) (_ net.Conn, err error) { + defer func() { + if err != nil { + conn.Close() + } + }() + + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Host: backendAddr}, + Header: map[string][]string{"User-Agent": {grpcUA}}, + } + if t := proxyURL.User; t != nil { + u := t.Username() + p, _ := t.Password() + req.Header.Add(proxyAuthHeaderKey, "Basic "+basicAuth(u, p)) + } + + if err := sendHTTPRequest(ctx, req, conn); err != nil { + return nil, fmt.Errorf("failed to write the HTTP request: %v", err) + } + + r := bufio.NewReader(conn) + resp, err := http.ReadResponse(r, req) + if err != nil { + return nil, fmt.Errorf("reading server HTTP response: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + dump, err := httputil.DumpResponse(resp, true) + if err != nil { + return nil, fmt.Errorf("failed to do connect handshake, status code: %s", resp.Status) + } + return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump) + } + + return &bufConn{Conn: conn, r: r}, nil +} + +// proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy +// is necessary, dials, does the HTTP CONNECT handshake, and returns the +// connection. +func proxyDial(ctx context.Context, addr string, grpcUA string) (conn net.Conn, err error) { + newAddr := addr + proxyURL, err := mapAddress(ctx, addr) + if err != nil { + return nil, err + } + if proxyURL != nil { + newAddr = proxyURL.Host + } + + conn, err = (&net.Dialer{}).DialContext(ctx, "tcp", newAddr) + if err != nil { + return + } + if proxyURL != nil { + // proxy is disabled if proxyURL is nil. + conn, err = doHTTPConnectHandshake(ctx, conn, addr, proxyURL, grpcUA) + } + return +} + +func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error { + req = req.WithContext(ctx) + if err := req.Write(conn); err != nil { + return fmt.Errorf("failed to write the HTTP request: %v", err) + } + return nil +} diff --git a/vendor/google.golang.org/grpc/internal/transport/transport.go b/vendor/google.golang.org/grpc/internal/transport/transport.go index b74030a96..6cc1031fd 100644 --- a/vendor/google.golang.org/grpc/internal/transport/transport.go +++ b/vendor/google.golang.org/grpc/internal/transport/transport.go @@ -241,6 +241,7 @@ type Stream struct { ctx context.Context // the associated context of the stream cancel context.CancelFunc // always nil for client side Stream done chan struct{} // closed at the end of stream to unblock writers. On the client side. + doneFunc func() // invoked at the end of stream on client side. ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance) method string // the associated RPC method of the stream recvCompress string @@ -569,6 +570,8 @@ type ConnectOptions struct { ChannelzParentID int64 // MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received. MaxHeaderListSize *uint32 + // UseProxy specifies if a proxy should be used. + UseProxy bool } // NewClientTransport establishes the transport with the required ConnectOptions @@ -609,6 +612,8 @@ type CallHdr struct { ContentSubtype string PreviousAttempts int // value of grpc-previous-rpc-attempts header to set + + DoneFunc func() // called when the stream is finished } // ClientTransport is the common interface for all gRPC client-side transport @@ -617,7 +622,7 @@ type ClientTransport interface { // Close tears down this transport. Once it returns, the transport // should not be accessed any more. The caller must make sure this // is called only once. - Close() error + Close(err error) // GracefulClose starts to tear down the transport: the transport will stop // accepting new RPCs and NewStream will return error. Once all streams are @@ -651,8 +656,9 @@ type ClientTransport interface { // HTTP/2). GoAway() <-chan struct{} - // GetGoAwayReason returns the reason why GoAway frame was received. - GetGoAwayReason() GoAwayReason + // GetGoAwayReason returns the reason why GoAway frame was received, along + // with a human readable string with debug info. + GetGoAwayReason() (GoAwayReason, string) // RemoteAddr returns the remote network address. RemoteAddr() net.Addr |