implement Private DNS (DoT)

due to limitations in the MiniDNS library this does not work when
'Validate hostname with DNSSEC' is enabled in the expert settings
This commit is contained in:
Daniel Gultsch 2023-10-06 13:28:55 +02:00
parent d3d582759f
commit 8f014d5525
No known key found for this signature in database
GPG key ID: F43D18AD2A0982C2
6 changed files with 604 additions and 1 deletions

View file

@ -0,0 +1,123 @@
package de.gultsch.minidns;
import android.content.Context;
import android.net.ConnectivityManager;
import android.net.LinkProperties;
import android.net.Network;
import android.os.Build;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import de.measite.minidns.AbstractDNSClient;
import de.measite.minidns.DNSMessage;
import java.io.IOException;
import java.net.InetAddress;
import java.util.List;
public class AndroidDNSClient extends AbstractDNSClient {
private final Context context;
private final NetworkDataSource networkDataSource = new NetworkDataSource();
private boolean askForDnssec = false;
public AndroidDNSClient(final Context context) {
super();
this.setDataSource(networkDataSource);
this.context = context;
}
private static String getPrivateDnsServerName(final LinkProperties linkProperties) {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
return linkProperties.getPrivateDnsServerName();
} else {
return null;
}
}
private static boolean isPrivateDnsActive(final LinkProperties linkProperties) {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
return linkProperties.isPrivateDnsActive();
} else {
return false;
}
}
@Override
protected DNSMessage.Builder newQuestion(final DNSMessage.Builder message) {
message.setRecursionDesired(true);
message.getEdnsBuilder()
.setUdpPayloadSize(networkDataSource.getUdpPayloadSize())
.setDnssecOk(askForDnssec);
return message;
}
@Override
protected DNSMessage query(final DNSMessage.Builder queryBuilder) throws IOException {
final DNSMessage question = newQuestion(queryBuilder).build();
for (final DNSServer dnsServer : getDNSServers()) {
final DNSMessage response = this.networkDataSource.query(question, dnsServer);
if (response == null) {
continue;
}
switch (response.responseCode) {
case NO_ERROR:
case NX_DOMAIN:
break;
default:
continue;
}
return response;
}
return null;
}
public boolean isAskForDnssec() {
return askForDnssec;
}
public void setAskForDnssec(boolean askForDnssec) {
this.askForDnssec = askForDnssec;
}
private List<DNSServer> getDNSServers() {
final ImmutableList.Builder<DNSServer> dnsServerBuilder = new ImmutableList.Builder<>();
final ConnectivityManager connectivityManager =
(ConnectivityManager) context.getSystemService(Context.CONNECTIVITY_SERVICE);
final Network[] networks = getActiveNetworks(connectivityManager);
for (final Network network : networks) {
final LinkProperties linkProperties = connectivityManager.getLinkProperties(network);
if (linkProperties == null) {
continue;
}
final String privateDnsServerName = getPrivateDnsServerName(linkProperties);
if (Strings.isNullOrEmpty(privateDnsServerName)) {
final boolean isPrivateDns = isPrivateDnsActive(linkProperties);
for (final InetAddress dnsServer : linkProperties.getDnsServers()) {
if (isPrivateDns) {
dnsServerBuilder.add(new DNSServer(dnsServer, Transport.TLS));
} else {
dnsServerBuilder.add(new DNSServer(dnsServer));
}
}
} else {
dnsServerBuilder.add(new DNSServer(privateDnsServerName, Transport.TLS));
}
}
return dnsServerBuilder.build();
}
private Network[] getActiveNetworks(final ConnectivityManager connectivityManager) {
if (connectivityManager == null) {
return new Network[0];
}
if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.M) {
final Network activeNetwork = connectivityManager.getActiveNetwork();
if (activeNetwork != null) {
return new Network[] {activeNetwork};
}
}
return connectivityManager.getAllNetworks();
}
}

View file

@ -0,0 +1,104 @@
package de.gultsch.minidns;
import com.google.common.base.MoreObjects;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import java.net.InetAddress;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import javax.annotation.Nonnull;
public final class DNSServer {
public final InetAddress inetAddress;
public final String hostname;
public final int port;
public final List<Transport> transports;
public DNSServer(InetAddress inetAddress, Integer port, Transport transport) {
this.inetAddress = inetAddress;
this.port = port == null ? 0 : port;
this.transports = Collections.singletonList(transport);
this.hostname = null;
}
public DNSServer(final String hostname, final Integer port, final Transport transport) {
Preconditions.checkArgument(
Arrays.asList(Transport.HTTPS, Transport.TLS).contains(transport),
"hostname validation only works with TLS based transports");
this.hostname = hostname;
this.port = port == null ? 0 : port;
this.transports = Collections.singletonList(transport);
this.inetAddress = null;
}
public DNSServer(final String hostname, final Transport transport) {
this(hostname, Transport.DEFAULT_PORTS.get(transport), transport);
}
public DNSServer(InetAddress inetAddress, Transport transport) {
this(inetAddress, Transport.DEFAULT_PORTS.get(transport), transport);
}
public DNSServer(final InetAddress inetAddress) {
this(inetAddress, 53, Arrays.asList(Transport.UDP, Transport.TCP));
}
public DNSServer(final InetAddress inetAddress, int port, List<Transport> transports) {
this(inetAddress, null, port, transports);
}
private DNSServer(
final InetAddress inetAddress,
final String hostname,
final int port,
final List<Transport> transports) {
this.inetAddress = inetAddress;
this.hostname = hostname;
this.port = port;
this.transports = transports;
}
public Transport uniqueTransport() {
return Iterables.getOnlyElement(this.transports);
}
public DNSServer asUniqueTransport(final Transport transport) {
Preconditions.checkArgument(
this.transports.contains(transport),
"This DNS server does not have transport ",
transport);
return new DNSServer(inetAddress, hostname, port, Collections.singletonList(transport));
}
@Override
@Nonnull
public String toString() {
return MoreObjects.toStringHelper(this)
.add("inetAddress", inetAddress)
.add("hostname", hostname)
.add("port", port)
.add("transports", transports)
.toString();
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
DNSServer dnsServer = (DNSServer) o;
return port == dnsServer.port
&& Objects.equal(inetAddress, dnsServer.inetAddress)
&& Objects.equal(hostname, dnsServer.hostname)
&& Objects.equal(transports, dnsServer.transports);
}
@Override
public int hashCode() {
return Objects.hashCode(inetAddress, hostname, port, transports);
}
}

View file

@ -0,0 +1,190 @@
package de.gultsch.minidns;
import android.util.Log;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import de.measite.minidns.DNSMessage;
import eu.siacs.conversations.Config;
import org.conscrypt.OkHostnameVerifier;
import java.io.Closeable;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Semaphore;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
final class DNSSocket implements Closeable {
private static final int CONNECT_TIMEOUT = 5_000;
private final Semaphore semaphore = new Semaphore(1);
private final Map<Integer, SettableFuture<DNSMessage>> inFlightQueries = new HashMap<>();
private final Socket socket;
private final DataInputStream dataInputStream;
private final DataOutputStream dataOutputStream;
private DNSSocket(
final Socket socket,
final DataInputStream dataInputStream,
final DataOutputStream dataOutputStream) {
this.socket = socket;
this.dataInputStream = dataInputStream;
this.dataOutputStream = dataOutputStream;
new Thread(this::readDNSMessages).start();
}
private void readDNSMessages() {
try {
while (socket.isConnected()) {
final DNSMessage response = readDNSMessage();
final SettableFuture<DNSMessage> future;
synchronized (inFlightQueries) {
future = inFlightQueries.remove(response.id);
}
if (future != null) {
future.set(response);
} else {
Log.e(Config.LOGTAG, "no in flight query found for response id " + response.id);
}
}
evictInFlightQueries(new EOFException());
} catch (final IOException e) {
evictInFlightQueries(e);
}
}
private void evictInFlightQueries(final Exception e) {
synchronized (inFlightQueries) {
final Iterator<Map.Entry<Integer, SettableFuture<DNSMessage>>> iterator =
inFlightQueries.entrySet().iterator();
while (iterator.hasNext()) {
final Map.Entry<Integer, SettableFuture<DNSMessage>> entry = iterator.next();
entry.getValue().setException(e);
iterator.remove();
}
}
}
private static DNSSocket of(final Socket socket) throws IOException {
final DataInputStream dataInputStream = new DataInputStream(socket.getInputStream());
final DataOutputStream dataOutputStream = new DataOutputStream(socket.getOutputStream());
return new DNSSocket(socket, dataInputStream, dataOutputStream);
}
public static DNSSocket connect(final DNSServer dnsServer) throws IOException {
switch (dnsServer.uniqueTransport()) {
case TCP:
return connectTcpSocket(dnsServer);
case TLS:
return connectTlsSocket(dnsServer);
default:
throw new IllegalStateException("This is not a socket based transport");
}
}
private static DNSSocket connectTcpSocket(final DNSServer dnsServer) throws IOException {
Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TCP);
final SocketAddress socketAddress =
new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
final Socket socket = new Socket();
socket.connect(socketAddress, CONNECT_TIMEOUT);
return DNSSocket.of(socket);
}
private static DNSSocket connectTlsSocket(final DNSServer dnsServer) throws IOException {
Preconditions.checkArgument(dnsServer.uniqueTransport() == Transport.TLS);
final SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault();
final SSLSocket sslSocket;
if (Strings.isNullOrEmpty(dnsServer.hostname)) {
final SocketAddress socketAddress =
new InetSocketAddress(dnsServer.inetAddress, dnsServer.port);
sslSocket = (SSLSocket) factory.createSocket(dnsServer.inetAddress, dnsServer.port);
sslSocket.connect(socketAddress, 5_000);
} else {
sslSocket = (SSLSocket) factory.createSocket(dnsServer.hostname, dnsServer.port);
final SSLSession session = sslSocket.getSession();
final Certificate[] peerCertificates = session.getPeerCertificates();
if (peerCertificates.length == 0 || !(peerCertificates[0] instanceof X509Certificate)) {
throw new IOException("Peer did not provide X509 certificates");
}
final X509Certificate certificate = (X509Certificate) peerCertificates[0];
if (!OkHostnameVerifier.strictInstance().verify(dnsServer.hostname, certificate)) {
throw new SSLPeerUnverifiedException("Peer did not provide valid certificates");
}
}
return DNSSocket.of(sslSocket);
}
public DNSMessage query(final DNSMessage query) throws IOException, InterruptedException {
try {
return queryAsync(query).get();
} catch (final ExecutionException e) {
final Throwable cause = e.getCause();
if (cause instanceof IOException) {
throw (IOException) cause;
} else {
throw new IOException(e);
}
}
}
public ListenableFuture<DNSMessage> queryAsync(final DNSMessage query)
throws InterruptedException, IOException {
final SettableFuture<DNSMessage> responseFuture = SettableFuture.create();
synchronized (this.inFlightQueries) {
this.inFlightQueries.put(query.id, responseFuture);
}
this.semaphore.acquire();
try {
query.writeTo(this.dataOutputStream);
this.dataOutputStream.flush();
} finally {
this.semaphore.release();
}
return responseFuture;
}
private DNSMessage readDNSMessage() throws IOException {
final int length = this.dataInputStream.readUnsignedShort();
byte[] data = new byte[length];
int read = 0;
while (read < length) {
read += this.dataInputStream.read(data, read, length - read);
}
return new DNSMessage(data);
}
@Override
public void close() throws IOException {
this.socket.close();
}
public void closeQuietly() {
try {
this.socket.close();
} catch (final IOException ignored) {
}
}
}

View file

@ -0,0 +1,160 @@
package de.gultsch.minidns;
import android.util.Log;
import androidx.annotation.NonNull;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.cache.RemovalListener;
import com.google.common.collect.ImmutableList;
import de.measite.minidns.DNSMessage;
import de.measite.minidns.MiniDNSException;
import de.measite.minidns.source.DNSDataSource;
import de.measite.minidns.util.MultipleIoException;
import eu.siacs.conversations.Config;
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
public class NetworkDataSource extends DNSDataSource {
private static final LoadingCache<DNSServer, DNSSocket> socketCache =
CacheBuilder.newBuilder()
.removalListener(
(RemovalListener<DNSServer, DNSSocket>)
notification -> {
final DNSServer dnsServer = notification.getKey();
final DNSSocket dnsSocket = notification.getValue();
if (dnsSocket == null) {
return;
}
Log.d(Config.LOGTAG, "closing connection to " + dnsServer);
dnsSocket.closeQuietly();
})
.expireAfterAccess(5, TimeUnit.MINUTES)
.build(
new CacheLoader<DNSServer, DNSSocket>() {
@Override
@NonNull
public DNSSocket load(@NonNull final DNSServer dnsServer)
throws Exception {
Log.d(Config.LOGTAG, "establishing connection to " + dnsServer);
return DNSSocket.connect(dnsServer);
}
});
private static List<Transport> transportsForPort(final int port) {
final ImmutableList.Builder<Transport> transportBuilder = new ImmutableList.Builder<>();
for (final Map.Entry<Transport, Integer> entry : Transport.DEFAULT_PORTS.entrySet()) {
if (entry.getValue().equals(port)) {
transportBuilder.add(entry.getKey());
}
}
return transportBuilder.build();
}
@Override
public DNSMessage query(final DNSMessage message, final InetAddress address, final int port)
throws IOException {
final List<Transport> transports = transportsForPort(port);
Log.w(
Config.LOGTAG,
"using legacy DataSource interface. guessing transports "
+ transports
+ " from port");
if (transports.isEmpty()) {
throw new IOException(String.format("No transports found for port %d", port));
}
return query(message, new DNSServer(address, port, transports));
}
public DNSMessage query(final DNSMessage message, final DNSServer dnsServer)
throws IOException {
Log.d(Config.LOGTAG, "using " + dnsServer);
final List<IOException> ioExceptions = new ArrayList<>();
for (final Transport transport : dnsServer.transports) {
try {
final DNSMessage response =
queryWithUniqueTransport(message, dnsServer.asUniqueTransport(transport));
if (response != null && !response.truncated) {
return response;
}
} catch (final IOException e) {
ioExceptions.add(e);
} catch (final InterruptedException e) {
return null;
}
}
MultipleIoException.throwIfRequired(ioExceptions);
return null;
}
private DNSMessage queryWithUniqueTransport(final DNSMessage message, final DNSServer dnsServer)
throws IOException, InterruptedException {
final Transport transport = dnsServer.uniqueTransport();
switch (transport) {
case UDP:
return queryUdp(message, dnsServer.inetAddress, dnsServer.port);
case TCP:
case TLS:
return queryDnsSocket(message, dnsServer);
default:
throw new IOException(
String.format("Transport %s has not been implemented", transport));
}
}
protected DNSMessage queryUdp(
final DNSMessage message, final InetAddress address, final int port)
throws IOException {
final DatagramPacket request = message.asDatagram(address, port);
final byte[] buffer = new byte[udpPayloadSize];
try (final DatagramSocket socket = new DatagramSocket()) {
socket.setSoTimeout(timeout);
socket.send(request);
final DatagramPacket response = new DatagramPacket(buffer, buffer.length);
socket.receive(response);
DNSMessage dnsMessage = new DNSMessage(response.getData());
if (dnsMessage.id != message.id) {
throw new MiniDNSException.IdMismatch(message, dnsMessage);
}
return dnsMessage;
}
}
protected DNSMessage queryDnsSocket(final DNSMessage message, final DNSServer dnsServer)
throws IOException, InterruptedException {
final DNSSocket cachedDnsSocket = socketCache.getIfPresent(dnsServer);
if (cachedDnsSocket != null) {
try {
return cachedDnsSocket.query(message);
} catch (final IOException e) {
Log.d(
Config.LOGTAG,
"IOException occurred at cached socket. invalidating and falling through to new socket creation");
socketCache.invalidate(dnsServer);
}
}
try {
return socketCache.get(dnsServer).query(message);
} catch (final ExecutionException e) {
final Throwable cause = e.getCause();
if (cause instanceof IOException) {
throw (IOException) cause;
} else {
throw new IOException(cause);
}
}
}
}

View file

@ -0,0 +1,23 @@
package de.gultsch.minidns;
import com.google.common.collect.ImmutableMap;
import java.util.Map;
public enum Transport {
UDP,
TCP,
TLS,
HTTPS;
public static final Map<Transport, Integer> DEFAULT_PORTS;
static {
final ImmutableMap.Builder<Transport, Integer> builder = new ImmutableMap.Builder<>();
builder.put(Transport.UDP, 53);
builder.put(Transport.TCP, 53);
builder.put(Transport.TLS, 853);
builder.put(Transport.HTTPS, 443);
DEFAULT_PORTS = builder.build();
}
}

View file

@ -15,6 +15,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import de.gultsch.minidns.AndroidDNSClient;
import de.measite.minidns.AbstractDNSClient;
import de.measite.minidns.DNSCache;
import de.measite.minidns.DNSClient;
@ -274,7 +275,9 @@ public class Resolver {
private static <D extends Data> ResolverResult<D> resolveWithFallback(DNSName dnsName, Class<D> type, boolean validateHostname) throws IOException {
final Question question = new Question(dnsName, Record.TYPE.getType(type));
if (!validateHostname) {
return ResolverApi.INSTANCE.resolve(question);
final AndroidDNSClient androidDNSClient = new AndroidDNSClient(SERVICE);
final ResolverApi resolverApi = new ResolverApi(androidDNSClient);
return resolverApi.resolve(question);
}
try {
return DnssecResolverApi.INSTANCE.resolveDnssecReliable(question);