adb: Add a way to reconnect TCP transports
This change adds a reconnect handler that tracks all TCP transports that
were connected at some point, but became disconnected. The handler then
attempts to reconnect each of them every 10s for up to a minute.
Bug: 74411879
Test: system/core/adb/test_adb.py
Test: adb connect chromebook:22 # This runs with sslh
Test: CtsBootStatsTestCases
Test: emulator -show-kernel ; adb -s emulator-5554 shell
Change-Id: I7b9f6d181b71ccf5c26ff96c45d36aaf6409b992
diff --git a/adb/client/main.cpp b/adb/client/main.cpp
index 31cb853..44ed3a2 100644
--- a/adb/client/main.cpp
+++ b/adb/client/main.cpp
@@ -117,6 +117,7 @@
atexit(adb_server_cleanup);
init_transport_registration();
+ init_reconnect_handler();
init_mdns_transport_discovery();
usb_init();
diff --git a/adb/test_adb.py b/adb/test_adb.py
index 32bf029..ce4d4ec 100644
--- a/adb/test_adb.py
+++ b/adb/test_adb.py
@@ -75,9 +75,11 @@
else:
# Client socket
data = r.recv(1024)
- if not data:
+ if not data or data.startswith('OPEN'):
if r in cnxn_sent:
del cnxn_sent[r]
+ r.shutdown(socket.SHUT_RDWR)
+ r.close()
rlist.remove(r)
continue
if r in cnxn_sent:
@@ -97,6 +99,25 @@
server_thread.join()
+@contextlib.contextmanager
+def adb_connect(unittest, serial):
+ """Context manager for an ADB connection.
+
+ Runs `adb connect <serial>` on entry and asserts that it printed
+ 'connected to <serial>'.
+
+ This automatically disconnects when done with the connection.
+
+ Args:
+ unittest: Test case object; its assertEqual() is used to check the
+ output of `adb connect`.
+ serial: Serial passed to `adb connect` and `adb disconnect`.
+
+ Raises:
+ subprocess.CalledProcessError: if `adb connect` exits with a non-zero
+ status (raised by subprocess.check_output).
+ """
+
+ # NOTE(review): the `unittest` parameter shadows the `unittest` module
+ # inside this function.
+ output = subprocess.check_output(['adb', 'connect', serial])
+ unittest.assertEqual(output.strip(), 'connected to {}'.format(serial))
+
+ try:
+ yield
+ finally:
+ # Perform best-effort disconnection. Discard the output.
+ # Popen (unlike check_output) does not raise on a non-zero exit status,
+ # so a failed disconnect cannot mask an error from the `with` body.
+ p = subprocess.Popen(['adb', 'disconnect', serial],
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ p.communicate()
+
+
class NonApiTest(unittest.TestCase):
"""Tests for ADB that aren't a part of the AndroidDevice API."""
@@ -278,29 +299,60 @@
for protocol in (socket.AF_INET, socket.AF_INET6):
try:
with fake_adb_server(protocol=protocol) as port:
- output = subprocess.check_output(
- ['adb', 'connect', 'localhost:{}'.format(port)])
-
- self.assertEqual(
- output.strip(), 'connected to localhost:{}'.format(port))
+ serial = 'localhost:{}'.format(port)
+ with adb_connect(self, serial):
+ pass
except socket.error:
print("IPv6 not available, skipping")
continue
def test_already_connected(self):
+ """Ensure that an already-connected device stays connected."""
+
with fake_adb_server() as port:
- output = subprocess.check_output(
- ['adb', 'connect', 'localhost:{}'.format(port)])
+ serial = 'localhost:{}'.format(port)
+ with adb_connect(self, serial):
+ # b/31250450: this always returns 0 but probably shouldn't.
+ output = subprocess.check_output(['adb', 'connect', serial])
+ self.assertEqual(
+ output.strip(), 'already connected to {}'.format(serial))
- self.assertEqual(
- output.strip(), 'connected to localhost:{}'.format(port))
+ def test_reconnect(self):
+ """Ensure that a disconnected device reconnects."""
- # b/31250450: this always returns 0 but probably shouldn't.
- output = subprocess.check_output(
- ['adb', 'connect', 'localhost:{}'.format(port)])
+ with fake_adb_server() as port:
+ serial = 'localhost:{}'.format(port)
+ with adb_connect(self, serial):
+ output = subprocess.check_output(['adb', '-s', serial,
+ 'get-state'])
+ self.assertEqual(output.strip(), 'device')
- self.assertEqual(
- output.strip(), 'already connected to localhost:{}'.format(port))
+ # This will fail.
+ p = subprocess.Popen(['adb', '-s', serial, 'shell', 'true'],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT)
+ output, _ = p.communicate()
+ self.assertEqual(output.strip(), 'error: closed')
+
+ subprocess.check_call(['adb', '-s', serial, 'wait-for-device'])
+
+ output = subprocess.check_output(['adb', '-s', serial,
+ 'get-state'])
+ self.assertEqual(output.strip(), 'device')
+
+ # Once we explicitly kick a device, it won't attempt to
+ # reconnect.
+ output = subprocess.check_output(['adb', 'disconnect', serial])
+ self.assertEqual(
+ output.strip(), 'disconnected {}'.format(serial))
+ try:
+ subprocess.check_output(['adb', '-s', serial, 'get-state'],
+ stderr=subprocess.STDOUT)
+ self.fail('Device should not be available')
+ except subprocess.CalledProcessError as e:
+ self.assertEqual(
+ e.output.strip(),
+ 'error: device \'{}\' not found'.format(serial))
def main():
random.seed(0)
diff --git a/adb/transport.cpp b/adb/transport.cpp
index be7f8fe..beec13a 100644
--- a/adb/transport.cpp
+++ b/adb/transport.cpp
@@ -33,6 +33,7 @@
#include <deque>
#include <list>
#include <mutex>
+#include <queue>
#include <thread>
#include <android-base/logging.h>
@@ -50,7 +51,9 @@
#include "adb_utils.h"
#include "fdevent.h"
-static void transport_unref(atransport *t);
+static void register_transport(atransport* transport);
+static void remove_transport(atransport* transport);
+static void transport_unref(atransport* transport);
// TODO: unordered_map<TransportId, atransport*>
static auto& transport_list = *new std::list<atransport*>();
@@ -77,6 +80,130 @@
~ScopedAssumeLocked() RELEASE() {}
};
+// Tracks and handles atransport*s that are attempting reconnection.
+//
+// Transports are queued with TrackTransport() and retried from a dedicated
+// thread (Run()). An attempt that fails is re-queued until its attempts_left
+// runs out.
+class ReconnectHandler {
+ public:
+ ReconnectHandler() = default;
+ ~ReconnectHandler() = default;
+
+ // Starts the ReconnectHandler thread.
+ void Start();
+
+ // Requests the ReconnectHandler thread to stop, waits for it to exit, and
+ // passes every transport still in the queue to remove_transport().
+ void Stop();
+
+ // Adds the atransport* to the queue of reconnect attempts. Does nothing
+ // once Stop() has been called.
+ void TrackTransport(atransport* transport);
+
+ private:
+ // The main thread loop.
+ void Run();
+
+ // Tracks a reconnection attempt.
+ struct ReconnectAttempt {
+ // The transport to reconnect.
+ atransport* transport;
+ // Run() waits until this point in time before making the attempt when
+ // the entry is at the front of a non-empty queue.
+ std::chrono::system_clock::time_point deadline;
+ // Number of further attempts allowed if this one fails; at 0 a failed
+ // attempt makes Run() give up on the transport.
+ size_t attempts_left;
+ };
+
+ // Only retry for up to one minute: kMaxAttempts retries spaced
+ // kDefaultTimeout apart.
+ static constexpr const std::chrono::seconds kDefaultTimeout = std::chrono::seconds(10);
+ static constexpr const size_t kMaxAttempts = 6;
+
+ // Protects running_ and reconnect_queue_.
+ std::mutex reconnect_mutex_;
+ // Cleared by Stop(); Run() returns once it observes false.
+ bool running_ GUARDED_BY(reconnect_mutex_) = true;
+ // The thread executing Run(); created by Start() and joined by Stop().
+ std::thread handler_thread_;
+ // Signalled when a transport is queued or when Stop() is called.
+ std::condition_variable reconnect_cv_;
+ // Pending attempts in FIFO order.
+ std::queue<ReconnectAttempt> reconnect_queue_ GUARDED_BY(reconnect_mutex_);
+
+ DISALLOW_COPY_AND_ASSIGN(ReconnectHandler);
+};
+
+// Spawns the handler thread, which executes Run() on this object.
+void ReconnectHandler::Start() {
+    check_main_thread();
+    handler_thread_ = std::thread([this]() { Run(); });
+}
+
+// Asks the handler thread to exit, waits for it to finish, and then hands
+// every transport still waiting in the queue to remove_transport().
+void ReconnectHandler::Stop() {
+    check_main_thread();
+    {
+        std::lock_guard<std::mutex> lock(reconnect_mutex_);
+        running_ = false;
+    }
+    reconnect_cv_.notify_one();
+    // Joining a thread that was never started (Start() not called) or that was
+    // already joined (Stop() called twice) is an error, so only join when there
+    // is actually a thread to wait for.
+    if (handler_thread_.joinable()) {
+        handler_thread_.join();
+    }
+
+    // Drain the queue to free all resources.
+    std::lock_guard<std::mutex> lock(reconnect_mutex_);
+    while (!reconnect_queue_.empty()) {
+        ReconnectAttempt attempt = reconnect_queue_.front();
+        reconnect_queue_.pop();
+        remove_transport(attempt.transport);
+    }
+}
+
+// Queues |transport| for reconnection, due kDefaultTimeout from now with
+// kMaxAttempts attempts left, and wakes the handler thread. Nothing is queued
+// (and nobody is woken) once running_ has been cleared.
+void ReconnectHandler::TrackTransport(atransport* transport) {
+    check_main_thread();
+
+    bool queued = false;
+    {
+        std::lock_guard<std::mutex> guard(reconnect_mutex_);
+        if (running_) {
+            ReconnectAttempt attempt;
+            attempt.transport = transport;
+            attempt.deadline = std::chrono::system_clock::now() + kDefaultTimeout;
+            attempt.attempts_left = kMaxAttempts;
+            reconnect_queue_.push(attempt);
+            queued = true;
+        }
+    }
+
+    if (queued) {
+        reconnect_cv_.notify_one();
+    }
+}
+
+// Body of the handler thread.
+//
+// Each iteration takes the attempt at the front of the queue and tries to
+// reconnect its transport:
+//   - a kicked transport is handed to remove_transport();
+//   - on success the transport is handed to register_transport();
+//   - on failure it is re-queued kDefaultTimeout later, or handed to
+//     remove_transport() once attempts_left has reached 0.
+// Returns once Stop() has cleared running_.
+void ReconnectHandler::Run() {
+    while (true) {
+        ReconnectAttempt attempt;
+        {
+            std::unique_lock<std::mutex> lock(reconnect_mutex_);
+            ScopedAssumeLocked assume_lock(reconnect_mutex_);
+
+            if (reconnect_queue_.empty()) {
+                // Nothing to time out on: sleep until a transport is queued or
+                // Stop() is called. An untimed wait is used instead of
+                // wait_until(time_point::max()) so that an early timeout can
+                // never leave us calling front() on an empty queue below.
+                reconnect_cv_.wait(lock, [&]() REQUIRES(reconnect_mutex_) {
+                    return !running_ || !reconnect_queue_.empty();
+                });
+            } else {
+                // Only this thread pops and TrackTransport() appends at the
+                // back, so the front entry cannot change while we sleep; the
+                // only reason to wake up before its deadline is Stop().
+                const auto deadline = reconnect_queue_.front().deadline;
+                reconnect_cv_.wait_until(lock, deadline,
+                                         [&]() REQUIRES(reconnect_mutex_) { return !running_; });
+            }
+
+            if (!running_) return;
+            // Both waits above only return with running_ set when the queue is
+            // non-empty.
+            attempt = reconnect_queue_.front();
+            reconnect_queue_.pop();
+            if (attempt.transport->kicked()) {
+                D("transport %s was kicked. giving up on it.", attempt.transport->serial);
+                remove_transport(attempt.transport);
+                continue;
+            }
+        }
+        // The reconnect attempt itself runs without reconnect_mutex_ held.
+        D("attempting to reconnect %s", attempt.transport->serial);
+
+        if (!attempt.transport->Reconnect()) {
+            D("attempting to reconnect %s failed.", attempt.transport->serial);
+            if (attempt.attempts_left == 0) {
+                D("transport %s exceeded the number of retry attempts. giving up on it.",
+                  attempt.transport->serial);
+                remove_transport(attempt.transport);
+                continue;
+            }
+
+            std::lock_guard<std::mutex> lock(reconnect_mutex_);
+            reconnect_queue_.emplace(ReconnectAttempt{
+                attempt.transport,
+                std::chrono::system_clock::now() + ReconnectHandler::kDefaultTimeout,
+                attempt.attempts_left - 1});
+            continue;
+        }
+
+        D("reconnection to %s succeeded.", attempt.transport->serial);
+        register_transport(attempt.transport);
+    }
+}
+
+static auto& reconnect_handler = *new ReconnectHandler();
+
} // namespace
TransportId NextTransportId() {
@@ -477,8 +604,6 @@
return 0;
}
-static void remove_transport(atransport*);
-
static void transport_registration_func(int _fd, unsigned ev, void*) {
tmsg m;
atransport* t;
@@ -515,8 +640,9 @@
/* don't create transport threads for inaccessible devices */
if (t->GetConnectionState() != kCsNoPerm) {
- /* initial references are the two threads */
- t->ref_count = 1;
+ // The connection gets a reference to the atransport. It will release it
+ // upon a read/write error.
+ t->ref_count++;
t->connection()->SetTransportName(t->serial_name());
t->connection()->SetReadCallback([t](Connection*, std::unique_ptr<apacket> p) {
if (!check_header(p.get(), t)) {
@@ -547,13 +673,20 @@
{
std::lock_guard<std::recursive_mutex> lock(transport_lock);
- pending_list.remove(t);
- transport_list.push_front(t);
+ auto it = std::find(pending_list.begin(), pending_list.end(), t);
+ if (it != pending_list.end()) {
+ pending_list.remove(t);
+ transport_list.push_front(t);
+ }
}
update_transports();
}
+// Starts the thread of the file-level reconnect_handler by calling its
+// Start() method.
+void init_reconnect_handler(void) {
+ reconnect_handler.Start();
+}
+
void init_transport_registration(void) {
int s[2];
@@ -572,6 +705,7 @@
}
void kick_all_transports() {
+ reconnect_handler.Stop();
// To avoid only writing part of a packet to a transport after exit, kick all transports.
std::lock_guard<std::recursive_mutex> lock(transport_lock);
for (auto t : transport_list) {
@@ -601,15 +735,21 @@
}
static void transport_unref(atransport* t) {
+ check_main_thread();
CHECK(t != nullptr);
std::lock_guard<std::recursive_mutex> lock(transport_lock);
CHECK_GT(t->ref_count, 0u);
t->ref_count--;
if (t->ref_count == 0) {
- D("transport: %s unref (kicking and closing)", t->serial);
t->connection()->Stop();
- remove_transport(t);
+ if (t->IsTcpDevice() && !t->kicked()) {
+ D("transport: %s unref (attempting reconnection) %d", t->serial, t->kicked());
+ reconnect_handler.TrackTransport(t);
+ } else {
+ D("transport: %s unref (kicking and closing)", t->serial);
+ remove_transport(t);
+ }
} else {
D("transport: %s unref (count=%zu)", t->serial, t->ref_count);
}
@@ -781,9 +921,8 @@
}
void atransport::Kick() {
- if (!kicked_) {
- D("kicking transport %s", this->serial);
- kicked_ = true;
+ if (!kicked_.exchange(true)) {
+ D("kicking transport %p %s", this, this->serial);
this->connection()->Stop();
}
}
@@ -941,6 +1080,10 @@
connection_waitable_->SetConnectionEstablished(success);
}
+// Invokes the stored reconnect_ callback with this transport and returns the
+// callback's result.
+bool atransport::Reconnect() {
+ return reconnect_(this);
+}
+
#if ADB_HOST
// We use newline as our delimiter, make sure to never output it.
@@ -1021,8 +1164,9 @@
}
#endif // ADB_HOST
-int register_socket_transport(int s, const char* serial, int port, int local) {
- atransport* t = new atransport();
+int register_socket_transport(int s, const char* serial, int port, int local,
+ atransport::ReconnectCallback reconnect) {
+ atransport* t = new atransport(std::move(reconnect), kCsOffline);
if (!serial) {
char buf[32];
@@ -1103,7 +1247,7 @@
void register_usb_transport(usb_handle* usb, const char* serial, const char* devpath,
unsigned writeable) {
- atransport* t = new atransport((writeable ? kCsConnecting : kCsNoPerm));
+ atransport* t = new atransport(writeable ? kCsOffline : kCsNoPerm);
D("transport: %p init'ing for usb_handle %p (sn='%s')", t, usb, serial ? serial : "");
init_usb_transport(t, usb);
diff --git a/adb/transport.h b/adb/transport.h
index e1cbc09..ae9cc02 100644
--- a/adb/transport.h
+++ b/adb/transport.h
@@ -198,20 +198,27 @@
// class in one go is a very large change. Given how bad our testing is,
// it's better to do this piece by piece.
- atransport(ConnectionState state = kCsConnecting)
+ using ReconnectCallback = std::function<bool(atransport*)>;
+
+ atransport(ReconnectCallback reconnect, ConnectionState state)
: id(NextTransportId()),
+ kicked_(false),
connection_state_(state),
connection_waitable_(std::make_shared<ConnectionWaitable>()),
- connection_(nullptr) {
+ connection_(nullptr),
+ reconnect_(std::move(reconnect)) {
// Initialize protocol to min version for compatibility with older versions.
// Version will be updated post-connect.
protocol_version = A_VERSION_MIN;
max_payload = MAX_PAYLOAD;
}
+ atransport(ConnectionState state = kCsOffline)
+ : atransport([](atransport*) { return false; }, state) {}
virtual ~atransport();
int Write(apacket* p);
void Kick();
+ bool kicked() const { return kicked_; }
// ConnectionState can be read by all threads, but can only be written in the main thread.
ConnectionState GetConnectionState() const;
@@ -286,8 +293,12 @@
// Gets a shared reference to the ConnectionWaitable.
std::shared_ptr<ConnectionWaitable> connection_waitable() { return connection_waitable_; }
+ // Attempts to reconnect with the underlying Connection. Returns true if the
+ // reconnection attempt succeeded.
+ bool Reconnect();
+
private:
- bool kicked_ = false;
+ std::atomic<bool> kicked_;
// A set of features transmitted in the banner with the initial connection.
// This is stored in the banner as 'features=feature0,feature1,etc'.
@@ -310,6 +321,9 @@
// The underlying connection object.
std::shared_ptr<Connection> connection_ GUARDED_BY(mutex_);
+ // A callback that will be invoked when the atransport needs to reconnect.
+ ReconnectCallback reconnect_;
+
std::mutex mutex_;
DISALLOW_COPY_AND_ASSIGN(atransport);
@@ -333,6 +347,7 @@
// Stops iteration and returns false if fn returns false, otherwise returns true.
bool iterate_transports(std::function<bool(const atransport*)> fn);
+void init_reconnect_handler(void);
void init_transport_registration(void);
void init_mdns_transport_discovery(void);
std::string list_transports(bool long_listing);
@@ -347,7 +362,8 @@
void connect_device(const std::string& address, std::string* response);
/* cause new transports to be init'd and added to the list */
-int register_socket_transport(int s, const char* serial, int port, int local);
+int register_socket_transport(int s, const char* serial, int port, int local,
+ atransport::ReconnectCallback reconnect);
// This should only be used for transports with connection_state == kCsNoPerm.
void unregister_usb_transport(usb_handle* usb);
diff --git a/adb/transport_local.cpp b/adb/transport_local.cpp
index e81f27c..181d666 100644
--- a/adb/transport_local.cpp
+++ b/adb/transport_local.cpp
@@ -68,28 +68,24 @@
return local_connect_arbitrary_ports(port - 1, port, &dummy) == 0;
}
-void connect_device(const std::string& address, std::string* response) {
- if (address.empty()) {
- *response = "empty address";
- return;
- }
-
+std::tuple<unique_fd, int, std::string> tcp_connect(const std::string& address,
+ std::string* response) {
std::string serial;
std::string host;
int port = DEFAULT_ADB_LOCAL_TRANSPORT_PORT;
if (!android::base::ParseNetAddress(address, &host, &port, &serial, response)) {
- return;
+ return std::make_tuple(unique_fd(), port, serial);
}
std::string error;
- int fd = network_connect(host.c_str(), port, SOCK_STREAM, 10, &error);
+ unique_fd fd(network_connect(host.c_str(), port, SOCK_STREAM, 10, &error));
if (fd == -1) {
*response = android::base::StringPrintf("unable to connect to %s: %s",
serial.c_str(), error.c_str());
- return;
+ return std::make_tuple(std::move(fd), port, serial);
}
- D("client: connected %s remote on fd %d", serial.c_str(), fd);
+ D("client: connected %s remote on fd %d", serial.c_str(), fd.get());
close_on_exec(fd);
disable_tcp_nagle(fd);
@@ -98,7 +94,38 @@
D("warning: failed to configure TCP keepalives (%s)", strerror(errno));
}
- int ret = register_socket_transport(fd, serial.c_str(), port, 0);
+ return std::make_tuple(std::move(fd), port, serial);
+}
+
+// Connects to the device at |address| over TCP and registers a socket
+// transport for it. The transport is given a reconnect callback that dials
+// |address| again and re-initializes the existing atransport with the new
+// socket. A human-readable outcome is written to |response|.
+void connect_device(const std::string& address, std::string* response) {
+    if (address.empty()) {
+        *response = "empty address";
+        return;
+    }
+
+    unique_fd fd;
+    int port;
+    std::string serial;
+    std::tie(fd, port, serial) = tcp_connect(address, response);
+    if (fd == -1) {
+        // Parsing the address or connecting failed. tcp_connect() was handed
+        // |response| to describe the failure, so keep that message and do not
+        // register a transport for an invalid descriptor.
+        return;
+    }
+
+    auto reconnect = [address](atransport* t) {
+        std::string response;
+        unique_fd fd;
+        int port;
+        std::string serial;
+        std::tie(fd, port, serial) = tcp_connect(address, &response);
+        if (fd == -1) {
+            D("reconnect failed: %s", response.c_str());
+            return false;
+        }
+
+        // This invokes the part of register_socket_transport() that needs to be
+        // invoked if the atransport* has already been setup. This eventually
+        // calls atransport->SetConnection() with a newly created Connection*
+        // that will in turn send the CNXN packet.
+        return init_socket_transport(t, fd.release(), port, 0) >= 0;
+    };
+
+    int ret = register_socket_transport(fd.release(), serial.c_str(), port, 0,
+                                        std::move(reconnect));
if (ret < 0) {
adb_close(fd);
if (ret == -EALREADY) {
@@ -135,7 +162,8 @@
close_on_exec(fd);
disable_tcp_nagle(fd);
std::string serial = getEmulatorSerialString(console_port);
- if (register_socket_transport(fd, serial.c_str(), adb_port, 1) == 0) {
+ if (register_socket_transport(fd, serial.c_str(), adb_port, 1,
+ [](atransport*) { return false; }) == 0) {
return 0;
}
adb_close(fd);
@@ -239,7 +267,8 @@
close_on_exec(fd);
disable_tcp_nagle(fd);
std::string serial = android::base::StringPrintf("host-%d", fd);
- if (register_socket_transport(fd, serial.c_str(), port, 1) != 0) {
+ if (register_socket_transport(fd, serial.c_str(), port, 1,
+ [](atransport*) { return false; }) != 0) {
adb_close(fd);
}
}
@@ -338,7 +367,8 @@
/* Host is connected. Register the transport, and start the
* exchange. */
std::string serial = android::base::StringPrintf("host-%d", fd);
- if (register_socket_transport(fd, serial.c_str(), port, 1) != 0 ||
+ if (register_socket_transport(fd, serial.c_str(), port, 1,
+ [](atransport*) { return false; }) != 0 ||
!WriteFdExactly(fd, _start_req, strlen(_start_req))) {
adb_close(fd);
}