diff --git a/opensearch_dashboards.json b/opensearch_dashboards.json index f286b75a..09e7df90 100644 --- a/opensearch_dashboards.json +++ b/opensearch_dashboards.json @@ -16,6 +16,6 @@ "dataSourceManagement" ], "configPath": [ - "assistant" + "assistant" ] } \ No newline at end of file diff --git a/public/components/__tests__/chat_window_header_title.test.tsx b/public/components/__tests__/chat_window_header_title.test.tsx index 025dd62c..5c87cd55 100644 --- a/public/components/__tests__/chat_window_header_title.test.tsx +++ b/public/components/__tests__/chat_window_header_title.test.tsx @@ -17,6 +17,7 @@ import * as coreContextExports from '../../contexts/core_context'; import { IMessage } from '../../../common/types/chat_saved_object_attributes'; import { ChatWindowHeaderTitle } from '../chat_window_header_title'; +import { DataSourceServiceMock } from '../../services/data_source_service.mock'; const setup = ({ messages = [], @@ -37,6 +38,7 @@ const setup = ({ }), reload: jest.fn(), }, + dataSource: new DataSourceServiceMock(), }, }; useCoreMock.services.http.put.mockImplementation(() => Promise.resolve()); diff --git a/public/components/__tests__/edit_conversation_name_modal.test.tsx b/public/components/__tests__/edit_conversation_name_modal.test.tsx index db594a02..c650fd60 100644 --- a/public/components/__tests__/edit_conversation_name_modal.test.tsx +++ b/public/components/__tests__/edit_conversation_name_modal.test.tsx @@ -15,10 +15,11 @@ import { EditConversationNameModalProps, } from '../edit_conversation_name_modal'; import { HttpHandler } from '../../../../../src/core/public'; +import { DataSourceServiceMock } from '../../services/data_source_service.mock'; const setup = ({ onClose, defaultTitle, conversationId }: EditConversationNameModalProps) => { const useCoreMock = { - services: coreMock.createStart(), + services: { ...coreMock.createStart(), dataSource: new DataSourceServiceMock() }, }; jest.spyOn(coreContextExports, 'useCore').mockReturnValue(useCoreMock); @@ -156,7 +157,10 @@ describe('', () => { expect(useCoreMock.services.http.put).not.toHaveBeenCalled(); fireEvent.click(renderResult.getByTestId('confirmModalConfirmButton')); - expect(useCoreMock.services.http.put).toHaveBeenCalled(); + + await waitFor(() => { + expect(useCoreMock.services.http.put).toHaveBeenCalled(); + }); fireEvent.click(renderResult.getByTestId('confirmModalCancelButton')); diff --git a/public/components/feedback_modal.tsx b/public/components/feedback_modal.tsx deleted file mode 100644 index e9889da3..00000000 --- a/public/components/feedback_modal.tsx +++ /dev/null @@ -1,309 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -import { - EuiButton, - EuiButtonEmpty, - EuiForm, - EuiFormRow, - EuiModal, - EuiModalBody, - EuiModalFooter, - EuiModalHeader, - EuiModalHeaderTitle, - EuiRadioGroup, - EuiTextArea, -} from '@elastic/eui'; -import React, { useState } from 'react'; -import { HttpStart } from '../../../../src/core/public'; -import { ASSISTANT_API } from '../../common/constants/llm'; -import { getCoreStart } from '../plugin'; - -export interface LabelData { - formHeader: string; - inputPlaceholder: string; - outputPlaceholder: string; -} - -export interface FeedbackFormData { - input: string; - output: string; - correct: boolean | undefined; - expectedOutput: string; - comment: string; -} - -interface FeedbackMetaData { - type: 'event_analytics' | 'chat' | 'ppl_submit'; - conversationId?: string; - interactionId?: string; - error?: boolean; - selectedIndex?: string; -} - -interface FeedbackModelProps { - input?: string; - output?: string; - metadata: FeedbackMetaData; - onClose: () => void; -} - -export const FeedbackModal: React.FC = (props) => { - const [formData, setFormData] = useState({ - input: props.input ?? '', - output: props.output ?? '', - correct: undefined, - expectedOutput: '', - comment: '', - }); - return ( - - - - ); -}; - -interface FeedbackModalContentProps { - formData: FeedbackFormData; - setFormData: React.Dispatch>; - metadata: FeedbackMetaData; - displayLabels?: Partial> & Partial; - onClose: () => void; -} - -export const FeedbackModalContent: React.FC = (props) => { - const core = getCoreStart(); - const labels: NonNullable> = Object.assign( - { - formHeader: 'Olly Skills Feedback', - inputPlaceholder: 'Your input question', - input: 'Input question', - outputPlaceholder: 'The LLM response', - output: 'Output', - correct: 'Does the output match your expectations?', - expectedOutput: 'Expected output', - comment: 'Comment', - }, - props.displayLabels - ); - const { loading, submitFeedback } = useSubmitFeedback(props.formData, props.metadata, core.http); - const [formErrors, setFormErrors] = useState< - Partial<{ [x in keyof FeedbackFormData]: string[] }> - >({ - input: [], - output: [], - expectedOutput: [], - }); - - const hasError = (key?: keyof FeedbackFormData) => { - if (!key) return Object.values(formErrors).some((e) => !!e.length); - return !!formErrors[key]?.length; - }; - - const onSubmit = async (event: React.FormEvent) => { - event.preventDefault(); - const errors = { - input: validator - .input(props.formData.input) - .concat(await validator.validateQuery(props.formData.input, props.metadata.type)), - output: validator.output(props.formData.output), - correct: validator.correct(props.formData.correct), - expectedOutput: validator.expectedOutput( - props.formData.expectedOutput, - props.formData.correct === false - ), - }; - if (Object.values(errors).some((e) => !!e.length)) { - setFormErrors(errors); - return; - } - - try { - await submitFeedback(); - props.setFormData({ - input: '', - output: '', - correct: undefined, - expectedOutput: '', - comment: '', - }); - core.notifications.toasts.addSuccess('Thanks for your feedback!'); - props.onClose(); - } catch (e) { - core.notifications.toasts.addError(e, { title: 'Failed to submit feedback' }); - } - }; - - return ( - <> - - {labels.formHeader} - - - - - - props.setFormData({ ...props.formData, input: e.target.value })} - onBlur={(e) => { - setFormErrors({ ...formErrors, input: validator.input(e.target.value) }); - }} - isInvalid={hasError('input')} - /> - - - props.setFormData({ ...props.formData, output: e.target.value })} - onBlur={(e) => { - setFormErrors({ ...formErrors, output: validator.output(e.target.value) }); - }} - isInvalid={hasError('output')} - /> - - {props.metadata.type !== 'ppl_submit' && ( - - { - props.setFormData({ ...props.formData, correct: id === 'yes' }); - setFormErrors({ ...formErrors, expectedOutput: [] }); - }} - onBlur={() => setFormErrors({ ...formErrors, correct: [] })} - /> - - )} - {props.formData.correct === false && ( - - - props.setFormData({ ...props.formData, expectedOutput: e.target.value }) - } - onBlur={(e) => { - setFormErrors({ - ...formErrors, - expectedOutput: validator.expectedOutput( - e.target.value, - props.formData.correct === false - ), - }); - }} - isInvalid={hasError('expectedOutput')} - /> - - )} - - props.setFormData({ ...props.formData, comment: e.target.value })} - /> - - - - - - Cancel - - Send - - - - ); -}; - -const useSubmitFeedback = (data: FeedbackFormData, metadata: FeedbackMetaData, http: HttpStart) => { - const [loading, setLoading] = useState(false); - return { - loading, - submitFeedback: async () => { - setLoading(true); - const auth = await http - .get<{ data: { user_name: string; user_requested_tenant: string; roles: string[] } }>( - '/api/v1/configuration/account' - ) - .then((res) => ({ user: res.data.user_name, tenant: res.data.user_requested_tenant })); - - return http - .post(ASSISTANT_API.FEEDBACK, { - body: JSON.stringify({ metadata: { ...metadata, ...auth }, ...data }), - }) - .finally(() => setLoading(false)); - }, - }; -}; - -const validatePPLQuery = async (logsQuery: string, feedBackType: FeedbackMetaData['type']) => { - return []; - // TODO remove - // let responseMessage: [] | string[] = []; - // const errorMessage = [' Invalid PPL Query, please re-check the ppl syntax']; - - // if (feedBackType === 'ppl_submit') { - // const pplService = getPPLService(); - // await pplService - // .fetch({ query: logsQuery, format: 'jdbc' }) - // .then((res) => { - // if (res === undefined) responseMessage = errorMessage; - // }) - // .catch((error: Error) => { - // responseMessage = errorMessage; - // }); - // } - // return responseMessage; -}; - -const validator = { - input: (text: string) => (text.trim().length === 0 ? ['Input is required'] : []), - output: (text: string) => (text.trim().length === 0 ? ['Output is required'] : []), - correct: (correct: boolean | undefined) => - correct === undefined ? ['Correctness is required'] : [], - expectedOutput: (text: string, required: boolean) => - required && text.trim().length === 0 ? ['expectedOutput is required'] : [], - validateQuery: async (logsQuery: string, feedBackType: FeedbackMetaData['type']) => - await validatePPLQuery(logsQuery, feedBackType), -}; diff --git a/public/contexts/__mocks__/core_context.tsx b/public/contexts/__mocks__/core_context.tsx index 5c1dab99..2d92d77f 100644 --- a/public/contexts/__mocks__/core_context.tsx +++ b/public/contexts/__mocks__/core_context.tsx @@ -5,6 +5,7 @@ import { BehaviorSubject } from 'rxjs'; import { coreMock } from '../../../../../src/core/public/mocks'; +import { DataSourceServiceMock } from '../../services/data_source_service.mock'; export const useCore = jest.fn(() => { const useCoreMock = { @@ -24,7 +25,7 @@ export const useCore = jest.fn(() => { load: jest.fn(), }, conversationLoad: {}, - dataSource: {}, + dataSource: new DataSourceServiceMock(), }, }; useCoreMock.services.http.delete.mockReturnValue(Promise.resolve()); diff --git a/public/contexts/core_context.tsx b/public/contexts/core_context.tsx index b4143154..6db5dd2e 100644 --- a/public/contexts/core_context.tsx +++ b/public/contexts/core_context.tsx @@ -15,6 +15,7 @@ export interface AssistantServices extends Required; @@ -76,6 +77,10 @@ describe('useDeleteConversation', () => { deleteConversationPromise = result.current.deleteConversation('foo'); }); + await waitFor(() => { + expect(useCoreMocked.mock.results[0].value.services.http.delete).toHaveBeenCalled(); + }); + let deleteConversationError; await act(async () => { result.current.abort(); @@ -160,6 +165,10 @@ describe('usePatchConversation', () => { patchConversationPromise = result.current.patchConversation('foo', 'new-title'); }); + await waitFor(() => { + expect(useCoreMocked.mock.results[0].value.services.http.put).toHaveBeenCalled(); + }); + let patchConversationError; await act(async () => { result.current.abort(); diff --git a/public/hooks/use_chat_actions.test.tsx b/public/hooks/use_chat_actions.test.tsx index aaa6c77a..6418ae31 100644 --- a/public/hooks/use_chat_actions.test.tsx +++ b/public/hooks/use_chat_actions.test.tsx @@ -13,6 +13,7 @@ import { ConversationLoadService } from '../services/conversation_load_service'; import * as chatStateHookExports from './use_chat_state'; import { ASSISTANT_API } from '../../common/constants/llm'; import { IMessage } from 'common/types/chat_saved_object_attributes'; +import { DataSourceServiceMock } from '../services/data_source_service.mock'; jest.mock('../services/conversations_service', () => { return { @@ -60,6 +61,7 @@ describe('useChatActions hook', () => { const setSelectedTabIdMock = jest.fn(); const pplVisualizationRenderMock = jest.fn(); const setInteractionIdMock = jest.fn(); + const dataSourceServiceMock = new DataSourceServiceMock(); const chatContextMock: chatContextHookExports.IChatContext = { selectedTabId: 'chat', @@ -88,8 +90,9 @@ describe('useChatActions hook', () => { jest.spyOn(coreHookExports, 'useCore').mockReturnValue({ services: { http: httpMock, - conversations: new ConversationsService(httpMock), - conversationLoad: new ConversationLoadService(httpMock), + conversations: new ConversationsService(httpMock, dataSourceServiceMock), + conversationLoad: new ConversationLoadService(httpMock, dataSourceServiceMock), + dataSource: dataSourceServiceMock, }, }); @@ -125,6 +128,7 @@ describe('useChatActions hook', () => { messages: [SEND_MESSAGE_RESPONSE.messages[0]], input: INPUT_MESSAGE, }), + query: await dataSourceServiceMock.getDataSourceQuery(), }); // it should send dispatch `receive` action to remove the message without messageId @@ -199,6 +203,7 @@ describe('useChatActions hook', () => { messages: [], input: { type: 'input', content: 'message that send as input', contentType: 'text' }, }), + query: await dataSourceServiceMock.getDataSourceQuery(), }); }); @@ -261,6 +266,7 @@ describe('useChatActions hook', () => { expect(chatStateDispatchMock).toHaveBeenCalledWith({ type: 'abort' }); expect(httpMock.post).toHaveBeenCalledWith(ASSISTANT_API.ABORT_AGENT_EXECUTION, { body: JSON.stringify({ conversationId: 'conversation_id_to_abort' }), + query: await dataSourceServiceMock.getDataSourceQuery(), }); }); @@ -288,6 +294,7 @@ describe('useChatActions hook', () => { conversationId: 'conversation_id_mock', interactionId: 'interaction_id_mock', }), + query: await dataSourceServiceMock.getDataSourceQuery(), }); expect(chatStateDispatchMock).toHaveBeenCalledWith( expect.objectContaining({ type: 'receive', payload: { messages: [], interactions: [] } }) @@ -323,6 +330,7 @@ describe('useChatActions hook', () => { conversationId: 'conversation_id_mock', interactionId: 'interaction_id_mock', }), + query: await dataSourceServiceMock.getDataSourceQuery(), }); expect(chatStateDispatchMock).not.toHaveBeenCalledWith( expect.objectContaining({ type: 'receive' }) diff --git a/public/hooks/use_chat_actions.tsx b/public/hooks/use_chat_actions.tsx index a8f6358b..962420f2 100644 --- a/public/hooks/use_chat_actions.tsx +++ b/public/hooks/use_chat_actions.tsx @@ -35,6 +35,7 @@ export const useChatActions = (): AssistantActions => { ...(!chatContext.conversationId && { messages: chatState.messages }), // include all previous messages for new chats input, }), + query: await core.services.dataSource.getDataSourceQuery(), }); if (abortController.signal.aborted) return; // Refresh history list after new conversation created if new conversation saved and history list page visible @@ -162,6 +163,7 @@ export const useChatActions = (): AssistantActions => { // abort agent execution await core.services.http.post(`${ASSISTANT_API.ABORT_AGENT_EXECUTION}`, { body: JSON.stringify({ conversationId }), + query: await core.services.dataSource.getDataSourceQuery(), }); } }; @@ -178,6 +180,7 @@ export const useChatActions = (): AssistantActions => { conversationId: chatContext.conversationId, interactionId, }), + query: await core.services.dataSource.getDataSourceQuery(), }); if (abortController.signal.aborted) { diff --git a/public/hooks/use_conversations.ts b/public/hooks/use_conversations.ts index 89e1a058..2828a4af 100644 --- a/public/hooks/use_conversations.ts +++ b/public/hooks/use_conversations.ts @@ -14,12 +14,13 @@ export const useDeleteConversation = () => { const abortControllerRef = useRef(); const deleteConversation = useCallback( - (conversationId: string) => { + async (conversationId: string) => { abortControllerRef.current = new AbortController(); dispatch({ type: 'request' }); return core.services.http .delete(`${ASSISTANT_API.CONVERSATION}/${conversationId}`, { signal: abortControllerRef.current.signal, + query: await core.services.dataSource.getDataSourceQuery(), }) .then((payload) => { dispatch({ type: 'success', payload }); @@ -52,7 +53,7 @@ export const usePatchConversation = () => { const abortControllerRef = useRef(); const patchConversation = useCallback( - (conversationId: string, title: string) => { + async (conversationId: string, title: string) => { abortControllerRef.current = new AbortController(); dispatch({ type: 'request' }); return core.services.http @@ -60,6 +61,7 @@ export const usePatchConversation = () => { body: JSON.stringify({ title, }), + query: await core.services.dataSource.getDataSourceQuery(), signal: abortControllerRef.current.signal, }) .then((payload) => dispatch({ type: 'success', payload })) diff --git a/public/hooks/use_feed_back.test.tsx b/public/hooks/use_feed_back.test.tsx index 8653c480..74de2031 100644 --- a/public/hooks/use_feed_back.test.tsx +++ b/public/hooks/use_feed_back.test.tsx @@ -12,11 +12,12 @@ import { httpServiceMock } from '../../../../src/core/public/mocks'; import * as chatContextHookExports from '../contexts/chat_context'; import { Interaction, IOutput, IMessage } from '../../common/types/chat_saved_object_attributes'; import { ASSISTANT_API } from '../../common/constants/llm'; +import { DataSourceServiceMock } from '../services/data_source_service.mock'; describe('useFeedback hook', () => { const httpMock = httpServiceMock.createStartContract(); const chatStateDispatchMock = jest.fn(); - + const dataSourceMock = new DataSourceServiceMock(); const chatContextMock = { rootAgentId: 'root_agent_id_mock', selectedTabId: 'chat', @@ -29,6 +30,7 @@ describe('useFeedback hook', () => { jest.spyOn(coreHookExports, 'useCore').mockReturnValue({ services: { http: httpMock, + dataSource: dataSourceMock, }, }); jest.spyOn(chatContextHookExports, 'useChatContext').mockReturnValue(chatContextMock); @@ -82,6 +84,7 @@ describe('useFeedback hook', () => { body: JSON.stringify({ satisfaction: true, }), + query: await dataSourceMock.getDataSourceQuery(), } ); expect(result.current.feedbackResult).toBe(true); @@ -116,6 +119,7 @@ describe('useFeedback hook', () => { body: JSON.stringify({ satisfaction: true, }), + query: await dataSourceMock.getDataSourceQuery(), } ); expect(result.current.feedbackResult).toBe(undefined); diff --git a/public/hooks/use_feed_back.tsx b/public/hooks/use_feed_back.tsx index 50e56380..2222b301 100644 --- a/public/hooks/use_feed_back.tsx +++ b/public/hooks/use_feed_back.tsx @@ -38,6 +38,7 @@ export const useFeedback = (interaction?: Interaction | null) => { try { await core.services.http.put(`${ASSISTANT_API.FEEDBACK}/${message.interactionId}`, { body: JSON.stringify(body), + query: await core.services.dataSource.getDataSourceQuery(), }); setFeedbackResult(correct); } catch (error) { diff --git a/public/hooks/use_fetch_agentframework_traces.test.ts b/public/hooks/use_fetch_agentframework_traces.test.ts index 4ce4c2a2..df81766c 100644 --- a/public/hooks/use_fetch_agentframework_traces.test.ts +++ b/public/hooks/use_fetch_agentframework_traces.test.ts @@ -9,11 +9,17 @@ import { createOpenSearchDashboardsReactContext } from '../../../../src/plugins/ import { coreMock } from '../../../../src/core/public/mocks'; import { HttpHandler } from '../../../../src/core/public'; import { AbortError } from '../../../../src/plugins/data/common'; +import { DataSourceServiceMock } from '../services/data_source_service.mock'; +import { waitFor } from '@testing-library/dom'; describe('useFetchAgentFrameworkTraces hook', () => { const interactionId = 'foo'; const services = coreMock.createStart(); - const { Provider } = createOpenSearchDashboardsReactContext(services); + const mockServices = { + ...services, + dataSource: new DataSourceServiceMock(), + }; + const { Provider } = createOpenSearchDashboardsReactContext(mockServices); const wrapper = { wrapper: Provider }; it('return undefined when interaction id is not specfied', () => { @@ -106,14 +112,16 @@ describe('useFetchAgentFrameworkTraces hook', () => { unmount(); }); - expect(services.http.get).toHaveBeenCalledWith( - `/api/assistant/trace/${interactionId}`, - expect.objectContaining({ - signal: expect.any(Object), - }) - ); + await waitFor(() => { + expect(services.http.get).toHaveBeenCalledWith( + `/api/assistant/trace/${interactionId}`, + expect.objectContaining({ + signal: expect.any(Object), + }) + ); - // expect the mock to be called - expect(abortFn).toBeCalledTimes(1); + // expect the mock to be called + expect(abortFn).toBeCalledTimes(1); + }); }); }); diff --git a/public/hooks/use_fetch_agentframework_traces.ts b/public/hooks/use_fetch_agentframework_traces.ts index 1edece5f..43925855 100644 --- a/public/hooks/use_fetch_agentframework_traces.ts +++ b/public/hooks/use_fetch_agentframework_traces.ts @@ -22,20 +22,23 @@ export const useFetchAgentFrameworkTraces = (interactionId: string) => { return; } - core.services.http - .get(`${ASSISTANT_API.TRACE}/${interactionId}`, { - signal: abortController.signal, - }) - .then((payload) => - dispatch({ - type: 'success', - payload, + core.services.dataSource.getDataSourceQuery().then((query) => { + core.services.http + .get(`${ASSISTANT_API.TRACE}/${interactionId}`, { + signal: abortController.signal, + query, }) - ) - .catch((error) => { - if (error.name === 'AbortError') return; - dispatch({ type: 'failure', error }); - }); + .then((payload) => + dispatch({ + type: 'success', + payload, + }) + ) + .catch((error) => { + if (error.name === 'AbortError') return; + dispatch({ type: 'failure', error }); + }); + }); return () => abortController.abort(); }, [core.services.http, interactionId]); diff --git a/public/plugin.tsx b/public/plugin.tsx index 76912891..a42449e5 100644 --- a/public/plugin.tsx +++ b/public/plugin.tsx @@ -111,8 +111,8 @@ export class AssistantPlugin ...coreStart, setupDeps, startDeps, - conversationLoad: new ConversationLoadService(coreStart.http), - conversations: new ConversationsService(coreStart.http), + conversationLoad: new ConversationLoadService(coreStart.http, this.dataSourceService), + conversations: new ConversationsService(coreStart.http, this.dataSourceService), dataSource: this.dataSourceService, }); const account = await getAccount(); diff --git a/public/services/__tests__/conversation_load_service.test.ts b/public/services/__tests__/conversation_load_service.test.ts index 81c2cf37..acf00ceb 100644 --- a/public/services/__tests__/conversation_load_service.test.ts +++ b/public/services/__tests__/conversation_load_service.test.ts @@ -3,13 +3,15 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { waitFor } from '@testing-library/dom'; import { HttpHandler } from '../../../../../src/core/public'; import { httpServiceMock } from '../../../../../src/core/public/mocks'; import { ConversationLoadService } from '../conversation_load_service'; +import { DataSourceServiceMock } from '../data_source_service.mock'; const setup = () => { const http = httpServiceMock.createSetupContract(); - const conversationLoad = new ConversationLoadService(http); + const conversationLoad = new ConversationLoadService(http, new DataSourceServiceMock()); return { conversationLoad, @@ -18,17 +20,19 @@ const setup = () => { }; describe('ConversationLoadService', () => { - it('should emit loading status and call get with specific conversation id', () => { + it('should emit loading status and call get with specific conversation id', async () => { const { conversationLoad, http } = setup(); conversationLoad.load('foo'); - expect(http.get).toHaveBeenCalledWith( - '/api/assistant/conversation/foo', - expect.objectContaining({ - signal: expect.anything(), - }) - ); expect(conversationLoad.status$.getValue()).toBe('loading'); + await waitFor(() => { + expect(http.get).toHaveBeenCalledWith( + '/api/assistant/conversation/foo', + expect.objectContaining({ + signal: expect.anything(), + }) + ); + }); }); it('should resolved with response data and "idle" status', async () => { @@ -53,6 +57,9 @@ describe('ConversationLoadService', () => { }); }) as HttpHandler); const loadResult = conversationLoad.load('foo'); + await waitFor(() => { + expect(http.get).toHaveBeenCalled(); + }); conversationLoad.abortController?.abort(); await loadResult; diff --git a/public/services/__tests__/conversations_service.test.ts b/public/services/__tests__/conversations_service.test.ts index 2ed8c75f..198a9ed5 100644 --- a/public/services/__tests__/conversations_service.test.ts +++ b/public/services/__tests__/conversations_service.test.ts @@ -3,35 +3,41 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { waitFor } from '@testing-library/dom'; import { HttpHandler } from '../../../../../src/core/public'; import { httpServiceMock } from '../../../../../src/core/public/mocks'; import { ConversationsService } from '../conversations_service'; +import { DataSourceServiceMock } from '../data_source_service.mock'; const setup = () => { const http = httpServiceMock.createSetupContract(); - const conversations = new ConversationsService(http); + const dataSourceServiceMock = new DataSourceServiceMock(); + const conversations = new ConversationsService(http, dataSourceServiceMock); return { conversations, http, + dataSource: dataSourceServiceMock, }; }; describe('ConversationsService', () => { - it('should emit loading status and call get with conversations API path', () => { + it('should emit loading status and call get with conversations API path', async () => { const { conversations, http } = setup(); conversations.load(); - expect(http.get).toHaveBeenCalledWith( - '/api/assistant/conversations', - expect.objectContaining({ - signal: expect.anything(), - }) - ); expect(conversations.status$.getValue()).toBe('loading'); + await waitFor(() => { + expect(http.get).toHaveBeenCalledWith( + '/api/assistant/conversations', + expect.objectContaining({ + signal: expect.anything(), + }) + ); + }); }); - it('should update options property and call get with passed query', () => { + it('should update options property and call get with passed query', async () => { const { conversations, http } = setup(); expect(conversations.options).toBeFalsy(); @@ -43,16 +49,19 @@ describe('ConversationsService', () => { page: 1, perPage: 10, }); - expect(http.get).toHaveBeenCalledWith( - '/api/assistant/conversations', - expect.objectContaining({ - query: { - page: 1, - perPage: 10, - }, - signal: expect.anything(), - }) - ); + await waitFor(() => { + expect(http.get).toHaveBeenCalledWith( + '/api/assistant/conversations', + expect.objectContaining({ + query: { + page: 1, + perPage: 10, + dataSourceId: '', + }, + signal: expect.anything(), + }) + ); + }); }); it('should emit latest conversations and "idle" status', async () => { @@ -78,17 +87,20 @@ describe('ConversationsService', () => { http.get.mockClear(); conversations.reload(); - expect(http.get).toHaveBeenCalledTimes(1); - expect(http.get).toHaveBeenCalledWith( - '/api/assistant/conversations', - expect.objectContaining({ - query: { - page: 1, - perPage: 10, - }, - signal: expect.anything(), - }) - ); + await waitFor(() => { + expect(http.get).toHaveBeenCalledTimes(1); + expect(http.get).toHaveBeenCalledWith( + '/api/assistant/conversations', + expect.objectContaining({ + query: { + page: 1, + perPage: 10, + dataSourceId: '', + }, + signal: expect.anything(), + }) + ); + }); }); it('should emit error after loading aborted', async () => { @@ -104,6 +116,9 @@ describe('ConversationsService', () => { }); }) as HttpHandler); const loadResult = conversations.load(); + await waitFor(() => { + expect(http.get).toHaveBeenCalled(); + }); conversations.abortController?.abort(); await loadResult; diff --git a/public/services/conversation_load_service.ts b/public/services/conversation_load_service.ts index 44b2e970..866ad6d6 100644 --- a/public/services/conversation_load_service.ts +++ b/public/services/conversation_load_service.ts @@ -7,6 +7,7 @@ import { BehaviorSubject } from 'rxjs'; import { HttpStart } from '../../../../src/core/public'; import { IConversation } from '../../common/types/chat_saved_object_attributes'; import { ASSISTANT_API } from '../../common/constants/llm'; +import { DataSourceService } from './data_source_service'; export class ConversationLoadService { status$: BehaviorSubject< @@ -14,7 +15,7 @@ export class ConversationLoadService { > = new BehaviorSubject<'idle' | 'loading' | { status: 'error'; error: Error }>('idle'); abortController?: AbortController; - constructor(private _http: HttpStart) {} + constructor(private _http: HttpStart, private _dataSource: DataSourceService) {} load = async (conversationId: string) => { this.abortController?.abort(); @@ -25,6 +26,7 @@ export class ConversationLoadService { `${ASSISTANT_API.CONVERSATION}/${conversationId}`, { signal: this.abortController.signal, + query: await this._dataSource.getDataSourceQuery(), } ); this.status$.next('idle'); diff --git a/public/services/conversations_service.ts b/public/services/conversations_service.ts index 4f95794a..ab188070 100644 --- a/public/services/conversations_service.ts +++ b/public/services/conversations_service.ts @@ -7,6 +7,7 @@ import { BehaviorSubject } from 'rxjs'; import { HttpFetchQuery, HttpStart, SavedObjectsFindOptions } from '../../../../src/core/public'; import { IConversationFindResponse } from '../../common/types/chat_saved_object_attributes'; import { ASSISTANT_API } from '../../common/constants/llm'; +import { DataSourceService } from './data_source_service'; export class ConversationsService { conversations$: BehaviorSubject = new BehaviorSubject( @@ -21,7 +22,7 @@ export class ConversationsService { >; abortController?: AbortController; - constructor(private _http: HttpStart) {} + constructor(private _http: HttpStart, private _dataSource: DataSourceService) {} public get options() { return this._options; @@ -37,7 +38,10 @@ export class ConversationsService { this.status$.next('loading'); this.conversations$.next( await this._http.get(ASSISTANT_API.CONVERSATIONS, { - query: this._options as HttpFetchQuery, + query: { + ...this._options, + ...(await this._dataSource.getDataSourceQuery()), + } as HttpFetchQuery, signal: this.abortController.signal, }) ); diff --git a/public/services/data_source_service.mock.ts b/public/services/data_source_service.mock.ts new file mode 100644 index 00000000..40fc2018 --- /dev/null +++ b/public/services/data_source_service.mock.ts @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +export class DataSourceServiceMock { + constructor() {} + + getDataSourceQuery() { + return new Promise((resolve) => { + resolve({ dataSourceId: '' }); + }); + // return { dataSourceId: '' }; + } +} diff --git a/public/tabs/history/__tests__/chat_history_page.test.tsx b/public/tabs/history/__tests__/chat_history_page.test.tsx index eb948ceb..944fedf6 100644 --- a/public/tabs/history/__tests__/chat_history_page.test.tsx +++ b/public/tabs/history/__tests__/chat_history_page.test.tsx @@ -14,6 +14,7 @@ import * as useChatStateExports from '../../../hooks/use_chat_state'; import * as chatContextExports from '../../../contexts/chat_context'; import * as coreContextExports from '../../../contexts/core_context'; import { ConversationsService } from '../../../services/conversations_service'; +import { DataSourceServiceMock } from '../../../services/data_source_service.mock'; import { ChatHistoryPage } from '../chat_history_page'; @@ -38,12 +39,14 @@ const setup = ({ http?: HttpStart; chatContext?: { flyoutFullScreen?: boolean }; } = {}) => { + const dataSourceMock = new DataSourceServiceMock(); const useCoreMock = { services: { ...coreMock.createStart(), http, - conversations: new ConversationsService(http), + conversations: new ConversationsService(http, dataSourceMock), conversationLoad: {}, + dataSource: dataSourceMock, }, }; const useChatStateMock = { @@ -91,10 +94,11 @@ describe('', () => { expect(useChatStateMock.chatStateDispatch).not.toHaveBeenCalled(); fireEvent.click(renderResult.getByTestId('confirmModalConfirmButton')); - - expect(useChatContextMock.setConversationId).toHaveBeenLastCalledWith(undefined); - expect(useChatContextMock.setTitle).toHaveBeenLastCalledWith(undefined); - expect(useChatStateMock.chatStateDispatch).toHaveBeenLastCalledWith({ type: 'reset' }); + await waitFor(() => { + expect(useChatContextMock.setConversationId).toHaveBeenLastCalledWith(undefined); + expect(useChatContextMock.setTitle).toHaveBeenLastCalledWith(undefined); + expect(useChatStateMock.chatStateDispatch).toHaveBeenLastCalledWith({ type: 'reset' }); + }); }); it('should render empty screen', async () => { diff --git a/public/tabs/history/__tests__/chat_history_search_list.test.tsx b/public/tabs/history/__tests__/chat_history_search_list.test.tsx index d84baf86..1792136e 100644 --- a/public/tabs/history/__tests__/chat_history_search_list.test.tsx +++ b/public/tabs/history/__tests__/chat_history_search_list.test.tsx @@ -12,6 +12,7 @@ import * as chatContextExports from '../../../contexts/chat_context'; import * as coreContextExports from '../../../contexts/core_context'; import { ChatHistorySearchList, ChatHistorySearchListProps } from '../chat_history_search_list'; +import { DataSourceServiceMock } from '../../../services/data_source_service.mock'; const setup = ({ loading = false, @@ -26,8 +27,9 @@ const setup = ({ conversationId: '1', setTitle: jest.fn(), }; + const dataSourceServiceMock = new DataSourceServiceMock(); const useCoreMock = { - services: coreMock.createStart(), + services: { ...coreMock.createStart(), dataSource: dataSourceServiceMock }, }; useCoreMock.services.http.put.mockImplementation(() => Promise.resolve()); useCoreMock.services.http.delete.mockImplementation(() => Promise.resolve()); diff --git a/public/tabs/history/__tests__/delete_conversation_confirm_modal.test.tsx b/public/tabs/history/__tests__/delete_conversation_confirm_modal.test.tsx index 764bd8a4..64530e70 100644 --- a/public/tabs/history/__tests__/delete_conversation_confirm_modal.test.tsx +++ b/public/tabs/history/__tests__/delete_conversation_confirm_modal.test.tsx @@ -15,10 +15,13 @@ import { DeleteConversationConfirmModalProps, } from '../delete_conversation_confirm_modal'; import { HttpHandler } from '../../../../../../src/core/public'; +import { DataSourceServiceMock } from '../../../services/data_source_service.mock'; const setup = ({ onClose, conversationId }: DeleteConversationConfirmModalProps) => { + const dataSourceServiceMock = new DataSourceServiceMock(); + const useCoreMock = { - services: coreMock.createStart(), + services: { ...coreMock.createStart(), dataSource: dataSourceServiceMock }, }; jest.spyOn(coreContextExports, 'useCore').mockReturnValue(useCoreMock); @@ -124,7 +127,9 @@ describe('', () => { expect(useCoreMock.services.http.delete).not.toHaveBeenCalled(); fireEvent.click(renderResult.getByTestId('confirmModalConfirmButton')); - expect(useCoreMock.services.http.delete).toHaveBeenCalled(); + await waitFor(() => { + expect(useCoreMock.services.http.delete).toHaveBeenCalled(); + }); fireEvent.click(renderResult.getByTestId('confirmModalCancelButton')); diff --git a/server/routes/chat_routes.test.ts b/server/routes/chat_routes.test.ts index 41f9c195..cbe03ad9 100644 --- a/server/routes/chat_routes.test.ts +++ b/server/routes/chat_routes.test.ts @@ -14,9 +14,24 @@ import { mockOllyChatService } from '../services/chat/olly_chat_service.mock'; import { loggerMock } from '../../../../src/core/server/logging/logger.mock'; import { registerChatRoutes } from './chat_routes'; import { ASSISTANT_API } from '../../common/constants/llm'; +import { getOpenSearchClientTransport } from '../utils/get_opensearch_client_transport'; -const mockedLogger = loggerMock.create(); +jest.mock('../utils/get_opensearch_client_transport'); + +beforeEach(() => { + (getOpenSearchClientTransport as jest.Mock).mockImplementation(({ dataSourceId }) => { + if (dataSourceId) { + return 'dataSource-client'; + } else { + return 'client'; + } + }); +}); +afterEach(() => { + (getOpenSearchClientTransport as jest.Mock).mockClear(); +}); +const mockedLogger = loggerMock.create(); const router = new Router( '', mockedLogger, @@ -30,38 +45,90 @@ registerChatRoutes(router, { messageParsers: [], }); -const triggerDeleteConversation = (conversationId: string) => +const triggerDeleteConversation = (conversationId: string, dataSourceId?: string) => triggerHandler(router, { method: 'delete', path: `${ASSISTANT_API.CONVERSATION}/{conversationId}`, - req: httpServerMock.createRawRequest({ params: { conversationId } }), + req: httpServerMock.createRawRequest({ + params: { conversationId }, + ...(dataSourceId + ? { + query: { + dataSourceId, + }, + } + : {}), + }), }); const triggerUpdateConversation = ( params: { conversationId: string }, - payload: { title: string } + payload: { title: string }, + dataSourceId?: string ) => triggerHandler(router, { method: 'put', path: `${ASSISTANT_API.CONVERSATION}/{conversationId}`, - req: httpServerMock.createRawRequest({ params, payload }), + req: httpServerMock.createRawRequest({ + params, + payload, + ...(dataSourceId + ? { + query: { + dataSourceId, + }, + } + : {}), + }), }); -const triggerGetTrace = (interactionId: string) => +const triggerGetTrace = (interactionId: string, dataSourceId?: string) => triggerHandler(router, { method: 'get', path: `${ASSISTANT_API.TRACE}/{interactionId}`, - req: httpServerMock.createRawRequest({ params: { interactionId } }), + req: httpServerMock.createRawRequest({ + params: { interactionId }, + ...(dataSourceId + ? { + query: { + dataSourceId, + }, + } + : {}), + }), }); -const triggerAbortAgentExecution = (conversationId: string) => +const triggerAbortAgentExecution = (conversationId: string, dataSourceId?: string) => triggerHandler(router, { method: 'post', path: ASSISTANT_API.ABORT_AGENT_EXECUTION, - req: httpServerMock.createRawRequest({ payload: { conversationId } }), + req: httpServerMock.createRawRequest({ + payload: { conversationId }, + ...(dataSourceId + ? { + query: { + dataSourceId, + }, + } + : {}), + }), }); -const triggerFeedback = (params: { interactionId: string }, payload: { satisfaction: boolean }) => +const triggerFeedback = ( + params: { interactionId: string }, + payload: { satisfaction: boolean }, + dataSourceId?: string +) => triggerHandler(router, { method: 'put', path: `${ASSISTANT_API.FEEDBACK}/{interactionId}`, - req: httpServerMock.createRawRequest({ params, payload }), + req: httpServerMock.createRawRequest({ + params, + payload, + ...(dataSourceId + ? { + query: { + dataSourceId, + }, + } + : {}), + }), }); describe('chat routes', () => { @@ -80,6 +147,22 @@ describe('chat routes', () => { expect(mockAgentFrameworkStorageService.deleteConversation).not.toHaveBeenCalled(); const result = (await triggerDeleteConversation('foo')) as ResponseObject; + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('client'); + expect(mockAgentFrameworkStorageService.deleteConversation).toHaveBeenCalledWith('foo'); + expect(result.source).toMatchInlineSnapshot(` + Object { + "success": true, + } + `); + }); + + it('should call delete conversation with passed data source id and get data source transport', async () => { + mockAgentFrameworkStorageService.deleteConversation.mockResolvedValueOnce({ + success: true, + }); + expect(mockAgentFrameworkStorageService.deleteConversation).not.toHaveBeenCalled(); + const result = (await triggerDeleteConversation('foo', 'data_source_id')) as ResponseObject; + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('dataSource-client'); expect(mockAgentFrameworkStorageService.deleteConversation).toHaveBeenCalledWith('foo'); expect(result.source).toMatchInlineSnapshot(` Object { @@ -109,6 +192,7 @@ describe('chat routes', () => { { conversationId: 'foo' }, { title: 'new-title' } )) as ResponseObject; + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('client'); expect(mockAgentFrameworkStorageService.updateConversation).toHaveBeenCalledWith( 'foo', 'new-title' @@ -120,6 +204,29 @@ describe('chat routes', () => { `); }); + it('should call update conversation with passed data source id and title then get data source transport', async () => { + mockAgentFrameworkStorageService.updateConversation.mockResolvedValueOnce({ + success: true, + }); + + expect(mockAgentFrameworkStorageService.updateConversation).not.toHaveBeenCalled(); + const result = (await triggerUpdateConversation( + { conversationId: 'foo' }, + { title: 'new-title' }, + 'data_source_id' + )) as ResponseObject; + expect(mockAgentFrameworkStorageService.updateConversation).toHaveBeenCalledWith( + 'foo', + 'new-title' + ); + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('dataSource-client'); + expect(result.source).toMatchInlineSnapshot(` + Object { + "success": true, + } + `); + }); + it('should log error and return 500 error when failed to update conversation', async () => { mockAgentFrameworkStorageService.updateConversation.mockRejectedValueOnce(new Error()); @@ -149,6 +256,27 @@ describe('chat routes', () => { expect(mockAgentFrameworkStorageService.getTraces).not.toHaveBeenCalled(); const result = (await triggerGetTrace('interaction-1')) as ResponseObject; + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('client'); + expect(mockAgentFrameworkStorageService.getTraces).toHaveBeenCalledWith('interaction-1'); + expect(result.source).toEqual(getTraceResultMock); + }); + + it('should call get traces with passed data source id and get data source transport', async () => { + const getTraceResultMock = [ + { + interactionId: 'interaction-1', + createTime: '', + input: 'foo', + output: 'bar', + origin: '', + traceNumber: 0, + }, + ]; + mockAgentFrameworkStorageService.getTraces.mockResolvedValueOnce(getTraceResultMock); + + expect(mockAgentFrameworkStorageService.getTraces).not.toHaveBeenCalled(); + const result = (await triggerGetTrace('interaction-1', 'data_source_id')) as ResponseObject; + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('dataSource-client'); expect(mockAgentFrameworkStorageService.getTraces).toHaveBeenCalledWith('interaction-1'); expect(result.source).toEqual(getTraceResultMock); }); @@ -169,6 +297,16 @@ describe('chat routes', () => { await triggerAbortAgentExecution('foo'); expect(mockOllyChatService.abortAgentExecution).toHaveBeenCalledWith('foo'); + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('client'); + expect(mockedLogger.info).toHaveBeenCalledWith('Abort agent execution: foo'); + }); + + it('should call get abort agent with passed data source id and get data source transport ', async () => { + expect(mockOllyChatService.abortAgentExecution).not.toHaveBeenCalled(); + + await triggerAbortAgentExecution('foo', 'data_source_id'); + expect(mockOllyChatService.abortAgentExecution).toHaveBeenCalledWith('foo'); + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('dataSource-client'); expect(mockedLogger.info).toHaveBeenCalledWith('Abort agent execution: foo'); }); @@ -200,6 +338,31 @@ describe('chat routes', () => { { interactionId: 'foo' }, { satisfaction: true } )) as ResponseObject; + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('client'); + expect(mockAgentFrameworkStorageService.updateInteraction).toHaveBeenCalledWith('foo', { + feedback: { + satisfaction: true, + }, + }); + expect(result.source).toMatchInlineSnapshot(` + Object { + "success": true, + } + `); + }); + + it('should call update interaction with passed data source id and get data source transport', async () => { + mockAgentFrameworkStorageService.updateConversation.mockResolvedValueOnce({ + success: true, + }); + + expect(mockAgentFrameworkStorageService.updateConversation).not.toHaveBeenCalled(); + const result = (await triggerFeedback( + { interactionId: 'foo' }, + { satisfaction: true }, + 'data_source_id' + )) as ResponseObject; + expect(getOpenSearchClientTransport.mock.results[0].value).toBe('dataSource-client'); expect(mockAgentFrameworkStorageService.updateInteraction).toHaveBeenCalledWith('foo', { feedback: { satisfaction: true, diff --git a/server/routes/chat_routes.ts b/server/routes/chat_routes.ts index 0a2b7b51..e0809be5 100644 --- a/server/routes/chat_routes.ts +++ b/server/routes/chat_routes.ts @@ -17,6 +17,7 @@ import { OllyChatService } from '../services/chat/olly_chat_service'; import { AgentFrameworkStorageService } from '../services/storage/agent_framework_storage_service'; import { RoutesOptions } from '../types'; import { ChatService } from '../services/chat/chat_service'; +import { getOpenSearchClientTransport } from '../utils/get_opensearch_client_transport'; const llmRequestRoute = { path: ASSISTANT_API.SEND_MESSAGE, @@ -33,6 +34,9 @@ const llmRequestRoute = { contentType: schema.literal('text'), }), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; export type LLMRequestSchema = TypeOf; @@ -43,6 +47,9 @@ const getConversationRoute = { params: schema.object({ conversationId: schema.string(), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; export type GetConversationSchema = TypeOf; @@ -53,6 +60,9 @@ const abortAgentExecutionRoute = { body: schema.object({ conversationId: schema.string(), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; export type AbortAgentExecutionSchema = TypeOf; @@ -64,6 +74,9 @@ const regenerateRoute = { conversationId: schema.string(), interactionId: schema.string(), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; export type RegenerateSchema = TypeOf; @@ -79,6 +92,7 @@ const getConversationsRoute = { fields: schema.maybe(schema.arrayOf(schema.string())), search: schema.maybe(schema.string()), searchFields: schema.maybe(schema.oneOf([schema.string(), schema.arrayOf(schema.string())])), + dataSourceId: schema.maybe(schema.string()), }), }, }; @@ -90,6 +104,9 @@ const deleteConversationRoute = { params: schema.object({ conversationId: schema.string(), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; @@ -102,6 +119,9 @@ const updateConversationRoute = { body: schema.object({ title: schema.string(), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; @@ -111,6 +131,9 @@ const getTracesRoute = { params: schema.object({ interactionId: schema.string(), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; @@ -123,16 +146,20 @@ const feedbackRoute = { body: schema.object({ satisfaction: schema.boolean(), }), + query: schema.object({ + dataSourceId: schema.maybe(schema.string()), + }), }, }; export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) { - const createStorageService = (context: RequestHandlerContext) => + const createStorageService = async (context: RequestHandlerContext, dataSourceId?: string) => new AgentFrameworkStorageService( - context.core.opensearch.client.asCurrentUser, + await getOpenSearchClientTransport({ context, dataSourceId }), routeOptions.messageParsers ); - const createChatService = (context: RequestHandlerContext) => new OllyChatService(context); + const createChatService = async (context: RequestHandlerContext, dataSourceId?: string) => + new OllyChatService(await getOpenSearchClientTransport({ context, dataSourceId })); router.post( llmRequestRoute, @@ -142,8 +169,8 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) response ): Promise> => { const { messages = [], input, conversationId: conversationIdInRequestBody } = request.body; - const storageService = createStorageService(context); - const chatService = createChatService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); + const chatService = await createChatService(context, request.query.dataSourceId); let outputs: Awaited> | undefined; @@ -217,7 +244,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const storageService = createStorageService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); try { const getResponse = await storageService.getConversation(request.params.conversationId); @@ -236,7 +263,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const storageService = createStorageService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); try { const getResponse = await storageService.getConversations(request.query); @@ -255,7 +282,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const storageService = createStorageService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); try { const getResponse = await storageService.deleteConversation(request.params.conversationId); @@ -274,7 +301,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const storageService = createStorageService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); try { const getResponse = await storageService.updateConversation( @@ -296,7 +323,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const storageService = createStorageService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); try { const getResponse = await storageService.getTraces(request.params.interactionId); @@ -315,7 +342,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const chatService = createChatService(context, ''); + const chatService = await createChatService(context, request.query.dataSourceId); try { chatService.abortAgentExecution(request.body.conversationId); context.assistant_plugin.logger.info( @@ -337,8 +364,8 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) response ): Promise> => { const { conversationId, interactionId } = request.body; - const storageService = createStorageService(context); - const chatService = createChatService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); + const chatService = await createChatService(context, request.query.dataSourceId); let outputs: Awaited> | undefined; @@ -386,7 +413,7 @@ export function registerChatRoutes(router: IRouter, routeOptions: RoutesOptions) request, response ): Promise> => { - const storageService = createStorageService(context); + const storageService = await createStorageService(context, request.query.dataSourceId); const { interactionId } = request.params; try { diff --git a/server/services/chat/olly_chat_service.test.ts b/server/services/chat/olly_chat_service.test.ts index 936f7fc5..49362dcb 100644 --- a/server/services/chat/olly_chat_service.test.ts +++ b/server/services/chat/olly_chat_service.test.ts @@ -6,30 +6,23 @@ import { OllyChatService } from './olly_chat_service'; import { CoreRouteHandlerContext } from '../../../../../src/core/server/core_route_handler_context'; import { coreMock, httpServerMock } from '../../../../../src/core/server/mocks'; -import { loggerMock } from '../../../../../src/core/server/logging/logger.mock'; -import { ResponseError } from '@opensearch-project/opensearch/lib/errors'; -import { ApiResponse } from '@opensearch-project/opensearch'; describe('OllyChatService', () => { const coreContext = new CoreRouteHandlerContext( coreMock.createInternalStart(), httpServerMock.createOpenSearchDashboardsRequest() ); - const mockedTransport = coreContext.opensearch.client.asCurrentUser.transport - .request as jest.Mock; - const contextMock = { - core: coreContext, - assistant_plugin: { - logger: loggerMock.create(), - }, - }; - const ollyChatService: OllyChatService = new OllyChatService(contextMock); + const mockedTransport = coreContext.opensearch.client.asCurrentUser.transport; + const mockedTransportRequest = mockedTransport.request as jest.Mock; + + const ollyChatService: OllyChatService = new OllyChatService(mockedTransport); + beforeEach(async () => { - mockedTransport.mockClear(); + mockedTransportRequest.mockClear(); }); it('requestLLM should invoke client call with correct params', async () => { - mockedTransport.mockImplementation((args) => { + mockedTransportRequest.mockImplementation((args) => { if (args.path === '/_plugins/_ml/config/os_chat') { return { body: { @@ -65,7 +58,7 @@ describe('OllyChatService', () => { }, conversationId: 'conversationId', }); - expect(mockedTransport.mock.calls).toMatchInlineSnapshot(` + expect(mockedTransportRequest.mock.calls).toMatchInlineSnapshot(` Array [ Array [ Object { @@ -102,7 +95,7 @@ describe('OllyChatService', () => { }); it('requestLLM should throw error when transport.request throws error', async () => { - mockedTransport + mockedTransportRequest .mockImplementationOnce(() => { return { body: { @@ -124,13 +117,14 @@ describe('OllyChatService', () => { contentType: 'text', content: 'content', }, + conversationId: '', }) ).rejects.toMatchInlineSnapshot(`[Error: error]`); }); it('regenerate should invoke client call with correct params', async () => { - mockedTransport.mockImplementation((args) => { + mockedTransportRequest.mockImplementation((args) => { if (args.path === '/_plugins/_ml/config/os_chat') { return { body: { @@ -161,7 +155,7 @@ describe('OllyChatService', () => { conversationId: 'conversationId', interactionId: 'interactionId', }); - expect(mockedTransport.mock.calls).toMatchInlineSnapshot(` + expect(mockedTransportRequest.mock.calls).toMatchInlineSnapshot(` Array [ Array [ Object { @@ -198,7 +192,7 @@ describe('OllyChatService', () => { }); it('regenerate should throw error when transport.request throws error', async () => { - mockedTransport + mockedTransportRequest .mockImplementationOnce(() => { return { body: { @@ -221,7 +215,7 @@ describe('OllyChatService', () => { }); it('fetching root agent id throws error', async () => { - mockedTransport.mockImplementationOnce(() => { + mockedTransportRequest.mockImplementationOnce(() => { return { body: { hits: { diff --git a/server/services/chat/olly_chat_service.ts b/server/services/chat/olly_chat_service.ts index cb239079..bdebec5a 100644 --- a/server/services/chat/olly_chat_service.ts +++ b/server/services/chat/olly_chat_service.ts @@ -4,7 +4,7 @@ */ import { ApiResponse } from '@opensearch-project/opensearch'; -import { RequestHandlerContext } from '../../../../../src/core/server'; +import { OpenSearchClient } from '../../../../../src/core/server'; import { IMessage, IInput } from '../../../common/types/chat_saved_object_attributes'; import { ChatService } from './chat_service'; import { ML_COMMONS_BASE_API, ROOT_AGENT_CONFIG_ID } from '../../utils/constants'; @@ -22,13 +22,12 @@ const INTERACTION_ID_FIELD = 'parent_interaction_id'; export class OllyChatService implements ChatService { static abortControllers: Map = new Map(); - constructor(private readonly context: RequestHandlerContext) {} + constructor(private readonly opensearchClientTransport: OpenSearchClient['transport']) {} private async getRootAgent(): Promise { try { - const opensearchClient = this.context.core.opensearch.client.asCurrentUser; const path = `${ML_COMMONS_BASE_API}/config/${ROOT_AGENT_CONFIG_ID}`; - const response = await opensearchClient.transport.request({ + const response = await this.opensearchClientTransport.request({ method: 'GET', path, }); @@ -53,9 +52,8 @@ export class OllyChatService implements ChatService { } private async callExecuteAgentAPI(payload: AgentRunPayload, rootAgentId: string) { - const opensearchClient = this.context.core.opensearch.client.asCurrentUser; try { - const agentFrameworkResponse = (await opensearchClient.transport.request( + const agentFrameworkResponse = (await this.opensearchClientTransport.request( { method: 'POST', path: `${ML_COMMONS_BASE_API}/agents/${rootAgentId}/_execute`, diff --git a/server/services/storage/agent_framework_storage_service.test.ts b/server/services/storage/agent_framework_storage_service.test.ts index e1b5781a..aa1a77b0 100644 --- a/server/services/storage/agent_framework_storage_service.test.ts +++ b/server/services/storage/agent_framework_storage_service.test.ts @@ -12,17 +12,15 @@ describe('AgentFrameworkStorageService', () => { coreMock.createInternalStart(), httpServerMock.createOpenSearchDashboardsRequest() ); - const mockedTransport = coreContext.opensearch.client.asCurrentUser.transport - .request as jest.Mock; - const agentFrameworkService = new AgentFrameworkStorageService( - coreContext.opensearch.client.asCurrentUser, - [] - ); + const mockedTransport = coreContext.opensearch.client.asCurrentUser.transport; + const mockedTransportRequest = mockedTransport.request as jest.Mock; + + const agentFrameworkService = new AgentFrameworkStorageService(mockedTransport, []); beforeEach(() => { - mockedTransport.mockReset(); + mockedTransportRequest.mockReset(); }); it('getConversation', async () => { - mockedTransport.mockImplementation(async (params) => { + mockedTransportRequest.mockImplementation(async (params) => { if (params.path.includes('/messages?max_results=1000')) { return { body: { @@ -62,7 +60,7 @@ describe('AgentFrameworkStorageService', () => { "updatedTimeMs": 0, } `); - expect(mockedTransport.mock.calls).toMatchInlineSnapshot(` + expect(mockedTransportRequest.mock.calls).toMatchInlineSnapshot(` Array [ Array [ Object { @@ -81,14 +79,14 @@ describe('AgentFrameworkStorageService', () => { }); it('should encode id when calls getConversation with non-standard params in request payload', async () => { - mockedTransport.mockResolvedValue({ + mockedTransportRequest.mockResolvedValue({ body: { messages: [], }, }); await agentFrameworkService.getConversation('../non-standard/id'); - expect(mockedTransport.mock.calls[0]).toMatchInlineSnapshot(` + expect(mockedTransportRequest.mock.calls[0]).toMatchInlineSnapshot(` Array [ Object { "method": "GET", @@ -99,7 +97,7 @@ describe('AgentFrameworkStorageService', () => { }); it('getConversations', async () => { - mockedTransport.mockImplementation(async (params) => { + mockedTransportRequest.mockImplementation(async (params) => { return { body: { hits: { @@ -165,7 +163,7 @@ describe('AgentFrameworkStorageService', () => { "total": 10, } `); - expect(mockedTransport.mock.calls).toMatchInlineSnapshot(` + expect(mockedTransportRequest.mock.calls).toMatchInlineSnapshot(` Array [ Array [ Object { @@ -214,7 +212,7 @@ describe('AgentFrameworkStorageService', () => { }); it('deleteConversation', async () => { - mockedTransport.mockImplementationOnce(async (params) => ({ + mockedTransportRequest.mockImplementationOnce(async (params) => ({ statusCode: 200, })); expect(agentFrameworkService.deleteConversation('foo')).resolves.toMatchInlineSnapshot(` @@ -222,7 +220,7 @@ describe('AgentFrameworkStorageService', () => { "success": true, } `); - mockedTransport.mockImplementationOnce(async (params) => ({ + mockedTransportRequest.mockImplementationOnce(async (params) => ({ statusCode: 404, body: { message: 'can not find conversation', @@ -235,18 +233,18 @@ describe('AgentFrameworkStorageService', () => { "success": false, } `); - mockedTransport.mockImplementationOnce(async (params) => { + mockedTransportRequest.mockImplementationOnce(async (params) => { return Promise.reject({ meta: { body: 'error' } }); }); expect(agentFrameworkService.deleteConversation('foo')).rejects.toBeDefined(); }); it('should encode id when calls deleteConversation with non-standard params in request payload', async () => { - mockedTransport.mockResolvedValueOnce({ + mockedTransportRequest.mockResolvedValueOnce({ statusCode: 200, }); await agentFrameworkService.deleteConversation('../non-standard/id'); - expect(mockedTransport.mock.calls[0]).toMatchInlineSnapshot(` + expect(mockedTransportRequest.mock.calls[0]).toMatchInlineSnapshot(` Array [ Object { "method": "DELETE", @@ -257,7 +255,7 @@ describe('AgentFrameworkStorageService', () => { }); it('updateConversation', async () => { - mockedTransport.mockImplementationOnce(async (params) => ({ + mockedTransportRequest.mockImplementationOnce(async (params) => ({ statusCode: 200, })); expect(agentFrameworkService.updateConversation('foo', 'title')).resolves @@ -266,7 +264,7 @@ describe('AgentFrameworkStorageService', () => { "success": true, } `); - mockedTransport.mockImplementationOnce(async (params) => ({ + mockedTransportRequest.mockImplementationOnce(async (params) => ({ statusCode: 404, body: { message: 'can not find conversation', @@ -280,18 +278,18 @@ describe('AgentFrameworkStorageService', () => { "success": false, } `); - mockedTransport.mockImplementationOnce(async (params) => { + mockedTransportRequest.mockImplementationOnce(async (params) => { return Promise.reject({ meta: { body: 'error' } }); }); expect(agentFrameworkService.updateConversation('foo', 'title')).rejects.toBeDefined(); }); it('should encode id when calls updateConversation with non-standard params in request payload', async () => { - mockedTransport.mockResolvedValueOnce({ + mockedTransportRequest.mockResolvedValueOnce({ statusCode: 200, }); await agentFrameworkService.updateConversation('../non-standard/id', 'title'); - expect(mockedTransport.mock.calls[0]).toMatchInlineSnapshot(` + expect(mockedTransportRequest.mock.calls[0]).toMatchInlineSnapshot(` Array [ Object { "body": Object { @@ -305,7 +303,7 @@ describe('AgentFrameworkStorageService', () => { }); it('getTraces', async () => { - mockedTransport.mockImplementationOnce(async (params) => ({ + mockedTransportRequest.mockImplementationOnce(async (params) => ({ body: { traces: [ { @@ -331,7 +329,7 @@ describe('AgentFrameworkStorageService', () => { }, ] `); - mockedTransport.mockImplementationOnce(async (params) => { + mockedTransportRequest.mockImplementationOnce(async (params) => { return Promise.reject({ meta: { body: 'error' } }); }); expect(agentFrameworkService.getTraces('foo')).rejects.toMatchInlineSnapshot(` @@ -344,7 +342,7 @@ describe('AgentFrameworkStorageService', () => { }); it('should encode id when calls getTraces with non-standard params in request payload', async () => { - mockedTransport.mockResolvedValueOnce({ + mockedTransportRequest.mockResolvedValueOnce({ body: { traces: [ { @@ -359,7 +357,7 @@ describe('AgentFrameworkStorageService', () => { }, }); await agentFrameworkService.getTraces('../non-standard/id'); - expect(mockedTransport.mock.calls[0]).toMatchInlineSnapshot(` + expect(mockedTransportRequest.mock.calls[0]).toMatchInlineSnapshot(` Array [ Object { "method": "GET", @@ -370,7 +368,7 @@ describe('AgentFrameworkStorageService', () => { }); it('updateInteraction', async () => { - mockedTransport.mockImplementationOnce(async (params) => ({ + mockedTransportRequest.mockImplementationOnce(async (params) => ({ statusCode: 200, })); expect( @@ -384,7 +382,7 @@ describe('AgentFrameworkStorageService', () => { "success": true, } `); - mockedTransport.mockImplementationOnce(async (params) => ({ + mockedTransportRequest.mockImplementationOnce(async (params) => ({ statusCode: 404, body: { message: 'can not find conversation', @@ -403,7 +401,7 @@ describe('AgentFrameworkStorageService', () => { "success": false, } `); - mockedTransport.mockImplementationOnce(async (params) => { + mockedTransportRequest.mockImplementationOnce(async (params) => { return Promise.reject({ meta: { body: 'error' } }); }); expect( @@ -422,7 +420,7 @@ describe('AgentFrameworkStorageService', () => { }); it('should encode id when calls updateInteraction with non-standard params in request payload', async () => { - mockedTransport.mockResolvedValueOnce({ + mockedTransportRequest.mockResolvedValueOnce({ statusCode: 200, }); await agentFrameworkService.updateInteraction('../non-standard/id', { @@ -430,7 +428,7 @@ describe('AgentFrameworkStorageService', () => { bar: 'foo', }, }); - expect(mockedTransport.mock.calls[0]).toMatchInlineSnapshot(` + expect(mockedTransportRequest.mock.calls[0]).toMatchInlineSnapshot(` Array [ Object { "body": Object { @@ -448,7 +446,7 @@ describe('AgentFrameworkStorageService', () => { }); it('getInteraction', async () => { - mockedTransport.mockImplementation(async (params) => ({ + mockedTransportRequest.mockImplementation(async (params) => ({ body: { input: 'input', response: 'response', @@ -460,7 +458,7 @@ describe('AgentFrameworkStorageService', () => { expect(agentFrameworkService.getInteraction('_id', '')).rejects.toMatchInlineSnapshot( `[Error: interactionId is required]` ); - expect(mockedTransport).toBeCalledTimes(0); + expect(mockedTransportRequest).toBeCalledTimes(0); expect(agentFrameworkService.getInteraction('_id', 'interaction_id')).resolves .toMatchInlineSnapshot(` Object { @@ -470,18 +468,18 @@ describe('AgentFrameworkStorageService', () => { "response": "response", } `); - expect(mockedTransport).toBeCalledTimes(1); + expect(mockedTransportRequest).toBeCalledTimes(1); }); it('should encode id when calls getInteraction with non-standard params in request payload', async () => { - mockedTransport.mockResolvedValueOnce({ + mockedTransportRequest.mockResolvedValueOnce({ body: { input: 'input', response: 'response', }, }); await agentFrameworkService.getInteraction('_id', '../non-standard/id'); - expect(mockedTransport.mock.calls[0]).toMatchInlineSnapshot(` + expect(mockedTransportRequest.mock.calls[0]).toMatchInlineSnapshot(` Array [ Object { "method": "GET", diff --git a/server/services/storage/agent_framework_storage_service.ts b/server/services/storage/agent_framework_storage_service.ts index f0f1a735..05f57b19 100644 --- a/server/services/storage/agent_framework_storage_service.ts +++ b/server/services/storage/agent_framework_storage_service.ts @@ -28,12 +28,12 @@ export interface ConversationOptResponse { export class AgentFrameworkStorageService implements StorageService { constructor( - private readonly client: OpenSearchClient, + private readonly clientTransport: OpenSearchClient['transport'], private readonly messageParsers: MessageParser[] = [] ) {} async getConversation(conversationId: string): Promise { const [interactionsResp, conversation] = await Promise.all([ - this.client.transport.request({ + this.clientTransport.request({ method: 'GET', path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent( conversationId @@ -43,7 +43,7 @@ export class AgentFrameworkStorageService implements StorageService { messages: InteractionFromAgentFramework[]; }> >, - this.client.transport.request({ + this.clientTransport.request({ method: 'GET', path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent(conversationId)}`, }) as TransportRequestPromise< @@ -103,7 +103,7 @@ export class AgentFrameworkStorageService implements StorageService { ...(sortField && query.sortOrder && { sort: [{ [sortField]: query.sortOrder }] }), }; - const conversations = await this.client.transport.request({ + const conversations = await this.clientTransport.request({ method: 'GET', path: `${ML_COMMONS_BASE_API}/memory/_search`, body: requestParams, @@ -146,7 +146,7 @@ export class AgentFrameworkStorageService implements StorageService { } async deleteConversation(conversationId: string): Promise { - const response = await this.client.transport.request({ + const response = await this.clientTransport.request({ method: 'DELETE', path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent(conversationId)}`, }); @@ -167,7 +167,7 @@ export class AgentFrameworkStorageService implements StorageService { conversationId: string, title: string ): Promise { - const response = await this.client.transport.request({ + const response = await this.clientTransport.request({ method: 'PUT', path: `${ML_COMMONS_BASE_API}/memory/${encodeURIComponent(conversationId)}`, body: { @@ -188,7 +188,7 @@ export class AgentFrameworkStorageService implements StorageService { } async getTraces(interactionId: string): Promise { - const response = (await this.client.transport.request({ + const response = (await this.clientTransport.request({ method: 'GET', path: `${ML_COMMONS_BASE_API}/memory/message/${encodeURIComponent(interactionId)}/traces`, })) as ApiResponse<{ @@ -216,7 +216,7 @@ export class AgentFrameworkStorageService implements StorageService { interactionId: string, additionalInfo: Record> ): Promise { - const response = await this.client.transport.request({ + const response = await this.clientTransport.request({ method: 'PUT', path: `${ML_COMMONS_BASE_API}/memory/message/${encodeURIComponent(interactionId)}`, body: { @@ -260,7 +260,7 @@ export class AgentFrameworkStorageService implements StorageService { if (!interactionId) { throw new Error('interactionId is required'); } - const interactionsResp = (await this.client.transport.request({ + const interactionsResp = (await this.clientTransport.request({ method: 'GET', path: `${ML_COMMONS_BASE_API}/memory/message/${encodeURIComponent(interactionId)}`, })) as ApiResponse; diff --git a/server/utils/get_opensearch_client_transport.test.ts b/server/utils/get_opensearch_client_transport.test.ts new file mode 100644 index 00000000..4b37577a --- /dev/null +++ b/server/utils/get_opensearch_client_transport.test.ts @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { getOpenSearchClientTransport } from './get_opensearch_client_transport'; +import { coreMock } from '../../../../src/core/server/mocks'; +import { loggerMock } from '../../../../src/core/server/logging/logger.mock'; + +const mockedLogger = loggerMock.create(); + +describe('getOpenSearchClientTransport', () => { + it('should return current user opensearch transport', async () => { + const core = coreMock.createRequestHandlerContext(); + + expect( + await getOpenSearchClientTransport({ + context: { core, assistant_plugin: { logger: mockedLogger } }, + }) + ).toBe(core.opensearch.client.asCurrentUser.transport); + }); + it('should return data source id related opensearch transport', async () => { + const transportMock = {}; + const core = coreMock.createRequestHandlerContext(); + const context = { + core, + assistant_plugin: { logger: mockedLogger }, + dataSource: { + opensearch: { + getClient: async (_dataSourceId: string) => ({ + transport: transportMock, + }), + }, + }, + }; + + expect( + await getOpenSearchClientTransport({ + context, + dataSourceId: 'foo', + }) + ).toBe(transportMock); + }); +}); diff --git a/server/utils/get_opensearch_client_transport.ts b/server/utils/get_opensearch_client_transport.ts new file mode 100644 index 00000000..9f54a4ff --- /dev/null +++ b/server/utils/get_opensearch_client_transport.ts @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +import { OpenSearchClient, RequestHandlerContext } from '../../../../src/core/server'; + +export const getOpenSearchClientTransport = async ({ + context, + dataSourceId, +}: { + context: RequestHandlerContext & { + dataSource?: { + opensearch: { + getClient: (dataSourceId: string) => Promise; + }; + }; + }; + dataSourceId?: string; +}) => { + if (dataSourceId && context.dataSource) { + return (await context.dataSource.opensearch.getClient(dataSourceId)).transport; + } + return context.core.opensearch.client.asCurrentUser.transport; +};