diff --git a/jid.go b/jid.go index cee292a..f44019b 100644 --- a/jid.go +++ b/jid.go @@ -51,11 +51,21 @@ func NewJid(sjid string) (*Jid, error) { } func (j *Jid) Full() string { - return j.Node + "@" + j.Domain + "/" + j.Resource + if j.Resource == "" { + return j.Bare() + } else if j.Node == "" { + return j.Node + "/" + j.Resource + } else { + return j.Node + "@" + j.Domain + "/" + j.Resource + } } func (j *Jid) Bare() string { - return j.Node + "@" + j.Domain + if j.Node == "" { + return j.Domain + } else { + return j.Node + "@" + j.Domain + } } // ============================================================================ diff --git a/jid_test.go b/jid_test.go index c6fee03..45483dd 100644 --- a/jid_test.go +++ b/jid_test.go @@ -61,26 +61,41 @@ func TestIncorrectJids(t *testing.T) { } func TestFull(t *testing.T) { - jid := "test@domain.com/my resource" - parsedJid, err := NewJid(jid) - if err != nil { - t.Errorf("could not parse jid: %v", err) + fullJids := []string{ + "test@domain.com/my resource", + "test@domain.com", + "domain.com", } - fullJid := parsedJid.Full() - if fullJid != jid { - t.Errorf("incorrect full jid: %s", fullJid) + for _, sjid := range fullJids { + parsedJid, err := NewJid(sjid) + if err != nil { + t.Errorf("could not parse jid: %v", err) + } + fullJid := parsedJid.Full() + if fullJid != sjid { + t.Errorf("incorrect full jid: %s", fullJid) + } } } func TestBare(t *testing.T) { - jid := "test@domain.com" - fullJid := jid + "/my resource" - parsedJid, err := NewJid(fullJid) - if err != nil { - t.Errorf("could not parse jid: %v", err) + tests := []struct { + jidstr string + expected string + }{ + {jidstr: "test@domain.com", expected: "test@domain.com"}, + {jidstr: "test@domain.com/resource", expected: "test@domain.com"}, + {jidstr: "domain.com", expected: "domain.com"}, } - bareJid := parsedJid.Bare() - if bareJid != jid { - t.Errorf("incorrect bare jid: %s", bareJid) + + for _, tt := range tests { + parsedJid, err := NewJid(tt.jidstr) + if err != nil { + t.Errorf("could not parse jid: %v", err) + } + bareJid := parsedJid.Bare() + if bareJid != tt.expected { + t.Errorf("incorrect bare jid: %s", bareJid) + } } }