Skip to content

Commit

Permalink
Handle MatrixRTC encryption keys arriving out of order (#4345)
Browse files Browse the repository at this point in the history
* Handle MatrixRTC encryption keys arriving out of order

* Apply suggestions from code review

Co-authored-by: Andrew Ferrazzutti <andrewf@element.io>

* Suggestion from code review

---------

Co-authored-by: Andrew Ferrazzutti <andrewf@element.io>
  • Loading branch information
hughns and AndrewFerr committed Aug 15, 2024
1 parent c65ef03 commit 87eddaf
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 14 deletions.
129 changes: 128 additions & 1 deletion spec/unit/matrixrtc/MatrixRTCSession.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

import { EventTimeline, EventType, MatrixClient, MatrixError, MatrixEvent, Room } from "../../../src";
import { encodeBase64, EventTimeline, EventType, MatrixClient, MatrixError, MatrixEvent, Room } from "../../../src";
import { KnownMembership } from "../../../src/@types/membership";
import {
CallMembershipData,
Expand Down Expand Up @@ -1145,6 +1145,7 @@ describe("MatrixRTCSession", () => {
],
}),
getSender: jest.fn().mockReturnValue("@bob:example.org"),
getTs: jest.fn().mockReturnValue(Date.now()),
} as unknown as MatrixEvent);

const bobKeys = sess.getKeysForParticipant("@bob:example.org", "bobsphone")!;
Expand All @@ -1168,6 +1169,7 @@ describe("MatrixRTCSession", () => {
],
}),
getSender: jest.fn().mockReturnValue("@bob:example.org"),
getTs: jest.fn().mockReturnValue(Date.now()),
} as unknown as MatrixEvent);

const bobKeys = sess.getKeysForParticipant("@bob:example.org", "bobsphone")!;
Expand All @@ -1179,6 +1181,130 @@ describe("MatrixRTCSession", () => {
expect(bobKeys[4]).toEqual(Buffer.from("this is the key", "utf-8"));
});

it("collects keys by merging", () => {
const mockRoom = makeMockRoom([membershipTemplate]);
sess = MatrixRTCSession.roomSessionForRoom(client, mockRoom);
sess.onCallEncryption({
getType: jest.fn().mockReturnValue("io.element.call.encryption_keys"),
getContent: jest.fn().mockReturnValue({
device_id: "bobsphone",
call_id: "",
keys: [
{
index: 0,
key: "dGhpcyBpcyB0aGUga2V5",
},
],
}),
getSender: jest.fn().mockReturnValue("@bob:example.org"),
getTs: jest.fn().mockReturnValue(Date.now()),
} as unknown as MatrixEvent);

let bobKeys = sess.getKeysForParticipant("@bob:example.org", "bobsphone")!;
expect(bobKeys).toHaveLength(1);
expect(bobKeys[0]).toEqual(Buffer.from("this is the key", "utf-8"));

sess.onCallEncryption({
getType: jest.fn().mockReturnValue("io.element.call.encryption_keys"),
getContent: jest.fn().mockReturnValue({
device_id: "bobsphone",
call_id: "",
keys: [
{
index: 4,
key: "dGhpcyBpcyB0aGUga2V5",
},
],
}),
getSender: jest.fn().mockReturnValue("@bob:example.org"),
getTs: jest.fn().mockReturnValue(Date.now()),
} as unknown as MatrixEvent);

bobKeys = sess.getKeysForParticipant("@bob:example.org", "bobsphone")!;
expect(bobKeys).toHaveLength(5);
expect(bobKeys[4]).toEqual(Buffer.from("this is the key", "utf-8"));
});

it("ignores older keys at same index", () => {
const mockRoom = makeMockRoom([membershipTemplate]);
sess = MatrixRTCSession.roomSessionForRoom(client, mockRoom);
sess.onCallEncryption({
getType: jest.fn().mockReturnValue("io.element.call.encryption_keys"),
getContent: jest.fn().mockReturnValue({
device_id: "bobsphone",
call_id: "",
keys: [
{
index: 0,
key: encodeBase64(Buffer.from("newer key", "utf-8")),
},
],
}),
getSender: jest.fn().mockReturnValue("@bob:example.org"),
getTs: jest.fn().mockReturnValue(2000),
} as unknown as MatrixEvent);

sess.onCallEncryption({
getType: jest.fn().mockReturnValue("io.element.call.encryption_keys"),
getContent: jest.fn().mockReturnValue({
device_id: "bobsphone",
call_id: "",
keys: [
{
index: 0,
key: encodeBase64(Buffer.from("older key", "utf-8")),
},
],
}),
getSender: jest.fn().mockReturnValue("@bob:example.org"),
getTs: jest.fn().mockReturnValue(1000), // earlier timestamp than the newer key
} as unknown as MatrixEvent);

const bobKeys = sess.getKeysForParticipant("@bob:example.org", "bobsphone")!;
expect(bobKeys).toHaveLength(1);
expect(bobKeys[0]).toEqual(Buffer.from("newer key", "utf-8"));
});

it("key timestamps are treated as monotonic", () => {
const mockRoom = makeMockRoom([membershipTemplate]);
sess = MatrixRTCSession.roomSessionForRoom(client, mockRoom);
sess.onCallEncryption({
getType: jest.fn().mockReturnValue("io.element.call.encryption_keys"),
getContent: jest.fn().mockReturnValue({
device_id: "bobsphone",
call_id: "",
keys: [
{
index: 0,
key: encodeBase64(Buffer.from("first key", "utf-8")),
},
],
}),
getSender: jest.fn().mockReturnValue("@bob:example.org"),
getTs: jest.fn().mockReturnValue(1000),
} as unknown as MatrixEvent);

sess.onCallEncryption({
getType: jest.fn().mockReturnValue("io.element.call.encryption_keys"),
getContent: jest.fn().mockReturnValue({
device_id: "bobsphone",
call_id: "",
keys: [
{
index: 0,
key: encodeBase64(Buffer.from("second key", "utf-8")),
},
],
}),
getSender: jest.fn().mockReturnValue("@bob:example.org"),
getTs: jest.fn().mockReturnValue(1000), // same timestamp as the first key
} as unknown as MatrixEvent);

const bobKeys = sess.getKeysForParticipant("@bob:example.org", "bobsphone")!;
expect(bobKeys).toHaveLength(1);
expect(bobKeys[0]).toEqual(Buffer.from("second key", "utf-8"));
});

it("ignores keys event for the local participant", () => {
const mockRoom = makeMockRoom([membershipTemplate]);
sess = MatrixRTCSession.roomSessionForRoom(client, mockRoom);
Expand All @@ -1195,6 +1321,7 @@ describe("MatrixRTCSession", () => {
],
}),
getSender: jest.fn().mockReturnValue(client.getUserId()),
getTs: jest.fn().mockReturnValue(Date.now()),
} as unknown as MatrixEvent);

const myKeys = sess.getKeysForParticipant(client.getUserId()!, client.getDeviceId()!)!;
Expand Down
66 changes: 53 additions & 13 deletions src/matrixrtc/MatrixRTCSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ const USE_KEY_DELAY = 5000;
const getParticipantId = (userId: string, deviceId: string): string => `${userId}:${deviceId}`;
const getParticipantIdFromMembership = (m: CallMembership): string => getParticipantId(m.sender!, m.deviceId);

function keysEqual(a: Uint8Array, b: Uint8Array): boolean {
function keysEqual(a: Uint8Array | undefined, b: Uint8Array | undefined): boolean {
if (a === b) return true;
return a && b && a.length === b.length && a.every((x, i) => x === b[i]);
return !!a && !!b && a.length === b.length && a.every((x, i) => x === b[i]);
}

export enum MatrixRTCSessionEvent {
Expand Down Expand Up @@ -134,8 +134,8 @@ export class MatrixRTCSession extends TypedEventEmitter<MatrixRTCSessionEvent, M

private manageMediaKeys = false;
private useLegacyMemberEvents = true;
// userId:deviceId => array of keys
private encryptionKeys = new Map<string, Array<Uint8Array>>();
// userId:deviceId => array of (key, timestamp)
private encryptionKeys = new Map<string, Array<{ key: Uint8Array; timestamp: number }>>();
private lastEncryptionKeyUpdateRequest?: number;

// We use this to store the last membership fingerprints we saw, so we can proactively re-send encryption keys
Expand Down Expand Up @@ -378,16 +378,26 @@ export class MatrixRTCSession extends TypedEventEmitter<MatrixRTCSessionEvent, M
}
}

/**
* Get the known encryption keys for a given participant device.
*
* @param userId the user ID of the participant
* @param deviceId the device ID of the participant
* @returns The encryption keys for the given participant, or undefined if they are not known.
*/
public getKeysForParticipant(userId: string, deviceId: string): Array<Uint8Array> | undefined {
return this.encryptionKeys.get(getParticipantId(userId, deviceId));
return this.encryptionKeys.get(getParticipantId(userId, deviceId))?.map((entry) => entry.key);
}

/**
* A map of keys used to encrypt and decrypt (we are using a symmetric
* cipher) given participant's media. This also includes our own key
*/
public getEncryptionKeys(): IterableIterator<[string, Array<Uint8Array>]> {
return this.encryptionKeys.entries();
// the returned array doesn't contain the timestamps
return Array.from(this.encryptionKeys.entries())
.map(([participantId, keys]): [string, Uint8Array[]] => [participantId, keys.map((k) => k.key)])
.values();
}

private getNewEncryptionKeyIndex(): number {
Expand All @@ -402,12 +412,14 @@ export class MatrixRTCSession extends TypedEventEmitter<MatrixRTCSessionEvent, M

/**
* Sets an encryption key at a specified index for a participant.
* The encryption keys for the local participanmt are also stored here under the
* The encryption keys for the local participant are also stored here under the
* user and device ID of the local participant.
* If the key is older than the existing key at the index, it will be ignored.
* @param userId - The user ID of the participant
* @param deviceId - Device ID of the participant
* @param encryptionKeyIndex - The index of the key to set
* @param encryptionKeyString - The string representation of the key to set in base64
* @param timestamp - The timestamp of the key. We assume that these are monotonic for each participant device.
* @param delayBeforeUse - If true, delay before emitting a key changed event. Useful when setting
* encryption keys for the local participant to allow time for the key to
* be distributed.
Expand All @@ -417,17 +429,38 @@ export class MatrixRTCSession extends TypedEventEmitter<MatrixRTCSessionEvent, M
deviceId: string,
encryptionKeyIndex: number,
encryptionKeyString: string,
timestamp: number,
delayBeforeUse = false,
): void {
const keyBin = decodeBase64(encryptionKeyString);

const participantId = getParticipantId(userId, deviceId);
const encryptionKeys = this.encryptionKeys.get(participantId) ?? [];
if (!this.encryptionKeys.has(participantId)) {
this.encryptionKeys.set(participantId, []);
}
const participantKeys = this.encryptionKeys.get(participantId)!;

if (keysEqual(encryptionKeys[encryptionKeyIndex], keyBin)) return;
const existingKeyAtIndex = participantKeys[encryptionKeyIndex];

if (existingKeyAtIndex) {
if (existingKeyAtIndex.timestamp > timestamp) {
logger.info(
`Ignoring new key at index ${encryptionKeyIndex} for ${participantId} as it is older than existing known key`,
);
return;
}

if (keysEqual(existingKeyAtIndex.key, keyBin)) {
existingKeyAtIndex.timestamp = timestamp;
return;
}
}

participantKeys[encryptionKeyIndex] = {
key: keyBin,
timestamp,
};

encryptionKeys[encryptionKeyIndex] = keyBin;
this.encryptionKeys.set(participantId, encryptionKeys);
if (delayBeforeUse) {
const useKeyTimeout = setTimeout(() => {
this.setNewKeyTimeouts.delete(useKeyTimeout);
Expand Down Expand Up @@ -455,7 +488,7 @@ export class MatrixRTCSession extends TypedEventEmitter<MatrixRTCSessionEvent, M
const encryptionKey = secureRandomBase64Url(16);
const encryptionKeyIndex = this.getNewEncryptionKeyIndex();
logger.info("Generated new key at index " + encryptionKeyIndex);
this.setEncryptionKey(userId, deviceId, encryptionKeyIndex, encryptionKey, delayBeforeUse);
this.setEncryptionKey(userId, deviceId, encryptionKeyIndex, encryptionKey, Date.now(), delayBeforeUse);
}

/**
Expand Down Expand Up @@ -574,6 +607,13 @@ export class MatrixRTCSession extends TypedEventEmitter<MatrixRTCSessionEvent, M
}
}

/**
* Process `m.call.encryption_keys` events to track the encryption keys for call participants.
* This should be called each time the relevant event is received from a room timeline.
* If the event is malformed then it will be logged and ignored.
*
* @param event the event to process
*/
public onCallEncryption = (event: MatrixEvent): void => {
const userId = event.getSender();
const content = event.getContent<EncryptionKeysEventContent>();
Expand Down Expand Up @@ -635,7 +675,7 @@ export class MatrixRTCSession extends TypedEventEmitter<MatrixRTCSessionEvent, M
`Embedded-E2EE-LOG onCallEncryption userId=${userId}:${deviceId} encryptionKeyIndex=${encryptionKeyIndex}`,
this.encryptionKeys,
);
this.setEncryptionKey(userId, deviceId, encryptionKeyIndex, encryptionKey);
this.setEncryptionKey(userId, deviceId, encryptionKeyIndex, encryptionKey, event.getTs());
}
}
};
Expand Down

0 comments on commit 87eddaf

Please sign in to comment.