From bf90e216a38fb4d399c1ddf360bd8076a4c62039 Mon Sep 17 00:00:00 2001
From: Corentin Garcia
Date: Sat, 16 Dec 2023 12:39:12 +0100
Subject: [PATCH] feat: switch mistral model

---
 README.md                          | 2 +-
 main.py                            | 2 +-
 roboquote/constants.py             | 2 +-
 roboquote/entities/text_model.py   | 2 +-
 roboquote/quote_text_generation.py | 7 ++++---
 5 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index 572c7c3..3609d57 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ Generate random "inspirational" quotes images by using an AI text generation mod
 
 The following models can be used:
 - [bigscience/bloom](https://huggingface.co/bigscience/bloom)
-- [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
+- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
 
 ## Examples
 
diff --git a/main.py b/main.py
index 2e483a1..3b2c58f 100644
--- a/main.py
+++ b/main.py
@@ -29,7 +29,7 @@ def generate(
         + "Works best with simple queries like 'mountain', 'sea' etc.",
     ),
     text_model: TextModel = typer.Option(
-        default=TextModel.MISTRAL_7B_INSTRUCT.value,
+        default=TextModel.MISTRAL_8X7B_INSTRUCT.value,
         help="The text generation model to use.",
     ),
 ) -> None:
diff --git a/roboquote/constants.py b/roboquote/constants.py
index 5877bb0..ace3f02 100644
--- a/roboquote/constants.py
+++ b/roboquote/constants.py
@@ -5,4 +5,4 @@
 
 HUGGING_FACE_BASE_API_URL = "https://api-inference.huggingface.co/models/"
 
-DEFAULT_WEB_TEXT_MODEL = TextModel.MISTRAL_7B_INSTRUCT
+DEFAULT_WEB_TEXT_MODEL = TextModel.MISTRAL_8X7B_INSTRUCT
diff --git a/roboquote/entities/text_model.py b/roboquote/entities/text_model.py
index 69bb373..0368722 100644
--- a/roboquote/entities/text_model.py
+++ b/roboquote/entities/text_model.py
@@ -3,4 +3,4 @@
 
 class TextModel(Enum):
     BLOOM = "bigscience/bloom"
-    MISTRAL_7B_INSTRUCT = "mistralai/Mistral-7B-Instruct-v0.1"
+    MISTRAL_8X7B_INSTRUCT = "mistralai/Mixtral-8x7B-Instruct-v0.1"
diff --git a/roboquote/quote_text_generation.py b/roboquote/quote_text_generation.py
index 471e5bb..8b6c554 100644
--- a/roboquote/quote_text_generation.py
+++ b/roboquote/quote_text_generation.py
@@ -24,9 +24,9 @@ def _get_base_prompt_by_model(
             f"On a {background_search_query} themed inspirational picture, "
             + "there was a fitting short quote: ",
         ]
-    elif text_model == TextModel.MISTRAL_7B_INSTRUCT:
+    elif text_model == TextModel.MISTRAL_8X7B_INSTRUCT:
         prompts = [
-            "[INST] Give me a inspirational quote "
+            "[INST] Give me an inspirational quote "
             + f"that fits a {background_search_query} themed picture, "
             + "similar to old Tumblr pictures. Give me the quote text "
             + "without any surrounding text. Do not return lists. "
@@ -81,7 +81,7 @@ def _cleanup_text(generated_text: str) -> str:
 async def get_random_quote(background_search_query: str, text_model: TextModel) -> str:
     """For a given background category, get a random quote."""
     prompt = _get_random_prompt(background_search_query, text_model)
-    logger.debug(f'Prompt for model: "{prompt}"')
+    logger.debug(f'Prompt for {text_model.value}: "{prompt}"')
 
     headers = {
         "Authorization": f"Bearer {config.HUGGING_FACE_API_TOKEN}",
@@ -107,6 +107,7 @@ async def get_random_quote(background_search_query: str, text_model: TextModel)
 
     try:
         response_content = json.loads(response.content.decode("utf-8"))
+        logger.debug(f"Hugging Face response {response.status_code}: {response_content}")
     except json.JSONDecodeError as e:
         raise CannotGenerateQuoteError() from e
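
Below is a minimal sketch, separate from the patch itself, of how the Mixtral model referenced above can be queried through the Hugging Face Inference API with the `[INST] ... [/INST]` prompt style used in `_get_base_prompt_by_model`. It assumes an httpx-based async client and an API token read from a `HUGGING_FACE_API_TOKEN` environment variable; the actual project reads its token from `config.HUGGING_FACE_API_TOKEN` and builds its URL from `HUGGING_FACE_BASE_API_URL`.

```python
# Illustrative sketch only, not part of the patch.
import asyncio
import os

import httpx

MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
API_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"


async def generate_quote(background: str) -> str:
    # Mixtral Instruct expects the [INST] ... [/INST] chat format,
    # matching the prompt built in quote_text_generation.py.
    prompt = (
        f"[INST] Give me an inspirational quote that fits a {background} "
        "themed picture. Give me the quote text without any surrounding text. [/INST]"
    )
    # Token name here is an assumption for the sketch.
    headers = {"Authorization": f"Bearer {os.environ['HUGGING_FACE_API_TOKEN']}"}
    payload = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": 50, "return_full_text": False},
    }
    async with httpx.AsyncClient(timeout=30) as client:
        response = await client.post(API_URL, headers=headers, json=payload)
        response.raise_for_status()
        # The Inference API returns a list of generations: [{"generated_text": "..."}]
        return response.json()[0]["generated_text"].strip()


if __name__ == "__main__":
    print(asyncio.run(generate_quote("mountain")))
```

This mirrors the request shape the module already uses (bearer token header, JSON body, JSON list response), only with the new model identifier substituted in.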