Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(core): add cache-control header to cognito identity client #10753

Merged
merged 11 commits into from
Dec 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 43 additions & 27 deletions packages/amazon-cognito-identity-js/src/Client.js
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,20 @@ export default class Client {
requestWithRetry(operation, params, callback) {
const MAX_DELAY_IN_MILLIS = 5 * 1000;

haverchuck marked this conversation as resolved.
Show resolved Hide resolved
jitteredExponentialRetry((p) => new Promise((res, rej) => {
this.request(operation, p, (error, result) => {
if (error) {
rej(error);
} else {
res(result);
}
});
}), [params], MAX_DELAY_IN_MILLIS)
jitteredExponentialRetry(
p =>
new Promise((res, rej) => {
this.request(operation, p, (error, result) => {
if (error) {
rej(error);
} else {
res(result);
}
});
}),
[params],
MAX_DELAY_IN_MILLIS
)
.then(result => callback(null, result))
.catch(error => callback(error));
}
Expand Down Expand Up @@ -112,9 +117,9 @@ export default class Client {
// Taken from aws-sdk-js/lib/protocol/json.js
// eslint-disable-next-line no-underscore-dangle
const code = (data.__type || data.code).split('#').pop();
const error = new Error(data.message || data.Message || null)
error.name = code
error.code = code
const error = new Error(data.message || data.Message || null);
error.name = code;
error.code = code;
return callback(error);
})
.catch(err => {
Expand All @@ -126,17 +131,19 @@ export default class Client {
) {
try {
const code = response.headers.get('x-amzn-errortype').split(':')[0];
const error = new Error(response.status ? response.status.toString() : null)
error.code = code
error.name = code
error.statusCode = response.status
const error = new Error(
response.status ? response.status.toString() : null
);
error.code = code;
error.name = code;
error.statusCode = response.status;
return callback(error);
} catch (ex) {
return callback(err);
}
// otherwise check if error is Network error
} else if (err instanceof Error && err.message === 'Network error') {
err.code = 'NetworkError'
err.code = 'NetworkError';
}
return callback(err);
});
Expand All @@ -146,7 +153,7 @@ export default class Client {
const logger = {
debug: () => {
// Intentionally blank. This package doesn't have logging
}
},
};

/**
Expand All @@ -159,7 +166,7 @@ class NonRetryableError extends Error {
}
}

const isNonRetryableError = (obj) => {
const isNonRetryableError = obj => {
const key = 'nonRetryable';
return obj && obj[key];
};
Expand All @@ -169,9 +176,13 @@ function retry(functionToRetry, args, delayFn, attempt = 1) {
throw Error('functionToRetry must be a function');
}

logger.debug(`${functionToRetry.name} attempt #${attempt} with args: ${JSON.stringify(args)}`);
logger.debug(
`${functionToRetry.name} attempt #${attempt} with args: ${JSON.stringify(
args
)}`
);

return functionToRetry(...args).catch((err) => {
return functionToRetry(...args).catch(err => {
logger.debug(`error on ${functionToRetry.name}`, err);

if (isNonRetryableError(err)) {
Expand All @@ -184,12 +195,13 @@ function retry(functionToRetry, args, delayFn, attempt = 1) {
logger.debug(`${functionToRetry.name} retrying in ${retryIn} ms`);

if (retryIn !== false) {
return new Promise(res => setTimeout(res, retryIn))
.then(() => retry(functionToRetry, args, delayFn, attempt + 1))
return new Promise(res => setTimeout(res, retryIn)).then(() =>
retry(functionToRetry, args, delayFn, attempt + 1)
);
} else {
throw err;
}
})
});
}

function jitteredBackoff(maxDelayMs) {
Expand All @@ -203,6 +215,10 @@ function jitteredBackoff(maxDelayMs) {
}

const MAX_DELAY_MS = 5 * 60 * 1000;
function jitteredExponentialRetry(functionToRetry, args, maxDelayMs = MAX_DELAY_MS) {
return retry(functionToRetry, args, jitteredBackoff(maxDelayMs))
};
function jitteredExponentialRetry(
functionToRetry,
args,
maxDelayMs = MAX_DELAY_MS
) {
return retry(functionToRetry, args, jitteredBackoff(maxDelayMs));
}
3 changes: 3 additions & 0 deletions packages/core/__tests__/Credentials-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ describe('Credentials test', () => {
};
}
},
middlewareStack: {
add: (next, _) => {},
},
};
});

Expand Down
82 changes: 82 additions & 0 deletions packages/core/__tests__/Util-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ import Reachability from '../src/Util/Reachability';
import { ConsoleLogger as Logger } from '../src/Logger';
import { urlSafeDecode, urlSafeEncode } from '../src/Util/StringUtils';
import { DateUtils } from '../src/Util/DateUtils';
import {
createCognitoIdentityClient,
middlewareArgs,
} from '../src/Util/CognitoIdentityClient';
import { BuildMiddleware, HttpRequest } from '@aws-sdk/types';
import {
GetCredentialsForIdentityCommand,
GetIdCommand,
} from '@aws-sdk/client-cognito-identity';

Logger.LOG_LEVEL = 'DEBUG';

Expand Down Expand Up @@ -51,6 +60,79 @@ describe('Util', () => {
});
});

describe('cognito identity client test', () => {
test('client should be instantiated', async () => {
const cognitoClient = createCognitoIdentityClient({
region: 'us-west-1',
});
expect(cognitoClient).toBeTruthy();
expect.assertions(1);
});

test('middlewareArgs helper should merge headers into request object', async () => {
const args = middlewareArgs({
request: {
headers: {
'test-header': '1234',
},
},
input: {},
});
expect(args.request.headers['test-header']).toEqual('1234');
expect(args.request.headers['cache-control']).toEqual('no-store');
expect.assertions(2);
});

test('headers should be added by middleware on GetIdCommand', async () => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add expect.assertions and also use done in case the tests ends before doing the assertions

Copy link
Member Author

@haverchuck haverchuck Dec 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the done is necessary if you are using async callback and awaiting, but I added the expect.assertions. Good callout, thanks.

const requestCacheHeaderValidator: BuildMiddleware<any, any> =
next => async args => {
// middleware intercept the request and return it early
const request = args.request as HttpRequest;
const { headers } = request;
expect(headers['cache-control']).toEqual('no-store');
return { output: {} as any, response: {} as any };
};

const client = createCognitoIdentityClient({ region: 'us-west-1' });
client.middlewareStack.addRelativeTo(requestCacheHeaderValidator, {
relation: 'after',
toMiddleware: 'cacheControlMiddleWare',
});

await client.send(
new GetIdCommand({
IdentityPoolId: 'us-west-1:12345678-1234-1234-1234-123456789000',
})
);
expect.assertions(1);
});

test('headers should be added by middleware on GetCredentialsForIdentityCommand', async () => {
haverchuck marked this conversation as resolved.
Show resolved Hide resolved
const requestCacheHeaderValidator: BuildMiddleware<any, any> =
next => async args => {
// middleware intercept the request and return it early
const request = args.request as HttpRequest;
const { headers } = request;
expect(headers['cache-control']).toEqual('no-store');
return { output: {} as any, response: {} as any };
};

const client = createCognitoIdentityClient({ region: 'us-west-1' });
client.middlewareStack.addRelativeTo(requestCacheHeaderValidator, {
relation: 'after',
toMiddleware: 'cacheControlMiddleWare',
});
await client.send(
new GetCredentialsForIdentityCommand({
IdentityId: '1234',
Logins: {},
})
);

expect.assertions(1);
});
});

test('jitteredExponential retry happy case', async () => {
const resolveAt = 3;
let attempts = 0;
Expand Down
18 changes: 8 additions & 10 deletions packages/core/src/Credentials.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import { makeQuerablePromise } from './JS';
import { FacebookOAuth, GoogleOAuth } from './OAuthHelper';
import { jitteredExponentialRetry } from './Util';
import { ICredentials } from './types';
import { getAmplifyUserAgent } from './Platform';
import { Amplify } from './Amplify';
import {
fromCognitoIdentity,
Expand All @@ -13,13 +12,13 @@ import {
FromCognitoIdentityPoolParameters,
} from '@aws-sdk/credential-provider-cognito-identity';
import {
CognitoIdentityClient,
GetIdCommand,
GetCredentialsForIdentityCommand,
} from '@aws-sdk/client-cognito-identity';
import { CredentialProvider } from '@aws-sdk/types';
import { parseAWSExports } from './parseAWSExports';
import { Hub } from './Hub';
import { createCognitoIdentityClient } from './Util/CognitoIdentityClient';

const logger = new Logger('Credentials');

Expand Down Expand Up @@ -265,7 +264,8 @@ export class CredentialsClass {
parseAWSExports(this._config || {}).Auth
);
}
const { identityPoolId, region, mandatorySignIn, identityPoolRegion } = this._config;
const { identityPoolId, region, mandatorySignIn, identityPoolRegion } =
this._config;

if (mandatorySignIn) {
return Promise.reject(
Expand All @@ -291,9 +291,8 @@ export class CredentialsClass {

const identityId = (this._identityId = await this._getGuestIdentityId());

const cognitoClient = new CognitoIdentityClient({
const cognitoClient = createCognitoIdentityClient({
region: identityPoolRegion || region,
customUserAgent: getAmplifyUserAgent(),
});

let credentials = undefined;
Expand Down Expand Up @@ -408,9 +407,8 @@ export class CredentialsClass {
);
}

const cognitoClient = new CognitoIdentityClient({
const cognitoClient = createCognitoIdentityClient({
region: identityPoolRegion || region,
customUserAgent: getAmplifyUserAgent(),
});

let credentials = undefined;
Expand All @@ -435,7 +433,8 @@ export class CredentialsClass {
private _setCredentialsFromSession(session): Promise<ICredentials> {
logger.debug('set credentials from session');
const idToken = session.getIdToken().getJwtToken();
const { region, userPoolId, identityPoolId, identityPoolRegion } = this._config;
const { region, userPoolId, identityPoolId, identityPoolRegion } =
this._config;
if (!identityPoolId) {
logger.debug('No Cognito Federated Identity pool provided');
return Promise.reject('No Cognito Federated Identity pool provided');
Expand All @@ -450,9 +449,8 @@ export class CredentialsClass {
const logins = {};
logins[key] = idToken;

const cognitoClient = new CognitoIdentityClient({
const cognitoClient = createCognitoIdentityClient({
region: identityPoolRegion || region,
customUserAgent: getAmplifyUserAgent(),
});

/*
Expand Down
48 changes: 48 additions & 0 deletions packages/core/src/Util/CognitoIdentityClient.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import {
CognitoIdentityClient,
CognitoIdentityClientConfig,
} from '@aws-sdk/client-cognito-identity';
import { Provider } from '@aws-sdk/types';
import { getAmplifyUserAgent } from '../Platform';

/**
* Returns a CognitoIdentityClient with middleware
* @param {CognitoIdentityClientConfig} config
* @return {CognitoIdentityClient}
*/
export function createCognitoIdentityClient(
config: CognitoIdentityClientConfig
): CognitoIdentityClient {
const client = new CognitoIdentityClient({
region: config.region,
customUserAgent: getAmplifyUserAgent(),
});

client.middlewareStack.add(
(next, _) => (args: any) => {
return next(middlewareArgs(args));
},
{
step: 'build',
name: 'cacheControlMiddleWare',
}
);

return client;
}

export function middlewareArgs(args: { request: any; input: any }) {
return {
...args,
request: {
...args.request,
headers: {
...args.request.headers,
'cache-control': 'no-store',
},
},
};
}