diff --git a/src/main/java/de/gultsch/minidns/AndroidDNSClient.java b/src/main/java/de/gultsch/minidns/AndroidDNSClient.java index 194ad23bd..d47430339 100644 --- a/src/main/java/de/gultsch/minidns/AndroidDNSClient.java +++ b/src/main/java/de/gultsch/minidns/AndroidDNSClient.java @@ -5,18 +5,30 @@ import android.net.ConnectivityManager; import android.net.LinkProperties; import android.net.Network; import android.os.Build; +import android.util.Log; +import androidx.collection.LruCache; + +import com.google.common.base.Objects; import com.google.common.base.Strings; +import com.google.common.collect.Collections2; import com.google.common.collect.ImmutableList; import de.measite.minidns.AbstractDNSClient; import de.measite.minidns.DNSMessage; +import eu.siacs.conversations.Config; + import java.io.IOException; import java.net.InetAddress; +import java.time.Duration; +import java.util.Collections; import java.util.List; public class AndroidDNSClient extends AbstractDNSClient { + + private static final LruCache QUERY_CACHE = + new LruCache<>(1024); private final Context context; private final NetworkDataSource networkDataSource = new NetworkDataSource(); private boolean askForDnssec = false; @@ -56,6 +68,8 @@ public class AndroidDNSClient extends AbstractDNSClient { protected DNSMessage query(final DNSMessage.Builder queryBuilder) throws IOException { final DNSMessage question = newQuestion(queryBuilder).build(); for (final DNSServer dnsServer : getDNSServers()) { + final QuestionServerTuple cacheKey = new QuestionServerTuple(dnsServer, question); + final DNSMessage cachedResponse = queryCache(cacheKey); final DNSMessage response = this.networkDataSource.query(question, dnsServer); if (response == null) { continue; @@ -67,7 +81,7 @@ public class AndroidDNSClient extends AbstractDNSClient { default: continue; } - + cacheQuery(cacheKey, response); return response; } return null; @@ -120,4 +134,68 @@ public class AndroidDNSClient extends AbstractDNSClient { } return connectivityManager.getAllNetworks(); } + + private DNSMessage queryCache(final QuestionServerTuple key) { + final DNSMessage cachedResponse; + synchronized (QUERY_CACHE) { + cachedResponse = QUERY_CACHE.get(key); + if (cachedResponse == null) { + return null; + } + final long expiresIn = expiresIn(cachedResponse); + if (expiresIn < 0) { + QUERY_CACHE.remove(key); + return null; + } + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) { + Log.d( + Config.LOGTAG, + "DNS query came from cache. expires in " + Duration.ofMillis(expiresIn)); + } + } + return cachedResponse; + } + + private void cacheQuery(final QuestionServerTuple key, final DNSMessage response) { + if (response.receiveTimestamp <= 0) { + return; + } + synchronized (QUERY_CACHE) { + QUERY_CACHE.put(key, response); + } + } + + private static long expiresAt(final DNSMessage dnsMessage) { + return dnsMessage.receiveTimestamp + + (Collections.min(Collections2.transform(dnsMessage.answerSection, d -> d.ttl)) + * 1000L); + } + + private static long expiresIn(final DNSMessage dnsMessage) { + return expiresAt(dnsMessage) - System.currentTimeMillis(); + } + + private static class QuestionServerTuple { + private final DNSServer dnsServer; + private final DNSMessage question; + + private QuestionServerTuple(final DNSServer dnsServer, final DNSMessage question) { + this.dnsServer = dnsServer; + this.question = question.asNormalizedVersion(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + QuestionServerTuple that = (QuestionServerTuple) o; + return Objects.equal(dnsServer, that.dnsServer) + && Objects.equal(question, that.question); + } + + @Override + public int hashCode() { + return Objects.hashCode(dnsServer, question); + } + } }