diff --git a/client.go b/client.go index cace5fd..e9161e9 100644 --- a/client.go +++ b/client.go @@ -90,7 +90,7 @@ type Client struct { // Session gather data that can be accessed by users of this library Session *Session // TCP level connection / can be replaced by a TLS session after starttls - conn net.Conn + transport Transport // Router is used to dispatch packets router *Router // Track and broadcast connection state @@ -139,6 +139,7 @@ func NewClient(config Config, r *Router) (c *Client, err error) { c = new(Client) c.config = config c.router = r + c.transport = &XMPPTransport{} if c.config.ConnectTimeout == 0 { c.config.ConnectTimeout = 15 // 15 second as default @@ -159,21 +160,21 @@ func (c *Client) Connect() error { func (c *Client) Resume(state SMState) error { var err error - c.conn, err = net.DialTimeout("tcp", c.config.Address, time.Duration(c.config.ConnectTimeout)*time.Second) + err = c.transport.Connect(c.config.Address, c.config) if err != nil { return err } c.updateState(StateConnected) // Client is ok, we now open XMPP session - if c.conn, c.Session, err = NewSession(c.conn, c.config, state); err != nil { + if c.Session, err = NewSession(c.transport, c.config, state); err != nil { return err } c.updateState(StateSessionEstablished) // Start the keepalive go routine keepaliveQuit := make(chan struct{}) - go keepalive(c.conn, keepaliveQuit) + go keepalive(c.transport, keepaliveQuit) // Start the receiver go routine state = c.Session.SMState go c.recv(state, keepaliveQuit) @@ -190,7 +191,7 @@ func (c *Client) Resume(state SMState) error { func (c *Client) Disconnect() { _ = c.SendRaw("") // TODO: Add a way to wait for stream close acknowledgement from the server for clean disconnect - conn := c.conn + conn := c.transport if conn != nil { _ = conn.Close() } @@ -202,7 +203,7 @@ func (c *Client) SetHandler(handler EventHandler) { // Send marshals XMPP stanza and sends it to the server. func (c *Client) Send(packet stanza.Packet) error { - conn := c.conn + conn := c.transport if conn == nil { return errors.New("client is not connected") } @@ -220,7 +221,7 @@ func (c *Client) Send(packet stanza.Packet) error { // disconnect the client. It is up to the user of this method to // carefully craft the XML content to produce valid XMPP. func (c *Client) SendRaw(packet string) error { - conn := c.conn + conn := c.transport if conn == nil { return errors.New("client is not connected") } @@ -272,16 +273,16 @@ func (c *Client) recv(state SMState, keepaliveQuit chan<- struct{}) (err error) // Loop: send whitespace keepalive to server // This is use to keep the connection open, but also to detect connection loss // and trigger proper client connection shutdown. -func keepalive(conn net.Conn, quit <-chan struct{}) { +func keepalive(transport Transport, quit <-chan struct{}) { // TODO: Make keepalive interval configurable ticker := time.NewTicker(30 * time.Second) for { select { case <-ticker.C: - if n, err := fmt.Fprintf(conn, "\n"); err != nil || n != 1 { - // When keep alive fails, we force close the connection. In all cases, the recv will also fail. + if n, err := fmt.Fprintf(transport, "\n"); err != nil || n != 1 { + // When keep alive fails, we force close the transportection. In all cases, the recv will also fail. ticker.Stop() - _ = conn.Close() + _ = transport.Close() return } case <-quit: diff --git a/config.go b/config.go index 84a7d29..8ce304f 100644 --- a/config.go +++ b/config.go @@ -14,8 +14,8 @@ type Config struct { StreamLogger *os.File // Used for debugging Lang string // TODO: should default to 'en' ConnectTimeout int // Client timeout in seconds. Default to 15 - // tls.Config must not be modified after having been passed to NewClient. The - // Client connect method may override the tls.Config.ServerName if it was not set. + // tls.Config must not be modified after having been passed to NewClient. Any + // changes made after connecting are ignored. TLSConfig *tls.Config // Insecure can be set to true to allow to open a session without TLS. If TLS // is supported on the server, we will still try to use it. diff --git a/session.go b/session.go index 1cad2b4..c7a1454 100644 --- a/session.go +++ b/session.go @@ -1,12 +1,10 @@ package xmpp import ( - "crypto/tls" "encoding/xml" "errors" "fmt" "io" - "net" "gosrc.io/xmpp/stanza" ) @@ -30,35 +28,33 @@ type Session struct { err error } -func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, error) { +func NewSession(transport Transport, o Config, state SMState) (*Session, error) { s := new(Session) s.SMState = state - s.init(conn, o) + s.init(transport, o) - // starttls - var tlsConn net.Conn - tlsConn = s.startTlsIfSupported(conn, o.parsedJid.Domain, o) + s.startTlsIfSupported(transport, o.parsedJid.Domain, o) if s.err != nil { - return nil, nil, NewConnError(s.err, true) + return nil, NewConnError(s.err, true) } if !s.TlsEnabled && !o.Insecure { err := fmt.Errorf("failed to negotiate TLS session : %s", s.err) - return nil, nil, NewConnError(err, true) + return nil, NewConnError(err, true) } if s.TlsEnabled { - s.reset(conn, tlsConn, o) + s.reset(transport, o) } // auth s.auth(o) - s.reset(tlsConn, tlsConn, o) + s.reset(transport, o) // attempt resumption if s.resume(o) { - return tlsConn, s, s.err + return s, s.err } // otherwise, bind resource and 'start' XMPP session @@ -68,7 +64,7 @@ func NewSession(conn net.Conn, o Config, state SMState) (net.Conn, *Session, err // Enable stream management if supported s.EnableStreamManagement(o) - return tlsConn, s, s.err + return s, s.err } func (s *Session) PacketId() string { @@ -76,24 +72,22 @@ func (s *Session) PacketId() string { return fmt.Sprintf("%x", s.lastPacketId) } -func (s *Session) init(conn net.Conn, o Config) { - s.setStreamLogger(nil, conn, o) +func (s *Session) init(transport Transport, o Config) { + s.setStreamLogger(transport, o) s.Features = s.open(o.parsedJid.Domain) } -func (s *Session) reset(conn net.Conn, newConn net.Conn, o Config) { +func (s *Session) reset(transport Transport, o Config) { if s.err != nil { return } - s.setStreamLogger(conn, newConn, o) + s.setStreamLogger(transport, o) s.Features = s.open(o.parsedJid.Domain) } -func (s *Session) setStreamLogger(conn net.Conn, newConn net.Conn, o Config) { - if newConn != conn { - s.streamLogger = newStreamLogger(newConn, o.StreamLogger) - } +func (s *Session) setStreamLogger(transport Transport, o Config) { + s.streamLogger = newStreamLogger(transport, o.StreamLogger) s.decoder = xml.NewDecoder(s.streamLogger) s.decoder.CharsetReader = o.CharsetReader } @@ -117,9 +111,16 @@ func (s *Session) open(domain string) (f stanza.StreamFeatures) { return } -func (s *Session) startTlsIfSupported(conn net.Conn, domain string, o Config) net.Conn { +func (s *Session) startTlsIfSupported(transport Transport, domain string, o Config) { if s.err != nil { - return conn + return + } + + if !transport.DoesStartTLS() { + if !o.Insecure { + s.err = errors.New("Transport does not support starttls") + } + return } if _, ok := s.Features.DoesStartTLS(); ok { @@ -128,39 +129,21 @@ func (s *Session) startTlsIfSupported(conn net.Conn, domain string, o Config) ne var k stanza.TLSProceed if s.err = s.decoder.DecodeElement(&k, nil); s.err != nil { s.err = errors.New("expecting starttls proceed: " + s.err.Error()) - return conn - } - - if o.TLSConfig == nil { - o.TLSConfig = &tls.Config{} + return } - if o.TLSConfig.ServerName == "" { - o.TLSConfig.ServerName = domain - } - tlsConn := tls.Client(conn, o.TLSConfig) - // We convert existing connection to TLS - if s.err = tlsConn.Handshake(); s.err != nil { - return tlsConn - } - - if !o.TLSConfig.InsecureSkipVerify { - s.err = tlsConn.VerifyHostname(domain) - } + s.err = transport.StartTLS(domain, o) if s.err == nil { s.TlsEnabled = true } - return tlsConn + return } // If we do not allow cleartext connections, make it explicit that server do not support starttls if !o.Insecure { s.err = errors.New("XMPP server does not advertise support for starttls") } - - // starttls is not supported => we do not upgrade the connection: - return conn } func (s *Session) auth(o Config) { diff --git a/transport.go b/transport.go new file mode 100644 index 0000000..0ac7195 --- /dev/null +++ b/transport.go @@ -0,0 +1,74 @@ +package xmpp + +import ( + "crypto/tls" + "net" + "time" +) + +type Transport interface { + Connect(address string, c Config) error + DoesStartTLS() bool + StartTLS(domain string, c Config) error + + Read(p []byte) (n int, err error) + Write(p []byte) (n int, err error) + Close() error +} + +// XMPPTransport implements the XMPP native TCP transport +type XMPPTransport struct { + TLSConfig *tls.Config + // TCP level connection / can be replaced by a TLS session after starttls + conn net.Conn +} + +func (t *XMPPTransport) Connect(address string, c Config) error { + var err error + + t.conn, err = net.DialTimeout("tcp", address, time.Duration(c.ConnectTimeout)*time.Second) + return err +} + +func (t XMPPTransport) DoesStartTLS() bool { + return true +} + +func (t *XMPPTransport) StartTLS(domain string, c Config) error { + if t.TLSConfig == nil { + if c.TLSConfig != nil { + t.TLSConfig = c.TLSConfig + } else { + t.TLSConfig = &tls.Config{} + } + } + + if t.TLSConfig.ServerName == "" { + t.TLSConfig.ServerName = domain + } + tlsConn := tls.Client(t.conn, t.TLSConfig) + // We convert existing connection to TLS + if err := tlsConn.Handshake(); err != nil { + return err + } + + if !t.TLSConfig.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(domain); err != nil { + return err + } + } + + return nil +} + +func (t XMPPTransport) Read(p []byte) (n int, err error) { + return t.conn.Read(p) +} + +func (t XMPPTransport) Write(p []byte) (n int, err error) { + return t.conn.Write(p) +} + +func (t XMPPTransport) Close() error { + return t.conn.Close() +}