adb: Make the Connection object a std::shared_ptr

This change is in preparation to allow the TCP-based transports to be
able to reconnect. This is needed because multiple threads can access
the Connection object. It used to be safe to do because one instance of
atransport would have the same Connection instance throughout its
lifetime, but now it is possible to replace the Connection instance,
which could cause threads that were attempting to Write to an
atransport* to use-after-free the Connection instance.

Bug: 74411879
Test: system/core/adb/test_adb.py
Change-Id: I4f092be11b2095088a9a9de2c0386086814d37ce
diff --git a/adb/transport.cpp b/adb/transport.cpp
index 0ab428e..706aee6 100644
--- a/adb/transport.cpp
+++ b/adb/transport.cpp
@@ -517,8 +517,8 @@
     if (t->GetConnectionState() != kCsNoPerm) {
         /* initial references are the two threads */
         t->ref_count = 1;
-        t->connection->SetTransportName(t->serial_name());
-        t->connection->SetReadCallback([t](Connection*, std::unique_ptr<apacket> p) {
+        t->connection()->SetTransportName(t->serial_name());
+        t->connection()->SetReadCallback([t](Connection*, std::unique_ptr<apacket> p) {
             if (!check_header(p.get(), t)) {
                 D("%s: remote read: bad header", t->serial);
                 return false;
@@ -531,7 +531,7 @@
             fdevent_run_on_main_thread([packet, t]() { handle_packet(packet, t); });
             return true;
         });
-        t->connection->SetErrorCallback([t](Connection*, const std::string& error) {
+        t->connection()->SetErrorCallback([t](Connection*, const std::string& error) {
             D("%s: connection terminated: %s", t->serial, error.c_str());
             fdevent_run_on_main_thread([t]() {
                 handle_offline(t);
@@ -539,7 +539,7 @@
             });
         });
 
-        t->connection->Start();
+        t->connection()->Start();
 #if ADB_HOST
         send_connect(t);
 #endif
@@ -608,7 +608,7 @@
     t->ref_count--;
     if (t->ref_count == 0) {
         D("transport: %s unref (kicking and closing)", t->serial);
-        t->connection->Stop();
+        t->connection()->Stop();
         remove_transport(t);
     } else {
         D("transport: %s unref (count=%zu)", t->serial, t->ref_count);
@@ -758,14 +758,14 @@
 }
 
 int atransport::Write(apacket* p) {
-    return this->connection->Write(std::unique_ptr<apacket>(p)) ? 0 : -1;
+    return this->connection()->Write(std::unique_ptr<apacket>(p)) ? 0 : -1;
 }
 
 void atransport::Kick() {
     if (!kicked_) {
         D("kicking transport %s", this->serial);
         kicked_ = true;
-        this->connection->Stop();
+        this->connection()->Stop();
     }
 }
 
@@ -778,6 +778,11 @@
     connection_state_ = state;
 }
 
+void atransport::SetConnection(std::unique_ptr<Connection> connection) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    connection_ = std::shared_ptr<Connection>(std::move(connection));
+}
+
 std::string atransport::connection_state_name() const {
     ConnectionState state = GetConnectionState();
     switch (state) {
@@ -1094,8 +1099,9 @@
 void unregister_usb_transport(usb_handle* usb) {
     std::lock_guard<std::recursive_mutex> lock(transport_lock);
     transport_list.remove_if([usb](atransport* t) {
-        if (auto connection = dynamic_cast<UsbConnection*>(t->connection.get())) {
-            return connection->handle_ == usb && t->GetConnectionState() == kCsNoPerm;
+        auto connection = t->connection();
+        if (auto usb_connection = dynamic_cast<UsbConnection*>(connection.get())) {
+            return usb_connection->handle_ == usb && t->GetConnectionState() == kCsNoPerm;
         }
         return false;
     });
diff --git a/adb/transport.h b/adb/transport.h
index 4e0220f..ebc186b 100644
--- a/adb/transport.h
+++ b/adb/transport.h
@@ -201,7 +201,8 @@
     atransport(ConnectionState state = kCsOffline)
         : id(NextTransportId()),
           connection_state_(state),
-          connection_waitable_(std::make_shared<ConnectionWaitable>()) {
+          connection_waitable_(std::make_shared<ConnectionWaitable>()),
+          connection_(nullptr) {
         // Initialize protocol to min version for compatibility with older versions.
         // Version will be updated post-connect.
         protocol_version = A_VERSION_MIN;
@@ -216,13 +217,17 @@
     ConnectionState GetConnectionState() const;
     void SetConnectionState(ConnectionState state);
 
+    void SetConnection(std::unique_ptr<Connection> connection);
+    std::shared_ptr<Connection> connection() {
+        std::lock_guard<std::mutex> lock(mutex_);
+        return connection_;
+    }
+
     const TransportId id;
     size_t ref_count = 0;
     bool online = false;
     TransportType type = kTransportAny;
 
-    std::unique_ptr<Connection> connection;
-
     // Used to identify transports for clients.
     char* serial = nullptr;
     char* product = nullptr;
@@ -302,6 +307,11 @@
     // connection to be established.
     std::shared_ptr<ConnectionWaitable> connection_waitable_;
 
+    // The underlying connection object.
+    std::shared_ptr<Connection> connection_ GUARDED_BY(mutex_);
+
+    std::mutex mutex_;
+
     DISALLOW_COPY_AND_ASSIGN(atransport);
 };
 
diff --git a/adb/transport_local.cpp b/adb/transport_local.cpp
index 8032421..e81f27c 100644
--- a/adb/transport_local.cpp
+++ b/adb/transport_local.cpp
@@ -456,7 +456,8 @@
     // Emulator connection.
     if (local) {
         auto emulator_connection = std::make_unique<EmulatorConnection>(std::move(fd), adb_port);
-        t->connection = std::make_unique<BlockingConnectionAdapter>(std::move(emulator_connection));
+        t->SetConnection(
+            std::make_unique<BlockingConnectionAdapter>(std::move(emulator_connection)));
         std::lock_guard<std::mutex> lock(local_transports_lock);
         atransport* existing_transport = find_emulator_transport_by_adb_port_locked(adb_port);
         if (existing_transport != NULL) {
@@ -476,6 +477,6 @@
 
     // Regular tcp connection.
     auto fd_connection = std::make_unique<FdConnection>(std::move(fd));
-    t->connection = std::make_unique<BlockingConnectionAdapter>(std::move(fd_connection));
+    t->SetConnection(std::make_unique<BlockingConnectionAdapter>(std::move(fd_connection)));
     return fail;
 }
diff --git a/adb/transport_usb.cpp b/adb/transport_usb.cpp
index e9a75cd..94b2e37 100644
--- a/adb/transport_usb.cpp
+++ b/adb/transport_usb.cpp
@@ -176,7 +176,7 @@
 void init_usb_transport(atransport* t, usb_handle* h) {
     D("transport: usb");
     auto connection = std::make_unique<UsbConnection>(h);
-    t->connection = std::make_unique<BlockingConnectionAdapter>(std::move(connection));
+    t->SetConnection(std::make_unique<BlockingConnectionAdapter>(std::move(connection)));
     t->type = kTransportUsb;
 }