Index: net/dns/dns_test_util.cc |
diff --git a/net/dns/dns_test_util.cc b/net/dns/dns_test_util.cc |
index 6f94b7f9e29873e003b81abd6437cf5399ba9518..01219af867c0f3a9ab4ad0d7da310ed507ec7804 100644 |
--- a/net/dns/dns_test_util.cc |
+++ b/net/dns/dns_test_util.cc |
@@ -8,6 +8,7 @@ |
#include "base/big_endian.h" |
#include "base/bind.h" |
+#include "base/callback.h" |
#include "base/location.h" |
#include "base/memory/weak_ptr.h" |
#include "base/single_thread_task_runner.h" |
@@ -43,8 +44,7 @@ class MockTransaction : public DnsTransaction, |
const std::string& hostname, |
uint16_t qtype, |
const DnsTransactionFactory::CallbackType& callback) |
- : result_(MockDnsClientRule::FAIL), |
- hostname_(hostname), |
+ : hostname_(hostname), |
qtype_(qtype), |
callback_(callback), |
started_(false), |
@@ -55,7 +55,7 @@ class MockTransaction : public DnsTransaction, |
if ((rules[i].qtype == qtype) && |
(hostname.size() >= prefix.size()) && |
(hostname.compare(0, prefix.size(), prefix) == 0)) { |
- result_ = rules[i].result; |
+ response_callback_ = rules[i].response_callback; |
delayed_ = rules[i].delay; |
break; |
} |
@@ -86,70 +86,37 @@ class MockTransaction : public DnsTransaction, |
private: |
void Finish() { |
- switch (result_) { |
- case MockDnsClientRule::EMPTY: |
- case MockDnsClientRule::OK: { |
- std::string qname; |
- DNSDomainFromDot(hostname_, &qname); |
- DnsQuery query(0, qname, qtype_); |
- |
- DnsResponse response; |
- char* buffer = response.io_buffer()->data(); |
- int nbytes = query.io_buffer()->size(); |
- memcpy(buffer, query.io_buffer()->data(), nbytes); |
- dns_protocol::Header* header = |
- reinterpret_cast<dns_protocol::Header*>(buffer); |
- header->flags |= dns_protocol::kFlagResponse; |
- |
- if (MockDnsClientRule::OK == result_) { |
- const uint16_t kPointerToQueryName = |
- static_cast<uint16_t>(0xc000 | sizeof(*header)); |
- |
- const uint32_t kTTL = 86400; // One day. |
- |
- // Size of RDATA which is a IPv4 or IPv6 address. |
- size_t rdata_size = qtype_ == dns_protocol::kTypeA |
- ? IPAddress::kIPv4AddressSize |
- : IPAddress::kIPv6AddressSize; |
- |
- // 12 is the sum of sizes of the compressed name reference, TYPE, |
- // CLASS, TTL and RDLENGTH. |
- size_t answer_size = 12 + rdata_size; |
- |
- // Write answer with loopback IP address. |
- header->ancount = base::HostToNet16(1); |
- base::BigEndianWriter writer(buffer + nbytes, answer_size); |
- writer.WriteU16(kPointerToQueryName); |
- writer.WriteU16(qtype_); |
- writer.WriteU16(dns_protocol::kClassIN); |
- writer.WriteU32(kTTL); |
- writer.WriteU16(static_cast<uint16_t>(rdata_size)); |
- if (qtype_ == dns_protocol::kTypeA) { |
- char kIPv4Loopback[] = { 0x7f, 0, 0, 1 }; |
- writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback)); |
- } else { |
- char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0, |
- 0, 0, 0, 0, 0, 0, 0, 1 }; |
- writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback)); |
- } |
- nbytes += answer_size; |
- } |
- EXPECT_TRUE(response.InitParse(nbytes, query)); |
- callback_.Run(this, OK, &response); |
- } break; |
- case MockDnsClientRule::FAIL: |
- callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL); |
- break; |
- case MockDnsClientRule::TIMEOUT: |
- callback_.Run(this, ERR_DNS_TIMED_OUT, NULL); |
- break; |
- default: |
- NOTREACHED(); |
- break; |
+ if (response_callback_.is_null()) { |
+ callback_.Run(this, ERR_NAME_NOT_RESOLVED, nullptr); |
+ return; |
+ } |
+ |
+ std::string qname; |
+ DNSDomainFromDot(hostname_, &qname); |
+ DnsQuery query(0, qname, qtype_); |
+ |
+ DnsResponse response; |
+ IOBufferWithSize* buffer = response.io_buffer(); |
+ int query_size = query.io_buffer()->size(); |
+ CHECK_GE(buffer->size(), query_size); |
+ memcpy(buffer->data(), query.io_buffer()->data(), query_size); |
+ dns_protocol::Header* header = |
+ reinterpret_cast<dns_protocol::Header*>(buffer->data()); |
+ header->flags |= dns_protocol::kFlagResponse; |
+ |
+ base::BigEndianWriter answer_writer(buffer->data() + query_size, |
+ buffer->size() - query_size); |
+ int net_error = response_callback_.Run(header, &answer_writer); |
+ if (net_error == OK) { |
+ int nbytes = answer_writer.ptr() - buffer->data(); |
+ EXPECT_TRUE(response.InitParse(nbytes, query)); |
+ callback_.Run(this, OK, &response); |
+ } else { |
+ callback_.Run(this, net_error, nullptr); |
} |
} |
- MockDnsClientRule::Result result_; |
+ MockDnsClientRule::ResponseCallback response_callback_; |
const std::string hostname_; |
const uint16_t qtype_; |
DnsTransactionFactory::CallbackType callback_; |
@@ -157,6 +124,48 @@ class MockTransaction : public DnsTransaction, |
bool delayed_; |
}; |
+// Simply returns the |net_error| argument. |
+// Useful as a simple callback that does nothing but reports an error. |
+int ReturnNetError(int net_error, |
+ dns_protocol::Header* response_header, |
+ base::BigEndianWriter* answer_writer) { |
+ CHECK_LE(net_error, 0); |
+ return net_error; |
+} |
+ |
+// Writes a |qtype| record for the loopback address using |answer_writer|. |
+// |qtype| must be |dns_protocol::kTypeA| or |dns_protocol::kTypeAAAA|. |
+// Returns net::OK if successful. |
+int WriteLoopbackRecordResponse(uint16_t qtype, |
+ dns_protocol::Header* response_header, |
+ base::BigEndianWriter* answer_writer) { |
+ const uint16_t kPointerToQueryName = |
+ static_cast<uint16_t>(0xc000 | sizeof(*response_header)); |
+ |
+ const uint32_t kTTL = 86400; // One day. |
+ |
+ // Write answer with loopback IP address. |
+ response_header->ancount = |
+ base::HostToNet16(base::NetToHost16(response_header->ancount) + 1); |
+ CHECK(answer_writer->WriteU16(kPointerToQueryName)); |
+ CHECK(answer_writer->WriteU16(qtype)); |
+ CHECK(answer_writer->WriteU16(dns_protocol::kClassIN)); |
+ CHECK(answer_writer->WriteU32(kTTL)); |
+ if (qtype == dns_protocol::kTypeA) { |
+ char kIPv4Loopback[] = {0x7f, 0, 0, 1}; |
+ CHECK(answer_writer->WriteU16(IPAddress::kIPv4AddressSize)); |
+ CHECK(answer_writer->WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback))); |
+ } else if (qtype == dns_protocol::kTypeAAAA) { |
+ char kIPv6Loopback[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}; |
+ CHECK(answer_writer->WriteU16(IPAddress::kIPv6AddressSize)); |
+ CHECK(answer_writer->WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback))); |
+ } else { |
+ NOTREACHED(); |
+ } |
+ |
+ return OK; |
+} |
+ |
} // namespace |
// A DnsTransactionFactory which creates MockTransaction. |
@@ -196,6 +205,63 @@ class MockTransactionFactory : public DnsTransactionFactory { |
DelayedTransactionList delayed_transactions_; |
}; |
+MockDnsClientRule::MockDnsClientRule(const std::string& prefix, |
+ uint16_t qtype, |
+ Result result, |
+ bool delay) |
+ : prefix(prefix), qtype(qtype), delay(delay) { |
+ switch (result) { |
+ case FAIL: |
+ response_callback = |
+ base::Bind(&ReturnNetError, net::ERR_NAME_NOT_RESOLVED); |
+ break; |
+ case TIMEOUT: |
+ response_callback = base::Bind(&ReturnNetError, net::ERR_DNS_TIMED_OUT); |
+ break; |
+ case EMPTY: |
+ response_callback = base::Bind(&ReturnNetError, net::OK); |
+ break; |
+ case OK: |
+ response_callback = base::Bind(&WriteLoopbackRecordResponse, qtype); |
+ break; |
+ } |
+ CHECK(!response_callback.is_null()); |
+} |
+ |
+MockDnsClientRule::MockDnsClientRule(const std::string& prefix, |
+ uint16_t qtype, |
+ ResponseCallback response_callback, |
+ bool delay) |
+ : response_callback(response_callback), |
+ prefix(prefix), |
+ qtype(qtype), |
+ delay(delay) {} |
+ |
+MockDnsClientRule::MockDnsClientRule(const MockDnsClientRule& o) |
+ : response_callback(o.response_callback), |
+ prefix(o.prefix), |
+ qtype(o.qtype), |
+ delay(o.delay) {} |
+ |
+MockDnsClientRule::MockDnsClientRule(MockDnsClientRule&& o) { |
+ swap(*this, o); |
+} |
+ |
+MockDnsClientRule::~MockDnsClientRule() {} |
+ |
+MockDnsClientRule& MockDnsClientRule::operator=(net::MockDnsClientRule o) { |
+ swap(*this, o); |
+ return *this; |
+} |
+ |
+void swap(MockDnsClientRule& x, MockDnsClientRule& y) { |
+ using std::swap; |
+ swap(x.response_callback, y.response_callback); |
+ swap(x.prefix, y.prefix); |
+ swap(x.qtype, y.qtype); |
+ swap(x.delay, y.delay); |
+} |
+ |
MockDnsClient::MockDnsClient(const DnsConfig& config, |
const MockDnsClientRuleList& rules) |
: config_(config), |