diff options
Diffstat (limited to 'pkg/varlinkapi/virtwriter/virtwriter.go')
-rw-r--r-- | pkg/varlinkapi/virtwriter/virtwriter.go | 141 |
1 files changed, 87 insertions, 54 deletions
diff --git a/pkg/varlinkapi/virtwriter/virtwriter.go b/pkg/varlinkapi/virtwriter/virtwriter.go index 3adaf6e17..27ecd1f52 100644 --- a/pkg/varlinkapi/virtwriter/virtwriter.go +++ b/pkg/varlinkapi/virtwriter/virtwriter.go @@ -4,10 +4,9 @@ import ( "bufio" "encoding/binary" "encoding/json" - "errors" "io" - "os" + "github.com/pkg/errors" "k8s.io/client-go/tools/remotecommand" ) @@ -90,66 +89,100 @@ func (v VirtWriteCloser) Write(input []byte) (int, error) { } // Reader decodes the content that comes over the wire and directs it to the proper destination. -func Reader(r *bufio.Reader, output, errput *os.File, input *io.PipeWriter, resize chan remotecommand.TerminalSize) error { - var saveb []byte - var eom int +func Reader(r *bufio.Reader, output, errput, input io.Writer, resize chan remotecommand.TerminalSize, execEcChan chan int) error { + var messageSize int64 + headerBytes := make([]byte, 8) + + if r == nil { + return errors.Errorf("Reader must not be nil") + } + for { - readb := make([]byte, 32*1024) - n, err := r.Read(readb) - // TODO, later may be worth checking in len of the read is 0 + n, err := io.ReadFull(r, headerBytes) if err != nil { - return err + return errors.Wrapf(err, "Virtual Read failed, %d", n) } - b := append(saveb, readb[0:n]...) - // no sense in reading less than the header len - for len(b) > 7 { - eom = int(binary.BigEndian.Uint32(b[4:8])) + 8 - // The message and header are togther - if len(b) >= eom { - out := append([]byte{}, b[8:eom]...) - - switch IntToSocketDest(int(b[0])) { - case ToStdout: - n, err := output.Write(out) - if err != nil { - return err - } - if n < len(out) { - return errors.New("short write error occurred on stdout") - } - case ToStderr: - n, err := errput.Write(out) - if err != nil { - return err - } - if n < len(out) { - return errors.New("short write error occurred on stderr") - } - case ToStdin: - n, err := input.Write(out) + if n < 8 { + return errors.New("short read and no full header read") + } + + messageSize = int64(binary.BigEndian.Uint32(headerBytes[4:8])) + + switch IntToSocketDest(int(headerBytes[0])) { + case ToStdout: + if output != nil { + _, err := io.CopyN(output, r, messageSize) + if err != nil { + return err + } + } + case ToStderr: + if errput != nil { + _, err := io.CopyN(errput, r, messageSize) + if err != nil { + return err + } + } + case ToStdin: + if input != nil { + _, err := io.CopyN(input, r, messageSize) + if err != nil { + return err + } + } + case TerminalResize: + if resize != nil { + out := make([]byte, messageSize) + if messageSize > 0 { + _, err = io.ReadFull(r, out) + if err != nil { return err } - if n < len(out) { - return errors.New("short write error occurred on stdin") - } - case TerminalResize: - // Resize events come over in bytes, need to be reserialized - resizeEvent := remotecommand.TerminalSize{} - if err := json.Unmarshal(out, &resizeEvent); err != nil { - return err - } - resize <- resizeEvent - case Quit: - return nil } - b = b[eom:] - } else { - // We do not have the header and full message, need to slurp again - saveb = b - break + // Resize events come over in bytes, need to be reserialized + resizeEvent := remotecommand.TerminalSize{} + if err := json.Unmarshal(out, &resizeEvent); err != nil { + return err + } + resize <- resizeEvent + } + case Quit: + out := make([]byte, messageSize) + if messageSize > 0 { + _, err = io.ReadFull(r, out) + + if err != nil { + return err + } + } + if execEcChan != nil { + ecInt := binary.BigEndian.Uint32(out) + execEcChan <- int(ecInt) } + return nil + + default: + // Something really went wrong + return errors.New("unknown multiplex destination") } } - return nil +} + +// HangUp sends message to peer to close connection +func HangUp(writer *bufio.Writer, ec uint32) (err error) { + n := 0 + msg := make([]byte, 4) + + binary.BigEndian.PutUint32(msg, ec) + + writeQuit := NewVirtWriteCloser(writer, Quit) + if n, err = writeQuit.Write(msg); err != nil { + return + } + + if n != len(msg) { + return errors.Errorf("Failed to send complete %s message", string(msg)) + } + return } |