use futures in DNS resolver

This commit is contained in:
Daniel Gultsch 2023-03-02 18:44:27 +01:00
parent 100c735636
commit 4addeaa356
No known key found for this signature in database
GPG key ID: F43D18AD2A0982C2
3 changed files with 231 additions and 250 deletions

View file

@ -2,9 +2,12 @@ package im.conversations.android.dns;
import android.app.Application; import android.app.Application;
import android.content.Context; import android.content.Context;
import androidx.annotation.NonNull;
import com.google.common.base.MoreObjects; import com.google.common.collect.ImmutableList;
import com.google.common.base.Objects; import com.google.common.collect.Ordering;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import de.measite.minidns.AbstractDNSClient; import de.measite.minidns.AbstractDNSClient;
import de.measite.minidns.DNSCache; import de.measite.minidns.DNSCache;
import de.measite.minidns.DNSClient; import de.measite.minidns.DNSClient;
@ -24,14 +27,17 @@ import de.measite.minidns.record.CNAME;
import de.measite.minidns.record.Data; import de.measite.minidns.record.Data;
import de.measite.minidns.record.InternetAddressRR; import de.measite.minidns.record.InternetAddressRR;
import de.measite.minidns.record.SRV; import de.measite.minidns.record.SRV;
import im.conversations.android.database.model.Connection;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.net.Inet4Address;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.UnknownHostException; import java.net.UnknownHostException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import org.jxmpp.jid.DomainJid; import org.jxmpp.jid.DomainJid;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -45,6 +51,8 @@ public class Resolver {
private static final String DIRECT_TLS_SERVICE = "_xmpps-client"; private static final String DIRECT_TLS_SERVICE = "_xmpps-client";
private static final String STARTTLS_SERVICE = "_xmpp-client"; private static final String STARTTLS_SERVICE = "_xmpp-client";
private static final Executor EXECUTOR = Executors.newFixedThreadPool(4);
private static Context SERVICE; private static Context SERVICE;
public static void init(final Application application) { public static void init(final Application application) {
@ -75,13 +83,15 @@ public class Resolver {
} }
} }
public static List<Result> fromHardCoded(final String hostname, final int port) { public static List<ServiceRecord> fromHardCoded(final Connection connection) {
final Result result = new Result(); return Collections.singletonList(
result.hostname = DNSName.from(hostname); new ServiceRecord(
result.port = port; null,
result.directTls = useDirectTls(port); DNSName.from(connection.hostname),
result.authenticated = true; connection.port,
return Collections.singletonList(result); connection.directTls,
0,
true));
} }
public static void checkDomain(final DomainJid jid) { public static void checkDomain(final DomainJid jid) {
@ -106,183 +116,132 @@ public class Resolver {
} }
} }
public static boolean useDirectTls(final int port) { public static List<ServiceRecord> resolve(final String domain) {
return port == 443 || port == 5223; final List<ServiceRecord> ipResults = fromIpAddress(domain);
}
public static List<Result> resolve(String domain) {
final List<Result> ipResults = fromIpAddress(domain);
if (ipResults.size() > 0) { if (ipResults.size() > 0) {
return ipResults; return ipResults;
} }
final List<Result> results = new ArrayList<>(); final ListenableFuture<List<ServiceRecord>> directTlsSrvRecords =
final List<Result> fallbackResults = new ArrayList<>(); Futures.submitAsync(() -> resolveSrv(domain, true), EXECUTOR);
final Thread[] threads = new Thread[3]; final ListenableFuture<List<ServiceRecord>> startTlsSrvRecords =
threads[0] = Futures.submitAsync(() -> resolveSrv(domain, false), EXECUTOR);
new Thread( final ListenableFuture<List<ServiceRecord>> srvRecords =
() -> { Futures.transform(
Futures.allAsList(directTlsSrvRecords, startTlsSrvRecords),
input -> {
final var list =
input.stream()
.flatMap(List::stream)
.collect(Collectors.toList());
if (list.isEmpty()) {
throw new IllegalStateException("No SRV records found");
}
return list;
},
MoreExecutors.directExecutor());
final ListenableFuture<List<ServiceRecord>> fallback =
Futures.submit(() -> resolveNoSrvRecords(DNSName.from(domain), true), EXECUTOR);
final var resultFuture =
Futures.catchingAsync(
srvRecords,
Exception.class,
input -> fallback,
MoreExecutors.directExecutor());
try { try {
final List<Result> list = resolveSrv(domain, true); return Ordering.natural().sortedCopy(resultFuture.get());
synchronized (results) { } catch (final Exception e) {
results.addAll(list);
}
} catch (final Throwable throwable) {
LOGGER.debug("error resolving SRV record (direct TLS)", throwable);
}
});
threads[1] =
new Thread(
() -> {
try {
final List<Result> list = resolveSrv(domain, false);
synchronized (results) {
results.addAll(list);
}
} catch (Throwable throwable) {
LOGGER.debug(
"error resolving SRV record (direct STARTTLS)", throwable);
}
});
threads[2] =
new Thread(
() -> {
List<Result> list = resolveNoSrvRecords(DNSName.from(domain), true);
synchronized (fallbackResults) {
fallbackResults.addAll(list);
}
});
for (final Thread thread : threads) {
thread.start();
}
try {
threads[0].join();
threads[1].join();
if (results.size() > 0) {
threads[2].interrupt();
synchronized (results) {
Collections.sort(results);
LOGGER.info("{}", results);
return new ArrayList<>(results);
}
} else {
threads[2].join();
synchronized (fallbackResults) {
Collections.sort(fallbackResults);
LOGGER.info("fallback {}", fallbackResults);
return new ArrayList<>(fallbackResults);
}
}
} catch (InterruptedException e) {
for (Thread thread : threads) {
thread.interrupt();
}
return Collections.emptyList(); return Collections.emptyList();
} }
} }
private static List<Result> fromIpAddress(String domain) { private static List<ServiceRecord> fromIpAddress(final String domain) {
if (!IP.matches(domain)) { if (!IP.matches(domain)) {
return Collections.emptyList(); return Collections.emptyList();
} }
final InetAddress ip;
try { try {
Result result = new Result(); ip = InetAddress.getByName(domain);
result.ip = InetAddress.getByName(domain); } catch (final UnknownHostException e) {
result.port = DEFAULT_PORT_XMPP;
return Collections.singletonList(result);
} catch (UnknownHostException e) {
return Collections.emptyList(); return Collections.emptyList();
} }
return Collections.singletonList(new ServiceRecord(ip, null, DEFAULT_PORT_XMPP, false, 0, false));
} }
private static List<Result> resolveSrv(String domain, final boolean directTls) private static ListenableFuture<List<ServiceRecord>> resolveSrv(
throws IOException { final String domain, final boolean directTls) throws IOException {
DNSName dnsName = DNSName dnsName =
DNSName.from( DNSName.from(
(directTls ? DIRECT_TLS_SERVICE : STARTTLS_SERVICE) + "._tcp." + domain); (directTls ? DIRECT_TLS_SERVICE : STARTTLS_SERVICE) + "._tcp." + domain);
ResolverResult<SRV> result = resolveWithFallback(dnsName, SRV.class); final ResolverResult<SRV> result = resolveWithFallback(dnsName, SRV.class);
final List<Result> results = new ArrayList<>(); final List<ListenableFuture<List<ServiceRecord>>> results = new ArrayList<>();
final List<Thread> threads = new ArrayList<>(); for (final SRV record : result.getAnswersOrEmptySet()) {
for (SRV record : result.getAnswersOrEmptySet()) {
if (record.name.length() == 0 && record.priority == 0) { if (record.name.length() == 0 && record.priority == 0) {
continue; continue;
} }
threads.add( results.add(
new Thread( Futures.submit(
() -> { () -> {
final List<Result> ipv4s = final List<ServiceRecord> ipv4s =
resolveIp( resolveIp(
record, record,
A.class, A.class,
result.isAuthenticData(), result.isAuthenticData(),
directTls); directTls);
if (ipv4s.size() == 0) { if (ipv4s.isEmpty()) {
Result resolverResult = Result.fromRecord(record, directTls); return Collections.singletonList(
resolverResult.authenticated = result.isAuthenticData(); ServiceRecord.fromRecord(
ipv4s.add(resolverResult); record, directTls, result.isAuthenticData()));
} else {
return ipv4s;
} }
synchronized (results) { },
results.addAll(ipv4s); EXECUTOR));
} results.add(
})); Futures.submit(
threads.add( () ->
new Thread(
() -> {
final List<Result> ipv6s =
resolveIp( resolveIp(
record, record,
AAAA.class, AAAA.class,
result.isAuthenticData(), result.isAuthenticData(),
directTls); directTls),
synchronized (results) { EXECUTOR));
results.addAll(ipv6s);
} }
})); return Futures.transform(
} Futures.allAsList(results),
for (Thread thread : threads) { input -> input.stream().flatMap(List::stream).collect(Collectors.toList()),
thread.start(); MoreExecutors.directExecutor());
}
for (Thread thread : threads) {
try {
thread.join();
} catch (InterruptedException e) {
return Collections.emptyList();
}
}
return results;
} }
private static <D extends InternetAddressRR> List<Result> resolveIp( private static <D extends InternetAddressRR> List<ServiceRecord> resolveIp(
SRV srv, Class<D> type, boolean authenticated, boolean directTls) { SRV srv, Class<D> type, boolean authenticated, boolean directTls) {
List<Result> list = new ArrayList<>(); final ImmutableList.Builder<ServiceRecord> builder = new ImmutableList.Builder<>();
try { try {
ResolverResult<D> results = resolveWithFallback(srv.name, type, authenticated); ResolverResult<D> results = resolveWithFallback(srv.name, type, authenticated);
for (D record : results.getAnswersOrEmptySet()) { for (D record : results.getAnswersOrEmptySet()) {
Result resolverResult = Result.fromRecord(srv, directTls); builder.add(
resolverResult.authenticated = ServiceRecord.fromRecord(
results.isAuthenticData() srv,
&& authenticated; // TODO technically it doesnt matter if the IP directTls,
// was authenticated results.isAuthenticData() && authenticated,
resolverResult.ip = record.getInetAddress(); record.getInetAddress()));
list.add(resolverResult);
} }
} catch (final Throwable t) { } catch (final Throwable t) {
LOGGER.info("error resolving {}", type.getSimpleName(), t); LOGGER.info("error resolving {}", type.getSimpleName(), t);
} }
return list; return builder.build();
} }
private static List<Result> resolveNoSrvRecords(DNSName dnsName, boolean withCnames) { private static List<ServiceRecord> resolveNoSrvRecords(DNSName dnsName, boolean includeCName) {
List<Result> results = new ArrayList<>(); List<ServiceRecord> results = new ArrayList<>();
try { try {
for (A a : resolveWithFallback(dnsName, A.class, false).getAnswersOrEmptySet()) { for (A a : resolveWithFallback(dnsName, A.class, false).getAnswersOrEmptySet()) {
results.add(Result.createDefault(dnsName, a.getInetAddress())); results.add(ServiceRecord.createDefault(dnsName, a.getInetAddress()));
} }
for (AAAA aaaa : for (AAAA aaaa :
resolveWithFallback(dnsName, AAAA.class, false).getAnswersOrEmptySet()) { resolveWithFallback(dnsName, AAAA.class, false).getAnswersOrEmptySet()) {
results.add(Result.createDefault(dnsName, aaaa.getInetAddress())); results.add(ServiceRecord.createDefault(dnsName, aaaa.getInetAddress()));
} }
if (results.size() == 0 && withCnames) { if (results.size() == 0 && includeCName) {
for (CNAME cname : for (CNAME cname :
resolveWithFallback(dnsName, CNAME.class, false).getAnswersOrEmptySet()) { resolveWithFallback(dnsName, CNAME.class, false).getAnswersOrEmptySet()) {
results.addAll(resolveNoSrvRecords(cname.name, false)); results.addAll(resolveNoSrvRecords(cname.name, false));
@ -291,7 +250,7 @@ public class Resolver {
} catch (Throwable throwable) { } catch (Throwable throwable) {
LOGGER.info("Error resolving fallback records", throwable); LOGGER.info("Error resolving fallback records", throwable);
} }
results.add(Result.createDefault(dnsName)); results.add(ServiceRecord.createDefault(dnsName));
return results; return results;
} }
@ -327,108 +286,4 @@ public class Resolver {
return false; return false;
} }
public static class Result implements Comparable<Result> {
private InetAddress ip;
private DNSName hostname;
private int port = DEFAULT_PORT_XMPP;
private boolean directTls = false;
private boolean authenticated = false;
private int priority;
static Result fromRecord(SRV srv, boolean directTls) {
Result result = new Result();
result.port = srv.port;
result.hostname = srv.name;
result.directTls = directTls;
result.priority = srv.priority;
return result;
}
static Result createDefault(DNSName hostname, InetAddress ip) {
Result result = new Result();
result.port = DEFAULT_PORT_XMPP;
result.hostname = hostname;
result.ip = ip;
return result;
}
static Result createDefault(DNSName hostname) {
return createDefault(hostname, null);
}
public InetAddress getIp() {
return ip;
}
public int getPort() {
return port;
}
public DNSName getHostname() {
return hostname;
}
public boolean isDirectTls() {
return directTls;
}
public boolean isAuthenticated() {
return authenticated;
}
@Override
public int compareTo(@NonNull Result result) {
// TODO use comparison chain. get rid of IPv4 preference
if (result.priority == priority) {
if (directTls == result.directTls) {
if (ip == null && result.ip == null) {
return 0;
} else if (ip != null && result.ip != null) {
if (ip instanceof Inet4Address && result.ip instanceof Inet4Address) {
return 0;
} else {
return ip instanceof Inet4Address ? -1 : 1;
}
} else {
return ip != null ? -1 : 1;
}
} else {
return directTls ? -1 : 1;
}
} else {
return priority - result.priority;
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result result = (Result) o;
return port == result.port
&& directTls == result.directTls
&& authenticated == result.authenticated
&& priority == result.priority
&& Objects.equal(ip, result.ip)
&& Objects.equal(hostname, result.hostname);
}
@Override
public int hashCode() {
return Objects.hashCode(ip, hostname, port, directTls, authenticated, priority);
}
@NonNull
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("ip", ip)
.add("hostname", hostname)
.add("port", port)
.add("directTls", directTls)
.add("authenticated", authenticated)
.add("priority", priority)
.toString();
}
}
} }

View file

@ -0,0 +1,124 @@
package im.conversations.android.dns;
import androidx.annotation.NonNull;
import com.google.common.base.MoreObjects;
import com.google.common.base.Objects;
import java.net.InetAddress;
import de.measite.minidns.DNSName;
import de.measite.minidns.record.SRV;
public class ServiceRecord implements Comparable<ServiceRecord> {
private final InetAddress ip;
private final DNSName hostname;
private final int port;
private final boolean directTls;
private final int priority;
private final boolean authenticated;
public ServiceRecord(
InetAddress ip,
DNSName hostname,
int port,
boolean directTls,
int priority,
boolean authenticated) {
this.ip = ip;
this.hostname = hostname;
this.port = port;
this.directTls = directTls;
this.authenticated = authenticated;
this.priority = priority;
}
public static ServiceRecord fromRecord(
final SRV srv,
final boolean directTls,
final boolean authenticated,
final InetAddress ip) {
return new ServiceRecord(ip, srv.name, srv.port, directTls, srv.priority, authenticated);
}
public static ServiceRecord fromRecord(
final SRV srv, final boolean directTls, final boolean authenticated) {
return fromRecord(srv, directTls, authenticated, null);
}
static ServiceRecord createDefault(final DNSName hostname, final InetAddress ip) {
return new ServiceRecord(ip, hostname, Resolver.DEFAULT_PORT_XMPP, false, 0, false);
}
static ServiceRecord createDefault(final DNSName hostname) {
return createDefault(hostname, null);
}
public InetAddress getIp() {
return ip;
}
public int getPort() {
return port;
}
public DNSName getHostname() {
return hostname;
}
public boolean isDirectTls() {
return directTls;
}
public boolean isAuthenticated() {
return authenticated;
}
@Override
public int compareTo(@NonNull ServiceRecord result) {
if (result.priority == priority) {
if (directTls == result.directTls) {
if (ip == null && result.ip == null) {
return 0;
} else {
return ip != null ? -1 : 1;
}
} else {
return directTls ? -1 : 1;
}
} else {
return priority - result.priority;
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ServiceRecord result = (ServiceRecord) o;
return port == result.port
&& directTls == result.directTls
&& authenticated == result.authenticated
&& priority == result.priority
&& Objects.equal(ip, result.ip)
&& Objects.equal(hostname, result.hostname);
}
@Override
public int hashCode() {
return Objects.hashCode(ip, hostname, port, directTls, authenticated, priority);
}
@NonNull
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("ip", ip)
.add("hostname", hostname)
.add("port", port)
.add("directTls", directTls)
.add("authenticated", authenticated)
.add("priority", priority)
.toString();
}
}

View file

@ -26,6 +26,7 @@ import im.conversations.android.database.model.Account;
import im.conversations.android.database.model.Connection; import im.conversations.android.database.model.Connection;
import im.conversations.android.database.model.Credential; import im.conversations.android.database.model.Credential;
import im.conversations.android.dns.Resolver; import im.conversations.android.dns.Resolver;
import im.conversations.android.dns.ServiceRecord;
import im.conversations.android.socks.SocksSocketFactory; import im.conversations.android.socks.SocksSocketFactory;
import im.conversations.android.tls.SSLSockets; import im.conversations.android.tls.SSLSockets;
import im.conversations.android.tls.XmppDomainVerifier; import im.conversations.android.tls.XmppDomainVerifier;
@ -322,12 +323,13 @@ public class XmppConnection implements Runnable {
} }
} else { } else {
final String domain = account.address.getDomain().toString(); final String domain = account.address.getDomain().toString();
final List<Resolver.Result> results; final List<ServiceRecord> results;
if (connection != null) { if (connection != null) {
results = Resolver.fromHardCoded(connection.hostname, connection.port); results = Resolver.fromHardCoded(connection);
} else { } else {
results = Resolver.resolve(domain); results = Resolver.resolve(domain);
} }
LOGGER.info("{}", results);
if (Thread.currentThread().isInterrupted()) { if (Thread.currentThread().isInterrupted()) {
LOGGER.debug(account.address + ": Thread was interrupted"); LOGGER.debug(account.address + ": Thread was interrupted");
return; return;
@ -336,7 +338,7 @@ public class XmppConnection implements Runnable {
LOGGER.warn("Resolver results were empty"); LOGGER.warn("Resolver results were empty");
return; return;
} }
final Resolver.Result storedBackupResult; final ServiceRecord storedBackupResult;
if (connection != null) { if (connection != null) {
storedBackupResult = null; storedBackupResult = null;
} else { } else {
@ -351,9 +353,9 @@ public class XmppConnection implements Runnable {
+ storedBackupResult); + storedBackupResult);
} }
} }
for (Iterator<Resolver.Result> iterator = results.iterator(); for (Iterator<ServiceRecord> iterator = results.iterator();
iterator.hasNext(); ) { iterator.hasNext(); ) {
final Resolver.Result result = iterator.next(); final ServiceRecord result = iterator.next();
if (Thread.currentThread().isInterrupted()) { if (Thread.currentThread().isInterrupted()) {
LOGGER.debug(account.address + ": Thread was interrupted"); LOGGER.debug(account.address + ": Thread was interrupted");
return; return;