signal-protocol/omemo: fix null-pointer issues

Fixes #44 and #58
This commit is contained in:
Marvin W 2017-04-18 17:55:20 +02:00
parent f95b4f4e09
commit 7e388fb2bc
No known key found for this signature in database
GPG key ID: 072E9235DB996F2A
13 changed files with 205 additions and 144 deletions

View file

@ -34,7 +34,7 @@ public class AccountSettingWidget : Plugins.AccountSettingsWidget, Box {
if (row == null) { if (row == null) {
fingerprint.set_markup("%s\n<span font='8'>%s</span>".printf(_("Own fingerprint"), _("Will be generated on first connect"))); fingerprint.set_markup("%s\n<span font='8'>%s</span>".printf(_("Own fingerprint"), _("Will be generated on first connect")));
} else { } else {
uint8[] arr = Base64.decode(row[plugin.db.identity.identity_key_public_base64]); uint8[] arr = Base64.decode(((!)row)[plugin.db.identity.identity_key_public_base64]);
arr = arr[1:arr.length]; arr = arr[1:arr.length];
string res = ""; string res = "";
foreach (uint8 i in arr) { foreach (uint8 i in arr) {

View file

@ -9,21 +9,22 @@ public class Bundle {
public Bundle(StanzaNode? node) { public Bundle(StanzaNode? node) {
this.node = node; this.node = node;
assert(Plugin.ensure_context());
} }
public int32 signed_pre_key_id { owned get { public int32 signed_pre_key_id { owned get {
if (node == null) return -1; if (node == null) return -1;
string id = node.get_deep_attribute("signedPreKeyPublic", "signedPreKeyId"); string? id = ((!)node).get_deep_attribute("signedPreKeyPublic", "signedPreKeyId");
if (id == null) return -1; if (id == null) return -1;
return int.parse(id); return int.parse((!)id);
}} }}
public ECPublicKey? signed_pre_key { owned get { public ECPublicKey? signed_pre_key { owned get {
if (node == null) return null; if (node == null) return null;
string? key = node.get_deep_string_content("signedPreKeyPublic"); string? key = ((!)node).get_deep_string_content("signedPreKeyPublic");
if (key == null) return null; if (key == null) return null;
try { try {
return Plugin.context.decode_public_key(Base64.decode(key)); return Plugin.get_context().decode_public_key(Base64.decode((!)key));
} catch (Error e) { } catch (Error e) {
return null; return null;
} }
@ -31,17 +32,17 @@ public class Bundle {
public uint8[]? signed_pre_key_signature { owned get { public uint8[]? signed_pre_key_signature { owned get {
if (node == null) return null; if (node == null) return null;
string? sig = node.get_deep_string_content("signedPreKeySignature"); string? sig = ((!)node).get_deep_string_content("signedPreKeySignature");
if (sig == null) return null; if (sig == null) return null;
return Base64.decode(sig); return Base64.decode((!)sig);
}} }}
public ECPublicKey? identity_key { owned get { public ECPublicKey? identity_key { owned get {
if (node == null) return null; if (node == null) return null;
string? key = node.get_deep_string_content("identityKey"); string? key = ((!)node).get_deep_string_content("identityKey");
if (key == null) return null; if (key == null) return null;
try { try {
return Plugin.context.decode_public_key(Base64.decode(key)); return Plugin.get_context().decode_public_key(Base64.decode((!)key));
} catch (Error e) { } catch (Error e) {
return null; return null;
} }
@ -49,9 +50,9 @@ public class Bundle {
public ArrayList<PreKey> pre_keys { owned get { public ArrayList<PreKey> pre_keys { owned get {
ArrayList<PreKey> list = new ArrayList<PreKey>(); ArrayList<PreKey> list = new ArrayList<PreKey>();
if (node == null || node.get_subnode("prekeys") == null) return list; if (node == null || ((!)node).get_subnode("prekeys") == null) return list;
node.get_deep_subnodes("prekeys", "preKeyPublic") ((!)node).get_deep_subnodes("prekeys", "preKeyPublic")
.filter((node) => node.get_attribute("preKeyId") != null) .filter((node) => ((!)node).get_attribute("preKeyId") != null)
.map<PreKey>(PreKey.create) .map<PreKey>(PreKey.create)
.foreach((key) => list.add(key)); .foreach((key) => list.add(key));
return list; return list;
@ -76,7 +77,7 @@ public class Bundle {
string? key = node.get_string_content(); string? key = node.get_string_content();
if (key == null) return null; if (key == null) return null;
try { try {
return Plugin.context.decode_public_key(Base64.decode(key)); return Plugin.get_context().decode_public_key(Base64.decode((!)key));
} catch (Error e) { } catch (Error e) {
return null; return null;
} }

View file

@ -12,8 +12,8 @@ public class Database : Qlite.Database {
public Column<int> id = new Column.Integer("id") { primary_key = true, auto_increment = true }; public Column<int> id = new Column.Integer("id") { primary_key = true, auto_increment = true };
public Column<int> account_id = new Column.Integer("account_id") { unique = true, not_null = true }; public Column<int> account_id = new Column.Integer("account_id") { unique = true, not_null = true };
public Column<int> device_id = new Column.Integer("device_id") { not_null = true }; public Column<int> device_id = new Column.Integer("device_id") { not_null = true };
public Column<string> identity_key_private_base64 = new Column.Text("identity_key_private_base64") { not_null = true }; public Column<string> identity_key_private_base64 = new Column.NonNullText("identity_key_private_base64");
public Column<string> identity_key_public_base64 = new Column.Text("identity_key_public_base64") { not_null = true }; public Column<string> identity_key_public_base64 = new Column.NonNullText("identity_key_public_base64");
internal IdentityTable(Database db) { internal IdentityTable(Database db) {
base(db, "identity"); base(db, "identity");
@ -24,7 +24,7 @@ public class Database : Qlite.Database {
public class SignedPreKeyTable : Table { public class SignedPreKeyTable : Table {
public Column<int> identity_id = new Column.Integer("identity_id") { not_null = true }; public Column<int> identity_id = new Column.Integer("identity_id") { not_null = true };
public Column<int> signed_pre_key_id = new Column.Integer("signed_pre_key_id") { not_null = true }; public Column<int> signed_pre_key_id = new Column.Integer("signed_pre_key_id") { not_null = true };
public Column<string> record_base64 = new Column.Text("record_base64") { not_null = true }; public Column<string> record_base64 = new Column.NonNullText("record_base64");
internal SignedPreKeyTable(Database db) { internal SignedPreKeyTable(Database db) {
base(db, "signed_pre_key"); base(db, "signed_pre_key");
@ -36,7 +36,7 @@ public class Database : Qlite.Database {
public class PreKeyTable : Table { public class PreKeyTable : Table {
public Column<int> identity_id = new Column.Integer("identity_id") { not_null = true }; public Column<int> identity_id = new Column.Integer("identity_id") { not_null = true };
public Column<int> pre_key_id = new Column.Integer("pre_key_id") { not_null = true }; public Column<int> pre_key_id = new Column.Integer("pre_key_id") { not_null = true };
public Column<string> record_base64 = new Column.Text("record_base64") { not_null = true }; public Column<string> record_base64 = new Column.NonNullText("record_base64");
internal PreKeyTable(Database db) { internal PreKeyTable(Database db) {
base(db, "pre_key"); base(db, "pre_key");
@ -47,9 +47,9 @@ public class Database : Qlite.Database {
public class SessionTable : Table { public class SessionTable : Table {
public Column<int> identity_id = new Column.Integer("identity_id") { not_null = true }; public Column<int> identity_id = new Column.Integer("identity_id") { not_null = true };
public Column<string> address_name = new Column.Text("name") { not_null = true }; public Column<string> address_name = new Column.NonNullText("name");
public Column<int> device_id = new Column.Integer("device_id") { not_null = true }; public Column<int> device_id = new Column.Integer("device_id") { not_null = true };
public Column<string> record_base64 = new Column.Text("record_base64") { not_null = true }; public Column<string> record_base64 = new Column.NonNullText("record_base64");
internal SessionTable(Database db) { internal SessionTable(Database db) {
base(db, "session"); base(db, "session");

View file

@ -70,16 +70,22 @@ public class Manager : StreamInteractionModule, Object {
} }
private void on_pre_message_received(Entities.Message message, Xmpp.Message.Stanza message_stanza, Conversation conversation) { private void on_pre_message_received(Entities.Message message, Xmpp.Message.Stanza message_stanza, Conversation conversation) {
if (MessageFlag.get_flag(message_stanza) != null && MessageFlag.get_flag(message_stanza).decrypted) { MessageFlag? flag = MessageFlag.get_flag(message_stanza);
if (flag != null && ((!)flag).decrypted) {
message.encryption = Encryption.OMEMO; message.encryption = Encryption.OMEMO;
} }
} }
private void on_pre_message_send(Entities.Message message, Xmpp.Message.Stanza message_stanza, Conversation conversation) { private void on_pre_message_send(Entities.Message message, Xmpp.Message.Stanza message_stanza, Conversation conversation) {
if (message.encryption == Encryption.OMEMO) { if (message.encryption == Encryption.OMEMO) {
StreamModule module = stream_interactor.get_stream(conversation.account).get_module(StreamModule.IDENTITY); Core.XmppStream? stream = stream_interactor.get_stream(conversation.account);
if (stream == null) {
message.marked = Entities.Message.Marked.UNSENT;
return;
}
StreamModule module = ((!)stream).get_module(StreamModule.IDENTITY);
EncryptState enc_state = module.encrypt(message_stanza, conversation.account.bare_jid.to_string()); EncryptState enc_state = module.encrypt(message_stanza, conversation.account.bare_jid.to_string());
MessageState state = null; MessageState state;
lock (message_states) { lock (message_states) {
if (message_states.has_key(message)) { if (message_states.has_key(message)) {
state = message_states.get(message); state = message_states.get(message);
@ -95,18 +101,18 @@ public class Manager : StreamInteractionModule, Object {
if (!state.will_send_now) { if (!state.will_send_now) {
if (message.marked == Entities.Message.Marked.WONTSEND) { if (message.marked == Entities.Message.Marked.WONTSEND) {
if (Plugin.DEBUG) print(@"OMEMO: message $(message.stanza_id) was not sent: $state\n"); if (Plugin.DEBUG) print(@"OMEMO: message was not sent: $state\n");
} else { } else {
if (Plugin.DEBUG) print(@"OMEMO: message $(message.stanza_id) will be delayed: $state\n"); if (Plugin.DEBUG) print(@"OMEMO: message will be delayed: $state\n");
if (state.waiting_own_sessions > 0) { if (state.waiting_own_sessions > 0) {
module.start_sessions_with(stream_interactor.get_stream(conversation.account), conversation.account.bare_jid.to_string()); module.start_sessions_with((!)stream, conversation.account.bare_jid.to_string());
} }
if (state.waiting_other_sessions > 0) { if (state.waiting_other_sessions > 0 && message.counterpart != null) {
module.start_sessions_with(stream_interactor.get_stream(conversation.account), message.counterpart.bare_jid.to_string()); module.start_sessions_with((!)stream, ((!)message.counterpart).bare_jid.to_string());
} }
if (state.waiting_other_devicelist) { if (state.waiting_other_devicelist && message.counterpart != null) {
module.request_user_devicelist(stream_interactor.get_stream(conversation.account), message.counterpart.bare_jid.to_string()); module.request_user_devicelist((!)stream, ((!)message.counterpart).bare_jid.to_string());
} }
} }
} }
@ -120,8 +126,7 @@ public class Manager : StreamInteractionModule, Object {
stream_interactor.module_manager.get_module(account, StreamModule.IDENTITY).session_start_failed.connect((jid, device_id) => on_session_started(account, jid, true)); stream_interactor.module_manager.get_module(account, StreamModule.IDENTITY).session_start_failed.connect((jid, device_id) => on_session_started(account, jid, true));
} }
private void on_stream_negotiated(Account account) { private void on_stream_negotiated(Account account, Core.XmppStream stream) {
Core.XmppStream stream = stream_interactor.get_stream(account);
stream_interactor.module_manager.get_module(account, StreamModule.IDENTITY).request_user_devicelist(stream, account.bare_jid.to_string()); stream_interactor.module_manager.get_module(account, StreamModule.IDENTITY).request_user_devicelist(stream, account.bare_jid.to_string());
} }
@ -134,7 +139,7 @@ public class Manager : StreamInteractionModule, Object {
MessageState state = message_states[msg]; MessageState state = message_states[msg];
if (account.bare_jid.to_string() == jid) { if (account.bare_jid.to_string() == jid) {
state.waiting_own_sessions--; state.waiting_own_sessions--;
} else if (msg.counterpart.bare_jid.to_string() == jid) { } else if (msg.counterpart != null && ((!)msg.counterpart).bare_jid.to_string() == jid) {
state.waiting_other_sessions--; state.waiting_other_sessions--;
} }
if (state.should_retry_now()) { if (state.should_retry_now()) {
@ -144,8 +149,10 @@ public class Manager : StreamInteractionModule, Object {
} }
} }
foreach (Entities.Message msg in send_now) { foreach (Entities.Message msg in send_now) {
Entities.Conversation conv = stream_interactor.get_module(ConversationManager.IDENTITY).get_conversation(msg.counterpart, account); if (msg.counterpart == null) continue;
stream_interactor.get_module(MessageProcessor.IDENTITY).send_xmpp_message(msg, conv, true); Entities.Conversation? conv = stream_interactor.get_module(ConversationManager.IDENTITY).get_conversation((!)msg.counterpart, account);
if (conv == null) continue;
stream_interactor.get_module(MessageProcessor.IDENTITY).send_xmpp_message(msg, (!)conv, true);
} }
} }
@ -158,7 +165,7 @@ public class Manager : StreamInteractionModule, Object {
MessageState state = message_states[msg]; MessageState state = message_states[msg];
if (account.bare_jid.to_string() == jid) { if (account.bare_jid.to_string() == jid) {
state.waiting_own_devicelist = false; state.waiting_own_devicelist = false;
} else if (msg.counterpart.bare_jid.to_string() == jid) { } else if (msg.counterpart != null && ((!)msg.counterpart).bare_jid.to_string() == jid) {
state.waiting_other_devicelist = false; state.waiting_other_devicelist = false;
} }
if (state.should_retry_now()) { if (state.should_retry_now()) {
@ -168,8 +175,10 @@ public class Manager : StreamInteractionModule, Object {
} }
} }
foreach (Entities.Message msg in send_now) { foreach (Entities.Message msg in send_now) {
Entities.Conversation conv = stream_interactor.get_module(ConversationManager.IDENTITY).get_conversation(msg.counterpart, account); if (msg.counterpart == null) continue;
stream_interactor.get_module(MessageProcessor.IDENTITY).send_xmpp_message(msg, conv, true); Entities.Conversation? conv = stream_interactor.get_module(ConversationManager.IDENTITY).get_conversation(((!)msg.counterpart), account);
if (conv == null) continue;
stream_interactor.get_module(MessageProcessor.IDENTITY).send_xmpp_message(msg, (!)conv, true);
} }
} }
@ -187,7 +196,7 @@ public class Manager : StreamInteractionModule, Object {
try { try {
store.identity_key_store.local_registration_id = Random.int_range(1, int32.MAX); store.identity_key_store.local_registration_id = Random.int_range(1, int32.MAX);
Signal.ECKeyPair key_pair = Plugin.context.generate_key_pair(); Signal.ECKeyPair key_pair = Plugin.get_context().generate_key_pair();
store.identity_key_store.identity_key_private = key_pair.private.serialize(); store.identity_key_store.identity_key_private = key_pair.private.serialize();
store.identity_key_store.identity_key_public = key_pair.public.serialize(); store.identity_key_store.identity_key_public = key_pair.public.serialize();
@ -201,10 +210,10 @@ public class Manager : StreamInteractionModule, Object {
// Ignore error // Ignore error
} }
} else { } else {
store.identity_key_store.local_registration_id = row[db.identity.device_id]; store.identity_key_store.local_registration_id = ((!)row)[db.identity.device_id];
store.identity_key_store.identity_key_private = Base64.decode(row[db.identity.identity_key_private_base64]); store.identity_key_store.identity_key_private = Base64.decode(((!)row)[db.identity.identity_key_private_base64]);
store.identity_key_store.identity_key_public = Base64.decode(row[db.identity.identity_key_public_base64]); store.identity_key_store.identity_key_public = Base64.decode(((!)row)[db.identity.identity_key_public_base64]);
identity_id = row[db.identity.id]; identity_id = ((!)row)[db.identity.id];
} }
if (identity_id >= 0) { if (identity_id >= 0) {
@ -218,9 +227,11 @@ public class Manager : StreamInteractionModule, Object {
public bool can_encrypt(Entities.Conversation conversation) { public bool can_encrypt(Entities.Conversation conversation) {
Core.XmppStream stream = stream_interactor.get_stream(conversation.account); Core.XmppStream? stream = stream_interactor.get_stream(conversation.account);
if (stream == null) return false; if (stream == null) return false;
return stream.get_module(StreamModule.IDENTITY).is_known_address(conversation.counterpart.bare_jid.to_string()); StreamModule? module = ((!)stream).get_module(StreamModule.IDENTITY);
if (module == null) return false;
return ((!)module).is_known_address(conversation.counterpart.bare_jid.to_string());
} }
public static void start(StreamInteractor stream_interactor, Database db) { public static void start(StreamInteractor stream_interactor, Database db) {

View file

@ -5,7 +5,23 @@ namespace Dino.Plugins.Omemo {
public class Plugin : RootInterface, Object { public class Plugin : RootInterface, Object {
public const bool DEBUG = false; public const bool DEBUG = false;
public static Signal.Context context; private static Signal.Context? _context;
public static Signal.Context get_context() {
assert(_context != null);
return (!)_context;
}
public static bool ensure_context() {
lock(_context) {
try {
if (_context == null) {
_context = new Signal.Context(DEBUG);
}
return true;
} catch (Error e) {
return false;
}
}
}
public Dino.Application app; public Dino.Application app;
public Database db; public Database db;
@ -14,7 +30,7 @@ public class Plugin : RootInterface, Object {
public void registered(Dino.Application app) { public void registered(Dino.Application app) {
try { try {
context = new Signal.Context(DEBUG); ensure_context();
this.app = app; this.app = app;
this.db = new Database(Path.build_filename(Application.get_storage_dir(), "omemo.db")); this.db = new Database(Path.build_filename(Application.get_storage_dir(), "omemo.db"));
this.list_entry = new EncryptionListEntry(this); this.list_entry = new EncryptionListEntry(this);
@ -26,7 +42,13 @@ public class Plugin : RootInterface, Object {
}); });
Manager.start(this.app.stream_interaction, db); Manager.start(this.app.stream_interaction, db);
internationalize(GETTEXT_PACKAGE, app.search_path_generator.get_locale_path(GETTEXT_PACKAGE, LOCALE_INSTALL_DIR)); string locales_dir;
if (app.search_path_generator != null) {
locales_dir = ((!)app.search_path_generator).get_locale_path(GETTEXT_PACKAGE, LOCALE_INSTALL_DIR);
} else {
locales_dir = LOCALE_INSTALL_DIR;
}
internationalize(GETTEXT_PACKAGE, locales_dir);
} catch (Error e) { } catch (Error e) {
print(@"Error initializing OMEMO: $(e.message)\n"); print(@"Error initializing OMEMO: $(e.message)\n");
} }

View file

@ -15,11 +15,10 @@ private class BackedSessionStore : SimpleSessionStore {
private void init() { private void init() {
try { try {
Address addr = new Address();
foreach (Row row in db.session.select().with(db.session.identity_id, "=", identity_id)) { foreach (Row row in db.session.select().with(db.session.identity_id, "=", identity_id)) {
addr.name = row[db.session.address_name]; Address addr = new Address(row[db.session.address_name], row[db.session.device_id]);
addr.device_id = row[db.session.device_id];
store_session(addr, Base64.decode(row[db.session.record_base64])); store_session(addr, Base64.decode(row[db.session.record_base64]));
addr.device_id = 0;
} }
} catch (Error e) { } catch (Error e) {
print(@"OMEMO: Error while initializing session store: $(e.message)\n"); print(@"OMEMO: Error while initializing session store: $(e.message)\n");

View file

@ -29,25 +29,26 @@ public class StreamModule : XmppStreamModule {
public EncryptState encrypt(Message.Stanza message, string self_bare_jid) { public EncryptState encrypt(Message.Stanza message, string self_bare_jid) {
EncryptState status = new EncryptState(); EncryptState status = new EncryptState();
if (Plugin.context == null) return status; if (!Plugin.ensure_context()) return status;
if (message.to == null) return status;
try { try {
string name = get_bare_jid(message.to); string name = get_bare_jid((!)message.to);
if (device_lists.get(self_bare_jid) == null) return status; if (!device_lists.has_key(self_bare_jid)) return status;
status.own_list = true; status.own_list = true;
status.own_devices = device_lists.get(self_bare_jid).size; status.own_devices = device_lists.get(self_bare_jid).size;
if (device_lists.get(name) == null) return status; if (!device_lists.has_key(name)) return status;
status.other_list = true; status.other_list = true;
status.other_devices = device_lists.get(name).size; status.other_devices = device_lists.get(name).size;
if (status.own_devices == 0 || status.other_devices == 0) return status; if (status.own_devices == 0 || status.other_devices == 0) return status;
uint8[] key = new uint8[16]; uint8[] key = new uint8[16];
Plugin.context.randomize(key); Plugin.get_context().randomize(key);
uint8[] iv = new uint8[16]; uint8[] iv = new uint8[16];
Plugin.context.randomize(iv); Plugin.get_context().randomize(iv);
uint8[] ciphertext = aes_encrypt(Cipher.AES_GCM_NOPADDING, key, iv, message.body.data); uint8[] ciphertext = aes_encrypt(Cipher.AES_GCM_NOPADDING, key, iv, message.body.data);
StanzaNode header = null; StanzaNode header;
StanzaNode encrypted = new StanzaNode.build("encrypted", NS_URI).add_self_xmlns() StanzaNode encrypted = new StanzaNode.build("encrypted", NS_URI).add_self_xmlns()
.put_node(header = new StanzaNode.build("header", NS_URI) .put_node(header = new StanzaNode.build("header", NS_URI)
.put_attribute("sid", store.local_registration_id.to_string()) .put_attribute("sid", store.local_registration_id.to_string())
@ -56,8 +57,7 @@ public class StreamModule : XmppStreamModule {
.put_node(new StanzaNode.build("payload", NS_URI) .put_node(new StanzaNode.build("payload", NS_URI)
.put_node(new StanzaNode.text(Base64.encode(ciphertext)))); .put_node(new StanzaNode.text(Base64.encode(ciphertext))));
Address address = new Address(); Address address = new Address(name, 0);
address.name = name;
foreach(int32 device_id in device_lists[name]) { foreach(int32 device_id in device_lists[name]) {
if (is_ignored_device(name, device_id)) { if (is_ignored_device(name, device_id)) {
status.other_lost++; status.other_lost++;
@ -114,43 +114,47 @@ public class StreamModule : XmppStreamModule {
public override void attach(XmppStream stream) { public override void attach(XmppStream stream) {
Message.Module.require(stream); Message.Module.require(stream);
Pubsub.Module.require(stream); Pubsub.Module.require(stream);
if (Plugin.context == null) return; if (!Plugin.ensure_context()) return;
this.store = Plugin.context.create_store(); this.store = Plugin.get_context().create_store();
store_created(store); store_created(store);
stream.get_module(Message.Module.IDENTITY).pre_received_message.connect(on_pre_received_message); stream.get_module(Message.Module.IDENTITY).pre_received_message.connect(on_pre_received_message);
stream.get_module(Pubsub.Module.IDENTITY).add_filtered_notification(stream, NODE_DEVICELIST, (stream, jid, id, node, obj) => (obj as StreamModule).on_devicelist(stream, jid, id, node), this); stream.get_module(Pubsub.Module.IDENTITY).add_filtered_notification(stream, NODE_DEVICELIST, (stream, jid, id, node, obj) => ((StreamModule)obj).on_devicelist(stream, jid, id, node), this);
} }
private void on_pre_received_message(XmppStream stream, Message.Stanza message) { private void on_pre_received_message(XmppStream stream, Message.Stanza message) {
StanzaNode? encrypted = message.stanza.get_subnode("encrypted", NS_URI); StanzaNode? _encrypted = message.stanza.get_subnode("encrypted", NS_URI);
if (encrypted == null || MessageFlag.get_flag(message) != null) return; if (_encrypted == null || MessageFlag.get_flag(message) != null || message.from == null) return;
StanzaNode encrypted = (!)_encrypted;
if (!Plugin.ensure_context()) return;
MessageFlag flag = new MessageFlag(); MessageFlag flag = new MessageFlag();
message.add_flag(flag); message.add_flag(flag);
StanzaNode? header = encrypted.get_subnode("header"); StanzaNode? _header = encrypted.get_subnode("header");
if (header == null || header.get_attribute_int("sid") <= 0) return; if (_header == null) return;
StanzaNode header = (!)_header;
if (header.get_attribute_int("sid") <= 0) return;
foreach (StanzaNode key_node in header.get_subnodes("key")) { foreach (StanzaNode key_node in header.get_subnodes("key")) {
if (key_node.get_attribute_int("rid") == store.local_registration_id) { if (key_node.get_attribute_int("rid") == store.local_registration_id) {
try { try {
uint8[] key = null; string? payload = encrypted.get_deep_string_content("payload");
uint8[] ciphertext = Base64.decode(encrypted.get_subnode("payload").get_string_content()); string? iv_node = header.get_deep_string_content("iv");
uint8[] iv = Base64.decode(header.get_subnode("iv").get_string_content()); string? key_node_content = key_node.get_string_content();
Address address = new Address(); if (payload == null || iv_node == null || key_node_content == null) continue;
address.name = get_bare_jid(message.from); uint8[] key;
address.device_id = header.get_attribute_int("sid"); uint8[] ciphertext = Base64.decode((!)payload);
uint8[] iv = Base64.decode((!)iv_node);
Address address = new Address(get_bare_jid((!)message.from), header.get_attribute_int("sid"));
if (key_node.get_attribute_bool("prekey")) { if (key_node.get_attribute_bool("prekey")) {
PreKeySignalMessage msg = Plugin.context.deserialize_pre_key_signal_message(Base64.decode(key_node.get_string_content())); PreKeySignalMessage msg = Plugin.get_context().deserialize_pre_key_signal_message(Base64.decode((!)key_node_content));
SessionCipher cipher = store.create_session_cipher(address); SessionCipher cipher = store.create_session_cipher(address);
key = cipher.decrypt_pre_key_signal_message(msg); key = cipher.decrypt_pre_key_signal_message(msg);
} else { } else {
SignalMessage msg = Plugin.context.deserialize_signal_message(Base64.decode(key_node.get_string_content())); SignalMessage msg = Plugin.get_context().deserialize_signal_message(Base64.decode((!)key_node_content));
SessionCipher cipher = store.create_session_cipher(address); SessionCipher cipher = store.create_session_cipher(address);
key = cipher.decrypt_signal_message(msg); key = cipher.decrypt_signal_message(msg);
} }
address.device_id = 0; // TODO: Hack to have address obj live longer address.device_id = 0; // TODO: Hack to have address obj live longer
if (key != null && ciphertext != null && iv != null) {
if (key.length >= 32) { if (key.length >= 32) {
int authtaglength = key.length - 16; int authtaglength = key.length - 16;
uint8[] new_ciphertext = new uint8[ciphertext.length + authtaglength]; uint8[] new_ciphertext = new uint8[ciphertext.length + authtaglength];
@ -164,7 +168,6 @@ public class StreamModule : XmppStreamModule {
message.body = arr_to_str(aes_decrypt(Cipher.AES_GCM_NOPADDING, key, iv, ciphertext)); message.body = arr_to_str(aes_decrypt(Cipher.AES_GCM_NOPADDING, key, iv, ciphertext));
flag.decrypted = true; flag.decrypted = true;
}
} catch (Error e) { } catch (Error e) {
if (Plugin.DEBUG) print(@"OMEMO: Signal error while decrypting message: $(e.message)\n"); if (Plugin.DEBUG) print(@"OMEMO: Signal error while decrypting message: $(e.message)\n");
} }
@ -182,17 +185,15 @@ public class StreamModule : XmppStreamModule {
public void request_user_devicelist(XmppStream stream, string jid) { public void request_user_devicelist(XmppStream stream, string jid) {
if (active_devicelist_requests.add(jid)) { if (active_devicelist_requests.add(jid)) {
if (Plugin.DEBUG) print(@"OMEMO: requesting device list for $jid\n"); if (Plugin.DEBUG) print(@"OMEMO: requesting device list for $jid\n");
stream.get_module(Pubsub.Module.IDENTITY).request(stream, jid, NODE_DEVICELIST, (stream, jid, id, node, obj) => (obj as StreamModule).on_devicelist(stream, jid, id ?? "", node), this); stream.get_module(Pubsub.Module.IDENTITY).request(stream, jid, NODE_DEVICELIST, (stream, jid, id, node, obj) => ((StreamModule)obj).on_devicelist(stream, jid, id ?? "", node), this);
} }
} }
public void on_devicelist(XmppStream stream, string jid, string id, StanzaNode? node_) { public void on_devicelist(XmppStream stream, string jid, string id, StanzaNode? node_) {
StanzaNode? node = node_; StanzaNode node = node_ ?? new StanzaNode.build("list", NS_URI).add_self_xmlns();
if (jid == get_bare_jid(stream.get_flag(Bind.Flag.IDENTITY).my_jid) && store.local_registration_id != 0) { string? my_jid = stream.get_flag(Bind.Flag.IDENTITY).my_jid;
if (node == null) { if (my_jid == null) return;
node = new StanzaNode.build("list", NS_URI).add_self_xmlns().put_node(new StanzaNode.build("device", NS_URI)); if (jid == get_bare_jid((!)my_jid) && store.local_registration_id != 0) {
}
bool am_on_devicelist = false; bool am_on_devicelist = false;
foreach (StanzaNode device_node in node.get_subnodes("device")) { foreach (StanzaNode device_node in node.get_subnodes("device")) {
int device_id = device_node.get_attribute_int("id"); int device_id = device_node.get_attribute_int("id");
@ -223,8 +224,7 @@ public class StreamModule : XmppStreamModule {
// TODO: manually request a device list // TODO: manually request a device list
return; return;
} }
Address address = new Address(); Address address = new Address(bare_jid, 0);
address.name = bare_jid;
foreach(int32 device_id in device_lists[bare_jid]) { foreach(int32 device_id in device_lists[bare_jid]) {
if (!is_ignored_device(bare_jid, device_id)) { if (!is_ignored_device(bare_jid, device_id)) {
address.device_id = device_id; address.device_id = device_id;
@ -293,9 +293,7 @@ public class StreamModule : XmppStreamModule {
if (signed_pre_key_id < 0 || signed_pre_key == null || identity_key == null || pre_key_id < 0 || pre_key == null) { if (signed_pre_key_id < 0 || signed_pre_key == null || identity_key == null || pre_key_id < 0 || pre_key == null) {
fail = true; fail = true;
} else { } else {
Address address = new Address(); Address address = new Address(jid, device_id);
address.name = jid;
address.device_id = device_id;
try { try {
if (store.contains_session(address)) { if (store.contains_session(address)) {
return; return;
@ -322,13 +320,13 @@ public class StreamModule : XmppStreamModule {
} }
private static void on_self_bundle_result(XmppStream stream, string jid, string? id, StanzaNode? node, Object? storage) { private static void on_self_bundle_result(XmppStream stream, string jid, string? id, StanzaNode? node, Object? storage) {
if (!Plugin.ensure_context()) return;
Store store = (Store)storage; Store store = (Store)storage;
Map<int, ECPublicKey> keys = new HashMap<int, ECPublicKey>(); Map<int, ECPublicKey> keys = new HashMap<int, ECPublicKey>();
ECPublicKey identity_key = null; ECPublicKey? identity_key = null;
IdentityKeyPair identity_key_pair = null;
int32 signed_pre_key_id = -1; int32 signed_pre_key_id = -1;
ECPublicKey signed_pre_key = null; ECPublicKey? signed_pre_key = null;
SignedPreKeyRecord signed_pre_key_record = null; SignedPreKeyRecord? signed_pre_key_record = null;
bool changed = false; bool changed = false;
if (node == null) { if (node == null) {
identity_key = store.identity_key_pair.public; identity_key = store.identity_key_pair.public;
@ -336,7 +334,10 @@ public class StreamModule : XmppStreamModule {
} else { } else {
Bundle bundle = new Bundle(node); Bundle bundle = new Bundle(node);
foreach (Bundle.PreKey prekey in bundle.pre_keys) { foreach (Bundle.PreKey prekey in bundle.pre_keys) {
keys[prekey.key_id] = prekey.key; ECPublicKey? key = prekey.key;
if (key != null) {
keys[prekey.key_id] = (!)key;
}
} }
identity_key = bundle.identity_key; identity_key = bundle.identity_key;
signed_pre_key_id = bundle.signed_pre_key_id;; signed_pre_key_id = bundle.signed_pre_key_id;;
@ -345,16 +346,16 @@ public class StreamModule : XmppStreamModule {
try { try {
// Validate IdentityKey // Validate IdentityKey
if (store.identity_key_pair.public.compare(identity_key) != 0) { if (identity_key == null || store.identity_key_pair.public.compare((!)identity_key) != 0) {
changed = true; changed = true;
} }
identity_key_pair = store.identity_key_pair; IdentityKeyPair identity_key_pair = store.identity_key_pair;
// Validate signedPreKeyRecord + ID // Validate signedPreKeyRecord + ID
if (signed_pre_key_id == -1 || !store.contains_signed_pre_key(signed_pre_key_id) || store.load_signed_pre_key(signed_pre_key_id).key_pair.public.compare(signed_pre_key) != 0) { if (signed_pre_key == null || signed_pre_key_id == -1 || !store.contains_signed_pre_key(signed_pre_key_id) || store.load_signed_pre_key(signed_pre_key_id).key_pair.public.compare((!)signed_pre_key) != 0) {
signed_pre_key_id = Random.int_range(1, int32.MAX); // TODO: No random, use ordered number signed_pre_key_id = Random.int_range(1, int32.MAX); // TODO: No random, use ordered number
signed_pre_key_record = Plugin.context.generate_signed_pre_key(identity_key_pair, signed_pre_key_id); signed_pre_key_record = Plugin.get_context().generate_signed_pre_key(identity_key_pair, signed_pre_key_id);
store.store_signed_pre_key(signed_pre_key_record); store.store_signed_pre_key((!)signed_pre_key_record);
changed = true; changed = true;
} else { } else {
signed_pre_key_record = store.load_signed_pre_key(signed_pre_key_id); signed_pre_key_record = store.load_signed_pre_key(signed_pre_key_id);
@ -373,7 +374,7 @@ public class StreamModule : XmppStreamModule {
int new_keys = NUM_KEYS_TO_PUBLISH - pre_key_records.size; int new_keys = NUM_KEYS_TO_PUBLISH - pre_key_records.size;
if (new_keys > 0) { if (new_keys > 0) {
int32 next_id = Random.int_range(1, int32.MAX); // TODO: No random, use ordered number int32 next_id = Random.int_range(1, int32.MAX); // TODO: No random, use ordered number
Set<PreKeyRecord> new_records = Plugin.context.generate_pre_keys((uint)next_id, (uint)new_keys); Set<PreKeyRecord> new_records = Plugin.get_context().generate_pre_keys((uint)next_id, (uint)new_keys);
pre_key_records.add_all(new_records); pre_key_records.add_all(new_records);
foreach (PreKeyRecord record in new_records) { foreach (PreKeyRecord record in new_records) {
store.store_pre_key(record); store.store_pre_key(record);
@ -382,7 +383,7 @@ public class StreamModule : XmppStreamModule {
} }
if (changed) { if (changed) {
publish_bundles(stream, signed_pre_key_record, identity_key_pair, pre_key_records, (int32) store.local_registration_id); publish_bundles(stream, (!)signed_pre_key_record, identity_key_pair, pre_key_records, (int32) store.local_registration_id);
} }
} catch (Error e) { } catch (Error e) {
if (Plugin.DEBUG) print(@"Unexpected error while publishing bundle: $(e.message)\n"); if (Plugin.DEBUG) print(@"Unexpected error while publishing bundle: $(e.message)\n");

View file

@ -3,14 +3,30 @@
#include <gcrypt.h> #include <gcrypt.h>
signal_protocol_address* signal_protocol_address_new() { signal_type_base* signal_type_ref_vapi(signal_type_base* instance) {
g_return_val_if_fail(instance != NULL, NULL);
signal_type_ref(instance);
return instance;
}
signal_type_base* signal_type_unref_vapi(signal_type_base* instance) {
g_return_val_if_fail(instance != NULL, NULL);
signal_type_unref(instance);
return NULL;
}
signal_protocol_address* signal_protocol_address_new(const gchar* name, int32_t device_id) {
g_return_val_if_fail(name != NULL, NULL);
signal_protocol_address* address = malloc(sizeof(signal_protocol_address)); signal_protocol_address* address = malloc(sizeof(signal_protocol_address));
address->name = 0; address->device_id = NULL;
address->device_id = 0; address->name = NULL;
signal_protocol_address_set_name(address, name);
signal_protocol_address_set_device_id(address, device_id);
return address; return address;
} }
void signal_protocol_address_free(signal_protocol_address* ptr) { void signal_protocol_address_free(signal_protocol_address* ptr) {
g_return_if_fail(ptr != NULL);
if (ptr->name) { if (ptr->name) {
g_free((void*)ptr->name); g_free((void*)ptr->name);
} }
@ -18,6 +34,8 @@ void signal_protocol_address_free(signal_protocol_address* ptr) {
} }
void signal_protocol_address_set_name(signal_protocol_address* self, const gchar* name) { void signal_protocol_address_set_name(signal_protocol_address* self, const gchar* name) {
g_return_if_fail(self != NULL);
g_return_if_fail(name != NULL);
gchar* n = g_malloc(strlen(name)+1); gchar* n = g_malloc(strlen(name)+1);
memcpy(n, name, strlen(name)); memcpy(n, name, strlen(name));
n[strlen(name)] = 0; n[strlen(name)] = 0;
@ -29,13 +47,25 @@ void signal_protocol_address_set_name(signal_protocol_address* self, const gchar
} }
gchar* signal_protocol_address_get_name(signal_protocol_address* self) { gchar* signal_protocol_address_get_name(signal_protocol_address* self) {
if (self->name == 0) return 0; g_return_val_if_fail(self != NULL, NULL);
g_return_val_if_fail(self->name != NULL, 0);
gchar* res = g_malloc(sizeof(char) * (self->name_len + 1)); gchar* res = g_malloc(sizeof(char) * (self->name_len + 1));
memcpy(res, self->name, self->name_len); memcpy(res, self->name, self->name_len);
res[self->name_len] = 0; res[self->name_len] = 0;
return res; return res;
} }
int32_t signal_protocol_address_get_device_id(signal_protocol_address* self) {
g_return_val_if_fail(self != NULL, NULL);
return self->device_id;
}
void signal_protocol_address_set_device_id(signal_protocol_address* self, int32_t device_id) {
g_return_if_fail(self != NULL);
self->device_id = device_id;
}
session_pre_key* session_pre_key_new(uint32_t pre_key_id, ec_key_pair* pair, int* err) { session_pre_key* session_pre_key_new(uint32_t pre_key_id, ec_key_pair* pair, int* err) {
session_pre_key* res; session_pre_key* res;
*err = session_pre_key_create(&res, pre_key_id, pair); *err = session_pre_key_create(&res, pre_key_id, pair);

View file

@ -9,10 +9,14 @@
signal_type_base* signal_type_ref_vapi(signal_type_base* what); signal_type_base* signal_type_ref_vapi(signal_type_base* what);
signal_type_base* signal_type_unref_vapi(signal_type_base* what); signal_type_base* signal_type_unref_vapi(signal_type_base* what);
signal_protocol_address* signal_protocol_address_new();
signal_protocol_address* signal_protocol_address_new(const gchar* name, int32_t device_id);
void signal_protocol_address_free(signal_protocol_address* ptr); void signal_protocol_address_free(signal_protocol_address* ptr);
void signal_protocol_address_set_name(signal_protocol_address* self, const gchar* name); void signal_protocol_address_set_name(signal_protocol_address* self, const gchar* name);
gchar* signal_protocol_address_get_name(signal_protocol_address* self); gchar* signal_protocol_address_get_name(signal_protocol_address* self);
void signal_protocol_address_set_device_id(signal_protocol_address* self, int32_t device_id);
int32_t signal_protocol_address_get_device_id(signal_protocol_address* self);
session_pre_key* session_pre_key_new(uint32_t pre_key_id, ec_key_pair* pair, int* err); session_pre_key* session_pre_key_new(uint32_t pre_key_id, ec_key_pair* pair, int* err);
session_signed_pre_key* session_signed_pre_key_new(uint32_t id, uint64_t timestamp, ec_key_pair* pair, uint8_t* key, int key_len, int* err); session_signed_pre_key* session_signed_pre_key_new(uint32_t id, uint64_t timestamp, ec_key_pair* pair, uint8_t* key, int key_len, int* err);

View file

@ -7,10 +7,8 @@ public class SimpleSessionStore : SessionStore {
private Map<string, ArrayList<SessionStore.Session>> session_map = new HashMap<string, ArrayList<SessionStore.Session>>(); private Map<string, ArrayList<SessionStore.Session>> session_map = new HashMap<string, ArrayList<SessionStore.Session>>();
public override uint8[]? load_session(Address address) throws Error { public override uint8[]? load_session(Address address) throws Error {
string name = address.name; if (session_map.has_key(address.name)) {
if (name == null) return null; foreach (SessionStore.Session session in session_map[address.name]) {
if (session_map.has_key(name)) {
foreach (SessionStore.Session session in session_map[name]) {
if (session.device_id == address.device_id) return session.record; if (session.device_id == address.device_id) return session.record;
} }
} }

View file

@ -142,9 +142,9 @@ public class Store : Object {
return 0; return 0;
} }
static int ss_load_session_func(out Buffer buffer, Address address, void* user_data) { static int ss_load_session_func(out Buffer? buffer, Address address, void* user_data) {
Store store = (Store) user_data; Store store = (Store) user_data;
uint8[] res = null; uint8[]? res = null;
try { try {
res = store.session_store.load_session(address); res = store.session_store.load_session(address);
} catch (Error e) { } catch (Error e) {
@ -155,12 +155,12 @@ public class Store : Object {
buffer = null; buffer = null;
return 0; return 0;
} }
buffer = new Buffer.from(res); buffer = new Buffer.from((!)res);
if (buffer == null) return ErrorCode.NOMEM; if (buffer == null) return ErrorCode.NOMEM;
return 1; return 1;
} }
static int ss_get_sub_device_sessions_func(out IntList sessions, char[] name, void* user_data) { static int ss_get_sub_device_sessions_func(out IntList? sessions, char[] name, void* user_data) {
Store store = (Store) user_data; Store store = (Store) user_data;
try { try {
sessions = store.session_store.get_sub_device_sessions(carr_to_string(name)); sessions = store.session_store.get_sub_device_sessions(carr_to_string(name));
@ -206,9 +206,9 @@ public class Store : Object {
return 0; return 0;
} }
static int pks_load_pre_key(out Buffer record, uint32 pre_key_id, void* user_data) { static int pks_load_pre_key(out Buffer? record, uint32 pre_key_id, void* user_data) {
Store store = (Store) user_data; Store store = (Store) user_data;
uint8[] res = null; uint8[]? res = null;
try { try {
res = store.pre_key_store.load_pre_key(pre_key_id); res = store.pre_key_store.load_pre_key(pre_key_id);
} catch (Error e) { } catch (Error e) {
@ -219,7 +219,7 @@ public class Store : Object {
record = new Buffer(0); record = new Buffer(0);
return 0; return 0;
} }
record = new Buffer.from(res); record = new Buffer.from((!)res);
if (record == null) return ErrorCode.NOMEM; if (record == null) return ErrorCode.NOMEM;
return 1; return 1;
} }
@ -251,9 +251,9 @@ public class Store : Object {
return 0; return 0;
} }
static int spks_load_signed_pre_key(out Buffer record, uint32 pre_key_id, void* user_data) { static int spks_load_signed_pre_key(out Buffer? record, uint32 pre_key_id, void* user_data) {
Store store = (Store) user_data; Store store = (Store) user_data;
uint8[] res = null; uint8[]? res = null;
try { try {
res = store.signed_pre_key_store.load_signed_pre_key(pre_key_id); res = store.signed_pre_key_store.load_signed_pre_key(pre_key_id);
} catch (Error e) { } catch (Error e) {
@ -264,7 +264,7 @@ public class Store : Object {
record = new Buffer(0); record = new Buffer(0);
return 0; return 0;
} }
record = new Buffer.from(res); record = new Buffer.from((!)res);
if (record == null) return ErrorCode.NOMEM; if (record == null) return ErrorCode.NOMEM;
return 1; return 1;
} }

View file

@ -18,12 +18,8 @@ class SessionBuilderTest : Gee.TestCase {
public override void set_up() { public override void set_up() {
try { try {
global_context = new Context(); global_context = new Context();
alice_address = new Address(); alice_address = new Address("+14151111111", 1);
alice_address.name = "+14151111111"; bob_address = new Address("+14152222222", 1);
alice_address.device_id = 1;
bob_address = new Address();
bob_address.name = "+14152222222";
bob_address.device_id = 1;
} catch (Error e) { } catch (Error e) {
fail_if_reached(@"Unexpected error: $(e.message)"); fail_if_reached(@"Unexpected error: $(e.message)");
} }

View file

@ -51,7 +51,7 @@ namespace Signal {
} }
[Compact] [Compact]
[CCode (cname = "signal_type_base", ref_function="signal_type_ref", ref_function_void=true, unref_function="signal_type_unref", cheader_filename="signal_protocol_types.h,signal_helper.h")] [CCode (cname = "signal_type_base", ref_function="signal_type_ref_vapi", unref_function="signal_type_unref_vapi", cheader_filename="signal_protocol_types.h,signal_helper.h")]
public class TypeBase { public class TypeBase {
} }
@ -103,8 +103,8 @@ namespace Signal {
[Compact] [Compact]
[CCode (cname = "session_pre_key_bundle", cprefix = "session_pre_key_bundle_", cheader_filename = "session_pre_key.h")] [CCode (cname = "session_pre_key_bundle", cprefix = "session_pre_key_bundle_", cheader_filename = "session_pre_key.h")]
public class PreKeyBundle : TypeBase { public class PreKeyBundle : TypeBase {
public static int create(out PreKeyBundle bundle, uint32 registration_id, int device_id, uint32 pre_key_id, ECPublicKey pre_key_public, public static int create(out PreKeyBundle bundle, uint32 registration_id, int device_id, uint32 pre_key_id, ECPublicKey? pre_key_public,
uint32 signed_pre_key_id, ECPublicKey signed_pre_key_public, uint8[] signed_pre_key_signature, ECPublicKey identity_key); uint32 signed_pre_key_id, ECPublicKey? signed_pre_key_public, uint8[]? signed_pre_key_signature, ECPublicKey? identity_key);
public uint32 registration_id { get; } public uint32 registration_id { get; }
public int device_id { get; } public int device_id { get; }
public uint32 pre_key_id { get; } public uint32 pre_key_id { get; }
@ -192,9 +192,8 @@ namespace Signal {
[Compact] [Compact]
[CCode (cname = "signal_protocol_address", cprefix = "signal_protocol_address_", cheader_filename = "signal_protocol.h,signal_helper.h")] [CCode (cname = "signal_protocol_address", cprefix = "signal_protocol_address_", cheader_filename = "signal_protocol.h,signal_helper.h")]
public class Address { public class Address {
public Address(); public Address(string name, int32 device_id);
public int32 device_id; public int32 device_id { get; set; }
public string name { owned get; set; } public string name { owned get; set; }
} }