blob: 3873b8d4add2f615aa543c5fd04e86f26f037062 [file] [log] [blame]
// Copyright 2015-2016 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
//! `RetryDnsHandle` allows for DnsQueries to be reattempted on failure
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_util::stream::{Stream, StreamExt};
use crate::error::{ProtoError, ProtoErrorKind};
use crate::xfer::{DnsRequest, DnsResponse};
use crate::DnsHandle;
/// Can be used to reattempt queries if they fail
///
/// Note: this does not reattempt queries that fail with a negative response.
/// For example, if a query gets a `NODATA` response from a name server, the
/// query will not be retried. It only reattempts queries that effectively
/// failed to get a response, such as queries that resulted in IO or timeout
/// errors.
///
/// Whether an error is retryable by the [`RetryDnsHandle`] is determined by the
/// [`RetryableError`] trait.
///
/// *note* Current value of this is not clear, it may be removed
#[derive(Clone)]
#[must_use = "queries can only be sent through a ClientHandle"]
pub struct RetryDnsHandle<H>
where
H: DnsHandle + Unpin + Send,
H::Error: RetryableError,
{
handle: H,
attempts: usize,
}
impl<H> RetryDnsHandle<H>
where
    H: DnsHandle + Unpin + Send,
    H::Error: RetryableError,
{
    /// Creates a new Client handler for reattempting requests on failures.
    ///
    /// # Arguments
    ///
    /// * `handle` - handle to the dns connection
    /// * `attempts` - number of *additional* sends allowed after the first one
    ///   fails. The initial send is not counted, so a query may be sent up to
    ///   `attempts + 1` times in total, and `0` disables retrying. While budget
    ///   remains, errors whose [`RetryableError::attempted`] returns `false` are
    ///   resent without consuming from it.
    pub fn new(handle: H, attempts: usize) -> Self {
        Self { handle, attempts }
    }
}
impl<H> DnsHandle for RetryDnsHandle<H>
where
    H: DnsHandle + Send + Unpin + 'static,
    H::Error: RetryableError,
{
    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, Self::Error>> + Send + Unpin>>;
    type Error = <H as DnsHandle>::Error;

    /// Sends `request` once through the wrapped handle and returns a stream
    /// that transparently resends it on retryable errors.
    fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
        let original: DnsRequest = request.into();

        // The first attempt consumes a copy. The original is kept so the retry
        // stream can resend it if needed. (Being lazy about this copy would be
        // nicer, but the request must outlive the first attempt.)
        let first_attempt = self.handle.send(original.clone());

        // Clone the handle only after the initial send, as before.
        let resend_handle = self.handle.clone();

        Box::pin(RetrySendStream {
            remaining_attempts: self.attempts,
            stream: first_attempt,
            handle: resend_handle,
            request: original,
        })
    }
}
/// A response stream that resends its request whenever the in-flight attempt
/// yields a retryable error, for as long as the retry budget allows.
struct RetrySendStream<H: DnsHandle> {
    /// The original request, retained so it can be resent after a failure.
    request: DnsRequest,
    /// Handle used to issue each resend.
    handle: H,
    /// Response stream of the attempt currently being polled.
    stream: <H as DnsHandle>::Response,
    /// Resends still allowed. Only decremented by errors whose
    /// `RetryableError::attempted` returns `true`.
    remaining_attempts: usize,
}
impl<H: DnsHandle + Unpin> Stream for RetrySendStream<H>
where
    <H as DnsHandle>::Error: RetryableError,
{
    type Item = Result<DnsResponse, <H as DnsHandle>::Error>;

    /// Polls the current attempt. Anything other than an error (a response,
    /// `Pending`, or end of stream) is passed straight through. An error is
    /// either returned or triggers a resend, which is then polled immediately.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let failure = match self.stream.poll_next_unpin(cx) {
                Poll::Ready(Some(Err(failure))) => failure,
                passthrough => return passthrough,
            };

            // Give up when the budget is spent or the error says not to retry.
            let budget_exhausted = self.remaining_attempts == 0;
            if budget_exhausted || !failure.should_retry() {
                return Poll::Ready(Some(Err(failure)));
            }

            // Only errors that count as an attempt consume the budget.
            // NOTE(review): a handle that keeps failing synchronously with
            // `attempted() == false` would make this loop spin without ever
            // returning `Pending`; confirm that cannot happen.
            if failure.attempted() {
                self.remaining_attempts -= 1;
            }

            // TODO: if the "sent" Message is part of the error result,
            // then we can just reuse it... and no clone necessary
            let resend = self.request.clone();
            self.stream = self.handle.send(resend);
        }
    }
}
/// Classifies errors for [`RetryDnsHandle`]: whether a failed query may be
/// resent, and whether the failure is charged against the retry budget.
pub trait RetryableError {
    /// Returns `true` if the query that produced this error may be resent.
    ///
    /// Returning `false` makes the error surface immediately, no matter how
    /// much retry budget remains.
    fn should_retry(&self) -> bool;

    /// Returns `true` if this error consumes one unit of the retry budget.
    ///
    /// Errors returning `false` are resent without decrementing the budget.
    /// Once the budget is already exhausted, however, they are returned like
    /// any other error.
    fn attempted(&self) -> bool;
}
impl RetryableError for ProtoError {
    /// Every `ProtoError` is eligible for a resend. When to stop is left
    /// entirely to the retry budget.
    fn should_retry(&self) -> bool {
        true
    }

    /// Errors of kind `Busy` are not charged against the retry budget. Every
    /// other kind is.
    fn attempted(&self) -> bool {
        let is_busy = matches!(self.kind(), ProtoErrorKind::Busy);
        !is_busy
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::error::*;
    use crate::op::*;
    use crate::xfer::FirstAnswer;
    use futures_executor::block_on;
    use futures_util::future::*;
    use futures_util::stream::*;
    use std::sync::{
        atomic::{AtomicU16, Ordering},
        Arc,
    };
    use DnsHandle;

    /// Fake handle that fails until `retries` failed sends have been recorded.
    /// After that, if `last_succeed` is set, it answers with a message whose
    /// id equals the number of failures seen so far.
    #[derive(Clone)]
    struct TestClient {
        last_succeed: bool,
        retries: u16,
        attempts: Arc<AtomicU16>,
    }

    impl DnsHandle for TestClient {
        type Response = Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin>;
        type Error = ProtoError;

        fn send<R: Into<DnsRequest>>(&mut self, _: R) -> Self::Response {
            let failures = self.attempts.load(Ordering::SeqCst);
            if self.last_succeed && failures >= self.retries {
                let mut answer = Message::new();
                answer.set_id(failures);
                return Box::new(once(ok(answer.into())));
            }
            self.attempts.fetch_add(1, Ordering::SeqCst);
            Box::new(once(err(ProtoError::from("last retry set to fail"))))
        }
    }

    /// A retrying handle with a budget of 2 around a client that fails once.
    fn retrying_client(last_succeed: bool) -> RetryDnsHandle<TestClient> {
        let inner = TestClient {
            last_succeed,
            retries: 1,
            attempts: Arc::new(AtomicU16::new(0)),
        };
        RetryDnsHandle::new(inner, 2)
    }

    #[test]
    fn test_retry() {
        let mut handle = retrying_client(true);
        let answer = block_on(handle.send(Message::new()).first_answer())
            .expect("should have succeeded");
        // the id records how many failed sends preceded the success
        assert_eq!(answer.id(), 1);
    }

    #[test]
    fn test_error() {
        let mut handle = retrying_client(false);
        let outcome = block_on(handle.send(Message::new()).first_answer());
        assert!(outcome.is_err());
    }
}