summaryrefslogtreecommitdiff
path: root/vendor/github.com/Microsoft/go-winio/hvsock.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/Microsoft/go-winio/hvsock.go')
-rw-r--r--vendor/github.com/Microsoft/go-winio/hvsock.go305
1 files changed, 305 insertions, 0 deletions
diff --git a/vendor/github.com/Microsoft/go-winio/hvsock.go b/vendor/github.com/Microsoft/go-winio/hvsock.go
new file mode 100644
index 000000000..dbfe790ee
--- /dev/null
+++ b/vendor/github.com/Microsoft/go-winio/hvsock.go
@@ -0,0 +1,305 @@
+package winio
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "syscall"
+ "time"
+ "unsafe"
+
+ "github.com/Microsoft/go-winio/pkg/guid"
+)
+
+//sys bind(s syscall.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
+
+const (
+ afHvSock = 34 // AF_HYPERV
+
+ socketError = ^uintptr(0)
+)
+
+// An HvsockAddr is an address for a AF_HYPERV socket.
+type HvsockAddr struct {
+ VMID guid.GUID
+ ServiceID guid.GUID
+}
+
+type rawHvsockAddr struct {
+ Family uint16
+ _ uint16
+ VMID guid.GUID
+ ServiceID guid.GUID
+}
+
+// Network returns the address's network name, "hvsock".
+func (addr *HvsockAddr) Network() string {
+ return "hvsock"
+}
+
+func (addr *HvsockAddr) String() string {
+ return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID)
+}
+
+// VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port.
+func VsockServiceID(port uint32) guid.GUID {
+ g, _ := guid.FromString("00000000-facb-11e6-bd58-64006a7986d3")
+ g.Data1 = port
+ return g
+}
+
+func (addr *HvsockAddr) raw() rawHvsockAddr {
+ return rawHvsockAddr{
+ Family: afHvSock,
+ VMID: addr.VMID,
+ ServiceID: addr.ServiceID,
+ }
+}
+
+func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
+ addr.VMID = raw.VMID
+ addr.ServiceID = raw.ServiceID
+}
+
+// HvsockListener is a socket listener for the AF_HYPERV address family.
+type HvsockListener struct {
+ sock *win32File
+ addr HvsockAddr
+}
+
+// HvsockConn is a connected socket of the AF_HYPERV address family.
+type HvsockConn struct {
+ sock *win32File
+ local, remote HvsockAddr
+}
+
+func newHvSocket() (*win32File, error) {
+ fd, err := syscall.Socket(afHvSock, syscall.SOCK_STREAM, 1)
+ if err != nil {
+ return nil, os.NewSyscallError("socket", err)
+ }
+ f, err := makeWin32File(fd)
+ if err != nil {
+ syscall.Close(fd)
+ return nil, err
+ }
+ f.socket = true
+ return f, nil
+}
+
+// ListenHvsock listens for connections on the specified hvsock address.
+func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
+ l := &HvsockListener{addr: *addr}
+ sock, err := newHvSocket()
+ if err != nil {
+ return nil, l.opErr("listen", err)
+ }
+ sa := addr.raw()
+ err = bind(sock.handle, unsafe.Pointer(&sa), int32(unsafe.Sizeof(sa)))
+ if err != nil {
+ return nil, l.opErr("listen", os.NewSyscallError("socket", err))
+ }
+ err = syscall.Listen(sock.handle, 16)
+ if err != nil {
+ return nil, l.opErr("listen", os.NewSyscallError("listen", err))
+ }
+ return &HvsockListener{sock: sock, addr: *addr}, nil
+}
+
+func (l *HvsockListener) opErr(op string, err error) error {
+ return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err}
+}
+
+// Addr returns the listener's network address.
+func (l *HvsockListener) Addr() net.Addr {
+ return &l.addr
+}
+
+// Accept waits for the next connection and returns it.
+func (l *HvsockListener) Accept() (_ net.Conn, err error) {
+ sock, err := newHvSocket()
+ if err != nil {
+ return nil, l.opErr("accept", err)
+ }
+ defer func() {
+ if sock != nil {
+ sock.Close()
+ }
+ }()
+ c, err := l.sock.prepareIo()
+ if err != nil {
+ return nil, l.opErr("accept", err)
+ }
+ defer l.sock.wg.Done()
+
+ // AcceptEx, per documentation, requires an extra 16 bytes per address.
+ const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{}))
+ var addrbuf [addrlen * 2]byte
+
+ var bytes uint32
+ err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0, addrlen, addrlen, &bytes, &c.o)
+ _, err = l.sock.asyncIo(c, nil, bytes, err)
+ if err != nil {
+ return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
+ }
+ conn := &HvsockConn{
+ sock: sock,
+ }
+ conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0])))
+ conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))
+ sock = nil
+ return conn, nil
+}
+
+// Close closes the listener, causing any pending Accept calls to fail.
+func (l *HvsockListener) Close() error {
+ return l.sock.Close()
+}
+
+/* Need to finish ConnectEx handling
+func DialHvsock(ctx context.Context, addr *HvsockAddr) (*HvsockConn, error) {
+ sock, err := newHvSocket()
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ if sock != nil {
+ sock.Close()
+ }
+ }()
+ c, err := sock.prepareIo()
+ if err != nil {
+ return nil, err
+ }
+ defer sock.wg.Done()
+ var bytes uint32
+ err = windows.ConnectEx(windows.Handle(sock.handle), sa, nil, 0, &bytes, &c.o)
+ _, err = sock.asyncIo(ctx, c, nil, bytes, err)
+ if err != nil {
+ return nil, err
+ }
+ conn := &HvsockConn{
+ sock: sock,
+ remote: *addr,
+ }
+ sock = nil
+ return conn, nil
+}
+*/
+
+func (conn *HvsockConn) opErr(op string, err error) error {
+ return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
+}
+
+func (conn *HvsockConn) Read(b []byte) (int, error) {
+ c, err := conn.sock.prepareIo()
+ if err != nil {
+ return 0, conn.opErr("read", err)
+ }
+ defer conn.sock.wg.Done()
+ buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
+ var flags, bytes uint32
+ err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
+ n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err)
+ if err != nil {
+ if _, ok := err.(syscall.Errno); ok {
+ err = os.NewSyscallError("wsarecv", err)
+ }
+ return 0, conn.opErr("read", err)
+ } else if n == 0 {
+ err = io.EOF
+ }
+ return n, err
+}
+
+func (conn *HvsockConn) Write(b []byte) (int, error) {
+ t := 0
+ for len(b) != 0 {
+ n, err := conn.write(b)
+ if err != nil {
+ return t + n, err
+ }
+ t += n
+ b = b[n:]
+ }
+ return t, nil
+}
+
+func (conn *HvsockConn) write(b []byte) (int, error) {
+ c, err := conn.sock.prepareIo()
+ if err != nil {
+ return 0, conn.opErr("write", err)
+ }
+ defer conn.sock.wg.Done()
+ buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
+ var bytes uint32
+ err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
+ n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err)
+ if err != nil {
+ if _, ok := err.(syscall.Errno); ok {
+ err = os.NewSyscallError("wsasend", err)
+ }
+ return 0, conn.opErr("write", err)
+ }
+ return n, err
+}
+
+// Close closes the socket connection, failing any pending read or write calls.
+func (conn *HvsockConn) Close() error {
+ return conn.sock.Close()
+}
+
+func (conn *HvsockConn) shutdown(how int) error {
+ err := syscall.Shutdown(conn.sock.handle, syscall.SHUT_RD)
+ if err != nil {
+ return os.NewSyscallError("shutdown", err)
+ }
+ return nil
+}
+
+// CloseRead shuts down the read end of the socket.
+func (conn *HvsockConn) CloseRead() error {
+ err := conn.shutdown(syscall.SHUT_RD)
+ if err != nil {
+ return conn.opErr("close", err)
+ }
+ return nil
+}
+
+// CloseWrite shuts down the write end of the socket, notifying the other endpoint that
+// no more data will be written.
+func (conn *HvsockConn) CloseWrite() error {
+ err := conn.shutdown(syscall.SHUT_WR)
+ if err != nil {
+ return conn.opErr("close", err)
+ }
+ return nil
+}
+
+// LocalAddr returns the local address of the connection.
+func (conn *HvsockConn) LocalAddr() net.Addr {
+ return &conn.local
+}
+
+// RemoteAddr returns the remote address of the connection.
+func (conn *HvsockConn) RemoteAddr() net.Addr {
+ return &conn.remote
+}
+
+// SetDeadline implements the net.Conn SetDeadline method.
+func (conn *HvsockConn) SetDeadline(t time.Time) error {
+ conn.SetReadDeadline(t)
+ conn.SetWriteDeadline(t)
+ return nil
+}
+
+// SetReadDeadline implements the net.Conn SetReadDeadline method.
+func (conn *HvsockConn) SetReadDeadline(t time.Time) error {
+ return conn.sock.SetReadDeadline(t)
+}
+
+// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
+func (conn *HvsockConn) SetWriteDeadline(t time.Time) error {
+ return conn.sock.SetWriteDeadline(t)
+}