From 36e153f9813d1d419dfa86bfe7b328435370ab38 Mon Sep 17 00:00:00 2001 From: Wichert Akkerman Date: Tue, 15 Oct 2019 20:56:11 +0200 Subject: [PATCH] Allow transports to define their own ping mechanism --- client.go | 4 ++-- transport.go | 1 + websocket_transport.go | 10 ++++++++++ xmpp_transport.go | 12 ++++++++++++ 4 files changed, 25 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 7f56d0b..9c74da7 100644 --- a/client.go +++ b/client.go @@ -276,8 +276,8 @@ func keepalive(transport Transport, quit <-chan struct{}) { for { select { case <-ticker.C: - if n, err := fmt.Fprintf(transport, "\n"); err != nil || n != 1 { - // When keep alive fails, we force close the transportection. In all cases, the recv will also fail. + if err := transport.Ping(); err != nil { + // When keepalive fails, we force close the transport. In all cases, the recv will also fail. ticker.Stop() _ = transport.Close() return diff --git a/transport.go b/transport.go index 3e0ca0d..6c4b8e0 100644 --- a/transport.go +++ b/transport.go @@ -25,6 +25,7 @@ type Transport interface { IsSecure() bool + Ping() error Read(p []byte) (n int, err error) Write(p []byte) (n int, err error) Close() error diff --git a/websocket_transport.go b/websocket_transport.go index a526fc4..26ab511 100644 --- a/websocket_transport.go +++ b/websocket_transport.go @@ -9,6 +9,8 @@ import ( "nhooyr.io/websocket" ) +const pingTimeout = time.Duration(5) * time.Second + type WebsocketTransport struct { Config TransportConfiguration wsConn *websocket.Conn @@ -46,6 +48,14 @@ func (t WebsocketTransport) IsSecure() bool { return strings.HasPrefix(t.Config.Address, "wss:") } +func (t WebsocketTransport) Ping() error { + ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + defer cancel() + // Note that we do not use wsConn.Ping(), because not all websocket servers + // (ejabberd for example) implement ping frames + return t.wsConn.Write(ctx, websocket.MessageText, []byte(" ")) +} + func (t WebsocketTransport) Read(p []byte) (n int, err error) { return t.netConn.Read(p) } diff --git a/xmpp_transport.go b/xmpp_transport.go index 2530b82..614a76d 100644 --- a/xmpp_transport.go +++ b/xmpp_transport.go @@ -2,6 +2,7 @@ package xmpp import ( "crypto/tls" + "errors" "net" "time" ) @@ -59,6 +60,17 @@ func (t *XMPPTransport) StartTLS(domain string) error { return nil } +func (t XMPPTransport) Ping() error { + n, err := t.conn.Write([]byte("\n")) + if err != nil { + return err + } + if n != 1 { + return errors.New("Could not write ping") + } + return nil +} + func (t XMPPTransport) Read(p []byte) (n int, err error) { return t.conn.Read(p) }