Merge "adb: Add a way to distinguish between connection failures and successes"
diff --git a/adb/adb.cpp b/adb/adb.cpp
index 38c11b9..3bf281c 100644
--- a/adb/adb.cpp
+++ b/adb/adb.cpp
@@ -132,6 +132,7 @@
 {
     D("adb: online");
     t->online = 1;
+    t->SetConnectionEstablished(true);
 }
 
 void handle_offline(atransport *t)
diff --git a/adb/test_adb.py b/adb/test_adb.py
index 363002f..32bf029 100644
--- a/adb/test_adb.py
+++ b/adb/test_adb.py
@@ -49,8 +49,16 @@
     # A pipe that is used to signal the thread that it should terminate.
     readpipe, writepipe = os.pipe()
 
+    def _adb_packet(command, arg0, arg1, data):
+        """Returns an ADB packet: six little-endian uint32 header words, then data."""
+        bin_command = struct.unpack('<I', command)[0]
+        header = struct.pack('<IIIIII', bin_command, arg0, arg1, len(data), 0,
+                             bin_command ^ 0xffffffff)
+        return header + data
+
     def _handle():
         rlist = [readpipe, serversock]
+        cnxn_sent = {}
         while True:
             ready, _, _ = select.select(rlist, [], [])
             for r in ready:
@@ -68,7 +76,15 @@
                     # Client socket
                     data = r.recv(1024)
                     if not data:
+                        if r in cnxn_sent:
+                            del cnxn_sent[r]
                         rlist.remove(r)
+                        continue
+                    if r in cnxn_sent:
+                        continue
+                    cnxn_sent[r] = True
+                    r.sendall(_adb_packet('CNXN', 0x01000001, 1024 * 1024,
+                                          'device::ro.product.name=fakeadb'))
 
     port = serversock.getsockname()[1]
     server_thread = threading.Thread(target=_handle)
diff --git a/adb/transport.cpp b/adb/transport.cpp
index f5f6d26..0ab428e 100644
--- a/adb/transport.cpp
+++ b/adb/transport.cpp
@@ -64,6 +64,21 @@
 const char* const kFeatureLibusb = "libusb";
 const char* const kFeaturePushSync = "push_sync";
 
+namespace {
+
+// Helps the Clang Thread Safety Analysis deal with std::unique_lock. Because
+// std::unique_lock is movable and the analysis can not currently perform alias
+// analysis, it is not annotated. Creating a ScopedAssumeLocked right after the
+// std::unique_lock tells the analysis that the mutex is held for the rest of
+// the scope. It performs no locking itself; the caller must already hold it.
+class SCOPED_CAPABILITY ScopedAssumeLocked {
+  public:
+    explicit ScopedAssumeLocked(std::mutex& mutex) ACQUIRE(mutex) {}
+    ~ScopedAssumeLocked() RELEASE() {}
+};
+
+}  // namespace
+
 TransportId NextTransportId() {
     static std::atomic<TransportId> next(1);
     return next++;
@@ -77,8 +92,6 @@
     Stop();
 }
 
-static void AssumeLocked(std::mutex& mutex) ASSERT_CAPABILITY(mutex) {}
-
 void BlockingConnectionAdapter::Start() {
     std::lock_guard<std::mutex> lock(mutex_);
     if (started_) {
@@ -103,12 +116,11 @@
         LOG(INFO) << this->transport_name_ << ": write thread spawning";
         while (true) {
             std::unique_lock<std::mutex> lock(mutex_);
+            ScopedAssumeLocked assume_locked(mutex_);
             cv_.wait(lock, [this]() REQUIRES(mutex_) {
                 return this->stopped_ || !this->write_queue_.empty();
             });
 
-            AssumeLocked(mutex_);
-
             if (this->stopped_) {
                 return;
             }
@@ -721,6 +733,30 @@
     return result;
 }
 
+bool ConnectionWaitable::WaitForConnection(std::chrono::milliseconds timeout) {
+    std::unique_lock<std::mutex> lock(mutex_);
+    ScopedAssumeLocked assume_locked(mutex_);
+    // False on timeout; otherwise whatever outcome SetConnectionEstablished() recorded.
+    auto is_ready = [this]() REQUIRES(mutex_) { return connection_established_ready_; };
+    return cv_.wait_for(lock, timeout, is_ready) && connection_established_;
+}
+
+void ConnectionWaitable::SetConnectionEstablished(bool success) {
+    {
+        std::lock_guard<std::mutex> lock(mutex_);
+        if (connection_established_ready_) return;  // Only the first report counts.
+        connection_established_ready_ = true;
+        connection_established_ = success;
+        D("connection established with %d", success);
+    }
+    cv_.notify_all();  // Wake every waiter; the waitable is shared and may have several.
+}
+
+atransport::~atransport() {
+    // If no outcome was reported yet, report failure so a pending waiter returns promptly.
+    SetConnectionEstablished(false);
+}
+
 int atransport::Write(apacket* p) {
     return this->connection->Write(std::unique_ptr<apacket>(p)) ? 0 : -1;
 }
@@ -873,6 +909,10 @@
            qual_match(target.c_str(), "device:", device, false);
 }
 
+void atransport::SetConnectionEstablished(bool success) {
+    connection_waitable_->SetConnectionEstablished(success);  // No-op after the first report.
+}
+
 #if ADB_HOST
 
 // We use newline as our delimiter, make sure to never output it.
@@ -992,8 +1032,10 @@
 
     lock.unlock();
 
+    auto waitable = t->connection_waitable();
     register_transport(t);
-    return 0;
+
+    return waitable->WaitForConnection(std::chrono::seconds(10)) ? 0 : -1;
 }
 
 #if ADB_HOST
diff --git a/adb/transport.h b/adb/transport.h
index d18c362..4e0220f 100644
--- a/adb/transport.h
+++ b/adb/transport.h
@@ -20,6 +20,7 @@
 #include <sys/types.h>
 
 #include <atomic>
+#include <chrono>
 #include <condition_variable>
 #include <deque>
 #include <functional>
@@ -30,6 +31,7 @@
 #include <thread>
 #include <unordered_set>
 
+#include <android-base/macros.h>
 #include <android-base/thread_annotations.h>
 #include <openssl/rsa.h>
 
@@ -160,6 +162,35 @@
     usb_handle* handle_;
 };
 
+// Lets another thread wait until a transport's connection attempt is resolved.
+// This is a separate object so that the transport can be destroyed and another
+// thread can be notified of it in a race-free way.
+class ConnectionWaitable {
+  public:
+    ConnectionWaitable() = default;
+    ~ConnectionWaitable() = default;
+
+    // Waits until the outcome of the connection attempt has been reported via
+    // SetConnectionEstablished(), or the specified timeout has elapsed. Can be
+    // called from any thread.
+    //
+    // Returns true only if success was reported before the timeout, false
+    // otherwise (failure reported, or timed out).
+    bool WaitForConnection(std::chrono::milliseconds timeout);
+
+    // Can be called from any thread when the connection stops being pending.
+    // Only the first invocation will be acknowledged, the rest will be no-ops.
+    void SetConnectionEstablished(bool success);
+
+  private:
+    bool connection_established_ GUARDED_BY(mutex_) = false;  // Reported outcome.
+    bool connection_established_ready_ GUARDED_BY(mutex_) = false;  // Outcome was reported.
+    std::mutex mutex_;
+    std::condition_variable cv_;
+
+    DISALLOW_COPY_AND_ASSIGN(ConnectionWaitable);
+};
+
 class atransport {
   public:
     // TODO(danalbert): We expose waaaaaaay too much stuff because this was
@@ -168,13 +199,15 @@
     // it's better to do this piece by piece.
 
     atransport(ConnectionState state = kCsOffline)
-        : id(NextTransportId()), connection_state_(state) {
+        : id(NextTransportId()),
+          connection_state_(state),
+          connection_waitable_(std::make_shared<ConnectionWaitable>()) {
         // Initialize protocol to min version for compatibility with older versions.
         // Version will be updated post-connect.
         protocol_version = A_VERSION_MIN;
         max_payload = MAX_PAYLOAD;
     }
-    virtual ~atransport() {}
+    virtual ~atransport();
 
     int Write(apacket* p);
     void Kick();
@@ -241,7 +274,14 @@
     // This is to make it easier to use the same network target for both fastboot and adb.
     bool MatchesTarget(const std::string& target) const;
 
-private:
+    // Notifies that the atransport is no longer waiting for the connection
+    // being established.
+    void SetConnectionEstablished(bool success);
+
+    // Gets a shared reference to the ConnectionWaitable.
+    std::shared_ptr<ConnectionWaitable> connection_waitable() { return connection_waitable_; }
+
+  private:
     bool kicked_ = false;
 
     // A set of features transmitted in the banner with the initial connection.
@@ -258,6 +298,10 @@
     std::deque<std::shared_ptr<RSA>> keys_;
 #endif
 
+    // A sharable object that can be used to wait for the atransport's
+    // connection to be established.
+    std::shared_ptr<ConnectionWaitable> connection_waitable_;
+
     DISALLOW_COPY_AND_ASSIGN(atransport);
 };