blob: 1e2ebdfa73d3f11bfc5279c0cc8372e7e1f564bf [file] [log] [blame]
// Copyright 2019 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package hcvault_test
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
vault_api "github.com/hashicorp/vault/api"
"github.com/google/tink/go/integration/hcvault"
)
const (
	// keyURITmpl is a fmt template for the URI of the fake server's "key-1"
	// transit key; format it with the URI prefix returned by newServer.
	keyURITmpl = "%s/transit/keys/key-1"

	// token is the Vault token given to every test client. The fake server
	// built by newServer never inspects it.
	token = "mytoken"
)
// TestVaultNewAEAD_EncryptDecrypt checks that an AEAD built directly with
// hcvault.NewAEAD round-trips a plaintext, and that decryption is rejected when
// different associated data is supplied.
func TestVaultNewAEAD_EncryptDecrypt(t *testing.T) {
	srv, _, tlsCfg := newServer(t)
	defer srv.Close()

	vaultClient := newVaultAPIClient(t, srv.URL, token, tlsCfg)
	a, err := hcvault.NewAEAD("/transit/keys/key-1", vaultClient.Logical())
	if err != nil {
		t.Fatalf("hcvault.NewAEAD() err = %v, want nil", err)
	}

	var (
		msg = []byte("plaintext")
		ad  = []byte("associatedData")
	)
	ct, err := a.Encrypt(msg, ad)
	if err != nil {
		t.Fatalf("aead.Encrypt(plaintext, associatedData) err = %v, want nil", err)
	}
	decrypted, err := a.Decrypt(ct, ad)
	if err != nil {
		t.Fatalf("aead.Decrypt(ciphertext, associatedData) err = %v, want nil", err)
	}
	if !bytes.Equal(decrypted, msg) {
		t.Fatalf("aead.Decrypt(ciphertext, associatedData) = %s, want %s", decrypted, msg)
	}

	// The ciphertext is bound to its associated data, so other data must fail.
	if _, err := a.Decrypt(ct, []byte("otherAssociatedData")); err == nil {
		t.Error("aead.Decrypt(ciphertext, otherAssociatedData) err = nil, want error")
	}
}
// TestVaultNewAEAD_DecryptWithFixedCiphertext checks that an AEAD built with
// hcvault.NewAEAD (without the legacy context option) decrypts a ciphertext
// produced independently by fakeEncrypt.
func TestVaultNewAEAD_DecryptWithFixedCiphertext(t *testing.T) {
	server, _, tlsConfig := newServer(t)
	defer server.Close()
	client := newVaultAPIClient(t, server.URL, token, tlsConfig)
	aead, err := hcvault.NewAEAD("/transit/keys/key-1", client.Logical())
	if err != nil {
		t.Fatalf("hcvault.NewAEAD() err = %v, want nil", err)
	}
	// Without WithLegacyContextParamater, associatedData is sent to Vault as the
	// "associated_data" parameter, not as "context". The fixed ciphertext
	// therefore carries it in the associated-data field, with an empty context.
	plaintext := []byte("plaintext")
	associatedData := []byte("associatedData")
	ciphertext := fakeEncrypt(plaintext, associatedData, nil)
	gotPlaintext, err := aead.Decrypt(ciphertext, associatedData)
	if err != nil {
		t.Fatalf("aead.Decrypt(ciphertext, associatedData) err = %v, want nil", err)
	}
	if !bytes.Equal(gotPlaintext, plaintext) {
		t.Fatalf("aead.Decrypt(ciphertext, associatedData) = %s, want %s", gotPlaintext, plaintext)
	}
}
// TestVaultNewAEADWithLegacyContextParamater_isCompatible checks that an AEAD
// created by hcvault.NewAEAD with WithLegacyContextParamater interoperates with
// one obtained from hcvault.NewClient and GetAEAD: each decrypts what the other
// encrypts.
func TestVaultNewAEADWithLegacyContextParamater_isCompatible(t *testing.T) {
	server, uriPrefix, tlsConfig := newServer(t)
	defer server.Close()
	client := newVaultAPIClient(t, server.URL, token, tlsConfig)
	// Create AEAD with WithLegacyContextParamater.
	aead1, err := hcvault.NewAEAD("/transit/keys/key-1", client.Logical(), hcvault.WithLegacyContextParamater())
	if err != nil {
		t.Fatalf("hcvault.NewAEAD() err = %v, want nil", err)
	}
	// Create AEAD with hcvault.NewClient and GetAEAD.
	hcvaultClient, err := hcvault.NewClient(uriPrefix, tlsConfig, token)
	if err != nil {
		t.Fatalf("hcvault.NewClient() err = %v, want nil", err)
	}
	keyURI := fmt.Sprintf(keyURITmpl, uriPrefix)
	aead2, err := hcvaultClient.GetAEAD(keyURI)
	if err != nil {
		t.Fatalf("hcvaultClient.GetAEAD(%q) err = %v, want nil", keyURI, err)
	}
	plaintext := []byte("plaintext")
	associatedData := []byte("associatedData")

	// aead1 must decrypt what aead2 encrypted.
	ciphertext2, err := aead2.Encrypt(plaintext, associatedData)
	if err != nil {
		t.Fatalf("aead2.Encrypt(plaintext, associatedData) err = %v, want nil", err)
	}
	gotPlaintext1, err := aead1.Decrypt(ciphertext2, associatedData)
	if err != nil {
		t.Fatalf("aead1.Decrypt(ciphertext2, associatedData) err = %v, want nil", err)
	}
	if !bytes.Equal(gotPlaintext1, plaintext) {
		t.Fatalf("aead1.Decrypt(ciphertext2, associatedData) = %s, want %s", gotPlaintext1, plaintext)
	}

	// aead2 must decrypt what aead1 encrypted.
	ciphertext1, err := aead1.Encrypt(plaintext, associatedData)
	if err != nil {
		t.Fatalf("aead1.Encrypt(plaintext, associatedData) err = %v, want nil", err)
	}
	gotPlaintext2, err := aead2.Decrypt(ciphertext1, associatedData)
	if err != nil {
		t.Fatalf("aead2.Decrypt(ciphertext1, associatedData) err = %v, want nil", err)
	}
	if !bytes.Equal(gotPlaintext2, plaintext) {
		t.Fatalf("aead2.Decrypt(ciphertext1, associatedData) = %s, want %s", gotPlaintext2, plaintext)
	}
}
// TestVaultClientAEAD_EncryptDecrypt checks that an AEAD obtained through
// hcvault.NewClient and GetAEAD round-trips a plaintext, and that decryption
// fails when a different context is supplied.
func TestVaultClientAEAD_EncryptDecrypt(t *testing.T) {
	srv, prefix, tlsCfg := newServer(t)
	defer srv.Close()

	c, err := hcvault.NewClient(prefix, tlsCfg, token)
	if err != nil {
		t.Fatalf("hcvault.NewClient() err = %v, want nil", err)
	}
	uri := fmt.Sprintf(keyURITmpl, prefix)
	a, err := c.GetAEAD(uri)
	if err != nil {
		t.Fatalf("client.GetAEAD(%q) err = %v, want nil", uri, err)
	}

	msg, ctxData := []byte("plaintext"), []byte("context")
	ct, err := a.Encrypt(msg, ctxData)
	if err != nil {
		t.Fatalf("aead.Encrypt(plaintext, context) err = %v, want nil", err)
	}
	decrypted, err := a.Decrypt(ct, ctxData)
	if err != nil {
		t.Fatalf("aead.Decrypt(ciphertext, context) err = %v, want nil", err)
	}
	if !bytes.Equal(decrypted, msg) {
		t.Fatalf("aead.Decrypt(ciphertext, context) = %s, want %s", decrypted, msg)
	}

	// A ciphertext produced under one context must not decrypt under another.
	if _, err := a.Decrypt(ct, []byte("invalidContext")); err == nil {
		t.Error("aead.Decrypt(ciphertext, invalidContext) err = nil, want error")
	}
}
// TestVaultClientAEAD_DecryptWithFixedCiphertext checks that an AEAD obtained
// through hcvault.NewClient and GetAEAD decrypts a ciphertext produced
// independently by fakeEncrypt.
func TestVaultClientAEAD_DecryptWithFixedCiphertext(t *testing.T) {
	srv, prefix, tlsCfg := newServer(t)
	defer srv.Close()

	c, err := hcvault.NewClient(prefix, tlsCfg, token)
	if err != nil {
		t.Fatalf("hcvault.NewClient() err = %v, want nil", err)
	}
	uri := fmt.Sprintf(keyURITmpl, prefix)
	a, err := c.GetAEAD(uri)
	if err != nil {
		t.Fatalf("client.GetAEAD(%q) err = %v, want nil", uri, err)
	}

	// An AEAD from GetAEAD sends its associated-data argument to Vault as the
	// "context" parameter. The fixed ciphertext is therefore built with the
	// data in the context field and no associated data.
	msg, ctxData := []byte("plaintext"), []byte("context")
	decrypted, err := a.Decrypt(fakeEncrypt(msg, nil, ctxData), ctxData)
	if err != nil {
		t.Fatalf("aead.Decrypt(ciphertext, context) err = %v, want nil", err)
	}
	if !bytes.Equal(decrypted, msg) {
		t.Fatalf("aead.Decrypt(ciphertext, context) = %s, want %s", decrypted, msg)
	}
}
// TestGetAEADFailWithBadKeyURI checks that GetAEAD rejects key URIs whose path
// does not have the expected shape.
func TestGetAEADFailWithBadKeyURI(t *testing.T) {
	srv, prefix, tlsCfg := newServer(t)
	defer srv.Close()

	c, err := hcvault.NewClient(prefix, tlsCfg, token)
	if err != nil {
		t.Fatalf("hcvault.NewClient() err = %v, want nil", err)
	}

	// Each path is appended to the URI prefix after a "/" separator.
	cases := []struct{ name, path string }{
		{"empty", ""},
		{"without slash", "badKeyUri"},
		{"with one slash", "bad/KeyUri"},
		{"with three slash", "one/two/three/four"},
	}
	for _, tc := range cases {
		keyURI := fmt.Sprintf("%s/%s", prefix, tc.path)
		t.Run(tc.name, func(t *testing.T) {
			if _, err := c.GetAEAD(keyURI); err == nil {
				t.Errorf("client.GetAEAD(%q) err = nil, want error", keyURI)
			}
		})
	}
}
// newVaultAPIClient returns a Vault API client for the server at url. The
// client authenticates with token and uses a clone of tlsConfig for its TLS
// connections.
//
// It starts from the HTTP client of vault_api.DefaultConfig() and replaces only
// that client's transport TLS configuration. The test fails, rather than
// panics, if that transport is not an *http.Transport.
func newVaultAPIClient(t *testing.T, url string, token string, tlsConfig *tls.Config) *vault_api.Client {
	t.Helper()
	httpClient := vault_api.DefaultConfig().HttpClient
	transport, ok := httpClient.Transport.(*http.Transport)
	if !ok {
		t.Fatalf("vault_api.DefaultConfig().HttpClient.Transport has type %T, want *http.Transport", httpClient.Transport)
	}
	// Clone so the client never shares (or mutates) the caller's config.
	transport.TLSClientConfig = tlsConfig.Clone()
	cfg := &vault_api.Config{
		Address:    url,
		HttpClient: httpClient,
	}
	client, err := vault_api.NewClient(cfg)
	if err != nil {
		t.Fatalf("vault_api.NewClient() err = %v, want nil", err)
	}
	client.SetToken(token)
	return client
}
// closeFunc is a func() error. Judging only by its name it presumably closes
// some resource — TODO confirm.
// NOTE(review): nothing in this file appears to reference closeFunc; it may be
// left over from an earlier version of newServer. Verify and remove if unused.
type closeFunc func() error
// newServer returns a fake, TLS-enabled Vault server, an "hcvault://" URI
// prefix for accessing it, and a TLS configuration which trusts the server's
// certificate.
//
// The fake serves only the transit encrypt and decrypt endpoints of key
// "key-1", backed by fakeEncrypt and fakeDecrypt. Any other request URI gets
// a 404. The Vault token is not checked.
//
// Once finished with the server, its Close() method should be called.
//
// The URI prefix and TLS configuration can be passed to hcvault.NewClient().
//
// The URI prefix can also be used to construct valid key URIs for the server.
func newServer(t *testing.T) (server *httptest.Server, uriPrefix string, clientTLSConfig *tls.Config) {
	t.Helper()
	// writeJSON sends resp as the JSON response body. It runs on the server's
	// goroutine, and t.Fatal/FailNow may only be called from the test goroutine.
	// Failures are therefore reported with t.Errorf and, where still possible,
	// an HTTP 500.
	writeJSON := func(w http.ResponseWriter, resp any) {
		respBytes, err := json.Marshal(resp)
		if err != nil {
			t.Errorf("Cannot encode response data: %v", err)
			http.Error(w, "cannot encode response data", http.StatusInternalServerError)
			return
		}
		if _, err := w.Write(respBytes); err != nil {
			t.Errorf("Cannot send response: %v", err)
		}
	}
	server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.RequestURI {
		// Encrypt
		case "/v1/transit/encrypt/key-1":
			var encReq map[string]string
			if err := json.NewDecoder(r.Body).Decode(&encReq); err != nil {
				http.Error(w, fmt.Sprintf("Cannot decode encryption request: %s", err), http.StatusBadRequest)
				return
			}
			plaintext, err := base64.StdEncoding.DecodeString(encReq["plaintext"])
			if err != nil {
				http.Error(w, "plaintext must be base64 encoded", http.StatusBadRequest)
				return
			}
			context, err := base64.StdEncoding.DecodeString(encReq["context"])
			if err != nil {
				http.Error(w, "context must be base64 encoded", http.StatusBadRequest)
				return
			}
			associatedData, err := base64.StdEncoding.DecodeString(encReq["associated_data"])
			if err != nil {
				http.Error(w, "associated_data must be base64 encoded", http.StatusBadRequest)
				return
			}
			ciphertext := fakeEncrypt(plaintext, associatedData, context)
			writeJSON(w, map[string]any{
				"data": map[string]string{
					"ciphertext": string(ciphertext),
				},
			})
		// Decrypt
		case "/v1/transit/decrypt/key-1":
			var decReq map[string]string
			if err := json.NewDecoder(r.Body).Decode(&decReq); err != nil {
				http.Error(w, fmt.Sprintf("Cannot decode decryption request: %s", err), http.StatusBadRequest)
				return
			}
			// The ciphertext is not base64-decoded: it is the opaque string the
			// encrypt endpoint returned.
			ciphertext := []byte(decReq["ciphertext"])
			context, err := base64.StdEncoding.DecodeString(decReq["context"])
			if err != nil {
				http.Error(w, "context must be base64 encoded", http.StatusBadRequest)
				return
			}
			associatedData, err := base64.StdEncoding.DecodeString(decReq["associated_data"])
			if err != nil {
				http.Error(w, "associated_data must be base64 encoded", http.StatusBadRequest)
				return
			}
			plaintext, err := fakeDecrypt(ciphertext, associatedData, context)
			if err != nil {
				http.Error(w, fmt.Sprintf("Cannot decrypt ciphertext: %s", err), http.StatusBadRequest)
				return
			}
			writeJSON(w, map[string]any{
				"data": map[string]string{
					"plaintext": base64.StdEncoding.EncodeToString(plaintext),
				},
			})
		default:
			http.NotFound(w, r)
		}
	}))
	uriPrefix = strings.Replace(server.URL, "https", "hcvault", 1)
	certpool := x509.NewCertPool()
	certpool.AddCert(server.Certificate())
	clientTLSConfig = &tls.Config{RootCAs: certpool}
	return server, uriPrefix, clientTLSConfig
}
// fakeEncrypt is a deterministic, insecure stand-in for Vault's transit
// encryption.
//
// Real HashiCorp Vault ciphertexts look like
//
//	vault:v1:<ciphertext>
//
// where ciphertext is base64-encoded. See:
// https://developer.hashicorp.com/vault/api-docs/secret/transit#sample-request-13
//
// This fake instead produces
//
//	enc:<context>:<associatedData>:<plaintext>
//
// Each field is the standard base64 encoding of the corresponding argument,
// and is empty when that argument is nil or empty. The standard base64
// alphabet contains no ':', so the fields can be recovered by splitting on ':'.
func fakeEncrypt(plaintext, associatedData, context []byte) []byte {
	b64 := base64.StdEncoding.EncodeToString
	encoded := fmt.Sprintf("enc:%s:%s:%s", b64(context), b64(associatedData), b64(plaintext))
	return []byte(encoded)
}
// TestFakeEncrypt pins fakeEncrypt's output format when every input is set.
func TestFakeEncrypt(t *testing.T) {
	got := fakeEncrypt([]byte("plaintext"), []byte("associatedData"), []byte("context"))
	if want := []byte("enc:Y29udGV4dA==:YXNzb2NpYXRlZERhdGE=:cGxhaW50ZXh0"); !bytes.Equal(got, want) {
		t.Errorf("fakeEncrypt(plaintext, associatedData, context) = %q, want %q", got, want)
	}
}
// TestFakeEncryptWithoutAssociatedData checks that nil associated data yields
// an empty middle field.
func TestFakeEncryptWithoutAssociatedData(t *testing.T) {
	got := fakeEncrypt([]byte("plaintext"), nil, []byte("context"))
	if want := []byte("enc:Y29udGV4dA==::cGxhaW50ZXh0"); !bytes.Equal(got, want) {
		t.Errorf("fakeEncrypt(plaintext, nil, context) = %q, want %q", got, want)
	}
}
// TestFakeEncryptWithoutContext checks that a nil context yields an empty
// first field after the "enc" tag.
func TestFakeEncryptWithoutContext(t *testing.T) {
	got := fakeEncrypt([]byte("plaintext"), []byte("associatedData"), nil)
	if want := []byte("enc::YXNzb2NpYXRlZERhdGE=:cGxhaW50ZXh0"); !bytes.Equal(got, want) {
		t.Errorf("fakeEncrypt(plaintext, associatedData, nil) = %q, want %q", got, want)
	}
}
// fakeDecrypt reverses fakeEncrypt. It parses a ciphertext of the form
//
//	enc:<context>:<associatedData>:<plaintext>
//
// in which each field is standard base64. It checks that the embedded context
// and associatedData equal the given ones (nil and empty compare equal) and
// returns the decoded plaintext.
//
// It returns an error if the ciphertext is malformed, if a field is not valid
// base64 (the error names the field), or if the context or associatedData does
// not match.
func fakeDecrypt(ciphertext, associatedData, context []byte) ([]byte, error) {
	// Standard base64 never emits ':', so splitting cannot cut inside a field.
	parts := strings.Split(string(ciphertext), ":")
	if len(parts) != 4 || parts[0] != "enc" {
		return nil, fmt.Errorf("malformed ciphertext: %s", ciphertext)
	}
	gotContext, err := base64.StdEncoding.DecodeString(parts[1])
	if err != nil {
		return nil, fmt.Errorf("decoding context: %w", err)
	}
	if !bytes.Equal(context, gotContext) {
		return nil, fmt.Errorf("invalid context: %s != %s", gotContext, context)
	}
	gotAssociatedData, err := base64.StdEncoding.DecodeString(parts[2])
	if err != nil {
		return nil, fmt.Errorf("decoding associatedData: %w", err)
	}
	if !bytes.Equal(gotAssociatedData, associatedData) {
		return nil, fmt.Errorf("invalid associatedData: %s != %s", gotAssociatedData, associatedData)
	}
	plaintext, err := base64.StdEncoding.DecodeString(parts[3])
	if err != nil {
		return nil, fmt.Errorf("decoding plaintext: %w", err)
	}
	return plaintext, nil
}
// TestFakeEncryptDecrypt checks that fakeDecrypt inverts fakeEncrypt, and that
// it rejects a ciphertext when the associated data or context differ.
func TestFakeEncryptDecrypt(t *testing.T) {
	ciphertext := fakeEncrypt([]byte("plaintext"), []byte("associatedData"), []byte("context"))
	got, err := fakeDecrypt(ciphertext, []byte("associatedData"), []byte("context"))
	if err != nil {
		// got is meaningless on error, so stop rather than also comparing it.
		t.Fatalf("fakeDecrypt() err = %v, want nil", err)
	}
	if want := []byte("plaintext"); !bytes.Equal(got, want) {
		t.Errorf("fakeDecrypt() = %q, want %q", got, want)
	}
	if _, err := fakeDecrypt(ciphertext, []byte("otherAssociatedData"), []byte("context")); err == nil {
		t.Error("fakeDecrypt() with otherAssociatedData err = nil, want error")
	}
	if _, err := fakeDecrypt(ciphertext, []byte("associatedData"), []byte("otherContext")); err == nil {
		t.Error("fakeDecrypt() with otherContext err = nil, want error")
	}
}