Transports need to handle open/close stanzas

XMPP and WebSocket transports require different open and close stanzas. To
handle this the responsibility handling those and creating the XML decoder is
moved to the Transport.
disco_info_form
Wichert Akkerman 5 years ago committed by Mickaël Rémond
parent 25fd476328
commit 92329b48e6

@ -13,6 +13,7 @@ func main() {
opts := xmpp.ComponentOptions{
TransportConfiguration: xmpp.TransportConfiguration{
Address: "localhost:9999",
Domain: "service.localhost",
},
Domain: "service.localhost",
Secret: "mypass",

@ -12,6 +12,7 @@ func main() {
opts := xmpp.ComponentOptions{
TransportConfiguration: xmpp.TransportConfiguration{
Address: "localhost:8888",
Domain: "service2.localhost",
},
Domain: "service2.localhost",
Secret: "mypass",

@ -16,7 +16,8 @@ import (
func main() {
config := xmpp.Config{
TransportConfiguration: xmpp.TransportConfiguration{
Address: "localhost:5222",
// Address: "localhost:5222",
Address: "ws://127.0.0.1:5280/xmpp",
},
Jid: "test@localhost",
Credential: xmpp.Password("test"),

@ -141,8 +141,15 @@ func NewClient(config Config, r *Router) (c *Client, err error) {
c.config.ConnectTimeout = 15 // 15 second as default
}
if config.TransportConfiguration.Domain == "" {
config.TransportConfiguration.Domain = config.parsedJid.Domain
}
c.transport = NewTransport(config.TransportConfiguration)
if config.StreamLogger != nil {
c.transport.LogTraffic(config.StreamLogger)
}
return
}
@ -158,7 +165,7 @@ func (c *Client) Connect() error {
func (c *Client) Resume(state SMState) error {
var err error
err = c.transport.Connect()
streamId, err := c.transport.Connect()
if err != nil {
return err
}
@ -168,6 +175,7 @@ func (c *Client) Resume(state SMState) error {
if c.Session, err = NewSession(c.transport, c.config, state); err != nil {
return err
}
c.Session.StreamId = streamId
c.updateState(StateSessionEstablished)
// Start the keepalive go routine
@ -181,13 +189,12 @@ func (c *Client) Resume(state SMState) error {
//fmt.Fprintf(client.conn, "<presence xml:lang='en'><show>%s</show><status>%s</status></presence>", "chat", "Online")
// TODO: Do we always want to send initial presence automatically ?
// Do we need an option to avoid that or do we rely on client to send the presence itself ?
fmt.Fprintf(c.Session.streamLogger, "<presence/>")
fmt.Fprintf(c.transport, "<presence/>")
return err
}
func (c *Client) Disconnect() {
_ = c.SendRaw("</stream:stream>")
// TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect
if c.transport != nil {
_ = c.transport.Close()
@ -210,7 +217,7 @@ func (c *Client) Send(packet stanza.Packet) error {
return errors.New("cannot marshal packet " + err.Error())
}
return c.sendWithWriter(c.Session.streamLogger, data)
return c.sendWithWriter(c.transport, data)
}
// SendRaw sends an XMPP stanza as a string to the server.
@ -223,7 +230,7 @@ func (c *Client) SendRaw(packet string) error {
return errors.New("client is not connected")
}
return c.sendWithWriter(c.Session.streamLogger, []byte(packet))
return c.sendWithWriter(c.transport, []byte(packet))
}
func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error {
@ -238,7 +245,7 @@ func (c *Client) sendWithWriter(writer io.Writer, packet []byte) error {
// Loop: Receive data from server
func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error) {
for {
val, err := stanza.NextPacket(c.Session.decoder)
val, err := stanza.NextPacket(c.transport.GetDecoder())
if err != nil {
close(keepaliveQuit)
c.disconnected(state)

@ -67,33 +67,25 @@ func (c *Component) Connect() error {
}
func (c *Component) Resume(sm SMState) error {
var err error
c.transport = NewTransport(c.ComponentOptions.TransportConfiguration)
if err = c.transport.Connect(); err != nil {
return err
}
c.updateState(StateConnected)
// 1. Send stream open tag
if _, err := fmt.Fprintf(c.transport, componentStreamOpen, c.Domain, stanza.NSComponent, stanza.NSStream); err != nil {
c.updateState(StateStreamError)
return NewConnError(errors.New("cannot send stream open "+err.Error()), false)
var streamId string
if c.ComponentOptions.TransportConfiguration.Domain == "" {
c.ComponentOptions.TransportConfiguration.Domain = c.ComponentOptions.Domain
}
c.decoder = xml.NewDecoder(c.transport)
c.transport = NewTransport(c.ComponentOptions.TransportConfiguration)
// 2. Initialize xml decoder and extract streamID from reply
streamId, err := stanza.InitStream(c.decoder)
if err != nil {
if streamId, err = c.transport.Connect(); err != nil {
c.updateState(StateStreamError)
return NewConnError(errors.New("cannot init decoder "+err.Error()), false)
return err
}
c.updateState(StateConnected)
// 3. Authentication
// Authentication
if _, err := fmt.Fprintf(c.transport, "<handshake>%s</handshake>", c.handshake(streamId)); err != nil {
c.updateState(StateStreamError)
return NewConnError(errors.New("cannot send handshake "+err.Error()), false)
}
// 4. Check server response for authentication
// Check server response for authentication
val, err := stanza.NextPacket(c.decoder)
if err != nil {
c.updateState(StateDisconnected)
@ -116,7 +108,6 @@ func (c *Component) Resume(sm SMState) error {
}
func (c *Component) Disconnect() {
_ = c.SendRaw("</stream:stream>")
// TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect
if c.transport != nil {
_ = c.transport.Close()

@ -1,7 +1,6 @@
package xmpp
import (
"io"
"os"
)
@ -18,6 +17,5 @@ type Config struct {
ConnectTimeout int // Client timeout in seconds. Default to 15
// Insecure can be set to true to allow to open a session without TLS. If TLS
// is supported on the server, we will still try to use it.
Insecure bool
CharsetReader func(charset string, input io.Reader) (io.Reader, error) // passed to xml decoder
Insecure bool
}

@ -1,16 +1,12 @@
package xmpp
import (
"encoding/xml"
"errors"
"fmt"
"io"
"gosrc.io/xmpp/stanza"
)
const xmppStreamOpen = "<?xml version='1.0'?><stream:stream to='%s' xmlns='%s' xmlns:stream='%s' version='1.0'>"
type Session struct {
// Session info
BindJid string // Jabber ID as provided by XMPP server
@ -21,8 +17,7 @@ type Session struct {
lastPacketId int
// read / write
streamLogger io.ReadWriter
decoder *xml.Decoder
transport Transport
// error management
err error
@ -30,10 +25,11 @@ type Session struct {
func NewSession(transport Transport, o Config, state SMState) (*Session, error) {
s := new(Session)
s.transport = transport
s.SMState = state
s.init(transport, o)
s.init(o)
s.startTlsIfSupported(transport, o.parsedJid.Domain, o)
s.startTlsIfSupported(o)
if s.err != nil {
return nil, NewConnError(s.err, true)
@ -45,12 +41,12 @@ func NewSession(transport Transport, o Config, state SMState) (*Session, error)
}
if s.TlsEnabled {
s.reset(transport, o)
s.reset(o)
}
// auth
s.auth(o)
s.reset(transport, o)
s.reset(o)
// attempt resumption
if s.resume(o) {
@ -72,51 +68,31 @@ func (s *Session) PacketId() string {
return fmt.Sprintf("%x", s.lastPacketId)
}
func (s *Session) init(transport Transport, o Config) {
s.setStreamLogger(transport, o)
func (s *Session) init(o Config) {
s.Features = s.open(o.parsedJid.Domain)
}
func (s *Session) reset(transport Transport, o Config) {
func (s *Session) reset(o Config) {
if s.err != nil {
return
}
s.setStreamLogger(transport, o)
s.Features = s.open(o.parsedJid.Domain)
}
func (s *Session) setStreamLogger(transport Transport, o Config) {
s.streamLogger = newStreamLogger(transport, o.StreamLogger)
s.decoder = xml.NewDecoder(s.streamLogger)
s.decoder.CharsetReader = o.CharsetReader
}
func (s *Session) open(domain string) (f stanza.StreamFeatures) {
// Send stream open tag
if _, s.err = fmt.Fprintf(s.streamLogger, xmppStreamOpen, domain, stanza.NSClient, stanza.NSStream); s.err != nil {
return
}
// Set xml decoder and extract streamID from reply
s.StreamId, s.err = stanza.InitStream(s.decoder) // TODO refactor / rename
if s.err != nil {
return
}
// extract stream features
if s.err = s.decoder.Decode(&f); s.err != nil {
if s.err = s.transport.GetDecoder().Decode(&f); s.err != nil {
s.err = errors.New("stream open decode features: " + s.err.Error())
}
return
}
func (s *Session) startTlsIfSupported(transport Transport, domain string, o Config) {
func (s *Session) startTlsIfSupported(o Config) {
if s.err != nil {
return
}
if !transport.DoesStartTLS() {
if !s.transport.DoesStartTLS() {
if !o.Insecure {
s.err = errors.New("Transport does not support starttls")
}
@ -124,15 +100,15 @@ func (s *Session) startTlsIfSupported(transport Transport, domain string, o Conf
}
if _, ok := s.Features.DoesStartTLS(); ok {
fmt.Fprintf(s.streamLogger, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
fmt.Fprintf(s.transport, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
var k stanza.TLSProceed
if s.err = s.decoder.DecodeElement(&k, nil); s.err != nil {
if s.err = s.transport.GetDecoder().DecodeElement(&k, nil); s.err != nil {
s.err = errors.New("expecting starttls proceed: " + s.err.Error())
return
}
s.err = transport.StartTLS(domain)
s.err = s.transport.StartTLS()
if s.err == nil {
s.TlsEnabled = true
@ -151,7 +127,7 @@ func (s *Session) auth(o Config) {
return
}
s.err = authSASL(s.streamLogger, s.decoder, s.Features, o.parsedJid.Node, o.Credential)
s.err = authSASL(s.transport, s.transport.GetDecoder(), s.Features, o.parsedJid.Node, o.Credential)
}
// Attempt to resume session using stream management
@ -163,11 +139,11 @@ func (s *Session) resume(o Config) bool {
return false
}
fmt.Fprintf(s.streamLogger, "<resume xmlns='%s' h='%d' previd='%s'/>",
fmt.Fprintf(s.transport, "<resume xmlns='%s' h='%d' previd='%s'/>",
stanza.NSStreamManagement, s.SMState.Inbound, s.SMState.Id)
var packet stanza.Packet
packet, s.err = stanza.NextPacket(s.decoder)
packet, s.err = stanza.NextPacket(s.transport.GetDecoder())
if s.err == nil {
switch p := packet.(type) {
case stanza.SMResumed:
@ -194,14 +170,14 @@ func (s *Session) bind(o Config) {
// Send IQ message asking to bind to the local user name.
var resource = o.parsedJid.Resource
if resource != "" {
fmt.Fprintf(s.streamLogger, "<iq type='set' id='%s'><bind xmlns='%s'><resource>%s</resource></bind></iq>",
fmt.Fprintf(s.transport, "<iq type='set' id='%s'><bind xmlns='%s'><resource>%s</resource></bind></iq>",
s.PacketId(), stanza.NSBind, resource)
} else {
fmt.Fprintf(s.streamLogger, "<iq type='set' id='%s'><bind xmlns='%s'/></iq>", s.PacketId(), stanza.NSBind)
fmt.Fprintf(s.transport, "<iq type='set' id='%s'><bind xmlns='%s'/></iq>", s.PacketId(), stanza.NSBind)
}
var iq stanza.IQ
if s.err = s.decoder.Decode(&iq); s.err != nil {
if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil {
s.err = errors.New("error decoding iq bind result: " + s.err.Error())
return
}
@ -226,8 +202,8 @@ func (s *Session) rfc3921Session(o Config) {
var iq stanza.IQ
// We only negotiate session binding if it is mandatory, we skip it when optional.
if !s.Features.Session.IsOptional() {
fmt.Fprintf(s.streamLogger, "<iq type='set' id='%s'><session xmlns='%s'/></iq>", s.PacketId(), stanza.NSSession)
if s.err = s.decoder.Decode(&iq); s.err != nil {
fmt.Fprintf(s.transport, "<iq type='set' id='%s'><session xmlns='%s'/></iq>", s.PacketId(), stanza.NSSession)
if s.err = s.transport.GetDecoder().Decode(&iq); s.err != nil {
s.err = errors.New("expecting iq result after session open: " + s.err.Error())
return
}
@ -243,10 +219,10 @@ func (s *Session) EnableStreamManagement(o Config) {
return
}
fmt.Fprintf(s.streamLogger, "<enable xmlns='%s' resume='true'/>", stanza.NSStreamManagement)
fmt.Fprintf(s.transport, "<enable xmlns='%s' resume='true'/>", stanza.NSStreamManagement)
var packet stanza.Packet
packet, s.err = stanza.NextPacket(s.decoder)
packet, s.err = stanza.NextPacket(s.transport.GetDecoder())
if s.err == nil {
switch p := packet.(type) {
case stanza.SMEnabled:

@ -0,0 +1,13 @@
package stanza
import "encoding/xml"
// Open Packet
// Reference: WebSocket connections must start with this element
// https://tools.ietf.org/html/rfc7395#section-3.4
type WebsocketOpen struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-framing open"`
From string `xml:"from,attr"`
Id string `xml:"id,attr"`
Version string `xml:"version,attr"`
}

@ -1,167 +1,14 @@
package stanza
import (
"encoding/xml"
)
// ============================================================================
// StreamFeatures Packet
// Reference: The active stream features are published on
// https://xmpp.org/registrar/stream-features.html
// Note: That page misses draft and experimental XEP (i.e CSI, etc)
type StreamFeatures struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
// Server capabilities hash
Caps Caps
// Stream features
StartTLS tlsStartTLS
Mechanisms saslMechanisms
Bind Bind
StreamManagement streamManagement
// Obsolete
Session StreamSession
// ProcessOne Stream Features
P1Push p1Push
P1Rebind p1Rebind
p1Ack p1Ack
Any []xml.Name `xml:",any"`
}
func (StreamFeatures) Name() string {
return "stream:features"
}
type streamFeatureDecoder struct{}
var streamFeatures streamFeatureDecoder
func (streamFeatureDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamFeatures, error) {
var packet StreamFeatures
err := p.DecodeElement(&packet, &se)
return packet, err
}
// Capabilities
// Reference: https://xmpp.org/extensions/xep-0115.html#stream
// "A server MAY include its entity capabilities in a stream feature element so that connecting clients
// and peer servers do not need to send service discovery requests each time they connect."
// This is not a stream feature but a way to let client cache server disco info.
type Caps struct {
XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"`
Hash string `xml:"hash,attr"`
Node string `xml:"node,attr"`
Ver string `xml:"ver,attr"`
Ext string `xml:"ext,attr,omitempty"`
}
// ============================================================================
// Supported Stream Features
// StartTLS feature
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4
type tlsStartTLS struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
Required bool
}
// UnmarshalXML implements custom parsing startTLS required flag
func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
stls.XMLName = start.Name
// Check subelements to extract required field as boolean
for {
t, err := d.Token()
if err != nil {
return err
}
switch tt := t.(type) {
case xml.StartElement:
elt := new(Node)
err = d.DecodeElement(elt, &tt)
if err != nil {
return err
}
if elt.XMLName.Local == "required" {
stls.Required = true
}
case xml.EndElement:
if tt == start.End() {
return nil
}
}
}
}
func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) {
if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" {
return sf.StartTLS, true
}
return feature, false
}
// Mechanisms
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-6.4.1
type saslMechanisms struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
Mechanism []string `xml:"mechanism"`
}
// StreamManagement
// Reference: XEP-0198 - https://xmpp.org/extensions/xep-0198.html#feature
type streamManagement struct {
XMLName xml.Name `xml:"urn:xmpp:sm:3 sm"`
}
func (sf *StreamFeatures) DoesStreamManagement() (isSupported bool) {
if sf.StreamManagement.XMLName.Space+" "+sf.StreamManagement.XMLName.Local == "urn:xmpp:sm:3 sm" {
return true
}
return false
}
// P1 extensions
// Reference: https://docs.ejabberd.im/developer/mobile/core-features/
// p1:push support
type p1Push struct {
XMLName xml.Name `xml:"p1:push push"`
}
// p1:rebind suppor
type p1Rebind struct {
XMLName xml.Name `xml:"p1:rebind rebind"`
}
// p1:ack support
type p1Ack struct {
XMLName xml.Name `xml:"p1:ack ack"`
}
// ============================================================================
// StreamError Packet
type StreamError struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams error"`
Error xml.Name `xml:",any"`
Text string `xml:"urn:ietf:params:xml:ns:xmpp-streams text"`
}
func (StreamError) Name() string {
return "stream:error"
}
type streamErrorDecoder struct{}
var streamError streamErrorDecoder
func (streamErrorDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamError, error) {
var packet StreamError
err := p.DecodeElement(&packet, &se)
return packet, err
import "encoding/xml"
// Start of stream
// Reference: XMPP Core stream open
// https://tools.ietf.org/html/rfc6120#section-4.2
type Stream struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams stream"`
From string `xml:"from,attr"`
To string `xml:"to,attr"`
Id string `xml:"id,attr"`
Version string `xml:"version,attr"`
}

@ -0,0 +1,167 @@
package stanza
import (
"encoding/xml"
)
// ============================================================================
// StreamFeatures Packet
// Reference: The active stream features are published on
// https://xmpp.org/registrar/stream-features.html
// Note: That page misses draft and experimental XEP (i.e CSI, etc)
type StreamFeatures struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"`
// Server capabilities hash
Caps Caps
// Stream features
StartTLS tlsStartTLS
Mechanisms saslMechanisms
Bind Bind
StreamManagement streamManagement
// Obsolete
Session StreamSession
// ProcessOne Stream Features
P1Push p1Push
P1Rebind p1Rebind
p1Ack p1Ack
Any []xml.Name `xml:",any"`
}
func (StreamFeatures) Name() string {
return "stream:features"
}
type streamFeatureDecoder struct{}
var streamFeatures streamFeatureDecoder
func (streamFeatureDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamFeatures, error) {
var packet StreamFeatures
err := p.DecodeElement(&packet, &se)
return packet, err
}
// Capabilities
// Reference: https://xmpp.org/extensions/xep-0115.html#stream
// "A server MAY include its entity capabilities in a stream feature element so that connecting clients
// and peer servers do not need to send service discovery requests each time they connect."
// This is not a stream feature but a way to let client cache server disco info.
type Caps struct {
XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"`
Hash string `xml:"hash,attr"`
Node string `xml:"node,attr"`
Ver string `xml:"ver,attr"`
Ext string `xml:"ext,attr,omitempty"`
}
// ============================================================================
// Supported Stream Features
// StartTLS feature
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4
type tlsStartTLS struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"`
Required bool
}
// UnmarshalXML implements custom parsing startTLS required flag
func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
stls.XMLName = start.Name
// Check subelements to extract required field as boolean
for {
t, err := d.Token()
if err != nil {
return err
}
switch tt := t.(type) {
case xml.StartElement:
elt := new(Node)
err = d.DecodeElement(elt, &tt)
if err != nil {
return err
}
if elt.XMLName.Local == "required" {
stls.Required = true
}
case xml.EndElement:
if tt == start.End() {
return nil
}
}
}
}
func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) {
if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" {
return sf.StartTLS, true
}
return feature, false
}
// Mechanisms
// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-6.4.1
type saslMechanisms struct {
XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
Mechanism []string `xml:"mechanism"`
}
// StreamManagement
// Reference: XEP-0198 - https://xmpp.org/extensions/xep-0198.html#feature
type streamManagement struct {
XMLName xml.Name `xml:"urn:xmpp:sm:3 sm"`
}
func (sf *StreamFeatures) DoesStreamManagement() (isSupported bool) {
if sf.StreamManagement.XMLName.Space+" "+sf.StreamManagement.XMLName.Local == "urn:xmpp:sm:3 sm" {
return true
}
return false
}
// P1 extensions
// Reference: https://docs.ejabberd.im/developer/mobile/core-features/
// p1:push support
type p1Push struct {
XMLName xml.Name `xml:"p1:push push"`
}
// p1:rebind suppor
type p1Rebind struct {
XMLName xml.Name `xml:"p1:rebind rebind"`
}
// p1:ack support
type p1Ack struct {
XMLName xml.Name `xml:"p1:ack ack"`
}
// ============================================================================
// StreamError Packet
type StreamError struct {
XMLName xml.Name `xml:"http://etherx.jabber.org/streams error"`
Error xml.Name `xml:",any"`
Text string `xml:"urn:ietf:params:xml:ns:xmpp-streams text"`
}
func (StreamError) Name() string {
return "stream:error"
}
type streamErrorDecoder struct{}
var streamError streamErrorDecoder
func (streamErrorDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamError, error) {
var packet StreamError
err := p.DecodeElement(&packet, &se)
return packet, err
}

@ -2,17 +2,16 @@ package xmpp
import (
"io"
"os"
)
// Mediated Read / Write on socket
// Used if logFile from Config is not nil
type streamLogger struct {
socket io.ReadWriter // Actual connection
logFile *os.File
logFile io.Writer
}
func newStreamLogger(conn io.ReadWriter, logFile *os.File) io.ReadWriter {
func newStreamLogger(conn io.ReadWriter, logFile io.Writer) io.ReadWriter {
if logFile == nil {
return conn
} else {

@ -2,7 +2,9 @@ package xmpp
import (
"crypto/tls"
"encoding/xml"
"errors"
"io"
"strings"
)
@ -12,17 +14,22 @@ type TransportConfiguration struct {
// Address is the XMPP Host and port to connect to. Host is of
// the form 'serverhost:port' i.e "localhost:8888"
Address string
Domain string
ConnectTimeout int // Client timeout in seconds. Default to 15
// tls.Config must not be modified after having been passed to NewClient. Any
// changes made after connecting are ignored.
TLSConfig *tls.Config
TLSConfig *tls.Config
CharsetReader func(charset string, input io.Reader) (io.Reader, error) // passed to xml decoder
}
type Transport interface {
Connect() error
Connect() (string, error)
DoesStartTLS() bool
StartTLS(domain string) error
StartTLS() error
LogTraffic(logFile io.Writer)
GetDecoder() *xml.Decoder
IsSecure() bool
Ping() error

@ -2,11 +2,15 @@ package xmpp
import (
"context"
"encoding/xml"
"errors"
"fmt"
"io"
"net"
"strings"
"time"
"gosrc.io/xmpp/stanza"
"nhooyr.io/websocket"
)
@ -16,35 +20,60 @@ var ServerDoesNotSupportXmppOverWebsocket = errors.New("The websocket server doe
type WebsocketTransport struct {
Config TransportConfiguration
decoder *xml.Decoder
wsConn *websocket.Conn
netConn net.Conn
ctx context.Context
logFile io.Writer
}
func (t *WebsocketTransport) Connect() error {
t.ctx = context.Background()
func (t *WebsocketTransport) Connect() (string, error) {
ctx := context.Background()
if t.Config.ConnectTimeout > 0 {
ctx, cancel := context.WithTimeout(t.ctx, time.Duration(t.Config.ConnectTimeout)*time.Second)
t.ctx = ctx
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(t.Config.ConnectTimeout)*time.Second)
defer cancel()
}
wsConn, response, err := websocket.Dial(t.ctx, t.Config.Address, &websocket.DialOptions{
wsConn, response, err := websocket.Dial(ctx, t.Config.Address, &websocket.DialOptions{
Subprotocols: []string{"xmpp"},
})
if err != nil {
return NewConnError(err, true)
return "", NewConnError(err, true)
}
if response.Header.Get("Sec-WebSocket-Protocol") != "xmpp" {
return ServerDoesNotSupportXmppOverWebsocket
_ = wsConn.Close(websocket.StatusBadGateway, "Could not negotiate XMPP subprotocol")
return "", NewConnError(ServerDoesNotSupportXmppOverWebsocket, true)
}
t.wsConn = wsConn
t.netConn = websocket.NetConn(t.ctx, t.wsConn, websocket.MessageText)
return nil
t.netConn = websocket.NetConn(ctx, t.wsConn, websocket.MessageText)
handshake := fmt.Sprintf("<open xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" to=\"%s\" version=\"1.0\" />", t.Config.Domain)
if _, err = t.Write([]byte(handshake)); err != nil {
_ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error")
return "", NewConnError(err, false)
}
handshakeResponse := make([]byte, 2048)
if _, err = t.Read(handshakeResponse); err != nil {
_ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error")
return "", NewConnError(err, false)
}
var openResponse = stanza.WebsocketOpen{}
if err = xml.Unmarshal(handshakeResponse, &openResponse); err != nil {
_ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error")
return "", NewConnError(err, false)
}
t.decoder = xml.NewDecoder(t)
t.decoder.CharsetReader = t.Config.CharsetReader
return openResponse.Id, nil
}
func (t WebsocketTransport) StartTLS(domain string) error {
func (t WebsocketTransport) StartTLS() error {
return TLSNotSupported
}
@ -52,6 +81,10 @@ func (t WebsocketTransport) DoesStartTLS() bool {
return false
}
func (t WebsocketTransport) GetDecoder() *xml.Decoder {
return t.decoder
}
func (t WebsocketTransport) IsSecure() bool {
return strings.HasPrefix(t.Config.Address, "wss:")
}
@ -59,19 +92,29 @@ func (t WebsocketTransport) IsSecure() bool {
func (t WebsocketTransport) Ping() error {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
// Note that we do not use wsConn.Ping(), because not all websocket servers
// (ejabberd for example) implement ping frames
return t.wsConn.Write(ctx, websocket.MessageText, []byte(" "))
return t.wsConn.Ping(ctx)
}
func (t WebsocketTransport) Read(p []byte) (n int, err error) {
return t.netConn.Read(p)
func (t *WebsocketTransport) Read(p []byte) (n int, err error) {
n, err = t.netConn.Read(p)
if t.logFile != nil && n > 0 {
_, _ = fmt.Fprintf(t.logFile, "RECV:\n%s\n\n", p)
}
return
}
func (t WebsocketTransport) Write(p []byte) (n int, err error) {
if t.logFile != nil {
_, _ = fmt.Fprintf(t.logFile, "SEND:\n%s\n\n", p)
}
return t.netConn.Write(p)
}
func (t WebsocketTransport) Close() error {
t.Write([]byte("<close xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" />"))
return t.netConn.Close()
}
func (t *WebsocketTransport) LogTraffic(logFile io.Writer) {
t.logFile = logFile
}

@ -2,39 +2,65 @@ package xmpp
import (
"crypto/tls"
"encoding/xml"
"errors"
"fmt"
"io"
"net"
"time"
"gosrc.io/xmpp/stanza"
)
// XMPPTransport implements the XMPP native TCP transport
type XMPPTransport struct {
Config TransportConfiguration
TLSConfig *tls.Config
// TCP level connection / can be replaced by a TLS session after starttls
conn net.Conn
isSecure bool
Config TransportConfiguration
TLSConfig *tls.Config
decoder *xml.Decoder
conn net.Conn
readWriter io.ReadWriter
isSecure bool
}
func (t *XMPPTransport) Connect() error {
const xmppStreamOpen = "<?xml version='1.0'?><stream:stream to='%s' xmlns='%s' xmlns:stream='%s' version='1.0'>"
func (t *XMPPTransport) Connect() (string, error) {
var err error
t.conn, err = net.DialTimeout("tcp", t.Config.Address, time.Duration(t.Config.ConnectTimeout)*time.Second)
if err != nil {
return NewConnError(err, true)
return "", NewConnError(err, true)
}
return nil
if _, err = fmt.Fprintf(t.conn, xmppStreamOpen, t.Config.Domain, stanza.NSClient, stanza.NSStream); err != nil {
t.conn.Close()
return "", NewConnError(err, true)
}
t.decoder = xml.NewDecoder(t.readWriter)
t.decoder.CharsetReader = t.Config.CharsetReader
sessionId, err := stanza.InitStream(t.decoder)
if err != nil {
t.conn.Close()
return "", NewConnError(err, false)
}
t.readWriter = t.conn
return sessionId, nil
}
func (t XMPPTransport) DoesStartTLS() bool {
return true
}
func (t XMPPTransport) GetDecoder() *xml.Decoder {
return t.decoder
}
func (t XMPPTransport) IsSecure() bool {
return t.isSecure
}
func (t *XMPPTransport) StartTLS(domain string) error {
func (t *XMPPTransport) StartTLS() error {
if t.Config.TLSConfig == nil {
t.TLSConfig = &tls.Config{}
} else {
@ -42,7 +68,7 @@ func (t *XMPPTransport) StartTLS(domain string) error {
}
if t.TLSConfig.ServerName == "" {
t.TLSConfig.ServerName = domain
t.TLSConfig.ServerName = t.Config.Domain
}
tlsConn := tls.Client(t.conn, t.TLSConfig)
// We convert existing connection to TLS
@ -51,7 +77,7 @@ func (t *XMPPTransport) StartTLS(domain string) error {
}
if !t.TLSConfig.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(domain); err != nil {
if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil {
return err
}
}
@ -72,13 +98,18 @@ func (t XMPPTransport) Ping() error {
}
func (t XMPPTransport) Read(p []byte) (n int, err error) {
return t.conn.Read(p)
return t.readWriter.Read(p)
}
func (t XMPPTransport) Write(p []byte) (n int, err error) {
return t.conn.Write(p)
return t.readWriter.Write(p)
}
func (t XMPPTransport) Close() error {
_, _ = t.readWriter.Write([]byte("</stream:stream>"))
return t.conn.Close()
}
func (t *XMPPTransport) LogTraffic(logFile io.Writer) {
t.readWriter = &streamLogger{t.conn, logFile}
}

Loading…
Cancel
Save