Correctly open new streams after StartTLS and auth
This commit is contained in:
parent
33446ad0ba
commit
6da1962962
|
@ -73,9 +73,10 @@ func (s *Session) init(o Config) {
|
|||
}
|
||||
|
||||
func (s *Session) reset(o Config) {
|
||||
if s.err != nil {
|
||||
if s.StreamId, s.err = s.transport.StartStream(); s.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.Features = s.open(o.parsedJid.Domain)
|
||||
}
|
||||
|
||||
|
@ -108,7 +109,7 @@ func (s *Session) startTlsIfSupported(o Config) {
|
|||
return
|
||||
}
|
||||
|
||||
s.StreamId, s.err = s.transport.StartTLS()
|
||||
s.err = s.transport.StartTLS()
|
||||
|
||||
if s.err == nil {
|
||||
s.TlsEnabled = true
|
||||
|
|
|
@ -6,6 +6,7 @@ const (
|
|||
NSSASL = "urn:ietf:params:xml:ns:xmpp-sasl"
|
||||
NSBind = "urn:ietf:params:xml:ns:xmpp-bind"
|
||||
NSSession = "urn:ietf:params:xml:ns:xmpp-session"
|
||||
NSFraming = "urn:ietf:params:xml:ns:xmpp-framing"
|
||||
NSClient = "jabber:client"
|
||||
NSComponent = "jabber:component:accept"
|
||||
)
|
||||
|
|
|
@ -24,8 +24,10 @@ func InitStream(p *xml.Decoder) (sessionID string, err error) {
|
|||
|
||||
switch elem := t.(type) {
|
||||
case xml.StartElement:
|
||||
if elem.Name.Space != NSStream || elem.Name.Local != "stream" {
|
||||
err = errors.New("xmpp: expected <stream> but got <" + elem.Name.Local + "> in " + elem.Name.Space)
|
||||
isStreamOpen := elem.Name.Space == NSStream && elem.Name.Local == "stream"
|
||||
isFrameOpen := elem.Name.Space == NSFraming && elem.Name.Local == "open"
|
||||
if !isStreamOpen && !isFrameOpen {
|
||||
err = errors.New("xmpp: expected <stream> or <open> but got <" + elem.Name.Local + "> in " + elem.Name.Space)
|
||||
return sessionID, err
|
||||
}
|
||||
|
||||
|
|
|
@ -27,10 +27,11 @@ type TransportConfiguration struct {
|
|||
type Transport interface {
|
||||
Connect() (string, error)
|
||||
DoesStartTLS() bool
|
||||
StartTLS() (string, error)
|
||||
StartTLS() error
|
||||
|
||||
LogTraffic(logFile io.Writer)
|
||||
|
||||
StartStream() (string, error)
|
||||
GetDecoder() *xml.Decoder
|
||||
IsSecure() bool
|
||||
|
||||
|
|
|
@ -57,29 +57,24 @@ func (t *WebsocketTransport) Connect() (string, error) {
|
|||
t.wsConn = wsConn
|
||||
t.startReader()
|
||||
|
||||
handshake := fmt.Sprintf(`<open xmlns="urn:ietf:params:xml:ns:xmpp-framing" to="%s" version="1.0" />`, t.Config.Domain)
|
||||
if _, err = t.Write([]byte(handshake)); err != nil {
|
||||
t.cleanup(websocket.StatusBadGateway)
|
||||
return "", NewConnError(err, false)
|
||||
}
|
||||
|
||||
handshakeResponse := make([]byte, 2048)
|
||||
if _, err = t.Read(handshakeResponse); err != nil {
|
||||
t.cleanup(websocket.StatusBadGateway)
|
||||
|
||||
return "", NewConnError(err, false)
|
||||
}
|
||||
|
||||
var openResponse = stanza.WebsocketOpen{}
|
||||
if err = xml.Unmarshal(handshakeResponse, &openResponse); err != nil {
|
||||
t.cleanup(websocket.StatusBadGateway)
|
||||
return "", NewConnError(err, false)
|
||||
}
|
||||
|
||||
t.decoder = xml.NewDecoder(t)
|
||||
t.decoder.CharsetReader = t.Config.CharsetReader
|
||||
|
||||
return openResponse.Id, nil
|
||||
return t.StartStream()
|
||||
}
|
||||
|
||||
func (t WebsocketTransport) StartStream() (string, error) {
|
||||
if _, err := fmt.Fprintf(t, `<open xmlns="urn:ietf:params:xml:ns:xmpp-framing" to="%s" version="1.0" />`, t.Config.Domain); err != nil {
|
||||
t.cleanup(websocket.StatusBadGateway)
|
||||
return "", NewConnError(err, true)
|
||||
}
|
||||
|
||||
sessionID, err := stanza.InitStream(t.GetDecoder())
|
||||
if err != nil {
|
||||
t.Close()
|
||||
return "", NewConnError(err, false)
|
||||
}
|
||||
return sessionID, nil
|
||||
}
|
||||
|
||||
// startReader runs a go function that keeps reading from the websocket. This
|
||||
|
@ -108,14 +103,18 @@ func (t WebsocketTransport) startReader() {
|
|||
}()
|
||||
}
|
||||
|
||||
func (t WebsocketTransport) StartTLS() (string, error) {
|
||||
return "", ErrTLSNotSupported
|
||||
func (t WebsocketTransport) StartTLS() error {
|
||||
return ErrTLSNotSupported
|
||||
}
|
||||
|
||||
func (t WebsocketTransport) DoesStartTLS() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t WebsocketTransport) GetDomain() string {
|
||||
return t.Config.Domain
|
||||
}
|
||||
|
||||
func (t WebsocketTransport) GetDecoder() *xml.Decoder {
|
||||
return t.decoder
|
||||
}
|
||||
|
@ -152,20 +151,21 @@ func (t WebsocketTransport) Write(p []byte) (int, error) {
|
|||
|
||||
func (t WebsocketTransport) Close() error {
|
||||
t.Write([]byte("<close xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" />"))
|
||||
return t.wsConn.Close(websocket.StatusGoingAway, "Done")
|
||||
return t.cleanup(websocket.StatusGoingAway)
|
||||
}
|
||||
|
||||
func (t *WebsocketTransport) LogTraffic(logFile io.Writer) {
|
||||
t.logFile = logFile
|
||||
}
|
||||
|
||||
func (t *WebsocketTransport) cleanup(code websocket.StatusCode) {
|
||||
func (t *WebsocketTransport) cleanup(code websocket.StatusCode) error {
|
||||
var err error
|
||||
if t.queue != nil {
|
||||
close(t.queue)
|
||||
t.queue = nil
|
||||
}
|
||||
if t.wsConn != nil {
|
||||
t.wsConn.Close(websocket.StatusGoingAway, "Done")
|
||||
err = t.wsConn.Close(websocket.StatusGoingAway, "Done")
|
||||
t.wsConn = nil
|
||||
}
|
||||
if t.closeFunc != nil {
|
||||
|
@ -173,4 +173,5 @@ func (t *WebsocketTransport) cleanup(code websocket.StatusCode) {
|
|||
t.closeFunc = nil
|
||||
t.closeCtx = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -37,20 +37,20 @@ func (t *XMPPTransport) Connect() (string, error) {
|
|||
}
|
||||
|
||||
t.readWriter = newStreamLogger(t.conn, t.logFile)
|
||||
return t.startStream()
|
||||
t.decoder = xml.NewDecoder(t.readWriter)
|
||||
t.decoder.CharsetReader = t.Config.CharsetReader
|
||||
return t.StartStream()
|
||||
}
|
||||
|
||||
func (t *XMPPTransport) startStream() (string, error) {
|
||||
if _, err := fmt.Fprintf(t.readWriter, t.openStatement, t.Config.Domain); err != nil {
|
||||
t.conn.Close()
|
||||
func (t XMPPTransport) StartStream() (string, error) {
|
||||
if _, err := fmt.Fprintf(t, t.openStatement, t.Config.Domain); err != nil {
|
||||
t.Close()
|
||||
return "", NewConnError(err, true)
|
||||
}
|
||||
|
||||
t.decoder = xml.NewDecoder(t.readWriter)
|
||||
t.decoder.CharsetReader = t.Config.CharsetReader
|
||||
sessionID, err := stanza.InitStream(t.decoder)
|
||||
sessionID, err := stanza.InitStream(t.GetDecoder())
|
||||
if err != nil {
|
||||
t.conn.Close()
|
||||
t.Close()
|
||||
return "", NewConnError(err, false)
|
||||
}
|
||||
return sessionID, nil
|
||||
|
@ -60,6 +60,10 @@ func (t XMPPTransport) DoesStartTLS() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (t XMPPTransport) GetDomain() string {
|
||||
return t.Config.Domain
|
||||
}
|
||||
|
||||
func (t XMPPTransport) GetDecoder() *xml.Decoder {
|
||||
return t.decoder
|
||||
}
|
||||
|
@ -68,7 +72,7 @@ func (t XMPPTransport) IsSecure() bool {
|
|||
return t.isSecure
|
||||
}
|
||||
|
||||
func (t *XMPPTransport) StartTLS() (string, error) {
|
||||
func (t *XMPPTransport) StartTLS() error {
|
||||
if t.Config.TLSConfig == nil {
|
||||
t.TLSConfig = &tls.Config{}
|
||||
} else {
|
||||
|
@ -81,7 +85,7 @@ func (t *XMPPTransport) StartTLS() (string, error) {
|
|||
tlsConn := tls.Client(t.conn, t.TLSConfig)
|
||||
// We convert existing connection to TLS
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
return "", err
|
||||
return err
|
||||
}
|
||||
|
||||
t.conn = tlsConn
|
||||
|
@ -91,13 +95,12 @@ func (t *XMPPTransport) StartTLS() (string, error) {
|
|||
|
||||
if !t.TLSConfig.InsecureSkipVerify {
|
||||
if err := tlsConn.VerifyHostname(t.Config.Domain); err != nil {
|
||||
return "", err
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
t.isSecure = true
|
||||
|
||||
return t.startStream()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t XMPPTransport) Ping() error {
|
||||
|
|
Loading…
Reference in a new issue