blob: 127fbcfa02437e449fbae216f345987f99669880 [file] [log] [blame]
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
package hpke
import (
"errors"
"fmt"
"math/big"
pb "github.com/google/tink/go/proto/hpke_go_proto"
)
// context holds the per-session HPKE state produced by createContext and
// consumed by seal and open.
type context struct {
	// aead is the AEAD primitive that seal and open delegate to.
	aead aead
	// maxSequenceNumber is the message limit, set by createContext to
	// maxSequenceNumber(aead.nonceLength()).
	maxSequenceNumber *big.Int
	// sequenceNumber starts at 0 and is advanced by incrementSequenceNumber
	// after each successful seal or open.
	sequenceNumber *big.Int
	// key is the AEAD key derived with the "key" label in createContext.
	key []byte
	// baseNonce is derived with the "base_nonce" label in createContext and is
	// XORed with sequenceNumber in computeNonce.
	baseNonce []byte
	// encapsulatedKey is the KEM output: produced by kem.encapsulate on the
	// sender side, supplied by the caller on the recipient side.
	encapsulatedKey []byte
}
// newSenderContext creates the HPKE sender context as per KeySchedule()
// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.1-10.
//
// It passes the recipient's public key bytes to kem.encapsulate to obtain the
// shared secret and the encapsulated key, then hands both to createContext.
// An error is returned if the public key is missing or empty, if
// encapsulation fails, or if createContext fails.
func newSenderContext(recipientPubKey *pb.HpkePublicKey, kem kem, kdf kdf, aead aead, info []byte) (*context, error) {
	pubKeyBytes := recipientPubKey.GetPublicKey()
	// Use len rather than a nil comparison so that a non-nil, zero-length key
	// is rejected here with the matching "empty" error instead of being
	// passed on to the KEM.
	if len(pubKeyBytes) == 0 {
		return nil, errors.New("HpkePublicKey has an empty PublicKey")
	}

	sharedSecret, encapsulatedKey, err := kem.encapsulate(pubKeyBytes)
	if err != nil {
		return nil, fmt.Errorf("encapsulate: %v", err)
	}

	return createContext(encapsulatedKey, sharedSecret, kem, kdf, aead, info)
}
// newRecipientContext creates the HPKE recipient context as per KeySchedule()
// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.1-10.
//
// It passes encapsulatedKey and the recipient's private key bytes to
// kem.decapsulate to recover the shared secret, then hands the result to
// createContext. An error is returned if the private key is missing or empty,
// if decapsulation fails, or if createContext fails.
func newRecipientContext(encapsulatedKey []byte, recipientPrivKey *pb.HpkePrivateKey, kem kem, kdf kdf, aead aead, info []byte) (*context, error) {
	privKeyBytes := recipientPrivKey.GetPrivateKey()
	// Use len rather than a nil comparison so that a non-nil, zero-length key
	// is rejected here with the matching "empty" error instead of being
	// passed on to the KEM.
	if len(privKeyBytes) == 0 {
		return nil, errors.New("HpkePrivateKey has an empty PrivateKey")
	}

	sharedSecret, err := kem.decapsulate(encapsulatedKey, privKeyBytes)
	if err != nil {
		return nil, fmt.Errorf("decapsulate: %v", err)
	}

	return createContext(encapsulatedKey, sharedSecret, kem, kdf, aead, info)
}
// createContext derives the AEAD key and base nonce from sharedSecret and
// info, and returns a context whose sequence number starts at zero.
//
// encapsulatedKey is stored in the returned context as-is (not copied). An
// error is returned if either labeledExpand call fails.
func createContext(encapsulatedKey []byte, sharedSecret []byte, kem kem, kdf kdf, aead aead, info []byte) (*context, error) {
	id := hpkeSuiteID(kem.id(), kdf.id(), aead.id())

	// Base mode uses an empty pre-shared key (default_psk) and an empty
	// pre-shared key ID (default_psk_id), see
	// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.1.1-4.
	// emptyIKM stands in for the default PSK ID here ...
	pskIDDigest := kdf.labeledExtract(emptySalt, emptyIKM, "psk_id_hash", id)
	infoDigest := kdf.labeledExtract(emptySalt, info, "info_hash", id)
	scheduleCtx := keyScheduleContext(baseMode, pskIDDigest, infoDigest)
	// ... and for the default PSK here.
	prk := kdf.labeledExtract(sharedSecret, emptyIKM, "secret", id)

	aeadKey, err := kdf.labeledExpand(prk, scheduleCtx, "key", id, aead.keyLength())
	if err != nil {
		return nil, fmt.Errorf("labeledExpand of key: %v", err)
	}

	nonceSeed, err := kdf.labeledExpand(prk, scheduleCtx, "base_nonce", id, aead.nonceLength())
	if err != nil {
		return nil, fmt.Errorf("labeledExpand of base nonce: %v", err)
	}

	result := &context{
		aead:              aead,
		maxSequenceNumber: maxSequenceNumber(aead.nonceLength()),
		sequenceNumber:    big.NewInt(0),
		key:               aeadKey,
		baseNonce:         nonceSeed,
		encapsulatedKey:   encapsulatedKey,
	}
	return result, nil
}
// maxSequenceNumber returns (1 << (8*nonceLength)) - 1, the sequence number
// at which the message limit is reached, as per
// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.2-11.
//
// nonceLength is the nonce size in bytes. Each call returns a newly allocated
// *big.Int.
func maxSequenceNumber(nonceLength int) *big.Int {
	bits := uint(8 * nonceLength)

	limit := big.NewInt(1)
	limit.Lsh(limit, bits)
	limit.Sub(limit, big.NewInt(1))
	return limit
}
// incrementSequenceNumber adds one to c.sequenceNumber in place and returns
// an error if the new value exceeds c.maxSequenceNumber.
//
// The counter is advanced before the limit check, so it remains incremented
// even when an error is returned.
func (c *context) incrementSequenceNumber() error {
	step := big.NewInt(1)
	c.sequenceNumber.Add(c.sequenceNumber, step)

	if c.sequenceNumber.Cmp(c.maxSequenceNumber) <= 0 {
		return nil
	}
	return errors.New("message limit reached")
}
// computeNonce returns c.baseNonce XORed with the big-endian encoding of
// c.sequenceNumber, right-aligned to the nonce length, as per
// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.2-12.
//
// The result is a new slice; c.baseNonce is not modified. An error is
// returned if the sequence number needs more bytes than the nonce has.
func (c *context) computeNonce() ([]byte, error) {
	nonceLen := len(c.baseNonce)

	// Big-endian bytes of the counter, without leading zeros.
	seqBytes := c.sequenceNumber.Bytes()
	offset := nonceLen - len(seqBytes)
	if offset < 0 {
		return nil, fmt.Errorf("sequence number length (%d) is larger than nonce length (%d)", len(seqBytes), nonceLen)
	}

	// Start from a copy of the base nonce. The first offset bytes would be
	// XORed with zero padding, so only the tail needs to be touched.
	out := make([]byte, nonceLen)
	copy(out, c.baseNonce)
	for i, b := range seqBytes {
		out[offset+i] ^= b
	}
	return out, nil
}
// seal encrypts plaintext with associatedData using the sender's context,
// defined as ContextS.Seal in
// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.2-7.
//
// The nonce comes from computeNonce. The sequence number is advanced only
// after encryption succeeds; if advancing it reports an error, the ciphertext
// is discarded and that error is returned.
func (c *context) seal(plaintext, associatedData []byte) ([]byte, error) {
	var (
		nonce  []byte
		sealed []byte
		err    error
	)

	if nonce, err = c.computeNonce(); err != nil {
		return nil, fmt.Errorf("computeNonce: %v", err)
	}
	if sealed, err = c.aead.seal(c.key, nonce, plaintext, associatedData); err != nil {
		return nil, fmt.Errorf("seal: %v", err)
	}
	if err = c.incrementSequenceNumber(); err != nil {
		return nil, err
	}
	return sealed, nil
}
// open decrypts ciphertext with associatedData using the receiver's context,
// defined as ContextR.Open in
// https://www.rfc-editor.org/rfc/rfc9180.html#section-5.2-9.
//
// The nonce comes from computeNonce. The sequence number is advanced only
// after decryption succeeds; if advancing it reports an error, the plaintext
// is discarded and that error is returned.
func (c *context) open(ciphertext, associatedData []byte) ([]byte, error) {
	var (
		nonce  []byte
		opened []byte
		err    error
	)

	if nonce, err = c.computeNonce(); err != nil {
		return nil, fmt.Errorf("computeNonce: %v", err)
	}
	if opened, err = c.aead.open(c.key, nonce, ciphertext, associatedData); err != nil {
		return nil, fmt.Errorf("open: %v", err)
	}
	if err = c.incrementSequenceNumber(); err != nil {
		return nil, err
	}
	return opened, nil
}