diff --git a/client.go b/client.go index e9161e9..fd5763a 100644 --- a/client.go +++ b/client.go @@ -88,8 +88,7 @@ type Client struct { // Store user defined options and states config Config // Session gather data that can be accessed by users of this library - Session *Session - // TCP level connection / can be replaced by a TLS session after starttls + Session *Session transport Transport // Router is used to dispatch packets router *Router @@ -139,12 +138,13 @@ func NewClient(config Config, r *Router) (c *Client, err error) { c = new(Client) c.config = config c.router = r - c.transport = &XMPPTransport{} if c.config.ConnectTimeout == 0 { c.config.ConnectTimeout = 15 // 15 second as default } + c.transport = &XMPPTransport{Config: config.TransportConfiguration} + return } @@ -160,7 +160,7 @@ func (c *Client) Connect() error { func (c *Client) Resume(state SMState) error { var err error - err = c.transport.Connect(c.config.Address, c.config) + err = c.transport.Connect(c.config.Address) if err != nil { return err } diff --git a/config.go b/config.go index 8ce304f..c4ff97f 100644 --- a/config.go +++ b/config.go @@ -1,12 +1,15 @@ package xmpp import ( - "crypto/tls" "io" "os" ) type Config struct { + // TransportConfiguration must not be modified after having been passed to NewClient. Any + // changes made after connecting are ignored. + TransportConfiguration + Address string Jid string parsedJid *Jid // For easier manipulation @@ -14,9 +17,6 @@ type Config struct { StreamLogger *os.File // Used for debugging Lang string // TODO: should default to 'en' ConnectTimeout int // Client timeout in seconds. Default to 15 - // tls.Config must not be modified after having been passed to NewClient. Any - // changes made after connecting are ignored. - TLSConfig *tls.Config // Insecure can be set to true to allow to open a session without TLS. If TLS // is supported on the server, we will still try to use it. Insecure bool diff --git a/session.go b/session.go index c7a1454..191c2c3 100644 --- a/session.go +++ b/session.go @@ -132,7 +132,7 @@ func (s *Session) startTlsIfSupported(transport Transport, domain string, o Conf return } - s.err = transport.StartTLS(domain, o) + s.err = transport.StartTLS(domain) if s.err == nil { s.TlsEnabled = true diff --git a/transport.go b/transport.go index 0ac7195..296cec1 100644 --- a/transport.go +++ b/transport.go @@ -6,10 +6,17 @@ import ( "time" ) +type TransportConfiguration struct { + ConnectTimeout int // Client timeout in seconds. Default to 15 + // tls.Config must not be modified after having been passed to NewClient. Any + // changes made after connecting are ignored. + TLSConfig *tls.Config +} + type Transport interface { - Connect(address string, c Config) error + Connect(address string) error DoesStartTLS() bool - StartTLS(domain string, c Config) error + StartTLS(domain string) error Read(p []byte) (n int, err error) Write(p []byte) (n int, err error) @@ -18,15 +25,16 @@ type Transport interface { // XMPPTransport implements the XMPP native TCP transport type XMPPTransport struct { + Config TransportConfiguration TLSConfig *tls.Config // TCP level connection / can be replaced by a TLS session after starttls conn net.Conn } -func (t *XMPPTransport) Connect(address string, c Config) error { +func (t *XMPPTransport) Connect(address string) error { var err error - t.conn, err = net.DialTimeout("tcp", address, time.Duration(c.ConnectTimeout)*time.Second) + t.conn, err = net.DialTimeout("tcp", address, time.Duration(t.Config.ConnectTimeout)*time.Second) return err } @@ -34,17 +42,13 @@ func (t XMPPTransport) DoesStartTLS() bool { return true } -func (t *XMPPTransport) StartTLS(domain string, c Config) error { - if t.TLSConfig == nil { - if c.TLSConfig != nil { - t.TLSConfig = c.TLSConfig - } else { - t.TLSConfig = &tls.Config{} - } +func (t *XMPPTransport) StartTLS(domain string) error { + if t.Config.TLSConfig == nil { + t.Config.TLSConfig = &tls.Config{} } - if t.TLSConfig.ServerName == "" { - t.TLSConfig.ServerName = domain + if t.Config.TLSConfig.ServerName == "" { + t.Config.TLSConfig.ServerName = domain } tlsConn := tls.Client(t.conn, t.TLSConfig) // We convert existing connection to TLS