diff --git a/router.go b/router.go index 21f0035..bf6d5e6 100644 --- a/router.go +++ b/router.go @@ -1,6 +1,7 @@ package xmpp import ( + "context" "encoding/xml" "strings" @@ -25,16 +26,26 @@ TODO: Automatically reply to IQ that do not match any route, to comply to XMPP s type Router struct { // Routes to be matched, in order. routes []*Route + + iqResultRoutes map[string]*IqResultRoute } // NewRouter returns a new router instance. func NewRouter() *Router { - return &Router{} + return &Router{ + iqResultRoutes: make(map[string]*IqResultRoute), + } } // route is called by the XMPP client to dispatch stanza received using the set up routes. // It is also used by test, but is not supposed to be used directly by users of the library. func (r *Router) route(s Sender, p stanza.Packet) { + iq, isIq := p.(stanza.IQ) + if isIq { + if route, ok := r.iqResultRoutes[iq.Id]; ok { + route.handler.HandlePacket(s, p) + } + } var match RouteMatch if r.Match(p, &match) { @@ -42,11 +53,10 @@ func (r *Router) route(s Sender, p stanza.Packet) { match.Handler.HandlePacket(s, p) return } + // If there is no match and we receive an iq set or get, we need to send a reply - if iq, ok := p.(stanza.IQ); ok { - if iq.Type == stanza.IQTypeGet || iq.Type == stanza.IQTypeSet { - iqNotImplemented(s, iq) - } + if isIq && (iq.Type == stanza.IQTypeGet || iq.Type == stanza.IQTypeSet) { + iqNotImplemented(s, iq) } } @@ -68,6 +78,28 @@ func (r *Router) NewRoute() *Route { return route } +// NewIqResultRoute register a route that will catch an IQ result stanza with +// the given Id. The route will only match ones, after which it will automatically +// be unregistered +func (r *Router) NewIqResultRoute(ctx context.Context, id string) *IqResultRoute { + route := &IqResultRoute{ + context: ctx, + matched: make(chan struct{}), + } + r.iqResultRoutes[id] = route + go func() { + select { + case <-route.context.Done(): + if route.timeoutHandler != nil { + route.timeoutHandler(route.context.Err()) + } + case <-route.matched: + } + delete(r.iqResultRoutes, id) + }() + return route +} + func (r *Router) Match(p stanza.Packet, match *RouteMatch) bool { for _, route := range r.routes { if route.Match(p, match) { @@ -89,6 +121,40 @@ func (r *Router) HandleFunc(name string, f func(s Sender, p stanza.Packet)) *Rou return r.NewRoute().Packet(name).HandlerFunc(f) } +// HandleIqResult register a temporary route +func (r *Router) HandleIqResult(id string, handler Handler) *IqResultRoute { + return r.NewIqResultRoute(context.Background(), id).Handler(handler) +} + +func (r *Router) HandleFuncIqResult(id string, f func(s Sender, p stanza.Packet)) *IqResultRoute { + return r.NewIqResultRoute(context.Background(), id).HandlerFunc(f) +} + +// ============================================================================ +// IqResultRoute +type TimeoutHandlerFunc func(err error) + +type IqResultRoute struct { + context context.Context + matched chan struct{} + handler Handler + timeoutHandler TimeoutHandlerFunc +} + +func (r *IqResultRoute) Handler(handler Handler) *IqResultRoute { + r.handler = handler + return r +} + +func (r *IqResultRoute) HandlerFunc(f HandlerFunc) *IqResultRoute { + return r.Handler(f) +} + +func (r *IqResultRoute) TimeoutHandlerFunc(f TimeoutHandlerFunc) *IqResultRoute { + r.timeoutHandler = f + return r +} + // ============================================================================ // Route type Handler interface {