浏览代码

2009-07-03 Tatsuhiro Tsujikawa <[email protected]>

	Try all available addresses returned by DNS until it gets
	connected in HTTP(S)/FTP download
	* src/AbstractCommand.cc
	* src/AbstractCommand.h
	* src/AbstractProxyRequestCommand.cc
	* src/AbstractProxyRequestCommand.h
	* src/DNSCache.h
	* src/DownloadEngine.cc
	* src/DownloadEngine.h
	* src/FtpInitiateConnectionCommand.cc
	* src/FtpInitiateConnectionCommand.h
	* src/FtpNegotiationCommand.cc
	* src/FtpNegotiationCommand.h
	* src/HttpInitiateConnectionCommand.cc
	* src/HttpInitiateConnectionCommand.h
	* src/HttpRequestCommand.cc
	* src/HttpRequestCommand.h
	* src/InitiateConnectionCommand.cc
	* src/InitiateConnectionCommand.h
	* test/DNSCacheTest.cc
	* test/Makefile.am
	* test/SimpleDNSCacheTest.cc
Tatsuhiro Tsujikawa 16 年之前
父节点
当前提交
01fdb2aaeb

+ 25 - 0
ChangeLog

@@ -1,3 +1,28 @@
+2009-07-03  Tatsuhiro Tsujikawa  <[email protected]>
+
+	Try all available addresses returned by DNS until it gets
+	connected in HTTP(S)/FTP download
+	* src/AbstractCommand.cc
+	* src/AbstractCommand.h
+	* src/AbstractProxyRequestCommand.cc
+	* src/AbstractProxyRequestCommand.h
+	* src/DNSCache.h
+	* src/DownloadEngine.cc
+	* src/DownloadEngine.h
+	* src/FtpInitiateConnectionCommand.cc
+	* src/FtpInitiateConnectionCommand.h
+	* src/FtpNegotiationCommand.cc
+	* src/FtpNegotiationCommand.h
+	* src/HttpInitiateConnectionCommand.cc
+	* src/HttpInitiateConnectionCommand.h
+	* src/HttpRequestCommand.cc
+	* src/HttpRequestCommand.h
+	* src/InitiateConnectionCommand.cc
+	* src/InitiateConnectionCommand.h
+	* test/DNSCacheTest.cc
+	* test/Makefile.am
+	* test/SimpleDNSCacheTest.cc
+
 2009-07-01  Tatsuhiro Tsujikawa  <[email protected]>
 
 	Updated doc

+ 19 - 2
src/AbstractCommand.cc

@@ -539,12 +539,28 @@ void AbstractCommand::prepareForNextAction(Command* nextCommand)
   e->setNoWait(true);
 }
 
-void AbstractCommand::checkIfConnectionEstablished
-(const SharedHandle<SocketCore>& socket)
+bool AbstractCommand::checkIfConnectionEstablished
+(const SharedHandle<SocketCore>& socket,
+ const std::string& connectedHostname,
+ const std::string& connectedAddr,
+ uint16_t connectedPort)
 {
   if(socket->isReadable(0)) {
     std::string error = socket->getSocketError();
     if(!error.empty()) {
+      e->markBadIPAddress(connectedHostname, connectedAddr, connectedPort);
+      if(!e->findCachedIPAddress(connectedHostname, connectedPort).empty()) {
+	logger->info("CUID#%d - Could not to connect to %s:%u."
+		     " Trying another address",
+		     cuid, connectedAddr.c_str(), connectedPort);
+	Command* command =
+	  InitiateConnectionCommandFactory::createInitiateConnectionCommand
+	  (cuid, req, _fileEntry, _requestGroup, e);
+	e->setNoWait(true);
+	e->commands.push_back(command);
+	return false;
+      }
+      e->removeCachedIPAddress(connectedHostname, connectedPort);
       // Don't set error if proxy server is used and its method is GET.
       if(resolveProxyMethod(req->getProtocol()) != V_GET ||
 	 !isProxyRequest(req->getProtocol(), getOption())) {
@@ -555,6 +571,7 @@ void AbstractCommand::checkIfConnectionEstablished
 	(StringFormat(MSG_ESTABLISHING_CONNECTION_FAILED, error.c_str()).str());
     }
   }
+  return true;
 }
 
 const std::string& AbstractCommand::resolveProxyMethod

+ 10 - 1
src/AbstractCommand.h

@@ -102,7 +102,16 @@ protected:
 
   void prepareForNextAction(Command* nextCommand = 0);
 
-  void checkIfConnectionEstablished(const SharedHandle<SocketCore>& socket);
+  // Check if socket is connected. If socket is not connected and
+  // there are other addresses to try, command is created using
+  // InitiateConnectionCommandFactory and it is pushed to
+  // DownloadEngine and returns false. If no addresses left, DlRetryEx
+  // exception is thrown.
+  bool checkIfConnectionEstablished
+  (const SharedHandle<SocketCore>& socket,
+   const std::string& connectedHostname,
+   const std::string& connectedAddr,
+   uint16_t connectedPort);
 
   /*
    * Returns true if proxy for the procol indicated by Request::getProtocol()

+ 4 - 0
src/AbstractProxyRequestCommand.cc

@@ -72,6 +72,10 @@ AbstractProxyRequestCommand::~AbstractProxyRequestCommand() {}
 bool AbstractProxyRequestCommand::executeInternal() {
   //socket->setBlockingMode();
   if(httpConnection->sendBufferIsEmpty()) {
+    if(!checkIfConnectionEstablished
+       (socket, _connectedHostname, _connectedAddr, _connectedPort)) {
+      return true;
+    }
     HttpRequestHandle httpRequest(new HttpRequest());
     httpRequest->setUserAgent(getOption()->get(PREF_USER_AGENT));
     httpRequest->setRequest(req);

+ 12 - 0
src/AbstractProxyRequestCommand.h

@@ -48,6 +48,10 @@ protected:
 
   SharedHandle<HttpConnection> httpConnection;
 
+  std::string _connectedHostname;
+  std::string _connectedAddr;
+  uint16_t _connectedPort;
+
   virtual bool executeInternal();
 public:
   AbstractProxyRequestCommand(int cuid,
@@ -61,6 +65,14 @@ public:
   virtual ~AbstractProxyRequestCommand();
 
   virtual Command* getNextCommand() = 0;
+
+  void setConnectedAddr
+  (const std::string& hostname, const std::string& addr, uint16_t port)
+  {
+    _connectedHostname = hostname;
+    _connectedAddr = addr;
+    _connectedPort = port;
+  }
 };
 
 } // namespace aria2

+ 75 - 30
src/DNSCache.h

@@ -38,57 +38,102 @@
 #include "common.h"
 
 #include <string>
-#include <map>
+#include <deque>
+#include <algorithm>
 
 #include "A2STR.h"
 
 namespace aria2 {
 
 class DNSCache {
-public:
-  virtual ~DNSCache() {}
-
-  virtual const std::string& find(const std::string& hostname) const = 0;
-
-  virtual void put(const std::string& hostname, const std::string& ipaddr) = 0;
-};
-
-class SimpleDNSCache : public DNSCache {
 private:
-  std::map<std::string, std::string> _table;
-public:
-  SimpleDNSCache() {}
+  struct CacheEntry {
+    std::string _hostname;
+    std::string _addr;
+    uint16_t _port;
+    bool _good;
+    CacheEntry
+    (const std::string& hostname, const std::string& addr, uint16_t port):
+      _hostname(hostname), _addr(addr), _port(port), _good(true) {}
+
+    void markBad() { _good = false; }
+
+    bool operator<(const CacheEntry& e) const
+    {
+      int r = _hostname.compare(e._hostname);
+      if(r != 0) {
+	return r < 0;
+      }
+      if(_port != e._port) {
+	return _port < e._port;
+      }
+      return _addr < e._addr;
+    }
+
+    bool operator==(const CacheEntry& e) const
+    {
+      return _hostname == e._hostname && _addr == e._addr && _port == e._port;
+    }
+  };
 
-  virtual ~SimpleDNSCache() {}
+  std::deque<CacheEntry> _entries;
 
-  virtual const std::string& find(const std::string& hostname) const
+  std::deque<CacheEntry>::iterator findEntry
+  (const std::string& hostname, const std::string& ipaddr, uint16_t port)
   {
-    std::map<std::string, std::string>::const_iterator i =
-      _table.find(hostname);
-    if(i == _table.end()) {
-      return A2STR::NIL;
+    CacheEntry target(hostname, ipaddr, port);
+    std::deque<CacheEntry>::iterator i =
+      std::lower_bound(_entries.begin(), _entries.end(), target);
+    if(i != _entries.end() && (*i) == target) {
+      return i;
     } else {
-      return (*i).second;
+      return _entries.end();
     }
   }
 
-  virtual void put(const std::string& hostname, const std::string& ipaddr)
+public:
+  const std::string& find(const std::string& hostname, uint16_t port) const
   {
-    _table[hostname] = ipaddr;
+    CacheEntry target(hostname, A2STR::NIL, port);
+    std::deque<CacheEntry>::const_iterator i =
+      std::lower_bound(_entries.begin(), _entries.end(), target);
+    for(; i != _entries.end() && (*i)._hostname == hostname && (*i)._port == port; ++i) {
+      if((*i)._good) {
+	return (*i)._addr;
+      }
+    }
+    return A2STR::NIL;
   }
-  
-};
 
-class NullDNSCache : public DNSCache {
-public:
-  virtual ~NullDNSCache() {}
+  void put
+  (const std::string& hostname, const std::string& ipaddr, uint16_t port)
+  {
+    CacheEntry target(hostname, ipaddr, port);
+    std::deque<CacheEntry>::iterator i =
+      std::lower_bound(_entries.begin(), _entries.end(), target);
+    if(i == _entries.end() || !((*i) == target)) {
+      _entries.insert(i, target);
+    }
+  }
 
-  virtual const std::string& find(const std::string& hostname)
+  void markBad
+  (const std::string& hostname, const std::string& ipaddr, uint16_t port)
   {
-    return A2STR::NIL;
+    std::deque<CacheEntry>::iterator i = findEntry(hostname, ipaddr, port);
+    if(i != _entries.end()) {
+      (*i).markBad();
+    }
   }
 
-  virtual void put(const std::string& hostname, const std::string& ipaddr) {}
+  void remove(const std::string& hostname, uint16_t port)
+  {
+    CacheEntry target(hostname, A2STR::NIL, port);
+    std::deque<CacheEntry>::iterator i =
+      std::lower_bound(_entries.begin(), _entries.end(), target);
+    for(; i != _entries.end() && (*i)._hostname == hostname && (*i)._port == port;) {
+      i = _entries.erase(i);
+    }
+  }
 };
 
 } // namespace aria2

+ 17 - 5
src/DownloadEngine.cc

@@ -92,7 +92,7 @@ DownloadEngine::DownloadEngine(const SharedHandle<EventPoll>& eventPoll):
   _refreshInterval(DEFAULT_REFRESH_INTERVAL),
   _cookieStorage(new CookieStorage()),
   _btRegistry(new BtRegistry()),
-  _dnsCache(new SimpleDNSCache())
+  _dnsCache(new DNSCache())
 {}
 
 DownloadEngine::~DownloadEngine() {
@@ -446,15 +446,27 @@ cuid_t DownloadEngine::newCUID()
 }
 
 const std::string& DownloadEngine::findCachedIPAddress
-(const std::string& hostname) const
+(const std::string& hostname, uint16_t port) const
 {
-  return _dnsCache->find(hostname);
+  return _dnsCache->find(hostname, port);
 }
 
 void DownloadEngine::cacheIPAddress
-(const std::string& hostname, const std::string& ipaddr)
+(const std::string& hostname, const std::string& ipaddr, uint16_t port)
 {
-  _dnsCache->put(hostname, ipaddr);
+  _dnsCache->put(hostname, ipaddr, port);
+}
+
+void DownloadEngine::markBadIPAddress
+(const std::string& hostname, const std::string& ipaddr, uint16_t port)
+{
+  _dnsCache->markBad(hostname, ipaddr, port);
+}
+
+void DownloadEngine::removeCachedIPAddress
+(const std::string& hostname, uint16_t port)
+{
+  _dnsCache->remove(hostname, port);
 }
 
 void DownloadEngine::setAuthConfigFactory

+ 9 - 2
src/DownloadEngine.h

@@ -244,9 +244,16 @@ public:
 
   cuid_t newCUID();
 
-  const std::string& findCachedIPAddress(const std::string& hostname) const;
+  const std::string& findCachedIPAddress
+  (const std::string& hostname, uint16_t port) const;
 
-  void cacheIPAddress(const std::string& hostname, const std::string& ipaddr);
+  void cacheIPAddress
+  (const std::string& hostname, const std::string& ipaddr, uint16_t port);
+
+  void markBadIPAddress
+  (const std::string& hostname, const std::string& ipaddr, uint16_t port);
+
+  void removeCachedIPAddress(const std::string& hostname, uint16_t port);
 
   void setAuthConfigFactory(const SharedHandle<AuthConfigFactory>& factory);
 

+ 18 - 13
src/FtpInitiateConnectionCommand.cc

@@ -65,7 +65,8 @@ FtpInitiateConnectionCommand::FtpInitiateConnectionCommand
 FtpInitiateConnectionCommand::~FtpInitiateConnectionCommand() {}
 
 Command* FtpInitiateConnectionCommand::createNextCommand
-(const std::deque<std::string>& resolvedAddresses,
+(const std::string& hostname, const std::string& addr, uint16_t port,
+ const std::deque<std::string>& resolvedAddresses,
  const SharedHandle<Request>& proxyRequest)
 {
   Command* command;
@@ -75,11 +76,9 @@ Command* FtpInitiateConnectionCommand::createNextCommand
       e->popPooledSocket(options, req->getHost(), req->getPort());
     std::string proxyMethod = resolveProxyMethod(req->getProtocol());
     if(pooledSocket.isNull()) {
-      logger->info(MSG_CONNECTING_TO_SERVER, cuid,
-		   proxyRequest->getHost().c_str(), proxyRequest->getPort());
+      logger->info(MSG_CONNECTING_TO_SERVER, cuid, addr.c_str(), port);
       socket.reset(new SocketCore());
-      socket->establishConnection(resolvedAddresses.front(),
-				  proxyRequest->getPort());
+      socket->establishConnection(addr, port);
       
       if(proxyMethod == V_GET) {
 	// Use GET for FTP via HTTP proxy.
@@ -90,12 +89,16 @@ Command* FtpInitiateConnectionCommand::createNextCommand
 	HttpRequestCommand* c =
 	  new HttpRequestCommand(cuid, req, _fileEntry,
 				 _requestGroup, hc, e, socket);
+	c->setConnectedAddr(hostname, addr, port);
 	c->setProxyRequest(proxyRequest);
 	command = c;
       } else if(proxyMethod == V_TUNNEL) {
-	command = new FtpTunnelRequestCommand(cuid, req, _fileEntry,
-					      _requestGroup, e,
-					      proxyRequest, socket);
+	FtpTunnelRequestCommand* c =
+	  new FtpTunnelRequestCommand(cuid, req, _fileEntry,
+				      _requestGroup, e,
+				      proxyRequest, socket);
+	c->setConnectedAddr(hostname, addr, port);
+	command = c;
       } else {
 	// TODO
 	throw DL_ABORT_EX("ERROR");
@@ -128,12 +131,14 @@ Command* FtpInitiateConnectionCommand::createNextCommand
     SharedHandle<SocketCore> pooledSocket =
       e->popPooledSocket(options, resolvedAddresses, req->getPort());
     if(pooledSocket.isNull()) {
-      logger->info(MSG_CONNECTING_TO_SERVER, cuid, req->getHost().c_str(),
-		   req->getPort());
+      logger->info(MSG_CONNECTING_TO_SERVER, cuid, addr.c_str(), port);
       socket.reset(new SocketCore());
-      socket->establishConnection(resolvedAddresses.front(), req->getPort());
-      command = new FtpNegotiationCommand(cuid, req, _fileEntry,
-					  _requestGroup, e, socket);
+      socket->establishConnection(addr, port);
+      FtpNegotiationCommand* c =
+	new FtpNegotiationCommand(cuid, req, _fileEntry,
+				  _requestGroup, e, socket);
+      c->setConnectedAddr(hostname, addr, port);
+      command = c;
     } else {
       command =
 	new FtpNegotiationCommand(cuid, req, _fileEntry,

+ 2 - 1
src/FtpInitiateConnectionCommand.h

@@ -42,7 +42,8 @@ namespace aria2 {
 class FtpInitiateConnectionCommand : public InitiateConnectionCommand {
 protected:
   virtual Command* createNextCommand
-  (const std::deque<std::string>& resolvedAddresses,
+  (const std::string& hostname, const std::string& addr, uint16_t port,
+   const std::deque<std::string>& resolvedAddresses,
    const SharedHandle<Request>& proxyRequest);
 public:
   FtpInitiateConnectionCommand(int cuid, const SharedHandle<Request>& req,

+ 7 - 1
src/FtpNegotiationCommand.cc

@@ -118,6 +118,8 @@ bool FtpNegotiationCommand::executeInternal() {
       sequence = SEQ_PREPARE_SERVER_SOCKET;
     }
     return false;
+  } else if(sequence == SEQ_EXIT) {
+    return true;
   } else {
     e->commands.push_back(this);
     return false;
@@ -125,7 +127,11 @@ bool FtpNegotiationCommand::executeInternal() {
 }
 
 bool FtpNegotiationCommand::recvGreeting() {
-  checkIfConnectionEstablished(socket);
+  if(!checkIfConnectionEstablished
+     (socket, _connectedHostname, _connectedAddr, _connectedPort)) {
+    sequence = SEQ_EXIT;
+    return false;
+  }
   setTimeout(_requestGroup->getTimeout());
   //socket->setBlockingMode();
   disableWriteCheckSocket();

+ 13 - 0
src/FtpNegotiationCommand.h

@@ -76,6 +76,7 @@ public:
     SEQ_HEAD_OK,
     SEQ_DOWNLOAD_ALREADY_COMPLETED,
     SEQ_FILE_PREPARATION, // File allocation after SIZE command
+    SEQ_EXIT
   };
 private:
   bool recvGreeting();
@@ -118,6 +119,10 @@ private:
   SharedHandle<SocketCore> serverSocket;
   Seq sequence;
   SharedHandle<FtpConnection> ftp;
+
+  std::string _connectedHostname;
+  std::string _connectedAddr;
+  uint16_t _connectedPort;
 protected:
   virtual bool executeInternal();
 public:
@@ -130,6 +135,14 @@ public:
 			Seq seq = SEQ_RECV_GREETING,
 			const std::string& baseWorkingDir = "/");
   virtual ~FtpNegotiationCommand();
+
+  void setConnectedAddr
+  (const std::string& hostname, const std::string& addr, uint16_t port)
+  {
+    _connectedHostname = hostname;
+    _connectedAddr = addr;
+    _connectedPort = port;
+  }
 };
 
 } // namespace aria2

+ 20 - 13
src/HttpInitiateConnectionCommand.cc

@@ -62,7 +62,8 @@ HttpInitiateConnectionCommand::HttpInitiateConnectionCommand
 HttpInitiateConnectionCommand::~HttpInitiateConnectionCommand() {}
 
 Command* HttpInitiateConnectionCommand::createNextCommand
-(const std::deque<std::string>& resolvedAddresses,
+(const std::string& hostname, const std::string& addr, uint16_t port,
+ const std::deque<std::string>& resolvedAddresses,
  const SharedHandle<Request>& proxyRequest)
 {
   Command* command;
@@ -71,16 +72,17 @@ Command* HttpInitiateConnectionCommand::createNextCommand
       e->popPooledSocket(req->getHost(), req->getPort());
     std::string proxyMethod = resolveProxyMethod(req->getProtocol());
     if(pooledSocket.isNull()) {
-      logger->info(MSG_CONNECTING_TO_SERVER, cuid,
-		   proxyRequest->getHost().c_str(), proxyRequest->getPort());
+      logger->info(MSG_CONNECTING_TO_SERVER, cuid, addr.c_str(), port);
       socket.reset(new SocketCore());
-      socket->establishConnection(resolvedAddresses.front(),
-				  proxyRequest->getPort());
+      socket->establishConnection(addr, port);
 
       if(proxyMethod == V_TUNNEL) {
-	command = new HttpProxyRequestCommand(cuid, req, _fileEntry,
-					      _requestGroup, e,
-					      proxyRequest, socket);
+	HttpProxyRequestCommand* c =
+	  new HttpProxyRequestCommand(cuid, req, _fileEntry,
+				      _requestGroup, e,
+				      proxyRequest, socket);
+	c->setConnectedAddr(hostname, addr, port);
+	command = c;
       } else if(proxyMethod == V_GET) {
 	SharedHandle<HttpConnection> httpConnection
 	  (new HttpConnection(cuid, socket, getOption().get()));
@@ -89,6 +91,7 @@ Command* HttpInitiateConnectionCommand::createNextCommand
 						       _requestGroup,
 						       httpConnection, e,
 						       socket);
+	c->setConnectedAddr(hostname, addr, port);
 	c->setProxyRequest(proxyRequest);
 	command = c;
       } else {
@@ -112,16 +115,20 @@ Command* HttpInitiateConnectionCommand::createNextCommand
     SharedHandle<SocketCore> pooledSocket =
       e->popPooledSocket(resolvedAddresses, req->getPort());
     if(pooledSocket.isNull()) {
-      logger->info(MSG_CONNECTING_TO_SERVER, cuid, req->getHost().c_str(),
-		   req->getPort());
+      logger->info(MSG_CONNECTING_TO_SERVER, cuid, addr.c_str(), port);
       socket.reset(new SocketCore());
-      socket->establishConnection(resolvedAddresses.front(), req->getPort());
+      socket->establishConnection(addr, port);
     } else {
       socket = pooledSocket;
     }
     SharedHandle<HttpConnection> httpConnection(new HttpConnection(cuid, socket, getOption().get()));
-    command = new HttpRequestCommand(cuid, req, _fileEntry, _requestGroup,
-				     httpConnection, e, socket);
+    HttpRequestCommand* c =
+      new HttpRequestCommand(cuid, req, _fileEntry, _requestGroup,
+			     httpConnection, e, socket);
+    if(pooledSocket.isNull()) {
+      c->setConnectedAddr(hostname, addr, port);
+    }
+    command = c;
   }
   return command;
 }

+ 2 - 1
src/HttpInitiateConnectionCommand.h

@@ -42,7 +42,8 @@ namespace aria2 {
 class HttpInitiateConnectionCommand : public InitiateConnectionCommand {
 protected:
   virtual Command* createNextCommand
-  (const std::deque<std::string>& resolvedAddresses,
+  (const std::string& hostname, const std::string& addr, uint16_t port,
+   const std::deque<std::string>& resolvedAddresses,
    const SharedHandle<Request>& proxyRequest);
 public:
   HttpInitiateConnectionCommand(int cuid, const SharedHandle<Request>& req,

+ 4 - 1
src/HttpRequestCommand.cc

@@ -112,7 +112,10 @@ bool HttpRequestCommand::executeInternal() {
     }
   }
   if(_httpConnection->sendBufferIsEmpty()) {
-    checkIfConnectionEstablished(socket);
+    if(!checkIfConnectionEstablished
+       (socket, _connectedHostname, _connectedAddr, _connectedPort)) {
+      return true;
+    }
 
     if(_segments.empty()) {
       HttpRequestHandle httpRequest

+ 12 - 0
src/HttpRequestCommand.h

@@ -47,6 +47,10 @@ private:
   SharedHandle<Request> _proxyRequest;
 
   SharedHandle<HttpConnection> _httpConnection;
+
+  std::string _connectedHostname;
+  std::string _connectedAddr;
+  uint16_t _connectedPort;
 protected:
   virtual bool executeInternal();
 public:
@@ -60,6 +64,14 @@ public:
   virtual ~HttpRequestCommand();
 
   void setProxyRequest(const SharedHandle<Request>& proxyRequest);
+
+  void setConnectedAddr
+  (const std::string& hostname, const std::string& addr, uint16_t port)
+  {
+    _connectedHostname = hostname;
+    _connectedAddr = addr;
+    _connectedPort = port;
+  }
 };
 
 } // namespace aria2

+ 13 - 5
src/InitiateConnectionCommand.cc

@@ -46,6 +46,7 @@
 #include "RequestGroup.h"
 #include "DownloadContext.h"
 #include "Segment.h"
+#include "a2functional.h"
 
 namespace aria2 {
 
@@ -68,14 +69,17 @@ InitiateConnectionCommand::~InitiateConnectionCommand() {}
 
 bool InitiateConnectionCommand::executeInternal() {
   std::string hostname;
+  uint16_t port;
   SharedHandle<Request> proxyRequest = createProxyRequest();
   if(proxyRequest.isNull()) {
     hostname = req->getHost();
+    port = req->getPort();
   } else {
     hostname = proxyRequest->getHost();
+    port = proxyRequest->getPort();
   }
   std::deque<std::string> addrs;
-  std::string ipaddr = e->findCachedIPAddress(hostname);
+  std::string ipaddr = e->findCachedIPAddress(hostname, port);
   if(ipaddr.empty()) {
 #ifdef ENABLE_ASYNC_DNS
     if(getOption()->getAsBool(PREF_ASYNC_DNS)) {
@@ -97,14 +101,18 @@ bool InitiateConnectionCommand::executeInternal() {
       }
     logger->info(MSG_NAME_RESOLUTION_COMPLETE, cuid,
 		 hostname.c_str(),
-		 addrs.front().c_str());
-    e->cacheIPAddress(hostname, addrs.front());
+		 strjoin(addrs.begin(), addrs.end(), ",").c_str());
+    for(std::deque<std::string>::const_iterator i = addrs.begin();
+	i != addrs.end(); ++i) {
+      e->cacheIPAddress(hostname, *i, port);
+    }
+    ipaddr = e->findCachedIPAddress(hostname, port);
   } else {
     logger->info(MSG_DNS_CACHE_HIT, cuid, hostname.c_str(), ipaddr.c_str());
     addrs.push_back(ipaddr);
   }
-
-  Command* command = createNextCommand(addrs, proxyRequest);
+  Command* command = createNextCommand(hostname, ipaddr, port,
+				       addrs, proxyRequest);
   e->commands.push_back(command);
   return true;
 }

+ 8 - 1
src/InitiateConnectionCommand.h

@@ -49,8 +49,15 @@ protected:
    */
   virtual bool executeInternal();
 
+  // hostname and port are the hostname and port number we are going
+  // to connect. If proxy server is used, these values are hostname
+  // and port of proxy server. addr is one of resolved address and we
+  // use this address this time.  resolvedAddresses are all addresses
+  // resolved.  proxyRequest is set if we are going to use proxy
+  // server.
   virtual Command* createNextCommand
-  (const std::deque<std::string>& resolvedAddresses,
+  (const std::string& hostname, const std::string& addr, uint16_t port,
+   const std::deque<std::string>& resolvedAddresses,
    const SharedHandle<Request>& proxyRequest) = 0;
 public:
   InitiateConnectionCommand(int cuid, const SharedHandle<Request>& req,

+ 64 - 0
test/DNSCacheTest.cc

@@ -0,0 +1,64 @@
+#include "DNSCache.h"
+
+#include <cppunit/extensions/HelperMacros.h>
+
+namespace aria2 {
+
+class DNSCacheTest:public CppUnit::TestFixture {
+
+  CPPUNIT_TEST_SUITE(DNSCacheTest);
+  CPPUNIT_TEST(testFind);
+  CPPUNIT_TEST(testMarkBad);
+  CPPUNIT_TEST(testPutBadAddr);
+  CPPUNIT_TEST(testRemove);
+  CPPUNIT_TEST_SUITE_END();
+
+  DNSCache _cache;
+public:
+  void setUp()
+  {
+    _cache = DNSCache();
+    _cache.put("www", "192.168.0.1", 80);
+    _cache.put("www", "::1", 80);
+    _cache.put("ftp", "192.168.0.1", 21);
+    _cache.put("proxy", "192.168.1.2", 8080);
+  }
+
+  void testFind();
+  void testMarkBad();
+  void testPutBadAddr();
+  void testRemove();
+};
+
+
+CPPUNIT_TEST_SUITE_REGISTRATION(DNSCacheTest);
+
+void DNSCacheTest::testFind()
+{
+  CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.1"), _cache.find("www", 80));
+  CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.1"), _cache.find("ftp", 21));
+  CPPUNIT_ASSERT_EQUAL(std::string("192.168.1.2"), _cache.find("proxy", 8080));
+  CPPUNIT_ASSERT_EQUAL(std::string(""), _cache.find("www", 8080));
+  CPPUNIT_ASSERT_EQUAL(std::string(""), _cache.find("another", 80));
+}
+
+void DNSCacheTest::testMarkBad()
+{
+  _cache.markBad("www", "192.168.0.1", 80);
+  CPPUNIT_ASSERT_EQUAL(std::string("::1"), _cache.find("www", 80));
+}
+
+void DNSCacheTest::testPutBadAddr()
+{
+  _cache.markBad("www", "192.168.0.1", 80);
+  _cache.put("www", "192.168.0.1", 80);
+  CPPUNIT_ASSERT_EQUAL(std::string("::1"), _cache.find("www", 80));
+}
+
+void DNSCacheTest::testRemove()
+{
+  _cache.remove("www", 80);
+  CPPUNIT_ASSERT_EQUAL(std::string(""), _cache.find("www", 80));
+}
+
+} // namespace aria2

+ 1 - 1
test/Makefile.am

@@ -60,7 +60,7 @@ aria2c_SOURCES = AllTest.cc\
 	TimeTest.cc\
 	FtpConnectionTest.cc\
 	OptionParserTest.cc\
-	SimpleDNSCacheTest.cc\
+	DNSCacheTest.cc\
 	DownloadHelperTest.cc\
 	SequentialPickerTest.cc\
 	RarestPieceSelectorTest.cc\

+ 14 - 16
test/Makefile.in

@@ -194,12 +194,11 @@ am__aria2c_SOURCES_DIST = AllTest.cc TestUtil.cc TestUtil.h \
 	ServerStatTest.cc NsCookieParserTest.cc \
 	DirectDiskAdaptorTest.cc CookieTest.cc CookieStorageTest.cc \
 	TimeTest.cc FtpConnectionTest.cc OptionParserTest.cc \
-	SimpleDNSCacheTest.cc DownloadHelperTest.cc \
-	SequentialPickerTest.cc RarestPieceSelectorTest.cc \
-	PieceStatManTest.cc InOrderPieceSelector.h \
-	LongestSequencePieceSelectorTest.cc a2algoTest.cc \
-	bitfieldTest.cc BDETest.cc DownloadContextTest.cc \
-	XmlRpcRequestParserControllerTest.cc \
+	DNSCacheTest.cc DownloadHelperTest.cc SequentialPickerTest.cc \
+	RarestPieceSelectorTest.cc PieceStatManTest.cc \
+	InOrderPieceSelector.h LongestSequencePieceSelectorTest.cc \
+	a2algoTest.cc bitfieldTest.cc BDETest.cc \
+	DownloadContextTest.cc XmlRpcRequestParserControllerTest.cc \
 	XmlRpcRequestProcessorTest.cc XmlRpcMethodTest.cc \
 	FallocFileAllocationIteratorTest.cc GZipDecoderTest.cc \
 	Sqlite3MozCookieParserTest.cc MessageDigestHelperTest.cc \
@@ -369,7 +368,7 @@ am_aria2c_OBJECTS = AllTest.$(OBJEXT) TestUtil.$(OBJEXT) \
 	NsCookieParserTest.$(OBJEXT) DirectDiskAdaptorTest.$(OBJEXT) \
 	CookieTest.$(OBJEXT) CookieStorageTest.$(OBJEXT) \
 	TimeTest.$(OBJEXT) FtpConnectionTest.$(OBJEXT) \
-	OptionParserTest.$(OBJEXT) SimpleDNSCacheTest.$(OBJEXT) \
+	OptionParserTest.$(OBJEXT) DNSCacheTest.$(OBJEXT) \
 	DownloadHelperTest.$(OBJEXT) SequentialPickerTest.$(OBJEXT) \
 	RarestPieceSelectorTest.$(OBJEXT) PieceStatManTest.$(OBJEXT) \
 	LongestSequencePieceSelectorTest.$(OBJEXT) \
@@ -599,14 +598,13 @@ aria2c_SOURCES = AllTest.cc TestUtil.cc TestUtil.h SocketCoreTest.cc \
 	ServerStatTest.cc NsCookieParserTest.cc \
 	DirectDiskAdaptorTest.cc CookieTest.cc CookieStorageTest.cc \
 	TimeTest.cc FtpConnectionTest.cc OptionParserTest.cc \
-	SimpleDNSCacheTest.cc DownloadHelperTest.cc \
-	SequentialPickerTest.cc RarestPieceSelectorTest.cc \
-	PieceStatManTest.cc InOrderPieceSelector.h \
-	LongestSequencePieceSelectorTest.cc a2algoTest.cc \
-	bitfieldTest.cc BDETest.cc DownloadContextTest.cc \
-	$(am__append_1) $(am__append_2) $(am__append_3) \
-	$(am__append_4) $(am__append_5) $(am__append_6) \
-	$(am__append_7)
+	DNSCacheTest.cc DownloadHelperTest.cc SequentialPickerTest.cc \
+	RarestPieceSelectorTest.cc PieceStatManTest.cc \
+	InOrderPieceSelector.h LongestSequencePieceSelectorTest.cc \
+	a2algoTest.cc bitfieldTest.cc BDETest.cc \
+	DownloadContextTest.cc $(am__append_1) $(am__append_2) \
+	$(am__append_3) $(am__append_4) $(am__append_5) \
+	$(am__append_6) $(am__append_7)
 
 #aria2c_CXXFLAGS = ${CPPUNIT_CFLAGS} -I../src -I../lib -Wall -D_FILE_OFFSET_BITS=64
 #aria2c_LDFLAGS = ${CPPUNIT_LIBS}
@@ -759,6 +757,7 @@ distclean-compile:
 @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTTokenTrackerTest.Po@am__quote@
 @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTUnknownMessageTest.Po@am__quote@
 @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DHTUtilTest.Po@am__quote@
+@AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DNSCacheTest.Po@am__quote@
 @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DefaultAuthResolverTest.Po@am__quote@
 @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DefaultBtAnnounceTest.Po@am__quote@
 @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/DefaultBtMessageDispatcherTest.Po@am__quote@
@@ -830,7 +829,6 @@ distclean-compile:
 @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/ShareRatioSeedCriteriaTest.Po@am__quote@
 @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/SharedHandleTest.Po@am__quote@
 @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/SignatureTest.Po@am__quote@
-@AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/SimpleDNSCacheTest.Po@am__quote@
 @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/SingleFileAllocationIteratorTest.Po@am__quote@
 @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/SingletonHolderTest.Po@am__quote@
 @AMDEP_TRUE@@am__include@ @am__quote@./$(DEPDIR)/SocketCoreTest.Po@am__quote@

+ 0 - 35
test/SimpleDNSCacheTest.cc

@@ -1,35 +0,0 @@
-#include "DNSCache.h"
-
-#include <iostream>
-
-#include <cppunit/extensions/HelperMacros.h>
-
-#include "Exception.h"
-#include "Util.h"
-
-namespace aria2 {
-
-class SimpleDNSCacheTest:public CppUnit::TestFixture {
-
-  CPPUNIT_TEST_SUITE(SimpleDNSCacheTest);
-  CPPUNIT_TEST(testFind);
-  CPPUNIT_TEST_SUITE_END();
-public:
-  void testFind();
-};
-
-
-CPPUNIT_TEST_SUITE_REGISTRATION(SimpleDNSCacheTest);
-
-void SimpleDNSCacheTest::testFind()
-{
-  SimpleDNSCache cache;
-  cache.put("host1", "192.168.0.1");
-  cache.put("host2", "192.168.1.2");
-
-  CPPUNIT_ASSERT_EQUAL(std::string("192.168.0.1"), cache.find("host1"));
-  CPPUNIT_ASSERT_EQUAL(std::string("192.168.1.2"), cache.find("host2"));
-  CPPUNIT_ASSERT_EQUAL(std::string(""), cache.find("host3"));
-}
-
-} // namespace aria2