diff --git a/auth.go b/auth.go index 66ac983..de38cac 100644 --- a/auth.go +++ b/auth.go @@ -8,7 +8,7 @@ import ( "io" ) -func authSASL(socket io.ReadWriter, decoder *xml.Decoder, f streamFeatures, user string, password string) (err error) { +func authSASL(socket io.ReadWriter, decoder *xml.Decoder, f StreamFeatures, user string, password string) (err error) { // TODO: Implement other type of SASL Authentication havePlain := false for _, m := range f.Mechanisms.Mechanism { diff --git a/check_cert.go b/check_cert.go new file mode 100644 index 0000000..b23b6bb --- /dev/null +++ b/check_cert.go @@ -0,0 +1,146 @@ +package xmpp + +import ( + "crypto/tls" + "encoding/xml" + "errors" + "fmt" + "net" + "strings" + "time" +) + +// TODO: Should I move this as an extension of the client? +// I should probably make the code more modular, but keep concern separated to keep it simple. +type ServerCheck struct { + address string + domain string +} + +func NewChecker(address, domain string) (*ServerCheck, error) { + client := ServerCheck{} + + var err error + var host string + if client.address, host, err = extractParams(address); err != nil { + return &client, err + } + + if domain != "" { + client.domain = domain + } else { + client.domain = host + } + + return &client, nil +} + +// Check triggers actual TCP connection, based on previously defined parameters. +func (c *ServerCheck) Check() error { + var tcpconn net.Conn + var err error + + timeout := 15 * time.Second + tcpconn, err = net.DialTimeout("tcp", c.address, timeout) + if err != nil { + return err + } + + decoder := xml.NewDecoder(tcpconn) + + // Send stream open tag + if _, err = fmt.Fprintf(tcpconn, xmppStreamOpen, c.domain, NSClient, NSStream); err != nil { + return err + } + + // Set xml decoder and extract streamID from reply (not used for now) + _, err = initDecoder(decoder) + if err != nil { + return err + } + + // extract stream features + var f StreamFeatures + packet, err := next(decoder) + if err != nil { + err = fmt.Errorf("stream open decode features: %s", err) + return err + } + + switch p := packet.(type) { + case StreamFeatures: + f = p + case StreamError: + return errors.New("open stream error: " + p.Error.Local) + default: + return errors.New("expected packet received while expecting features, got " + p.Name()) + } + + startTLSFeature := f.StartTLS.XMLName.Space + " " + f.StartTLS.XMLName.Local + if startTLSFeature == nsTLS+" starttls" { + fmt.Fprintf(tcpconn, "") + + var k tlsProceed + if err = decoder.DecodeElement(&k, nil); err != nil { + return fmt.Errorf("expecting starttls proceed: %s", err) + } + + DefaultTlsConfig.ServerName = c.domain + tlsConn := tls.Client(tcpconn, &DefaultTlsConfig) + // We convert existing connection to TLS + if err = tlsConn.Handshake(); err != nil { + return err + } + + // We check that cert matches hostname + if err = tlsConn.VerifyHostname(c.domain); err != nil { + return err + } + + if err = checkExpiration(tlsConn); err != nil { + return err + } + return nil + } + return errors.New("TLS not supported on server") +} + +// Check expiration date for the whole certificate chain and returns an error +// if the expiration date is in less than 48 hours. +func checkExpiration(tlsConn *tls.Conn) error { + checkedCerts := make(map[string]struct{}) + for _, chain := range tlsConn.ConnectionState().VerifiedChains { + for _, cert := range chain { + if _, checked := checkedCerts[string(cert.Signature)]; checked { + continue + } + checkedCerts[string(cert.Signature)] = struct{}{} + + // Check the expiration. + timeNow := time.Now() + expiresInHours := int64(cert.NotAfter.Sub(timeNow).Hours()) + // fmt.Printf("Cert '%s' expires in %d days\n", cert.Subject.CommonName, expiresInHours/24) + if expiresInHours <= 48 { + return fmt.Errorf("certificate '%s' will expire on %s", cert.Subject.CommonName, cert.NotAfter) + } + } + } + return nil +} + +func extractParams(addr string) (string, string, error) { + var err error + hostport := strings.Split(addr, ":") + if len(hostport) > 2 { + err = errors.New("too many colons in xmpp server address") + return addr, hostport[0], err + } + + // Address is composed of two parts, we are good + if len(hostport) == 2 && hostport[1] != "" { + return addr, hostport[0], err + } + + // Port was not passed, we append XMPP default port: + return strings.Join([]string{hostport[0], "5222"}, ":"), hostport[0], err +} diff --git a/cmd/xmpp-check/TODO.md b/cmd/xmpp-check/TODO.md new file mode 100644 index 0000000..682b4cb --- /dev/null +++ b/cmd/xmpp-check/TODO.md @@ -0,0 +1,3 @@ +# TODO + +- Use a config file to define the checks to perform as client on an XMPP server. \ No newline at end of file diff --git a/cmd/xmpp-check/xmpp-check.go b/cmd/xmpp-check/xmpp-check.go new file mode 100644 index 0000000..6e7ceab --- /dev/null +++ b/cmd/xmpp-check/xmpp-check.go @@ -0,0 +1,43 @@ +package main + +import ( + "log" + "os" + + "gosrc.io/xmpp" +) + +func main() { + args := os.Args[1:] + + if len(args) == 0 { + log.Fatal("usage: xmpp-check host[:port] [domain]") + } + + var address string + var domain string + if len(args) >= 1 { + address = args[0] + } + + if len(args) >= 2 { + domain = args[1] + } + + runCheck(address, domain) +} + +func runCheck(address, domain string) { + client, err := xmpp.NewChecker(address, domain) + // client, err := xmpp.NewChecker("mickael.m.in-app.io:5222", "mickael.m.in-app.io") + + if err != nil { + log.Fatal("Error: ", err) + } + + if err = client.Check(); err != nil { + log.Fatal("Failed connection check: ", err) + } + + log.Println("All checks passed") +} diff --git a/parser.go b/parser.go index 898d408..0ad24fc 100644 --- a/parser.go +++ b/parser.go @@ -90,6 +90,8 @@ func decodeStream(p *xml.Decoder, se xml.StartElement) (Packet, error) { switch se.Name.Local { case "error": return streamError.decode(p, se) + case "features": + return streamFeatures.decode(p, se) default: return nil, errors.New("unexpected XMPP packet " + se.Name.Space + " <" + se.Name.Local + "/>") diff --git a/session.go b/session.go index 8259b9a..a15fdb3 100644 --- a/session.go +++ b/session.go @@ -15,7 +15,7 @@ type Session struct { // Session info BindJid string // Jabber ID as provided by XMPP server StreamId string - Features streamFeatures + Features StreamFeatures TlsEnabled bool lastPacketId int @@ -85,7 +85,7 @@ func (s *Session) setProxy(conn net.Conn, newConn net.Conn, o Config) { s.decoder.CharsetReader = o.CharsetReader } -func (s *Session) open(domain string) (f streamFeatures) { +func (s *Session) open(domain string) (f StreamFeatures) { // Send stream open tag if _, s.err = fmt.Fprintf(s.socketProxy, xmppStreamOpen, domain, NSClient, NSStream); s.err != nil { return @@ -121,7 +121,7 @@ func (s *Session) startTlsIfSupported(conn net.Conn, domain string) net.Conn { // TODO: add option to accept all TLS certificates: insecureSkipTlsVerify (DefaultTlsConfig.InsecureSkipVerify) DefaultTlsConfig.ServerName = domain - var tlsConn *tls.Conn = tls.Client(conn, &DefaultTlsConfig) + tlsConn := tls.Client(conn, &DefaultTlsConfig) // We convert existing connection to TLS if s.err = tlsConn.Handshake(); s.err != nil { return tlsConn diff --git a/socket_proxy.go b/socket_proxy.go index 49f2f8e..6068124 100644 --- a/socket_proxy.go +++ b/socket_proxy.go @@ -20,21 +20,21 @@ func newSocketProxy(conn io.ReadWriter, logFile *os.File) io.ReadWriter { } } -func (pl *socketProxy) Read(p []byte) (n int, err error) { - n, err = pl.socket.Read(p) +func (sp *socketProxy) Read(p []byte) (n int, err error) { + n, err = sp.socket.Read(p) if n > 0 { - pl.logFile.Write([]byte("RECV:\n")) // Prefix - if n, err := pl.logFile.Write(p[:n]); err != nil { + sp.logFile.Write([]byte("RECV:\n")) // Prefix + if n, err := sp.logFile.Write(p[:n]); err != nil { return n, err } - pl.logFile.Write([]byte("\n\n")) // Separator + sp.logFile.Write([]byte("\n\n")) // Separator } return } -func (pl *socketProxy) Write(p []byte) (n int, err error) { - pl.logFile.Write([]byte("SEND:\n")) // Prefix - for _, w := range []io.Writer{pl.socket, pl.logFile} { +func (sp *socketProxy) Write(p []byte) (n int, err error) { + sp.logFile.Write([]byte("SEND:\n")) // Prefix + for _, w := range []io.Writer{sp.socket, sp.logFile} { n, err = w.Write(p) if err != nil { return @@ -44,6 +44,6 @@ func (pl *socketProxy) Write(p []byte) (n int, err error) { return } } - pl.logFile.Write([]byte("\n\n")) // Separator + sp.logFile.Write([]byte("\n\n")) // Separator return len(p), nil } diff --git a/stream.go b/stream.go index 3b48db0..1e0df97 100644 --- a/stream.go +++ b/stream.go @@ -7,7 +7,7 @@ import ( // ============================================================================ // StreamFeatures Packet -type streamFeatures struct { +type StreamFeatures struct { XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"` StartTLS tlsStartTLS Caps Caps @@ -17,6 +17,20 @@ type streamFeatures struct { Any []xml.Name `xml:",any"` } +func (StreamFeatures) Name() string { + return "stream:features" +} + +type streamFeatureDecoder struct{} + +var streamFeatures streamFeatureDecoder + +func (streamFeatureDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamFeatures, error) { + var packet StreamFeatures + err := p.DecodeElement(&packet, &se) + return packet, err +} + // ============================================================================ // StreamError Packet