diff options
Diffstat (limited to 'vendor/github.com/digitalocean/go-libvirt/rpc.go')
-rw-r--r-- | vendor/github.com/digitalocean/go-libvirt/rpc.go | 606 |
1 files changed, 606 insertions, 0 deletions
diff --git a/vendor/github.com/digitalocean/go-libvirt/rpc.go b/vendor/github.com/digitalocean/go-libvirt/rpc.go new file mode 100644 index 000000000..8181e0991 --- /dev/null +++ b/vendor/github.com/digitalocean/go-libvirt/rpc.go @@ -0,0 +1,606 @@ +// Copyright 2018 The go-libvirt Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package libvirt + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "reflect" + "strings" + "sync/atomic" + "unsafe" + + "github.com/digitalocean/go-libvirt/internal/constants" + "github.com/digitalocean/go-libvirt/internal/event" + xdr "github.com/digitalocean/go-libvirt/internal/go-xdr/xdr2" +) + +// ErrUnsupported is returned if a procedure is not supported by libvirt +var ErrUnsupported = errors.New("unsupported procedure requested") + +// request and response types +const ( + // Call is used when making calls to the remote server. + Call = iota + + // Reply indicates a server reply. + Reply + + // Message is an asynchronous notification. + Message + + // Stream represents a stream data packet. + Stream + + // CallWithFDs is used by a client to indicate the request has + // arguments with file descriptors. + CallWithFDs + + // ReplyWithFDs is used by a server to indicate the request has + // arguments with file descriptors. + ReplyWithFDs +) + +// request and response statuses +const ( + // StatusOK is always set for method calls or events. + // For replies it indicates successful completion of the method. + // For streams it indicates confirmation of the end of file on the stream. + StatusOK = iota + + // StatusError for replies indicates that the method call failed + // and error information is being returned. For streams this indicates + // that not all data was sent and the stream has aborted. + StatusError + + // StatusContinue is only used for streams. + // This indicates that further data packets will be following. + StatusContinue +) + +// header is a libvirt rpc packet header +type header struct { + // Program identifier + Program uint32 + + // Program version + Version uint32 + + // Remote procedure identifier + Procedure uint32 + + // Call type, e.g., Reply + Type uint32 + + // Call serial number + Serial int32 + + // Request status, e.g., StatusOK + Status uint32 +} + +// packet represents a RPC request or response. +type packet struct { + // Size of packet, in bytes, including length. + // Len + Header + Payload + Len uint32 + Header header +} + +// Global packet instance, for use with unsafe.Sizeof() +var _p packet + +// internal rpc response +type response struct { + Payload []byte + Status uint32 +} + +// libvirt error response +type libvirtError struct { + Code uint32 + DomainID uint32 + Padding uint8 + Message string + Level uint32 +} + +func (e libvirtError) Error() string { + return e.Message +} + +// checkError is used to check whether an error is a libvirtError, and if it is, +// whether its error code matches the one passed in. It will return false if +// these conditions are not met. +func checkError(err error, expectedError errorNumber) bool { + e, ok := err.(libvirtError) + if ok { + return e.Code == uint32(expectedError) + } + return false +} + +// IsNotFound detects libvirt's ERR_NO_DOMAIN. +func IsNotFound(err error) bool { + return checkError(err, errNoDomain) +} + +// listen processes incoming data and routes +// responses to their respective callback handler. +func (l *Libvirt) listen() { + for { + // response packet length + length, err := pktlen(l.r) + if err != nil { + // When the underlying connection EOFs or is closed, stop + // this goroutine + if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") { + return + } + + // invalid packet + continue + } + + // response header + h, err := extractHeader(l.r) + if err != nil { + // invalid packet + continue + } + + // payload: packet length minus what was previously read + size := int(length) - int(unsafe.Sizeof(_p)) + buf := make([]byte, size) + _, err = io.ReadFull(l.r, buf) + if err != nil { + // invalid packet + continue + } + + // route response to caller + l.route(h, buf) + } +} + +// callback sends RPC responses to respective callers. +func (l *Libvirt) callback(id int32, res response) { + l.cmux.Lock() + defer l.cmux.Unlock() + + c, ok := l.callbacks[id] + if !ok { + return + } + + c <- res +} + +// route sends incoming packets to their listeners. +func (l *Libvirt) route(h *header, buf []byte) { + // route events to their respective listener + var event event.Event + + switch { + case h.Program == constants.QEMUProgram && h.Procedure == constants.QEMUProcDomainMonitorEvent: + event = &DomainEvent{} + case h.Program == constants.Program && h.Procedure == constants.ProcDomainEventCallbackLifecycle: + event = &DomainEventCallbackLifecycleMsg{} + } + + if event != nil { + err := eventDecoder(buf, event) + if err != nil { // event was malformed, drop. + return + } + + l.stream(event) + return + } + + // send response to caller + l.callback(h.Serial, response{Payload: buf, Status: h.Status}) +} + +// serial provides atomic access to the next sequential request serial number. +func (l *Libvirt) serial() int32 { + return atomic.AddInt32(&l.s, 1) +} + +// stream decodes and relays domain events to their respective listener. +func (l *Libvirt) stream(e event.Event) { + l.emux.RLock() + defer l.emux.RUnlock() + + q, ok := l.events[e.GetCallbackID()] + if !ok { + return + } + + q.Push(e) +} + +// addStream configures the routing for an event stream. +func (l *Libvirt) addStream(s *event.Stream) { + l.emux.Lock() + defer l.emux.Unlock() + + l.events[s.CallbackID] = s +} + +// removeStream notifies the libvirt server to stop sending events for the +// provided callback ID. Upon successful de-registration the callback handler +// is destroyed. Subsequent calls to removeStream are idempotent and return +// nil. +// TODO: Fix this comment +func (l *Libvirt) removeStream(id int32) error { + l.emux.Lock() + defer l.emux.Unlock() + + // if the event is already removed, just return nil + _, ok := l.events[id] + if ok { + delete(l.events, id) + } + + return nil +} + +// register configures a method response callback +func (l *Libvirt) register(id int32, c chan response) { + l.cmux.Lock() + defer l.cmux.Unlock() + + l.callbacks[id] = c +} + +// deregister destroys a method response callback. It is the responsibility of +// the caller to manage locking (l.cmux) during this call. +func (l *Libvirt) deregister(id int32) { + _, ok := l.callbacks[id] + if !ok { + return + } + + close(l.callbacks[id]) + delete(l.callbacks, id) +} + +// deregisterAll closes all waiting callback channels. This is used to clean up +// if the connection to libvirt is lost. Callers waiting for responses will +// return an error when the response channel is closed, rather than just +// hanging. +func (l *Libvirt) deregisterAll() { + l.cmux.Lock() + defer l.cmux.Unlock() + + for id := range l.callbacks { + l.deregister(id) + } +} + +// request performs a libvirt RPC request. +// returns response returned by server. +// if response is not OK, decodes error from it and returns it. +func (l *Libvirt) request(proc uint32, program uint32, payload []byte) (response, error) { + return l.requestStream(proc, program, payload, nil, nil) +} + +// requestStream performs a libvirt RPC request. The `out` and `in` parameters +// are optional, and should be nil when RPC endpoints don't return a stream. +func (l *Libvirt) requestStream(proc uint32, program uint32, payload []byte, + out io.Reader, in io.Writer) (response, error) { + serial := l.serial() + c := make(chan response) + + l.register(serial, c) + defer func() { + l.cmux.Lock() + defer l.cmux.Unlock() + + l.deregister(serial) + }() + + err := l.sendPacket(serial, proc, program, payload, Call, StatusOK) + if err != nil { + return response{}, err + } + + resp, err := l.getResponse(c) + if err != nil { + return resp, err + } + + if out != nil { + abort := make(chan bool) + outErr := make(chan error) + go func() { + outErr <- l.sendStream(serial, proc, program, out, abort) + }() + + // Even without incoming stream server sends confirmation once all data is received + resp, err = l.processIncomingStream(c, in) + if err != nil { + abort <- true + return resp, err + } + + err = <-outErr + if err != nil { + return response{}, err + } + } + + switch in { + case nil: + return resp, nil + default: + return l.processIncomingStream(c, in) + } +} + +// processIncomingStream is called once we've successfully sent a request to +// libvirt. It writes the responses back to the stream passed by the caller +// until libvirt sends a packet with statusOK or an error. +func (l *Libvirt) processIncomingStream(c chan response, inStream io.Writer) (response, error) { + for { + resp, err := l.getResponse(c) + if err != nil { + return resp, err + } + + // StatusOK indicates end of stream + if resp.Status == StatusOK { + return resp, nil + } + + // FIXME: this smells. + // StatusError is handled in getResponse, so this must be StatusContinue + // StatusContinue is only valid here for stream packets + // libvirtd breaks protocol and returns StatusContinue with an + // empty response Payload when the stream finishes + if len(resp.Payload) == 0 { + return resp, nil + } + if inStream != nil { + _, err = inStream.Write(resp.Payload) + if err != nil { + return response{}, err + } + } + } +} + +func (l *Libvirt) sendStream(serial int32, proc uint32, program uint32, stream io.Reader, abort chan bool) error { + // Keep total packet length under 4 MiB to follow possible limitation in libvirt server code + buf := make([]byte, 4*MiB-unsafe.Sizeof(_p)) + for { + select { + case <-abort: + return l.sendPacket(serial, proc, program, nil, Stream, StatusError) + default: + } + n, err := stream.Read(buf) + if n > 0 { + err2 := l.sendPacket(serial, proc, program, buf[:n], Stream, StatusContinue) + if err2 != nil { + return err2 + } + } + if err != nil { + if err == io.EOF { + return l.sendPacket(serial, proc, program, nil, Stream, StatusOK) + } + // keep original error + err2 := l.sendPacket(serial, proc, program, nil, Stream, StatusError) + if err2 != nil { + return err2 + } + return err + } + } +} + +func (l *Libvirt) sendPacket(serial int32, proc uint32, program uint32, payload []byte, typ uint32, status uint32) error { + + p := packet{ + Header: header{ + Program: program, + Version: constants.ProtocolVersion, + Procedure: proc, + Type: typ, + Serial: serial, + Status: status, + }, + } + + size := int(unsafe.Sizeof(p.Len)) + int(unsafe.Sizeof(p.Header)) + if payload != nil { + size += len(payload) + } + p.Len = uint32(size) + + // write header + l.mu.Lock() + defer l.mu.Unlock() + err := binary.Write(l.w, binary.BigEndian, p) + if err != nil { + return err + } + + // write payload + if payload != nil { + err = binary.Write(l.w, binary.BigEndian, payload) + if err != nil { + return err + } + } + + return l.w.Flush() +} + +func (l *Libvirt) getResponse(c chan response) (response, error) { + resp := <-c + if resp.Status == StatusError { + return resp, decodeError(resp.Payload) + } + + return resp, nil +} + +// encode XDR encodes the provided data. +func encode(data interface{}) ([]byte, error) { + var buf bytes.Buffer + _, err := xdr.Marshal(&buf, data) + + return buf.Bytes(), err +} + +// decodeError extracts an error message from the provider buffer. +func decodeError(buf []byte) error { + var e libvirtError + + dec := xdr.NewDecoder(bytes.NewReader(buf)) + _, err := dec.Decode(&e) + if err != nil { + return err + } + + if strings.Contains(e.Message, "unknown procedure") { + return ErrUnsupported + } + // if libvirt returns ERR_OK, ignore the error + if checkError(e, errOk) { + return nil + } + + return e +} + +// eventDecoder decodes an event from a xdr buffer. +func eventDecoder(buf []byte, e interface{}) error { + dec := xdr.NewDecoder(bytes.NewReader(buf)) + _, err := dec.Decode(e) + return err +} + +// pktlen returns the length of an incoming RPC packet. Read errors will +// result in a returned response length of 0 and a non-nil error. +func pktlen(r io.Reader) (uint32, error) { + buf := make([]byte, unsafe.Sizeof(_p.Len)) + + // extract the packet's length from the header + _, err := io.ReadFull(r, buf) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint32(buf), nil +} + +// extractHeader returns the decoded header from an incoming response. +func extractHeader(r io.Reader) (*header, error) { + buf := make([]byte, unsafe.Sizeof(_p.Header)) + + // extract the packet's header from r + _, err := io.ReadFull(r, buf) + if err != nil { + return nil, err + } + + return &header{ + Program: binary.BigEndian.Uint32(buf[0:4]), + Version: binary.BigEndian.Uint32(buf[4:8]), + Procedure: binary.BigEndian.Uint32(buf[8:12]), + Type: binary.BigEndian.Uint32(buf[12:16]), + Serial: int32(binary.BigEndian.Uint32(buf[16:20])), + Status: binary.BigEndian.Uint32(buf[20:24]), + }, nil +} + +type typedParamDecoder struct{} + +// Decode decodes a TypedParam. These are part of the libvirt spec, and not xdr +// proper. TypedParams contain a name, which is called Field for some reason, +// and a Value, which itself has a "discriminant" - an integer enum encoding the +// actual type, and a value, the length of which varies based on the actual +// type. +func (tpd typedParamDecoder) Decode(d *xdr.Decoder, v reflect.Value) (int, error) { + // Get the name of the typed param first + name, n, err := d.DecodeString() + if err != nil { + return n, err + } + val, n2, err := tpd.decodeTypedParamValue(d) + n += n2 + if err != nil { + return n, err + } + tp := &TypedParam{Field: name, Value: *val} + v.Set(reflect.ValueOf(*tp)) + + return n, nil +} + +// decodeTypedParamValue decodes the Value part of a TypedParam. +func (typedParamDecoder) decodeTypedParamValue(d *xdr.Decoder) (*TypedParamValue, int, error) { + // All TypedParamValues begin with a uint32 discriminant that tells us what + // type they are. + discriminant, n, err := d.DecodeUint() + if err != nil { + return nil, n, err + } + var n2 int + var tpv *TypedParamValue + switch discriminant { + case 1: + var val int32 + n2, err = d.Decode(&val) + tpv = &TypedParamValue{D: discriminant, I: val} + case 2: + var val uint32 + n2, err = d.Decode(&val) + tpv = &TypedParamValue{D: discriminant, I: val} + case 3: + var val int64 + n2, err = d.Decode(&val) + tpv = &TypedParamValue{D: discriminant, I: val} + case 4: + var val uint64 + n2, err = d.Decode(&val) + tpv = &TypedParamValue{D: discriminant, I: val} + case 5: + var val float64 + n2, err = d.Decode(&val) + tpv = &TypedParamValue{D: discriminant, I: val} + case 6: + var val int32 + n2, err = d.Decode(&val) + tpv = &TypedParamValue{D: discriminant, I: val} + case 7: + var val string + n2, err = d.Decode(&val) + tpv = &TypedParamValue{D: discriminant, I: val} + + default: + err = fmt.Errorf("invalid parameter type %v", discriminant) + } + n += n2 + + return tpv, n, err +} |