blob: 39e1bb98e66b348eba9841fd8e278682af881766 [file] [log] [blame]
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
package hpke
import (
"testing"
pb "github.com/google/tink/go/proto/hpke_go_proto"
)
func TestNewKEM(t *testing.T) {
kemID, err := kemIDFromProto(pb.HpkeKem_DHKEM_X25519_HKDF_SHA256)
if err != nil {
t.Fatal(err)
}
if kemID != x25519HKDFSHA256 {
t.Errorf("kemID: got %d, want %d", kemID, x25519HKDFSHA256)
}
kem, err := newKEM(kemID)
if err != nil {
t.Fatal(err)
}
if kem.id() != x25519HKDFSHA256 {
t.Errorf("id: got %d, want %d", kem.id(), x25519HKDFSHA256)
}
}
// TestNewKEMUnsupportedID checks that newKEM returns an error for a KEM
// identifier it does not implement.
func TestNewKEMUnsupportedID(t *testing.T) {
	// 0x0010 is the HPKE identifier for DHKEM(P-256, HKDF-SHA256).
	const unsupportedKEM = 0x0010
	_, err := newKEM(unsupportedKEM)
	if err == nil {
		t.Fatal("newKEM(unsupported ID): got success, want err")
	}
}
func TestKEMIDFromProtoUnsupportedID(t *testing.T) {
if _, err := kemIDFromProto(pb.HpkeKem_KEM_UNKNOWN); err == nil {
t.Fatal("kemIDFromProto(unsupported ID): got success, want err")
}
}
func TestNewKDF(t *testing.T) {
kdfID, err := kdfIDFromProto(pb.HpkeKdf_HKDF_SHA256)
if err != nil {
t.Fatal(err)
}
if kdfID != hkdfSHA256 {
t.Errorf("kdfID: got %d, want %d", kdfID, hkdfSHA256)
}
kdf, err := newKDF(kdfID)
if err != nil {
t.Fatal(err)
}
if kdf.id() != hkdfSHA256 {
t.Errorf("id: got %d, want %d", kdf.id(), hkdfSHA256)
}
}
// TestNewKDFUnsupportedID checks that newKDF returns an error for a KDF
// identifier it does not implement.
func TestNewKDFUnsupportedID(t *testing.T) {
	// 0x0002 is the HPKE identifier for HKDF-SHA384.
	const unsupportedKDF = 0x0002
	_, err := newKDF(unsupportedKDF)
	if err == nil {
		t.Fatal("newKDF(unsupported ID): got success, want err")
	}
}
func TestKDFIDFromProtoUnsupportedID(t *testing.T) {
if _, err := kdfIDFromProto(pb.HpkeKdf_KDF_UNKNOWN); err == nil {
t.Fatal("kdfIDFromProto(unsupported ID): got success, want err")
}
}
// aeads is the shared table for the AEAD-parameterized tests in this file.
// Each entry pairs a subtest name with a proto enum value and the HPKE AEAD
// identifier that the enum value is expected to map to.
var aeads = []struct {
	name  string
	proto pb.HpkeAead
	id    uint16
}{
	{name: "AES-128-GCM", proto: pb.HpkeAead_AES_128_GCM, id: aes128GCM},
	{name: "AES-256-GCM", proto: pb.HpkeAead_AES_256_GCM, id: aes256GCM},
	{name: "ChaCha20Poly1305", proto: pb.HpkeAead_CHACHA20_POLY1305, id: chaCha20Poly1305},
}
// TestNewAEAD checks, for every entry in aeads, that the proto enum maps to
// the expected HPKE AEAD identifier and that newAEAD builds an AEAD reporting
// that identifier. Fatal failures name the call and its input.
func TestNewAEAD(t *testing.T) {
	for _, a := range aeads {
		t.Run(a.name, func(t *testing.T) {
			aeadID, err := aeadIDFromProto(a.proto)
			if err != nil {
				t.Fatalf("aeadIDFromProto(%v): %v", a.proto, err)
			}
			if aeadID != a.id {
				t.Errorf("aeadID: got %d, want %d", aeadID, a.id)
			}
			aead, err := newAEAD(aeadID)
			if err != nil {
				t.Fatalf("newAEAD(%d): %v", aeadID, err)
			}
			if got := aead.id(); got != a.id {
				t.Errorf("id: got %d, want %d", got, a.id)
			}
		})
	}
}
// TestNewAEADUnsupportedID checks that newAEAD returns an error for an AEAD
// identifier it does not implement.
func TestNewAEADUnsupportedID(t *testing.T) {
	// 0xFFFF is the HPKE "Export-only" AEAD identifier.
	const unsupportedAEAD = 0xFFFF
	_, err := newAEAD(unsupportedAEAD)
	if err == nil {
		t.Fatal("newAEAD(unsupported ID): got success, want err")
	}
}
func TestAEADIDFromProtoUnsupportedID(t *testing.T) {
if _, err := aeadIDFromProto(pb.HpkeAead_AEAD_UNKNOWN); err == nil {
t.Fatal("aeadIDFromProto(unsupported ID): got success, want err")
}
}
// TestNewPrimitivesFromProto checks that newPrimitivesFromProto builds a KEM,
// KDF and AEAD with the expected identifiers for X25519/HKDF-SHA256 combined
// with every AEAD in the aeads table.
func TestNewPrimitivesFromProto(t *testing.T) {
	for _, a := range aeads {
		// Name each subtest after its AEAD. An empty name makes the testing
		// package number the subtests (#00, #01, ...), which hides which AEAD
		// failed and prevents selecting one with -run.
		t.Run(a.name, func(t *testing.T) {
			params := &pb.HpkeParams{
				Kem:  pb.HpkeKem_DHKEM_X25519_HKDF_SHA256,
				Kdf:  pb.HpkeKdf_HKDF_SHA256,
				Aead: a.proto,
			}
			kem, kdf, aead, err := newPrimitivesFromProto(params)
			if err != nil {
				t.Fatalf("newPrimitivesFromProto: %v", err)
			}
			if got := kem.id(); got != x25519HKDFSHA256 {
				t.Errorf("kem.id: got %d, want %d", got, x25519HKDFSHA256)
			}
			if got := kdf.id(); got != hkdfSHA256 {
				t.Errorf("kdf.id: got %d, want %d", got, hkdfSHA256)
			}
			if got := aead.id(); got != a.id {
				t.Errorf("aead.id: got %d, want %d", got, a.id)
			}
		})
	}
}
// TestNewPrimitivesFromProtoUnsupportedID checks that newPrimitivesFromProto
// fails when exactly one of the KEM, KDF or AEAD fields is set to its UNKNOWN
// enum value. The other two fields are values that TestNewPrimitivesFromProto
// accepts.
func TestNewPrimitivesFromProtoUnsupportedID(t *testing.T) {
	for _, tc := range []struct {
		name   string
		params *pb.HpkeParams
	}{
		{
			name: "KEM",
			params: &pb.HpkeParams{
				Kem:  pb.HpkeKem_KEM_UNKNOWN,
				Kdf:  pb.HpkeKdf_HKDF_SHA256,
				Aead: pb.HpkeAead_AES_256_GCM,
			},
		},
		{
			name: "KDF",
			params: &pb.HpkeParams{
				Kem:  pb.HpkeKem_DHKEM_X25519_HKDF_SHA256,
				Kdf:  pb.HpkeKdf_KDF_UNKNOWN,
				Aead: pb.HpkeAead_AES_256_GCM,
			},
		},
		{
			name: "AEAD",
			params: &pb.HpkeParams{
				Kem:  pb.HpkeKem_DHKEM_X25519_HKDF_SHA256,
				Kdf:  pb.HpkeKdf_HKDF_SHA256,
				Aead: pb.HpkeAead_AEAD_UNKNOWN,
			},
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			_, _, _, err := newPrimitivesFromProto(tc.params)
			if err == nil {
				t.Error("newPrimitivesFromProto: got success, want err")
			}
		})
	}
}