@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"io"
"net"
"strings"
"time"
@ -14,6 +13,8 @@ import (
"nhooyr.io/websocket"
)
const maxPacketSize = 32768
const pingTimeout = time . Duration ( 5 ) * time . Second
var ServerDoesNotSupportXmppOverWebsocket = errors . New ( "The websocket server does not support the xmpp subprotocol" )
@ -22,17 +23,23 @@ type WebsocketTransport struct {
Config TransportConfiguration
decoder * xml . Decoder
wsConn * websocket . Conn
netConn net . Conn
queue chan [ ] byte
logFile io . Writer
closeCtx context . Context
closeFunc context . CancelFunc
}
func ( t * WebsocketTransport ) Connect ( ) ( string , error ) {
ctx := context . Background ( )
t . queue = make ( chan [ ] byte , 256 )
t . closeCtx , t . closeFunc = context . WithCancel ( context . Background ( ) )
var ctx context . Context
ctx = context . Background ( )
if t . Config . ConnectTimeout > 0 {
var cancel context . CancelFunc
ctx , cancel = context . WithTimeout ( ctx, time . Duration ( t . Config . ConnectTimeout ) * time . Second )
defer cancel ( )
var cancel Connect context . CancelFunc
ctx , cancel Connect = context . WithTimeout ( t. closeC tx, time . Duration ( t . Config . ConnectTimeout ) * time . Second )
defer cancel Connect ( )
}
wsConn , response , err := websocket . Dial ( ctx , t . Config . Address , & websocket . DialOptions {
@ -42,28 +49,30 @@ func (t *WebsocketTransport) Connect() (string, error) {
return "" , NewConnError ( err , true )
}
if response . Header . Get ( "Sec-WebSocket-Protocol" ) != "xmpp" {
_ = wsConn . Close ( websocket . StatusBadGateway , "Could not negotiate XMPP subprotocol" )
t. cleanup ( websocket . StatusBadGateway )
return "" , NewConnError ( ServerDoesNotSupportXmppOverWebsocket , true )
}
wsConn . SetReadLimit ( maxPacketSize )
t . wsConn = wsConn
t . netConn = websocket . NetConn ( ctx , t . wsConn , websocket . MessageText )
t . startReader ( )
handshake := fmt . Sprintf ( "<open xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" to=\"%s\" version=\"1.0\" />" , t . Config . Domain )
handshake := fmt . Sprintf ( `<open xmlns="urn:ietf:params:xml:ns:xmpp-framing" to="%s" version="1.0" /> ` , t . Config . Domain )
if _ , err = t . Write ( [ ] byte ( handshake ) ) ; err != nil {
_ = wsConn . Close ( websocket . StatusBadGateway , "XMPP handshake error" )
t. cleanup ( websocket . StatusBadGateway )
return "" , NewConnError ( err , false )
}
handshakeResponse := make ( [ ] byte , 2048 )
if _ , err = t . Read ( handshakeResponse ) ; err != nil {
_ = wsConn . Close ( websocket . StatusBadGateway , "XMPP handshake error" )
t . cleanup ( websocket . StatusBadGateway )
return "" , NewConnError ( err , false )
}
var openResponse = stanza . WebsocketOpen { }
if err = xml . Unmarshal ( handshakeResponse , & openResponse ) ; err != nil {
_ = wsConn . Close ( websocket . StatusBadGateway , "XMPP handshake error" )
t. cleanup ( websocket . StatusBadGateway )
return "" , NewConnError ( err , false )
}
@ -73,6 +82,32 @@ func (t *WebsocketTransport) Connect() (string, error) {
return openResponse . Id , nil
}
// startReader runs a go function that keeps reading from the websocket. This
// is required to allow Ping() to work: Ping requires a Reader to be running
// to process incoming control frames.
func ( t WebsocketTransport ) startReader ( ) {
go func ( ) {
buffer := make ( [ ] byte , maxPacketSize )
for {
_ , reader , err := t . wsConn . Reader ( t . closeCtx )
if err != nil {
return
}
n , err := reader . Read ( buffer )
if err != nil && err != io . EOF {
return
}
if n > 0 {
// We need to make a copy, otherwise we will overwrite the slice content
// on the next iteration of the for loop.
tmp := make ( [ ] byte , len ( buffer ) )
copy ( tmp , buffer )
t . queue <- tmp
}
}
} ( )
}
func ( t WebsocketTransport ) StartTLS ( ) error {
return TLSNotSupported
}
@ -90,31 +125,52 @@ func (t WebsocketTransport) IsSecure() bool {
}
func ( t WebsocketTransport ) Ping ( ) error {
ctx , cancel := context . WithTimeout ( context. Background ( ) , pingTimeout )
ctx , cancel := context . WithTimeout ( t. closeCtx , pingTimeout )
defer cancel ( )
return t . wsConn . Ping ( ctx )
}
func ( t * WebsocketTransport ) Read ( p [ ] byte ) ( n int , err error ) {
n , err = t . netConn . Read ( p )
if t . logFile != nil && n > 0 {
_ , _ = fmt . Fprintf ( t . logFile , "RECV:\n%s\n\n" , p )
func ( t * WebsocketTransport ) Read ( p [ ] byte ) ( int , error ) {
select {
case <- t . closeCtx . Done ( ) :
return 0 , t . closeCtx . Err ( )
case data := <- t . queue :
if t . logFile != nil && len ( data ) > 0 {
_ , _ = fmt . Fprintf ( t . logFile , "RECV:\n%s\n\n" , data )
}
copy ( p , data )
return len ( data ) , nil
}
return
}
func ( t WebsocketTransport ) Write ( p [ ] byte ) ( n int , err error ) {
func ( t WebsocketTransport ) Write ( p [ ] byte ) ( int , error ) {
if t . logFile != nil {
_ , _ = fmt . Fprintf ( t . logFile , "SEND:\n%s\n\n" , p )
}
return t . netConn . Write ( p )
return len ( p ) , t . wsConn . Write ( t . closeCtx , websocket . MessageText , p )
}
func ( t WebsocketTransport ) Close ( ) error {
t . Write ( [ ] byte ( "<close xmlns=\"urn:ietf:params:xml:ns:xmpp-framing\" />" ) )
return t . netConn. Close ( )
return t . wsConn. Close ( websocket . StatusGoingAway , "Done" )
}
func ( t * WebsocketTransport ) LogTraffic ( logFile io . Writer ) {
t . logFile = logFile
}
func ( t * WebsocketTransport ) cleanup ( code websocket . StatusCode ) {
if t . queue != nil {
close ( t . queue )
t . queue = nil
}
if t . wsConn != nil {
t . wsConn . Close ( websocket . StatusGoingAway , "Done" )
t . wsConn = nil
}
if t . closeFunc != nil {
t . closeFunc ( )
t . closeFunc = nil
t . closeCtx = nil
}
}