diff --git a/session.go b/session.go index 191c2c3..ccf1993 100644 --- a/session.go +++ b/session.go @@ -39,7 +39,7 @@ func NewSession(transport Transport, o Config, state SMState) (*Session, error) return nil, NewConnError(s.err, true) } - if !s.TlsEnabled && !o.Insecure { + if !transport.IsSecure() && !o.Insecure { err := fmt.Errorf("failed to negotiate TLS session : %s", s.err) return nil, NewConnError(err, true) } diff --git a/transport.go b/transport.go index e435633..e9a35db 100644 --- a/transport.go +++ b/transport.go @@ -2,8 +2,11 @@ package xmpp import ( "crypto/tls" + "errors" ) +var TLSNotSupported = errors.New("Transport does not support StartTLS") + type TransportConfiguration struct { // Address is the XMPP Host and port to connect to. Host is of // the form 'serverhost:port' i.e "localhost:8888" @@ -19,6 +22,8 @@ type Transport interface { DoesStartTLS() bool StartTLS(domain string) error + IsSecure() bool + Read(p []byte) (n int, err error) Write(p []byte) (n int, err error) Close() error diff --git a/websocket_transport.go b/websocket_transport.go index 9363cf8..13345a0 100644 --- a/websocket_transport.go +++ b/websocket_transport.go @@ -2,7 +2,6 @@ package xmpp import ( "context" - "errors" "net" "strings" "time" @@ -23,9 +22,6 @@ func (t *WebsocketTransport) Connect() error { ctx, cancel := context.WithTimeout(t.ctx, time.Duration(t.Config.ConnectTimeout)*time.Second) defer cancel() - if !c.Insecure && strings.HasPrefix(address, "wss:") { - return errors.New("Websocket address is not secure") - } wsConn, _, err := websocket.Dial(ctx, t.Config.Address, nil) if err != nil { t.wsConn = wsConn @@ -34,10 +30,18 @@ func (t *WebsocketTransport) Connect() error { return err } +func (t WebsocketTransport) StartTLS(domain string) error { + return TLSNotSupported +} + func (t WebsocketTransport) DoesStartTLS() bool { return false } +func (t WebsocketTransport) IsSecure() bool { + return strings.HasPrefix(t.Config.Address, "wss:") +} + func (t WebsocketTransport) Read(p []byte) (n int, err error) { return t.netConn.Read(p) } diff --git a/xmpp_transport.go b/xmpp_transport.go index 6f526db..088e22b 100644 --- a/xmpp_transport.go +++ b/xmpp_transport.go @@ -11,7 +11,8 @@ type XMPPTransport struct { Config TransportConfiguration TLSConfig *tls.Config // TCP level connection / can be replaced by a TLS session after starttls - conn net.Conn + conn net.Conn + isSecure bool } func (t *XMPPTransport) Connect() error { @@ -25,6 +26,10 @@ func (t XMPPTransport) DoesStartTLS() bool { return true } +func (t XMPPTransport) IsSecure() bool { + return t.isSecure +} + func (t *XMPPTransport) StartTLS(domain string) error { if t.Config.TLSConfig == nil { t.Config.TLSConfig = &tls.Config{} @@ -45,6 +50,7 @@ func (t *XMPPTransport) StartTLS(domain string) error { } } + t.isSecure = true return nil }