diff options
Diffstat (limited to 'vendor/github.com/godbus/dbus/v5/conn.go')
-rw-r--r-- | vendor/github.com/godbus/dbus/v5/conn.go | 159 |
1 files changed, 103 insertions, 56 deletions
diff --git a/vendor/github.com/godbus/dbus/v5/conn.go b/vendor/github.com/godbus/dbus/v5/conn.go index b55bc99c8..29fe018ad 100644 --- a/vendor/github.com/godbus/dbus/v5/conn.go +++ b/vendor/github.com/godbus/dbus/v5/conn.go @@ -45,6 +45,7 @@ type Conn struct { serialGen SerialGenerator inInt Interceptor outInt Interceptor + auth []Auth names *nameTracker calls *callTracker @@ -59,7 +60,8 @@ type Conn struct { func SessionBus() (conn *Conn, err error) { sessionBusLck.Lock() defer sessionBusLck.Unlock() - if sessionBus != nil { + if sessionBus != nil && + sessionBus.Connected() { return sessionBus, nil } defer func() { @@ -67,19 +69,7 @@ func SessionBus() (conn *Conn, err error) { sessionBus = conn } }() - conn, err = SessionBusPrivate() - if err != nil { - return - } - if err = conn.Auth(nil); err != nil { - conn.Close() - conn = nil - return - } - if err = conn.Hello(); err != nil { - conn.Close() - conn = nil - } + conn, err = ConnectSessionBus() return } @@ -116,7 +106,8 @@ func SessionBusPrivateHandler(handler Handler, signalHandler SignalHandler) (*Co func SystemBus() (conn *Conn, err error) { systemBusLck.Lock() defer systemBusLck.Unlock() - if systemBus != nil { + if systemBus != nil && + systemBus.Connected() { return systemBus, nil } defer func() { @@ -124,20 +115,42 @@ func SystemBus() (conn *Conn, err error) { systemBus = conn } }() - conn, err = SystemBusPrivate() + conn, err = ConnectSystemBus() + return +} + +// ConnectSessionBus connects to the session bus. +func ConnectSessionBus(opts ...ConnOption) (*Conn, error) { + address, err := getSessionBusAddress() if err != nil { - return + return nil, err } - if err = conn.Auth(nil); err != nil { - conn.Close() - conn = nil - return + return Connect(address, opts...) +} + +// ConnectSystemBus connects to the system bus. +func ConnectSystemBus(opts ...ConnOption) (*Conn, error) { + return Connect(getSystemBusPlatformAddress(), opts...) +} + +// Connect connects to the given address. +// +// Returned connection is ready to use and doesn't require calling +// Auth and Hello methods to make it usable. +func Connect(address string, opts ...ConnOption) (*Conn, error) { + conn, err := Dial(address, opts...) + if err != nil { + return nil, err + } + if err = conn.Auth(conn.auth); err != nil { + _ = conn.Close() + return nil, err } if err = conn.Hello(); err != nil { - conn.Close() - conn = nil + _ = conn.Close() + return nil, err } - return + return conn, nil } // SystemBusPrivate returns a new private connection to the system bus. @@ -197,6 +210,14 @@ func WithSerialGenerator(gen SerialGenerator) ConnOption { } } +// WithAuth sets authentication methods for the auth conversation. +func WithAuth(methods ...Auth) ConnOption { + return func(conn *Conn) error { + conn.auth = methods + return nil + } +} + // Interceptor intercepts incoming and outgoing messages. type Interceptor func(msg *Message) @@ -309,6 +330,11 @@ func (conn *Conn) Context() context.Context { return conn.ctx } +// Connected returns whether conn is connected +func (conn *Conn) Connected() bool { + return conn.ctx.Err() == nil +} + // Eavesdrop causes conn to send all incoming messages to the given channel // without further processing. Method replies, errors and signals will not be // sent to the appropriate channels and method calls will not be handled. If nil @@ -342,8 +368,9 @@ func (conn *Conn) Hello() error { } // inWorker runs in an own goroutine, reading incoming messages from the -// transport and dispatching them appropiately. +// transport and dispatching them appropriately. func (conn *Conn) inWorker() { + sequenceGen := newSequenceGenerator() for { msg, err := conn.ReadMessage() if err != nil { @@ -352,7 +379,7 @@ func (conn *Conn) inWorker() { // anything but to shut down all stuff and returns errors to all // pending replies. conn.Close() - conn.calls.finalizeAllWithError(err) + conn.calls.finalizeAllWithError(sequenceGen, err) return } // invalid messages are ignored @@ -381,13 +408,14 @@ func (conn *Conn) inWorker() { if conn.inInt != nil { conn.inInt(msg) } + sequence := sequenceGen.next() switch msg.Type { case TypeError: - conn.serialGen.RetireSerial(conn.calls.handleDBusError(msg)) + conn.serialGen.RetireSerial(conn.calls.handleDBusError(sequence, msg)) case TypeMethodReply: - conn.serialGen.RetireSerial(conn.calls.handleReply(msg)) + conn.serialGen.RetireSerial(conn.calls.handleReply(sequence, msg)) case TypeSignal: - conn.handleSignal(msg) + conn.handleSignal(sequence, msg) case TypeMethodCall: go conn.handleCall(msg) } @@ -395,7 +423,7 @@ func (conn *Conn) inWorker() { } } -func (conn *Conn) handleSignal(msg *Message) { +func (conn *Conn) handleSignal(sequence Sequence, msg *Message) { iface := msg.Headers[FieldInterface].value.(string) member := msg.Headers[FieldMember].value.(string) // as per http://dbus.freedesktop.org/doc/dbus-specification.html , @@ -421,10 +449,11 @@ func (conn *Conn) handleSignal(msg *Message) { } } signal := &Signal{ - Sender: sender, - Path: msg.Headers[FieldPath].value.(ObjectPath), - Name: iface + "." + member, - Body: msg.Body, + Sender: sender, + Path: msg.Headers[FieldPath].value.(ObjectPath), + Name: iface + "." + member, + Body: msg.Body, + Sequence: sequence, } conn.signalHandler.DeliverSignal(iface, member, signal) } @@ -442,6 +471,9 @@ func (conn *Conn) Object(dest string, path ObjectPath) BusObject { } func (conn *Conn) sendMessageAndIfClosed(msg *Message, ifClosed func()) { + if msg.serial == 0 { + msg.serial = conn.getSerial() + } if conn.outInt != nil { conn.outInt(msg) } @@ -473,16 +505,16 @@ func (conn *Conn) send(ctx context.Context, msg *Message, ch chan *Call) *Call { if ctx == nil { panic("nil context") } + if ch == nil { + ch = make(chan *Call, 1) + } else if cap(ch) == 0 { + panic("dbus: unbuffered channel passed to (*Conn).Send") + } var call *Call ctx, canceler := context.WithCancel(ctx) msg.serial = conn.getSerial() if msg.Type == TypeMethodCall && msg.Flags&FlagNoReplyExpected == 0 { - if ch == nil { - ch = make(chan *Call, 5) - } else if cap(ch) == 0 { - panic("dbus: unbuffered channel passed to (*Conn).Send") - } call = new(Call) call.Destination, _ = msg.Headers[FieldDestination].value.(string) call.Path, _ = msg.Headers[FieldPath].value.(ObjectPath) @@ -504,7 +536,8 @@ func (conn *Conn) send(ctx context.Context, msg *Message, ch chan *Call) *Call { }) } else { canceler() - call = &Call{Err: nil} + call = &Call{Err: nil, Done: ch} + ch <- call conn.sendMessageAndIfClosed(msg, func() { call = &Call{Err: ErrClosed} }) @@ -529,7 +562,6 @@ func (conn *Conn) sendError(err error, dest string, serial uint32) { } msg := new(Message) msg.Type = TypeError - msg.serial = conn.getSerial() msg.Headers = make(map[HeaderField]Variant) if dest != "" { msg.Headers[FieldDestination] = MakeVariant(dest) @@ -548,7 +580,6 @@ func (conn *Conn) sendError(err error, dest string, serial uint32) { func (conn *Conn) sendReply(dest string, serial uint32, values ...interface{}) { msg := new(Message) msg.Type = TypeMethodReply - msg.serial = conn.getSerial() msg.Headers = make(map[HeaderField]Variant) if dest != "" { msg.Headers[FieldDestination] = MakeVariant(dest) @@ -564,8 +595,14 @@ func (conn *Conn) sendReply(dest string, serial uint32, values ...interface{}) { // AddMatchSignal registers the given match rule to receive broadcast // signals based on their contents. func (conn *Conn) AddMatchSignal(options ...MatchOption) error { + return conn.AddMatchSignalContext(context.Background(), options...) +} + +// AddMatchSignalContext acts like AddMatchSignal but takes a context. +func (conn *Conn) AddMatchSignalContext(ctx context.Context, options ...MatchOption) error { options = append([]MatchOption{withMatchType("signal")}, options...) - return conn.busObj.Call( + return conn.busObj.CallWithContext( + ctx, "org.freedesktop.DBus.AddMatch", 0, formatMatchOptions(options), ).Store() @@ -573,8 +610,14 @@ func (conn *Conn) AddMatchSignal(options ...MatchOption) error { // RemoveMatchSignal removes the first rule that matches previously registered with AddMatchSignal. func (conn *Conn) RemoveMatchSignal(options ...MatchOption) error { + return conn.RemoveMatchSignalContext(context.Background(), options...) +} + +// RemoveMatchSignalContext acts like RemoveMatchSignal but takes a context. +func (conn *Conn) RemoveMatchSignalContext(ctx context.Context, options ...MatchOption) error { options = append([]MatchOption{withMatchType("signal")}, options...) - return conn.busObj.Call( + return conn.busObj.CallWithContext( + ctx, "org.freedesktop.DBus.RemoveMatch", 0, formatMatchOptions(options), ).Store() @@ -639,10 +682,11 @@ func (e Error) Error() string { // Signal represents a D-Bus message of type Signal. The name member is given in // "interface.member" notation, e.g. org.freedesktop.D-Bus.NameLost. type Signal struct { - Sender string - Path ObjectPath - Name string - Body []interface{} + Sender string + Path ObjectPath + Name string + Body []interface{} + Sequence Sequence } // transport is a D-Bus transport. @@ -825,25 +869,25 @@ func (tracker *callTracker) track(sn uint32, call *Call) { tracker.lck.Unlock() } -func (tracker *callTracker) handleReply(msg *Message) uint32 { +func (tracker *callTracker) handleReply(sequence Sequence, msg *Message) uint32 { serial := msg.Headers[FieldReplySerial].value.(uint32) tracker.lck.RLock() _, ok := tracker.calls[serial] tracker.lck.RUnlock() if ok { - tracker.finalizeWithBody(serial, msg.Body) + tracker.finalizeWithBody(serial, sequence, msg.Body) } return serial } -func (tracker *callTracker) handleDBusError(msg *Message) uint32 { +func (tracker *callTracker) handleDBusError(sequence Sequence, msg *Message) uint32 { serial := msg.Headers[FieldReplySerial].value.(uint32) tracker.lck.RLock() _, ok := tracker.calls[serial] tracker.lck.RUnlock() if ok { name, _ := msg.Headers[FieldErrorName].value.(string) - tracker.finalizeWithError(serial, Error{name, msg.Body}) + tracker.finalizeWithError(serial, sequence, Error{name, msg.Body}) } return serial } @@ -856,7 +900,7 @@ func (tracker *callTracker) handleSendError(msg *Message, err error) { _, ok := tracker.calls[msg.serial] tracker.lck.RUnlock() if ok { - tracker.finalizeWithError(msg.serial, err) + tracker.finalizeWithError(msg.serial, NoSequence, err) } } @@ -871,7 +915,7 @@ func (tracker *callTracker) finalize(sn uint32) { } } -func (tracker *callTracker) finalizeWithBody(sn uint32, body []interface{}) { +func (tracker *callTracker) finalizeWithBody(sn uint32, sequence Sequence, body []interface{}) { tracker.lck.Lock() c, ok := tracker.calls[sn] if ok { @@ -880,11 +924,12 @@ func (tracker *callTracker) finalizeWithBody(sn uint32, body []interface{}) { tracker.lck.Unlock() if ok { c.Body = body + c.ResponseSequence = sequence c.done() } } -func (tracker *callTracker) finalizeWithError(sn uint32, err error) { +func (tracker *callTracker) finalizeWithError(sn uint32, sequence Sequence, err error) { tracker.lck.Lock() c, ok := tracker.calls[sn] if ok { @@ -893,11 +938,12 @@ func (tracker *callTracker) finalizeWithError(sn uint32, err error) { tracker.lck.Unlock() if ok { c.Err = err + c.ResponseSequence = sequence c.done() } } -func (tracker *callTracker) finalizeAllWithError(err error) { +func (tracker *callTracker) finalizeAllWithError(sequenceGen *sequenceGenerator, err error) { tracker.lck.Lock() closedCalls := make([]*Call, 0, len(tracker.calls)) for sn := range tracker.calls { @@ -907,6 +953,7 @@ func (tracker *callTracker) finalizeAllWithError(err error) { tracker.lck.Unlock() for _, call := range closedCalls { call.Err = err + call.ResponseSequence = sequenceGen.next() call.done() } } |