HybridCipher: Expose simpler API. Hide complex API.

We convert the functions Encrypt() and Decrypt() which expose
three-part ciphertexts into private internal functions and instead
expose new versions of Encrypt() and Decrypt() that hide that
complexity and deal with a simple opaque compound ciphertext.

We have in the past had conversations around using protocol buffer
messages to encapsulate the three parts of the hybrid ciphertext.
I actually started to do that work and came to the conclusion that
it was not a good idea. Our three part ciphertext is quite simple:
Two of the parts have fixed length so it is quite easy to simply
concatenate the parts. I say let the encoding be simply
<33-byte-public-key><16-byte-salt><symmetric ciphertext>
and be done with it.

There is no reason that a caller to the HybridCipher API needs to know
the internal structure of the ciphertext any more than there is
for a caller to know the internal structure of an AES ciphertext.

We will need a specification document that describes this byte layout
but we will need a specification document anyway to describe the whole
scheme.

Change-Id: I3e9205b41b1928c92ff6b6ee4c649ffeebe68481
diff --git a/util/crypto_util/cipher.cc b/util/crypto_util/cipher.cc
index 81d1d0f..e0acec9 100644
--- a/util/crypto_util/cipher.cc
+++ b/util/crypto_util/cipher.cc
@@ -39,29 +39,27 @@
 // here http://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf)
 #define EC_CURVE_CONSTANT NID_X9_62_prime256v1
 
-  const EVP_AEAD* GetAEAD() {
-    // Note(rudominer) The constants KEY_SIZE and NONCE_SIZE are set based
-    // on the algorithm chosen. If this algorithm changes you must also
-    // change those constants accordingly.
-    //
-    // NOTE(pseudorandom) By using a 256-bit curve in EC_CURVE_CONSTANT for
-    // public-key cryptography when SymmetricCipher is used in HybridCipher,
-    // the effective security level is AES-128 and not AES-256.
-    return EVP_aead_aes_256_gcm();
-  }
-  static const size_t GROUP_ELEMENT_SIZE  = 256 / 8;  // (g^xy) object length
+const EVP_AEAD* GetAEAD() {
+  // Note(rudominer) The constants KEY_SIZE and NONCE_SIZE are set based
+  // on the algorithm chosen. If this algorithm changes you must also
+  // change those constants accordingly.
+  //
+  // NOTE(pseudorandom) By using a 256-bit curve in EC_CURVE_CONSTANT for
+  // public-key cryptography when SymmetricCipher is used in HybridCipher,
+  // the effective security level is AES-128 and not AES-256.
+  return EVP_aead_aes_256_gcm();
+}
+static const size_t GROUP_ELEMENT_SIZE = 256 / 8;  // (g^xy) object length
 
-  // For hybrid mode, we can fix the nonce to all zeroes without losing
-  // security. See: https://goto.google.com/aes-gcm-zero-nonce-security
-  const byte kAllZeroNonce[SymmetricCipher::NONCE_SIZE] = {0x00};
+// For hybrid mode, we can fix the nonce to all zeroes without losing
+// security. See: https://goto.google.com/aes-gcm-zero-nonce-security
+const byte kAllZeroNonce[SymmetricCipher::NONCE_SIZE] = {0x00};
 
 }  //  namespace
 
 class CipherContext {
  public:
-  ~CipherContext() {
-    EVP_AEAD_CTX_cleanup(&impl_);
-  }
+  ~CipherContext() { EVP_AEAD_CTX_cleanup(&impl_); }
 
   bool SetKey(const byte key[SymmetricCipher::KEY_SIZE]) {
     EVP_AEAD_CTX_cleanup(&impl_);
@@ -69,9 +67,7 @@
                              EVP_AEAD_DEFAULT_TAG_LENGTH, NULL);
   }
 
-  EVP_AEAD_CTX* get() {
-    return &impl_;
-  }
+  EVP_AEAD_CTX* get() { return &impl_; }
 
  private:
   EVP_AEAD_CTX impl_;
@@ -79,17 +75,14 @@
 
 class HybridCipherContext {
  public:
-  HybridCipherContext()
-      : key_(nullptr, ::EVP_PKEY_free) {}
+  HybridCipherContext() : key_(nullptr, ::EVP_PKEY_free) {}
 
   bool ResetKey() {
     key_.reset(EVP_PKEY_new());
     return (nullptr != key_);
   }
 
-  EVP_PKEY* GetKey() {
-    return key_.get();
-  }
+  EVP_PKEY* GetKey() { return key_.get(); }
 
  private:
   std::unique_ptr<EVP_PKEY, decltype(&::EVP_PKEY_free)> key_;
@@ -105,32 +98,33 @@
   return context_->SetKey(key);
 }
 
-bool SymmetricCipher::Encrypt(const byte nonce[NONCE_SIZE], const byte *ptext,
-    int ptext_len, std::vector<byte>* ctext) {
-
+bool SymmetricCipher::Encrypt(const byte nonce[NONCE_SIZE], const byte* ptext,
+                              int ptext_len, std::vector<byte>* ctext) {
   int max_out_len = EVP_AEAD_max_overhead(GetAEAD()) + ptext_len;
   ctext->resize(max_out_len);
   size_t out_len;
-  int rc = EVP_AEAD_CTX_seal(context_->get(), ctext->data(), &out_len,
-      max_out_len, nonce, NONCE_SIZE, ptext, ptext_len, NULL, 0);
+  int rc =
+      EVP_AEAD_CTX_seal(context_->get(), ctext->data(), &out_len, max_out_len,
+                        nonce, NONCE_SIZE, ptext, ptext_len, NULL, 0);
   ctext->resize(out_len);
   return rc;
 }
 
-bool SymmetricCipher::Decrypt(const byte nonce[NONCE_SIZE], const byte *ctext,
-  int ctext_len, std::vector<byte>* ptext) {
+bool SymmetricCipher::Decrypt(const byte nonce[NONCE_SIZE], const byte* ctext,
+                              int ctext_len, std::vector<byte>* ptext) {
   ptext->resize(ctext_len);
   size_t out_len;
-  int rc = EVP_AEAD_CTX_open(context_->get(), ptext->data(), &out_len,
-      ptext->size(), nonce, NONCE_SIZE, ctext, ctext_len, NULL, 0);
+  int rc =
+      EVP_AEAD_CTX_open(context_->get(), ptext->data(), &out_len, ptext->size(),
+                        nonce, NONCE_SIZE, ctext, ctext_len, NULL, 0);
   ptext->resize(out_len);
   return rc;
 }
 
 // HybridCipher methods.
 
-HybridCipher::HybridCipher() : context_(new HybridCipherContext()),
-                               symm_cipher_(new SymmetricCipher) {}
+HybridCipher::HybridCipher()
+    : context_(new HybridCipherContext()), symm_cipher_(new SymmetricCipher) {}
 
 HybridCipher::~HybridCipher() {}
 
@@ -149,7 +143,7 @@
 
   // Read bytes from public_key into ecpoint
   if (!EC_POINT_oct2point(EC_KEY_get0_group(eckey.get()), ecpoint.get(),
-                         public_key, PUBLIC_KEY_SIZE, nullptr)) {
+                          public_key, PUBLIC_KEY_SIZE, nullptr)) {
     return false;
   }
 
@@ -196,10 +190,27 @@
   return true;
 }
 
-bool HybridCipher::Encrypt(const byte *ptext, int ptext_len,
-                           byte public_key_part_out[PUBLIC_KEY_SIZE],
-                           byte salt_out[SALT_SIZE],
-                           std::vector<byte>* ctext) {
+bool HybridCipher::Encrypt(const byte* ptext, int ptext_len,
+                           std::vector<byte>* hybrid_ctext) {
+  byte public_key_part[PUBLIC_KEY_SIZE];
+  byte salt[SALT_SIZE];
+  std::vector<byte> symmetric_ctext;
+  if (!EncryptInternal(ptext, ptext_len, public_key_part, salt,
+                       &symmetric_ctext)) {
+    return false;
+  }
+  hybrid_ctext->resize(symmetric_ctext.size() + PUBLIC_KEY_SIZE + SALT_SIZE);
+  std::memcpy(hybrid_ctext->data(), public_key_part, PUBLIC_KEY_SIZE);
+  std::memcpy(hybrid_ctext->data() + PUBLIC_KEY_SIZE, salt, SALT_SIZE);
+  std::memcpy(hybrid_ctext->data() + PUBLIC_KEY_SIZE + SALT_SIZE,
+              symmetric_ctext.data(), symmetric_ctext.size());
+  return true;
+}
+
+bool HybridCipher::EncryptInternal(const byte* ptext, int ptext_len,
+                                   byte public_key_part_out[PUBLIC_KEY_SIZE],
+                                   byte salt_out[SALT_SIZE],
+                                   std::vector<byte>* symmetric_ctext_out) {
   std::unique_ptr<EC_KEY, decltype(&::EC_KEY_free)> eckey(
       EC_KEY_new_by_curve_name(EC_CURVE_CONSTANT), EC_KEY_free);
   if (!eckey) {
@@ -208,7 +219,7 @@
 
   // Generate fresh EC key. The public key will be published in
   // public_key_part and the EC key will be used to generate a symmetric key
-  // that encrypts ptext bytes into ctext
+  // that encrypts ptext bytes into symmetric_ctext_out
   if (!EC_KEY_generate_key(eckey.get())) {
     return false;
   }
@@ -216,14 +227,13 @@
   // Write EC public key into public_key_part
   if (EC_POINT_point2oct(EC_KEY_get0_group(eckey.get()),
                          EC_KEY_get0_public_key(eckey.get()),
-                         POINT_CONVERSION_COMPRESSED,
-                         public_key_part_out,
+                         POINT_CONVERSION_COMPRESSED, public_key_part_out,
                          PUBLIC_KEY_SIZE, nullptr) != PUBLIC_KEY_SIZE) {
     return false;
   }
 
   byte shared_key[GROUP_ELEMENT_SIZE];  // To store g^(xy) after ECDH
-  const EC_POINT *ec_pub_point =
+  const EC_POINT* ec_pub_point =
       EC_KEY_get0_public_key(EVP_PKEY_get0_EC_KEY(context_->GetKey()));
   size_t shared_key_len = ECDH_compute_key(shared_key, sizeof(shared_key),
                                            ec_pub_point, eckey.get(), nullptr);
@@ -242,9 +252,8 @@
   std::memcpy(hkdf_input.data() + PUBLIC_KEY_SIZE, shared_key,
               GROUP_ELEMENT_SIZE);
   if (!HKDF(hkdf_derived_key, SymmetricCipher::KEY_SIZE, EVP_sha512(),
-            hkdf_input.data(), hkdf_input.size(),
-            salt_out, SALT_SIZE,
-            nullptr, 0)) {
+            hkdf_input.data(), hkdf_input.size(), salt_out, SALT_SIZE, nullptr,
+            0)) {
     return false;
   }
 
@@ -254,7 +263,8 @@
   }
   // For hybrid mode, we can fix the nonce to all zeroes without losing
   // security. See: https://goto.google.com/aes-gcm-zero-nonce-security
-  if (!symm_cipher_->Encrypt(kAllZeroNonce, ptext, ptext_len, ctext)) {
+  if (!symm_cipher_->Encrypt(kAllZeroNonce, ptext, ptext_len,
+                             symmetric_ctext_out)) {
     return false;
   }
 
@@ -262,10 +272,21 @@
   return true;
 }
 
-bool HybridCipher::Decrypt(const byte public_key_part[PUBLIC_KEY_SIZE],
-                           const byte salt[SALT_SIZE],
-                           const byte *ctext, int ctext_len,
+bool HybridCipher::Decrypt(const byte* hybrid_ctext, int ctext_len,
                            std::vector<byte>* ptext) {
+  if (!hybrid_ctext || ctext_len < PUBLIC_KEY_SIZE + SALT_SIZE + 1) {
+    return false;
+  }
+  return DecryptInternal(hybrid_ctext, hybrid_ctext + PUBLIC_KEY_SIZE,
+                         hybrid_ctext + PUBLIC_KEY_SIZE + SALT_SIZE,
+                         ctext_len - (PUBLIC_KEY_SIZE + SALT_SIZE), ptext);
+}
+
+bool HybridCipher::DecryptInternal(const byte public_key_part[PUBLIC_KEY_SIZE],
+                                   const byte salt[SALT_SIZE],
+                                   const byte* symmetric_ctext,
+                                   int symmetric_ctext_len,
+                                   std::vector<byte>* ptext) {
   // Read public_key_part into new EVP_PKEY object
   std::unique_ptr<EC_KEY, decltype(&::EC_KEY_free)> eckey(
       EC_KEY_new_by_curve_name(EC_CURVE_CONSTANT), EC_KEY_free);
@@ -279,9 +300,8 @@
   }
 
   // Read bytes from public_key_part into ecpoint
-  if (!EC_POINT_oct2point(EC_KEY_get0_group(eckey.get()),
-                          ecpoint.get(), public_key_part,
-                          PUBLIC_KEY_SIZE, nullptr)) {
+  if (!EC_POINT_oct2point(EC_KEY_get0_group(eckey.get()), ecpoint.get(),
+                          public_key_part, PUBLIC_KEY_SIZE, nullptr)) {
     return false;
   }
 
@@ -291,10 +311,9 @@
   }
 
   byte shared_key[GROUP_ELEMENT_SIZE];  // To store g^(xy) after ECDH
-  size_t shared_key_len = ECDH_compute_key(shared_key, sizeof(shared_key),
-                                ecpoint.get(),
-                                EVP_PKEY_get0_EC_KEY(context_->GetKey()),
-                                nullptr);
+  size_t shared_key_len =
+      ECDH_compute_key(shared_key, sizeof(shared_key), ecpoint.get(),
+                       EVP_PKEY_get0_EC_KEY(context_->GetKey()), nullptr);
   if (shared_key_len != sizeof(shared_key)) {
     return false;
   }
@@ -306,20 +325,19 @@
   std::memcpy(hkdf_input.data() + PUBLIC_KEY_SIZE, shared_key,
               GROUP_ELEMENT_SIZE);
   if (!HKDF(hkdf_derived_key, SymmetricCipher::KEY_SIZE, EVP_sha512(),
-            hkdf_input.data(), hkdf_input.size(),
-            salt, SALT_SIZE,
-            nullptr, 0)) {
+            hkdf_input.data(), hkdf_input.size(), salt, SALT_SIZE, nullptr,
+            0)) {
     return false;
   }
 
-  // Now decrypt using symm_cipher_ interface
+  // Decrypt using symm_cipher_ interface
   if (!symm_cipher_->set_key(hkdf_derived_key)) {
     return false;
   }
 
-  // For hybrid mode, we can fix the nonce to all zeroes without losing
-  // security. See: https://goto.google.com/aes-gcm-zero-nonce-security
-  if (!symm_cipher_->Decrypt(kAllZeroNonce, ctext, ctext_len, ptext)) {
+  // Our encryption always uses the all-zero nonce.
+  if (!symm_cipher_->Decrypt(kAllZeroNonce, symmetric_ctext,
+                             symmetric_ctext_len, ptext)) {
     return false;
   }
 
diff --git a/util/crypto_util/cipher.h b/util/crypto_util/cipher.h
index 6179f83..cf50e66 100644
--- a/util/crypto_util/cipher.h
+++ b/util/crypto_util/cipher.h
@@ -105,15 +105,16 @@
 //    compression function (also, see Note 2)
 //    4. (Symmetric) encrypts message using SymmetricCipher::encrypt with key
 //    and all-zero nonce into ciphertext
-//    5. Publishes (public_key_part, salt, ciphertext) as the hybrid
+//    5. Publishes (public_key_part, salt, symmetric_ciphertext) as the hybrid
 //    ciphertext, where public_key_part is the X9.62 serialization of g^y.
 //
-// Dec(private key, hybrid ciphertext = (public_key_part, salt, ciphertext)):
+// Dec(private key, hybrid_ciphertext)
+//    where hybrid_ciphertext = (public_key_part, salt, symmetric_ciphertext)):
 //
 //    1. Computes symmetric key = HKDF(g^y, g^xy, salt) with SHA512
 //    compression function from private key (x) and public_key_part (g^y)
-//    2. (Symmetric) decrypts ciphertext using SymmetricCipher::decrypt with
-//    key and all-zero nonce.
+//    2. (Symmetric) decrypts symmetric_ciphertext using
+//    SymmetricCipher::decrypt with key and all-zero nonce.
 //
 // An instance of HybridCipher may be used repeatedly for multiple
 // encryptions or decryptions. The method set_public_key() must be used before
@@ -132,6 +133,7 @@
 // compromising security? Probably not.
 class HybridCipher {
  public:
+  // NOTE: All thre sizes below specify a number of bytes (not bits.)
   static const size_t PUBLIC_KEY_SIZE    = 33;  // One byte extra
                                                 // for X9.62 serialization
   static const size_t PRIVATE_KEY_SIZE   = 256 / 8;
@@ -141,7 +143,7 @@
   ~HybridCipher();
 
   // Sets the public key for encryption. This must be invoked
-  // at least once before encrypt is called. Using decryption after
+  // at least once before Encrypt is called. Using decryption after
   // set_public_key is undefined behavior.
   // Returns true for success or false for failure. Use the functions
   // in errors.h to obtain error information upon failure.
@@ -151,7 +153,7 @@
   bool set_public_key(const byte key[PUBLIC_KEY_SIZE]);
 
   // Sets the private key for decryption. This must be invoked
-  // at least once before decrypt is called. Using encryption after
+  // at least once before Decrypt is called. Using encryption after
   // set_private_key is undefined behavior.
   // Returns true for success or false for failure. Use the functions
   // in errors.h to obtain error information upon failure.
@@ -166,8 +168,37 @@
   //
   // |ptext_len| The number of bytes of plain text
   //
-  // The output is represented in four parts which will be written to
-  // |public_key_part_out|, |salt_out|, |nonce_out| and |ctext| respectively
+  // |hybrid_ctext| A pointer to a vector that will be modified to
+  // contain the hybrid ciphertext.
+  //
+  // Returns true for success or false for failure. Use the functions
+  // in errors.h to obtain error information upon failure.
+  bool Encrypt(const byte *ptext, int ptext_len,
+               std::vector<byte>* hybrid_ctext);
+
+  // Performs ECDH-based hybrid decryption.
+  //
+  // |hybrid_ctext| The hybrid ciphertext to be decrypted.
+  //
+  // |hybrid_ctext_len| The number of bytes of hybrid ciphertext.
+  //
+  // |ptext| A pointer to a vector that will be modified to contain the
+  // recovered plaintext.
+  //
+  // Returns true for success or false for failure. Use the functions
+  // in errors.h to obtain error information upon failure.
+  bool Decrypt(const byte *hybrid_ctext, int hybrid_ctext_len,
+               std::vector<byte>* ptext);
+
+ private:
+  // Performs ECDH-based hybrid encryption
+  //
+  // |ptext| The plain text to be encrypted
+  //
+  // |ptext_len| The number of bytes of plain text
+  //
+  // The output is represented in three parts which will be written to
+  // |public_key_part_out|, |salt_out| and |symmetric_ctext_out| respectively
   //
   // |public_key_part_out| must point to a buffer of size |PUBLIC_KEY_SIZE|.
   // The X9.62 serialization of g^y will be written there
@@ -175,15 +206,15 @@
   // |salt| is a pointer to a vector that will be modified to store a random
   // salt of size |SALT_SIZE|
   //
-  // |ctext| A pointer to a vector that will be modified to contain
-  // the ciphertext under the derived symmetric key
+  // |symmetric_ctext_out| A pointer to a vector that will be modified to
+  // contain the ciphertext under the derived symmetric key
   //
   // Returns true for success or false for failure. Use the functions
   // in errors.h to obtain error information upon failure.
-  bool Encrypt(const byte *ptext, int ptext_len,
-               byte public_key_part_out[PUBLIC_KEY_SIZE],
-               byte salt_out[SALT_SIZE],
-               std::vector<byte>* ctext);
+  bool EncryptInternal(const byte* ptext, int ptext_len,
+                       byte public_key_part_out[PUBLIC_KEY_SIZE],
+                       byte salt_out[SALT_SIZE],
+                       std::vector<byte>* symmetric_ctext_out);
 
   // Performs ECDH-based hybrid decryption.
   //
@@ -195,21 +226,19 @@
   // |salt| The salt to be used in the HKDF step in decryption. Must have size
   // |SALT_SIZE|
   //
-  // |ctext| The ciphertext to be decrypted.
+  // |symmetric_ctext| The ciphertext to be decrypted.
   //
-  // |ctext_len| The number of bytes of ciphertext.
+  // |symmetric_ctext_len| The number of bytes of symmetric ciphertext.
   //
   // |ptext| A pointer to a vector that will be modified to contain the
   // recovered plaintext.
   //
   // Returns true for success or false for failure. Use the functions
   // in errors.h to obtain error information upon failure.
-  bool Decrypt(const byte public_key_part[PUBLIC_KEY_SIZE],
-               const byte salt[SALT_SIZE],
-               const byte *ctext, int ctext_len,
-               std::vector<byte>* ptext);
+  bool DecryptInternal(const byte public_key_part[PUBLIC_KEY_SIZE],
+                       const byte salt[SALT_SIZE], const byte* symmetric_ctext,
+                       int symmetric_ctext_len, std::vector<byte>* ptext);
 
- private:
   std::unique_ptr<HybridCipherContext> context_;
   std::unique_ptr<SymmetricCipher> symm_cipher_;
 };
diff --git a/util/crypto_util/cipher_test.cc b/util/crypto_util/cipher_test.cc
index 3b9712f..86c157a 100644
--- a/util/crypto_util/cipher_test.cc
+++ b/util/crypto_util/cipher_test.cc
@@ -101,21 +101,17 @@
     const byte private_key[HybridCipher::PRIVATE_KEY_SIZE]) {
 
   // Encrypt
-  byte public_key_part[HybridCipher::PUBLIC_KEY_SIZE];
-  byte salt[HybridCipher::SALT_SIZE];
   std::vector<byte> cipher_text;
   ASSERT_TRUE(hybrid_cipher->set_public_key(public_key))
       << GetLastErrorMessage();
-  EXPECT_TRUE(hybrid_cipher->Encrypt(plain_text, ptext_len, public_key_part,
-                                     salt, &cipher_text))
+  EXPECT_TRUE(hybrid_cipher->Encrypt(plain_text, ptext_len, &cipher_text))
       << GetLastErrorMessage();
 
   // Decrypt
   std::vector<byte> recovered_text;
   ASSERT_TRUE(hybrid_cipher->set_private_key(private_key))
       << GetLastErrorMessage();
-  ASSERT_TRUE(hybrid_cipher->Decrypt(public_key_part, salt,
-                                     cipher_text.data(), cipher_text.size(),
+  ASSERT_TRUE(hybrid_cipher->Decrypt(cipher_text.data(), cipher_text.size(),
                                      &recovered_text))
       << GetLastErrorMessage();
 
@@ -125,18 +121,18 @@
             std::string((const char*)plain_text));
 
   // Decrypt with flipped salt
-  salt[0] ^= 0x1;  // flip a bit in the salt
-  EXPECT_FALSE(hybrid_cipher->Decrypt(public_key_part, salt,
-                                      cipher_text.data(), cipher_text.size(),
+  cipher_text.data()[HybridCipher::PUBLIC_KEY_SIZE] ^=
+      0x1;  // flip a bit in the first byte of the salt
+  EXPECT_FALSE(hybrid_cipher->Decrypt(cipher_text.data(), cipher_text.size(),
                                       &recovered_text))
       << GetLastErrorMessage();
 
   // Decrypt with modified public_key_part
-  salt[0] ^= 0x1;  // flip salt bit back
-  public_key_part[2] ^= 0x1;  // flip any bit except in first byte (due to
-                              // X9.62 serialization)
-  EXPECT_FALSE(hybrid_cipher->Decrypt(public_key_part, salt,
-                                      cipher_text.data(), cipher_text.size(),
+  cipher_text.data()[HybridCipher::PUBLIC_KEY_SIZE] ^=
+      0x1;                       // flip salt bit back
+  cipher_text.data()[2] ^= 0x1;  // flip any bit except in first byte (due to
+                                 // X9.62 serialization)
+  EXPECT_FALSE(hybrid_cipher->Decrypt(cipher_text.data(), cipher_text.size(),
                                       &recovered_text))
       << GetLastErrorMessage();
 }