This commit is contained in:
fmodf 2024-08-31 12:59:18 +02:00
parent f5e201af06
commit 217dd152d3
2 changed files with 89 additions and 71 deletions

View file

@ -44,89 +44,49 @@ final class ClientMartinOMEMO {
// MARK: - Session // MARK: - Session
extension ClientMartinOMEMO: SignalSessionStoreProtocol { extension ClientMartinOMEMO: SignalSessionStoreProtocol {
func sessionRecord(forAddress address: MartinOMEMO.SignalAddress) -> Data? { func sessionRecord(forAddress address: MartinOMEMO.SignalAddress) -> Data? {
do { if let key = OMEMOSession.keyFor(account: credentials.bareJid, name: address.name, deviceId: address.deviceId) {
let data = try Database.shared.dbQueue.read { db in return Data(base64Encoded: key)
try Row.fetchOne( } else {
db,
sql: "SELECT key FROM omemo_sessions WHERE account = :account AND name = :name AND device_id = :deviceId",
arguments: ["account": credentials.bareJid, "name": address.name, "deviceId": address.deviceId]
)
}
return data?["key"]
} catch {
logIt(.error, "Error fetching chats: \(error.localizedDescription)")
return nil return nil
} }
} }
func allDevices(for name: String, activeAndTrusted: Bool) -> [Int32] { func allDevices(for name: String, activeAndTrusted: Bool) -> [Int32] {
do { activeAndTrusted ?
let sql = activeAndTrusted ? OMEMOSession.trustedDevicesIdsFor(account: credentials.bareJid, name: name) :
""" OMEMOSession.devicesIdsFor(account: credentials.bareJid, name: name)
SELECT s.device_id
FROM omemo_sessions s
LEFT JOIN omemo_identities i
ON s.account = i.account
AND s.name = i.name
AND s.device_id = i.device_id
WHERE s.account = :account
AND s.name = :name
AND ((i.status >= 0 AND i.status % 2 = 0) OR i.status IS NULL)
"""
:
"SELECT device_id FROM omemo_sessions WHERE account = :account AND name = :name"
let data = try Database.shared.dbQueue.read { db in
try Row.fetchAll(
db,
sql: sql,
arguments: ["account": credentials.bareJid, "name": name]
)
}
return data.map { $0["device_id"] }
} catch {
logIt(.error, "Error fetching chats: \(error.localizedDescription)")
return []
}
} }
func storeSessionRecord(_ data: Data, forAddress: MartinOMEMO.SignalAddress) -> Bool { func storeSessionRecord(_ data: Data, forAddress: MartinOMEMO.SignalAddress) -> Bool {
do { do {
try Database.shared.dbQueue.write { db in try Database.shared.dbQueue.write { db in
try db.execute( try OMEMOSession(
sql: "INSERT INTO omemo_sessions (account, name, device_id, key) VALUES (:account, :name, :deviceId, :key)", account: credentials.bareJid,
arguments: ["account": credentials.bareJid, "name": forAddress.name, "deviceId": forAddress.deviceId, "key": data] name: forAddress.name,
deviceId: Int(forAddress.deviceId),
key: data.base64EncodedString()
) )
.insert(db)
} }
return true return true
} catch { } catch {
logIt(.error, "Error fetching chats: \(error.localizedDescription)") logIt(.error, "Error storing session info: \(error.localizedDescription)")
return false return false
} }
} }
func containsSessionRecord(forAddress: MartinOMEMO.SignalAddress) -> Bool { func containsSessionRecord(forAddress: MartinOMEMO.SignalAddress) -> Bool {
do { OMEMOSession.keyFor(account: credentials.bareJid, name: forAddress.name, deviceId: forAddress.deviceId) != nil
let rec = try Database.shared.dbQueue.read { db in
try Row.fetchOne(
db,
sql: "SELECT key FROM omemo_sessions WHERE account = :account AND name = :name AND device_id = :deviceId",
arguments: ["account": credentials.bareJid, "name": forAddress.name, "deviceId": forAddress.deviceId]
)
}
return rec != nil
} catch {
logIt(.error, "Error fetching chats: \(error.localizedDescription)")
return false
}
} }
func deleteSessionRecord(forAddress: MartinOMEMO.SignalAddress) -> Bool { func deleteSessionRecord(forAddress: MartinOMEMO.SignalAddress) -> Bool {
do { do {
try Database.shared.dbQueue.write { db in _ = try Database.shared.dbQueue.write { db in
try db.execute( try OMEMOSession
sql: "DELETE FROM omemo_sessions WHERE account = :account AND name = :name AND device_id = :deviceId", .filter(Column("account") == credentials.bareJid)
arguments: ["account": credentials.bareJid, "name": forAddress.name, "deviceId": forAddress.deviceId] .filter(Column("name") == forAddress.name)
) .filter(Column("deviceId") == forAddress.deviceId)
.deleteAll(db)
} }
return true return true
} catch { } catch {
@ -137,11 +97,11 @@ extension ClientMartinOMEMO: SignalSessionStoreProtocol {
func deleteAllSessions(for name: String) -> Bool { func deleteAllSessions(for name: String) -> Bool {
do { do {
try Database.shared.dbQueue.write { db in _ = try Database.shared.dbQueue.write { db in
try db.execute( try OMEMOSession
sql: "DELETE FROM omemo_sessions WHERE account = :account AND name = :name", .filter(Column("account") == credentials.bareJid)
arguments: ["account": credentials.bareJid, "name": name] .filter(Column("name") == name)
) .deleteAll(db)
} }
return true return true
} catch { } catch {
@ -152,11 +112,10 @@ extension ClientMartinOMEMO: SignalSessionStoreProtocol {
func sessionsWipe() { func sessionsWipe() {
do { do {
try Database.shared.dbQueue.write { db in _ = try Database.shared.dbQueue.write { db in
try db.execute( try OMEMOSession
sql: "DELETE FROM omemo_sessions WHERE account = :account", .filter(Column("account") == credentials.bareJid)
arguments: ["account": credentials.bareJid] .deleteAll(db)
)
} }
} catch { } catch {
logIt(.error, "Error fetching chats: \(error.localizedDescription)") logIt(.error, "Error fetching chats: \(error.localizedDescription)")

View file

@ -3,13 +3,72 @@ import GRDB
import Martin import Martin
// MARK: - Session // MARK: - Session
struct OMEMOSession: Codable & Equatable, DatabaseValueConvertible { struct OMEMOSession: DBStorable {
static let databaseTableName = "omemo_sessions" static let databaseTableName = "omemo_sessions"
let account: String let account: String
let name: String let name: String
let deviceId: Int let deviceId: Int
let key: String let key: String
var id: String {
"\(account)_\(name)_\(deviceId)"
}
}
extension OMEMOSession {
static func keyFor(account: String, name: String, deviceId: Int32) -> String? {
do {
return try Database.shared.dbQueue.read { db in
try OMEMOSession
.filter(Column("account") == account)
.filter(Column("name") == name)
.filter(Column("deviceId") == deviceId)
.fetchOne(db)
}?.key
} catch {
return nil
}
}
static func devicesIdsFor(account: String, name: String) -> [Int32] {
do {
return try Database.shared.dbQueue.read { db in
try OMEMOSession
.filter(Column("account") == account)
.filter(Column("name") == name)
.fetchAll(db)
.map(\.deviceId)
}.map { Int32($0) }
} catch {
return []
}
}
static func trustedDevicesIdsFor(account: String, name: String) -> [Int32] {
do {
let sql =
"""
SELECT s.device_id
FROM omemo_sessions s
LEFT JOIN omemo_identities i
ON s.account = i.account
AND s.name = i.name
AND s.device_id = i.device_id
WHERE s.account = :account
AND s.name = :name
AND ((i.status >= 0 AND i.status % 2 = 0) OR i.status IS NULL)
"""
let arguments: StatementArguments = ["account": account, "name": name]
return try Database.shared.dbQueue.read { db in
try Int32.fetchAll(db, sql: sql, arguments: arguments)
}
} catch {
return []
}
}
} }
// MARK: - Identity // MARK: - Identity