Skip to content

Commit

Permalink
feat(llm): add topic_prompt.py which acts as an interface to the LLM …
Browse files Browse the repository at this point in the history
…repo.
  • Loading branch information
nsantacruz committed Feb 4, 2024
1 parent fe9c971 commit 0c64a20
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 0 deletions.
Empty file added sefaria/helper/llm/__init__.py
Empty file.
133 changes: 133 additions & 0 deletions sefaria/helper/llm/topic_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from typing import List, Callable, Any, Optional, Dict
import re
from sefaria.model.text import Ref, library, TextChunk
from sefaria.model.passage import Passage
from sefaria.model.topic import Topic
from sefaria.client.wrapper import get_links
from sefaria.datatype.jagged_array import JaggedTextArray


def _lang_dict_by_func(func: Callable[[str], Any]):
return {lang: func(lang) for lang in ('en', 'he')}


def _get_commentary_for_tref(tref: str) -> List[dict]:
    """
    Return list of commentary for tref. Currently only considers commentary that has English text.
    :param tref: textual ref whose links are fetched with `get_links`
    :return: list where each element represents a single commentary on `tref`. Each element is a dict with keys
    `ref` (the `sourceRef` of the commentary link) and `text` (a dict with keys `en` and `he` holding the flattened,
    markup-free English and Hebrew text).
    """
    # NOTE(review): the TOC is rebuilt on every call, which looks expensive — confirm this is required here.
    library.rebuild_toc()
    commentary = []

    for link_dict in get_links(tref, with_text=True):
        if link_dict['category'] != 'Commentary':
            continue
        if not link_dict['sourceHasEn']:
            continue
        # The link dict stores the English text under 'text' and the Hebrew text under 'he'.
        flat_text = _lang_dict_by_func(
            lambda lang: JaggedTextArray(link_dict['text' if lang == 'en' else 'he']).flatten_to_string())
        # Strip inline tags first, then replace any remaining <...> markup with a space.
        # `flat_text` is a plain dict, so it must be subscripted (attribute access raises AttributeError).
        clean_text = _lang_dict_by_func(
            lambda lang: re.sub(r"<[^>]+>", " ", TextChunk.strip_itags(flat_text[lang])))
        commentary.append({
            "ref": link_dict['sourceRef'],
            "text": clean_text,
        })
    return commentary


def _get_context_ref(segment_oref: Ref) -> Optional[Ref]:
    """
    Pick the context ref for `segment_oref`, if it needs one.
    A context ref is a ref which contains `segment_oref` and provides more context for it,
    e.g. Genesis 1 is a context ref for Genesis 1:13.
    :param segment_oref: the ref to find a context ref for
    :return: the section ref when the primary category is "Tanakh"; the ref of the containing Passage when the
    index's primary corpus is "Bavli"; otherwise None.
    """
    if segment_oref.primary_category == "Tanakh":
        return segment_oref.section_ref()
    if segment_oref.index.get_primary_corpus() != "Bavli":
        return None
    # NOTE(review): assumes a containing passage is always found; if `containing_segment` can return None,
    # the `.ref()` call below raises AttributeError — confirm against Passage.
    containing_passage = Passage.containing_segment(segment_oref)
    return containing_passage.ref()


def _get_surrounding_text(oref: Ref) -> Optional[Dict[str, str]]:
    """
    Fetch the text of the context ref of `oref`. See _get_context_ref() for what a context ref is.
    :param oref: the ref whose surrounding text is wanted
    :return: dict with keys "en" and "he" and values the English and Hebrew text of the surrounding text,
    respectively, or None when `oref` has no context ref.
    """
    context_ref = _get_context_ref(oref)
    if not context_ref:
        return None

    def _context_text(lang):
        return context_ref.text(lang).as_string()

    return _lang_dict_by_func(_context_text)


def _make_topic_prompt_topic(sefaria_topic: Topic) -> dict:
    """
    Serialize the basic metadata of a topic for the LLM repo to process.
    The returned dict can be instantiated as `sefaria_interface.Topic` in the LLM repo.
    :param sefaria_topic: the topic to serialize
    :return: dict with keys "slug", "description" (empty dict when the topic has no `description` attribute)
    and "title" (primary title keyed by 'en' and 'he')
    """
    slug = sefaria_topic.slug
    description = getattr(sefaria_topic, 'description', {})
    title_by_lang = _lang_dict_by_func(sefaria_topic.get_primary_title)
    return {"slug": slug, "description": description, "title": title_by_lang}


def _make_topic_prompt_source(oref: Ref, context: str) -> dict:
    """
    Serialize the basic metadata of a source for the LLM repo to process.
    The returned dict can be instantiated as `sefaria_interface.TopicPromptSource` in the LLM repo.
    Fields that cannot be determined (book description, composition date, author name) fall back to "N/A".
    :param oref: the source ref
    :param context: passed through unchanged as "context_hint"
    :return: dict with keys "ref", "categories", "book_description", "book_title", "comp_date", "author_name",
    "context_hint", "text", "commentary" and "surrounding_text"
    """

    def _source_text(lang):
        return oref.text(lang).as_string()

    def _index_description(lang):
        # Descriptions are read from the index's `enDesc` / `heDesc` attributes.
        return getattr(index, f"{lang}Desc", "N/A")

    index = oref.index
    text = _lang_dict_by_func(_source_text)
    book_description = _lang_dict_by_func(_index_description)
    book_title = _lang_dict_by_func(index.get_title)

    time_period = index.composition_time_period()
    if time_period:
        pub_year = time_period.period_string("en")
    else:
        pub_year = "N/A"

    # Only the first listed author is used. An AttributeError anywhere in this lookup falls back to "N/A".
    try:
        if len(index.authors) > 0:
            author_name = Topic.init(index.authors[0]).get_primary_title("en")
        else:
            author_name = "N/A"
    except AttributeError:
        author_name = "N/A"

    # Commentary is only gathered for Tanakh sources; every other source gets None.
    if index.get_primary_category() == "Tanakh":
        commentary = _get_commentary_for_tref(oref.normal())
    else:
        commentary = None
    surrounding_text = _get_surrounding_text(oref)

    source = {}
    source["ref"] = oref.normal()
    source["categories"] = index.categories
    source["book_description"] = book_description
    source["book_title"] = book_title
    source["comp_date"] = pub_year
    source["author_name"] = author_name
    source["context_hint"] = context
    source["text"] = text
    source["commentary"] = commentary
    source["surrounding_text"] = surrounding_text
    return source


def make_topic_prompt_input(lang: str, sefaria_topic: Topic, orefs: List[Ref], contexts: List[str]) -> dict:
    """
    Assemble the full input required for the LLM repo to generate topic prompts.
    The returned dict can be instantiated as `sefaria_interface.TopicPromptInput` in the LLM repo.
    :param lang: passed through unchanged as "lang"
    :param sefaria_topic: topic serialized under "topic"
    :param orefs: source refs, paired positionally with `contexts`
    :param contexts: context hints, one per ref in `orefs`
    :return: dict with keys "lang", "topic" and "sources"
    """
    topic = _make_topic_prompt_topic(sefaria_topic)
    sources = []
    # zip() stops at the shorter list, so unpaired trailing refs or contexts are ignored.
    # NOTE(review): confirm callers always pass lists of equal length.
    for oref, context in zip(orefs, contexts):
        sources.append(_make_topic_prompt_source(oref, context))
    return {"lang": lang, "topic": topic, "sources": sources}

0 comments on commit 0c64a20

Please sign in to comment.