telegabber/xmpp/component.go

298 lines
6.6 KiB
Go

package xmpp
import (
"github.com/pkg/errors"
"regexp"
"strconv"
"sync"
"time"
"dev.narayana.im/narayana/telegabber/badger"
"dev.narayana.im/narayana/telegabber/config"
"dev.narayana.im/narayana/telegabber/persistence"
"dev.narayana.im/narayana/telegabber/telegram"
"dev.narayana.im/narayana/telegabber/xmpp/gateway"
log "github.com/sirupsen/logrus"
"gosrc.io/xmpp"
"gosrc.io/xmpp/stanza"
)
// Package-level state of the XMPP component.
var (
	// tgConf is the Telegram configuration stored by NewComponent; it is
	// passed to telegram.NewClient and read by heartbeat for the content path.
	tgConf config.TelegramConfig

	// sessions maps a JID string to its Telegram client.
	sessions map[string]*telegram.Client

	// db is the sessions database returned by persistence.LoadSessions.
	db *persistence.SessionsYamlDB

	// sessionLock is taken around iteration over and insertion into sessions.
	sessionLock sync.Mutex
)
// Binary size units, each 1024 times the previous one, used by parseSize as
// multipliers, followed by the largest uint64 value used by its overflow check.
const (
	B uint64 = 1 << (10 * iota)
	KB
	MB
	GB
	TB
	PB
	EB

	maxUint64 uint64 = ^uint64(0)
)
// sizeRegex matches a complete size string: a decimal number, at most one
// space, and an optional unit made of an optional K/M/G/T/P/E letter and an
// optional trailing "B". Group 1 captures the number, group 2 the unit
// (which may be empty).
var sizeRegex = regexp.MustCompile(`\A([0-9]+) ?([KMGTPE]?B?)\z`)
// NewComponent sets up the gateway globals from the given configuration,
// creates the XMPP component with its stanza router and loads the saved
// sessions. The component is returned together with a stream manager
// wrapping it; starting the manager is left to the caller.
//
// An unparsable storage quota is only logged as a warning and does not
// abort the setup. Any other failure is returned as the error.
func NewComponent(conf config.XMPPConfig, tc config.TelegramConfig, idsPath string) (*xmpp.StreamManager, *xmpp.Component, error) {
	var err error

	if gateway.Jid, err = stanza.NewJid(conf.Jid); err != nil {
		return nil, nil, err
	}

	// a JID configured without a resource gets the device model as its
	// resource, or a fixed name when no device model is configured either
	if gateway.Jid.Resource == "" {
		switch model := tc.Tdlib.Client.DeviceModel; model {
		case "":
			gateway.Jid.Resource = "telegabber"
		default:
			gateway.Jid.Resource = model
		}
	}

	gateway.IdsDB = badger.IdsDBOpen(idsPath)
	tgConf = tc

	if quota := tc.Content.Quota; quota != "" {
		if gateway.StorageQuota, err = parseSize(quota); err != nil {
			log.Warnf("Error parsing the storage quota: %v; the cleaner is disabled", err)
		}
	}

	// one handler per top-level stanza kind
	router := xmpp.NewRouter()
	router.HandleFunc("iq", HandleIq)
	router.HandleFunc("presence", HandlePresence)
	router.HandleFunc("message", HandleMessage)

	component, err := xmpp.NewComponent(
		xmpp.ComponentOptions{
			TransportConfiguration: xmpp.TransportConfiguration{
				Address: conf.Host + ":" + conf.Port,
				Domain:  conf.Jid,
			},
			Domain: conf.Jid,
			Secret: conf.Password,
			Name:   "telegabber",
		},
		router,
		func(err error) {
			log.Error(err)
		},
	)
	if err != nil {
		return nil, nil, err
	}

	// load the saved sessions and create their Telegram clients
	if err = loadSessions(conf.Db, component); err != nil {
		return nil, nil, err
	}

	// NOTE(review): if the stream manager invokes this hook on every
	// reconnect, each invocation starts one more never-ending heartbeat
	// loop — confirm against the stream manager's behaviour.
	manager := xmpp.NewStreamManager(component, func(s xmpp.Sender) {
		go heartbeat(component)
	})

	return manager, component, nil
}
// heartbeat sends a "probe" presence to every known session and then runs
// the periodic maintenance loop, which never returns. Each iteration:
//
//  1. measures the content storage and asks the gateway to clean old files
//     when the cached size exceeds the threshold,
//  2. sleeps for one minute,
//  3. processes the delayed statuses whose expiration time has passed,
//  4. sends the queued presences, dropping the ones that were delivered,
//  5. saves the sessions if they were marked dirty,
//  6. runs the ids database GC.
func heartbeat(component *xmpp.Component) {
	probe := gateway.SPType("probe")

	sessionLock.Lock()
	for jid := range sessions {
		if err := gateway.SendPresence(component, jid, probe); err != nil {
			log.Error(err)
		}
	}
	sessionLock.Unlock()

	// nine tenths of the quota (integer arithmetic); a zero threshold
	// turns the storage check off
	cleanupThreshold := gateway.StorageQuota / 10 * 9

	log.Info("Starting heartbeat queue")

	// status updater thread
	for {
		gateway.StorageLock.Lock()
		if cleanupThreshold > 0 && tgConf.Content.Path != "" {
			gateway.MeasureStorageSize(tgConf.Content.Path)
			if gateway.CachedStorageSize > cleanupThreshold {
				gateway.CleanOldFiles(tgConf.Content.Path, cleanupThreshold)
			}
		}
		gateway.StorageLock.Unlock()

		time.Sleep(time.Minute)

		now := time.Now().Unix()

		sessionLock.Lock()
		for _, client := range sessions {
			client.DelayedStatusesLock.Lock()
			for chatID, delayed := range client.DelayedStatuses {
				if delayed.TimestampExpired > now {
					continue
				}
				// the status text is computed here, under the locks; only
				// the update itself runs in its own goroutine
				lastSeen := client.LastSeenStatus(delayed.TimestampOnline)
				go client.ProcessStatusUpdate(chatID, lastSeen, "away")
				delete(client.DelayedStatuses, chatID)
			}
			client.DelayedStatusesLock.Unlock()
		}
		sessionLock.Unlock()

		// NOTE(review): the queue is ranged over without QueueLock while the
		// deletion below takes it — confirm that nothing inserts into the
		// queue concurrently with this loop.
		for key, presence := range gateway.Queue {
			if err := gateway.ResumableSend(component, presence); err != nil {
				// keep the entry queued so that it is retried next time
				gateway.LogBadPresence(presence)
				continue
			}
			gateway.QueueLock.Lock()
			delete(gateway.Queue, key)
			gateway.QueueLock.Unlock()
		}

		if gateway.DirtySessions {
			// the flag is cleared before saving: if it is raised again in
			// the meantime, the next iteration takes care of it
			gateway.DirtySessions = false
			SaveSessions()
		}

		gateway.IdsDB.Gc()
	}
}
// loadSessions resets the in-memory sessions map, loads the sessions
// database from dbPath and obtains a Telegram client for every stored
// session via getTelegramInstance (whose result is ignored here).
// Only a failure to load the database is reported as an error.
func loadSessions(dbPath string, component *xmpp.Component) error {
	sessions = make(map[string]*telegram.Client)

	var err error
	if db, err = persistence.LoadSessions(dbPath); err != nil {
		return err
	}

	restore := func() bool {
		for jid, stored := range db.Data.Sessions {
			// hand out the address of a per-iteration copy, so that the
			// clients do not all end up sharing the one range variable
			saved := stored
			getTelegramInstance(jid, &saved, component)
		}
		// NOTE(review): SaveSessions returns true from its callback while
		// this one returns false; presumably the result tells the
		// transaction whether the data has to be written back — confirm.
		return false
	}
	db.Transaction(restore, persistence.SessionMarshaller)

	return nil
}
// getTelegramInstance returns the Telegram client registered for jid. When
// there is none yet, a client is created from savedSession, connected right
// away if savedSession.KeepOnline is set, and registered in sessions.
//
// The boolean result is false when creating or connecting the client failed.
// In that case the failure is logged, nothing is registered, and the
// returned client should not be used.
//
// sessionLock is taken internally, so callers must not hold it.
func getTelegramInstance(jid string, savedSession *persistence.Session, component *xmpp.Component) (*telegram.Client, bool) {
	// The lookup takes the same lock as every other access to the sessions
	// map: a bare map read racing with a locked insert is still a data race.
	sessionLock.Lock()
	session, ok := sessions[jid]
	sessionLock.Unlock()
	if ok {
		return session, true
	}

	// The lock is not held while the client is created and connected, so
	// session iteration elsewhere is not blocked for that long.
	var err error
	if session, err = telegram.NewClient(tgConf, jid, component, savedSession); err != nil {
		log.Error(errors.Wrap(err, "TDlib initialization failure"))
		return session, false
	}

	if savedSession.KeepOnline {
		if err = session.Connect(""); err != nil {
			log.Error(err)
			return session, false
		}
	}

	sessionLock.Lock()
	sessions[jid] = session
	sessionLock.Unlock()

	return session, true
}
// SaveSessions copies the session data of every client in the sessions map
// into the sessions database, inside a database transaction. sessionLock is
// held for the whole operation.
func SaveSessions() {
	sessionLock.Lock()
	defer sessionLock.Unlock()

	dump := func() bool {
		for jid, client := range sessions {
			db.Data.Sessions[jid] = *client.Session
		}
		return true
	}
	db.Transaction(dump, persistence.SessionMarshaller)
}
// Close shuts the gateway down: it disconnects every session, saves the
// sessions, closes the ids database and finally disconnects the component.
func Close(component *xmpp.Component) {
	log.Error("Disconnecting...")

	// the lock must be released before SaveSessions, which takes it itself
	sessionLock.Lock()
	for _, client := range sessions {
		client.Disconnect("", true)
	}
	sessionLock.Unlock()

	SaveSessions()

	gateway.IdsDB.Close()

	component.Disconnect()
}
// parseSize converts a size string such as "512MB", "10 G" or "42B" into a
// number of bytes.
//
// The input must match sizeRegex: a decimal number, at most one space, and a
// unit. Only the first letter of the unit is significant (B, K, M, G, T, P
// or E), and every step is a factor of 1024. A number without any unit is
// rejected.
//
// Errors are *strconv.NumError values: ErrSyntax when the string does not
// match or carries no unit, ErrRange when the number or the resulting byte
// count does not fit into a uint64.
//
// Based on https://github.com/c2h5oh/datasize/blob/master/datasize.go
func parseSize(sSize string) (uint64, error) {
	sizeParts := sizeRegex.FindStringSubmatch(sSize)
	if len(sizeParts) <= 2 {
		return 0, &strconv.NumError{Func: "Not enough parts", Num: sSize, Err: strconv.ErrSyntax}
	}

	// The result is a uint64, so parse the number as one: a signed parse
	// would reject valid values from 2^63 upwards.
	val, err := strconv.ParseUint(sizeParts[1], 10, 64)
	if err != nil {
		return 0, err
	}

	// stays zero when the unit is missing
	var multiplier uint64
	if unit := sizeParts[2]; unit != "" {
		switch unit[0] {
		case 'B':
			multiplier = B
		case 'K':
			multiplier = KB
		case 'M':
			multiplier = MB
		case 'G':
			multiplier = GB
		case 'T':
			multiplier = TB
		case 'P':
			multiplier = PB
		case 'E':
			multiplier = EB
		}
	}
	if multiplier == 0 {
		return 0, &strconv.NumError{Func: "Wrong suffix", Num: sSize, Err: strconv.ErrSyntax}
	}

	// refuse products that would wrap around
	if val > maxUint64/multiplier {
		return 0, &strconv.NumError{Func: "Overflow", Num: sSize, Err: strconv.ErrRange}
	}

	return val * multiplier, nil
}