| Index: components/certificate_transparency/log_dns_client.cc
|
| diff --git a/components/certificate_transparency/log_dns_client.cc b/components/certificate_transparency/log_dns_client.cc
|
| index ce7e8627a703783416b2fae22ac764c74e236400..76ff9f706bb3387692590cf30d5aa931fe7fb39d 100644
|
| --- a/components/certificate_transparency/log_dns_client.cc
|
| +++ b/components/certificate_transparency/log_dns_client.cc
|
| @@ -2,17 +2,17 @@
|
| // Use of this source code is governed by a BSD-style license that can be
|
| // found in the LICENSE file.
|
|
|
| #include "components/certificate_transparency/log_dns_client.h"
|
|
|
| -#include <sstream>
|
| -
|
| #include "base/bind.h"
|
| +#include "base/format_macros.h"
|
| #include "base/location.h"
|
| #include "base/logging.h"
|
| #include "base/strings/string_number_conversions.h"
|
| #include "base/strings/string_util.h"
|
| +#include "base/strings/stringprintf.h"
|
| #include "base/threading/thread_task_runner_handle.h"
|
| #include "base/time/time.h"
|
| #include "components/base32/base32.h"
|
| #include "crypto/sha2.h"
|
| #include "net/base/net_errors.h"
|
| @@ -96,13 +96,15 @@ bool ParseAuditPath(const net::DnsResponse& response,
|
| }
|
|
|
| } // namespace
|
|
|
| LogDnsClient::LogDnsClient(std::unique_ptr<net::DnsClient> dns_client,
|
| - const net::NetLogWithSource& net_log)
|
| + const net::NetLogWithSource& net_log,
|
| + size_t max_concurrent_queries)
|
| : dns_client_(std::move(dns_client)),
|
| net_log_(net_log),
|
| + max_concurrent_queries_(max_concurrent_queries),
|
| weak_ptr_factory_(this) {
|
| CHECK(dns_client_);
|
| net::NetworkChangeNotifier::AddDNSObserver(this);
|
| UpdateDnsConfig();
|
| }
|
| @@ -126,10 +128,17 @@ void LogDnsClient::QueryLeafIndex(base::StringPiece domain_for_log,
|
| base::ThreadTaskRunnerHandle::Get()->PostTask(
|
| FROM_HERE, base::Bind(callback, net::Error::ERR_INVALID_ARGUMENT, 0));
|
| return;
|
| }
|
|
|
| + if (HasMaxConcurrentQueriesInProgress()) {
|
| + base::ThreadTaskRunnerHandle::Get()->PostTask(
|
| + FROM_HERE,
|
| + base::Bind(callback, net::Error::ERR_TEMPORARILY_THROTTLED, 0));
|
| + return;
|
| + }
|
| +
|
| std::string encoded_leaf_hash =
|
| base32::Base32Encode(leaf_hash, base32::Base32EncodePolicy::OMIT_PADDING);
|
| DCHECK_EQ(encoded_leaf_hash.size(), 52u);
|
|
|
| net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory();
|
| @@ -138,18 +147,18 @@ void LogDnsClient::QueryLeafIndex(base::StringPiece domain_for_log,
|
| FROM_HERE,
|
| base::Bind(callback, net::Error::ERR_NAME_RESOLUTION_FAILED, 0));
|
| return;
|
| }
|
|
|
| - std::ostringstream qname;
|
| - qname << encoded_leaf_hash << ".hash." << domain_for_log << ".";
|
| + std::string qname = base::StringPrintf(
|
| + "%s.hash.%s.", encoded_leaf_hash.c_str(), domain_for_log.data());
|
|
|
| net::DnsTransactionFactory::CallbackType transaction_callback = base::Bind(
|
| &LogDnsClient::QueryLeafIndexComplete, weak_ptr_factory_.GetWeakPtr());
|
|
|
| std::unique_ptr<net::DnsTransaction> dns_transaction =
|
| - factory->CreateTransaction(qname.str(), net::dns_protocol::kTypeTXT,
|
| + factory->CreateTransaction(qname, net::dns_protocol::kTypeTXT,
|
| transaction_callback, net_log_);
|
|
|
| dns_transaction->Start();
|
| leaf_index_queries_.push_back({std::move(dns_transaction), callback});
|
| }
|
| @@ -160,11 +169,12 @@ void LogDnsClient::QueryLeafIndex(base::StringPiece domain_for_log,
|
| // 7-13 and 14-19) immediately. Currently, it sends only the first and then,
|
| // based on the number of nodes received, sends the next query. The complexity
|
| // of the code would increase though, as it would need to detect gaps in the
|
| // audit proof caused by the server not responding with the anticipated number
|
| // of nodes. Ownership of the proof would need to change, as it would be shared
|
| -// between simultaneous DNS transactions.
|
| +// between simultaneous DNS transactions. Throttling of queries would also need
|
| +// to take into account this increase in parallelism.
|
| void LogDnsClient::QueryAuditProof(base::StringPiece domain_for_log,
|
| uint64_t leaf_index,
|
| uint64_t tree_size,
|
| const AuditProofCallback& callback) {
|
| if (domain_for_log.empty() || leaf_index >= tree_size) {
|
| @@ -172,10 +182,17 @@ void LogDnsClient::QueryAuditProof(base::StringPiece domain_for_log,
|
| FROM_HERE,
|
| base::Bind(callback, net::Error::ERR_INVALID_ARGUMENT, nullptr));
|
| return;
|
| }
|
|
|
| + if (HasMaxConcurrentQueriesInProgress()) {
|
| + base::ThreadTaskRunnerHandle::Get()->PostTask(
|
| + FROM_HERE,
|
| + base::Bind(callback, net::Error::ERR_TEMPORARILY_THROTTLED, nullptr));
|
| + return;
|
| + }
|
| +
|
| std::unique_ptr<net::ct::MerkleAuditProof> proof(
|
| new net::ct::MerkleAuditProof);
|
| proof->leaf_index = leaf_index;
|
| // TODO(robpercival): Once a "tree_size" field is added to MerkleAuditProof,
|
| // pass |tree_size| to QueryAuditProofNodes using that.
|
| @@ -243,21 +260,21 @@ void LogDnsClient::QueryAuditProofNodes(
|
| FROM_HERE,
|
| base::Bind(callback, net::Error::ERR_NAME_RESOLUTION_FAILED, nullptr));
|
| return;
|
| }
|
|
|
| - std::ostringstream qname;
|
| - qname << node_index << "." << proof->leaf_index << "." << tree_size
|
| - << ".tree." << domain_for_log << ".";
|
| + std::string qname = base::StringPrintf(
|
| + "%" PRIu64 ".%" PRIu64 ".%" PRIu64 ".tree.%s.", node_index,
|
| + proof->leaf_index, tree_size, domain_for_log.data());
|
|
|
| net::DnsTransactionFactory::CallbackType transaction_callback =
|
| base::Bind(&LogDnsClient::QueryAuditProofNodesComplete,
|
| weak_ptr_factory_.GetWeakPtr(), base::Passed(std::move(proof)),
|
| domain_for_log, tree_size);
|
|
|
| std::unique_ptr<net::DnsTransaction> dns_transaction =
|
| - factory->CreateTransaction(qname.str(), net::dns_protocol::kTypeTXT,
|
| + factory->CreateTransaction(qname, net::dns_protocol::kTypeTXT,
|
| transaction_callback, net_log_);
|
| dns_transaction->Start();
|
| audit_proof_queries_.push_back({std::move(dns_transaction), callback});
|
| }
|
|
|
| @@ -318,10 +335,18 @@ void LogDnsClient::QueryAuditProofNodesComplete(
|
| base::ThreadTaskRunnerHandle::Get()->PostTask(
|
| FROM_HERE,
|
| base::Bind(query.callback, net::OK, base::Passed(std::move(proof))));
|
| }
|
|
|
| +bool LogDnsClient::HasMaxConcurrentQueriesInProgress() const {
|
| + const size_t queries_in_progress =
|
| + leaf_index_queries_.size() + audit_proof_queries_.size();
|
| +
|
| + return max_concurrent_queries_ != 0 &&
|
| + queries_in_progress >= max_concurrent_queries_;
|
| +}
|
| +
|
| void LogDnsClient::UpdateDnsConfig() {
|
| net::DnsConfig config;
|
| net::NetworkChangeNotifier::GetDnsConfig(&config);
|
| if (config.IsValid())
|
| dns_client_->SetConfig(config);
|
|
|