diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
index 97392168a..ffd15cc2c 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
@@ -267,7 +267,12 @@ def get_llm_chain(self):
         if not lm_provider or not lm_provider_params:
             return None
 
-        if curr_lm_id != next_lm_id:
+        if len(self._chat_history) <= 2:
+            # The chat history was just cleared, so the conversation memory
+            # backing the chain is stale; re-initialize the LLM chain.
+            self.log.info("Chat history was cleared; re-initializing the LLM chain.")
+            self.create_llm_chain(lm_provider, lm_provider_params)
+        elif curr_lm_id != next_lm_id:
             self.log.info(
                 f"Switching chat language model from {curr_lm_id} to {next_lm_id}."
             )
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py
index 51526801b..7d776a739 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/clear.py
@@ -19,14 +19,15 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     async def process_message(self, _):
-        tmp_chat_history = self._chat_history[0]
-        self._chat_history.clear()
        for handler in self._root_chat_handlers.values():
             if not handler:
                 continue
 
             handler.broadcast_message(ClearMessage())
+            tmp_chat_history = self._chat_history[0]
+            self._chat_history.clear()
+            self._chat_history = [tmp_chat_history]
             self.reply(tmp_chat_history.body)
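
Taken together, the two hunks implement a "clear but keep the greeting" flow: `/clear` retains only the first message of the history, and `get_llm_chain()` treats a history of length <= 2 as the signal that the conversation memory behind the chain is stale and must be rebuilt. Below is a minimal, self-contained sketch of that heuristic, not the jupyter-ai API itself; `Message`, `clear_history`, and `needs_new_chain` are illustrative names introduced only for this example.

```python
from dataclasses import dataclass


@dataclass
class Message:
    body: str


def clear_history(chat_history: list) -> list:
    """Keep only the first (greeting) message, as /clear does above."""
    return [chat_history[0]]


def needs_new_chain(chat_history: list) -> bool:
    """Mirror of the `len(self._chat_history) <= 2` check in base.py:
    right after a clear, the history holds at most the greeting plus
    one new prompt, so the chain's conversation memory is stale."""
    return len(chat_history) <= 2


history = [Message("Hi, I'm Jupyternaut."), Message("What is 2 + 2?"), Message("4")]
assert not needs_new_chain(history)   # normal conversation: keep the current chain

history = clear_history(history)      # user runs /clear
assert needs_new_chain(history)       # next turn calls create_llm_chain() again
```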