diff --git a/auth.go b/auth.go index 8569297..e69f82e 100644 --- a/auth.go +++ b/auth.go @@ -50,11 +50,6 @@ func authPlain(socket io.ReadWriter, decoder *xml.Decoder, user string, password return err } -type saslMechanisms struct { - XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"` - Mechanism []string `xml:"mechanism"` -} - // ============================================================================ // SASLSuccess diff --git a/check_cert.go b/check_cert.go index 6190d96..074676e 100644 --- a/check_cert.go +++ b/check_cert.go @@ -76,8 +76,7 @@ func (c *ServerCheck) Check() error { return errors.New("expected packet received while expecting features, got " + p.Name()) } - startTLSFeature := f.StartTLS.XMLName.Space + " " + f.StartTLS.XMLName.Local - if startTLSFeature == nsTLS+" starttls" { + if _, ok := f.DoesStartTLS(); ok { fmt.Fprintf(tcpconn, "") var k tlsProceed diff --git a/client_test.go b/client_test.go index b7fb1ac..7d68717 100644 --- a/client_test.go +++ b/client_test.go @@ -60,6 +60,30 @@ func TestClient_NoInsecure(t *testing.T) { mock.Stop() } +// Check that the client is properly tracking features, as session negotiation progresses. +func TestClient_FeaturesTracking(t *testing.T) { + // Setup Mock server + mock := ServerMock{} + mock.Start(t, testXMPPAddress, handlerAbortTLS) + + // Test / Check result + config := Config{Address: testXMPPAddress, Jid: "test@localhost", Password: "test"} + + var client *Client + var err error + if client, err = NewClient(config); err != nil { + t.Errorf("cannot create XMPP client: %s", err) + } + + if err = client.Connect(); err == nil { + // When insecure is not allowed: + t.Errorf("should fail as insecure connection is not allowed and server does not support TLS") + } + + mock.Stop() + +} + //============================================================================= // Basic XMPP Server Mock Handlers. diff --git a/session.go b/session.go index 0c0b278..3bce3cd 100644 --- a/session.go +++ b/session.go @@ -109,7 +109,7 @@ func (s *Session) startTlsIfSupported(conn net.Conn, domain string) net.Conn { return conn } - if s.Features.StartTLS.XMLName.Space+" "+s.Features.StartTLS.XMLName.Local == nsTLS+" starttls" { + if _, ok := s.Features.DoesStartTLS(); ok { fmt.Fprintf(s.socketProxy, "") var k tlsProceed diff --git a/starttls.go b/starttls.go index 28149b7..8c36222 100644 --- a/starttls.go +++ b/starttls.go @@ -7,12 +7,7 @@ import ( var DefaultTlsConfig tls.Config -// XMPP Packet Parsing -type tlsStartTLS struct { - XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"` - Required bool -} - +// Used during stream initiation / session establishment type tlsProceed struct { XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls proceed"` } diff --git a/stream.go b/stream.go index a66ac85..b887c55 100644 --- a/stream.go +++ b/stream.go @@ -6,11 +6,14 @@ import ( // ============================================================================ // StreamFeatures Packet +// Reference: https://xmpp.org/registrar/stream-features.html type StreamFeatures struct { - XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"` + XMLName xml.Name `xml:"http://etherx.jabber.org/streams features"` + // Server capabilities hash + Caps Caps + // Stream features StartTLS tlsStartTLS - Caps Caps Mechanisms saslMechanisms Bind BindBind Session sessionSession @@ -31,6 +34,76 @@ func (streamFeatureDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamF return packet, err } +// Capabilities +// Reference: https://xmpp.org/extensions/xep-0115.html#stream +// "A server MAY include its entity capabilities in a stream feature element so that connecting clients +// and peer servers do not need to send service discovery requests each time they connect." +// This is not a stream feature but a way to let client cache server disco info. +type Caps struct { + XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"` + Hash string `xml:"hash,attr"` + Node string `xml:"node,attr"` + Ver string `xml:"ver,attr"` + Ext string `xml:"ext,attr,omitempty"` +} + +// ============================================================================ +// Supported Stream Features + +// StartTLS feature +// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-5.4 +type tlsStartTLS struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-tls starttls"` + Required bool +} + +// UnmarshalXML implements custom parsing startTLS required flag +func (stls *tlsStartTLS) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + stls.XMLName = start.Name + + // Check subelements to extract required field as boolean + for { + t, err := d.Token() + if err != nil { + return err + } + + switch tt := t.(type) { + + case xml.StartElement: + elt := new(Node) + + err = d.DecodeElement(elt, &tt) + if err != nil { + return err + } + + if elt.XMLName.Local == "required" { + stls.Required = true + } + + case xml.EndElement: + if tt == start.End() { + return nil + } + } + } +} + +func (sf *StreamFeatures) DoesStartTLS() (feature tlsStartTLS, isSupported bool) { + if sf.StartTLS.XMLName.Space+" "+sf.StartTLS.XMLName.Local == nsTLS+" starttls" { + return sf.StartTLS, true + } + return feature, false +} + +// Mechanisms +// Reference: RFC 6120 - https://tools.ietf.org/html/rfc6120#section-6.4.1 +type saslMechanisms struct { + XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"` + Mechanism []string `xml:"mechanism"` +} + // ============================================================================ // StreamError Packet @@ -53,14 +126,3 @@ func (streamErrorDecoder) decode(p *xml.Decoder, se xml.StartElement) (StreamErr err := p.DecodeElement(&packet, &se) return packet, err } - -// ============================================================================ -// Caps subElement - -type Caps struct { - XMLName xml.Name `xml:"http://jabber.org/protocol/caps c"` - Hash string `xml:"hash,attr"` - Node string `xml:"node,attr"` - Ver string `xml:"ver,attr"` - Ext string `xml:"ext,attr,omitempty"` -} diff --git a/stream_test.go b/stream_test.go new file mode 100644 index 0000000..f10d1de --- /dev/null +++ b/stream_test.go @@ -0,0 +1,47 @@ +package xmpp_test + +import ( + "encoding/xml" + "testing" + + "gosrc.io/xmpp" +) + +func TestNoStartTLS(t *testing.T) { + streamFeatures := ` +` + + var parsedSF xmpp.StreamFeatures + if err := xml.Unmarshal([]byte(streamFeatures), &parsedSF); err != nil { + t.Errorf("Unmarshal(%s) returned error: %v", streamFeatures, err) + } + + startTLS, ok := parsedSF.DoesStartTLS() + if ok { + t.Error("StartTLS feature should not be enabled") + } + if startTLS.Required { + t.Error("StartTLS cannot be required as default") + } +} + +func TestStartTLS(t *testing.T) { + streamFeatures := ` + + + +` + + var parsedSF xmpp.StreamFeatures + if err := xml.Unmarshal([]byte(streamFeatures), &parsedSF); err != nil { + t.Errorf("Unmarshal(%s) returned error: %v", streamFeatures, err) + } + + startTLS, ok := parsedSF.DoesStartTLS() + if !ok { + t.Error("StartTLS feature should be enabled") + } + if !startTLS.Required { + t.Error("StartTLS feature should be required") + } +}