blob: f5b18f787a0516fb86a7bb16e9c4062c420067d7 [file] [log] [blame]
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use {
crate::{parse_ip_addr, HyperConnectorFuture, SocketOptions, TcpOptions, TcpStream},
async_net as net,
futures::io,
http::uri::{Scheme, Uri},
hyper::service::Service,
rustls::ClientConfig,
std::net::ToSocketAddrs,
std::task::{Context, Poll},
tracing::warn,
};
pub(crate) fn configure_cert_store(tls: &mut ClientConfig) {
tls.root_store = rustls_native_certs::load_native_certs().unwrap_or_else(|(certs, err)| {
if certs.is_some() {
warn!("One or more TLS certificates in root store failed to load: {}", err);
}
certs.unwrap_or_else(|| {
panic!("Unable to load any TLS CA certificates from platform root store: {}", err)
})
})
}
/// A Async-std-compatible implementation of hyper's `Connect` trait which allows
/// creating a TcpStream to a particular destination.
#[derive(Clone, Debug)]
pub struct HyperConnector {
tcp_options: TcpOptions,
socket_options: SocketOptions,
}
impl From<(TcpOptions, SocketOptions)> for HyperConnector {
fn from((tcp_options, socket_options): (TcpOptions, SocketOptions)) -> Self {
Self { tcp_options, socket_options }
}
}
impl HyperConnector {
pub fn new() -> Self {
Self::from_tcp_options(TcpOptions::default())
}
pub fn from_tcp_options(tcp_options: TcpOptions) -> Self {
Self { tcp_options, socket_options: SocketOptions::default() }
}
}
impl Service<Uri> for HyperConnector {
type Response = TcpStream;
type Error = std::io::Error;
type Future = HyperConnectorFuture;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, dst: Uri) -> Self::Future {
let self_ = self.clone();
HyperConnectorFuture { fut: Box::pin(async move { self_.call_async(dst).await }) }
}
}
impl HyperConnector {
async fn call_async(&self, dst: Uri) -> Result<TcpStream, io::Error> {
let port = match dst.port() {
Some(p) => p.as_u16(),
None => {
if dst.scheme() == Some(&Scheme::HTTPS) {
443
} else {
80
}
}
};
let host = match dst.host() {
Some(host) => host,
_ => return Err(io::Error::new(io::ErrorKind::Other, "missing host in Uri")),
};
let addr = parse_ip_addr(host, port, |_| async {
Err(io::Error::new(io::ErrorKind::Other, "does not yet support non-integer zone ids"))
})
.await?;
if self.socket_options.bind_device.is_some() {
unimplemented!("TODO(https://fxbug.dev/42083862) fuchsia-hyper does not support bind_device on non-fuchsia devices");
}
let stream = if let Some(addr) = addr {
net::TcpStream::connect(addr).await?
} else {
resolve_host_port(host, port).await?
};
let () = self.tcp_options.apply(&stream)?;
Ok(TcpStream { stream })
}
}
/// Resolve a hostname into an address.
async fn resolve_host_port(host: &str, port: u16) -> Result<net::TcpStream, io::Error> {
// TODO(https://fxbug.dev/42075095): Implement happy eyeballs algorithm to make this
// more efficient.
let mut last_err = None;
for addr in (host, port).to_socket_addrs()? {
match net::TcpStream::connect(addr).await {
Ok(stream) => {
return Ok(stream);
}
Err(err) => {
last_err = Some(err);
}
}
}
if let Some(err) = last_err {
Err(err)
} else {
Err(io::Error::new(io::ErrorKind::Other, "destination resolved to no address"))
}
}
////////////////////////////////////////////////////////////////////////////////
///// tests
#[cfg(test)]
mod test {
use {
crate::*,
anyhow::{Error, Result},
async_net::TcpListener,
futures::{
future::BoxFuture, stream::FuturesUnordered, StreamExt, TryFutureExt, TryStreamExt,
},
hyper::{
body::HttpBody,
server::{accept::from_stream, Server},
service::{make_service_fn, service_fn},
Response, StatusCode,
},
std::{convert::Infallible, io::Write},
};
trait AsyncReadWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send {}
impl<T> AsyncReadWrite for T where T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send {}
async fn fetch_url<W: Write>(url: hyper::Uri, mut buffer: W) -> Result<StatusCode> {
let client = new_https_client();
let mut res = client.get(url).await?;
let status = res.status();
if status == StatusCode::OK {
while let Some(next) = res.data().await {
let chunk = next?;
buffer.write_all(&chunk)?;
}
buffer.flush()?;
}
Ok(status)
}
#[fuchsia_async::run_singlethreaded(test)]
async fn test_download_succeeds() -> Result<()> {
let (listener, addr) = {
let addr = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 0);
let listener = TcpListener::bind(&addr).await.unwrap();
let local_addr = listener.local_addr().unwrap();
(listener, local_addr)
};
let listener =
listener.incoming().map_err(Error::from).map_ok(|conn| TcpStream { stream: conn });
let connections = listener
.map_ok(|conn| Pin::new(Box::new(conn)) as Pin<Box<dyn AsyncReadWrite>>)
.boxed();
let make_svc = make_service_fn(move |_socket| async move {
Ok::<_, Infallible>(service_fn(move |_req| async move {
Ok::<_, Infallible>(Response::new(Body::from("Hello")))
}))
});
let (stop, rx_stop) = futures::channel::oneshot::channel();
let server = async {
Server::builder(from_stream(connections))
.executor(Executor)
.serve(make_svc)
.with_graceful_shutdown(
rx_stop.map(|res| res.unwrap_or_else(|futures::channel::oneshot::Canceled| ())),
)
.unwrap_or_else(|e| panic!("error serving repo over http: {}", e))
.await;
Ok(())
};
let client = async {
let output: Vec<u8> = Vec::new();
let status = fetch_url(
format!("http://localhost:{}", addr.port()).parse::<hyper::Uri>().unwrap(),
output,
)
.await
.unwrap();
match status {
StatusCode::OK | StatusCode::FOUND => {}
_ => assert!(false, "Unexpected status code: {}", status),
}
stop.send(()).expect("server to still be running");
Ok(())
};
let mut tasks: FuturesUnordered<BoxFuture<'_, Result<(), Error>>> = FuturesUnordered::new();
tasks.push(Box::pin(server));
tasks.push(Box::pin(client));
while let Some(Ok(())) = tasks.next().await {}
Ok(())
}
#[fuchsia_async::run_singlethreaded(test)]
async fn test_download_handles_bad_domain() -> Result<()> {
let output: Vec<u8> = Vec::new();
let res = fetch_url("https://domain.invalid".parse::<hyper::Uri>()?, output).await;
assert!(res.is_err());
Ok(())
}
}