[mle] perform in-place AES-CCM encryption/decryption (#7614)
This commit adds a new method in `AesCcm` to encrypt/decrypt the
payload content in place within a given `Message`. This is then used
in `Mle` for processing MLE messages. This commit also updates unit
test `test_aes.cpp` to validate the new method.
diff --git a/src/core/common/message.hpp b/src/core/common/message.hpp
index 4a78413..2d359ce 100644
--- a/src/core/common/message.hpp
+++ b/src/core/common/message.hpp
@@ -70,6 +70,7 @@
namespace Crypto {
+class AesCcm;
class Sha256;
class HmacSha256;
@@ -262,6 +263,7 @@
friend class Checksum;
friend class Crypto::HmacSha256;
friend class Crypto::Sha256;
+ friend class Crypto::AesCcm;
friend class MessagePool;
friend class MessageQueue;
friend class PriorityQueue;
diff --git a/src/core/crypto/aes_ccm.cpp b/src/core/crypto/aes_ccm.cpp
index a3ed5e5..49c42f2 100644
--- a/src/core/crypto/aes_ccm.cpp
+++ b/src/core/crypto/aes_ccm.cpp
@@ -248,6 +248,21 @@
}
}
+#if !OPENTHREAD_RADIO
+void AesCcm::Payload(Message &aMessage, uint16_t aOffset, uint16_t aLength, Mode aMode)
+{
+ Message::MutableChunk chunk;
+
+ aMessage.GetFirstChunk(aOffset, aLength, chunk);
+
+ while (chunk.GetLength() > 0)
+ {
+ Payload(chunk.GetBytes(), chunk.GetBytes(), chunk.GetLength(), aMode);
+ aMessage.GetNextChunk(aLength, chunk);
+ }
+}
+#endif
+
void AesCcm::Finalize(void *aTag)
{
uint8_t *tagBytes = reinterpret_cast<uint8_t *>(aTag);
diff --git a/src/core/crypto/aes_ccm.hpp b/src/core/crypto/aes_ccm.hpp
index fcca7fc..0234243 100644
--- a/src/core/crypto/aes_ccm.hpp
+++ b/src/core/crypto/aes_ccm.hpp
@@ -40,6 +40,8 @@
#include <openthread/platform/crypto.h>
#include "common/error.hpp"
+#include "common/message.hpp"
+#include "common/type_traits.hpp"
#include "crypto/aes_ecb.hpp"
#include "crypto/storage.hpp"
#include "mac/mac_types.hpp"
@@ -126,6 +128,21 @@
void Header(const void *aHeader, uint32_t aHeaderLength);
/**
+ * This method processes the header.
+ *
+ * @tparam ObjectType The object type.
+ *
+ * @param[in] aObject A reference to the object to add to header.
+ *
+ */
+ template <typename ObjectType> void Header(const ObjectType &aObject)
+ {
+ static_assert(!TypeTraits::IsPointer<ObjectType>::kValue, "ObjectType must not be a pointer");
+
+ Header(&aObject, sizeof(ObjectType));
+ }
+
+ /**
* This method processes the payload.
*
* @param[in,out] aPlainText A pointer to the plaintext.
@@ -136,6 +153,21 @@
*/
void Payload(void *aPlainText, void *aCipherText, uint32_t aLength, Mode aMode);
+#if !OPENTHREAD_RADIO
+ /**
+ * This method processes the payload within a given message.
+ *
+ * This method encrypts/decrypts the payload content in place within the @p aMessage.
+ *
+ * @param[in,out] aMessage The message to read from and update.
+ * @param[in] aOffset The offset in @p aMessage to start of payload.
+ * @param[in] aLength Payload length in bytes.
+ * @param[in] aMode Mode to indicate whether to encrypt (`kEncrypt`) or decrypt (`kDecrypt`).
+ *
+ */
+ void Payload(Message &aMessage, uint16_t aOffset, uint16_t aLength, Mode aMode);
+#endif
+
/**
* This method returns the tag length in bytes.
*
diff --git a/src/core/thread/mle.cpp b/src/core/thread/mle.cpp
index 543c234..023d47d 100644
--- a/src/core/thread/mle.cpp
+++ b/src/core/thread/mle.cpp
@@ -2678,18 +2678,17 @@
{
Error error = kErrorNone;
Header header;
- uint32_t keySequence;
- uint8_t nonce[Crypto::AesCcm::kNonceSize];
- uint8_t tag[kMleSecurityTagSize];
- Crypto::AesCcm aesCcm;
- uint8_t buf[64];
- uint16_t length;
Ip6::MessageInfo messageInfo;
IgnoreError(aMessage.Read(0, header));
if (header.GetSecuritySuite() == Header::k154Security)
{
+ uint32_t keySequence;
+ uint8_t nonce[Crypto::AesCcm::kNonceSize];
+ uint8_t tag[kMleSecurityTagSize];
+ Crypto::AesCcm aesCcm;
+
header.SetFrameCounter(Get<KeyManager>().GetMleFrameCounter());
keySequence = Get<KeyManager>().GetCurrentKeySequence();
@@ -2704,22 +2703,17 @@
aesCcm.Init(16 + 16 + header.GetHeaderLength(), aMessage.GetLength() - (header.GetLength() - 1), sizeof(tag),
nonce, sizeof(nonce));
- aesCcm.Header(&mLinkLocal64.GetAddress(), sizeof(mLinkLocal64.GetAddress()));
- aesCcm.Header(&aDestination, sizeof(aDestination));
+ aesCcm.Header(mLinkLocal64.GetAddress());
+ aesCcm.Header(aDestination);
aesCcm.Header(header.GetBytes() + 1, header.GetHeaderLength());
aMessage.SetOffset(header.GetLength() - 1);
- while (aMessage.GetOffset() < aMessage.GetLength())
- {
- length = aMessage.ReadBytes(aMessage.GetOffset(), buf, sizeof(buf));
- aesCcm.Payload(buf, buf, length, Crypto::AesCcm::kEncrypt);
- aMessage.WriteBytes(aMessage.GetOffset(), buf, length);
- aMessage.MoveOffset(length);
- }
+ aesCcm.Payload(aMessage, aMessage.GetOffset(), aMessage.GetLength() - aMessage.GetOffset(),
+ Crypto::AesCcm::kEncrypt);
aesCcm.Finalize(tag);
- SuccessOrExit(error = aMessage.AppendBytes(tag, sizeof(tag)));
+ SuccessOrExit(error = aMessage.Append(tag));
Get<KeyManager>().IncrementMleFrameCounter();
}
@@ -2768,8 +2762,6 @@
uint8_t nonce[Crypto::AesCcm::kNonceSize];
Mac::ExtAddress extAddr;
Crypto::AesCcm aesCcm;
- uint16_t mleOffset;
- uint8_t buf[64];
uint16_t length;
uint8_t tag[kMleSecurityTagSize];
uint8_t command;
@@ -2837,24 +2829,15 @@
aesCcm.Init(sizeof(aMessageInfo.GetPeerAddr()) + sizeof(aMessageInfo.GetSockAddr()) + header.GetHeaderLength(),
aMessage.GetLength() - aMessage.GetOffset(), sizeof(messageTag), nonce, sizeof(nonce));
- aesCcm.Header(&aMessageInfo.GetPeerAddr(), sizeof(aMessageInfo.GetPeerAddr()));
- aesCcm.Header(&aMessageInfo.GetSockAddr(), sizeof(aMessageInfo.GetSockAddr()));
+ aesCcm.Header(aMessageInfo.GetPeerAddr());
+ aesCcm.Header(aMessageInfo.GetSockAddr());
aesCcm.Header(header.GetBytes() + 1, header.GetHeaderLength());
- mleOffset = aMessage.GetOffset();
-
- while (aMessage.GetOffset() < aMessage.GetLength())
- {
- length = aMessage.ReadBytes(aMessage.GetOffset(), buf, sizeof(buf));
- aesCcm.Payload(buf, buf, length, Crypto::AesCcm::kDecrypt);
#ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
- aMessage.WriteBytes(aMessage.GetOffset(), buf, length);
-#endif
- aMessage.MoveOffset(length);
- }
-
+ aesCcm.Payload(aMessage, aMessage.GetOffset(), aMessage.GetLength() - aMessage.GetOffset(),
+ Crypto::AesCcm::kDecrypt);
aesCcm.Finalize(tag);
-#ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
+
if (memcmp(messageTag, tag, sizeof(tag)) != 0)
{
// We skip logging security check failures for broadcast MLE
@@ -2871,8 +2854,6 @@
Get<KeyManager>().SetCurrentKeySequence(keySequence);
}
- aMessage.SetOffset(mleOffset);
-
IgnoreError(aMessage.Read(aMessage.GetOffset(), command));
aMessage.MoveOffset(sizeof(command));
diff --git a/tests/unit/test_aes.cpp b/tests/unit/test_aes.cpp
index 6a9f516..ce75b28 100644
--- a/tests/unit/test_aes.cpp
+++ b/tests/unit/test_aes.cpp
@@ -32,7 +32,7 @@
#include "crypto/aes_ccm.hpp"
#include "test_platform.h"
-#include "test_util.h"
+#include "test_util.hpp"
/**
* Verifies test vectors from IEEE 802.15.4-2006 Annex C Section C.2.1
@@ -88,7 +88,7 @@
/**
* Verifies test vectors from IEEE 802.15.4-2006 Annex C Section C.2.3
*/
-void TestMacCommandFrame()
+void TestMacCommandFrame(void)
{
uint8_t key[] = {
0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf,
@@ -100,8 +100,9 @@
0x00, 0x00, 0x01, 0xCE, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
};
- uint32_t headerLength = 29, payloadLength = 1;
- uint8_t tagLength = 8;
+ static constexpr uint32_t kHeaderLength = 29;
+ static constexpr uint32_t kPayloadLength = 1;
+ static constexpr uint8_t kTagLength = 8;
uint8_t encrypted[] = {
0x2B, 0xDC, 0x84, 0x21, 0x43, 0x02, 0x00, 0x00, 0x00, 0x00, 0x48, 0xDE, 0xAC,
@@ -119,28 +120,144 @@
0xAC, 0xDE, 0x48, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x06,
};
+ uint8_t tag[kTagLength];
+
+ ot::Instance * instance = testInitInstance();
+ ot::Message * message;
ot::Crypto::AesCcm aesCcm;
+
+ VerifyOrQuit(instance != nullptr);
+
aesCcm.SetKey(key, sizeof(key));
- aesCcm.Init(headerLength, payloadLength, tagLength, nonce, sizeof(nonce));
- aesCcm.Header(test, headerLength);
- aesCcm.Payload(test + headerLength, test + headerLength, payloadLength, ot::Crypto::AesCcm::kEncrypt);
- VerifyOrQuit(aesCcm.GetTagLength() == tagLength);
- aesCcm.Finalize(test + headerLength + payloadLength);
+ aesCcm.Init(kHeaderLength, kPayloadLength, kTagLength, nonce, sizeof(nonce));
+ aesCcm.Header(test, kHeaderLength);
+ aesCcm.Payload(test + kHeaderLength, test + kHeaderLength, kPayloadLength, ot::Crypto::AesCcm::kEncrypt);
+ VerifyOrQuit(aesCcm.GetTagLength() == kTagLength);
+ aesCcm.Finalize(test + kHeaderLength + kPayloadLength);
VerifyOrQuit(memcmp(test, encrypted, sizeof(encrypted)) == 0);
- aesCcm.Init(headerLength, payloadLength, tagLength, nonce, sizeof(nonce));
- aesCcm.Header(test, headerLength);
- aesCcm.Payload(test + headerLength, test + headerLength, payloadLength, ot::Crypto::AesCcm::kDecrypt);
- VerifyOrQuit(aesCcm.GetTagLength() == tagLength);
- aesCcm.Finalize(test + headerLength + payloadLength);
+ aesCcm.Init(kHeaderLength, kPayloadLength, kTagLength, nonce, sizeof(nonce));
+ aesCcm.Header(test, kHeaderLength);
+ aesCcm.Payload(test + kHeaderLength, test + kHeaderLength, kPayloadLength, ot::Crypto::AesCcm::kDecrypt);
+ VerifyOrQuit(aesCcm.GetTagLength() == kTagLength);
+ aesCcm.Finalize(test + kHeaderLength + kPayloadLength);
VerifyOrQuit(memcmp(test, decrypted, sizeof(decrypted)) == 0);
+
+ // Verify encryption/decryption in place within a message.
+
+ message = instance->Get<ot::MessagePool>().Allocate(ot::Message::kTypeIp6);
+ VerifyOrQuit(message != nullptr);
+
+ SuccessOrQuit(message->AppendBytes(test, kHeaderLength + kPayloadLength));
+
+ aesCcm.Init(kHeaderLength, kPayloadLength, kTagLength, nonce, sizeof(nonce));
+ aesCcm.Header(test, kHeaderLength);
+
+ aesCcm.Payload(*message, kHeaderLength, kPayloadLength, ot::Crypto::AesCcm::kEncrypt);
+ VerifyOrQuit(aesCcm.GetTagLength() == kTagLength);
+ aesCcm.Finalize(tag);
+ SuccessOrQuit(message->Append(tag));
+ VerifyOrQuit(message->GetLength() == sizeof(encrypted));
+ VerifyOrQuit(message->Compare(0, encrypted));
+
+ aesCcm.Init(kHeaderLength, kPayloadLength, kTagLength, nonce, sizeof(nonce));
+ aesCcm.Header(test, kHeaderLength);
+ aesCcm.Payload(*message, kHeaderLength, kPayloadLength, ot::Crypto::AesCcm::kDecrypt);
+
+ VerifyOrQuit(message->GetLength() == sizeof(encrypted));
+ VerifyOrQuit(message->Compare(0, decrypted));
+
+ message->Free();
+ testFreeInstance(instance);
+}
+
+/**
+ * Verifies in-place encryption/decryption.
+ *
+ */
+void TestInPlaceAesCcmProcessing(void)
+{
+ static constexpr uint16_t kTagLength = 4;
+ static constexpr uint32_t kHeaderLength = 19;
+
+ static const uint8_t kKey[] = {
+ 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
+ };
+
+ static const uint8_t kNonce[] = {
+ 0xac, 0xde, 0x48, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x06,
+ };
+
+ static uint16_t kMessageLengths[] = {30, 400, 800};
+
+ uint8_t tag[kTagLength];
+ uint8_t header[kHeaderLength];
+
+ ot::Crypto::AesCcm aesCcm;
+ ot::Instance * instance = testInitInstance();
+ ot::Message * message;
+ ot::Message * messageClone;
+
+ VerifyOrQuit(instance != nullptr);
+
+ message = instance->Get<ot::MessagePool>().Allocate(ot::Message::kTypeIp6);
+ VerifyOrQuit(message != nullptr);
+
+ aesCcm.SetKey(kKey, sizeof(kKey));
+
+ for (uint16_t msgLength : kMessageLengths)
+ {
+ printf("msgLength %d\n", msgLength);
+
+ SuccessOrQuit(message->SetLength(0));
+
+ for (uint16_t i = msgLength; i != 0; i--)
+ {
+ SuccessOrQuit(message->Append<uint8_t>(i & 0xff));
+ }
+
+ messageClone = message->Clone();
+ VerifyOrQuit(messageClone != nullptr);
+ VerifyOrQuit(messageClone->GetLength() == msgLength);
+
+ SuccessOrQuit(message->Read(0, header));
+
+ // Encrypt in place
+ aesCcm.Init(kHeaderLength, msgLength - kHeaderLength, kTagLength, kNonce, sizeof(kNonce));
+ aesCcm.Header(header);
+ aesCcm.Payload(*message, kHeaderLength, msgLength - kHeaderLength, ot::Crypto::AesCcm::kEncrypt);
+
+ // Append the tag
+ aesCcm.Finalize(tag);
+ SuccessOrQuit(message->Append(tag));
+
+ VerifyOrQuit(message->GetLength() == msgLength + kTagLength);
+
+ // Decrpt in place
+ aesCcm.Init(kHeaderLength, msgLength - kHeaderLength, kTagLength, kNonce, sizeof(kNonce));
+ aesCcm.Header(header);
+ aesCcm.Payload(*message, kHeaderLength, msgLength - kHeaderLength, ot::Crypto::AesCcm::kDecrypt);
+
+ // Check the tag against what is the message
+ aesCcm.Finalize(tag);
+ VerifyOrQuit(message->Compare(msgLength, tag));
+
+ // Check that decrypted message is the same as original (cloned) message
+ VerifyOrQuit(message->CompareBytes(0, *messageClone, 0, msgLength));
+
+ messageClone->Free();
+ }
+
+ message->Free();
+ testFreeInstance(instance);
}
int main(void)
{
TestMacBeaconFrame();
TestMacCommandFrame();
+ TestInPlaceAesCcmProcessing();
printf("All tests passed\n");
return 0;
}