blob: c9a61253403c7f23e032027b349e41005744d3c2 [file] [log] [blame]
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
package hpke
import (
"bytes"
"encoding/hex"
"fmt"
"math"
"testing"
"github.com/google/tink/go/subtle"
)
// TODO(b/201070904): Write tests using baseModeX25519HKDFSHA256Vectors.
func TestHKDFKDFLabeledExtract(t *testing.T) {
kdf, err := newKDF(hkdfSHA256)
if err != nil {
t.Fatalf("newKDF(hkdfSHA256): got err %q, want success", err)
}
id, v := internetDraftVector(t)
suiteID := hpkeSuiteID(id.kemID, id.kdfID, id.aeadID)
// Base mode uses a default empty value for the pre-shared key (PSK), see
// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.1.1-4.
pskIDHash := kdf.labeledExtract(emptySalt, emptyIKM /*= default PSK ID*/, "psk_id_hash", suiteID)
infoHash := kdf.labeledExtract(emptySalt, v.info, "info_hash", suiteID)
keyScheduleCtx := keyScheduleContext(id.mode, pskIDHash, infoHash)
if !bytes.Equal(keyScheduleCtx, v.keyScheduleCtx) {
t.Errorf("labeledExtract: got %x, want %x", keyScheduleCtx, v.keyScheduleCtx)
}
secret := kdf.labeledExtract(v.sharedSecret, emptyIKM /*= default PSK*/, "secret", suiteID)
if !bytes.Equal(secret, v.secret) {
t.Errorf("labeledExtract: got %x, want %x", secret, v.secret)
}
}
func TestHKDFKDFLabeledExpand(t *testing.T) {
kdf, err := newKDF(hkdfSHA256)
if err != nil {
t.Fatalf("newKDF(hkdfSHA256): got err %q, want success", err)
}
id, v := internetDraftVector(t)
suiteID := hpkeSuiteID(id.kemID, id.kdfID, id.aeadID)
tests := []struct {
infoLabel string
length int
want []byte
wantErr bool
}{
{"key", 16, v.key, false},
{"base_nonce", 12, v.baseNonce, false},
{"large_length", int(math.Pow(2, 16)), []byte{}, true},
}
for _, test := range tests {
t.Run(test.infoLabel, func(t *testing.T) {
got, err := kdf.labeledExpand(v.secret, v.keyScheduleCtx, test.infoLabel, suiteID, test.length)
if test.wantErr {
if err == nil {
t.Error("labeledExpand: got success, want err")
}
return
}
if err != nil {
t.Errorf("labeledExpand: got err %q, want success", err)
}
if !bytes.Equal(got, test.want) {
t.Errorf("labeledExpand: got %x, want %x", got, test.want)
}
})
}
}
func TestHKDFKDFLabeledExpandRFCVectors(t *testing.T) {
kdf, err := newKDF(hkdfSHA256)
if err != nil {
t.Fatalf("newKDF(hkdfSHA256): got err %q, want success", err)
}
suiteID := hpkeSuiteID(x25519HKDFSHA256, hkdfSHA256, aes128GCM)
// Vectors are defined at
// https://datatracker.ietf.org/doc/html/rfc5869#appendix-A.
var tests = []struct {
name string
info string
prk string
length int
want string // Generated manually.
}{
{
name: "basic",
info: "f0f1f2f3f4f5f6f7f8f9",
prk: "077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5",
length: 42,
want: "2f1a8eb86971cd1850d04a1b98f9a63d52d56c5a4d5fcb68103e57c7a85a1df2c9be1346ae041007712d",
},
{
name: "longer inputs",
info: "b0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff",
prk: "06a6b88c5853361a06104c9ceb35b45cef760014904671014a193f40c15fc244",
length: 82,
want: "3961afd1985cb4d811e261b3568c44b88ae7e5d5909d33a5419e954eb245fe03fd3635769d88cec8adb709e900fa399e1a68bdb9d5c879e385845eeb99034fd232e30d1acc58f7fa37791fe0c433221b1fec",
},
{
name: "zero-length info",
info: "",
prk: "19ef24a32c717b167f33a91d6f648bdf96596776afdb6377ac434c1c293ccb04",
length: 42,
want: "bdb2761a4f8504177b10ecc354f41153a3964435b9072d1f349c2993afbaa77a05ed426c384e195dba76",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
info, err := hex.DecodeString(test.info)
if err != nil {
t.Fatal("hex.DecodeString(info) failed")
}
prk, err := hex.DecodeString(test.prk)
if err != nil {
t.Fatal("hex.DecodeString(prk) failed")
}
want, err := hex.DecodeString(test.want)
if err != nil {
t.Fatal("hex.DecodeString(want) failed")
}
got, err := kdf.labeledExpand(prk, info, "info_label", suiteID, test.length)
if err != nil {
t.Errorf("labeledExpand: got err %q, want success", err)
}
if !bytes.Equal(got, want) {
t.Errorf("labeledExpand: got %x, want %x", got, want)
}
})
}
}
func TestHKDFKDFExtractAndExpand(t *testing.T) {
kdf, err := newKDF(hkdfSHA256)
if err != nil {
t.Fatalf("newKDF(hkdfSHA256): got err %q, want success", err)
}
_, v := internetDraftVector(t)
dhSharedSecret, err := subtle.ComputeSharedSecretX25519(v.senderPrivKey, v.recipientPubKey)
if err != nil {
t.Fatalf("ComputeSharedSecretX25519: got err %q, want success", err)
}
kemCtx := []byte{}
kemCtx = append(kemCtx, v.senderPubKey...)
kemCtx = append(kemCtx, v.recipientPubKey...)
var tests = []struct {
length int
want []byte
wantErr bool
}{
{32, v.sharedSecret, false},
{int(math.Pow(2, 16)), nil, true},
}
for _, test := range tests {
t.Run(fmt.Sprintf("%d", test.length), func(t *testing.T) {
sharedSecret, err := kdf.extractAndExpand(
emptySalt,
dhSharedSecret,
"eae_prk",
kemCtx,
"shared_secret",
kemSuiteID(x25519HKDFSHA256),
test.length)
if test.wantErr {
if err == nil {
t.Error("extractAndExpand: got success, want err")
}
return
}
if err != nil {
t.Errorf("extractAndExpand: got err %q, want success", err)
}
if !bytes.Equal(sharedSecret, v.sharedSecret) {
t.Errorf("extractAndExpand: got %x, want %x", sharedSecret, v.sharedSecret)
}
})
}
}