diff --git a/client_test.go b/client_test.go index f2b775a..8d109d0 100644 --- a/client_test.go +++ b/client_test.go @@ -5,7 +5,6 @@ import ( "encoding/xml" "errors" "fmt" - "net" "testing" "time" @@ -157,10 +156,10 @@ func TestClient_RFC3921Session(t *testing.T) { func TestClient_SendIQ(t *testing.T) { done := make(chan struct{}) // Handler for Mock server - h := func(t *testing.T, c net.Conn) { - handlerClientConnectSuccess(t, c) - discardPresence(t, c) - respondToIQ(t, c) + h := func(t *testing.T, sc *ServerConn) { + handlerClientConnectSuccess(t, sc) + discardPresence(t, sc) + respondToIQ(t, sc) done <- struct{}{} } client, mock := mockClientConnection(t, h, testClientIqPort) @@ -199,10 +198,10 @@ func TestClient_SendIQ(t *testing.T) { func TestClient_SendIQFail(t *testing.T) { done := make(chan struct{}) // Handler for Mock server - h := func(t *testing.T, c net.Conn) { - handlerClientConnectSuccess(t, c) - discardPresence(t, c) - respondToIQ(t, c) + h := func(t *testing.T, sc *ServerConn) { + handlerClientConnectSuccess(t, sc) + discardPresence(t, sc) + respondToIQ(t, sc) done <- struct{}{} } client, mock := mockClientConnection(t, h, testClientIqFailPort) @@ -244,10 +243,10 @@ func TestClient_SendIQFail(t *testing.T) { func TestClient_SendRaw(t *testing.T) { done := make(chan struct{}) // Handler for Mock server - h := func(t *testing.T, c net.Conn) { - handlerClientConnectSuccess(t, c) - discardPresence(t, c) - respondToIQ(t, c) + h := func(t *testing.T, sc *ServerConn) { + handlerClientConnectSuccess(t, sc) + discardPresence(t, sc) + respondToIQ(t, sc) done <- struct{}{} } type testCase struct { @@ -365,48 +364,44 @@ func TestClient_DisconnectStreamManager(t *testing.T) { // Basic XMPP Server Mock Handlers. // Test connection with a basic straightforward workflow -func handlerClientConnectSuccess(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) - checkClientOpenStream(t, c, decoder) - - sendStreamFeatures(t, c, decoder) // Send initial features - readAuth(t, decoder) - fmt.Fprintln(c, "") - - checkClientOpenStream(t, c, decoder) // Reset stream - sendBindFeature(t, c, decoder) // Send post auth features - bind(t, c, decoder) +func handlerClientConnectSuccess(t *testing.T, sc *ServerConn) { + checkClientOpenStream(t, sc) + sendStreamFeatures(t, sc) // Send initial features + readAuth(t, sc.decoder) + fmt.Fprintln(sc.connection, "") + + checkClientOpenStream(t, sc) // Reset stream + sendBindFeature(t, sc) // Send post auth features + bind(t, sc) } // We expect client will abort on TLS -func handlerAbortTLS(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) - checkClientOpenStream(t, c, decoder) - sendStreamFeatures(t, c, decoder) // Send initial features +func handlerAbortTLS(t *testing.T, sc *ServerConn) { + checkClientOpenStream(t, sc) + sendStreamFeatures(t, sc) // Send initial features } // Test connection with mandatory session (RFC-3921) -func handlerClientConnectWithSession(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) - checkClientOpenStream(t, c, decoder) - - sendStreamFeatures(t, c, decoder) // Send initial features - readAuth(t, decoder) - fmt.Fprintln(c, "") - - checkClientOpenStream(t, c, decoder) // Reset stream - sendRFC3921Feature(t, c, decoder) // Send post auth features - bind(t, c, decoder) - session(t, c, decoder) +func handlerClientConnectWithSession(t *testing.T, sc *ServerConn) { + checkClientOpenStream(t, sc) + + sendStreamFeatures(t, sc) // Send initial features + readAuth(t, sc.decoder) + fmt.Fprintln(sc.connection, "") + + checkClientOpenStream(t, sc) // Reset stream + sendRFC3921Feature(t, sc) // Send post auth features + bind(t, sc) + session(t, sc) } -func checkClientOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) +func checkClientOpenStream(t *testing.T, sc *ServerConn) { + sc.connection.SetDeadline(time.Now().Add(defaultTimeout)) + defer sc.connection.SetDeadline(time.Time{}) for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion. var token xml.Token - token, err := decoder.Token() + token, err := sc.decoder.Token() if err != nil { t.Errorf("cannot read next token: %s", err) } @@ -418,7 +413,7 @@ func checkClientOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { err = errors.New("xmpp: expected but got <" + elem.Name.Local + "> in " + elem.Name.Space) return } - if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", "streamid1", stanza.NSClient, stanza.NSStream); err != nil { + if _, err := fmt.Fprintf(sc.connection, serverStreamOpen, "localhost", "streamid1", stanza.NSClient, stanza.NSStream); err != nil { t.Errorf("cannot write server stream open: %s", err) } return @@ -426,8 +421,8 @@ func checkClientOpenStream(t *testing.T, c net.Conn, decoder *xml.Decoder) { } } -func mockClientConnection(t *testing.T, serverHandler func(*testing.T, net.Conn), port int) (*Client, ServerMock) { - mock := ServerMock{} +func mockClientConnection(t *testing.T, serverHandler func(*testing.T, *ServerConn), port int) (*Client, *ServerMock) { + mock := &ServerMock{} testServerAddress := fmt.Sprintf("%s:%d", testClientDomain, port) mock.Start(t, testServerAddress, serverHandler) diff --git a/component_test.go b/component_test.go index 48963a5..f4d1a07 100644 --- a/component_test.go +++ b/component_test.go @@ -7,7 +7,6 @@ import ( "fmt" "github.com/google/uuid" "gosrc.io/xmpp/stanza" - "net" "strings" "testing" "time" @@ -36,7 +35,7 @@ func TestHandshake(t *testing.T) { } // Tests connection process with a handshake exchange -// Tests multiple session IDs. All connections should generate a unique stream ID +// Tests multiple session IDs. All serverConnections should generate a unique stream ID func TestGenerateHandshakeId(t *testing.T) { // Using this array with a channel to make a queue of values to test // These are stream IDs that will be used to test the connection process, mixing them with the "secret" to generate @@ -56,11 +55,11 @@ func TestGenerateHandshakeId(t *testing.T) { // Performs a Component connection with a handshake. It expects to have an ID sent its way through the "uchan" // channel of this file. Otherwise it will hang for ever. - h := func(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) - checkOpenStreamHandshakeID(t, c, decoder, <-uchan) - readHandshakeComponent(t, decoder) - fmt.Fprintln(c, "") // That's all the server needs to return (see xep-0114) + h := func(t *testing.T, sc *ServerConn) { + + checkOpenStreamHandshakeID(t, sc, <-uchan) + readHandshakeComponent(t, sc.decoder) + fmt.Fprintln(sc.connection, "") // That's all the server needs to return (see xep-0114) return } @@ -122,8 +121,8 @@ func TestDecoder(t *testing.T) { // Tests sending an IQ to the server, and getting the response func TestSendIq(t *testing.T) { done := make(chan struct{}) - h := func(t *testing.T, c net.Conn) { - handlerForComponentIQSend(t, c) + h := func(t *testing.T, sc *ServerConn) { + handlerForComponentIQSend(t, sc) done <- struct{}{} } @@ -164,8 +163,8 @@ func TestSendIq(t *testing.T) { // Checking that error handling is done properly client side when an invalid IQ is sent and the server responds in kind. func TestSendIqFail(t *testing.T) { done := make(chan struct{}) - h := func(t *testing.T, c net.Conn) { - handlerForComponentIQSend(t, c) + h := func(t *testing.T, sc *ServerConn) { + handlerForComponentIQSend(t, sc) done <- struct{}{} } //Connecting to a mock server, initialized with given port and handler function @@ -213,27 +212,30 @@ func TestSendIqFail(t *testing.T) { func TestSendRaw(t *testing.T) { done := make(chan struct{}) // Handler for the mock server - h := func(t *testing.T, c net.Conn) { + h := func(t *testing.T, sc *ServerConn) { // Completes the connection by exchanging handshakes - handlerForComponentHandshakeDefaultID(t, c) - receiveIq(c, xml.NewDecoder(c)) + handlerForComponentHandshakeDefaultID(t, sc) + respondToIQ(t, sc) done <- struct{}{} } type testCase struct { req string shouldErr bool + port int } testRequests := make(map[string]testCase) // Sending a correct IQ of type get. Not supposed to err testRequests["Correct IQ"] = testCase{ req: ``, shouldErr: false, + port: testSendRawPort + 100, } // Sending an IQ with a missing ID. Should err testRequests["IQ with missing ID"] = testCase{ req: ``, shouldErr: true, + port: testSendRawPort + 200, } // A handler for the component. @@ -247,7 +249,7 @@ func TestSendRaw(t *testing.T) { for name, tcase := range testRequests { t.Run(name, func(st *testing.T) { //Connecting to a mock server, initialized with given port and handler function - c, m := mockComponentConnection(t, testSendRawPort, h) + c, m := mockComponentConnection(t, tcase.port, h) c.ErrorHandler = errHandler // Sending raw xml from test case err := c.SendRaw(tcase.req) @@ -328,10 +330,10 @@ func TestStreamManagerDisconnect(t *testing.T) { // Init mock server and connection // Creating a mock server and connecting a Component to it. Initialized with given port and handler function // The Component and mock are both returned -func mockComponentConnection(t *testing.T, port int, handler func(t *testing.T, c net.Conn)) (*Component, *ServerMock) { +func mockComponentConnection(t *testing.T, port int, handler func(t *testing.T, sc *ServerConn)) (*Component, *ServerMock) { // Init mock server testComponentAddress := fmt.Sprintf("%s:%d", testComponentDomain, port) - mock := ServerMock{} + mock := &ServerMock{} mock.Start(t, testComponentAddress, handler) //================================== @@ -345,7 +347,9 @@ func mockComponentConnection(t *testing.T, port int, handler func(t *testing.T, t.Errorf("%+v", err) } - return c, &mock + // Now that the Component is connected, let's set the xml.Decoder for the server + + return c, mock } func makeBasicComponent(name string, mockServerAddr string, t *testing.T) *Component { @@ -380,19 +384,19 @@ func componentDefaultErrorHandler(err error) { // Sends IQ response to Component request. // No parsing of the request here. We just check that it's valid, and send the default response. -func handlerForComponentIQSend(t *testing.T, c net.Conn) { +func handlerForComponentIQSend(t *testing.T, sc *ServerConn) { // Completes the connection by exchanging handshakes - handlerForComponentHandshakeDefaultID(t, c) - respondToIQ(t, c) + handlerForComponentHandshakeDefaultID(t, sc) + respondToIQ(t, sc) } // Used for ID and handshake related tests -func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, streamID string) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) +func checkOpenStreamHandshakeID(t *testing.T, sc *ServerConn, streamID string) { + sc.connection.SetDeadline(time.Now().Add(defaultTimeout)) + defer sc.connection.SetDeadline(time.Time{}) for { // TODO clean up. That for loop is not elegant and I prefer bounded recursion. - token, err := decoder.Token() + token, err := sc.decoder.Token() if err != nil { t.Errorf("cannot read next token: %s", err) } @@ -404,7 +408,7 @@ func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, err = errors.New("xmpp: expected but got <" + elem.Name.Local + "> in " + elem.Name.Space) return } - if _, err := fmt.Fprintf(c, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil { + if _, err := fmt.Fprintf(sc.connection, serverStreamOpen, "localhost", streamID, stanza.NSComponent, stanza.NSStream); err != nil { t.Errorf("cannot write server stream open: %s", err) } return @@ -412,16 +416,15 @@ func checkOpenStreamHandshakeID(t *testing.T, c net.Conn, decoder *xml.Decoder, } } -func checkOpenStreamHandshakeDefaultID(t *testing.T, c net.Conn, decoder *xml.Decoder) { - checkOpenStreamHandshakeID(t, c, decoder, defaultStreamID) +func checkOpenStreamHandshakeDefaultID(t *testing.T, sc *ServerConn) { + checkOpenStreamHandshakeID(t, sc, defaultStreamID) } // Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. // This handler is supposed to fail by sending a "message" stanza instead of a stanza to finalize the handshake. -func handlerComponentFailedHandshakeDefaultID(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) - checkOpenStreamHandshakeDefaultID(t, c, decoder) - readHandshakeComponent(t, decoder) +func handlerComponentFailedHandshakeDefaultID(t *testing.T, sc *ServerConn) { + checkOpenStreamHandshakeDefaultID(t, sc) + readHandshakeComponent(t, sc.decoder) // Send a message, instead of a "" tag, to fail the handshake process dans disconnect the client. me := stanza.Message{ @@ -429,7 +432,7 @@ func handlerComponentFailedHandshakeDefaultID(t *testing.T, c net.Conn) { Body: "Fail my handshake.", } s, _ := xml.Marshal(me) - fmt.Fprintln(c, string(s)) + fmt.Fprintln(sc.connection, string(s)) return } @@ -454,10 +457,9 @@ func readHandshakeComponent(t *testing.T, decoder *xml.Decoder) { // Performs a Component connection with a handshake. It uses a default ID defined in this file as a constant. // Used in the mock server as a Handler -func handlerForComponentHandshakeDefaultID(t *testing.T, c net.Conn) { - decoder := xml.NewDecoder(c) - checkOpenStreamHandshakeDefaultID(t, c, decoder) - readHandshakeComponent(t, decoder) - fmt.Fprintln(c, "") // That's all the server needs to return (see xep-0114) +func handlerForComponentHandshakeDefaultID(t *testing.T, sc *ServerConn) { + checkOpenStreamHandshakeDefaultID(t, sc) + readHandshakeComponent(t, sc.decoder) + fmt.Fprintln(sc.connection, "") // That's all the server needs to return (see xep-0114) return } diff --git a/doc.go b/doc.go index 40f4f6a..f29bbf6 100644 --- a/doc.go +++ b/doc.go @@ -29,7 +29,7 @@ Components XMPP components can typically be used to extends the features of an XMPP server, in a portable way, using component protocol over persistent TCP -connections. +serverConnections. Component protocol is defined in XEP-114 (https://xmpp.org/extensions/xep-0114.html). diff --git a/session.go b/session.go index 22d76b2..6b9c75a 100644 --- a/session.go +++ b/session.go @@ -119,7 +119,7 @@ func (s *Session) startTlsIfSupported(o Config) { return } - // If we do not allow cleartext connections, make it explicit that server do not support starttls + // If we do not allow cleartext serverConnections, make it explicit that server do not support starttls if !o.Insecure { s.err = errors.New("XMPP server does not advertise support for starttls") } diff --git a/tcp_server_mock.go b/tcp_server_mock.go index 1084cbd..c8f5d97 100644 --- a/tcp_server_mock.go +++ b/tcp_server_mock.go @@ -41,16 +41,21 @@ const ( // ClientHandler is passed by the test client to provide custom behaviour to // the TCP server mock. This allows customizing the server behaviour to allow // testing clients under various scenarii. -type ClientHandler func(t *testing.T, conn net.Conn) +type ClientHandler func(t *testing.T, serverConn *ServerConn) // ServerMock is a simple TCP server that can be use to mock basic server // behaviour to test clients. type ServerMock struct { - t *testing.T - handler ClientHandler - listener net.Listener - connections []net.Conn - done chan struct{} + t *testing.T + handler ClientHandler + listener net.Listener + serverConnections []*ServerConn + done chan struct{} +} + +type ServerConn struct { + connection net.Conn + decoder *xml.Decoder } // Start launches the mock TCP server, listening to an actual address / port. @@ -68,9 +73,9 @@ func (mock *ServerMock) Stop() { if mock.listener != nil { mock.listener.Close() } - // Close all existing connections - for _, c := range mock.connections { - c.Close() + // Close all existing serverConnections + for _, c := range mock.serverConnections { + c.connection.Close() } } @@ -90,13 +95,14 @@ func (mock *ServerMock) init(addr string) error { return nil } -// loop accepts connections and creates a go routine per connection. +// loop accepts serverConnections and creates a go routine per connection. // The go routine is running the client handler, that is used to provide the // real TCP server behaviour. func (mock *ServerMock) loop() { listener := mock.listener for { conn, err := listener.Accept() + serverConn := &ServerConn{conn, xml.NewDecoder(conn)} if err != nil { select { case <-mock.done: @@ -106,9 +112,10 @@ func (mock *ServerMock) loop() { } return } - mock.connections = append(mock.connections, conn) + mock.serverConnections = append(mock.serverConnections, serverConn) + // TODO Create and pass a context to cancel the handler if they are still around = avoid possible leak on complex handlers - go mock.handler(mock.t, conn) + go mock.handler(mock.t, serverConn) } } @@ -116,27 +123,20 @@ func (mock *ServerMock) loop() { // A few functions commonly used for tests. Trying to avoid duplicates in client and component test files. //====================================================================================================================== -func respondToIQ(t *testing.T, c net.Conn) { - recvBuf := make([]byte, 1024) - var iqR stanza.IQ - _, err := c.Read(recvBuf[:]) // recv data - +func respondToIQ(t *testing.T, sc *ServerConn) { + // Decoder to parse the request + iqReq, err := receiveIq(sc) if err != nil { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - t.Errorf("read timeout: %s", err) - } else { - t.Errorf("read error: %s", err) - } + t.Fatalf("failed to receive IQ : %s", err.Error()) } - xml.Unmarshal(recvBuf, &iqR) - if !iqR.IsValid() { - mockIQError(c) + if !iqReq.IsValid() { + mockIQError(sc.connection) return } // Crafting response - iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqR.To, To: iqR.From, Id: iqR.Id, Lang: "en"}) + iqResp := stanza.NewIQ(stanza.Attrs{Type: stanza.IQTypeResult, From: iqReq.To, To: iqReq.From, Id: iqReq.Id, Lang: "en"}) disco := iqResp.DiscoInfo() disco.AddFeatures("vcard-temp", `http://jabber.org/protocol/address`) @@ -146,7 +146,7 @@ func respondToIQ(t *testing.T, c net.Conn) { // Sending response to the Component mResp, err := xml.Marshal(iqResp) - _, err = fmt.Fprintln(c, string(mResp)) + _, err = fmt.Fprintln(sc.connection, string(mResp)) if err != nil { t.Errorf("Could not send response stanza : %s", err) } @@ -155,13 +155,13 @@ func respondToIQ(t *testing.T, c net.Conn) { // When a presence stanza is automatically sent (right now it's the case in the client), we may want to discard it // and test further stanzas. -func discardPresence(t *testing.T, c net.Conn) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) +func discardPresence(t *testing.T, sc *ServerConn) { + sc.connection.SetDeadline(time.Now().Add(defaultTimeout)) + defer sc.connection.SetDeadline(time.Time{}) var presenceStz stanza.Presence recvBuf := make([]byte, len(InitialPresence)) - _, err := c.Read(recvBuf[:]) // recv data + _, err := sc.connection.Read(recvBuf[:]) // recv data if err != nil { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { @@ -178,11 +178,11 @@ func discardPresence(t *testing.T, c net.Conn) { } // Reads next request coming from the Component. Expecting it to be an IQ request -func receiveIq(c net.Conn, decoder *xml.Decoder) (*stanza.IQ, error) { - c.SetDeadline(time.Now().Add(defaultTimeout)) - defer c.SetDeadline(time.Time{}) +func receiveIq(sc *ServerConn) (*stanza.IQ, error) { + sc.connection.SetDeadline(time.Now().Add(defaultTimeout)) + defer sc.connection.SetDeadline(time.Time{}) var iqStz stanza.IQ - err := decoder.Decode(&iqStz) + err := sc.decoder.Decode(&iqStz) if err != nil { return nil, err } @@ -202,14 +202,14 @@ func mockIQError(c net.Conn) { fmt.Fprintln(c, ``) } -func sendStreamFeatures(t *testing.T, c net.Conn, _ *xml.Decoder) { +func sendStreamFeatures(t *testing.T, sc *ServerConn) { // This is a basic server, supporting only 1 stream feature: SASL Plain Auth features := ` PLAIN ` - if _, err := fmt.Fprintln(c, features); err != nil { + if _, err := fmt.Fprintln(sc.connection, features); err != nil { t.Errorf("cannot send stream feature: %s", err) } } @@ -237,29 +237,29 @@ func readAuth(t *testing.T, decoder *xml.Decoder) string { return "" } -func sendBindFeature(t *testing.T, c net.Conn, _ *xml.Decoder) { +func sendBindFeature(t *testing.T, sc *ServerConn) { // This is a basic server, supporting only 1 stream feature after auth: resource binding features := ` ` - if _, err := fmt.Fprintln(c, features); err != nil { + if _, err := fmt.Fprintln(sc.connection, features); err != nil { t.Errorf("cannot send stream feature: %s", err) } } -func sendRFC3921Feature(t *testing.T, c net.Conn, _ *xml.Decoder) { +func sendRFC3921Feature(t *testing.T, sc *ServerConn) { // This is a basic server, supporting only 2 features after auth: resource & session binding features := ` ` - if _, err := fmt.Fprintln(c, features); err != nil { + if _, err := fmt.Fprintln(sc.connection, features); err != nil { t.Errorf("cannot send stream feature: %s", err) } } -func bind(t *testing.T, c net.Conn, decoder *xml.Decoder) { - se, err := stanza.NextStart(decoder) +func bind(t *testing.T, sc *ServerConn) { + se, err := stanza.NextStart(sc.decoder) if err != nil { t.Errorf("cannot read bind: %s", err) return @@ -267,7 +267,7 @@ func bind(t *testing.T, c net.Conn, decoder *xml.Decoder) { iq := &stanza.IQ{} // Decode element into pointer storage - if err = decoder.DecodeElement(&iq, &se); err != nil { + if err = sc.decoder.DecodeElement(&iq, &se); err != nil { t.Errorf("cannot decode bind iq: %s", err) return } @@ -280,12 +280,12 @@ func bind(t *testing.T, c net.Conn, decoder *xml.Decoder) { %s ` - fmt.Fprintf(c, result, iq.Id, "test@localhost/test") // TODO use real JID + fmt.Fprintf(sc.connection, result, iq.Id, "test@localhost/test") // TODO use real JID } } -func session(t *testing.T, c net.Conn, decoder *xml.Decoder) { - se, err := stanza.NextStart(decoder) +func session(t *testing.T, sc *ServerConn) { + se, err := stanza.NextStart(sc.decoder) if err != nil { t.Errorf("cannot read session: %s", err) return @@ -293,7 +293,7 @@ func session(t *testing.T, c net.Conn, decoder *xml.Decoder) { iq := &stanza.IQ{} // Decode element into pointer storage - if err = decoder.DecodeElement(&iq, &se); err != nil { + if err = sc.decoder.DecodeElement(&iq, &se); err != nil { t.Errorf("cannot decode session iq: %s", err) return } @@ -301,6 +301,6 @@ func session(t *testing.T, c net.Conn, decoder *xml.Decoder) { switch iq.Payload.(type) { case *stanza.StreamSession: result := `` - fmt.Fprintf(c, result, iq.Id) + fmt.Fprintf(sc.connection, result, iq.Id) } }