You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

151 lines
3.7 KiB

  1. package xmpp
import (
	"crypto/tls"
	"crypto/x509"
	"encoding/xml"
	"errors"
	"fmt"
	"net"
	"strings"
	"time"

	"gosrc.io/xmpp/stanza"
)
  12. // TODO: Should I move this as an extension of the client?
  13. // I should probably make the code more modular, but keep concern separated to keep it simple.
  14. type ServerCheck struct {
  15. address string
  16. domain string
  17. }
  18. func NewChecker(address, domain string) (*ServerCheck, error) {
  19. client := ServerCheck{}
  20. var err error
  21. var host string
  22. if client.address, host, err = extractParams(address); err != nil {
  23. return &client, err
  24. }
  25. if domain != "" {
  26. client.domain = domain
  27. } else {
  28. client.domain = host
  29. }
  30. return &client, nil
  31. }
  32. // Check triggers actual TCP connection, based on previously defined parameters.
  33. func (c *ServerCheck) Check() error {
  34. var tcpconn net.Conn
  35. var err error
  36. timeout := 15 * time.Second
  37. tcpconn, err = net.DialTimeout("tcp", c.address, timeout)
  38. if err != nil {
  39. return err
  40. }
  41. decoder := xml.NewDecoder(tcpconn)
  42. // Send stream open tag
  43. if _, err = fmt.Fprintf(tcpconn, clientStreamOpen, c.domain); err != nil {
  44. return err
  45. }
  46. // Set xml decoder and extract streamID from reply (not used for now)
  47. _, err = stanza.InitStream(decoder)
  48. if err != nil {
  49. return err
  50. }
  51. // extract stream features
  52. var f stanza.StreamFeatures
  53. packet, err := stanza.NextPacket(decoder)
  54. if err != nil {
  55. err = fmt.Errorf("stream open decode features: %s", err)
  56. return err
  57. }
  58. switch p := packet.(type) {
  59. case stanza.StreamFeatures:
  60. f = p
  61. case stanza.StreamError:
  62. return errors.New("open stream error: " + p.Error.Local)
  63. default:
  64. return errors.New("expected packet received while expecting features, got " + p.Name())
  65. }
  66. if _, ok := f.DoesStartTLS(); ok {
  67. _, err = fmt.Fprintf(tcpconn, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
  68. if err != nil {
  69. return err
  70. }
  71. var k stanza.TLSProceed
  72. if err = decoder.DecodeElement(&k, nil); err != nil {
  73. return fmt.Errorf("expecting starttls proceed: %s", err)
  74. }
  75. var tlsConfig tls.Config
  76. tlsConfig.ServerName = c.domain
  77. tlsConn := tls.Client(tcpconn, &tlsConfig)
  78. // We convert existing connection to TLS
  79. if err = tlsConn.Handshake(); err != nil {
  80. return err
  81. }
  82. // We check that cert matches hostname
  83. if err = tlsConn.VerifyHostname(c.domain); err != nil {
  84. return err
  85. }
  86. if err = checkExpiration(tlsConn); err != nil {
  87. return err
  88. }
  89. return nil
  90. }
  91. return errors.New("TLS not supported on server")
  92. }
  93. // Check expiration date for the whole certificate chain and returns an error
  94. // if the expiration date is in less than 48 hours.
  95. func checkExpiration(tlsConn *tls.Conn) error {
  96. checkedCerts := make(map[string]struct{})
  97. for _, chain := range tlsConn.ConnectionState().VerifiedChains {
  98. for _, cert := range chain {
  99. if _, checked := checkedCerts[string(cert.Signature)]; checked {
  100. continue
  101. }
  102. checkedCerts[string(cert.Signature)] = struct{}{}
  103. // Check the expiration.
  104. timeNow := time.Now()
  105. expiresInHours := int64(cert.NotAfter.Sub(timeNow).Hours())
  106. // fmt.Printf("Cert '%s' expires in %d days\n", cert.Subject.CommonName, expiresInHours/24)
  107. if expiresInHours <= 48 {
  108. return fmt.Errorf("certificate '%s' will expire on %s", cert.Subject.CommonName, cert.NotAfter)
  109. }
  110. }
  111. }
  112. return nil
  113. }
  114. func extractParams(addr string) (string, string, error) {
  115. var err error
  116. hostport := strings.Split(addr, ":")
  117. if len(hostport) > 2 {
  118. err = errors.New("too many colons in xmpp server address")
  119. return addr, hostport[0], err
  120. }
  121. // Address is composed of two parts, we are good
  122. if len(hostport) == 2 && hostport[1] != "" {
  123. return addr, hostport[0], err
  124. }
  125. // Port was not passed, we append XMPP default port:
  126. return strings.Join([]string{hostport[0], "5222"}, ":"), hostport[0], err
  127. }