summaryrefslogtreecommitdiff
path: root/vendor/github.com/vishvananda/netlink/socket_linux.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/vishvananda/netlink/socket_linux.go')
-rw-r--r--vendor/github.com/vishvananda/netlink/socket_linux.go159
1 files changed, 159 insertions, 0 deletions
diff --git a/vendor/github.com/vishvananda/netlink/socket_linux.go b/vendor/github.com/vishvananda/netlink/socket_linux.go
new file mode 100644
index 000000000..b42b84f0c
--- /dev/null
+++ b/vendor/github.com/vishvananda/netlink/socket_linux.go
@@ -0,0 +1,159 @@
+package netlink
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "syscall"
+
+ "github.com/vishvananda/netlink/nl"
+)
+
+const (
+ sizeofSocketID = 0x30
+ sizeofSocketRequest = sizeofSocketID + 0x8
+ sizeofSocket = sizeofSocketID + 0x18
+)
+
+type socketRequest struct {
+ Family uint8
+ Protocol uint8
+ Ext uint8
+ pad uint8
+ States uint32
+ ID SocketID
+}
+
+type writeBuffer struct {
+ Bytes []byte
+ pos int
+}
+
+func (b *writeBuffer) Write(c byte) {
+ b.Bytes[b.pos] = c
+ b.pos++
+}
+
+func (b *writeBuffer) Next(n int) []byte {
+ s := b.Bytes[b.pos : b.pos+n]
+ b.pos += n
+ return s
+}
+
+func (r *socketRequest) Serialize() []byte {
+ b := writeBuffer{Bytes: make([]byte, sizeofSocketRequest)}
+ b.Write(r.Family)
+ b.Write(r.Protocol)
+ b.Write(r.Ext)
+ b.Write(r.pad)
+ native.PutUint32(b.Next(4), r.States)
+ networkOrder.PutUint16(b.Next(2), r.ID.SourcePort)
+ networkOrder.PutUint16(b.Next(2), r.ID.DestinationPort)
+ copy(b.Next(4), r.ID.Source.To4())
+ b.Next(12)
+ copy(b.Next(4), r.ID.Destination.To4())
+ b.Next(12)
+ native.PutUint32(b.Next(4), r.ID.Interface)
+ native.PutUint32(b.Next(4), r.ID.Cookie[0])
+ native.PutUint32(b.Next(4), r.ID.Cookie[1])
+ return b.Bytes
+}
+
+func (r *socketRequest) Len() int { return sizeofSocketRequest }
+
+type readBuffer struct {
+ Bytes []byte
+ pos int
+}
+
+func (b *readBuffer) Read() byte {
+ c := b.Bytes[b.pos]
+ b.pos++
+ return c
+}
+
+func (b *readBuffer) Next(n int) []byte {
+ s := b.Bytes[b.pos : b.pos+n]
+ b.pos += n
+ return s
+}
+
+func (s *Socket) deserialize(b []byte) error {
+ if len(b) < sizeofSocket {
+ return fmt.Errorf("socket data short read (%d); want %d", len(b), sizeofSocket)
+ }
+ rb := readBuffer{Bytes: b}
+ s.Family = rb.Read()
+ s.State = rb.Read()
+ s.Timer = rb.Read()
+ s.Retrans = rb.Read()
+ s.ID.SourcePort = networkOrder.Uint16(rb.Next(2))
+ s.ID.DestinationPort = networkOrder.Uint16(rb.Next(2))
+ s.ID.Source = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read())
+ rb.Next(12)
+ s.ID.Destination = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read())
+ rb.Next(12)
+ s.ID.Interface = native.Uint32(rb.Next(4))
+ s.ID.Cookie[0] = native.Uint32(rb.Next(4))
+ s.ID.Cookie[1] = native.Uint32(rb.Next(4))
+ s.Expires = native.Uint32(rb.Next(4))
+ s.RQueue = native.Uint32(rb.Next(4))
+ s.WQueue = native.Uint32(rb.Next(4))
+ s.UID = native.Uint32(rb.Next(4))
+ s.INode = native.Uint32(rb.Next(4))
+ return nil
+}
+
+// SocketGet returns the Socket identified by its local and remote addresses.
+func SocketGet(local, remote net.Addr) (*Socket, error) {
+ localTCP, ok := local.(*net.TCPAddr)
+ if !ok {
+ return nil, ErrNotImplemented
+ }
+ remoteTCP, ok := remote.(*net.TCPAddr)
+ if !ok {
+ return nil, ErrNotImplemented
+ }
+ localIP := localTCP.IP.To4()
+ if localIP == nil {
+ return nil, ErrNotImplemented
+ }
+ remoteIP := remoteTCP.IP.To4()
+ if remoteIP == nil {
+ return nil, ErrNotImplemented
+ }
+
+ s, err := nl.Subscribe(syscall.NETLINK_INET_DIAG)
+ if err != nil {
+ return nil, err
+ }
+ defer s.Close()
+ req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, 0)
+ req.AddData(&socketRequest{
+ Family: syscall.AF_INET,
+ Protocol: syscall.IPPROTO_TCP,
+ ID: SocketID{
+ SourcePort: uint16(localTCP.Port),
+ DestinationPort: uint16(remoteTCP.Port),
+ Source: localIP,
+ Destination: remoteIP,
+ Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE},
+ },
+ })
+ s.Send(req)
+ msgs, err := s.Receive()
+ if err != nil {
+ return nil, err
+ }
+ if len(msgs) == 0 {
+ return nil, errors.New("no message nor error from netlink")
+ }
+ if len(msgs) > 2 {
+ return nil, fmt.Errorf("multiple (%d) matching sockets", len(msgs))
+ }
+ sock := &Socket{}
+ if err := sock.deserialize(msgs[0].Data); err != nil {
+ return nil, err
+ }
+ return sock, nil
+}