Skip to content

Commit

Permalink
feat: switch mistral model
Browse files Browse the repository at this point in the history
  • Loading branch information
corenting committed Dec 16, 2023
1 parent 3ce1e24 commit bf90e21
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Generate random "inspirational" quotes images by using an AI text generation mod

The following models can be used:
- [bigscience/bloom](https://huggingface.co/bigscience/bloom)
- [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)

## Examples

Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def generate(
+ "Works best with simple queries like 'mountain', 'sea' etc.",
),
text_model: TextModel = typer.Option(
default=TextModel.MISTRAL_7B_INSTRUCT.value,
default=TextModel.MISTRAL_8X7B_INSTRUCT.value,
help="The text generation model to use.",
),
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion roboquote/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@

HUGGING_FACE_BASE_API_URL = "https://api-inference.huggingface.co/models/"

DEFAULT_WEB_TEXT_MODEL = TextModel.MISTRAL_7B_INSTRUCT
DEFAULT_WEB_TEXT_MODEL = TextModel.MISTRAL_8X7B_INSTRUCT
2 changes: 1 addition & 1 deletion roboquote/entities/text_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

class TextModel(Enum):
BLOOM = "bigscience/bloom"
MISTRAL_7B_INSTRUCT = "mistralai/Mistral-7B-Instruct-v0.1"
MISTRAL_8X7B_INSTRUCT = "mistralai/Mixtral-8x7B-Instruct-v0.1"
7 changes: 4 additions & 3 deletions roboquote/quote_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def _get_base_prompt_by_model(
f"On a {background_search_query} themed inspirational picture, "
+ "there was a fitting short quote: ",
]
elif text_model == TextModel.MISTRAL_7B_INSTRUCT:
elif text_model == TextModel.MISTRAL_8X7B_INSTRUCT:
prompts = [
"<s>[INST] Give me a inspirational quote "
"[INST] Give me a inspirational quote "
+ f"that fits a {background_search_query} themed picture, "
+ "similar to old Tumblr pictures. Give me the quote text "
+ "without any surrounding text. Do not return lists. "
Expand Down Expand Up @@ -81,7 +81,7 @@ def _cleanup_text(generated_text: str) -> str:
async def get_random_quote(background_search_query: str, text_model: TextModel) -> str:
"""For a given background category, get a random quote."""
prompt = _get_random_prompt(background_search_query, text_model)
logger.debug(f'Prompt for model: "{prompt}"')
logger.debug(f'Prompt for {text_model.value}: "{prompt}"')

headers = {
"Authorization": f"Bearer {config.HUGGING_FACE_API_TOKEN}",
Expand All @@ -107,6 +107,7 @@ async def get_random_quote(background_search_query: str, text_model: TextModel)

try:
response_content = json.loads(response.content.decode("utf-8"))
logger.debug(f'Hugging Face response {response.status_code}: {response_content}')
except json.JSONDecodeError as e:
raise CannotGenerateQuoteError() from e

Expand Down

0 comments on commit bf90e21

Please sign in to comment.