diff --git a/src/main/java/de/measite/minidns/Client.java b/src/main/java/de/measite/minidns/Client.java index 3f1523468..462d52ff2 100644 --- a/src/main/java/de/measite/minidns/Client.java +++ b/src/main/java/de/measite/minidns/Client.java @@ -13,8 +13,6 @@ import java.security.SecureRandom; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.Map.Entry; import java.util.Random; import java.util.logging.Level; import java.util.logging.Logger; @@ -48,23 +46,26 @@ public class Client { /** * The internal DNS cache. */ - protected LinkedHashMap cache; + protected DNSCache cache; /** - * Maximum acceptable ttl. + * Create a new DNS client with the given DNS cache. + * @param cache The backend DNS cache. */ - protected long maxTTL = 60 * 60 * 1000; - - /** - * Create a new DNS client. - */ - public Client() { + public Client(DNSCache cache) { try { random = SecureRandom.getInstance("SHA1PRNG"); } catch (NoSuchAlgorithmException e1) { random = new SecureRandom(); } - setCacheSize(10); + this.cache = cache; + } + + /** + * Create a new DNS client. + */ + public Client() { + this(null); } /** @@ -132,19 +133,8 @@ public class Client { */ public DNSMessage query(Question q, String host, int port) throws IOException { DNSMessage dnsMessage = (cache == null) ? null : cache.get(q); - if (dnsMessage != null && dnsMessage.getReceiveTimestamp() > 0l) { - // check the ttl - long ttl = maxTTL; - for (Record r : dnsMessage.getAnswers()) { - ttl = Math.min(ttl, r.ttl); - } - for (Record r : dnsMessage.getAdditionalResourceRecords()) { - ttl = Math.min(ttl, r.ttl); - } - if (dnsMessage.getReceiveTimestamp() + ttl < - System.currentTimeMillis()) { - return dnsMessage; - } + if (dnsMessage != null) { + return dnsMessage; } DNSMessage message = new DNSMessage(); message.setQuestions(new Question[]{q}); @@ -326,52 +316,4 @@ public class Client { return null; } - /** - * Configure the cache size (default 10). - * @param maximumSize The new cache size or 0 to disable. - */ - @SuppressWarnings("serial") - public void setCacheSize(final int maximumSize) { - if (maximumSize == 0) { - this.cache = null; - } else { - LinkedHashMap old = cache; - cache = new LinkedHashMap() { - @Override - protected boolean removeEldestEntry( - Entry eldest) { - return size() > maximumSize; - } - }; - if (old != null) { - cache.putAll(old); - } - } - } - - /** - * Flush the DNS cache. - */ - public void flushCache() { - if (cache != null) { - cache.clear(); - } - } - - /** - * Get the current maximum record ttl. - * @return The maximum record ttl. - */ - public long getMaxTTL() { - return maxTTL; - } - - /** - * Set the maximum record ttl. - * @param maxTTL The new maximum ttl. - */ - public void setMaxTTL(long maxTTL) { - this.maxTTL = maxTTL; - } - } diff --git a/src/main/java/de/measite/minidns/DNSCache.java b/src/main/java/de/measite/minidns/DNSCache.java new file mode 100644 index 000000000..78e6c041f --- /dev/null +++ b/src/main/java/de/measite/minidns/DNSCache.java @@ -0,0 +1,12 @@ +package de.measite.minidns; + +/** + * Cache for DNS Entries. Implementations must be thread safe. + */ +public interface DNSCache { + + void put(Question q, DNSMessage message); + + DNSMessage get(Question q); + +} diff --git a/src/main/java/de/measite/minidns/LRUCache.java b/src/main/java/de/measite/minidns/LRUCache.java new file mode 100644 index 000000000..132bf7946 --- /dev/null +++ b/src/main/java/de/measite/minidns/LRUCache.java @@ -0,0 +1,139 @@ +package de.measite.minidns; + +import java.util.LinkedHashMap; +import java.util.Map.Entry; + +/** + * LRU based DNSCache backed by a LinkedHashMap. + */ +public class LRUCache implements DNSCache { + + /** + * Internal miss count. + */ + protected long missCount = 0l; + + /** + * Internal expire count (subset of misses that was caused by expire). + */ + protected long expireCount = 0l; + + /** + * Internal hit count. + */ + protected long hitCount = 0l; + + /** + * The internal capacity of the backend cache. + */ + protected int capacity; + + /** + * The upper bound of the ttl. All longer TTLs will be capped by this ttl. + */ + protected long maxTTL; + + /** + * The backend cache. + */ + protected LinkedHashMap backend; + + /** + * Create a new LRUCache with given capacity and upper bound ttl. + * @param capacity The internal capacity. + * @param maxTTL The upper bound for any ttl. + */ + @SuppressWarnings("serial") + public LRUCache(final int capacity, final long maxTTL) { + this.capacity = capacity; + this.maxTTL = maxTTL; + backend = new LinkedHashMap( + Math.min(capacity, 11), 0.75f, true) + { + @Override + protected boolean removeEldestEntry( + Entry eldest) { + return size() > capacity; + } + }; + } + + /** + * Create a new LRUCache with given capacity. + * @param capacity + */ + public LRUCache(final int capacity) { + this(capacity, Long.MAX_VALUE); + } + + @Override + public synchronized void put(Question q, DNSMessage message) { + if (message.getReceiveTimestamp() <= 0l) { + return; + } + backend.put(q, message); + } + + @Override + public synchronized DNSMessage get(Question q) { + DNSMessage message = backend.get(q); + if (message == null) { + missCount++; + return null; + } + + long ttl = maxTTL; + for (Record r : message.getAnswers()) { + ttl = Math.min(ttl, r.ttl); + } + for (Record r : message.getAdditionalResourceRecords()) { + ttl = Math.min(ttl, r.ttl); + } + if (message.getReceiveTimestamp() + ttl > System.currentTimeMillis()) { + missCount++; + expireCount++; + backend.remove(q); + return null; + } else { + hitCount++; + return message; + } + } + + /** + * Clear all entries in this cache. + */ + public synchronized void clear() { + backend.clear(); + missCount = 0l; + hitCount = 0l; + expireCount = 0l; + } + + /** + * Get the miss count of this cache which is the number of fruitless + * get calls since this cache was last resetted. + * @return The number of cache misses. + */ + public long getMissCount() { + return missCount; + } + + /** + * The number of expires (cache hits that have had a ttl to low to be + * retrieved). + * @return The expire count. + */ + public long getExpireCount() { + return expireCount; + } + + /** + * The cache hit count (all sucessful calls to get). + * @return The hit count. + */ + public long getHitCount() { + return hitCount; + } + +}