Tests for Component and code style fixes (#129)

* Tests for Component and code style fixes
160-regression
remicorniere 4 years ago committed by Jérôme Sautret
parent 7d89353156
commit 1822089db6

@ -58,7 +58,7 @@ func handleMessage(_ xmpp.Sender, p stanza.Packet) {
func discoInfo(c xmpp.Sender, p stanza.Packet, opts xmpp.ComponentOptions) { func discoInfo(c xmpp.Sender, p stanza.Packet, opts xmpp.ComponentOptions) {
// Type conversion & sanity checks // Type conversion & sanity checks
iq, ok := p.(stanza.IQ) iq, ok := p.(stanza.IQ)
if !ok || iq.Type != "get" { if !ok || iq.Type != stanza.IQTypeGet {
return return
} }
@ -73,7 +73,7 @@ func discoInfo(c xmpp.Sender, p stanza.Packet, opts xmpp.ComponentOptions) {
func discoItems(c xmpp.Sender, p stanza.Packet) { func discoItems(c xmpp.Sender, p stanza.Packet) {
// Type conversion & sanity checks // Type conversion & sanity checks
iq, ok := p.(stanza.IQ) iq, ok := p.(stanza.IQ)
if !ok || iq.Type != "get" { if !ok || iq.Type != stanza.IQTypeGet {
return return
} }

@ -106,7 +106,7 @@ func handleIQ(s xmpp.Sender, p stanza.Packet, player *mpg123.Player) {
func sendUserTune(s xmpp.Sender, artist string, title string) { func sendUserTune(s xmpp.Sender, artist string, title string) {
tune := stanza.Tune{Artist: artist, Title: title} tune := stanza.Tune{Artist: artist, Title: title}
iq := stanza.NewIQ(stanza.Attrs{Type: "set", Id: "usertune-1", Lang: "en"}) iq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeSet, Id: "usertune-1", Lang: "en"})
payload := stanza.PubSub{Publish: &stanza.Publish{Node: "http://jabber.org/protocol/tune", Item: stanza.Item{Tune: &tune}}} payload := stanza.PubSub{Publish: &stanza.Publish{Node: "http://jabber.org/protocol/tune", Item: stanza.Item{Tune: &tune}}}
iq.Payload = &payload iq.Payload = &payload
_ = s.Send(iq) _ = s.Send(iq)

@ -60,7 +60,10 @@ func authPlain(socket io.ReadWriter, decoder *xml.Decoder, mech string, user str
raw := "\x00" + user + "\x00" + secret raw := "\x00" + user + "\x00" + secret
enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw))) enc := make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
base64.StdEncoding.Encode(enc, []byte(raw)) base64.StdEncoding.Encode(enc, []byte(raw))
fmt.Fprintf(socket, "<auth xmlns='%s' mechanism='%s'>%s</auth>", stanza.NSSASL, mech, enc) _, err := fmt.Fprintf(socket, "<auth xmlns='%s' mechanism='%s'>%s</auth>", stanza.NSSASL, mech, enc)
if err != nil {
return err
}
// Next message should be either success or failure. // Next message should be either success or failure.
val, err := stanza.NextPacket(decoder) val, err := stanza.NextPacket(decoder)

@ -79,7 +79,10 @@ func (c *ServerCheck) Check() error {
} }
if _, ok := f.DoesStartTLS(); ok { if _, ok := f.DoesStartTLS(); ok {
fmt.Fprintf(tcpconn, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>") _, err = fmt.Fprintf(tcpconn, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
if err != nil {
return err
}
var k stanza.TLSProceed var k stanza.TLSProceed
if err = decoder.DecodeElement(&k, nil); err != nil { if err = decoder.DecodeElement(&k, nil); err != nil {

@ -50,7 +50,7 @@ type SMState struct {
// EventHandler is use to pass events about state of the connection to // EventHandler is use to pass events about state of the connection to
// client implementation. // client implementation.
type EventHandler func(Event) type EventHandler func(Event) error
type EventManager struct { type EventManager struct {
// Store current state // Store current state
@ -188,13 +188,16 @@ func (c *Client) Resume(state SMState) error {
go keepalive(c.transport, keepaliveQuit) go keepalive(c.transport, keepaliveQuit)
// Start the receiver go routine // Start the receiver go routine
state = c.Session.SMState state = c.Session.SMState
go c.recv(state, keepaliveQuit) // Leaving this channel here for later. Not used atm. We should return this instead of an error because right
// now the returned error is lost in limbo.
errChan := make(chan error)
go c.recv(state, keepaliveQuit, errChan)
// We're connected and can now receive and send messages. // We're connected and can now receive and send messages.
//fmt.Fprintf(client.conn, "<presence xml:lang='en'><show>%s</show><status>%s</status></presence>", "chat", "Online") //fmt.Fprintf(client.conn, "<presence xml:lang='en'><show>%s</show><status>%s</status></presence>", "chat", "Online")
// TODO: Do we always want to send initial presence automatically ? // TODO: Do we always want to send initial presence automatically ?
// Do we need an option to avoid that or do we rely on client to send the presence itself ? // Do we need an option to avoid that or do we rely on client to send the presence itself ?
fmt.Fprintf(c.transport, "<presence/>") _, err = fmt.Fprintf(c.transport, "<presence/>")
return err return err
} }
@ -235,7 +238,7 @@ func (c *Client) Send(packet stanza.Packet) error {
// result := <- client.SendIQ(ctx, iq) // result := <- client.SendIQ(ctx, iq)
// //
func (c *Client) SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) { func (c *Client) SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) {
if iq.Attrs.Type != "set" && iq.Attrs.Type != "get" { if iq.Attrs.Type != stanza.IQTypeSet && iq.Attrs.Type != stanza.IQTypeGet {
return nil, ErrCanOnlySendGetOrSetIq return nil, ErrCanOnlySendGetOrSetIq
} }
if err := c.Send(iq); err != nil { if err := c.Send(iq); err != nil {
@ -267,13 +270,14 @@ func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error {
// Go routines // Go routines
// Loop: Receive data from server // Loop: Receive data from server
func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error) { func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}, errChan chan<- error) {
for { for {
val, err := stanza.NextPacket(c.transport.GetDecoder()) val, err := stanza.NextPacket(c.transport.GetDecoder())
if err != nil { if err != nil {
errChan <- err
close(keepaliveQuit) close(keepaliveQuit)
c.disconnected(state) c.disconnected(state)
return err return
} }
// Handle stream errors // Handle stream errors
@ -282,18 +286,22 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error)
c.router.route(c, val) c.router.route(c, val)
close(keepaliveQuit) close(keepaliveQuit)
c.streamError(packet.Error.Local, packet.Text) c.streamError(packet.Error.Local, packet.Text)
return errors.New("stream error: " + packet.Error.Local) errChan <- errors.New("stream error: " + packet.Error.Local)
return
// Process Stream management nonzas // Process Stream management nonzas
case stanza.SMRequest: case stanza.SMRequest:
answer := stanza.SMAnswer{XMLName: xml.Name{ answer := stanza.SMAnswer{XMLName: xml.Name{
Space: stanza.NSStreamManagement, Space: stanza.NSStreamManagement,
Local: "a", Local: "a",
}, H: state.Inbound} }, H: state.Inbound}
c.Send(answer) err = c.Send(answer)
if err != nil {
errChan <- err
return
}
default: default:
state.Inbound++ state.Inbound++
} }
// Do normal route processing in a go-routine so we can immediately // Do normal route processing in a go-routine so we can immediately
// start receiving other stanzas. This also allows route handlers to // start receiving other stanzas. This also allows route handlers to
// send and receive more stanzas. // send and receive more stanzas.

@ -72,11 +72,13 @@ func (c *Component) Resume(sm SMState) error {
c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration) c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration)
if err != nil { if err != nil {
c.updateState(StatePermanentError) c.updateState(StatePermanentError)
return NewConnError(err, true) return NewConnError(err, true)
} }
if streamId, err = c.transport.Connect(); err != nil { if streamId, err = c.transport.Connect(); err != nil {
c.updateState(StatePermanentError) c.updateState(StatePermanentError)
return NewConnError(err, true) return NewConnError(err, true)
} }
c.updateState(StateConnected) c.updateState(StateConnected)
@ -84,6 +86,7 @@ func (c *Component) Resume(sm SMState) error {
// Authentication // Authentication
if _, err := fmt.Fprintf(c.transport, "<handshake>%s</handshake>", c.handshake(streamId)); err != nil { if _, err := fmt.Fprintf(c.transport, "<handshake>%s</handshake>", c.handshake(streamId)); err != nil {
c.updateState(StateStreamError) c.updateState(StateStreamError)
return NewConnError(errors.New("cannot send handshake "+err.Error()), false) return NewConnError(errors.New("cannot send handshake "+err.Error()), false)
} }
@ -101,12 +104,16 @@ func (c *Component) Resume(sm SMState) error {
case stanza.Handshake: case stanza.Handshake:
// Start the receiver go routine // Start the receiver go routine
c.updateState(StateSessionEstablished) c.updateState(StateSessionEstablished)
go c.recv() // Leaving this channel here for later. Not used atm. We should return this instead of an error because right
return nil // now the returned error is lost in limbo.
errChan := make(chan error)
go c.recv(errChan) // Sends to errChan
return err // Should be empty at this point
default: default:
c.updateState(StatePermanentError) c.updateState(StatePermanentError)
return NewConnError(errors.New("expecting handshake result, got "+v.Name()), true) return NewConnError(errors.New("expecting handshake result, got "+v.Name()), true)
} }
return err
} }
func (c *Component) Disconnect() { func (c *Component) Disconnect() {
@ -121,20 +128,22 @@ func (c *Component) SetHandler(handler EventHandler) {
} }
// Receiver Go routine receiver // Receiver Go routine receiver
func (c *Component) recv() (err error) { func (c *Component) recv(errChan chan<- error) {
for { for {
val, err := stanza.NextPacket(c.transport.GetDecoder()) val, err := stanza.NextPacket(c.transport.GetDecoder())
if err != nil { if err != nil {
c.updateState(StateDisconnected) c.updateState(StateDisconnected)
return err errChan <- err
return
} }
// Handle stream errors // Handle stream errors
switch p := val.(type) { switch p := val.(type) {
case stanza.StreamError: case stanza.StreamError:
c.router.route(c, val) c.router.route(c, val)
c.streamError(p.Error.Local, p.Text) c.streamError(p.Error.Local, p.Text)
return errors.New("stream error: " + p.Error.Local) errChan <- errors.New("stream error: " + p.Error.Local)
return
} }
c.router.route(c, val) c.router.route(c, val)
} }
@ -168,7 +177,7 @@ func (c *Component) Send(packet stanza.Packet) error {
// result := <- client.SendIQ(ctx, iq) // result := <- client.SendIQ(ctx, iq)
// //
func (c *Component) SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) { func (c *Component) SendIQ(ctx context.Context, iq stanza.IQ) (chan stanza.IQ, error) {
if iq.Attrs.Type != "set" && iq.Attrs.Type != "get" { if iq.Attrs.Type != stanza.IQTypeSet && iq.Attrs.Type != stanza.IQTypeGet {
return nil, ErrCanOnlySendGetOrSetIq return nil, ErrCanOnlySendGetOrSetIq
} }
if err := c.Send(iq); err != nil { if err := c.Send(iq); err != nil {

@ -1,12 +1,34 @@
package xmpp package xmpp
import ( import (
"context"
"encoding/xml"
"errors"
"fmt" "fmt"
"gosrc.io/xmpp/stanza"
"net"
"strings"
"testing" "testing"
"time"
) )
const testComponentDomain = "localhost" // Tests are ran in parallel, so each test creating a server must use a different port so we do not get any
const testComponentPort = "15222" // conflict. Using iota for this should do the trick.
const (
testComponentDomain = "localhost"
defaultServerName = "testServer"
defaultStreamID = "91bd0bba-012f-4d92-bb17-5fc41e6fe545"
defaultComponentName = "Test Component"
// Default port is not standard XMPP port to avoid interfering
// with local running XMPP server
testHandshakePort = iota + 15222
testDecoderPort
testSendIqPort
testSendRawPort
testDisconnectPort
testSManDisconnectPort
)
func TestHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
opts := ComponentOptions{ opts := ComponentOptions{
@ -24,8 +46,73 @@ func TestHandshake(t *testing.T) {
} }
} }
// Tests connection process with a handshake exchange
// Tests multiple session IDs. All connections should generate a unique stream ID
func TestGenerateHandshake(t *testing.T) { func TestGenerateHandshake(t *testing.T) {
// TODO // Using this array with a channel to make a queue of values to test
// These are stream IDs that will be used to test the connection process, mixing them with the "secret" to generate
// some handshake value
var uuidsArray = [5]string{
"cc9b3249-9582-4780-825f-4311b42f9b0e",
"bba8be3c-d98e-4e26-b9bb-9ed34578a503",
"dae72822-80e8-496b-b763-ab685f53a188",
"a45d6c06-de49-4bb0-935b-1a2201b71028",
"7dc6924f-0eca-4237-9898-18654b8d891e",
}
// Channel to pass stream IDs as a queue
var uchan = make(chan string, len(uuidsArray))
// Populate test channel
for _, elt := range uuidsArray {
uchan <- elt
}
// Performs a Component connection with a handshake. It expects to have an ID sent its way through the "uchan"
// channel of this file. Otherwise it will hang for ever.
h := func(t *testing.T, c net.Conn) {
decoder := xml.NewDecoder(c)
checkOpenStreamHandshakeID(t, c, decoder, <-uchan)
readHandshakeComponent(t, decoder)
fmt.Fprintln(c, "<handshake/>") // That's all the server needs to return (see xep-0114)
return
}
// Init mock server
testComponentAddess := fmt.Sprintf("%s:%d", testComponentDomain, testHandshakePort)
mock := ServerMock{}
mock.Start(t, testComponentAddess, h)
// Init component
opts := ComponentOptions{
TransportConfiguration: TransportConfiguration{
Address: testComponentAddess,
Domain: "localhost",
},
Domain: testComponentDomain,
Secret: "mypass",
Name: "Test Component",
Category: "gateway",
Type: "service",
}
router := NewRouter()
c, err := NewComponent(opts, router)
if err != nil {
t.Errorf("%+v", err)
}
c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration)
if err != nil {
t.Errorf("%+v", err)
}
// Try connecting, and storing the resulting streamID in a map.
m := make(map[string]bool)
for _, _ = range uuidsArray {
streamId, _ := c.transport.Connect()
m[c.handshake(streamId)] = true
}
if len(uuidsArray) != len(m) {
t.Errorf("Handshake does not produce a unique id. Expected: %d unique ids, got: %d", len(uuidsArray), len(m))
}
} }
// Test that NewStreamManager can accept a Component. // Test that NewStreamManager can accept a Component.
@ -39,18 +126,318 @@ func TestStreamManager(t *testing.T) {
// The decoder is expected to be built after a valid connection // The decoder is expected to be built after a valid connection
// Based on the xmpp_component example. // Based on the xmpp_component example.
func TestDecoder(t *testing.T) { func TestDecoder(t *testing.T) {
testComponentAddess := fmt.Sprintf("%s:%s", testComponentDomain, testComponentPort) c, _ := mockConnection(t, testDecoderPort, handlerForComponentHandshakeDefaultID)
if c.transport.GetDecoder() == nil {
t.Errorf("Failed to initialize decoder. Decoder is nil.")
}
}
// Tests sending an IQ to the server, and getting the response
func TestSendIq(t *testing.T) {
//Connecting to a mock server, initialized with given port and handler function
c, m := mockConnection(t, testSendIqPort, handlerForComponentIQSend)
ctx, _ := context.WithTimeout(context.Background(), 30*time.Second)
iqReq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "test1@localhost/mremond-mbp", To: defaultServerName, Id: defaultStreamID, Lang: "en"})
disco := iqReq.DiscoInfo()
iqReq.Payload = disco
var res chan stanza.IQ
res, _ = c.SendIQ(ctx, iqReq)
select {
case <-res:
case <-time.After(100 * time.Millisecond):
t.Errorf("Failed to receive response, to sent IQ, from mock server")
}
m.Stop()
}
// Tests sending raw xml to the mock server.
// TODO : check the server response client side ?
// Right now, the server response is not checked and an err is passed in a channel if the test is supposed to err.
// In this test, we use IQs
func TestSendRaw(t *testing.T) {
// Error channel for the handler
errChan := make(chan error)
// Handler for the mock server
h := func(t *testing.T, c net.Conn) {
// Completes the connection by exchanging handshakes
handlerForComponentHandshakeDefaultID(t, c)
receiveRawIq(t, c, errChan)
return
}
type testCase struct {
req string
shouldErr bool
}
testRequests := make(map[string]testCase)
// Sending a correct IQ of type get. Not supposed to err
testRequests["Correct IQ"] = testCase{
req: `<iq type="get" id="91bd0bba-012f-4d92-bb17-5fc41e6fe545" from="test1@localhost/mremond-mbp" to="testServer" lang="en"><query xmlns="http://jabber.org/protocol/disco#info"></query></iq>`,
shouldErr: false,
}
// Sending an IQ with a missing ID. Should err
testRequests["IQ with missing ID"] = testCase{
req: `<iq type="get" from="test1@localhost/mremond-mbp" to="testServer" lang="en"><query xmlns="http://jabber.org/protocol/disco#info"></query></iq>`,
shouldErr: true,
}
// Tests for all the IQs
for name, tcase := range testRequests {
t.Run(name, func(st *testing.T) {
//Connecting to a mock server, initialized with given port and handler function
c, m := mockConnection(t, testSendRawPort, h)
// Sending raw xml from test case
err := c.SendRaw(tcase.req)
if err != nil {
t.Errorf("Error sending Raw string")
}
// Just wait a little so the message has time to arrive
select {
case <-time.After(100 * time.Millisecond):
case err = <-errChan:
if err == nil && tcase.shouldErr {
t.Errorf("Failed to get closing stream err")
}
}
c.transport.Close()
m.Stop()
})
}
}
// Tests the Disconnect method for Components
func TestDisconnect(t *testing.T) {
c, m := mockConnection(t, testDisconnectPort, handlerForComponentHandshakeDefaultID)
err := c.transport.Ping()
if err != nil {
t.Errorf("Could not ping but not disconnected yet")
}
c.Disconnect()
err = c.transport.Ping()
if err == nil {
t.Errorf("Did not disconnect properly")
}
m.Stop()
}
// Tests that a streamManager successfully disconnects when a handshake fails between the component and the server.
func TestStreamManagerDisconnect(t *testing.T) {
// Init mock server
testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, testSManDisconnectPort)
mock := ServerMock{} mock := ServerMock{}
mock.Start(t, testComponentAddess, handlerConnectSuccess) // Handler fails the handshake, which is currently the only option to disconnect completely when using a streamManager
// a failed handshake being a permanent error, except for a "conflict"
mock.Start(t, testComponentAddress, handlerComponentFailedHandshakeDefaultID)
//==================================
// Create Component to connect to it
c := makeBasicComponent(defaultComponentName, testComponentAddress, t)
//========================================
// Connect the new Component to the server
cm := NewStreamManager(c, nil)
errChan := make(chan error)
runSMan := func(errChan chan error) {
errChan <- cm.Run()
}
go runSMan(errChan)
select {
case <-errChan:
case <-time.After(100 * time.Millisecond):
t.Errorf("The component and server seem to still be connected while they should not.")
}
mock.Stop()
}
//=============================================================================
// Basic XMPP Server Mock Handlers.
// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant.
// Used in the mock server as a Handler
func handlerForComponentHandshakeDefaultID(t *testing.T, c net.Conn) {
decoder := xml.NewDecoder(c)
checkOpenStreamHandshakeDefaultID(t, c, decoder)
readHandshakeComponent(t, decoder)
fmt.Fprintln(c, "<handshake/>") // That's all the server needs to return (see xep-0114)
return
}
// Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant.
// This handler is supposed to fail by sending a "message" stanza instead of a <handshake/> stanza to finalize the handshake.
func handlerComponentFailedHandshakeDefaultID(t *testing.T, c net.Conn) {
decoder := xml.NewDecoder(c)
checkOpenStreamHandshakeDefaultID(t, c, decoder)
readHandshakeComponent(t, decoder)
// Send a message, instead of a "<handshake/>" tag, to fail the handshake process dans disconnect the client.
me := stanza.Message{
Attrs: stanza.Attrs{Type: stanza.MessageTypeChat, From: defaultServerName, To: defaultComponentName, Lang: "en"},
Body: "Fail my handshake.",
}
s, _ := xml.Marshal(me)
fmt.Fprintln(c, string(s))
return
}
// Reads from the connection with the Component. Expects a handshake request, and returns the <handshake/> tag.
func readHandshakeComponent(t *testing.T, decoder *xml.Decoder) {
se, err := stanza.NextStart(decoder)
if err != nil {
t.Errorf("cannot read auth: %s", err)
return
}
nv := &stanza.Handshake{}
// Decode element into pointer storage
if err = decoder.DecodeElement(nv, &se); err != nil {
t.Errorf("cannot decode handshake: %s", err)
return
}
if len(strings.TrimSpace(nv.Value)) == 0 {
t.Errorf("did not receive handshake ID")
}
}
func checkOpenStreamHandshakeDefaultID(t *testing.T, c net.Conn, decoder *xml.Decoder) {
checkOpenStreamHandshakeID(t, c, decoder, defaultStreamID)
}
// Used for ID and handshake related tests
func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, streamID string) {
c.SetDeadline(time.Now().Add(defaultTimeout))
defer c.SetDeadline(time.Time{})
for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion.
token, err := decoder.Token()
if err != nil {
t.Errorf("cannot read next token: %s", err)
}
switch elem := token.(type) {
// Wait for first startElement
case xml.StartElement:
if elem.Name.Space != stanza.NSStream || elem.Name.Local != "stream" {
err = errors.New("xmpp: expected <stream> but got <" + elem.Name.Local + "> in " + elem.Name.Space)
return
}
if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil {
t.Errorf("cannot write server stream open: %s", err)
}
return
}
}
}
//=============================================================================
// Sends IQ response to Component request.
// No parsing of the request here. We just check that it's valid, and send the default response.
func handlerForComponentIQSend(t *testing.T, c net.Conn) {
// Completes the connection by exchanging handshakes
handlerForComponentHandshakeDefaultID(t, c)
// Decoder to parse the request
decoder := xml.NewDecoder(c)
iqReq, err := receiveIq(t, c, decoder)
if err != nil {
t.Errorf("Error receiving the IQ stanza : %v", err)
} else if !iqReq.IsValid() {
t.Errorf("server received an IQ stanza : %v", iqReq)
}
// Crafting response
iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"})
disco := iqResp.DiscoInfo()
disco.AddFeatures("vcard-temp",
`http://jabber.org/protocol/address`)
disco.AddIdentity("Multicast", "service", "multicast")
iqResp.Payload = disco
// Sending response to the Component
mResp, err := xml.Marshal(iqResp)
_, err = fmt.Fprintln(c, string(mResp))
if err != nil {
t.Errorf("Could not send response stanza : %s", err)
}
return
}
// Reads next request coming from the Component. Expecting it to be an IQ request
func receiveIq(t *testing.T, c net.Conn, decoder *xml.Decoder) (stanza.IQ, error) {
c.SetDeadline(time.Now().Add(defaultTimeout))
defer c.SetDeadline(time.Time{})
var iqStz stanza.IQ
err := decoder.Decode(&iqStz)
if err != nil {
t.Errorf("cannot read the received IQ stanza: %s", err)
}
if !iqStz.IsValid() {
t.Errorf("received IQ stanza is invalid : %s", err)
}
return iqStz, nil
}
func receiveRawIq(t *testing.T, c net.Conn, errChan chan error) {
c.SetDeadline(time.Now().Add(defaultTimeout))
defer c.SetDeadline(time.Time{})
decoder := xml.NewDecoder(c)
var iq stanza.IQ
err := decoder.Decode(&iq)
if err != nil || !iq.IsValid() {
s := stanza.StreamError{
XMLName: xml.Name{Local: "stream:error"},
Error: xml.Name{Local: "xml-not-well-formed"},
Text: `XML was not well-formed`,
}
raw, _ := xml.Marshal(s)
fmt.Fprintln(c, string(raw))
fmt.Fprintln(c, `</stream:stream>`) // TODO : check this client side
errChan <- fmt.Errorf("invalid xml")
return
}
errChan <- nil
return
}
//===============================
// Init mock server and connection
// Creating a mock server and connecting a Component to it. Initialized with given port and handler function
// The Component and mock are both returned
func mockConnection(t *testing.T, port int, handler func(t *testing.T, c net.Conn)) (*Component, *ServerMock) {
// Init mock server
testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port)
mock := ServerMock{}
mock.Start(t, testComponentAddress, handler)
//==================================
// Create Component to connect to it
c := makeBasicComponent(defaultComponentName, testComponentAddress, t)
//========================================
// Connect the new Component to the server
err := c.Connect()
if err != nil {
t.Errorf("%+v", err)
}
return c, &mock
}
func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component {
opts := ComponentOptions{ opts := ComponentOptions{
TransportConfiguration: TransportConfiguration{ TransportConfiguration: TransportConfiguration{
Address: testComponentAddess, Address: mockServerAddr,
Domain: "localhost", Domain: "localhost",
}, },
Domain: testComponentDomain, Domain: testComponentDomain,
Secret: "mypass", Secret: "mypass",
Name: "Test Component", Name: name,
Category: "gateway", Category: "gateway",
Type: "service", Type: "service",
} }
@ -63,12 +450,5 @@ func TestDecoder(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("%+v", err) t.Errorf("%+v", err)
} }
_, err = c.transport.Connect() return c
if err != nil {
t.Errorf("%+v", err)
}
if c.transport.GetDecoder() == nil {
t.Errorf("Failed to initialize decoder. Decoder is nil.")
}
} }

@ -23,7 +23,7 @@ func ensurePort(addr string, port int) string {
// This is IPV4 without port // This is IPV4 without port
return addr + ":" + strconv.Itoa(port) return addr + ":" + strconv.Itoa(port)
case 1: case 1:
// This is IPV$ with port // This is IPV6 with port
return addr return addr
default: default:
// This is IPV6 without port, as you need to use bracket with port in IPV6 // This is IPV6 without port, as you need to use bracket with port in IPV6

@ -1,12 +1,10 @@
package xmpp package xmpp
import ( import (
"strings"
"testing" "testing"
) )
type params struct {
}
func TestParseAddr(t *testing.T) { func TestParseAddr(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -33,3 +31,36 @@ func TestParseAddr(t *testing.T) {
}) })
} }
} }
func TestEnsurePort(t *testing.T) {
testAddresses := []string{
"1ca3:6c07:ee3a:89ca:e065:9a70:71d:daad",
"1ca3:6c07:ee3a:89ca:e065:9a70:71d:daad:5252",
"[::1]",
"127.0.0.1:5555",
"127.0.0.1",
"[::1]:5555",
}
for _, oldAddr := range testAddresses {
t.Run(oldAddr, func(st *testing.T) {
newAddr := ensurePort(oldAddr, 5222)
if len(newAddr) < len(oldAddr) {
st.Errorf("incorrect Result: transformed address is shorter than input : %v (old) > %v (new)", newAddr, oldAddr)
}
// If IPv6, the new address needs brackets to specify a port, like so : [2001:db8:85a3:0:0:8a2e:370:7334]:5222
if strings.Count(newAddr, "[") < strings.Count(oldAddr, "[") ||
strings.Count(newAddr, "]") < strings.Count(oldAddr, "]") {
st.Errorf("incorrect Result. Transformed address seems to not have correct brakets : %v => %v", oldAddr, newAddr)
}
// Check if we messed up the colons, or didn't properly add a port
if strings.Count(newAddr, ":") < strings.Count(oldAddr, ":") {
st.Errorf("incorrect Result: transformed address doesn't seem to have a port %v (=> %v, no port ?)", oldAddr, newAddr)
}
})
}
}

@ -146,7 +146,7 @@ func TestTypeMatcher(t *testing.T) {
// We do not match on other types // We do not match on other types
conn = NewSenderMock() conn = NewSenderMock()
iqVersion := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"}) iqVersion := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"})
iqVersion.Payload = &stanza.DiscoInfo{ iqVersion.Payload = &stanza.DiscoInfo{
XMLName: xml.Name{ XMLName: xml.Name{
Space: "jabber:iq:version", Space: "jabber:iq:version",
@ -163,27 +163,27 @@ func TestCompositeMatcher(t *testing.T) {
router := NewRouter() router := NewRouter()
router.NewRoute(). router.NewRoute().
IQNamespaces("jabber:iq:version"). IQNamespaces("jabber:iq:version").
StanzaType("get"). StanzaType(string(stanza.IQTypeGet)).
HandlerFunc(func(s Sender, p stanza.Packet) { HandlerFunc(func(s Sender, p stanza.Packet) {
_ = s.SendRaw(successFlag) _ = s.SendRaw(successFlag)
}) })
// Data set // Data set
getVersionIq := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"}) getVersionIq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"})
getVersionIq.Payload = &stanza.Version{ getVersionIq.Payload = &stanza.Version{
XMLName: xml.Name{ XMLName: xml.Name{
Space: "jabber:iq:version", Space: "jabber:iq:version",
Local: "query", Local: "query",
}} }}
setVersionIq := stanza.NewIQ(stanza.Attrs{Type: "set", From: "service.localhost", To: "test@localhost", Id: "1"}) setVersionIq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeSet, From: "service.localhost", To: "test@localhost", Id: "1"})
setVersionIq.Payload = &stanza.Version{ setVersionIq.Payload = &stanza.Version{
XMLName: xml.Name{ XMLName: xml.Name{
Space: "jabber:iq:version", Space: "jabber:iq:version",
Local: "query", Local: "query",
}} }}
GetDiscoIq := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"}) GetDiscoIq := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"})
GetDiscoIq.Payload = &stanza.DiscoInfo{ GetDiscoIq.Payload = &stanza.DiscoInfo{
XMLName: xml.Name{ XMLName: xml.Name{
Space: "http://jabber.org/protocol/disco#info", Space: "http://jabber.org/protocol/disco#info",
@ -238,7 +238,7 @@ func TestCatchallMatcher(t *testing.T) {
} }
conn = NewSenderMock() conn = NewSenderMock()
iqVersion := stanza.NewIQ(stanza.Attrs{Type: "get", From: "service.localhost", To: "test@localhost", Id: "1"}) iqVersion := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeGet, From: "service.localhost", To: "test@localhost", Id: "1"})
iqVersion.Payload = &stanza.DiscoInfo{ iqVersion.Payload = &stanza.DiscoInfo{
XMLName: xml.Name{ XMLName: xml.Name{
Space: "jabber:iq:version", Space: "jabber:iq:version",

@ -12,7 +12,7 @@ import (
type Handshake struct { type Handshake struct {
XMLName xml.Name `xml:"jabber:component:accept handshake"` XMLName xml.Name `xml:"jabber:component:accept handshake"`
// TODO Add handshake value with test for proper serialization // TODO Add handshake value with test for proper serialization
// Value string `xml:",innerxml"` Value string `xml:",innerxml"`
} }
func (Handshake) Name() string { func (Handshake) Name() string {

@ -54,7 +54,7 @@ func (x *Err) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
textName := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: "text"} textName := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: "text"}
if elt.XMLName == textName { if elt.XMLName == textName {
x.Text = string(elt.Content) x.Text = elt.Content
} else if elt.XMLName.Space == "urn:ietf:params:xml:ns:xmpp-stanzas" { } else if elt.XMLName.Space == "urn:ietf:params:xml:ns:xmpp-stanzas" {
x.Reason = elt.XMLName.Local x.Reason = elt.XMLName.Local
} }
@ -94,16 +94,32 @@ func (x Err) MarshalXML(e *xml.Encoder, start xml.StartElement) (err error) {
// Reason // Reason
if x.Reason != "" { if x.Reason != "" {
reason := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: x.Reason} reason := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: x.Reason}
e.EncodeToken(xml.StartElement{Name: reason}) err = e.EncodeToken(xml.StartElement{Name: reason})
e.EncodeToken(xml.EndElement{Name: reason}) if err != nil {
return err
}
err = e.EncodeToken(xml.EndElement{Name: reason})
if err != nil {
return err
}
} }
// Text // Text
if x.Text != "" { if x.Text != "" {
text := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: "text"} text := xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-stanzas", Local: "text"}
e.EncodeToken(xml.StartElement{Name: text}) err = e.EncodeToken(xml.StartElement{Name: text})
e.EncodeToken(xml.CharData(x.Text)) if err != nil {
e.EncodeToken(xml.EndElement{Name: text}) return err
}
err = e.EncodeToken(xml.CharData(x.Text))
if err != nil {
return err
}
err = e.EncodeToken(xml.EndElement{Name: text})
if err != nil {
return err
}
} }
return e.EncodeToken(xml.EndElement{Name: start.Name}) return e.EncodeToken(xml.EndElement{Name: start.Name})

@ -2,6 +2,7 @@ package stanza
import ( import (
"encoding/xml" "encoding/xml"
"strings"
"github.com/google/uuid" "github.com/google/uuid"
) )
@ -23,7 +24,7 @@ type IQ struct { // Info/Query
// child element, which specifies the semantics of the particular // child element, which specifies the semantics of the particular
// request." // request."
Payload IQPayload `xml:",omitempty"` Payload IQPayload `xml:",omitempty"`
Error Err `xml:"error,omitempty"` Error *Err `xml:"error,omitempty"`
// Any is used to decode unknown payload as a generic structure // Any is used to decode unknown payload as a generic structure
Any *Node `xml:",any"` Any *Node `xml:",any"`
} }
@ -52,7 +53,7 @@ func (iq IQ) MakeError(xerror Err) IQ {
iq.Type = "error" iq.Type = "error"
iq.From = to iq.From = to
iq.To = from iq.To = from
iq.Error = xerror iq.Error = &xerror
return iq return iq
} }
@ -106,7 +107,7 @@ func (iq *IQ) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
if err != nil { if err != nil {
return err return err
} }
iq.Error = xmppError iq.Error = &xmppError
continue continue
} }
if iqExt := TypeRegistry.GetIQExtension(tt.Name); iqExt != nil { if iqExt := TypeRegistry.GetIQExtension(tt.Name); iqExt != nil {
@ -132,3 +133,39 @@ func (iq *IQ) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
} }
} }
} }
// Following RFC-3920 for IQs
func (iq *IQ) IsValid() bool {
// ID is required
if len(strings.TrimSpace(iq.Id)) == 0 {
return false
}
// Type is required
if iq.Type.IsEmpty() {
return false
}
// Type get and set must contain one and only one child element that specifies the semantics
if iq.Type == IQTypeGet || iq.Type == IQTypeSet {
if iq.Payload == nil && iq.Any == nil {
return false
}
}
// A result must include zero or one child element
if iq.Type == IQTypeResult {
if iq.Payload != nil && iq.Any != nil {
return false
}
}
//Error type must contain an "error" child element
if iq.Type == IQTypeError {
if iq.Error == nil {
return false
}
}
return true
}

@ -187,3 +187,38 @@ func TestUnknownPayload(t *testing.T) {
t.Errorf("could not extract namespace: '%s'", parsedIQ.Any.XMLName.Space) t.Errorf("could not extract namespace: '%s'", parsedIQ.Any.XMLName.Space)
} }
} }
func TestIsValid(t *testing.T) {
type testCase struct {
iq string
shouldErr bool
}
testIQs := make(map[string]testCase)
testIQs["Valid IQ"] = testCase{
`<iq type="get" to="service.localhost" id="1" >
<query xmlns="unknown:ns"/>
</iq>`,
false,
}
testIQs["Invalid IQ"] = testCase{
`<iq type="get" to="service.localhost">
<query xmlns="unknown:ns"/>
</iq>`,
true,
}
for name, tcase := range testIQs {
t.Run(name, func(st *testing.T) {
parsedIQ := stanza.IQ{}
err := xml.Unmarshal([]byte(tcase.iq), &parsedIQ)
if err != nil {
t.Errorf("Unmarshal error: %#v (%s)", err, tcase.iq)
return
}
if !parsedIQ.IsValid() && !tcase.shouldErr {
t.Errorf("failed iq validation for : %s", tcase.iq)
}
})
}
}

@ -46,9 +46,18 @@ func (n Node) MarshalXML(e *xml.Encoder, start xml.StartElement) (err error) {
start.Name = n.XMLName start.Name = n.XMLName
err = e.EncodeToken(start) err = e.EncodeToken(start)
e.EncodeElement(n.Nodes, xml.StartElement{Name: n.XMLName}) if err != nil {
return err
}
err = e.EncodeElement(n.Nodes, xml.StartElement{Name: n.XMLName})
if err != nil {
return err
}
if n.Content != "" { if n.Content != "" {
e.EncodeToken(xml.CharData(n.Content)) err = e.EncodeToken(xml.CharData(n.Content))
if err != nil {
return err
}
} }
return e.EncodeToken(xml.EndElement{Name: start.Name}) return e.EncodeToken(xml.EndElement{Name: start.Name})
} }

@ -1,5 +1,7 @@
package stanza package stanza
import "strings"
type StanzaType string type StanzaType string
// RFC 6120: part of A.5 Client Namespace and A.6 Server Namespace // RFC 6120: part of A.5 Client Namespace and A.6 Server Namespace
@ -23,3 +25,7 @@ const (
PresenceTypeUnsubscribe StanzaType = "unsubscribe" PresenceTypeUnsubscribe StanzaType = "unsubscribe"
PresenceTypeUnsubscribed StanzaType = "unsubscribed" PresenceTypeUnsubscribed StanzaType = "unsubscribed"
) )
func (s StanzaType) IsEmpty() bool {
return len(strings.TrimSpace(string(s))) == 0
}

@ -107,6 +107,6 @@ func (s *StreamSession) IsOptional() bool {
// Registry init // Registry init
func init() { func init() {
TypeRegistry.MapExtension(PKTIQ, xml.Name{"urn:ietf:params:xml:ns:xmpp-bind", "bind"}, Bind{}) TypeRegistry.MapExtension(PKTIQ, xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-bind", Local: "bind"}, Bind{})
TypeRegistry.MapExtension(PKTIQ, xml.Name{"urn:ietf:params:xml:ns:xmpp-session", "session"}, StreamSession{}) TypeRegistry.MapExtension(PKTIQ, xml.Name{Space: "urn:ietf:params:xml:ns:xmpp-session", Local: "session"}, StreamSession{})
} }

@ -8,7 +8,7 @@ import "encoding/xml"
type Stream struct { type Stream struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams stream"` XMLName xml.Name `xml:"http://etherx.jabber.org/streams stream"`
From string `xml:"from,attr"` From string `xml:"from,attr"`
To string `xml:"to,attr"` To string `xml:"to,attr"`
Id string `xml:"id,attr"` Id string `xml:"id,attr"`
Version string `xml:"version,attr"` Version string `xml:"version,attr"`
} }

@ -15,7 +15,7 @@ type StreamFeatures struct {
// Server capabilities hash // Server capabilities hash
Caps Caps Caps Caps
// Stream features // Stream features
StartTLS tlsStartTLS StartTLS TlsStartTLS
Mechanisms saslMechanisms Mechanisms saslMechanisms
Bind Bind Bind Bind
StreamManagement streamManagement StreamManagement streamManagement
@ -60,13 +60,13 @@ type Caps struct {
// StartTLS feature // StartTLS feature
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4 // Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4
type tlsStartTLS struct { type TlsStartTLS struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"` XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
Required bool Required bool
} }
// UnmarshalXML implements custom parsing startTLS required flag // UnmarshalXML implements custom parsing startTLS required flag
func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { func (stls *TlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
stls.XMLName = start.Name stls.XMLName = start.Name
// Check subelements to extract required field as boolean // Check subelements to extract required field as boolean
@ -98,7 +98,7 @@ func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) er
} }
} }
func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) { func (sf *StreamFeatures) DoesStartTLS() (feature TlsStartTLS, isSupported bool) {
if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" { if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" {
return sf.StartTLS, true return sf.StartTLS, true
} }

@ -74,7 +74,7 @@ func (sm *StreamManager) Run() error {
return errors.New("missing stream client") return errors.New("missing stream client")
} }
handler := func(e Event) { handler := func(e Event) error {
switch e.State { switch e.State {
case StateConnected: case StateConnected:
sm.Metrics.setConnectTime() sm.Metrics.setConnectTime()
@ -82,17 +82,18 @@ func (sm *StreamManager) Run() error {
sm.Metrics.setLoginTime() sm.Metrics.setLoginTime()
case StateDisconnected: case StateDisconnected:
// Reconnect on disconnection // Reconnect on disconnection
sm.resume(e.SMState) return sm.resume(e.SMState)
case StateStreamError: case StateStreamError:
sm.client.Disconnect() sm.client.Disconnect()
// Only try reconnecting if we have not been kicked by another session to avoid connection loop. // Only try reconnecting if we have not been kicked by another session to avoid connection loop.
// TODO: Make this conflict exception a permanent error // TODO: Make this conflict exception a permanent error
if e.StreamError != "conflict" { if e.StreamError != "conflict" {
sm.connect() return sm.connect()
} }
case StatePermanentError: case StatePermanentError:
// Do not attempt to reconnect // Do not attempt to reconnect
} }
return nil
} }
sm.client.SetHandler(handler) sm.client.SetHandler(handler)

@ -5,7 +5,7 @@ export GO111MODULE=on
echo "" > coverage.txt echo "" > coverage.txt
for d in $(go list ./... | grep -v vendor); do for d in $(go list ./... | grep -v vendor); do
go test -race -coverprofile=profile.out -covermode=atomic ${d} go test -race -coverprofile=profile.out -covermode=atomic "${d}"
if [ -f profile.out ]; then if [ -f profile.out ]; then
cat profile.out >> coverage.txt cat profile.out >> coverage.txt
rm profile.out rm profile.out

Loading…
Cancel
Save