Correctly open new streams after StartTLS and auth

disco_info_form
Wichert Akkerman 5 years ago committed by Mickaël Rémond
parent 33446ad0ba
commit 6da1962962

@ -73,9 +73,10 @@ func (s *Session) init(o Config) {
} }
func (s *Session) reset(o Config) { func (s *Session) reset(o Config) {
if s.err != nil { if s.StreamId, s.err = s.transport.StartStream(); s.err != nil {
return return
} }
s.Features = s.open(o.parsedJid.Domain) s.Features = s.open(o.parsedJid.Domain)
} }
@ -108,7 +109,7 @@ func (s *Session) startTlsIfSupported(o Config) {
return return
} }
s.StreamId, s.err = s.transport.StartTLS() s.err = s.transport.StartTLS()
if s.err == nil { if s.err == nil {
s.TlsEnabled = true s.TlsEnabled = true

@ -6,6 +6,7 @@ const (
NSSASL = "urn:ietf:params:xml:ns:xmpp-sasl" NSSASL = "urn:ietf:params:xml:ns:xmpp-sasl"
NSBind = "urn:ietf:params:xml:ns:xmpp-bind" NSBind = "urn:ietf:params:xml:ns:xmpp-bind"
NSSession = "urn:ietf:params:xml:ns:xmpp-session" NSSession = "urn:ietf:params:xml:ns:xmpp-session"
NSFraming = "urn:ietf:params:xml:ns:xmpp-framing"
NSClient = "jabber:client" NSClient = "jabber:client"
NSComponent = "jabber:component:accept" NSComponent = "jabber:component:accept"
) )

@ -24,8 +24,10 @@ func InitStream(p *xml.Decoder) (sessionID string, err error) {
switch elem := t.(type) { switch elem := t.(type) {
case xml.StartElement: case xml.StartElement:
if elem.Name.Space != NSStream || elem.Name.Local != "stream" { isStreamOpen := elem.Name.Space == NSStream && elem.Name.Local == "stream"
err = errors.New("xmpp: expected <stream> but got <" + elem.Name.Local + "> in " + elem.Name.Space) isFrameOpen := elem.Name.Space == NSFraming && elem.Name.Local == "open"
if !isStreamOpen && !isFrameOpen {
err = errors.New("xmpp: expected <stream> or <open> but got <" + elem.Name.Local + "> in " + elem.Name.Space)
return sessionID, err return sessionID, err
} }

@ -27,10 +27,11 @@ type TransportConfiguration struct {
type Transport interface { type Transport interface {
Connect() (string, error) Connect() (string, error)
DoesStartTLS() bool DoesStartTLS() bool
StartTLS() (string, error) StartTLS() error
LogTraffic(logFile io.Writer) LogTraffic(logFile io.Writer)
StartStream() (string, error)
GetDecoder() *xml.Decoder GetDecoder() *xml.Decoder
IsSecure() bool IsSecure() bool

@ -57,29 +57,24 @@ func (t *WebsocketTransport) Connect() (string, error) {
t.wsConn = wsConn t.wsConn = wsConn
t.startReader() t.startReader()
handshake := fmt.Sprintf(`<open xmlns="urn:ietf:params:xml:ns:xmpp-framing" to="%s" version="1.0" />`, t.Config.Domain) t.decoder = xml.NewDecoder(t)
if _, err = t.Write([]byte(handshake)); err != nil { t.decoder.CharsetReader = t.Config.CharsetReader
t.cleanup(websocket.StatusBadGateway)
return "", NewConnError(err, false)
}
handshakeResponse := make([]byte, 2048) return t.StartStream()
if _, err = t.Read(handshakeResponse); err != nil { }
t.cleanup(websocket.StatusBadGateway)
return "", NewConnError(err, false) func (t WebsocketTransport) StartStream() (string, error) {
if _, err := fmt.Fprintf(t, `<open xmlns="urn:ietf:params:xml:ns:xmpp-framing" to="%s" version="1.0" />`, t.Config.Domain); err != nil {
t.cleanup(websocket.StatusBadGateway)
return "", NewConnError(err, true)
} }
var openResponse = stanza.WebsocketOpen{} sessionID, err := stanza.InitStream(t.GetDecoder())
if err = xml.Unmarshal(handshakeResponse, &openResponse); err != nil { if err != nil {
t.cleanup(websocket.StatusBadGateway) t.Close()
return "", NewConnError(err, false) return "", NewConnError(err, false)
} }
return sessionID, nil
t.decoder = xml.NewDecoder(t)
t.decoder.CharsetReader = t.Config.CharsetReader
return openResponse.Id, nil
} }
// startReader runs a go function that keeps reading from the websocket. This // startReader runs a go function that keeps reading from the websocket. This
@ -108,14 +103,18 @@ func (t WebsocketTransport) startReader() {
}() }()
} }
func (t WebsocketTransport) StartTLS() (string, error) { func (t WebsocketTransport) StartTLS() error {
return "", ErrTLSNotSupported return ErrTLSNotSupported
} }
func (t WebsocketTransport) DoesStartTLS() bool { func (t WebsocketTransport) DoesStartTLS() bool {
return false return false
} }
func (t WebsocketTransport) GetDomain() string {
return t.Config.Domain
}
func (t WebsocketTransport) GetDecoder() *xml.Decoder { func (t WebsocketTransport) GetDecoder() *xml.Decoder {
return t.decoder return t.decoder
} }
@ -152,20 +151,21 @@ func (t WebsocketTransport) Write(p []byte) (int, error) {
func (t WebsocketTransport) Close() error { func (t WebsocketTransport) Close() error {
t.Write([]byte("<close xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" />")) t.Write([]byte("<close xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" />"))
return t.wsConn.Close(websocket.StatusGoingAway, "Done") return t.cleanup(websocket.StatusGoingAway)
} }
func (t *WebsocketTransport) LogTraffic(logFile io.Writer) { func (t *WebsocketTransport) LogTraffic(logFile io.Writer) {
t.logFile = logFile t.logFile = logFile
} }
func (t *WebsocketTransport) cleanup(code websocket.StatusCode) { func (t *WebsocketTransport) cleanup(code websocket.StatusCode) error {
var err error
if t.queue != nil { if t.queue != nil {
close(t.queue) close(t.queue)
t.queue = nil t.queue = nil
} }
if t.wsConn != nil { if t.wsConn != nil {
t.wsConn.Close(websocket.StatusGoingAway, "Done") err = t.wsConn.Close(websocket.StatusGoingAway, "Done")
t.wsConn = nil t.wsConn = nil
} }
if t.closeFunc != nil { if t.closeFunc != nil {
@ -173,4 +173,5 @@ func (t *WebsocketTransport) cleanup(code websocket.StatusCode) {
t.closeFunc = nil t.closeFunc = nil
t.closeCtx = nil t.closeCtx = nil
} }
return err
} }

@ -37,20 +37,20 @@ func (t *XMPPTransport) Connect() (string, error) {
} }
t.readWriter = newStreamLogger(t.conn, t.logFile) t.readWriter = newStreamLogger(t.conn, t.logFile)
return t.startStream() t.decoder = xml.NewDecoder(t.readWriter)
t.decoder.CharsetReader = t.Config.CharsetReader
return t.StartStream()
} }
func (t *XMPPTransport) startStream() (string, error) { func (t XMPPTransport) StartStream() (string, error) {
if _, err := fmt.Fprintf(t.readWriter, t.openStatement, t.Config.Domain); err != nil { if _, err := fmt.Fprintf(t, t.openStatement, t.Config.Domain); err != nil {
t.conn.Close() t.Close()
return "", NewConnError(err, true) return "", NewConnError(err, true)
} }
t.decoder = xml.NewDecoder(t.readWriter) sessionID, err := stanza.InitStream(t.GetDecoder())
t.decoder.CharsetReader = t.Config.CharsetReader
sessionID, err := stanza.InitStream(t.decoder)
if err != nil { if err != nil {
t.conn.Close() t.Close()
return "", NewConnError(err, false) return "", NewConnError(err, false)
} }
return sessionID, nil return sessionID, nil
@ -60,6 +60,10 @@ func (t XMPPTransport) DoesStartTLS() bool {
return true return true
} }
func (t XMPPTransport) GetDomain() string {
return t.Config.Domain
}
func (t XMPPTransport) GetDecoder() *xml.Decoder { func (t XMPPTransport) GetDecoder() *xml.Decoder {
return t.decoder return t.decoder
} }
@ -68,7 +72,7 @@ func (t XMPPTransport) IsSecure() bool {
return t.isSecure return t.isSecure
} }
func (t *XMPPTransport) StartTLS() (string, error) { func (t *XMPPTransport) StartTLS() error {
if t.Config.TLSConfig == nil { if t.Config.TLSConfig == nil {
t.TLSConfig = &tls.Config{} t.TLSConfig = &tls.Config{}
} else { } else {
@ -81,7 +85,7 @@ func (t *XMPPTransport) StartTLS() (string, error) {
tlsConn := tls.Client(t.conn, t.TLSConfig) tlsConn := tls.Client(t.conn, t.TLSConfig)
// We convert existing connection to TLS // We convert existing connection to TLS
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
return "", err return err
} }
t.conn = tlsConn t.conn = tlsConn
@ -91,13 +95,12 @@ func (t *XMPPTransport) StartTLS() (string, error) {
if !t.TLSConfig.InsecureSkipVerify { if !t.TLSConfig.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil { if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil {
return "", err return err
} }
} }
t.isSecure = true t.isSecure = true
return nil
return t.startStream()
} }
func (t XMPPTransport) Ping() error { func (t XMPPTransport) Ping() error {

Loading…
Cancel
Save