diff --git a/packages/api/__tests__/API-test.ts b/packages/api/__tests__/API-test.ts index 044d1608ed3..4fa93305e1e 100644 --- a/packages/api/__tests__/API-test.ts +++ b/packages/api/__tests__/API-test.ts @@ -954,6 +954,100 @@ describe('API test', () => { expect(spyon).toBeCalledWith(url, init); }); + + test('happy case query with additionalHeaders', async () => { + const spyonAuth = jest + .spyOn(Credentials, 'get') + .mockImplementationOnce(() => { + return new Promise((res, rej) => { + res('cred'); + }); + }); + + const spyon = jest + .spyOn(RestClient.prototype, 'post') + .mockImplementationOnce((url, init) => { + return new Promise((res, rej) => { + res({}); + }); + }); + + const api = new API(config); + const url = 'https://appsync.amazonaws.com', + region = 'us-east-2', + apiKey = 'secret_api_key', + variables = { id: '809392da-ec91-4ef0-b219-5238a8f942b2' }; + api.configure({ + aws_appsync_graphqlEndpoint: url, + aws_appsync_region: region, + aws_appsync_authenticationType: 'API_KEY', + aws_appsync_apiKey: apiKey, + graphql_headers: async () => + Promise.resolve({ + someHeaderSetAtConfigThatWillBeOverridden: 'initialValue', + someOtherHeaderSetAtConfig: 'expectedValue', + }), + }); + const GetEvent = `query GetEvent($id: ID! $nextToken: String) { + getEvent(id: $id) { + id + name + where + when + description + comments(nextToken: $nextToken) { + items { + commentId + content + createdAt + } + } + } + }`; + + const doc = parse(GetEvent); + const query = print(doc); + + const headers = { + Authorization: null, + 'X-Api-Key': apiKey, + 'x-amz-user-agent': Constants.userAgent, + }; + + const body = { + query, + variables, + }; + + const init = { + headers, + body, + signerServiceInfo: { + service: 'appsync', + region, + }, + }; + + const additionalHeaders = { + someAddtionalHeader: 'foo', + someHeaderSetAtConfigThatWillBeOverridden: 'expectedValue', + }; + + await api.graphql( + graphqlOperation(GetEvent, variables), + additionalHeaders + ); + + expect(spyon).toBeCalledWith(url, { + ...init, + headers: { + someAddtionalHeader: 'foo', + someHeaderSetAtConfigThatWillBeOverridden: 'expectedValue', + ...init.headers, + someOtherHeaderSetAtConfig: 'expectedValue', + }, + }); + }); }); describe('configure test', () => { diff --git a/packages/api/src/API.ts b/packages/api/src/API.ts index 6f9da9843d7..bab9641de4c 100644 --- a/packages/api/src/API.ts +++ b/packages/api/src/API.ts @@ -347,9 +347,13 @@ export default class APIClass { * Executes a GraphQL operation * * @param {GraphQLOptions} GraphQL Options + * @param {object} additionalHeaders headers to merge in after any `graphql_headers` set in the config * @returns {Promise | Observable} */ - graphql({ query: paramQuery, variables = {}, authMode }: GraphQLOptions) { + graphql( + { query: paramQuery, variables = {}, authMode }: GraphQLOptions, + addtionalHeaders?: { [key: string]: string } + ) { const query = typeof paramQuery === 'string' ? parse(paramQuery) @@ -365,7 +369,7 @@ export default class APIClass { switch (operationType) { case 'query': case 'mutation': - return this._graphql({ query, variables, authMode }); + return this._graphql({ query, variables, authMode }, addtionalHeaders); case 'subscription': return this._graphqlSubscribe({ query, @@ -399,8 +403,8 @@ export default class APIClass { (customEndpointRegion ? await this._headerBasedAuth(authMode) : { Authorization: null })), - ...additionalHeaders, ...(await graphql_headers({ query, variables })), + ...additionalHeaders, ...(!customGraphqlEndpoint && { [USER_AGENT_HEADER]: Constants.userAgent, }),