From 7ee3e079461fe1b48b59c14b32ad5fb1f66312a0 Mon Sep 17 00:00:00 2001 From: Daniel Gultsch Date: Fri, 13 Jan 2023 10:59:23 +0100 Subject: [PATCH] Introduce XmppConnection v3 The various layers of the app are too intertwined to refactor them in place. The C3 refactor is going to create a parallel architecture for all classes that have too strong of a connection to other parts of the app. This commit introduces XmppConnection v3 that keeps a lot of the logic of the privous XmppConnection but cuts ties to XmppConnectionService and the very stateful `entites.Account`. The latter is replaced by a lightweight immutable account model. The reconnection logic has been kept but was moved from XmppConnectionService to a singleton ConnectionPool. --- build.gradle | 2 + .../1.json | 38 +- src/main/AndroidManifest.xml | 1 + .../services/MemorizingTrustManager.java | 280 +- .../siacs/conversations/utils/SSLSockets.java | 41 +- .../java/eu/siacs/conversations/xml/Tag.java | 14 +- .../conversations/android/Conversations.java | 13 + .../java/im/conversations/android/Uuids.java | 38 + .../database/ConversationsDatabase.java | 6 + .../android/database/CredentialStore.java | 206 ++ .../android/database/dao/AccountDao.java | 64 + .../android/database/dao/PresenceDao.java | 11 + .../database/entity/AccountEntity.java | 9 + .../database/entity/PresenceEntity.java | 3 + .../android/database/model/Account.java | 57 + .../android/database/model/Credential.java | 50 + .../android/xmpp/ConnectionPool.java | 356 +++ .../android/xmpp/ConnectionState.java | 124 + .../android/xmpp/XmppConnection.java | 2805 +++++++++++++++++ .../android/xmpp/processor/BaseProcessor.java | 25 + .../android/xmpp/processor/BindProcessor.java | 54 + .../android/xmpp/processor/IqProcessor.java | 14 + .../xmpp/processor/JingleProcessor.java | 14 + .../MessageAcknowledgeProcessor.java | 16 + .../xmpp/processor/MessageProcessor.java | 14 + .../xmpp/processor/PresenceProcessor.java | 14 + .../android/xmpp/sasl/Anonymous.java | 29 + .../android/xmpp/sasl/ChannelBinding.java | 117 + .../xmpp/sasl/ChannelBindingMechanism.java | 100 + .../android/xmpp/sasl/DigestMd5.java | 112 + .../android/xmpp/sasl/External.java | 30 + .../android/xmpp/sasl/HashedToken.java | 189 ++ .../android/xmpp/sasl/HashedTokenSha256.java | 26 + .../android/xmpp/sasl/HashedTokenSha512.java | 26 + .../android/xmpp/sasl/Plain.java | 36 + .../android/xmpp/sasl/SaslMechanism.java | 236 ++ .../android/xmpp/sasl/ScramMechanism.java | 318 ++ .../android/xmpp/sasl/ScramPlusMechanism.java | 24 + .../android/xmpp/sasl/ScramSha1.java | 37 + .../android/xmpp/sasl/ScramSha1Plus.java | 38 + .../android/xmpp/sasl/ScramSha256.java | 37 + .../android/xmpp/sasl/ScramSha256Plus.java | 40 + .../android/xmpp/sasl/ScramSha512.java | 37 + .../android/xmpp/sasl/ScramSha512Plus.java | 40 + .../android/xmpp/sasl/Tokenizer.java | 77 + 45 files changed, 5671 insertions(+), 147 deletions(-) create mode 100644 src/main/java/im/conversations/android/Conversations.java create mode 100644 src/main/java/im/conversations/android/Uuids.java create mode 100644 src/main/java/im/conversations/android/database/CredentialStore.java create mode 100644 src/main/java/im/conversations/android/database/dao/AccountDao.java create mode 100644 src/main/java/im/conversations/android/database/dao/PresenceDao.java create mode 100644 src/main/java/im/conversations/android/database/model/Account.java create mode 100644 src/main/java/im/conversations/android/database/model/Credential.java create mode 100644 src/main/java/im/conversations/android/xmpp/ConnectionPool.java create mode 100644 src/main/java/im/conversations/android/xmpp/ConnectionState.java create mode 100644 src/main/java/im/conversations/android/xmpp/XmppConnection.java create mode 100644 src/main/java/im/conversations/android/xmpp/processor/BaseProcessor.java create mode 100644 src/main/java/im/conversations/android/xmpp/processor/BindProcessor.java create mode 100644 src/main/java/im/conversations/android/xmpp/processor/IqProcessor.java create mode 100644 src/main/java/im/conversations/android/xmpp/processor/JingleProcessor.java create mode 100644 src/main/java/im/conversations/android/xmpp/processor/MessageAcknowledgeProcessor.java create mode 100644 src/main/java/im/conversations/android/xmpp/processor/MessageProcessor.java create mode 100644 src/main/java/im/conversations/android/xmpp/processor/PresenceProcessor.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/Anonymous.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/ChannelBinding.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/ChannelBindingMechanism.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/DigestMd5.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/External.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/HashedToken.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/HashedTokenSha256.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/HashedTokenSha512.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/Plain.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/SaslMechanism.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/ScramMechanism.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/ScramPlusMechanism.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/ScramSha1.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/ScramSha1Plus.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/ScramSha256.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/ScramSha256Plus.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/ScramSha512.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/ScramSha512Plus.java create mode 100644 src/main/java/im/conversations/android/xmpp/sasl/Tokenizer.java diff --git a/build.gradle b/build.gradle index 5784f3e89..0c3c998bc 100644 --- a/build.gradle +++ b/build.gradle @@ -53,6 +53,8 @@ dependencies { annotationProcessor "androidx.room:room-compiler:$room_version" implementation "androidx.room:room-guava:$room_version" + implementation "androidx.security:security-crypto:1.0.0" + // legacy dependencies. Ideally everything below should be carefully reviewed and eventually moved up diff --git a/schemas/im.conversations.android.database.ConversationsDatabase/1.json b/schemas/im.conversations.android.database.ConversationsDatabase/1.json index 2af6825e8..91b5916c2 100644 --- a/schemas/im.conversations.android.database.ConversationsDatabase/1.json +++ b/schemas/im.conversations.android.database.ConversationsDatabase/1.json @@ -2,11 +2,11 @@ "formatVersion": 1, "database": { "version": 1, - "identityHash": "c78cb993428558b863fd91c46b608926", + "identityHash": "4a70ff0733436f5a2a08e7abb8e6cc95", "entities": [ { "tableName": "account", - "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT, `address` TEXT NOT NULL, `resource` TEXT, `randomSeed` BLOB, `enabled` INTEGER NOT NULL, `rosterVersion` TEXT, `hostname` TEXT, `port` INTEGER, `directTls` INTEGER, `proxytype` TEXT, `proxyhostname` TEXT, `proxyport` INTEGER)", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT, `address` TEXT NOT NULL, `resource` TEXT, `randomSeed` BLOB, `enabled` INTEGER NOT NULL, `quickStartAvailable` INTEGER NOT NULL, `pendingRegistration` INTEGER NOT NULL, `loggedInSuccessfully` INTEGER NOT NULL, `showErrorNotification` INTEGER NOT NULL, `rosterVersion` TEXT, `hostname` TEXT, `port` INTEGER, `directTls` INTEGER, `proxytype` TEXT, `proxyhostname` TEXT, `proxyport` INTEGER)", "fields": [ { "fieldPath": "id", @@ -38,6 +38,30 @@ "affinity": "INTEGER", "notNull": true }, + { + "fieldPath": "quickStartAvailable", + "columnName": "quickStartAvailable", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "pendingRegistration", + "columnName": "pendingRegistration", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "loggedInSuccessfully", + "columnName": "loggedInSuccessfully", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "showErrorNotification", + "columnName": "showErrorNotification", + "affinity": "INTEGER", + "notNull": true + }, { "fieldPath": "rosterVersion", "columnName": "rosterVersion", @@ -830,7 +854,7 @@ }, { "tableName": "presence", - "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT, `accountId` INTEGER NOT NULL, `address` TEXT NOT NULL, `resource` TEXT, `type` TEXT, `show` TEXT, `status` TEXT, `vCardPhoto` TEXT, `occupantId` TEXT, `mucUserAffiliation` TEXT, `mucUserRole` TEXT, `mucUserJid` TEXT, FOREIGN KEY(`accountId`) REFERENCES `account`(`id`) ON UPDATE NO ACTION ON DELETE CASCADE )", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT, `accountId` INTEGER NOT NULL, `address` TEXT NOT NULL, `resource` TEXT, `type` TEXT, `show` TEXT, `status` TEXT, `vCardPhoto` TEXT, `occupantId` TEXT, `mucUserAffiliation` TEXT, `mucUserRole` TEXT, `mucUserJid` TEXT, `mucUserSelf` INTEGER NOT NULL, FOREIGN KEY(`accountId`) REFERENCES `account`(`id`) ON UPDATE NO ACTION ON DELETE CASCADE )", "fields": [ { "fieldPath": "id", @@ -903,6 +927,12 @@ "columnName": "mucUserJid", "affinity": "TEXT", "notNull": false + }, + { + "fieldPath": "mucUserSelf", + "columnName": "mucUserSelf", + "affinity": "INTEGER", + "notNull": true } ], "primaryKey": { @@ -1159,7 +1189,7 @@ "views": [], "setupQueries": [ "CREATE TABLE IF NOT EXISTS room_master_table (id INTEGER PRIMARY KEY,identity_hash TEXT)", - "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, 'c78cb993428558b863fd91c46b608926')" + "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, '4a70ff0733436f5a2a08e7abb8e6cc95')" ] } } \ No newline at end of file diff --git a/src/main/AndroidManifest.xml b/src/main/AndroidManifest.xml index c18addf27..837ebd3be 100644 --- a/src/main/AndroidManifest.xml +++ b/src/main/AndroidManifest.xml @@ -74,6 +74,7 @@ - * The certificate validity is checked using the system default X509 - * TrustManager, creating a query Dialog if the check fails. - *

- * WARNING: This only works if a dedicated thread is used for - * opening sockets! + * A X509 trust manager implementation which asks the user about invalid certificates and memorizes + * their decision. + * + *

The certificate validity is checked using the system default X509 TrustManager, creating a + * query Dialog if the check fails. + * + *

WARNING: This only works if a dedicated thread is used for opening sockets! */ public class MemorizingTrustManager { - private static final SimpleDateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd", Locale.US); + private static final SimpleDateFormat DATE_FORMAT = + new SimpleDateFormat("yyyy-MM-dd", Locale.US); - final static String DECISION_INTENT = "de.duenndns.ssl.DECISION"; - public final static String DECISION_INTENT_ID = DECISION_INTENT + ".decisionId"; - public final static String DECISION_INTENT_CERT = DECISION_INTENT + ".cert"; - public final static String DECISION_TITLE_ID = DECISION_INTENT + ".titleId"; - final static String NO_TRUST_ANCHOR = "Trust anchor for certification path not found."; - private static final Pattern PATTERN_IPV4 = Pattern.compile("\\A(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)(\\.(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)){3}\\z"); - private static final Pattern PATTERN_IPV6_HEX4DECCOMPRESSED = Pattern.compile("\\A((?:[0-9A-Fa-f]{1,4}(?::[0-9A-Fa-f]{1,4})*)?) ::((?:[0-9A-Fa-f]{1,4}:)*)(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)(\\.(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)){3}\\z"); - private static final Pattern PATTERN_IPV6_6HEX4DEC = Pattern.compile("\\A((?:[0-9A-Fa-f]{1,4}:){6,6})(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)(\\.(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)){3}\\z"); - private static final Pattern PATTERN_IPV6_HEXCOMPRESSED = Pattern.compile("\\A((?:[0-9A-Fa-f]{1,4}(?::[0-9A-Fa-f]{1,4})*)?)::((?:[0-9A-Fa-f]{1,4}(?::[0-9A-Fa-f]{1,4})*)?)\\z"); - private static final Pattern PATTERN_IPV6 = Pattern.compile("\\A(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\\z"); - private final static Logger LOGGER = Logger.getLogger(MemorizingTrustManager.class.getName()); + static final String DECISION_INTENT = "de.duenndns.ssl.DECISION"; + public static final String DECISION_INTENT_ID = DECISION_INTENT + ".decisionId"; + public static final String DECISION_INTENT_CERT = DECISION_INTENT + ".cert"; + public static final String DECISION_TITLE_ID = DECISION_INTENT + ".titleId"; + static final String NO_TRUST_ANCHOR = "Trust anchor for certification path not found."; + private static final Pattern PATTERN_IPV4 = + Pattern.compile( + "\\A(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)(\\.(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)){3}\\z"); + private static final Pattern PATTERN_IPV6_HEX4DECCOMPRESSED = + Pattern.compile( + "\\A((?:[0-9A-Fa-f]{1,4}(?::[0-9A-Fa-f]{1,4})*)?)" + + " ::((?:[0-9A-Fa-f]{1,4}:)*)(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)(\\.(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)){3}\\z"); + private static final Pattern PATTERN_IPV6_6HEX4DEC = + Pattern.compile( + "\\A((?:[0-9A-Fa-f]{1,4}:){6,6})(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)(\\.(25[0-5]|2[0-4]\\d|[0-1]?\\d?\\d)){3}\\z"); + private static final Pattern PATTERN_IPV6_HEXCOMPRESSED = + Pattern.compile( + "\\A((?:[0-9A-Fa-f]{1,4}(?::[0-9A-Fa-f]{1,4})*)?)::((?:[0-9A-Fa-f]{1,4}(?::[0-9A-Fa-f]{1,4})*)?)\\z"); + private static final Pattern PATTERN_IPV6 = + Pattern.compile("\\A(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\\z"); + private static final Logger LOGGER = Logger.getLogger(MemorizingTrustManager.class.getName()); static String KEYSTORE_DIR = "KeyStore"; static String KEYSTORE_FILE = "KeyStore.bks"; private static int decisionId = 0; @@ -125,19 +129,32 @@ public class MemorizingTrustManager { private X509TrustManager appTrustManager; private String poshCacheDir; + public static MemorizingTrustManager create(final Context context) { + final SharedPreferences preferences = + PreferenceManager.getDefaultSharedPreferences(context.getApplicationContext()); + final boolean dontTrustSystemCAs = + preferences.getBoolean( + "dont_trust_system_cas", + context.getResources().getBoolean(R.bool.dont_trust_system_cas)); + if (dontTrustSystemCAs) { + return new MemorizingTrustManager(context.getApplicationContext(), null); + } else { + return new MemorizingTrustManager(context.getApplicationContext()); + } + } /** - * Creates an instance of the MemorizingTrustManager class that falls back to a custom TrustManager. - *

- * You need to supply the application context. This has to be one of: - * - Application - * - Activity - * - Service - *

- * The context is used for file management, to display the dialog / - * notification and for obtaining translated strings. + * Creates an instance of the MemorizingTrustManager class that falls back to a custom + * TrustManager. * - * @param m Context for the application. - * @param defaultTrustManager Delegate trust management to this TM. If null, the user must accept every certificate. + *

You need to supply the application context. This has to be one of: - Application - + * Activity - Service + * + *

The context is used for file management, to display the dialog / notification and for + * obtaining translated strings. + * + * @param m Context for the application. + * @param defaultTrustManager Delegate trust management to this TM. If null, the user must + * accept every certificate. */ public MemorizingTrustManager(Context m, X509TrustManager defaultTrustManager) { init(m); @@ -147,14 +164,12 @@ public class MemorizingTrustManager { /** * Creates an instance of the MemorizingTrustManager class using the system X509TrustManager. - *

- * You need to supply the application context. This has to be one of: - * - Application - * - Activity - * - Service - *

- * The context is used for file management, to display the dialog / - * notification and for obtaining translated strings. + * + *

You need to supply the application context. This has to be one of: - Application - + * Activity - Service + * + *

The context is used for file management, to display the dialog / notification and for + * obtaining translated strings. * * @param m Context for the application. */ @@ -165,15 +180,16 @@ public class MemorizingTrustManager { } private static boolean isIp(final String server) { - return server != null && ( - PATTERN_IPV4.matcher(server).matches() + return server != null + && (PATTERN_IPV4.matcher(server).matches() || PATTERN_IPV6.matcher(server).matches() || PATTERN_IPV6_6HEX4DEC.matcher(server).matches() || PATTERN_IPV6_HEX4DECCOMPRESSED.matcher(server).matches() || PATTERN_IPV6_HEXCOMPRESSED.matcher(server).matches()); } - private static String getBase64Hash(X509Certificate certificate, String digest) throws CertificateEncodingException { + private static String getBase64Hash(X509Certificate certificate, String digest) + throws CertificateEncodingException { MessageDigest md; try { md = MessageDigest.getInstance(digest); @@ -188,8 +204,7 @@ public class MemorizingTrustManager { StringBuffer si = new StringBuffer(); for (int i = 0; i < data.length; i++) { si.append(String.format("%02x", data[i])); - if (i < data.length - 1) - si.append(":"); + if (i < data.length - 1) si.append(":"); } return si.toString(); } @@ -223,7 +238,8 @@ public class MemorizingTrustManager { void init(final Context m) { master = m; masterHandler = new Handler(m.getMainLooper()); - notificationManager = (NotificationManager) master.getSystemService(Context.NOTIFICATION_SERVICE); + notificationManager = + (NotificationManager) master.getSystemService(Context.NOTIFICATION_SERVICE); Application app; if (m instanceof Application) { @@ -233,7 +249,8 @@ public class MemorizingTrustManager { } else if (m instanceof AppCompatActivity) { app = ((AppCompatActivity) m).getApplication(); } else - throw new ClassCastException("MemorizingTrustManager context must be either Activity or Service!"); + throw new ClassCastException( + "MemorizingTrustManager context must be either Activity or Service!"); File dir = app.getDir(KEYSTORE_DIR, Context.MODE_PRIVATE); keyStoreFile = new File(dir + File.separator + KEYSTORE_FILE); @@ -260,12 +277,9 @@ public class MemorizingTrustManager { /** * Removes the given certificate from MTMs key store. * - *

- * WARNING: this does not immediately invalidate the certificate. It is - * well possible that (a) data is transmitted over still existing connections or - * (b) new connections are created using TLS renegotiation, without a new cert - * check. - *

+ *

WARNING: this does not immediately invalidate the certificate. It is well possible + * that (a) data is transmitted over still existing connections or (b) new connections are + * created using TLS renegotiation, without a new cert check. * * @param alias the certificate's alias as returned by {@link #getCertificates()}. * @throws KeyStoreException if the certificate could not be deleted. @@ -361,45 +375,60 @@ public class MemorizingTrustManager { } } - - private void checkCertTrusted(X509Certificate[] chain, String authType, String domain, boolean isServer, boolean interactive) + private void checkCertTrusted( + X509Certificate[] chain, + String authType, + String domain, + boolean isServer, + boolean interactive) throws CertificateException { - LOGGER.log(Level.FINE, "checkCertTrusted(" + chain + ", " + authType + ", " + isServer + ")"); + LOGGER.log( + Level.FINE, "checkCertTrusted(" + chain + ", " + authType + ", " + isServer + ")"); try { LOGGER.log(Level.FINE, "checkCertTrusted: trying appTrustManager"); - if (isServer) - appTrustManager.checkServerTrusted(chain, authType); - else - appTrustManager.checkClientTrusted(chain, authType); + if (isServer) appTrustManager.checkServerTrusted(chain, authType); + else appTrustManager.checkClientTrusted(chain, authType); } catch (final CertificateException ae) { LOGGER.log(Level.FINER, "checkCertTrusted: appTrustManager failed", ae); if (isCertKnown(chain[0])) { - LOGGER.log(Level.INFO, "checkCertTrusted: accepting cert already stored in keystore"); + LOGGER.log( + Level.INFO, "checkCertTrusted: accepting cert already stored in keystore"); return; } try { - if (defaultTrustManager == null) - throw ae; + if (defaultTrustManager == null) throw ae; LOGGER.log(Level.FINE, "checkCertTrusted: trying defaultTrustManager"); - if (isServer) - defaultTrustManager.checkServerTrusted(chain, authType); - else - defaultTrustManager.checkClientTrusted(chain, authType); + if (isServer) defaultTrustManager.checkServerTrusted(chain, authType); + else defaultTrustManager.checkClientTrusted(chain, authType); } catch (final CertificateException e) { - final SharedPreferences preferences = PreferenceManager.getDefaultSharedPreferences(master); - final boolean trustSystemCAs = !preferences.getBoolean("dont_trust_system_cas", false); - if (domain != null && isServer && trustSystemCAs && !isIp(domain) && !domain.endsWith(".onion")) { + final SharedPreferences preferences = + PreferenceManager.getDefaultSharedPreferences(master); + final boolean trustSystemCAs = + !preferences.getBoolean("dont_trust_system_cas", false); + if (domain != null + && isServer + && trustSystemCAs + && !isIp(domain) + && !domain.endsWith(".onion")) { final String hash = getBase64Hash(chain[0], "SHA-256"); final List fingerprints = getPoshFingerprints(domain); if (hash != null && fingerprints.size() > 0) { if (fingerprints.contains(hash)) { - Log.d(Config.LOGTAG, "trusted cert fingerprint of " + domain + " via posh"); + Log.d( + Config.LOGTAG, + "trusted cert fingerprint of " + domain + " via posh"); return; } else { - Log.d(Config.LOGTAG, "fingerprint " + hash + " not found in " + fingerprints); + Log.d( + Config.LOGTAG, + "fingerprint " + hash + " not found in " + fingerprints); } if (getPoshCacheFile(domain).delete()) { - Log.d(Config.LOGTAG, "deleted posh file for " + domain + " after not being able to verify"); + Log.d( + Config.LOGTAG, + "deleted posh file for " + + domain + + " after not being able to verify"); } } } @@ -422,17 +451,25 @@ public class MemorizingTrustManager { } private List getPoshFingerprintsFromServer(String domain) { - return getPoshFingerprintsFromServer(domain, "https://" + domain + "/.well-known/posh/xmpp-client.json", -1, true); + return getPoshFingerprintsFromServer( + domain, "https://" + domain + "/.well-known/posh/xmpp-client.json", -1, true); } - private List getPoshFingerprintsFromServer(String domain, String url, int maxTtl, boolean followUrl) { + private List getPoshFingerprintsFromServer( + String domain, String url, int maxTtl, boolean followUrl) { Log.d(Config.LOGTAG, "downloading json for " + domain + " from " + url); final SharedPreferences preferences = PreferenceManager.getDefaultSharedPreferences(master); - final boolean useTor = QuickConversationsService.isConversations() && preferences.getBoolean("use_tor", master.getResources().getBoolean(R.bool.use_tor)); + final boolean useTor = + QuickConversationsService.isConversations() + && preferences.getBoolean( + "use_tor", master.getResources().getBoolean(R.bool.use_tor)); try { final List results = new ArrayList<>(); final InputStream inputStream = HttpConnectionManager.open(url, useTor); - final String body = CharStreams.toString(new InputStreamReader(ByteStreams.limit(inputStream,10_000), Charsets.UTF_8)); + final String body = + CharStreams.toString( + new InputStreamReader( + ByteStreams.limit(inputStream, 10_000), Charsets.UTF_8)); final JSONObject jsonObject = new JSONObject(body); int expires = jsonObject.getInt("expires"); if (expires <= 0) { @@ -459,7 +496,7 @@ public class MemorizingTrustManager { writeFingerprintsToCache(domain, results, 1000L * expires + System.currentTimeMillis()); return results; } catch (final Exception e) { - Log.d(Config.LOGTAG, "error fetching posh",e); + Log.d(Config.LOGTAG, "error fetching posh", e); return new ArrayList<>(); } } @@ -489,7 +526,8 @@ public class MemorizingTrustManager { final File file = getPoshCacheFile(domain); try { final InputStream inputStream = new FileInputStream(file); - final String json = CharStreams.toString(new InputStreamReader(inputStream, Charsets.UTF_8)); + final String json = + CharStreams.toString(new InputStreamReader(inputStream, Charsets.UTF_8)); final JSONObject jsonObject = new JSONObject(json); long expires = jsonObject.getLong("expires"); long expiresIn = expires - System.currentTimeMillis(); @@ -514,7 +552,9 @@ public class MemorizingTrustManager { } private X509Certificate[] getAcceptedIssuers() { - return defaultTrustManager == null ? new X509Certificate[0] : defaultTrustManager.getAcceptedIssuers(); + return defaultTrustManager == null + ? new X509Certificate[0] + : defaultTrustManager.getAcceptedIssuers(); } private int createDecisionId(MTMDecision d) { @@ -527,7 +567,8 @@ public class MemorizingTrustManager { return myId; } - private void certDetails(final StringBuffer si, final X509Certificate c, final boolean showValidFor) { + private void certDetails( + final StringBuffer si, final X509Certificate c, final boolean showValidFor) { si.append("\n"); if (showValidFor) { @@ -564,8 +605,7 @@ public class MemorizingTrustManager { // not found", so we use string comparison. if (NO_TRUST_ANCHOR.equals(e.getMessage())) { si.append(master.getString(R.string.mtm_trust_anchor)); - } else - si.append(e.getLocalizedMessage()); + } else si.append(e.getLocalizedMessage()); si.append("\n"); } si.append("\n"); @@ -573,7 +613,7 @@ public class MemorizingTrustManager { si.append("\n\n"); si.append(master.getString(R.string.mtm_cert_details)); si.append('\n'); - for(int i = 0; i < chain.length; ++i) { + for (int i = 0; i < chain.length; ++i) { certDetails(si, chain[i], i == 0); } return si.toString(); @@ -593,24 +633,25 @@ public class MemorizingTrustManager { MTMDecision choice = new MTMDecision(); final int myId = createDecisionId(choice); - masterHandler.post(new Runnable() { - public void run() { - Intent ni = new Intent(master, MemorizingActivity.class); - ni.setFlags(Intent.FLAG_ACTIVITY_NEW_TASK); - ni.setData(Uri.parse(MemorizingTrustManager.class.getName() + "/" + myId)); - ni.putExtra(DECISION_INTENT_ID, myId); - ni.putExtra(DECISION_INTENT_CERT, message); - ni.putExtra(DECISION_TITLE_ID, titleId); + masterHandler.post( + new Runnable() { + public void run() { + Intent ni = new Intent(master, MemorizingActivity.class); + ni.setFlags(Intent.FLAG_ACTIVITY_NEW_TASK); + ni.setData(Uri.parse(MemorizingTrustManager.class.getName() + "/" + myId)); + ni.putExtra(DECISION_INTENT_ID, myId); + ni.putExtra(DECISION_INTENT_CERT, message); + ni.putExtra(DECISION_TITLE_ID, titleId); - // we try to directly start the activity and fall back to - // making a notification - try { - getUI().startActivity(ni); - } catch (Exception e) { - LOGGER.log(Level.FINE, "startActivity(MemorizingActivity)", e); - } - } - }); + // we try to directly start the activity and fall back to + // making a notification + try { + getUI().startActivity(ni); + } catch (Exception e) { + LOGGER.log(Level.FINE, "startActivity(MemorizingActivity)", e); + } + } + }); LOGGER.log(Level.FINE, "openDecisions: " + openDecisions + ", waiting on " + myId); try { @@ -661,7 +702,8 @@ public class MemorizingTrustManager { } @Override - public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { MemorizingTrustManager.this.checkCertTrusted(chain, authType, domain, false, false); } @@ -675,7 +717,6 @@ public class MemorizingTrustManager { public X509Certificate[] getAcceptedIssuers() { return MemorizingTrustManager.this.getAcceptedIssuers(); } - } private class InteractiveMemorizingTrustManager implements X509TrustManager { @@ -686,7 +727,8 @@ public class MemorizingTrustManager { } @Override - public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { MemorizingTrustManager.this.checkCertTrusted(chain, authType, domain, false, true); } diff --git a/src/main/java/eu/siacs/conversations/utils/SSLSockets.java b/src/main/java/eu/siacs/conversations/utils/SSLSockets.java index ae853bea8..e5f0c52e8 100644 --- a/src/main/java/eu/siacs/conversations/utils/SSLSockets.java +++ b/src/main/java/eu/siacs/conversations/utils/SSLSockets.java @@ -2,13 +2,11 @@ package eu.siacs.conversations.utils; import android.os.Build; import android.util.Log; - import androidx.annotation.RequiresApi; - import com.google.common.base.Strings; - -import org.conscrypt.Conscrypt; - +import eu.siacs.conversations.Config; +import eu.siacs.conversations.entities.Account; +import eu.siacs.conversations.xmpp.Jid; import java.lang.reflect.Method; import java.net.Socket; import java.nio.charset.StandardCharsets; @@ -17,29 +15,26 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.LinkedList; - import javax.net.ssl.SNIHostName; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocket; - -import eu.siacs.conversations.Config; -import eu.siacs.conversations.entities.Account; +import org.conscrypt.Conscrypt; public class SSLSockets { public static void setSecurity(final SSLSocket sslSocket) { final String[] supportProtocols; - final Collection supportedProtocols = new LinkedList<>( - Arrays.asList(sslSocket.getSupportedProtocols())); + final Collection supportedProtocols = + new LinkedList<>(Arrays.asList(sslSocket.getSupportedProtocols())); supportedProtocols.remove("SSLv3"); supportProtocols = supportedProtocols.toArray(new String[0]); sslSocket.setEnabledProtocols(supportProtocols); - final String[] cipherSuites = CryptoHelper.getOrderedCipherSuites( - sslSocket.getSupportedCipherSuites()); + final String[] cipherSuites = + CryptoHelper.getOrderedCipherSuites(sslSocket.getSupportedCipherSuites()); if (cipherSuites.length > 0) { sslSocket.setEnabledCipherSuites(cipherSuites); } @@ -70,7 +65,8 @@ public class SSLSockets { socket.setSSLParameters(parameters); } - private static void setApplicationProtocolReflection(final SSLSocket socket, final String protocol) { + private static void setApplicationProtocolReflection( + final SSLSocket socket, final String protocol) { try { final Method method = socket.getClass().getMethod("setAlpnProtocols", byte[].class); // the concatenation of 8-bit, length prefixed protocol names, just one in our case... @@ -78,16 +74,17 @@ public class SSLSockets { final byte[] protocolUTF8Bytes = protocol.getBytes(StandardCharsets.UTF_8); final byte[] lengthPrefixedProtocols = new byte[protocolUTF8Bytes.length + 1]; lengthPrefixedProtocols[0] = (byte) protocol.length(); // cannot be over 255 anyhow - System.arraycopy(protocolUTF8Bytes, 0, lengthPrefixedProtocols, 1, protocolUTF8Bytes.length); - method.invoke(socket, new Object[]{lengthPrefixedProtocols}); + System.arraycopy( + protocolUTF8Bytes, 0, lengthPrefixedProtocols, 1, protocolUTF8Bytes.length); + method.invoke(socket, new Object[] {lengthPrefixedProtocols}); } catch (Throwable e) { - Log.e(Config.LOGTAG,"unable to set ALPN on socket",e); + Log.e(Config.LOGTAG, "unable to set ALPN on socket", e); } } public static void setApplicationProtocol(final SSLSocket socket, final String protocol) { if (Conscrypt.isConscrypt(socket)) { - Conscrypt.setApplicationProtocols(socket, new String[]{protocol}); + Conscrypt.setApplicationProtocols(socket, new String[] {protocol}); } else { setApplicationProtocolReflection(socket, protocol); } @@ -101,11 +98,15 @@ public class SSLSockets { } } - public static void log(Account account, SSLSocket socket) { + public static void log(final Account account, SSLSocket socket) { + log(account.getJid(), socket); + } + + public static void log(final Jid address, SSLSocket socket) { SSLSession session = socket.getSession(); Log.d( Config.LOGTAG, - account.getJid().asBareJid() + address + ": protocol=" + session.getProtocol() + " cipher=" diff --git a/src/main/java/eu/siacs/conversations/xml/Tag.java b/src/main/java/eu/siacs/conversations/xml/Tag.java index db2b11172..ce90800c8 100644 --- a/src/main/java/eu/siacs/conversations/xml/Tag.java +++ b/src/main/java/eu/siacs/conversations/xml/Tag.java @@ -1,12 +1,11 @@ package eu.siacs.conversations.xml; -import org.jetbrains.annotations.NotNull; - +import eu.siacs.conversations.utils.XmlHelper; +import eu.siacs.conversations.xmpp.Jid; import java.util.Hashtable; import java.util.Map.Entry; import java.util.Set; - -import eu.siacs.conversations.utils.XmlHelper; +import org.jetbrains.annotations.NotNull; public class Tag { public static final int NO = -1; @@ -52,6 +51,13 @@ public class Tag { return this; } + public Tag setAttribute(final String attrName, final Jid attrValue) { + if (attrValue != null) { + this.attributes.put(attrName, attrValue.toEscapedString()); + } + return this; + } + public void setAttributes(final Hashtable attributes) { this.attributes = attributes; } diff --git a/src/main/java/im/conversations/android/Conversations.java b/src/main/java/im/conversations/android/Conversations.java new file mode 100644 index 000000000..488134ba6 --- /dev/null +++ b/src/main/java/im/conversations/android/Conversations.java @@ -0,0 +1,13 @@ +package im.conversations.android; + +import android.app.Application; +import im.conversations.android.xmpp.ConnectionPool; + +public class Conversations extends Application { + + @Override + public void onCreate() { + super.onCreate(); + ConnectionPool.getInstance(this).reconfigure(); + } +} diff --git a/src/main/java/im/conversations/android/Uuids.java b/src/main/java/im/conversations/android/Uuids.java new file mode 100644 index 000000000..4cf3ff29c --- /dev/null +++ b/src/main/java/im/conversations/android/Uuids.java @@ -0,0 +1,38 @@ +package im.conversations.android; + +import com.google.common.base.Preconditions; +import java.util.UUID; + +public class Uuids { + + private static final long VERSION_MASK = 4 << 12; + + public static UUID getUuid(final byte[] bytes) { + Preconditions.checkArgument(bytes != null && bytes.length == 32); + + long msb = 0; + long lsb = 0; + + msb |= (bytes[0x0] & 0xffL) << 56; + msb |= (bytes[0x1] & 0xffL) << 48; + msb |= (bytes[0x2] & 0xffL) << 40; + msb |= (bytes[0x3] & 0xffL) << 32; + msb |= (bytes[0x4] & 0xffL) << 24; + msb |= (bytes[0x5] & 0xffL) << 16; + msb |= (bytes[0x6] & 0xffL) << 8; + msb |= (bytes[0x7] & 0xffL); + + lsb |= (bytes[0x8] & 0xffL) << 56; + lsb |= (bytes[0x9] & 0xffL) << 48; + lsb |= (bytes[0xa] & 0xffL) << 40; + lsb |= (bytes[0xb] & 0xffL) << 32; + lsb |= (bytes[0xc] & 0xffL) << 24; + lsb |= (bytes[0xd] & 0xffL) << 16; + lsb |= (bytes[0xe] & 0xffL) << 8; + lsb |= (bytes[0xf] & 0xffL); + + msb = (msb & 0xffffffffffff0fffL) | VERSION_MASK; // set version + lsb = (lsb & 0x3fffffffffffffffL) | 0x8000000000000000L; // set variant + return new UUID(msb, lsb); + } +} diff --git a/src/main/java/im/conversations/android/database/ConversationsDatabase.java b/src/main/java/im/conversations/android/database/ConversationsDatabase.java index 9a7b2600f..9547e332b 100644 --- a/src/main/java/im/conversations/android/database/ConversationsDatabase.java +++ b/src/main/java/im/conversations/android/database/ConversationsDatabase.java @@ -5,6 +5,8 @@ import androidx.room.Database; import androidx.room.Room; import androidx.room.RoomDatabase; import androidx.room.TypeConverters; +import im.conversations.android.database.dao.AccountDao; +import im.conversations.android.database.dao.PresenceDao; import im.conversations.android.database.entity.AccountEntity; import im.conversations.android.database.entity.BlockedItemEntity; import im.conversations.android.database.entity.ChatEntity; @@ -62,4 +64,8 @@ public abstract class ConversationsDatabase extends RoomDatabase { return INSTANCE; } } + + public abstract AccountDao accountDao(); + + public abstract PresenceDao presenceDao(); } diff --git a/src/main/java/im/conversations/android/database/CredentialStore.java b/src/main/java/im/conversations/android/database/CredentialStore.java new file mode 100644 index 000000000..c0678c4af --- /dev/null +++ b/src/main/java/im/conversations/android/database/CredentialStore.java @@ -0,0 +1,206 @@ +package im.conversations.android.database; + +import android.content.Context; +import android.security.keystore.KeyGenParameterSpec; +import androidx.annotation.NonNull; +import androidx.security.crypto.EncryptedFile; +import androidx.security.crypto.MasterKeys; +import com.google.common.collect.ImmutableMap; +import com.google.common.reflect.TypeToken; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; +import im.conversations.android.xmpp.sasl.ChannelBindingMechanism; +import im.conversations.android.xmpp.sasl.HashedToken; +import im.conversations.android.xmpp.sasl.SaslMechanism; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.lang.reflect.Type; +import java.security.GeneralSecurityException; +import java.util.HashMap; +import java.util.Map; + +// TODO cache credentials?! +public class CredentialStore { + + private static final String FILENAME = "credential.store"; + + private static final Gson GSON = new GsonBuilder().create(); + + private static volatile CredentialStore INSTANCE; + + private final Context context; + + private CredentialStore(final Context context) { + this.context = context.getApplicationContext(); + } + + public static CredentialStore getInstance(final Context context) { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (CredentialStore.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new CredentialStore(context); + return INSTANCE; + } + } + + public synchronized Credential get(final Account account) { + return getOrEmpty(account); + } + + public void setPassword(final Account account, final String password) + throws GeneralSecurityException, IOException { + setPassword(account, password, false); + } + + public synchronized void setPassword( + final Account account, final String password, final boolean autogeneratedPassword) + throws GeneralSecurityException, IOException { + final Credential credential = getOrEmpty(account); + final Credential modifiedCredential = + new Credential( + password, + autogeneratedPassword, + credential.pinnedMechanism, + credential.pinnedChannelBinding, + credential.fastMechanism, + credential.fastToken, + credential.preAuthRegistrationToken, + credential.privateKeyAlias); + // TODO ignore if unchanged + this.set(account, modifiedCredential); + } + + public void setFastToken( + final Account account, final HashedToken.Mechanism mechanism, final String token) + throws GeneralSecurityException, IOException { + final Credential credential = getOrEmpty(account); + final Credential modifiedCredential = + new Credential( + credential.password, + credential.autogeneratedPassword, + credential.pinnedMechanism, + credential.pinnedChannelBinding, + mechanism.name(), + token, + credential.preAuthRegistrationToken, + credential.privateKeyAlias); + // TODO ignore if unchanged + this.set(account, modifiedCredential); + } + + public void resetFastToken(final Account account) throws GeneralSecurityException, IOException { + final Credential credential = getOrEmpty(account); + final Credential modifiedCredential = + new Credential( + credential.password, + credential.autogeneratedPassword, + credential.pinnedMechanism, + credential.pinnedChannelBinding, + null, + null, + credential.preAuthRegistrationToken, + credential.privateKeyAlias); + // TODO ignore if unchanged + this.set(account, modifiedCredential); + } + + public void setPinnedMechanism(final Account account, final SaslMechanism mechanism) + throws GeneralSecurityException, IOException { + final String pinnedMechanism = mechanism.getMechanism(); + final String pinnedChannelBinding; + if (mechanism instanceof ChannelBindingMechanism) { + pinnedChannelBinding = + ((ChannelBindingMechanism) mechanism).getChannelBinding().toString(); + } else { + pinnedChannelBinding = null; + } + final Credential credential = getOrEmpty(account); + final Credential modifiedCredential = + new Credential( + credential.password, + credential.autogeneratedPassword, + pinnedMechanism, + pinnedChannelBinding, + credential.fastMechanism, + credential.fastToken, + credential.preAuthRegistrationToken, + credential.privateKeyAlias); + // TODO ignore if unchanged + this.set(account, modifiedCredential); + } + + public void resetPinnedMechanism(final Account account) + throws GeneralSecurityException, IOException { + final Credential credential = getOrEmpty(account); + final Credential modifiedCredential = + new Credential( + credential.password, + credential.autogeneratedPassword, + null, + null, + credential.fastMechanism, + credential.fastToken, + credential.preAuthRegistrationToken, + credential.privateKeyAlias); + // TODO ignore if unchanged + this.set(account, modifiedCredential); + } + + private Credential getOrEmpty(final Account account) { + final Map store = loadOrEmpty(); + final Credential credential = store.get(account.address.toEscapedString()); + return credential == null ? Credential.empty() : credential; + } + + private void set(@NonNull final Account account, @NonNull final Credential credential) + throws GeneralSecurityException, IOException { + final HashMap credentialStore = new HashMap<>(loadOrEmpty()); + credentialStore.put(account.address.toEscapedString(), credential); + store(credentialStore); + } + + private Map loadOrEmpty() { + final Map store; + try { + store = load(); + } catch (final GeneralSecurityException | IOException e) { + return ImmutableMap.of(); + } + return store == null ? ImmutableMap.of() : store; + } + + private Map load() throws GeneralSecurityException, IOException { + final EncryptedFile encryptedFile = getEncryptedFile(); + final FileInputStream inputStream = encryptedFile.openFileInput(); + final Type type = new TypeToken>() {}.getType(); + return GSON.fromJson(new InputStreamReader(inputStream), type); + } + + private void store(final Map store) + throws GeneralSecurityException, IOException { + final EncryptedFile encryptedFile = getEncryptedFile(); + final FileOutputStream outputStream = encryptedFile.openFileOutput(); + GSON.toJson(store, new OutputStreamWriter(outputStream)); + } + + private EncryptedFile getEncryptedFile() throws GeneralSecurityException, IOException { + final KeyGenParameterSpec keyGenParameterSpec = MasterKeys.AES256_GCM_SPEC; + final String mainKeyAlias = MasterKeys.getOrCreate(keyGenParameterSpec); + return new EncryptedFile.Builder( + new File(context.getFilesDir(), FILENAME), + context, + mainKeyAlias, + EncryptedFile.FileEncryptionScheme.AES256_GCM_HKDF_4KB) + .build(); + } +} diff --git a/src/main/java/im/conversations/android/database/dao/AccountDao.java b/src/main/java/im/conversations/android/database/dao/AccountDao.java new file mode 100644 index 000000000..094a47f26 --- /dev/null +++ b/src/main/java/im/conversations/android/database/dao/AccountDao.java @@ -0,0 +1,64 @@ +package im.conversations.android.database.dao; + +import androidx.room.Dao; +import androidx.room.Insert; +import androidx.room.OnConflictStrategy; +import androidx.room.Query; +import com.google.common.util.concurrent.ListenableFuture; +import im.conversations.android.database.entity.AccountEntity; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Connection; +import java.util.List; + +@Dao +public interface AccountDao { + + @Insert(onConflict = OnConflictStrategy.ABORT) + void insert(final AccountEntity account); + + @Query("SELECT id,address,randomSeed FROM account WHERE enabled = 1") + ListenableFuture> getEnabledAccounts(); + + @Query("SELECT hostname,port,directTls FROM account WHERE id=:id AND hostname != null") + Connection getConnectionSettings(long id); + + @Query("SELECT resource FROM account WHERE id=:id") + String getResource(long id); + + @Query("SELECT rosterVersion FROM account WHERE id=:id") + String getRosterVersion(long id); + + @Query("SELECT quickStartAvailable FROM account where id=:id") + boolean quickStartAvailable(long id); + + @Query("SELECT pendingRegistration FROM account where id=:id") + boolean pendingRegistration(long id); + + @Query("SELECT loggedInSuccessfully == 0 FROM account where id=:id") + boolean isInitialLogin(long id); + + @Query( + "UPDATE account set quickStartAvailable=:available WHERE id=:id AND" + + " quickStartAvailable != :available") + void setQuickStartAvailable(long id, boolean available); + + @Query( + "UPDATE account set pendingRegistration=:pendingRegistration WHERE id=:id AND" + + " pendingRegistration != :pendingRegistration") + void setPendingRegistration(long id, boolean pendingRegistration); + + @Query( + "UPDATE account set loggedInSuccessfully=:loggedInSuccessfully WHERE id=:id AND" + + " loggedInSuccessfully != :loggedInSuccessfully") + int setLoggedInSuccessfully(long id, boolean loggedInSuccessfully); + + @Query( + "UPDATE account set showErrorNotification=:showErrorNotification WHERE id=:id AND" + + " showErrorNotification != :showErrorNotification") + int setShowErrorNotification(long id, boolean showErrorNotification); + + @Query("UPDATE account set resource=:resource WHERE id=:id") + void setResource(long id, String resource); + + // TODO on disable set resource to null +} diff --git a/src/main/java/im/conversations/android/database/dao/PresenceDao.java b/src/main/java/im/conversations/android/database/dao/PresenceDao.java new file mode 100644 index 000000000..c5cf93945 --- /dev/null +++ b/src/main/java/im/conversations/android/database/dao/PresenceDao.java @@ -0,0 +1,11 @@ +package im.conversations.android.database.dao; + +import androidx.room.Dao; +import androidx.room.Query; + +@Dao +public interface PresenceDao { + + @Query("DELETE FROM presence WHERE accountId=:account") + void deletePresences(long account); +} diff --git a/src/main/java/im/conversations/android/database/entity/AccountEntity.java b/src/main/java/im/conversations/android/database/entity/AccountEntity.java index aef2c1b5c..0942623c4 100644 --- a/src/main/java/im/conversations/android/database/entity/AccountEntity.java +++ b/src/main/java/im/conversations/android/database/entity/AccountEntity.java @@ -26,6 +26,15 @@ public class AccountEntity { public boolean enabled; + public boolean quickStartAvailable = false; + public boolean pendingRegistration = false; + + // TODO this is only used during setup; depending on how the setup procedure will look in the + // future we might get rid of this property + public boolean loggedInSuccessfully = false; + + public boolean showErrorNotification = true; + public String rosterVersion; @Embedded public Connection connection; diff --git a/src/main/java/im/conversations/android/database/entity/PresenceEntity.java b/src/main/java/im/conversations/android/database/entity/PresenceEntity.java index 102cf39a6..dffb9ca63 100644 --- a/src/main/java/im/conversations/android/database/entity/PresenceEntity.java +++ b/src/main/java/im/conversations/android/database/entity/PresenceEntity.java @@ -50,4 +50,7 @@ public class PresenceEntity { @Nullable public MucOptions.Role mucUserRole; @Nullable public Jid mucUserJid; + + // set to true if presence has status code 110 (this means we are online) + public boolean mucUserSelf; } diff --git a/src/main/java/im/conversations/android/database/model/Account.java b/src/main/java/im/conversations/android/database/model/Account.java new file mode 100644 index 000000000..95e048abf --- /dev/null +++ b/src/main/java/im/conversations/android/database/model/Account.java @@ -0,0 +1,57 @@ +package im.conversations.android.database.model; + +import androidx.annotation.NonNull; +import com.google.common.base.Objects; +import com.google.common.base.Preconditions; +import com.google.common.hash.Hashing; +import com.google.common.io.ByteSource; +import eu.siacs.conversations.xmpp.Jid; +import im.conversations.android.Uuids; +import java.io.IOException; +import java.util.UUID; + +public class Account { + + public final long id; + @NonNull public final Jid address; + @NonNull public final byte[] randomSeed; + + public Account(final long id, @NonNull final Jid address, @NonNull byte[] randomSeed) { + Preconditions.checkNotNull(address, "Account can not be instantiated without an address"); + Preconditions.checkArgument(address.isBareJid(), "Account address must be bare"); + Preconditions.checkArgument( + randomSeed.length == 32, "RandomSeed must have exactly 32 bytes"); + this.id = id; + this.address = address; + this.randomSeed = randomSeed; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Account account = (Account) o; + return id == account.id + && Objects.equal(address, account.address) + && Objects.equal(randomSeed, account.randomSeed); + } + + @Override + public int hashCode() { + return Objects.hashCode(id, address, randomSeed); + } + + public boolean isOnion() { + final String domain = address.getDomain().toEscapedString(); + return domain.endsWith(".onion"); + } + + public UUID getPublicDeviceId() { + try { + return Uuids.getUuid( + ByteSource.wrap(randomSeed).slice(0, 16).hash(Hashing.sha256()).asBytes()); + } catch (final IOException e) { + return UUID.randomUUID(); + } + } +} diff --git a/src/main/java/im/conversations/android/database/model/Credential.java b/src/main/java/im/conversations/android/database/model/Credential.java new file mode 100644 index 000000000..aeefcb747 --- /dev/null +++ b/src/main/java/im/conversations/android/database/model/Credential.java @@ -0,0 +1,50 @@ +package im.conversations.android.database.model; + +public class Credential { + + public final String password; + public final boolean autogeneratedPassword; + public final String pinnedMechanism; + public final String pinnedChannelBinding; + + public final String fastMechanism; + public final String fastToken; + + public final String preAuthRegistrationToken; + + public final String privateKeyAlias; + + private Credential() { + this.password = null; + this.autogeneratedPassword = false; + this.pinnedMechanism = null; + this.pinnedChannelBinding = null; + this.fastMechanism = null; + this.fastToken = null; + this.preAuthRegistrationToken = null; + this.privateKeyAlias = null; + } + + public Credential( + String password, + boolean autogeneratedPassword, + String pinnedMechanism, + String pinnedChannelBinding, + String fastMechanism, + String fastToken, + String preAuthRegistrationToken, + String privateKeyAlias) { + this.password = password; + this.autogeneratedPassword = autogeneratedPassword; + this.pinnedMechanism = pinnedMechanism; + this.pinnedChannelBinding = pinnedChannelBinding; + this.fastMechanism = fastMechanism; + this.fastToken = fastToken; + this.preAuthRegistrationToken = preAuthRegistrationToken; + this.privateKeyAlias = privateKeyAlias; + } + + public static Credential empty() { + return new Credential(); + } +} diff --git a/src/main/java/im/conversations/android/xmpp/ConnectionPool.java b/src/main/java/im/conversations/android/xmpp/ConnectionPool.java new file mode 100644 index 000000000..d4af0f8f5 --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/ConnectionPool.java @@ -0,0 +1,356 @@ +package im.conversations.android.xmpp; + +import static eu.siacs.conversations.utils.Random.SECURE_RANDOM; + +import android.content.Context; +import android.os.SystemClock; +import android.util.Log; +import com.google.common.base.Optional; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import eu.siacs.conversations.Config; +import eu.siacs.conversations.utils.CryptoHelper; +import eu.siacs.conversations.utils.PhoneHelper; +import eu.siacs.conversations.xmpp.Jid; +import im.conversations.android.database.ConversationsDatabase; +import im.conversations.android.database.model.Account; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +public class ConnectionPool { + + private static volatile ConnectionPool INSTANCE; + + private final Context context; + + private final Executor reconfigurationExecutor = Executors.newSingleThreadExecutor(); + private final ScheduledExecutorService reconnectExecutor = + Executors.newSingleThreadScheduledExecutor(); + + private final List connections = new ArrayList<>(); + private final HashSet lowPingTimeoutMode = new HashSet<>(); + + private ConnectionPool(final Context context) { + this.context = context.getApplicationContext(); + } + + public ListenableFuture reconfigure() { + final ListenableFuture> accountFuture = + ConversationsDatabase.getInstance(context).accountDao().getEnabledAccounts(); + return Futures.transform( + accountFuture, + accounts -> this.reconfigure(ImmutableSet.copyOf(accounts)), + reconfigurationExecutor); + } + + public synchronized XmppConnection get(final Jid address) { + return Iterables.find(this.connections, c -> address.equals(c.getAccount().address)); + } + + public synchronized XmppConnection get(final long id) { + return Iterables.find(this.connections, c -> id == c.getAccount().id); + } + + public synchronized boolean isEnabled(final long id) { + return Iterables.any(this.connections, c -> id == c.getAccount().id); + } + + public synchronized List getConnections() { + return ImmutableList.copyOf(this.connections); + } + + private synchronized Void reconfigure(final Set accounts) { + final Set current = getAccounts(); + final Set removed = Sets.difference(current, accounts); + final Set added = Sets.difference(accounts, current); + for (final Account account : added) { + final XmppConnection connection = this.instantiate(context, account); + connection.setOnStatusChangedListener(this::onStatusChanged); + } + for (final Account account : removed) { + final Optional connectionOptional = + Iterables.tryFind(connections, c -> c.getAccount().equals(account)); + if (connectionOptional.isPresent()) { + final XmppConnection connection = connectionOptional.get(); + disconnect(connection, false); + } + } + return null; + } + + private void onStatusChanged(final XmppConnection connection) { + final Account account = connection.getAccount(); + if (connection.getStatus() == ConnectionState.ONLINE || connection.getStatus().isError()) { + // TODO notify QuickConversationsService of account state change + // mQuickConversationsService.signalAccountStateChange(); + } + + if (connection.getStatus() == ConnectionState.ONLINE) { + synchronized (lowPingTimeoutMode) { + if (lowPingTimeoutMode.remove(account.address)) { + Log.d(Config.LOGTAG, account.address + ": leaving low ping timeout mode"); + } + } + ConversationsDatabase.getInstance(context) + .accountDao() + .setShowErrorNotification(account.id, true); + if (connection.getFeatures().csi()) { + // TODO send correct CSI state (connection.sendActive or connection.sendInactive) + } + scheduleWakeUpCall(Config.PING_MAX_INTERVAL); + } else if (connection.getStatus() == ConnectionState.OFFLINE) { + + // TODO previously we would call resetSendingToWaiting. The new architecture likely + // won’t need this but we should double check + + // resetSendingToWaiting(account); + if (isInLowPingTimeoutMode(account)) { + Log.d( + Config.LOGTAG, + account.address + + ": went into offline state during low ping mode." + + " reconnecting now"); + reconnectAccount(connection); + } else { + final int timeToReconnect = SECURE_RANDOM.nextInt(10) + 2; + scheduleWakeUpCall(timeToReconnect); + } + } else if (connection.getStatus() == ConnectionState.REGISTRATION_SUCCESSFUL) { + // databaseBackend.updateAccount(account); + reconnectAccount(connection); + } else if (connection.getStatus() != ConnectionState.CONNECTING) { + // resetSendingToWaiting(account); + if (connection.getStatus().isAttemptReconnect()) { + final int next = connection.getTimeToNextAttempt(); + final boolean lowPingTimeoutMode = isInLowPingTimeoutMode(account); + if (next <= 0) { + Log.d( + Config.LOGTAG, + account.address + + ": error connecting account. reconnecting now." + + " lowPingTimeout=" + + lowPingTimeoutMode); + reconnectAccount(connection); + } else { + final int attempt = connection.getAttempt() + 1; + Log.d( + Config.LOGTAG, + account.address + + ": error connecting account. try again in " + + next + + "s for the " + + attempt + + " time. lowPingTimeout=" + + lowPingTimeoutMode); + scheduleWakeUpCall(next); + } + } + } + // TODO toggle error notification + // getNotificationService().updateErrorNotification(); + } + + public void scheduleWakeUpCall(final int seconds) { + reconnectExecutor.schedule( + () -> { + manageConnectionStates(); + }, + Math.max(0, seconds) + 1, + TimeUnit.SECONDS); + } + + /** This is called externally if we want to force pings for example on connection switches */ + public void ping() { + manageConnectionStates(null, true); + } + + /** + * This is called externally from the push receiver + * + * @param pushedAccountHash + */ + public void receivePush(final String pushedAccountHash) { + manageConnectionStates(pushedAccountHash, false); + } + + private void manageConnectionStates() { + manageConnectionStates(null, false); + } + + private void manageConnectionStates( + final String pushedAccountHash, final boolean immediatePing) { + // WakeLockHelper.acquire(wakeLock); + int pingNow = 0; + final HashSet pingCandidates = new HashSet<>(); + final String androidId = PhoneHelper.getAndroidId(context); + for (final XmppConnection xmppConnection : this.connections) { + final Account account = xmppConnection.getAccount(); + final boolean pushWasMeantForThisAccount = + CryptoHelper.getFingerprint(account.address, androidId) + .equals(pushedAccountHash); + if (processAccountState(xmppConnection, pushWasMeantForThisAccount, pingCandidates)) { + pingNow++; + } + } + if (pingNow > 0 || immediatePing) { + for (final XmppConnection xmppConnection : pingCandidates) { + final Account account = xmppConnection.getAccount(); + final boolean lowTimeout = isInLowPingTimeoutMode(account); + xmppConnection.sendPing(); + Log.d( + Config.LOGTAG, + account.address + " send ping (lowTimeout=" + lowTimeout + ")"); + scheduleWakeUpCall(lowTimeout ? Config.LOW_PING_TIMEOUT : Config.PING_TIMEOUT); + } + } + // WakeLockHelper.release(wakeLock); + } + + private boolean processAccountState( + final XmppConnection connection, + final boolean isAccountPushed, + final HashSet pingCandidates) { + boolean pingNow = false; + if (connection.getStatus().isAttemptReconnect()) { + final Account account = connection.getAccount(); + if (connection.getStatus() == ConnectionState.ONLINE) { + synchronized (lowPingTimeoutMode) { + final long lastReceived = connection.getLastPacketReceived(); + final long lastSent = connection.getLastPingSent(); + final long msToNextPing = + (Math.max(lastReceived, lastSent) + Config.PING_MAX_INTERVAL) + - SystemClock.elapsedRealtime(); + final int pingTimeout = + lowPingTimeoutMode.contains(account.address) + ? Config.LOW_PING_TIMEOUT * 1000 + : Config.PING_TIMEOUT * 1000; + final long pingTimeoutIn = + (lastSent + pingTimeout) - SystemClock.elapsedRealtime(); + if (lastSent > lastReceived) { + if (pingTimeoutIn < 0) { + Log.d(Config.LOGTAG, account.address + ": ping timeout"); + this.reconnectAccount(connection); + } else { + final int secs = (int) (pingTimeoutIn / 1000); + this.scheduleWakeUpCall(secs); + } + } else { + pingCandidates.add(connection); + if (isAccountPushed) { + pingNow = true; + if (lowPingTimeoutMode.add(account.address)) { + Log.d( + Config.LOGTAG, + account.address + ": entering low ping timeout mode"); + } + } else if (msToNextPing <= 0) { + pingNow = true; + } else { + this.scheduleWakeUpCall(Ints.saturatedCast(msToNextPing / 1000)); + if (lowPingTimeoutMode.remove(account.address)) { + Log.d( + Config.LOGTAG, + account.address + ": leaving low ping timeout mode"); + } + } + } + } + } else if (connection.getStatus() == ConnectionState.OFFLINE) { + reconnectAccount(connection); + } else if (connection.getStatus() == ConnectionState.CONNECTING) { + long secondsSinceLastConnect = + (SystemClock.elapsedRealtime() - connection.getLastConnect()) / 1000; + long secondsSinceLastDisco = + (SystemClock.elapsedRealtime() - connection.getLastDiscoStarted()) / 1000; + long discoTimeout = Config.CONNECT_DISCO_TIMEOUT - secondsSinceLastDisco; + long timeout = Config.CONNECT_TIMEOUT - secondsSinceLastConnect; + if (timeout < 0) { + Log.d( + Config.LOGTAG, + account.address + + ": time out during connect reconnecting" + + " (secondsSinceLast=" + + secondsSinceLastConnect + + ")"); + connection.resetAttemptCount(false); + reconnectAccount(connection); + } else if (discoTimeout < 0) { + connection.sendDiscoTimeout(); + scheduleWakeUpCall(Ints.saturatedCast(discoTimeout)); + } else { + scheduleWakeUpCall(Ints.saturatedCast(Math.min(timeout, discoTimeout))); + } + } else { + if (connection.getTimeToNextAttempt() <= 0) { + reconnectAccount(connection); + } + } + } + return pingNow; + } + + private void reconnectAccount(final XmppConnection connection) { + final Account account = connection.getAccount(); + if (isEnabled(account.id)) { + final Thread thread = new Thread(connection); + connection.prepareNewConnection(); + connection.interrupt(); + thread.start(); + scheduleWakeUpCall(Config.CONNECT_DISCO_TIMEOUT); + } else { + disconnect(connection, true); + connection.resetEverything(); + } + } + + private void disconnect(final XmppConnection connection, boolean force) { + if (force) { + connection.disconnect(true); + } else { + // TODO bring back code that gracefully leaves MUCs + // TODO send offline presence + connection.disconnect(false); + } + } + + private boolean isInLowPingTimeoutMode(Account account) { + synchronized (lowPingTimeoutMode) { + return lowPingTimeoutMode.contains(account.address); + } + } + + private XmppConnection instantiate(final Context context, final Account account) { + final XmppConnection xmppConnection = new XmppConnection(context, account); + this.connections.add(xmppConnection); + return xmppConnection; + } + + private Set getAccounts() { + return ImmutableSet.copyOf(Lists.transform(this.connections, XmppConnection::getAccount)); + } + + public static ConnectionPool getInstance(final Context context) { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (ConnectionPool.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new ConnectionPool(context); + return INSTANCE; + } + } +} diff --git a/src/main/java/im/conversations/android/xmpp/ConnectionState.java b/src/main/java/im/conversations/android/xmpp/ConnectionState.java new file mode 100644 index 000000000..6fd71b13d --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/ConnectionState.java @@ -0,0 +1,124 @@ +package im.conversations.android.xmpp; + +import androidx.annotation.StringRes; +import eu.siacs.conversations.R; + +public enum ConnectionState { + OFFLINE(false), + CONNECTING(false), + ONLINE(false), + UNAUTHORIZED, + TEMPORARY_AUTH_FAILURE, + SERVER_NOT_FOUND, + REGISTRATION_SUCCESSFUL(false), + REGISTRATION_FAILED(true, false), + REGISTRATION_WEB(true, false), + REGISTRATION_CONFLICT(true, false), + REGISTRATION_NOT_SUPPORTED(true, false), + REGISTRATION_PLEASE_WAIT(true, false), + REGISTRATION_INVALID_TOKEN(true, false), + REGISTRATION_PASSWORD_TOO_WEAK(true, false), + TLS_ERROR, + TLS_ERROR_DOMAIN, + INCOMPATIBLE_SERVER, + INCOMPATIBLE_CLIENT, + TOR_NOT_AVAILABLE, + DOWNGRADE_ATTACK, + SESSION_FAILURE, + BIND_FAILURE, + HOST_UNKNOWN, + STREAM_ERROR, + STREAM_OPENING_ERROR, + POLICY_VIOLATION, + PAYMENT_REQUIRED, + MISSING_INTERNET_PERMISSION(false); + + private final boolean isError; + private final boolean attemptReconnect; + + ConnectionState(final boolean isError) { + this(isError, true); + } + + ConnectionState(final boolean isError, final boolean reconnect) { + this.isError = isError; + this.attemptReconnect = reconnect; + } + + ConnectionState() { + this(true, true); + } + + public boolean isError() { + return this.isError; + } + + public boolean isAttemptReconnect() { + return this.attemptReconnect; + } + + // TODO refactor into DataBinder (we can print the enum directly in the UI) + @StringRes + public int getReadableId() { + switch (this) { + case ONLINE: + return R.string.account_status_online; + case CONNECTING: + return R.string.account_status_connecting; + case OFFLINE: + return R.string.account_status_offline; + case UNAUTHORIZED: + return R.string.account_status_unauthorized; + case SERVER_NOT_FOUND: + return R.string.account_status_not_found; + case REGISTRATION_FAILED: + return R.string.account_status_regis_fail; + case REGISTRATION_WEB: + return R.string.account_status_regis_web; + case REGISTRATION_CONFLICT: + return R.string.account_status_regis_conflict; + case REGISTRATION_SUCCESSFUL: + return R.string.account_status_regis_success; + case REGISTRATION_NOT_SUPPORTED: + return R.string.account_status_regis_not_sup; + case REGISTRATION_INVALID_TOKEN: + return R.string.account_status_regis_invalid_token; + case TLS_ERROR: + return R.string.account_status_tls_error; + case TLS_ERROR_DOMAIN: + return R.string.account_status_tls_error_domain; + case INCOMPATIBLE_SERVER: + return R.string.account_status_incompatible_server; + case INCOMPATIBLE_CLIENT: + return R.string.account_status_incompatible_client; + case TOR_NOT_AVAILABLE: + return R.string.account_status_tor_unavailable; + case BIND_FAILURE: + return R.string.account_status_bind_failure; + case SESSION_FAILURE: + return R.string.session_failure; + case DOWNGRADE_ATTACK: + return R.string.sasl_downgrade; + case HOST_UNKNOWN: + return R.string.account_status_host_unknown; + case POLICY_VIOLATION: + return R.string.account_status_policy_violation; + case REGISTRATION_PLEASE_WAIT: + return R.string.registration_please_wait; + case REGISTRATION_PASSWORD_TOO_WEAK: + return R.string.registration_password_too_weak; + case STREAM_ERROR: + return R.string.account_status_stream_error; + case STREAM_OPENING_ERROR: + return R.string.account_status_stream_opening_error; + case PAYMENT_REQUIRED: + return R.string.payment_required; + case MISSING_INTERNET_PERMISSION: + return R.string.missing_internet_permission; + case TEMPORARY_AUTH_FAILURE: + return R.string.account_status_temporary_auth_failure; + default: + return R.string.account_status_unknown; + } + } +} diff --git a/src/main/java/im/conversations/android/xmpp/XmppConnection.java b/src/main/java/im/conversations/android/xmpp/XmppConnection.java new file mode 100644 index 000000000..ab0e1ee7a --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/XmppConnection.java @@ -0,0 +1,2805 @@ +package im.conversations.android.xmpp; + +import static eu.siacs.conversations.utils.Random.SECURE_RANDOM; + +import android.content.Context; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.os.Build; +import android.os.SystemClock; +import android.security.KeyChain; +import android.util.Base64; +import android.util.Log; +import android.util.Pair; +import android.util.SparseArray; +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; +import com.google.common.base.Optional; +import com.google.common.base.Strings; +import com.google.common.io.ByteSource; +import eu.siacs.conversations.Config; +import eu.siacs.conversations.R; +import eu.siacs.conversations.crypto.XmppDomainVerifier; +import eu.siacs.conversations.crypto.axolotl.AxolotlService; +import eu.siacs.conversations.entities.ServiceDiscoveryResult; +import eu.siacs.conversations.http.HttpConnectionManager; +import eu.siacs.conversations.persistance.FileBackend; +import eu.siacs.conversations.services.MemorizingTrustManager; +import eu.siacs.conversations.services.MessageArchiveService; +import eu.siacs.conversations.services.NotificationService; +import eu.siacs.conversations.utils.CryptoHelper; +import eu.siacs.conversations.utils.Patterns; +import eu.siacs.conversations.utils.PhoneHelper; +import eu.siacs.conversations.utils.Resolver; +import eu.siacs.conversations.utils.SSLSockets; +import eu.siacs.conversations.utils.SocksSocketFactory; +import eu.siacs.conversations.utils.XmlHelper; +import eu.siacs.conversations.xml.Element; +import eu.siacs.conversations.xml.LocalizedContent; +import eu.siacs.conversations.xml.Namespace; +import eu.siacs.conversations.xml.Tag; +import eu.siacs.conversations.xml.TagWriter; +import eu.siacs.conversations.xml.XmlReader; +import eu.siacs.conversations.xmpp.InvalidJid; +import eu.siacs.conversations.xmpp.Jid; +import eu.siacs.conversations.xmpp.bind.Bind2; +import eu.siacs.conversations.xmpp.forms.Data; +import eu.siacs.conversations.xmpp.jingle.stanzas.JinglePacket; +import eu.siacs.conversations.xmpp.stanzas.AbstractAcknowledgeableStanza; +import eu.siacs.conversations.xmpp.stanzas.AbstractStanza; +import eu.siacs.conversations.xmpp.stanzas.IqPacket; +import eu.siacs.conversations.xmpp.stanzas.MessagePacket; +import eu.siacs.conversations.xmpp.stanzas.PresencePacket; +import eu.siacs.conversations.xmpp.stanzas.csi.ActivePacket; +import eu.siacs.conversations.xmpp.stanzas.csi.InactivePacket; +import eu.siacs.conversations.xmpp.stanzas.streammgmt.AckPacket; +import eu.siacs.conversations.xmpp.stanzas.streammgmt.EnablePacket; +import eu.siacs.conversations.xmpp.stanzas.streammgmt.RequestPacket; +import eu.siacs.conversations.xmpp.stanzas.streammgmt.ResumePacket; +import im.conversations.android.database.ConversationsDatabase; +import im.conversations.android.database.CredentialStore; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Connection; +import im.conversations.android.database.model.Credential; +import im.conversations.android.xmpp.processor.BindProcessor; +import im.conversations.android.xmpp.processor.IqProcessor; +import im.conversations.android.xmpp.processor.JingleProcessor; +import im.conversations.android.xmpp.processor.MessageAcknowledgeProcessor; +import im.conversations.android.xmpp.processor.MessageProcessor; +import im.conversations.android.xmpp.processor.PresenceProcessor; +import im.conversations.android.xmpp.sasl.ChannelBinding; +import im.conversations.android.xmpp.sasl.HashedToken; +import im.conversations.android.xmpp.sasl.SaslMechanism; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.ConnectException; +import java.net.IDN; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.UnknownHostException; +import java.security.KeyManagementException; +import java.security.NoSuchAlgorithmException; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Hashtable; +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.regex.Matcher; +import javax.net.ssl.KeyManager; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.X509KeyManager; +import javax.net.ssl.X509TrustManager; +import okhttp3.HttpUrl; +import org.xmlpull.v1.XmlPullParserException; + +public class XmppConnection implements Runnable { + + private static final int PACKET_IQ = 0; + private static final int PACKET_MESSAGE = 1; + private static final int PACKET_PRESENCE = 2; + protected final Account account; + private final Features features = new Features(this); + private final HashMap disco = new HashMap<>(); + private final HashMap commands = new HashMap<>(); + private final SparseArray mStanzaQueue = new SparseArray<>(); + private final Hashtable>> packetCallbacks = + new Hashtable<>(); + private final Context context; + private Socket socket; + private XmlReader tagReader; + private TagWriter tagWriter = new TagWriter(); + private boolean shouldAuthenticate = true; + private boolean inSmacksSession = false; + private boolean quickStartInProgress = false; + private boolean isBound = false; + private Element streamFeatures; + private String streamId = null; + private Jid connectionAddress; + private ConnectionState connectionState = ConnectionState.OFFLINE; + private ConnectionState recentErrorConnectionState = ConnectionState.OFFLINE; + private int stanzasReceived = 0; + private int stanzasSent = 0; + private int stanzasSentBeforeAuthentication; + private long lastPacketReceived = 0; + private long lastPingSent = 0; + private long lastConnect = 0; + private long lastSessionStarted = 0; + private long lastDiscoStarted = 0; + private boolean isMamPreferenceAlways = false; + private final AtomicInteger mPendingServiceDiscoveries = new AtomicInteger(0); + private final AtomicBoolean mWaitForDisco = new AtomicBoolean(true); + private final AtomicBoolean mWaitingForSmCatchup = new AtomicBoolean(false); + private final AtomicInteger mSmCatchupMessageCounter = new AtomicInteger(0); + private int attempt = 0; + private final Consumer presencePacketConsumer; + private final Consumer jinglePacketConsumer; + private final Consumer iqPacketConsumer; + private final Consumer messagePacketConsumer; + private final BiFunction messageAcknowledgeProcessor; + private final Consumer bindConsumer; + private Consumer statusListener = null; + private SaslMechanism saslMechanism; + private HashedToken.Mechanism hashTokenRequest; + private HttpUrl redirectionUrl = null; + private String verifiedHostname = null; + private volatile Thread mThread; + private CountDownLatch mStreamCountDownLatch; + + public XmppConnection(final Context context, final Account account) { + this.context = context; + this.account = account; + this.connectionAddress = account.address; + + // these consumers are pure listeners; they don’t have public method except for accept|apply + // those consumers don’t need to be invoked from anywhere except this connection + // this is different to 'Managers' (like MAM, OMEMO, Avatar) that need to listen to external + // events like 'go fetch history for x' + this.messagePacketConsumer = new MessageProcessor(context, this); + this.presencePacketConsumer = new PresenceProcessor(context, this); + this.iqPacketConsumer = new IqProcessor(context, this); + this.jinglePacketConsumer = new JingleProcessor(context, this); + this.messageAcknowledgeProcessor = new MessageAcknowledgeProcessor(context, this); + this.bindConsumer = new BindProcessor(context, this); + } + + public Account getAccount() { + return account; + } + + private String fixResource(final String resource) { + if (Strings.isNullOrEmpty(resource)) { + return null; + } + int fixedPartLength = + context.getString(R.string.app_name).length() + 1; // include the trailing dot + int randomPartLength = 4; // 3 bytes + if (resource.length() > fixedPartLength + randomPartLength) { + if (validBase64( + resource.substring(fixedPartLength, fixedPartLength + randomPartLength))) { + return resource.substring(0, fixedPartLength + randomPartLength); + } + } + return resource; + } + + private static boolean validBase64(String input) { + try { + return Base64.decode(input, Base64.URL_SAFE).length == 3; + } catch (Throwable throwable) { + return false; + } + } + + private void changeStatus(final ConnectionState nextStatus) { + synchronized (this) { + if (Thread.currentThread().isInterrupted()) { + Log.d( + Config.LOGTAG, + account.address + + ": not changing status to " + + nextStatus + + " because thread was interrupted"); + return; + } + final ConnectionState current = this.connectionState; + if (current == nextStatus) { + return; + } + if ((nextStatus == ConnectionState.OFFLINE) + && (current != ConnectionState.CONNECTING) + && (current != ConnectionState.ONLINE)) { + return; + } + if (nextStatus == ConnectionState.ONLINE) { + this.attempt = 0; + } + this.connectionState = nextStatus; + if (nextStatus.isError() || nextStatus == ConnectionState.ONLINE) { + this.recentErrorConnectionState = nextStatus; + } + } + if (statusListener != null) { + statusListener.accept(this); + } + } + + public Jid getJidForCommand(final String node) { + synchronized (this.commands) { + return this.commands.get(node); + } + } + + public void prepareNewConnection() { + this.lastConnect = SystemClock.elapsedRealtime(); + this.lastPingSent = SystemClock.elapsedRealtime(); + this.lastDiscoStarted = Long.MAX_VALUE; + this.mWaitingForSmCatchup.set(false); + this.changeStatus(ConnectionState.CONNECTING); + } + + public boolean isWaitingForSmCatchup() { + return mWaitingForSmCatchup.get(); + } + + public void incrementSmCatchupMessageCounter() { + this.mSmCatchupMessageCounter.incrementAndGet(); + } + + protected void connect() { + final Connection connection = + ConversationsDatabase.getInstance(context) + .accountDao() + .getConnectionSettings(account.id); + Log.d(Config.LOGTAG, account.address + ": connecting"); + features.encryptionEnabled = false; + this.inSmacksSession = false; + this.quickStartInProgress = false; + this.isBound = false; + this.attempt++; + this.verifiedHostname = + null; // will be set if user entered hostname is being used or hostname was verified + // with dnssec + try { + Socket localSocket; + shouldAuthenticate = + !ConversationsDatabase.getInstance(context) + .accountDao() + .pendingRegistration(account.id); + this.changeStatus(ConnectionState.CONNECTING); + // TODO introduce proxy check + final boolean useTor = /*fcontext.useTorToConnect() ||*/ account.isOnion(); + if (useTor) { + final String destination; + final int port; + final boolean directTls; + if (connection == null || account.isOnion()) { + destination = account.address.getDomain().toEscapedString(); + port = 5222; + directTls = false; + } else { + destination = connection.hostname; + this.verifiedHostname = destination; + port = connection.port; + directTls = connection.directTls; + } + + Log.d( + Config.LOGTAG, + account.address + + ": connect to " + + destination + + " via Tor. directTls=" + + directTls); + localSocket = SocksSocketFactory.createSocketOverTor(destination, port); + + if (directTls) { + localSocket = upgradeSocketToTls(localSocket); + features.encryptionEnabled = true; + } + + try { + startXmpp(localSocket); + } catch (final InterruptedException e) { + Log.d( + Config.LOGTAG, + account.address + ": thread was interrupted before beginning stream"); + return; + } catch (final Exception e) { + throw new IOException("Could not start stream", e); + } + } else { + final String domain = account.address.getDomain().toEscapedString(); + final List results; + if (connection != null) { + results = Resolver.fromHardCoded(connection.hostname, connection.port); + } else { + results = Resolver.resolve(domain); + } + if (Thread.currentThread().isInterrupted()) { + Log.d(Config.LOGTAG, account.address + ": Thread was interrupted"); + return; + } + if (results.size() == 0) { + Log.e(Config.LOGTAG, account.address + ": Resolver results were empty"); + return; + } + final Resolver.Result storedBackupResult; + if (connection != null) { + storedBackupResult = null; + } else { + // TODO fix resolver result caching + storedBackupResult = + null; // context.databaseBackend.findResolverResult(domain); + if (storedBackupResult != null && !results.contains(storedBackupResult)) { + results.add(storedBackupResult); + Log.d( + Config.LOGTAG, + account.address + + ": loaded backup resolver result from db: " + + storedBackupResult); + } + } + for (Iterator iterator = results.iterator(); + iterator.hasNext(); ) { + final Resolver.Result result = iterator.next(); + if (Thread.currentThread().isInterrupted()) { + Log.d(Config.LOGTAG, account.address + ": Thread was interrupted"); + return; + } + try { + // if tls is true, encryption is implied and must not be started + features.encryptionEnabled = result.isDirectTls(); + verifiedHostname = + result.isAuthenticated() ? result.getHostname().toString() : null; + Log.d(Config.LOGTAG, "verified hostname " + verifiedHostname); + final InetSocketAddress addr; + if (result.getIp() != null) { + addr = new InetSocketAddress(result.getIp(), result.getPort()); + Log.d( + Config.LOGTAG, + account.address + + ": using values from resolver " + + (result.getHostname() == null + ? "" + : result.getHostname().toString() + "/") + + result.getIp().getHostAddress() + + ":" + + result.getPort() + + " tls: " + + features.encryptionEnabled); + } else { + addr = + new InetSocketAddress( + IDN.toASCII(result.getHostname().toString()), + result.getPort()); + Log.d( + Config.LOGTAG, + account.address + + ": using values from resolver " + + result.getHostname().toString() + + ":" + + result.getPort() + + " tls: " + + features.encryptionEnabled); + } + + localSocket = new Socket(); + localSocket.connect(addr, Config.SOCKET_TIMEOUT * 1000); + + if (features.encryptionEnabled) { + localSocket = upgradeSocketToTls(localSocket); + } + + localSocket.setSoTimeout(Config.SOCKET_TIMEOUT * 1000); + if (startXmpp(localSocket)) { + localSocket.setSoTimeout( + 0); // reset to 0; once the connection is established we don’t + // want this + if (connection == null && !result.equals(storedBackupResult)) { + // TODO store resolver result + // context.databaseBackend.saveResolverResult(domain, result); + } + break; // successfully connected to server that speaks xmpp + } else { + FileBackend.close(localSocket); + throw new StateChangingException(ConnectionState.STREAM_OPENING_ERROR); + } + } catch (final StateChangingException e) { + if (!iterator.hasNext()) { + throw e; + } + } catch (InterruptedException e) { + Log.d( + Config.LOGTAG, + account.address + + ": thread was interrupted before beginning stream"); + return; + } catch (final Throwable e) { + Log.d( + Config.LOGTAG, + account.address + + ": " + + e.getMessage() + + "(" + + e.getClass().getName() + + ")"); + if (!iterator.hasNext()) { + throw new UnknownHostException(); + } + } + } + } + processStream(); + } catch (final SecurityException e) { + this.changeStatus(ConnectionState.MISSING_INTERNET_PERMISSION); + } catch (final StateChangingException e) { + this.changeStatus(e.state); + } catch (final UnknownHostException + | ConnectException + | SocksSocketFactory.HostNotFoundException e) { + this.changeStatus(ConnectionState.SERVER_NOT_FOUND); + } catch (final SocksSocketFactory.SocksProxyNotFoundException e) { + this.changeStatus(ConnectionState.TOR_NOT_AVAILABLE); + } catch (final IOException | XmlPullParserException e) { + Log.d(Config.LOGTAG, account.address + ": " + e.getMessage()); + this.changeStatus(ConnectionState.OFFLINE); + this.attempt = Math.max(0, this.attempt - 1); + } finally { + if (!Thread.currentThread().isInterrupted()) { + forceCloseSocket(); + } else { + Log.d( + Config.LOGTAG, + account.address + + ": not force closing socket because thread was interrupted"); + } + } + } + + /** + * Starts xmpp protocol, call after connecting to socket + * + * @return true if server returns with valid xmpp, false otherwise + */ + private boolean startXmpp(final Socket socket) throws Exception { + if (Thread.currentThread().isInterrupted()) { + throw new InterruptedException(); + } + this.socket = socket; + tagReader = new XmlReader(); + if (tagWriter != null) { + tagWriter.forceClose(); + } + tagWriter = new TagWriter(); + tagWriter.setOutputStream(socket.getOutputStream()); + tagReader.setInputStream(socket.getInputStream()); + tagWriter.beginDocument(); + final boolean quickStart; + if (socket instanceof SSLSocket) { + final SSLSocket sslSocket = (SSLSocket) socket; + SSLSockets.log(account.address, sslSocket); + quickStart = establishStream(SSLSockets.version(sslSocket)); + } else { + quickStart = establishStream(SSLSockets.Version.NONE); + } + final Tag tag = tagReader.readTag(); + if (Thread.currentThread().isInterrupted()) { + throw new InterruptedException(); + } + final boolean success = tag != null && tag.isStart("stream", Namespace.STREAMS); + if (success && quickStart) { + this.quickStartInProgress = true; + } + return success; + } + + private SSLSocketFactory getSSLSocketFactory() + throws NoSuchAlgorithmException, KeyManagementException { + final SSLContext sc = SSLSockets.getSSLContext(); + final MemorizingTrustManager trustManager = MemorizingTrustManager.create(context); + final KeyManager[] keyManager; + final Credential credential = CredentialStore.getInstance(context).get(account); + if (Strings.isNullOrEmpty(credential.privateKeyAlias)) { + keyManager = null; + } else { + keyManager = new KeyManager[] {new MyKeyManager(context, credential)}; + } + final String domain = account.address.getDomain().toEscapedString(); + // TODO we used to use two different trust managers; interactive and non interactive (to + // trigger SSL cert prompts) + // we need a better solution for this using live data or similar + sc.init( + keyManager, + new X509TrustManager[] {trustManager.getInteractive(domain)}, + SECURE_RANDOM); + return sc.getSocketFactory(); + } + + @Override + public void run() { + synchronized (this) { + this.mThread = Thread.currentThread(); + if (this.mThread.isInterrupted()) { + Log.d( + Config.LOGTAG, + account.address + ": aborting connect because thread was interrupted"); + return; + } + forceCloseSocket(); + } + connect(); + } + + private void processStream() throws XmlPullParserException, IOException { + final CountDownLatch streamCountDownLatch = new CountDownLatch(1); + this.mStreamCountDownLatch = streamCountDownLatch; + Tag nextTag = tagReader.readTag(); + while (nextTag != null && !nextTag.isEnd("stream")) { + if (nextTag.isStart("error")) { + processStreamError(nextTag); + } else if (nextTag.isStart("features", Namespace.STREAMS)) { + processStreamFeatures(nextTag); + } else if (nextTag.isStart("proceed", Namespace.TLS)) { + switchOverToTls(); + } else if (nextTag.isStart("success")) { + final Element success = tagReader.readElement(nextTag); + if (processSuccess(success)) { + break; + } + + } else if (nextTag.isStart("failure", Namespace.TLS)) { + throw new StateChangingException(ConnectionState.TLS_ERROR); + } else if (nextTag.isStart("failure")) { + final Element failure = tagReader.readElement(nextTag); + processFailure(failure); + } else if (nextTag.isStart("continue", Namespace.SASL_2)) { + // two step sasl2 - we don’t support this yet + throw new StateChangingException(ConnectionState.INCOMPATIBLE_CLIENT); + } else if (nextTag.isStart("challenge")) { + if (isSecure() && this.saslMechanism != null) { + final Element challenge = tagReader.readElement(nextTag); + processChallenge(challenge); + } else { + Log.d( + Config.LOGTAG, + account.address + ": received 'challenge on an unsecure connection"); + throw new StateChangingException(ConnectionState.INCOMPATIBLE_CLIENT); + } + } else if (nextTag.isStart("enabled", Namespace.STREAM_MANAGEMENT)) { + final Element enabled = tagReader.readElement(nextTag); + processEnabled(enabled); + } else if (nextTag.isStart("resumed")) { + final Element resumed = tagReader.readElement(nextTag); + processResumed(resumed); + } else if (nextTag.isStart("r")) { + tagReader.readElement(nextTag); + if (Config.EXTENDED_SM_LOGGING) { + Log.d( + Config.LOGTAG, + account.address + ": acknowledging stanza #" + this.stanzasReceived); + } + final AckPacket ack = new AckPacket(this.stanzasReceived); + tagWriter.writeStanzaAsync(ack); + } else if (nextTag.isStart("a")) { + synchronized (NotificationService.CATCHUP_LOCK) { + if (mWaitingForSmCatchup.compareAndSet(true, false)) { + final int messageCount = mSmCatchupMessageCounter.get(); + final int pendingIQs = packetCallbacks.size(); + Log.d( + Config.LOGTAG, + account.address + + ": SM catchup complete (messages=" + + messageCount + + ", pending IQs=" + + pendingIQs + + ")"); + if (messageCount > 0) { + // TODO finish notification backlog (ok to pling now) + // context.getNotificationService().finishBacklog(true, account); + } + } + } + final Element ack = tagReader.readElement(nextTag); + lastPacketReceived = SystemClock.elapsedRealtime(); + final boolean acknowledgedMessages; + synchronized (this.mStanzaQueue) { + final Optional serverSequence = ack.getOptionalIntAttribute("h"); + if (serverSequence.isPresent()) { + acknowledgedMessages = acknowledgeStanzaUpTo(serverSequence.get()); + } else { + acknowledgedMessages = false; + Log.d( + Config.LOGTAG, + account.address + ": server send ack without sequence number"); + } + } + } else if (nextTag.isStart("failed")) { + final Element failed = tagReader.readElement(nextTag); + processFailed(failed, true); + } else if (nextTag.isStart("iq")) { + processIq(nextTag); + } else if (nextTag.isStart("message")) { + processMessage(nextTag); + } else if (nextTag.isStart("presence")) { + processPresence(nextTag); + } + nextTag = tagReader.readTag(); + } + if (nextTag != null && nextTag.isEnd("stream")) { + streamCountDownLatch.countDown(); + } + } + + private void processChallenge(final Element challenge) throws IOException { + final SaslMechanism.Version version; + try { + version = SaslMechanism.Version.of(challenge); + } catch (final IllegalArgumentException e) { + throw new StateChangingException(ConnectionState.INCOMPATIBLE_SERVER); + } + final Element response; + if (version == SaslMechanism.Version.SASL) { + response = new Element("response", Namespace.SASL); + } else if (version == SaslMechanism.Version.SASL_2) { + response = new Element("response", Namespace.SASL_2); + } else { + throw new AssertionError("Missing implementation for " + version); + } + try { + response.setContent( + saslMechanism.getResponse(challenge.getContent(), sslSocketOrNull(socket))); + } catch (final SaslMechanism.AuthenticationException e) { + // TODO: Send auth abort tag. + Log.e(Config.LOGTAG, e.toString()); + throw new StateChangingException(ConnectionState.UNAUTHORIZED); + } + tagWriter.writeElement(response); + } + + private boolean processSuccess(final Element success) + throws IOException, XmlPullParserException { + final SaslMechanism.Version version; + try { + version = SaslMechanism.Version.of(success); + } catch (final IllegalArgumentException e) { + throw new StateChangingException(ConnectionState.INCOMPATIBLE_SERVER); + } + final SaslMechanism currentSaslMechanism = this.saslMechanism; + if (currentSaslMechanism == null) { + throw new StateChangingException(ConnectionState.INCOMPATIBLE_SERVER); + } + final String challenge; + if (version == SaslMechanism.Version.SASL) { + challenge = success.getContent(); + } else if (version == SaslMechanism.Version.SASL_2) { + challenge = success.findChildContent("additional-data"); + } else { + throw new AssertionError("Missing implementation for " + version); + } + try { + currentSaslMechanism.getResponse(challenge, sslSocketOrNull(socket)); + } catch (final SaslMechanism.AuthenticationException e) { + Log.e(Config.LOGTAG, String.valueOf(e)); + throw new StateChangingException(ConnectionState.UNAUTHORIZED); + } + Log.d(Config.LOGTAG, account.address + ": logged in (using " + version + ")"); + if (SaslMechanism.pin(currentSaslMechanism)) { + try { + CredentialStore.getInstance(context) + .setPinnedMechanism(account, currentSaslMechanism); + } catch (final Exception e) { + Log.d(Config.LOGTAG, "unable to pin mechanism in credential store", e); + } + } + if (version == SaslMechanism.Version.SASL_2) { + final String authorizationIdentifier = + success.findChildContent("authorization-identifier"); + final Jid authorizationJid; + try { + authorizationJid = + Strings.isNullOrEmpty(authorizationIdentifier) + ? null + : Jid.ofEscaped(authorizationIdentifier); + } catch (final IllegalArgumentException e) { + Log.d( + Config.LOGTAG, + account.address + + ": SASL 2.0 authorization identifier was not a valid jid"); + throw new StateChangingException(ConnectionState.BIND_FAILURE); + } + if (authorizationJid == null) { + throw new StateChangingException(ConnectionState.BIND_FAILURE); + } + Log.d( + Config.LOGTAG, + account.address + + ": SASL 2.0 authorization identifier was " + + authorizationJid); + if (!account.address.getDomain().equals(authorizationJid.getDomain())) { + Log.d( + Config.LOGTAG, + account.address + + ": server tried to re-assign domain to " + + authorizationJid.getDomain()); + throw new StateChangingError(ConnectionState.BIND_FAILURE); + } + setConnectionAddress(authorizationJid); + final Element bound = success.findChild("bound", Namespace.BIND2); + final Element resumed = success.findChild("resumed", Namespace.STREAM_MANAGEMENT); + final Element failed = success.findChild("failed", Namespace.STREAM_MANAGEMENT); + final Element tokenWrapper = success.findChild("token", Namespace.FAST); + final String token = tokenWrapper == null ? null : tokenWrapper.getAttribute("token"); + if (bound != null && resumed != null) { + Log.d( + Config.LOGTAG, + account.address + ": server sent bound and resumed in SASL2 success"); + throw new StateChangingException(ConnectionState.INCOMPATIBLE_SERVER); + } + final boolean processNopStreamFeatures; + if (resumed != null && streamId != null) { + processResumed(resumed); + } else if (failed != null) { + processFailed(failed, false); // wait for new stream features + } + if (bound != null) { + clearIqCallbacks(); + this.isBound = true; + final Element streamManagementEnabled = + bound.findChild("enabled", Namespace.STREAM_MANAGEMENT); + final Element carbonsEnabled = bound.findChild("enabled", Namespace.CARBONS); + final boolean waitForDisco; + if (streamManagementEnabled != null) { + resetOutboundStanzaQueue(); + processEnabled(streamManagementEnabled); + waitForDisco = true; + } else { + // if we did not enable stream management in bind do it now + waitForDisco = enableStreamManagement(); + } + if (carbonsEnabled != null) { + Log.d(Config.LOGTAG, account.address + ": successfully enabled carbons"); + features.carbonsEnabled = true; + } + sendPostBindInitialization(waitForDisco, carbonsEnabled != null); + processNopStreamFeatures = true; + } else { + processNopStreamFeatures = false; + } + final HashedToken.Mechanism tokenMechanism; + if (SaslMechanism.hashedToken(currentSaslMechanism)) { + tokenMechanism = ((HashedToken) currentSaslMechanism).getTokenMechanism(); + } else if (this.hashTokenRequest != null) { + tokenMechanism = this.hashTokenRequest; + } else { + tokenMechanism = null; + } + if (tokenMechanism != null && !Strings.isNullOrEmpty(token)) { + try { + CredentialStore.getInstance(context) + .setFastToken(account, tokenMechanism, token); + Log.d( + Config.LOGTAG, + account.address + ": storing hashed token " + tokenMechanism); + } catch (final Exception e) { + Log.d(Config.LOGTAG, "could not store fast token", e); + } + } else if (this.hashTokenRequest != null) { + Log.w( + Config.LOGTAG, + account.address + + ": no response to our hashed token request " + + this.hashTokenRequest); + } + // a successful resume will not send stream features + if (processNopStreamFeatures) { + processNopStreamFeatures(); + } + } + this.quickStartInProgress = false; + if (version == SaslMechanism.Version.SASL) { + tagReader.reset(); + sendStartStream(false, true); + final Tag tag = tagReader.readTag(); + if (tag != null && tag.isStart("stream", Namespace.STREAMS)) { + processStream(); + return true; + } else { + throw new StateChangingException(ConnectionState.STREAM_OPENING_ERROR); + } + } else { + return false; + } + } + + private void resetOutboundStanzaQueue() { + synchronized (this.mStanzaQueue) { + final List intermediateStanzas = new ArrayList<>(); + if (Config.EXTENDED_SM_LOGGING) { + Log.d( + Config.LOGTAG, + account.address + + ": stanzas sent before auth: " + + this.stanzasSentBeforeAuthentication); + } + for (int i = this.stanzasSentBeforeAuthentication + 1; i <= this.stanzasSent; ++i) { + final AbstractAcknowledgeableStanza stanza = this.mStanzaQueue.get(i); + if (stanza != null) { + intermediateStanzas.add(stanza); + } + } + this.mStanzaQueue.clear(); + for (int i = 0; i < intermediateStanzas.size(); ++i) { + this.mStanzaQueue.put(i, intermediateStanzas.get(i)); + } + this.stanzasSent = intermediateStanzas.size(); + if (Config.EXTENDED_SM_LOGGING) { + Log.d( + Config.LOGTAG, + account.address + + ": resetting outbound stanza queue to " + + this.stanzasSent); + } + } + } + + private void processNopStreamFeatures() throws IOException { + final Tag tag = tagReader.readTag(); + if (tag != null && tag.isStart("features", Namespace.STREAMS)) { + this.streamFeatures = tagReader.readElement(tag); + Log.d( + Config.LOGTAG, + account.address + + ": processed NOP stream features after success: " + + XmlHelper.printElementNames(this.streamFeatures)); + } else { + Log.d(Config.LOGTAG, account.address + ": received " + tag); + Log.d( + Config.LOGTAG, + account.address + ": server did not send stream features after SASL2 success"); + throw new StateChangingException(ConnectionState.INCOMPATIBLE_SERVER); + } + } + + private void processFailure(final Element failure) throws IOException { + final SaslMechanism.Version version; + try { + version = SaslMechanism.Version.of(failure); + } catch (final IllegalArgumentException e) { + throw new StateChangingException(ConnectionState.INCOMPATIBLE_SERVER); + } + Log.d(Config.LOGTAG, failure.toString()); + Log.d(Config.LOGTAG, account.address + ": login failure " + version); + if (SaslMechanism.hashedToken(this.saslMechanism)) { + Log.d(Config.LOGTAG, account.address + ": resetting token"); + try { + CredentialStore.getInstance(context).resetFastToken(account); + } catch (final Exception e) { + Log.d(Config.LOGTAG, "could not reset fast token in credential store", e); + } + } + if (failure.hasChild("temporary-auth-failure")) { + throw new StateChangingException(ConnectionState.TEMPORARY_AUTH_FAILURE); + } else if (failure.hasChild("account-disabled")) { + final String text = failure.findChildContent("text"); + if (Strings.isNullOrEmpty(text)) { + throw new StateChangingException(ConnectionState.UNAUTHORIZED); + } + final Matcher matcher = Patterns.AUTOLINK_WEB_URL.matcher(text); + if (matcher.find()) { + final HttpUrl url; + try { + url = HttpUrl.get(text.substring(matcher.start(), matcher.end())); + } catch (final IllegalArgumentException e) { + throw new StateChangingException(ConnectionState.UNAUTHORIZED); + } + if (url.isHttps()) { + this.redirectionUrl = url; + throw new StateChangingException(ConnectionState.PAYMENT_REQUIRED); + } + } + } + if (SaslMechanism.hashedToken(this.saslMechanism)) { + Log.d( + Config.LOGTAG, + account.address + + ": fast authentication failed. falling back to regular" + + " authentication"); + authenticate(); + } else { + throw new StateChangingException(ConnectionState.UNAUTHORIZED); + } + } + + private static SSLSocket sslSocketOrNull(final Socket socket) { + if (socket instanceof SSLSocket) { + return (SSLSocket) socket; + } else { + return null; + } + } + + private void processEnabled(final Element enabled) { + final String streamId; + if (enabled.getAttributeAsBoolean("resume")) { + streamId = enabled.getAttribute("id"); + Log.d(Config.LOGTAG, account.address + ": stream management enabled (resumable)"); + } else { + Log.d(Config.LOGTAG, account.address + ": stream management enabled"); + streamId = null; + } + this.streamId = streamId; + this.stanzasReceived = 0; + this.inSmacksSession = true; + final RequestPacket r = new RequestPacket(); + tagWriter.writeStanzaAsync(r); + } + + private void processResumed(final Element resumed) throws StateChangingException { + this.inSmacksSession = true; + this.isBound = true; + this.tagWriter.writeStanzaAsync(new RequestPacket()); + lastPacketReceived = SystemClock.elapsedRealtime(); + final Optional h = resumed.getOptionalIntAttribute("h"); + final int serverCount; + if (h.isPresent()) { + serverCount = h.get(); + } else { + resetStreamId(); + throw new StateChangingException(ConnectionState.INCOMPATIBLE_SERVER); + } + final ArrayList failedStanzas = new ArrayList<>(); + final boolean acknowledgedMessages; + synchronized (this.mStanzaQueue) { + if (serverCount < stanzasSent) { + Log.d(Config.LOGTAG, account.address + ": session resumed with lost packages"); + stanzasSent = serverCount; + } else { + Log.d(Config.LOGTAG, account.address + ": session resumed"); + } + acknowledgedMessages = acknowledgeStanzaUpTo(serverCount); + for (int i = 0; i < this.mStanzaQueue.size(); ++i) { + failedStanzas.add(mStanzaQueue.valueAt(i)); + } + mStanzaQueue.clear(); + } + Log.d(Config.LOGTAG, account.address + ": resending " + failedStanzas.size() + " stanzas"); + for (final AbstractAcknowledgeableStanza packet : failedStanzas) { + if (packet instanceof MessagePacket) { + MessagePacket message = (MessagePacket) packet; + // TODO set ack = false in message table + // context.markMessage(account, message.getTo().asBareJid(), message.getId(), + // Message.STATUS_UNSEND); + } + sendPacket(packet); + } + changeStatusToOnline(); + } + + private void changeStatusToOnline() { + Log.d( + Config.LOGTAG, + account.address + ": online with resource " + connectionAddress.getResource()); + changeStatus(ConnectionState.ONLINE); + } + + private void processFailed(final Element failed, final boolean sendBindRequest) { + final Optional serverCount = failed.getOptionalIntAttribute("h"); + if (serverCount.isPresent()) { + Log.d( + Config.LOGTAG, + account.address + + ": resumption failed but server acknowledged stanza #" + + serverCount.get()); + final boolean acknowledgedMessages; + synchronized (this.mStanzaQueue) { + acknowledgedMessages = acknowledgeStanzaUpTo(serverCount.get()); + } + } else { + Log.d(Config.LOGTAG, account.address + ": resumption failed"); + } + resetStreamId(); + if (sendBindRequest) { + sendBindRequest(); + } + } + + private boolean acknowledgeStanzaUpTo(final int serverCount) { + if (serverCount > stanzasSent) { + Log.e( + Config.LOGTAG, + "server acknowledged more stanzas than we sent. serverCount=" + + serverCount + + ", ourCount=" + + stanzasSent); + } + boolean acknowledgedMessages = false; + for (int i = 0; i < mStanzaQueue.size(); ++i) { + if (serverCount >= mStanzaQueue.keyAt(i)) { + if (Config.EXTENDED_SM_LOGGING) { + Log.d( + Config.LOGTAG, + account.address + + ": server acknowledged stanza #" + + mStanzaQueue.keyAt(i)); + } + final AbstractAcknowledgeableStanza stanza = mStanzaQueue.valueAt(i); + if (stanza instanceof MessagePacket && messageAcknowledgeProcessor != null) { + final MessagePacket packet = (MessagePacket) stanza; + final String id = packet.getId(); + final Jid to = packet.getTo(); + if (id != null && to != null) { + acknowledgedMessages |= messageAcknowledgeProcessor.apply(to, id); + } + } + mStanzaQueue.removeAt(i); + i--; + } + } + return acknowledgedMessages; + } + + private @NonNull Element processPacket(final Tag currentTag, final int packetType) + throws IOException { + final Element element; + switch (packetType) { + case PACKET_IQ: + element = new IqPacket(); + break; + case PACKET_MESSAGE: + element = new MessagePacket(); + break; + case PACKET_PRESENCE: + element = new PresencePacket(); + break; + default: + throw new AssertionError("Should never encounter invalid type"); + } + element.setAttributes(currentTag.getAttributes()); + Tag nextTag = tagReader.readTag(); + if (nextTag == null) { + throw new IOException("interrupted mid tag"); + } + while (!nextTag.isEnd(element.getName())) { + if (!nextTag.isNo()) { + element.addChild(tagReader.readElement(nextTag)); + } + nextTag = tagReader.readTag(); + if (nextTag == null) { + throw new IOException("interrupted mid tag"); + } + } + if (stanzasReceived == Integer.MAX_VALUE) { + resetStreamId(); + throw new IOException("time to restart the session. cant handle >2 billion pcks"); + } + if (inSmacksSession) { + ++stanzasReceived; + } else if (features.sm()) { + Log.d( + Config.LOGTAG, + account.address + + ": not counting stanza(" + + element.getClass().getSimpleName() + + "). Not in smacks session."); + } + lastPacketReceived = SystemClock.elapsedRealtime(); + if (element instanceof IqPacket + && (((IqPacket) element).getType() == IqPacket.TYPE.SET) + && element.hasChild("jingle", Namespace.JINGLE)) { + return JinglePacket.upgrade((IqPacket) element); + } else { + return element; + } + } + + private void processIq(final Tag currentTag) throws IOException { + final IqPacket packet = (IqPacket) processPacket(currentTag, PACKET_IQ); + if (!packet.valid()) { + Log.e( + Config.LOGTAG, + "encountered invalid iq from='" + + packet.getFrom() + + "' to='" + + packet.getTo() + + "'"); + return; + } + if (packet instanceof JinglePacket) { + this.jinglePacketConsumer.accept((JinglePacket) packet); + } else { + final Consumer callback; + synchronized (this.packetCallbacks) { + final Pair> packetCallbackDuple = + packetCallbacks.get(packet.getId()); + if (packetCallbackDuple != null) { + // Packets to the server should have responses from the server + if (toServer(packetCallbackDuple.first)) { + if (fromServer(packet)) { + callback = packetCallbackDuple.second; + packetCallbacks.remove(packet.getId()); + } else { + callback = null; + Log.e(Config.LOGTAG, account.address + ": ignoring spoofed iq packet"); + } + } else { + if (packet.getFrom() != null + && packet.getFrom().equals(packetCallbackDuple.first.getTo())) { + callback = packetCallbackDuple.second; + packetCallbacks.remove(packet.getId()); + } else { + callback = null; + Log.e(Config.LOGTAG, account.address + ": ignoring spoofed iq packet"); + } + } + } else if (packet.getType() == IqPacket.TYPE.GET + || packet.getType() == IqPacket.TYPE.SET) { + callback = this.iqPacketConsumer; + } else { + callback = null; + } + } + if (callback != null) { + try { + callback.accept(packet); + } catch (StateChangingError error) { + throw new StateChangingException(error.state); + } + } + } + } + + private void processMessage(final Tag currentTag) throws IOException { + final MessagePacket packet = (MessagePacket) processPacket(currentTag, PACKET_MESSAGE); + if (!packet.valid()) { + Log.e( + Config.LOGTAG, + "encountered invalid message from='" + + packet.getFrom() + + "' to='" + + packet.getTo() + + "'"); + return; + } + this.messagePacketConsumer.accept(packet); + } + + private void processPresence(final Tag currentTag) throws IOException { + PresencePacket packet = (PresencePacket) processPacket(currentTag, PACKET_PRESENCE); + if (!packet.valid()) { + Log.e( + Config.LOGTAG, + "encountered invalid presence from='" + + packet.getFrom() + + "' to='" + + packet.getTo() + + "'"); + return; + } + this.presencePacketConsumer.accept(packet); + } + + private void sendStartTLS() throws IOException { + final Tag startTLS = Tag.empty("starttls"); + startTLS.setAttribute("xmlns", Namespace.TLS); + tagWriter.writeTag(startTLS); + } + + private void switchOverToTls() throws XmlPullParserException, IOException { + tagReader.readTag(); + final Socket socket = this.socket; + final SSLSocket sslSocket = upgradeSocketToTls(socket); + tagReader.setInputStream(sslSocket.getInputStream()); + tagWriter.setOutputStream(sslSocket.getOutputStream()); + Log.d(Config.LOGTAG, account.address + ": TLS connection established"); + final boolean quickStart; + try { + quickStart = establishStream(SSLSockets.version(sslSocket)); + } catch (final InterruptedException e) { + return; + } + if (quickStart) { + this.quickStartInProgress = true; + } + features.encryptionEnabled = true; + final Tag tag = tagReader.readTag(); + if (tag != null && tag.isStart("stream", Namespace.STREAMS)) { + SSLSockets.log(account.address, sslSocket); + processStream(); + } else { + throw new StateChangingException(ConnectionState.STREAM_OPENING_ERROR); + } + sslSocket.close(); + } + + private SSLSocket upgradeSocketToTls(final Socket socket) throws IOException { + final SSLSocketFactory sslSocketFactory; + try { + sslSocketFactory = getSSLSocketFactory(); + } catch (final NoSuchAlgorithmException | KeyManagementException e) { + throw new StateChangingException(ConnectionState.TLS_ERROR); + } + final InetAddress address = socket.getInetAddress(); + final SSLSocket sslSocket = + (SSLSocket) + sslSocketFactory.createSocket( + socket, address.getHostAddress(), socket.getPort(), true); + SSLSockets.setSecurity(sslSocket); + SSLSockets.setHostname( + sslSocket, IDN.toASCII(account.address.getDomain().toEscapedString())); + SSLSockets.setApplicationProtocol(sslSocket, "xmpp-client"); + final XmppDomainVerifier xmppDomainVerifier = new XmppDomainVerifier(); + try { + if (!xmppDomainVerifier.verify( + account.address.getDomain().toEscapedString(), + this.verifiedHostname, + sslSocket.getSession())) { + Log.d( + Config.LOGTAG, + account.address + ": TLS certificate domain verification failed"); + FileBackend.close(sslSocket); + throw new StateChangingException(ConnectionState.TLS_ERROR_DOMAIN); + } + } catch (final SSLPeerUnverifiedException e) { + FileBackend.close(sslSocket); + throw new StateChangingException(ConnectionState.TLS_ERROR); + } + return sslSocket; + } + + private void processStreamFeatures(final Tag currentTag) throws IOException { + final boolean pendingRegistration = + ConversationsDatabase.getInstance(context) + .accountDao() + .pendingRegistration(account.id); + this.streamFeatures = tagReader.readElement(currentTag); + final boolean isSecure = isSecure(); + final boolean needsBinding = !isBound && !pendingRegistration; + if (this.quickStartInProgress) { + if (this.streamFeatures.hasChild("authentication", Namespace.SASL_2)) { + Log.d( + Config.LOGTAG, + account.address + + ": quick start in progress. ignoring features: " + + XmlHelper.printElementNames(this.streamFeatures)); + if (SaslMechanism.hashedToken(this.saslMechanism)) { + return; + } + if (isFastTokenAvailable( + this.streamFeatures.findChild("authentication", Namespace.SASL_2))) { + Log.d( + Config.LOGTAG, + account.address + ": fast token available; resetting quick start"); + ConversationsDatabase.getInstance(context) + .accountDao() + .setQuickStartAvailable(account.id, false); + } + return; + } + Log.d( + Config.LOGTAG, + account.address + ": server lost support for SASL 2. quick start not possible"); + ConversationsDatabase.getInstance(context) + .accountDao() + .setQuickStartAvailable(account.id, false); + throw new StateChangingException(ConnectionState.INCOMPATIBLE_SERVER); + } + if (this.streamFeatures.hasChild("starttls", Namespace.TLS) + && !features.encryptionEnabled) { + sendStartTLS(); + } else if (this.streamFeatures.hasChild("register", Namespace.REGISTER_STREAM_FEATURE) + && pendingRegistration) { + if (isSecure) { + register(); + } else { + Log.d( + Config.LOGTAG, + account.address + + ": unable to find STARTTLS for registration process " + + XmlHelper.printElementNames(this.streamFeatures)); + throw new StateChangingException(ConnectionState.INCOMPATIBLE_SERVER); + } + } else if (!this.streamFeatures.hasChild("register", Namespace.REGISTER_STREAM_FEATURE) + && pendingRegistration) { + throw new StateChangingException(ConnectionState.REGISTRATION_NOT_SUPPORTED); + } else if (this.streamFeatures.hasChild("authentication", Namespace.SASL_2) + && shouldAuthenticate + && isSecure) { + authenticate(SaslMechanism.Version.SASL_2); + } else if (this.streamFeatures.hasChild("mechanisms", Namespace.SASL) + && shouldAuthenticate + && isSecure) { + authenticate(SaslMechanism.Version.SASL); + } else if (this.streamFeatures.hasChild("sm", Namespace.STREAM_MANAGEMENT) + && streamId != null + && !inSmacksSession) { + if (Config.EXTENDED_SM_LOGGING) { + Log.d( + Config.LOGTAG, + account.address + ": resuming after stanza #" + stanzasReceived); + } + final ResumePacket resume = new ResumePacket(this.streamId, stanzasReceived); + this.mSmCatchupMessageCounter.set(0); + this.mWaitingForSmCatchup.set(true); + this.tagWriter.writeStanzaAsync(resume); + } else if (needsBinding) { + if (this.streamFeatures.hasChild("bind", Namespace.BIND) && isSecure) { + sendBindRequest(); + } else { + Log.d( + Config.LOGTAG, + account.address + + ": unable to find bind feature " + + XmlHelper.printElementNames(this.streamFeatures)); + throw new StateChangingException(ConnectionState.INCOMPATIBLE_SERVER); + } + } else { + Log.d( + Config.LOGTAG, + account.address + + ": received NOP stream features: " + + XmlHelper.printElementNames(this.streamFeatures)); + } + } + + private void authenticate() throws IOException { + final boolean isSecure = isSecure(); + if (isSecure && this.streamFeatures.hasChild("authentication", Namespace.SASL_2)) { + authenticate(SaslMechanism.Version.SASL_2); + } else if (isSecure && this.streamFeatures.hasChild("mechanisms", Namespace.SASL)) { + authenticate(SaslMechanism.Version.SASL); + } else { + throw new StateChangingException(ConnectionState.INCOMPATIBLE_SERVER); + } + } + + private boolean isSecure() { + return features.encryptionEnabled || Config.ALLOW_NON_TLS_CONNECTIONS || account.isOnion(); + } + + private void authenticate(final SaslMechanism.Version version) throws IOException { + final Element authElement; + if (version == SaslMechanism.Version.SASL) { + authElement = this.streamFeatures.findChild("mechanisms", Namespace.SASL); + } else { + authElement = this.streamFeatures.findChild("authentication", Namespace.SASL_2); + } + final Collection mechanisms = SaslMechanism.mechanisms(authElement); + final Element cbElement = + this.streamFeatures.findChild("sasl-channel-binding", Namespace.CHANNEL_BINDING); + final Collection channelBindings = ChannelBinding.of(cbElement); + final SaslMechanism.Factory saslFactory = + new SaslMechanism.Factory( + account, CredentialStore.getInstance(context).get(account)); + final SaslMechanism saslMechanism = + saslFactory.of( + mechanisms, channelBindings, version, SSLSockets.version(this.socket)); + this.saslMechanism = validate(saslMechanism, mechanisms); + final boolean quickStartAvailable; + final String firstMessage = + this.saslMechanism.getClientFirstMessage(sslSocketOrNull(this.socket)); + final boolean usingFast = SaslMechanism.hashedToken(this.saslMechanism); + final Element authenticate; + if (version == SaslMechanism.Version.SASL) { + authenticate = new Element("auth", Namespace.SASL); + if (!Strings.isNullOrEmpty(firstMessage)) { + authenticate.setContent(firstMessage); + } + quickStartAvailable = false; + } else if (version == SaslMechanism.Version.SASL_2) { + final Element inline = authElement.findChild("inline", Namespace.SASL_2); + final boolean sm = inline != null && inline.hasChild("sm", Namespace.STREAM_MANAGEMENT); + final HashedToken.Mechanism hashTokenRequest; + if (usingFast) { + hashTokenRequest = null; + } else { + final Element fast = + inline == null ? null : inline.findChild("fast", Namespace.FAST); + final Collection fastMechanisms = SaslMechanism.mechanisms(fast); + hashTokenRequest = + HashedToken.Mechanism.best(fastMechanisms, SSLSockets.version(this.socket)); + } + final Collection bindFeatures = Bind2.features(inline); + quickStartAvailable = + sm + && bindFeatures != null + && bindFeatures.containsAll(Bind2.QUICKSTART_FEATURES); + this.hashTokenRequest = hashTokenRequest; + authenticate = + generateAuthenticationRequest( + firstMessage, usingFast, hashTokenRequest, bindFeatures, sm); + } else { + throw new AssertionError("Missing implementation for " + version); + } + + ConversationsDatabase.getInstance(context) + .accountDao() + .setQuickStartAvailable(account.id, quickStartAvailable); + + Log.d( + Config.LOGTAG, + account.address + + ": Authenticating with " + + version + + "/" + + this.saslMechanism.getMechanism()); + authenticate.setAttribute("mechanism", this.saslMechanism.getMechanism()); + synchronized (this.mStanzaQueue) { + this.stanzasSentBeforeAuthentication = this.stanzasSent; + tagWriter.writeElement(authenticate); + } + } + + private static boolean isFastTokenAvailable(final Element authentication) { + final Element inline = authentication == null ? null : authentication.findChild("inline"); + return inline != null && inline.hasChild("fast", Namespace.FAST); + } + + @NonNull + private SaslMechanism validate( + final @Nullable SaslMechanism saslMechanism, Collection mechanisms) + throws StateChangingException { + if (saslMechanism == null) { + Log.d( + Config.LOGTAG, + account.address + ": unable to find supported SASL mechanism in " + mechanisms); + throw new StateChangingException(ConnectionState.INCOMPATIBLE_SERVER); + } + if (SaslMechanism.hashedToken(saslMechanism)) { + return saslMechanism; + } + final SaslMechanism.Factory saslFactory = + new SaslMechanism.Factory( + account, CredentialStore.getInstance(context).get(account)); + final int pinnedMechanism = saslFactory.getPinnedMechanismPriority(); + if (pinnedMechanism > saslMechanism.getPriority()) { + Log.e( + Config.LOGTAG, + "Auth failed. Authentication mechanism " + + saslMechanism.getMechanism() + + " has lower priority (" + + saslMechanism.getPriority() + + ") than pinned priority (" + + pinnedMechanism + + "). Possible downgrade attack?"); + throw new StateChangingException(ConnectionState.DOWNGRADE_ATTACK); + } + return saslMechanism; + } + + private Element generateAuthenticationRequest( + final String firstMessage, final boolean usingFast) { + return generateAuthenticationRequest( + firstMessage, usingFast, null, Bind2.QUICKSTART_FEATURES, true); + } + + private Element generateAuthenticationRequest( + final String firstMessage, + final boolean usingFast, + final HashedToken.Mechanism hashedTokenRequest, + final Collection bind, + final boolean inlineStreamManagement) { + final Element authenticate = new Element("authenticate", Namespace.SASL_2); + if (!Strings.isNullOrEmpty(firstMessage)) { + authenticate.addChild("initial-response").setContent(firstMessage); + } + final Element userAgent = authenticate.addChild("user-agent"); + userAgent.setAttribute("id", account.getPublicDeviceId().toString()); + userAgent.addChild("software").setContent(context.getString(R.string.app_name)); + if (!PhoneHelper.isEmulator()) { + userAgent + .addChild("device") + .setContent(String.format("%s %s", Build.MANUFACTURER, Build.MODEL)); + } + if (bind != null) { + authenticate.addChild(generateBindRequest(bind)); + } + if (inlineStreamManagement && streamId != null) { + final ResumePacket resume = new ResumePacket(this.streamId, stanzasReceived); + this.mSmCatchupMessageCounter.set(0); + this.mWaitingForSmCatchup.set(true); + authenticate.addChild(resume); + } + if (hashedTokenRequest != null) { + authenticate + .addChild("request-token", Namespace.FAST) + .setAttribute("mechanism", hashedTokenRequest.name()); + } + if (usingFast) { + authenticate.addChild("fast", Namespace.FAST); + } + return authenticate; + } + + private Element generateBindRequest(final Collection bindFeatures) { + Log.d(Config.LOGTAG, "inline bind features: " + bindFeatures); + final Element bind = new Element("bind", Namespace.BIND2); + bind.addChild("tag").setContent(context.getString(R.string.app_name)); + if (bindFeatures.contains(Namespace.CARBONS)) { + bind.addChild("enable", Namespace.CARBONS); + } + if (bindFeatures.contains(Namespace.STREAM_MANAGEMENT)) { + bind.addChild(new EnablePacket()); + } + return bind; + } + + private void register() { + final Credential credential = CredentialStore.getInstance(context).get(account); + final String preAuth = credential.preAuthRegistrationToken; + if (Strings.isNullOrEmpty(preAuth) || !features.invite()) { + sendRegistryRequest(); + return; + } + final IqPacket preAuthRequest = new IqPacket(IqPacket.TYPE.SET); + preAuthRequest.addChild("preauth", Namespace.PARS).setAttribute("token", preAuth); + sendUnmodifiedIqPacket( + preAuthRequest, + (response) -> { + if (response.getType() == IqPacket.TYPE.RESULT) { + sendRegistryRequest(); + } else { + final String error = response.getErrorCondition(); + Log.d(Config.LOGTAG, account.address + ": failed to pre auth. " + error); + throw new StateChangingError(ConnectionState.REGISTRATION_INVALID_TOKEN); + } + }, + true); + } + + private void sendRegistryRequest() { + final IqPacket register = new IqPacket(IqPacket.TYPE.GET); + register.query(Namespace.REGISTER); + register.setTo(account.address.getDomain()); + sendUnmodifiedIqPacket( + register, + (packet) -> { + if (packet.getType() == IqPacket.TYPE.TIMEOUT) { + return; + } + if (packet.getType() == IqPacket.TYPE.ERROR) { + throw new StateChangingError(ConnectionState.REGISTRATION_FAILED); + } + final Element query = packet.query(Namespace.REGISTER); + if (query.hasChild("username") && (query.hasChild("password"))) { + final Credential credential = + CredentialStore.getInstance(context).get(account); + final IqPacket register1 = new IqPacket(IqPacket.TYPE.SET); + final Element username = + new Element("username") + .setContent(account.address.getEscapedLocal()); + final Element password = + new Element("password").setContent(credential.password); + register1.query(Namespace.REGISTER).addChild(username); + register1.query().addChild(password); + register1.setFrom(account.address); + sendUnmodifiedIqPacket(register1, this::handleRegistrationResponse, true); + } else if (query.hasChild("x", Namespace.DATA)) { + final Data data = Data.parse(query.findChild("x", Namespace.DATA)); + final Element blob = query.findChild("data", "urn:xmpp:bob"); + final String id = packet.getId(); + InputStream is; + if (blob != null) { + try { + final String base64Blob = blob.getContent(); + final byte[] strBlob = Base64.decode(base64Blob, Base64.DEFAULT); + is = new ByteArrayInputStream(strBlob); + } catch (Exception e) { + is = null; + } + } else { + // TODO this too needs fixing + final boolean useTor = /*context.useTorToConnect() ||*/ + account.isOnion(); + try { + final String url = data.getValue("url"); + final String fallbackUrl = data.getValue("captcha-fallback-url"); + if (url != null) { + is = HttpConnectionManager.open(url, useTor); + } else if (fallbackUrl != null) { + is = HttpConnectionManager.open(fallbackUrl, useTor); + } else { + is = null; + } + } catch (final IOException e) { + Log.d( + Config.LOGTAG, + account.address + ": unable to fetch captcha", + e); + is = null; + } + } + + if (is != null) { + Bitmap captcha = BitmapFactory.decodeStream(is); + throw new StateChangingError(ConnectionState.REGISTRATION_FAILED); + } + throw new StateChangingError(ConnectionState.REGISTRATION_FAILED); + } else if (query.hasChild("instructions") + || query.hasChild("x", Namespace.OOB)) { + final String instructions = query.findChildContent("instructions"); + final Element oob = query.findChild("x", Namespace.OOB); + final String url = oob == null ? null : oob.findChildContent("url"); + if (url != null) { + setAccountCreationFailed(url); + } else if (instructions != null) { + final Matcher matcher = Patterns.AUTOLINK_WEB_URL.matcher(instructions); + if (matcher.find()) { + setAccountCreationFailed( + instructions.substring(matcher.start(), matcher.end())); + } + } + throw new StateChangingError(ConnectionState.REGISTRATION_FAILED); + } + }, + true); + } + + private void handleRegistrationResponse(final IqPacket packet) { + if (packet.getType() == IqPacket.TYPE.RESULT) { + ConversationsDatabase.getInstance(context) + .accountDao() + .setPendingRegistration(account.id, false); + Log.d( + Config.LOGTAG, + account.address + ": successfully registered new account on server"); + throw new StateChangingError(ConnectionState.REGISTRATION_SUCCESSFUL); + } else { + final List PASSWORD_TOO_WEAK_MSGS = + Arrays.asList("The password is too weak", "Please use a longer password."); + Element error = packet.findChild("error"); + ConnectionState state = ConnectionState.REGISTRATION_FAILED; + if (error != null) { + if (error.hasChild("conflict")) { + state = ConnectionState.REGISTRATION_CONFLICT; + } else if (error.hasChild("resource-constraint") + && "wait".equals(error.getAttribute("type"))) { + state = ConnectionState.REGISTRATION_PLEASE_WAIT; + } else if (error.hasChild("not-acceptable") + && PASSWORD_TOO_WEAK_MSGS.contains(error.findChildContent("text"))) { + state = ConnectionState.REGISTRATION_PASSWORD_TOO_WEAK; + } + } + throw new StateChangingError(state); + } + } + + private void setAccountCreationFailed(final String url) { + final HttpUrl httpUrl = url == null ? null : HttpUrl.parse(url); + if (httpUrl != null && httpUrl.isHttps()) { + this.redirectionUrl = httpUrl; + throw new StateChangingError(ConnectionState.REGISTRATION_WEB); + } + throw new StateChangingError(ConnectionState.REGISTRATION_FAILED); + } + + public HttpUrl getRedirectionUrl() { + return this.redirectionUrl; + } + + public void resetEverything() { + resetAttemptCount(true); + resetStreamId(); + clearIqCallbacks(); + this.stanzasSent = 0; + mStanzaQueue.clear(); + this.redirectionUrl = null; + synchronized (this.disco) { + disco.clear(); + } + synchronized (this.commands) { + this.commands.clear(); + } + this.saslMechanism = null; + } + + private void sendBindRequest() { + clearIqCallbacks(); + final String recentResource = + fixResource( + ConversationsDatabase.getInstance(context) + .accountDao() + .getResource(account.id)); + final String resource; + if (recentResource != null) { + resource = recentResource; + } else { + resource = this.createNewResource(account.randomSeed); + } + final IqPacket iq = new IqPacket(IqPacket.TYPE.SET); + iq.addChild("bind", Namespace.BIND).addChild("resource").setContent(resource); + this.sendUnmodifiedIqPacket( + iq, + (packet) -> { + if (packet.getType() == IqPacket.TYPE.TIMEOUT) { + return; + } + final Element bind = packet.findChild("bind"); + if (bind != null && packet.getType() == IqPacket.TYPE.RESULT) { + isBound = true; + final Element jid = bind.findChild("jid"); + if (jid != null && jid.getContent() != null) { + try { + final Jid assignedJid = Jid.ofEscaped(jid.getContent()); + if (!account.address.getDomain().equals(assignedJid.getDomain())) { + Log.d( + Config.LOGTAG, + account.address + + ": server tried to re-assign domain to " + + assignedJid.getDomain()); + throw new StateChangingError(ConnectionState.BIND_FAILURE); + } + setConnectionAddress(assignedJid); + if (streamFeatures.hasChild("session") + && !streamFeatures + .findChild("session") + .hasChild("optional")) { + sendStartSession(); + } else { + final boolean waitForDisco = enableStreamManagement(); + sendPostBindInitialization(waitForDisco, false); + } + return; + } catch (final IllegalArgumentException e) { + Log.d( + Config.LOGTAG, + account.address + + ": server reported invalid jid (" + + jid.getContent() + + ") on bind"); + } + } else { + Log.d( + Config.LOGTAG, + account.address + + ": disconnecting because of bind failure. (no jid)"); + } + } else { + Log.d( + Config.LOGTAG, + account.address + + ": disconnecting because of bind failure (" + + packet); + } + final Element error = packet.findChild("error"); + if (packet.getType() == IqPacket.TYPE.ERROR + && error != null + && error.hasChild("conflict")) { + final String alternativeResource = + createNewResource(SECURE_RANDOM.generateSeed(3)); + ConversationsDatabase.getInstance(context) + .accountDao() + .setResource(account.id, alternativeResource); + Log.d( + Config.LOGTAG, + account.address + + ": switching resource due to conflict (" + + alternativeResource + + ")"); + } + throw new StateChangingError(ConnectionState.BIND_FAILURE); + }, + true); + } + + private void setConnectionAddress(final Jid jid) { + this.connectionAddress = jid; + } + + private void clearIqCallbacks() { + final IqPacket failurePacket = new IqPacket(IqPacket.TYPE.TIMEOUT); + final ArrayList> callbacks = new ArrayList<>(); + synchronized (this.packetCallbacks) { + if (this.packetCallbacks.size() == 0) { + return; + } + Log.d( + Config.LOGTAG, + account.address + + ": clearing " + + this.packetCallbacks.size() + + " iq callbacks"); + final Iterator>> iterator = + this.packetCallbacks.values().iterator(); + while (iterator.hasNext()) { + Pair> entry = iterator.next(); + callbacks.add(entry.second); + iterator.remove(); + } + } + for (final Consumer callback : callbacks) { + try { + callback.accept(failurePacket); + } catch (StateChangingError error) { + Log.d( + Config.LOGTAG, + account.address + + ": caught StateChangingError(" + + error.state.toString() + + ") while clearing callbacks"); + // ignore + } + } + Log.d( + Config.LOGTAG, + account.address + + ": done clearing iq callbacks. " + + this.packetCallbacks.size() + + " left"); + } + + public void sendDiscoTimeout() { + if (mWaitForDisco.compareAndSet(true, false)) { + Log.d(Config.LOGTAG, account.address + ": finalizing bind after disco timeout"); + finalizeBind(); + } + } + + private void sendStartSession() { + Log.d(Config.LOGTAG, account.address + ": sending legacy session to outdated server"); + final IqPacket startSession = new IqPacket(IqPacket.TYPE.SET); + startSession.addChild("session", "urn:ietf:params:xml:ns:xmpp-session"); + this.sendUnmodifiedIqPacket( + startSession, + (packet) -> { + if (packet.getType() == IqPacket.TYPE.RESULT) { + final boolean waitForDisco = enableStreamManagement(); + sendPostBindInitialization(waitForDisco, false); + } else if (packet.getType() != IqPacket.TYPE.TIMEOUT) { + throw new StateChangingError(ConnectionState.SESSION_FAILURE); + } + }, + true); + } + + private boolean enableStreamManagement() { + final boolean streamManagement = + this.streamFeatures.hasChild("sm", Namespace.STREAM_MANAGEMENT); + if (streamManagement) { + synchronized (this.mStanzaQueue) { + final EnablePacket enable = new EnablePacket(); + tagWriter.writeStanzaAsync(enable); + stanzasSent = 0; + mStanzaQueue.clear(); + } + return true; + } else { + return false; + } + } + + private void sendPostBindInitialization( + final boolean waitForDisco, final boolean carbonsEnabled) { + features.carbonsEnabled = carbonsEnabled; + features.blockListRequested = false; + synchronized (this.disco) { + this.disco.clear(); + } + Log.d(Config.LOGTAG, account.address + ": starting service discovery"); + mPendingServiceDiscoveries.set(0); + mWaitForDisco.set(waitForDisco); + lastDiscoStarted = SystemClock.elapsedRealtime(); + // TODO bring back disco timeout + // context.scheduleWakeUpCall(Config.CONNECT_DISCO_TIMEOUT, account); + final Element caps = streamFeatures.findChild("c"); + final String hash = caps == null ? null : caps.getAttribute("hash"); + final String ver = caps == null ? null : caps.getAttribute("ver"); + ServiceDiscoveryResult discoveryResult = null; + if (hash != null && ver != null) { + // Bring back disco result caching + discoveryResult = + null; // context.getCachedServiceDiscoveryResult(new Pair<>(hash, ver)); + } + // TODO from an older git commit "should make initial connect faster because code is not + // waiting for omemo code to run" - do we need to keep this? + final boolean requestDiscoItemsFirst = + !ConversationsDatabase.getInstance(context).accountDao().isInitialLogin(account.id); + + if (requestDiscoItemsFirst) { + sendServiceDiscoveryItems(account.address.getDomain()); + } + if (discoveryResult == null) { + sendServiceDiscoveryInfo(account.address.getDomain()); + } else { + Log.d(Config.LOGTAG, account.address + ": server caps came from cache"); + disco.put(account.address.getDomain(), discoveryResult); + } + discoverMamPreferences(); + sendServiceDiscoveryInfo(account.address); + if (!requestDiscoItemsFirst) { + sendServiceDiscoveryItems(account.address.getDomain()); + } + + if (!mWaitForDisco.get()) { + finalizeBind(); + } + this.lastSessionStarted = SystemClock.elapsedRealtime(); + } + + // TODO rename to getConnectionState + public ConnectionState getStatus() { + return this.connectionState; + } + + private void sendServiceDiscoveryInfo(final Jid jid) { + mPendingServiceDiscoveries.incrementAndGet(); + final IqPacket iq = new IqPacket(IqPacket.TYPE.GET); + iq.setTo(jid); + iq.query("http://jabber.org/protocol/disco#info"); + this.sendIqPacket( + iq, + (packet) -> { + if (packet.getType() == IqPacket.TYPE.RESULT) { + boolean advancedStreamFeaturesLoaded; + synchronized (XmppConnection.this.disco) { + ServiceDiscoveryResult result = new ServiceDiscoveryResult(packet); + if (jid.equals(account.address.getDomain())) { + // context.databaseBackend.insertDiscoveryResult(result); + } + disco.put(jid, result); + advancedStreamFeaturesLoaded = + disco.containsKey(account.address.getDomain()) + && disco.containsKey(account.address); + } + if (advancedStreamFeaturesLoaded + && (jid.equals(account.address.getDomain()) + || jid.equals(account.address))) { + enableAdvancedStreamFeatures(); + } + } else if (packet.getType() == IqPacket.TYPE.ERROR) { + Log.d( + Config.LOGTAG, + account.address + + ": could not query disco info for " + + jid.toString()); + final boolean serverOrAccount = + jid.equals(account.address.getDomain()) + || jid.equals(account.address); + final boolean advancedStreamFeaturesLoaded; + if (serverOrAccount) { + synchronized (XmppConnection.this.disco) { + disco.put(jid, ServiceDiscoveryResult.empty()); + advancedStreamFeaturesLoaded = + disco.containsKey(account.address.getDomain()) + && disco.containsKey(account.address); + } + } else { + advancedStreamFeaturesLoaded = false; + } + if (advancedStreamFeaturesLoaded) { + enableAdvancedStreamFeatures(); + } + } + if (packet.getType() != IqPacket.TYPE.TIMEOUT) { + if (mPendingServiceDiscoveries.decrementAndGet() == 0 + && mWaitForDisco.compareAndSet(true, false)) { + finalizeBind(); + } + } + }); + } + + private void discoverMamPreferences() { + IqPacket request = new IqPacket(IqPacket.TYPE.GET); + request.addChild("prefs", MessageArchiveService.Version.MAM_2.namespace); + sendIqPacket( + request, + (response) -> { + if (response.getType() == IqPacket.TYPE.RESULT) { + Element prefs = + response.findChild( + "prefs", MessageArchiveService.Version.MAM_2.namespace); + isMamPreferenceAlways = + "always" + .equals( + prefs == null + ? null + : prefs.getAttribute("default")); + } + }); + } + + private void discoverCommands() { + final IqPacket request = new IqPacket(IqPacket.TYPE.GET); + request.setTo(account.address.getDomain()); + request.addChild("query", Namespace.DISCO_ITEMS).setAttribute("node", Namespace.COMMANDS); + sendIqPacket( + request, + (response) -> { + if (response.getType() == IqPacket.TYPE.RESULT) { + final Element query = response.findChild("query", Namespace.DISCO_ITEMS); + if (query == null) { + return; + } + final HashMap commands = new HashMap<>(); + for (final Element child : query.getChildren()) { + if ("item".equals(child.getName())) { + final String node = child.getAttribute("node"); + final Jid jid = child.getAttributeAsJid("jid"); + if (node != null && jid != null) { + commands.put(node, jid); + } + } + } + synchronized (this.commands) { + this.commands.clear(); + this.commands.putAll(commands); + } + } + }); + } + + public boolean isMamPreferenceAlways() { + return isMamPreferenceAlways; + } + + private void finalizeBind() { + this.bindConsumer.accept(this.connectionAddress); + this.changeStatusToOnline(); + } + + private void enableAdvancedStreamFeatures() { + if (getFeatures().blocking() && !features.blockListRequested) { + Log.d(Config.LOGTAG, account.address + ": Requesting block list"); + // TODO actually request block list + /*this.sendIqPacket( + getIqGenerator().generateGetBlockList(), context.getIqParser());*/ + } + if (getFeatures().carbons() && !features.carbonsEnabled) { + sendEnableCarbons(); + } + if (getFeatures().commands()) { + discoverCommands(); + } + } + + private void sendServiceDiscoveryItems(final Jid server) { + mPendingServiceDiscoveries.incrementAndGet(); + final IqPacket iq = new IqPacket(IqPacket.TYPE.GET); + iq.setTo(server.getDomain()); + iq.query("http://jabber.org/protocol/disco#items"); + this.sendIqPacket( + iq, + (packet) -> { + if (packet.getType() == IqPacket.TYPE.RESULT) { + final HashSet items = new HashSet<>(); + final List elements = packet.query().getChildren(); + for (final Element element : elements) { + if (element.getName().equals("item")) { + final Jid jid = + InvalidJid.getNullForInvalid( + element.getAttributeAsJid("jid")); + if (jid != null && !jid.equals(account.address.getDomain())) { + items.add(jid); + } + } + } + for (Jid jid : items) { + sendServiceDiscoveryInfo(jid); + } + } else { + Log.d( + Config.LOGTAG, + account.address + ": could not query disco items of " + server); + } + if (packet.getType() != IqPacket.TYPE.TIMEOUT) { + if (mPendingServiceDiscoveries.decrementAndGet() == 0 + && mWaitForDisco.compareAndSet(true, false)) { + finalizeBind(); + } + } + }); + } + + private void sendEnableCarbons() { + final IqPacket iq = new IqPacket(IqPacket.TYPE.SET); + iq.addChild("enable", Namespace.CARBONS); + this.sendIqPacket( + iq, + (packet) -> { + if (packet.getType() == IqPacket.TYPE.RESULT) { + Log.d(Config.LOGTAG, account.address + ": successfully enabled carbons"); + features.carbonsEnabled = true; + } else { + Log.d( + Config.LOGTAG, + account.address + ": could not enable carbons " + packet); + } + }); + } + + private void processStreamError(final Tag currentTag) throws IOException { + final Element streamError = tagReader.readElement(currentTag); + if (streamError == null) { + return; + } + if (streamError.hasChild("conflict")) { + final String alternativeResource = createNewResource(SECURE_RANDOM.generateSeed(3)); + ConversationsDatabase.getInstance(context) + .accountDao() + .setResource(account.id, alternativeResource); + Log.d( + Config.LOGTAG, + account.address + + ": switching resource due to conflict (" + + alternativeResource + + ")"); + throw new IOException(); + } else if (streamError.hasChild("host-unknown")) { + throw new StateChangingException(ConnectionState.HOST_UNKNOWN); + } else if (streamError.hasChild("policy-violation")) { + this.lastConnect = SystemClock.elapsedRealtime(); + final String text = streamError.findChildContent("text"); + Log.d(Config.LOGTAG, account.address + ": policy violation. " + text); + failPendingMessages(text); + throw new StateChangingException(ConnectionState.POLICY_VIOLATION); + } else { + Log.d(Config.LOGTAG, account.address + ": stream error " + streamError); + throw new StateChangingException(ConnectionState.STREAM_ERROR); + } + } + + private void failPendingMessages(final String error) { + synchronized (this.mStanzaQueue) { + for (int i = 0; i < mStanzaQueue.size(); ++i) { + final AbstractAcknowledgeableStanza stanza = mStanzaQueue.valueAt(i); + if (stanza instanceof MessagePacket) { + final MessagePacket packet = (MessagePacket) stanza; + final String id = packet.getId(); + final Jid to = packet.getTo(); + // TODO set ack=true but add error? + // TODO the intent was clearly to stop resending + // context.markMessage(account, to.asBareJid(), id, Message.STATUS_SEND_FAILED, + // error); + } + } + } + } + + private boolean establishStream(final SSLSockets.Version sslVersion) + throws IOException, InterruptedException { + final SaslMechanism.Factory saslFactory = + new SaslMechanism.Factory( + account, CredentialStore.getInstance(context).get(account)); + final SaslMechanism quickStartMechanism = + SaslMechanism.ensureAvailable(saslFactory.getQuickStartMechanism(), sslVersion); + final boolean secureConnection = sslVersion != SSLSockets.Version.NONE; + if (secureConnection + && Config.QUICKSTART_ENABLED + && quickStartMechanism != null + && ConversationsDatabase.getInstance(context) + .accountDao() + .quickStartAvailable(account.id)) { + // context.restoredFromDatabaseLatch.await(); + this.saslMechanism = quickStartMechanism; + final boolean usingFast = quickStartMechanism instanceof HashedToken; + final Element authenticate = + generateAuthenticationRequest( + quickStartMechanism.getClientFirstMessage(sslSocketOrNull(this.socket)), + usingFast); + authenticate.setAttribute("mechanism", quickStartMechanism.getMechanism()); + sendStartStream(true, false); + synchronized (this.mStanzaQueue) { + this.stanzasSentBeforeAuthentication = this.stanzasSent; + tagWriter.writeElement(authenticate); + } + Log.d( + Config.LOGTAG, + account.address + ": quick start with " + quickStartMechanism.getMechanism()); + return true; + } else { + sendStartStream(secureConnection, true); + return false; + } + } + + private void sendStartStream(final boolean from, final boolean flush) throws IOException { + final Tag stream = Tag.start("stream:stream"); + stream.setAttribute("to", account.address.getDomain()); + if (from) { + stream.setAttribute("from", account.address); + } + stream.setAttribute("version", "1.0"); + stream.setAttribute("xml:lang", LocalizedContent.STREAM_LANGUAGE); + stream.setAttribute("xmlns", "jabber:client"); + stream.setAttribute("xmlns:stream", Namespace.STREAMS); + tagWriter.writeTag(stream, flush); + } + + private String createNewResource(final byte[] random) { + return String.format( + "%s.%s", + context.getString(R.string.app_name), + Base64.encodeToString( + slice(random), Base64.NO_PADDING | Base64.NO_WRAP | Base64.URL_SAFE)); + } + + private static byte[] slice(final byte[] input) { + if (input == null || input.length < 3) { + return new byte[3]; + } + try { + return ByteSource.wrap(input).slice(0, 3).read(); + } catch (final IOException e) { + return new byte[3]; + } + } + + private String nextRandomId() { + return nextRandomId(false); + } + + private String nextRandomId(final boolean s) { + return CryptoHelper.random(s ? 3 : 9); + } + + public String sendIqPacket(final IqPacket packet, final Consumer callback) { + packet.setFrom(account.address); + return this.sendUnmodifiedIqPacket(packet, callback, false); + } + + public synchronized String sendUnmodifiedIqPacket( + final IqPacket packet, final Consumer callback, boolean force) { + if (packet.getId() == null) { + packet.setAttribute("id", nextRandomId()); + } + if (callback != null) { + synchronized (this.packetCallbacks) { + packetCallbacks.put(packet.getId(), new Pair<>(packet, callback)); + } + } + this.sendPacket(packet, force); + return packet.getId(); + } + + public void sendMessagePacket(final MessagePacket packet) { + this.sendPacket(packet); + } + + public void sendPresencePacket(final PresencePacket packet) { + this.sendPacket(packet); + } + + private synchronized void sendPacket(final AbstractStanza packet) { + sendPacket(packet, false); + } + + private synchronized void sendPacket(final AbstractStanza packet, final boolean force) { + if (stanzasSent == Integer.MAX_VALUE) { + resetStreamId(); + disconnect(true); + return; + } + synchronized (this.mStanzaQueue) { + if (force || isBound) { + tagWriter.writeStanzaAsync(packet); + } else { + Log.d( + Config.LOGTAG, + account.address + + " do not write stanza to unbound stream " + + packet.toString()); + } + if (packet instanceof AbstractAcknowledgeableStanza) { + AbstractAcknowledgeableStanza stanza = (AbstractAcknowledgeableStanza) packet; + + if (this.mStanzaQueue.size() != 0) { + int currentHighestKey = this.mStanzaQueue.keyAt(this.mStanzaQueue.size() - 1); + if (currentHighestKey != stanzasSent) { + throw new AssertionError("Stanza count messed up"); + } + } + + ++stanzasSent; + if (Config.EXTENDED_SM_LOGGING) { + Log.d( + Config.LOGTAG, + account.address + + ": counting outbound " + + packet.getName() + + " as #" + + stanzasSent); + } + this.mStanzaQueue.append(stanzasSent, stanza); + if (stanza instanceof MessagePacket && stanza.getId() != null && inSmacksSession) { + if (Config.EXTENDED_SM_LOGGING) { + Log.d( + Config.LOGTAG, + account.address + + ": requesting ack for message stanza #" + + stanzasSent); + } + tagWriter.writeStanzaAsync(new RequestPacket()); + } + } + } + } + + public void sendPing() { + if (!r()) { + final IqPacket iq = new IqPacket(IqPacket.TYPE.GET); + iq.setFrom(account.address); + iq.addChild("ping", Namespace.PING); + this.sendIqPacket(iq, null); + } + this.lastPingSent = SystemClock.elapsedRealtime(); + } + + public void setOnStatusChangedListener(final Consumer listener) { + this.statusListener = listener; + } + + private void forceCloseSocket() { + FileBackend.close(this.socket); + FileBackend.close(this.tagReader); + } + + public void interrupt() { + if (this.mThread != null) { + this.mThread.interrupt(); + } + } + + public void disconnect(final boolean force) { + interrupt(); + Log.d(Config.LOGTAG, account.address + ": disconnecting force=" + force); + if (force) { + forceCloseSocket(); + } else { + final TagWriter currentTagWriter = this.tagWriter; + if (currentTagWriter.isActive()) { + currentTagWriter.finish(); + final Socket currentSocket = this.socket; + final CountDownLatch streamCountDownLatch = this.mStreamCountDownLatch; + try { + currentTagWriter.await(1, TimeUnit.SECONDS); + Log.d(Config.LOGTAG, account.address + ": closing stream"); + currentTagWriter.writeTag(Tag.end("stream:stream")); + if (streamCountDownLatch != null) { + if (streamCountDownLatch.await(1, TimeUnit.SECONDS)) { + Log.d(Config.LOGTAG, account.address + ": remote ended stream"); + } else { + Log.d( + Config.LOGTAG, + account.address + + ": remote has not closed socket. force closing"); + } + } + } catch (InterruptedException e) { + Log.d( + Config.LOGTAG, + account.address + ": interrupted while gracefully closing stream"); + } catch (final IOException e) { + Log.d( + Config.LOGTAG, + account.address + + ": io exception during disconnect (" + + e.getMessage() + + ")"); + } finally { + FileBackend.close(currentSocket); + } + } else { + forceCloseSocket(); + } + } + } + + private void resetStreamId() { + this.streamId = null; + } + + private List> findDiscoItemsByFeature(final String feature) { + synchronized (this.disco) { + final List> items = new ArrayList<>(); + for (final Entry cursor : this.disco.entrySet()) { + if (cursor.getValue().getFeatures().contains(feature)) { + items.add(cursor); + } + } + return items; + } + } + + public Jid findDiscoItemByFeature(final String feature) { + final List> items = findDiscoItemsByFeature(feature); + if (items.size() >= 1) { + return items.get(0).getKey(); + } + return null; + } + + public boolean r() { + if (getFeatures().sm()) { + this.tagWriter.writeStanzaAsync(new RequestPacket()); + return true; + } else { + return false; + } + } + + public List getMucServersWithholdAccount() { + final List servers = getMucServers(); + servers.remove(account.address.getDomain().toEscapedString()); + return servers; + } + + public List getMucServers() { + List servers = new ArrayList<>(); + synchronized (this.disco) { + for (final Entry cursor : disco.entrySet()) { + final ServiceDiscoveryResult value = cursor.getValue(); + if (value.getFeatures().contains("http://jabber.org/protocol/muc") + && value.hasIdentity("conference", "text") + && !value.getFeatures().contains("jabber:iq:gateway") + && !value.hasIdentity("conference", "irc")) { + servers.add(cursor.getKey().toString()); + } + } + } + return servers; + } + + public String getMucServer() { + List servers = getMucServers(); + return servers.size() > 0 ? servers.get(0) : null; + } + + public int getTimeToNextAttempt() { + final int additionalTime = + recentErrorConnectionState == ConnectionState.POLICY_VIOLATION ? 3 : 0; + final int interval = Math.min((int) (25 * Math.pow(1.3, (additionalTime + attempt))), 300); + final int secondsSinceLast = + (int) ((SystemClock.elapsedRealtime() - this.lastConnect) / 1000); + return interval - secondsSinceLast; + } + + public int getAttempt() { + return this.attempt; + } + + public Features getFeatures() { + return this.features; + } + + public long getLastSessionEstablished() { + final long diff = SystemClock.elapsedRealtime() - this.lastSessionStarted; + return System.currentTimeMillis() - diff; + } + + public long getLastConnect() { + return this.lastConnect; + } + + public long getLastPingSent() { + return this.lastPingSent; + } + + public long getLastDiscoStarted() { + return this.lastDiscoStarted; + } + + public long getLastPacketReceived() { + return this.lastPacketReceived; + } + + public void sendActive() { + this.sendPacket(new ActivePacket()); + } + + public void sendInactive() { + this.sendPacket(new InactivePacket()); + } + + public void resetAttemptCount(boolean resetConnectTime) { + this.attempt = 0; + if (resetConnectTime) { + this.lastConnect = 0; + } + } + + public boolean fromServer(final AbstractStanza stanza) { + final Jid from = stanza.getFrom(); + return from == null + || from.equals(connectionAddress.getDomain()) + || from.equals(connectionAddress.asBareJid()) + || from.equals(connectionAddress); + } + + public boolean toServer(final AbstractStanza stanza) { + final Jid to = stanza.getTo(); + return to == null + || to.equals(connectionAddress.getDomain()) + || to.equals(connectionAddress.asBareJid()) + || to.equals(connectionAddress); + } + + public boolean fromAccount(final AbstractStanza stanza) { + final Jid from = stanza.getFrom(); + return from != null && from.asBareJid().equals(connectionAddress.asBareJid()); + } + + private static class MyKeyManager implements X509KeyManager { + + private final Context context; + private final Credential credential; + + private MyKeyManager(Context context, Credential credential) { + this.context = context; + this.credential = credential; + } + + @Override + public String chooseClientAlias(String[] strings, Principal[] principals, Socket socket) { + return credential.privateKeyAlias; + } + + @Override + public String chooseServerAlias(String s, Principal[] principals, Socket socket) { + return null; + } + + @Override + public X509Certificate[] getCertificateChain(String alias) { + Log.d(Config.LOGTAG, "getting certificate chain"); + try { + return KeyChain.getCertificateChain(context, alias); + } catch (final Exception e) { + Log.d(Config.LOGTAG, "could not get certificate chain", e); + return new X509Certificate[0]; + } + } + + @Override + public String[] getClientAliases(String s, Principal[] principals) { + final String alias = credential.privateKeyAlias; + return alias != null ? new String[] {alias} : new String[0]; + } + + @Override + public String[] getServerAliases(String s, Principal[] principals) { + return new String[0]; + } + + @Override + public PrivateKey getPrivateKey(String alias) { + try { + return KeyChain.getPrivateKey(context, alias); + } catch (Exception e) { + return null; + } + } + } + + private static class StateChangingError extends Error { + private final ConnectionState state; + + public StateChangingError(ConnectionState state) { + this.state = state; + } + } + + private static class StateChangingException extends IOException { + private final ConnectionState state; + + public StateChangingException(ConnectionState state) { + this.state = state; + } + } + + public class Features { + XmppConnection connection; + private boolean carbonsEnabled = false; + private boolean encryptionEnabled = false; + private boolean blockListRequested = false; + + public Features(final XmppConnection connection) { + this.connection = connection; + } + + private boolean hasDiscoFeature(final Jid server, final String feature) { + synchronized (XmppConnection.this.disco) { + final ServiceDiscoveryResult sdr = connection.disco.get(server); + return sdr != null && sdr.getFeatures().contains(feature); + } + } + + public boolean carbons() { + return hasDiscoFeature(account.address.getDomain(), Namespace.CARBONS); + } + + public boolean commands() { + return hasDiscoFeature(account.address.getDomain(), Namespace.COMMANDS); + } + + public boolean easyOnboardingInvites() { + synchronized (commands) { + return commands.containsKey(Namespace.EASY_ONBOARDING_INVITE); + } + } + + public boolean bookmarksConversion() { + return hasDiscoFeature(account.address, Namespace.BOOKMARKS_CONVERSION) + && pepPublishOptions(); + } + + public boolean avatarConversion() { + return hasDiscoFeature(account.address, Namespace.AVATAR_CONVERSION) + && pepPublishOptions(); + } + + public boolean blocking() { + return hasDiscoFeature(account.address.getDomain(), Namespace.BLOCKING); + } + + public boolean spamReporting() { + return hasDiscoFeature(account.address.getDomain(), "urn:xmpp:reporting:reason:spam:0"); + } + + public boolean flexibleOfflineMessageRetrieval() { + return hasDiscoFeature( + account.address.getDomain(), Namespace.FLEXIBLE_OFFLINE_MESSAGE_RETRIEVAL); + } + + public boolean register() { + return hasDiscoFeature(account.address.getDomain(), Namespace.REGISTER); + } + + public boolean invite() { + return connection.streamFeatures != null + && connection.streamFeatures.hasChild("register", Namespace.INVITE); + } + + public boolean sm() { + return streamId != null + || (connection.streamFeatures != null + && connection.streamFeatures.hasChild( + "sm", Namespace.STREAM_MANAGEMENT)); + } + + public boolean csi() { + return connection.streamFeatures != null + && connection.streamFeatures.hasChild("csi", Namespace.CSI); + } + + public boolean pep() { + synchronized (XmppConnection.this.disco) { + ServiceDiscoveryResult info = disco.get(account.address); + return info != null && info.hasIdentity("pubsub", "pep"); + } + } + + public boolean pepPersistent() { + synchronized (XmppConnection.this.disco) { + ServiceDiscoveryResult info = disco.get(account.address); + return info != null + && info.getFeatures() + .contains("http://jabber.org/protocol/pubsub#persistent-items"); + } + } + + public boolean pepPublishOptions() { + return hasDiscoFeature(account.address, Namespace.PUBSUB_PUBLISH_OPTIONS); + } + + public boolean pepOmemoWhitelisted() { + return hasDiscoFeature(account.address, AxolotlService.PEP_OMEMO_WHITELISTED); + } + + public boolean mam() { + return MessageArchiveService.Version.has(getAccountFeatures()); + } + + public List getAccountFeatures() { + ServiceDiscoveryResult result = connection.disco.get(account.address); + return result == null ? Collections.emptyList() : result.getFeatures(); + } + + public boolean push() { + return hasDiscoFeature(account.address, Namespace.PUSH) + || hasDiscoFeature(account.address.getDomain(), Namespace.PUSH); + } + + public boolean rosterVersioning() { + return connection.streamFeatures != null && connection.streamFeatures.hasChild("ver"); + } + + public void setBlockListRequested(boolean value) { + this.blockListRequested = value; + } + + public boolean httpUpload(long filesize) { + if (Config.DISABLE_HTTP_UPLOAD) { + return false; + } else { + for (String namespace : + new String[] {Namespace.HTTP_UPLOAD, Namespace.HTTP_UPLOAD_LEGACY}) { + List> items = + findDiscoItemsByFeature(namespace); + if (items.size() > 0) { + try { + long maxsize = + Long.parseLong( + items.get(0) + .getValue() + .getExtendedDiscoInformation( + namespace, "max-file-size")); + if (filesize <= maxsize) { + return true; + } else { + Log.d( + Config.LOGTAG, + account.address + + ": http upload is not available for files with" + + " size " + + filesize + + " (max is " + + maxsize + + ")"); + return false; + } + } catch (Exception e) { + return true; + } + } + } + return false; + } + } + + public boolean useLegacyHttpUpload() { + return findDiscoItemByFeature(Namespace.HTTP_UPLOAD) == null + && findDiscoItemByFeature(Namespace.HTTP_UPLOAD_LEGACY) != null; + } + + public long getMaxHttpUploadSize() { + for (String namespace : + new String[] {Namespace.HTTP_UPLOAD, Namespace.HTTP_UPLOAD_LEGACY}) { + List> items = findDiscoItemsByFeature(namespace); + if (items.size() > 0) { + try { + return Long.parseLong( + items.get(0) + .getValue() + .getExtendedDiscoInformation(namespace, "max-file-size")); + } catch (Exception e) { + // ignored + } + } + } + return -1; + } + + public boolean stanzaIds() { + return hasDiscoFeature(account.address, Namespace.STANZA_IDS); + } + + public boolean bookmarks2() { + return Config + .USE_BOOKMARKS2 /* || hasDiscoFeature(account.address, Namespace.BOOKMARKS2_COMPAT)*/; + } + + public boolean externalServiceDiscovery() { + return hasDiscoFeature( + account.address.getDomain(), Namespace.EXTERNAL_SERVICE_DISCOVERY); + } + } +} diff --git a/src/main/java/im/conversations/android/xmpp/processor/BaseProcessor.java b/src/main/java/im/conversations/android/xmpp/processor/BaseProcessor.java new file mode 100644 index 000000000..7d3b2b3ce --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/processor/BaseProcessor.java @@ -0,0 +1,25 @@ +package im.conversations.android.xmpp.processor; + +import android.content.Context; +import im.conversations.android.database.ConversationsDatabase; +import im.conversations.android.database.model.Account; +import im.conversations.android.xmpp.XmppConnection; + +abstract class BaseProcessor { + + protected final Context context; + protected final XmppConnection connection; + + BaseProcessor(final Context context, final XmppConnection connection) { + this.context = context; + this.connection = connection; + } + + protected Account getAccount() { + return connection.getAccount(); + } + + protected ConversationsDatabase getDatabase() { + return ConversationsDatabase.getInstance(context); + } +} diff --git a/src/main/java/im/conversations/android/xmpp/processor/BindProcessor.java b/src/main/java/im/conversations/android/xmpp/processor/BindProcessor.java new file mode 100644 index 000000000..0f01f3c40 --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/processor/BindProcessor.java @@ -0,0 +1,54 @@ +package im.conversations.android.xmpp.processor; + +import android.content.Context; +import android.util.Log; +import com.google.common.base.Strings; +import eu.siacs.conversations.Config; +import eu.siacs.conversations.xml.Namespace; +import eu.siacs.conversations.xmpp.Jid; +import eu.siacs.conversations.xmpp.stanzas.IqPacket; +import im.conversations.android.xmpp.XmppConnection; +import java.util.function.Consumer; + +public class BindProcessor extends BaseProcessor implements Consumer { + + public BindProcessor(final Context context, final XmppConnection connection) { + super(context, connection); + } + + @Override + public void accept(final Jid jid) { + final var account = getAccount(); + final var database = getDatabase(); + + final boolean firstLogin = + database.accountDao().setLoggedInSuccessfully(account.id, true) > 0; + + if (firstLogin) { + // TODO publish display name if this is the first attempt + // iirc this is used when the display name is set from a certificate or something + } + + database.presenceDao().deletePresences(account.id); + + fetchRoster(); + + // TODO fetch bookmarks + + // TODO send initial presence + } + + private void fetchRoster() { + final var account = getAccount(); + final var database = getDatabase(); + final String rosterVersion = database.accountDao().getRosterVersion(account.id); + final IqPacket iqPacket = new IqPacket(IqPacket.TYPE.GET); + if (Strings.isNullOrEmpty(rosterVersion)) { + Log.d(Config.LOGTAG, account.address + ": fetching roster"); + } else { + Log.d(Config.LOGTAG, account.address + ": fetching roster version " + rosterVersion); + } + iqPacket.query(Namespace.ROSTER).setAttribute("ver", rosterVersion); + connection.sendIqPacket(iqPacket, result -> {}); + } +} diff --git a/src/main/java/im/conversations/android/xmpp/processor/IqProcessor.java b/src/main/java/im/conversations/android/xmpp/processor/IqProcessor.java new file mode 100644 index 000000000..1a1002062 --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/processor/IqProcessor.java @@ -0,0 +1,14 @@ +package im.conversations.android.xmpp.processor; + +import android.content.Context; +import eu.siacs.conversations.xmpp.stanzas.IqPacket; +import im.conversations.android.xmpp.XmppConnection; +import java.util.function.Consumer; + +public class IqProcessor implements Consumer { + + public IqProcessor(final Context context, final XmppConnection connection) {} + + @Override + public void accept(final IqPacket packet) {} +} diff --git a/src/main/java/im/conversations/android/xmpp/processor/JingleProcessor.java b/src/main/java/im/conversations/android/xmpp/processor/JingleProcessor.java new file mode 100644 index 000000000..493b398bb --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/processor/JingleProcessor.java @@ -0,0 +1,14 @@ +package im.conversations.android.xmpp.processor; + +import android.content.Context; +import eu.siacs.conversations.xmpp.jingle.stanzas.JinglePacket; +import im.conversations.android.xmpp.XmppConnection; +import java.util.function.Consumer; + +public class JingleProcessor implements Consumer { + + public JingleProcessor(final Context context, final XmppConnection connection) {} + + @Override + public void accept(JinglePacket jinglePacket) {} +} diff --git a/src/main/java/im/conversations/android/xmpp/processor/MessageAcknowledgeProcessor.java b/src/main/java/im/conversations/android/xmpp/processor/MessageAcknowledgeProcessor.java new file mode 100644 index 000000000..e35986fde --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/processor/MessageAcknowledgeProcessor.java @@ -0,0 +1,16 @@ +package im.conversations.android.xmpp.processor; + +import android.content.Context; +import eu.siacs.conversations.xmpp.Jid; +import im.conversations.android.xmpp.XmppConnection; +import java.util.function.BiFunction; + +public class MessageAcknowledgeProcessor implements BiFunction { + + public MessageAcknowledgeProcessor(final Context context, final XmppConnection connection) {} + + @Override + public Boolean apply(final Jid to, final String id) { + return null; + } +} diff --git a/src/main/java/im/conversations/android/xmpp/processor/MessageProcessor.java b/src/main/java/im/conversations/android/xmpp/processor/MessageProcessor.java new file mode 100644 index 000000000..594b7d7ca --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/processor/MessageProcessor.java @@ -0,0 +1,14 @@ +package im.conversations.android.xmpp.processor; + +import android.content.Context; +import eu.siacs.conversations.xmpp.stanzas.MessagePacket; +import im.conversations.android.xmpp.XmppConnection; +import java.util.function.Consumer; + +public class MessageProcessor implements Consumer { + + public MessageProcessor(final Context context, final XmppConnection connection) {} + + @Override + public void accept(final MessagePacket messagePacket) {} +} diff --git a/src/main/java/im/conversations/android/xmpp/processor/PresenceProcessor.java b/src/main/java/im/conversations/android/xmpp/processor/PresenceProcessor.java new file mode 100644 index 000000000..bbe4f95c2 --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/processor/PresenceProcessor.java @@ -0,0 +1,14 @@ +package im.conversations.android.xmpp.processor; + +import android.content.Context; +import eu.siacs.conversations.xmpp.stanzas.PresencePacket; +import im.conversations.android.xmpp.XmppConnection; +import java.util.function.Consumer; + +public class PresenceProcessor implements Consumer { + + public PresenceProcessor(final Context context, final XmppConnection connection) {} + + @Override + public void accept(PresencePacket presencePacket) {} +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/Anonymous.java b/src/main/java/im/conversations/android/xmpp/sasl/Anonymous.java new file mode 100644 index 000000000..e440e3856 --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/Anonymous.java @@ -0,0 +1,29 @@ +package im.conversations.android.xmpp.sasl; + +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; +import javax.net.ssl.SSLSocket; + +public class Anonymous extends SaslMechanism { + + public static final String MECHANISM = "ANONYMOUS"; + + public Anonymous(final Account account) { + super(account, Credential.empty()); + } + + @Override + public int getPriority() { + return 0; + } + + @Override + public String getMechanism() { + return MECHANISM; + } + + @Override + public String getClientFirstMessage(final SSLSocket sslSocket) { + return ""; + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/ChannelBinding.java b/src/main/java/im/conversations/android/xmpp/sasl/ChannelBinding.java new file mode 100644 index 000000000..3bc38d0b9 --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/ChannelBinding.java @@ -0,0 +1,117 @@ +package im.conversations.android.xmpp.sasl; + +import android.util.Log; +import com.google.common.base.CaseFormat; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicates; +import com.google.common.base.Strings; +import com.google.common.collect.BiMap; +import com.google.common.collect.Collections2; +import com.google.common.collect.ImmutableBiMap; +import eu.siacs.conversations.Config; +import eu.siacs.conversations.utils.SSLSockets; +import eu.siacs.conversations.xml.Element; +import eu.siacs.conversations.xml.Namespace; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; + +public enum ChannelBinding { + NONE, + TLS_EXPORTER, + TLS_SERVER_END_POINT, + TLS_UNIQUE; + + public static final BiMap SHORT_NAMES; + + static { + final ImmutableBiMap.Builder builder = ImmutableBiMap.builder(); + for (final ChannelBinding cb : values()) { + builder.put(cb, shortName(cb)); + } + SHORT_NAMES = builder.build(); + } + + public static Collection of(final Element channelBinding) { + Preconditions.checkArgument( + channelBinding == null + || ("sasl-channel-binding".equals(channelBinding.getName()) + && Namespace.CHANNEL_BINDING.equals(channelBinding.getNamespace())), + "pass null or a valid channel binding stream feature"); + return Collections2.filter( + Collections2.transform( + Collections2.filter( + channelBinding == null + ? Collections.emptyList() + : channelBinding.getChildren(), + c -> c != null && "channel-binding".equals(c.getName())), + c -> c == null ? null : ChannelBinding.of(c.getAttribute("type"))), + Predicates.notNull()); + } + + private static ChannelBinding of(final String type) { + if (type == null) { + return null; + } + try { + return valueOf( + CaseFormat.LOWER_HYPHEN.converterTo(CaseFormat.UPPER_UNDERSCORE).convert(type)); + } catch (final IllegalArgumentException e) { + Log.d(Config.LOGTAG, type + " is not a known channel binding"); + return null; + } + } + + public static ChannelBinding get(final String name) { + if (Strings.isNullOrEmpty(name)) { + return NONE; + } + try { + return valueOf(name); + } catch (final IllegalArgumentException e) { + return NONE; + } + } + + public static ChannelBinding best( + final Collection bindings, final SSLSockets.Version sslVersion) { + if (sslVersion == SSLSockets.Version.NONE) { + return NONE; + } + if (bindings.contains(TLS_EXPORTER) && sslVersion == SSLSockets.Version.TLS_1_3) { + return TLS_EXPORTER; + } else if (bindings.contains(TLS_UNIQUE) + && Arrays.asList( + SSLSockets.Version.TLS_1_0, + SSLSockets.Version.TLS_1_1, + SSLSockets.Version.TLS_1_2) + .contains(sslVersion)) { + return TLS_UNIQUE; + } else if (bindings.contains(TLS_SERVER_END_POINT)) { + return TLS_SERVER_END_POINT; + } else { + return NONE; + } + } + + public static boolean isAvailable( + final ChannelBinding channelBinding, final SSLSockets.Version sslVersion) { + return ChannelBinding.best(Collections.singleton(channelBinding), sslVersion) + == channelBinding; + } + + private static String shortName(final ChannelBinding channelBinding) { + switch (channelBinding) { + case TLS_UNIQUE: + return "UNIQ"; + case TLS_EXPORTER: + return "EXPR"; + case TLS_SERVER_END_POINT: + return "ENDP"; + case NONE: + return "NONE"; + default: + throw new AssertionError("Missing short name for " + channelBinding); + } + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/ChannelBindingMechanism.java b/src/main/java/im/conversations/android/xmpp/sasl/ChannelBindingMechanism.java new file mode 100644 index 000000000..33b5839a5 --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/ChannelBindingMechanism.java @@ -0,0 +1,100 @@ +package im.conversations.android.xmpp.sasl; + +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.cert.Certificate; +import java.security.cert.CertificateEncodingException; +import java.security.cert.X509Certificate; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; +import org.bouncycastle.jcajce.provider.digest.SHA256; +import org.conscrypt.Conscrypt; + +public interface ChannelBindingMechanism { + + String EXPORTER_LABEL = "EXPORTER-Channel-Binding"; + + ChannelBinding getChannelBinding(); + + static byte[] getChannelBindingData( + final SSLSocket sslSocket, final ChannelBinding channelBinding) + throws SaslMechanism.AuthenticationException { + if (sslSocket == null) { + throw new SaslMechanism.AuthenticationException( + "Channel binding attempt on non secure socket"); + } + if (channelBinding == ChannelBinding.TLS_EXPORTER) { + final byte[] keyingMaterial; + try { + keyingMaterial = + Conscrypt.exportKeyingMaterial(sslSocket, EXPORTER_LABEL, new byte[0], 32); + } catch (final SSLException e) { + throw new SaslMechanism.AuthenticationException("Could not export keying material"); + } + if (keyingMaterial == null) { + throw new SaslMechanism.AuthenticationException( + "Could not export keying material. Socket not ready"); + } + return keyingMaterial; + } else if (channelBinding == ChannelBinding.TLS_UNIQUE) { + final byte[] unique = Conscrypt.getTlsUnique(sslSocket); + if (unique == null) { + throw new SaslMechanism.AuthenticationException( + "Could not retrieve tls unique. Socket not ready"); + } + return unique; + } else if (channelBinding == ChannelBinding.TLS_SERVER_END_POINT) { + return getServerEndPointChannelBinding(sslSocket.getSession()); + } else { + throw new SaslMechanism.AuthenticationException( + String.format("%s is not a valid channel binding", channelBinding)); + } + } + + static byte[] getServerEndPointChannelBinding(final SSLSession session) + throws SaslMechanism.AuthenticationException { + final Certificate[] certificates; + try { + certificates = session.getPeerCertificates(); + } catch (final SSLPeerUnverifiedException e) { + throw new SaslMechanism.AuthenticationException("Could not verify peer certificates"); + } + if (certificates == null || certificates.length == 0) { + throw new SaslMechanism.AuthenticationException("Could not retrieve peer certificate"); + } + final X509Certificate certificate; + if (certificates[0] instanceof X509Certificate) { + certificate = (X509Certificate) certificates[0]; + } else { + throw new SaslMechanism.AuthenticationException("Certificate was not X509"); + } + final String algorithm = certificate.getSigAlgName(); + final int withIndex = algorithm.indexOf("with"); + if (withIndex <= 0) { + throw new SaslMechanism.AuthenticationException("Unable to parse SigAlgName"); + } + final String hashAlgorithm = algorithm.substring(0, withIndex); + final MessageDigest messageDigest; + // https://www.rfc-editor.org/rfc/rfc5929#section-4.1 + if ("MD5".equalsIgnoreCase(hashAlgorithm) || "SHA1".equalsIgnoreCase(hashAlgorithm)) { + messageDigest = new SHA256.Digest(); + } else { + try { + messageDigest = MessageDigest.getInstance(hashAlgorithm); + } catch (final NoSuchAlgorithmException e) { + throw new SaslMechanism.AuthenticationException( + "Could not instantiate message digest for " + hashAlgorithm); + } + } + final byte[] encodedCertificate; + try { + encodedCertificate = certificate.getEncoded(); + } catch (final CertificateEncodingException e) { + throw new SaslMechanism.AuthenticationException("Could not encode certificate"); + } + messageDigest.update(encodedCertificate); + return messageDigest.digest(); + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/DigestMd5.java b/src/main/java/im/conversations/android/xmpp/sasl/DigestMd5.java new file mode 100644 index 000000000..d1e2b914b --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/DigestMd5.java @@ -0,0 +1,112 @@ +package im.conversations.android.xmpp.sasl; + +import android.util.Base64; +import eu.siacs.conversations.utils.CryptoHelper; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; +import java.nio.charset.Charset; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import javax.net.ssl.SSLSocket; + +public class DigestMd5 extends SaslMechanism { + + public static final String MECHANISM = "DIGEST-MD5"; + private State state = State.INITIAL; + + public DigestMd5(final Account account, final Credential credential) { + super(account, credential); + } + + @Override + public int getPriority() { + return 10; + } + + @Override + public String getMechanism() { + return MECHANISM; + } + + @Override + public String getResponse(final String challenge, final SSLSocket sslSocket) + throws AuthenticationException { + switch (state) { + case INITIAL: + state = State.RESPONSE_SENT; + final String encodedResponse; + try { + final Tokenizer tokenizer = + new Tokenizer(Base64.decode(challenge, Base64.DEFAULT)); + String nonce = ""; + for (final String token : tokenizer) { + final String[] parts = token.split("=", 2); + if (parts[0].equals("nonce")) { + nonce = parts[1].replace("\"", ""); + } else if (parts[0].equals("rspauth")) { + return ""; + } + } + final String digestUri = "xmpp/" + account.address.getDomain(); + final String nonceCount = "00000001"; + final String x = + account.address.getEscapedLocal() + + ":" + + account.address.getDomain() + + ":" + + credential.password; + final MessageDigest md = MessageDigest.getInstance("MD5"); + final byte[] y = md.digest(x.getBytes(Charset.defaultCharset())); + final String cNonce = CryptoHelper.random(100); + final byte[] a1 = + CryptoHelper.concatenateByteArrays( + y, + (":" + nonce + ":" + cNonce) + .getBytes(Charset.defaultCharset())); + final String a2 = "AUTHENTICATE:" + digestUri; + final String ha1 = CryptoHelper.bytesToHex(md.digest(a1)); + final String ha2 = + CryptoHelper.bytesToHex( + md.digest(a2.getBytes(Charset.defaultCharset()))); + final String kd = + ha1 + ":" + nonce + ":" + nonceCount + ":" + cNonce + ":auth:" + ha2; + final String response = + CryptoHelper.bytesToHex( + md.digest(kd.getBytes(Charset.defaultCharset()))); + final String saslString = + "username=\"" + + account.address.getEscapedLocal() + + "\",realm=\"" + + account.address.getDomain() + + "\",nonce=\"" + + nonce + + "\",cnonce=\"" + + cNonce + + "\",nc=" + + nonceCount + + ",qop=auth,digest-uri=\"" + + digestUri + + "\",response=" + + response + + ",charset=utf-8"; + encodedResponse = + Base64.encodeToString( + saslString.getBytes(Charset.defaultCharset()), Base64.NO_WRAP); + } catch (final NoSuchAlgorithmException e) { + throw new AuthenticationException(e); + } + + return encodedResponse; + case RESPONSE_SENT: + state = State.VALID_SERVER_RESPONSE; + break; + case VALID_SERVER_RESPONSE: + if (challenge == null) { + return null; // everything is fine + } + default: + throw new InvalidStateException(state); + } + return null; + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/External.java b/src/main/java/im/conversations/android/xmpp/sasl/External.java new file mode 100644 index 000000000..8ef0f060c --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/External.java @@ -0,0 +1,30 @@ +package im.conversations.android.xmpp.sasl; + +import android.util.Base64; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; +import javax.net.ssl.SSLSocket; + +public class External extends SaslMechanism { + + public static final String MECHANISM = "EXTERNAL"; + + public External(final Account account) { + super(account, Credential.empty()); + } + + @Override + public int getPriority() { + return 25; + } + + @Override + public String getMechanism() { + return MECHANISM; + } + + @Override + public String getClientFirstMessage(final SSLSocket sslSocket) { + return Base64.encodeToString(account.address.toEscapedString().getBytes(), Base64.NO_WRAP); + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/HashedToken.java b/src/main/java/im/conversations/android/xmpp/sasl/HashedToken.java new file mode 100644 index 000000000..46046683d --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/HashedToken.java @@ -0,0 +1,189 @@ +package im.conversations.android.xmpp.sasl; + +import android.util.Base64; +import android.util.Log; +import com.google.common.base.MoreObjects; +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.Multimap; +import com.google.common.hash.HashFunction; +import com.google.common.primitives.Bytes; +import eu.siacs.conversations.Config; +import eu.siacs.conversations.utils.SSLSockets; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import javax.net.ssl.SSLSocket; +import org.jetbrains.annotations.NotNull; + +public abstract class HashedToken extends SaslMechanism implements ChannelBindingMechanism { + + private static final String PREFIX = "HT"; + + private static final List HASH_FUNCTIONS = Arrays.asList("SHA-512", "SHA-256"); + private static final byte[] INITIATOR = "Initiator".getBytes(StandardCharsets.UTF_8); + private static final byte[] RESPONDER = "Responder".getBytes(StandardCharsets.UTF_8); + + protected final ChannelBinding channelBinding; + + protected HashedToken( + final Account account, + final Credential credential, + final ChannelBinding channelBinding) { + super(account, credential); + this.channelBinding = channelBinding; + } + + @Override + public int getPriority() { + throw new UnsupportedOperationException(); + } + + @Override + public String getClientFirstMessage(final SSLSocket sslSocket) { + final String token = Strings.nullToEmpty(this.credential.fastToken); + final HashFunction hashing = getHashFunction(token.getBytes(StandardCharsets.UTF_8)); + final byte[] cbData = getChannelBindingData(sslSocket); + final byte[] initiatorHashedToken = + hashing.hashBytes(Bytes.concat(INITIATOR, cbData)).asBytes(); + final byte[] firstMessage = + Bytes.concat( + account.address.getEscapedLocal().getBytes(StandardCharsets.UTF_8), + new byte[] {0x00}, + initiatorHashedToken); + return Base64.encodeToString(firstMessage, Base64.NO_WRAP); + } + + private byte[] getChannelBindingData(final SSLSocket sslSocket) { + if (this.channelBinding == ChannelBinding.NONE) { + return new byte[0]; + } + try { + return ChannelBindingMechanism.getChannelBindingData(sslSocket, this.channelBinding); + } catch (final AuthenticationException e) { + Log.e( + Config.LOGTAG, + account.address + + ": unable to retrieve channel binding data for " + + getMechanism(), + e); + return new byte[0]; + } + } + + @Override + public String getResponse(final String challenge, final SSLSocket socket) + throws AuthenticationException { + final byte[] responderMessage; + try { + responderMessage = Base64.decode(challenge, Base64.NO_WRAP); + } catch (final Exception e) { + throw new AuthenticationException("Unable to decode responder message", e); + } + final String token = Strings.nullToEmpty(this.credential.fastToken); + final HashFunction hashing = getHashFunction(token.getBytes(StandardCharsets.UTF_8)); + final byte[] cbData = getChannelBindingData(socket); + final byte[] expectedResponderMessage = + hashing.hashBytes(Bytes.concat(RESPONDER, cbData)).asBytes(); + if (Arrays.equals(responderMessage, expectedResponderMessage)) { + return null; + } + throw new AuthenticationException("Responder message did not match"); + } + + protected abstract HashFunction getHashFunction(final byte[] key); + + public abstract Mechanism getTokenMechanism(); + + @Override + public String getMechanism() { + return getTokenMechanism().name(); + } + + public static final class Mechanism { + public final String hashFunction; + public final ChannelBinding channelBinding; + + public Mechanism(String hashFunction, ChannelBinding channelBinding) { + this.hashFunction = hashFunction; + this.channelBinding = channelBinding; + } + + public static Mechanism of(final String mechanism) { + final int first = mechanism.indexOf('-'); + final int last = mechanism.lastIndexOf('-'); + if (last <= first || mechanism.length() <= last) { + throw new IllegalArgumentException("Not a valid HashedToken name"); + } + if (mechanism.substring(0, first).equals(PREFIX)) { + final String hashFunction = mechanism.substring(first + 1, last); + final String cbShortName = mechanism.substring(last + 1); + final ChannelBinding channelBinding = + ChannelBinding.SHORT_NAMES.inverse().get(cbShortName); + if (channelBinding == null) { + throw new IllegalArgumentException("Unknown channel binding " + cbShortName); + } + return new Mechanism(hashFunction, channelBinding); + } else { + throw new IllegalArgumentException("HashedToken name does not start with HT"); + } + } + + public static Mechanism ofOrNull(final String mechanism) { + try { + return mechanism == null ? null : of(mechanism); + } catch (final IllegalArgumentException e) { + return null; + } + } + + public static Multimap of(final Collection mechanisms) { + final ImmutableMultimap.Builder builder = + ImmutableMultimap.builder(); + for (final String name : mechanisms) { + try { + final Mechanism mechanism = Mechanism.of(name); + builder.put(mechanism.hashFunction, mechanism.channelBinding); + } catch (final IllegalArgumentException ignored) { + } + } + return builder.build(); + } + + public static Mechanism best( + final Collection mechanisms, final SSLSockets.Version sslVersion) { + final Multimap multimap = of(mechanisms); + for (final String hashFunction : HASH_FUNCTIONS) { + final Collection channelBindings = multimap.get(hashFunction); + if (channelBindings.isEmpty()) { + continue; + } + final ChannelBinding cb = ChannelBinding.best(channelBindings, sslVersion); + return new Mechanism(hashFunction, cb); + } + return null; + } + + @NotNull + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("hashFunction", hashFunction) + .add("channelBinding", channelBinding) + .toString(); + } + + public String name() { + return String.format( + "%s-%s-%s", + PREFIX, hashFunction, ChannelBinding.SHORT_NAMES.get(channelBinding)); + } + } + + public ChannelBinding getChannelBinding() { + return this.channelBinding; + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/HashedTokenSha256.java b/src/main/java/im/conversations/android/xmpp/sasl/HashedTokenSha256.java new file mode 100644 index 000000000..8ee0b5083 --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/HashedTokenSha256.java @@ -0,0 +1,26 @@ +package im.conversations.android.xmpp.sasl; + +import com.google.common.hash.HashFunction; +import com.google.common.hash.Hashing; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; + +public class HashedTokenSha256 extends HashedToken { + + public HashedTokenSha256( + final Account account, + final Credential credential, + final ChannelBinding channelBinding) { + super(account, credential, channelBinding); + } + + @Override + protected HashFunction getHashFunction(final byte[] key) { + return Hashing.hmacSha256(key); + } + + @Override + public Mechanism getTokenMechanism() { + return new Mechanism("SHA-256", channelBinding); + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/HashedTokenSha512.java b/src/main/java/im/conversations/android/xmpp/sasl/HashedTokenSha512.java new file mode 100644 index 000000000..eee51aa2d --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/HashedTokenSha512.java @@ -0,0 +1,26 @@ +package im.conversations.android.xmpp.sasl; + +import com.google.common.hash.HashFunction; +import com.google.common.hash.Hashing; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; + +public class HashedTokenSha512 extends HashedToken { + + public HashedTokenSha512( + final Account account, + final Credential credential, + final ChannelBinding channelBinding) { + super(account, credential, channelBinding); + } + + @Override + protected HashFunction getHashFunction(final byte[] key) { + return Hashing.hmacSha512(key); + } + + @Override + public Mechanism getTokenMechanism() { + return new Mechanism("SHA-512", this.channelBinding); + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/Plain.java b/src/main/java/im/conversations/android/xmpp/sasl/Plain.java new file mode 100644 index 000000000..d41c07629 --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/Plain.java @@ -0,0 +1,36 @@ +package im.conversations.android.xmpp.sasl; + +import android.util.Base64; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; +import java.nio.charset.Charset; +import javax.net.ssl.SSLSocket; + +public class Plain extends SaslMechanism { + + public static final String MECHANISM = "PLAIN"; + + public Plain(final Account account, final Credential credential) { + super(account, credential); + } + + public static String getMessage(String username, String password) { + final String message = '\u0000' + username + '\u0000' + password; + return Base64.encodeToString(message.getBytes(Charset.defaultCharset()), Base64.NO_WRAP); + } + + @Override + public int getPriority() { + return 10; + } + + @Override + public String getMechanism() { + return MECHANISM; + } + + @Override + public String getClientFirstMessage(final SSLSocket sslSocket) { + return getMessage(account.address.getEscapedLocal(), credential.password); + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/SaslMechanism.java b/src/main/java/im/conversations/android/xmpp/sasl/SaslMechanism.java new file mode 100644 index 000000000..be030fcb3 --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/SaslMechanism.java @@ -0,0 +1,236 @@ +package im.conversations.android.xmpp.sasl; + +import android.util.Log; +import com.google.common.base.Preconditions; +import com.google.common.base.Strings; +import com.google.common.collect.Collections2; +import eu.siacs.conversations.Config; +import eu.siacs.conversations.utils.SSLSockets; +import eu.siacs.conversations.xml.Element; +import eu.siacs.conversations.xml.Namespace; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; +import java.util.Collection; +import java.util.Collections; +import javax.net.ssl.SSLSocket; + +public abstract class SaslMechanism { + + protected final Account account; + protected final Credential credential; + + protected SaslMechanism(final Account account, final Credential credential) { + this.account = account; + this.credential = credential; + } + + public static String namespace(final Version version) { + if (version == Version.SASL) { + return Namespace.SASL; + } else { + return Namespace.SASL_2; + } + } + + /** + * The priority is used to pin the authentication mechanism. If authentication fails, it MAY be + * retried with another mechanism of the same priority, but MUST NOT be tried with a mechanism + * of lower priority (to prevent downgrade attacks). + * + * @return An arbitrary int representing the priority + */ + public abstract int getPriority(); + + public abstract String getMechanism(); + + public String getClientFirstMessage(final SSLSocket sslSocket) { + return ""; + } + + public String getResponse(final String challenge, final SSLSocket sslSocket) + throws AuthenticationException { + return ""; + } + + public static Collection mechanisms(final Element authElement) { + if (authElement == null) { + return Collections.emptyList(); + } + return Collections2.transform( + Collections2.filter( + authElement.getChildren(), + c -> c != null && "mechanism".equals(c.getName())), + c -> c == null ? null : c.getContent()); + } + + protected enum State { + INITIAL, + AUTH_TEXT_SENT, + RESPONSE_SENT, + VALID_SERVER_RESPONSE, + } + + public enum Version { + SASL, + SASL_2; + + public static Version of(final Element element) { + switch (Strings.nullToEmpty(element.getNamespace())) { + case Namespace.SASL: + return SASL; + case Namespace.SASL_2: + return SASL_2; + default: + throw new IllegalArgumentException("Unrecognized SASL namespace"); + } + } + } + + public static class AuthenticationException extends Exception { + public AuthenticationException(final String message) { + super(message); + } + + public AuthenticationException(final Exception inner) { + super(inner); + } + + public AuthenticationException(final String message, final Exception exception) { + super(message, exception); + } + } + + public static class InvalidStateException extends AuthenticationException { + public InvalidStateException(final String message) { + super(message); + } + + public InvalidStateException(final State state) { + this("Invalid state: " + state.toString()); + } + } + + public static final class Factory { + + private final Account account; + private final Credential credential; + + public Factory(final Account account, final Credential credential) { + this.account = account; + this.credential = credential; + } + + private SaslMechanism of( + final Collection mechanisms, final ChannelBinding channelBinding) { + Preconditions.checkNotNull(channelBinding, "Use ChannelBinding.NONE instead of null"); + if (mechanisms.contains(External.MECHANISM) && credential.privateKeyAlias != null) { + return new External(account); + } else if (mechanisms.contains(ScramSha512Plus.MECHANISM) + && channelBinding != ChannelBinding.NONE) { + return new ScramSha512Plus(account, credential, channelBinding); + } else if (mechanisms.contains(ScramSha256Plus.MECHANISM) + && channelBinding != ChannelBinding.NONE) { + return new ScramSha256Plus(account, credential, channelBinding); + } else if (mechanisms.contains(ScramSha1Plus.MECHANISM) + && channelBinding != ChannelBinding.NONE) { + return new ScramSha1Plus(account, credential, channelBinding); + } else if (mechanisms.contains(ScramSha512.MECHANISM)) { + return new ScramSha512(account, credential); + } else if (mechanisms.contains(ScramSha256.MECHANISM)) { + return new ScramSha256(account, credential); + } else if (mechanisms.contains(ScramSha1.MECHANISM)) { + return new ScramSha1(account, credential); + } else if (mechanisms.contains(Plain.MECHANISM)) { + return new Plain(account, credential); + } else if (mechanisms.contains(DigestMd5.MECHANISM)) { + return new DigestMd5(account, credential); + } else if (mechanisms.contains(Anonymous.MECHANISM)) { + return new Anonymous(account); + } else { + return null; + } + } + + public SaslMechanism of( + final Collection mechanisms, + final Collection bindings, + final Version version, + final SSLSockets.Version sslVersion) { + final HashedToken fastMechanism = getFastMechanism(); + if (version == Version.SASL_2 && fastMechanism != null) { + return fastMechanism; + } + final ChannelBinding channelBinding = ChannelBinding.best(bindings, sslVersion); + return of(mechanisms, channelBinding); + } + + public SaslMechanism of(final String mechanism, final ChannelBinding channelBinding) { + return of(Collections.singleton(mechanism), channelBinding); + } + + public HashedToken getFastMechanism() { + final HashedToken.Mechanism fastMechanism = + HashedToken.Mechanism.ofOrNull(credential.fastMechanism); + final String token = credential.fastToken; + if (fastMechanism == null || Strings.isNullOrEmpty(token)) { + return null; + } + if (fastMechanism.hashFunction.equals("SHA-256")) { + return new HashedTokenSha256(account, credential, fastMechanism.channelBinding); + } else if (fastMechanism.hashFunction.equals("SHA-512")) { + return new HashedTokenSha512(account, credential, fastMechanism.channelBinding); + } else { + return null; + } + } + + private SaslMechanism getPinnedMechanism() { + final String mechanism = Strings.nullToEmpty(credential.pinnedMechanism); + final ChannelBinding channelBinding = + ChannelBinding.get(credential.pinnedChannelBinding); + return this.of(mechanism, channelBinding); + } + + public SaslMechanism getQuickStartMechanism() { + final HashedToken hashedTokenMechanism = getFastMechanism(); + if (hashedTokenMechanism != null) { + return hashedTokenMechanism; + } + return getPinnedMechanism(); + } + + public int getPinnedMechanismPriority() { + final SaslMechanism saslMechanism = getPinnedMechanism(); + if (saslMechanism == null) { + return Integer.MIN_VALUE; + } else { + return saslMechanism.getPriority(); + } + } + } + + public static SaslMechanism ensureAvailable( + final SaslMechanism mechanism, final SSLSockets.Version sslVersion) { + if (mechanism instanceof ChannelBindingMechanism) { + final ChannelBinding cb = ((ChannelBindingMechanism) mechanism).getChannelBinding(); + if (ChannelBinding.isAvailable(cb, sslVersion)) { + return mechanism; + } else { + Log.d( + Config.LOGTAG, + "pinned channel binding method " + cb + " no longer available"); + return null; + } + } else { + return mechanism; + } + } + + public static boolean hashedToken(final SaslMechanism saslMechanism) { + return saslMechanism instanceof HashedToken; + } + + public static boolean pin(final SaslMechanism saslMechanism) { + return !hashedToken(saslMechanism); + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/ScramMechanism.java b/src/main/java/im/conversations/android/xmpp/sasl/ScramMechanism.java new file mode 100644 index 000000000..99f500d75 --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/ScramMechanism.java @@ -0,0 +1,318 @@ +package im.conversations.android.xmpp.sasl; + +import android.util.Base64; +import com.google.common.base.CaseFormat; +import com.google.common.base.Objects; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.hash.HashFunction; +import eu.siacs.conversations.utils.CryptoHelper; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; +import java.nio.charset.Charset; +import java.security.InvalidKeyException; +import java.util.concurrent.ExecutionException; +import javax.crypto.SecretKey; +import javax.net.ssl.SSLSocket; + +abstract class ScramMechanism extends SaslMechanism { + + public static final SecretKey EMPTY_KEY = + new SecretKey() { + @Override + public String getAlgorithm() { + return "HMAC"; + } + + @Override + public String getFormat() { + return "RAW"; + } + + @Override + public byte[] getEncoded() { + return new byte[0]; + } + }; + + private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes(); + private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes(); + private static final Cache CACHE = + CacheBuilder.newBuilder().maximumSize(10).build(); + protected final ChannelBinding channelBinding; + private final String gs2Header; + private final String clientNonce; + protected State state = State.INITIAL; + private String clientFirstMessageBare; + private byte[] serverSignature = null; + + ScramMechanism( + final Account account, + final Credential credential, + final ChannelBinding channelBinding) { + super(account, credential); + this.channelBinding = channelBinding; + if (channelBinding == ChannelBinding.NONE) { + // TODO this needs to be changed to "y,," for the scram internal down grade protection + // but we might risk compatibility issues if the server supports a binding that we don’t + // support + this.gs2Header = "n,,"; + } else { + this.gs2Header = + String.format( + "p=%s,,", + CaseFormat.UPPER_UNDERSCORE + .converterTo(CaseFormat.LOWER_HYPHEN) + .convert(channelBinding.toString())); + } + // This nonce should be different for each authentication attempt. + this.clientNonce = CryptoHelper.random(100); + clientFirstMessageBare = ""; + } + + protected abstract HashFunction getHMac(final byte[] key); + + protected abstract HashFunction getDigest(); + + private KeyPair getKeyPair(final String password, final String salt, final int iterations) + throws ExecutionException { + return CACHE.get( + new CacheKey(getMechanism(), password, salt, iterations), + () -> { + final byte[] saltedPassword, serverKey, clientKey; + saltedPassword = + hi( + password.getBytes(), + Base64.decode(salt, Base64.DEFAULT), + iterations); + serverKey = hmac(saltedPassword, SERVER_KEY_BYTES); + clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES); + return new KeyPair(clientKey, serverKey); + }); + } + + private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException { + return getHMac(key).hashBytes(input).asBytes(); + } + + private byte[] digest(final byte[] bytes) { + return getDigest().hashBytes(bytes).asBytes(); + } + + /* + * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the + * pseudorandom function (PRF) and with dkLen == output length of + * HMAC() == output length of H(). + */ + private byte[] hi(final byte[] key, final byte[] salt, final int iterations) + throws InvalidKeyException { + byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE)); + byte[] out = u.clone(); + for (int i = 1; i < iterations; i++) { + u = hmac(key, u); + for (int j = 0; j < u.length; j++) { + out[j] ^= u[j]; + } + } + return out; + } + + @Override + public String getClientFirstMessage(final SSLSocket sslSocket) { + if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) { + clientFirstMessageBare = + "n=" + + CryptoHelper.saslEscape( + CryptoHelper.saslPrep(account.address.getEscapedLocal())) + + ",r=" + + this.clientNonce; + state = State.AUTH_TEXT_SENT; + } + return Base64.encodeToString( + (gs2Header + clientFirstMessageBare).getBytes(Charset.defaultCharset()), + Base64.NO_WRAP); + } + + @Override + public String getResponse(final String challenge, final SSLSocket socket) + throws AuthenticationException { + switch (state) { + case AUTH_TEXT_SENT: + if (challenge == null) { + throw new AuthenticationException("challenge can not be null"); + } + byte[] serverFirstMessage; + try { + serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT); + } catch (IllegalArgumentException e) { + throw new AuthenticationException("Unable to decode server challenge", e); + } + final Tokenizer tokenizer = new Tokenizer(serverFirstMessage); + String nonce = ""; + int iterationCount = -1; + String salt = ""; + for (final String token : tokenizer) { + if (token.charAt(1) == '=') { + switch (token.charAt(0)) { + case 'i': + try { + iterationCount = Integer.parseInt(token.substring(2)); + } catch (final NumberFormatException e) { + throw new AuthenticationException(e); + } + break; + case 's': + salt = token.substring(2); + break; + case 'r': + nonce = token.substring(2); + break; + case 'm': + /* + * RFC 5802: + * m: This attribute is reserved for future extensibility. In this + * version of SCRAM, its presence in a client or a server message + * MUST cause authentication failure when the attribute is parsed by + * the other end. + */ + throw new AuthenticationException( + "Server sent reserved token: `m'"); + } + } + } + + if (iterationCount < 0) { + throw new AuthenticationException("Server did not send iteration count"); + } + if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) { + throw new AuthenticationException( + "Server nonce does not contain client nonce: " + nonce); + } + if (salt.isEmpty()) { + throw new AuthenticationException("Server sent empty salt"); + } + + final byte[] channelBindingData = getChannelBindingData(socket); + + final int gs2Len = this.gs2Header.getBytes().length; + final byte[] cMessage = new byte[gs2Len + channelBindingData.length]; + System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len); + System.arraycopy( + channelBindingData, 0, cMessage, gs2Len, channelBindingData.length); + + final String clientFinalMessageWithoutProof = + "c=" + Base64.encodeToString(cMessage, Base64.NO_WRAP) + ",r=" + nonce; + + final byte[] authMessage = + (clientFirstMessageBare + + ',' + + new String(serverFirstMessage) + + ',' + + clientFinalMessageWithoutProof) + .getBytes(); + + final KeyPair keys; + try { + keys = + getKeyPair( + CryptoHelper.saslPrep(credential.password), + salt, + iterationCount); + } catch (ExecutionException e) { + throw new AuthenticationException("Invalid keys generated"); + } + final byte[] clientSignature; + try { + serverSignature = hmac(keys.serverKey, authMessage); + final byte[] storedKey = digest(keys.clientKey); + + clientSignature = hmac(storedKey, authMessage); + + } catch (final InvalidKeyException e) { + throw new AuthenticationException(e); + } + + final byte[] clientProof = new byte[keys.clientKey.length]; + + if (clientSignature.length < keys.clientKey.length) { + throw new AuthenticationException( + "client signature was shorter than clientKey"); + } + + for (int i = 0; i < clientProof.length; i++) { + clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]); + } + + final String clientFinalMessage = + clientFinalMessageWithoutProof + + ",p=" + + Base64.encodeToString(clientProof, Base64.NO_WRAP); + state = State.RESPONSE_SENT; + return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP); + case RESPONSE_SENT: + try { + final String clientCalculatedServerFinalMessage = + "v=" + Base64.encodeToString(serverSignature, Base64.NO_WRAP); + if (!clientCalculatedServerFinalMessage.equals( + new String(Base64.decode(challenge, Base64.DEFAULT)))) { + throw new Exception(); + } + state = State.VALID_SERVER_RESPONSE; + return ""; + } catch (Exception e) { + throw new AuthenticationException( + "Server final message does not match calculated final message"); + } + default: + throw new InvalidStateException(state); + } + } + + protected byte[] getChannelBindingData(final SSLSocket sslSocket) + throws AuthenticationException { + if (this.channelBinding == ChannelBinding.NONE) { + return new byte[0]; + } + throw new AssertionError("getChannelBindingData needs to be overwritten"); + } + + private static class CacheKey { + final String algorithm; + final String password; + final String salt; + final int iterations; + + private CacheKey(String algorithm, String password, String salt, int iterations) { + this.algorithm = algorithm; + this.password = password; + this.salt = salt; + this.iterations = iterations; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CacheKey cacheKey = (CacheKey) o; + return iterations == cacheKey.iterations + && Objects.equal(algorithm, cacheKey.algorithm) + && Objects.equal(password, cacheKey.password) + && Objects.equal(salt, cacheKey.salt); + } + + @Override + public int hashCode() { + return Objects.hashCode(algorithm, password, salt, iterations); + } + } + + private static class KeyPair { + final byte[] clientKey; + final byte[] serverKey; + + KeyPair(final byte[] clientKey, final byte[] serverKey) { + this.clientKey = clientKey; + this.serverKey = serverKey; + } + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/ScramPlusMechanism.java b/src/main/java/im/conversations/android/xmpp/sasl/ScramPlusMechanism.java new file mode 100644 index 000000000..60c4e211d --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/ScramPlusMechanism.java @@ -0,0 +1,24 @@ +package im.conversations.android.xmpp.sasl; + +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; +import javax.net.ssl.SSLSocket; + +public abstract class ScramPlusMechanism extends ScramMechanism implements ChannelBindingMechanism { + + ScramPlusMechanism( + Account account, final Credential credential, ChannelBinding channelBinding) { + super(account, credential, channelBinding); + } + + @Override + protected byte[] getChannelBindingData(final SSLSocket sslSocket) + throws AuthenticationException { + return ChannelBindingMechanism.getChannelBindingData(sslSocket, this.channelBinding); + } + + @Override + public ChannelBinding getChannelBinding() { + return this.channelBinding; + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/ScramSha1.java b/src/main/java/im/conversations/android/xmpp/sasl/ScramSha1.java new file mode 100644 index 000000000..e83b2d201 --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/ScramSha1.java @@ -0,0 +1,37 @@ +package im.conversations.android.xmpp.sasl; + +import com.google.common.hash.HashFunction; +import com.google.common.hash.Hashing; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; + +public class ScramSha1 extends ScramMechanism { + + public static final String MECHANISM = "SCRAM-SHA-1"; + + public ScramSha1(final Account account, final Credential credential) { + super(account, credential, ChannelBinding.NONE); + } + + @Override + protected HashFunction getHMac(final byte[] key) { + return (key == null || key.length == 0) + ? Hashing.hmacSha1(EMPTY_KEY) + : Hashing.hmacSha1(key); + } + + @Override + protected HashFunction getDigest() { + return Hashing.sha1(); + } + + @Override + public int getPriority() { + return 20; + } + + @Override + public String getMechanism() { + return MECHANISM; + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/ScramSha1Plus.java b/src/main/java/im/conversations/android/xmpp/sasl/ScramSha1Plus.java new file mode 100644 index 000000000..2f9bbde2f --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/ScramSha1Plus.java @@ -0,0 +1,38 @@ +package im.conversations.android.xmpp.sasl; + +import com.google.common.hash.HashFunction; +import com.google.common.hash.Hashing; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; + +public class ScramSha1Plus extends ScramPlusMechanism { + + public static final String MECHANISM = "SCRAM-SHA-1-PLUS"; + + public ScramSha1Plus( + final Account account, Credential credential, final ChannelBinding channelBinding) { + super(account, credential, channelBinding); + } + + @Override + protected HashFunction getHMac(final byte[] key) { + return (key == null || key.length == 0) + ? Hashing.hmacSha1(EMPTY_KEY) + : Hashing.hmacSha1(key); + } + + @Override + protected HashFunction getDigest() { + return Hashing.sha1(); + } + + @Override + public int getPriority() { + return 35; // higher than SCRAM-SHA512 (30) + } + + @Override + public String getMechanism() { + return MECHANISM; + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/ScramSha256.java b/src/main/java/im/conversations/android/xmpp/sasl/ScramSha256.java new file mode 100644 index 000000000..f3f3f3419 --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/ScramSha256.java @@ -0,0 +1,37 @@ +package im.conversations.android.xmpp.sasl; + +import com.google.common.hash.HashFunction; +import com.google.common.hash.Hashing; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; + +public class ScramSha256 extends ScramMechanism { + + public static final String MECHANISM = "SCRAM-SHA-256"; + + public ScramSha256(final Account account, final Credential credential) { + super(account, credential, ChannelBinding.NONE); + } + + @Override + protected HashFunction getHMac(final byte[] key) { + return (key == null || key.length == 0) + ? Hashing.hmacSha256(EMPTY_KEY) + : Hashing.hmacSha256(key); + } + + @Override + protected HashFunction getDigest() { + return Hashing.sha256(); + } + + @Override + public int getPriority() { + return 25; + } + + @Override + public String getMechanism() { + return MECHANISM; + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/ScramSha256Plus.java b/src/main/java/im/conversations/android/xmpp/sasl/ScramSha256Plus.java new file mode 100644 index 000000000..e0e57331a --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/ScramSha256Plus.java @@ -0,0 +1,40 @@ +package im.conversations.android.xmpp.sasl; + +import com.google.common.hash.HashFunction; +import com.google.common.hash.Hashing; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; + +public class ScramSha256Plus extends ScramPlusMechanism { + + public static final String MECHANISM = "SCRAM-SHA-256-PLUS"; + + public ScramSha256Plus( + final Account account, + final Credential credential, + final ChannelBinding channelBinding) { + super(account, credential, channelBinding); + } + + @Override + protected HashFunction getHMac(final byte[] key) { + return (key == null || key.length == 0) + ? Hashing.hmacSha256(EMPTY_KEY) + : Hashing.hmacSha256(key); + } + + @Override + protected HashFunction getDigest() { + return Hashing.sha256(); + } + + @Override + public int getPriority() { + return 40; + } + + @Override + public String getMechanism() { + return MECHANISM; + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/ScramSha512.java b/src/main/java/im/conversations/android/xmpp/sasl/ScramSha512.java new file mode 100644 index 000000000..4057f79a1 --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/ScramSha512.java @@ -0,0 +1,37 @@ +package im.conversations.android.xmpp.sasl; + +import com.google.common.hash.HashFunction; +import com.google.common.hash.Hashing; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; + +public class ScramSha512 extends ScramMechanism { + + public static final String MECHANISM = "SCRAM-SHA-512"; + + public ScramSha512(final Account account, final Credential credential) { + super(account, credential, ChannelBinding.NONE); + } + + @Override + protected HashFunction getHMac(final byte[] key) { + return (key == null || key.length == 0) + ? Hashing.hmacSha512(EMPTY_KEY) + : Hashing.hmacSha512(key); + } + + @Override + protected HashFunction getDigest() { + return Hashing.sha512(); + } + + @Override + public int getPriority() { + return 30; + } + + @Override + public String getMechanism() { + return MECHANISM; + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/ScramSha512Plus.java b/src/main/java/im/conversations/android/xmpp/sasl/ScramSha512Plus.java new file mode 100644 index 000000000..5dfa34d07 --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/ScramSha512Plus.java @@ -0,0 +1,40 @@ +package im.conversations.android.xmpp.sasl; + +import com.google.common.hash.HashFunction; +import com.google.common.hash.Hashing; +import im.conversations.android.database.model.Account; +import im.conversations.android.database.model.Credential; + +public class ScramSha512Plus extends ScramPlusMechanism { + + public static final String MECHANISM = "SCRAM-SHA-512-PLUS"; + + public ScramSha512Plus( + final Account account, + final Credential credential, + final ChannelBinding channelBinding) { + super(account, credential, channelBinding); + } + + @Override + protected HashFunction getHMac(final byte[] key) { + return (key == null || key.length == 0) + ? Hashing.hmacSha512(EMPTY_KEY) + : Hashing.hmacSha512(key); + } + + @Override + protected HashFunction getDigest() { + return Hashing.sha512(); + } + + @Override + public int getPriority() { + return 45; + } + + @Override + public String getMechanism() { + return MECHANISM; + } +} diff --git a/src/main/java/im/conversations/android/xmpp/sasl/Tokenizer.java b/src/main/java/im/conversations/android/xmpp/sasl/Tokenizer.java new file mode 100644 index 000000000..3e649b83b --- /dev/null +++ b/src/main/java/im/conversations/android/xmpp/sasl/Tokenizer.java @@ -0,0 +1,77 @@ +package im.conversations.android.xmpp.sasl; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; + +/** A tokenizer for GS2 header strings */ +public final class Tokenizer implements Iterator, Iterable { + private final List parts; + private int index; + + public Tokenizer(final byte[] challenge) { + final String challengeString = new String(challenge); + parts = new ArrayList<>(Arrays.asList(challengeString.split(","))); + // Trim parts. + for (int i = 0; i < parts.size(); i++) { + parts.set(i, parts.get(i).trim()); + } + index = 0; + } + + /** + * Returns true if there is at least one more element, false otherwise. + * + * @see #next + */ + @Override + public boolean hasNext() { + return parts.size() != index + 1; + } + + /** + * Returns the next object and advances the iterator. + * + * @return the next object. + * @throws java.util.NoSuchElementException if there are no more elements. + * @see #hasNext + */ + @Override + public String next() { + if (hasNext()) { + return parts.get(index++); + } else { + throw new NoSuchElementException("No such element. Size is: " + parts.size()); + } + } + + /** + * Removes the last object returned by {@code next} from the collection. This method can only be + * called once between each call to {@code next}. + * + * @throws UnsupportedOperationException if removing is not supported by the collection being + * iterated. + * @throws IllegalStateException if {@code next} has not been called, or {@code remove} has + * already been called after the last call to {@code next}. + */ + @Override + public void remove() { + if (index <= 0) { + throw new IllegalStateException( + "You can't delete an element before first next() method call"); + } + parts.remove(--index); + } + + /** + * Returns an {@link java.util.Iterator} for the elements in this object. + * + * @return An {@code Iterator} instance. + */ + @Override + public Iterator iterator() { + return parts.iterator(); + } +}