diff --git a/src/Lifecycle.ts b/src/Lifecycle.ts index e11b9b8a01e..23873140667 100644 --- a/src/Lifecycle.ts +++ b/src/Lifecycle.ts @@ -18,7 +18,7 @@ limitations under the License. */ import { ReactNode } from "react"; -import { createClient, MatrixClient, SSOAction } from "matrix-js-sdk/src/matrix"; +import { createClient, MatrixClient, SSOAction, OidcTokenRefresher } from "matrix-js-sdk/src/matrix"; import { InvalidStoreError } from "matrix-js-sdk/src/errors"; import { IEncryptedPayload } from "matrix-js-sdk/src/crypto/aes"; import { QueryDict } from "matrix-js-sdk/src/utils"; @@ -65,7 +65,12 @@ import { OverwriteLoginPayload } from "./dispatcher/payloads/OverwriteLoginPaylo import { SdkContextClass } from "./contexts/SDKContext"; import { messageForLoginError } from "./utils/ErrorUtils"; import { completeOidcLogin } from "./utils/oidc/authorize"; -import { persistOidcAuthenticatedSettings } from "./utils/oidc/persistOidcSettings"; +import { + getStoredOidcClientId, + getStoredOidcIdTokenClaims, + getStoredOidcTokenIssuer, + persistOidcAuthenticatedSettings, +} from "./utils/oidc/persistOidcSettings"; import GenericToast from "./components/views/toasts/GenericToast"; import { ACCESS_TOKEN_IV, @@ -78,6 +83,7 @@ import { REFRESH_TOKEN_STORAGE_KEY, tryDecryptToken, } from "./utils/tokens/tokens"; +import { TokenRefresher } from "./utils/oidc/TokenRefresher"; const HOMESERVER_URL_KEY = "mx_hs_url"; const ID_SERVER_URL_KEY = "mx_is_url"; @@ -746,6 +752,45 @@ export async function hydrateSession(credentials: IMatrixClientCreds): Promise { + if (!credentials.refreshToken) { + return; + } + // stored token issuer indicates we authenticated via OIDC-native flow + const tokenIssuer = getStoredOidcTokenIssuer(); + if (!tokenIssuer) { + return; + } + try { + const clientId = getStoredOidcClientId(); + const idTokenClaims = getStoredOidcIdTokenClaims(); + const redirectUri = window.location.origin; + const deviceId = credentials.deviceId; + if (!deviceId) { + throw new Error("Expected deviceId in user credentials."); + } + const tokenRefresher = new TokenRefresher( + { issuer: tokenIssuer }, + clientId, + redirectUri, + deviceId, + idTokenClaims!, + credentials.userId, + ); + // wait for the OIDC client to initialise + await tokenRefresher.oidcClientReady; + return tokenRefresher; + } catch (error) { + logger.error("Failed to initialise OIDC token refresher", error); + } +} + /** * optionally clears localstorage, persists new credentials * to localstorage, starts the new client. @@ -787,9 +832,11 @@ async function doSetLoggedIn(credentials: IMatrixClientCreds, clearStorageEnable await abortLogin(); } + const tokenRefresher = await createOidcTokenRefresher(credentials); + // check the session lock just before creating the new client checkSessionLock(); - MatrixClientPeg.replaceUsingCreds(credentials); + MatrixClientPeg.replaceUsingCreds(credentials, tokenRefresher?.doRefreshAccessToken.bind(tokenRefresher)); const client = MatrixClientPeg.safeGet(); setSentryUser(credentials.userId); diff --git a/src/MatrixClientPeg.ts b/src/MatrixClientPeg.ts index 5ae2d16dba7..fbc54b38967 100644 --- a/src/MatrixClientPeg.ts +++ b/src/MatrixClientPeg.ts @@ -27,6 +27,7 @@ import { IStartClientOpts, MatrixClient, MemoryStore, + TokenRefreshFunction, } from "matrix-js-sdk/src/matrix"; import * as utils from "matrix-js-sdk/src/utils"; import { verificationMethods } from "matrix-js-sdk/src/crypto"; @@ -122,8 +123,10 @@ export interface IMatrixClientPeg { * homeserver / identity server URLs and active credentials * * @param {IMatrixClientCreds} creds The new credentials to use. + * @param {TokenRefreshFunction} tokenRefreshFunction OPTIONAL function used by MatrixClient to attempt token refresh + * see {@link ICreateClientOpts.tokenRefreshFunction} */ - replaceUsingCreds(creds: IMatrixClientCreds): void; + replaceUsingCreds(creds: IMatrixClientCreds, tokenRefreshFunction?: TokenRefreshFunction): void; } /** @@ -196,8 +199,8 @@ class MatrixClientPegClass implements IMatrixClientPeg { } } - public replaceUsingCreds(creds: IMatrixClientCreds): void { - this.createClient(creds); + public replaceUsingCreds(creds: IMatrixClientCreds, tokenRefreshFunction?: TokenRefreshFunction): void { + this.createClient(creds, tokenRefreshFunction); } private onUnexpectedStoreClose = async (): Promise => { @@ -378,11 +381,13 @@ class MatrixClientPegClass implements IMatrixClientPeg { }); } - private createClient(creds: IMatrixClientCreds): void { + private createClient(creds: IMatrixClientCreds, tokenRefreshFunction?: TokenRefreshFunction): void { const opts: ICreateClientOpts = { baseUrl: creds.homeserverUrl, idBaseUrl: creds.identityServerUrl, accessToken: creds.accessToken, + refreshToken: creds.refreshToken, + tokenRefreshFunction, userId: creds.userId, deviceId: creds.deviceId, pickleKey: creds.pickleKey, diff --git a/src/utils/oidc/TokenRefresher.ts b/src/utils/oidc/TokenRefresher.ts new file mode 100644 index 00000000000..a6a0be29be7 --- /dev/null +++ b/src/utils/oidc/TokenRefresher.ts @@ -0,0 +1,47 @@ +/* +Copyright 2023 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import { IDelegatedAuthConfig, OidcTokenRefresher, AccessTokens } from "matrix-js-sdk/src/matrix"; +import { IdTokenClaims } from "oidc-client-ts"; + +import PlatformPeg from "../../PlatformPeg"; +import { persistAccessTokenInStorage, persistRefreshTokenInStorage } from "../tokens/tokens"; + +/** + * OidcTokenRefresher that implements token persistence. + * Stores tokens in the same way as login flow in Lifecycle. + */ +export class TokenRefresher extends OidcTokenRefresher { + private readonly deviceId!: string; + + public constructor( + authConfig: IDelegatedAuthConfig, + clientId: string, + redirectUri: string, + deviceId: string, + idTokenClaims: IdTokenClaims, + private readonly userId: string, + ) { + super(authConfig, clientId, deviceId, redirectUri, idTokenClaims); + this.deviceId = deviceId; + } + + public async persistTokens({ accessToken, refreshToken }: AccessTokens): Promise { + const pickleKey = (await PlatformPeg.get()?.getPickleKey(this.userId, this.deviceId)) ?? undefined; + await persistAccessTokenInStorage(accessToken, pickleKey); + await persistRefreshTokenInStorage(refreshToken, pickleKey); + } +} diff --git a/src/utils/oidc/persistOidcSettings.ts b/src/utils/oidc/persistOidcSettings.ts index f5132291be1..da4510bbacb 100644 --- a/src/utils/oidc/persistOidcSettings.ts +++ b/src/utils/oidc/persistOidcSettings.ts @@ -57,3 +57,15 @@ export const getStoredOidcClientId = (): string => { } return clientId; }; + +/** + * Retrieve stored id token claims from session storage + * @returns idtokenclaims or undefined + */ +export const getStoredOidcIdTokenClaims = (): IdTokenClaims | undefined => { + const idTokenClaims = sessionStorage.getItem(idTokenClaimsStorageKey); + if (!idTokenClaims) { + return; + } + return JSON.parse(idTokenClaims) as IdTokenClaims; +}; diff --git a/test/Lifecycle-test.ts b/test/Lifecycle-test.ts index c44ea31d698..802556f868f 100644 --- a/test/Lifecycle-test.ts +++ b/test/Lifecycle-test.ts @@ -19,6 +19,7 @@ import { logger } from "matrix-js-sdk/src/logger"; import * as MatrixJs from "matrix-js-sdk/src/matrix"; import { setCrypto } from "matrix-js-sdk/src/crypto/crypto"; import * as MatrixCryptoAes from "matrix-js-sdk/src/crypto/aes"; +import fetchMock from "fetch-mock-jest"; import StorageEvictedDialog from "../src/components/views/dialogs/StorageEvictedDialog"; import { restoreFromLocalStorage, setLoggedIn } from "../src/Lifecycle"; @@ -27,6 +28,8 @@ import Modal from "../src/Modal"; import * as StorageManager from "../src/utils/StorageManager"; import { getMockClientWithEventEmitter, mockPlatformPeg } from "./test-utils"; import ToastStore from "../src/stores/ToastStore"; +import { makeDelegatedAuthConfig } from "./test-utils/oidc"; +import { persistOidcAuthenticatedSettings } from "../src/utils/oidc/persistOidcSettings"; const webCrypto = new Crypto(); @@ -233,6 +236,7 @@ describe("Lifecycle", () => { userId, guest: true, }), + undefined, ); expect(localStorage.setItem).toHaveBeenCalledWith("mx_is_guest", "true"); }); @@ -264,16 +268,19 @@ describe("Lifecycle", () => { it("should create new matrix client with credentials", async () => { expect(await restoreFromLocalStorage()).toEqual(true); - expect(MatrixClientPeg.replaceUsingCreds).toHaveBeenCalledWith({ - userId, - accessToken, - homeserverUrl, - identityServerUrl, - deviceId, - freshLogin: false, - guest: false, - pickleKey: undefined, - }); + expect(MatrixClientPeg.replaceUsingCreds).toHaveBeenCalledWith( + { + userId, + accessToken, + homeserverUrl, + identityServerUrl, + deviceId, + freshLogin: false, + guest: false, + pickleKey: undefined, + }, + undefined, + ); }); it("should remove fresh login flag from session storage", async () => { @@ -312,18 +319,21 @@ describe("Lifecycle", () => { it("should create new matrix client with credentials", async () => { expect(await restoreFromLocalStorage()).toEqual(true); - expect(MatrixClientPeg.replaceUsingCreds).toHaveBeenCalledWith({ - userId, - accessToken, - // refreshToken included in credentials - refreshToken, - homeserverUrl, - identityServerUrl, - deviceId, - freshLogin: false, - guest: false, - pickleKey: undefined, - }); + expect(MatrixClientPeg.replaceUsingCreds).toHaveBeenCalledWith( + { + userId, + accessToken, + // refreshToken included in credentials + refreshToken, + homeserverUrl, + identityServerUrl, + deviceId, + freshLogin: false, + guest: false, + pickleKey: undefined, + }, + undefined, + ); }); }); }); @@ -373,17 +383,20 @@ describe("Lifecycle", () => { it("should create new matrix client with credentials", async () => { expect(await restoreFromLocalStorage()).toEqual(true); - expect(MatrixClientPeg.replaceUsingCreds).toHaveBeenCalledWith({ - userId, - // decrypted accessToken - accessToken, - homeserverUrl, - identityServerUrl, - deviceId, - freshLogin: true, - guest: false, - pickleKey: expect.any(String), - }); + expect(MatrixClientPeg.replaceUsingCreds).toHaveBeenCalledWith( + { + userId, + // decrypted accessToken + accessToken, + homeserverUrl, + identityServerUrl, + deviceId, + freshLogin: true, + guest: false, + pickleKey: expect.any(String), + }, + undefined, + ); }); describe("with a refresh token", () => { @@ -412,18 +425,21 @@ describe("Lifecycle", () => { it("should create new matrix client with credentials", async () => { expect(await restoreFromLocalStorage()).toEqual(true); - expect(MatrixClientPeg.replaceUsingCreds).toHaveBeenCalledWith({ - userId, - accessToken, - // refreshToken included in credentials - refreshToken, - homeserverUrl, - identityServerUrl, - deviceId, - freshLogin: false, - guest: false, - pickleKey: expect.any(String), - }); + expect(MatrixClientPeg.replaceUsingCreds).toHaveBeenCalledWith( + { + userId, + accessToken, + // refreshToken included in credentials + refreshToken, + homeserverUrl, + identityServerUrl, + deviceId, + freshLogin: false, + guest: false, + pickleKey: expect.any(String), + }, + undefined, + ); }); }); }); @@ -529,16 +545,19 @@ describe("Lifecycle", () => { it("should create new matrix client with credentials", async () => { expect(await setLoggedIn(credentials)).toEqual(mockClient); - expect(MatrixClientPeg.replaceUsingCreds).toHaveBeenCalledWith({ - userId, - accessToken, - homeserverUrl, - identityServerUrl, - deviceId, - freshLogin: true, - guest: false, - pickleKey: null, - }); + expect(MatrixClientPeg.replaceUsingCreds).toHaveBeenCalledWith( + { + userId, + accessToken, + homeserverUrl, + identityServerUrl, + deviceId, + freshLogin: true, + guest: false, + pickleKey: null, + }, + undefined, + ); }); }); @@ -628,16 +647,132 @@ describe("Lifecycle", () => { it("should create new matrix client with credentials", async () => { expect(await setLoggedIn(credentials)).toEqual(mockClient); - expect(MatrixClientPeg.replaceUsingCreds).toHaveBeenCalledWith({ - userId, - accessToken, - homeserverUrl, - identityServerUrl, - deviceId, - freshLogin: true, - guest: false, - pickleKey: expect.any(String), + expect(MatrixClientPeg.replaceUsingCreds).toHaveBeenCalledWith( + { + userId, + accessToken, + homeserverUrl, + identityServerUrl, + deviceId, + freshLogin: true, + guest: false, + pickleKey: expect.any(String), + }, + undefined, + ); + }); + }); + + describe("when authenticated via OIDC native flow", () => { + const clientId = "test-client-id"; + const issuer = "https://auth.com/"; + + const delegatedAuthConfig = makeDelegatedAuthConfig(issuer); + const idTokenClaims = { + aud: "123", + iss: issuer, + sub: "123", + exp: 123, + iat: 456, + }; + + beforeAll(() => { + fetchMock.get( + `${delegatedAuthConfig.issuer}.well-known/openid-configuration`, + delegatedAuthConfig.metadata, + ); + fetchMock.get(`${delegatedAuthConfig.issuer}jwks`, { + status: 200, + headers: { + "Content-Type": "application/json", + }, + keys: [], + }); + }); + + beforeEach(() => { + // mock oidc config for oidc client initialisation + mockClient.getClientWellKnown.mockReturnValue({ + "m.authentication": { + issuer: issuer, + }, }); + initSessionStorageMock(); + // set values in session storage as they would be after a successful oidc authentication + persistOidcAuthenticatedSettings(clientId, issuer, idTokenClaims); + }); + + it("should not try to create a token refresher without a refresh token", async () => { + await setLoggedIn(credentials); + + // didn't try to initialise token refresher + expect(fetchMock).not.toHaveFetched(`${delegatedAuthConfig.issuer}.well-known/openid-configuration`); + }); + + it("should not try to create a token refresher without a deviceId", async () => { + await setLoggedIn({ + ...credentials, + refreshToken, + deviceId: undefined, + }); + + // didn't try to initialise token refresher + expect(fetchMock).not.toHaveFetched(`${delegatedAuthConfig.issuer}.well-known/openid-configuration`); + }); + + it("should not try to create a token refresher without an issuer in session storage", async () => { + persistOidcAuthenticatedSettings( + clientId, + // @ts-ignore set undefined issuer + undefined, + idTokenClaims, + ); + await setLoggedIn({ + ...credentials, + refreshToken, + }); + + // didn't try to initialise token refresher + expect(fetchMock).not.toHaveFetched(`${delegatedAuthConfig.issuer}.well-known/openid-configuration`); + }); + + it("should create a client with a tokenRefreshFunction", async () => { + expect( + await setLoggedIn({ + ...credentials, + refreshToken, + }), + ).toEqual(mockClient); + + expect(MatrixClientPeg.replaceUsingCreds).toHaveBeenCalledWith( + expect.objectContaining({ + accessToken, + refreshToken, + }), + expect.any(Function), + ); + }); + + it("should create a client when creating token refresher fails", async () => { + // set invalid value in session storage for a malformed oidc authentication + persistOidcAuthenticatedSettings(null as any, issuer, idTokenClaims); + + // succeeded + expect( + await setLoggedIn({ + ...credentials, + refreshToken, + }), + ).toEqual(mockClient); + + expect(MatrixClientPeg.replaceUsingCreds).toHaveBeenCalledWith( + expect.objectContaining({ + accessToken, + refreshToken, + }), + // no token refresh function + undefined, + ); }); }); }); diff --git a/test/utils/oidc/TokenRefresher-test.ts b/test/utils/oidc/TokenRefresher-test.ts new file mode 100644 index 00000000000..46b33da52a8 --- /dev/null +++ b/test/utils/oidc/TokenRefresher-test.ts @@ -0,0 +1,96 @@ +/* +Copyright 2023 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import fetchMock from "fetch-mock-jest"; +import { mocked } from "jest-mock"; + +import { TokenRefresher } from "../../../src/utils/oidc/TokenRefresher"; +import { persistAccessTokenInStorage, persistRefreshTokenInStorage } from "../../../src/utils/tokens/tokens"; +import { mockPlatformPeg } from "../../test-utils"; +import { makeDelegatedAuthConfig } from "../../test-utils/oidc"; + +jest.mock("../../../src/utils/tokens/tokens", () => ({ + persistAccessTokenInStorage: jest.fn(), + persistRefreshTokenInStorage: jest.fn(), +})); + +describe("TokenRefresher", () => { + const clientId = "test-client-id"; + const issuer = "https://auth.com/"; + const redirectUri = "https://test.com"; + const deviceId = "test-device-id"; + const userId = "@alice:server.org"; + const accessToken = "test-access-token"; + const refreshToken = "test-refresh-token"; + + const authConfig = makeDelegatedAuthConfig(issuer); + const idTokenClaims = { + aud: "123", + iss: issuer, + sub: "123", + exp: 123, + iat: 456, + }; + + beforeEach(() => { + fetchMock.get(`${authConfig.issuer}.well-known/openid-configuration`, authConfig.metadata); + fetchMock.get(`${authConfig.issuer}jwks`, { + status: 200, + headers: { + "Content-Type": "application/json", + }, + keys: [], + }); + + mocked(persistAccessTokenInStorage).mockResolvedValue(undefined); + mocked(persistRefreshTokenInStorage).mockResolvedValue(undefined); + }); + + afterEach(() => { + jest.restoreAllMocks(); + }); + + it("should persist tokens with a pickle key", async () => { + const pickleKey = "test-pickle-key"; + const getPickleKey = jest.fn().mockResolvedValue(pickleKey); + mockPlatformPeg({ getPickleKey }); + + const refresher = new TokenRefresher(authConfig, clientId, redirectUri, deviceId, idTokenClaims, userId); + + await refresher.oidcClientReady; + + await refresher.persistTokens({ accessToken, refreshToken }); + + expect(getPickleKey).toHaveBeenCalledWith(userId, deviceId); + expect(persistAccessTokenInStorage).toHaveBeenCalledWith(accessToken, pickleKey); + expect(persistRefreshTokenInStorage).toHaveBeenCalledWith(refreshToken, pickleKey); + }); + + it("should persist tokens without a pickle key", async () => { + const getPickleKey = jest.fn().mockResolvedValue(null); + mockPlatformPeg({ getPickleKey }); + + const refresher = new TokenRefresher(authConfig, clientId, redirectUri, deviceId, idTokenClaims, userId); + + await refresher.oidcClientReady; + + await refresher.persistTokens({ accessToken, refreshToken }); + + expect(getPickleKey).toHaveBeenCalledWith(userId, deviceId); + expect(persistAccessTokenInStorage).toHaveBeenCalledWith(accessToken, undefined); + expect(persistRefreshTokenInStorage).toHaveBeenCalledWith(refreshToken, undefined); + }); +}); diff --git a/test/utils/oidc/persistOidcSettings-test.ts b/test/utils/oidc/persistOidcSettings-test.ts index f71a7f3ed66..03ac61d199a 100644 --- a/test/utils/oidc/persistOidcSettings-test.ts +++ b/test/utils/oidc/persistOidcSettings-test.ts @@ -18,6 +18,7 @@ import { IdTokenClaims } from "oidc-client-ts"; import { getStoredOidcClientId, + getStoredOidcIdTokenClaims, getStoredOidcTokenIssuer, persistOidcAuthenticatedSettings, } from "../../../src/utils/oidc/persistOidcSettings"; @@ -75,4 +76,16 @@ describe("persist OIDC settings", () => { expect(() => getStoredOidcClientId()).toThrow("Oidc client id not found in storage"); }); }); + + describe("getStoredOidcIdTokenClaims()", () => { + it("should return issuer from session storage", () => { + jest.spyOn(sessionStorage.__proto__, "getItem").mockReturnValue(JSON.stringify(idTokenClaims)); + expect(getStoredOidcIdTokenClaims()).toEqual(idTokenClaims); + expect(sessionStorage.getItem).toHaveBeenCalledWith("mx_oidc_id_token_claims"); + }); + + it("should return undefined when no issuer in session storage", () => { + expect(getStoredOidcIdTokenClaims()).toBeUndefined(); + }); + }); });