Refactor / clean up registry

This commit is contained in:
Mickael Remond 2019-06-04 17:04:25 +02:00 committed by Mickaël Rémond
parent 836e723273
commit 7a4364be95
13 changed files with 205 additions and 143 deletions

View file

@ -105,14 +105,12 @@ type auth struct {
} }
type BindBind struct { type BindBind struct {
IQPayload
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"` XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-bind bind"`
Resource string `xml:"resource,omitempty"` Resource string `xml:"resource,omitempty"`
Jid string `xml:"jid,omitempty"` Jid string `xml:"jid,omitempty"`
} }
func (*BindBind) IsIQPayload() {
}
// Session is obsolete in RFC 6121. // Session is obsolete in RFC 6121.
// Added for compliance with RFC 3121. // Added for compliance with RFC 3121.
// Remove when ejabberd purely conforms to RFC 6121. // Remove when ejabberd purely conforms to RFC 6121.

View file

@ -29,7 +29,6 @@ func main() {
func runCheck(address, domain string) { func runCheck(address, domain string) {
client, err := xmpp.NewChecker(address, domain) client, err := xmpp.NewChecker(address, domain)
// client, err := xmpp.NewChecker("mickael.m.in-app.io:5222", "mickael.m.in-app.io")
if err != nil { if err != nil {
log.Fatal("Error: ", err) log.Fatal("Error: ", err)

View file

@ -12,8 +12,6 @@ import (
"github.com/processone/mpg123" "github.com/processone/mpg123"
"github.com/processone/soundcloud" "github.com/processone/soundcloud"
"gosrc.io/xmpp" "gosrc.io/xmpp"
"gosrc.io/xmpp/iot"
"gosrc.io/xmpp/pep"
) )
// Get the actual song Stream URL from SoundCloud website song URL and play it with mpg123 player. // Get the actual song Stream URL from SoundCloud website song URL and play it with mpg123 player.
@ -65,7 +63,7 @@ func processMessage(client *xmpp.Client, p *mpg123.Player, packet *xmpp.Message)
func processIq(client *xmpp.Client, p *mpg123.Player, packet *xmpp.IQ) { func processIq(client *xmpp.Client, p *mpg123.Player, packet *xmpp.IQ) {
switch payload := packet.Payload[0].(type) { switch payload := packet.Payload[0].(type) {
// We support IOT Control IQ // We support IOT Control IQ
case *iot.ControlSet: case *xmpp.ControlSet:
var url string var url string
for _, element := range payload.Fields { for _, element := range payload.Fields {
if element.XMLName.Local == "string" && element.Name == "url" { if element.XMLName.Local == "string" && element.Name == "url" {
@ -75,7 +73,7 @@ func processIq(client *xmpp.Client, p *mpg123.Player, packet *xmpp.IQ) {
} }
playSCURL(p, url) playSCURL(p, url)
setResponse := new(iot.ControlSetResponse) setResponse := new(xmpp.ControlSetResponse)
reply := xmpp.IQ{PacketAttrs: xmpp.PacketAttrs{To: packet.From, Type: "result", Id: packet.Id}, Payload: []xmpp.IQPayload{setResponse}} reply := xmpp.IQ{PacketAttrs: xmpp.PacketAttrs{To: packet.From, Type: "result", Id: packet.Id}, Payload: []xmpp.IQPayload{setResponse}}
_ = client.Send(reply) _ = client.Send(reply)
// TODO add Soundclound artist / title retrieval // TODO add Soundclound artist / title retrieval
@ -86,8 +84,11 @@ func processIq(client *xmpp.Client, p *mpg123.Player, packet *xmpp.IQ) {
} }
func sendUserTune(client *xmpp.Client, artist string, title string) { func sendUserTune(client *xmpp.Client, artist string, title string) {
tune := pep.Tune{Artist: artist, Title: title} tune := xmpp.Tune{Artist: artist, Title: title}
_ = client.SendRaw(tune.XMPPFormat()) iq := xmpp.NewIQ("set", "", "", "usertune-1", "en")
payload := xmpp.PubSub{Publish: xmpp.Publish{Node: "http://jabber.org/protocol/tune", Item: xmpp.Item{Tune: tune}}}
iq.AddPayload(&payload)
_ = client.Send(iq)
} }
func playSCURL(p *mpg123.Player, rawURL string) { func playSCURL(p *mpg123.Player, rawURL string) {

View file

@ -3,8 +3,6 @@ package xmpp // import "gosrc.io/xmpp"
import ( import (
"encoding/xml" "encoding/xml"
"testing" "testing"
"gosrc.io/xmpp/iot"
) )
func TestControlSet(t *testing.T) { func TestControlSet(t *testing.T) {
@ -22,7 +20,7 @@ func TestControlSet(t *testing.T) {
t.Errorf("Unmarshal(%s) returned error", data) t.Errorf("Unmarshal(%s) returned error", data)
} }
if cs, ok := parsedIQ.Payload[0].(*iot.ControlSet); !ok { if cs, ok := parsedIQ.Payload[0].(*ControlSet); !ok {
t.Errorf("Paylod is not an iot control set: %v", cs) t.Errorf("Paylod is not an iot control set: %v", cs)
} }
} }

View file

@ -1,15 +1,15 @@
package iot // import "gosrc.io/xmpp/iot" package xmpp // import "gosrc.io/xmpp/iot"
import "encoding/xml" import (
"encoding/xml"
)
type ControlSet struct { type ControlSet struct {
IQPayload
XMLName xml.Name `xml:"urn:xmpp:iot:control set"` XMLName xml.Name `xml:"urn:xmpp:iot:control set"`
Fields []ControlField `xml:",any"` Fields []ControlField `xml:",any"`
} }
func (*ControlSet) IsIQPayload() {
}
type ControlGetForm struct { type ControlGetForm struct {
XMLName xml.Name `xml:"urn:xmpp:iot:control getForm"` XMLName xml.Name `xml:"urn:xmpp:iot:control getForm"`
} }
@ -21,8 +21,6 @@ type ControlField struct {
} }
type ControlSetResponse struct { type ControlSetResponse struct {
IQPayload
XMLName xml.Name `xml:"urn:xmpp:iot:control setResponse"` XMLName xml.Name `xml:"urn:xmpp:iot:control setResponse"`
} }
func (*ControlSetResponse) IsIQPayload() {
}

47
iq.go
View file

@ -2,10 +2,7 @@ package xmpp // import "gosrc.io/xmpp"
import ( import (
"encoding/xml" "encoding/xml"
"reflect"
"strconv" "strconv"
"gosrc.io/xmpp/iot"
) )
/* /*
@ -19,6 +16,7 @@ TODO support ability to put Raw payload inside IQ
// presence or iq stanza. // presence or iq stanza.
// It is intended to be added in the payload of the erroneous stanza. // It is intended to be added in the payload of the erroneous stanza.
type Err struct { type Err struct {
IQPayload
XMLName xml.Name `xml:"error"` XMLName xml.Name `xml:"error"`
Code int `xml:"code,attr,omitempty"` Code int `xml:"code,attr,omitempty"`
Type string `xml:"type,attr,omitempty"` Type string `xml:"type,attr,omitempty"`
@ -26,8 +24,6 @@ type Err struct {
Text string `xml:"urn:ietf:params:xml:ns:xmpp-stanzas text,omitempty"` Text string `xml:"urn:ietf:params:xml:ns:xmpp-stanzas text,omitempty"`
} }
func (*Err) IsIQPayload() {}
// UnmarshalXML implements custom parsing for IQs // UnmarshalXML implements custom parsing for IQs
func (x *Err) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { func (x *Err) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
x.XMLName = start.Name x.XMLName = start.Name
@ -208,22 +204,16 @@ func (iq *IQ) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
case xml.StartElement: case xml.StartElement:
level++ level++
if level <= 1 { if level <= 1 {
var elt interface{} if iqExt := typeRegistry.GetIQExtension(tt.Name); iqExt != nil {
payloadType := tt.Name.Space + " " + tt.Name.Local // Decode payload extension
if payloadType := iqTypeRegistry[payloadType]; payloadType != nil { err = d.DecodeElement(iqExt, &tt)
val := reflect.New(payloadType)
elt = val.Interface()
} else {
// TODO: Fix me. We do nothing of that element here.
elt = new(Node)
}
if iqPl, ok := elt.(IQPayload); ok {
err = d.DecodeElement(elt, &tt)
if err != nil { if err != nil {
return err return err
} }
iq.Payload = append(iq.Payload, iqPl) iq.Payload = append(iq.Payload, iqExt)
} else {
// TODO: Fix me. We do nothing of that element here.
// elt = new(Node)
} }
} }
@ -239,13 +229,12 @@ func (iq *IQ) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
// ============================================================================ // ============================================================================
// Generic IQ Payload // Generic IQ Payload
type IQPayload interface { type IQPayload interface{}
IsIQPayload()
}
// Node is a generic structure to represent XML data. It is used to parse // Node is a generic structure to represent XML data. It is used to parse
// unreferenced or custom stanza payload. // unreferenced or custom stanza payload.
type Node struct { type Node struct {
IQPayload
XMLName xml.Name XMLName xml.Name
Attrs []xml.Attr `xml:"-"` Attrs []xml.Attr `xml:"-"`
Content string `xml:",innerxml"` Content string `xml:",innerxml"`
@ -284,8 +273,6 @@ func (n Node) MarshalXML(e *xml.Encoder, start xml.StartElement) (err error) {
return e.EncodeToken(xml.EndElement{Name: start.Name}) return e.EncodeToken(xml.EndElement{Name: start.Name})
} }
func (*Node) IsIQPayload() {}
// ============================================================================ // ============================================================================
// Disco // Disco
@ -295,14 +282,13 @@ const (
) )
type DiscoInfo struct { type DiscoInfo struct {
IQPayload
XMLName xml.Name `xml:"http://jabber.org/protocol/disco#info query"` XMLName xml.Name `xml:"http://jabber.org/protocol/disco#info query"`
Node string `xml:"node,attr,omitempty"` Node string `xml:"node,attr,omitempty"`
Identity Identity `xml:"identity"` Identity Identity `xml:"identity"`
Features []Feature `xml:"feature"` Features []Feature `xml:"feature"`
} }
func (*DiscoInfo) IsIQPayload() {}
type Identity struct { type Identity struct {
XMLName xml.Name `xml:"identity,omitempty"` XMLName xml.Name `xml:"identity,omitempty"`
Name string `xml:"name,attr,omitempty"` Name string `xml:"name,attr,omitempty"`
@ -318,13 +304,12 @@ type Feature struct {
// ============================================================================ // ============================================================================
type DiscoItems struct { type DiscoItems struct {
IQPayload
XMLName xml.Name `xml:"http://jabber.org/protocol/disco#items query"` XMLName xml.Name `xml:"http://jabber.org/protocol/disco#items query"`
Node string `xml:"node,attr,omitempty"` Node string `xml:"node,attr,omitempty"`
Items []DiscoItem `xml:"item"` Items []DiscoItem `xml:"item"`
} }
func (*DiscoItems) IsIQPayload() {}
type DiscoItem struct { type DiscoItem struct {
XMLName xml.Name `xml:"item"` XMLName xml.Name `xml:"item"`
Name string `xml:"name,attr,omitempty"` Name string `xml:"name,attr,omitempty"`
@ -333,8 +318,8 @@ type DiscoItem struct {
} }
func init() { func init() {
iqTypeRegistry["http://jabber.org/protocol/disco#info query"] = reflect.TypeOf(DiscoInfo{}) typeRegistry.MapExtension(PKTIQ, xml.Name{"http://jabber.org/protocol/disco#info", "query"}, DiscoInfo{})
iqTypeRegistry["http://jabber.org/protocol/disco#items query"] = reflect.TypeOf(DiscoItems{}) typeRegistry.MapExtension(PKTIQ, xml.Name{"http://jabber.org/protocol/disco#items", "query"}, DiscoItems{})
iqTypeRegistry["urn:ietf:params:xml:ns:xmpp-bind bind"] = reflect.TypeOf(BindBind{}) typeRegistry.MapExtension(PKTIQ, xml.Name{"urn:ietf:params:xml:ns:xmpp-bind", "bind"}, BindBind{})
iqTypeRegistry["urn:xmpp:iot:control set"] = reflect.TypeOf(iot.ControlSet{}) typeRegistry.MapExtension(PKTIQ, xml.Name{"urn:xmpp:iot:control", "set"}, ControlSet{})
} }

View file

@ -86,8 +86,7 @@ func (msg *Message) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
switch tt := t.(type) { switch tt := t.(type) {
case xml.StartElement: case xml.StartElement:
elementType := tt.Name.Space if msgExt := typeRegistry.GetMsgExtension(tt.Name); msgExt != nil {
if msgExt := typeRegistry.getmsgType(elementType); msgExt != nil {
// Decode message extension // Decode message extension
err = d.DecodeElement(msgExt, &tt) err = d.DecodeElement(msgExt, &tt)
if err != nil { if err != nil {

View file

@ -47,37 +47,3 @@ func TestDecodeError(t *testing.T) {
t.Errorf("incorrect error type: %s", parsedMessage.Error.Type) t.Errorf("incorrect error type: %s", parsedMessage.Error.Type)
} }
} }
func TestDecodeXEP0184(t *testing.T) {
str := `<message
from='northumberland@shakespeare.lit/westminster'
id='richard2-4.1.247'
to='kingrichard@royalty.england.lit/throne'>
<body>My lord, dispatch; read o'er these articles.</body>
<request xmlns='urn:xmpp:receipts'/>
</message>`
parsedMessage := xmpp.Message{}
if err := xml.Unmarshal([]byte(str), &parsedMessage); err != nil {
t.Errorf("message receipt unmarshall error: %v", err)
return
}
if parsedMessage.Body != "My lord, dispatch; read o'er these articles." {
t.Errorf("Unexpected body: '%s'", parsedMessage.Body)
}
if len(parsedMessage.Extensions) < 1 {
t.Errorf("no extension found on parsed message")
return
}
switch ext := parsedMessage.Extensions[0].(type) {
case *xmpp.Receipt:
if ext.XMLName.Local != "request" {
t.Errorf("unexpected extension: %s:%s", ext.XMLName.Space, ext.XMLName.Local)
}
default:
t.Errorf("could not find receipt extension")
}
}

View file

@ -7,17 +7,19 @@ Support for:
- XEP-0184 - Message Delivery Receipts: https://xmpp.org/extensions/xep-0184.html - XEP-0184 - Message Delivery Receipts: https://xmpp.org/extensions/xep-0184.html
*/ */
const ( // Used on outgoing message, to tell the recipient that you are requesting a message receipt / ack.
NSReceipts = "urn:xmpp:receipts" type ReceiptRequest struct {
)
// XEP-0184 message receipt markers
type Receipt struct {
MsgExtension MsgExtension
XMLName xml.Name XMLName xml.Name `xml:"urn:xmpp:receipts request"`
Id string }
type ReceiptReceived struct {
MsgExtension
XMLName xml.Name `xml:"urn:xmpp:receipts received"`
ID string
} }
func init() { func init() {
typeRegistry.RegisterMsgExt(NSReceipts, Receipt{}) typeRegistry.MapExtension(PKTMessage, xml.Name{"urn:xmpp:receipts", "request"}, ReceiptRequest{})
typeRegistry.MapExtension(PKTMessage, xml.Name{"urn:xmpp:receipts", "received"}, ReceiptReceived{})
} }

42
msg_receipts_test.go Normal file
View file

@ -0,0 +1,42 @@
package xmpp_test
import (
"encoding/xml"
"testing"
"gosrc.io/xmpp"
)
func TestDecodeRequest(t *testing.T) {
str := `<message
from='northumberland@shakespeare.lit/westminster'
id='richard2-4.1.247'
to='kingrichard@royalty.england.lit/throne'>
<body>My lord, dispatch; read o'er these articles.</body>
<request xmlns='urn:xmpp:receipts'/>
</message>`
parsedMessage := xmpp.Message{}
if err := xml.Unmarshal([]byte(str), &parsedMessage); err != nil {
t.Errorf("message receipt unmarshall error: %v", err)
return
}
if parsedMessage.Body != "My lord, dispatch; read o'er these articles." {
t.Errorf("Unexpected body: '%s'", parsedMessage.Body)
}
if len(parsedMessage.Extensions) < 1 {
t.Errorf("no extension found on parsed message")
return
}
switch ext := parsedMessage.Extensions[0].(type) {
case *xmpp.ReceiptRequest:
if ext.XMLName.Local != "request" {
t.Errorf("unexpected extension: %s:%s", ext.XMLName.Space, ext.XMLName.Local)
}
default:
t.Errorf("could not find receipts extension")
}
}

View file

@ -1,29 +1,21 @@
package pep // import "gosrc.io/xmpp/pep" package xmpp // import "gosrc.io/xmpp/pep"
import ( import (
"encoding/xml" "encoding/xml"
"gosrc.io/xmpp"
) )
type iq struct { type PubSub struct {
XMLName xml.Name `xml:"jabber:client iq"`
C pubSub // c for "contains"
xmpp.PacketAttrs // Rename h for "header" ?
}
type pubSub struct {
XMLName xml.Name `xml:"http://jabber.org/protocol/pubsub pubsub"` XMLName xml.Name `xml:"http://jabber.org/protocol/pubsub pubsub"`
Publish publish Publish Publish
} }
type publish struct { type Publish struct {
XMLName xml.Name `xml:"publish"` XMLName xml.Name `xml:"publish"`
Node string `xml:"node,attr"` Node string `xml:"node,attr"`
Item item Item Item
} }
type item struct { type Item struct {
XMLName xml.Name `xml:"item"` XMLName xml.Name `xml:"item"`
Tune Tune Tune Tune
} }
@ -67,11 +59,6 @@ type Tune struct {
} }
*/ */
func (t *Tune) XMPPFormat() (s string) {
packet, _ := xml.Marshal(iq{PacketAttrs: xmpp.PacketAttrs{Id: "tunes", Type: "set"}, C: pubSub{Publish: publish{Node: "http://jabber.org/protocol/tune", Item: item{Tune: *t}}}})
return string(packet)
}
/* /*
func (*Tune) XMPPFormat() string { func (*Tune) XMPPFormat() string {
return fmt.Sprintf( return fmt.Sprintf(

View file

@ -1,6 +1,7 @@
package xmpp package xmpp
import ( import (
"encoding/xml"
"reflect" "reflect"
"sync" "sync"
) )
@ -11,43 +12,76 @@ type MsgExtension interface{}
// TODO: Move to the client init process to remove the dependency on a global variable. // TODO: Move to the client init process to remove the dependency on a global variable.
// That should make it possible to be able to share the decoder. // That should make it possible to be able to share the decoder.
// TODO: Ensure that a client can add its own custom namespace to the registry (or overload existing ones). // TODO: Ensure that a client can add its own custom namespace to the registry (or overload existing ones).
type packetType uint8
const (
PKTPresence packetType = iota
PKTMessage
PKTIQ
)
var typeRegistry = newRegistry() var typeRegistry = newRegistry()
type namespace = string // We store different registries per packet type and namespace.
type registryKey struct {
type registry struct { packetType packetType
// Key is namespace of message extension namespace string
msgTypes map[namespace]reflect.Type
msgTypesLock *sync.RWMutex
iqTypes map[namespace]reflect.Type
} }
func newRegistry() registry { type registryForNamespace map[string]reflect.Type
return registry{
msgTypes: make(map[namespace]reflect.Type), type registry struct {
// We store different registries per packet type and namespace.
msgTypes map[registryKey]registryForNamespace
// Handle concurrent access
msgTypesLock *sync.RWMutex
}
func newRegistry() *registry {
return &registry{
msgTypes: make(map[registryKey]registryForNamespace),
msgTypesLock: &sync.RWMutex{}, msgTypesLock: &sync.RWMutex{},
iqTypes: make(map[namespace]reflect.Type),
} }
} }
// Mutexes are not needed when adding a Message or IQ extension in init function. // MapExtension stores extension type for packet payload.
// However, forcing the use of the mutex protect the data structure against unexpected use // The match is done per packetType (iq, message, or presence) and XML tag name.
// of the registry by developers using the library. // You can use the alias "*" as local XML name to be able to match all unknown tag name for that
func (r registry) RegisterMsgExt(namespace string, extension MsgExtension) { // packet type and namespace.
func (r *registry) MapExtension(pktType packetType, name xml.Name, extension MsgExtension) {
key := registryKey{pktType, name.Space}
r.msgTypesLock.RLock()
store := r.msgTypes[key]
r.msgTypesLock.RUnlock()
r.msgTypesLock.Lock() r.msgTypesLock.Lock()
defer r.msgTypesLock.Unlock() defer r.msgTypesLock.Unlock()
r.msgTypes[namespace] = reflect.TypeOf(extension) if store == nil {
store = make(map[string]reflect.Type)
}
store[name.Local] = reflect.TypeOf(extension)
r.msgTypes[key] = store
} }
func (r registry) getMsgExtType(namespace string) reflect.Type { // GetExtensionType returns extension type for packet payload, based on packet type and tag name.
func (r *registry) GetExtensionType(pktType packetType, name xml.Name) reflect.Type {
key := registryKey{pktType, name.Space}
r.msgTypesLock.RLock() r.msgTypesLock.RLock()
defer r.msgTypesLock.RUnlock() defer r.msgTypesLock.RUnlock()
return r.msgTypes[namespace] store := r.msgTypes[key]
result := store[name.Local]
if result == nil && name.Local != "*" {
return store["*"]
}
return result
} }
func (r registry) getmsgType(namespace string) MsgExtension { // GetMsgExtension returns an instance of MsgExtension, by matching packet type and XML
if extensionType := r.getMsgExtType(namespace); extensionType != nil { // tag name against the registry.
func (r *registry) GetMsgExtension(name xml.Name) MsgExtension {
if extensionType := r.GetExtensionType(PKTMessage, name); extensionType != nil {
val := reflect.New(extensionType) val := reflect.New(extensionType)
elt := val.Interface() elt := val.Interface()
if msgExt, ok := elt.(MsgExtension); ok { if msgExt, ok := elt.(MsgExtension); ok {
@ -57,9 +91,15 @@ func (r registry) getmsgType(namespace string) MsgExtension {
return nil return nil
} }
// Registry to support message extensions // GetIQExtension returns an instance of IQPayload, by matching packet type and XML
//var msgTypeRegistry = make(map[string]reflect.Type) // tag name against the registry.
func (r *registry) GetIQExtension(name xml.Name) IQPayload {
// Registry to instantiate the right IQ payload element if extensionType := r.GetExtensionType(PKTIQ, name); extensionType != nil {
// Key is namespace and key of the payload val := reflect.New(extensionType)
var iqTypeRegistry = make(map[string]reflect.Type) elt := val.Interface()
if iqExt, ok := elt.(IQPayload); ok {
return iqExt
}
}
return nil
}

47
registry_test.go Normal file
View file

@ -0,0 +1,47 @@
package xmpp // import "gosrc.io/xmpp"
import (
"encoding/xml"
"reflect"
"testing"
)
func TestRegistry_RegisterMsgExt(t *testing.T) {
// Setup registry
typeRegistry := newRegistry()
// Register an element
name := xml.Name{Space: "urn:xmpp:receipts", Local: "request"}
typeRegistry.MapExtension(PKTMessage, name, ReceiptRequest{})
// Match that element
receipt := typeRegistry.GetMsgExtension(name)
if receipt == nil {
t.Error("cannot read element type from registry")
return
}
switch r := receipt.(type) {
case *ReceiptRequest:
default:
t.Errorf("Registry did not return expected type ReceiptRequest: %v", reflect.TypeOf(r))
}
}
func BenchmarkRegistryGet(b *testing.B) {
// Setup registry
typeRegistry := newRegistry()
// Register an element
name := xml.Name{Space: "urn:xmpp:receipts", Local: "request"}
typeRegistry.MapExtension(PKTMessage, name, ReceiptRequest{})
for i := 0; i < b.N; i++ {
// Match that element
receipt := typeRegistry.GetExtensionType(PKTMessage, name)
if receipt == nil {
b.Error("cannot read element type from registry")
return
}
}
}