use futures in DNS resolver

This commit is contained in:
Daniel Gultsch 2023-03-02 18:44:27 +01:00
parent 100c735636
commit 4addeaa356
No known key found for this signature in database
GPG key ID: F43D18AD2A0982C2
3 changed files with 231 additions and 250 deletions

View file

@ -2,9 +2,12 @@ package im.conversations.android.dns;
import android.app.Application;
import android.content.Context;
import androidx.annotation.NonNull;
import com.google.common.base.MoreObjects;
import com.google.common.base.Objects;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Ordering;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import de.measite.minidns.AbstractDNSClient;
import de.measite.minidns.DNSCache;
import de.measite.minidns.DNSClient;
@ -24,14 +27,17 @@ import de.measite.minidns.record.CNAME;
import de.measite.minidns.record.Data;
import de.measite.minidns.record.InternetAddressRR;
import de.measite.minidns.record.SRV;
import im.conversations.android.database.model.Connection;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import org.jxmpp.jid.DomainJid;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -45,6 +51,8 @@ public class Resolver {
private static final String DIRECT_TLS_SERVICE = "_xmpps-client";
private static final String STARTTLS_SERVICE = "_xmpp-client";
private static final Executor EXECUTOR = Executors.newFixedThreadPool(4);
private static Context SERVICE;
public static void init(final Application application) {
@ -75,13 +83,15 @@ public class Resolver {
}
}
public static List<Result> fromHardCoded(final String hostname, final int port) {
final Result result = new Result();
result.hostname = DNSName.from(hostname);
result.port = port;
result.directTls = useDirectTls(port);
result.authenticated = true;
return Collections.singletonList(result);
public static List<ServiceRecord> fromHardCoded(final Connection connection) {
return Collections.singletonList(
new ServiceRecord(
null,
DNSName.from(connection.hostname),
connection.port,
connection.directTls,
0,
true));
}
public static void checkDomain(final DomainJid jid) {
@ -106,183 +116,132 @@ public class Resolver {
}
}
public static boolean useDirectTls(final int port) {
return port == 443 || port == 5223;
}
public static List<Result> resolve(String domain) {
final List<Result> ipResults = fromIpAddress(domain);
public static List<ServiceRecord> resolve(final String domain) {
final List<ServiceRecord> ipResults = fromIpAddress(domain);
if (ipResults.size() > 0) {
return ipResults;
}
final List<Result> results = new ArrayList<>();
final List<Result> fallbackResults = new ArrayList<>();
final Thread[] threads = new Thread[3];
threads[0] =
new Thread(
() -> {
final ListenableFuture<List<ServiceRecord>> directTlsSrvRecords =
Futures.submitAsync(() -> resolveSrv(domain, true), EXECUTOR);
final ListenableFuture<List<ServiceRecord>> startTlsSrvRecords =
Futures.submitAsync(() -> resolveSrv(domain, false), EXECUTOR);
final ListenableFuture<List<ServiceRecord>> srvRecords =
Futures.transform(
Futures.allAsList(directTlsSrvRecords, startTlsSrvRecords),
input -> {
final var list =
input.stream()
.flatMap(List::stream)
.collect(Collectors.toList());
if (list.isEmpty()) {
throw new IllegalStateException("No SRV records found");
}
return list;
},
MoreExecutors.directExecutor());
final ListenableFuture<List<ServiceRecord>> fallback =
Futures.submit(() -> resolveNoSrvRecords(DNSName.from(domain), true), EXECUTOR);
final var resultFuture =
Futures.catchingAsync(
srvRecords,
Exception.class,
input -> fallback,
MoreExecutors.directExecutor());
try {
final List<Result> list = resolveSrv(domain, true);
synchronized (results) {
results.addAll(list);
}
} catch (final Throwable throwable) {
LOGGER.debug("error resolving SRV record (direct TLS)", throwable);
}
});
threads[1] =
new Thread(
() -> {
try {
final List<Result> list = resolveSrv(domain, false);
synchronized (results) {
results.addAll(list);
}
} catch (Throwable throwable) {
LOGGER.debug(
"error resolving SRV record (direct STARTTLS)", throwable);
}
});
threads[2] =
new Thread(
() -> {
List<Result> list = resolveNoSrvRecords(DNSName.from(domain), true);
synchronized (fallbackResults) {
fallbackResults.addAll(list);
}
});
for (final Thread thread : threads) {
thread.start();
}
try {
threads[0].join();
threads[1].join();
if (results.size() > 0) {
threads[2].interrupt();
synchronized (results) {
Collections.sort(results);
LOGGER.info("{}", results);
return new ArrayList<>(results);
}
} else {
threads[2].join();
synchronized (fallbackResults) {
Collections.sort(fallbackResults);
LOGGER.info("fallback {}", fallbackResults);
return new ArrayList<>(fallbackResults);
}
}
} catch (InterruptedException e) {
for (Thread thread : threads) {
thread.interrupt();
}
return Ordering.natural().sortedCopy(resultFuture.get());
} catch (final Exception e) {
return Collections.emptyList();
}
}
private static List<Result> fromIpAddress(String domain) {
private static List<ServiceRecord> fromIpAddress(final String domain) {
if (!IP.matches(domain)) {
return Collections.emptyList();
}
final InetAddress ip;
try {
Result result = new Result();
result.ip = InetAddress.getByName(domain);
result.port = DEFAULT_PORT_XMPP;
return Collections.singletonList(result);
} catch (UnknownHostException e) {
ip = InetAddress.getByName(domain);
} catch (final UnknownHostException e) {
return Collections.emptyList();
}
return Collections.singletonList(new ServiceRecord(ip, null, DEFAULT_PORT_XMPP, false, 0, false));
}
private static List<Result> resolveSrv(String domain, final boolean directTls)
throws IOException {
private static ListenableFuture<List<ServiceRecord>> resolveSrv(
final String domain, final boolean directTls) throws IOException {
DNSName dnsName =
DNSName.from(
(directTls ? DIRECT_TLS_SERVICE : STARTTLS_SERVICE) + "._tcp." + domain);
ResolverResult<SRV> result = resolveWithFallback(dnsName, SRV.class);
final List<Result> results = new ArrayList<>();
final List<Thread> threads = new ArrayList<>();
for (SRV record : result.getAnswersOrEmptySet()) {
final ResolverResult<SRV> result = resolveWithFallback(dnsName, SRV.class);
final List<ListenableFuture<List<ServiceRecord>>> results = new ArrayList<>();
for (final SRV record : result.getAnswersOrEmptySet()) {
if (record.name.length() == 0 && record.priority == 0) {
continue;
}
threads.add(
new Thread(
results.add(
Futures.submit(
() -> {
final List<Result> ipv4s =
final List<ServiceRecord> ipv4s =
resolveIp(
record,
A.class,
result.isAuthenticData(),
directTls);
if (ipv4s.size() == 0) {
Result resolverResult = Result.fromRecord(record, directTls);
resolverResult.authenticated = result.isAuthenticData();
ipv4s.add(resolverResult);
if (ipv4s.isEmpty()) {
return Collections.singletonList(
ServiceRecord.fromRecord(
record, directTls, result.isAuthenticData()));
} else {
return ipv4s;
}
synchronized (results) {
results.addAll(ipv4s);
}
}));
threads.add(
new Thread(
() -> {
final List<Result> ipv6s =
},
EXECUTOR));
results.add(
Futures.submit(
() ->
resolveIp(
record,
AAAA.class,
result.isAuthenticData(),
directTls);
synchronized (results) {
results.addAll(ipv6s);
directTls),
EXECUTOR));
}
}));
}
for (Thread thread : threads) {
thread.start();
}
for (Thread thread : threads) {
try {
thread.join();
} catch (InterruptedException e) {
return Collections.emptyList();
}
}
return results;
return Futures.transform(
Futures.allAsList(results),
input -> input.stream().flatMap(List::stream).collect(Collectors.toList()),
MoreExecutors.directExecutor());
}
private static <D extends InternetAddressRR> List<Result> resolveIp(
private static <D extends InternetAddressRR> List<ServiceRecord> resolveIp(
SRV srv, Class<D> type, boolean authenticated, boolean directTls) {
List<Result> list = new ArrayList<>();
final ImmutableList.Builder<ServiceRecord> builder = new ImmutableList.Builder<>();
try {
ResolverResult<D> results = resolveWithFallback(srv.name, type, authenticated);
for (D record : results.getAnswersOrEmptySet()) {
Result resolverResult = Result.fromRecord(srv, directTls);
resolverResult.authenticated =
results.isAuthenticData()
&& authenticated; // TODO technically it doesnt matter if the IP
// was authenticated
resolverResult.ip = record.getInetAddress();
list.add(resolverResult);
builder.add(
ServiceRecord.fromRecord(
srv,
directTls,
results.isAuthenticData() && authenticated,
record.getInetAddress()));
}
} catch (final Throwable t) {
LOGGER.info("error resolving {}", type.getSimpleName(), t);
}
return list;
return builder.build();
}
private static List<Result> resolveNoSrvRecords(DNSName dnsName, boolean withCnames) {
List<Result> results = new ArrayList<>();
private static List<ServiceRecord> resolveNoSrvRecords(DNSName dnsName, boolean includeCName) {
List<ServiceRecord> results = new ArrayList<>();
try {
for (A a : resolveWithFallback(dnsName, A.class, false).getAnswersOrEmptySet()) {
results.add(Result.createDefault(dnsName, a.getInetAddress()));
results.add(ServiceRecord.createDefault(dnsName, a.getInetAddress()));
}
for (AAAA aaaa :
resolveWithFallback(dnsName, AAAA.class, false).getAnswersOrEmptySet()) {
results.add(Result.createDefault(dnsName, aaaa.getInetAddress()));
results.add(ServiceRecord.createDefault(dnsName, aaaa.getInetAddress()));
}
if (results.size() == 0 && withCnames) {
if (results.size() == 0 && includeCName) {
for (CNAME cname :
resolveWithFallback(dnsName, CNAME.class, false).getAnswersOrEmptySet()) {
results.addAll(resolveNoSrvRecords(cname.name, false));
@ -291,7 +250,7 @@ public class Resolver {
} catch (Throwable throwable) {
LOGGER.info("Error resolving fallback records", throwable);
}
results.add(Result.createDefault(dnsName));
results.add(ServiceRecord.createDefault(dnsName));
return results;
}
@ -327,108 +286,4 @@ public class Resolver {
return false;
}
public static class Result implements Comparable<Result> {
private InetAddress ip;
private DNSName hostname;
private int port = DEFAULT_PORT_XMPP;
private boolean directTls = false;
private boolean authenticated = false;
private int priority;
static Result fromRecord(SRV srv, boolean directTls) {
Result result = new Result();
result.port = srv.port;
result.hostname = srv.name;
result.directTls = directTls;
result.priority = srv.priority;
return result;
}
static Result createDefault(DNSName hostname, InetAddress ip) {
Result result = new Result();
result.port = DEFAULT_PORT_XMPP;
result.hostname = hostname;
result.ip = ip;
return result;
}
static Result createDefault(DNSName hostname) {
return createDefault(hostname, null);
}
public InetAddress getIp() {
return ip;
}
public int getPort() {
return port;
}
public DNSName getHostname() {
return hostname;
}
public boolean isDirectTls() {
return directTls;
}
public boolean isAuthenticated() {
return authenticated;
}
@Override
public int compareTo(@NonNull Result result) {
// TODO use comparison chain. get rid of IPv4 preference
if (result.priority == priority) {
if (directTls == result.directTls) {
if (ip == null && result.ip == null) {
return 0;
} else if (ip != null && result.ip != null) {
if (ip instanceof Inet4Address && result.ip instanceof Inet4Address) {
return 0;
} else {
return ip instanceof Inet4Address ? -1 : 1;
}
} else {
return ip != null ? -1 : 1;
}
} else {
return directTls ? -1 : 1;
}
} else {
return priority - result.priority;
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result result = (Result) o;
return port == result.port
&& directTls == result.directTls
&& authenticated == result.authenticated
&& priority == result.priority
&& Objects.equal(ip, result.ip)
&& Objects.equal(hostname, result.hostname);
}
@Override
public int hashCode() {
return Objects.hashCode(ip, hostname, port, directTls, authenticated, priority);
}
@NonNull
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("ip", ip)
.add("hostname", hostname)
.add("port", port)
.add("directTls", directTls)
.add("authenticated", authenticated)
.add("priority", priority)
.toString();
}
}
}

View file

@ -0,0 +1,124 @@
package im.conversations.android.dns;
import androidx.annotation.NonNull;
import com.google.common.base.MoreObjects;
import com.google.common.base.Objects;
import java.net.InetAddress;
import de.measite.minidns.DNSName;
import de.measite.minidns.record.SRV;
public class ServiceRecord implements Comparable<ServiceRecord> {
private final InetAddress ip;
private final DNSName hostname;
private final int port;
private final boolean directTls;
private final int priority;
private final boolean authenticated;
public ServiceRecord(
InetAddress ip,
DNSName hostname,
int port,
boolean directTls,
int priority,
boolean authenticated) {
this.ip = ip;
this.hostname = hostname;
this.port = port;
this.directTls = directTls;
this.authenticated = authenticated;
this.priority = priority;
}
public static ServiceRecord fromRecord(
final SRV srv,
final boolean directTls,
final boolean authenticated,
final InetAddress ip) {
return new ServiceRecord(ip, srv.name, srv.port, directTls, srv.priority, authenticated);
}
public static ServiceRecord fromRecord(
final SRV srv, final boolean directTls, final boolean authenticated) {
return fromRecord(srv, directTls, authenticated, null);
}
static ServiceRecord createDefault(final DNSName hostname, final InetAddress ip) {
return new ServiceRecord(ip, hostname, Resolver.DEFAULT_PORT_XMPP, false, 0, false);
}
static ServiceRecord createDefault(final DNSName hostname) {
return createDefault(hostname, null);
}
public InetAddress getIp() {
return ip;
}
public int getPort() {
return port;
}
public DNSName getHostname() {
return hostname;
}
public boolean isDirectTls() {
return directTls;
}
public boolean isAuthenticated() {
return authenticated;
}
@Override
public int compareTo(@NonNull ServiceRecord result) {
if (result.priority == priority) {
if (directTls == result.directTls) {
if (ip == null && result.ip == null) {
return 0;
} else {
return ip != null ? -1 : 1;
}
} else {
return directTls ? -1 : 1;
}
} else {
return priority - result.priority;
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ServiceRecord result = (ServiceRecord) o;
return port == result.port
&& directTls == result.directTls
&& authenticated == result.authenticated
&& priority == result.priority
&& Objects.equal(ip, result.ip)
&& Objects.equal(hostname, result.hostname);
}
@Override
public int hashCode() {
return Objects.hashCode(ip, hostname, port, directTls, authenticated, priority);
}
@NonNull
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("ip", ip)
.add("hostname", hostname)
.add("port", port)
.add("directTls", directTls)
.add("authenticated", authenticated)
.add("priority", priority)
.toString();
}
}

View file

@ -26,6 +26,7 @@ import im.conversations.android.database.model.Account;
import im.conversations.android.database.model.Connection;
import im.conversations.android.database.model.Credential;
import im.conversations.android.dns.Resolver;
import im.conversations.android.dns.ServiceRecord;
import im.conversations.android.socks.SocksSocketFactory;
import im.conversations.android.tls.SSLSockets;
import im.conversations.android.tls.XmppDomainVerifier;
@ -322,12 +323,13 @@ public class XmppConnection implements Runnable {
}
} else {
final String domain = account.address.getDomain().toString();
final List<Resolver.Result> results;
final List<ServiceRecord> results;
if (connection != null) {
results = Resolver.fromHardCoded(connection.hostname, connection.port);
results = Resolver.fromHardCoded(connection);
} else {
results = Resolver.resolve(domain);
}
LOGGER.info("{}", results);
if (Thread.currentThread().isInterrupted()) {
LOGGER.debug(account.address + ": Thread was interrupted");
return;
@ -336,7 +338,7 @@ public class XmppConnection implements Runnable {
LOGGER.warn("Resolver results were empty");
return;
}
final Resolver.Result storedBackupResult;
final ServiceRecord storedBackupResult;
if (connection != null) {
storedBackupResult = null;
} else {
@ -351,9 +353,9 @@ public class XmppConnection implements Runnable {
+ storedBackupResult);
}
}
for (Iterator<Resolver.Result> iterator = results.iterator();
for (Iterator<ServiceRecord> iterator = results.iterator();
iterator.hasNext(); ) {
final Resolver.Result result = iterator.next();
final ServiceRecord result = iterator.next();
if (Thread.currentThread().isInterrupted()) {
LOGGER.debug(account.address + ": Thread was interrupted");
return;