blob: baa23d00327f90bd6119de0f5da9b62e682735b8 [file] [log] [blame]
// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use {
anyhow::Context as _,
argh::FromArgs,
async_net::{TcpListener, TcpStream},
fuchsia_async as fasync,
fuchsia_sync::Mutex,
futures::prelude::*,
hyper::{
server::{accept::from_stream, Server},
service::{make_service_fn, service_fn},
},
mock_omaha_server::{
handle_request, OmahaServerBuilder, PrivateKeyAndId, PrivateKeys, ResponseAndMetadata,
},
std::{
collections::HashMap,
convert::Infallible,
io,
net::{Ipv6Addr, SocketAddr},
pin::Pin,
sync::Arc,
task::{Context, Poll},
},
};
// Command-line arguments, parsed in `main` via `argh::from_env()`.
// NOTE: the `///` line after the derive is the program description argh shows in `--help`,
// so its wording is user-facing text, not just a source comment.
#[derive(FromArgs)]
/// Arguments for mock-omaha-server.
struct Args {
    /// A JSON object mapping each appid to its response metadata, parsed by
    /// `parse_responses_by_appid`. The `--help` text comes from the explicit `description`
    /// attribute below, not from this comment. Defaults to an empty map.
    /// Example JSON argument:
    /// {
    ///     "appid_01": {
    ///         "response": "NoUpdate",
    ///         "merkle": "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef",
    ///         "check_assertion": "UpdatesEnabled",
    ///         "version": "0.1.2.3"
    ///     },
    ///     ...
    /// }
    /// (`...` stands for further appid entries and is not itself valid JSON.)
    #[argh(
        option,
        description = "responses and metadata keyed by appid",
        from_str_fn(parse_responses_by_appid),
        default = "HashMap::new()"
    )]
    responses_by_appid: HashMap<String, ResponseAndMetadata>,
    // ID reported for the signing key. The default converts the i32 constant
    // DEFAULT_PRIVATE_KEY_ID to u64, which would panic ("key parse") if that constant
    // were ever negative.
    #[argh(
        option,
        description = "private key ID",
        default = "DEFAULT_PRIVATE_KEY_ID.try_into().expect(\"key parse\")"
    )]
    key_id: u64,
    // Path to the PEM private key read by `main`. A relative path (like the default)
    // is resolved against the process's current working directory.
    #[argh(
        option,
        description = "path to private key",
        default = "\"testing_keys/text_private_key.pem\".to_string()"
    )]
    key_path: String,
    // TCP port to bind. With the default of 0 the OS picks a port; `main` prints the
    // address actually bound.
    #[argh(option, description = "which port to serve on", default = "0")]
    port: u16,
    // Address to bind; the default `Ipv6Addr::UNSPECIFIED` is `::`.
    #[argh(
        option,
        description = "which IP address to listen on. One of '::', '::1', or anything Ipv6Addr::from_str() can interpret.",
        default = "Ipv6Addr::UNSPECIFIED"
    )]
    listen_on: Ipv6Addr,
    // Forwarded to `OmahaServerBuilder::require_cup`.
    // NOTE(review): this is a switch (flag present/absent), so the "if 'true'" wording in
    // the help text may mislead users; consider "if set".
    #[argh(switch, description = "if 'true', will only accept requests with CUP enabled.")]
    require_cup: bool,
}
fn parse_responses_by_appid(value: &str) -> Result<HashMap<String, ResponseAndMetadata>, String> {
serde_json::from_str(value).map_err(|e| format!("Parsing failed: {e:?}"))
}
/// A server-side connection, wrapped so hyper can drive it.
///
/// hyper's server is handed these as `tokio::io::AsyncRead`/`AsyncWrite` streams, while
/// both wrapped types are driven through the `futures` I/O traits. The tokio trait impls
/// for this enum forward each call to the matching `futures` call on the active variant.
#[derive(Debug)]
pub enum ConnectionStream {
    /// A TCP connection (`async_net::TcpStream`). `main` in this file only ever
    /// constructs this variant, from `TcpListener::incoming()`.
    Tcp(TcpStream),
    /// A Fuchsia async socket (`fasync::Socket`).
    Socket(fasync::Socket),
}
impl tokio::io::AsyncRead for ConnectionStream {
    /// Reads into the unfilled part of `buf` through the wrapped stream's
    /// `futures::AsyncRead` impl, then marks the bytes that were read as filled.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // `futures::AsyncRead` needs an initialized `&mut [u8]`, so zero-initialize the
        // unfilled region before handing it over.
        let dst = buf.initialize_unfilled();
        let polled = match self.get_mut() {
            Self::Tcp(stream) => Pin::new(stream).poll_read(cx, dst),
            Self::Socket(socket) => futures::AsyncRead::poll_read(Pin::new(socket), cx, dst),
        };
        match polled {
            Poll::Ready(Ok(read)) => {
                buf.advance(read);
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Pending => Poll::Pending,
        }
    }
}
impl tokio::io::AsyncWrite for ConnectionStream {
    /// Writes `buf` to whichever stream this connection wraps.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            Self::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
            Self::Socket(socket) => futures::AsyncWrite::poll_write(Pin::new(socket), cx, buf),
        }
    }

    /// Flushes the wrapped stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Self::Tcp(stream) => Pin::new(stream).poll_flush(cx),
            Self::Socket(socket) => futures::AsyncWrite::poll_flush(Pin::new(socket), cx),
        }
    }

    /// Shuts down the write side; tokio's `poll_shutdown` maps onto `futures`' `poll_close`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Self::Tcp(stream) => Pin::new(stream).poll_close(cx),
            Self::Socket(socket) => Pin::new(socket).poll_close(cx),
        }
    }
}
/// Default value of the `key_id` command-line option.
///
/// Declared as `i32` but converted to the `u64` key ID when arguments are parsed (with a
/// panicking `expect`), so it must stay non-negative.
pub const DEFAULT_PRIVATE_KEY_ID: i32 = 42;
#[fasync::run(10)]
async fn main() -> Result<(), anyhow::Error> {
let args: Args = argh::from_env();
let server = OmahaServerBuilder::default()
.responses_by_appid(args.responses_by_appid)
.private_keys(PrivateKeys {
latest: PrivateKeyAndId {
id: args.key_id,
key: std::fs::read_to_string(args.key_path)
.expect("read from key_path failed")
.parse()
.expect("failed to parse key"),
},
historical: vec![],
})
.require_cup(args.require_cup)
.build()
.expect("omaha server build");
let arc_server = Arc::new(Mutex::new(server));
let addr = SocketAddr::new(args.listen_on.into(), args.port);
let listener = TcpListener::bind(&addr).await.context("binding to addr")?;
println!("listening on {}", listener.local_addr()?);
let connections = listener.incoming().map_ok(ConnectionStream::Tcp);
let make_svc = make_service_fn(move |_socket| {
let arc_server = Arc::clone(&arc_server);
async move {
Ok::<_, Infallible>(service_fn(move |req| {
println!("received req: {req:?}");
let arc_server = Arc::clone(&arc_server);
async move { handle_request(req, &arc_server).await }
}))
}
});
Server::builder(from_stream(connections))
.executor(fuchsia_hyper::Executor)
.serve(make_svc)
.await
.context("error serving omaha server")
}