Skip to content

Commit

Permalink
Send optional MRP parameters in PASE messages (#12638)
Browse files Browse the repository at this point in the history
* Send optional MRP parameters in PASE messages

* add documentation and unit tests

* fix test build errors

* Restyled by clang-format

* Use base class references to enable CASE MRP related work

Co-authored-by: Restyled.io <commits@restyled.io>
  • Loading branch information
2 people authored and pull[bot] committed Jan 9, 2024
1 parent ecc8a0f commit 1069240
Show file tree
Hide file tree
Showing 12 changed files with 403 additions and 57 deletions.
8 changes: 5 additions & 3 deletions src/app/server/CommissioningWindowManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,9 @@ CHIP_ERROR CommissioningWindowManager::OpenCommissioningWindow()
if (mUseECM)
{
ReturnErrorOnFailure(SetTemporaryDiscriminator(mECMDiscriminator));
ReturnErrorOnFailure(mPairingSession.WaitForPairing(mECMPASEVerifier, mECMIterations, ByteSpan(mECMSalt, mECMSaltLength),
mECMPasscodeID, keyID, this));
ReturnErrorOnFailure(
mPairingSession.WaitForPairing(mECMPASEVerifier, mECMIterations, ByteSpan(mECMSalt, mECMSaltLength), mECMPasscodeID,
keyID, Optional<ReliableMessageProtocolConfig>::Value(gDefaultMRPConfig), this));

// reset all advertising, indicating we are in commissioningMode
app::DnssdServer::Instance().StartServer(Dnssd::CommissioningMode::kEnabledEnhanced);
Expand All @@ -189,7 +190,8 @@ CHIP_ERROR CommissioningWindowManager::OpenCommissioningWindow()

ReturnErrorOnFailure(mPairingSession.WaitForPairing(
pinCode, kSpake2p_Iteration_Count,
ByteSpan(reinterpret_cast<const uint8_t *>(kSpake2pKeyExchangeSalt), strlen(kSpake2pKeyExchangeSalt)), keyID, this));
ByteSpan(reinterpret_cast<const uint8_t *>(kSpake2pKeyExchangeSalt), strlen(kSpake2pKeyExchangeSalt)), keyID,
Optional<ReliableMessageProtocolConfig>::Value(gDefaultMRPConfig), this));

// reset all advertising, indicating we are in commissioningMode
app::DnssdServer::Instance().StartServer(Dnssd::CommissioningMode::kEnabledBasic);
Expand Down
3 changes: 2 additions & 1 deletion src/controller/CHIPDeviceController.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,8 @@ CHIP_ERROR DeviceCommissioner::EstablishPASEConnection(NodeId remoteDeviceId, Re
// TODO - Remove use of SetActive/IsActive from CommissioneeDeviceProxy
device->SetActive(true);

err = device->GetPairing().Pair(params.GetPeerAddress(), params.GetSetupPINCode(), keyID, exchangeCtxt, this);
err = device->GetPairing().Pair(params.GetPeerAddress(), params.GetSetupPINCode(), keyID,
Optional<ReliableMessageProtocolConfig>::Value(mMRPConfig), exchangeCtxt, this);
SuccessOrExit(err);

// Immediately persist the updated mNextKeyID value
Expand Down
2 changes: 2 additions & 0 deletions src/controller/CHIPDeviceController.h
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,8 @@ class DLL_EXPORT DeviceCommissioner : public DeviceController,

Callback::Callback<OnNOCChainGeneration> mDeviceNOCChainCallback;
SetUpCodePairer mSetUpCodePairer;

ReliableMessageProtocolConfig mMRPConfig = gDefaultMRPConfig;
};

} // namespace Controller
Expand Down
4 changes: 0 additions & 4 deletions src/protocols/secure_channel/CASESession.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,6 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin
*/
virtual CHIP_ERROR DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) override;

const char * GetI2RSessionInfo() const override { return "Sigma I2R Key"; }

const char * GetR2ISessionInfo() const override { return "Sigma R2I Key"; }

/**
* @brief Serialize the CASESession to the given cachableSession data structure for secure pairing
**/
Expand Down
59 changes: 39 additions & 20 deletions src/protocols/secure_channel/PASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@ CHIP_ERROR PASESession::SetupSpake2p(uint32_t pbkdf2IterCount, const ByteSpan &
}

CHIP_ERROR PASESession::WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2IterCount, const ByteSpan & salt,
uint16_t mySessionId, SessionEstablishmentDelegate * delegate)
uint16_t mySessionId, Optional<ReliableMessageProtocolConfig> mrpConfig,
SessionEstablishmentDelegate * delegate)
{
// Return early on error here, as we have not initalized any state yet
ReturnErrorCodeIf(salt.empty(), CHIP_ERROR_INVALID_ARGUMENT);
Expand Down Expand Up @@ -281,6 +282,7 @@ CHIP_ERROR PASESession::WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2I
mNextExpectedMsg = MsgType::PBKDFParamRequest;
mPairingComplete = false;
mPasscodeID = 0;
mLocalMRPConfig = mrpConfig;

ChipLogDetail(SecureChannel, "Waiting for PBKDF param request");

Expand All @@ -293,9 +295,10 @@ CHIP_ERROR PASESession::WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2I
}

CHIP_ERROR PASESession::WaitForPairing(const PASEVerifier & verifier, uint32_t pbkdf2IterCount, const ByteSpan & salt,
uint16_t passcodeID, uint16_t mySessionId, SessionEstablishmentDelegate * delegate)
uint16_t passcodeID, uint16_t mySessionId, Optional<ReliableMessageProtocolConfig> mrpConfig,
SessionEstablishmentDelegate * delegate)
{
ReturnErrorOnFailure(WaitForPairing(0, pbkdf2IterCount, salt, mySessionId, delegate));
ReturnErrorOnFailure(WaitForPairing(0, pbkdf2IterCount, salt, mySessionId, mrpConfig, delegate));

memmove(&mPASEVerifier, &verifier, sizeof(verifier));
mComputeVerifier = false;
Expand All @@ -305,7 +308,8 @@ CHIP_ERROR PASESession::WaitForPairing(const PASEVerifier & verifier, uint32_t p
}

CHIP_ERROR PASESession::Pair(const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, uint16_t mySessionId,
Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate)
Optional<ReliableMessageProtocolConfig> mrpConfig, Messaging::ExchangeContext * exchangeCtxt,
SessionEstablishmentDelegate * delegate)
{
ReturnErrorCodeIf(exchangeCtxt == nullptr, CHIP_ERROR_INVALID_ARGUMENT);
CHIP_ERROR err = Init(mySessionId, peerSetUpPINCode, delegate);
Expand All @@ -316,6 +320,8 @@ CHIP_ERROR PASESession::Pair(const Transport::PeerAddress peerAddress, uint32_t

SetPeerAddress(peerAddress);

mLocalMRPConfig = mrpConfig;

err = SendPBKDFParamRequest();
SuccessOrExit(err);

Expand Down Expand Up @@ -355,13 +361,14 @@ CHIP_ERROR PASESession::SendPBKDFParamRequest()
{
ReturnErrorOnFailure(DRBG_get_bytes(mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData)));

const size_t max_msg_len = TLV::EstimateStructOverhead(kPBKDFParamRandomNumberSize, // initiatorRandom,
const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0;
const size_t max_msg_len = TLV::EstimateStructOverhead(kPBKDFParamRandomNumberSize, // initiatorRandom,
sizeof(uint16_t), // initiatorSessionId
sizeof(uint16_t), // passcodeId,
sizeof(uint8_t) // hasPBKDFParameters
/* TLV::EstimateStructOverhead(sizeof(uint16_t),
sizeof(uint16)_t), // initiatorMRPParams */
sizeof(uint8_t), // hasPBKDFParameters
mrpParamsSize // MRP Parameters
);

System::PacketBufferHandle req = System::PacketBufferHandle::New(max_msg_len);
VerifyOrReturnError(!req.IsNull(), CHIP_ERROR_NO_MEMORY);

Expand All @@ -374,9 +381,11 @@ CHIP_ERROR PASESession::SendPBKDFParamRequest()
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalSessionId()));
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), mPasscodeID));
ReturnErrorOnFailure(tlvWriter.PutBoolean(TLV::ContextTag(4), mHavePBKDFParameters));
// TODO - Add optional MRP parameter support to PASE
// When we add MRP params here, adjust the TLV::EstimateStructOverhead call
// above accordingly.
if (mLocalMRPConfig.HasValue())
{
ChipLogDetail(SecureChannel, "Including MRP parameters in PBKDF param request");
ReturnErrorOnFailure(EncodeMRPParameters(TLV::ContextTag(5), mLocalMRPConfig.Value(), tlvWriter));
}
ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType));
ReturnErrorOnFailure(tlvWriter.Finalize(&req));

Expand Down Expand Up @@ -433,7 +442,7 @@ CHIP_ERROR PASESession::HandlePBKDFParamRequest(System::PacketBufferHandle && ms
VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG);
SuccessOrExit(err = tlvReader.Get(hasPBKDFParameters));

// TODO - Check if optional MRP parameters were sent. If so, cache them.
SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader));

err = SendPBKDFParamResponse(ByteSpan(initiatorRandom), hasPBKDFParameters);
SuccessOrExit(err);
Expand All @@ -453,14 +462,15 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool in
{
ReturnErrorOnFailure(DRBG_get_bytes(mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData)));

const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0;
const size_t max_msg_len =
TLV::EstimateStructOverhead(kPBKDFParamRandomNumberSize, // initiatorRandom
kPBKDFParamRandomNumberSize, // responderRandom
sizeof(uint16_t), // responderSessionId
TLV::EstimateStructOverhead(sizeof(uint32_t), mSaltLength) // pbkdf_parameters
/* TLV::EstimateStructOverhead(sizeof(uint16_t),
sizeof(uint16)_t), // responderMRPParams */
TLV::EstimateStructOverhead(kPBKDFParamRandomNumberSize, // initiatorRandom
kPBKDFParamRandomNumberSize, // responderRandom
sizeof(uint16_t), // responderSessionId
TLV::EstimateStructOverhead(sizeof(uint32_t), mSaltLength), // pbkdf_parameters
mrpParamsSize // MRP Parameters
);

System::PacketBufferHandle resp = System::PacketBufferHandle::New(max_msg_len);
VerifyOrReturnError(!resp.IsNull(), CHIP_ERROR_NO_MEMORY);

Expand All @@ -483,8 +493,11 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool in
ReturnErrorOnFailure(tlvWriter.EndContainer(pbkdfParamContainer));
}

// When we add MRP params here, adjust the TLV::EstimateStructOverhead call
// above accordingly.
if (mLocalMRPConfig.HasValue())
{
ChipLogDetail(SecureChannel, "Including MRP parameters in PBKDF param response");
ReturnErrorOnFailure(EncodeMRPParameters(TLV::ContextTag(5), mLocalMRPConfig.Value(), tlvWriter));
}

ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType));
ReturnErrorOnFailure(tlvWriter.Finalize(&resp));
Expand Down Expand Up @@ -549,6 +562,8 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m

if (mHavePBKDFParameters)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader));

// TODO - Add a unit test that exercises mHavePBKDFParameters path
err = SetupSpake2p(mIterationCount, ByteSpan(mSalt, mSaltLength));
SuccessOrExit(err);
Expand All @@ -568,6 +583,10 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m
saltLength = tlvReader.GetLength();
SuccessOrExit(err = tlvReader.GetDataPtr(salt));

SuccessOrExit(err = tlvReader.ExitContainer(containerType));

SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader));

err = SetupSpake2p(iterCount, ByteSpan(salt, saltLength));
SuccessOrExit(err);
}
Expand Down
18 changes: 7 additions & 11 deletions src/protocols/secure_channel/PASESession.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin
* @return CHIP_ERROR The result of initialization
*/
CHIP_ERROR WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2IterCount, const ByteSpan & salt, uint16_t mySessionId,
SessionEstablishmentDelegate * delegate);
Optional<ReliableMessageProtocolConfig> mrpConfig, SessionEstablishmentDelegate * delegate);

/**
* @brief
Expand All @@ -123,7 +123,8 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin
* @return CHIP_ERROR The result of initialization
*/
CHIP_ERROR WaitForPairing(const PASEVerifier & verifier, uint32_t pbkdf2IterCount, const ByteSpan & salt, uint16_t passcodeID,
uint16_t mySessionId, SessionEstablishmentDelegate * delegate);
uint16_t mySessionId, Optional<ReliableMessageProtocolConfig> mrpConfig,
SessionEstablishmentDelegate * delegate);

/**
* @brief
Expand All @@ -141,7 +142,8 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin
* @return CHIP_ERROR The result of initialization
*/
CHIP_ERROR Pair(const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, uint16_t mySessionId,
Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate);
Optional<ReliableMessageProtocolConfig> mrpConfig, Messaging::ExchangeContext * exchangeCtxt,
SessionEstablishmentDelegate * delegate);

/**
* @brief
Expand Down Expand Up @@ -170,10 +172,6 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin
*/
CHIP_ERROR DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) override;

const char * GetI2RSessionInfo() const override { return kSpake2pI2RSessionInfo; }

const char * GetR2ISessionInfo() const override { return kSpake2pR2ISessionInfo; }

/** @brief Serialize the Pairing Session to a string.
*
* @return Returns a CHIP_ERROR on error, CHIP_NO_ERROR otherwise
Expand Down Expand Up @@ -304,6 +302,8 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin

SessionEstablishmentExchangeDispatch mMessageDispatch;

Optional<ReliableMessageProtocolConfig> mLocalMRPConfig;

struct Spake2pErrorMsg
{
Spake2pErrorType error;
Expand Down Expand Up @@ -369,10 +369,6 @@ class SecurePairingUsingTestSecret : public PairingSession
return CHIP_NO_ERROR;
}

const char * GetI2RSessionInfo() const override { return "i2r"; }

const char * GetR2ISessionInfo() const override { return "r2i"; }

private:
const char * kTestSecret = CHIP_CONFIG_TEST_SHARED_SECRET_VALUE;
};
Expand Down
Loading

0 comments on commit 1069240

Please sign in to comment.