adb: Add a way to distinguish between connection failures and successes

This change adds a callback that is invoked exactly once, either when
the connection is fully established (i.e. CNXN packets have been sent
and received) or the atransport object is deleted before that (because
the connection failed).

This helps in distinguishing between successful and failing connections
for TCP. Especially when there is some kind of port
forwarding/multiplexing in between (like an SSH tunnel or SSLH proxy).

Bug: 74411879
Test: adb connect chromebook:22 (which runs an sslh tunnel to adbd).
      either succeeds or fails, but not fake-succeeds.

Change-Id: I7e826c6f5d4c30338a03b2d376a857ac5d05672a
diff --git a/adb/adb.cpp b/adb/adb.cpp
index 38c11b9..3bf281c 100644
--- a/adb/adb.cpp
+++ b/adb/adb.cpp
@@ -132,6 +132,7 @@
 {
     D("adb: online");
     t->online = 1;
+    t->SetConnectionEstablished(true);
 }
 
 void handle_offline(atransport *t)
diff --git a/adb/test_adb.py b/adb/test_adb.py
index 363002f..32bf029 100644
--- a/adb/test_adb.py
+++ b/adb/test_adb.py
@@ -49,8 +49,16 @@
     # A pipe that is used to signal the thread that it should terminate.
     readpipe, writepipe = os.pipe()
 
+    def _adb_packet(command, arg0, arg1, data):
+        bin_command = struct.unpack('I', command)[0]
+        buf = struct.pack('IIIIII', bin_command, arg0, arg1, len(data), 0,
+                          bin_command ^ 0xffffffff)
+        buf += data
+        return buf
+
     def _handle():
         rlist = [readpipe, serversock]
+        cnxn_sent = {}
         while True:
             ready, _, _ = select.select(rlist, [], [])
             for r in ready:
@@ -68,7 +76,15 @@
                     # Client socket
                     data = r.recv(1024)
                     if not data:
+                        if r in cnxn_sent:
+                            del cnxn_sent[r]
                         rlist.remove(r)
+                        continue
+                    if r in cnxn_sent:
+                        continue
+                    cnxn_sent[r] = True
+                    r.sendall(_adb_packet('CNXN', 0x01000001, 1024 * 1024,
+                                          'device::ro.product.name=fakeadb'))
 
     port = serversock.getsockname()[1]
     server_thread = threading.Thread(target=_handle)
diff --git a/adb/transport.cpp b/adb/transport.cpp
index f5f6d26..0ab428e 100644
--- a/adb/transport.cpp
+++ b/adb/transport.cpp
@@ -64,6 +64,21 @@
 const char* const kFeatureLibusb = "libusb";
 const char* const kFeaturePushSync = "push_sync";
 
+namespace {
+
+// A class that helps the Clang Thread Safety Analysis deal with
+// std::unique_lock. Given that std::unique_lock is movable, and the analysis
+// can not currently perform alias analysis, it is not annotated. In order to
+// assert that the mutex is held, a ScopedAssumeLocked can be created just after
+// the std::unique_lock.
+class SCOPED_CAPABILITY ScopedAssumeLocked {
+  public:
+    ScopedAssumeLocked(std::mutex& mutex) ACQUIRE(mutex) {}
+    ~ScopedAssumeLocked() RELEASE() {}
+};
+
+}  // namespace
+
 TransportId NextTransportId() {
     static std::atomic<TransportId> next(1);
     return next++;
@@ -77,8 +92,6 @@
     Stop();
 }
 
-static void AssumeLocked(std::mutex& mutex) ASSERT_CAPABILITY(mutex) {}
-
 void BlockingConnectionAdapter::Start() {
     std::lock_guard<std::mutex> lock(mutex_);
     if (started_) {
@@ -103,12 +116,11 @@
         LOG(INFO) << this->transport_name_ << ": write thread spawning";
         while (true) {
             std::unique_lock<std::mutex> lock(mutex_);
+            ScopedAssumeLocked assume_locked(mutex_);
             cv_.wait(lock, [this]() REQUIRES(mutex_) {
                 return this->stopped_ || !this->write_queue_.empty();
             });
 
-            AssumeLocked(mutex_);
-
             if (this->stopped_) {
                 return;
             }
@@ -721,6 +733,30 @@
     return result;
 }
 
+bool ConnectionWaitable::WaitForConnection(std::chrono::milliseconds timeout) {
+    std::unique_lock<std::mutex> lock(mutex_);
+    ScopedAssumeLocked assume_locked(mutex_);
+    return cv_.wait_for(lock, timeout, [&]() REQUIRES(mutex_) {
+        return connection_established_ready_;
+    }) && connection_established_;
+}
+
+void ConnectionWaitable::SetConnectionEstablished(bool success) {
+    {
+        std::lock_guard<std::mutex> lock(mutex_);
+        if (connection_established_ready_) return;
+        connection_established_ready_ = true;
+        connection_established_ = success;
+        D("connection established with %d", success);
+    }
+    cv_.notify_one();
+}
+
+atransport::~atransport() {
+    // If the connection callback had not been run before, run it now.
+    SetConnectionEstablished(false);
+}
+
 int atransport::Write(apacket* p) {
     return this->connection->Write(std::unique_ptr<apacket>(p)) ? 0 : -1;
 }
@@ -873,6 +909,10 @@
            qual_match(target.c_str(), "device:", device, false);
 }
 
+void atransport::SetConnectionEstablished(bool success) {
+    connection_waitable_->SetConnectionEstablished(success);
+}
+
 #if ADB_HOST
 
 // We use newline as our delimiter, make sure to never output it.
@@ -992,8 +1032,10 @@
 
     lock.unlock();
 
+    auto waitable = t->connection_waitable();
     register_transport(t);
-    return 0;
+
+    return waitable->WaitForConnection(std::chrono::seconds(10)) ? 0 : -1;
 }
 
 #if ADB_HOST
diff --git a/adb/transport.h b/adb/transport.h
index d18c362..4e0220f 100644
--- a/adb/transport.h
+++ b/adb/transport.h
@@ -20,6 +20,7 @@
 #include <sys/types.h>
 
 #include <atomic>
+#include <chrono>
 #include <condition_variable>
 #include <deque>
 #include <functional>
@@ -30,6 +31,7 @@
 #include <thread>
 #include <unordered_set>
 
+#include <android-base/macros.h>
 #include <android-base/thread_annotations.h>
 #include <openssl/rsa.h>
 
@@ -160,6 +162,35 @@
     usb_handle* handle_;
 };
 
+// Waits for a transport's connection to be not pending. This is a separate
+// object so that the transport can be destroyed and another thread can be
+// notified of it in a race-free way.
+class ConnectionWaitable {
+  public:
+    ConnectionWaitable() = default;
+    ~ConnectionWaitable() = default;
+
+    // Waits until the first CNXN packet has been received by the owning
+    // atransport, or the specified timeout has elapsed. Can be called from any
+    // thread.
+    //
+    // Returns true if the CNXN packet was received in a timely fashion, false
+    // otherwise.
+    bool WaitForConnection(std::chrono::milliseconds timeout);
+
+    // Can be called from any thread when the connection stops being pending.
+    // Only the first invocation will be acknowledged, the rest will be no-ops.
+    void SetConnectionEstablished(bool success);
+
+  private:
+    bool connection_established_ GUARDED_BY(mutex_) = false;
+    bool connection_established_ready_ GUARDED_BY(mutex_) = false;
+    std::mutex mutex_;
+    std::condition_variable cv_;
+
+    DISALLOW_COPY_AND_ASSIGN(ConnectionWaitable);
+};
+
 class atransport {
   public:
     // TODO(danalbert): We expose waaaaaaay too much stuff because this was
@@ -168,13 +199,15 @@
     // it's better to do this piece by piece.
 
     atransport(ConnectionState state = kCsOffline)
-        : id(NextTransportId()), connection_state_(state) {
+        : id(NextTransportId()),
+          connection_state_(state),
+          connection_waitable_(std::make_shared<ConnectionWaitable>()) {
         // Initialize protocol to min version for compatibility with older versions.
         // Version will be updated post-connect.
         protocol_version = A_VERSION_MIN;
         max_payload = MAX_PAYLOAD;
     }
-    virtual ~atransport() {}
+    virtual ~atransport();
 
     int Write(apacket* p);
     void Kick();
@@ -241,7 +274,14 @@
     // This is to make it easier to use the same network target for both fastboot and adb.
     bool MatchesTarget(const std::string& target) const;
 
-private:
+    // Notifies that the atransport is no longer waiting for the connection
+    // being established.
+    void SetConnectionEstablished(bool success);
+
+    // Gets a shared reference to the ConnectionWaitable.
+    std::shared_ptr<ConnectionWaitable> connection_waitable() { return connection_waitable_; }
+
+  private:
     bool kicked_ = false;
 
     // A set of features transmitted in the banner with the initial connection.
@@ -258,6 +298,10 @@
     std::deque<std::shared_ptr<RSA>> keys_;
 #endif
 
+    // A sharable object that can be used to wait for the atransport's
+    // connection to be established.
+    std::shared_ptr<ConnectionWaitable> connection_waitable_;
+
     DISALLOW_COPY_AND_ASSIGN(atransport);
 };