adb: Add a way to reconnect TCP transports

This change adds a reconnect handler that tracks all TCP transports that
were connected at some point, but became disconnected. It does so by
attempting to reconnect every 10s for up to a minute.

Bug: 74411879
Test: system/core/adb/test_adb.py
Test: adb connect chromebook:22  # This runs with sslh
Test: CtsBootStatsTestCases
Test: emulator -show-kernel ; adb -s emulator-5554 shell

Change-Id: I7b9f6d181b71ccf5c26ff96c45d36aaf6409b992
diff --git a/adb/client/main.cpp b/adb/client/main.cpp
index 31cb853..44ed3a2 100644
--- a/adb/client/main.cpp
+++ b/adb/client/main.cpp
@@ -117,6 +117,7 @@
     atexit(adb_server_cleanup);
 
     init_transport_registration();
+    init_reconnect_handler();
     init_mdns_transport_discovery();
 
     usb_init();
diff --git a/adb/test_adb.py b/adb/test_adb.py
index 32bf029..ce4d4ec 100644
--- a/adb/test_adb.py
+++ b/adb/test_adb.py
@@ -75,9 +75,11 @@
                 else:
                     # Client socket
                     data = r.recv(1024)
-                    if not data:
+                    if not data or data.startswith('OPEN'):
                         if r in cnxn_sent:
                             del cnxn_sent[r]
+                        r.shutdown(socket.SHUT_RDWR)
+                        r.close()
                         rlist.remove(r)
                         continue
                     if r in cnxn_sent:
@@ -97,6 +99,25 @@
         server_thread.join()
 
 
+@contextlib.contextmanager
+def adb_connect(unittest, serial):
+    """Context manager for an ADB connection.
+
+    This automatically disconnects when done with the connection.
+    """
+
+    output = subprocess.check_output(['adb', 'connect', serial])
+    unittest.assertEqual(output.strip(), 'connected to {}'.format(serial))
+
+    try:
+        yield
+    finally:
+        # Perform best-effort disconnection. Discard the output.
+        p = subprocess.Popen(['adb', 'disconnect', serial],
+                             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        p.communicate()
+
+
 class NonApiTest(unittest.TestCase):
     """Tests for ADB that aren't a part of the AndroidDevice API."""
 
@@ -278,29 +299,60 @@
         for protocol in (socket.AF_INET, socket.AF_INET6):
             try:
                 with fake_adb_server(protocol=protocol) as port:
-                    output = subprocess.check_output(
-                        ['adb', 'connect', 'localhost:{}'.format(port)])
-
-                    self.assertEqual(
-                        output.strip(), 'connected to localhost:{}'.format(port))
+                    serial = 'localhost:{}'.format(port)
+                    with adb_connect(self, serial):
+                        pass
             except socket.error:
                 print("IPv6 not available, skipping")
                 continue
 
     def test_already_connected(self):
+        """Ensure that an already-connected device stays connected."""
+
         with fake_adb_server() as port:
-            output = subprocess.check_output(
-                ['adb', 'connect', 'localhost:{}'.format(port)])
+            serial = 'localhost:{}'.format(port)
+            with adb_connect(self, serial):
+                # b/31250450: this always returns 0 but probably shouldn't.
+                output = subprocess.check_output(['adb', 'connect', serial])
+                self.assertEqual(
+                    output.strip(), 'already connected to {}'.format(serial))
 
-            self.assertEqual(
-                output.strip(), 'connected to localhost:{}'.format(port))
+    def test_reconnect(self):
+        """Ensure that a disconnected device reconnects."""
 
-            # b/31250450: this always returns 0 but probably shouldn't.
-            output = subprocess.check_output(
-                ['adb', 'connect', 'localhost:{}'.format(port)])
+        with fake_adb_server() as port:
+            serial = 'localhost:{}'.format(port)
+            with adb_connect(self, serial):
+                output = subprocess.check_output(['adb', '-s', serial,
+                                                  'get-state'])
+                self.assertEqual(output.strip(), 'device')
 
-            self.assertEqual(
-                output.strip(), 'already connected to localhost:{}'.format(port))
+                # This will fail.
+                p = subprocess.Popen(['adb', '-s', serial, 'shell', 'true'],
+                                     stdout=subprocess.PIPE,
+                                     stderr=subprocess.STDOUT)
+                output, _ = p.communicate()
+                self.assertEqual(output.strip(), 'error: closed')
+
+                subprocess.check_call(['adb', '-s', serial, 'wait-for-device'])
+
+                output = subprocess.check_output(['adb', '-s', serial,
+                                                  'get-state'])
+                self.assertEqual(output.strip(), 'device')
+
+                # Once we explicitly kick a device, it won't attempt to
+                # reconnect.
+                output = subprocess.check_output(['adb', 'disconnect', serial])
+                self.assertEqual(
+                    output.strip(), 'disconnected {}'.format(serial))
+                try:
+                    subprocess.check_output(['adb', '-s', serial, 'get-state'],
+                                            stderr=subprocess.STDOUT)
+                    self.fail('Device should not be available')
+                except subprocess.CalledProcessError as e:
+                    self.assertEqual(
+                        e.output.strip(),
+                        'error: device \'{}\' not found'.format(serial))
 
 def main():
     random.seed(0)
diff --git a/adb/transport.cpp b/adb/transport.cpp
index be7f8fe..beec13a 100644
--- a/adb/transport.cpp
+++ b/adb/transport.cpp
@@ -33,6 +33,7 @@
 #include <deque>
 #include <list>
 #include <mutex>
+#include <queue>
 #include <thread>
 
 #include <android-base/logging.h>
@@ -50,7 +51,9 @@
 #include "adb_utils.h"
 #include "fdevent.h"
 
-static void transport_unref(atransport *t);
+static void register_transport(atransport* transport);
+static void remove_transport(atransport* transport);
+static void transport_unref(atransport* transport);
 
 // TODO: unordered_map<TransportId, atransport*>
 static auto& transport_list = *new std::list<atransport*>();
@@ -77,6 +80,130 @@
     ~ScopedAssumeLocked() RELEASE() {}
 };
 
+// Tracks and handles atransport*s that are attempting reconnection.
+class ReconnectHandler {
+  public:
+    ReconnectHandler() = default;
+    ~ReconnectHandler() = default;
+
+    // Starts the ReconnectHandler thread.
+    void Start();
+
+    // Requests the ReconnectHandler thread to stop.
+    void Stop();
+
+    // Adds the atransport* to the queue of reconnect attempts.
+    void TrackTransport(atransport* transport);
+
+  private:
+    // The main thread loop.
+    void Run();
+
+    // Tracks a reconnection attempt.
+    struct ReconnectAttempt {
+        atransport* transport;
+        std::chrono::system_clock::time_point deadline;
+        size_t attempts_left;
+    };
+
+    // Only retry for up to one minute.
+    static constexpr const std::chrono::seconds kDefaultTimeout = std::chrono::seconds(10);
+    static constexpr const size_t kMaxAttempts = 6;
+
+    // Protects all members.
+    std::mutex reconnect_mutex_;
+    bool running_ GUARDED_BY(reconnect_mutex_) = true;
+    std::thread handler_thread_;
+    std::condition_variable reconnect_cv_;
+    std::queue<ReconnectAttempt> reconnect_queue_ GUARDED_BY(reconnect_mutex_);
+
+    DISALLOW_COPY_AND_ASSIGN(ReconnectHandler);
+};
+
+void ReconnectHandler::Start() {
+    check_main_thread();
+    handler_thread_ = std::thread(&ReconnectHandler::Run, this);
+}
+
+void ReconnectHandler::Stop() {
+    check_main_thread();
+    {
+        std::lock_guard<std::mutex> lock(reconnect_mutex_);
+        running_ = false;
+    }
+    reconnect_cv_.notify_one();
+    handler_thread_.join();
+
+    // Drain the queue to free all resources.
+    std::lock_guard<std::mutex> lock(reconnect_mutex_);
+    while (!reconnect_queue_.empty()) {
+        ReconnectAttempt attempt = reconnect_queue_.front();
+        reconnect_queue_.pop();
+        remove_transport(attempt.transport);
+    }
+}
+
+void ReconnectHandler::TrackTransport(atransport* transport) {
+    check_main_thread();
+    {
+        std::lock_guard<std::mutex> lock(reconnect_mutex_);
+        if (!running_) return;
+        reconnect_queue_.emplace(ReconnectAttempt{
+            transport, std::chrono::system_clock::now() + ReconnectHandler::kDefaultTimeout,
+            ReconnectHandler::kMaxAttempts});
+    }
+    reconnect_cv_.notify_one();
+}
+
+void ReconnectHandler::Run() {
+    while (true) {
+        ReconnectAttempt attempt;
+        {
+            std::unique_lock<std::mutex> lock(reconnect_mutex_);
+            ScopedAssumeLocked assume_lock(reconnect_mutex_);
+
+            auto deadline = std::chrono::time_point<std::chrono::system_clock>::max();
+            if (!reconnect_queue_.empty()) deadline = reconnect_queue_.front().deadline;
+            reconnect_cv_.wait_until(lock, deadline, [&]() REQUIRES(reconnect_mutex_) {
+                return !running_ ||
+                       (!reconnect_queue_.empty() && reconnect_queue_.front().deadline < deadline);
+            });
+
+            if (!running_) return;
+            attempt = reconnect_queue_.front();
+            reconnect_queue_.pop();
+            if (attempt.transport->kicked()) {
+                D("transport %s was kicked. giving up on it.", attempt.transport->serial);
+                remove_transport(attempt.transport);
+                continue;
+            }
+        }
+        D("attempting to reconnect %s", attempt.transport->serial);
+
+        if (!attempt.transport->Reconnect()) {
+            D("attempting to reconnect %s failed.", attempt.transport->serial);
+            if (attempt.attempts_left == 0) {
+                D("transport %s exceeded the number of retry attempts. giving up on it.",
+                  attempt.transport->serial);
+                remove_transport(attempt.transport);
+                continue;
+            }
+
+            std::lock_guard<std::mutex> lock(reconnect_mutex_);
+            reconnect_queue_.emplace(ReconnectAttempt{
+                attempt.transport,
+                std::chrono::system_clock::now() + ReconnectHandler::kDefaultTimeout,
+                attempt.attempts_left - 1});
+            continue;
+        }
+
+        D("reconnection to %s succeeded.", attempt.transport->serial);
+        register_transport(attempt.transport);
+    }
+}
+
+static auto& reconnect_handler = *new ReconnectHandler();
+
 }  // namespace
 
 TransportId NextTransportId() {
@@ -477,8 +604,6 @@
     return 0;
 }
 
-static void remove_transport(atransport*);
-
 static void transport_registration_func(int _fd, unsigned ev, void*) {
     tmsg m;
     atransport* t;
@@ -515,8 +640,9 @@
 
     /* don't create transport threads for inaccessible devices */
     if (t->GetConnectionState() != kCsNoPerm) {
-        /* initial references are the two threads */
-        t->ref_count = 1;
+        // The connection gets a reference to the atransport. It will release it
+        // upon a read/write error.
+        t->ref_count++;
         t->connection()->SetTransportName(t->serial_name());
         t->connection()->SetReadCallback([t](Connection*, std::unique_ptr<apacket> p) {
             if (!check_header(p.get(), t)) {
@@ -547,13 +673,20 @@
 
     {
         std::lock_guard<std::recursive_mutex> lock(transport_lock);
-        pending_list.remove(t);
-        transport_list.push_front(t);
+        auto it = std::find(pending_list.begin(), pending_list.end(), t);
+        if (it != pending_list.end()) {
+            pending_list.remove(t);
+            transport_list.push_front(t);
+        }
     }
 
     update_transports();
 }
 
+void init_reconnect_handler(void) {
+    reconnect_handler.Start();
+}
+
 void init_transport_registration(void) {
     int s[2];
 
@@ -572,6 +705,7 @@
 }
 
 void kick_all_transports() {
+    reconnect_handler.Stop();
     // To avoid only writing part of a packet to a transport after exit, kick all transports.
     std::lock_guard<std::recursive_mutex> lock(transport_lock);
     for (auto t : transport_list) {
@@ -601,15 +735,21 @@
 }
 
 static void transport_unref(atransport* t) {
+    check_main_thread();
     CHECK(t != nullptr);
 
     std::lock_guard<std::recursive_mutex> lock(transport_lock);
     CHECK_GT(t->ref_count, 0u);
     t->ref_count--;
     if (t->ref_count == 0) {
-        D("transport: %s unref (kicking and closing)", t->serial);
         t->connection()->Stop();
-        remove_transport(t);
+        if (t->IsTcpDevice() && !t->kicked()) {
+            D("transport: %s unref (attempting reconnection) %d", t->serial, t->kicked());
+            reconnect_handler.TrackTransport(t);
+        } else {
+            D("transport: %s unref (kicking and closing)", t->serial);
+            remove_transport(t);
+        }
     } else {
         D("transport: %s unref (count=%zu)", t->serial, t->ref_count);
     }
@@ -781,9 +921,8 @@
 }
 
 void atransport::Kick() {
-    if (!kicked_) {
-        D("kicking transport %s", this->serial);
-        kicked_ = true;
+    if (!kicked_.exchange(true)) {
+        D("kicking transport %p %s", this, this->serial);
         this->connection()->Stop();
     }
 }
@@ -941,6 +1080,10 @@
     connection_waitable_->SetConnectionEstablished(success);
 }
 
+bool atransport::Reconnect() {
+    return reconnect_(this);
+}
+
 #if ADB_HOST
 
 // We use newline as our delimiter, make sure to never output it.
@@ -1021,8 +1164,9 @@
 }
 #endif  // ADB_HOST
 
-int register_socket_transport(int s, const char* serial, int port, int local) {
-    atransport* t = new atransport();
+int register_socket_transport(int s, const char* serial, int port, int local,
+                              atransport::ReconnectCallback reconnect) {
+    atransport* t = new atransport(std::move(reconnect), kCsOffline);
 
     if (!serial) {
         char buf[32];
@@ -1103,7 +1247,7 @@
 
 void register_usb_transport(usb_handle* usb, const char* serial, const char* devpath,
                             unsigned writeable) {
-    atransport* t = new atransport((writeable ? kCsConnecting : kCsNoPerm));
+    atransport* t = new atransport(writeable ? kCsOffline : kCsNoPerm);
 
     D("transport: %p init'ing for usb_handle %p (sn='%s')", t, usb, serial ? serial : "");
     init_usb_transport(t, usb);
diff --git a/adb/transport.h b/adb/transport.h
index e1cbc09..ae9cc02 100644
--- a/adb/transport.h
+++ b/adb/transport.h
@@ -198,20 +198,27 @@
     // class in one go is a very large change. Given how bad our testing is,
     // it's better to do this piece by piece.
 
-    atransport(ConnectionState state = kCsConnecting)
+    using ReconnectCallback = std::function<bool(atransport*)>;
+
+    atransport(ReconnectCallback reconnect, ConnectionState state)
         : id(NextTransportId()),
+          kicked_(false),
           connection_state_(state),
           connection_waitable_(std::make_shared<ConnectionWaitable>()),
-          connection_(nullptr) {
+          connection_(nullptr),
+          reconnect_(std::move(reconnect)) {
         // Initialize protocol to min version for compatibility with older versions.
         // Version will be updated post-connect.
         protocol_version = A_VERSION_MIN;
         max_payload = MAX_PAYLOAD;
     }
+    atransport(ConnectionState state = kCsOffline)
+        : atransport([](atransport*) { return false; }, state) {}
     virtual ~atransport();
 
     int Write(apacket* p);
     void Kick();
+    bool kicked() const { return kicked_; }
 
     // ConnectionState can be read by all threads, but can only be written in the main thread.
     ConnectionState GetConnectionState() const;
@@ -286,8 +293,12 @@
     // Gets a shared reference to the ConnectionWaitable.
     std::shared_ptr<ConnectionWaitable> connection_waitable() { return connection_waitable_; }
 
+    // Attempts to reconnect with the underlying Connection. Returns true if the
+    // reconnection attempt succeeded.
+    bool Reconnect();
+
   private:
-    bool kicked_ = false;
+    std::atomic<bool> kicked_;
 
     // A set of features transmitted in the banner with the initial connection.
     // This is stored in the banner as 'features=feature0,feature1,etc'.
@@ -310,6 +321,9 @@
     // The underlying connection object.
     std::shared_ptr<Connection> connection_ GUARDED_BY(mutex_);
 
+    // A callback that will be invoked when the atransport needs to reconnect.
+    ReconnectCallback reconnect_;
+
     std::mutex mutex_;
 
     DISALLOW_COPY_AND_ASSIGN(atransport);
@@ -333,6 +347,7 @@
 // Stops iteration and returns false if fn returns false, otherwise returns true.
 bool iterate_transports(std::function<bool(const atransport*)> fn);
 
+void init_reconnect_handler(void);
 void init_transport_registration(void);
 void init_mdns_transport_discovery(void);
 std::string list_transports(bool long_listing);
@@ -347,7 +362,8 @@
 void connect_device(const std::string& address, std::string* response);
 
 /* cause new transports to be init'd and added to the list */
-int register_socket_transport(int s, const char* serial, int port, int local);
+int register_socket_transport(int s, const char* serial, int port, int local,
+                              atransport::ReconnectCallback reconnect);
 
 // This should only be used for transports with connection_state == kCsNoPerm.
 void unregister_usb_transport(usb_handle* usb);
diff --git a/adb/transport_local.cpp b/adb/transport_local.cpp
index e81f27c..181d666 100644
--- a/adb/transport_local.cpp
+++ b/adb/transport_local.cpp
@@ -68,28 +68,24 @@
     return local_connect_arbitrary_ports(port - 1, port, &dummy) == 0;
 }
 
-void connect_device(const std::string& address, std::string* response) {
-    if (address.empty()) {
-        *response = "empty address";
-        return;
-    }
-
+std::tuple<unique_fd, int, std::string> tcp_connect(const std::string& address,
+                                                    std::string* response) {
     std::string serial;
     std::string host;
     int port = DEFAULT_ADB_LOCAL_TRANSPORT_PORT;
     if (!android::base::ParseNetAddress(address, &host, &port, &serial, response)) {
-        return;
+        return std::make_tuple(unique_fd(), port, serial);
     }
 
     std::string error;
-    int fd = network_connect(host.c_str(), port, SOCK_STREAM, 10, &error);
+    unique_fd fd(network_connect(host.c_str(), port, SOCK_STREAM, 10, &error));
     if (fd == -1) {
         *response = android::base::StringPrintf("unable to connect to %s: %s",
                                                 serial.c_str(), error.c_str());
-        return;
+        return std::make_tuple(std::move(fd), port, serial);
     }
 
-    D("client: connected %s remote on fd %d", serial.c_str(), fd);
+    D("client: connected %s remote on fd %d", serial.c_str(), fd.get());
     close_on_exec(fd);
     disable_tcp_nagle(fd);
 
@@ -98,7 +94,38 @@
         D("warning: failed to configure TCP keepalives (%s)", strerror(errno));
     }
 
-    int ret = register_socket_transport(fd, serial.c_str(), port, 0);
+    return std::make_tuple(std::move(fd), port, serial);
+}
+
+void connect_device(const std::string& address, std::string* response) {
+    if (address.empty()) {
+        *response = "empty address";
+        return;
+    }
+
+    unique_fd fd;
+    int port;
+    std::string serial;
+    std::tie(fd, port, serial) = tcp_connect(address, response);
+    auto reconnect = [address](atransport* t) {
+        std::string response;
+        unique_fd fd;
+        int port;
+        std::string serial;
+        std::tie(fd, port, serial) = tcp_connect(address, &response);
+        if (fd == -1) {
+            D("reconnect failed: %s", response.c_str());
+            return false;
+        }
+
+        // This invokes the part of register_socket_transport() that needs to be
+        // invoked if the atransport* has already been setup. This eventually
+        // calls atransport->SetConnection() with a newly created Connection*
+        // that will in turn send the CNXN packet.
+        return init_socket_transport(t, fd.release(), port, 0) >= 0;
+    };
+
+    int ret = register_socket_transport(fd.release(), serial.c_str(), port, 0, std::move(reconnect));
     if (ret < 0) {
         adb_close(fd);
         if (ret == -EALREADY) {
@@ -135,7 +162,8 @@
         close_on_exec(fd);
         disable_tcp_nagle(fd);
         std::string serial = getEmulatorSerialString(console_port);
-        if (register_socket_transport(fd, serial.c_str(), adb_port, 1) == 0) {
+        if (register_socket_transport(fd, serial.c_str(), adb_port, 1,
+                                      [](atransport*) { return false; }) == 0) {
             return 0;
         }
         adb_close(fd);
@@ -239,7 +267,8 @@
             close_on_exec(fd);
             disable_tcp_nagle(fd);
             std::string serial = android::base::StringPrintf("host-%d", fd);
-            if (register_socket_transport(fd, serial.c_str(), port, 1) != 0) {
+            if (register_socket_transport(fd, serial.c_str(), port, 1,
+                                          [](atransport*) { return false; }) != 0) {
                 adb_close(fd);
             }
         }
@@ -338,7 +367,8 @@
                 /* Host is connected. Register the transport, and start the
                  * exchange. */
                 std::string serial = android::base::StringPrintf("host-%d", fd);
-                if (register_socket_transport(fd, serial.c_str(), port, 1) != 0 ||
+                if (register_socket_transport(fd, serial.c_str(), port, 1,
+                                              [](atransport*) { return false; }) != 0 ||
                     !WriteFdExactly(fd, _start_req, strlen(_start_req))) {
                     adb_close(fd);
                 }