diff options
Diffstat (limited to 'vendor/google.golang.org/grpc/internal/transport')
9 files changed, 5975 insertions, 0 deletions
diff --git a/vendor/google.golang.org/grpc/internal/transport/bdp_estimator.go b/vendor/google.golang.org/grpc/internal/transport/bdp_estimator.go new file mode 100644 index 000000000..070680edb --- /dev/null +++ b/vendor/google.golang.org/grpc/internal/transport/bdp_estimator.go @@ -0,0 +1,141 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "sync" + "time" +) + +const ( + // bdpLimit is the maximum value the flow control windows will be increased + // to. TCP typically limits this to 4MB, but some systems go up to 16MB. + // Since this is only a limit, it is safe to make it optimistic. + bdpLimit = (1 << 20) * 16 + // alpha is a constant factor used to keep a moving average + // of RTTs. + alpha = 0.9 + // If the current bdp sample is greater than or equal to + // our beta * our estimated bdp and the current bandwidth + // sample is the maximum bandwidth observed so far, we + // increase our bbp estimate by a factor of gamma. + beta = 0.66 + // To put our bdp to be smaller than or equal to twice the real BDP, + // we should multiply our current sample with 4/3, however to round things out + // we use 2 as the multiplication factor. + gamma = 2 +) + +// Adding arbitrary data to ping so that its ack can be identified. +// Easter-egg: what does the ping message say? +var bdpPing = &ping{data: [8]byte{2, 4, 16, 16, 9, 14, 7, 7}} + +type bdpEstimator struct { + // sentAt is the time when the ping was sent. + sentAt time.Time + + mu sync.Mutex + // bdp is the current bdp estimate. + bdp uint32 + // sample is the number of bytes received in one measurement cycle. + sample uint32 + // bwMax is the maximum bandwidth noted so far (bytes/sec). + bwMax float64 + // bool to keep track of the beginning of a new measurement cycle. + isSent bool + // Callback to update the window sizes. + updateFlowControl func(n uint32) + // sampleCount is the number of samples taken so far. + sampleCount uint64 + // round trip time (seconds) + rtt float64 +} + +// timesnap registers the time bdp ping was sent out so that +// network rtt can be calculated when its ack is received. +// It is called (by controller) when the bdpPing is +// being written on the wire. +func (b *bdpEstimator) timesnap(d [8]byte) { + if bdpPing.data != d { + return + } + b.sentAt = time.Now() +} + +// add adds bytes to the current sample for calculating bdp. +// It returns true only if a ping must be sent. This can be used +// by the caller (handleData) to make decision about batching +// a window update with it. +func (b *bdpEstimator) add(n uint32) bool { + b.mu.Lock() + defer b.mu.Unlock() + if b.bdp == bdpLimit { + return false + } + if !b.isSent { + b.isSent = true + b.sample = n + b.sentAt = time.Time{} + b.sampleCount++ + return true + } + b.sample += n + return false +} + +// calculate is called when an ack for a bdp ping is received. +// Here we calculate the current bdp and bandwidth sample and +// decide if the flow control windows should go up. +func (b *bdpEstimator) calculate(d [8]byte) { + // Check if the ping acked for was the bdp ping. + if bdpPing.data != d { + return + } + b.mu.Lock() + rttSample := time.Since(b.sentAt).Seconds() + if b.sampleCount < 10 { + // Bootstrap rtt with an average of first 10 rtt samples. + b.rtt += (rttSample - b.rtt) / float64(b.sampleCount) + } else { + // Heed to the recent past more. + b.rtt += (rttSample - b.rtt) * float64(alpha) + } + b.isSent = false + // The number of bytes accumulated so far in the sample is smaller + // than or equal to 1.5 times the real BDP on a saturated connection. + bwCurrent := float64(b.sample) / (b.rtt * float64(1.5)) + if bwCurrent > b.bwMax { + b.bwMax = bwCurrent + } + // If the current sample (which is smaller than or equal to the 1.5 times the real BDP) is + // greater than or equal to 2/3rd our perceived bdp AND this is the maximum bandwidth seen so far, we + // should update our perception of the network BDP. + if float64(b.sample) >= beta*float64(b.bdp) && bwCurrent == b.bwMax && b.bdp != bdpLimit { + sampleFloat := float64(b.sample) + b.bdp = uint32(gamma * sampleFloat) + if b.bdp > bdpLimit { + b.bdp = bdpLimit + } + bdp := b.bdp + b.mu.Unlock() + b.updateFlowControl(bdp) + return + } + b.mu.Unlock() +} diff --git a/vendor/google.golang.org/grpc/internal/transport/controlbuf.go b/vendor/google.golang.org/grpc/internal/transport/controlbuf.go new file mode 100644 index 000000000..40ef23923 --- /dev/null +++ b/vendor/google.golang.org/grpc/internal/transport/controlbuf.go @@ -0,0 +1,942 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "bytes" + "fmt" + "runtime" + "sync" + "sync/atomic" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" +) + +var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) { + e.SetMaxDynamicTableSizeLimit(v) +} + +type itemNode struct { + it interface{} + next *itemNode +} + +type itemList struct { + head *itemNode + tail *itemNode +} + +func (il *itemList) enqueue(i interface{}) { + n := &itemNode{it: i} + if il.tail == nil { + il.head, il.tail = n, n + return + } + il.tail.next = n + il.tail = n +} + +// peek returns the first item in the list without removing it from the +// list. +func (il *itemList) peek() interface{} { + return il.head.it +} + +func (il *itemList) dequeue() interface{} { + if il.head == nil { + return nil + } + i := il.head.it + il.head = il.head.next + if il.head == nil { + il.tail = nil + } + return i +} + +func (il *itemList) dequeueAll() *itemNode { + h := il.head + il.head, il.tail = nil, nil + return h +} + +func (il *itemList) isEmpty() bool { + return il.head == nil +} + +// The following defines various control items which could flow through +// the control buffer of transport. They represent different aspects of +// control tasks, e.g., flow control, settings, streaming resetting, etc. + +// maxQueuedTransportResponseFrames is the most queued "transport response" +// frames we will buffer before preventing new reads from occurring on the +// transport. These are control frames sent in response to client requests, +// such as RST_STREAM due to bad headers or settings acks. +const maxQueuedTransportResponseFrames = 50 + +type cbItem interface { + isTransportResponseFrame() bool +} + +// registerStream is used to register an incoming stream with loopy writer. +type registerStream struct { + streamID uint32 + wq *writeQuota +} + +func (*registerStream) isTransportResponseFrame() bool { return false } + +// headerFrame is also used to register stream on the client-side. +type headerFrame struct { + streamID uint32 + hf []hpack.HeaderField + endStream bool // Valid on server side. + initStream func(uint32) error // Used only on the client side. + onWrite func() + wq *writeQuota // write quota for the stream created. + cleanup *cleanupStream // Valid on the server side. + onOrphaned func(error) // Valid on client-side +} + +func (h *headerFrame) isTransportResponseFrame() bool { + return h.cleanup != nil && h.cleanup.rst // Results in a RST_STREAM +} + +type cleanupStream struct { + streamID uint32 + rst bool + rstCode http2.ErrCode + onWrite func() +} + +func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM + +type dataFrame struct { + streamID uint32 + endStream bool + h []byte + d []byte + // onEachWrite is called every time + // a part of d is written out. + onEachWrite func() +} + +func (*dataFrame) isTransportResponseFrame() bool { return false } + +type incomingWindowUpdate struct { + streamID uint32 + increment uint32 +} + +func (*incomingWindowUpdate) isTransportResponseFrame() bool { return false } + +type outgoingWindowUpdate struct { + streamID uint32 + increment uint32 +} + +func (*outgoingWindowUpdate) isTransportResponseFrame() bool { + return false // window updates are throttled by thresholds +} + +type incomingSettings struct { + ss []http2.Setting +} + +func (*incomingSettings) isTransportResponseFrame() bool { return true } // Results in a settings ACK + +type outgoingSettings struct { + ss []http2.Setting +} + +func (*outgoingSettings) isTransportResponseFrame() bool { return false } + +type incomingGoAway struct { +} + +func (*incomingGoAway) isTransportResponseFrame() bool { return false } + +type goAway struct { + code http2.ErrCode + debugData []byte + headsUp bool + closeConn bool +} + +func (*goAway) isTransportResponseFrame() bool { return false } + +type ping struct { + ack bool + data [8]byte +} + +func (*ping) isTransportResponseFrame() bool { return true } + +type outFlowControlSizeRequest struct { + resp chan uint32 +} + +func (*outFlowControlSizeRequest) isTransportResponseFrame() bool { return false } + +type outStreamState int + +const ( + active outStreamState = iota + empty + waitingOnStreamQuota +) + +type outStream struct { + id uint32 + state outStreamState + itl *itemList + bytesOutStanding int + wq *writeQuota + + next *outStream + prev *outStream +} + +func (s *outStream) deleteSelf() { + if s.prev != nil { + s.prev.next = s.next + } + if s.next != nil { + s.next.prev = s.prev + } + s.next, s.prev = nil, nil +} + +type outStreamList struct { + // Following are sentinel objects that mark the + // beginning and end of the list. They do not + // contain any item lists. All valid objects are + // inserted in between them. + // This is needed so that an outStream object can + // deleteSelf() in O(1) time without knowing which + // list it belongs to. + head *outStream + tail *outStream +} + +func newOutStreamList() *outStreamList { + head, tail := new(outStream), new(outStream) + head.next = tail + tail.prev = head + return &outStreamList{ + head: head, + tail: tail, + } +} + +func (l *outStreamList) enqueue(s *outStream) { + e := l.tail.prev + e.next = s + s.prev = e + s.next = l.tail + l.tail.prev = s +} + +// remove from the beginning of the list. +func (l *outStreamList) dequeue() *outStream { + b := l.head.next + if b == l.tail { + return nil + } + b.deleteSelf() + return b +} + +// controlBuffer is a way to pass information to loopy. +// Information is passed as specific struct types called control frames. +// A control frame not only represents data, messages or headers to be sent out +// but can also be used to instruct loopy to update its internal state. +// It shouldn't be confused with an HTTP2 frame, although some of the control frames +// like dataFrame and headerFrame do go out on wire as HTTP2 frames. +type controlBuffer struct { + ch chan struct{} + done <-chan struct{} + mu sync.Mutex + consumerWaiting bool + list *itemList + err error + + // transportResponseFrames counts the number of queued items that represent + // the response of an action initiated by the peer. trfChan is created + // when transportResponseFrames >= maxQueuedTransportResponseFrames and is + // closed and nilled when transportResponseFrames drops below the + // threshold. Both fields are protected by mu. + transportResponseFrames int + trfChan atomic.Value // *chan struct{} +} + +func newControlBuffer(done <-chan struct{}) *controlBuffer { + return &controlBuffer{ + ch: make(chan struct{}, 1), + list: &itemList{}, + done: done, + } +} + +// throttle blocks if there are too many incomingSettings/cleanupStreams in the +// controlbuf. +func (c *controlBuffer) throttle() { + ch, _ := c.trfChan.Load().(*chan struct{}) + if ch != nil { + select { + case <-*ch: + case <-c.done: + } + } +} + +func (c *controlBuffer) put(it cbItem) error { + _, err := c.executeAndPut(nil, it) + return err +} + +func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it cbItem) (bool, error) { + var wakeUp bool + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return false, c.err + } + if f != nil { + if !f(it) { // f wasn't successful + c.mu.Unlock() + return false, nil + } + } + if c.consumerWaiting { + wakeUp = true + c.consumerWaiting = false + } + c.list.enqueue(it) + if it.isTransportResponseFrame() { + c.transportResponseFrames++ + if c.transportResponseFrames == maxQueuedTransportResponseFrames { + // We are adding the frame that puts us over the threshold; create + // a throttling channel. + ch := make(chan struct{}) + c.trfChan.Store(&ch) + } + } + c.mu.Unlock() + if wakeUp { + select { + case c.ch <- struct{}{}: + default: + } + } + return true, nil +} + +// Note argument f should never be nil. +func (c *controlBuffer) execute(f func(it interface{}) bool, it interface{}) (bool, error) { + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return false, c.err + } + if !f(it) { // f wasn't successful + c.mu.Unlock() + return false, nil + } + c.mu.Unlock() + return true, nil +} + +func (c *controlBuffer) get(block bool) (interface{}, error) { + for { + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return nil, c.err + } + if !c.list.isEmpty() { + h := c.list.dequeue().(cbItem) + if h.isTransportResponseFrame() { + if c.transportResponseFrames == maxQueuedTransportResponseFrames { + // We are removing the frame that put us over the + // threshold; close and clear the throttling channel. + ch := c.trfChan.Load().(*chan struct{}) + close(*ch) + c.trfChan.Store((*chan struct{})(nil)) + } + c.transportResponseFrames-- + } + c.mu.Unlock() + return h, nil + } + if !block { + c.mu.Unlock() + return nil, nil + } + c.consumerWaiting = true + c.mu.Unlock() + select { + case <-c.ch: + case <-c.done: + c.finish() + return nil, ErrConnClosing + } + } +} + +func (c *controlBuffer) finish() { + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return + } + c.err = ErrConnClosing + // There may be headers for streams in the control buffer. + // These streams need to be cleaned out since the transport + // is still not aware of these yet. + for head := c.list.dequeueAll(); head != nil; head = head.next { + hdr, ok := head.it.(*headerFrame) + if !ok { + continue + } + if hdr.onOrphaned != nil { // It will be nil on the server-side. + hdr.onOrphaned(ErrConnClosing) + } + } + c.mu.Unlock() +} + +type side int + +const ( + clientSide side = iota + serverSide +) + +// Loopy receives frames from the control buffer. +// Each frame is handled individually; most of the work done by loopy goes +// into handling data frames. Loopy maintains a queue of active streams, and each +// stream maintains a queue of data frames; as loopy receives data frames +// it gets added to the queue of the relevant stream. +// Loopy goes over this list of active streams by processing one node every iteration, +// thereby closely resemebling to a round-robin scheduling over all streams. While +// processing a stream, loopy writes out data bytes from this stream capped by the min +// of http2MaxFrameLen, connection-level flow control and stream-level flow control. +type loopyWriter struct { + side side + cbuf *controlBuffer + sendQuota uint32 + oiws uint32 // outbound initial window size. + // estdStreams is map of all established streams that are not cleaned-up yet. + // On client-side, this is all streams whose headers were sent out. + // On server-side, this is all streams whose headers were received. + estdStreams map[uint32]*outStream // Established streams. + // activeStreams is a linked-list of all streams that have data to send and some + // stream-level flow control quota. + // Each of these streams internally have a list of data items(and perhaps trailers + // on the server-side) to be sent out. + activeStreams *outStreamList + framer *framer + hBuf *bytes.Buffer // The buffer for HPACK encoding. + hEnc *hpack.Encoder // HPACK encoder. + bdpEst *bdpEstimator + draining bool + + // Side-specific handlers + ssGoAwayHandler func(*goAway) (bool, error) +} + +func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator) *loopyWriter { + var buf bytes.Buffer + l := &loopyWriter{ + side: s, + cbuf: cbuf, + sendQuota: defaultWindowSize, + oiws: defaultWindowSize, + estdStreams: make(map[uint32]*outStream), + activeStreams: newOutStreamList(), + framer: fr, + hBuf: &buf, + hEnc: hpack.NewEncoder(&buf), + bdpEst: bdpEst, + } + return l +} + +const minBatchSize = 1000 + +// run should be run in a separate goroutine. +// It reads control frames from controlBuf and processes them by: +// 1. Updating loopy's internal state, or/and +// 2. Writing out HTTP2 frames on the wire. +// +// Loopy keeps all active streams with data to send in a linked-list. +// All streams in the activeStreams linked-list must have both: +// 1. Data to send, and +// 2. Stream level flow control quota available. +// +// In each iteration of run loop, other than processing the incoming control +// frame, loopy calls processData, which processes one node from the activeStreams linked-list. +// This results in writing of HTTP2 frames into an underlying write buffer. +// When there's no more control frames to read from controlBuf, loopy flushes the write buffer. +// As an optimization, to increase the batch size for each flush, loopy yields the processor, once +// if the batch size is too low to give stream goroutines a chance to fill it up. +func (l *loopyWriter) run() (err error) { + defer func() { + if err == ErrConnClosing { + // Don't log ErrConnClosing as error since it happens + // 1. When the connection is closed by some other known issue. + // 2. User closed the connection. + // 3. A graceful close of connection. + if logger.V(logLevel) { + logger.Infof("transport: loopyWriter.run returning. %v", err) + } + err = nil + } + }() + for { + it, err := l.cbuf.get(true) + if err != nil { + return err + } + if err = l.handle(it); err != nil { + return err + } + if _, err = l.processData(); err != nil { + return err + } + gosched := true + hasdata: + for { + it, err := l.cbuf.get(false) + if err != nil { + return err + } + if it != nil { + if err = l.handle(it); err != nil { + return err + } + if _, err = l.processData(); err != nil { + return err + } + continue hasdata + } + isEmpty, err := l.processData() + if err != nil { + return err + } + if !isEmpty { + continue hasdata + } + if gosched { + gosched = false + if l.framer.writer.offset < minBatchSize { + runtime.Gosched() + continue hasdata + } + } + l.framer.writer.Flush() + break hasdata + + } + } +} + +func (l *loopyWriter) outgoingWindowUpdateHandler(w *outgoingWindowUpdate) error { + return l.framer.fr.WriteWindowUpdate(w.streamID, w.increment) +} + +func (l *loopyWriter) incomingWindowUpdateHandler(w *incomingWindowUpdate) error { + // Otherwise update the quota. + if w.streamID == 0 { + l.sendQuota += w.increment + return nil + } + // Find the stream and update it. + if str, ok := l.estdStreams[w.streamID]; ok { + str.bytesOutStanding -= int(w.increment) + if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota > 0 && str.state == waitingOnStreamQuota { + str.state = active + l.activeStreams.enqueue(str) + return nil + } + } + return nil +} + +func (l *loopyWriter) outgoingSettingsHandler(s *outgoingSettings) error { + return l.framer.fr.WriteSettings(s.ss...) +} + +func (l *loopyWriter) incomingSettingsHandler(s *incomingSettings) error { + if err := l.applySettings(s.ss); err != nil { + return err + } + return l.framer.fr.WriteSettingsAck() +} + +func (l *loopyWriter) registerStreamHandler(h *registerStream) error { + str := &outStream{ + id: h.streamID, + state: empty, + itl: &itemList{}, + wq: h.wq, + } + l.estdStreams[h.streamID] = str + return nil +} + +func (l *loopyWriter) headerHandler(h *headerFrame) error { + if l.side == serverSide { + str, ok := l.estdStreams[h.streamID] + if !ok { + if logger.V(logLevel) { + logger.Warningf("transport: loopy doesn't recognize the stream: %d", h.streamID) + } + return nil + } + // Case 1.A: Server is responding back with headers. + if !h.endStream { + return l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite) + } + // else: Case 1.B: Server wants to close stream. + + if str.state != empty { // either active or waiting on stream quota. + // add it str's list of items. + str.itl.enqueue(h) + return nil + } + if err := l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite); err != nil { + return err + } + return l.cleanupStreamHandler(h.cleanup) + } + // Case 2: Client wants to originate stream. + str := &outStream{ + id: h.streamID, + state: empty, + itl: &itemList{}, + wq: h.wq, + } + str.itl.enqueue(h) + return l.originateStream(str) +} + +func (l *loopyWriter) originateStream(str *outStream) error { + hdr := str.itl.dequeue().(*headerFrame) + if err := hdr.initStream(str.id); err != nil { + if err == ErrConnClosing { + return err + } + // Other errors(errStreamDrain) need not close transport. + return nil + } + if err := l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil { + return err + } + l.estdStreams[str.id] = str + return nil +} + +func (l *loopyWriter) writeHeader(streamID uint32, endStream bool, hf []hpack.HeaderField, onWrite func()) error { + if onWrite != nil { + onWrite() + } + l.hBuf.Reset() + for _, f := range hf { + if err := l.hEnc.WriteField(f); err != nil { + if logger.V(logLevel) { + logger.Warningf("transport: loopyWriter.writeHeader encountered error while encoding headers: %v", err) + } + } + } + var ( + err error + endHeaders, first bool + ) + first = true + for !endHeaders { + size := l.hBuf.Len() + if size > http2MaxFrameLen { + size = http2MaxFrameLen + } else { + endHeaders = true + } + if first { + first = false + err = l.framer.fr.WriteHeaders(http2.HeadersFrameParam{ + StreamID: streamID, + BlockFragment: l.hBuf.Next(size), + EndStream: endStream, + EndHeaders: endHeaders, + }) + } else { + err = l.framer.fr.WriteContinuation( + streamID, + endHeaders, + l.hBuf.Next(size), + ) + } + if err != nil { + return err + } + } + return nil +} + +func (l *loopyWriter) preprocessData(df *dataFrame) error { + str, ok := l.estdStreams[df.streamID] + if !ok { + return nil + } + // If we got data for a stream it means that + // stream was originated and the headers were sent out. + str.itl.enqueue(df) + if str.state == empty { + str.state = active + l.activeStreams.enqueue(str) + } + return nil +} + +func (l *loopyWriter) pingHandler(p *ping) error { + if !p.ack { + l.bdpEst.timesnap(p.data) + } + return l.framer.fr.WritePing(p.ack, p.data) + +} + +func (l *loopyWriter) outFlowControlSizeRequestHandler(o *outFlowControlSizeRequest) error { + o.resp <- l.sendQuota + return nil +} + +func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error { + c.onWrite() + if str, ok := l.estdStreams[c.streamID]; ok { + // On the server side it could be a trailers-only response or + // a RST_STREAM before stream initialization thus the stream might + // not be established yet. + delete(l.estdStreams, c.streamID) + str.deleteSelf() + } + if c.rst { // If RST_STREAM needs to be sent. + if err := l.framer.fr.WriteRSTStream(c.streamID, c.rstCode); err != nil { + return err + } + } + if l.side == clientSide && l.draining && len(l.estdStreams) == 0 { + return ErrConnClosing + } + return nil +} + +func (l *loopyWriter) incomingGoAwayHandler(*incomingGoAway) error { + if l.side == clientSide { + l.draining = true + if len(l.estdStreams) == 0 { + return ErrConnClosing + } + } + return nil +} + +func (l *loopyWriter) goAwayHandler(g *goAway) error { + // Handling of outgoing GoAway is very specific to side. + if l.ssGoAwayHandler != nil { + draining, err := l.ssGoAwayHandler(g) + if err != nil { + return err + } + l.draining = draining + } + return nil +} + +func (l *loopyWriter) handle(i interface{}) error { + switch i := i.(type) { + case *incomingWindowUpdate: + return l.incomingWindowUpdateHandler(i) + case *outgoingWindowUpdate: + return l.outgoingWindowUpdateHandler(i) + case *incomingSettings: + return l.incomingSettingsHandler(i) + case *outgoingSettings: + return l.outgoingSettingsHandler(i) + case *headerFrame: + return l.headerHandler(i) + case *registerStream: + return l.registerStreamHandler(i) + case *cleanupStream: + return l.cleanupStreamHandler(i) + case *incomingGoAway: + return l.incomingGoAwayHandler(i) + case *dataFrame: + return l.preprocessData(i) + case *ping: + return l.pingHandler(i) + case *goAway: + return l.goAwayHandler(i) + case *outFlowControlSizeRequest: + return l.outFlowControlSizeRequestHandler(i) + default: + return fmt.Errorf("transport: unknown control message type %T", i) + } +} + +func (l *loopyWriter) applySettings(ss []http2.Setting) error { + for _, s := range ss { + switch s.ID { + case http2.SettingInitialWindowSize: + o := l.oiws + l.oiws = s.Val + if o < l.oiws { + // If the new limit is greater make all depleted streams active. + for _, stream := range l.estdStreams { + if stream.state == waitingOnStreamQuota { + stream.state = active + l.activeStreams.enqueue(stream) + } + } + } + case http2.SettingHeaderTableSize: + updateHeaderTblSize(l.hEnc, s.Val) + } + } + return nil +} + +// processData removes the first stream from active streams, writes out at most 16KB +// of its data and then puts it at the end of activeStreams if there's still more data +// to be sent and stream has some stream-level flow control. +func (l *loopyWriter) processData() (bool, error) { + if l.sendQuota == 0 { + return true, nil + } + str := l.activeStreams.dequeue() // Remove the first stream. + if str == nil { + return true, nil + } + dataItem := str.itl.peek().(*dataFrame) // Peek at the first data item this stream. + // A data item is represented by a dataFrame, since it later translates into + // multiple HTTP2 data frames. + // Every dataFrame has two buffers; h that keeps grpc-message header and d that is acutal data. + // As an optimization to keep wire traffic low, data from d is copied to h to make as big as the + // maximum possilbe HTTP2 frame size. + + if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // Empty data frame + // Client sends out empty data frame with endStream = true + if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil { + return false, err + } + str.itl.dequeue() // remove the empty data item from stream + if str.itl.isEmpty() { + str.state = empty + } else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers. + if err := l.writeHeader(trailer.streamID, trailer.endStream, trailer.hf, trailer.onWrite); err != nil { + return false, err + } + if err := l.cleanupStreamHandler(trailer.cleanup); err != nil { + return false, nil + } + } else { + l.activeStreams.enqueue(str) + } + return false, nil + } + var ( + buf []byte + ) + // Figure out the maximum size we can send + maxSize := http2MaxFrameLen + if strQuota := int(l.oiws) - str.bytesOutStanding; strQuota <= 0 { // stream-level flow control. + str.state = waitingOnStreamQuota + return false, nil + } else if maxSize > strQuota { + maxSize = strQuota + } + if maxSize > int(l.sendQuota) { // connection-level flow control. + maxSize = int(l.sendQuota) + } + // Compute how much of the header and data we can send within quota and max frame length + hSize := min(maxSize, len(dataItem.h)) + dSize := min(maxSize-hSize, len(dataItem.d)) + if hSize != 0 { + if dSize == 0 { + buf = dataItem.h + } else { + // We can add some data to grpc message header to distribute bytes more equally across frames. + // Copy on the stack to avoid generating garbage + var localBuf [http2MaxFrameLen]byte + copy(localBuf[:hSize], dataItem.h) + copy(localBuf[hSize:], dataItem.d[:dSize]) + buf = localBuf[:hSize+dSize] + } + } else { + buf = dataItem.d + } + + size := hSize + dSize + + // Now that outgoing flow controls are checked we can replenish str's write quota + str.wq.replenish(size) + var endStream bool + // If this is the last data message on this stream and all of it can be written in this iteration. + if dataItem.endStream && len(dataItem.h)+len(dataItem.d) <= size { + endStream = true + } + if dataItem.onEachWrite != nil { + dataItem.onEachWrite() + } + if err := l.framer.fr.WriteData(dataItem.streamID, endStream, buf[:size]); err != nil { + return false, err + } + str.bytesOutStanding += size + l.sendQuota -= uint32(size) + dataItem.h = dataItem.h[hSize:] + dataItem.d = dataItem.d[dSize:] + + if len(dataItem.h) == 0 && len(dataItem.d) == 0 { // All the data from that message was written out. + str.itl.dequeue() + } + if str.itl.isEmpty() { + str.state = empty + } else if trailer, ok := str.itl.peek().(*headerFrame); ok { // The next item is trailers. + if err := l.writeHeader(trailer.streamID, trailer.endStream, trailer.hf, trailer.onWrite); err != nil { + return false, err + } + if err := l.cleanupStreamHandler(trailer.cleanup); err != nil { + return false, err + } + } else if int(l.oiws)-str.bytesOutStanding <= 0 { // Ran out of stream quota. + str.state = waitingOnStreamQuota + } else { // Otherwise add it back to the list of active streams. + l.activeStreams.enqueue(str) + } + return false, nil +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/vendor/google.golang.org/grpc/internal/transport/defaults.go b/vendor/google.golang.org/grpc/internal/transport/defaults.go new file mode 100644 index 000000000..9fa306b2e --- /dev/null +++ b/vendor/google.golang.org/grpc/internal/transport/defaults.go @@ -0,0 +1,49 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "math" + "time" +) + +const ( + // The default value of flow control window size in HTTP2 spec. + defaultWindowSize = 65535 + // The initial window size for flow control. + initialWindowSize = defaultWindowSize // for an RPC + infinity = time.Duration(math.MaxInt64) + defaultClientKeepaliveTime = infinity + defaultClientKeepaliveTimeout = 20 * time.Second + defaultMaxStreamsClient = 100 + defaultMaxConnectionIdle = infinity + defaultMaxConnectionAge = infinity + defaultMaxConnectionAgeGrace = infinity + defaultServerKeepaliveTime = 2 * time.Hour + defaultServerKeepaliveTimeout = 20 * time.Second + defaultKeepalivePolicyMinTime = 5 * time.Minute + // max window limit set by HTTP2 Specs. + maxWindowSize = math.MaxInt32 + // defaultWriteQuota is the default value for number of data + // bytes that each stream can schedule before some of it being + // flushed out. + defaultWriteQuota = 64 * 1024 + defaultClientMaxHeaderListSize = uint32(16 << 20) + defaultServerMaxHeaderListSize = uint32(16 << 20) +) diff --git a/vendor/google.golang.org/grpc/internal/transport/flowcontrol.go b/vendor/google.golang.org/grpc/internal/transport/flowcontrol.go new file mode 100644 index 000000000..f262edd8e --- /dev/null +++ b/vendor/google.golang.org/grpc/internal/transport/flowcontrol.go @@ -0,0 +1,217 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "fmt" + "math" + "sync" + "sync/atomic" +) + +// writeQuota is a soft limit on the amount of data a stream can +// schedule before some of it is written out. +type writeQuota struct { + quota int32 + // get waits on read from when quota goes less than or equal to zero. + // replenish writes on it when quota goes positive again. + ch chan struct{} + // done is triggered in error case. + done <-chan struct{} + // replenish is called by loopyWriter to give quota back to. + // It is implemented as a field so that it can be updated + // by tests. + replenish func(n int) +} + +func newWriteQuota(sz int32, done <-chan struct{}) *writeQuota { + w := &writeQuota{ + quota: sz, + ch: make(chan struct{}, 1), + done: done, + } + w.replenish = w.realReplenish + return w +} + +func (w *writeQuota) get(sz int32) error { + for { + if atomic.LoadInt32(&w.quota) > 0 { + atomic.AddInt32(&w.quota, -sz) + return nil + } + select { + case <-w.ch: + continue + case <-w.done: + return errStreamDone + } + } +} + +func (w *writeQuota) realReplenish(n int) { + sz := int32(n) + a := atomic.AddInt32(&w.quota, sz) + b := a - sz + if b <= 0 && a > 0 { + select { + case w.ch <- struct{}{}: + default: + } + } +} + +type trInFlow struct { + limit uint32 + unacked uint32 + effectiveWindowSize uint32 +} + +func (f *trInFlow) newLimit(n uint32) uint32 { + d := n - f.limit + f.limit = n + f.updateEffectiveWindowSize() + return d +} + +func (f *trInFlow) onData(n uint32) uint32 { + f.unacked += n + if f.unacked >= f.limit/4 { + w := f.unacked + f.unacked = 0 + f.updateEffectiveWindowSize() + return w + } + f.updateEffectiveWindowSize() + return 0 +} + +func (f *trInFlow) reset() uint32 { + w := f.unacked + f.unacked = 0 + f.updateEffectiveWindowSize() + return w +} + +func (f *trInFlow) updateEffectiveWindowSize() { + atomic.StoreUint32(&f.effectiveWindowSize, f.limit-f.unacked) +} + +func (f *trInFlow) getSize() uint32 { + return atomic.LoadUint32(&f.effectiveWindowSize) +} + +// TODO(mmukhi): Simplify this code. +// inFlow deals with inbound flow control +type inFlow struct { + mu sync.Mutex + // The inbound flow control limit for pending data. + limit uint32 + // pendingData is the overall data which have been received but not been + // consumed by applications. + pendingData uint32 + // The amount of data the application has consumed but grpc has not sent + // window update for them. Used to reduce window update frequency. + pendingUpdate uint32 + // delta is the extra window update given by receiver when an application + // is reading data bigger in size than the inFlow limit. + delta uint32 +} + +// newLimit updates the inflow window to a new value n. +// It assumes that n is always greater than the old limit. +func (f *inFlow) newLimit(n uint32) uint32 { + f.mu.Lock() + d := n - f.limit + f.limit = n + f.mu.Unlock() + return d +} + +func (f *inFlow) maybeAdjust(n uint32) uint32 { + if n > uint32(math.MaxInt32) { + n = uint32(math.MaxInt32) + } + f.mu.Lock() + defer f.mu.Unlock() + // estSenderQuota is the receiver's view of the maximum number of bytes the sender + // can send without a window update. + estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate)) + // estUntransmittedData is the maximum number of bytes the sends might not have put + // on the wire yet. A value of 0 or less means that we have already received all or + // more bytes than the application is requesting to read. + estUntransmittedData := int32(n - f.pendingData) // Casting into int32 since it could be negative. + // This implies that unless we send a window update, the sender won't be able to send all the bytes + // for this message. Therefore we must send an update over the limit since there's an active read + // request from the application. + if estUntransmittedData > estSenderQuota { + // Sender's window shouldn't go more than 2^31 - 1 as specified in the HTTP spec. + if f.limit+n > maxWindowSize { + f.delta = maxWindowSize - f.limit + } else { + // Send a window update for the whole message and not just the difference between + // estUntransmittedData and estSenderQuota. This will be helpful in case the message + // is padded; We will fallback on the current available window(at least a 1/4th of the limit). + f.delta = n + } + return f.delta + } + return 0 +} + +// onData is invoked when some data frame is received. It updates pendingData. +func (f *inFlow) onData(n uint32) error { + f.mu.Lock() + f.pendingData += n + if f.pendingData+f.pendingUpdate > f.limit+f.delta { + limit := f.limit + rcvd := f.pendingData + f.pendingUpdate + f.mu.Unlock() + return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", rcvd, limit) + } + f.mu.Unlock() + return nil +} + +// onRead is invoked when the application reads the data. It returns the window size +// to be sent to the peer. +func (f *inFlow) onRead(n uint32) uint32 { + f.mu.Lock() + if f.pendingData == 0 { + f.mu.Unlock() + return 0 + } + f.pendingData -= n + if n > f.delta { + n -= f.delta + f.delta = 0 + } else { + f.delta -= n + n = 0 + } + f.pendingUpdate += n + if f.pendingUpdate >= f.limit/4 { + wu := f.pendingUpdate + f.pendingUpdate = 0 + f.mu.Unlock() + return wu + } + f.mu.Unlock() + return 0 +} diff --git a/vendor/google.golang.org/grpc/internal/transport/handler_server.go b/vendor/google.golang.org/grpc/internal/transport/handler_server.go new file mode 100644 index 000000000..05d3871e6 --- /dev/null +++ b/vendor/google.golang.org/grpc/internal/transport/handler_server.go @@ -0,0 +1,463 @@ +/* + * + * Copyright 2016 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// This file is the implementation of a gRPC server using HTTP/2 which +// uses the standard Go http2 Server implementation (via the +// http.Handler interface), rather than speaking low-level HTTP/2 +// frames itself. It is the implementation of *grpc.Server.ServeHTTP. + +package transport + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/golang/protobuf/proto" + "golang.org/x/net/http2" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/grpcutil" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" +) + +// NewServerHandlerTransport returns a ServerTransport handling gRPC +// from inside an http.Handler. It requires that the http Server +// supports HTTP/2. +func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler) (ServerTransport, error) { + if r.ProtoMajor != 2 { + return nil, errors.New("gRPC requires HTTP/2") + } + if r.Method != "POST" { + return nil, errors.New("invalid gRPC request method") + } + contentType := r.Header.Get("Content-Type") + // TODO: do we assume contentType is lowercase? we did before + contentSubtype, validContentType := grpcutil.ContentSubtype(contentType) + if !validContentType { + return nil, errors.New("invalid gRPC request content-type") + } + if _, ok := w.(http.Flusher); !ok { + return nil, errors.New("gRPC requires a ResponseWriter supporting http.Flusher") + } + + st := &serverHandlerTransport{ + rw: w, + req: r, + closedCh: make(chan struct{}), + writes: make(chan func()), + contentType: contentType, + contentSubtype: contentSubtype, + stats: stats, + } + + if v := r.Header.Get("grpc-timeout"); v != "" { + to, err := decodeTimeout(v) + if err != nil { + return nil, status.Errorf(codes.Internal, "malformed time-out: %v", err) + } + st.timeoutSet = true + st.timeout = to + } + + metakv := []string{"content-type", contentType} + if r.Host != "" { + metakv = append(metakv, ":authority", r.Host) + } + for k, vv := range r.Header { + k = strings.ToLower(k) + if isReservedHeader(k) && !isWhitelistedHeader(k) { + continue + } + for _, v := range vv { + v, err := decodeMetadataHeader(k, v) + if err != nil { + return nil, status.Errorf(codes.Internal, "malformed binary metadata: %v", err) + } + metakv = append(metakv, k, v) + } + } + st.headerMD = metadata.Pairs(metakv...) + + return st, nil +} + +// serverHandlerTransport is an implementation of ServerTransport +// which replies to exactly one gRPC request (exactly one HTTP request), +// using the net/http.Handler interface. This http.Handler is guaranteed +// at this point to be speaking over HTTP/2, so it's able to speak valid +// gRPC. +type serverHandlerTransport struct { + rw http.ResponseWriter + req *http.Request + timeoutSet bool + timeout time.Duration + + headerMD metadata.MD + + closeOnce sync.Once + closedCh chan struct{} // closed on Close + + // writes is a channel of code to run serialized in the + // ServeHTTP (HandleStreams) goroutine. The channel is closed + // when WriteStatus is called. + writes chan func() + + // block concurrent WriteStatus calls + // e.g. grpc/(*serverStream).SendMsg/RecvMsg + writeStatusMu sync.Mutex + + // we just mirror the request content-type + contentType string + // we store both contentType and contentSubtype so we don't keep recreating them + // TODO make sure this is consistent across handler_server and http2_server + contentSubtype string + + stats stats.Handler +} + +func (ht *serverHandlerTransport) Close() error { + ht.closeOnce.Do(ht.closeCloseChanOnce) + return nil +} + +func (ht *serverHandlerTransport) closeCloseChanOnce() { close(ht.closedCh) } + +func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.req.RemoteAddr) } + +// strAddr is a net.Addr backed by either a TCP "ip:port" string, or +// the empty string if unknown. +type strAddr string + +func (a strAddr) Network() string { + if a != "" { + // Per the documentation on net/http.Request.RemoteAddr, if this is + // set, it's set to the IP:port of the peer (hence, TCP): + // https://golang.org/pkg/net/http/#Request + // + // If we want to support Unix sockets later, we can + // add our own grpc-specific convention within the + // grpc codebase to set RemoteAddr to a different + // format, or probably better: we can attach it to the + // context and use that from serverHandlerTransport.RemoteAddr. + return "tcp" + } + return "" +} + +func (a strAddr) String() string { return string(a) } + +// do runs fn in the ServeHTTP goroutine. +func (ht *serverHandlerTransport) do(fn func()) error { + select { + case <-ht.closedCh: + return ErrConnClosing + case ht.writes <- fn: + return nil + } +} + +func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error { + ht.writeStatusMu.Lock() + defer ht.writeStatusMu.Unlock() + + headersWritten := s.updateHeaderSent() + err := ht.do(func() { + if !headersWritten { + ht.writePendingHeaders(s) + } + + // And flush, in case no header or body has been sent yet. + // This forces a separation of headers and trailers if this is the + // first call (for example, in end2end tests's TestNoService). + ht.rw.(http.Flusher).Flush() + + h := ht.rw.Header() + h.Set("Grpc-Status", fmt.Sprintf("%d", st.Code())) + if m := st.Message(); m != "" { + h.Set("Grpc-Message", encodeGrpcMessage(m)) + } + + if p := st.Proto(); p != nil && len(p.Details) > 0 { + stBytes, err := proto.Marshal(p) + if err != nil { + // TODO: return error instead, when callers are able to handle it. + panic(err) + } + + h.Set("Grpc-Status-Details-Bin", encodeBinHeader(stBytes)) + } + + if md := s.Trailer(); len(md) > 0 { + for k, vv := range md { + // Clients don't tolerate reading restricted headers after some non restricted ones were sent. + if isReservedHeader(k) { + continue + } + for _, v := range vv { + // http2 ResponseWriter mechanism to send undeclared Trailers after + // the headers have possibly been written. + h.Add(http2.TrailerPrefix+k, encodeMetadataHeader(k, v)) + } + } + } + }) + + if err == nil { // transport has not been closed + if ht.stats != nil { + // Note: The trailer fields are compressed with hpack after this call returns. + // No WireLength field is set here. + ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{ + Trailer: s.trailer.Copy(), + }) + } + } + ht.Close() + return err +} + +// writePendingHeaders sets common and custom headers on the first +// write call (Write, WriteHeader, or WriteStatus) +func (ht *serverHandlerTransport) writePendingHeaders(s *Stream) { + ht.writeCommonHeaders(s) + ht.writeCustomHeaders(s) +} + +// writeCommonHeaders sets common headers on the first write +// call (Write, WriteHeader, or WriteStatus). +func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) { + h := ht.rw.Header() + h["Date"] = nil // suppress Date to make tests happy; TODO: restore + h.Set("Content-Type", ht.contentType) + + // Predeclare trailers we'll set later in WriteStatus (after the body). + // This is a SHOULD in the HTTP RFC, and the way you add (known) + // Trailers per the net/http.ResponseWriter contract. + // See https://golang.org/pkg/net/http/#ResponseWriter + // and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers + h.Add("Trailer", "Grpc-Status") + h.Add("Trailer", "Grpc-Message") + h.Add("Trailer", "Grpc-Status-Details-Bin") + + if s.sendCompress != "" { + h.Set("Grpc-Encoding", s.sendCompress) + } +} + +// writeCustomHeaders sets custom headers set on the stream via SetHeader +// on the first write call (Write, WriteHeader, or WriteStatus). +func (ht *serverHandlerTransport) writeCustomHeaders(s *Stream) { + h := ht.rw.Header() + + s.hdrMu.Lock() + for k, vv := range s.header { + if isReservedHeader(k) { + continue + } + for _, v := range vv { + h.Add(k, encodeMetadataHeader(k, v)) + } + } + + s.hdrMu.Unlock() +} + +func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { + headersWritten := s.updateHeaderSent() + return ht.do(func() { + if !headersWritten { + ht.writePendingHeaders(s) + } + ht.rw.Write(hdr) + ht.rw.Write(data) + ht.rw.(http.Flusher).Flush() + }) +} + +func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { + if err := s.SetHeader(md); err != nil { + return err + } + + headersWritten := s.updateHeaderSent() + err := ht.do(func() { + if !headersWritten { + ht.writePendingHeaders(s) + } + + ht.rw.WriteHeader(200) + ht.rw.(http.Flusher).Flush() + }) + + if err == nil { + if ht.stats != nil { + // Note: The header fields are compressed with hpack after this call returns. + // No WireLength field is set here. + ht.stats.HandleRPC(s.Context(), &stats.OutHeader{ + Header: md.Copy(), + Compression: s.sendCompress, + }) + } + } + return err +} + +func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) { + // With this transport type there will be exactly 1 stream: this HTTP request. + + ctx := ht.req.Context() + var cancel context.CancelFunc + if ht.timeoutSet { + ctx, cancel = context.WithTimeout(ctx, ht.timeout) + } else { + ctx, cancel = context.WithCancel(ctx) + } + + // requestOver is closed when the status has been written via WriteStatus. + requestOver := make(chan struct{}) + go func() { + select { + case <-requestOver: + case <-ht.closedCh: + case <-ht.req.Context().Done(): + } + cancel() + ht.Close() + }() + + req := ht.req + + s := &Stream{ + id: 0, // irrelevant + requestRead: func(int) {}, + cancel: cancel, + buf: newRecvBuffer(), + st: ht, + method: req.URL.Path, + recvCompress: req.Header.Get("grpc-encoding"), + contentSubtype: ht.contentSubtype, + } + pr := &peer.Peer{ + Addr: ht.RemoteAddr(), + } + if req.TLS != nil { + pr.AuthInfo = credentials.TLSInfo{State: *req.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}} + } + ctx = metadata.NewIncomingContext(ctx, ht.headerMD) + s.ctx = peer.NewContext(ctx, pr) + if ht.stats != nil { + s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) + inHeader := &stats.InHeader{ + FullMethod: s.method, + RemoteAddr: ht.RemoteAddr(), + Compression: s.recvCompress, + } + ht.stats.HandleRPC(s.ctx, inHeader) + } + s.trReader = &transportReader{ + reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}}, + windowHandler: func(int) {}, + } + + // readerDone is closed when the Body.Read-ing goroutine exits. + readerDone := make(chan struct{}) + go func() { + defer close(readerDone) + + // TODO: minimize garbage, optimize recvBuffer code/ownership + const readSize = 8196 + for buf := make([]byte, readSize); ; { + n, err := req.Body.Read(buf) + if n > 0 { + s.buf.put(recvMsg{buffer: bytes.NewBuffer(buf[:n:n])}) + buf = buf[n:] + } + if err != nil { + s.buf.put(recvMsg{err: mapRecvMsgError(err)}) + return + } + if len(buf) == 0 { + buf = make([]byte, readSize) + } + } + }() + + // startStream is provided by the *grpc.Server's serveStreams. + // It starts a goroutine serving s and exits immediately. + // The goroutine that is started is the one that then calls + // into ht, calling WriteHeader, Write, WriteStatus, Close, etc. + startStream(s) + + ht.runStream() + close(requestOver) + + // Wait for reading goroutine to finish. + req.Body.Close() + <-readerDone +} + +func (ht *serverHandlerTransport) runStream() { + for { + select { + case fn := <-ht.writes: + fn() + case <-ht.closedCh: + return + } + } +} + +func (ht *serverHandlerTransport) IncrMsgSent() {} + +func (ht *serverHandlerTransport) IncrMsgRecv() {} + +func (ht *serverHandlerTransport) Drain() { + panic("Drain() is not implemented") +} + +// mapRecvMsgError returns the non-nil err into the appropriate +// error value as expected by callers of *grpc.parser.recvMsg. +// In particular, in can only be: +// * io.EOF +// * io.ErrUnexpectedEOF +// * of type transport.ConnectionError +// * an error from the status package +func mapRecvMsgError(err error) error { + if err == io.EOF || err == io.ErrUnexpectedEOF { + return err + } + if se, ok := err.(http2.StreamError); ok { + if code, ok := http2ErrConvTab[se.Code]; ok { + return status.Error(code, se.Error()) + } + } + if strings.Contains(err.Error(), "body closed by handler") { + return status.Error(codes.Canceled, err.Error()) + } + return connectionErrorf(true, err, err.Error()) +} diff --git a/vendor/google.golang.org/grpc/internal/transport/http2_client.go b/vendor/google.golang.org/grpc/internal/transport/http2_client.go new file mode 100644 index 000000000..e73b77a15 --- /dev/null +++ b/vendor/google.golang.org/grpc/internal/transport/http2_client.go @@ -0,0 +1,1491 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "context" + "fmt" + "io" + "math" + "net" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" + "google.golang.org/grpc/internal/grpcutil" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/internal/syscall" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" +) + +// clientConnectionCounter counts the number of connections a client has +// initiated (equal to the number of http2Clients created). Must be accessed +// atomically. +var clientConnectionCounter uint64 + +// http2Client implements the ClientTransport interface with HTTP2. +type http2Client struct { + lastRead int64 // Keep this field 64-bit aligned. Accessed atomically. + ctx context.Context + cancel context.CancelFunc + ctxDone <-chan struct{} // Cache the ctx.Done() chan. + userAgent string + md interface{} + conn net.Conn // underlying communication channel + loopy *loopyWriter + remoteAddr net.Addr + localAddr net.Addr + authInfo credentials.AuthInfo // auth info about the connection + + readerDone chan struct{} // sync point to enable testing. + writerDone chan struct{} // sync point to enable testing. + // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) + // that the server sent GoAway on this transport. + goAway chan struct{} + + framer *framer + // controlBuf delivers all the control related tasks (e.g., window + // updates, reset streams, and various settings) to the controller. + controlBuf *controlBuffer + fc *trInFlow + // The scheme used: https if TLS is on, http otherwise. + scheme string + + isSecure bool + + perRPCCreds []credentials.PerRPCCredentials + + kp keepalive.ClientParameters + keepaliveEnabled bool + + statsHandler stats.Handler + + initialWindowSize int32 + + // configured by peer through SETTINGS_MAX_HEADER_LIST_SIZE + maxSendHeaderListSize *uint32 + + bdpEst *bdpEstimator + // onPrefaceReceipt is a callback that client transport calls upon + // receiving server preface to signal that a succefull HTTP2 + // connection was established. + onPrefaceReceipt func() + + maxConcurrentStreams uint32 + streamQuota int64 + streamsQuotaAvailable chan struct{} + waitingStreams uint32 + nextID uint32 + + mu sync.Mutex // guard the following variables + state transportState + activeStreams map[uint32]*Stream + // prevGoAway ID records the Last-Stream-ID in the previous GOAway frame. + prevGoAwayID uint32 + // goAwayReason records the http2.ErrCode and debug data received with the + // GoAway frame. + goAwayReason GoAwayReason + // A condition variable used to signal when the keepalive goroutine should + // go dormant. The condition for dormancy is based on the number of active + // streams and the `PermitWithoutStream` keepalive client parameter. And + // since the number of active streams is guarded by the above mutex, we use + // the same for this condition variable as well. + kpDormancyCond *sync.Cond + // A boolean to track whether the keepalive goroutine is dormant or not. + // This is checked before attempting to signal the above condition + // variable. + kpDormant bool + + // Fields below are for channelz metric collection. + channelzID int64 // channelz unique identification number + czData *channelzData + + onGoAway func(GoAwayReason) + onClose func() + + bufferPool *bufferPool + + connectionID uint64 +} + +func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) { + if fn != nil { + return fn(ctx, addr) + } + return (&net.Dialer{}).DialContext(ctx, "tcp", addr) +} + +func isTemporary(err error) bool { + switch err := err.(type) { + case interface { + Temporary() bool + }: + return err.Temporary() + case interface { + Timeout() bool + }: + // Timeouts may be resolved upon retry, and are thus treated as + // temporary. + return err.Timeout() + } + return true +} + +// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 +// and starts to receive messages on it. Non-nil error returns if construction +// fails. +func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onPrefaceReceipt func(), onGoAway func(GoAwayReason), onClose func()) (_ *http2Client, err error) { + scheme := "http" + ctx, cancel := context.WithCancel(ctx) + defer func() { + if err != nil { + cancel() + } + }() + + conn, err := dial(connectCtx, opts.Dialer, addr.Addr) + if err != nil { + if opts.FailOnNonTempDialError { + return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err) + } + return nil, connectionErrorf(true, err, "transport: Error while dialing %v", err) + } + // Any further errors will close the underlying connection + defer func(conn net.Conn) { + if err != nil { + conn.Close() + } + }(conn) + kp := opts.KeepaliveParams + // Validate keepalive parameters. + if kp.Time == 0 { + kp.Time = defaultClientKeepaliveTime + } + if kp.Timeout == 0 { + kp.Timeout = defaultClientKeepaliveTimeout + } + keepaliveEnabled := false + if kp.Time != infinity { + if err = syscall.SetTCPUserTimeout(conn, kp.Timeout); err != nil { + return nil, connectionErrorf(false, err, "transport: failed to set TCP_USER_TIMEOUT: %v", err) + } + keepaliveEnabled = true + } + var ( + isSecure bool + authInfo credentials.AuthInfo + ) + transportCreds := opts.TransportCredentials + perRPCCreds := opts.PerRPCCredentials + + if b := opts.CredsBundle; b != nil { + if t := b.TransportCredentials(); t != nil { + transportCreds = t + } + if t := b.PerRPCCredentials(); t != nil { + perRPCCreds = append(perRPCCreds, t) + } + } + if transportCreds != nil { + // gRPC, resolver, balancer etc. can specify arbitrary data in the + // Attributes field of resolver.Address, which is shoved into connectCtx + // and passed to the credential handshaker. This makes it possible for + // address specific arbitrary data to reach the credential handshaker. + contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) + connectCtx = contextWithHandshakeInfo(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) + conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, conn) + if err != nil { + return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err) + } + isSecure = true + if transportCreds.Info().SecurityProtocol == "tls" { + scheme = "https" + } + } + dynamicWindow := true + icwz := int32(initialWindowSize) + if opts.InitialConnWindowSize >= defaultWindowSize { + icwz = opts.InitialConnWindowSize + dynamicWindow = false + } + writeBufSize := opts.WriteBufferSize + readBufSize := opts.ReadBufferSize + maxHeaderListSize := defaultClientMaxHeaderListSize + if opts.MaxHeaderListSize != nil { + maxHeaderListSize = *opts.MaxHeaderListSize + } + t := &http2Client{ + ctx: ctx, + ctxDone: ctx.Done(), // Cache Done chan. + cancel: cancel, + userAgent: opts.UserAgent, + md: addr.Metadata, + conn: conn, + remoteAddr: conn.RemoteAddr(), + localAddr: conn.LocalAddr(), + authInfo: authInfo, + readerDone: make(chan struct{}), + writerDone: make(chan struct{}), + goAway: make(chan struct{}), + framer: newFramer(conn, writeBufSize, readBufSize, maxHeaderListSize), + fc: &trInFlow{limit: uint32(icwz)}, + scheme: scheme, + activeStreams: make(map[uint32]*Stream), + isSecure: isSecure, + perRPCCreds: perRPCCreds, + kp: kp, + statsHandler: opts.StatsHandler, + initialWindowSize: initialWindowSize, + onPrefaceReceipt: onPrefaceReceipt, + nextID: 1, + maxConcurrentStreams: defaultMaxStreamsClient, + streamQuota: defaultMaxStreamsClient, + streamsQuotaAvailable: make(chan struct{}, 1), + czData: new(channelzData), + onGoAway: onGoAway, + onClose: onClose, + keepaliveEnabled: keepaliveEnabled, + bufferPool: newBufferPool(), + } + t.controlBuf = newControlBuffer(t.ctxDone) + if opts.InitialWindowSize >= defaultWindowSize { + t.initialWindowSize = opts.InitialWindowSize + dynamicWindow = false + } + if dynamicWindow { + t.bdpEst = &bdpEstimator{ + bdp: initialWindowSize, + updateFlowControl: t.updateFlowControl, + } + } + if t.statsHandler != nil { + t.ctx = t.statsHandler.TagConn(t.ctx, &stats.ConnTagInfo{ + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + }) + connBegin := &stats.ConnBegin{ + Client: true, + } + t.statsHandler.HandleConn(t.ctx, connBegin) + } + if channelz.IsOn() { + t.channelzID = channelz.RegisterNormalSocket(t, opts.ChannelzParentID, fmt.Sprintf("%s -> %s", t.localAddr, t.remoteAddr)) + } + if t.keepaliveEnabled { + t.kpDormancyCond = sync.NewCond(&t.mu) + go t.keepalive() + } + // Start the reader goroutine for incoming message. Each transport has + // a dedicated goroutine which reads HTTP2 frame from network. Then it + // dispatches the frame to the corresponding stream entity. + go t.reader() + + // Send connection preface to server. + n, err := t.conn.Write(clientPreface) + if err != nil { + t.Close() + return nil, connectionErrorf(true, err, "transport: failed to write client preface: %v", err) + } + if n != len(clientPreface) { + t.Close() + return nil, connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) + } + var ss []http2.Setting + + if t.initialWindowSize != defaultWindowSize { + ss = append(ss, http2.Setting{ + ID: http2.SettingInitialWindowSize, + Val: uint32(t.initialWindowSize), + }) + } + if opts.MaxHeaderListSize != nil { + ss = append(ss, http2.Setting{ + ID: http2.SettingMaxHeaderListSize, + Val: *opts.MaxHeaderListSize, + }) + } + err = t.framer.fr.WriteSettings(ss...) + if err != nil { + t.Close() + return nil, connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err) + } + // Adjust the connection flow control window if needed. + if delta := uint32(icwz - defaultWindowSize); delta > 0 { + if err := t.framer.fr.WriteWindowUpdate(0, delta); err != nil { + t.Close() + return nil, connectionErrorf(true, err, "transport: failed to write window update: %v", err) + } + } + + t.connectionID = atomic.AddUint64(&clientConnectionCounter, 1) + + if err := t.framer.writer.Flush(); err != nil { + return nil, err + } + go func() { + t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst) + err := t.loopy.run() + if err != nil { + if logger.V(logLevel) { + logger.Errorf("transport: loopyWriter.run returning. Err: %v", err) + } + } + // If it's a connection error, let reader goroutine handle it + // since there might be data in the buffers. + if _, ok := err.(net.Error); !ok { + t.conn.Close() + } + close(t.writerDone) + }() + return t, nil +} + +func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { + // TODO(zhaoq): Handle uint32 overflow of Stream.id. + s := &Stream{ + ct: t, + done: make(chan struct{}), + method: callHdr.Method, + sendCompress: callHdr.SendCompress, + buf: newRecvBuffer(), + headerChan: make(chan struct{}), + contentSubtype: callHdr.ContentSubtype, + } + s.wq = newWriteQuota(defaultWriteQuota, s.done) + s.requestRead = func(n int) { + t.adjustWindow(s, uint32(n)) + } + // The client side stream context should have exactly the same life cycle with the user provided context. + // That means, s.ctx should be read-only. And s.ctx is done iff ctx is done. + // So we use the original context here instead of creating a copy. + s.ctx = ctx + s.trReader = &transportReader{ + reader: &recvBufferReader{ + ctx: s.ctx, + ctxDone: s.ctx.Done(), + recv: s.buf, + closeStream: func(err error) { + t.CloseStream(s, err) + }, + freeBuffer: t.bufferPool.put, + }, + windowHandler: func(n int) { + t.updateWindow(s, uint32(n)) + }, + } + return s +} + +func (t *http2Client) getPeer() *peer.Peer { + return &peer.Peer{ + Addr: t.remoteAddr, + AuthInfo: t.authInfo, + } +} + +func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) ([]hpack.HeaderField, error) { + aud := t.createAudience(callHdr) + ri := credentials.RequestInfo{ + Method: callHdr.Method, + AuthInfo: t.authInfo, + } + ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri) + authData, err := t.getTrAuthData(ctxWithRequestInfo, aud) + if err != nil { + return nil, err + } + callAuthData, err := t.getCallAuthData(ctxWithRequestInfo, aud, callHdr) + if err != nil { + return nil, err + } + // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields + // first and create a slice of that exact size. + // Make the slice of certain predictable size to reduce allocations made by append. + hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te + hfLen += len(authData) + len(callAuthData) + headerFields := make([]hpack.HeaderField, 0, hfLen) + headerFields = append(headerFields, hpack.HeaderField{Name: ":method", Value: "POST"}) + headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme}) + headerFields = append(headerFields, hpack.HeaderField{Name: ":path", Value: callHdr.Method}) + headerFields = append(headerFields, hpack.HeaderField{Name: ":authority", Value: callHdr.Host}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: grpcutil.ContentType(callHdr.ContentSubtype)}) + headerFields = append(headerFields, hpack.HeaderField{Name: "user-agent", Value: t.userAgent}) + headerFields = append(headerFields, hpack.HeaderField{Name: "te", Value: "trailers"}) + if callHdr.PreviousAttempts > 0 { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-previous-rpc-attempts", Value: strconv.Itoa(callHdr.PreviousAttempts)}) + } + + if callHdr.SendCompress != "" { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-accept-encoding", Value: callHdr.SendCompress}) + } + if dl, ok := ctx.Deadline(); ok { + // Send out timeout regardless its value. The server can detect timeout context by itself. + // TODO(mmukhi): Perhaps this field should be updated when actually writing out to the wire. + timeout := time.Until(dl) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: grpcutil.EncodeDuration(timeout)}) + } + for k, v := range authData { + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + } + for k, v := range callAuthData { + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + } + if b := stats.OutgoingTags(ctx); b != nil { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-tags-bin", Value: encodeBinHeader(b)}) + } + if b := stats.OutgoingTrace(ctx); b != nil { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)}) + } + + if md, added, ok := metadata.FromOutgoingContextRaw(ctx); ok { + var k string + for k, vv := range md { + // HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set. + if isReservedHeader(k) { + continue + } + for _, v := range vv { + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + } + } + for _, vv := range added { + for i, v := range vv { + if i%2 == 0 { + k = v + continue + } + // HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set. + if isReservedHeader(k) { + continue + } + headerFields = append(headerFields, hpack.HeaderField{Name: strings.ToLower(k), Value: encodeMetadataHeader(k, v)}) + } + } + } + if md, ok := t.md.(*metadata.MD); ok { + for k, vv := range *md { + if isReservedHeader(k) { + continue + } + for _, v := range vv { + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + } + } + } + return headerFields, nil +} + +func (t *http2Client) createAudience(callHdr *CallHdr) string { + // Create an audience string only if needed. + if len(t.perRPCCreds) == 0 && callHdr.Creds == nil { + return "" + } + // Construct URI required to get auth request metadata. + // Omit port if it is the default one. + host := strings.TrimSuffix(callHdr.Host, ":443") + pos := strings.LastIndex(callHdr.Method, "/") + if pos == -1 { + pos = len(callHdr.Method) + } + return "https://" + host + callHdr.Method[:pos] +} + +func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[string]string, error) { + if len(t.perRPCCreds) == 0 { + return nil, nil + } + authData := map[string]string{} + for _, c := range t.perRPCCreds { + data, err := c.GetRequestMetadata(ctx, audience) + if err != nil { + if _, ok := status.FromError(err); ok { + return nil, err + } + + return nil, status.Errorf(codes.Unauthenticated, "transport: %v", err) + } + for k, v := range data { + // Capital header names are illegal in HTTP/2. + k = strings.ToLower(k) + authData[k] = v + } + } + return authData, nil +} + +func (t *http2Client) getCallAuthData(ctx context.Context, audience string, callHdr *CallHdr) (map[string]string, error) { + var callAuthData map[string]string + // Check if credentials.PerRPCCredentials were provided via call options. + // Note: if these credentials are provided both via dial options and call + // options, then both sets of credentials will be applied. + if callCreds := callHdr.Creds; callCreds != nil { + if !t.isSecure && callCreds.RequireTransportSecurity() { + return nil, status.Error(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure connection") + } + data, err := callCreds.GetRequestMetadata(ctx, audience) + if err != nil { + return nil, status.Errorf(codes.Internal, "transport: %v", err) + } + callAuthData = make(map[string]string, len(data)) + for k, v := range data { + // Capital header names are illegal in HTTP/2 + k = strings.ToLower(k) + callAuthData[k] = v + } + } + return callAuthData, nil +} + +// PerformedIOError wraps an error to indicate IO may have been performed +// before the error occurred. +type PerformedIOError struct { + Err error +} + +// Error implements error. +func (p PerformedIOError) Error() string { + return p.Err.Error() +} + +// NewStream creates a stream and registers it into the transport as "active" +// streams. +func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { + ctx = peer.NewContext(ctx, t.getPeer()) + headerFields, err := t.createHeaderFields(ctx, callHdr) + if err != nil { + // We may have performed I/O in the per-RPC creds callback, so do not + // allow transparent retry. + return nil, PerformedIOError{err} + } + s := t.newStream(ctx, callHdr) + cleanup := func(err error) { + if s.swapState(streamDone) == streamDone { + // If it was already done, return. + return + } + // The stream was unprocessed by the server. + atomic.StoreUint32(&s.unprocessed, 1) + s.write(recvMsg{err: err}) + close(s.done) + // If headerChan isn't closed, then close it. + if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) { + close(s.headerChan) + } + } + hdr := &headerFrame{ + hf: headerFields, + endStream: false, + initStream: func(id uint32) error { + t.mu.Lock() + if state := t.state; state != reachable { + t.mu.Unlock() + // Do a quick cleanup. + err := error(errStreamDrain) + if state == closing { + err = ErrConnClosing + } + cleanup(err) + return err + } + t.activeStreams[id] = s + if channelz.IsOn() { + atomic.AddInt64(&t.czData.streamsStarted, 1) + atomic.StoreInt64(&t.czData.lastStreamCreatedTime, time.Now().UnixNano()) + } + // If the keepalive goroutine has gone dormant, wake it up. + if t.kpDormant { + t.kpDormancyCond.Signal() + } + t.mu.Unlock() + return nil + }, + onOrphaned: cleanup, + wq: s.wq, + } + firstTry := true + var ch chan struct{} + checkForStreamQuota := func(it interface{}) bool { + if t.streamQuota <= 0 { // Can go negative if server decreases it. + if firstTry { + t.waitingStreams++ + } + ch = t.streamsQuotaAvailable + return false + } + if !firstTry { + t.waitingStreams-- + } + t.streamQuota-- + h := it.(*headerFrame) + h.streamID = t.nextID + t.nextID += 2 + s.id = h.streamID + s.fc = &inFlow{limit: uint32(t.initialWindowSize)} + if t.streamQuota > 0 && t.waitingStreams > 0 { + select { + case t.streamsQuotaAvailable <- struct{}{}: + default: + } + } + return true + } + var hdrListSizeErr error + checkForHeaderListSize := func(it interface{}) bool { + if t.maxSendHeaderListSize == nil { + return true + } + hdrFrame := it.(*headerFrame) + var sz int64 + for _, f := range hdrFrame.hf { + if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) { + hdrListSizeErr = status.Errorf(codes.Internal, "header list size to send violates the maximum size (%d bytes) set by server", *t.maxSendHeaderListSize) + return false + } + } + return true + } + for { + success, err := t.controlBuf.executeAndPut(func(it interface{}) bool { + if !checkForStreamQuota(it) { + return false + } + if !checkForHeaderListSize(it) { + return false + } + return true + }, hdr) + if err != nil { + return nil, err + } + if success { + break + } + if hdrListSizeErr != nil { + return nil, hdrListSizeErr + } + firstTry = false + select { + case <-ch: + case <-s.ctx.Done(): + return nil, ContextErr(s.ctx.Err()) + case <-t.goAway: + return nil, errStreamDrain + case <-t.ctx.Done(): + return nil, ErrConnClosing + } + } + if t.statsHandler != nil { + header, ok := metadata.FromOutgoingContext(ctx) + if ok { + header.Set("user-agent", t.userAgent) + } else { + header = metadata.Pairs("user-agent", t.userAgent) + } + // Note: The header fields are compressed with hpack after this call returns. + // No WireLength field is set here. + outHeader := &stats.OutHeader{ + Client: true, + FullMethod: callHdr.Method, + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + Compression: callHdr.SendCompress, + Header: header, + } + t.statsHandler.HandleRPC(s.ctx, outHeader) + } + return s, nil +} + +// CloseStream clears the footprint of a stream when the stream is not needed any more. +// This must not be executed in reader's goroutine. +func (t *http2Client) CloseStream(s *Stream, err error) { + var ( + rst bool + rstCode http2.ErrCode + ) + if err != nil { + rst = true + rstCode = http2.ErrCodeCancel + } + t.closeStream(s, err, rst, rstCode, status.Convert(err), nil, false) +} + +func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string, eosReceived bool) { + // Set stream status to done. + if s.swapState(streamDone) == streamDone { + // If it was already done, return. If multiple closeStream calls + // happen simultaneously, wait for the first to finish. + <-s.done + return + } + // status and trailers can be updated here without any synchronization because the stream goroutine will + // only read it after it sees an io.EOF error from read or write and we'll write those errors + // only after updating this. + s.status = st + if len(mdata) > 0 { + s.trailer = mdata + } + if err != nil { + // This will unblock reads eventually. + s.write(recvMsg{err: err}) + } + // If headerChan isn't closed, then close it. + if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) { + s.noHeaders = true + close(s.headerChan) + } + cleanup := &cleanupStream{ + streamID: s.id, + onWrite: func() { + t.mu.Lock() + if t.activeStreams != nil { + delete(t.activeStreams, s.id) + } + t.mu.Unlock() + if channelz.IsOn() { + if eosReceived { + atomic.AddInt64(&t.czData.streamsSucceeded, 1) + } else { + atomic.AddInt64(&t.czData.streamsFailed, 1) + } + } + }, + rst: rst, + rstCode: rstCode, + } + addBackStreamQuota := func(interface{}) bool { + t.streamQuota++ + if t.streamQuota > 0 && t.waitingStreams > 0 { + select { + case t.streamsQuotaAvailable <- struct{}{}: + default: + } + } + return true + } + t.controlBuf.executeAndPut(addBackStreamQuota, cleanup) + // This will unblock write. + close(s.done) +} + +// Close kicks off the shutdown process of the transport. This should be called +// only once on a transport. Once it is called, the transport should not be +// accessed any more. +// +// This method blocks until the addrConn that initiated this transport is +// re-connected. This happens because t.onClose() begins reconnect logic at the +// addrConn level and blocks until the addrConn is successfully connected. +func (t *http2Client) Close() error { + t.mu.Lock() + // Make sure we only Close once. + if t.state == closing { + t.mu.Unlock() + return nil + } + // Call t.onClose before setting the state to closing to prevent the client + // from attempting to create new streams ASAP. + t.onClose() + t.state = closing + streams := t.activeStreams + t.activeStreams = nil + if t.kpDormant { + // If the keepalive goroutine is blocked on this condition variable, we + // should unblock it so that the goroutine eventually exits. + t.kpDormancyCond.Signal() + } + t.mu.Unlock() + t.controlBuf.finish() + t.cancel() + err := t.conn.Close() + if channelz.IsOn() { + channelz.RemoveEntry(t.channelzID) + } + // Notify all active streams. + for _, s := range streams { + t.closeStream(s, ErrConnClosing, false, http2.ErrCodeNo, status.New(codes.Unavailable, ErrConnClosing.Desc), nil, false) + } + if t.statsHandler != nil { + connEnd := &stats.ConnEnd{ + Client: true, + } + t.statsHandler.HandleConn(t.ctx, connEnd) + } + return err +} + +// GracefulClose sets the state to draining, which prevents new streams from +// being created and causes the transport to be closed when the last active +// stream is closed. If there are no active streams, the transport is closed +// immediately. This does nothing if the transport is already draining or +// closing. +func (t *http2Client) GracefulClose() { + t.mu.Lock() + // Make sure we move to draining only from active. + if t.state == draining || t.state == closing { + t.mu.Unlock() + return + } + t.state = draining + active := len(t.activeStreams) + t.mu.Unlock() + if active == 0 { + t.Close() + return + } + t.controlBuf.put(&incomingGoAway{}) +} + +// Write formats the data into HTTP2 data frame(s) and sends it out. The caller +// should proceed only if Write returns nil. +func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { + if opts.Last { + // If it's the last message, update stream state. + if !s.compareAndSwapState(streamActive, streamWriteDone) { + return errStreamDone + } + } else if s.getState() != streamActive { + return errStreamDone + } + df := &dataFrame{ + streamID: s.id, + endStream: opts.Last, + h: hdr, + d: data, + } + if hdr != nil || data != nil { // If it's not an empty data frame, check quota. + if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { + return err + } + } + return t.controlBuf.put(df) +} + +func (t *http2Client) getStream(f http2.Frame) *Stream { + t.mu.Lock() + s := t.activeStreams[f.Header().StreamID] + t.mu.Unlock() + return s +} + +// adjustWindow sends out extra window update over the initial window size +// of stream if the application is requesting data larger in size than +// the window. +func (t *http2Client) adjustWindow(s *Stream, n uint32) { + if w := s.fc.maybeAdjust(n); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w}) + } +} + +// updateWindow adjusts the inbound quota for the stream. +// Window updates will be sent out when the cumulative quota +// exceeds the corresponding threshold. +func (t *http2Client) updateWindow(s *Stream, n uint32) { + if w := s.fc.onRead(n); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w}) + } +} + +// updateFlowControl updates the incoming flow control windows +// for the transport and the stream based on the current bdp +// estimation. +func (t *http2Client) updateFlowControl(n uint32) { + t.mu.Lock() + for _, s := range t.activeStreams { + s.fc.newLimit(n) + } + t.mu.Unlock() + updateIWS := func(interface{}) bool { + t.initialWindowSize = int32(n) + return true + } + t.controlBuf.executeAndPut(updateIWS, &outgoingWindowUpdate{streamID: 0, increment: t.fc.newLimit(n)}) + t.controlBuf.put(&outgoingSettings{ + ss: []http2.Setting{ + { + ID: http2.SettingInitialWindowSize, + Val: n, + }, + }, + }) +} + +func (t *http2Client) handleData(f *http2.DataFrame) { + size := f.Header().Length + var sendBDPPing bool + if t.bdpEst != nil { + sendBDPPing = t.bdpEst.add(size) + } + // Decouple connection's flow control from application's read. + // An update on connection's flow control should not depend on + // whether user application has read the data or not. Such a + // restriction is already imposed on the stream's flow control, + // and therefore the sender will be blocked anyways. + // Decoupling the connection flow control will prevent other + // active(fast) streams from starving in presence of slow or + // inactive streams. + // + if w := t.fc.onData(size); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) + } + if sendBDPPing { + // Avoid excessive ping detection (e.g. in an L7 proxy) + // by sending a window update prior to the BDP ping. + + if w := t.fc.reset(); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) + } + + t.controlBuf.put(bdpPing) + } + // Select the right stream to dispatch. + s := t.getStream(f) + if s == nil { + return + } + if size > 0 { + if err := s.fc.onData(size); err != nil { + t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false) + return + } + if f.Header().Flags.Has(http2.FlagDataPadded) { + if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{s.id, w}) + } + } + // TODO(bradfitz, zhaoq): A copy is required here because there is no + // guarantee f.Data() is consumed before the arrival of next frame. + // Can this copy be eliminated? + if len(f.Data()) > 0 { + buffer := t.bufferPool.get() + buffer.Reset() + buffer.Write(f.Data()) + s.write(recvMsg{buffer: buffer}) + } + } + // The server has closed the stream without sending trailers. Record that + // the read direction is closed, and set the status appropriately. + if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true) + } +} + +func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { + s := t.getStream(f) + if s == nil { + return + } + if f.ErrCode == http2.ErrCodeRefusedStream { + // The stream was unprocessed by the server. + atomic.StoreUint32(&s.unprocessed, 1) + } + statusCode, ok := http2ErrConvTab[f.ErrCode] + if !ok { + if logger.V(logLevel) { + logger.Warningf("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error %v", f.ErrCode) + } + statusCode = codes.Unknown + } + if statusCode == codes.Canceled { + if d, ok := s.ctx.Deadline(); ok && !d.After(time.Now()) { + // Our deadline was already exceeded, and that was likely the cause + // of this cancelation. Alter the status code accordingly. + statusCode = codes.DeadlineExceeded + } + } + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil, false) +} + +func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) { + if f.IsAck() { + return + } + var maxStreams *uint32 + var ss []http2.Setting + var updateFuncs []func() + f.ForeachSetting(func(s http2.Setting) error { + switch s.ID { + case http2.SettingMaxConcurrentStreams: + maxStreams = new(uint32) + *maxStreams = s.Val + case http2.SettingMaxHeaderListSize: + updateFuncs = append(updateFuncs, func() { + t.maxSendHeaderListSize = new(uint32) + *t.maxSendHeaderListSize = s.Val + }) + default: + ss = append(ss, s) + } + return nil + }) + if isFirst && maxStreams == nil { + maxStreams = new(uint32) + *maxStreams = math.MaxUint32 + } + sf := &incomingSettings{ + ss: ss, + } + if maxStreams != nil { + updateStreamQuota := func() { + delta := int64(*maxStreams) - int64(t.maxConcurrentStreams) + t.maxConcurrentStreams = *maxStreams + t.streamQuota += delta + if delta > 0 && t.waitingStreams > 0 { + close(t.streamsQuotaAvailable) // wake all of them up. + t.streamsQuotaAvailable = make(chan struct{}, 1) + } + } + updateFuncs = append(updateFuncs, updateStreamQuota) + } + t.controlBuf.executeAndPut(func(interface{}) bool { + for _, f := range updateFuncs { + f() + } + return true + }, sf) +} + +func (t *http2Client) handlePing(f *http2.PingFrame) { + if f.IsAck() { + // Maybe it's a BDP ping. + if t.bdpEst != nil { + t.bdpEst.calculate(f.Data) + } + return + } + pingAck := &ping{ack: true} + copy(pingAck.data[:], f.Data[:]) + t.controlBuf.put(pingAck) +} + +func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { + t.mu.Lock() + if t.state == closing { + t.mu.Unlock() + return + } + if f.ErrCode == http2.ErrCodeEnhanceYourCalm { + if logger.V(logLevel) { + logger.Infof("Client received GoAway with http2.ErrCodeEnhanceYourCalm.") + } + } + id := f.LastStreamID + if id > 0 && id%2 != 1 { + t.mu.Unlock() + t.Close() + return + } + // A client can receive multiple GoAways from the server (see + // https://github.com/grpc/grpc-go/issues/1387). The idea is that the first + // GoAway will be sent with an ID of MaxInt32 and the second GoAway will be + // sent after an RTT delay with the ID of the last stream the server will + // process. + // + // Therefore, when we get the first GoAway we don't necessarily close any + // streams. While in case of second GoAway we close all streams created after + // the GoAwayId. This way streams that were in-flight while the GoAway from + // server was being sent don't get killed. + select { + case <-t.goAway: // t.goAway has been closed (i.e.,multiple GoAways). + // If there are multiple GoAways the first one should always have an ID greater than the following ones. + if id > t.prevGoAwayID { + t.mu.Unlock() + t.Close() + return + } + default: + t.setGoAwayReason(f) + close(t.goAway) + t.controlBuf.put(&incomingGoAway{}) + // Notify the clientconn about the GOAWAY before we set the state to + // draining, to allow the client to stop attempting to create streams + // before disallowing new streams on this connection. + t.onGoAway(t.goAwayReason) + t.state = draining + } + // All streams with IDs greater than the GoAwayId + // and smaller than the previous GoAway ID should be killed. + upperLimit := t.prevGoAwayID + if upperLimit == 0 { // This is the first GoAway Frame. + upperLimit = math.MaxUint32 // Kill all streams after the GoAway ID. + } + for streamID, stream := range t.activeStreams { + if streamID > id && streamID <= upperLimit { + // The stream was unprocessed by the server. + atomic.StoreUint32(&stream.unprocessed, 1) + t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) + } + } + t.prevGoAwayID = id + active := len(t.activeStreams) + t.mu.Unlock() + if active == 0 { + t.Close() + } +} + +// setGoAwayReason sets the value of t.goAwayReason based +// on the GoAway frame received. +// It expects a lock on transport's mutext to be held by +// the caller. +func (t *http2Client) setGoAwayReason(f *http2.GoAwayFrame) { + t.goAwayReason = GoAwayNoReason + switch f.ErrCode { + case http2.ErrCodeEnhanceYourCalm: + if string(f.DebugData()) == "too_many_pings" { + t.goAwayReason = GoAwayTooManyPings + } + } +} + +func (t *http2Client) GetGoAwayReason() GoAwayReason { + t.mu.Lock() + defer t.mu.Unlock() + return t.goAwayReason +} + +func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { + t.controlBuf.put(&incomingWindowUpdate{ + streamID: f.Header().StreamID, + increment: f.Increment, + }) +} + +// operateHeaders takes action on the decoded headers. +func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { + s := t.getStream(frame) + if s == nil { + return + } + endStream := frame.StreamEnded() + atomic.StoreUint32(&s.bytesReceived, 1) + initialHeader := atomic.LoadUint32(&s.headerChanClosed) == 0 + + if !initialHeader && !endStream { + // As specified by gRPC over HTTP2, a HEADERS frame (and associated CONTINUATION frames) can only appear at the start or end of a stream. Therefore, second HEADERS frame must have EOS bit set. + st := status.New(codes.Internal, "a HEADERS frame cannot appear in the middle of a stream") + t.closeStream(s, st.Err(), true, http2.ErrCodeProtocol, st, nil, false) + return + } + + state := &decodeState{} + // Initialize isGRPC value to be !initialHeader, since if a gRPC Response-Headers has already been received, then it means that the peer is speaking gRPC and we are in gRPC mode. + state.data.isGRPC = !initialHeader + if h2code, err := state.decodeHeader(frame); err != nil { + t.closeStream(s, err, true, h2code, status.Convert(err), nil, endStream) + return + } + + isHeader := false + defer func() { + if t.statsHandler != nil { + if isHeader { + inHeader := &stats.InHeader{ + Client: true, + WireLength: int(frame.Header().Length), + Header: s.header.Copy(), + Compression: s.recvCompress, + } + t.statsHandler.HandleRPC(s.ctx, inHeader) + } else { + inTrailer := &stats.InTrailer{ + Client: true, + WireLength: int(frame.Header().Length), + Trailer: s.trailer.Copy(), + } + t.statsHandler.HandleRPC(s.ctx, inTrailer) + } + } + }() + + // If headerChan hasn't been closed yet + if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) { + s.headerValid = true + if !endStream { + // HEADERS frame block carries a Response-Headers. + isHeader = true + // These values can be set without any synchronization because + // stream goroutine will read it only after seeing a closed + // headerChan which we'll close after setting this. + s.recvCompress = state.data.encoding + if len(state.data.mdata) > 0 { + s.header = state.data.mdata + } + } else { + // HEADERS frame block carries a Trailers-Only. + s.noHeaders = true + } + close(s.headerChan) + } + + if !endStream { + return + } + + // if client received END_STREAM from server while stream was still active, send RST_STREAM + rst := s.getState() == streamActive + t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, state.status(), state.data.mdata, true) +} + +// reader runs as a separate goroutine in charge of reading data from network +// connection. +// +// TODO(zhaoq): currently one reader per transport. Investigate whether this is +// optimal. +// TODO(zhaoq): Check the validity of the incoming frame sequence. +func (t *http2Client) reader() { + defer close(t.readerDone) + // Check the validity of server preface. + frame, err := t.framer.fr.ReadFrame() + if err != nil { + t.Close() // this kicks off resetTransport, so must be last before return + return + } + t.conn.SetReadDeadline(time.Time{}) // reset deadline once we get the settings frame (we didn't time out, yay!) + if t.keepaliveEnabled { + atomic.StoreInt64(&t.lastRead, time.Now().UnixNano()) + } + sf, ok := frame.(*http2.SettingsFrame) + if !ok { + t.Close() // this kicks off resetTransport, so must be last before return + return + } + t.onPrefaceReceipt() + t.handleSettings(sf, true) + + // loop to keep reading incoming messages on this transport. + for { + t.controlBuf.throttle() + frame, err := t.framer.fr.ReadFrame() + if t.keepaliveEnabled { + atomic.StoreInt64(&t.lastRead, time.Now().UnixNano()) + } + if err != nil { + // Abort an active stream if the http2.Framer returns a + // http2.StreamError. This can happen only if the server's response + // is malformed http2. + if se, ok := err.(http2.StreamError); ok { + t.mu.Lock() + s := t.activeStreams[se.StreamID] + t.mu.Unlock() + if s != nil { + // use error detail to provide better err message + code := http2ErrConvTab[se.Code] + errorDetail := t.framer.fr.ErrorDetail() + var msg string + if errorDetail != nil { + msg = errorDetail.Error() + } else { + msg = "received invalid frame" + } + t.closeStream(s, status.Error(code, msg), true, http2.ErrCodeProtocol, status.New(code, msg), nil, false) + } + continue + } else { + // Transport error. + t.Close() + return + } + } + switch frame := frame.(type) { + case *http2.MetaHeadersFrame: + t.operateHeaders(frame) + case *http2.DataFrame: + t.handleData(frame) + case *http2.RSTStreamFrame: + t.handleRSTStream(frame) + case *http2.SettingsFrame: + t.handleSettings(frame, false) + case *http2.PingFrame: + t.handlePing(frame) + case *http2.GoAwayFrame: + t.handleGoAway(frame) + case *http2.WindowUpdateFrame: + t.handleWindowUpdate(frame) + default: + if logger.V(logLevel) { + logger.Errorf("transport: http2Client.reader got unhandled frame type %v.", frame) + } + } + } +} + +func minTime(a, b time.Duration) time.Duration { + if a < b { + return a + } + return b +} + +// keepalive running in a separate goroutune makes sure the connection is alive by sending pings. +func (t *http2Client) keepalive() { + p := &ping{data: [8]byte{}} + // True iff a ping has been sent, and no data has been received since then. + outstandingPing := false + // Amount of time remaining before which we should receive an ACK for the + // last sent ping. + timeoutLeft := time.Duration(0) + // Records the last value of t.lastRead before we go block on the timer. + // This is required to check for read activity since then. + prevNano := time.Now().UnixNano() + timer := time.NewTimer(t.kp.Time) + for { + select { + case <-timer.C: + lastRead := atomic.LoadInt64(&t.lastRead) + if lastRead > prevNano { + // There has been read activity since the last time we were here. + outstandingPing = false + // Next timer should fire at kp.Time seconds from lastRead time. + timer.Reset(time.Duration(lastRead) + t.kp.Time - time.Duration(time.Now().UnixNano())) + prevNano = lastRead + continue + } + if outstandingPing && timeoutLeft <= 0 { + t.Close() + return + } + t.mu.Lock() + if t.state == closing { + // If the transport is closing, we should exit from the + // keepalive goroutine here. If not, we could have a race + // between the call to Signal() from Close() and the call to + // Wait() here, whereby the keepalive goroutine ends up + // blocking on the condition variable which will never be + // signalled again. + t.mu.Unlock() + return + } + if len(t.activeStreams) < 1 && !t.kp.PermitWithoutStream { + // If a ping was sent out previously (because there were active + // streams at that point) which wasn't acked and its timeout + // hadn't fired, but we got here and are about to go dormant, + // we should make sure that we unconditionally send a ping once + // we awaken. + outstandingPing = false + t.kpDormant = true + t.kpDormancyCond.Wait() + } + t.kpDormant = false + t.mu.Unlock() + + // We get here either because we were dormant and a new stream was + // created which unblocked the Wait() call, or because the + // keepalive timer expired. In both cases, we need to send a ping. + if !outstandingPing { + if channelz.IsOn() { + atomic.AddInt64(&t.czData.kpCount, 1) + } + t.controlBuf.put(p) + timeoutLeft = t.kp.Timeout + outstandingPing = true + } + // The amount of time to sleep here is the minimum of kp.Time and + // timeoutLeft. This will ensure that we wait only for kp.Time + // before sending out the next ping (for cases where the ping is + // acked). + sleepDuration := minTime(t.kp.Time, timeoutLeft) + timeoutLeft -= sleepDuration + timer.Reset(sleepDuration) + case <-t.ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return + } + } +} + +func (t *http2Client) Error() <-chan struct{} { + return t.ctx.Done() +} + +func (t *http2Client) GoAway() <-chan struct{} { + return t.goAway +} + +func (t *http2Client) ChannelzMetric() *channelz.SocketInternalMetric { + s := channelz.SocketInternalMetric{ + StreamsStarted: atomic.LoadInt64(&t.czData.streamsStarted), + StreamsSucceeded: atomic.LoadInt64(&t.czData.streamsSucceeded), + StreamsFailed: atomic.LoadInt64(&t.czData.streamsFailed), + MessagesSent: atomic.LoadInt64(&t.czData.msgSent), + MessagesReceived: atomic.LoadInt64(&t.czData.msgRecv), + KeepAlivesSent: atomic.LoadInt64(&t.czData.kpCount), + LastLocalStreamCreatedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastStreamCreatedTime)), + LastMessageSentTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgSentTime)), + LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)), + LocalFlowControlWindow: int64(t.fc.getSize()), + SocketOptions: channelz.GetSocketOption(t.conn), + LocalAddr: t.localAddr, + RemoteAddr: t.remoteAddr, + // RemoteName : + } + if au, ok := t.authInfo.(credentials.ChannelzSecurityInfo); ok { + s.Security = au.GetSecurityValue() + } + s.RemoteFlowControlWindow = t.getOutFlowWindow() + return &s +} + +func (t *http2Client) RemoteAddr() net.Addr { return t.remoteAddr } + +func (t *http2Client) IncrMsgSent() { + atomic.AddInt64(&t.czData.msgSent, 1) + atomic.StoreInt64(&t.czData.lastMsgSentTime, time.Now().UnixNano()) +} + +func (t *http2Client) IncrMsgRecv() { + atomic.AddInt64(&t.czData.msgRecv, 1) + atomic.StoreInt64(&t.czData.lastMsgRecvTime, time.Now().UnixNano()) +} + +func (t *http2Client) getOutFlowWindow() int64 { + resp := make(chan uint32, 1) + timer := time.NewTimer(time.Second) + defer timer.Stop() + t.controlBuf.put(&outFlowControlSizeRequest{resp}) + select { + case sz := <-resp: + return int64(sz) + case <-t.ctxDone: + return -1 + case <-timer.C: + return -2 + } +} diff --git a/vendor/google.golang.org/grpc/internal/transport/http2_server.go b/vendor/google.golang.org/grpc/internal/transport/http2_server.go new file mode 100644 index 000000000..0cf1cc320 --- /dev/null +++ b/vendor/google.golang.org/grpc/internal/transport/http2_server.go @@ -0,0 +1,1268 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "math" + "net" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/golang/protobuf/proto" + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" + "google.golang.org/grpc/internal/grpcutil" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" + "google.golang.org/grpc/tap" +) + +var ( + // ErrIllegalHeaderWrite indicates that setting header is illegal because of + // the stream's state. + ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHeader was already called") + // ErrHeaderListSizeLimitViolation indicates that the header list size is larger + // than the limit set by peer. + ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer") +) + +// serverConnectionCounter counts the number of connections a server has seen +// (equal to the number of http2Servers created). Must be accessed atomically. +var serverConnectionCounter uint64 + +// http2Server implements the ServerTransport interface with HTTP2. +type http2Server struct { + lastRead int64 // Keep this field 64-bit aligned. Accessed atomically. + ctx context.Context + done chan struct{} + conn net.Conn + loopy *loopyWriter + readerDone chan struct{} // sync point to enable testing. + writerDone chan struct{} // sync point to enable testing. + remoteAddr net.Addr + localAddr net.Addr + maxStreamID uint32 // max stream ID ever seen + authInfo credentials.AuthInfo // auth info about the connection + inTapHandle tap.ServerInHandle + framer *framer + // The max number of concurrent streams. + maxStreams uint32 + // controlBuf delivers all the control related tasks (e.g., window + // updates, reset streams, and various settings) to the controller. + controlBuf *controlBuffer + fc *trInFlow + stats stats.Handler + // Keepalive and max-age parameters for the server. + kp keepalive.ServerParameters + // Keepalive enforcement policy. + kep keepalive.EnforcementPolicy + // The time instance last ping was received. + lastPingAt time.Time + // Number of times the client has violated keepalive ping policy so far. + pingStrikes uint8 + // Flag to signify that number of ping strikes should be reset to 0. + // This is set whenever data or header frames are sent. + // 1 means yes. + resetPingStrikes uint32 // Accessed atomically. + initialWindowSize int32 + bdpEst *bdpEstimator + maxSendHeaderListSize *uint32 + + mu sync.Mutex // guard the following + + // drainChan is initialized when drain(...) is called the first time. + // After which the server writes out the first GoAway(with ID 2^31-1) frame. + // Then an independent goroutine will be launched to later send the second GoAway. + // During this time we don't want to write another first GoAway(with ID 2^31 -1) frame. + // Thus call to drain(...) will be a no-op if drainChan is already initialized since draining is + // already underway. + drainChan chan struct{} + state transportState + activeStreams map[uint32]*Stream + // idle is the time instant when the connection went idle. + // This is either the beginning of the connection or when the number of + // RPCs go down to 0. + // When the connection is busy, this value is set to 0. + idle time.Time + + // Fields below are for channelz metric collection. + channelzID int64 // channelz unique identification number + czData *channelzData + bufferPool *bufferPool + + connectionID uint64 +} + +// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is +// returned if something goes wrong. +func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) { + writeBufSize := config.WriteBufferSize + readBufSize := config.ReadBufferSize + maxHeaderListSize := defaultServerMaxHeaderListSize + if config.MaxHeaderListSize != nil { + maxHeaderListSize = *config.MaxHeaderListSize + } + framer := newFramer(conn, writeBufSize, readBufSize, maxHeaderListSize) + // Send initial settings as connection preface to client. + isettings := []http2.Setting{{ + ID: http2.SettingMaxFrameSize, + Val: http2MaxFrameLen, + }} + // TODO(zhaoq): Have a better way to signal "no limit" because 0 is + // permitted in the HTTP2 spec. + maxStreams := config.MaxStreams + if maxStreams == 0 { + maxStreams = math.MaxUint32 + } else { + isettings = append(isettings, http2.Setting{ + ID: http2.SettingMaxConcurrentStreams, + Val: maxStreams, + }) + } + dynamicWindow := true + iwz := int32(initialWindowSize) + if config.InitialWindowSize >= defaultWindowSize { + iwz = config.InitialWindowSize + dynamicWindow = false + } + icwz := int32(initialWindowSize) + if config.InitialConnWindowSize >= defaultWindowSize { + icwz = config.InitialConnWindowSize + dynamicWindow = false + } + if iwz != defaultWindowSize { + isettings = append(isettings, http2.Setting{ + ID: http2.SettingInitialWindowSize, + Val: uint32(iwz)}) + } + if config.MaxHeaderListSize != nil { + isettings = append(isettings, http2.Setting{ + ID: http2.SettingMaxHeaderListSize, + Val: *config.MaxHeaderListSize, + }) + } + if config.HeaderTableSize != nil { + isettings = append(isettings, http2.Setting{ + ID: http2.SettingHeaderTableSize, + Val: *config.HeaderTableSize, + }) + } + if err := framer.fr.WriteSettings(isettings...); err != nil { + return nil, connectionErrorf(false, err, "transport: %v", err) + } + // Adjust the connection flow control window if needed. + if delta := uint32(icwz - defaultWindowSize); delta > 0 { + if err := framer.fr.WriteWindowUpdate(0, delta); err != nil { + return nil, connectionErrorf(false, err, "transport: %v", err) + } + } + kp := config.KeepaliveParams + if kp.MaxConnectionIdle == 0 { + kp.MaxConnectionIdle = defaultMaxConnectionIdle + } + if kp.MaxConnectionAge == 0 { + kp.MaxConnectionAge = defaultMaxConnectionAge + } + // Add a jitter to MaxConnectionAge. + kp.MaxConnectionAge += getJitter(kp.MaxConnectionAge) + if kp.MaxConnectionAgeGrace == 0 { + kp.MaxConnectionAgeGrace = defaultMaxConnectionAgeGrace + } + if kp.Time == 0 { + kp.Time = defaultServerKeepaliveTime + } + if kp.Timeout == 0 { + kp.Timeout = defaultServerKeepaliveTimeout + } + kep := config.KeepalivePolicy + if kep.MinTime == 0 { + kep.MinTime = defaultKeepalivePolicyMinTime + } + done := make(chan struct{}) + t := &http2Server{ + ctx: context.Background(), + done: done, + conn: conn, + remoteAddr: conn.RemoteAddr(), + localAddr: conn.LocalAddr(), + authInfo: config.AuthInfo, + framer: framer, + readerDone: make(chan struct{}), + writerDone: make(chan struct{}), + maxStreams: maxStreams, + inTapHandle: config.InTapHandle, + fc: &trInFlow{limit: uint32(icwz)}, + state: reachable, + activeStreams: make(map[uint32]*Stream), + stats: config.StatsHandler, + kp: kp, + idle: time.Now(), + kep: kep, + initialWindowSize: iwz, + czData: new(channelzData), + bufferPool: newBufferPool(), + } + t.controlBuf = newControlBuffer(t.done) + if dynamicWindow { + t.bdpEst = &bdpEstimator{ + bdp: initialWindowSize, + updateFlowControl: t.updateFlowControl, + } + } + if t.stats != nil { + t.ctx = t.stats.TagConn(t.ctx, &stats.ConnTagInfo{ + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + }) + connBegin := &stats.ConnBegin{} + t.stats.HandleConn(t.ctx, connBegin) + } + if channelz.IsOn() { + t.channelzID = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr)) + } + + t.connectionID = atomic.AddUint64(&serverConnectionCounter, 1) + + t.framer.writer.Flush() + + defer func() { + if err != nil { + t.Close() + } + }() + + // Check the validity of client preface. + preface := make([]byte, len(clientPreface)) + if _, err := io.ReadFull(t.conn, preface); err != nil { + return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to receive the preface from client: %v", err) + } + if !bytes.Equal(preface, clientPreface) { + return nil, connectionErrorf(false, nil, "transport: http2Server.HandleStreams received bogus greeting from client: %q", preface) + } + + frame, err := t.framer.fr.ReadFrame() + if err == io.EOF || err == io.ErrUnexpectedEOF { + return nil, err + } + if err != nil { + return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to read initial settings frame: %v", err) + } + atomic.StoreInt64(&t.lastRead, time.Now().UnixNano()) + sf, ok := frame.(*http2.SettingsFrame) + if !ok { + return nil, connectionErrorf(false, nil, "transport: http2Server.HandleStreams saw invalid preface type %T from client", frame) + } + t.handleSettings(sf) + + go func() { + t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst) + t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler + if err := t.loopy.run(); err != nil { + if logger.V(logLevel) { + logger.Errorf("transport: loopyWriter.run returning. Err: %v", err) + } + } + t.conn.Close() + close(t.writerDone) + }() + go t.keepalive() + return t, nil +} + +// operateHeader takes action on the decoded headers. +func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (fatal bool) { + streamID := frame.Header().StreamID + state := &decodeState{ + serverSide: true, + } + if h2code, err := state.decodeHeader(frame); err != nil { + if _, ok := status.FromError(err); ok { + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: h2code, + onWrite: func() {}, + }) + } + return false + } + + buf := newRecvBuffer() + s := &Stream{ + id: streamID, + st: t, + buf: buf, + fc: &inFlow{limit: uint32(t.initialWindowSize)}, + recvCompress: state.data.encoding, + method: state.data.method, + contentSubtype: state.data.contentSubtype, + } + if frame.StreamEnded() { + // s is just created by the caller. No lock needed. + s.state = streamReadDone + } + if state.data.timeoutSet { + s.ctx, s.cancel = context.WithTimeout(t.ctx, state.data.timeout) + } else { + s.ctx, s.cancel = context.WithCancel(t.ctx) + } + pr := &peer.Peer{ + Addr: t.remoteAddr, + } + // Attach Auth info if there is any. + if t.authInfo != nil { + pr.AuthInfo = t.authInfo + } + s.ctx = peer.NewContext(s.ctx, pr) + // Attach the received metadata to the context. + if len(state.data.mdata) > 0 { + s.ctx = metadata.NewIncomingContext(s.ctx, state.data.mdata) + } + if state.data.statsTags != nil { + s.ctx = stats.SetIncomingTags(s.ctx, state.data.statsTags) + } + if state.data.statsTrace != nil { + s.ctx = stats.SetIncomingTrace(s.ctx, state.data.statsTrace) + } + if t.inTapHandle != nil { + var err error + info := &tap.Info{ + FullMethodName: state.data.method, + } + s.ctx, err = t.inTapHandle(s.ctx, info) + if err != nil { + if logger.V(logLevel) { + logger.Warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err) + } + t.controlBuf.put(&cleanupStream{ + streamID: s.id, + rst: true, + rstCode: http2.ErrCodeRefusedStream, + onWrite: func() {}, + }) + s.cancel() + return false + } + } + t.mu.Lock() + if t.state != reachable { + t.mu.Unlock() + s.cancel() + return false + } + if uint32(len(t.activeStreams)) >= t.maxStreams { + t.mu.Unlock() + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: http2.ErrCodeRefusedStream, + onWrite: func() {}, + }) + s.cancel() + return false + } + if streamID%2 != 1 || streamID <= t.maxStreamID { + t.mu.Unlock() + // illegal gRPC stream id. + if logger.V(logLevel) { + logger.Errorf("transport: http2Server.HandleStreams received an illegal stream id: %v", streamID) + } + s.cancel() + return true + } + t.maxStreamID = streamID + t.activeStreams[streamID] = s + if len(t.activeStreams) == 1 { + t.idle = time.Time{} + } + t.mu.Unlock() + if channelz.IsOn() { + atomic.AddInt64(&t.czData.streamsStarted, 1) + atomic.StoreInt64(&t.czData.lastStreamCreatedTime, time.Now().UnixNano()) + } + s.requestRead = func(n int) { + t.adjustWindow(s, uint32(n)) + } + s.ctx = traceCtx(s.ctx, s.method) + if t.stats != nil { + s.ctx = t.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) + inHeader := &stats.InHeader{ + FullMethod: s.method, + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + Compression: s.recvCompress, + WireLength: int(frame.Header().Length), + Header: metadata.MD(state.data.mdata).Copy(), + } + t.stats.HandleRPC(s.ctx, inHeader) + } + s.ctxDone = s.ctx.Done() + s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) + s.trReader = &transportReader{ + reader: &recvBufferReader{ + ctx: s.ctx, + ctxDone: s.ctxDone, + recv: s.buf, + freeBuffer: t.bufferPool.put, + }, + windowHandler: func(n int) { + t.updateWindow(s, uint32(n)) + }, + } + // Register the stream with loopy. + t.controlBuf.put(®isterStream{ + streamID: s.id, + wq: s.wq, + }) + handle(s) + return false +} + +// HandleStreams receives incoming streams using the given handler. This is +// typically run in a separate goroutine. +// traceCtx attaches trace to ctx and returns the new context. +func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) { + defer close(t.readerDone) + for { + t.controlBuf.throttle() + frame, err := t.framer.fr.ReadFrame() + atomic.StoreInt64(&t.lastRead, time.Now().UnixNano()) + if err != nil { + if se, ok := err.(http2.StreamError); ok { + if logger.V(logLevel) { + logger.Warningf("transport: http2Server.HandleStreams encountered http2.StreamError: %v", se) + } + t.mu.Lock() + s := t.activeStreams[se.StreamID] + t.mu.Unlock() + if s != nil { + t.closeStream(s, true, se.Code, false) + } else { + t.controlBuf.put(&cleanupStream{ + streamID: se.StreamID, + rst: true, + rstCode: se.Code, + onWrite: func() {}, + }) + } + continue + } + if err == io.EOF || err == io.ErrUnexpectedEOF { + t.Close() + return + } + if logger.V(logLevel) { + logger.Warningf("transport: http2Server.HandleStreams failed to read frame: %v", err) + } + t.Close() + return + } + switch frame := frame.(type) { + case *http2.MetaHeadersFrame: + if t.operateHeaders(frame, handle, traceCtx) { + t.Close() + break + } + case *http2.DataFrame: + t.handleData(frame) + case *http2.RSTStreamFrame: + t.handleRSTStream(frame) + case *http2.SettingsFrame: + t.handleSettings(frame) + case *http2.PingFrame: + t.handlePing(frame) + case *http2.WindowUpdateFrame: + t.handleWindowUpdate(frame) + case *http2.GoAwayFrame: + // TODO: Handle GoAway from the client appropriately. + default: + if logger.V(logLevel) { + logger.Errorf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame) + } + } + } +} + +func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) { + t.mu.Lock() + defer t.mu.Unlock() + if t.activeStreams == nil { + // The transport is closing. + return nil, false + } + s, ok := t.activeStreams[f.Header().StreamID] + if !ok { + // The stream is already done. + return nil, false + } + return s, true +} + +// adjustWindow sends out extra window update over the initial window size +// of stream if the application is requesting data larger in size than +// the window. +func (t *http2Server) adjustWindow(s *Stream, n uint32) { + if w := s.fc.maybeAdjust(n); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w}) + } + +} + +// updateWindow adjusts the inbound quota for the stream and the transport. +// Window updates will deliver to the controller for sending when +// the cumulative quota exceeds the corresponding threshold. +func (t *http2Server) updateWindow(s *Stream, n uint32) { + if w := s.fc.onRead(n); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, + increment: w, + }) + } +} + +// updateFlowControl updates the incoming flow control windows +// for the transport and the stream based on the current bdp +// estimation. +func (t *http2Server) updateFlowControl(n uint32) { + t.mu.Lock() + for _, s := range t.activeStreams { + s.fc.newLimit(n) + } + t.initialWindowSize = int32(n) + t.mu.Unlock() + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: t.fc.newLimit(n), + }) + t.controlBuf.put(&outgoingSettings{ + ss: []http2.Setting{ + { + ID: http2.SettingInitialWindowSize, + Val: n, + }, + }, + }) + +} + +func (t *http2Server) handleData(f *http2.DataFrame) { + size := f.Header().Length + var sendBDPPing bool + if t.bdpEst != nil { + sendBDPPing = t.bdpEst.add(size) + } + // Decouple connection's flow control from application's read. + // An update on connection's flow control should not depend on + // whether user application has read the data or not. Such a + // restriction is already imposed on the stream's flow control, + // and therefore the sender will be blocked anyways. + // Decoupling the connection flow control will prevent other + // active(fast) streams from starving in presence of slow or + // inactive streams. + if w := t.fc.onData(size); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) + } + if sendBDPPing { + // Avoid excessive ping detection (e.g. in an L7 proxy) + // by sending a window update prior to the BDP ping. + if w := t.fc.reset(); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) + } + t.controlBuf.put(bdpPing) + } + // Select the right stream to dispatch. + s, ok := t.getStream(f) + if !ok { + return + } + if s.getState() == streamReadDone { + t.closeStream(s, true, http2.ErrCodeStreamClosed, false) + return + } + if size > 0 { + if err := s.fc.onData(size); err != nil { + t.closeStream(s, true, http2.ErrCodeFlowControl, false) + return + } + if f.Header().Flags.Has(http2.FlagDataPadded) { + if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{s.id, w}) + } + } + // TODO(bradfitz, zhaoq): A copy is required here because there is no + // guarantee f.Data() is consumed before the arrival of next frame. + // Can this copy be eliminated? + if len(f.Data()) > 0 { + buffer := t.bufferPool.get() + buffer.Reset() + buffer.Write(f.Data()) + s.write(recvMsg{buffer: buffer}) + } + } + if f.Header().Flags.Has(http2.FlagDataEndStream) { + // Received the end of stream from the client. + s.compareAndSwapState(streamActive, streamReadDone) + s.write(recvMsg{err: io.EOF}) + } +} + +func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) { + // If the stream is not deleted from the transport's active streams map, then do a regular close stream. + if s, ok := t.getStream(f); ok { + t.closeStream(s, false, 0, false) + return + } + // If the stream is already deleted from the active streams map, then put a cleanupStream item into controlbuf to delete the stream from loopy writer's established streams map. + t.controlBuf.put(&cleanupStream{ + streamID: f.Header().StreamID, + rst: false, + rstCode: 0, + onWrite: func() {}, + }) +} + +func (t *http2Server) handleSettings(f *http2.SettingsFrame) { + if f.IsAck() { + return + } + var ss []http2.Setting + var updateFuncs []func() + f.ForeachSetting(func(s http2.Setting) error { + switch s.ID { + case http2.SettingMaxHeaderListSize: + updateFuncs = append(updateFuncs, func() { + t.maxSendHeaderListSize = new(uint32) + *t.maxSendHeaderListSize = s.Val + }) + default: + ss = append(ss, s) + } + return nil + }) + t.controlBuf.executeAndPut(func(interface{}) bool { + for _, f := range updateFuncs { + f() + } + return true + }, &incomingSettings{ + ss: ss, + }) +} + +const ( + maxPingStrikes = 2 + defaultPingTimeout = 2 * time.Hour +) + +func (t *http2Server) handlePing(f *http2.PingFrame) { + if f.IsAck() { + if f.Data == goAwayPing.data && t.drainChan != nil { + close(t.drainChan) + return + } + // Maybe it's a BDP ping. + if t.bdpEst != nil { + t.bdpEst.calculate(f.Data) + } + return + } + pingAck := &ping{ack: true} + copy(pingAck.data[:], f.Data[:]) + t.controlBuf.put(pingAck) + + now := time.Now() + defer func() { + t.lastPingAt = now + }() + // A reset ping strikes means that we don't need to check for policy + // violation for this ping and the pingStrikes counter should be set + // to 0. + if atomic.CompareAndSwapUint32(&t.resetPingStrikes, 1, 0) { + t.pingStrikes = 0 + return + } + t.mu.Lock() + ns := len(t.activeStreams) + t.mu.Unlock() + if ns < 1 && !t.kep.PermitWithoutStream { + // Keepalive shouldn't be active thus, this new ping should + // have come after at least defaultPingTimeout. + if t.lastPingAt.Add(defaultPingTimeout).After(now) { + t.pingStrikes++ + } + } else { + // Check if keepalive policy is respected. + if t.lastPingAt.Add(t.kep.MinTime).After(now) { + t.pingStrikes++ + } + } + + if t.pingStrikes > maxPingStrikes { + // Send goaway and close the connection. + if logger.V(logLevel) { + logger.Errorf("transport: Got too many pings from the client, closing the connection.") + } + t.controlBuf.put(&goAway{code: http2.ErrCodeEnhanceYourCalm, debugData: []byte("too_many_pings"), closeConn: true}) + } +} + +func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) { + t.controlBuf.put(&incomingWindowUpdate{ + streamID: f.Header().StreamID, + increment: f.Increment, + }) +} + +func appendHeaderFieldsFromMD(headerFields []hpack.HeaderField, md metadata.MD) []hpack.HeaderField { + for k, vv := range md { + if isReservedHeader(k) { + // Clients don't tolerate reading restricted headers after some non restricted ones were sent. + continue + } + for _, v := range vv { + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + } + } + return headerFields +} + +func (t *http2Server) checkForHeaderListSize(it interface{}) bool { + if t.maxSendHeaderListSize == nil { + return true + } + hdrFrame := it.(*headerFrame) + var sz int64 + for _, f := range hdrFrame.hf { + if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) { + if logger.V(logLevel) { + logger.Errorf("header list size to send violates the maximum size (%d bytes) set by client", *t.maxSendHeaderListSize) + } + return false + } + } + return true +} + +// WriteHeader sends the header metadata md back to the client. +func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { + if s.updateHeaderSent() || s.getState() == streamDone { + return ErrIllegalHeaderWrite + } + s.hdrMu.Lock() + if md.Len() > 0 { + if s.header.Len() > 0 { + s.header = metadata.Join(s.header, md) + } else { + s.header = md + } + } + if err := t.writeHeaderLocked(s); err != nil { + s.hdrMu.Unlock() + return err + } + s.hdrMu.Unlock() + return nil +} + +func (t *http2Server) setResetPingStrikes() { + atomic.StoreUint32(&t.resetPingStrikes, 1) +} + +func (t *http2Server) writeHeaderLocked(s *Stream) error { + // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields + // first and create a slice of that exact size. + headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else. + headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: grpcutil.ContentType(s.contentSubtype)}) + if s.sendCompress != "" { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) + } + headerFields = appendHeaderFieldsFromMD(headerFields, s.header) + success, err := t.controlBuf.executeAndPut(t.checkForHeaderListSize, &headerFrame{ + streamID: s.id, + hf: headerFields, + endStream: false, + onWrite: t.setResetPingStrikes, + }) + if !success { + if err != nil { + return err + } + t.closeStream(s, true, http2.ErrCodeInternal, false) + return ErrHeaderListSizeLimitViolation + } + if t.stats != nil { + // Note: Headers are compressed with hpack after this call returns. + // No WireLength field is set here. + outHeader := &stats.OutHeader{ + Header: s.header.Copy(), + Compression: s.sendCompress, + } + t.stats.HandleRPC(s.Context(), outHeader) + } + return nil +} + +// WriteStatus sends stream status to the client and terminates the stream. +// There is no further I/O operations being able to perform on this stream. +// TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early +// OK is adopted. +func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { + if s.getState() == streamDone { + return nil + } + s.hdrMu.Lock() + // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields + // first and create a slice of that exact size. + headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else. + if !s.updateHeaderSent() { // No headers have been sent. + if len(s.header) > 0 { // Send a separate header frame. + if err := t.writeHeaderLocked(s); err != nil { + s.hdrMu.Unlock() + return err + } + } else { // Send a trailer only response. + headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: grpcutil.ContentType(s.contentSubtype)}) + } + } + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())}) + + if p := st.Proto(); p != nil && len(p.Details) > 0 { + stBytes, err := proto.Marshal(p) + if err != nil { + // TODO: return error instead, when callers are able to handle it. + logger.Errorf("transport: failed to marshal rpc status: %v, error: %v", p, err) + } else { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)}) + } + } + + // Attach the trailer metadata. + headerFields = appendHeaderFieldsFromMD(headerFields, s.trailer) + trailingHeader := &headerFrame{ + streamID: s.id, + hf: headerFields, + endStream: true, + onWrite: t.setResetPingStrikes, + } + s.hdrMu.Unlock() + success, err := t.controlBuf.execute(t.checkForHeaderListSize, trailingHeader) + if !success { + if err != nil { + return err + } + t.closeStream(s, true, http2.ErrCodeInternal, false) + return ErrHeaderListSizeLimitViolation + } + // Send a RST_STREAM after the trailers if the client has not already half-closed. + rst := s.getState() == streamActive + t.finishStream(s, rst, http2.ErrCodeNo, trailingHeader, true) + if t.stats != nil { + // Note: The trailer fields are compressed with hpack after this call returns. + // No WireLength field is set here. + t.stats.HandleRPC(s.Context(), &stats.OutTrailer{ + Trailer: s.trailer.Copy(), + }) + } + return nil +} + +// Write converts the data into HTTP2 data frame and sends it out. Non-nil error +// is returns if it fails (e.g., framing error, transport error). +func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { + if !s.isHeaderSent() { // Headers haven't been written yet. + if err := t.WriteHeader(s, nil); err != nil { + if _, ok := err.(ConnectionError); ok { + return err + } + // TODO(mmukhi, dfawley): Make sure this is the right code to return. + return status.Errorf(codes.Internal, "transport: %v", err) + } + } else { + // Writing headers checks for this condition. + if s.getState() == streamDone { + // TODO(mmukhi, dfawley): Should the server write also return io.EOF? + s.cancel() + select { + case <-t.done: + return ErrConnClosing + default: + } + return ContextErr(s.ctx.Err()) + } + } + df := &dataFrame{ + streamID: s.id, + h: hdr, + d: data, + onEachWrite: t.setResetPingStrikes, + } + if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { + select { + case <-t.done: + return ErrConnClosing + default: + } + return ContextErr(s.ctx.Err()) + } + return t.controlBuf.put(df) +} + +// keepalive running in a separate goroutine does the following: +// 1. Gracefully closes an idle connection after a duration of keepalive.MaxConnectionIdle. +// 2. Gracefully closes any connection after a duration of keepalive.MaxConnectionAge. +// 3. Forcibly closes a connection after an additive period of keepalive.MaxConnectionAgeGrace over keepalive.MaxConnectionAge. +// 4. Makes sure a connection is alive by sending pings with a frequency of keepalive.Time and closes a non-responsive connection +// after an additional duration of keepalive.Timeout. +func (t *http2Server) keepalive() { + p := &ping{} + // True iff a ping has been sent, and no data has been received since then. + outstandingPing := false + // Amount of time remaining before which we should receive an ACK for the + // last sent ping. + kpTimeoutLeft := time.Duration(0) + // Records the last value of t.lastRead before we go block on the timer. + // This is required to check for read activity since then. + prevNano := time.Now().UnixNano() + // Initialize the different timers to their default values. + idleTimer := time.NewTimer(t.kp.MaxConnectionIdle) + ageTimer := time.NewTimer(t.kp.MaxConnectionAge) + kpTimer := time.NewTimer(t.kp.Time) + defer func() { + // We need to drain the underlying channel in these timers after a call + // to Stop(), only if we are interested in resetting them. Clearly we + // are not interested in resetting them here. + idleTimer.Stop() + ageTimer.Stop() + kpTimer.Stop() + }() + + for { + select { + case <-idleTimer.C: + t.mu.Lock() + idle := t.idle + if idle.IsZero() { // The connection is non-idle. + t.mu.Unlock() + idleTimer.Reset(t.kp.MaxConnectionIdle) + continue + } + val := t.kp.MaxConnectionIdle - time.Since(idle) + t.mu.Unlock() + if val <= 0 { + // The connection has been idle for a duration of keepalive.MaxConnectionIdle or more. + // Gracefully close the connection. + t.drain(http2.ErrCodeNo, []byte{}) + return + } + idleTimer.Reset(val) + case <-ageTimer.C: + t.drain(http2.ErrCodeNo, []byte{}) + ageTimer.Reset(t.kp.MaxConnectionAgeGrace) + select { + case <-ageTimer.C: + // Close the connection after grace period. + if logger.V(logLevel) { + logger.Infof("transport: closing server transport due to maximum connection age.") + } + t.Close() + case <-t.done: + } + return + case <-kpTimer.C: + lastRead := atomic.LoadInt64(&t.lastRead) + if lastRead > prevNano { + // There has been read activity since the last time we were + // here. Setup the timer to fire at kp.Time seconds from + // lastRead time and continue. + outstandingPing = false + kpTimer.Reset(time.Duration(lastRead) + t.kp.Time - time.Duration(time.Now().UnixNano())) + prevNano = lastRead + continue + } + if outstandingPing && kpTimeoutLeft <= 0 { + if logger.V(logLevel) { + logger.Infof("transport: closing server transport due to idleness.") + } + t.Close() + return + } + if !outstandingPing { + if channelz.IsOn() { + atomic.AddInt64(&t.czData.kpCount, 1) + } + t.controlBuf.put(p) + kpTimeoutLeft = t.kp.Timeout + outstandingPing = true + } + // The amount of time to sleep here is the minimum of kp.Time and + // timeoutLeft. This will ensure that we wait only for kp.Time + // before sending out the next ping (for cases where the ping is + // acked). + sleepDuration := minTime(t.kp.Time, kpTimeoutLeft) + kpTimeoutLeft -= sleepDuration + kpTimer.Reset(sleepDuration) + case <-t.done: + return + } + } +} + +// Close starts shutting down the http2Server transport. +// TODO(zhaoq): Now the destruction is not blocked on any pending streams. This +// could cause some resource issue. Revisit this later. +func (t *http2Server) Close() error { + t.mu.Lock() + if t.state == closing { + t.mu.Unlock() + return errors.New("transport: Close() was already called") + } + t.state = closing + streams := t.activeStreams + t.activeStreams = nil + t.mu.Unlock() + t.controlBuf.finish() + close(t.done) + err := t.conn.Close() + if channelz.IsOn() { + channelz.RemoveEntry(t.channelzID) + } + // Cancel all active streams. + for _, s := range streams { + s.cancel() + } + if t.stats != nil { + connEnd := &stats.ConnEnd{} + t.stats.HandleConn(t.ctx, connEnd) + } + return err +} + +// deleteStream deletes the stream s from transport's active streams. +func (t *http2Server) deleteStream(s *Stream, eosReceived bool) { + // In case stream sending and receiving are invoked in separate + // goroutines (e.g., bi-directional streaming), cancel needs to be + // called to interrupt the potential blocking on other goroutines. + s.cancel() + + t.mu.Lock() + if _, ok := t.activeStreams[s.id]; ok { + delete(t.activeStreams, s.id) + if len(t.activeStreams) == 0 { + t.idle = time.Now() + } + } + t.mu.Unlock() + + if channelz.IsOn() { + if eosReceived { + atomic.AddInt64(&t.czData.streamsSucceeded, 1) + } else { + atomic.AddInt64(&t.czData.streamsFailed, 1) + } + } +} + +// finishStream closes the stream and puts the trailing headerFrame into controlbuf. +func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) { + oldState := s.swapState(streamDone) + if oldState == streamDone { + // If the stream was already done, return. + return + } + + hdr.cleanup = &cleanupStream{ + streamID: s.id, + rst: rst, + rstCode: rstCode, + onWrite: func() { + t.deleteStream(s, eosReceived) + }, + } + t.controlBuf.put(hdr) +} + +// closeStream clears the footprint of a stream when the stream is not needed any more. +func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eosReceived bool) { + s.swapState(streamDone) + t.deleteStream(s, eosReceived) + + t.controlBuf.put(&cleanupStream{ + streamID: s.id, + rst: rst, + rstCode: rstCode, + onWrite: func() {}, + }) +} + +func (t *http2Server) RemoteAddr() net.Addr { + return t.remoteAddr +} + +func (t *http2Server) Drain() { + t.drain(http2.ErrCodeNo, []byte{}) +} + +func (t *http2Server) drain(code http2.ErrCode, debugData []byte) { + t.mu.Lock() + defer t.mu.Unlock() + if t.drainChan != nil { + return + } + t.drainChan = make(chan struct{}) + t.controlBuf.put(&goAway{code: code, debugData: debugData, headsUp: true}) +} + +var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}} + +// Handles outgoing GoAway and returns true if loopy needs to put itself +// in draining mode. +func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { + t.mu.Lock() + if t.state == closing { // TODO(mmukhi): This seems unnecessary. + t.mu.Unlock() + // The transport is closing. + return false, ErrConnClosing + } + sid := t.maxStreamID + if !g.headsUp { + // Stop accepting more streams now. + t.state = draining + if len(t.activeStreams) == 0 { + g.closeConn = true + } + t.mu.Unlock() + if err := t.framer.fr.WriteGoAway(sid, g.code, g.debugData); err != nil { + return false, err + } + if g.closeConn { + // Abruptly close the connection following the GoAway (via + // loopywriter). But flush out what's inside the buffer first. + t.framer.writer.Flush() + return false, fmt.Errorf("transport: Connection closing") + } + return true, nil + } + t.mu.Unlock() + // For a graceful close, send out a GoAway with stream ID of MaxUInt32, + // Follow that with a ping and wait for the ack to come back or a timer + // to expire. During this time accept new streams since they might have + // originated before the GoAway reaches the client. + // After getting the ack or timer expiration send out another GoAway this + // time with an ID of the max stream server intends to process. + if err := t.framer.fr.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, []byte{}); err != nil { + return false, err + } + if err := t.framer.fr.WritePing(false, goAwayPing.data); err != nil { + return false, err + } + go func() { + timer := time.NewTimer(time.Minute) + defer timer.Stop() + select { + case <-t.drainChan: + case <-timer.C: + case <-t.done: + return + } + t.controlBuf.put(&goAway{code: g.code, debugData: g.debugData}) + }() + return false, nil +} + +func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric { + s := channelz.SocketInternalMetric{ + StreamsStarted: atomic.LoadInt64(&t.czData.streamsStarted), + StreamsSucceeded: atomic.LoadInt64(&t.czData.streamsSucceeded), + StreamsFailed: atomic.LoadInt64(&t.czData.streamsFailed), + MessagesSent: atomic.LoadInt64(&t.czData.msgSent), + MessagesReceived: atomic.LoadInt64(&t.czData.msgRecv), + KeepAlivesSent: atomic.LoadInt64(&t.czData.kpCount), + LastRemoteStreamCreatedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastStreamCreatedTime)), + LastMessageSentTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgSentTime)), + LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)), + LocalFlowControlWindow: int64(t.fc.getSize()), + SocketOptions: channelz.GetSocketOption(t.conn), + LocalAddr: t.localAddr, + RemoteAddr: t.remoteAddr, + // RemoteName : + } + if au, ok := t.authInfo.(credentials.ChannelzSecurityInfo); ok { + s.Security = au.GetSecurityValue() + } + s.RemoteFlowControlWindow = t.getOutFlowWindow() + return &s +} + +func (t *http2Server) IncrMsgSent() { + atomic.AddInt64(&t.czData.msgSent, 1) + atomic.StoreInt64(&t.czData.lastMsgSentTime, time.Now().UnixNano()) +} + +func (t *http2Server) IncrMsgRecv() { + atomic.AddInt64(&t.czData.msgRecv, 1) + atomic.StoreInt64(&t.czData.lastMsgRecvTime, time.Now().UnixNano()) +} + +func (t *http2Server) getOutFlowWindow() int64 { + resp := make(chan uint32, 1) + timer := time.NewTimer(time.Second) + defer timer.Stop() + t.controlBuf.put(&outFlowControlSizeRequest{resp}) + select { + case sz := <-resp: + return int64(sz) + case <-t.done: + return -1 + case <-timer.C: + return -2 + } +} + +func getJitter(v time.Duration) time.Duration { + if v == infinity { + return 0 + } + // Generate a jitter between +/- 10% of the value. + r := int64(v / 10) + j := grpcrand.Int63n(2*r) - r + return time.Duration(j) +} diff --git a/vendor/google.golang.org/grpc/internal/transport/http_util.go b/vendor/google.golang.org/grpc/internal/transport/http_util.go new file mode 100644 index 000000000..4d15afbf7 --- /dev/null +++ b/vendor/google.golang.org/grpc/internal/transport/http_util.go @@ -0,0 +1,600 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "bufio" + "bytes" + "encoding/base64" + "fmt" + "io" + "math" + "net" + "net/http" + "strconv" + "strings" + "time" + "unicode/utf8" + + "github.com/golang/protobuf/proto" + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" + spb "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/grpcutil" + "google.golang.org/grpc/status" +) + +const ( + // http2MaxFrameLen specifies the max length of a HTTP2 frame. + http2MaxFrameLen = 16384 // 16KB frame + // http://http2.github.io/http2-spec/#SettingValues + http2InitHeaderTableSize = 4096 + // baseContentType is the base content-type for gRPC. This is a valid + // content-type on it's own, but can also include a content-subtype such as + // "proto" as a suffix after "+" or ";". See + // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests + // for more details. + +) + +var ( + clientPreface = []byte(http2.ClientPreface) + http2ErrConvTab = map[http2.ErrCode]codes.Code{ + http2.ErrCodeNo: codes.Internal, + http2.ErrCodeProtocol: codes.Internal, + http2.ErrCodeInternal: codes.Internal, + http2.ErrCodeFlowControl: codes.ResourceExhausted, + http2.ErrCodeSettingsTimeout: codes.Internal, + http2.ErrCodeStreamClosed: codes.Internal, + http2.ErrCodeFrameSize: codes.Internal, + http2.ErrCodeRefusedStream: codes.Unavailable, + http2.ErrCodeCancel: codes.Canceled, + http2.ErrCodeCompression: codes.Internal, + http2.ErrCodeConnect: codes.Internal, + http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted, + http2.ErrCodeInadequateSecurity: codes.PermissionDenied, + http2.ErrCodeHTTP11Required: codes.Internal, + } + // HTTPStatusConvTab is the HTTP status code to gRPC error code conversion table. + HTTPStatusConvTab = map[int]codes.Code{ + // 400 Bad Request - INTERNAL. + http.StatusBadRequest: codes.Internal, + // 401 Unauthorized - UNAUTHENTICATED. + http.StatusUnauthorized: codes.Unauthenticated, + // 403 Forbidden - PERMISSION_DENIED. + http.StatusForbidden: codes.PermissionDenied, + // 404 Not Found - UNIMPLEMENTED. + http.StatusNotFound: codes.Unimplemented, + // 429 Too Many Requests - UNAVAILABLE. + http.StatusTooManyRequests: codes.Unavailable, + // 502 Bad Gateway - UNAVAILABLE. + http.StatusBadGateway: codes.Unavailable, + // 503 Service Unavailable - UNAVAILABLE. + http.StatusServiceUnavailable: codes.Unavailable, + // 504 Gateway timeout - UNAVAILABLE. + http.StatusGatewayTimeout: codes.Unavailable, + } + logger = grpclog.Component("transport") +) + +type parsedHeaderData struct { + encoding string + // statusGen caches the stream status received from the trailer the server + // sent. Client side only. Do not access directly. After all trailers are + // parsed, use the status method to retrieve the status. + statusGen *status.Status + // rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not + // intended for direct access outside of parsing. + rawStatusCode *int + rawStatusMsg string + httpStatus *int + // Server side only fields. + timeoutSet bool + timeout time.Duration + method string + // key-value metadata map from the peer. + mdata map[string][]string + statsTags []byte + statsTrace []byte + contentSubtype string + + // isGRPC field indicates whether the peer is speaking gRPC (otherwise HTTP). + // + // We are in gRPC mode (peer speaking gRPC) if: + // * We are client side and have already received a HEADER frame that indicates gRPC peer. + // * The header contains valid a content-type, i.e. a string starts with "application/grpc" + // And we should handle error specific to gRPC. + // + // Otherwise (i.e. a content-type string starts without "application/grpc", or does not exist), we + // are in HTTP fallback mode, and should handle error specific to HTTP. + isGRPC bool + grpcErr error + httpErr error + contentTypeErr string +} + +// decodeState configures decoding criteria and records the decoded data. +type decodeState struct { + // whether decoding on server side or not + serverSide bool + + // Records the states during HPACK decoding. It will be filled with info parsed from HTTP HEADERS + // frame once decodeHeader function has been invoked and returned. + data parsedHeaderData +} + +// isReservedHeader checks whether hdr belongs to HTTP2 headers +// reserved by gRPC protocol. Any other headers are classified as the +// user-specified metadata. +func isReservedHeader(hdr string) bool { + if hdr != "" && hdr[0] == ':' { + return true + } + switch hdr { + case "content-type", + "user-agent", + "grpc-message-type", + "grpc-encoding", + "grpc-message", + "grpc-status", + "grpc-timeout", + "grpc-status-details-bin", + // Intentionally exclude grpc-previous-rpc-attempts and + // grpc-retry-pushback-ms, which are "reserved", but their API + // intentionally works via metadata. + "te": + return true + default: + return false + } +} + +// isWhitelistedHeader checks whether hdr should be propagated into metadata +// visible to users, even though it is classified as "reserved", above. +func isWhitelistedHeader(hdr string) bool { + switch hdr { + case ":authority", "user-agent": + return true + default: + return false + } +} + +func (d *decodeState) status() *status.Status { + if d.data.statusGen == nil { + // No status-details were provided; generate status using code/msg. + d.data.statusGen = status.New(codes.Code(int32(*(d.data.rawStatusCode))), d.data.rawStatusMsg) + } + return d.data.statusGen +} + +const binHdrSuffix = "-bin" + +func encodeBinHeader(v []byte) string { + return base64.RawStdEncoding.EncodeToString(v) +} + +func decodeBinHeader(v string) ([]byte, error) { + if len(v)%4 == 0 { + // Input was padded, or padding was not necessary. + return base64.StdEncoding.DecodeString(v) + } + return base64.RawStdEncoding.DecodeString(v) +} + +func encodeMetadataHeader(k, v string) string { + if strings.HasSuffix(k, binHdrSuffix) { + return encodeBinHeader(([]byte)(v)) + } + return v +} + +func decodeMetadataHeader(k, v string) (string, error) { + if strings.HasSuffix(k, binHdrSuffix) { + b, err := decodeBinHeader(v) + return string(b), err + } + return v, nil +} + +func (d *decodeState) decodeHeader(frame *http2.MetaHeadersFrame) (http2.ErrCode, error) { + // frame.Truncated is set to true when framer detects that the current header + // list size hits MaxHeaderListSize limit. + if frame.Truncated { + return http2.ErrCodeFrameSize, status.Error(codes.Internal, "peer header list size exceeded limit") + } + + for _, hf := range frame.Fields { + d.processHeaderField(hf) + } + + if d.data.isGRPC { + if d.data.grpcErr != nil { + return http2.ErrCodeProtocol, d.data.grpcErr + } + if d.serverSide { + return http2.ErrCodeNo, nil + } + if d.data.rawStatusCode == nil && d.data.statusGen == nil { + // gRPC status doesn't exist. + // Set rawStatusCode to be unknown and return nil error. + // So that, if the stream has ended this Unknown status + // will be propagated to the user. + // Otherwise, it will be ignored. In which case, status from + // a later trailer, that has StreamEnded flag set, is propagated. + code := int(codes.Unknown) + d.data.rawStatusCode = &code + } + return http2.ErrCodeNo, nil + } + + // HTTP fallback mode + if d.data.httpErr != nil { + return http2.ErrCodeProtocol, d.data.httpErr + } + + var ( + code = codes.Internal // when header does not include HTTP status, return INTERNAL + ok bool + ) + + if d.data.httpStatus != nil { + code, ok = HTTPStatusConvTab[*(d.data.httpStatus)] + if !ok { + code = codes.Unknown + } + } + + return http2.ErrCodeProtocol, status.Error(code, d.constructHTTPErrMsg()) +} + +// constructErrMsg constructs error message to be returned in HTTP fallback mode. +// Format: HTTP status code and its corresponding message + content-type error message. +func (d *decodeState) constructHTTPErrMsg() string { + var errMsgs []string + + if d.data.httpStatus == nil { + errMsgs = append(errMsgs, "malformed header: missing HTTP status") + } else { + errMsgs = append(errMsgs, fmt.Sprintf("%s: HTTP status code %d", http.StatusText(*(d.data.httpStatus)), *d.data.httpStatus)) + } + + if d.data.contentTypeErr == "" { + errMsgs = append(errMsgs, "transport: missing content-type field") + } else { + errMsgs = append(errMsgs, d.data.contentTypeErr) + } + + return strings.Join(errMsgs, "; ") +} + +func (d *decodeState) addMetadata(k, v string) { + if d.data.mdata == nil { + d.data.mdata = make(map[string][]string) + } + d.data.mdata[k] = append(d.data.mdata[k], v) +} + +func (d *decodeState) processHeaderField(f hpack.HeaderField) { + switch f.Name { + case "content-type": + contentSubtype, validContentType := grpcutil.ContentSubtype(f.Value) + if !validContentType { + d.data.contentTypeErr = fmt.Sprintf("transport: received the unexpected content-type %q", f.Value) + return + } + d.data.contentSubtype = contentSubtype + // TODO: do we want to propagate the whole content-type in the metadata, + // or come up with a way to just propagate the content-subtype if it was set? + // ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"} + // in the metadata? + d.addMetadata(f.Name, f.Value) + d.data.isGRPC = true + case "grpc-encoding": + d.data.encoding = f.Value + case "grpc-status": + code, err := strconv.Atoi(f.Value) + if err != nil { + d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status: %v", err) + return + } + d.data.rawStatusCode = &code + case "grpc-message": + d.data.rawStatusMsg = decodeGrpcMessage(f.Value) + case "grpc-status-details-bin": + v, err := decodeBinHeader(f.Value) + if err != nil { + d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) + return + } + s := &spb.Status{} + if err := proto.Unmarshal(v, s); err != nil { + d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) + return + } + d.data.statusGen = status.FromProto(s) + case "grpc-timeout": + d.data.timeoutSet = true + var err error + if d.data.timeout, err = decodeTimeout(f.Value); err != nil { + d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed time-out: %v", err) + } + case ":path": + d.data.method = f.Value + case ":status": + code, err := strconv.Atoi(f.Value) + if err != nil { + d.data.httpErr = status.Errorf(codes.Internal, "transport: malformed http-status: %v", err) + return + } + d.data.httpStatus = &code + case "grpc-tags-bin": + v, err := decodeBinHeader(f.Value) + if err != nil { + d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err) + return + } + d.data.statsTags = v + d.addMetadata(f.Name, string(v)) + case "grpc-trace-bin": + v, err := decodeBinHeader(f.Value) + if err != nil { + d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err) + return + } + d.data.statsTrace = v + d.addMetadata(f.Name, string(v)) + default: + if isReservedHeader(f.Name) && !isWhitelistedHeader(f.Name) { + break + } + v, err := decodeMetadataHeader(f.Name, f.Value) + if err != nil { + if logger.V(logLevel) { + logger.Errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err) + } + return + } + d.addMetadata(f.Name, v) + } +} + +type timeoutUnit uint8 + +const ( + hour timeoutUnit = 'H' + minute timeoutUnit = 'M' + second timeoutUnit = 'S' + millisecond timeoutUnit = 'm' + microsecond timeoutUnit = 'u' + nanosecond timeoutUnit = 'n' +) + +func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) { + switch u { + case hour: + return time.Hour, true + case minute: + return time.Minute, true + case second: + return time.Second, true + case millisecond: + return time.Millisecond, true + case microsecond: + return time.Microsecond, true + case nanosecond: + return time.Nanosecond, true + default: + } + return +} + +func decodeTimeout(s string) (time.Duration, error) { + size := len(s) + if size < 2 { + return 0, fmt.Errorf("transport: timeout string is too short: %q", s) + } + if size > 9 { + // Spec allows for 8 digits plus the unit. + return 0, fmt.Errorf("transport: timeout string is too long: %q", s) + } + unit := timeoutUnit(s[size-1]) + d, ok := timeoutUnitToDuration(unit) + if !ok { + return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s) + } + t, err := strconv.ParseInt(s[:size-1], 10, 64) + if err != nil { + return 0, err + } + const maxHours = math.MaxInt64 / int64(time.Hour) + if d == time.Hour && t > maxHours { + // This timeout would overflow math.MaxInt64; clamp it. + return time.Duration(math.MaxInt64), nil + } + return d * time.Duration(t), nil +} + +const ( + spaceByte = ' ' + tildeByte = '~' + percentByte = '%' +) + +// encodeGrpcMessage is used to encode status code in header field +// "grpc-message". It does percent encoding and also replaces invalid utf-8 +// characters with Unicode replacement character. +// +// It checks to see if each individual byte in msg is an allowable byte, and +// then either percent encoding or passing it through. When percent encoding, +// the byte is converted into hexadecimal notation with a '%' prepended. +func encodeGrpcMessage(msg string) string { + if msg == "" { + return "" + } + lenMsg := len(msg) + for i := 0; i < lenMsg; i++ { + c := msg[i] + if !(c >= spaceByte && c <= tildeByte && c != percentByte) { + return encodeGrpcMessageUnchecked(msg) + } + } + return msg +} + +func encodeGrpcMessageUnchecked(msg string) string { + var buf bytes.Buffer + for len(msg) > 0 { + r, size := utf8.DecodeRuneInString(msg) + for _, b := range []byte(string(r)) { + if size > 1 { + // If size > 1, r is not ascii. Always do percent encoding. + buf.WriteString(fmt.Sprintf("%%%02X", b)) + continue + } + + // The for loop is necessary even if size == 1. r could be + // utf8.RuneError. + // + // fmt.Sprintf("%%%02X", utf8.RuneError) gives "%FFFD". + if b >= spaceByte && b <= tildeByte && b != percentByte { + buf.WriteByte(b) + } else { + buf.WriteString(fmt.Sprintf("%%%02X", b)) + } + } + msg = msg[size:] + } + return buf.String() +} + +// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage. +func decodeGrpcMessage(msg string) string { + if msg == "" { + return "" + } + lenMsg := len(msg) + for i := 0; i < lenMsg; i++ { + if msg[i] == percentByte && i+2 < lenMsg { + return decodeGrpcMessageUnchecked(msg) + } + } + return msg +} + +func decodeGrpcMessageUnchecked(msg string) string { + var buf bytes.Buffer + lenMsg := len(msg) + for i := 0; i < lenMsg; i++ { + c := msg[i] + if c == percentByte && i+2 < lenMsg { + parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8) + if err != nil { + buf.WriteByte(c) + } else { + buf.WriteByte(byte(parsed)) + i += 2 + } + } else { + buf.WriteByte(c) + } + } + return buf.String() +} + +type bufWriter struct { + buf []byte + offset int + batchSize int + conn net.Conn + err error + + onFlush func() +} + +func newBufWriter(conn net.Conn, batchSize int) *bufWriter { + return &bufWriter{ + buf: make([]byte, batchSize*2), + batchSize: batchSize, + conn: conn, + } +} + +func (w *bufWriter) Write(b []byte) (n int, err error) { + if w.err != nil { + return 0, w.err + } + if w.batchSize == 0 { // Buffer has been disabled. + return w.conn.Write(b) + } + for len(b) > 0 { + nn := copy(w.buf[w.offset:], b) + b = b[nn:] + w.offset += nn + n += nn + if w.offset >= w.batchSize { + err = w.Flush() + } + } + return n, err +} + +func (w *bufWriter) Flush() error { + if w.err != nil { + return w.err + } + if w.offset == 0 { + return nil + } + if w.onFlush != nil { + w.onFlush() + } + _, w.err = w.conn.Write(w.buf[:w.offset]) + w.offset = 0 + return w.err +} + +type framer struct { + writer *bufWriter + fr *http2.Framer +} + +func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, maxHeaderListSize uint32) *framer { + if writeBufferSize < 0 { + writeBufferSize = 0 + } + var r io.Reader = conn + if readBufferSize > 0 { + r = bufio.NewReaderSize(r, readBufferSize) + } + w := newBufWriter(conn, writeBufferSize) + f := &framer{ + writer: w, + fr: http2.NewFramer(w, r), + } + f.fr.SetMaxReadFrameSize(http2MaxFrameLen) + // Opt-in to Frame reuse API on framer to reduce garbage. + // Frames aren't safe to read from after a subsequent call to ReadFrame. + f.fr.SetReuseFrames() + f.fr.MaxHeaderListSize = maxHeaderListSize + f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil) + return f +} diff --git a/vendor/google.golang.org/grpc/internal/transport/transport.go b/vendor/google.golang.org/grpc/internal/transport/transport.go new file mode 100644 index 000000000..b74030a96 --- /dev/null +++ b/vendor/google.golang.org/grpc/internal/transport/transport.go @@ -0,0 +1,804 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package transport defines and implements message oriented communication +// channel to complete various transactions (e.g., an RPC). It is meant for +// grpc-internal usage and is not intended to be imported directly by users. +package transport + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" + "google.golang.org/grpc/tap" +) + +const logLevel = 2 + +type bufferPool struct { + pool sync.Pool +} + +func newBufferPool() *bufferPool { + return &bufferPool{ + pool: sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, + }, + } +} + +func (p *bufferPool) get() *bytes.Buffer { + return p.pool.Get().(*bytes.Buffer) +} + +func (p *bufferPool) put(b *bytes.Buffer) { + p.pool.Put(b) +} + +// recvMsg represents the received msg from the transport. All transport +// protocol specific info has been removed. +type recvMsg struct { + buffer *bytes.Buffer + // nil: received some data + // io.EOF: stream is completed. data is nil. + // other non-nil error: transport failure. data is nil. + err error +} + +// recvBuffer is an unbounded channel of recvMsg structs. +// +// Note: recvBuffer differs from buffer.Unbounded only in the fact that it +// holds a channel of recvMsg structs instead of objects implementing "item" +// interface. recvBuffer is written to much more often and using strict recvMsg +// structs helps avoid allocation in "recvBuffer.put" +type recvBuffer struct { + c chan recvMsg + mu sync.Mutex + backlog []recvMsg + err error +} + +func newRecvBuffer() *recvBuffer { + b := &recvBuffer{ + c: make(chan recvMsg, 1), + } + return b +} + +func (b *recvBuffer) put(r recvMsg) { + b.mu.Lock() + if b.err != nil { + b.mu.Unlock() + // An error had occurred earlier, don't accept more + // data or errors. + return + } + b.err = r.err + if len(b.backlog) == 0 { + select { + case b.c <- r: + b.mu.Unlock() + return + default: + } + } + b.backlog = append(b.backlog, r) + b.mu.Unlock() +} + +func (b *recvBuffer) load() { + b.mu.Lock() + if len(b.backlog) > 0 { + select { + case b.c <- b.backlog[0]: + b.backlog[0] = recvMsg{} + b.backlog = b.backlog[1:] + default: + } + } + b.mu.Unlock() +} + +// get returns the channel that receives a recvMsg in the buffer. +// +// Upon receipt of a recvMsg, the caller should call load to send another +// recvMsg onto the channel if there is any. +func (b *recvBuffer) get() <-chan recvMsg { + return b.c +} + +// recvBufferReader implements io.Reader interface to read the data from +// recvBuffer. +type recvBufferReader struct { + closeStream func(error) // Closes the client transport stream with the given error and nil trailer metadata. + ctx context.Context + ctxDone <-chan struct{} // cache of ctx.Done() (for performance). + recv *recvBuffer + last *bytes.Buffer // Stores the remaining data in the previous calls. + err error + freeBuffer func(*bytes.Buffer) +} + +// Read reads the next len(p) bytes from last. If last is drained, it tries to +// read additional data from recv. It blocks if there no additional data available +// in recv. If Read returns any non-nil error, it will continue to return that error. +func (r *recvBufferReader) Read(p []byte) (n int, err error) { + if r.err != nil { + return 0, r.err + } + if r.last != nil { + // Read remaining data left in last call. + copied, _ := r.last.Read(p) + if r.last.Len() == 0 { + r.freeBuffer(r.last) + r.last = nil + } + return copied, nil + } + if r.closeStream != nil { + n, r.err = r.readClient(p) + } else { + n, r.err = r.read(p) + } + return n, r.err +} + +func (r *recvBufferReader) read(p []byte) (n int, err error) { + select { + case <-r.ctxDone: + return 0, ContextErr(r.ctx.Err()) + case m := <-r.recv.get(): + return r.readAdditional(m, p) + } +} + +func (r *recvBufferReader) readClient(p []byte) (n int, err error) { + // If the context is canceled, then closes the stream with nil metadata. + // closeStream writes its error parameter to r.recv as a recvMsg. + // r.readAdditional acts on that message and returns the necessary error. + select { + case <-r.ctxDone: + // Note that this adds the ctx error to the end of recv buffer, and + // reads from the head. This will delay the error until recv buffer is + // empty, thus will delay ctx cancellation in Recv(). + // + // It's done this way to fix a race between ctx cancel and trailer. The + // race was, stream.Recv() may return ctx error if ctxDone wins the + // race, but stream.Trailer() may return a non-nil md because the stream + // was not marked as done when trailer is received. This closeStream + // call will mark stream as done, thus fix the race. + // + // TODO: delaying ctx error seems like a unnecessary side effect. What + // we really want is to mark the stream as done, and return ctx error + // faster. + r.closeStream(ContextErr(r.ctx.Err())) + m := <-r.recv.get() + return r.readAdditional(m, p) + case m := <-r.recv.get(): + return r.readAdditional(m, p) + } +} + +func (r *recvBufferReader) readAdditional(m recvMsg, p []byte) (n int, err error) { + r.recv.load() + if m.err != nil { + return 0, m.err + } + copied, _ := m.buffer.Read(p) + if m.buffer.Len() == 0 { + r.freeBuffer(m.buffer) + r.last = nil + } else { + r.last = m.buffer + } + return copied, nil +} + +type streamState uint32 + +const ( + streamActive streamState = iota + streamWriteDone // EndStream sent + streamReadDone // EndStream received + streamDone // the entire stream is finished. +) + +// Stream represents an RPC in the transport layer. +type Stream struct { + id uint32 + st ServerTransport // nil for client side Stream + ct *http2Client // nil for server side Stream + ctx context.Context // the associated context of the stream + cancel context.CancelFunc // always nil for client side Stream + done chan struct{} // closed at the end of stream to unblock writers. On the client side. + ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance) + method string // the associated RPC method of the stream + recvCompress string + sendCompress string + buf *recvBuffer + trReader io.Reader + fc *inFlow + wq *writeQuota + + // Callback to state application's intentions to read data. This + // is used to adjust flow control, if needed. + requestRead func(int) + + headerChan chan struct{} // closed to indicate the end of header metadata. + headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times. + // headerValid indicates whether a valid header was received. Only + // meaningful after headerChan is closed (always call waitOnHeader() before + // reading its value). Not valid on server side. + headerValid bool + + // hdrMu protects header and trailer metadata on the server-side. + hdrMu sync.Mutex + // On client side, header keeps the received header metadata. + // + // On server side, header keeps the header set by SetHeader(). The complete + // header will merged into this after t.WriteHeader() is called. + header metadata.MD + trailer metadata.MD // the key-value map of trailer metadata. + + noHeaders bool // set if the client never received headers (set only after the stream is done). + + // On the server-side, headerSent is atomically set to 1 when the headers are sent out. + headerSent uint32 + + state streamState + + // On client-side it is the status error received from the server. + // On server-side it is unused. + status *status.Status + + bytesReceived uint32 // indicates whether any bytes have been received on this stream + unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream + + // contentSubtype is the content-subtype for requests. + // this must be lowercase or the behavior is undefined. + contentSubtype string +} + +// isHeaderSent is only valid on the server-side. +func (s *Stream) isHeaderSent() bool { + return atomic.LoadUint32(&s.headerSent) == 1 +} + +// updateHeaderSent updates headerSent and returns true +// if it was alreay set. It is valid only on server-side. +func (s *Stream) updateHeaderSent() bool { + return atomic.SwapUint32(&s.headerSent, 1) == 1 +} + +func (s *Stream) swapState(st streamState) streamState { + return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st))) +} + +func (s *Stream) compareAndSwapState(oldState, newState streamState) bool { + return atomic.CompareAndSwapUint32((*uint32)(&s.state), uint32(oldState), uint32(newState)) +} + +func (s *Stream) getState() streamState { + return streamState(atomic.LoadUint32((*uint32)(&s.state))) +} + +func (s *Stream) waitOnHeader() { + if s.headerChan == nil { + // On the server headerChan is always nil since a stream originates + // only after having received headers. + return + } + select { + case <-s.ctx.Done(): + // Close the stream to prevent headers/trailers from changing after + // this function returns. + s.ct.CloseStream(s, ContextErr(s.ctx.Err())) + // headerChan could possibly not be closed yet if closeStream raced + // with operateHeaders; wait until it is closed explicitly here. + <-s.headerChan + case <-s.headerChan: + } +} + +// RecvCompress returns the compression algorithm applied to the inbound +// message. It is empty string if there is no compression applied. +func (s *Stream) RecvCompress() string { + s.waitOnHeader() + return s.recvCompress +} + +// SetSendCompress sets the compression algorithm to the stream. +func (s *Stream) SetSendCompress(str string) { + s.sendCompress = str +} + +// Done returns a channel which is closed when it receives the final status +// from the server. +func (s *Stream) Done() <-chan struct{} { + return s.done +} + +// Header returns the header metadata of the stream. +// +// On client side, it acquires the key-value pairs of header metadata once it is +// available. It blocks until i) the metadata is ready or ii) there is no header +// metadata or iii) the stream is canceled/expired. +// +// On server side, it returns the out header after t.WriteHeader is called. It +// does not block and must not be called until after WriteHeader. +func (s *Stream) Header() (metadata.MD, error) { + if s.headerChan == nil { + // On server side, return the header in stream. It will be the out + // header after t.WriteHeader is called. + return s.header.Copy(), nil + } + s.waitOnHeader() + if !s.headerValid { + return nil, s.status.Err() + } + return s.header.Copy(), nil +} + +// TrailersOnly blocks until a header or trailers-only frame is received and +// then returns true if the stream was trailers-only. If the stream ends +// before headers are received, returns true, nil. Client-side only. +func (s *Stream) TrailersOnly() bool { + s.waitOnHeader() + return s.noHeaders +} + +// Trailer returns the cached trailer metedata. Note that if it is not called +// after the entire stream is done, it could return an empty MD. Client +// side only. +// It can be safely read only after stream has ended that is either read +// or write have returned io.EOF. +func (s *Stream) Trailer() metadata.MD { + c := s.trailer.Copy() + return c +} + +// ContentSubtype returns the content-subtype for a request. For example, a +// content-subtype of "proto" will result in a content-type of +// "application/grpc+proto". This will always be lowercase. See +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for +// more details. +func (s *Stream) ContentSubtype() string { + return s.contentSubtype +} + +// Context returns the context of the stream. +func (s *Stream) Context() context.Context { + return s.ctx +} + +// Method returns the method for the stream. +func (s *Stream) Method() string { + return s.method +} + +// Status returns the status received from the server. +// Status can be read safely only after the stream has ended, +// that is, after Done() is closed. +func (s *Stream) Status() *status.Status { + return s.status +} + +// SetHeader sets the header metadata. This can be called multiple times. +// Server side only. +// This should not be called in parallel to other data writes. +func (s *Stream) SetHeader(md metadata.MD) error { + if md.Len() == 0 { + return nil + } + if s.isHeaderSent() || s.getState() == streamDone { + return ErrIllegalHeaderWrite + } + s.hdrMu.Lock() + s.header = metadata.Join(s.header, md) + s.hdrMu.Unlock() + return nil +} + +// SendHeader sends the given header metadata. The given metadata is +// combined with any metadata set by previous calls to SetHeader and +// then written to the transport stream. +func (s *Stream) SendHeader(md metadata.MD) error { + return s.st.WriteHeader(s, md) +} + +// SetTrailer sets the trailer metadata which will be sent with the RPC status +// by the server. This can be called multiple times. Server side only. +// This should not be called parallel to other data writes. +func (s *Stream) SetTrailer(md metadata.MD) error { + if md.Len() == 0 { + return nil + } + if s.getState() == streamDone { + return ErrIllegalHeaderWrite + } + s.hdrMu.Lock() + s.trailer = metadata.Join(s.trailer, md) + s.hdrMu.Unlock() + return nil +} + +func (s *Stream) write(m recvMsg) { + s.buf.put(m) +} + +// Read reads all p bytes from the wire for this stream. +func (s *Stream) Read(p []byte) (n int, err error) { + // Don't request a read if there was an error earlier + if er := s.trReader.(*transportReader).er; er != nil { + return 0, er + } + s.requestRead(len(p)) + return io.ReadFull(s.trReader, p) +} + +// tranportReader reads all the data available for this Stream from the transport and +// passes them into the decoder, which converts them into a gRPC message stream. +// The error is io.EOF when the stream is done or another non-nil error if +// the stream broke. +type transportReader struct { + reader io.Reader + // The handler to control the window update procedure for both this + // particular stream and the associated transport. + windowHandler func(int) + er error +} + +func (t *transportReader) Read(p []byte) (n int, err error) { + n, err = t.reader.Read(p) + if err != nil { + t.er = err + return + } + t.windowHandler(n) + return +} + +// BytesReceived indicates whether any bytes have been received on this stream. +func (s *Stream) BytesReceived() bool { + return atomic.LoadUint32(&s.bytesReceived) == 1 +} + +// Unprocessed indicates whether the server did not process this stream -- +// i.e. it sent a refused stream or GOAWAY including this stream ID. +func (s *Stream) Unprocessed() bool { + return atomic.LoadUint32(&s.unprocessed) == 1 +} + +// GoString is implemented by Stream so context.String() won't +// race when printing %#v. +func (s *Stream) GoString() string { + return fmt.Sprintf("<stream: %p, %v>", s, s.method) +} + +// state of transport +type transportState int + +const ( + reachable transportState = iota + closing + draining +) + +// ServerConfig consists of all the configurations to establish a server transport. +type ServerConfig struct { + MaxStreams uint32 + AuthInfo credentials.AuthInfo + InTapHandle tap.ServerInHandle + StatsHandler stats.Handler + KeepaliveParams keepalive.ServerParameters + KeepalivePolicy keepalive.EnforcementPolicy + InitialWindowSize int32 + InitialConnWindowSize int32 + WriteBufferSize int + ReadBufferSize int + ChannelzParentID int64 + MaxHeaderListSize *uint32 + HeaderTableSize *uint32 +} + +// NewServerTransport creates a ServerTransport with conn or non-nil error +// if it fails. +func NewServerTransport(protocol string, conn net.Conn, config *ServerConfig) (ServerTransport, error) { + return newHTTP2Server(conn, config) +} + +// ConnectOptions covers all relevant options for communicating with the server. +type ConnectOptions struct { + // UserAgent is the application user agent. + UserAgent string + // Dialer specifies how to dial a network address. + Dialer func(context.Context, string) (net.Conn, error) + // FailOnNonTempDialError specifies if gRPC fails on non-temporary dial errors. + FailOnNonTempDialError bool + // PerRPCCredentials stores the PerRPCCredentials required to issue RPCs. + PerRPCCredentials []credentials.PerRPCCredentials + // TransportCredentials stores the Authenticator required to setup a client + // connection. Only one of TransportCredentials and CredsBundle is non-nil. + TransportCredentials credentials.TransportCredentials + // CredsBundle is the credentials bundle to be used. Only one of + // TransportCredentials and CredsBundle is non-nil. + CredsBundle credentials.Bundle + // KeepaliveParams stores the keepalive parameters. + KeepaliveParams keepalive.ClientParameters + // StatsHandler stores the handler for stats. + StatsHandler stats.Handler + // InitialWindowSize sets the initial window size for a stream. + InitialWindowSize int32 + // InitialConnWindowSize sets the initial window size for a connection. + InitialConnWindowSize int32 + // WriteBufferSize sets the size of write buffer which in turn determines how much data can be batched before it's written on the wire. + WriteBufferSize int + // ReadBufferSize sets the size of read buffer, which in turn determines how much data can be read at most for one read syscall. + ReadBufferSize int + // ChannelzParentID sets the addrConn id which initiate the creation of this client transport. + ChannelzParentID int64 + // MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received. + MaxHeaderListSize *uint32 +} + +// NewClientTransport establishes the transport with the required ConnectOptions +// and returns it to the caller. +func NewClientTransport(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onPrefaceReceipt func(), onGoAway func(GoAwayReason), onClose func()) (ClientTransport, error) { + return newHTTP2Client(connectCtx, ctx, addr, opts, onPrefaceReceipt, onGoAway, onClose) +} + +// Options provides additional hints and information for message +// transmission. +type Options struct { + // Last indicates whether this write is the last piece for + // this stream. + Last bool +} + +// CallHdr carries the information of a particular RPC. +type CallHdr struct { + // Host specifies the peer's host. + Host string + + // Method specifies the operation to perform. + Method string + + // SendCompress specifies the compression algorithm applied on + // outbound message. + SendCompress string + + // Creds specifies credentials.PerRPCCredentials for a call. + Creds credentials.PerRPCCredentials + + // ContentSubtype specifies the content-subtype for a request. For example, a + // content-subtype of "proto" will result in a content-type of + // "application/grpc+proto". The value of ContentSubtype must be all + // lowercase, otherwise the behavior is undefined. See + // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests + // for more details. + ContentSubtype string + + PreviousAttempts int // value of grpc-previous-rpc-attempts header to set +} + +// ClientTransport is the common interface for all gRPC client-side transport +// implementations. +type ClientTransport interface { + // Close tears down this transport. Once it returns, the transport + // should not be accessed any more. The caller must make sure this + // is called only once. + Close() error + + // GracefulClose starts to tear down the transport: the transport will stop + // accepting new RPCs and NewStream will return error. Once all streams are + // finished, the transport will close. + // + // It does not block. + GracefulClose() + + // Write sends the data for the given stream. A nil stream indicates + // the write is to be performed on the transport as a whole. + Write(s *Stream, hdr []byte, data []byte, opts *Options) error + + // NewStream creates a Stream for an RPC. + NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) + + // CloseStream clears the footprint of a stream when the stream is + // not needed any more. The err indicates the error incurred when + // CloseStream is called. Must be called when a stream is finished + // unless the associated transport is closing. + CloseStream(stream *Stream, err error) + + // Error returns a channel that is closed when some I/O error + // happens. Typically the caller should have a goroutine to monitor + // this in order to take action (e.g., close the current transport + // and create a new one) in error case. It should not return nil + // once the transport is initiated. + Error() <-chan struct{} + + // GoAway returns a channel that is closed when ClientTransport + // receives the draining signal from the server (e.g., GOAWAY frame in + // HTTP/2). + GoAway() <-chan struct{} + + // GetGoAwayReason returns the reason why GoAway frame was received. + GetGoAwayReason() GoAwayReason + + // RemoteAddr returns the remote network address. + RemoteAddr() net.Addr + + // IncrMsgSent increments the number of message sent through this transport. + IncrMsgSent() + + // IncrMsgRecv increments the number of message received through this transport. + IncrMsgRecv() +} + +// ServerTransport is the common interface for all gRPC server-side transport +// implementations. +// +// Methods may be called concurrently from multiple goroutines, but +// Write methods for a given Stream will be called serially. +type ServerTransport interface { + // HandleStreams receives incoming streams using the given handler. + HandleStreams(func(*Stream), func(context.Context, string) context.Context) + + // WriteHeader sends the header metadata for the given stream. + // WriteHeader may not be called on all streams. + WriteHeader(s *Stream, md metadata.MD) error + + // Write sends the data for the given stream. + // Write may not be called on all streams. + Write(s *Stream, hdr []byte, data []byte, opts *Options) error + + // WriteStatus sends the status of a stream to the client. WriteStatus is + // the final call made on a stream and always occurs. + WriteStatus(s *Stream, st *status.Status) error + + // Close tears down the transport. Once it is called, the transport + // should not be accessed any more. All the pending streams and their + // handlers will be terminated asynchronously. + Close() error + + // RemoteAddr returns the remote network address. + RemoteAddr() net.Addr + + // Drain notifies the client this ServerTransport stops accepting new RPCs. + Drain() + + // IncrMsgSent increments the number of message sent through this transport. + IncrMsgSent() + + // IncrMsgRecv increments the number of message received through this transport. + IncrMsgRecv() +} + +// connectionErrorf creates an ConnectionError with the specified error description. +func connectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError { + return ConnectionError{ + Desc: fmt.Sprintf(format, a...), + temp: temp, + err: e, + } +} + +// ConnectionError is an error that results in the termination of the +// entire connection and the retry of all the active streams. +type ConnectionError struct { + Desc string + temp bool + err error +} + +func (e ConnectionError) Error() string { + return fmt.Sprintf("connection error: desc = %q", e.Desc) +} + +// Temporary indicates if this connection error is temporary or fatal. +func (e ConnectionError) Temporary() bool { + return e.temp +} + +// Origin returns the original error of this connection error. +func (e ConnectionError) Origin() error { + // Never return nil error here. + // If the original error is nil, return itself. + if e.err == nil { + return e + } + return e.err +} + +var ( + // ErrConnClosing indicates that the transport is closing. + ErrConnClosing = connectionErrorf(true, nil, "transport is closing") + // errStreamDrain indicates that the stream is rejected because the + // connection is draining. This could be caused by goaway or balancer + // removing the address. + errStreamDrain = status.Error(codes.Unavailable, "the connection is draining") + // errStreamDone is returned from write at the client side to indiacte application + // layer of an error. + errStreamDone = errors.New("the stream is done") + // StatusGoAway indicates that the server sent a GOAWAY that included this + // stream's ID in unprocessed RPCs. + statusGoAway = status.New(codes.Unavailable, "the stream is rejected because server is draining the connection") +) + +// GoAwayReason contains the reason for the GoAway frame received. +type GoAwayReason uint8 + +const ( + // GoAwayInvalid indicates that no GoAway frame is received. + GoAwayInvalid GoAwayReason = 0 + // GoAwayNoReason is the default value when GoAway frame is received. + GoAwayNoReason GoAwayReason = 1 + // GoAwayTooManyPings indicates that a GoAway frame with + // ErrCodeEnhanceYourCalm was received and that the debug data said + // "too_many_pings". + GoAwayTooManyPings GoAwayReason = 2 +) + +// channelzData is used to store channelz related data for http2Client and http2Server. +// These fields cannot be embedded in the original structs (e.g. http2Client), since to do atomic +// operation on int64 variable on 32-bit machine, user is responsible to enforce memory alignment. +// Here, by grouping those int64 fields inside a struct, we are enforcing the alignment. +type channelzData struct { + kpCount int64 + // The number of streams that have started, including already finished ones. + streamsStarted int64 + // Client side: The number of streams that have ended successfully by receiving + // EoS bit set frame from server. + // Server side: The number of streams that have ended successfully by sending + // frame with EoS bit set. + streamsSucceeded int64 + streamsFailed int64 + // lastStreamCreatedTime stores the timestamp that the last stream gets created. It is of int64 type + // instead of time.Time since it's more costly to atomically update time.Time variable than int64 + // variable. The same goes for lastMsgSentTime and lastMsgRecvTime. + lastStreamCreatedTime int64 + msgSent int64 + msgRecv int64 + lastMsgSentTime int64 + lastMsgRecvTime int64 +} + +// ContextErr converts the error from context package into a status error. +func ContextErr(err error) error { + switch err { + case context.DeadlineExceeded: + return status.Error(codes.DeadlineExceeded, err.Error()) + case context.Canceled: + return status.Error(codes.Canceled, err.Error()) + } + return status.Errorf(codes.Internal, "Unexpected error from context packet: %v", err) +} |