diff --git a/transport.go b/transport.go index 296cec1..30bcfca 100644 --- a/transport.go +++ b/transport.go @@ -2,8 +2,6 @@ package xmpp import ( "crypto/tls" - "net" - "time" ) type TransportConfiguration struct { @@ -22,57 +20,3 @@ type Transport interface { Write(p []byte) (n int, err error) Close() error } - -// XMPPTransport implements the XMPP native TCP transport -type XMPPTransport struct { - Config TransportConfiguration - TLSConfig *tls.Config - // TCP level connection / can be replaced by a TLS session after starttls - conn net.Conn -} - -func (t *XMPPTransport) Connect(address string) error { - var err error - - t.conn, err = net.DialTimeout("tcp", address, time.Duration(t.Config.ConnectTimeout)*time.Second) - return err -} - -func (t XMPPTransport) DoesStartTLS() bool { - return true -} - -func (t *XMPPTransport) StartTLS(domain string) error { - if t.Config.TLSConfig == nil { - t.Config.TLSConfig = &tls.Config{} - } - - if t.Config.TLSConfig.ServerName == "" { - t.Config.TLSConfig.ServerName = domain - } - tlsConn := tls.Client(t.conn, t.TLSConfig) - // We convert existing connection to TLS - if err := tlsConn.Handshake(); err != nil { - return err - } - - if !t.TLSConfig.InsecureSkipVerify { - if err := tlsConn.VerifyHostname(domain); err != nil { - return err - } - } - - return nil -} - -func (t XMPPTransport) Read(p []byte) (n int, err error) { - return t.conn.Read(p) -} - -func (t XMPPTransport) Write(p []byte) (n int, err error) { - return t.conn.Write(p) -} - -func (t XMPPTransport) Close() error { - return t.conn.Close() -} diff --git a/xmpp_transport.go b/xmpp_transport.go new file mode 100644 index 0000000..fcbaa3b --- /dev/null +++ b/xmpp_transport.go @@ -0,0 +1,61 @@ +package xmpp + +import ( + "crypto/tls" + "net" + "time" +) + +// XMPPTransport implements the XMPP native TCP transport +type XMPPTransport struct { + Config TransportConfiguration + TLSConfig *tls.Config + // TCP level connection / can be replaced by a TLS session after starttls + conn net.Conn +} + +func (t *XMPPTransport) Connect(address string) error { + var err error + + t.conn, err = net.DialTimeout("tcp", address, time.Duration(t.Config.ConnectTimeout)*time.Second) + return err +} + +func (t XMPPTransport) DoesStartTLS() bool { + return true +} + +func (t *XMPPTransport) StartTLS(domain string) error { + if t.Config.TLSConfig == nil { + t.Config.TLSConfig = &tls.Config{} + } + + if t.Config.TLSConfig.ServerName == "" { + t.Config.TLSConfig.ServerName = domain + } + tlsConn := tls.Client(t.conn, t.TLSConfig) + // We convert existing connection to TLS + if err := tlsConn.Handshake(); err != nil { + return err + } + + if !t.TLSConfig.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(domain); err != nil { + return err + } + } + + return nil +} + +func (t XMPPTransport) Read(p []byte) (n int, err error) { + return t.conn.Read(p) +} + +func (t XMPPTransport) Write(p []byte) (n int, err error) { + return t.conn.Write(p) +} + +func (t XMPPTransport) Close() error { + return t.conn.Close() +}