blob: 5a0d3f61684b6834ce392a0230658846c8cd2c88 [file] [log] [blame]
/*
* Copyright (C) 2019 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define ANDROID_BASE_UNIQUE_FD_DISABLE_IMPLICIT_CONVERSION
#include "include/adbd_auth.h"
#include <inttypes.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/uio.h>
#include <chrono>
#include <deque>
#include <string>
#include <string_view>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>
#include <android-base/file.h>
#include <android-base/logging.h>
#include <android-base/macros.h>
#include <android-base/strings.h>
#include <android-base/thread_annotations.h>
#include <android-base/unique_fd.h>
#include <cutils/sockets.h>
using android::base::unique_fd;
static constexpr uint32_t kAuthVersion = 1;
struct AdbdAuthPacketAuthenticated {
std::string public_key;
};
struct AdbdAuthPacketDisconnected {
std::string public_key;
};
struct AdbdAuthPacketRequestAuthorization {
std::string public_key;
};
struct AdbdPacketTlsDeviceConnected {
uint8_t transport_type;
std::string public_key;
};
struct AdbdPacketTlsDeviceDisconnected {
uint8_t transport_type;
std::string public_key;
};
using AdbdAuthPacket = std::variant<AdbdAuthPacketAuthenticated,
AdbdAuthPacketDisconnected,
AdbdAuthPacketRequestAuthorization,
AdbdPacketTlsDeviceConnected,
AdbdPacketTlsDeviceDisconnected>;
struct AdbdAuthContext {
static constexpr uint64_t kEpollConstSocket = 0;
static constexpr uint64_t kEpollConstEventFd = 1;
static constexpr uint64_t kEpollConstFramework = 2;
public:
explicit AdbdAuthContext(AdbdAuthCallbacksV1* callbacks) : next_id_(0), callbacks_(*callbacks) {
InitFrameworkHandlers();
epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC));
if (epoll_fd_ == -1) {
PLOG(FATAL) << "failed to create epoll fd";
}
event_fd_.reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
if (event_fd_ == -1) {
PLOG(FATAL) << "failed to create eventfd";
}
sock_fd_.reset(android_get_control_socket("adbd"));
if (sock_fd_ == -1) {
PLOG(ERROR) << "failed to get adbd authentication socket";
} else {
if (fcntl(sock_fd_.get(), F_SETFD, FD_CLOEXEC) != 0) {
PLOG(FATAL) << "failed to make adbd authentication socket cloexec";
}
if (fcntl(sock_fd_.get(), F_SETFL, O_NONBLOCK) != 0) {
PLOG(FATAL) << "failed to make adbd authentication socket nonblocking";
}
if (listen(sock_fd_.get(), 4) != 0) {
PLOG(FATAL) << "failed to listen on adbd authentication socket";
}
}
}
AdbdAuthContext(const AdbdAuthContext& copy) = delete;
AdbdAuthContext(AdbdAuthContext&& move) = delete;
AdbdAuthContext& operator=(const AdbdAuthContext& copy) = delete;
AdbdAuthContext& operator=(AdbdAuthContext&& move) = delete;
uint64_t NextId() { return next_id_++; }
void DispatchPendingPrompt() REQUIRES(mutex_) {
if (dispatched_prompt_) {
LOG(INFO) << "adbd_auth: prompt currently pending, skipping";
return;
}
if (pending_prompts_.empty()) {
LOG(INFO) << "adbd_auth: no prompts to send";
return;
}
LOG(INFO) << "adbd_auth: prompting user for adb authentication";
auto [id, public_key, arg] = std::move(pending_prompts_.front());
pending_prompts_.pop_front();
this->output_queue_.emplace_back(
AdbdAuthPacketRequestAuthorization{.public_key = public_key});
Interrupt();
dispatched_prompt_ = std::make_tuple(id, public_key, arg);
}
void UpdateFrameworkWritable() REQUIRES(mutex_) {
// This might result in redundant calls to EPOLL_CTL_MOD if, for example, we get notified
// at the same time as a framework connection, but that's unlikely and this doesn't need to
// be fast anyway.
if (framework_fd_ != -1) {
struct epoll_event event;
event.events = EPOLLIN;
if (!output_queue_.empty()) {
LOG(INFO) << "marking framework writable";
event.events |= EPOLLOUT;
}
event.data.u64 = kEpollConstFramework;
CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_MOD, framework_fd_.get(), &event));
}
}
void ReplaceFrameworkFd(unique_fd new_fd) REQUIRES(mutex_) {
LOG(INFO) << "received new framework fd " << new_fd.get()
<< " (current = " << framework_fd_.get() << ")";
// If we already had a framework fd, clean up after ourselves.
if (framework_fd_ != -1) {
output_queue_.clear();
dispatched_prompt_.reset();
CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_DEL, framework_fd_.get(), nullptr));
framework_fd_.reset();
}
if (new_fd != -1) {
struct epoll_event event;
event.events = EPOLLIN;
if (!output_queue_.empty()) {
LOG(INFO) << "marking framework writable";
event.events |= EPOLLOUT;
}
event.data.u64 = kEpollConstFramework;
CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, new_fd.get(), &event));
framework_fd_ = std::move(new_fd);
}
}
void HandlePacket(std::string_view packet) EXCLUDES(mutex_) {
LOG(INFO) << "received packet: " << packet;
if (packet.size() < 2) {
LOG(ERROR) << "received packet of invalid length";
std::lock_guard<std::mutex> lock(mutex_);
ReplaceFrameworkFd(unique_fd());
}
bool handled_packet = false;
for (size_t i = 0; i < framework_handlers_.size(); ++i) {
if (android::base::ConsumePrefix(&packet, framework_handlers_[i].code)) {
framework_handlers_[i].cb(packet);
handled_packet = true;
break;
}
}
if (!handled_packet) {
LOG(ERROR) << "unhandled packet: " << packet;
std::lock_guard<std::mutex> lock(mutex_);
ReplaceFrameworkFd(unique_fd());
}
}
void AllowUsbDevice(std::string_view buf) EXCLUDES(mutex_) {
std::lock_guard<std::mutex> lock(mutex_);
CHECK(buf.empty());
CHECK(dispatched_prompt_.has_value());
auto& [id, key, arg] = *dispatched_prompt_;
keys_.emplace(id, std::move(key));
callbacks_.key_authorized(arg, id);
dispatched_prompt_ = std::nullopt;
// We need to dispatch pending prompts here upon success as well,
// since we might have multiple queued prompts.
DispatchPendingPrompt();
}
void DenyUsbDevice(std::string_view buf) EXCLUDES(mutex_) {
std::lock_guard<std::mutex> lock(mutex_);
CHECK(buf.empty());
// TODO: Do we want a callback if the key is denied?
dispatched_prompt_ = std::nullopt;
DispatchPendingPrompt();
}
void KeyRemoved(std::string_view buf) EXCLUDES(mutex_) {
CHECK(!buf.empty());
callbacks_.key_removed(buf.data(), buf.size());
}
bool SendPacket() REQUIRES(mutex_) {
if (output_queue_.empty()) {
return false;
}
CHECK_NE(-1, framework_fd_.get());
auto& packet = output_queue_.front();
struct iovec iovs[3];
int iovcnt = 2;
if (auto* p = std::get_if<AdbdAuthPacketAuthenticated>(&packet)) {
iovs[0].iov_base = const_cast<char*>("CK");
iovs[0].iov_len = 2;
iovs[1].iov_base = p->public_key.data();
iovs[1].iov_len = p->public_key.size();
} else if (auto* p = std::get_if<AdbdAuthPacketDisconnected>(&packet)) {
iovs[0].iov_base = const_cast<char*>("DC");
iovs[0].iov_len = 2;
iovs[1].iov_base = p->public_key.data();
iovs[1].iov_len = p->public_key.size();
} else if (auto* p = std::get_if<AdbdAuthPacketRequestAuthorization>(&packet)) {
iovs[0].iov_base = const_cast<char*>("PK");
iovs[0].iov_len = 2;
iovs[1].iov_base = p->public_key.data();
iovs[1].iov_len = p->public_key.size();
} else if (auto* p = std::get_if<AdbdPacketTlsDeviceConnected>(&packet)) {
iovcnt = 3;
iovs[0].iov_base = const_cast<char*>("WE");
iovs[0].iov_len = 2;
iovs[1].iov_base = &p->transport_type;
iovs[1].iov_len = 1;
iovs[2].iov_base = p->public_key.data();
iovs[2].iov_len = p->public_key.size();
} else if (auto* p = std::get_if<AdbdPacketTlsDeviceDisconnected>(&packet)) {
iovcnt = 3;
iovs[0].iov_base = const_cast<char*>("WF");
iovs[0].iov_len = 2;
iovs[1].iov_base = &p->transport_type;
iovs[1].iov_len = 1;
iovs[2].iov_base = p->public_key.data();
iovs[2].iov_len = p->public_key.size();
} else {
LOG(FATAL) << "unhandled packet type?";
}
output_queue_.pop_front();
ssize_t rc = writev(framework_fd_.get(), iovs, iovcnt);
if (rc == -1 && errno != EAGAIN && errno != EWOULDBLOCK) {
PLOG(ERROR) << "failed to write to framework fd";
ReplaceFrameworkFd(unique_fd());
return false;
}
return true;
}
void Run() {
if (sock_fd_ == -1) {
LOG(ERROR) << "adbd authentication socket unavailable, disabling user prompts";
} else {
struct epoll_event event;
event.events = EPOLLIN;
event.data.u64 = kEpollConstSocket;
CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, sock_fd_.get(), &event));
}
{
struct epoll_event event;
event.events = EPOLLIN;
event.data.u64 = kEpollConstEventFd;
CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, event_fd_.get(), &event));
}
while (true) {
struct epoll_event events[3];
int rc = TEMP_FAILURE_RETRY(epoll_wait(epoll_fd_.get(), events, 3, -1));
if (rc == -1) {
PLOG(FATAL) << "epoll_wait failed";
} else if (rc == 0) {
LOG(FATAL) << "epoll_wait returned 0";
}
bool restart = false;
for (int i = 0; i < rc; ++i) {
if (restart) {
break;
}
struct epoll_event& event = events[i];
switch (event.data.u64) {
case kEpollConstSocket: {
unique_fd new_framework_fd(accept4(sock_fd_.get(), nullptr, nullptr,
SOCK_CLOEXEC | SOCK_NONBLOCK));
if (new_framework_fd == -1) {
PLOG(FATAL) << "failed to accept framework fd";
}
LOG(INFO) << "adbd_auth: received a new framework connection";
std::lock_guard<std::mutex> lock(mutex_);
ReplaceFrameworkFd(std::move(new_framework_fd));
// Stop iterating over events: one of the later ones might be the old
// framework fd.
restart = false;
break;
}
case kEpollConstEventFd: {
// We were woken up to write something.
uint64_t dummy;
int rc = TEMP_FAILURE_RETRY(read(event_fd_.get(), &dummy, sizeof(dummy)));
if (rc != 8) {
PLOG(FATAL) << "failed to read from eventfd (rc = " << rc << ")";
}
std::lock_guard<std::mutex> lock(mutex_);
UpdateFrameworkWritable();
break;
}
case kEpollConstFramework: {
char buf[4096];
if (event.events & EPOLLIN) {
int rc = TEMP_FAILURE_RETRY(read(framework_fd_.get(), buf, sizeof(buf)));
if (rc == -1) {
LOG(FATAL) << "failed to read from framework fd";
} else if (rc == 0) {
LOG(INFO) << "hit EOF on framework fd";
std::lock_guard<std::mutex> lock(mutex_);
ReplaceFrameworkFd(unique_fd());
} else {
HandlePacket(std::string_view(buf, rc));
}
}
if (event.events & EPOLLOUT) {
std::lock_guard<std::mutex> lock(mutex_);
while (SendPacket()) {
continue;
}
UpdateFrameworkWritable();
}
break;
}
}
}
}
}
static constexpr const char* key_paths[] = {"/adb_keys", "/data/misc/adb/adb_keys"};
void IteratePublicKeys(bool (*callback)(void*, const char*, size_t), void* opaque) {
for (const auto& path : key_paths) {
if (access(path, R_OK) == 0) {
LOG(INFO) << "Loading keys from " << path;
std::string content;
if (!android::base::ReadFileToString(path, &content)) {
PLOG(ERROR) << "Couldn't read " << path;
continue;
}
for (const auto& line : android::base::Split(content, "\n")) {
if (!callback(opaque, line.data(), line.size())) {
return;
}
}
}
}
}
uint64_t PromptUser(std::string_view public_key, void* arg) EXCLUDES(mutex_) {
uint64_t id = NextId();
std::lock_guard<std::mutex> lock(mutex_);
pending_prompts_.emplace_back(id, public_key, arg);
DispatchPendingPrompt();
return id;
}
uint64_t NotifyAuthenticated(std::string_view public_key) EXCLUDES(mutex_) {
uint64_t id = NextId();
std::lock_guard<std::mutex> lock(mutex_);
keys_.emplace(id, public_key);
output_queue_.emplace_back(
AdbdAuthPacketAuthenticated{.public_key = std::string(public_key)});
return id;
}
void NotifyDisconnected(uint64_t id) EXCLUDES(mutex_) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = keys_.find(id);
if (it == keys_.end()) {
LOG(DEBUG) << "couldn't find public key to notify disconnection, skipping";
return;
}
output_queue_.emplace_back(AdbdAuthPacketDisconnected{.public_key = std::move(it->second)});
keys_.erase(it);
}
uint64_t NotifyTlsDeviceConnected(AdbTransportType type,
std::string_view public_key) EXCLUDES(mutex_) {
uint64_t id = NextId();
std::lock_guard<std::mutex> lock(mutex_);
keys_.emplace(id, public_key);
output_queue_.emplace_back(AdbdPacketTlsDeviceConnected{
.transport_type = static_cast<uint8_t>(type),
.public_key = std::string(public_key)});
Interrupt();
return id;
}
void NotifyTlsDeviceDisconnected(AdbTransportType type, uint64_t id) EXCLUDES(mutex_) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = keys_.find(id);
if (it == keys_.end()) {
LOG(DEBUG) << "couldn't find public key to notify disconnection of tls device, skipping";
return;
}
output_queue_.emplace_back(AdbdPacketTlsDeviceDisconnected{
.transport_type = static_cast<uint8_t>(type),
.public_key = std::move(it->second)});
keys_.erase(it);
Interrupt();
}
// Interrupt the worker thread to do some work.
void Interrupt() {
uint64_t value = 1;
ssize_t rc = write(event_fd_.get(), &value, sizeof(value));
if (rc == -1) {
PLOG(FATAL) << "write to eventfd failed";
} else if (rc != sizeof(value)) {
LOG(FATAL) << "write to eventfd returned short (" << rc << ")";
}
}
void InitFrameworkHandlers() {
// Framework wants to disconnect from a secured wifi device
framework_handlers_.emplace_back(
FrameworkPktHandler{
.code = "DD",
.cb = std::bind(&AdbdAuthContext::KeyRemoved, this, std::placeholders::_1)});
// Framework allows USB debugging for the device
framework_handlers_.emplace_back(
FrameworkPktHandler{
.code = "OK",
.cb = std::bind(&AdbdAuthContext::AllowUsbDevice, this, std::placeholders::_1)});
// Framework denies USB debugging for the device
framework_handlers_.emplace_back(
FrameworkPktHandler{
.code = "NO",
.cb = std::bind(&AdbdAuthContext::DenyUsbDevice, this, std::placeholders::_1)});
}
unique_fd epoll_fd_;
unique_fd event_fd_;
unique_fd sock_fd_;
unique_fd framework_fd_;
std::atomic<uint64_t> next_id_;
AdbdAuthCallbacksV1 callbacks_;
std::mutex mutex_;
std::unordered_map<uint64_t, std::string> keys_ GUARDED_BY(mutex_);
// We keep two separate queues: one to handle backpressure from the socket (output_queue_)
// and one to make sure we only dispatch one authrequest at a time (pending_prompts_).
std::deque<AdbdAuthPacket> output_queue_ GUARDED_BY(mutex_);
std::optional<std::tuple<uint64_t, std::string, void*>> dispatched_prompt_ GUARDED_BY(mutex_);
std::deque<std::tuple<uint64_t, std::string, void*>> pending_prompts_ GUARDED_BY(mutex_);
// This is a list of commands that the framework could send to us.
using FrameworkHandlerCb = std::function<void(std::string_view)>;
struct FrameworkPktHandler {
const char* code;
FrameworkHandlerCb cb;
};
std::vector<FrameworkPktHandler> framework_handlers_;
};
AdbdAuthContext* adbd_auth_new(AdbdAuthCallbacks* callbacks) {
if (callbacks->version == 1) {
return new AdbdAuthContext(reinterpret_cast<AdbdAuthCallbacksV1*>(callbacks));
} else {
LOG(ERROR) << "received unknown AdbdAuthCallbacks version " << callbacks->version;
return nullptr;
}
}
void adbd_auth_delete(AdbdAuthContext* ctx) {
delete ctx;
}
void adbd_auth_run(AdbdAuthContext* ctx) {
return ctx->Run();
}
void adbd_auth_get_public_keys(AdbdAuthContext* ctx,
bool (*callback)(void* opaque, const char* public_key, size_t len),
void* opaque) {
ctx->IteratePublicKeys(callback, opaque);
}
uint64_t adbd_auth_notify_auth(AdbdAuthContext* ctx, const char* public_key, size_t len) {
return ctx->NotifyAuthenticated(std::string_view(public_key, len));
}
void adbd_auth_notify_disconnect(AdbdAuthContext* ctx, uint64_t id) {
return ctx->NotifyDisconnected(id);
}
void adbd_auth_prompt_user(AdbdAuthContext* ctx, const char* public_key, size_t len,
void* opaque) {
ctx->PromptUser(std::string_view(public_key, len), opaque);
}
uint64_t adbd_auth_tls_device_connected(AdbdAuthContext* ctx,
AdbTransportType type,
const char* public_key,
size_t len) {
return ctx->NotifyTlsDeviceConnected(type, std::string_view(public_key, len));
}
void adbd_auth_tls_device_disconnected(AdbdAuthContext* ctx,
AdbTransportType type,
uint64_t id) {
ctx->NotifyTlsDeviceDisconnected(type, id);
}
uint32_t adbd_auth_get_max_version() {
return kAuthVersion;
}
bool adbd_auth_supports_feature(AdbdAuthFeature f) {
UNUSED(f);
return false;
}