[joiner] adding otJoinerPskd and JoinerPskd types (#5137)
diff --git a/include/openthread/commissioner.h b/include/openthread/commissioner.h
index d81e043..8ff811e 100644
--- a/include/openthread/commissioner.h
+++ b/include/openthread/commissioner.h
@@ -113,7 +113,16 @@
bool mIsJoinerUdpPortSet : 1; ///< TRUE if Joiner UDP Port is set, FALSE otherwise.
} otCommissioningDataset;
-#define OT_PSKD_MAX_SIZE 32 ///< Size of a Joiner PSKd (bytes)
+#define OT_JOINER_MAX_PSKD_LENGTH 32 ///< Maximum string length of a Joiner PSKd (does not include null char).
+
+/**
+ * This structure represent a Joiner PSKd.
+ *
+ */
+typedef struct otJoinerPskd
+{
+ char m8[OT_JOINER_MAX_PSKD_LENGTH + 1]; ///< Char string array (must be null terminated - +1 is for null char).
+} otJoinerPskd;
/**
* This enumeration defines a Joiner Info Typer.
@@ -135,11 +144,11 @@
otJoinerInfoType mType; ///< Joiner type.
union
{
- otExtAddress mEui64; ///< Joiner EUI64 (when `mType` is `OT_JOINER_INFO_TYPE_EUI64`)
- otJoinerDiscerner mDiscerner; ///< Joiner Discerner (when `mType` is `OT_JOINER_INFO_TYPE_DISCERNER`)
- } mSharedId; ///< Shared fields
- char mPsk[OT_PSKD_MAX_SIZE + 1]; ///< Joiner PSKd
- uint32_t mExpirationTime; ///< Joiner expiration time in msec
+ otExtAddress mEui64; ///< Joiner EUI64 (when `mType` is `OT_JOINER_INFO_TYPE_EUI64`)
+ otJoinerDiscerner mDiscerner; ///< Joiner Discerner (when `mType` is `OT_JOINER_INFO_TYPE_DISCERNER`)
+ } mSharedId; ///< Shared fields
+ otJoinerPskd mPskd; ///< Joiner PSKd
+ uint32_t mExpirationTime; ///< Joiner expiration time in msec
} otJoinerInfo;
/**
diff --git a/include/openthread/instance.h b/include/openthread/instance.h
index 2c8897e..a0722a5 100644
--- a/include/openthread/instance.h
+++ b/include/openthread/instance.h
@@ -53,7 +53,7 @@
* @note This number versions both OpenThread platform and user APIs.
*
*/
-#define OPENTHREAD_API_VERSION (8)
+#define OPENTHREAD_API_VERSION (9)
/**
* @addtogroup api-instance
diff --git a/src/core/coap/coap_secure.cpp b/src/core/coap/coap_secure.cpp
index 9a4f79c..e345c89 100644
--- a/src/core/coap/coap_secure.cpp
+++ b/src/core/coap/coap_secure.cpp
@@ -115,6 +115,21 @@
return mDtls.SetPsk(aPsk, aPskLength);
}
+void CoapSecure::SetPsk(const MeshCoP::JoinerPskd &aPskd)
+{
+ otError error;
+
+ OT_UNUSED_VARIABLE(error);
+
+ static_assert(static_cast<uint16_t>(MeshCoP::JoinerPskd::kMaxLength) <=
+ static_cast<uint16_t>(MeshCoP::Dtls::kPskMaxLength),
+ "The maximum length of DTLS PSK is smaller than joiner PSKd");
+
+ error = mDtls.SetPsk(reinterpret_cast<const uint8_t *>(aPskd.GetAsCString()), aPskd.GetLength());
+
+ OT_ASSERT(error == OT_ERROR_NONE);
+}
+
#if OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
#ifdef MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
diff --git a/src/core/coap/coap_secure.hpp b/src/core/coap/coap_secure.hpp
index a0281f5..bb3fc56 100644
--- a/src/core/coap/coap_secure.hpp
+++ b/src/core/coap/coap_secure.hpp
@@ -33,6 +33,7 @@
#include "coap/coap.hpp"
#include "meshcop/dtls.hpp"
+#include "meshcop/meshcop.hpp"
#include <openthread/coap_secure.h>
@@ -160,6 +161,14 @@
*/
otError SetPsk(const uint8_t *aPsk, uint8_t aPskLength);
+ /**
+ * This method sets the PSK.
+ *
+ * @param[in] aPskd A Joiner PSKd.
+ *
+ */
+ void SetPsk(const MeshCoP::JoinerPskd &aPskd);
+
#if OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE
#ifdef MBEDTLS_KEY_EXCHANGE_PSK_ENABLED
diff --git a/src/core/meshcop/commissioner.cpp b/src/core/meshcop/commissioner.cpp
index 9df9481..1e2e933 100644
--- a/src/core/meshcop/commissioner.cpp
+++ b/src/core/meshcop/commissioner.cpp
@@ -452,7 +452,6 @@
Joiner *joiner;
VerifyOrExit(mState == OT_COMMISSIONER_STATE_ACTIVE, error = OT_ERROR_INVALID_STATE);
- VerifyOrExit(IsPskdValid(aPskd), error = OT_ERROR_INVALID_ARGS);
if (aDiscerner != nullptr)
{
@@ -471,6 +470,8 @@
VerifyOrExit(joiner != nullptr, error = OT_ERROR_NO_BUFS);
+ SuccessOrExit(error = joiner->mPskd.SetFrom(aPskd));
+
if (aDiscerner != nullptr)
{
joiner->mType = Joiner::kTypeDiscerner;
@@ -486,9 +487,6 @@
joiner->mType = Joiner::kTypeAny;
}
- strncpy(joiner->mPsk, aPskd, sizeof(joiner->mPsk) - 1);
- joiner->mPsk[sizeof(joiner->mPsk) - 1] = '\0';
-
joiner->mExpirationTime = TimerMilli::GetNow() + Time::SecToMsec(aTimeout);
UpdateJoinerExpirationTimer();
@@ -525,7 +523,7 @@
ExitNow();
}
- strncpy(aInfo.mPsk, mPsk, sizeof(aInfo.mPsk) - 1);
+ aInfo.mPskd = mPskd;
aInfo.mExpirationTime = mExpirationTime - TimerMilli::GetNow();
exit:
@@ -1063,8 +1061,7 @@
joiner = FindBestMatchingJoinerEntry(receivedId);
VerifyOrExit(joiner != nullptr, OT_NOOP);
- SuccessOrExit(error = Get<Coap::CoapSecure>().SetPsk(reinterpret_cast<const uint8_t *>(joiner->mPsk),
- static_cast<uint8_t>(strlen(joiner->mPsk))));
+ Get<Coap::CoapSecure>().SetPsk(joiner->mPskd);
mActiveJoiner = joiner;
LogJoinerEntry("Starting new session with", *joiner);
@@ -1302,17 +1299,17 @@
break;
case Joiner::kTypeAny:
- otLogInfoMeshCoP("%s Joiner (any, %s)", aAction, aJoiner.mPsk);
+ otLogInfoMeshCoP("%s Joiner (any, %s)", aAction, aJoiner.mPskd.GetAsCString());
break;
case Joiner::kTypeEui64:
otLogInfoMeshCoP("%s Joiner (eui64:%s, %s)", aAction, aJoiner.mSharedId.mEui64.ToString().AsCString(),
- aJoiner.mPsk);
+ aJoiner.mPskd.GetAsCString());
break;
case Joiner::kTypeDiscerner:
otLogInfoMeshCoP("%s Joiner (disc:%s, %s)", aAction, aJoiner.mSharedId.mDiscerner.ToString().AsCString(),
- aJoiner.mPsk);
+ aJoiner.mPskd.GetAsCString());
break;
}
}
diff --git a/src/core/meshcop/commissioner.hpp b/src/core/meshcop/commissioner.hpp
index 44e77bf..cfaaa41 100644
--- a/src/core/meshcop/commissioner.hpp
+++ b/src/core/meshcop/commissioner.hpp
@@ -345,8 +345,8 @@
JoinerDiscerner mDiscerner;
} mSharedId;
- char mPsk[Dtls::kPskMaxLength + 1];
- Type mType;
+ JoinerPskd mPskd;
+ Type mType;
void CopyToJoinerInfo(otJoinerInfo &aInfo) const;
};
diff --git a/src/core/meshcop/joiner.cpp b/src/core/meshcop/joiner.cpp
index 92640e8..3676aff 100644
--- a/src/core/meshcop/joiner.cpp
+++ b/src/core/meshcop/joiner.cpp
@@ -136,6 +136,7 @@
void * aContext)
{
otError error;
+ JoinerPskd joinerPskd;
Mac::ExtAddress randomAddress;
SteeringData::HashBitIndexes filterIndexes;
@@ -143,7 +144,7 @@
VerifyOrExit(mState == OT_JOINER_STATE_IDLE, error = OT_ERROR_BUSY);
- VerifyOrExit(IsPskdValid(aPskd), error = OT_ERROR_INVALID_ARGS);
+ SuccessOrExit(error = joinerPskd.SetFrom(aPskd));
// Use random-generated extended address.
randomAddress.GenerateRandom();
@@ -151,8 +152,7 @@
Get<Mle::MleRouter>().UpdateLinkLocalAddress();
SuccessOrExit(error = Get<Coap::CoapSecure>().Start(kJoinerUdpPort));
- SuccessOrExit(error = Get<Coap::CoapSecure>().SetPsk(reinterpret_cast<const uint8_t *>(aPskd),
- static_cast<uint8_t>(strlen(aPskd))));
+ Get<Coap::CoapSecure>().SetPsk(joinerPskd);
for (JoinerRouter *router = &mJoinerRouters[0]; router < OT_ARRAY_END(mJoinerRouters); router++)
{
diff --git a/src/core/meshcop/meshcop.cpp b/src/core/meshcop/meshcop.cpp
index 15dbe92..1a30d87 100644
--- a/src/core/meshcop/meshcop.cpp
+++ b/src/core/meshcop/meshcop.cpp
@@ -44,11 +44,61 @@
namespace ot {
namespace MeshCoP {
-enum
+otError JoinerPskd::SetFrom(const char *aPskdString)
{
- kPskdMinLength = 6, ///< Minimum PSKd length.
- kPskdMaxLength = 32, ///< Maximum PSKd Length.
-};
+ otError error = OT_ERROR_NONE;
+
+ VerifyOrExit(IsPskdValid(aPskdString), error = OT_ERROR_INVALID_ARGS);
+
+ Clear();
+ memcpy(m8, aPskdString, StringLength(aPskdString, sizeof(m8)));
+
+exit:
+ return error;
+}
+
+bool JoinerPskd::operator==(const JoinerPskd &aOther) const
+{
+ bool isEqual = true;
+
+ for (uint8_t i = 0; i < sizeof(m8); i++)
+ {
+ if (m8[i] != aOther.m8[i])
+ {
+ isEqual = false;
+ ExitNow();
+ }
+
+ if (m8[i] == '\0')
+ {
+ break;
+ }
+ }
+
+exit:
+ return isEqual;
+}
+
+bool JoinerPskd::IsPskdValid(const char *aPskString)
+{
+ bool valid = false;
+ uint16_t pskdLength = StringLength(aPskString, kMaxLength + 1);
+
+ VerifyOrExit(pskdLength >= kMinLength && pskdLength <= kMaxLength, OT_NOOP);
+
+ for (uint16_t i = 0; i < pskdLength; i++)
+ {
+ char c = aPskString[i];
+
+ VerifyOrExit(isdigit(c) || isupper(c), OT_NOOP);
+ VerifyOrExit(c != 'I' && c != 'O' && c != 'Q' && c != 'Z', OT_NOOP);
+ }
+
+ valid = true;
+
+exit:
+ return valid;
+}
void JoinerDiscerner::GenerateJoinerId(Mac::ExtAddress &aJoinerId) const
{
@@ -300,31 +350,5 @@
}
#endif // OPENTHREAD_FTD
-#if OPENTHREAD_CONFIG_JOINER_ENABLE || OPENTHREAD_CONFIG_COMMISSIONER_ENABLE
-bool IsPskdValid(const char *aPskd)
-{
- bool valid = false;
- size_t pskdLength = StringLength(aPskd, kPskdMaxLength + 1);
-
- static_assert(static_cast<uint8_t>(kPskdMaxLength) <= static_cast<uint8_t>(Dtls::kPskMaxLength),
- "The maximum length of DTLS PSK is smaller than joiner PSKd");
-
- VerifyOrExit(pskdLength >= kPskdMinLength && pskdLength <= kPskdMaxLength, OT_NOOP);
-
- for (size_t i = 0; i < pskdLength; i++)
- {
- char c = aPskd[i];
-
- VerifyOrExit(isdigit(c) || isupper(c), OT_NOOP);
- VerifyOrExit(c != 'I' && c != 'O' && c != 'Q' && c != 'Z', OT_NOOP);
- }
-
- valid = true;
-
-exit:
- return valid;
-}
-#endif // OPENTHREAD_CONFIG_JOINER_ENABLE || OPENTHREAD_CONFIG_COMMISSIONER_ENABLE
-
} // namespace MeshCoP
} // namespace ot
diff --git a/src/core/meshcop/meshcop.hpp b/src/core/meshcop/meshcop.hpp
index f386003..e8cd865 100644
--- a/src/core/meshcop/meshcop.hpp
+++ b/src/core/meshcop/meshcop.hpp
@@ -39,11 +39,14 @@
#include <limits.h>
+#include <openthread/commissioner.h>
#include <openthread/instance.h>
#include <openthread/joiner.h>
#include "coap/coap.hpp"
+#include "common/clearable.hpp"
#include "common/message.hpp"
+#include "common/string.hpp"
#include "mac/mac_types.hpp"
#include "meshcop/meshcop_tlvs.hpp"
@@ -59,6 +62,97 @@
};
/**
+ * This type represents a Joiner PSKd.
+ *
+ */
+class JoinerPskd : public otJoinerPskd, public Clearable<JoinerPskd>
+{
+public:
+ enum
+ {
+ kMinLength = 6, ///< Minimum PSKd string length (excluding null char).
+ kMaxLength = OT_JOINER_MAX_PSKD_LENGTH, ///< Maximum PSKd string length (excluding null char)
+ };
+
+ /**
+ * This method indicates whether the PSKd if well-formed and valid.
+ *
+ * Per Thread specification, a Joining Device Credential is encoded as uppercase alphanumeric characters
+ * (base32-thread: 0-9, A-Z excluding I, O, Q, and Z for readability) with a minimum length of 6 such characters
+ * and a maximum length of 32 such characters.
+ *
+ * @returns TRUE if the PSKd is valid, FALSE otherwise.
+ *
+ */
+ bool IsValid(void) const { return IsPskdValid(m8); }
+
+ /**
+ * This method sets the joiner PSKd from a given C string.
+ *
+ * @param[in] aPskdString A pointer to the PSKd C string array.
+ *
+ * @retval OT_ERROR_NONE The PSKd was updated successfully.
+ * @retval OT_ERROR_INVALID_ARGS The given PSKd C string is not valid.
+ *
+ */
+ otError SetFrom(const char *aPskdString);
+
+ /**
+ * This method gets the PSKd as a null terminated C string.
+ *
+ * This method must be used after the PSKd is validated, otherwise its behavior is undefined.
+ *
+ * @returns The PSKd as a C string.
+ *
+ */
+ const char *GetAsCString(void) const { return m8; }
+
+ /**
+ * This method gets the PSKd string length.
+ *
+ * This method must be used after the PSKd is validated, otherwise its behavior is undefined.
+ *
+ * @returns The PSKd string length.
+ *
+ */
+ uint8_t GetLength(void) const { return static_cast<uint8_t>(StringLength(m8, kMaxLength + 1)); }
+
+ /**
+ * This method overloads operator `==` to evaluate whether or not two PSKds are equal.
+ *
+ * @param[in] aOther The other PSKd to compare with.
+ *
+ * @retval TRUE If the two are equal.
+ * @retval FALSE If the two are not equal.
+ *
+ */
+ bool operator==(const JoinerPskd &aOther) const;
+
+ /**
+ * This method overloads operator `!=` to evaluate whether or not two PSKds are equal.
+ *
+ * @param[in] aOther The other PSKd to compare with.
+ *
+ * @retval TRUE If the two are not equal.
+ * @retval FALSE If the two are equal.
+ *
+ */
+ bool operator!=(const JoinerPskd &aOther) const { return !(*this == aOther); }
+
+ /**
+ * This static method indicates whether a given PSKd string if well-formed and valid.
+ *
+ * @param[in] aPskdString A pointer to a PSKd string array.
+ *
+ * @sa IsValid()
+ *
+ * @returns TRUE if @p aPskdString is valid, FALSE otherwise.
+ *
+ */
+ static bool IsPskdValid(const char *aPskdString);
+};
+
+/**
* This type represents a Joiner Discerner.
*
*/
@@ -399,23 +493,6 @@
*/
otError GetBorderAgentRloc(ThreadNetif &aNetIf, uint16_t &aRloc);
-#if OPENTHREAD_CONFIG_JOINER_ENABLE || OPENTHREAD_CONFIG_COMMISSIONER_ENABLE
-/**
- * This method validates the PSKd.
- *
- * Per Thread specification, a Joining Device Credential is encoded as
- * uppercase alphanumeric characters (base32-thread: 0-9, A-Z excluding
- * I, O, Q, and Z for readability) with a minimum length of 6 such
- * characters and a maximum length of 32 such characters.
- *
- * param[in] aPskd The PSKd to validate.
- *
- * @retval A boolean indicates whether the given @p aPskd is valid.
- *
- */
-bool IsPskdValid(const char *aPskd);
-#endif // OPENTHREAD_CONFIG_JOINER_ENABLE || OPENTHREAD_CONFIG_COMMISSIONER_ENABLE
-
} // namespace MeshCoP
} // namespace ot