diff --git a/cert_checker.go b/cert_checker.go index c85664b..fcee7b1 100644 --- a/cert_checker.go +++ b/cert_checker.go @@ -51,7 +51,7 @@ func (c *ServerCheck) Check() error { decoder := xml.NewDecoder(tcpconn) // Send stream open tag - if _, err = fmt.Fprintf(tcpconn, xmppStreamOpen, c.domain, stanza.NSClient, stanza.NSStream); err != nil { + if _, err = fmt.Fprintf(tcpconn, clientStreamOpen, c.domain); err != nil { return err } diff --git a/client.go b/client.go index 3d3c868..686519a 100644 --- a/client.go +++ b/client.go @@ -144,7 +144,7 @@ func NewClient(config Config, r *Router) (c *Client, err error) { if config.TransportConfiguration.Domain == "" { config.TransportConfiguration.Domain = config.parsedJid.Domain } - c.transport = NewTransport(config.TransportConfiguration) + c.transport = NewClientTransport(config.TransportConfiguration) if config.StreamLogger != nil { c.transport.LogTraffic(config.StreamLogger) diff --git a/component.go b/component.go index 137bc05..18be3bc 100644 --- a/component.go +++ b/component.go @@ -11,8 +11,6 @@ import ( "gosrc.io/xmpp/stanza" ) -const componentStreamOpen = "" - type ComponentOptions struct { TransportConfiguration @@ -71,7 +69,11 @@ func (c *Component) Resume(sm SMState) error { if c.ComponentOptions.TransportConfiguration.Domain == "" { c.ComponentOptions.TransportConfiguration.Domain = c.ComponentOptions.Domain } - c.transport = NewTransport(c.ComponentOptions.TransportConfiguration) + c.transport, err = NewComponentTransport(c.ComponentOptions.TransportConfiguration) + if err != nil { + c.updateState(StateStreamError) + return err + } if streamId, err = c.transport.Connect(); err != nil { c.updateState(StateStreamError) diff --git a/session.go b/session.go index 692203e..ab7191e 100644 --- a/session.go +++ b/session.go @@ -108,7 +108,7 @@ func (s *Session) startTlsIfSupported(o Config) { return } - s.err = s.transport.StartTLS() + s.StreamId, s.err = s.transport.StartTLS() if s.err == nil { s.TlsEnabled = true diff --git a/transport.go b/transport.go index daf66a1..2e44381 100644 --- a/transport.go +++ b/transport.go @@ -4,11 +4,13 @@ import ( "crypto/tls" "encoding/xml" "errors" + "fmt" "io" "strings" ) -var TLSNotSupported = errors.New("Transport does not support StartTLS") +var ErrTransportProtocolNotSupported = errors.New("Transport protocol not supported") +var ErrTLSNotSupported = errors.New("Transport does not support StartTLS") type TransportConfiguration struct { // Address is the XMPP Host and port to connect to. Host is of @@ -25,7 +27,7 @@ type TransportConfiguration struct { type Transport interface { Connect() (string, error) DoesStartTLS() bool - StartTLS() error + StartTLS() (string, error) LogTraffic(logFile io.Writer) @@ -38,16 +40,34 @@ type Transport interface { Close() error } -// NewTransport creates a new Transport instance. +// NewClientTransport creates a new Transport instance for clients. // The type of transport is determined by the address in the configuration: // - if the address is a URL with the `ws` or `wss` scheme WebsocketTransport is used // - in all other cases a XMPPTransport is used // For XMPPTransport it is mandatory for the address to have a port specified. -func NewTransport(config TransportConfiguration) Transport { +func NewClientTransport(config TransportConfiguration) Transport { if strings.HasPrefix(config.Address, "ws:") || strings.HasPrefix(config.Address, "wss:") { return &WebsocketTransport{Config: config} } config.Address = ensurePort(config.Address, 5222) - return &XMPPTransport{Config: config} + return &XMPPTransport{ + Config: config, + openStatement: clientStreamOpen, + } +} + +// NewComponentTransport creates a new Transport instance for components. +// Only XMPP transports are allowed. If you try to use any other protocol an error +// will be returned. +func NewComponentTransport(config TransportConfiguration) (Transport, error) { + if strings.HasPrefix(config.Address, "ws:") || strings.HasPrefix(config.Address, "wss:") { + return nil, fmt.Errorf("Components only support XMPP transport: %w", ErrTransportProtocolNotSupported) + } + + config.Address = ensurePort(config.Address, 5222) + return &XMPPTransport{ + Config: config, + openStatement: componentStreamOpen, + }, nil } diff --git a/websocket_transport.go b/websocket_transport.go index 45b54a4..ad422c1 100644 --- a/websocket_transport.go +++ b/websocket_transport.go @@ -108,8 +108,8 @@ func (t WebsocketTransport) startReader() { }() } -func (t WebsocketTransport) StartTLS() error { - return TLSNotSupported +func (t WebsocketTransport) StartTLS() (string, error) { + return "", ErrTLSNotSupported } func (t WebsocketTransport) DoesStartTLS() bool { diff --git a/xmpp_transport.go b/xmpp_transport.go index 87b4a93..edcabdf 100644 --- a/xmpp_transport.go +++ b/xmpp_transport.go @@ -14,16 +14,19 @@ import ( // XMPPTransport implements the XMPP native TCP transport type XMPPTransport struct { - Config TransportConfiguration - TLSConfig *tls.Config - decoder *xml.Decoder - conn net.Conn - readWriter io.ReadWriter - logFile io.Writer - isSecure bool + openStatement string + Config TransportConfiguration + TLSConfig *tls.Config + decoder *xml.Decoder + conn net.Conn + readWriter io.ReadWriter + logFile io.Writer + isSecure bool } -const xmppStreamOpen = "" +var componentStreamOpen = fmt.Sprintf("", stanza.NSComponent, stanza.NSStream) + +var clientStreamOpen = fmt.Sprintf("", stanza.NSClient, stanza.NSStream) func (t *XMPPTransport) Connect() (string, error) { var err error @@ -34,8 +37,11 @@ func (t *XMPPTransport) Connect() (string, error) { } t.readWriter = newStreamLogger(t.conn, t.logFile) + return t.startStream() +} - if _, err = fmt.Fprintf(t.readWriter, xmppStreamOpen, t.Config.Domain, stanza.NSClient, stanza.NSStream); err != nil { +func (t *XMPPTransport) startStream() (string, error) { + if _, err := fmt.Fprintf(t.readWriter, t.openStatement, t.Config.Domain); err != nil { t.conn.Close() return "", NewConnError(err, true) } @@ -62,7 +68,7 @@ func (t XMPPTransport) IsSecure() bool { return t.isSecure } -func (t *XMPPTransport) StartTLS() error { +func (t *XMPPTransport) StartTLS() (string, error) { if t.Config.TLSConfig == nil { t.TLSConfig = &tls.Config{} } else { @@ -75,7 +81,7 @@ func (t *XMPPTransport) StartTLS() error { tlsConn := tls.Client(t.conn, t.TLSConfig) // We convert existing connection to TLS if err := tlsConn.Handshake(); err != nil { - return err + return "", err } t.conn = tlsConn @@ -85,12 +91,13 @@ func (t *XMPPTransport) StartTLS() error { if !t.TLSConfig.InsecureSkipVerify { if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil { - return err + return "", err } } t.isSecure = true - return nil + + return t.startStream() } func (t XMPPTransport) Ping() error {