diff --git a/client.go b/client.go index 686519a..a7e6c7d 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package xmpp import ( + "context" "encoding/xml" "errors" "fmt" @@ -82,6 +83,8 @@ func (em EventManager) streamError(error, desc string) { // Client // ============================================================================ +var ErrCanOnlySendGetOrSetIq = errors.New("SendIQ can only send get and set IQ stanzas") + // Client is the main structure used to connect as a client on an XMPP // server. type Client struct { @@ -221,6 +224,25 @@ func (c *Client) Send(packet stanza.Packet) error { return c.sendWithWriter(c.transport, data) } +// SendIQ sends an IQ set or get stanza to the server. If a result is received +// the provided handler function will automatically be called. +// +// The provided context should have a timeout to prevent the client from waiting +// forever for an IQ result. For example: +// +// ctx, _ := context.WithTimeout(context.Background(), 30 * time.Second) +// result := <- client.SendIQ(ctx, iq) +// +func (c *Client) SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) { + if iq.Attrs.Type != "set" && iq.Attrs.Type != "get" { + return nil, ErrCanOnlySendGetOrSetIq + } + if err := c.Send(iq); err != nil { + return nil, err + } + return c.router.NewIQResultRoute(ctx, iq.Attrs.Id), nil +} + // SendRaw sends an XMPP stanza as a string to the server. // It can be invalid XML or XMPP content. In that case, the server will // disconnect the client. It is up to the user of this method to @@ -271,7 +293,10 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error) state.Inbound++ } - c.router.route(c, val) + // Do normal route processing in a go-routine so we can immediately + // start receiving other stanzas. This also allows route handlers to + // send and receive more stanzas. + go c.router.route(c, val) } } diff --git a/component.go b/component.go index 18be3bc..d459c00 100644 --- a/component.go +++ b/component.go @@ -1,6 +1,7 @@ package xmpp import ( + "context" "crypto/sha1" "encoding/hex" "encoding/xml" @@ -158,6 +159,25 @@ func (c *Component) Send(packet stanza.Packet) error { return nil } +// SendIQ sends an IQ set or get stanza to the server. If a result is received +// the provided handler function will automatically be called. +// +// The provided context should have a timeout to prevent the client from waiting +// forever for an IQ result. For example: +// +// ctx, _ := context.WithTimeout(context.Background(), 30 * time.Second) +// result := <- client.SendIQ(ctx, iq) +// +func (c *Component) SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) { + if iq.Attrs.Type != "set" && iq.Attrs.Type != "get" { + return nil, ErrCanOnlySendGetOrSetIq + } + if err := c.Send(iq); err != nil { + return nil, err + } + return c.router.NewIQResultRoute(ctx, iq.Attrs.Id), nil +} + // SendRaw sends an XMPP stanza as a string to the server. // It can be invalid XML or XMPP content. In that case, the server will // disconnect the component. It is up to the user of this method to diff --git a/go.mod b/go.mod index 3298cfa..f31fe40 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.13 require ( github.com/google/go-cmp v0.3.1 + github.com/google/uuid v1.1.1 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 nhooyr.io/websocket v1.6.5 ) diff --git a/go.sum b/go.sum index cf05d14..ae38d07 100644 --- a/go.sum +++ b/go.sum @@ -21,9 +21,12 @@ github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190908185732-236ed259b199/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/knq/sysutil v0.0.0-20181215143952-f05b59f0f307/go.mod h1:BjPj+aVjl9FW/cCGiF3nGh5v+9Gd3VCgBQbod/GlMaQ= diff --git a/router.go b/router.go index 21f0035..23a134e 100644 --- a/router.go +++ b/router.go @@ -1,8 +1,10 @@ package xmpp import ( + "context" "encoding/xml" "strings" + "sync" "gosrc.io/xmpp/stanza" ) @@ -25,16 +27,35 @@ TODO: Automatically reply to IQ that do not match any route, to comply to XMPP s type Router struct { // Routes to be matched, in order. routes []*Route + + IQResultRoutes map[string]*IQResultRoute + IQResultRouteLock sync.RWMutex } // NewRouter returns a new router instance. func NewRouter() *Router { - return &Router{} + return &Router{ + IQResultRoutes: make(map[string]*IQResultRoute), + } } // route is called by the XMPP client to dispatch stanza received using the set up routes. // It is also used by test, but is not supposed to be used directly by users of the library. func (r *Router) route(s Sender, p stanza.Packet) { + iq, isIq := p.(stanza.IQ) + if isIq { + r.IQResultRouteLock.RLock() + route, ok := r.IQResultRoutes[iq.Id] + r.IQResultRouteLock.RUnlock() + if ok { + r.IQResultRouteLock.Lock() + delete(r.IQResultRoutes, iq.Id) + r.IQResultRouteLock.Unlock() + route.result <- iq + close(route.result) + return + } + } var match RouteMatch if r.Match(p, &match) { @@ -42,11 +63,10 @@ func (r *Router) route(s Sender, p stanza.Packet) { match.Handler.HandlePacket(s, p) return } + // If there is no match and we receive an iq set or get, we need to send a reply - if iq, ok := p.(stanza.IQ); ok { - if iq.Type == stanza.IQTypeGet || iq.Type == stanza.IQTypeSet { - iqNotImplemented(s, iq) - } + if isIq && (iq.Type == stanza.IQTypeGet || iq.Type == stanza.IQTypeSet) { + iqNotImplemented(s, iq) } } @@ -68,6 +88,27 @@ func (r *Router) NewRoute() *Route { return route } +// NewIQResultRoute register a route that will catch an IQ result stanza with +// the given Id. The route will only match ones, after which it will automatically +// be unregistered +func (r *Router) NewIQResultRoute(ctx context.Context, id string) chan stanza.IQ { + route := NewIQResultRoute(ctx) + r.IQResultRouteLock.Lock() + r.IQResultRoutes[id] = route + r.IQResultRouteLock.Unlock() + + // Start a go function to make sure the route is unregistered when the context + // is done. + go func() { + <-route.context.Done() + r.IQResultRouteLock.Lock() + delete(r.IQResultRoutes, id) + r.IQResultRouteLock.Unlock() + }() + + return route.result +} + func (r *Router) Match(p stanza.Packet, match *RouteMatch) bool { for _, route := range r.routes { if route.Match(p, match) { @@ -89,8 +130,44 @@ func (r *Router) HandleFunc(name string, f func(s Sender, p stanza.Packet)) *Rou return r.NewRoute().Packet(name).HandlerFunc(f) } +// ============================================================================ + +// TimeoutHandlerFunc is a function type for handling IQ result timeouts. +type TimeoutHandlerFunc func(err error) + +// IQResultRoute is a temporary route to match IQ result stanzas +type IQResultRoute struct { + context context.Context + result chan stanza.IQ +} + +// NewIQResultRoute creates a new IQResultRoute instance +func NewIQResultRoute(ctx context.Context) *IQResultRoute { + return &IQResultRoute{ + context: ctx, + result: make(chan stanza.IQ), + } +} + +// ============================================================================ +// IQ result handler + +// IQResultHandler is a utility interface for IQ result handlers +type IQResultHandler interface { + HandleIQ(ctx context.Context, s Sender, iq stanza.IQ) +} + +// IQResultHandlerFunc is an adapter to allow using functions as IQ result handlers. +type IQResultHandlerFunc func(ctx context.Context, s Sender, iq stanza.IQ) + +// HandleIQ is a proxy function to implement IQResultHandler using a function. +func (f IQResultHandlerFunc) HandleIQ(ctx context.Context, s Sender, iq stanza.IQ) { + f(ctx, s, iq) +} + // ============================================================================ // Route + type Handler interface { HandlePacket(s Sender, p stanza.Packet) } diff --git a/router_test.go b/router_test.go index 98a4697..f9725ba 100644 --- a/router_test.go +++ b/router_test.go @@ -2,8 +2,10 @@ package xmpp import ( "bytes" + "context" "encoding/xml" "testing" + "time" "gosrc.io/xmpp/stanza" ) @@ -11,6 +13,47 @@ import ( // ============================================================================ // Test route & matchers +func TestIQResultRoutes(t *testing.T) { + t.Parallel() + router := NewRouter() + conn := NewSenderMock() + + if router.IQResultRoutes == nil { + t.Fatal("NewRouter does not initialize isResultRoutes") + } + + // Check if the IQ handler was called + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + iq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, Id: "1234"}) + res := router.NewIQResultRoute(ctx, "1234") + go router.route(conn, iq) + select { + case <-ctx.Done(): + t.Fatal("IQ result was not matched") + case <-res: + // Success + } + + // The match must only happen once, so the id should no longer be in IQResultRoutes + if _, ok := router.IQResultRoutes[iq.Attrs.Id]; ok { + t.Fatal("IQ ID was not removed from the route map") + } + + // Check other IQ does not matcah + ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + iq.Attrs.Id = "4321" + res = router.NewIQResultRoute(ctx, "1234") + go router.route(conn, iq) + select { + case <-ctx.Done(): + // Success + case <-res: + t.Fatal("IQ result with wrong ID was matched") + } +} + func TestNameMatcher(t *testing.T) { router := NewRouter() router.HandleFunc("message", func(s Sender, p stanza.Packet) { @@ -211,7 +254,8 @@ func TestCatchallMatcher(t *testing.T) { // ============================================================================ // SenderMock -var successFlag = "matched" +const successFlag = "matched" +const cancelledFlag = "cancelled" type SenderMock struct { buffer *bytes.Buffer diff --git a/stanza/iq.go b/stanza/iq.go index f1663bc..923cf28 100644 --- a/stanza/iq.go +++ b/stanza/iq.go @@ -2,6 +2,8 @@ package stanza import ( "encoding/xml" + + "github.com/google/uuid" ) /* @@ -31,8 +33,12 @@ type IQPayload interface { } func NewIQ(a Attrs) IQ { - // TODO generate IQ ID if not set // TODO ensure that type is set, as it is required + if a.Id == "" { + if id, err := uuid.NewRandom(); err == nil { + a.Id = id.String() + } + } return IQ{ XMLName: xml.Name{Local: "iq"}, Attrs: a, diff --git a/stanza/iq_test.go b/stanza/iq_test.go index 04a868a..93f7ebb 100644 --- a/stanza/iq_test.go +++ b/stanza/iq_test.go @@ -34,6 +34,24 @@ func TestUnmarshalIqs(t *testing.T) { } } +func TestGenerateIqId(t *testing.T) { + t.Parallel() + iq := stanza.NewIQ(stanza.Attrs{Id: "1"}) + if iq.Id != "1" { + t.Errorf("NewIQ replaced id with %s", iq.Id) + } + + iq = stanza.NewIQ(stanza.Attrs{}) + if iq.Id != "1" { + t.Error("NewIQ did not generate an Id") + } + + otherIq := stanza.NewIQ(stanza.Attrs{}) + if iq.Id == otherIq.Id { + t.Errorf("NewIQ generated two identical ids: %s", iq.Id) + } +} + func TestGenerateIq(t *testing.T) { iq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: "admin@localhost", To: "test@localhost", Id: "1"}) payload := stanza.DiscoInfo{ diff --git a/stream_manager.go b/stream_manager.go index c21d79a..1011f6e 100644 --- a/stream_manager.go +++ b/stream_manager.go @@ -1,6 +1,7 @@ package xmpp import ( + "context" "errors" "sync" "time" @@ -26,6 +27,7 @@ type StreamClient interface { Connect() error Resume(state SMState) error Send(packet stanza.Packet) error + SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) SendRaw(packet string) error Disconnect() SetHandler(handler EventHandler) @@ -35,6 +37,7 @@ type StreamClient interface { // It is mostly use in callback to pass a limited subset of the stream client interface type Sender interface { Send(packet stanza.Packet) error + SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) SendRaw(packet string) error }