Correctly open new streams after StartTLS and auth
This commit is contained in:
parent
33446ad0ba
commit
6da1962962
|
@ -73,9 +73,10 @@ func (s *Session) init(o Config) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) reset(o Config) {
|
func (s *Session) reset(o Config) {
|
||||||
if s.err != nil {
|
if s.StreamId, s.err = s.transport.StartStream(); s.err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.Features = s.open(o.parsedJid.Domain)
|
s.Features = s.open(o.parsedJid.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,7 +109,7 @@ func (s *Session) startTlsIfSupported(o Config) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.StreamId, s.err = s.transport.StartTLS()
|
s.err = s.transport.StartTLS()
|
||||||
|
|
||||||
if s.err == nil {
|
if s.err == nil {
|
||||||
s.TlsEnabled = true
|
s.TlsEnabled = true
|
||||||
|
|
|
@ -6,6 +6,7 @@ const (
|
||||||
NSSASL = "urn:ietf:params:xml:ns:xmpp-sasl"
|
NSSASL = "urn:ietf:params:xml:ns:xmpp-sasl"
|
||||||
NSBind = "urn:ietf:params:xml:ns:xmpp-bind"
|
NSBind = "urn:ietf:params:xml:ns:xmpp-bind"
|
||||||
NSSession = "urn:ietf:params:xml:ns:xmpp-session"
|
NSSession = "urn:ietf:params:xml:ns:xmpp-session"
|
||||||
|
NSFraming = "urn:ietf:params:xml:ns:xmpp-framing"
|
||||||
NSClient = "jabber:client"
|
NSClient = "jabber:client"
|
||||||
NSComponent = "jabber:component:accept"
|
NSComponent = "jabber:component:accept"
|
||||||
)
|
)
|
||||||
|
|
|
@ -24,8 +24,10 @@ func InitStream(p *xml.Decoder) (sessionID string, err error) {
|
||||||
|
|
||||||
switch elem := t.(type) {
|
switch elem := t.(type) {
|
||||||
case xml.StartElement:
|
case xml.StartElement:
|
||||||
if elem.Name.Space != NSStream || elem.Name.Local != "stream" {
|
isStreamOpen := elem.Name.Space == NSStream && elem.Name.Local == "stream"
|
||||||
err = errors.New("xmpp: expected <stream> but got <" + elem.Name.Local + "> in " + elem.Name.Space)
|
isFrameOpen := elem.Name.Space == NSFraming && elem.Name.Local == "open"
|
||||||
|
if !isStreamOpen && !isFrameOpen {
|
||||||
|
err = errors.New("xmpp: expected <stream> or <open> but got <" + elem.Name.Local + "> in " + elem.Name.Space)
|
||||||
return sessionID, err
|
return sessionID, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -27,10 +27,11 @@ type TransportConfiguration struct {
|
||||||
type Transport interface {
|
type Transport interface {
|
||||||
Connect() (string, error)
|
Connect() (string, error)
|
||||||
DoesStartTLS() bool
|
DoesStartTLS() bool
|
||||||
StartTLS() (string, error)
|
StartTLS() error
|
||||||
|
|
||||||
LogTraffic(logFile io.Writer)
|
LogTraffic(logFile io.Writer)
|
||||||
|
|
||||||
|
StartStream() (string, error)
|
||||||
GetDecoder() *xml.Decoder
|
GetDecoder() *xml.Decoder
|
||||||
IsSecure() bool
|
IsSecure() bool
|
||||||
|
|
||||||
|
|
|
@ -57,29 +57,24 @@ func (t *WebsocketTransport) Connect() (string, error) {
|
||||||
t.wsConn = wsConn
|
t.wsConn = wsConn
|
||||||
t.startReader()
|
t.startReader()
|
||||||
|
|
||||||
handshake := fmt.Sprintf(`<open xmlns="urn:ietf:params:xml:ns:xmpp-framing" to="%s" version="1.0" />`, t.Config.Domain)
|
|
||||||
if _, err = t.Write([]byte(handshake)); err != nil {
|
|
||||||
t.cleanup(websocket.StatusBadGateway)
|
|
||||||
return "", NewConnError(err, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
handshakeResponse := make([]byte, 2048)
|
|
||||||
if _, err = t.Read(handshakeResponse); err != nil {
|
|
||||||
t.cleanup(websocket.StatusBadGateway)
|
|
||||||
|
|
||||||
return "", NewConnError(err, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
var openResponse = stanza.WebsocketOpen{}
|
|
||||||
if err = xml.Unmarshal(handshakeResponse, &openResponse); err != nil {
|
|
||||||
t.cleanup(websocket.StatusBadGateway)
|
|
||||||
return "", NewConnError(err, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.decoder = xml.NewDecoder(t)
|
t.decoder = xml.NewDecoder(t)
|
||||||
t.decoder.CharsetReader = t.Config.CharsetReader
|
t.decoder.CharsetReader = t.Config.CharsetReader
|
||||||
|
|
||||||
return openResponse.Id, nil
|
return t.StartStream()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t WebsocketTransport) StartStream() (string, error) {
|
||||||
|
if _, err := fmt.Fprintf(t, `<open xmlns="urn:ietf:params:xml:ns:xmpp-framing" to="%s" version="1.0" />`, t.Config.Domain); err != nil {
|
||||||
|
t.cleanup(websocket.StatusBadGateway)
|
||||||
|
return "", NewConnError(err, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionID, err := stanza.InitStream(t.GetDecoder())
|
||||||
|
if err != nil {
|
||||||
|
t.Close()
|
||||||
|
return "", NewConnError(err, false)
|
||||||
|
}
|
||||||
|
return sessionID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// startReader runs a go function that keeps reading from the websocket. This
|
// startReader runs a go function that keeps reading from the websocket. This
|
||||||
|
@ -108,14 +103,18 @@ func (t WebsocketTransport) startReader() {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t WebsocketTransport) StartTLS() (string, error) {
|
func (t WebsocketTransport) StartTLS() error {
|
||||||
return "", ErrTLSNotSupported
|
return ErrTLSNotSupported
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t WebsocketTransport) DoesStartTLS() bool {
|
func (t WebsocketTransport) DoesStartTLS() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t WebsocketTransport) GetDomain() string {
|
||||||
|
return t.Config.Domain
|
||||||
|
}
|
||||||
|
|
||||||
func (t WebsocketTransport) GetDecoder() *xml.Decoder {
|
func (t WebsocketTransport) GetDecoder() *xml.Decoder {
|
||||||
return t.decoder
|
return t.decoder
|
||||||
}
|
}
|
||||||
|
@ -152,20 +151,21 @@ func (t WebsocketTransport) Write(p []byte) (int, error) {
|
||||||
|
|
||||||
func (t WebsocketTransport) Close() error {
|
func (t WebsocketTransport) Close() error {
|
||||||
t.Write([]byte("<close xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" />"))
|
t.Write([]byte("<close xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" />"))
|
||||||
return t.wsConn.Close(websocket.StatusGoingAway, "Done")
|
return t.cleanup(websocket.StatusGoingAway)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *WebsocketTransport) LogTraffic(logFile io.Writer) {
|
func (t *WebsocketTransport) LogTraffic(logFile io.Writer) {
|
||||||
t.logFile = logFile
|
t.logFile = logFile
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *WebsocketTransport) cleanup(code websocket.StatusCode) {
|
func (t *WebsocketTransport) cleanup(code websocket.StatusCode) error {
|
||||||
|
var err error
|
||||||
if t.queue != nil {
|
if t.queue != nil {
|
||||||
close(t.queue)
|
close(t.queue)
|
||||||
t.queue = nil
|
t.queue = nil
|
||||||
}
|
}
|
||||||
if t.wsConn != nil {
|
if t.wsConn != nil {
|
||||||
t.wsConn.Close(websocket.StatusGoingAway, "Done")
|
err = t.wsConn.Close(websocket.StatusGoingAway, "Done")
|
||||||
t.wsConn = nil
|
t.wsConn = nil
|
||||||
}
|
}
|
||||||
if t.closeFunc != nil {
|
if t.closeFunc != nil {
|
||||||
|
@ -173,4 +173,5 @@ func (t *WebsocketTransport) cleanup(code websocket.StatusCode) {
|
||||||
t.closeFunc = nil
|
t.closeFunc = nil
|
||||||
t.closeCtx = nil
|
t.closeCtx = nil
|
||||||
}
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,20 +37,20 @@ func (t *XMPPTransport) Connect() (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.readWriter = newStreamLogger(t.conn, t.logFile)
|
t.readWriter = newStreamLogger(t.conn, t.logFile)
|
||||||
return t.startStream()
|
t.decoder = xml.NewDecoder(t.readWriter)
|
||||||
|
t.decoder.CharsetReader = t.Config.CharsetReader
|
||||||
|
return t.StartStream()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *XMPPTransport) startStream() (string, error) {
|
func (t XMPPTransport) StartStream() (string, error) {
|
||||||
if _, err := fmt.Fprintf(t.readWriter, t.openStatement, t.Config.Domain); err != nil {
|
if _, err := fmt.Fprintf(t, t.openStatement, t.Config.Domain); err != nil {
|
||||||
t.conn.Close()
|
t.Close()
|
||||||
return "", NewConnError(err, true)
|
return "", NewConnError(err, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.decoder = xml.NewDecoder(t.readWriter)
|
sessionID, err := stanza.InitStream(t.GetDecoder())
|
||||||
t.decoder.CharsetReader = t.Config.CharsetReader
|
|
||||||
sessionID, err := stanza.InitStream(t.decoder)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.conn.Close()
|
t.Close()
|
||||||
return "", NewConnError(err, false)
|
return "", NewConnError(err, false)
|
||||||
}
|
}
|
||||||
return sessionID, nil
|
return sessionID, nil
|
||||||
|
@ -60,6 +60,10 @@ func (t XMPPTransport) DoesStartTLS() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t XMPPTransport) GetDomain() string {
|
||||||
|
return t.Config.Domain
|
||||||
|
}
|
||||||
|
|
||||||
func (t XMPPTransport) GetDecoder() *xml.Decoder {
|
func (t XMPPTransport) GetDecoder() *xml.Decoder {
|
||||||
return t.decoder
|
return t.decoder
|
||||||
}
|
}
|
||||||
|
@ -68,7 +72,7 @@ func (t XMPPTransport) IsSecure() bool {
|
||||||
return t.isSecure
|
return t.isSecure
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *XMPPTransport) StartTLS() (string, error) {
|
func (t *XMPPTransport) StartTLS() error {
|
||||||
if t.Config.TLSConfig == nil {
|
if t.Config.TLSConfig == nil {
|
||||||
t.TLSConfig = &tls.Config{}
|
t.TLSConfig = &tls.Config{}
|
||||||
} else {
|
} else {
|
||||||
|
@ -81,7 +85,7 @@ func (t *XMPPTransport) StartTLS() (string, error) {
|
||||||
tlsConn := tls.Client(t.conn, t.TLSConfig)
|
tlsConn := tls.Client(t.conn, t.TLSConfig)
|
||||||
// We convert existing connection to TLS
|
// We convert existing connection to TLS
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
return "", err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
t.conn = tlsConn
|
t.conn = tlsConn
|
||||||
|
@ -91,13 +95,12 @@ func (t *XMPPTransport) StartTLS() (string, error) {
|
||||||
|
|
||||||
if !t.TLSConfig.InsecureSkipVerify {
|
if !t.TLSConfig.InsecureSkipVerify {
|
||||||
if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil {
|
if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil {
|
||||||
return "", err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t.isSecure = true
|
t.isSecure = true
|
||||||
|
return nil
|
||||||
return t.startStream()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t XMPPTransport) Ping() error {
|
func (t XMPPTransport) Ping() error {
|
||||||
|
|
Loading…
Reference in a new issue