From ffadd331dd05313f9adf6a90046deb463ec98172 Mon Sep 17 00:00:00 2001 From: Wichert Akkerman Date: Mon, 21 Oct 2019 08:57:44 +0200 Subject: [PATCH] Add a go function to always read websockets Websocket need to have a Reader running at all times in order to allow Ping to work (because a Reader is the only thing that will correctly handle control frames). To faciliate this a go function is introduced that will always read from the websocket until it is cancelled. Read data is passed to the transport via a channel. --- websocket_transport.go | 98 +++++++++++++++++++++++++++++++++--------- 1 file changed, 77 insertions(+), 21 deletions(-) diff --git a/websocket_transport.go b/websocket_transport.go index 391c012..bd8a87b 100644 --- a/websocket_transport.go +++ b/websocket_transport.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "net" "strings" "time" @@ -14,6 +13,8 @@ import ( "nhooyr.io/websocket" ) +const maxPacketSize = 32768 + const pingTimeout = time.Duration(5) * time.Second var ServerDoesNotSupportXmppOverWebsocket = errors.New("The websocket server does not support the xmpp subprotocol") @@ -22,17 +23,23 @@ type WebsocketTransport struct { Config TransportConfiguration decoder *xml.Decoder wsConn *websocket.Conn - netConn net.Conn + queue chan []byte logFile io.Writer + + closeCtx context.Context + closeFunc context.CancelFunc } func (t *WebsocketTransport) Connect() (string, error) { - ctx := context.Background() + t.queue = make(chan []byte, 256) + t.closeCtx, t.closeFunc = context.WithCancel(context.Background()) + var ctx context.Context + ctx = context.Background() if t.Config.ConnectTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Duration(t.Config.ConnectTimeout)*time.Second) - defer cancel() + var cancelConnect context.CancelFunc + ctx, cancelConnect = context.WithTimeout(t.closeCtx, time.Duration(t.Config.ConnectTimeout)*time.Second) + defer cancelConnect() } wsConn, response, err := websocket.Dial(ctx, t.Config.Address, &websocket.DialOptions{ @@ -42,28 +49,30 @@ func (t *WebsocketTransport) Connect() (string, error) { return "", NewConnError(err, true) } if response.Header.Get("Sec-WebSocket-Protocol") != "xmpp" { - _ = wsConn.Close(websocket.StatusBadGateway, "Could not negotiate XMPP subprotocol") + t.cleanup(websocket.StatusBadGateway) return "", NewConnError(ServerDoesNotSupportXmppOverWebsocket, true) } + wsConn.SetReadLimit(maxPacketSize) t.wsConn = wsConn - t.netConn = websocket.NetConn(ctx, t.wsConn, websocket.MessageText) + t.startReader() - handshake := fmt.Sprintf("", t.Config.Domain) + handshake := fmt.Sprintf(``, t.Config.Domain) if _, err = t.Write([]byte(handshake)); err != nil { - _ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error") + t.cleanup(websocket.StatusBadGateway) return "", NewConnError(err, false) } handshakeResponse := make([]byte, 2048) if _, err = t.Read(handshakeResponse); err != nil { - _ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error") + t.cleanup(websocket.StatusBadGateway) + return "", NewConnError(err, false) } var openResponse = stanza.WebsocketOpen{} if err = xml.Unmarshal(handshakeResponse, &openResponse); err != nil { - _ = wsConn.Close(websocket.StatusBadGateway, "XMPP handshake error") + t.cleanup(websocket.StatusBadGateway) return "", NewConnError(err, false) } @@ -73,6 +82,32 @@ func (t *WebsocketTransport) Connect() (string, error) { return openResponse.Id, nil } +// startReader runs a go function that keeps reading from the websocket. This +// is required to allow Ping() to work: Ping requires a Reader to be running +// to process incoming control frames. +func (t WebsocketTransport) startReader() { + go func() { + buffer := make([]byte, maxPacketSize) + for { + _, reader, err := t.wsConn.Reader(t.closeCtx) + if err != nil { + return + } + n, err := reader.Read(buffer) + if err != nil && err != io.EOF { + return + } + if n > 0 { + // We need to make a copy, otherwise we will overwrite the slice content + // on the next iteration of the for loop. + tmp := make([]byte, len(buffer)) + copy(tmp, buffer) + t.queue <- tmp + } + } + }() +} + func (t WebsocketTransport) StartTLS() error { return TLSNotSupported } @@ -90,31 +125,52 @@ func (t WebsocketTransport) IsSecure() bool { } func (t WebsocketTransport) Ping() error { - ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + ctx, cancel := context.WithTimeout(t.closeCtx, pingTimeout) defer cancel() return t.wsConn.Ping(ctx) } -func (t *WebsocketTransport) Read(p []byte) (n int, err error) { - n, err = t.netConn.Read(p) - if t.logFile != nil && n > 0 { - _, _ = fmt.Fprintf(t.logFile, "RECV:\n%s\n\n", p) +func (t *WebsocketTransport) Read(p []byte) (int, error) { + select { + case <-t.closeCtx.Done(): + return 0, t.closeCtx.Err() + case data := <-t.queue: + if t.logFile != nil && len(data) > 0 { + _, _ = fmt.Fprintf(t.logFile, "RECV:\n%s\n\n", data) + } + copy(p, data) + return len(data), nil } - return } -func (t WebsocketTransport) Write(p []byte) (n int, err error) { +func (t WebsocketTransport) Write(p []byte) (int, error) { if t.logFile != nil { _, _ = fmt.Fprintf(t.logFile, "SEND:\n%s\n\n", p) } - return t.netConn.Write(p) + return len(p), t.wsConn.Write(t.closeCtx, websocket.MessageText, p) } func (t WebsocketTransport) Close() error { t.Write([]byte("")) - return t.netConn.Close() + return t.wsConn.Close(websocket.StatusGoingAway, "Done") } func (t *WebsocketTransport) LogTraffic(logFile io.Writer) { t.logFile = logFile } + +func (t *WebsocketTransport) cleanup(code websocket.StatusCode) { + if t.queue != nil { + close(t.queue) + t.queue = nil + } + if t.wsConn != nil { + t.wsConn.Close(websocket.StatusGoingAway, "Done") + t.wsConn = nil + } + if t.closeFunc != nil { + t.closeFunc() + t.closeFunc = nil + t.closeCtx = nil + } +}