Skip to content

Commit

Permalink
feat(llm): save topic prompts after generation
Browse files Browse the repository at this point in the history
  • Loading branch information
nsantacruz committed Feb 4, 2024
1 parent 8e45a7c commit b554a62
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
27 changes: 27 additions & 0 deletions sefaria/helper/llm/llm_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
Classes for instantiating objects received from the LLM repo
"""
from typing import List
from dataclasses import dataclass


@dataclass
class TopicPrompt:
title: str
prompt: str
lang: str
ref: str
slug: str


@dataclass
class TopicPromptGenerationOutput:
lang: str
prompts: List[TopicPrompt]

@staticmethod
def create(raw_output):
return TopicPromptGenerationOutput(
**{**raw_output, "prompts": [TopicPrompt(**raw_prompt) for raw_prompt in raw_output['prompts']]}
)

8 changes: 5 additions & 3 deletions sefaria/helper/llm/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from celery import shared_task
from sefaria.model.text import Ref
from sefaria.model.topic import Topic
from sefaria.helper.llm.topic_prompt import make_topic_prompt_input
from sefaria.helper.llm.topic_prompt import make_topic_prompt_input, save_topic_prompt_output
from sefaria.helper.llm.llm_interface import TopicPromptGenerationOutput


@shared_task
Expand All @@ -14,5 +15,6 @@ def generate_topic_prompts(lang: str, sefaria_topic: Topic, orefs: List[Ref], co


@shared_task
def save_topic_prompts(topic_prompts: List[dict]):
pass
def save_topic_prompts(raw_output: TopicPromptGenerationOutput):
output = TopicPromptGenerationOutput.create(raw_output)
save_topic_prompt_output(output)
22 changes: 21 additions & 1 deletion sefaria/helper/llm/topic_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import re
from sefaria.model.text import Ref, library, TextChunk
from sefaria.model.passage import Passage
from sefaria.model.topic import Topic
from sefaria.model.topic import Topic, RefTopicLink
from sefaria.client.wrapper import get_links
from sefaria.datatype.jagged_array import JaggedTextArray
from sefaria.helper.llm.llm_interface import TopicPromptGenerationOutput
from sefaria.utils.util import deep_update


def _lang_dict_by_func(func: Callable[[str], Any]):
Expand Down Expand Up @@ -131,3 +133,21 @@ def make_topic_prompt_input(lang: str, sefaria_topic: Topic, orefs: List[Ref], c
"topic": _make_topic_prompt_topic(sefaria_topic),
"sources": [_make_topic_prompt_source(oref, context) for oref, context in zip(orefs, contexts)]
}


def save_topic_prompt_output(output: TopicPromptGenerationOutput) -> None:
for prompt in output.prompts:
link = RefTopicLink().load({
"ref": prompt.ref,
"toTopic": prompt.slug,
"dataSource": "learning-team",
"linkType": "about",
})
curr_descriptions = getattr(link, "descriptions", {})
description_edits = {output.lang: {
"title": prompt.title, "ai_title": prompt.title,
"prompt": prompt.prompt, "ai_prompt": prompt.prompt,
"published": False, "review_state": "not reviewed"
}}
setattr(link, "descriptions", deep_update(curr_descriptions, description_edits))
link.save()

0 comments on commit b554a62

Please sign in to comment.