fetch required device ids on demand

This commit is contained in:
Daniel Gultsch 2017-06-29 14:17:45 +02:00
parent f7258d16e1
commit 7a2856ac86
4 changed files with 104 additions and 36 deletions

View file

@ -12,7 +12,7 @@ public final class Config {
private static final int OTR = 4; private static final int OTR = 4;
private static final int OMEMO = 8; private static final int OMEMO = 8;
private static final int ENCRYPTION_MASK = UNENCRYPTED | OPENPGP | OTR | OMEMO; private static final int ENCRYPTION_MASK = OMEMO;
public static boolean supportUnencrypted() { public static boolean supportUnencrypted() {
return (ENCRYPTION_MASK & UNENCRYPTED) != 0; return (ENCRYPTION_MASK & UNENCRYPTED) != 0;

View file

@ -31,6 +31,7 @@ import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Random; import java.util.Random;
@ -74,6 +75,7 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
private final Map<Jid, Set<Integer>> deviceIds; private final Map<Jid, Set<Integer>> deviceIds;
private final Map<String, XmppAxolotlMessage> messageCache; private final Map<String, XmppAxolotlMessage> messageCache;
private final FetchStatusMap fetchStatusMap; private final FetchStatusMap fetchStatusMap;
private final HashMap<Jid,List<OnDeviceIdsFetched>> fetchDeviceIdsMap = new HashMap<>();
private final SerialSingleThreadExecutor executor; private final SerialSingleThreadExecutor executor;
private int numPublishTriesOnEmptyPep = 0; private int numPublishTriesOnEmptyPep = 0;
private boolean pepBroken = false; private boolean pepBroken = false;
@ -210,10 +212,9 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
private void fillMap(SQLiteAxolotlStore store) { private void fillMap(SQLiteAxolotlStore store) {
List<Integer> deviceIds = store.getSubDeviceSessions(account.getJid().toBareJid().toPreppedString()); List<Integer> deviceIds = store.getSubDeviceSessions(account.getJid().toBareJid().toPreppedString());
putDevicesForJid(account.getJid().toBareJid().toPreppedString(), deviceIds, store); putDevicesForJid(account.getJid().toBareJid().toPreppedString(), deviceIds, store);
for (Contact contact : account.getRoster().getContacts()) { for (String address : store.getKnownAddresses()) {
Jid bareJid = contact.getJid().toBareJid();
String address = bareJid.toPreppedString();
deviceIds = store.getSubDeviceSessions(address); deviceIds = store.getSubDeviceSessions(address);
Log.d(Config.LOGTAG,account.getJid().toBareJid()+" adding device ids for "+address+" "+deviceIds);
putDevicesForJid(address, deviceIds, store); putDevicesForJid(address, deviceIds, store);
} }
@ -358,6 +359,7 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
axolotlStore.regenerate(); axolotlStore.regenerate();
sessions.clear(); sessions.clear();
fetchStatusMap.clear(); fetchStatusMap.clear();
fetchDeviceIdsMap.clear();
publishBundlesIfNeeded(true, wipeOther); publishBundlesIfNeeded(true, wipeOther);
} }
@ -753,7 +755,8 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
public List<Jid> getCryptoTargets(Conversation conversation) { public List<Jid> getCryptoTargets(Conversation conversation) {
final List<Jid> jids; final List<Jid> jids;
if (conversation.getMode() == Conversation.MODE_SINGLE) { if (conversation.getMode() == Conversation.MODE_SINGLE) {
jids = Arrays.asList(conversation.getJid().toBareJid()); jids = new ArrayList<>();
jids.add(conversation.getJid().toBareJid());
} else { } else {
jids = conversation.getMucOptions().getMembers(); jids = conversation.getMucOptions().getMembers();
} }
@ -866,36 +869,80 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
} }
public interface OnDeviceIdsFetched { public interface OnDeviceIdsFetched {
void fetched(Set<Integer> deviceIds); void fetched(Jid jid, Set<Integer> deviceIds);
}
public interface OnMultipleDeviceIdFetched {
void fetched();
} }
public void fetchDeviceIds(final Jid jid) { public void fetchDeviceIds(final Jid jid) {
fetchDeviceIds(jid,null); fetchDeviceIds(jid,null);
} }
public void fetchDeviceIds(final Jid jid, final OnDeviceIdsFetched callback) { public void fetchDeviceIds(final Jid jid, OnDeviceIdsFetched callback) {
Log.d(Config.LOGTAG,"fetching device ids for "+jid); synchronized (this.fetchDeviceIdsMap) {
List<OnDeviceIdsFetched> callbacks = this.fetchDeviceIdsMap.get(jid);
if (callbacks != null) {
if (callback != null) {
callbacks.add(callback);
}
Log.d(Config.LOGTAG,account.getJid().toBareJid()+": fetching device ids for "+jid+" already running. adding callback");
} else {
callbacks = new ArrayList<>();
if (callback != null) {
callbacks.add(callback);
}
this.fetchDeviceIdsMap.put(jid,callbacks);
Log.d(Config.LOGTAG,account.getJid().toBareJid()+": fetching device ids for " + jid);
IqPacket packet = mXmppConnectionService.getIqGenerator().retrieveDeviceIds(jid); IqPacket packet = mXmppConnectionService.getIqGenerator().retrieveDeviceIds(jid);
mXmppConnectionService.sendIqPacket(account, packet, new OnIqPacketReceived() { mXmppConnectionService.sendIqPacket(account, packet, new OnIqPacketReceived() {
@Override @Override
public void onIqPacketReceived(Account account, IqPacket packet) { public void onIqPacketReceived(Account account, IqPacket packet) {
synchronized (fetchDeviceIdsMap) {
List<OnDeviceIdsFetched> callbacks = fetchDeviceIdsMap.remove(jid);
if (packet.getType() == IqPacket.TYPE.RESULT) { if (packet.getType() == IqPacket.TYPE.RESULT) {
Element item = mXmppConnectionService.getIqParser().getItem(packet); Element item = mXmppConnectionService.getIqParser().getItem(packet);
Set<Integer> deviceIds = mXmppConnectionService.getIqParser().deviceIds(item); Set<Integer> deviceIds = mXmppConnectionService.getIqParser().deviceIds(item);
registerDevices(jid,deviceIds); registerDevices(jid, deviceIds);
if (callback != null) { if (callbacks != null) {
callback.fetched(deviceIds); for(OnDeviceIdsFetched callback : callbacks) {
callback.fetched(jid, deviceIds);
}
} }
} else { } else {
Log.d(Config.LOGTAG,packet.toString()); Log.d(Config.LOGTAG, packet.toString());
if (callback != null) { if (callbacks != null) {
callback.fetched(null); for(OnDeviceIdsFetched callback : callbacks) {
callback.fetched(jid, null);
}
}
}
}
}
});
}
} }
} }
private void fetchDeviceIds(List<Jid> jids, final OnMultipleDeviceIdFetched callback) {
final ArrayList<Jid> unfinishedJids = new ArrayList<>(jids);
synchronized (unfinishedJids) {
for (Jid jid : unfinishedJids) {
fetchDeviceIds(jid, new OnDeviceIdsFetched() {
@Override
public void fetched(Jid jid, Set<Integer> deviceIds) {
synchronized (unfinishedJids) {
unfinishedJids.remove(jid);
if (unfinishedJids.size() == 0 && callback != null) {
callback.fetched();
}
}
} }
}); });
} }
}
}
private void buildSessionFromPEP(final SignalProtocolAddress address) { private void buildSessionFromPEP(final SignalProtocolAddress address) {
Log.i(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Building new session for " + address.toString()); Log.i(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Building new session for " + address.toString());
@ -1028,19 +1075,19 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
} }
public boolean createSessionsIfNeeded(final Conversation conversation) { public boolean createSessionsIfNeeded(final Conversation conversation) {
final Jid jid = conversation.getJid().toBareJid(); final List<Jid> jidsWithEmptyDeviceList = getCryptoTargets(conversation);
if (conversation.getMode() == Conversation.MODE_SINGLE && hasEmptyDeviceList(jid)) { for(Iterator<Jid> iterator = jidsWithEmptyDeviceList.iterator(); iterator.hasNext();) {
final SignalProtocolAddress placeholder = new SignalProtocolAddress(jid.toPreppedString(), Integer.MIN_VALUE); final Jid jid = iterator.next();
FetchStatus status = fetchStatusMap.get(placeholder); if (!hasEmptyDeviceList(jid)) {
if (status == null || status == FetchStatus.TIMEOUT) { iterator.remove();
fetchStatusMap.put(placeholder, FetchStatus.PENDING);
} }
fetchDeviceIds(conversation.getJid().toBareJid(), new OnDeviceIdsFetched() { }
Log.d(Config.LOGTAG,account.getJid().toBareJid()+": createSessionsIfNeeded() - jids with empty device list: "+jidsWithEmptyDeviceList);
if (jidsWithEmptyDeviceList.size() > 0) {
fetchDeviceIds(jidsWithEmptyDeviceList, new OnMultipleDeviceIdFetched() {
@Override @Override
public void fetched(Set<Integer> deviceIds) { public void fetched() {
createSessionsIfNeededActual(conversation); createSessionsIfNeededActual(conversation);
fetchStatusMap.put(placeholder,deviceIds != null && !deviceIds.isEmpty() ? FetchStatus.SUCCESS : FetchStatus.ERROR);
finishBuildingSessionsFromPEP(placeholder);
} }
}); });
return true; return true;

View file

@ -283,6 +283,10 @@ public class SQLiteAxolotlStore implements SignalProtocolStore {
new SignalProtocolAddress(name, 0)); new SignalProtocolAddress(name, 0));
} }
public List<String> getKnownAddresses() {
return mXmppConnectionService.databaseBackend.getKnownSignalAddresses(account);
}
/** /**
* Commit to storage the {@link SessionRecord} for a given recipientId + deviceId tuple. * Commit to storage the {@link SessionRecord} for a given recipientId + deviceId tuple.
* *

View file

@ -946,6 +946,23 @@ public class DatabaseBackend extends SQLiteOpenHelper {
return devices; return devices;
} }
public List<String> getKnownSignalAddresses(Account account) {
List<String> addresses = new ArrayList<>();
String[] colums = {SQLiteAxolotlStore.NAME};
String[] selectionArgs = {account.getUuid()};
Cursor cursor = getReadableDatabase().query(SQLiteAxolotlStore.SESSION_TABLENAME,
colums,
SQLiteAxolotlStore.ACCOUNT + " = ?",
selectionArgs,
null,null,null
);
while (cursor.moveToNext()) {
addresses.add(cursor.getString(cursor.getColumnIndex(SQLiteAxolotlStore.NAME)));
}
cursor.close();
return addresses;
}
public boolean containsSession(Account account, SignalProtocolAddress contact) { public boolean containsSession(Account account, SignalProtocolAddress contact) {
Cursor cursor = getCursorForSession(account, contact); Cursor cursor = getCursorForSession(account, contact);
int count = cursor.getCount(); int count = cursor.getCount();