
// Copyright 2017 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <assert.h>
#include <fuchsia/hardware/usb/peripheral/block/c/fidl.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <zircon/device/usb-peripheral.h>
#include <zircon/hw/usb/ums.h>
#include <zircon/process.h>
#include <zircon/syscalls.h>

#include <ddk/binding.h>
#include <ddk/debug.h>
#include <ddk/device.h>
#include <ddk/driver.h>
#include <ddk/protocol/usb/function.h>
#include <usb/usb-request.h>

// Geometry of the emulated disk: STORAGE_SIZE bytes of VMO-backed storage exposed to the
// host as BLOCK_COUNT logical blocks of BLOCK_SIZE bytes each.
#define BLOCK_SIZE 512L
#define STORAGE_SIZE (4L * 1024L * 1024L * 1024L)
#define BLOCK_COUNT (STORAGE_SIZE / BLOCK_SIZE)
// Size of data_req's buffer, and therefore the largest chunk moved per data-phase request.
#define DATA_REQ_SIZE 16384
// wMaxPacketSize advertised for both bulk endpoints; also the cbw_req/csw_req buffer size.
#define BULK_MAX_PACKET 512

// What the in-flight data_req is doing; consulted by ums_continue_transfer and
// ums_data_complete to decide how to continue the current command.
typedef enum {
  DATA_STATE_NONE,   // idle, or a single-request response (INQUIRY, REQUEST SENSE, ...)
  DATA_STATE_READ,   // streaming storage contents to the host (bulk IN)
  DATA_STATE_WRITE,  // receiving storage contents from the host (bulk OUT)
  DATA_STATE_FAILED  // zero-length data phase for an unsupported command; CSW reports failure
} ums_data_state_t;

// Interface descriptor followed by its two bulk endpoint descriptors, returned verbatim by
// ums_get_descriptors. Advertises a SCSI transparent command set over Bulk-Only Transport.
// The interface number and endpoint addresses are filled in by usb_ums_bind once the parent
// has allocated them.
static struct {
  usb_interface_descriptor_t intf;
  usb_endpoint_descriptor_t out_ep;
  usb_endpoint_descriptor_t in_ep;
} descriptors = {
    .intf =
        {
            .bLength = sizeof(usb_interface_descriptor_t),
            .bDescriptorType = USB_DT_INTERFACE,
            //      .bInterfaceNumber is assigned in usb_ums_bind
            .bAlternateSetting = 0,
            .bNumEndpoints = 2,
            .bInterfaceClass = USB_CLASS_MSC,
            .bInterfaceSubClass = USB_SUBCLASS_MSC_SCSI,
            .bInterfaceProtocol = USB_PROTOCOL_MSC_BULK_ONLY,
            .iInterface = 0,
        },
    .out_ep =
        {
            .bLength = sizeof(usb_endpoint_descriptor_t),
            .bDescriptorType = USB_DT_ENDPOINT,
            //      .bEndpointAddress is assigned in usb_ums_bind
            .bmAttributes = USB_ENDPOINT_BULK,
            .wMaxPacketSize = htole16(BULK_MAX_PACKET),
            .bInterval = 0,
        },
    .in_ep =
        {
            .bLength = sizeof(usb_endpoint_descriptor_t),
            .bDescriptorType = USB_DT_ENDPOINT,
            //      .bEndpointAddress is assigned in usb_ums_bind
            .bmAttributes = USB_ENDPOINT_BULK,
            .wMaxPacketSize = htole16(BULK_MAX_PACKET),
            .bInterval = 0,
        },
};

// Per-instance driver state. Allocated with calloc in usb_ums_bind and freed in
// usb_ums_release.
typedef struct {
  zx_device_t* zxdev;                // device published by usb_ums_bind
  usb_function_protocol_t function;  // parent USB function protocol
  // One reusable request per Bulk-Only Transport phase (command, data, status). Each
  // *_complete flag is set by ums_completion_callback while holding mtx and is consumed
  // by usb_ums_thread.
  usb_request_t* cbw_req;
  bool cbw_req_complete;
  usb_request_t* data_req;
  bool data_req_complete;
  usb_request_t* csw_req;
  bool csw_req_complete;

  // vmo for backing storage, and its mapping in this process
  zx_handle_t storage_handle;
  void* storage;

  // command we are currently handling
  ums_cbw_t current_cbw;
  // data bytes queued for the current command; used to compute the CSW residue
  uint32_t data_length;

  // state for data transfers
  ums_data_state_t data_state;
  // state for reads and writes: byte offset into storage and bytes still to move
  zx_off_t data_offset;
  size_t data_remaining;

  uint8_t bulk_out_addr;
  uint8_t bulk_in_addr;
  size_t parent_req_size;  // request size reported by the parent, passed to usb_request_alloc
  bool writeback_cache;         // set via FIDL; not otherwise read in this file
  bool writeback_cache_report;  // set via FIDL; makes MODE SENSE(6) report the cache page
  thrd_t thread;                // worker running usb_ums_thread
  bool active;                  // cleared by usb_ums_unbind to stop the worker
  cnd_t event;                  // signalled on request completion and on unbind
  mtx_t mtx;
  // Requests queued via usb_request_queue whose completion the worker has not yet consumed.
  atomic_int pending_request_count;
} usb_ums_t;

// Per-phase completion handlers, invoked from usb_ums_thread after ums_completion_callback
// has flagged the corresponding request as complete.
static void ums_cbw_complete(void* ctx, usb_request_t* req);
static void ums_data_complete(void* ctx, usb_request_t* req);
static void ums_csw_complete(void* ctx, usb_request_t* req);

// Hands |req| to the parent function driver, first recording it as outstanding so the
// worker thread knows one more completion is still on its way. The worker decrements
// the count when it consumes that completion.
static void usb_request_queue(void* ctx, usb_function_protocol_t* function, usb_request_t* req,
                              const usb_request_complete_t* completion) {
  atomic_int* outstanding = &((usb_ums_t*)ctx)->pending_request_count;
  atomic_fetch_add(outstanding, 1);
  usb_function_request_queue(function, req, completion);
}

// Completion callback shared by all three requests. It only records which request
// finished and wakes the worker thread; the real processing happens on that thread.
// Any request that is neither cbw_req nor data_req is treated as csw_req.
static void ums_completion_callback(void* ctx, usb_request_t* req) {
  usb_ums_t* ums = (usb_ums_t*)ctx;

  mtx_lock(&ums->mtx);
  bool* finished;
  if (req == ums->cbw_req) {
    finished = &ums->cbw_req_complete;
  } else if (req == ums->data_req) {
    finished = &ums->data_req_complete;
  } else {
    finished = &ums->csw_req_complete;
  }
  *finished = true;
  cnd_signal(&ums->event);
  mtx_unlock(&ums->mtx);
}

// Queues |req| as the data phase of the current command. The endpoint follows the
// direction bit of the current CBW, and the request length is added to data_length so
// ums_queue_csw can compute the residue.
static void ums_function_queue_data(usb_ums_t* ums, usb_request_t* req) {
  ums->data_length += req->header.length;

  const bool to_host = (ums->current_cbw.bmCBWFlags & USB_DIR_IN) != 0;
  req->header.ep_address = to_host ? ums->bulk_in_addr : ums->bulk_out_addr;

  const usb_request_complete_t on_done = {
      .callback = ums_completion_callback,
      .ctx = ums,
  };
  usb_request_queue(ums, &ums->function, req, &on_done);
}

// Finishes the current command: re-arms the OUT endpoint with cbw_req for the next
// command, then queues a Command Status Wrapper carrying |status| for the current one.
//
// The residue reported to the host is the requested transfer length minus the bytes
// queued in the data phase (data_length).
static void ums_queue_csw(usb_ums_t* ums, uint8_t status) {
  // first queue next cbw so it is ready to go
  usb_request_complete_t cbw_complete = {
      .callback = ums_completion_callback,
      .ctx = ums,
  };
  usb_request_queue(ums, &ums->function, ums->cbw_req, &cbw_complete);

  usb_request_t* req = ums->csw_req;
  ums_csw_t* csw;
  usb_request_mmap(req, (void**)&csw);

  // Some replies (e.g. INQUIRY, the MODE SENSE cache page) are sent at a fixed length
  // regardless of dCBWDataTransferLength. Clamp at zero so the unsigned subtraction
  // cannot wrap into a huge bogus residue.
  uint32_t expected = le32toh(ums->current_cbw.dCBWDataTransferLength);
  uint32_t residue = expected > ums->data_length ? expected - ums->data_length : 0;

  csw->dCSWSignature = htole32(CSW_SIGNATURE);
  csw->dCSWTag = ums->current_cbw.dCBWTag;
  csw->dCSWDataResidue = htole32(residue);
  csw->bmCSWStatus = status;

  req->header.length = sizeof(ums_csw_t);
  usb_request_complete_t csw_complete = {
      .callback = ums_completion_callback,
      .ctx = ums,
  };
  usb_request_queue(ums, &ums->function, req, &csw_complete);
}

// Queues the next chunk (at most DATA_REQ_SIZE bytes) of the READ or WRITE in progress.
// For reads, the chunk is copied out of storage into data_req before queueing. For
// writes, the received bytes are copied into storage by ums_data_complete.
static void ums_continue_transfer(usb_ums_t* ums) {
  usb_request_t* req = ums->data_req;

  size_t chunk = ums->data_remaining < DATA_REQ_SIZE ? ums->data_remaining : DATA_REQ_SIZE;
  req->header.length = chunk;

  switch (ums->data_state) {
    case DATA_STATE_READ: {
      size_t copied =
          usb_request_copy_to(req, (uint8_t*)ums->storage + ums->data_offset, chunk, 0);
      ZX_ASSERT(copied == chunk);
      ums_function_queue_data(ums, req);
      break;
    }
    case DATA_STATE_WRITE:
      ums_function_queue_data(ums, req);
      break;
    default:
      zxlogf(ERROR, "ums_continue_transfer: bad data state %d", ums->data_state);
      break;
  }
}

// Begins a READ or WRITE of |blocks| blocks starting at |lba| and queues the first chunk.
//
// |lba| and |blocks| come straight from the host's CDB. The range is validated in block
// units before any multiplication: a huge lba could otherwise wrap lba * BLOCK_SIZE, or
// offset + length, around 2^64 and pass the bounds check while pointing outside the
// storage mapping. Accepted ranges are exactly those with lba + blocks <= BLOCK_COUNT.
static void ums_start_transfer(usb_ums_t* ums, ums_data_state_t state, uint64_t lba,
                               uint32_t blocks) {
  if (lba > (uint64_t)BLOCK_COUNT || blocks > (uint64_t)BLOCK_COUNT - lba) {
    zxlogf(ERROR, "ums_start_transfer: transfer out of range state: %d, lba: %zu blocks: %u", state,
           lba, blocks);
    // TODO(voydanoff) report error to host
    return;
  }

  ums->data_state = state;
  ums->data_offset = (zx_off_t)(lba * BLOCK_SIZE);   // cannot overflow after the check above
  ums->data_remaining = (size_t)blocks * BLOCK_SIZE;

  ums_continue_transfer(ums);
}

// INQUIRY: replies with UMS_INQUIRY_TRANSFER_LENGTH bytes of standard INQUIRY data
// describing a removable direct-access block device.
static void ums_handle_inquiry(usb_ums_t* ums, ums_cbw_t* cbw) {
  zxlogf(DEBUG, "ums_handle_inquiry");

  usb_request_t* req = ums->data_req;
  uint8_t* buffer;
  usb_request_mmap(req, (void**)&buffer);
  memset(buffer, 0, UMS_INQUIRY_TRANSFER_LENGTH);
  req->header.length = UMS_INQUIRY_TRANSFER_LENGTH;

  // fill in inquiry result
  buffer[0] = 0;     // Peripheral Device Type: Direct access block device
  buffer[1] = 0x80;  // Removable
  buffer[2] = 6;     // Version SPC-4
  buffer[3] = 0x12;  // HISUP (0x10) | Response Data Format 2
  // Additional Length: number of bytes following byte 4. A value of 0 would tell the host
  // that the vendor/product/revision fields below are not present.
  buffer[4] = (uint8_t)(UMS_INQUIRY_TRANSFER_LENGTH - 5);
  memcpy(buffer + 8, "Google  ", 8);              // Vendor Identification
  memcpy(buffer + 16, "Zircon UMS      ", 16);    // Product Identification
  memcpy(buffer + 32, "1.00", 4);                 // Product Revision Level

  ums_function_queue_data(ums, req);
}

// TEST UNIT READY: the command has no data phase, so the status is sent immediately
// and always reports success.
static void ums_handle_test_unit_ready(usb_ums_t* ums, ums_cbw_t* cbw) {
  (void)cbw;
  zxlogf(DEBUG, "ums_handle_test_unit_ready");
  ums_queue_csw(ums, CSW_SUCCESS);
}

// REQUEST SENSE: always returns the same fixed-format sense data
// (ILLEGAL REQUEST / INVALID COMMAND OPERATION CODE), whatever the last command was.
static void ums_handle_request_sense(usb_ums_t* ums, ums_cbw_t* cbw) {
  zxlogf(DEBUG, "ums_handle_request_sense");

  usb_request_t* req = ums->data_req;
  uint8_t* sense = NULL;
  usb_request_mmap(req, (void**)&sense);
  memset(sense, 0, UMS_REQUEST_SENSE_TRANSFER_LENGTH);

  // TODO(voydanoff) This is a hack. Figure out correct values to return here.
  sense[0] = 0x70;   // Response Code: current error, fixed format
  sense[2] = 5;      // Sense Key: Illegal Request
  sense[7] = 10;     // Additional Sense Length
  sense[12] = 0x20;  // Additional Sense Code: invalid command operation code
  req->header.length = UMS_REQUEST_SENSE_TRANSFER_LENGTH;

  ums_function_queue_data(ums, req);
}

// READ CAPACITY(10): reports the last LBA and the block size, both big-endian. If the
// last LBA does not fit in 32 bits, 0xFFFFFFFF is returned so the host falls back to
// READ CAPACITY(16).
static void ums_handle_read_capacity10(usb_ums_t* ums, ums_cbw_t* cbw) {
  zxlogf(DEBUG, "ums_handle_read_capacity10");

  usb_request_t* req = ums->data_req;
  scsi_read_capacity_10_t* reply;
  usb_request_mmap(req, (void**)&reply);

  const uint64_t last_lba = BLOCK_COUNT - 1;
  const uint32_t reported = last_lba > UINT32_MAX ? UINT32_MAX : (uint32_t)last_lba;
  reply->lba = htobe32(reported);
  reply->block_length = htobe32(BLOCK_SIZE);

  req->header.length = sizeof(*reply);
  ums_function_queue_data(ums, req);
}

// READ CAPACITY(16): zero-filled reply carrying the 64-bit last LBA and the block size,
// both big-endian.
static void ums_handle_read_capacity16(usb_ums_t* ums, ums_cbw_t* cbw) {
  zxlogf(DEBUG, "ums_handle_read_capacity16");

  usb_request_t* req = ums->data_req;
  scsi_read_capacity_16_t* reply = NULL;
  usb_request_mmap(req, (void**)&reply);
  memset(reply, 0, sizeof(*reply));

  reply->block_length = htobe32(BLOCK_SIZE);
  reply->lba = htobe64(BLOCK_COUNT - 1);  // last addressable block

  req->header.length = sizeof(*reply);
  ums_function_queue_data(ums, req);
}

// MODE SENSE(6): normally replies with an all-zero mode parameter header. When the host
// asks for all pages (0x3F) and writeback_cache_report is set, it instead sends a 20-byte
// reply with the Write Cache Enable bit set in byte 6.
static void ums_handle_mode_sense6(usb_ums_t* ums, ums_cbw_t* cbw) {
  // Length of the reply used when the cache page is reported.
  enum { kCachePageReplyLength = 20 };

  zxlogf(DEBUG, "ums_handle_mode_sense6");
  scsi_mode_sense_6_command_t command;
  memcpy(&command, cbw->CBWCB, sizeof(command));
  usb_request_t* req = ums->data_req;
  scsi_mode_sense_6_data_t* data;
  usb_request_mmap(req, (void**)&data);
  memset(data, 0, sizeof(*data));
  req->header.length = sizeof(*data);
  if (command.page == 0x3F && ums->writeback_cache_report) {
    // Special request (cache page), 20 byte response. data_req's buffer is reused for every
    // command, including bulk READ data, so the whole reply is cleared. Clearing only the
    // header would send stale bytes to the host.
    memset(data, 0, kCachePageReplyLength);
    ((unsigned char*)data)[6] = 1 << 2;  // Write Cache enable bit
    req->header.length = kCachePageReplyLength;
  }
  ums_function_queue_data(ums, req);
}

static void ums_handle_read10(usb_ums_t* ums, ums_cbw_t* cbw) {
  zxlogf(DEBUG, "ums_handle_read10");

  scsi_command10_t* command = (scsi_command10_t*)cbw->CBWCB;
  uint64_t lba = be32toh(command->lba);
  uint32_t blocks = ((uint32_t)command->length_hi << 8) | (uint32_t)command->length_lo;
  ums_start_transfer(ums, DATA_STATE_READ, lba, blocks);
}

static void ums_handle_read12(usb_ums_t* ums, ums_cbw_t* cbw) {
  zxlogf(DEBUG, "ums_handle_read12");

  scsi_command12_t* command = (scsi_command12_t*)cbw->CBWCB;
  uint64_t lba = be32toh(command->lba);
  uint32_t blocks = be32toh(command->length);
  ums_start_transfer(ums, DATA_STATE_READ, lba, blocks);
}

static void ums_handle_read16(usb_ums_t* ums, ums_cbw_t* cbw) {
  zxlogf(DEBUG, "ums_handle_read16");

  scsi_command16_t* command = (scsi_command16_t*)cbw->CBWCB;
  uint64_t lba = be64toh(command->lba);
  uint32_t blocks = be32toh(command->length);
  ums_start_transfer(ums, DATA_STATE_READ, lba, blocks);
}

static void ums_handle_write10(usb_ums_t* ums, ums_cbw_t* cbw) {
  zxlogf(DEBUG, "ums_handle_write10");

  scsi_command10_t* command = (scsi_command10_t*)cbw->CBWCB;
  uint64_t lba = be32toh(command->lba);
  uint32_t blocks = ((uint32_t)command->length_hi << 8) | (uint32_t)command->length_lo;
  ums_start_transfer(ums, DATA_STATE_WRITE, lba, blocks);
}

static void ums_handle_write12(usb_ums_t* ums, ums_cbw_t* cbw) {
  zxlogf(DEBUG, "ums_handle_write12");

  scsi_command12_t* command = (scsi_command12_t*)cbw->CBWCB;
  uint64_t lba = be32toh(command->lba);
  uint32_t blocks = be32toh(command->length);
  ums_start_transfer(ums, DATA_STATE_WRITE, lba, blocks);
}

static void ums_handle_write16(usb_ums_t* ums, ums_cbw_t* cbw) {
  zxlogf(DEBUG, "ums_handle_write16");

  scsi_command16_t* command = (scsi_command16_t*)cbw->CBWCB;
  uint64_t lba = be64toh(command->lba);
  uint32_t blocks = be32toh(command->length);
  ums_start_transfer(ums, DATA_STATE_WRITE, lba, blocks);
}

// Validates a received Command Block Wrapper and dispatches its SCSI command.
//
// Each handler either queues a data phase (whose completion eventually queues the CSW)
// or queues the CSW directly. A CBW with a bad signature is dropped without a reply.
static void ums_handle_cbw(usb_ums_t* ums, ums_cbw_t* cbw) {
  if (le32toh(cbw->dCBWSignature) != CBW_SIGNATURE) {
    zxlogf(ERROR, "ums_handle_cbw: bad dCBWSignature 0x%x", le32toh(cbw->dCBWSignature));
    return;
  }

  // reset data length for computing residue
  ums->data_length = 0;

  // all SCSI commands have opcode in the same place, so using scsi_command6_t works here.
  scsi_command6_t* command = (scsi_command6_t*)cbw->CBWCB;
  switch (command->opcode) {
    case UMS_INQUIRY:
      ums_handle_inquiry(ums, cbw);
      break;
    case UMS_TEST_UNIT_READY:
      ums_handle_test_unit_ready(ums, cbw);
      break;
    case UMS_REQUEST_SENSE:
      ums_handle_request_sense(ums, cbw);
      break;
    case UMS_READ_CAPACITY10:
      ums_handle_read_capacity10(ums, cbw);
      break;
    case UMS_READ_CAPACITY16:
      ums_handle_read_capacity16(ums, cbw);
      break;
    case UMS_MODE_SENSE6:
      ums_handle_mode_sense6(ums, cbw);
      break;
    case UMS_READ10:
      ums_handle_read10(ums, cbw);
      break;
    case UMS_READ12:
      ums_handle_read12(ums, cbw);
      break;
    case UMS_READ16:
      ums_handle_read16(ums, cbw);
      break;
    case UMS_WRITE10:
      ums_handle_write10(ums, cbw);
      break;
    case UMS_WRITE12:
      ums_handle_write12(ums, cbw);
      break;
    case UMS_WRITE16:
      ums_handle_write16(ums, cbw);
      break;
    case UMS_SYNCHRONIZE_CACHE:
      // TODO: This is presently untestable.
      // Implement this once we have a means of testing this.
      // NOTE(review): no CSW is queued and cbw_req is not re-armed for this command;
      // confirm the host side tolerates this.
      break;
    default:
      zxlogf(DEBUG, "ums_handle_cbw: unsupported opcode %d", command->opcode);
      if (le32toh(cbw->dCBWDataTransferLength) != 0) {
        // Queue a zero length packet to satisfy the data phase. Its completion sees
        // DATA_STATE_FAILED and queues the failing CSW itself (ums_data_complete). Queueing
        // a CSW here as well would send two CSWs and put cbw_req/csw_req in flight twice.
        usb_request_t* req = ums->data_req;
        req->header.length = 0;
        ums->data_state = DATA_STATE_FAILED;
        ums_function_queue_data(ums, req);
      } else {
        // No data phase expected: report the failure right away.
        ums_queue_csw(ums, CSW_FAILED);
      }
      break;
  }
}

// Handles completion of cbw_req. Only a successful transfer of exactly one CBW is acted
// upon; anything else is ignored and no new request is queued from here.
static void ums_cbw_complete(void* ctx, usb_request_t* req) {
  usb_ums_t* ums = ctx;

  zxlogf(DEBUG, "ums_cbw_complete %d %ld", req->response.status, req->response.actual);

  const bool got_cbw =
      req->response.status == ZX_OK && req->response.actual == sizeof(ums_cbw_t);
  if (!got_cbw) {
    return;
  }

  ums_cbw_t* cbw = &ums->current_cbw;
  memset(cbw, 0, sizeof(*cbw));
  (void)usb_request_copy_from(req, cbw, sizeof(*cbw), 0);
  ums_handle_cbw(ums, cbw);
}

// Handles completion of data_req according to data_state:
//   WRITE  - copies the received bytes into storage, then continues or finishes.
//   READ   - advances past the bytes just sent, then continues or finishes.
//   FAILED - the zero-length data phase of an unsupported command; sends a failed CSW.
//   NONE   - a single-request reply (INQUIRY etc.); sends a successful CSW.
static void ums_data_complete(void* ctx, usb_request_t* req) {
  usb_ums_t* ums = ctx;

  zxlogf(DEBUG, "ums_data_complete %d %ld", req->response.status, req->response.actual);

  // A failed or cancelled data phase moved nothing. Without this check a WRITE would
  // requeue the same chunk indefinitely (actual == 0 leaves data_remaining unchanged),
  // which also keeps pending_request_count from draining during unbind, and a READ would
  // report success. Abandon the command without queueing anything further; cbw_req is
  // re-armed when the interface is configured again (ums_set_configured).
  if (req->response.status != ZX_OK) {
    zxlogf(ERROR, "ums_data_complete: data phase failed %d", req->response.status);
    ums->data_state = DATA_STATE_NONE;
    ums->data_remaining = 0;
    return;
  }

  switch (ums->data_state) {
    case DATA_STATE_WRITE: {
      size_t result = usb_request_copy_from(req, (uint8_t*)ums->storage + ums->data_offset,
                                            req->response.actual, 0);
      ZX_ASSERT(result == req->response.actual);
      break;
    }
    case DATA_STATE_READ:
      // The chunk was already copied into data_req by ums_continue_transfer. Falling
      // through to the accounting below keeps multi-chunk reads (> DATA_REQ_SIZE) going.
      break;
    case DATA_STATE_FAILED:
      ums->data_state = DATA_STATE_NONE;
      ums_queue_csw(ums, CSW_FAILED);
      return;
    default:
      // Single-request reply: the command is complete.
      ums->data_state = DATA_STATE_NONE;
      ums_queue_csw(ums, CSW_SUCCESS);
      return;
  }

  ums->data_offset += req->response.actual;
  if (ums->data_remaining > req->response.actual) {
    ums->data_remaining -= req->response.actual;
  } else {
    ums->data_remaining = 0;
  }

  if (ums->data_remaining > 0) {
    ums_continue_transfer(ums);
  } else {
    ums->data_state = DATA_STATE_NONE;
    ums_queue_csw(ums, CSW_SUCCESS);
  }
}

// Handles completion of csw_req. Nothing beyond logging is required, because the next
// CBW request was queued together with this CSW in ums_queue_csw.
static void ums_csw_complete(void* ctx, usb_request_t* req) {
  (void)ctx;
  zxlogf(DEBUG, "ums_csw_complete %d %ld", req->response.status, req->response.actual);
}

// Size in bytes of the interface + endpoint descriptor block this function exposes.
static size_t ums_get_descriptors_size(void* ctx) {
  (void)ctx;
  return sizeof(descriptors);
}

// Copies as much of the descriptor block as fits into |buffer| and reports the number
// of bytes written through |out_actual|.
static void ums_get_descriptors(void* ctx, void* buffer, size_t buffer_size, size_t* out_actual) {
  const size_t copy_len = buffer_size < sizeof(descriptors) ? buffer_size : sizeof(descriptors);
  memcpy(buffer, &descriptors, copy_len);
  *out_actual = copy_len;
}

static zx_status_t ums_control(void* ctx, const usb_setup_t* setup, const void* write_buffer,
                               size_t write_size, void* out_read_buffer, size_t read_size,
                               size_t* out_read_actual) {
  if (setup->bmRequestType == (USB_DIR_IN | USB_TYPE_CLASS | USB_RECIP_INTERFACE) &&
      setup->bRequest == USB_REQ_GET_MAX_LUN && setup->wValue == 0 && setup->wIndex == 0 &&
      setup->wLength >= sizeof(uint8_t)) {
    *((uint8_t*)out_read_buffer) = 0;
    *out_read_actual = sizeof(uint8_t);
    return ZX_OK;
  }

  return ZX_ERR_NOT_SUPPORTED;
}

// Called when the host configures or deconfigures the device. Configuring enables both
// bulk endpoints and queues the first CBW read. Deconfiguring disables the endpoints.
// Returns the status of the first endpoint operation that failed, or ZX_OK.
static zx_status_t ums_set_configured(void* ctx, bool configured, usb_speed_t speed) {
  zxlogf(DEBUG, "ums_set_configured %d %d", configured, speed);
  usb_ums_t* ums = ctx;
  zx_status_t status;

  // TODO(voydanoff) fullspeed and superspeed support
  if (!configured) {
    status = usb_function_disable_ep(&ums->function, ums->bulk_out_addr);
    if (status == ZX_OK) {
      status = usb_function_disable_ep(&ums->function, ums->bulk_in_addr);
    }
    if (status != ZX_OK) {
      zxlogf(ERROR, "ums_set_configured: usb_function_disable_ep failed");
    }
    return status;
  }

  status = usb_function_config_ep(&ums->function, &descriptors.out_ep, NULL);
  if (status == ZX_OK) {
    status = usb_function_config_ep(&ums->function, &descriptors.in_ep, NULL);
  }
  if (status != ZX_OK) {
    zxlogf(ERROR, "ums_set_configured: usb_function_config_ep failed");
    return status;
  }

  // Arm the OUT endpoint so the host's first command can be received.
  usb_request_complete_t first_cbw = {
      .callback = ums_completion_callback,
      .ctx = ums,
  };
  usb_request_queue(ums, &ums->function, ums->cbw_req, &first_cbw);
  return ZX_OK;
}

static zx_status_t ums_set_interface(void* ctx, uint8_t interface, uint8_t alt_setting) {
  return ZX_ERR_NOT_SUPPORTED;
}

usb_function_interface_protocol_ops_t ums_device_ops = {
    .get_descriptors_size = ums_get_descriptors_size,
    .get_descriptors = ums_get_descriptors,
    .control = ums_control,
    .set_configured = ums_set_configured,
    .set_interface = ums_set_interface,
};

// DDK unbind hook: cancels outstanding USB requests, tells the worker thread to stop,
// waits for it to exit, then acknowledges the unbind.
static void usb_ums_unbind(void* ctx) {
  zxlogf(DEBUG, "usb_ums_unbind");
  usb_ums_t* ums = ctx;

  usb_function_cancel_all(&ums->function, ums->bulk_out_addr);
  usb_function_cancel_all(&ums->function, ums->bulk_in_addr);
  // NOTE(review): the two calls above pass endpoint addresses, but this one passes an
  // interface number; confirm this is intended.
  usb_function_cancel_all(&ums->function, descriptors.intf.bInterfaceNumber);

  // Clear |active| under the lock and wake the worker. usb_ums_thread returns once it
  // observes !active with no pending requests left.
  mtx_lock(&ums->mtx);
  ums->active = false;
  cnd_signal(&ums->event);
  mtx_unlock(&ums->mtx);
  int retval;
  thrd_join(ums->thread, &retval);
  device_unbind_reply(ums->zxdev);
}

// DDK release hook, also used to unwind a partially completed usb_ums_bind. Each
// resource is checked individually because bind may have failed before allocating it.
static void usb_ums_release(void* ctx) {
  zxlogf(DEBUG, "usb_ums_release");
  usb_ums_t* ums = ctx;

  if (ums->storage) {
    // Only this instance's mapping is removed. The file-static backing VMO is not closed
    // here; usb_ums_bind reuses it on the next bind.
    zx_vmar_unmap(zx_vmar_root_self(), (uintptr_t)ums->storage, STORAGE_SIZE);
  }
  if (ums->cbw_req) {
    usb_request_release(ums->cbw_req);
  }
  if (ums->data_req) {
    usb_request_release(ums->data_req);
  }
  // Must test csw_req itself: if its allocation failed in bind, cbw_req is set but
  // csw_req is NULL.
  if (ums->csw_req) {
    usb_request_release(ums->csw_req);
  }
  cnd_destroy(&ums->event);
  mtx_destroy(&ums->mtx);
  free(ums);
}

// FIDL handlers for fuchsia.hardware.usb.peripheral.block.Device. Each handler only
// updates one flag on the instance and replies ZX_OK.
static zx_status_t usb_ums_enable_writeback_cache(void* ctx, fidl_txn_t* txn) {
  ((usb_ums_t*)ctx)->writeback_cache = true;
  return fuchsia_hardware_usb_peripheral_block_DeviceEnableWritebackCache_reply(txn, ZX_OK);
}

static zx_status_t usb_ums_disable_writeback_cache(void* ctx, fidl_txn_t* txn) {
  ((usb_ums_t*)ctx)->writeback_cache = false;
  return fuchsia_hardware_usb_peripheral_block_DeviceDisableWritebackCache_reply(txn, ZX_OK);
}

// Controls whether MODE SENSE(6) reports the write-cache page (see ums_handle_mode_sense6).
static zx_status_t usb_ums_set_writeback_cache_reported(void* ctx, bool report, fidl_txn_t* txn) {
  ((usb_ums_t*)ctx)->writeback_cache_report = report;
  return fuchsia_hardware_usb_peripheral_block_DeviceSetWritebackCacheReported_reply(txn, ZX_OK);
}

fuchsia_hardware_usb_peripheral_block_Device_ops_t usb_cache_proto = {
    .SetWritebackCacheReported = usb_ums_set_writeback_cache_reported,
    .DisableWritebackCache = usb_ums_disable_writeback_cache,
    .EnableWritebackCache = usb_ums_enable_writeback_cache,
};

// Routes FIDL messages sent to the published device to the handlers in usb_cache_proto.
static zx_status_t usb_ums_message(void* ctx, fidl_incoming_msg_t* msg, fidl_txn_t* txn) {
  zx_status_t status =
      fuchsia_hardware_usb_peripheral_block_Device_dispatch(ctx, txn, msg, &usb_cache_proto);
  return status;
}

static zx_protocol_device_t usb_ums_proto = {
    .version = DEVICE_OPS_VERSION,
    .unbind = usb_ums_unbind,
    .release = usb_ums_release,
    .message = usb_ums_message,
};
static zx_handle_t vmo = 0;

// Worker thread: runs the per-phase completion handlers for requests flagged by
// ums_completion_callback. It exits once unbind has cleared |active| and every queued
// request has been accounted for.
//
// The completion flags and |active| are shared with the completion callback and unbind,
// so they are read and cleared only while holding mtx. The handlers themselves run
// without the lock, because they queue new requests whose callbacks take it.
static int usb_ums_thread(void* ctx) {
  usb_ums_t* ums = (usb_ums_t*)ctx;
  for (;;) {
    mtx_lock(&ums->mtx);
    // Re-check the predicate after every wakeup (cnd_wait may wake spuriously). After
    // unbind, keep sleeping while requests are outstanding rather than busy-polling;
    // each of their completions signals |event|.
    while (!ums->cbw_req_complete && !ums->csw_req_complete && !ums->data_req_complete &&
           (ums->active || atomic_load(&ums->pending_request_count) != 0)) {
      cnd_wait(&ums->event, &ums->mtx);
    }
    const bool cbw_done = ums->cbw_req_complete;
    const bool csw_done = ums->csw_req_complete;
    const bool data_done = ums->data_req_complete;
    const bool active = ums->active;
    ums->cbw_req_complete = false;
    ums->csw_req_complete = false;
    ums->data_req_complete = false;
    mtx_unlock(&ums->mtx);

    // A flagged completion still counts as pending, so this cannot drop one unprocessed.
    if (!active && atomic_load(&ums->pending_request_count) == 0) {
      return 0;
    }
    if (cbw_done) {
      atomic_fetch_sub(&ums->pending_request_count, 1);
      ums_cbw_complete(ums, ums->cbw_req);
    }
    if (csw_done) {
      atomic_fetch_sub(&ums->pending_request_count, 1);
      ums_csw_complete(ums, ums->csw_req);
    }
    if (data_done) {
      atomic_fetch_sub(&ums->pending_request_count, 1);
      ums_data_complete(ums, ums->data_req);
    }
  }
}

// Driver bind hook. Allocates the instance, the interface, both bulk endpoints and the
// three transfer requests. It also maps the (shared) backing VMO, starts the worker
// thread, publishes the device and registers the USB function callbacks. On failure,
// everything acquired so far is released.
zx_status_t usb_ums_bind(void* ctx, zx_device_t* parent) {
  zxlogf(INFO, "usb_ums_bind");

  usb_ums_t* ums = calloc(1, sizeof(usb_ums_t));
  if (!ums) {
    return ZX_ERR_NO_MEMORY;
  }
  ums->data_state = DATA_STATE_NONE;
  ums->active = true;
  mtx_init(&ums->mtx, 0);
  atomic_init(&ums->pending_request_count, 0);
  cnd_init(&ums->event);
  zx_status_t status = ZX_OK;
  ums->writeback_cache = false;
  ums->writeback_cache_report = false;
  status = device_get_protocol(parent, ZX_PROTOCOL_USB_FUNCTION, &ums->function);
  if (status != ZX_OK) {
    goto fail;
  }

  ums->parent_req_size = usb_function_get_request_size(&ums->function);
  ZX_DEBUG_ASSERT(ums->parent_req_size != 0);

  status = usb_function_alloc_interface(&ums->function, &descriptors.intf.bInterfaceNumber);
  if (status != ZX_OK) {
    zxlogf(ERROR, "usb_ums_bind: usb_function_alloc_interface failed");
    goto fail;
  }
  status = usb_function_alloc_ep(&ums->function, USB_DIR_OUT, &ums->bulk_out_addr);
  if (status != ZX_OK) {
    zxlogf(ERROR, "usb_ums_bind: usb_function_alloc_ep failed");
    goto fail;
  }
  status = usb_function_alloc_ep(&ums->function, USB_DIR_IN, &ums->bulk_in_addr);
  if (status != ZX_OK) {
    zxlogf(ERROR, "usb_ums_bind: usb_function_alloc_ep failed");
    goto fail;
  }
  descriptors.out_ep.bEndpointAddress = ums->bulk_out_addr;
  descriptors.in_ep.bEndpointAddress = ums->bulk_in_addr;

  status =
      usb_request_alloc(&ums->cbw_req, BULK_MAX_PACKET, ums->bulk_out_addr, ums->parent_req_size);
  if (status != ZX_OK) {
    goto fail;
  }
  // Endpoint for data_req depends on current_cbw.bmCBWFlags,
  // and will be set in ums_function_queue_data.
  status = usb_request_alloc(&ums->data_req, DATA_REQ_SIZE, 0, ums->parent_req_size);
  if (status != ZX_OK) {
    goto fail;
  }
  status =
      usb_request_alloc(&ums->csw_req, BULK_MAX_PACKET, ums->bulk_in_addr, ums->parent_req_size);
  if (status != ZX_OK) {
    goto fail;
  }
  // Create the backing VMO once (file-static, reused by later binds) and map it.
  if (!vmo) {
    status = zx_vmo_create(STORAGE_SIZE, 0, &vmo);
    if (status != ZX_OK) {
      goto fail;
    }
  }
  ums->storage_handle = vmo;
  status = zx_vmar_map(zx_vmar_root_self(), ZX_VM_PERM_READ | ZX_VM_PERM_WRITE, 0,
                       ums->storage_handle, 0, STORAGE_SIZE, (zx_vaddr_t*)&ums->storage);
  if (status != ZX_OK) {
    goto fail;
  }

  ums->csw_req->header.length = sizeof(ums_csw_t);

  // Start the worker before the device is published and before
  // usb_function_set_interface registers the callbacks (ums_set_configured queues the
  // first CBW). A thread-creation failure can then still be unwound here, and
  // usb_ums_unbind always has a valid thread to join.
  if (thrd_create_with_name(&ums->thread, usb_ums_thread, ums, "ums_worker") != thrd_success) {
    zxlogf(ERROR, "usb_ums_bind: failed to create worker thread");
    status = ZX_ERR_NO_RESOURCES;
    goto fail;
  }

  device_add_args_t args = {
      .version = DEVICE_ADD_ARGS_VERSION,
      .name = "usb-ums-function",
      .ctx = ums,
      .ops = &usb_ums_proto,
  };
  args.proto_id = ZX_PROTOCOL_CACHE_TEST;

  status = device_add(parent, &args, &ums->zxdev);
  if (status != ZX_OK) {
    zxlogf(ERROR, "usb_device_bind add_device failed %d", status);
    goto fail_stop_thread;
  }

  usb_function_set_interface(&ums->function, ums, &ums_device_ops);
  return ZX_OK;

fail_stop_thread:
  // No request has been queued yet, so the worker exits as soon as it sees !active.
  mtx_lock(&ums->mtx);
  ums->active = false;
  cnd_signal(&ums->event);
  mtx_unlock(&ums->mtx);
  {
    int retval;
    thrd_join(ums->thread, &retval);
  }
fail:
  usb_ums_release(ums);
  return status;
}

// Driver entry points; only bind is provided.
static zx_driver_ops_t usb_ums_ops = {
    .version = DRIVER_OPS_VERSION,
    .bind = usb_ums_bind,
};

// Bind program: match a USB function device whose class/subclass/protocol is
// Mass Storage / SCSI transparent command set / Bulk-Only Transport.
// clang-format off
ZIRCON_DRIVER_BEGIN(usb_ums, usb_ums_ops, "zircon", "0.1", 4)
    BI_ABORT_IF(NE, BIND_PROTOCOL, ZX_PROTOCOL_USB_FUNCTION),
    BI_ABORT_IF(NE, BIND_USB_CLASS, USB_CLASS_MSC),
    BI_ABORT_IF(NE, BIND_USB_SUBCLASS, USB_SUBCLASS_MSC_SCSI),
    BI_MATCH_IF(EQ, BIND_USB_PROTOCOL, USB_PROTOCOL_MSC_BULK_ONLY),
ZIRCON_DRIVER_END(usb_ums)
