[mle] perform in-place AES-CCM encryption/decryption (#7614)

This commit adds a new method in `AesCcm` to encrypt/decrypt the
payload content in place within a given `Message`. This is then used
in `Mle` for processing MLE messages. This commit also updates unit
test `test_aes.cpp` to validate the new method.
diff --git a/src/core/common/message.hpp b/src/core/common/message.hpp
index 4a78413..2d359ce 100644
--- a/src/core/common/message.hpp
+++ b/src/core/common/message.hpp
@@ -70,6 +70,7 @@
 
 namespace Crypto {
 
+class AesCcm;
 class Sha256;
 class HmacSha256;
 
@@ -262,6 +263,7 @@
     friend class Checksum;
     friend class Crypto::HmacSha256;
     friend class Crypto::Sha256;
+    friend class Crypto::AesCcm;
     friend class MessagePool;
     friend class MessageQueue;
     friend class PriorityQueue;
diff --git a/src/core/crypto/aes_ccm.cpp b/src/core/crypto/aes_ccm.cpp
index a3ed5e5..49c42f2 100644
--- a/src/core/crypto/aes_ccm.cpp
+++ b/src/core/crypto/aes_ccm.cpp
@@ -248,6 +248,21 @@
     }
 }
 
+#if !OPENTHREAD_RADIO
+void AesCcm::Payload(Message &aMessage, uint16_t aOffset, uint16_t aLength, Mode aMode)
+{
+    Message::MutableChunk chunk;
+
+    aMessage.GetFirstChunk(aOffset, aLength, chunk);
+
+    while (chunk.GetLength() > 0)
+    {
+        Payload(chunk.GetBytes(), chunk.GetBytes(), chunk.GetLength(), aMode);
+        aMessage.GetNextChunk(aLength, chunk);
+    }
+}
+#endif
+
 void AesCcm::Finalize(void *aTag)
 {
     uint8_t *tagBytes = reinterpret_cast<uint8_t *>(aTag);
diff --git a/src/core/crypto/aes_ccm.hpp b/src/core/crypto/aes_ccm.hpp
index fcca7fc..0234243 100644
--- a/src/core/crypto/aes_ccm.hpp
+++ b/src/core/crypto/aes_ccm.hpp
@@ -40,6 +40,8 @@
 
 #include <openthread/platform/crypto.h>
 #include "common/error.hpp"
+#include "common/message.hpp"
+#include "common/type_traits.hpp"
 #include "crypto/aes_ecb.hpp"
 #include "crypto/storage.hpp"
 #include "mac/mac_types.hpp"
@@ -126,6 +128,21 @@
     void Header(const void *aHeader, uint32_t aHeaderLength);
 
     /**
+     * This method processes the header.
+     *
+     * @tparam    ObjectType   The object type.
+     *
+     * @param[in] aObject      A reference to the object to add to header.
+     *
+     */
+    template <typename ObjectType> void Header(const ObjectType &aObject)
+    {
+        static_assert(!TypeTraits::IsPointer<ObjectType>::kValue, "ObjectType must not be a pointer");
+
+        Header(&aObject, sizeof(ObjectType));
+    }
+
+    /**
      * This method processes the payload.
      *
      * @param[in,out]  aPlainText   A pointer to the plaintext.
@@ -136,6 +153,21 @@
      */
     void Payload(void *aPlainText, void *aCipherText, uint32_t aLength, Mode aMode);
 
+#if !OPENTHREAD_RADIO
+    /**
+     * This method processes the payload within a given message.
+     *
+     * This method encrypts/decrypts the payload content in place within the @p aMessage.
+     *
+     * @param[in,out]  aMessage     The message to read from and update.
+     * @param[in]      aOffset      The offset in @p aMessage to start of payload.
+     * @param[in]      aLength      Payload length in bytes.
+     * @param[in]      aMode        Mode to indicate whether to encrypt (`kEncrypt`) or decrypt (`kDecrypt`).
+     *
+     */
+    void Payload(Message &aMessage, uint16_t aOffset, uint16_t aLength, Mode aMode);
+#endif
+
     /**
      * This method returns the tag length in bytes.
      *
diff --git a/src/core/thread/mle.cpp b/src/core/thread/mle.cpp
index 543c234..023d47d 100644
--- a/src/core/thread/mle.cpp
+++ b/src/core/thread/mle.cpp
@@ -2678,18 +2678,17 @@
 {
     Error            error = kErrorNone;
     Header           header;
-    uint32_t         keySequence;
-    uint8_t          nonce[Crypto::AesCcm::kNonceSize];
-    uint8_t          tag[kMleSecurityTagSize];
-    Crypto::AesCcm   aesCcm;
-    uint8_t          buf[64];
-    uint16_t         length;
     Ip6::MessageInfo messageInfo;
 
     IgnoreError(aMessage.Read(0, header));
 
     if (header.GetSecuritySuite() == Header::k154Security)
     {
+        uint32_t       keySequence;
+        uint8_t        nonce[Crypto::AesCcm::kNonceSize];
+        uint8_t        tag[kMleSecurityTagSize];
+        Crypto::AesCcm aesCcm;
+
         header.SetFrameCounter(Get<KeyManager>().GetMleFrameCounter());
 
         keySequence = Get<KeyManager>().GetCurrentKeySequence();
@@ -2704,22 +2703,17 @@
         aesCcm.Init(16 + 16 + header.GetHeaderLength(), aMessage.GetLength() - (header.GetLength() - 1), sizeof(tag),
                     nonce, sizeof(nonce));
 
-        aesCcm.Header(&mLinkLocal64.GetAddress(), sizeof(mLinkLocal64.GetAddress()));
-        aesCcm.Header(&aDestination, sizeof(aDestination));
+        aesCcm.Header(mLinkLocal64.GetAddress());
+        aesCcm.Header(aDestination);
         aesCcm.Header(header.GetBytes() + 1, header.GetHeaderLength());
 
         aMessage.SetOffset(header.GetLength() - 1);
 
-        while (aMessage.GetOffset() < aMessage.GetLength())
-        {
-            length = aMessage.ReadBytes(aMessage.GetOffset(), buf, sizeof(buf));
-            aesCcm.Payload(buf, buf, length, Crypto::AesCcm::kEncrypt);
-            aMessage.WriteBytes(aMessage.GetOffset(), buf, length);
-            aMessage.MoveOffset(length);
-        }
+        aesCcm.Payload(aMessage, aMessage.GetOffset(), aMessage.GetLength() - aMessage.GetOffset(),
+                       Crypto::AesCcm::kEncrypt);
 
         aesCcm.Finalize(tag);
-        SuccessOrExit(error = aMessage.AppendBytes(tag, sizeof(tag)));
+        SuccessOrExit(error = aMessage.Append(tag));
 
         Get<KeyManager>().IncrementMleFrameCounter();
     }
@@ -2768,8 +2762,6 @@
     uint8_t            nonce[Crypto::AesCcm::kNonceSize];
     Mac::ExtAddress    extAddr;
     Crypto::AesCcm     aesCcm;
-    uint16_t           mleOffset;
-    uint8_t            buf[64];
     uint16_t           length;
     uint8_t            tag[kMleSecurityTagSize];
     uint8_t            command;
@@ -2837,24 +2829,15 @@
     aesCcm.Init(sizeof(aMessageInfo.GetPeerAddr()) + sizeof(aMessageInfo.GetSockAddr()) + header.GetHeaderLength(),
                 aMessage.GetLength() - aMessage.GetOffset(), sizeof(messageTag), nonce, sizeof(nonce));
 
-    aesCcm.Header(&aMessageInfo.GetPeerAddr(), sizeof(aMessageInfo.GetPeerAddr()));
-    aesCcm.Header(&aMessageInfo.GetSockAddr(), sizeof(aMessageInfo.GetSockAddr()));
+    aesCcm.Header(aMessageInfo.GetPeerAddr());
+    aesCcm.Header(aMessageInfo.GetSockAddr());
     aesCcm.Header(header.GetBytes() + 1, header.GetHeaderLength());
 
-    mleOffset = aMessage.GetOffset();
-
-    while (aMessage.GetOffset() < aMessage.GetLength())
-    {
-        length = aMessage.ReadBytes(aMessage.GetOffset(), buf, sizeof(buf));
-        aesCcm.Payload(buf, buf, length, Crypto::AesCcm::kDecrypt);
 #ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
-        aMessage.WriteBytes(aMessage.GetOffset(), buf, length);
-#endif
-        aMessage.MoveOffset(length);
-    }
-
+    aesCcm.Payload(aMessage, aMessage.GetOffset(), aMessage.GetLength() - aMessage.GetOffset(),
+                   Crypto::AesCcm::kDecrypt);
     aesCcm.Finalize(tag);
-#ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
+
     if (memcmp(messageTag, tag, sizeof(tag)) != 0)
     {
         // We skip logging security check failures for broadcast MLE
@@ -2871,8 +2854,6 @@
         Get<KeyManager>().SetCurrentKeySequence(keySequence);
     }
 
-    aMessage.SetOffset(mleOffset);
-
     IgnoreError(aMessage.Read(aMessage.GetOffset(), command));
     aMessage.MoveOffset(sizeof(command));
 
diff --git a/tests/unit/test_aes.cpp b/tests/unit/test_aes.cpp
index 6a9f516..ce75b28 100644
--- a/tests/unit/test_aes.cpp
+++ b/tests/unit/test_aes.cpp
@@ -32,7 +32,7 @@
 #include "crypto/aes_ccm.hpp"
 
 #include "test_platform.h"
-#include "test_util.h"
+#include "test_util.hpp"
 
 /**
  * Verifies test vectors from IEEE 802.15.4-2006 Annex C Section C.2.1
@@ -88,7 +88,7 @@
 /**
  * Verifies test vectors from IEEE 802.15.4-2006 Annex C Section C.2.3
  */
-void TestMacCommandFrame()
+void TestMacCommandFrame(void)
 {
     uint8_t key[] = {
         0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf,
@@ -100,8 +100,9 @@
         0x00, 0x00, 0x01, 0xCE, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
     };
 
-    uint32_t headerLength = 29, payloadLength = 1;
-    uint8_t  tagLength = 8;
+    static constexpr uint32_t kHeaderLength  = 29;
+    static constexpr uint32_t kPayloadLength = 1;
+    static constexpr uint8_t  kTagLength     = 8;
 
     uint8_t encrypted[] = {
         0x2B, 0xDC, 0x84, 0x21, 0x43, 0x02, 0x00, 0x00, 0x00, 0x00, 0x48, 0xDE, 0xAC,
@@ -119,28 +120,144 @@
         0xAC, 0xDE, 0x48, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x06,
     };
 
+    uint8_t tag[kTagLength];
+
+    ot::Instance *     instance = testInitInstance();
+    ot::Message *      message;
     ot::Crypto::AesCcm aesCcm;
+
+    VerifyOrQuit(instance != nullptr);
+
     aesCcm.SetKey(key, sizeof(key));
-    aesCcm.Init(headerLength, payloadLength, tagLength, nonce, sizeof(nonce));
-    aesCcm.Header(test, headerLength);
-    aesCcm.Payload(test + headerLength, test + headerLength, payloadLength, ot::Crypto::AesCcm::kEncrypt);
-    VerifyOrQuit(aesCcm.GetTagLength() == tagLength);
-    aesCcm.Finalize(test + headerLength + payloadLength);
+    aesCcm.Init(kHeaderLength, kPayloadLength, kTagLength, nonce, sizeof(nonce));
+    aesCcm.Header(test, kHeaderLength);
+    aesCcm.Payload(test + kHeaderLength, test + kHeaderLength, kPayloadLength, ot::Crypto::AesCcm::kEncrypt);
+    VerifyOrQuit(aesCcm.GetTagLength() == kTagLength);
+    aesCcm.Finalize(test + kHeaderLength + kPayloadLength);
     VerifyOrQuit(memcmp(test, encrypted, sizeof(encrypted)) == 0);
 
-    aesCcm.Init(headerLength, payloadLength, tagLength, nonce, sizeof(nonce));
-    aesCcm.Header(test, headerLength);
-    aesCcm.Payload(test + headerLength, test + headerLength, payloadLength, ot::Crypto::AesCcm::kDecrypt);
-    VerifyOrQuit(aesCcm.GetTagLength() == tagLength);
-    aesCcm.Finalize(test + headerLength + payloadLength);
+    aesCcm.Init(kHeaderLength, kPayloadLength, kTagLength, nonce, sizeof(nonce));
+    aesCcm.Header(test, kHeaderLength);
+    aesCcm.Payload(test + kHeaderLength, test + kHeaderLength, kPayloadLength, ot::Crypto::AesCcm::kDecrypt);
+    VerifyOrQuit(aesCcm.GetTagLength() == kTagLength);
+    aesCcm.Finalize(test + kHeaderLength + kPayloadLength);
 
     VerifyOrQuit(memcmp(test, decrypted, sizeof(decrypted)) == 0);
+
+    // Verify encryption/decryption in place within a message.
+
+    message = instance->Get<ot::MessagePool>().Allocate(ot::Message::kTypeIp6);
+    VerifyOrQuit(message != nullptr);
+
+    SuccessOrQuit(message->AppendBytes(test, kHeaderLength + kPayloadLength));
+
+    aesCcm.Init(kHeaderLength, kPayloadLength, kTagLength, nonce, sizeof(nonce));
+    aesCcm.Header(test, kHeaderLength);
+
+    aesCcm.Payload(*message, kHeaderLength, kPayloadLength, ot::Crypto::AesCcm::kEncrypt);
+    VerifyOrQuit(aesCcm.GetTagLength() == kTagLength);
+    aesCcm.Finalize(tag);
+    SuccessOrQuit(message->Append(tag));
+    VerifyOrQuit(message->GetLength() == sizeof(encrypted));
+    VerifyOrQuit(message->Compare(0, encrypted));
+
+    aesCcm.Init(kHeaderLength, kPayloadLength, kTagLength, nonce, sizeof(nonce));
+    aesCcm.Header(test, kHeaderLength);
+    aesCcm.Payload(*message, kHeaderLength, kPayloadLength, ot::Crypto::AesCcm::kDecrypt);
+
+    VerifyOrQuit(message->GetLength() == sizeof(encrypted));
+    VerifyOrQuit(message->Compare(0, decrypted));
+
+    message->Free();
+    testFreeInstance(instance);
+}
+
+/**
+ * Verifies in-place encryption/decryption.
+ *
+ */
+void TestInPlaceAesCcmProcessing(void)
+{
+    static constexpr uint16_t kTagLength    = 4;
+    static constexpr uint32_t kHeaderLength = 19;
+
+    static const uint8_t kKey[] = {
+        0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
+    };
+
+    static const uint8_t kNonce[] = {
+        0xac, 0xde, 0x48, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x06,
+    };
+
+    static uint16_t kMessageLengths[] = {30, 400, 800};
+
+    uint8_t tag[kTagLength];
+    uint8_t header[kHeaderLength];
+
+    ot::Crypto::AesCcm aesCcm;
+    ot::Instance *     instance = testInitInstance();
+    ot::Message *      message;
+    ot::Message *      messageClone;
+
+    VerifyOrQuit(instance != nullptr);
+
+    message = instance->Get<ot::MessagePool>().Allocate(ot::Message::kTypeIp6);
+    VerifyOrQuit(message != nullptr);
+
+    aesCcm.SetKey(kKey, sizeof(kKey));
+
+    for (uint16_t msgLength : kMessageLengths)
+    {
+        printf("msgLength %d\n", msgLength);
+
+        SuccessOrQuit(message->SetLength(0));
+
+        for (uint16_t i = msgLength; i != 0; i--)
+        {
+            SuccessOrQuit(message->Append<uint8_t>(i & 0xff));
+        }
+
+        messageClone = message->Clone();
+        VerifyOrQuit(messageClone != nullptr);
+        VerifyOrQuit(messageClone->GetLength() == msgLength);
+
+        SuccessOrQuit(message->Read(0, header));
+
+        // Encrypt in place
+        aesCcm.Init(kHeaderLength, msgLength - kHeaderLength, kTagLength, kNonce, sizeof(kNonce));
+        aesCcm.Header(header);
+        aesCcm.Payload(*message, kHeaderLength, msgLength - kHeaderLength, ot::Crypto::AesCcm::kEncrypt);
+
+        // Append the tag
+        aesCcm.Finalize(tag);
+        SuccessOrQuit(message->Append(tag));
+
+        VerifyOrQuit(message->GetLength() == msgLength + kTagLength);
+
+        // Decrpt in place
+        aesCcm.Init(kHeaderLength, msgLength - kHeaderLength, kTagLength, kNonce, sizeof(kNonce));
+        aesCcm.Header(header);
+        aesCcm.Payload(*message, kHeaderLength, msgLength - kHeaderLength, ot::Crypto::AesCcm::kDecrypt);
+
+        // Check the tag against what is the message
+        aesCcm.Finalize(tag);
+        VerifyOrQuit(message->Compare(msgLength, tag));
+
+        // Check that decrypted message is the same as original (cloned) message
+        VerifyOrQuit(message->CompareBytes(0, *messageClone, 0, msgLength));
+
+        messageClone->Free();
+    }
+
+    message->Free();
+    testFreeInstance(instance);
 }
 
 int main(void)
 {
     TestMacBeaconFrame();
     TestMacCommandFrame();
+    TestInPlaceAesCcmProcessing();
     printf("All tests passed\n");
     return 0;
 }