diff --git a/opensearch_dashboards.json b/opensearch_dashboards.json
index f286b75a..09e7df90 100644
--- a/opensearch_dashboards.json
+++ b/opensearch_dashboards.json
@@ -16,6 +16,6 @@
"dataSourceManagement"
],
"configPath": [
- "assistant"
+ "assistant"
]
}
\ No newline at end of file
diff --git a/public/components/__tests__/chat_window_header_title.test.tsx b/public/components/__tests__/chat_window_header_title.test.tsx
index 025dd62c..5c87cd55 100644
--- a/public/components/__tests__/chat_window_header_title.test.tsx
+++ b/public/components/__tests__/chat_window_header_title.test.tsx
@@ -17,6 +17,7 @@ import * as coreContextExports from '../../contexts/core_context';
import { IMessage } from '../../../common/types/chat_saved_object_attributes';
import { ChatWindowHeaderTitle } from '../chat_window_header_title';
+import { DataSourceServiceMock } from '../../services/data_source_service.mock';
const setup = ({
messages = [],
@@ -37,6 +38,7 @@ const setup = ({
}),
reload: jest.fn(),
},
+ dataSource: new DataSourceServiceMock(),
},
};
useCoreMock.services.http.put.mockImplementation(() => Promise.resolve());
diff --git a/public/components/__tests__/edit_conversation_name_modal.test.tsx b/public/components/__tests__/edit_conversation_name_modal.test.tsx
index db594a02..c650fd60 100644
--- a/public/components/__tests__/edit_conversation_name_modal.test.tsx
+++ b/public/components/__tests__/edit_conversation_name_modal.test.tsx
@@ -15,10 +15,11 @@ import {
EditConversationNameModalProps,
} from '../edit_conversation_name_modal';
import { HttpHandler } from '../../../../../src/core/public';
+import { DataSourceServiceMock } from '../../services/data_source_service.mock';
const setup = ({ onClose, defaultTitle, conversationId }: EditConversationNameModalProps) => {
const useCoreMock = {
- services: coreMock.createStart(),
+ services: { ...coreMock.createStart(), dataSource: new DataSourceServiceMock() },
};
jest.spyOn(coreContextExports, 'useCore').mockReturnValue(useCoreMock);
@@ -156,7 +157,10 @@ describe('', () => {
expect(useCoreMock.services.http.put).not.toHaveBeenCalled();
fireEvent.click(renderResult.getByTestId('confirmModalConfirmButton'));
- expect(useCoreMock.services.http.put).toHaveBeenCalled();
+
+ await waitFor(() => {
+ expect(useCoreMock.services.http.put).toHaveBeenCalled();
+ });
fireEvent.click(renderResult.getByTestId('confirmModalCancelButton'));
diff --git a/public/components/feedback_modal.tsx b/public/components/feedback_modal.tsx
deleted file mode 100644
index e9889da3..00000000
--- a/public/components/feedback_modal.tsx
+++ /dev/null
@@ -1,309 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-import {
- EuiButton,
- EuiButtonEmpty,
- EuiForm,
- EuiFormRow,
- EuiModal,
- EuiModalBody,
- EuiModalFooter,
- EuiModalHeader,
- EuiModalHeaderTitle,
- EuiRadioGroup,
- EuiTextArea,
-} from '@elastic/eui';
-import React, { useState } from 'react';
-import { HttpStart } from '../../../../src/core/public';
-import { ASSISTANT_API } from '../../common/constants/llm';
-import { getCoreStart } from '../plugin';
-
-export interface LabelData {
- formHeader: string;
- inputPlaceholder: string;
- outputPlaceholder: string;
-}
-
-export interface FeedbackFormData {
- input: string;
- output: string;
- correct: boolean | undefined;
- expectedOutput: string;
- comment: string;
-}
-
-interface FeedbackMetaData {
- type: 'event_analytics' | 'chat' | 'ppl_submit';
- conversationId?: string;
- interactionId?: string;
- error?: boolean;
- selectedIndex?: string;
-}
-
-interface FeedbackModelProps {
- input?: string;
- output?: string;
- metadata: FeedbackMetaData;
- onClose: () => void;
-}
-
-export const FeedbackModal: React.FC = (props) => {
- const [formData, setFormData] = useState({
- input: props.input ?? '',
- output: props.output ?? '',
- correct: undefined,
- expectedOutput: '',
- comment: '',
- });
- return (
-
-
-
- );
-};
-
-interface FeedbackModalContentProps {
- formData: FeedbackFormData;
- setFormData: React.Dispatch>;
- metadata: FeedbackMetaData;
- displayLabels?: Partial> & Partial;
- onClose: () => void;
-}
-
-export const FeedbackModalContent: React.FC = (props) => {
- const core = getCoreStart();
- const labels: NonNullable> = Object.assign(
- {
- formHeader: 'Olly Skills Feedback',
- inputPlaceholder: 'Your input question',
- input: 'Input question',
- outputPlaceholder: 'The LLM response',
- output: 'Output',
- correct: 'Does the output match your expectations?',
- expectedOutput: 'Expected output',
- comment: 'Comment',
- },
- props.displayLabels
- );
- const { loading, submitFeedback } = useSubmitFeedback(props.formData, props.metadata, core.http);
- const [formErrors, setFormErrors] = useState<
- Partial<{ [x in keyof FeedbackFormData]: string[] }>
- >({
- input: [],
- output: [],
- expectedOutput: [],
- });
-
- const hasError = (key?: keyof FeedbackFormData) => {
- if (!key) return Object.values(formErrors).some((e) => !!e.length);
- return !!formErrors[key]?.length;
- };
-
- const onSubmit = async (event: React.FormEvent) => {
- event.preventDefault();
- const errors = {
- input: validator
- .input(props.formData.input)
- .concat(await validator.validateQuery(props.formData.input, props.metadata.type)),
- output: validator.output(props.formData.output),
- correct: validator.correct(props.formData.correct),
- expectedOutput: validator.expectedOutput(
- props.formData.expectedOutput,
- props.formData.correct === false
- ),
- };
- if (Object.values(errors).some((e) => !!e.length)) {
- setFormErrors(errors);
- return;
- }
-
- try {
- await submitFeedback();
- props.setFormData({
- input: '',
- output: '',
- correct: undefined,
- expectedOutput: '',
- comment: '',
- });
- core.notifications.toasts.addSuccess('Thanks for your feedback!');
- props.onClose();
- } catch (e) {
- core.notifications.toasts.addError(e, { title: 'Failed to submit feedback' });
- }
- };
-
- return (
- <>
-
- {labels.formHeader}
-
-
-
-
-
- props.setFormData({ ...props.formData, input: e.target.value })}
- onBlur={(e) => {
- setFormErrors({ ...formErrors, input: validator.input(e.target.value) });
- }}
- isInvalid={hasError('input')}
- />
-
-
- props.setFormData({ ...props.formData, output: e.target.value })}
- onBlur={(e) => {
- setFormErrors({ ...formErrors, output: validator.output(e.target.value) });
- }}
- isInvalid={hasError('output')}
- />
-
- {props.metadata.type !== 'ppl_submit' && (
-
- {
- props.setFormData({ ...props.formData, correct: id === 'yes' });
- setFormErrors({ ...formErrors, expectedOutput: [] });
- }}
- onBlur={() => setFormErrors({ ...formErrors, correct: [] })}
- />
-
- )}
- {props.formData.correct === false && (
-
-
- props.setFormData({ ...props.formData, expectedOutput: e.target.value })
- }
- onBlur={(e) => {
- setFormErrors({
- ...formErrors,
- expectedOutput: validator.expectedOutput(
- e.target.value,
- props.formData.correct === false
- ),
- });
- }}
- isInvalid={hasError('expectedOutput')}
- />
-
- )}
-
- props.setFormData({ ...props.formData, comment: e.target.value })}
- />
-
-
-
-
-
- Cancel
-
- Send
-
-
- >
- );
-};
-
-const useSubmitFeedback = (data: FeedbackFormData, metadata: FeedbackMetaData, http: HttpStart) => {
- const [loading, setLoading] = useState(false);
- return {
- loading,
- submitFeedback: async () => {
- setLoading(true);
- const auth = await http
- .get<{ data: { user_name: string; user_requested_tenant: string; roles: string[] } }>(
- '/api/v1/configuration/account'
- )
- .then((res) => ({ user: res.data.user_name, tenant: res.data.user_requested_tenant }));
-
- return http
- .post(ASSISTANT_API.FEEDBACK, {
- body: JSON.stringify({ metadata: { ...metadata, ...auth }, ...data }),
- })
- .finally(() => setLoading(false));
- },
- };
-};
-
-const validatePPLQuery = async (logsQuery: string, feedBackType: FeedbackMetaData['type']) => {
- return [];
- // TODO remove
- // let responseMessage: [] | string[] = [];
- // const errorMessage = [' Invalid PPL Query, please re-check the ppl syntax'];
-
- // if (feedBackType === 'ppl_submit') {
- // const pplService = getPPLService();
- // await pplService
- // .fetch({ query: logsQuery, format: 'jdbc' })
- // .then((res) => {
- // if (res === undefined) responseMessage = errorMessage;
- // })
- // .catch((error: Error) => {
- // responseMessage = errorMessage;
- // });
- // }
- // return responseMessage;
-};
-
-const validator = {
- input: (text: string) => (text.trim().length === 0 ? ['Input is required'] : []),
- output: (text: string) => (text.trim().length === 0 ? ['Output is required'] : []),
- correct: (correct: boolean | undefined) =>
- correct === undefined ? ['Correctness is required'] : [],
- expectedOutput: (text: string, required: boolean) =>
- required && text.trim().length === 0 ? ['expectedOutput is required'] : [],
- validateQuery: async (logsQuery: string, feedBackType: FeedbackMetaData['type']) =>
- await validatePPLQuery(logsQuery, feedBackType),
-};
diff --git a/public/contexts/__mocks__/core_context.tsx b/public/contexts/__mocks__/core_context.tsx
index 5c1dab99..2d92d77f 100644
--- a/public/contexts/__mocks__/core_context.tsx
+++ b/public/contexts/__mocks__/core_context.tsx
@@ -5,6 +5,7 @@
import { BehaviorSubject } from 'rxjs';
import { coreMock } from '../../../../../src/core/public/mocks';
+import { DataSourceServiceMock } from '../../services/data_source_service.mock';
export const useCore = jest.fn(() => {
const useCoreMock = {
@@ -24,7 +25,7 @@ export const useCore = jest.fn(() => {
load: jest.fn(),
},
conversationLoad: {},
- dataSource: {},
+ dataSource: new DataSourceServiceMock(),
},
};
useCoreMock.services.http.delete.mockReturnValue(Promise.resolve());
diff --git a/public/contexts/core_context.tsx b/public/contexts/core_context.tsx
index b4143154..6db5dd2e 100644
--- a/public/contexts/core_context.tsx
+++ b/public/contexts/core_context.tsx
@@ -15,6 +15,7 @@ export interface AssistantServices extends Required;
@@ -76,6 +77,10 @@ describe('useDeleteConversation', () => {
deleteConversationPromise = result.current.deleteConversation('foo');
});
+ await waitFor(() => {
+ expect(useCoreMocked.mock.results[0].value.services.http.delete).toHaveBeenCalled();
+ });
+
let deleteConversationError;
await act(async () => {
result.current.abort();
@@ -160,6 +165,10 @@ describe('usePatchConversation', () => {
patchConversationPromise = result.current.patchConversation('foo', 'new-title');
});
+ await waitFor(() => {
+ expect(useCoreMocked.mock.results[0].value.services.http.put).toHaveBeenCalled();
+ });
+
let patchConversationError;
await act(async () => {
result.current.abort();
diff --git a/public/hooks/use_chat_actions.test.tsx b/public/hooks/use_chat_actions.test.tsx
index aaa6c77a..6418ae31 100644
--- a/public/hooks/use_chat_actions.test.tsx
+++ b/public/hooks/use_chat_actions.test.tsx
@@ -13,6 +13,7 @@ import { ConversationLoadService } from '../services/conversation_load_service';
import * as chatStateHookExports from './use_chat_state';
import { ASSISTANT_API } from '../../common/constants/llm';
import { IMessage } from 'common/types/chat_saved_object_attributes';
+import { DataSourceServiceMock } from '../services/data_source_service.mock';
jest.mock('../services/conversations_service', () => {
return {
@@ -60,6 +61,7 @@ describe('useChatActions hook', () => {
const setSelectedTabIdMock = jest.fn();
const pplVisualizationRenderMock = jest.fn();
const setInteractionIdMock = jest.fn();
+ const dataSourceServiceMock = new DataSourceServiceMock();
const chatContextMock: chatContextHookExports.IChatContext = {
selectedTabId: 'chat',
@@ -88,8 +90,9 @@ describe('useChatActions hook', () => {
jest.spyOn(coreHookExports, 'useCore').mockReturnValue({
services: {
http: httpMock,
- conversations: new ConversationsService(httpMock),
- conversationLoad: new ConversationLoadService(httpMock),
+ conversations: new ConversationsService(httpMock, dataSourceServiceMock),
+ conversationLoad: new ConversationLoadService(httpMock, dataSourceServiceMock),
+ dataSource: dataSourceServiceMock,
},
});
@@ -125,6 +128,7 @@ describe('useChatActions hook', () => {
messages: [SEND_MESSAGE_RESPONSE.messages[0]],
input: INPUT_MESSAGE,
}),
+ query: await dataSourceServiceMock.getDataSourceQuery(),
});
// it should send dispatch `receive` action to remove the message without messageId
@@ -199,6 +203,7 @@ describe('useChatActions hook', () => {
messages: [],
input: { type: 'input', content: 'message that send as input', contentType: 'text' },
}),
+ query: await dataSourceServiceMock.getDataSourceQuery(),
});
});
@@ -261,6 +266,7 @@ describe('useChatActions hook', () => {
expect(chatStateDispatchMock).toHaveBeenCalledWith({ type: 'abort' });
expect(httpMock.post).toHaveBeenCalledWith(ASSISTANT_API.ABORT_AGENT_EXECUTION, {
body: JSON.stringify({ conversationId: 'conversation_id_to_abort' }),
+ query: await dataSourceServiceMock.getDataSourceQuery(),
});
});
@@ -288,6 +294,7 @@ describe('useChatActions hook', () => {
conversationId: 'conversation_id_mock',
interactionId: 'interaction_id_mock',
}),
+ query: await dataSourceServiceMock.getDataSourceQuery(),
});
expect(chatStateDispatchMock).toHaveBeenCalledWith(
expect.objectContaining({ type: 'receive', payload: { messages: [], interactions: [] } })
@@ -323,6 +330,7 @@ describe('useChatActions hook', () => {
conversationId: 'conversation_id_mock',
interactionId: 'interaction_id_mock',
}),
+ query: await dataSourceServiceMock.getDataSourceQuery(),
});
expect(chatStateDispatchMock).not.toHaveBeenCalledWith(
expect.objectContaining({ type: 'receive' })
diff --git a/public/hooks/use_chat_actions.tsx b/public/hooks/use_chat_actions.tsx
index a8f6358b..962420f2 100644
--- a/public/hooks/use_chat_actions.tsx
+++ b/public/hooks/use_chat_actions.tsx
@@ -35,6 +35,7 @@ export const useChatActions = (): AssistantActions => {
...(!chatContext.conversationId && { messages: chatState.messages }), // include all previous messages for new chats
input,
}),
+ query: await core.services.dataSource.getDataSourceQuery(),
});
if (abortController.signal.aborted) return;
// Refresh history list after new conversation created if new conversation saved and history list page visible
@@ -162,6 +163,7 @@ export const useChatActions = (): AssistantActions => {
// abort agent execution
await core.services.http.post(`${ASSISTANT_API.ABORT_AGENT_EXECUTION}`, {
body: JSON.stringify({ conversationId }),
+ query: await core.services.dataSource.getDataSourceQuery(),
});
}
};
@@ -178,6 +180,7 @@ export const useChatActions = (): AssistantActions => {
conversationId: chatContext.conversationId,
interactionId,
}),
+ query: await core.services.dataSource.getDataSourceQuery(),
});
if (abortController.signal.aborted) {
diff --git a/public/hooks/use_conversations.ts b/public/hooks/use_conversations.ts
index 89e1a058..2828a4af 100644
--- a/public/hooks/use_conversations.ts
+++ b/public/hooks/use_conversations.ts
@@ -14,12 +14,13 @@ export const useDeleteConversation = () => {
const abortControllerRef = useRef();
const deleteConversation = useCallback(
- (conversationId: string) => {
+ async (conversationId: string) => {
abortControllerRef.current = new AbortController();
dispatch({ type: 'request' });
return core.services.http
.delete(`${ASSISTANT_API.CONVERSATION}/${conversationId}`, {
signal: abortControllerRef.current.signal,
+ query: await core.services.dataSource.getDataSourceQuery(),
})
.then((payload) => {
dispatch({ type: 'success', payload });
@@ -52,7 +53,7 @@ export const usePatchConversation = () => {
const abortControllerRef = useRef();
const patchConversation = useCallback(
- (conversationId: string, title: string) => {
+ async (conversationId: string, title: string) => {
abortControllerRef.current = new AbortController();
dispatch({ type: 'request' });
return core.services.http
@@ -60,6 +61,7 @@ export const usePatchConversation = () => {
body: JSON.stringify({
title,
}),
+ query: await core.services.dataSource.getDataSourceQuery(),
signal: abortControllerRef.current.signal,
})
.then((payload) => dispatch({ type: 'success', payload }))
diff --git a/public/hooks/use_feed_back.test.tsx b/public/hooks/use_feed_back.test.tsx
index 8653c480..74de2031 100644
--- a/public/hooks/use_feed_back.test.tsx
+++ b/public/hooks/use_feed_back.test.tsx
@@ -12,11 +12,12 @@ import { httpServiceMock } from '../../../../src/core/public/mocks';
import * as chatContextHookExports from '../contexts/chat_context';
import { Interaction, IOutput, IMessage } from '../../common/types/chat_saved_object_attributes';
import { ASSISTANT_API } from '../../common/constants/llm';
+import { DataSourceServiceMock } from '../services/data_source_service.mock';
describe('useFeedback hook', () => {
const httpMock = httpServiceMock.createStartContract();
const chatStateDispatchMock = jest.fn();
-
+ const dataSourceMock = new DataSourceServiceMock();
const chatContextMock = {
rootAgentId: 'root_agent_id_mock',
selectedTabId: 'chat',
@@ -29,6 +30,7 @@ describe('useFeedback hook', () => {
jest.spyOn(coreHookExports, 'useCore').mockReturnValue({
services: {
http: httpMock,
+ dataSource: dataSourceMock,
},
});
jest.spyOn(chatContextHookExports, 'useChatContext').mockReturnValue(chatContextMock);
@@ -82,6 +84,7 @@ describe('useFeedback hook', () => {
body: JSON.stringify({
satisfaction: true,
}),
+ query: await dataSourceMock.getDataSourceQuery(),
}
);
expect(result.current.feedbackResult).toBe(true);
@@ -116,6 +119,7 @@ describe('useFeedback hook', () => {
body: JSON.stringify({
satisfaction: true,
}),
+ query: await dataSourceMock.getDataSourceQuery(),
}
);
expect(result.current.feedbackResult).toBe(undefined);
diff --git a/public/hooks/use_feed_back.tsx b/public/hooks/use_feed_back.tsx
index 50e56380..2222b301 100644
--- a/public/hooks/use_feed_back.tsx
+++ b/public/hooks/use_feed_back.tsx
@@ -38,6 +38,7 @@ export const useFeedback = (interaction?: Interaction | null) => {
try {
await core.services.http.put(`${ASSISTANT_API.FEEDBACK}/${message.interactionId}`, {
body: JSON.stringify(body),
+ query: await core.services.dataSource.getDataSourceQuery(),
});
setFeedbackResult(correct);
} catch (error) {
diff --git a/public/hooks/use_fetch_agentframework_traces.test.ts b/public/hooks/use_fetch_agentframework_traces.test.ts
index 4ce4c2a2..df81766c 100644
--- a/public/hooks/use_fetch_agentframework_traces.test.ts
+++ b/public/hooks/use_fetch_agentframework_traces.test.ts
@@ -9,11 +9,17 @@ import { createOpenSearchDashboardsReactContext } from '../../../../src/plugins/
import { coreMock } from '../../../../src/core/public/mocks';
import { HttpHandler } from '../../../../src/core/public';
import { AbortError } from '../../../../src/plugins/data/common';
+import { DataSourceServiceMock } from '../services/data_source_service.mock';
+import { waitFor } from '@testing-library/dom';
describe('useFetchAgentFrameworkTraces hook', () => {
const interactionId = 'foo';
const services = coreMock.createStart();
- const { Provider } = createOpenSearchDashboardsReactContext(services);
+ const mockServices = {
+ ...services,
+ dataSource: new DataSourceServiceMock(),
+ };
+ const { Provider } = createOpenSearchDashboardsReactContext(mockServices);
const wrapper = { wrapper: Provider };
it('return undefined when interaction id is not specfied', () => {
@@ -106,14 +112,16 @@ describe('useFetchAgentFrameworkTraces hook', () => {
unmount();
});
- expect(services.http.get).toHaveBeenCalledWith(
- `/api/assistant/trace/${interactionId}`,
- expect.objectContaining({
- signal: expect.any(Object),
- })
- );
+ await waitFor(() => {
+ expect(services.http.get).toHaveBeenCalledWith(
+ `/api/assistant/trace/${interactionId}`,
+ expect.objectContaining({
+ signal: expect.any(Object),
+ })
+ );
- // expect the mock to be called
- expect(abortFn).toBeCalledTimes(1);
+ // expect the mock to be called
+ expect(abortFn).toBeCalledTimes(1);
+ });
});
});
diff --git a/public/hooks/use_fetch_agentframework_traces.ts b/public/hooks/use_fetch_agentframework_traces.ts
index 1edece5f..43925855 100644
--- a/public/hooks/use_fetch_agentframework_traces.ts
+++ b/public/hooks/use_fetch_agentframework_traces.ts
@@ -22,20 +22,23 @@ export const useFetchAgentFrameworkTraces = (interactionId: string) => {
return;
}
- core.services.http
- .get(`${ASSISTANT_API.TRACE}/${interactionId}`, {
- signal: abortController.signal,
- })
- .then((payload) =>
- dispatch({
- type: 'success',
- payload,
+ core.services.dataSource.getDataSourceQuery().then((query) => {
+ core.services.http
+ .get(`${ASSISTANT_API.TRACE}/${interactionId}`, {
+ signal: abortController.signal,
+ query,
})
- )
- .catch((error) => {
- if (error.name === 'AbortError') return;
- dispatch({ type: 'failure', error });
- });
+ .then((payload) =>
+ dispatch({
+ type: 'success',
+ payload,
+ })
+ )
+ .catch((error) => {
+ if (error.name === 'AbortError') return;
+ dispatch({ type: 'failure', error });
+ });
+ });
return () => abortController.abort();
}, [core.services.http, interactionId]);
diff --git a/public/plugin.tsx b/public/plugin.tsx
index 76912891..a42449e5 100644
--- a/public/plugin.tsx
+++ b/public/plugin.tsx
@@ -111,8 +111,8 @@ export class AssistantPlugin
...coreStart,
setupDeps,
startDeps,
- conversationLoad: new ConversationLoadService(coreStart.http),
- conversations: new ConversationsService(coreStart.http),
+ conversationLoad: new ConversationLoadService(coreStart.http, this.dataSourceService),
+ conversations: new ConversationsService(coreStart.http, this.dataSourceService),
dataSource: this.dataSourceService,
});
const account = await getAccount();
diff --git a/public/services/__tests__/conversation_load_service.test.ts b/public/services/__tests__/conversation_load_service.test.ts
index 81c2cf37..acf00ceb 100644
--- a/public/services/__tests__/conversation_load_service.test.ts
+++ b/public/services/__tests__/conversation_load_service.test.ts
@@ -3,13 +3,15 @@
* SPDX-License-Identifier: Apache-2.0
*/
+import { waitFor } from '@testing-library/dom';
import { HttpHandler } from '../../../../../src/core/public';
import { httpServiceMock } from '../../../../../src/core/public/mocks';
import { ConversationLoadService } from '../conversation_load_service';
+import { DataSourceServiceMock } from '../data_source_service.mock';
const setup = () => {
const http = httpServiceMock.createSetupContract();
- const conversationLoad = new ConversationLoadService(http);
+ const conversationLoad = new ConversationLoadService(http, new DataSourceServiceMock());
return {
conversationLoad,
@@ -18,17 +20,19 @@ const setup = () => {
};
describe('ConversationLoadService', () => {
- it('should emit loading status and call get with specific conversation id', () => {
+ it('should emit loading status and call get with specific conversation id', async () => {
const { conversationLoad, http } = setup();
conversationLoad.load('foo');
- expect(http.get).toHaveBeenCalledWith(
- '/api/assistant/conversation/foo',
- expect.objectContaining({
- signal: expect.anything(),
- })
- );
expect(conversationLoad.status$.getValue()).toBe('loading');
+ await waitFor(() => {
+ expect(http.get).toHaveBeenCalledWith(
+ '/api/assistant/conversation/foo',
+ expect.objectContaining({
+ signal: expect.anything(),
+ })
+ );
+ });
});
it('should resolved with response data and "idle" status', async () => {
@@ -53,6 +57,9 @@ describe('ConversationLoadService', () => {
});
}) as HttpHandler);
const loadResult = conversationLoad.load('foo');
+ await waitFor(() => {
+ expect(http.get).toHaveBeenCalled();
+ });
conversationLoad.abortController?.abort();
await loadResult;
diff --git a/public/services/__tests__/conversations_service.test.ts b/public/services/__tests__/conversations_service.test.ts
index 2ed8c75f..198a9ed5 100644
--- a/public/services/__tests__/conversations_service.test.ts
+++ b/public/services/__tests__/conversations_service.test.ts
@@ -3,35 +3,41 @@
* SPDX-License-Identifier: Apache-2.0
*/
+import { waitFor } from '@testing-library/dom';
import { HttpHandler } from '../../../../../src/core/public';
import { httpServiceMock } from '../../../../../src/core/public/mocks';
import { ConversationsService } from '../conversations_service';
+import { DataSourceServiceMock } from '../data_source_service.mock';
const setup = () => {
const http = httpServiceMock.createSetupContract();
- const conversations = new ConversationsService(http);
+ const dataSourceServiceMock = new DataSourceServiceMock();
+ const conversations = new ConversationsService(http, dataSourceServiceMock);
return {
conversations,
http,
+ dataSource: dataSourceServiceMock,
};
};
describe('ConversationsService', () => {
- it('should emit loading status and call get with conversations API path', () => {
+ it('should emit loading status and call get with conversations API path', async () => {
const { conversations, http } = setup();
conversations.load();
- expect(http.get).toHaveBeenCalledWith(
- '/api/assistant/conversations',
- expect.objectContaining({
- signal: expect.anything(),
- })
- );
expect(conversations.status$.getValue()).toBe('loading');
+ await waitFor(() => {
+ expect(http.get).toHaveBeenCalledWith(
+ '/api/assistant/conversations',
+ expect.objectContaining({
+ signal: expect.anything(),
+ })
+ );
+ });
});
- it('should update options property and call get with passed query', () => {
+ it('should update options property and call get with passed query', async () => {
const { conversations, http } = setup();
expect(conversations.options).toBeFalsy();
@@ -43,16 +49,19 @@ describe('ConversationsService', () => {
page: 1,
perPage: 10,
});
- expect(http.get).toHaveBeenCalledWith(
- '/api/assistant/conversations',
- expect.objectContaining({
- query: {
- page: 1,
- perPage: 10,
- },
- signal: expect.anything(),
- })
- );
+ await waitFor(() => {
+ expect(http.get).toHaveBeenCalledWith(
+ '/api/assistant/conversations',
+ expect.objectContaining({
+ query: {
+ page: 1,
+ perPage: 10,
+ dataSourceId: '',
+ },
+ signal: expect.anything(),
+ })
+ );
+ });
});
it('should emit latest conversations and "idle" status', async () => {
@@ -78,17 +87,20 @@ describe('ConversationsService', () => {
http.get.mockClear();
conversations.reload();
- expect(http.get).toHaveBeenCalledTimes(1);
- expect(http.get).toHaveBeenCalledWith(
- '/api/assistant/conversations',
- expect.objectContaining({
- query: {
- page: 1,
- perPage: 10,
- },
- signal: expect.anything(),
- })
- );
+ await waitFor(() => {
+ expect(http.get).toHaveBeenCalledTimes(1);
+ expect(http.get).toHaveBeenCalledWith(
+ '/api/assistant/conversations',
+ expect.objectContaining({
+ query: {
+ page: 1,
+ perPage: 10,
+ dataSourceId: '',
+ },
+ signal: expect.anything(),
+ })
+ );
+ });
});
it('should emit error after loading aborted', async () => {
@@ -104,6 +116,9 @@ describe('ConversationsService', () => {
});
}) as HttpHandler);
const loadResult = conversations.load();
+ await waitFor(() => {
+ expect(http.get).toHaveBeenCalled();
+ });
conversations.abortController?.abort();
await loadResult;
diff --git a/public/services/conversation_load_service.ts b/public/services/conversation_load_service.ts
index 44b2e970..866ad6d6 100644
--- a/public/services/conversation_load_service.ts
+++ b/public/services/conversation_load_service.ts
@@ -7,6 +7,7 @@ import { BehaviorSubject } from 'rxjs';
import { HttpStart } from '../../../../src/core/public';
import { IConversation } from '../../common/types/chat_saved_object_attributes';
import { ASSISTANT_API } from '../../common/constants/llm';
+import { DataSourceService } from './data_source_service';
export class ConversationLoadService {
status$: BehaviorSubject<
@@ -14,7 +15,7 @@ export class ConversationLoadService {
> = new BehaviorSubject<'idle' | 'loading' | { status: 'error'; error: Error }>('idle');
abortController?: AbortController;
- constructor(private _http: HttpStart) {}
+ constructor(private _http: HttpStart, private _dataSource: DataSourceService) {}
load = async (conversationId: string) => {
this.abortController?.abort();
@@ -25,6 +26,7 @@ export class ConversationLoadService {
`${ASSISTANT_API.CONVERSATION}/${conversationId}`,
{
signal: this.abortController.signal,
+ query: await this._dataSource.getDataSourceQuery(),
}
);
this.status$.next('idle');
diff --git a/public/services/conversations_service.ts b/public/services/conversations_service.ts
index 4f95794a..ab188070 100644
--- a/public/services/conversations_service.ts
+++ b/public/services/conversations_service.ts
@@ -7,6 +7,7 @@ import { BehaviorSubject } from 'rxjs';
import { HttpFetchQuery, HttpStart, SavedObjectsFindOptions } from '../../../../src/core/public';
import { IConversationFindResponse } from '../../common/types/chat_saved_object_attributes';
import { ASSISTANT_API } from '../../common/constants/llm';
+import { DataSourceService } from './data_source_service';
export class ConversationsService {
conversations$: BehaviorSubject = new BehaviorSubject(
@@ -21,7 +22,7 @@ export class ConversationsService {
>;
abortController?: AbortController;
- constructor(private _http: HttpStart) {}
+ constructor(private _http: HttpStart, private _dataSource: DataSourceService) {}
public get options() {
return this._options;
@@ -37,7 +38,10 @@ export class ConversationsService {
this.status$.next('loading');
this.conversations$.next(
await this._http.get(ASSISTANT_API.CONVERSATIONS, {
- query: this._options as HttpFetchQuery,
+ query: {
+ ...this._options,
+ ...(await this._dataSource.getDataSourceQuery()),
+ } as HttpFetchQuery,
signal: this.abortController.signal,
})
);
diff --git a/public/services/data_source_service.mock.ts b/public/services/data_source_service.mock.ts
new file mode 100644
index 00000000..40fc2018
--- /dev/null
+++ b/public/services/data_source_service.mock.ts
@@ -0,0 +1,15 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+export class DataSourceServiceMock {
+ constructor() {}
+
+ getDataSourceQuery() {
+ return new Promise((resolve) => {
+ resolve({ dataSourceId: '' });
+ });
+ // return { dataSourceId: '' };
+ }
+}
diff --git a/public/tabs/history/__tests__/chat_history_page.test.tsx b/public/tabs/history/__tests__/chat_history_page.test.tsx
index eb948ceb..944fedf6 100644
--- a/public/tabs/history/__tests__/chat_history_page.test.tsx
+++ b/public/tabs/history/__tests__/chat_history_page.test.tsx
@@ -14,6 +14,7 @@ import * as useChatStateExports from '../../../hooks/use_chat_state';
import * as chatContextExports from '../../../contexts/chat_context';
import * as coreContextExports from '../../../contexts/core_context';
import { ConversationsService } from '../../../services/conversations_service';
+import { DataSourceServiceMock } from '../../../services/data_source_service.mock';
import { ChatHistoryPage } from '../chat_history_page';
@@ -38,12 +39,14 @@ const setup = ({
http?: HttpStart;
chatContext?: { flyoutFullScreen?: boolean };
} = {}) => {
+ const dataSourceMock = new DataSourceServiceMock();
const useCoreMock = {
services: {
...coreMock.createStart(),
http,
- conversations: new ConversationsService(http),
+ conversations: new ConversationsService(http, dataSourceMock),
conversationLoad: {},
+ dataSource: dataSourceMock,
},
};
const useChatStateMock = {
@@ -91,10 +94,11 @@ describe('', () => {
expect(useChatStateMock.chatStateDispatch).not.toHaveBeenCalled();
fireEvent.click(renderResult.getByTestId('confirmModalConfirmButton'));
-
- expect(useChatContextMock.setConversationId).toHaveBeenLastCalledWith(undefined);
- expect(useChatContextMock.setTitle).toHaveBeenLastCalledWith(undefined);
- expect(useChatStateMock.chatStateDispatch).toHaveBeenLastCalledWith({ type: 'reset' });
+ await waitFor(() => {
+ expect(useChatContextMock.setConversationId).toHaveBeenLastCalledWith(undefined);
+ expect(useChatContextMock.setTitle).toHaveBeenLastCalledWith(undefined);
+ expect(useChatStateMock.chatStateDispatch).toHaveBeenLastCalledWith({ type: 'reset' });
+ });
});
it('should render empty screen', async () => {
diff --git a/public/tabs/history/__tests__/chat_history_search_list.test.tsx b/public/tabs/history/__tests__/chat_history_search_list.test.tsx
index d84baf86..1792136e 100644
--- a/public/tabs/history/__tests__/chat_history_search_list.test.tsx
+++ b/public/tabs/history/__tests__/chat_history_search_list.test.tsx
@@ -12,6 +12,7 @@ import * as chatContextExports from '../../../contexts/chat_context';
import * as coreContextExports from '../../../contexts/core_context';
import { ChatHistorySearchList, ChatHistorySearchListProps } from '../chat_history_search_list';
+import { DataSourceServiceMock } from '../../../services/data_source_service.mock';
const setup = ({
loading = false,
@@ -26,8 +27,9 @@ const setup = ({
conversationId: '1',
setTitle: jest.fn(),
};
+ const dataSourceServiceMock = new DataSourceServiceMock();
const useCoreMock = {
- services: coreMock.createStart(),
+ services: { ...coreMock.createStart(), dataSource: dataSourceServiceMock },
};
useCoreMock.services.http.put.mockImplementation(() => Promise.resolve());
useCoreMock.services.http.delete.mockImplementation(() => Promise.resolve());
diff --git a/public/tabs/history/__tests__/delete_conversation_confirm_modal.test.tsx b/public/tabs/history/__tests__/delete_conversation_confirm_modal.test.tsx
index 764bd8a4..64530e70 100644
--- a/public/tabs/history/__tests__/delete_conversation_confirm_modal.test.tsx
+++ b/public/tabs/history/__tests__/delete_conversation_confirm_modal.test.tsx
@@ -15,10 +15,13 @@ import {
DeleteConversationConfirmModalProps,
} from '../delete_conversation_confirm_modal';
import { HttpHandler } from '../../../../../../src/core/public';
+import { DataSourceServiceMock } from '../../../services/data_source_service.mock';
const setup = ({ onClose, conversationId }: DeleteConversationConfirmModalProps) => {
+ const dataSourceServiceMock = new DataSourceServiceMock();
+
const useCoreMock = {
- services: coreMock.createStart(),
+ services: { ...coreMock.createStart(), dataSource: dataSourceServiceMock },
};
jest.spyOn(coreContextExports, 'useCore').mockReturnValue(useCoreMock);
@@ -124,7 +127,9 @@ describe('', () => {
expect(useCoreMock.services.http.delete).not.toHaveBeenCalled();
fireEvent.click(renderResult.getByTestId('confirmModalConfirmButton'));
- expect(useCoreMock.services.http.delete).toHaveBeenCalled();
+ await waitFor(() => {
+ expect(useCoreMock.services.http.delete).toHaveBeenCalled();
+ });
fireEvent.click(renderResult.getByTestId('confirmModalCancelButton'));
diff --git a/server/routes/chat_routes.test.ts b/server/routes/chat_routes.test.ts
index 41f9c195..cbe03ad9 100644
--- a/server/routes/chat_routes.test.ts
+++ b/server/routes/chat_routes.test.ts
@@ -14,9 +14,24 @@ import { mockOllyChatService } from '../services/chat/olly_chat_service.mock';
import { loggerMock } from '../../../../src/core/server/logging/logger.mock';
import { registerChatRoutes } from './chat_routes';
import { ASSISTANT_API } from '../../common/constants/llm';
+import { getOpenSearchClientTransport } from '../utils/get_opensearch_client_transport';
-const mockedLogger = loggerMock.create();
+jest.mock('../utils/get_opensearch_client_transport');
+
+beforeEach(() => {
+ (getOpenSearchClientTransport as jest.Mock).mockImplementation(({ dataSourceId }) => {
+ if (dataSourceId) {
+ return 'dataSource-client';
+ } else {
+ return 'client';
+ }
+ });
+});
+afterEach(() => {
+ (getOpenSearchClientTransport as jest.Mock).mockClear();
+});
+const mockedLogger = loggerMock.create();
const router = new Router(
'',
mockedLogger,
@@ -30,38 +45,90 @@ registerChatRoutes(router, {
messageParsers: [],
});
-const triggerDeleteConversation = (conversationId: string) =>
+const triggerDeleteConversation = (conversationId: string, dataSourceId?: string) =>
triggerHandler(router, {
method: 'delete',
path: `${ASSISTANT_API.CONVERSATION}/{conversationId}`,
- req: httpServerMock.createRawRequest({ params: { conversationId } }),
+ req: httpServerMock.createRawRequest({
+ params: { conversationId },
+ ...(dataSourceId
+ ? {
+ query: {
+ dataSourceId,
+ },
+ }
+ : {}),
+ }),
});
const triggerUpdateConversation = (
params: { conversationId: string },
- payload: { title: string }
+ payload: { title: string },
+ dataSourceId?: string
) =>
triggerHandler(router, {
method: 'put',
path: `${ASSISTANT_API.CONVERSATION}/{conversationId}`,
- req: httpServerMock.createRawRequest({ params, payload }),
+ req: httpServerMock.createRawRequest({
+ params,
+ payload,
+ ...(dataSourceId
+ ? {
+ query: {
+ dataSourceId,
+ },
+ }
+ : {}),
+ }),
});
-const triggerGetTrace = (interactionId: string) =>
+const triggerGetTrace = (interactionId: string, dataSourceId?: string) =>
triggerHandler(router, {
method: 'get',
path: `${ASSISTANT_API.TRACE}/{interactionId}`,
- req: httpServerMock.createRawRequest({ params: { interactionId } }),
+ req: httpServerMock.createRawRequest({
+ params: { interactionId },
+ ...(dataSourceId
+ ? {
+ query: {
+ dataSourceId,
+ },
+ }
+ : {}),
+ }),
});
-const triggerAbortAgentExecution = (conversationId: string) =>
+const triggerAbortAgentExecution = (conversationId: string, dataSourceId?: string) =>
triggerHandler(router, {
method: 'post',
path: ASSISTANT_API.ABORT_AGENT_EXECUTION,
- req: httpServerMock.createRawRequest({ payload: { conversationId } }),
+ req: httpServerMock.createRawRequest({
+ payload: { conversationId },
+ ...(dataSourceId
+ ? {
+ query: {
+ dataSourceId,
+ },
+ }
+ : {}),
+ }),
});
-const triggerFeedback = (params: { interactionId: string }, payload: { satisfaction: boolean }) =>
+const triggerFeedback = (
+ params: { interactionId: string },
+ payload: { satisfaction: boolean },
+ dataSourceId?: string
+) =>
triggerHandler(router, {
method: 'put',
path: `${ASSISTANT_API.FEEDBACK}/{interactionId}`,
- req: httpServerMock.createRawRequest({ params, payload }),
+ req: httpServerMock.createRawRequest({
+ params,
+ payload,
+ ...(dataSourceId
+ ? {
+ query: {
+ dataSourceId,
+ },
+ }
+ : {}),
+ }),
});
describe('chat routes', () => {
@@ -80,6 +147,22 @@ describe('chat routes', () => {
expect(mockAgentFrameworkStorageService.deleteConversation).not.toHaveBeenCalled();
const result = (await triggerDeleteConversation('foo')) as ResponseObject;
+ expect(getOpenSearchClientTransport.mock.results[0].value).toBe('client');
+ expect(mockAgentFrameworkStorageService.deleteConversation).toHaveBeenCalledWith('foo');
+ expect(result.source).toMatchInlineSnapshot(`
+ Object {
+ "success": true,
+ }
+ `);
+ });
+
+ it('should call delete conversation with passed data source id and get data source transport', async () => {
+ mockAgentFrameworkStorageService.deleteConversation.mockResolvedValueOnce({
+ success: true,
+ });
+ expect(mockAgentFrameworkStorageService.deleteConversation).not.toHaveBeenCalled();
+ const result = (await triggerDeleteConversation('foo', 'data_source_id')) as ResponseObject;
+ expect(getOpenSearchClientTransport.mock.results[0].value).toBe('dataSource-client');
expect(mockAgentFrameworkStorageService.deleteConversation).toHaveBeenCalledWith('foo');
expect(result.source).toMatchInlineSnapshot(`
Object {
@@ -109,6 +192,7 @@ describe('chat routes', () => {
{ conversationId: 'foo' },
{ title: 'new-title' }
)) as ResponseObject;
+ expect(getOpenSearchClientTransport.mock.results[0].value).toBe('client');
expect(mockAgentFrameworkStorageService.updateConversation).toHaveBeenCalledWith(
'foo',
'new-title'
@@ -120,6 +204,29 @@ describe('chat routes', () => {
`);
});
+ it('should call update conversation with passed data source id and title then get data source transport', async () => {
+ mockAgentFrameworkStorageService.updateConversation.mockResolvedValueOnce({
+ success: true,
+ });
+
+ expect(mockAgentFrameworkStorageService.updateConversation).not.toHaveBeenCalled();
+ const result = (await triggerUpdateConversation(
+ { conversationId: 'foo' },
+ { title: 'new-title' },
+ 'data_source_id'
+ )) as ResponseObject;
+ expect(mockAgentFrameworkStorageService.updateConversation).toHaveBeenCalledWith(
+ 'foo',
+ 'new-title'
+ );
+ expect(getOpenSearchClientTransport.mock.results[0].value).toBe('dataSource-client');
+ expect(result.source).toMatchInlineSnapshot(`
+ Object {
+ "success": true,
+ }
+ `);
+ });
+
it('should log error and return 500 error when failed to update conversation', async () => {
mockAgentFrameworkStorageService.updateConversation.mockRejectedValueOnce(new Error());
@@ -149,6 +256,27 @@ describe('chat routes', () => {
expect(mockAgentFrameworkStorageService.getTraces).not.toHaveBeenCalled();
const result = (await triggerGetTrace('interaction-1')) as ResponseObject;
+ expect(getOpenSearchClientTransport.mock.results[0].value).toBe('client');
+ expect(mockAgentFrameworkStorageService.getTraces).toHaveBeenCalledWith('interaction-1');
+ expect(result.source).toEqual(getTraceResultMock);
+ });
+
+ it('should call get traces with passed data source id and get data source transport', async () => {
+ const getTraceResultMock = [
+ {
+ interactionId: 'interaction-1',
+ createTime: '',
+ input: 'foo',
+ output: 'bar',
+ origin: '',
+ traceNumber: 0,
+ },
+ ];
+ mockAgentFrameworkStorageService.getTraces.mockResolvedValueOnce(getTraceResultMock);
+
+ expect(mockAgentFrameworkStorageService.getTraces).not.toHaveBeenCalled();
+ const result = (await triggerGetTrace('interaction-1', 'data_source_id')) as ResponseObject;
+ expect(getOpenSearchClientTransport.mock.results[0].value).toBe('dataSource-client');
expect(mockAgentFrameworkStorageService.getTraces).toHaveBeenCalledWith('interaction-1');
expect(result.source).toEqual(getTraceResultMock);
});
@@ -169,6 +297,16 @@ describe('chat routes', () => {
await triggerAbortAgentExecution('foo');
expect(mockOllyChatService.abortAgentExecution).toHaveBeenCalledWith('foo');
+ expect(getOpenSearchClientTransport.mock.results[0].value).toBe('client');
+ expect(mockedLogger.info).toHaveBeenCalledWith('Abort agent execution: foo');
+ });
+
+ it('should call get abort agent with passed data source id and get data source transport ', async () => {
+ expect(mockOllyChatService.abortAgentExecution).not.toHaveBeenCalled();
+
+ await triggerAbortAgentExecution('foo', 'data_source_id');
+ expect(mockOllyChatService.abortAgentExecution).toHaveBeenCalledWith('foo');
+ expect(getOpenSearchClientTransport.mock.results[0].value).toBe('dataSource-client');
expect(mockedLogger.info).toHaveBeenCalledWith('Abort agent execution: foo');
});
@@ -200,6 +338,31 @@ describe('chat routes', () => {
{ interactionId: 'foo' },
{ satisfaction: true }
)) as ResponseObject;
+ expect(getOpenSearchClientTransport.mock.results[0].value).toBe('client');
+ expect(mockAgentFrameworkStorageService.updateInteraction).toHaveBeenCalledWith('foo', {
+ feedback: {
+ satisfaction: true,
+ },
+ });
+ expect(result.source).toMatchInlineSnapshot(`
+ Object {
+ "success": true,
+ }
+ `);
+ });
+
+ it('should call update interaction with passed data source id and get data source transport', async () => {
+ mockAgentFrameworkStorageService.updateConversation.mockResolvedValueOnce({
+ success: true,
+ });
+
+ expect(mockAgentFrameworkStorageService.updateConversation).not.toHaveBeenCalled();
+ const result = (await triggerFeedback(
+ { interactionId: 'foo' },
+ { satisfaction: true },
+ 'data_source_id'
+ )) as ResponseObject;
+ expect(getOpenSearchClientTransport.mock.results[0].value).toBe('dataSource-client');
expect(mockAgentFrameworkStorageService.updateInteraction).toHaveBeenCalledWith('foo', {
feedback: {
satisfaction: true,
diff --git a/server/routes/chat_routes.ts b/server/routes/chat_routes.ts
index 0a2b7b51..e0809be5 100644
--- a/server/routes/chat_routes.ts
+++ b/server/routes/chat_routes.ts
@@ -17,6 +17,7 @@ import { OllyChatService } from '../services/chat/olly_chat_service';
import { AgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service';
import { RoutesOptions } from '../types';
import { ChatService } from '../services/chat/chat_service';
+import { getOpenSearchClientTransport } from '../utils/get_opensearch_client_transport';
const llmRequestRoute = {
path: ASSISTANT_API.SEND_MESSAGE,
@@ -33,6 +34,9 @@ const llmRequestRoute = {
contentType: schema.literal('text'),
}),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
export type LLMRequestSchema = TypeOf;
@@ -43,6 +47,9 @@ const getConversationRoute = {
params: schema.object({
conversationId: schema.string(),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
export type GetConversationSchema = TypeOf;
@@ -53,6 +60,9 @@ const abortAgentExecutionRoute = {
body: schema.object({
conversationId: schema.string(),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
export type AbortAgentExecutionSchema = TypeOf;
@@ -64,6 +74,9 @@ const regenerateRoute = {
conversationId: schema.string(),
interactionId: schema.string(),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
export type RegenerateSchema = TypeOf;
@@ -79,6 +92,7 @@ const getConversationsRoute = {
fields: schema.maybe(schema.arrayOf(schema.string())),
search: schema.maybe(schema.string()),
searchFields: schema.maybe(schema.oneOf([schema.string(), schema.arrayOf(schema.string())])),
+ dataSourceId: schema.maybe(schema.string()),
}),
},
};
@@ -90,6 +104,9 @@ const deleteConversationRoute = {
params: schema.object({
conversationId: schema.string(),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
@@ -102,6 +119,9 @@ const updateConversationRoute = {
body: schema.object({
title: schema.string(),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
@@ -111,6 +131,9 @@ const getTracesRoute = {
params: schema.object({
interactionId: schema.string(),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
@@ -123,16 +146,20 @@ const feedbackRoute = {
body: schema.object({
satisfaction: schema.boolean(),
}),
+ query: schema.object({
+ dataSourceId: schema.maybe(schema.string()),
+ }),
},
};
export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) {
- const createStorageService = (context: RequestHandlerContext) =>
+ const createStorageService = async (context: RequestHandlerContext, dataSourceId?: string) =>
new AgentFrameworkStorageService(
- context.core.opensearch.client.asCurrentUser,
+ await getOpenSearchClientTransport({ context, dataSourceId }),
routeOptions.messageParsers
);
- const createChatService = (context: RequestHandlerContext) => new OllyChatService(context);
+ const createChatService = async (context: RequestHandlerContext, dataSourceId?: string) =>
+ new OllyChatService(await getOpenSearchClientTransport({ context, dataSourceId }));
router.post(
llmRequestRoute,
@@ -142,8 +169,8 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
response
): Promise> => {
const { messages = [], input, conversationId: conversationIdInRequestBody } = request.body;
- const storageService = createStorageService(context);
- const chatService = createChatService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
+ const chatService = await createChatService(context, request.query.dataSourceId);
let outputs: Awaited> | undefined;
@@ -217,7 +244,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise> => {
- const storageService = createStorageService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
try {
const getResponse = await storageService.getConversation(request.params.conversationId);
@@ -236,7 +263,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise> => {
- const storageService = createStorageService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
try {
const getResponse = await storageService.getConversations(request.query);
@@ -255,7 +282,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise> => {
- const storageService = createStorageService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
try {
const getResponse = await storageService.deleteConversation(request.params.conversationId);
@@ -274,7 +301,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise> => {
- const storageService = createStorageService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
try {
const getResponse = await storageService.updateConversation(
@@ -296,7 +323,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise> => {
- const storageService = createStorageService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
try {
const getResponse = await storageService.getTraces(request.params.interactionId);
@@ -315,7 +342,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise> => {
- const chatService = createChatService(context, '');
+ const chatService = await createChatService(context, request.query.dataSourceId);
try {
chatService.abortAgentExecution(request.body.conversationId);
context.assistant_plugin.logger.info(
@@ -337,8 +364,8 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
response
): Promise> => {
const { conversationId, interactionId } = request.body;
- const storageService = createStorageService(context);
- const chatService = createChatService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
+ const chatService = await createChatService(context, request.query.dataSourceId);
let outputs: Awaited> | undefined;
@@ -386,7 +413,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions)
request,
response
): Promise> => {
- const storageService = createStorageService(context);
+ const storageService = await createStorageService(context, request.query.dataSourceId);
const { interactionId } = request.params;
try {
diff --git a/server/services/chat/olly_chat_service.test.ts b/server/services/chat/olly_chat_service.test.ts
index 936f7fc5..49362dcb 100644
--- a/server/services/chat/olly_chat_service.test.ts
+++ b/server/services/chat/olly_chat_service.test.ts
@@ -6,30 +6,23 @@
import { OllyChatService } from './olly_chat_service';
import { CoreRouteHandlerContext } from '../../../../../src/core/server/core_route_handler_context';
import { coreMock, httpServerMock } from '../../../../../src/core/server/mocks';
-import { loggerMock } from '../../../../../src/core/server/logging/logger.mock';
-import { ResponseError } from '@opensearch-project/opensearch/lib/errors';
-import { ApiResponse } from '@opensearch-project/opensearch';
describe('OllyChatService', () => {
const coreContext = new CoreRouteHandlerContext(
coreMock.createInternalStart(),
httpServerMock.createOpenSearchDashboardsRequest()
);
- const mockedTransport = coreContext.opensearch.client.asCurrentUser.transport
- .request as jest.Mock;
- const contextMock = {
- core: coreContext,
- assistant_plugin: {
- logger: loggerMock.create(),
- },
- };
- const ollyChatService: OllyChatService = new OllyChatService(contextMock);
+ const mockedTransport = coreContext.opensearch.client.asCurrentUser.transport;
+ const mockedTransportRequest = mockedTransport.request as jest.Mock;
+
+ const ollyChatService: OllyChatService = new OllyChatService(mockedTransport);
+
beforeEach(async () => {
- mockedTransport.mockClear();
+ mockedTransportRequest.mockClear();
});
it('requestLLM should invoke client call with correct params', async () => {
- mockedTransport.mockImplementation((args) => {
+ mockedTransportRequest.mockImplementation((args) => {
if (args.path === '/_plugins/_ml/config/os_chat') {
return {
body: {
@@ -65,7 +58,7 @@ describe('OllyChatService', () => {
},
conversationId: 'conversationId',
});
- expect(mockedTransport.mock.calls).toMatchInlineSnapshot(`
+ expect(mockedTransportRequest.mock.calls).toMatchInlineSnapshot(`
Array [
Array [
Object {
@@ -102,7 +95,7 @@ describe('OllyChatService', () => {
});
it('requestLLM should throw error when transport.request throws error', async () => {
- mockedTransport
+ mockedTransportRequest
.mockImplementationOnce(() => {
return {
body: {
@@ -124,13 +117,14 @@ describe('OllyChatService', () => {
contentType: 'text',
content: 'content',
},
+
conversationId: '',
})
).rejects.toMatchInlineSnapshot(`[Error: error]`);
});
it('regenerate should invoke client call with correct params', async () => {
- mockedTransport.mockImplementation((args) => {
+ mockedTransportRequest.mockImplementation((args) => {
if (args.path === '/_plugins/_ml/config/os_chat') {
return {
body: {
@@ -161,7 +155,7 @@ describe('OllyChatService', () => {
conversationId: 'conversationId',
interactionId: 'interactionId',
});
- expect(mockedTransport.mock.calls).toMatchInlineSnapshot(`
+ expect(mockedTransportRequest.mock.calls).toMatchInlineSnapshot(`
Array [
Array [
Object {
@@ -198,7 +192,7 @@ describe('OllyChatService', () => {
});
it('regenerate should throw error when transport.request throws error', async () => {
- mockedTransport
+ mockedTransportRequest
.mockImplementationOnce(() => {
return {
body: {
@@ -221,7 +215,7 @@ describe('OllyChatService', () => {
});
it('fetching root agent id throws error', async () => {
- mockedTransport.mockImplementationOnce(() => {
+ mockedTransportRequest.mockImplementationOnce(() => {
return {
body: {
hits: {
diff --git a/server/services/chat/olly_chat_service.ts b/server/services/chat/olly_chat_service.ts
index cb239079..bdebec5a 100644
--- a/server/services/chat/olly_chat_service.ts
+++ b/server/services/chat/olly_chat_service.ts
@@ -4,7 +4,7 @@
*/
import { ApiResponse } from '@opensearch-project/opensearch';
-import { RequestHandlerContext } from '../../../../../src/core/server';
+import { OpenSearchClient } from '../../../../../src/core/server';
import { IMessage, IInput } from '../../../common/types/chat_saved_object_attributes';
import { ChatService } from './chat_service';
import { ML_COMMONS_BASE_API, ROOT_AGENT_CONFIG_ID } from '../../utils/constants';
@@ -22,13 +22,12 @@ const INTERACTION_ID_FIELD = 'parent_interaction_id';
export class OllyChatService implements ChatService {
static abortControllers: Map = new Map();
- constructor(private readonly context: RequestHandlerContext) {}
+ constructor(private readonly opensearchClientTransport: OpenSearchClient['transport']) {}
private async getRootAgent(): Promise {
try {
- const opensearchClient = this.context.core.opensearch.client.asCurrentUser;
const path = `${ML_COMMONS_BASE_API}/config/${ROOT_AGENT_CONFIG_ID}`;
- const response = await opensearchClient.transport.request({
+ const response = await this.opensearchClientTransport.request({
method: 'GET',
path,
});
@@ -53,9 +52,8 @@ export class OllyChatService implements ChatService {
}
private async callExecuteAgentAPI(payload: AgentRunPayload, rootAgentId: string) {
- const opensearchClient = this.context.core.opensearch.client.asCurrentUser;
try {
- const agentFrameworkResponse = (await opensearchClient.transport.request(
+ const agentFrameworkResponse = (await this.opensearchClientTransport.request(
{
method: 'POST',
path: `${ML_COMMONS_BASE_API}/agents/${rootAgentId}/_execute`,
diff --git a/server/services/storage/agent_framework_storage_service.test.ts b/server/services/storage/agent_framework_storage_service.test.ts
index e1b5781a..aa1a77b0 100644
--- a/server/services/storage/agent_framework_storage_service.test.ts
+++ b/server/services/storage/agent_framework_storage_service.test.ts
@@ -12,17 +12,15 @@ describe('AgentFrameworkStorageService', () => {
coreMock.createInternalStart(),
httpServerMock.createOpenSearchDashboardsRequest()
);
- const mockedTransport = coreContext.opensearch.client.asCurrentUser.transport
- .request as jest.Mock;
- const agentFrameworkService = new AgentFrameworkStorageService(
- coreContext.opensearch.client.asCurrentUser,
- []
- );
+ const mockedTransport = coreContext.opensearch.client.asCurrentUser.transport;
+ const mockedTransportRequest = mockedTransport.request as jest.Mock;
+
+ const agentFrameworkService = new AgentFrameworkStorageService(mockedTransport, []);
beforeEach(() => {
- mockedTransport.mockReset();
+ mockedTransportRequest.mockReset();
});
it('getConversation', async () => {
- mockedTransport.mockImplementation(async (params) => {
+ mockedTransportRequest.mockImplementation(async (params) => {
if (params.path.includes('/messages?max_results=1000')) {
return {
body: {
@@ -62,7 +60,7 @@ describe('AgentFrameworkStorageService', () => {
"updatedTimeMs": 0,
}
`);
- expect(mockedTransport.mock.calls).toMatchInlineSnapshot(`
+ expect(mockedTransportRequest.mock.calls).toMatchInlineSnapshot(`
Array [
Array [
Object {
@@ -81,14 +79,14 @@ describe('AgentFrameworkStorageService', () => {
});
it('should encode id when calls getConversation with non-standard params in request payload', async () => {
- mockedTransport.mockResolvedValue({
+ mockedTransportRequest.mockResolvedValue({
body: {
messages: [],
},
});
await agentFrameworkService.getConversation('../non-standard/id');
- expect(mockedTransport.mock.calls[0]).toMatchInlineSnapshot(`
+ expect(mockedTransportRequest.mock.calls[0]).toMatchInlineSnapshot(`
Array [
Object {
"method": "GET",
@@ -99,7 +97,7 @@ describe('AgentFrameworkStorageService', () => {
});
it('getConversations', async () => {
- mockedTransport.mockImplementation(async (params) => {
+ mockedTransportRequest.mockImplementation(async (params) => {
return {
body: {
hits: {
@@ -165,7 +163,7 @@ describe('AgentFrameworkStorageService', () => {
"total": 10,
}
`);
- expect(mockedTransport.mock.calls).toMatchInlineSnapshot(`
+ expect(mockedTransportRequest.mock.calls).toMatchInlineSnapshot(`
Array [
Array [
Object {
@@ -214,7 +212,7 @@ describe('AgentFrameworkStorageService', () => {
});
it('deleteConversation', async () => {
- mockedTransport.mockImplementationOnce(async (params) => ({
+ mockedTransportRequest.mockImplementationOnce(async (params) => ({
statusCode: 200,
}));
expect(agentFrameworkService.deleteConversation('foo')).resolves.toMatchInlineSnapshot(`
@@ -222,7 +220,7 @@ describe('AgentFrameworkStorageService', () => {
"success": true,
}
`);
- mockedTransport.mockImplementationOnce(async (params) => ({
+ mockedTransportRequest.mockImplementationOnce(async (params) => ({
statusCode: 404,
body: {
message: 'can not find conversation',
@@ -235,18 +233,18 @@ describe('AgentFrameworkStorageService', () => {
"success": false,
}
`);
- mockedTransport.mockImplementationOnce(async (params) => {
+ mockedTransportRequest.mockImplementationOnce(async (params) => {
return Promise.reject({ meta: { body: 'error' } });
});
expect(agentFrameworkService.deleteConversation('foo')).rejects.toBeDefined();
});
it('should encode id when calls deleteConversation with non-standard params in request payload', async () => {
- mockedTransport.mockResolvedValueOnce({
+ mockedTransportRequest.mockResolvedValueOnce({
statusCode: 200,
});
await agentFrameworkService.deleteConversation('../non-standard/id');
- expect(mockedTransport.mock.calls[0]).toMatchInlineSnapshot(`
+ expect(mockedTransportRequest.mock.calls[0]).toMatchInlineSnapshot(`
Array [
Object {
"method": "DELETE",
@@ -257,7 +255,7 @@ describe('AgentFrameworkStorageService', () => {
});
it('updateConversation', async () => {
- mockedTransport.mockImplementationOnce(async (params) => ({
+ mockedTransportRequest.mockImplementationOnce(async (params) => ({
statusCode: 200,
}));
expect(agentFrameworkService.updateConversation('foo', 'title')).resolves
@@ -266,7 +264,7 @@ describe('AgentFrameworkStorageService', () => {
"success": true,
}
`);
- mockedTransport.mockImplementationOnce(async (params) => ({
+ mockedTransportRequest.mockImplementationOnce(async (params) => ({
statusCode: 404,
body: {
message: 'can not find conversation',
@@ -280,18 +278,18 @@ describe('AgentFrameworkStorageService', () => {
"success": false,
}
`);
- mockedTransport.mockImplementationOnce(async (params) => {
+ mockedTransportRequest.mockImplementationOnce(async (params) => {
return Promise.reject({ meta: { body: 'error' } });
});
expect(agentFrameworkService.updateConversation('foo', 'title')).rejects.toBeDefined();
});
it('should encode id when calls updateConversation with non-standard params in request payload', async () => {
- mockedTransport.mockResolvedValueOnce({
+ mockedTransportRequest.mockResolvedValueOnce({
statusCode: 200,
});
await agentFrameworkService.updateConversation('../non-standard/id', 'title');
- expect(mockedTransport.mock.calls[0]).toMatchInlineSnapshot(`
+ expect(mockedTransportRequest.mock.calls[0]).toMatchInlineSnapshot(`
Array [
Object {
"body": Object {
@@ -305,7 +303,7 @@ describe('AgentFrameworkStorageService', () => {
});
it('getTraces', async () => {
- mockedTransport.mockImplementationOnce(async (params) => ({
+ mockedTransportRequest.mockImplementationOnce(async (params) => ({
body: {
traces: [
{
@@ -331,7 +329,7 @@ describe('AgentFrameworkStorageService', () => {
},
]
`);
- mockedTransport.mockImplementationOnce(async (params) => {
+ mockedTransportRequest.mockImplementationOnce(async (params) => {
return Promise.reject({ meta: { body: 'error' } });
});
expect(agentFrameworkService.getTraces('foo')).rejects.toMatchInlineSnapshot(`
@@ -344,7 +342,7 @@ describe('AgentFrameworkStorageService', () => {
});
it('should encode id when calls getTraces with non-standard params in request payload', async () => {
- mockedTransport.mockResolvedValueOnce({
+ mockedTransportRequest.mockResolvedValueOnce({
body: {
traces: [
{
@@ -359,7 +357,7 @@ describe('AgentFrameworkStorageService', () => {
},
});
await agentFrameworkService.getTraces('../non-standard/id');
- expect(mockedTransport.mock.calls[0]).toMatchInlineSnapshot(`
+ expect(mockedTransportRequest.mock.calls[0]).toMatchInlineSnapshot(`
Array [
Object {
"method": "GET",
@@ -370,7 +368,7 @@ describe('AgentFrameworkStorageService', () => {
});
it('updateInteraction', async () => {
- mockedTransport.mockImplementationOnce(async (params) => ({
+ mockedTransportRequest.mockImplementationOnce(async (params) => ({
statusCode: 200,
}));
expect(
@@ -384,7 +382,7 @@ describe('AgentFrameworkStorageService', () => {
"success": true,
}
`);
- mockedTransport.mockImplementationOnce(async (params) => ({
+ mockedTransportRequest.mockImplementationOnce(async (params) => ({
statusCode: 404,
body: {
message: 'can not find conversation',
@@ -403,7 +401,7 @@ describe('AgentFrameworkStorageService', () => {
"success": false,
}
`);
- mockedTransport.mockImplementationOnce(async (params) => {
+ mockedTransportRequest.mockImplementationOnce(async (params) => {
return Promise.reject({ meta: { body: 'error' } });
});
expect(
@@ -422,7 +420,7 @@ describe('AgentFrameworkStorageService', () => {
});
it('should encode id when calls updateInteraction with non-standard params in request payload', async () => {
- mockedTransport.mockResolvedValueOnce({
+ mockedTransportRequest.mockResolvedValueOnce({
statusCode: 200,
});
await agentFrameworkService.updateInteraction('../non-standard/id', {
@@ -430,7 +428,7 @@ describe('AgentFrameworkStorageService', () => {
bar: 'foo',
},
});
- expect(mockedTransport.mock.calls[0]).toMatchInlineSnapshot(`
+ expect(mockedTransportRequest.mock.calls[0]).toMatchInlineSnapshot(`
Array [
Object {
"body": Object {
@@ -448,7 +446,7 @@ describe('AgentFrameworkStorageService', () => {
});
it('getInteraction', async () => {
- mockedTransport.mockImplementation(async (params) => ({
+ mockedTransportRequest.mockImplementation(async (params) => ({
body: {
input: 'input',
response: 'response',
@@ -460,7 +458,7 @@ describe('AgentFrameworkStorageService', () => {
expect(agentFrameworkService.getInteraction('_id', '')).rejects.toMatchInlineSnapshot(
`[Error: interactionId is required]`
);
- expect(mockedTransport).toBeCalledTimes(0);
+ expect(mockedTransportRequest).toBeCalledTimes(0);
expect(agentFrameworkService.getInteraction('_id', 'interaction_id')).resolves
.toMatchInlineSnapshot(`
Object {
@@ -470,18 +468,18 @@ describe('AgentFrameworkStorageService', () => {
"response": "response",
}
`);
- expect(mockedTransport).toBeCalledTimes(1);
+ expect(mockedTransportRequest).toBeCalledTimes(1);
});
it('should encode id when calls getInteraction with non-standard params in request payload', async () => {
- mockedTransport.mockResolvedValueOnce({
+ mockedTransportRequest.mockResolvedValueOnce({
body: {
input: 'input',
response: 'response',
},
});
await agentFrameworkService.getInteraction('_id', '../non-standard/id');
- expect(mockedTransport.mock.calls[0]).toMatchInlineSnapshot(`
+ expect(mockedTransportRequest.mock.calls[0]).toMatchInlineSnapshot(`
Array [
Object {
"method": "GET",
diff --git a/server/services/storage/agent_framework_storage_service.ts b/server/services/storage/agent_framework_storage_service.ts
index f0f1a735..05f57b19 100644
--- a/server/services/storage/agent_framework_storage_service.ts
+++ b/server/services/storage/agent_framework_storage_service.ts
@@ -28,12 +28,12 @@ export interface ConversationOptResponse {
export class AgentFrameworkStorageService implements StorageService {
constructor(
- private readonly client: OpenSearchClient,
+ private readonly clientTransport: OpenSearchClient['transport'],
private readonly messageParsers: MessageParser[] = []
) {}
async getConversation(conversationId: string): Promise {
const [interactionsResp, conversation] = await Promise.all([
- this.client.transport.request({
+ this.clientTransport.request({
method: 'GET',
path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent(
conversationId
@@ -43,7 +43,7 @@ export class AgentFrameworkStorageService implements StorageService {
messages: InteractionFromAgentFramework[];
}>
>,
- this.client.transport.request({
+ this.clientTransport.request({
method: 'GET',
path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent(conversationId)}`,
}) as TransportRequestPromise<
@@ -103,7 +103,7 @@ export class AgentFrameworkStorageService implements StorageService {
...(sortField && query.sortOrder && { sort: [{ [sortField]: query.sortOrder }] }),
};
- const conversations = await this.client.transport.request({
+ const conversations = await this.clientTransport.request({
method: 'GET',
path: `${ML_COMMONS_BASE_API}/memory/_search`,
body: requestParams,
@@ -146,7 +146,7 @@ export class AgentFrameworkStorageService implements StorageService {
}
async deleteConversation(conversationId: string): Promise {
- const response = await this.client.transport.request({
+ const response = await this.clientTransport.request({
method: 'DELETE',
path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent(conversationId)}`,
});
@@ -167,7 +167,7 @@ export class AgentFrameworkStorageService implements StorageService {
conversationId: string,
title: string
): Promise {
- const response = await this.client.transport.request({
+ const response = await this.clientTransport.request({
method: 'PUT',
path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent(conversationId)}`,
body: {
@@ -188,7 +188,7 @@ export class AgentFrameworkStorageService implements StorageService {
}
async getTraces(interactionId: string): Promise {
- const response = (await this.client.transport.request({
+ const response = (await this.clientTransport.request({
method: 'GET',
path: `${ML_COMMONS_BASE_API}/memory/message/${encodeURIComponent(interactionId)}/traces`,
})) as ApiResponse<{
@@ -216,7 +216,7 @@ export class AgentFrameworkStorageService implements StorageService {
interactionId: string,
additionalInfo: Record>
): Promise {
- const response = await this.client.transport.request({
+ const response = await this.clientTransport.request({
method: 'PUT',
path: `${ML_COMMONS_BASE_API}/memory/message/${encodeURIComponent(interactionId)}`,
body: {
@@ -260,7 +260,7 @@ export class AgentFrameworkStorageService implements StorageService {
if (!interactionId) {
throw new Error('interactionId is required');
}
- const interactionsResp = (await this.client.transport.request({
+ const interactionsResp = (await this.clientTransport.request({
method: 'GET',
path: `${ML_COMMONS_BASE_API}/memory/message/${encodeURIComponent(interactionId)}`,
})) as ApiResponse;
diff --git a/server/utils/get_opensearch_client_transport.test.ts b/server/utils/get_opensearch_client_transport.test.ts
new file mode 100644
index 00000000..4b37577a
--- /dev/null
+++ b/server/utils/get_opensearch_client_transport.test.ts
@@ -0,0 +1,44 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { getOpenSearchClientTransport } from './get_opensearch_client_transport';
+import { coreMock } from '../../../../src/core/server/mocks';
+import { loggerMock } from '../../../../src/core/server/logging/logger.mock';
+
+const mockedLogger = loggerMock.create();
+
+describe('getOpenSearchClientTransport', () => {
+ it('should return current user opensearch transport', async () => {
+ const core = coreMock.createRequestHandlerContext();
+
+ expect(
+ await getOpenSearchClientTransport({
+ context: { core, assistant_plugin: { logger: mockedLogger } },
+ })
+ ).toBe(core.opensearch.client.asCurrentUser.transport);
+ });
+ it('should return data source id related opensearch transport', async () => {
+ const transportMock = {};
+ const core = coreMock.createRequestHandlerContext();
+ const context = {
+ core,
+ assistant_plugin: { logger: mockedLogger },
+ dataSource: {
+ opensearch: {
+ getClient: async (_dataSourceId: string) => ({
+ transport: transportMock,
+ }),
+ },
+ },
+ };
+
+ expect(
+ await getOpenSearchClientTransport({
+ context,
+ dataSourceId: 'foo',
+ })
+ ).toBe(transportMock);
+ });
+});
diff --git a/server/utils/get_opensearch_client_transport.ts b/server/utils/get_opensearch_client_transport.ts
new file mode 100644
index 00000000..9f54a4ff
--- /dev/null
+++ b/server/utils/get_opensearch_client_transport.ts
@@ -0,0 +1,25 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { OpenSearchClient, RequestHandlerContext } from '../../../../src/core/server';
+
+export const getOpenSearchClientTransport = async ({
+ context,
+ dataSourceId,
+}: {
+ context: RequestHandlerContext & {
+ dataSource?: {
+ opensearch: {
+ getClient: (dataSourceId: string) => Promise;
+ };
+ };
+ };
+ dataSourceId?: string;
+}) => {
+ if (dataSourceId && context.dataSource) {
+ return (await context.dataSource.opensearch.getClient(dataSourceId)).transport;
+ }
+ return context.core.opensearch.client.asCurrentUser.transport;
+};