// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "zstd-seekable.h"

#include <zircon/types.h>

#include <memory>

#include <blobfs/format.h>
#include <fbl/algorithm.h>
#include <fbl/auto_call.h>
#include <fbl/macros.h>
#include <fs/trace.h>
#include <zstd/zstd.h>
#include <zstd/zstd_seekable.h>

#include "compressor.h"
#include "zircon/errors.h"

namespace blobfs {

constexpr int kSeekableCompressionLevel = 5;

// TODO(49551): Consider disabling checksums if cryptographic verification suffices.
constexpr int kSeekableChecksumFlag = 1;

ZSTDSeekableCompressor::ZSTDSeekableCompressor(ZSTD_seekable_CStream* stream,
                                               void* compressed_buffer,
                                               size_t compressed_buffer_length)
    : stream_(stream),
      output_({
          .dst = compressed_buffer,
          .size = compressed_buffer_length,
          // Initialize output buffer leaving space for archive size header.
          .pos = kZSTDSeekableHeaderSize,
      }) {}

ZSTDSeekableCompressor::~ZSTDSeekableCompressor() { ZSTD_seekable_freeCStream(stream_); }

zx_status_t ZSTDSeekableCompressor::Create(size_t input_size, void* compression_buffer,
                                           size_t compression_buffer_length,
                                           std::unique_ptr<ZSTDSeekableCompressor>* out) {
  if (BufferMax(input_size) > compression_buffer_length)
    return ZX_ERR_BUFFER_TOO_SMALL;

  ZSTD_seekable_CStream* stream = ZSTD_seekable_createCStream();
  if (stream == nullptr)
    return ZX_ERR_NO_MEMORY;

  auto compressor = std::unique_ptr<ZSTDSeekableCompressor>(
      new ZSTDSeekableCompressor(std::move(stream), compression_buffer, compression_buffer_length));

  size_t r = ZSTD_seekable_initCStream(compressor->stream_, kSeekableCompressionLevel,
                                       kSeekableChecksumFlag, kZSTDSeekableMaxFrameSize);
  if (ZSTD_isError(r)) {
    FS_TRACE_ERROR("[blobfs][zstd-seekable] Failed to initialize seekable cstream: %s\n",
                   ZSTD_getErrorName(r));
    return ZX_ERR_INTERNAL;
  }

  *out = std::move(compressor);
  return ZX_OK;
}

// TODO(markdittmer): This doesn't take into account a couple issues related to the seekable format:
// 1. It doesn't include the seekable format footer.
// 2. Frequent flushes caused by the seekable format's max frame size can cause compressed contents
//    to exceed this bound.
size_t ZSTDSeekableCompressor::BufferMax(size_t blob_size) {
  // Add archive size header to estimate.
  return kZSTDSeekableHeaderSize + ZSTD_compressBound(blob_size);
}

zx_status_t ZSTDSeekableCompressor::WriteHeader(void* buf, size_t buf_size,
                                                ZSTDSeekableHeader header) {
  if (buf_size < kZSTDSeekableHeaderSize) {
    return ZX_ERR_BUFFER_TOO_SMALL;
  }

  uint64_t* size_header = static_cast<uint64_t*>(buf);
  size_header[0] = header.archive_size;

  return ZX_OK;
}

zx_status_t ZSTDSeekableCompressor::Update(const void* input_data, size_t input_length) {
  ZSTD_inBuffer input;
  input.src = input_data;
  input.size = input_length;
  input.pos = 0;

  // Invoke ZSTD_seekable_compressStream repeatedly to consume entire input buffer.
  //
  // From the ZSTD seekable format documentation:
  //   Use ZSTD_seekable_compressStream() repetitively to consume input stream.
  //   The function will automatically update both `pos` fields.
  //   Note that it may not consume the entire input, in which case `pos < size`,
  //   and it's up to the caller to present again remaining data.
  size_t zstd_return = 0;
  while (input.pos != input_length) {
    zstd_return = ZSTD_seekable_compressStream(stream_, &output_, &input);
    if (ZSTD_isError(zstd_return)) {
      FS_TRACE_ERROR("[blobfs][zstd-seekable] Failed to compress in seekable format: %s\n",
                     ZSTD_getErrorName(zstd_return));
      return ZX_ERR_IO_DATA_INTEGRITY;
    }
  }

  return ZX_OK;
}

zx_status_t ZSTDSeekableCompressor::End() {
  size_t zstd_return = ZSTD_seekable_endStream(stream_, &output_);
  if (ZSTD_isError(zstd_return)) {
    FS_TRACE_ERROR("[blobfs][zstd-seekable] Failed to end seekable stream: %s\n",
                   ZSTD_getErrorName(zstd_return));
    return ZX_ERR_IO_DATA_INTEGRITY;
  }

  // Store archive size header as first bytes of blob.
  uint64_t archive_size = output_.pos - kZSTDSeekableHeaderSize;
  WriteHeader(output_.dst, output_.size, ZSTDSeekableHeader{archive_size});

  return ZX_OK;
}

size_t ZSTDSeekableCompressor::Size() const { return output_.pos; }

zx_status_t ZSTDSeekableDecompressor::DecompressArchive(void* uncompressed_buf,
                                                        size_t* uncompressed_size,
                                                        const void* compressed_buf,
                                                        size_t compressed_size, size_t offset) {
  ZSTD_seekable* stream = ZSTD_seekable_create();
  auto cleanup = fbl::MakeAutoCall([&stream] { ZSTD_seekable_free(stream); });
  size_t zstd_return = ZSTD_seekable_initBuff(stream, compressed_buf, compressed_size);
  if (ZSTD_isError(zstd_return)) {
    FS_TRACE_ERROR("[blobfs][zstd-seekable] Failed to initialize seekable dstream: %s\n",
                   ZSTD_getErrorName(zstd_return));
    return ZX_ERR_INTERNAL;
  }

  size_t decompressed = 0;
  zstd_return = 0;
  do {
    zstd_return = ZSTD_seekable_decompress(stream, uncompressed_buf, *uncompressed_size,
                                           offset + decompressed);
    decompressed += zstd_return;
    if (ZSTD_isError(zstd_return)) {
      FS_TRACE_ERROR("[blobfs][zstd-seekable] Failed to decompress: %s\n",
                     ZSTD_getErrorName(zstd_return));
      return ZX_ERR_IO_DATA_INTEGRITY;
    }
    // From the ZSTD_seekable_decompress Documentation:
    //   The return value is the number of bytes decompressed, or an error code checkable with
    //   ZSTD_isError().
    // Assume that a return value of 0 indicates, not only that 0 bytes were decompressed, but also
    // that there are no more bytes to decompress.
  } while (zstd_return > 0 && decompressed < *uncompressed_size);

  *uncompressed_size = decompressed;
  return ZX_OK;
}

zx_status_t ZSTDSeekableDecompressor::Decompress(void* uncompressed_buf, size_t* uncompressed_size,
                                                 const void* compressed_buf,
                                                 const size_t max_compressed_size) {
  return DecompressRange(uncompressed_buf, uncompressed_size, compressed_buf, max_compressed_size,
                         0);
}

// SeekableDecompressor implementation.
zx_status_t ZSTDSeekableDecompressor::DecompressRange(void* uncompressed_buf,
                                                      size_t* uncompressed_size,
                                                      const void* compressed_buf,
                                                      size_t max_compressed_size,
                                                      size_t offset) {
  TRACE_DURATION("blobfs", "ZSTDSeekableDecompressor::DecompressRange", "uncompressed_size",
                 *uncompressed_size, "max_compressed_size", max_compressed_size);

  ZSTDSeekableHeader header;
  zx_status_t status = ReadHeader(compressed_buf, max_compressed_size, &header);
  if (status != ZX_OK) {
    return status;
  }

  const uint8_t* compressed_byte_buf = static_cast<const uint8_t*>(compressed_buf);
  return DecompressArchive(uncompressed_buf, uncompressed_size,
                           compressed_byte_buf + kZSTDSeekableHeaderSize, header.archive_size,
                           offset);
}

zx_status_t ZSTDSeekableDecompressor::ReadHeader(const void* buf, size_t buf_size,
                                                 ZSTDSeekableHeader* header) {
  if (buf_size < kZSTDSeekableHeaderSize) {
    return ZX_ERR_BUFFER_TOO_SMALL;
  }
  const uint64_t* size_header = static_cast<const uint64_t*>(buf);
  const uint64_t archive_size = size_header[0];
  header->archive_size = archive_size;
  if (buf_size < archive_size + kZSTDSeekableHeaderSize) {
    return ZX_ERR_BUFFER_TOO_SMALL;
  }

  return ZX_OK;
}

}  // namespace blobfs
