blob: 5e356fa8020d46ca73d66fa046dfd21d15ee2258 [file] [log] [blame]
// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures::{future, Future, TryFutureExt};
use smallvec::SmallVec;
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
use tokio::runtime::Handle;
use proto::error::ProtoError;
use proto::op::ResponseCode;
use proto::xfer::{DnsHandle, DnsRequest, DnsResponse};
use crate::config::{ResolverConfig, ResolverOpts};
#[cfg(feature = "mdns")]
use crate::name_server;
use crate::name_server::{ConnectionProvider, NameServer};
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
use crate::name_server::{TokioConnection, TokioConnectionProvider};
/// A pool of NameServers
///
/// This is not expected to be used directly, see [`AsyncResolver`].
#[derive(Clone)]
pub struct NameServerPool<
C: DnsHandle + Send + Sync + 'static,
P: ConnectionProvider<Conn = C> + Send + 'static,
> {
// TODO: switch to FuturesMutex (Mutex will have some undesireable locking)
datagram_conns: Arc<Vec<NameServer<C, P>>>, /* All NameServers must be the same type */
stream_conns: Arc<Vec<NameServer<C, P>>>, /* All NameServers must be the same type */
#[cfg(feature = "mdns")]
mdns_conns: NameServer<C, P>, /* All NameServers must be the same type */
options: ResolverOpts,
conn_provider: P,
}
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
impl NameServerPool<TokioConnection, TokioConnectionProvider> {
pub(crate) fn from_config(
config: &ResolverConfig,
options: &ResolverOpts,
runtime: Handle,
) -> Self {
Self::from_config_with_provider(config, options, TokioConnectionProvider::new(runtime))
}
}
impl<C: DnsHandle + Sync + 'static, P: ConnectionProvider<Conn = C> + 'static>
NameServerPool<C, P>
{
pub(crate) fn from_config_with_provider(
config: &ResolverConfig,
options: &ResolverOpts,
conn_provider: P,
) -> NameServerPool<C, P> {
let datagram_conns: Vec<NameServer<C, P>> = config
.name_servers()
.iter()
.filter(|ns_config| ns_config.protocol.is_datagram())
.map(|ns_config| {
#[cfg(feature = "dns-over-rustls")]
let ns_config = {
let mut ns_config = ns_config.clone();
ns_config.tls_config = config.client_config().clone();
ns_config
};
#[cfg(not(feature = "dns-over-rustls"))]
let ns_config = { ns_config.clone() };
NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
})
.collect();
let stream_conns: Vec<NameServer<C, P>> = config
.name_servers()
.iter()
.filter(|ns_config| ns_config.protocol.is_stream())
.map(|ns_config| {
#[cfg(feature = "dns-over-rustls")]
let ns_config = {
let mut ns_config = ns_config.clone();
ns_config.tls_config = config.client_config().clone();
ns_config
};
#[cfg(not(feature = "dns-over-rustls"))]
let ns_config = { ns_config.clone() };
NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
})
.collect();
NameServerPool {
datagram_conns: Arc::new(datagram_conns),
stream_conns: Arc::new(stream_conns),
#[cfg(feature = "mdns")]
mdns_conns: name_server::mdns_nameserver(*options, conn_provider.clone()),
options: *options,
conn_provider,
}
}
#[doc(hidden)]
#[cfg(not(feature = "mdns"))]
pub fn from_nameservers(
options: &ResolverOpts,
datagram_conns: Vec<NameServer<C, P>>,
stream_conns: Vec<NameServer<C, P>>,
conn_provider: P,
) -> Self {
NameServerPool {
datagram_conns: Arc::new(datagram_conns.into_iter().collect()),
stream_conns: Arc::new(stream_conns.into_iter().collect()),
options: *options,
conn_provider,
}
}
#[doc(hidden)]
#[cfg(feature = "mdns")]
pub fn from_nameservers(
options: &ResolverOpts,
datagram_conns: Vec<NameServer<C, P>>,
stream_conns: Vec<NameServer<C, P>>,
mdns_conns: NameServer<C, P>,
conn_provider: P,
) -> Self {
NameServerPool {
datagram_conns: Arc::new(datagram_conns.into_iter().collect()),
stream_conns: Arc::new(stream_conns.into_iter().collect()),
mdns_conns,
options: *options,
conn_provider,
}
}
async fn try_send(
opts: ResolverOpts,
conns: Arc<Vec<NameServer<C, P>>>,
request: DnsRequest,
) -> Result<DnsResponse, ProtoError> {
let mut conns: Vec<NameServer<C, P>> = conns.to_vec();
// select the highest priority connection
// reorder the connections based on current view...
// this reorders the inner set
conns.sort_unstable();
let request_loop = request.clone();
parallel_conn_loop(conns, request_loop, opts).await
}
}
impl<C, P> DnsHandle for NameServerPool<C, P>
where
C: DnsHandle + Sync + 'static,
P: ConnectionProvider<Conn = C> + 'static,
{
type Response = Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>;
fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
let opts = self.options;
let request = request.into();
let datagram_conns = Arc::clone(&self.datagram_conns);
let stream_conns1 = Arc::clone(&self.stream_conns);
let stream_conns2 = Arc::clone(&self.stream_conns);
// TODO: remove this clone, return the Message in the error?
let tcp_message1 = request.clone();
let tcp_message2 = request.clone();
// if it's a .local. query, then we *only* query mDNS, these should never be sent on to upstream resolvers
#[cfg(feature = "mdns")]
let mdns = mdns::maybe_local(&mut self.mdns_conns, request);
// TODO: limited to only when mDNS is enabled, but this should probably always be enforced?
#[cfg(not(feature = "mdns"))]
let mdns = Local::NotMdns(request);
// local queries are queried through mDNS
if mdns.is_local() {
return mdns.take_future();
}
// TODO: should we allow mDNS to be used for standard lookups as well?
// it wasn't a local query, continue with standard lookup path
let request = mdns.take_request();
debug!("sending request: {:?}", request.queries());
// First try the UDP connections
Box::pin(
Self::try_send(opts, datagram_conns, request)
.and_then(move |response| {
// handling promotion from datagram to stream base on truncation in message
if ResponseCode::NoError == response.response_code() && response.truncated() {
// TCP connections should not truncate
future::Either::Left(Self::try_send(opts, stream_conns1, tcp_message1))
} else {
debug!("mDNS responsed for query: {:?}", response.response_code());
// Return the result from the UDP connection
future::Either::Right(future::ok(response))
}
})
// if UDP fails, try TCP
.or_else(move |_| Self::try_send(opts, stream_conns2, tcp_message2)),
)
}
}
// TODO: we should be able to have a self-referential future here with Pin and not require cloned conns
/// An async function that will loop over all the conns with a max parallel request count of ops.num_concurrent_req
async fn parallel_conn_loop<C, P>(
mut conns: Vec<NameServer<C, P>>,
request: DnsRequest,
opts: ResolverOpts,
) -> Result<DnsResponse, ProtoError>
where
C: DnsHandle + 'static,
P: ConnectionProvider<Conn = C> + 'static,
{
let mut err = ProtoError::from("No connections available");
loop {
let request_cont = request.clone();
// construct the parallel requests, 2 is the default
let mut par_conns = SmallVec::<[NameServer<C, P>; 2]>::new();
let count = conns.len().min(opts.num_concurrent_reqs.max(1));
for conn in conns.drain(..count) {
par_conns.push(conn);
}
// construct the requests to send
let requests = if par_conns.is_empty() {
return Err(err);
} else {
par_conns
.into_iter()
.map(move |mut conn| conn.send(request_cont.clone()))
};
match future::select_ok(requests).await {
Ok((sent, _)) => return Ok(sent),
// consider a debug msg here
Err(e) => {
err = e;
continue;
}
};
}
}
#[cfg(feature = "mdns")]
mod mdns {
use super::*;
use proto::rr::domain::usage;
use proto::DnsHandle;
/// Returns true
pub fn maybe_local<C, P>(name_server: &mut NameServer<C, P>, request: DnsRequest) -> Local
where
C: DnsHandle + 'static,
P: ConnectionProvider<Conn = C> + 'static,
P: ConnectionProvider,
{
if request
.queries()
.iter()
.any(|query| usage::LOCAL.name().zone_of(query.name()))
{
Local::ResolveFuture(name_server.send(request))
} else {
Local::NotMdns(request)
}
}
}
pub enum Local {
#[allow(dead_code)]
ResolveFuture(Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>>),
NotMdns(DnsRequest),
}
impl Local {
fn is_local(&self) -> bool {
if let Local::ResolveFuture(..) = *self {
true
} else {
false
}
}
/// Takes the future
///
/// # Panics
///
/// Panics if this is in fact a Local::NotMdns
fn take_future(self) -> Pin<Box<dyn Future<Output = Result<DnsResponse, ProtoError>> + Send>> {
match self {
Local::ResolveFuture(future) => future,
_ => panic!("non Local queries have no future, see take_message()"),
}
}
/// Takes the message
///
/// # Panics
///
/// Panics if this is in fact a Local::ResolveFuture
fn take_request(self) -> DnsRequest {
match self {
Local::NotMdns(request) => request,
_ => panic!("Local queries must be polled, see take_future()"),
}
}
}
impl Future for Local {
type Output = Result<DnsResponse, ProtoError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match *self {
Local::ResolveFuture(ref mut ns) => ns.as_mut().poll(cx),
// TODO: making this a panic for now
Local::NotMdns(..) => panic!("Local queries that are not mDNS should not be polled"), //Local::NotMdns(message) => return Err(ResolveErrorKind::Message("not mDNS")),
}
}
}
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
extern crate env_logger;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tokio::runtime::Runtime;
use proto::op::Query;
use proto::rr::{Name, RecordType};
use proto::xfer::{DnsHandle, DnsRequestOptions};
use super::*;
use crate::config::NameServerConfig;
use crate::config::Protocol;
#[ignore]
// because of there is a real connection that needs a reasonable timeout
#[test]
fn test_failed_then_success_pool() {
let config1 = NameServerConfig {
socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)), 253),
protocol: Protocol::Udp,
tls_dns_name: None,
#[cfg(feature = "dns-over-rustls")]
tls_config: None,
};
let config2 = NameServerConfig {
socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
protocol: Protocol::Udp,
tls_dns_name: None,
#[cfg(feature = "dns-over-rustls")]
tls_config: None,
};
let mut resolver_config = ResolverConfig::new();
resolver_config.add_name_server(config1);
resolver_config.add_name_server(config2);
let mut io_loop = Runtime::new().unwrap();
let mut pool = NameServerPool::<_, TokioConnectionProvider>::from_config(
&resolver_config,
&ResolverOpts::default(),
io_loop.handle().clone(),
);
let name = Name::parse("www.example.com.", None).unwrap();
// TODO: it's not clear why there are two failures before the success
for i in 0..2 {
assert!(
io_loop
.block_on(pool.lookup(
Query::query(name.clone(), RecordType::A),
DnsRequestOptions::default()
))
.is_err(),
"iter: {}",
i
);
}
for i in 0..10 {
assert!(
io_loop
.block_on(pool.lookup(
Query::query(name.clone(), RecordType::A),
DnsRequestOptions::default()
))
.is_ok(),
"iter: {}",
i
);
}
}
}