// blob: 6b0d45c906c348e5e5a517c3b18b1e9a5f599429
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef SRC_DEVELOPER_FORENSICS_UTILS_FIDL_HANGING_GET_PTR_H_
#define SRC_DEVELOPER_FORENSICS_UTILS_FIDL_HANGING_GET_PTR_H_
#include <lib/async/cpp/task.h>
#include <lib/async/dispatcher.h>
#include <lib/fidl/cpp/interface_ptr.h>
#include <lib/fit/function.h>
#include <lib/fpromise/promise.h>
#include <lib/sys/cpp/service_directory.h>
#include <lib/syslog/cpp/macros.h>
#include <functional>
#include "src/developer/forensics/utils/errors.h"
#include "src/developer/forensics/utils/fit/bridge_map.h"
#include "src/developer/forensics/utils/fit/timeout.h"
#include "src/lib/backoff/exponential_backoff.h"
namespace forensics {
namespace fidl {
// Wrapper around InterfacePtr<Interface> that automatically manages making new calls to a
// hanging-get protocol.
//
// For example, if we wished to fetch a device's update channel from
// fuchsia::update::channel::Provider assuming a hanging get pattern, then we would use
// HangingGetPtr as follows:
//
// class HangingGetChannelPtr {
// public:
// HangingGetChannelPtr(async_dispatcher_t* dispatcher, std::shared_ptr<sys::ServiceDirectory>
// services)
// : connection_(dispatcher, services, [this] { GetChannel(); }) {}
//
// ::fpromise::promise<std::string, Error> GetChannel(zx::duration timeout) {
// return connection_.GetValue(fit::Timeout(timeout));
// }
//
// private:
// void GetChannel() {
// connection_->GetCurrent([this](std::string channel) {
// if (!channel.empty()) {
// connection_.SetValue(channel);
// } else {
// connection_.SetError(Error::kMissingValue);
// }
// });
// }
//
// HangingGetPtr<fuchsia::update::channel::Provider, std::string> connection_;
// };
//
// This class is not thread safe.
template <typename Interface, typename V>
class HangingGetPtr {
 public:
  // Creates the wrapper and schedules the first hanging-get call on |dispatcher|.
  //
  // |dispatcher|  async loop on which the first call and reconnection attempts are posted; it is
  //               also handed to the pending-call bookkeeping used by GetValue().
  // |services|    directory used to connect to |Interface|.
  // |make_call|   invoked every time a new request should be issued on the connection. As in the
  //               usage example for this class, it is expected to report the outcome through
  //               SetValue() or SetError().
  HangingGetPtr(async_dispatcher_t* dispatcher, std::shared_ptr<sys::ServiceDirectory> services,
                std::function<void(void)> make_call)
      : dispatcher_(dispatcher),
        services_(std::move(services)),
        pending_calls_(dispatcher_),
        make_call_(std::move(make_call)),
        make_call_task_([this] {
          Connect();
          make_call_();
        }),
        make_call_backoff_(/*initial_delay=*/zx::msec(100), /*retry_factor=*/2u,
                           /*max_delay=*/zx::hour(1)) {
    // Post |make_call_| on the async loop with an immediate deadline in an attempt to
    // pre-cache |value_|. Because the class that owns the HangingGetPtr and supplies |make_call_|
    // may capture itself in the lambda, we need to ensure that the owning class is fully
    // initialized before |make_call_| is executed. Thus, we post |make_call_| on the async loop and
    // are guaranteed that any data initialized along side the HangingGetPtr is initialized before
    // |make_call_| is executed.
    //
    // This isn't safe if |dispatcher_| is on a different thread than |this|.
    //
    // A failure to post is only logged: going through SetError() would run |make_call_|
    // synchronously, which is exactly what posting the task is meant to avoid.
    if (const zx_status_t post_status = make_call_task_.Post(dispatcher_);
        post_status != ZX_OK) {
      FX_PLOGS(ERROR, post_status) << "Failed to post task to make call on async loop";
    }
  }

  // Caches |value|, completes every pending GetValue() call and issues the next request.
  void SetValue(V value) { UpdateValue(std::move(value)); }

  // Caches |error| in place of a value, completes every pending GetValue() call and issues the
  // next request.
  void SetError(Error error) { UpdateValue(error); }

  // Returns a promise for the cached value (or cached error).
  //
  // If something is already cached the promise is immediately ready. Otherwise the promise
  // completes once SetValue() / SetError() is called, or with an Error specific to this call
  // (e.g. |timeout| expired).
  ::fpromise::promise<V, Error> GetValue(fit::Timeout timeout) {
    if (value_) {
      return ::fpromise::make_result_promise(ValueToResult());
    }

    const uint64_t id = pending_calls_.NewBridgeForTask(Interface::Name_);

    // A call to GetValue() is only ever completed with an error due to circumstances that affect
    // only that call, i.e. the call times out or there is an issue posting the timeout task, so we
    // don't set an error for all pending GetValue() calls, i.e. we don't set |value_| with the
    // Error and instead only propagate the Error to the caller.
    return pending_calls_.WaitForDone(id, std::move(timeout))
        .then([id, this](const ::fpromise::result<void, Error>& result) {
          // The bridge is no longer needed regardless of how the wait ended.
          pending_calls_.Delete(id);
          return result;
        })
        .and_then([this]() { return ValueToResult(); })
        .or_else([](const Error& error) { return ::fpromise::error(error); });
  }

  // Gives |make_call_| access to the underlying connection in order to issue requests.
  Interface* operator->() { return connection_.get(); }

 private:
  // Converts the cached |value_| into a result. Logs at FATAL severity if nothing is cached.
  ::fpromise::result<V, Error> ValueToResult() const {
    if (!value_) {
      FX_LOGS(FATAL) << "Attempting to return a result when none has been cached";
    }

    if (value_->HasValue()) {
      return ::fpromise::ok(value_->Value());
    } else {
      return ::fpromise::error(value_->Error());
    }
  }

  // Connects to |Interface| if the connection is not bound and installs the error handler that
  // schedules a new connection attempt, delayed by the next backoff interval.
  void Connect() {
    if (!connection_.is_bound()) {
      connection_ = services_->Connect<Interface>();
    }

    connection_.set_error_handler([this](zx_status_t status) {
      FX_PLOGS(WARNING, status) << "Lost connection with " << Interface::Name_;

      if (const auto post_status =
              make_call_task_.PostDelayed(dispatcher_, make_call_backoff_.GetNext());
          post_status != ZX_OK) {
        FX_PLOGS(ERROR, post_status) << "Failed to post task to make call on async loop";
        SetError(Error::kAsyncTaskPostFailure);
      }
    });
  }

  // Stores |value| in the cache, completes all pending GetValue() calls, resets the reconnection
  // backoff and immediately issues the next request.
  void UpdateValue(ErrorOr<V> value) {
    value_ = std::make_unique<ErrorOr<V>>(std::move(value));
    pending_calls_.CompleteAllOk();
    make_call_backoff_.Reset();
    make_call_();
  }

  async_dispatcher_t* dispatcher_;
  const std::shared_ptr<sys::ServiceDirectory> services_;

  ::fidl::InterfacePtr<Interface> connection_;

  // Tracks the GetValue() calls waiting for a value to be cached.
  fit::BridgeMap<> pending_calls_;

  // The std::unique_ptr<> indicates whether a value is cached, the ErrorOr<> indicates
  // whether the cached value is a payload or an error.
  std::unique_ptr<ErrorOr<V>> value_;

  // User-supplied callback that issues a request on |connection_|.
  std::function<void(void)> make_call_;

  // Task that (re)connects and then runs |make_call_|; posted at construction and on connection
  // loss.
  async::TaskClosure make_call_task_;

  // Delay between reconnection attempts; reset whenever a value or error is cached.
  backoff::ExponentialBackoff make_call_backoff_;
};
} // namespace fidl
} // namespace forensics
#endif // SRC_DEVELOPER_FORENSICS_UTILS_FIDL_HANGING_GET_PTR_H_