diff --git a/crates/matrix-sdk-crypto/src/identities/manager.rs b/crates/matrix-sdk-crypto/src/identities/manager.rs index 28ab62eba8..7819d7acfa 100644 --- a/crates/matrix-sdk-crypto/src/identities/manager.rs +++ b/crates/matrix-sdk-crypto/src/identities/manager.rs @@ -36,7 +36,7 @@ use crate::{ requests::KeysQueryRequest, store::{ caches::SequenceNumber, Changes, DeviceChanges, IdentityChanges, KeyQueryManager, - Result as StoreResult, Store, StoreCache, UserKeyQueryResult, + Result as StoreResult, Store, StoreCache, StoreCacheGuard, UserKeyQueryResult, }, types::{CrossSigningKey, DeviceKeys, MasterPubkey, SelfSigningPubkey, UserSigningPubkey}, CryptoStoreError, LocalTrust, OwnUserIdentity, SignatureError, UserIdentities, @@ -1147,6 +1147,28 @@ impl IdentityManager { Ok(()) } + + /// Mark all tracked users as dirty. + /// + /// All users *whose device lists we are tracking* are flagged as needing a + /// key query. Users whose devices we are not tracking are ignored. + pub(crate) async fn mark_all_tracked_users_as_dirty( + &self, + store_cache: StoreCacheGuard, + ) -> StoreResult<()> { + let store_wrapper = store_cache.store_wrapper(); + let tracked_users = store_wrapper.load_tracked_users().await?; + + self.key_query_manager + .synced(&store_cache) + .await? + .mark_tracked_users_as_changed( + tracked_users.iter().map(|tracked_user| tracked_user.user_id.as_ref()), + ) + .await?; + + Ok(()) + } } /// Log information about what changed after processing a /keys/query response. diff --git a/crates/matrix-sdk-crypto/src/machine/mod.rs b/crates/matrix-sdk-crypto/src/machine/mod.rs index 5a2214566f..d0e985d056 100644 --- a/crates/matrix-sdk-crypto/src/machine/mod.rs +++ b/crates/matrix-sdk-crypto/src/machine/mod.rs @@ -185,30 +185,43 @@ impl OlmMachine { }) .await?; + let (verification_machine, store, identity_manager) = + Self::new_helper_prelude(store, static_account, self.store().private_identity()); + Ok(Self::new_helper( device_id, store, - static_account, + verification_machine, + identity_manager, self.store().private_identity(), None, )) } + fn new_helper_prelude( + store_wrapper: Arc, + account: StaticAccountData, + user_identity: Arc>, + ) -> (VerificationMachine, Store, IdentityManager) { + let verification_machine = + VerificationMachine::new(account.clone(), user_identity.clone(), store_wrapper.clone()); + let store = Store::new(account, user_identity, store_wrapper, verification_machine.clone()); + + let identity_manager = IdentityManager::new(store.clone()); + + (verification_machine, store, identity_manager) + } + fn new_helper( device_id: &DeviceId, - store: Arc, - account: StaticAccountData, + store: Store, + verification_machine: VerificationMachine, + identity_manager: IdentityManager, user_identity: Arc>, maybe_backup_key: Option, ) -> Self { - let verification_machine = - VerificationMachine::new(account.clone(), user_identity.clone(), store.clone()); - let store = Store::new(account, user_identity.clone(), store, verification_machine.clone()); - let group_session_manager = GroupSessionManager::new(store.clone()); - let identity_manager = IdentityManager::new(store.clone()); - let users_for_key_claim = Arc::new(StdRwLock::new(BTreeMap::new())); let key_request_machine = GossipMachine::new( store.clone(), @@ -360,11 +373,21 @@ impl OlmMachine { let identity = Arc::new(Mutex::new(identity)); let store = Arc::new(CryptoStoreWrapper::new(user_id, device_id, store)); + let (verification_machine, store, identity_manager) = + Self::new_helper_prelude(store, static_account, identity.clone()); + // FIXME: We might want in the future a more generic high-level data migration // mechanism (at the store wrapper layer). - Self::migration_post_verified_latch_support(&store).await?; + Self::migration_post_verified_latch_support(&store, &identity_manager).await?; - Ok(OlmMachine::new_helper(device_id, store, static_account, identity, maybe_backup_key)) + Ok(Self::new_helper( + device_id, + store, + verification_machine, + identity_manager, + identity, + maybe_backup_key, + )) } // The sdk now support verified identity change detection. @@ -375,19 +398,15 @@ impl OlmMachine { // // pub(crate) visibility for testing. pub(crate) async fn migration_post_verified_latch_support( - store: &CryptoStoreWrapper, + store: &Store, + identity_manager: &IdentityManager, ) -> Result<(), CryptoStoreError> { let maybe_migrate_for_identity_verified_latch = store.get_custom_value(Self::HAS_MIGRATED_VERIFICATION_LATCH).await?.is_none(); + if maybe_migrate_for_identity_verified_latch { - // We want to mark all tracked users as dirty to ensure the verified latch is - // set up correctly. - let tracked_user = store.load_tracked_users().await?; - let mut store_updates = Vec::with_capacity(tracked_user.len()); - tracked_user.iter().for_each(|tu| { - store_updates.push((tu.user_id.as_ref(), true)); - }); - store.save_tracked_users(&store_updates).await?; + identity_manager.mark_all_tracked_users_as_dirty(store.cache().await?).await?; + store.set_custom_value(Self::HAS_MIGRATED_VERIFICATION_LATCH, vec![0]).await? } Ok(()) @@ -1992,6 +2011,17 @@ impl OlmMachine { self.inner.identity_manager.update_tracked_users(users).await } + /// Mark all tracked users as dirty. + /// + /// All users *whose device lists we are tracking* are flagged as needing a + /// key query. Users whose devices we are not tracking are ignored. + pub async fn mark_all_tracked_users_as_dirty(&self) -> StoreResult<()> { + self.inner + .identity_manager + .mark_all_tracked_users_as_dirty(self.inner.store.cache().await?) + .await + } + async fn wait_if_user_pending( &self, user_id: &UserId, @@ -2404,10 +2434,10 @@ impl OlmMachine { Ok(()) } - #[cfg(any(feature = "testing", test))] /// Returns whether this `OlmMachine` is the same another one. /// /// Useful for testing purposes only. + #[cfg(any(feature = "testing", test))] pub fn same_as(&self, other: &OlmMachine) -> bool { Arc::ptr_eq(&self.inner, &other.inner) } @@ -2419,6 +2449,18 @@ impl OlmMachine { let account = cache.account().await?; Ok(account.uploaded_key_count()) } + + /// Returns the identity manager. + #[cfg(test)] + pub(crate) fn identity_manager(&self) -> &IdentityManager { + &self.inner.identity_manager + } + + /// Returns a store key, only useful for testing purposes. + #[cfg(test)] + pub(crate) fn key_for_has_migrated_verification_latch() -> &'static str { + Self::HAS_MIGRATED_VERIFICATION_LATCH + } } fn sender_data_to_verification_state( diff --git a/crates/matrix-sdk-crypto/src/machine/tests/mod.rs b/crates/matrix-sdk-crypto/src/machine/tests/mod.rs index 3d355fd603..8d7bea24da 100644 --- a/crates/matrix-sdk-crypto/src/machine/tests/mod.rs +++ b/crates/matrix-sdk-crypto/src/machine/tests/mod.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::BTreeMap, iter, sync::Arc, time::Duration}; +use std::{collections::BTreeMap, iter, ops::Not, sync::Arc, time::Duration}; use assert_matches2::assert_matches; use futures_util::{pin_mut, FutureExt, StreamExt}; @@ -1528,6 +1528,43 @@ async fn test_unsigned_decryption() { assert_matches!(thread_encryption_result, UnsignedDecryptionResult::Decrypted(_)); } +#[async_test] +async fn test_mark_all_tracked_users_as_dirty() { + let store = MemoryStore::new(); + let account = vodozemac::olm::Account::new(); + + // Put some tracked users + let damir = user_id!("@damir:localhost"); + let ben = user_id!("@ben:localhost"); + let ivan = user_id!("@ivan:localhost"); + + // Mark them as not dirty. + store.save_tracked_users(&[(damir, false), (ben, false), (ivan, false)]).await.unwrap(); + + // Let's imagine the migration has been done: this is useful so that tracked + // users are not marked as dirty when creating the `OlmMachine`. + store + .set_custom_value(OlmMachine::key_for_has_migrated_verification_latch(), vec![0]) + .await + .unwrap(); + + let alice = + OlmMachine::with_store(user_id(), alice_device_id(), store, Some(account)).await.unwrap(); + + // All users are marked as not dirty. + alice.store().load_tracked_users().await.unwrap().iter().for_each(|tracked_user| { + assert!(tracked_user.dirty.not()); + }); + + // Now, mark all tracked users as dirty. + alice.mark_all_tracked_users_as_dirty().await.unwrap(); + + // All users are now marked as dirty. + alice.store().load_tracked_users().await.unwrap().iter().for_each(|tracked_user| { + assert!(tracked_user.dirty); + }); +} + #[async_test] async fn test_verified_latch_migration() { let store = MemoryStore::new(); @@ -1544,20 +1581,22 @@ async fn test_verified_latch_migration() { let alice = OlmMachine::with_store(user_id(), alice_device_id(), store, Some(account)).await.unwrap(); + let alice_store = alice.store(); + // A migration should have occurred and all users should be marked as dirty - alice.store().load_tracked_users().await.unwrap().iter().for_each(|tu| { + alice_store.load_tracked_users().await.unwrap().iter().for_each(|tu| { assert!(tu.dirty); }); // Ensure it does so only once - alice.store().save_tracked_users(&to_track_not_dirty).await.unwrap(); + alice_store.save_tracked_users(&to_track_not_dirty).await.unwrap(); - OlmMachine::migration_post_verified_latch_support(alice.store().crypto_store().as_ref()) + OlmMachine::migration_post_verified_latch_support(alice_store, alice.identity_manager()) .await .unwrap(); // Migration already done, so user should not be marked as dirty - alice.store().load_tracked_users().await.unwrap().iter().for_each(|tu| { + alice_store.load_tracked_users().await.unwrap().iter().for_each(|tu| { assert!(!tu.dirty); }); } diff --git a/crates/matrix-sdk-crypto/src/store/mod.rs b/crates/matrix-sdk-crypto/src/store/mod.rs index f294803b82..c06de54761 100644 --- a/crates/matrix-sdk-crypto/src/store/mod.rs +++ b/crates/matrix-sdk-crypto/src/store/mod.rs @@ -351,6 +351,10 @@ pub(crate) struct StoreCache { } impl StoreCache { + pub(crate) fn store_wrapper(&self) -> &CryptoStoreWrapper { + self.store.as_ref() + } + /// Returns a reference to the `Account`. /// /// Either load the account from the cache, or the store if missing from diff --git a/crates/matrix-sdk/src/sliding_sync/error.rs b/crates/matrix-sdk/src/sliding_sync/error.rs index 10706d860f..c72d5333e0 100644 --- a/crates/matrix-sdk/src/sliding_sync/error.rs +++ b/crates/matrix-sdk/src/sliding_sync/error.rs @@ -53,4 +53,14 @@ pub enum Error { /// The original `JoinError`. error: JoinError, }, + + /// No Olm machine. + #[cfg(feature = "e2e-encryption")] + #[error("The Olm machine is missing")] + NoOlmMachine, + + /// An error occurred during a E2EE operation. + #[cfg(feature = "e2e-encryption")] + #[error(transparent)] + CryptoStoreError(#[from] matrix_sdk_base::crypto::CryptoStoreError), } diff --git a/crates/matrix-sdk/src/sliding_sync/mod.rs b/crates/matrix-sdk/src/sliding_sync/mod.rs index de002be738..c767709582 100644 --- a/crates/matrix-sdk/src/sliding_sync/mod.rs +++ b/crates/matrix-sdk/src/sliding_sync/mod.rs @@ -491,6 +491,30 @@ impl SlidingSync { Span::current().record("pos", &pos); + // There is a non-negligible difference MSC3575 and MSC4186 in how + // the `e2ee` extension works. When the client sends a request with + // no `pos`: + // + // * MSC3575 returns all device lists updates since the last request from the + // device that asked for device lists (this works similarly to to-device + // message handling), + // * MSC4186 returns no device lists updates, as it only returns changes since + // the provided `pos` (which is `null` in this case); this is in line with + // sync v2. + // + // Therefore, with MSC4186, the device list cache must be marked as to be + // re-downloaded if the `since` token is `None`, otherwise it's easy to miss + // device lists updates that happened between the previous request and the new + // “initial” request. + #[cfg(feature = "e2e-encryption")] + if pos.is_none() && self.inner.version.is_native() && self.is_e2ee_enabled() { + info!("Marking all tracked users as dirty"); + + let olm_machine = self.inner.client.olm_machine().await; + let olm_machine = olm_machine.as_ref().ok_or(Error::NoOlmMachine)?; + olm_machine.mark_all_tracked_users_as_dirty().await?; + } + // Configure the timeout. // // The `timeout` query is necessary when all lists require it. Please see @@ -841,15 +865,9 @@ enum SlidingSyncInternalMessage { #[cfg(any(test, feature = "testing"))] impl SlidingSync { - /// Get a copy of the `pos` value. - pub fn pos(&self) -> Option { - let position_lock = self.inner.position.blocking_lock(); - position_lock.pos.clone() - } - /// Set a new value for `pos`. - pub fn set_pos(&self, new_pos: String) { - let mut position_lock = self.inner.position.blocking_lock(); + pub async fn set_pos(&self, new_pos: String) { + let mut position_lock = self.inner.position.lock().await; position_lock.pos = Some(new_pos); } @@ -1660,6 +1678,153 @@ mod tests { Ok(()) } + // With MSC4186, with the `e2ee` extension enabled, if a request has no `pos`, + // all the tracked users by the `OlmMachine` must be marked as dirty, i.e. + // `/key/query` requests must be sent. See the code to see the details. + // + // This test is asserting that. + #[async_test] + #[cfg(feature = "e2e-encryption")] + async fn test_no_pos_with_e2ee_marks_all_tracked_users_as_dirty() -> anyhow::Result<()> { + use matrix_sdk_base::crypto::{IncomingResponse, OutgoingRequests}; + use matrix_sdk_test::ruma_response_from_json; + use ruma::user_id; + + let server = MockServer::start().await; + let client = logged_in_client(Some(server.uri())).await; + + let alice = user_id!("@alice:localhost"); + let bob = user_id!("@bob:localhost"); + let me = user_id!("@example:localhost"); + + // Track and mark users are not dirty, so that we can check they are “dirty” + // after that. Dirty here means that a `/key/query` must be sent. + { + let olm_machine = client.olm_machine().await; + let olm_machine = olm_machine.as_ref().unwrap(); + + olm_machine.update_tracked_users([alice, bob]).await?; + + // Assert requests. + let outgoing_requests = olm_machine.outgoing_requests().await?; + + assert_eq!(outgoing_requests.len(), 2); + assert_matches!(outgoing_requests[0].request(), OutgoingRequests::KeysUpload(_)); + assert_matches!(outgoing_requests[1].request(), OutgoingRequests::KeysQuery(_)); + + // Fake responses. + olm_machine + .mark_request_as_sent( + outgoing_requests[0].request_id(), + IncomingResponse::KeysUpload(&ruma_response_from_json(&json!({ + "one_time_key_counts": {} + }))), + ) + .await?; + + olm_machine + .mark_request_as_sent( + outgoing_requests[1].request_id(), + IncomingResponse::KeysQuery(&ruma_response_from_json(&json!({ + "device_keys": { + alice: {}, + bob: {}, + } + }))), + ) + .await?; + + // Once more. + let outgoing_requests = olm_machine.outgoing_requests().await?; + + assert_eq!(outgoing_requests.len(), 1); + assert_matches!(outgoing_requests[0].request(), OutgoingRequests::KeysQuery(_)); + + olm_machine + .mark_request_as_sent( + outgoing_requests[0].request_id(), + IncomingResponse::KeysQuery(&ruma_response_from_json(&json!({ + "device_keys": { + me: {}, + } + }))), + ) + .await?; + + // No more. + let outgoing_requests = olm_machine.outgoing_requests().await?; + + assert!(outgoing_requests.is_empty()); + } + + let sync = client + .sliding_sync("test-slidingsync")? + .add_list(SlidingSyncList::builder("new_list")) + .with_e2ee_extension(assign!(http::request::E2EE::default(), { enabled: Some(true)})) + .build() + .await?; + + // First request: no `pos`. + let txn_id = TransactionId::new(); + let (_request, _, _) = sync + .generate_sync_request(&mut LazyTransactionId::from_owned(txn_id.to_owned())) + .await?; + + // Now, tracked users must be dirty. + { + let olm_machine = client.olm_machine().await; + let olm_machine = olm_machine.as_ref().unwrap(); + + // Assert requests. + let outgoing_requests = olm_machine.outgoing_requests().await?; + + assert_eq!(outgoing_requests.len(), 1); + assert_matches!( + outgoing_requests[0].request(), + OutgoingRequests::KeysQuery(request) => { + assert!(request.device_keys.contains_key(alice)); + assert!(request.device_keys.contains_key(bob)); + assert!(request.device_keys.contains_key(me)); + } + ); + + // Fake responses. + olm_machine + .mark_request_as_sent( + outgoing_requests[0].request_id(), + IncomingResponse::KeysQuery(&ruma_response_from_json(&json!({ + "device_keys": { + alice: {}, + bob: {}, + me: {}, + } + }))), + ) + .await?; + } + + // Second request: with a `pos` this time. + sync.set_pos("chocolat".to_owned()).await; + + let txn_id = TransactionId::new(); + let (_request, _, _) = sync + .generate_sync_request(&mut LazyTransactionId::from_owned(txn_id.to_owned())) + .await?; + + // Tracked users are not marked as dirty. + { + let olm_machine = client.olm_machine().await; + let olm_machine = olm_machine.as_ref().unwrap(); + + // Assert requests. + let outgoing_requests = olm_machine.outgoing_requests().await?; + + assert!(outgoing_requests.is_empty()); + } + + Ok(()) + } + #[async_test] async fn test_unknown_pos_resets_pos_and_sticky_parameters() -> Result<()> { let server = MockServer::start().await;