diff --git a/crates/matrix-sdk/src/sliding_sync/mod.rs b/crates/matrix-sdk/src/sliding_sync/mod.rs index cef6d4e5e7..065ec02682 100644 --- a/crates/matrix-sdk/src/sliding_sync/mod.rs +++ b/crates/matrix-sdk/src/sliding_sync/mod.rs @@ -506,8 +506,7 @@ impl SlidingSync { // re-downloaded if the `since` token is `None`, otherwise it's easy to miss // device lists updates that happened between the previous request and the new // “initial” request. - #[cfg(feature = "e2e-encryption")] - if pos.is_none() && self.inner.version.is_native() { + if pos.is_none() && self.inner.version.is_native() && self.is_e2ee_enabled() { let olm_machine = self.inner.client.olm_machine().await; let olm_machine = olm_machine.as_ref().ok_or(Error::NoOlmMachine)?; olm_machine.mark_all_tracked_users_as_dirty().await?; @@ -865,15 +864,9 @@ enum SlidingSyncInternalMessage { #[cfg(any(test, feature = "testing"))] impl SlidingSync { - /// Get a copy of the `pos` value. - pub fn pos(&self) -> Option { - let position_lock = self.inner.position.blocking_lock(); - position_lock.pos.clone() - } - /// Set a new value for `pos`. - pub fn set_pos(&self, new_pos: String) { - let mut position_lock = self.inner.position.blocking_lock(); + pub async fn set_pos(&self, new_pos: String) { + let mut position_lock = self.inner.position.lock().await; position_lock.pos = Some(new_pos); } @@ -1684,6 +1677,153 @@ mod tests { Ok(()) } + // With MSC4186, with the `e2ee` extension enabled, if a request has no `pos`, + // all the tracked users by the `OlmMachine` must be marked as dirty, i.e. + // `/key/query` requests must be sent. See the code to see the details. + // + // This test is asserting that. + #[async_test] + #[cfg(feature = "e2e-encryption")] + async fn test_no_pos_with_e2ee_marks_all_tracked_users_as_dirty() -> anyhow::Result<()> { + use matrix_sdk_base::crypto::{IncomingResponse, OutgoingRequests}; + use matrix_sdk_test::ruma_response_from_json; + use ruma::user_id; + + let server = MockServer::start().await; + let client = logged_in_client(Some(server.uri())).await; + + let alice = user_id!("@alice:localhost"); + let bob = user_id!("@bob:localhost"); + let me = user_id!("@example:localhost"); + + // Track and mark users are not dirty, so that we can check they are “dirty” + // after that. Dirty here means that a `/key/query` must be sent. + { + let olm_machine = client.olm_machine().await; + let olm_machine = olm_machine.as_ref().unwrap(); + + olm_machine.update_tracked_users([alice, bob]).await?; + + // Assert requests. + let outgoing_requests = olm_machine.outgoing_requests().await?; + + assert_eq!(outgoing_requests.len(), 2); + assert_matches!(outgoing_requests[0].request(), OutgoingRequests::KeysUpload(_)); + assert_matches!(outgoing_requests[1].request(), OutgoingRequests::KeysQuery(_)); + + // Fake responses. + olm_machine + .mark_request_as_sent( + outgoing_requests[0].request_id(), + IncomingResponse::KeysUpload(&ruma_response_from_json(&json!({ + "one_time_key_counts": {} + }))), + ) + .await?; + + olm_machine + .mark_request_as_sent( + outgoing_requests[1].request_id(), + IncomingResponse::KeysQuery(&ruma_response_from_json(&json!({ + "device_keys": { + alice: {}, + bob: {}, + } + }))), + ) + .await?; + + // Once more. + let outgoing_requests = olm_machine.outgoing_requests().await?; + + assert_eq!(outgoing_requests.len(), 1); + assert_matches!(outgoing_requests[0].request(), OutgoingRequests::KeysQuery(_)); + + olm_machine + .mark_request_as_sent( + outgoing_requests[0].request_id(), + IncomingResponse::KeysQuery(&ruma_response_from_json(&json!({ + "device_keys": { + me: {}, + } + }))), + ) + .await?; + + // No more. + let outgoing_requests = olm_machine.outgoing_requests().await?; + + assert!(outgoing_requests.is_empty()); + } + + let sync = client + .sliding_sync("test-slidingsync")? + .add_list(SlidingSyncList::builder("new_list")) + .with_e2ee_extension(assign!(http::request::E2EE::default(), { enabled: Some(true)})) + .build() + .await?; + + // First request: no `pos`. + let txn_id = TransactionId::new(); + let (_request, _, _) = sync + .generate_sync_request(&mut LazyTransactionId::from_owned(txn_id.to_owned())) + .await?; + + // Now, tracked users must be dirty. + { + let olm_machine = client.olm_machine().await; + let olm_machine = olm_machine.as_ref().unwrap(); + + // Assert requests. + let outgoing_requests = olm_machine.outgoing_requests().await?; + + assert_eq!(outgoing_requests.len(), 1); + assert_matches!( + outgoing_requests[0].request(), + OutgoingRequests::KeysQuery(request) => { + assert!(request.device_keys.contains_key(alice)); + assert!(request.device_keys.contains_key(bob)); + assert!(request.device_keys.contains_key(me)); + } + ); + + // Fake responses. + olm_machine + .mark_request_as_sent( + outgoing_requests[0].request_id(), + IncomingResponse::KeysQuery(&ruma_response_from_json(&json!({ + "device_keys": { + alice: {}, + bob: {}, + me: {}, + } + }))), + ) + .await?; + } + + // Second request: with a `pos` this time. + sync.set_pos("chocolat".to_owned()).await; + + let txn_id = TransactionId::new(); + let (_request, _, _) = sync + .generate_sync_request(&mut LazyTransactionId::from_owned(txn_id.to_owned())) + .await?; + + // Tracked users are not marked as dirty. + { + let olm_machine = client.olm_machine().await; + let olm_machine = olm_machine.as_ref().unwrap(); + + // Assert requests. + let outgoing_requests = olm_machine.outgoing_requests().await?; + + assert!(outgoing_requests.is_empty()); + } + + Ok(()) + } + #[async_test] async fn test_unknown_pos_resets_pos_and_sticky_parameters() -> Result<()> { let server = MockServer::start().await;