[mle] add common `ProcessMessageSecurity()` (#7625)

This commit adds `Mle::ProcessMessageSecurity()`. Based on the `aMode`
input, it can be used to encrypt and append tag to a message (which
will be sent) or to decrypt and validate the tag in a received
message. The common method handles the generation of nonce and
preparation of AES-CCM header.
diff --git a/src/core/thread/mle.cpp b/src/core/thread/mle.cpp
index 396df6d..b64f529 100644
--- a/src/core/thread/mle.cpp
+++ b/src/core/thread/mle.cpp
@@ -46,7 +46,6 @@
 #include "common/random.hpp"
 #include "common/serial_number.hpp"
 #include "common/settings.hpp"
-#include "crypto/aes_ccm.hpp"
 #include "meshcop/meshcop.hpp"
 #include "meshcop/meshcop_tlvs.hpp"
 #include "net/netif.hpp"
@@ -2685,6 +2684,98 @@
 }
 #endif
 
+Error Mle::ProcessMessageSecurity(Crypto::AesCcm::Mode    aMode,
+                                  Message &               aMessage,
+                                  const Ip6::MessageInfo &aMessageInfo,
+                                  uint16_t                aCmdOffset,
+                                  const SecurityHeader &  aHeader)
+{
+    // This method performs MLE message security. Based on `aMode` it
+    // can be used to encrypt and append tag to `aMessage` or to
+    // decrypt and validate the tag in a received `aMessage` (which is
+    // then removed from `aMessage`).
+    //
+    // `aCmdOffset` in both cases specifies the offset in `aMessage`
+    // to the start of MLE payload (i.e., the command field).
+    //
+    // When decrypting, possible errors are:
+    // `kErrorNone` decrypted and verified tag, tag is also removed.
+    // `kErrorParse` message does not contain the tag
+    // `kErrorSecurity` message tag is invalid.
+    //
+    // When encrypting, possible errors are:
+    // `kErrorNone` message encrypted and tag appended to message.
+    // `kErrorNoBufs` could not grow the message to append the tag.
+
+    Error               error = kErrorNone;
+    Crypto::AesCcm      aesCcm;
+    uint8_t             nonce[Crypto::AesCcm::kNonceSize];
+    uint8_t             tag[kMleSecurityTagSize];
+    Mac::ExtAddress     extAddress;
+    uint32_t            keySequence;
+    uint16_t            payloadLength   = aMessage.GetLength() - aCmdOffset;
+    const Ip6::Address *senderAddress   = &aMessageInfo.GetSockAddr();
+    const Ip6::Address *receiverAddress = &aMessageInfo.GetPeerAddr();
+
+    switch (aMode)
+    {
+    case Crypto::AesCcm::kEncrypt:
+        // Use the initialized values for `senderAddress`,
+        // `receiverAddress` and `payloadLength`
+        break;
+
+    case Crypto::AesCcm::kDecrypt:
+        senderAddress   = &aMessageInfo.GetPeerAddr();
+        receiverAddress = &aMessageInfo.GetSockAddr();
+        // Ensure message contains command field (uint8_t) and
+        // tag. Then exclude the tag from payload to decrypt.
+        VerifyOrExit(aCmdOffset + sizeof(uint8_t) + kMleSecurityTagSize <= aMessage.GetLength(), error = kErrorParse);
+        payloadLength -= kMleSecurityTagSize;
+        break;
+    }
+
+    senderAddress->GetIid().ConvertToExtAddress(extAddress);
+    Crypto::AesCcm::GenerateNonce(extAddress, aHeader.GetFrameCounter(), Mac::Frame::kSecEncMic32, nonce);
+
+    keySequence = aHeader.GetKeyId();
+
+    aesCcm.SetKey(keySequence == Get<KeyManager>().GetCurrentKeySequence()
+                      ? Get<KeyManager>().GetCurrentMleKey()
+                      : Get<KeyManager>().GetTemporaryMleKey(keySequence));
+
+    aesCcm.Init(sizeof(Ip6::Address) + sizeof(Ip6::Address) + sizeof(SecurityHeader), payloadLength,
+                kMleSecurityTagSize, nonce, sizeof(nonce));
+
+    aesCcm.Header(*senderAddress);
+    aesCcm.Header(*receiverAddress);
+    aesCcm.Header(aHeader);
+
+#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
+    if (aMode == Crypto::AesCcm::kDecrypt)
+    {
+        // Skip decrypting the message under fuzz build mode
+        IgnoreError(aMessage.SetLength(aMessage.GetLength() - kMleSecurityTagSize));
+        ExitNow();
+    }
+#endif
+
+    aesCcm.Payload(aMessage, aCmdOffset, payloadLength, aMode);
+    aesCcm.Finalize(tag);
+
+    if (aMode == Crypto::AesCcm::kEncrypt)
+    {
+        SuccessOrExit(error = aMessage.Append(tag));
+    }
+    else
+    {
+        VerifyOrExit(aMessage.Compare(aMessage.GetLength() - kMleSecurityTagSize, tag), error = kErrorSecurity);
+        IgnoreError(aMessage.SetLength(aMessage.GetLength() - kMleSecurityTagSize));
+    }
+
+exit:
+    return error;
+}
+
 Error Mle::SendMessage(Message &aMessage, const Ip6::Address &aDestination)
 {
     Error            error  = kErrorNone;
@@ -2692,49 +2783,31 @@
     uint8_t          securitySuite;
     Ip6::MessageInfo messageInfo;
 
+    messageInfo.SetPeerAddr(aDestination);
+    messageInfo.SetSockAddr(mLinkLocal64.GetAddress());
+    messageInfo.SetPeerPort(kUdpPort);
+    messageInfo.SetHopLimit(kMleHopLimit);
+
     IgnoreError(aMessage.Read(offset, securitySuite));
     offset += sizeof(securitySuite);
 
     if (securitySuite == k154Security)
     {
         SecurityHeader header;
-        uint8_t        nonce[Crypto::AesCcm::kNonceSize];
-        uint8_t        tag[kMleSecurityTagSize];
-        Crypto::AesCcm aesCcm;
+
+        // Update the fields in the security header
 
         IgnoreError(aMessage.Read(offset, header));
-
         header.SetFrameCounter(Get<KeyManager>().GetMleFrameCounter());
         header.SetKeyId(Get<KeyManager>().GetCurrentKeySequence());
-
         aMessage.Write(offset, header);
         offset += sizeof(SecurityHeader);
 
-        Crypto::AesCcm::GenerateNonce(Get<Mac::Mac>().GetExtAddress(), Get<KeyManager>().GetMleFrameCounter(),
-                                      Mac::Frame::kSecEncMic32, nonce);
-
-        aesCcm.SetKey(Get<KeyManager>().GetCurrentMleKey());
-
-        aesCcm.Init(sizeof(Ip6::Address) + sizeof(Ip6::Address) + sizeof(SecurityHeader), aMessage.GetLength() - offset,
-                    kMleSecurityTagSize, nonce, sizeof(nonce));
-
-        aesCcm.Header(mLinkLocal64.GetAddress());
-        aesCcm.Header(aDestination);
-        aesCcm.Header(header);
-
-        aesCcm.Payload(aMessage, offset, aMessage.GetLength() - offset, Crypto::AesCcm::kEncrypt);
-
-        aesCcm.Finalize(tag);
-        SuccessOrExit(error = aMessage.Append(tag));
+        SuccessOrExit(error = ProcessMessageSecurity(Crypto::AesCcm::kEncrypt, aMessage, messageInfo, offset, header));
 
         Get<KeyManager>().IncrementMleFrameCounter();
     }
 
-    messageInfo.SetPeerAddr(aDestination);
-    messageInfo.SetSockAddr(mLinkLocal64.GetAddress());
-    messageInfo.SetPeerPort(kUdpPort);
-    messageInfo.SetHopLimit(kMleHopLimit);
-
     SuccessOrExit(error = mSocket.SendTo(aMessage, messageInfo));
 
 exit:
@@ -2771,14 +2844,9 @@
     SecurityHeader  header;
     uint32_t        keySequence;
     uint32_t        frameCounter;
-    uint8_t         messageTag[kMleSecurityTagSize];
-    uint8_t         nonce[Crypto::AesCcm::kNonceSize];
     Mac::ExtAddress extAddr;
-    Crypto::AesCcm  aesCcm;
-    uint8_t         tag[kMleSecurityTagSize];
     uint8_t         command;
     Neighbor *      neighbor;
-    bool            skipLoggingError = false;
 
     LogDebg("Receive MLE message");
 
@@ -2822,42 +2890,8 @@
     keySequence  = header.GetKeyId();
     frameCounter = header.GetFrameCounter();
 
-    VerifyOrExit(aMessage.GetOffset() + sizeof(command) + sizeof(messageTag) <= aMessage.GetLength(),
-                 error = kErrorParse);
-
-    IgnoreError(aMessage.Read(aMessage.GetLength() - sizeof(messageTag), messageTag));
-    SuccessOrExit(error = aMessage.SetLength(aMessage.GetLength() - sizeof(messageTag)));
-
-    aMessageInfo.GetPeerAddr().GetIid().ConvertToExtAddress(extAddr);
-
-    Crypto::AesCcm::GenerateNonce(extAddr, frameCounter, Mac::Frame::kSecEncMic32, nonce);
-
-    aesCcm.SetKey((keySequence == Get<KeyManager>().GetCurrentKeySequence())
-                      ? Get<KeyManager>().GetCurrentMleKey()
-                      : Get<KeyManager>().GetTemporaryMleKey(keySequence));
-
-    aesCcm.Init(sizeof(Ip6::Address) + sizeof(Ip6::Address) + sizeof(SecurityHeader),
-                aMessage.GetLength() - aMessage.GetOffset(), sizeof(messageTag), nonce, sizeof(nonce));
-
-    aesCcm.Header(aMessageInfo.GetPeerAddr());
-    aesCcm.Header(aMessageInfo.GetSockAddr());
-    aesCcm.Header(header);
-
-#ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
-    aesCcm.Payload(aMessage, aMessage.GetOffset(), aMessage.GetLength() - aMessage.GetOffset(),
-                   Crypto::AesCcm::kDecrypt);
-    aesCcm.Finalize(tag);
-
-    if (memcmp(messageTag, tag, sizeof(tag)) != 0)
-    {
-        // We skip logging security check failures for broadcast MLE
-        // messages since it can be common to receive such messages
-        // from adjacent Thread networks.
-        skipLoggingError =
-            (aMessageInfo.GetSockAddr().IsMulticast() && aMessageInfo.GetThreadLinkInfo()->IsDstPanIdBroadcast());
-        ExitNow(error = kErrorSecurity);
-    }
-#endif
+    SuccessOrExit(
+        error = ProcessMessageSecurity(Crypto::AesCcm::kDecrypt, aMessage, aMessageInfo, aMessage.GetOffset(), header));
 
     if (keySequence > Get<KeyManager>().GetCurrentKeySequence())
     {
@@ -2867,6 +2901,7 @@
     IgnoreError(aMessage.Read(aMessage.GetOffset(), command));
     aMessage.MoveOffset(sizeof(command));
 
+    aMessageInfo.GetPeerAddr().GetIid().ConvertToExtAddress(extAddr);
     neighbor = (command == kCommandChildIdResponse) ? mNeighborTable.FindParent(extAddr)
                                                     : mNeighborTable.FindNeighbor(extAddr);
 
@@ -3047,7 +3082,10 @@
 #endif
 
 exit:
-    if (!skipLoggingError)
+    // We skip logging failures for broadcast MLE messages since it
+    // can be common to receive such messages from adjacent Thread
+    // networks.
+    if (!aMessageInfo.GetSockAddr().IsMulticast() || !aMessageInfo.GetThreadLinkInfo()->IsDstPanIdBroadcast())
     {
         LogProcessError(kTypeGenericUdp, error);
     }
diff --git a/src/core/thread/mle.hpp b/src/core/thread/mle.hpp
index be2a1a2..205b818 100644
--- a/src/core/thread/mle.hpp
+++ b/src/core/thread/mle.hpp
@@ -42,6 +42,7 @@
 #include "common/non_copyable.hpp"
 #include "common/notifier.hpp"
 #include "common/timer.hpp"
+#include "crypto/aes_ccm.hpp"
 #include "mac/mac.hpp"
 #include "meshcop/joiner_router.hpp"
 #include "meshcop/meshcop.hpp"
@@ -1828,6 +1829,12 @@
                         uint8_t                aCslUncertainty);
     bool IsNetworkDataNewer(const LeaderData &aLeaderData);
 
+    Error ProcessMessageSecurity(Crypto::AesCcm::Mode    aMode,
+                                 Message &               aMessage,
+                                 const Ip6::MessageInfo &aMessageInfo,
+                                 uint16_t                aCmdOffset,
+                                 const SecurityHeader &  aHeader);
+
 #if OPENTHREAD_CONFIG_TMF_NETDATA_SERVICE_ENABLE
     ServiceAloc *FindInServiceAlocs(uint16_t aAloc16);
     void         UpdateServiceAlocs(void);