[usb-device] Verify wTotalLength sanity

A malicious USB device could change the value it returns in wTotalLength
when the USB host reads configuration descriptors, because those
descriptors must be requested twice -- once to read only the header of
the descriptor to learn the length, and then once again to retrieve the
entire descriptor into an appropriately-sized buffer.

We already verified that the number of bytes we read back in the second
request matched the length specified in the first, but we did not verify
that the contents of the second request matched those of the first.
Failing to do so would leave a descriptor that claims to be longer than
the buffer allocated to it is, which if not handled very carefully could
lead to out-of-bounds reads.  Indeed, the rest of the code treats
wTotalLength as authoritative.

If a device attempts such trickery, we should reject the device in the
same way we'd reject it if it gave us a short read.

Additionally, we should reject wTotalLength values that are shorter than
the config descriptor header -- they can't possibly be valid.

Test: added new test for this case to usb-device-test that fails before
the change but passes after.

We're grateful to Quarkslab for reporting this vulnerability.

Fixed: 50619
Change-Id: I70caa5c1da47c305fc4bcd32c0d35484ca18a323
Reviewed-on: https://fuchsia-review.googlesource.com/c/fuchsia/+/391569
Reviewed-by: Brian Bosak <bbosak@google.com>
Testability-Review: Drew Fisher <zarvox@google.com>
Commit-Queue: Drew Fisher <zarvox@google.com>
diff --git a/src/devices/usb/drivers/usb-bus/tests/usb-device.cc b/src/devices/usb/drivers/usb-bus/tests/usb-device.cc
index 84d6546..24ca985 100644
--- a/src/devices/usb/drivers/usb-bus/tests/usb-device.cc
+++ b/src/devices/usb/drivers/usb-bus/tests/usb-device.cc
@@ -689,4 +689,162 @@
   ASSERT_EQ(get_configuration(), 2);
 }
 
+class EvilFakeHci : public ddk::UsbHciProtocol<EvilFakeHci> {
+  // A fake HCI that pretends to be a device that does dodgy things with
+  // configuration descriptors: namely, changing the size they claim to be
+  // depending on how many requests for config descriptors have been made
+  // previously.
+ public:
+  EvilFakeHci(uint16_t initial_config_length, uint16_t subsequent_config_length) {
+    initial_config_length_ = initial_config_length;
+    subsequent_config_length_ = subsequent_config_length;
+    proto_.ops = &usb_hci_protocol_ops_;
+    proto_.ctx = this;
+  }
+  uint64_t UsbHciGetCurrentFrame() { return kCurrentFrame; }
+
+  zx_status_t UsbHciConfigureHub(uint32_t device_id, usb_speed_t speed,
+                                 const usb_hub_descriptor_t* desc, bool multi_tt) {
+    return ZX_ERR_NOT_SUPPORTED;
+  }
+
+  zx_status_t UsbHciHubDeviceAdded(uint32_t device_id, uint32_t port, usb_speed_t speed) {
+    return ZX_ERR_NOT_SUPPORTED;
+  }
+
+  zx_status_t UsbHciHubDeviceRemoved(uint32_t device_id, uint32_t port) {
+    return ZX_ERR_NOT_SUPPORTED;
+  }
+
+  zx_status_t UsbHciHubDeviceReset(uint32_t device_id, uint32_t port) {
+    return ZX_ERR_NOT_SUPPORTED;
+  }
+
+  zx_status_t UsbHciResetEndpoint(uint32_t device_id, uint8_t ep_address) { return ZX_OK; }
+
+  zx_status_t UsbHciResetDevice(uint32_t hub_address, uint32_t device_id) { return ZX_OK; }
+
+  size_t UsbHciGetMaxTransferSize(uint32_t device_id, uint8_t ep_address) {
+    return ((device_id == kDeviceId) && (ep_address == kTransferSizeEndpoint)) ? kMaxTransferSize
+                                                                               : 0;
+  }
+
+  zx_status_t UsbHciCancelAll(uint32_t device_id, uint8_t ep_address) {
+    auto requests = pending_requests();
+    requests.CompleteAll(ZX_ERR_CANCELED, 0);
+    return ZX_OK;
+  }
+
+  void UsbHciSetBusInterface(const usb_bus_interface_protocol_t* bus_intf) {}
+
+  size_t UsbHciGetMaxDeviceCount() { return 0; }
+
+  size_t UsbHciGetRequestSize() {
+    return usb::BorrowedRequest<void>::RequestSize(sizeof(usb_request_t));
+  }
+
+  void UsbHciRequestQueue(usb_request_t* usb_request_, const usb_request_complete_t* complete_cb_) {
+    usb::BorrowedRequest<void> request(usb_request_, *complete_cb_, sizeof(usb_request_t));
+    EXPECT_EQ(request.request()->header.ep_address, 0);
+    EXPECT_EQ(request.request()->setup.bmRequestType,
+              USB_DIR_IN | USB_TYPE_STANDARD | USB_RECIP_DEVICE);
+    EXPECT_EQ(request.request()->setup.bRequest, USB_REQ_GET_DESCRIPTOR);
+
+    if (request.request()->header.ep_address == 0) {
+      if ((request.request()->setup.bmRequestType ==
+           (USB_DIR_IN | USB_TYPE_STANDARD | USB_RECIP_DEVICE)) &&
+          (request.request()->setup.bRequest == USB_REQ_GET_DESCRIPTOR)) {
+        uint8_t type = static_cast<uint8_t>(request.request()->setup.wValue >> 8);
+        uint8_t index = static_cast<uint8_t>(request.request()->setup.wValue);
+        switch (type) {
+          case USB_DT_DEVICE: {
+            usb_device_descriptor_t* descriptor;
+            request.Mmap(reinterpret_cast<void**>(&descriptor));
+            descriptor->bNumConfigurations = 2;
+            descriptor->idVendor = kVendorId;
+            descriptor->idProduct = kProductId;
+            descriptor->bDeviceClass = kDeviceClass;
+            descriptor->bDeviceSubClass = kDeviceSubclass;
+            descriptor->bDeviceProtocol = kDeviceProtocol;
+            request.Complete(ZX_OK, sizeof(*descriptor));
+          }
+            return;
+          case USB_DT_CONFIG: {
+            usb_configuration_descriptor_t* descriptor;
+            request.Mmap(reinterpret_cast<void**>(&descriptor));
+            // Use the config descriptor lengths described in the constructor
+            // arguments.
+            descriptor->wTotalLength =
+                (config_descriptor_request_count_ % 2 == 0 ? initial_config_length_
+                                                           : subsequent_config_length_);
+            config_descriptor_request_count_++;
+            descriptor->bConfigurationValue = static_cast<uint8_t>(index + 1);
+            request.Complete(ZX_OK, sizeof(*descriptor));
+          }
+            return;
+        }
+      }
+
+      // The host should not send us any requests (like attempting to set a configuration)
+      // after we do questionable things with wTotalLength.
+      request.Complete(ZX_ERR_INVALID_ARGS, 0);
+      return;
+    }
+    pending_requests_.push(std::move(request));
+  }
+
+  zx_status_t UsbHciEnableEndpoint(uint32_t device_id, const usb_endpoint_descriptor_t* ep_desc,
+                                   const usb_ss_ep_comp_descriptor_t* ss_com_desc, bool enable) {
+    return ZX_ERR_BAD_STATE;
+  }
+
+  const usb_hci_protocol_t* proto() { return &proto_; }
+
+  usb::BorrowedRequestQueue<void> pending_requests() { return std::move(pending_requests_); }
+
+ private:
+  int config_descriptor_request_count_ = 0;
+  uint16_t initial_config_length_;
+  uint16_t subsequent_config_length_;
+  usb_hci_protocol_t proto_;
+  fit::function<zx_status_t(uint32_t device_id, const usb_endpoint_descriptor_t* ep_desc,
+                            const usb_ss_ep_comp_descriptor_t* ss_com_desc, bool enable)>
+      enable_endpoint_hook_;
+  usb::BorrowedRequestQueue<void> pending_requests_;
+};
+
+TEST(DeviceTest, GetConfigurationDescriptorTooShortRejected) {
+  // We expect this device to fail to initialize because wTotalLength is too
+  // short -- 1 byte is shorter than the minimal config descriptor length, so
+  // such a response is invalid.
+  EvilFakeHci hci(1, 1);
+  fbl::RefPtr<FakeTimer> timer = fbl::MakeRefCounted<FakeTimer>();
+  timer->set_timeout_handler([=](sync_completion_t* completion, zx_duration_t duration) {
+    return sync_completion_wait(completion, duration);
+  });
+
+  auto device =
+      fbl::MakeRefCounted<UsbDevice>(fake_ddk::kFakeParent, ddk::UsbHciProtocolClient(hci.proto()),
+                                     kDeviceId, kHubId, kDeviceSpeed, timer);
+  auto result = device->Init();
+  ASSERT_EQ(result, ZX_ERR_IO);
+}
+
+TEST(DeviceTest, GetConfigurationDescriptorDifferentSizesAreRejected) {
+  // We expect this device to fail to initialize because when we request its
+  // configuration descriptors, the wTotalSize value we get back changes between
+  // the first (size-fetching) request and second (full descriptor-fetching) request.
+  EvilFakeHci hci(sizeof(usb_configuration_descriptor_t), 65535);
+  fbl::RefPtr<FakeTimer> timer = fbl::MakeRefCounted<FakeTimer>();
+  timer->set_timeout_handler([=](sync_completion_t* completion, zx_duration_t duration) {
+    return sync_completion_wait(completion, duration);
+  });
+
+  auto device =
+      fbl::MakeRefCounted<UsbDevice>(fake_ddk::kFakeParent, ddk::UsbHciProtocolClient(hci.proto()),
+                                     kDeviceId, kHubId, kDeviceSpeed, timer);
+  auto result = device->Init();
+  ASSERT_EQ(result, ZX_ERR_IO);
+}
+
 }  // namespace usb_bus
diff --git a/src/devices/usb/drivers/usb-bus/usb-device.cc b/src/devices/usb/drivers/usb-bus/usb-device.cc
index d7a4377..8830d7a 100644
--- a/src/devices/usb/drivers/usb-bus/usb-device.cc
+++ b/src/devices/usb/drivers/usb-bus/usb-device.cc
@@ -832,6 +832,13 @@
       return status;
     }
     uint16_t config_desc_size = letoh16(config_desc_header.wTotalLength);
+    if (config_desc_size < sizeof(config_desc_header)) {
+      zxlogf(ERROR,
+             "%s: GetDescriptor(USB_DT_CONFIG) gave length shorter than self: "
+             "expected at least %lu, got %u\n",
+             __func__, sizeof(config_desc_header), config_desc_size);
+      return ZX_ERR_IO;
+    }
     auto* config_desc = new (&ac) uint8_t[config_desc_size];
     if (!ac.check()) {
       return ZX_ERR_NO_MEMORY;
@@ -840,13 +847,34 @@
 
     // read full configuration descriptor
     status = GetDescriptor(USB_DT_CONFIG, config, 0, config_desc, config_desc_size, &actual);
-    if (status == ZX_OK && actual != config_desc_size) {
-      status = ZX_ERR_IO;
-    }
     if (status != ZX_OK) {
       zxlogf(ERROR, "%s: GetDescriptor(USB_DT_CONFIG) failed", __func__);
       return status;
     }
+
+    // Guard against the device being evil in a couple ways here.
+
+    // If the actual number of bytes we read for this descriptor doesn't
+    // match the number of bytes the descriptor said we should
+    // expect when we first asked for the descriptor header, return an error.
+    if (actual != config_desc_size) {
+      zxlogf(ERROR, "%s GetDescriptor(USB_DT_CONFIG) config %u expected %u bytes, got %lu\n",
+             __func__, config, config_desc_size, actual);
+      return ZX_ERR_IO;
+    }
+
+    // Similarly, if the second time we read the descriptor, the field
+    // inside the descriptor we just read says that it's a different size from
+    // what we expected or what we read the first time, return an error.
+    uint16_t config_desc_size_on_second_read =
+        letoh16(reinterpret_cast<usb_configuration_descriptor_t*>(config_desc)->wTotalLength);
+    if (actual != config_desc_size_on_second_read) {
+      zxlogf(ERROR,
+             "%s GetDescriptor(USB_DT_CONFIG) config %u length changed between reads: "
+             "was %u bytes, then became %u\n",
+             __func__, config, config_desc_size, config_desc_size_on_second_read);
+      return ZX_ERR_IO;
+    }
   }
   // we will create devices for interfaces on the first configuration by default
   uint8_t configuration = 1;