Correctly open new streams after StartTLS and auth

This commit is contained in:
Wichert Akkerman 2019-10-25 15:55:27 +02:00 committed by Mickaël Rémond
parent 33446ad0ba
commit 6da1962962
6 changed files with 52 additions and 43 deletions

View file

@ -73,9 +73,10 @@ func (s *Session) init(o Config) {
}
func (s *Session) reset(o Config) {
if s.err != nil {
if s.StreamId, s.err = s.transport.StartStream(); s.err != nil {
return
}
s.Features = s.open(o.parsedJid.Domain)
}
@ -108,7 +109,7 @@ func (s *Session) startTlsIfSupported(o Config) {
return
}
s.StreamId, s.err = s.transport.StartTLS()
s.err = s.transport.StartTLS()
if s.err == nil {
s.TlsEnabled = true

View file

@ -6,6 +6,7 @@ const (
NSSASL = "urn:ietf:params:xml:ns:xmpp-sasl"
NSBind = "urn:ietf:params:xml:ns:xmpp-bind"
NSSession = "urn:ietf:params:xml:ns:xmpp-session"
NSFraming = "urn:ietf:params:xml:ns:xmpp-framing"
NSClient = "jabber:client"
NSComponent = "jabber:component:accept"
)

View file

@ -24,8 +24,10 @@ func InitStream(p *xml.Decoder) (sessionID string, err error) {
switch elem := t.(type) {
case xml.StartElement:
if elem.Name.Space != NSStream || elem.Name.Local != "stream" {
err = errors.New("xmpp: expected <stream> but got <" + elem.Name.Local + "> in " + elem.Name.Space)
isStreamOpen := elem.Name.Space == NSStream && elem.Name.Local == "stream"
isFrameOpen := elem.Name.Space == NSFraming && elem.Name.Local == "open"
if !isStreamOpen && !isFrameOpen {
err = errors.New("xmpp: expected <stream> or <open> but got <" + elem.Name.Local + "> in " + elem.Name.Space)
return sessionID, err
}

View file

@ -27,10 +27,11 @@ type TransportConfiguration struct {
type Transport interface {
Connect() (string, error)
DoesStartTLS() bool
StartTLS() (string, error)
StartTLS() error
LogTraffic(logFile io.Writer)
StartStream() (string, error)
GetDecoder() *xml.Decoder
IsSecure() bool

View file

@ -57,29 +57,24 @@ func (t *WebsocketTransport) Connect() (string, error) {
t.wsConn = wsConn
t.startReader()
handshake := fmt.Sprintf(`<open xmlns="urn:ietf:params:xml:ns:xmpp-framing" to="%s" version="1.0" />`, t.Config.Domain)
if _, err = t.Write([]byte(handshake)); err != nil {
t.cleanup(websocket.StatusBadGateway)
return "", NewConnError(err, false)
}
handshakeResponse := make([]byte, 2048)
if _, err = t.Read(handshakeResponse); err != nil {
t.cleanup(websocket.StatusBadGateway)
return "", NewConnError(err, false)
}
var openResponse = stanza.WebsocketOpen{}
if err = xml.Unmarshal(handshakeResponse, &openResponse); err != nil {
t.cleanup(websocket.StatusBadGateway)
return "", NewConnError(err, false)
}
t.decoder = xml.NewDecoder(t)
t.decoder.CharsetReader = t.Config.CharsetReader
return openResponse.Id, nil
return t.StartStream()
}
func (t WebsocketTransport) StartStream() (string, error) {
if _, err := fmt.Fprintf(t, `<open xmlns="urn:ietf:params:xml:ns:xmpp-framing" to="%s" version="1.0" />`, t.Config.Domain); err != nil {
t.cleanup(websocket.StatusBadGateway)
return "", NewConnError(err, true)
}
sessionID, err := stanza.InitStream(t.GetDecoder())
if err != nil {
t.Close()
return "", NewConnError(err, false)
}
return sessionID, nil
}
// startReader runs a go function that keeps reading from the websocket. This
@ -108,14 +103,18 @@ func (t WebsocketTransport) startReader() {
}()
}
func (t WebsocketTransport) StartTLS() (string, error) {
return "", ErrTLSNotSupported
func (t WebsocketTransport) StartTLS() error {
return ErrTLSNotSupported
}
func (t WebsocketTransport) DoesStartTLS() bool {
return false
}
func (t WebsocketTransport) GetDomain() string {
return t.Config.Domain
}
func (t WebsocketTransport) GetDecoder() *xml.Decoder {
return t.decoder
}
@ -152,20 +151,21 @@ func (t WebsocketTransport) Write(p []byte) (int, error) {
func (t WebsocketTransport) Close() error {
t.Write([]byte("<close xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" />"))
return t.wsConn.Close(websocket.StatusGoingAway, "Done")
return t.cleanup(websocket.StatusGoingAway)
}
func (t *WebsocketTransport) LogTraffic(logFile io.Writer) {
t.logFile = logFile
}
func (t *WebsocketTransport) cleanup(code websocket.StatusCode) {
func (t *WebsocketTransport) cleanup(code websocket.StatusCode) error {
var err error
if t.queue != nil {
close(t.queue)
t.queue = nil
}
if t.wsConn != nil {
t.wsConn.Close(websocket.StatusGoingAway, "Done")
err = t.wsConn.Close(websocket.StatusGoingAway, "Done")
t.wsConn = nil
}
if t.closeFunc != nil {
@ -173,4 +173,5 @@ func (t *WebsocketTransport) cleanup(code websocket.StatusCode) {
t.closeFunc = nil
t.closeCtx = nil
}
return err
}

View file

@ -37,20 +37,20 @@ func (t *XMPPTransport) Connect() (string, error) {
}
t.readWriter = newStreamLogger(t.conn, t.logFile)
return t.startStream()
t.decoder = xml.NewDecoder(t.readWriter)
t.decoder.CharsetReader = t.Config.CharsetReader
return t.StartStream()
}
func (t *XMPPTransport) startStream() (string, error) {
if _, err := fmt.Fprintf(t.readWriter, t.openStatement, t.Config.Domain); err != nil {
t.conn.Close()
func (t XMPPTransport) StartStream() (string, error) {
if _, err := fmt.Fprintf(t, t.openStatement, t.Config.Domain); err != nil {
t.Close()
return "", NewConnError(err, true)
}
t.decoder = xml.NewDecoder(t.readWriter)
t.decoder.CharsetReader = t.Config.CharsetReader
sessionID, err := stanza.InitStream(t.decoder)
sessionID, err := stanza.InitStream(t.GetDecoder())
if err != nil {
t.conn.Close()
t.Close()
return "", NewConnError(err, false)
}
return sessionID, nil
@ -60,6 +60,10 @@ func (t XMPPTransport) DoesStartTLS() bool {
return true
}
func (t XMPPTransport) GetDomain() string {
return t.Config.Domain
}
func (t XMPPTransport) GetDecoder() *xml.Decoder {
return t.decoder
}
@ -68,7 +72,7 @@ func (t XMPPTransport) IsSecure() bool {
return t.isSecure
}
func (t *XMPPTransport) StartTLS() (string, error) {
func (t *XMPPTransport) StartTLS() error {
if t.Config.TLSConfig == nil {
t.TLSConfig = &tls.Config{}
} else {
@ -81,7 +85,7 @@ func (t *XMPPTransport) StartTLS() (string, error) {
tlsConn := tls.Client(t.conn, t.TLSConfig)
// We convert existing connection to TLS
if err := tlsConn.Handshake(); err != nil {
return "", err
return err
}
t.conn = tlsConn
@ -91,13 +95,12 @@ func (t *XMPPTransport) StartTLS() (string, error) {
if !t.TLSConfig.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil {
return "", err
return err
}
}
t.isSecure = true
return t.startStream()
return nil
}
func (t XMPPTransport) Ping() error {