[mle] add common `ProcessMessageSecurity()` (#7625)
This commit adds `Mle::ProcessMessageSecurity()`. Based on the `aMode`
input, it can be used to encrypt and append tag to a message (which
will be sent) or to decrypt and validate the tag in a received
message. The common method handles the generation of nonce and
preparation of AES-CCM header.
diff --git a/src/core/thread/mle.cpp b/src/core/thread/mle.cpp
index 396df6d..b64f529 100644
--- a/src/core/thread/mle.cpp
+++ b/src/core/thread/mle.cpp
@@ -46,7 +46,6 @@
#include "common/random.hpp"
#include "common/serial_number.hpp"
#include "common/settings.hpp"
-#include "crypto/aes_ccm.hpp"
#include "meshcop/meshcop.hpp"
#include "meshcop/meshcop_tlvs.hpp"
#include "net/netif.hpp"
@@ -2685,6 +2684,98 @@
}
#endif
+Error Mle::ProcessMessageSecurity(Crypto::AesCcm::Mode aMode,
+ Message & aMessage,
+ const Ip6::MessageInfo &aMessageInfo,
+ uint16_t aCmdOffset,
+ const SecurityHeader & aHeader)
+{
+ // This method performs MLE message security. Based on `aMode` it
+ // can be used to encrypt and append tag to `aMessage` or to
+ // decrypt and validate the tag in a received `aMessage` (which is
+ // then removed from `aMessage`).
+ //
+ // `aCmdOffset` in both cases specifies the offset in `aMessage`
+ // to the start of MLE payload (i.e., the command field).
+ //
+ // When decrypting, possible errors are:
+ // `kErrorNone` decrypted and verified tag, tag is also removed.
+ // `kErrorParse` message does not contain the tag
+ // `kErrorSecurity` message tag is invalid.
+ //
+ // When encrypting, possible errors are:
+ // `kErrorNone` message encrypted and tag appended to message.
+ // `kErrorNoBufs` could not grow the message to append the tag.
+
+ Error error = kErrorNone;
+ Crypto::AesCcm aesCcm;
+ uint8_t nonce[Crypto::AesCcm::kNonceSize];
+ uint8_t tag[kMleSecurityTagSize];
+ Mac::ExtAddress extAddress;
+ uint32_t keySequence;
+ uint16_t payloadLength = aMessage.GetLength() - aCmdOffset;
+ const Ip6::Address *senderAddress = &aMessageInfo.GetSockAddr();
+ const Ip6::Address *receiverAddress = &aMessageInfo.GetPeerAddr();
+
+ switch (aMode)
+ {
+ case Crypto::AesCcm::kEncrypt:
+ // Use the initialized values for `senderAddress`,
+ // `receiverAddress` and `payloadLength`
+ break;
+
+ case Crypto::AesCcm::kDecrypt:
+ senderAddress = &aMessageInfo.GetPeerAddr();
+ receiverAddress = &aMessageInfo.GetSockAddr();
+ // Ensure message contains command field (uint8_t) and
+ // tag. Then exclude the tag from payload to decrypt.
+ VerifyOrExit(aCmdOffset + sizeof(uint8_t) + kMleSecurityTagSize <= aMessage.GetLength(), error = kErrorParse);
+ payloadLength -= kMleSecurityTagSize;
+ break;
+ }
+
+ senderAddress->GetIid().ConvertToExtAddress(extAddress);
+ Crypto::AesCcm::GenerateNonce(extAddress, aHeader.GetFrameCounter(), Mac::Frame::kSecEncMic32, nonce);
+
+ keySequence = aHeader.GetKeyId();
+
+ aesCcm.SetKey(keySequence == Get<KeyManager>().GetCurrentKeySequence()
+ ? Get<KeyManager>().GetCurrentMleKey()
+ : Get<KeyManager>().GetTemporaryMleKey(keySequence));
+
+ aesCcm.Init(sizeof(Ip6::Address) + sizeof(Ip6::Address) + sizeof(SecurityHeader), payloadLength,
+ kMleSecurityTagSize, nonce, sizeof(nonce));
+
+ aesCcm.Header(*senderAddress);
+ aesCcm.Header(*receiverAddress);
+ aesCcm.Header(aHeader);
+
+#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
+ if (aMode == Crypto::AesCcm::kDecrypt)
+ {
+ // Skip decrypting the message under fuzz build mode
+ IgnoreError(aMessage.SetLength(aMessage.GetLength() - kMleSecurityTagSize));
+ ExitNow();
+ }
+#endif
+
+ aesCcm.Payload(aMessage, aCmdOffset, payloadLength, aMode);
+ aesCcm.Finalize(tag);
+
+ if (aMode == Crypto::AesCcm::kEncrypt)
+ {
+ SuccessOrExit(error = aMessage.Append(tag));
+ }
+ else
+ {
+ VerifyOrExit(aMessage.Compare(aMessage.GetLength() - kMleSecurityTagSize, tag), error = kErrorSecurity);
+ IgnoreError(aMessage.SetLength(aMessage.GetLength() - kMleSecurityTagSize));
+ }
+
+exit:
+ return error;
+}
+
Error Mle::SendMessage(Message &aMessage, const Ip6::Address &aDestination)
{
Error error = kErrorNone;
@@ -2692,49 +2783,31 @@
uint8_t securitySuite;
Ip6::MessageInfo messageInfo;
+ messageInfo.SetPeerAddr(aDestination);
+ messageInfo.SetSockAddr(mLinkLocal64.GetAddress());
+ messageInfo.SetPeerPort(kUdpPort);
+ messageInfo.SetHopLimit(kMleHopLimit);
+
IgnoreError(aMessage.Read(offset, securitySuite));
offset += sizeof(securitySuite);
if (securitySuite == k154Security)
{
SecurityHeader header;
- uint8_t nonce[Crypto::AesCcm::kNonceSize];
- uint8_t tag[kMleSecurityTagSize];
- Crypto::AesCcm aesCcm;
+
+ // Update the fields in the security header
IgnoreError(aMessage.Read(offset, header));
-
header.SetFrameCounter(Get<KeyManager>().GetMleFrameCounter());
header.SetKeyId(Get<KeyManager>().GetCurrentKeySequence());
-
aMessage.Write(offset, header);
offset += sizeof(SecurityHeader);
- Crypto::AesCcm::GenerateNonce(Get<Mac::Mac>().GetExtAddress(), Get<KeyManager>().GetMleFrameCounter(),
- Mac::Frame::kSecEncMic32, nonce);
-
- aesCcm.SetKey(Get<KeyManager>().GetCurrentMleKey());
-
- aesCcm.Init(sizeof(Ip6::Address) + sizeof(Ip6::Address) + sizeof(SecurityHeader), aMessage.GetLength() - offset,
- kMleSecurityTagSize, nonce, sizeof(nonce));
-
- aesCcm.Header(mLinkLocal64.GetAddress());
- aesCcm.Header(aDestination);
- aesCcm.Header(header);
-
- aesCcm.Payload(aMessage, offset, aMessage.GetLength() - offset, Crypto::AesCcm::kEncrypt);
-
- aesCcm.Finalize(tag);
- SuccessOrExit(error = aMessage.Append(tag));
+ SuccessOrExit(error = ProcessMessageSecurity(Crypto::AesCcm::kEncrypt, aMessage, messageInfo, offset, header));
Get<KeyManager>().IncrementMleFrameCounter();
}
- messageInfo.SetPeerAddr(aDestination);
- messageInfo.SetSockAddr(mLinkLocal64.GetAddress());
- messageInfo.SetPeerPort(kUdpPort);
- messageInfo.SetHopLimit(kMleHopLimit);
-
SuccessOrExit(error = mSocket.SendTo(aMessage, messageInfo));
exit:
@@ -2771,14 +2844,9 @@
SecurityHeader header;
uint32_t keySequence;
uint32_t frameCounter;
- uint8_t messageTag[kMleSecurityTagSize];
- uint8_t nonce[Crypto::AesCcm::kNonceSize];
Mac::ExtAddress extAddr;
- Crypto::AesCcm aesCcm;
- uint8_t tag[kMleSecurityTagSize];
uint8_t command;
Neighbor * neighbor;
- bool skipLoggingError = false;
LogDebg("Receive MLE message");
@@ -2822,42 +2890,8 @@
keySequence = header.GetKeyId();
frameCounter = header.GetFrameCounter();
- VerifyOrExit(aMessage.GetOffset() + sizeof(command) + sizeof(messageTag) <= aMessage.GetLength(),
- error = kErrorParse);
-
- IgnoreError(aMessage.Read(aMessage.GetLength() - sizeof(messageTag), messageTag));
- SuccessOrExit(error = aMessage.SetLength(aMessage.GetLength() - sizeof(messageTag)));
-
- aMessageInfo.GetPeerAddr().GetIid().ConvertToExtAddress(extAddr);
-
- Crypto::AesCcm::GenerateNonce(extAddr, frameCounter, Mac::Frame::kSecEncMic32, nonce);
-
- aesCcm.SetKey((keySequence == Get<KeyManager>().GetCurrentKeySequence())
- ? Get<KeyManager>().GetCurrentMleKey()
- : Get<KeyManager>().GetTemporaryMleKey(keySequence));
-
- aesCcm.Init(sizeof(Ip6::Address) + sizeof(Ip6::Address) + sizeof(SecurityHeader),
- aMessage.GetLength() - aMessage.GetOffset(), sizeof(messageTag), nonce, sizeof(nonce));
-
- aesCcm.Header(aMessageInfo.GetPeerAddr());
- aesCcm.Header(aMessageInfo.GetSockAddr());
- aesCcm.Header(header);
-
-#ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
- aesCcm.Payload(aMessage, aMessage.GetOffset(), aMessage.GetLength() - aMessage.GetOffset(),
- Crypto::AesCcm::kDecrypt);
- aesCcm.Finalize(tag);
-
- if (memcmp(messageTag, tag, sizeof(tag)) != 0)
- {
- // We skip logging security check failures for broadcast MLE
- // messages since it can be common to receive such messages
- // from adjacent Thread networks.
- skipLoggingError =
- (aMessageInfo.GetSockAddr().IsMulticast() && aMessageInfo.GetThreadLinkInfo()->IsDstPanIdBroadcast());
- ExitNow(error = kErrorSecurity);
- }
-#endif
+ SuccessOrExit(
+ error = ProcessMessageSecurity(Crypto::AesCcm::kDecrypt, aMessage, aMessageInfo, aMessage.GetOffset(), header));
if (keySequence > Get<KeyManager>().GetCurrentKeySequence())
{
@@ -2867,6 +2901,7 @@
IgnoreError(aMessage.Read(aMessage.GetOffset(), command));
aMessage.MoveOffset(sizeof(command));
+ aMessageInfo.GetPeerAddr().GetIid().ConvertToExtAddress(extAddr);
neighbor = (command == kCommandChildIdResponse) ? mNeighborTable.FindParent(extAddr)
: mNeighborTable.FindNeighbor(extAddr);
@@ -3047,7 +3082,10 @@
#endif
exit:
- if (!skipLoggingError)
+ // We skip logging failures for broadcast MLE messages since it
+ // can be common to receive such messages from adjacent Thread
+ // networks.
+ if (!aMessageInfo.GetSockAddr().IsMulticast() || !aMessageInfo.GetThreadLinkInfo()->IsDstPanIdBroadcast())
{
LogProcessError(kTypeGenericUdp, error);
}
diff --git a/src/core/thread/mle.hpp b/src/core/thread/mle.hpp
index be2a1a2..205b818 100644
--- a/src/core/thread/mle.hpp
+++ b/src/core/thread/mle.hpp
@@ -42,6 +42,7 @@
#include "common/non_copyable.hpp"
#include "common/notifier.hpp"
#include "common/timer.hpp"
+#include "crypto/aes_ccm.hpp"
#include "mac/mac.hpp"
#include "meshcop/joiner_router.hpp"
#include "meshcop/meshcop.hpp"
@@ -1828,6 +1829,12 @@
uint8_t aCslUncertainty);
bool IsNetworkDataNewer(const LeaderData &aLeaderData);
+ Error ProcessMessageSecurity(Crypto::AesCcm::Mode aMode,
+ Message & aMessage,
+ const Ip6::MessageInfo &aMessageInfo,
+ uint16_t aCmdOffset,
+ const SecurityHeader & aHeader);
+
#if OPENTHREAD_CONFIG_TMF_NETDATA_SERVICE_ENABLE
ServiceAloc *FindInServiceAlocs(uint16_t aAloc16);
void UpdateServiceAlocs(void);