diff --git a/x-pack/plugins/observability_ai_assistant/kibana.jsonc b/x-pack/plugins/observability_ai_assistant/kibana.jsonc index 1de47f47ac3867..38c5e81ab0f2b4 100644 --- a/x-pack/plugins/observability_ai_assistant/kibana.jsonc +++ b/x-pack/plugins/observability_ai_assistant/kibana.jsonc @@ -8,7 +8,7 @@ "browser": true, "configPath": ["xpack", "observabilityAIAssistant"], "requiredPlugins": ["triggersActionsUi", "actions", "security", "observabilityShared"], - "requiredBundles": ["kibanaReact"], + "requiredBundles": ["kibanaReact", "kibanaUtils"], "optionalPlugins": [], "extraPublicDirs": [] } diff --git a/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx b/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx index f474c20fd20c4c..4f9442efeaa708 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx @@ -14,6 +14,8 @@ import { MessagePanel } from '../message_panel/message_panel'; import { MessageText } from '../message_panel/message_text'; import { InsightBase } from './insight_base'; import { InsightMissingCredentials } from './insight_missing_credentials'; +import { StopGeneratingButton } from '../stop_generating_button'; +import { RegenerateResponseButton } from '../regenerate_response_button'; function ChatContent({ messages, connectorId }: { messages: Message[]; connectorId: string }) { const chat = useChat({ messages, connectorId }); @@ -22,7 +24,21 @@ function ChatContent({ messages, connectorId }: { messages: Message[]; connector } error={chat.error} - controls={null} + controls={ + chat.loading ? ( + { + chat.abort(); + }} + /> + ) : ( + { + chat.regenerate(); + }} + /> + ) + } /> ); } diff --git a/x-pack/plugins/observability_ai_assistant/public/components/regenerate_response_button.tsx b/x-pack/plugins/observability_ai_assistant/public/components/regenerate_response_button.tsx index 922f3c34a302b8..b0439c854e2ca8 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/regenerate_response_button.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/regenerate_response_button.tsx @@ -11,7 +11,7 @@ import { i18n } from '@kbn/i18n'; export function RegenerateResponseButton(props: Partial) { return ( - + {i18n.translate('xpack.observabilityAiAssistant.regenerateResponseButtonLabel', { defaultMessage: 'Regenerate', })} diff --git a/x-pack/plugins/observability_ai_assistant/public/components/stop_generating_button.stories.tsx b/x-pack/plugins/observability_ai_assistant/public/components/stop_generating_button.stories.tsx new file mode 100644 index 00000000000000..acf27b4a012748 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/public/components/stop_generating_button.stories.tsx @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import { ComponentMeta, ComponentStoryObj } from '@storybook/react'; +import { StopGeneratingButton as Component } from './stop_generating_button'; + +const meta: ComponentMeta = { + component: Component, + title: 'app/Atoms/StopGeneratingButton', +}; + +export default meta; + +export const StopGeneratingButton: ComponentStoryObj = { + args: {}, +}; diff --git a/x-pack/plugins/observability_ai_assistant/public/components/stop_generating_button.tsx b/x-pack/plugins/observability_ai_assistant/public/components/stop_generating_button.tsx new file mode 100644 index 00000000000000..3008ee884ae645 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/public/components/stop_generating_button.tsx @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { EuiButtonEmpty, EuiButtonEmptyProps } from '@elastic/eui'; +import React from 'react'; +import { i18n } from '@kbn/i18n'; + +export function StopGeneratingButton(props: Partial) { + return ( + + {i18n.translate('xpack.observabilityAiAssistant.stopGeneratingButtonLabel', { + defaultMessage: 'Stop generating', + })} + + ); +} diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.test.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.test.ts new file mode 100644 index 00000000000000..0e8c51ce55f733 --- /dev/null +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.test.ts @@ -0,0 +1,264 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { useKibana } from '@kbn/kibana-react-plugin/public'; +import { act, renderHook } from '@testing-library/react-hooks'; +import { ChatCompletionResponseMessage } from 'openai'; +import { Observable } from 'rxjs'; +import { AbortError } from '@kbn/kibana-utils-plugin/common'; +import { ObservabilityAIAssistantService } from '../types'; +import { useChat } from './use_chat'; +import { useObservabilityAIAssistant } from './use_observability_ai_assistant'; + +jest.mock('@kbn/kibana-react-plugin/public'); +jest.mock('./use_observability_ai_assistant'); + +const WAIT_OPTIONS = { timeout: 5000 }; + +const mockUseKibana = useKibana as jest.MockedFunction; +const mockUseObservabilityAIAssistant = useObservabilityAIAssistant as jest.MockedFunction< + typeof useObservabilityAIAssistant +>; + +function mockDeltas(deltas: Array>) { + return mockResponse( + Promise.resolve( + new Observable((subscriber) => { + async function simulateDelays() { + for (const delta of deltas) { + await new Promise((resolve) => { + setTimeout(() => { + subscriber.next({ + choices: [ + { + role: 'assistant', + delta, + }, + ], + }); + resolve(); + }, 100); + }); + } + subscriber.complete(); + } + + simulateDelays(); + }) + ) + ); +} + +function mockResponse(response: Promise) { + mockUseObservabilityAIAssistant.mockReturnValue({ + chat: jest.fn().mockImplementation(() => { + return response; + }), + } as unknown as ObservabilityAIAssistantService); +} + +describe('useChat', () => { + beforeEach(() => { + mockUseKibana.mockReturnValue({ + services: { notifications: { showErrorDialog: jest.fn() } }, + } as any); + }); + + it('returns the result of the chat API', async () => { + mockDeltas([{ content: 'testContent' }]); + const { result, waitFor } = renderHook( + ({ messages, connectorId }) => useChat({ messages, connectorId }), + { initialProps: { messages: [], connectorId: 'myConnectorId' } } + ); + + expect(result.current.loading).toBeTruthy(); + expect(result.current.error).toBeUndefined(); + expect(result.current.content).toBeUndefined(); + + await waitFor(() => result.current.loading === false, WAIT_OPTIONS); + + expect(result.current.error).toBeUndefined(); + expect(result.current.content).toBe('testContent'); + }); + + it('handles 4xx and 5xx', async () => { + mockResponse(Promise.reject(new Error())); + const { result, waitFor } = renderHook( + ({ messages, connectorId }) => useChat({ messages, connectorId }), + { initialProps: { messages: [], connectorId: 'myConnectorId' } } + ); + + await waitFor(() => result.current.loading === false, WAIT_OPTIONS); + + expect(result.current.error).toBeInstanceOf(Error); + expect(result.current.content).toBeUndefined(); + + expect(mockUseKibana().services.notifications?.showErrorDialog).toHaveBeenCalled(); + }); + + it('handles valid responses but generation errors', async () => { + mockResponse( + Promise.resolve( + new Observable((subscriber) => { + subscriber.next({ choices: [{ role: 'assistant', delta: { content: 'foo' } }] }); + setTimeout(() => { + subscriber.error(new Error()); + }, 100); + }) + ) + ); + + const { result, waitFor } = renderHook( + ({ messages, connectorId }) => useChat({ messages, connectorId }), + { initialProps: { messages: [], connectorId: 'myConnectorId' } } + ); + + await waitFor(() => result.current.loading === false, WAIT_OPTIONS); + + expect(result.current.loading).toBe(false); + expect(result.current.error).toBeInstanceOf(Error); + expect(result.current.content).toBe('foo'); + + expect(mockUseKibana().services.notifications?.showErrorDialog).toHaveBeenCalled(); + }); + + it('handles aborted requests', async () => { + mockResponse( + Promise.resolve( + new Observable((subscriber) => { + subscriber.next({ choices: [{ role: 'assistant', delta: { content: 'foo' } }] }); + }) + ) + ); + + const { result, waitFor, unmount } = renderHook( + ({ messages, connectorId }) => useChat({ messages, connectorId }), + { initialProps: { messages: [], connectorId: 'myConnectorId' } } + ); + + await waitFor(() => result.current.content === 'foo', WAIT_OPTIONS); + + unmount(); + + expect(mockUseKibana().services.notifications?.showErrorDialog).not.toHaveBeenCalled(); + }); + + it('handles regenerations triggered by updates', async () => { + mockResponse( + Promise.resolve( + new Observable((subscriber) => { + subscriber.next({ choices: [{ role: 'assistant', delta: { content: 'foo' } }] }); + }) + ) + ); + + const { result, waitFor, rerender } = renderHook( + ({ messages, connectorId }) => useChat({ messages, connectorId }), + { initialProps: { messages: [], connectorId: 'myConnectorId' } } + ); + + await waitFor(() => result.current.content === 'foo', WAIT_OPTIONS); + + mockDeltas([{ content: 'bar' }]); + + rerender({ messages: [], connectorId: 'bar' }); + + await waitFor(() => result.current.loading === false); + + expect(mockUseKibana().services.notifications?.showErrorDialog).not.toHaveBeenCalled(); + + expect(result.current.content).toBe('bar'); + }); + + it('handles streaming updates', async () => { + mockDeltas([ + { + content: 'my', + }, + { + content: ' ', + }, + { + content: 'update', + }, + ]); + + const { result, waitForNextUpdate } = renderHook( + ({ messages, connectorId }) => useChat({ messages, connectorId }), + { initialProps: { messages: [], connectorId: 'myConnectorId' } } + ); + + await waitForNextUpdate(WAIT_OPTIONS); + + expect(result.current.content).toBe('my'); + + await waitForNextUpdate(WAIT_OPTIONS); + + expect(result.current.content).toBe('my '); + + await waitForNextUpdate(WAIT_OPTIONS); + + expect(result.current.content).toBe('my update'); + }); + + it('handles user aborts', async () => { + mockResponse( + Promise.resolve( + new Observable((subscriber) => { + subscriber.next({ choices: [{ role: 'assistant', delta: { content: 'foo' } }] }); + }) + ) + ); + + const { result, waitForNextUpdate } = renderHook( + ({ messages, connectorId }) => useChat({ messages, connectorId }), + { initialProps: { messages: [], connectorId: 'myConnectorId' } } + ); + + await waitForNextUpdate(WAIT_OPTIONS); + + act(() => { + result.current.abort(); + }); + + expect(mockUseKibana().services.notifications?.showErrorDialog).not.toHaveBeenCalled(); + + expect(result.current.content).toBe('foo'); + expect(result.current.loading).toBe(false); + expect(result.current.error).toBeInstanceOf(AbortError); + }); + + it('handles user regenerations', async () => { + mockResponse( + Promise.resolve( + new Observable((subscriber) => { + subscriber.next({ choices: [{ role: 'assistant', delta: { content: 'foo' } }] }); + }) + ) + ); + + const { result, waitForNextUpdate } = renderHook( + ({ messages, connectorId }) => useChat({ messages, connectorId }), + { initialProps: { messages: [], connectorId: 'myConnectorId' } } + ); + + await waitForNextUpdate(WAIT_OPTIONS); + + act(() => { + mockDeltas([{ content: 'bar' }]); + result.current.regenerate(); + }); + + await waitForNextUpdate(WAIT_OPTIONS); + + expect(mockUseKibana().services.notifications?.showErrorDialog).not.toHaveBeenCalled(); + + expect(result.current.content).toBe('bar'); + expect(result.current.loading).toBe(false); + expect(result.current.error).toBeUndefined(); + }); +}); diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts index eb61f592cd2f32..36a957cc1d660e 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_chat.ts @@ -6,10 +6,11 @@ */ import { clone } from 'lodash'; -import { useEffect, useState } from 'react'; +import { useCallback, useEffect, useRef, useState } from 'react'; import { concatMap, delay, of } from 'rxjs'; import { useKibana } from '@kbn/kibana-react-plugin/public'; import { i18n } from '@kbn/i18n'; +import { AbortError } from '@kbn/kibana-utils-plugin/common'; import type { Message } from '../../common/types'; import { useObservabilityAIAssistant } from './use_observability_ai_assistant'; @@ -29,6 +30,8 @@ export function useChat({ messages, connectorId }: { messages: Message[]; connec }; loading: boolean; error?: Error; + abort: () => void; + regenerate: () => void; } { const assistant = useObservabilityAIAssistant(); @@ -42,8 +45,12 @@ export function useChat({ messages, connectorId }: { messages: Message[]; connec const [loading, setLoading] = useState(false); - useEffect(() => { - const controller = new AbortController(); + const controllerRef = useRef(new AbortController()); + + const regenerate = useCallback(() => { + controllerRef.current.abort(); + + const controller = (controllerRef.current = new AbortController()); setResponse(undefined); setError(undefined); @@ -65,6 +72,9 @@ export function useChat({ messages, connectorId }: { messages: Message[]; connec .pipe(concatMap((value) => of(value).pipe(delay(50)))) .subscribe({ next: (chunk) => { + if (controller.signal.aborted) { + return; + } partialResponse.content += chunk.choices[0].delta.content ?? ''; partialResponse.function_call.name += chunk.choices[0].delta.function_call?.name ?? ''; @@ -80,12 +90,16 @@ export function useChat({ messages, connectorId }: { messages: Message[]; connec }, }); - controller.signal.addEventListener('abort', () => { + controllerRef.current.signal.addEventListener('abort', () => { subscription.unsubscribe(); + reject(new AbortError()); }); }); }) .catch((err) => { + if (controller.signal.aborted) { + return; + } notifications?.showErrorDialog({ title: i18n.translate('xpack.observabilityAiAssistant.failedToLoadChatTitle', { defaultMessage: 'Failed to load chat', @@ -95,6 +109,9 @@ export function useChat({ messages, connectorId }: { messages: Message[]; connec setError(err); }) .finally(() => { + if (controller.signal.aborted) { + return; + } setLoading(false); }); @@ -103,9 +120,19 @@ export function useChat({ messages, connectorId }: { messages: Message[]; connec }; }, [messages, connectorId, assistant, notifications]); + useEffect(() => { + return regenerate(); + }, [regenerate]); + return { ...response, error, loading, + abort: () => { + setLoading(false); + setError(new AbortError()); + controllerRef.current.abort(); + }, + regenerate, }; } diff --git a/x-pack/plugins/observability_ai_assistant/public/service/create_service.ts b/x-pack/plugins/observability_ai_assistant/public/service/create_service.ts index 2d0a2851fc901e..78ab34730484f2 100644 --- a/x-pack/plugins/observability_ai_assistant/public/service/create_service.ts +++ b/x-pack/plugins/observability_ai_assistant/public/service/create_service.ts @@ -52,6 +52,10 @@ export function createService(coreSetup: CoreSetup): ObservabilityAIAssistantSer throw new Error('Could not get reader from response'); } + signal.addEventListener('abort', () => { + reader.cancel(); + }); + return readableStreamReaderIntoObservable(reader).pipe( map((line) => line.substring(6)), filter((line) => !!line && line !== '[DONE]'),