Skip to content

Commit

Permalink
feat: expose linker from library class
Browse files Browse the repository at this point in the history
  • Loading branch information
nsantacruz committed Nov 9, 2023
1 parent 10300f3 commit facd588
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 35 deletions.
3 changes: 1 addition & 2 deletions sefaria/model/linker/named_entity_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,7 @@ def match(self, named_entity: RawNamedEntity) -> List[Topic]:

class NamedEntityResolver:

def __init__(self, named_entity_recognizer: NamedEntityRecognizer, topic_matcher: TopicMatcher):
self._named_entity_recognizer = named_entity_recognizer
def __init__(self, topic_matcher: TopicMatcher):
self._topic_matcher = topic_matcher

def bulk_resolve(self, raw_named_entities: List[RawNamedEntity], with_failures=False) -> List[ResolvedNamedEntity]:
Expand Down
7 changes: 1 addition & 6 deletions sefaria/model/linker/ref_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,9 @@ def get_ref_by_title(self, title: str) -> Optional[text.Ref]:

class RefResolver:

def __init__(self, lang: str, named_entity_recognizer: NamedEntityRecognizer,
ref_part_title_trie: MatchTemplateTrie, term_matcher: TermMatcher) -> None:
def __init__(self, lang: str, ref_part_title_trie: MatchTemplateTrie, term_matcher: TermMatcher) -> None:

self._lang = lang
self._named_entity_recognizer = named_entity_recognizer
self._ref_part_title_trie = ref_part_title_trie
self._term_matcher = term_matcher
self._ibid_history = IbidHistory()
Expand Down Expand Up @@ -292,9 +290,6 @@ def _update_ibid_history(self, temp_resolved: List[PossiblyAmbigResolvedRef]):
else:
self._ibid_history.last_refs = temp_resolved[-1].ref

def get_ner(self) -> NamedEntityRecognizer:
return self._named_entity_recognizer

def get_ref_part_title_trie(self) -> MatchTemplateTrie:
return self._ref_part_title_trie

Expand Down
50 changes: 23 additions & 27 deletions sefaria/model/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -4865,8 +4865,7 @@ def __init__(self):
self._simple_term_mapping = {}
self._full_term_mapping = {}
self._simple_term_mapping_json = None
self._ref_resolver_by_lang = {}
self._named_entity_resolver_by_lang = {}
self._linker_by_lang = {}

# Topics
self._topic_mapping = {}
Expand Down Expand Up @@ -5601,24 +5600,32 @@ def _build_topic_mapping(self):
self._topic_mapping = {t.slug: {"en": t.get_primary_title("en"), "he": t.get_primary_title("he")} for t in TopicSet()}
return self._topic_mapping

def get_named_entity_resolver(self, lang: str, rebuild=False):
resolver = self._named_entity_resolver_by_lang.get(lang)
if not resolver or rebuild:
resolver = self.build_named_entity_resolver(lang)
return resolver
def get_linker(self, lang: str, rebuild=False):
linker = self._linker_by_lang.get(lang)
if not linker or rebuild:
linker = self.build_linker(lang)
return linker

def build_named_entity_resolver(self, lang: str):
def build_linker(self, lang: str):
from sefaria.model.linker.linker import Linker

logger.info("Loading Spacy Model")

named_entity_resolver = self._build_named_entity_resolver(lang)
ref_resolver = self._build_ref_resolver(lang)
named_entity_recognizer = self._build_named_entity_recognizer(lang)
self._linker_by_lang[lang] = Linker(ref_resolver, named_entity_resolver, named_entity_recognizer)
return self._linker_by_lang[lang]

@staticmethod
def _build_named_entity_resolver(self, lang: str):
from .linker.named_entity_resolver import TopicMatcher, NamedEntityResolver

named_entity_types_to_topics = {
"PERSON": {"ontology_roots": ['people'], "single_slugs": ['god', 'the-tetragrammaton']},
"GROUP": {'ontology_roots': ["group-of-people"]},
}
self._named_entity_resolver_by_lang[lang] = NamedEntityResolver(
self._build_named_entity_recognizer(lang),
TopicMatcher(lang, named_entity_types_to_topics)
)
return self._named_entity_resolver_by_lang[lang]
return NamedEntityResolver(TopicMatcher(lang, named_entity_types_to_topics))

@staticmethod
def _build_named_entity_recognizer(lang: str):
Expand All @@ -5631,30 +5638,19 @@ def _build_named_entity_recognizer(lang: str):
load_spacy_model(RAW_REF_PART_MODEL_BY_LANG_FILEPATH[lang])
)

def get_ref_resolver(self, lang: str, rebuild=False):
resolver = self._ref_resolver_by_lang.get(lang)
if not resolver or rebuild:
resolver = self.build_ref_resolver(lang)
return resolver

def build_ref_resolver(self, lang: str):
def _build_ref_resolver(self, lang: str):
from .linker.match_template import MatchTemplateTrie
from .linker.ref_resolver import RefResolver, TermMatcher
from sefaria.model.schema import NonUniqueTermSet

logger.info("Loading Spacy Model")

root_nodes = list(filter(lambda n: getattr(n, 'match_templates', None) is not None, self.get_index_forest()))
alone_nodes = reduce(lambda a, b: a + b.index.get_referenceable_alone_nodes(), root_nodes, [])
non_unique_terms = NonUniqueTermSet()
ner = self._build_named_entity_recognizer(lang)

self._ref_resolver_by_lang[lang] = RefResolver(
lang, ner,
MatchTemplateTrie(lang, nodes=(root_nodes + alone_nodes), scope='alone'),
return RefResolver(
lang, MatchTemplateTrie(lang, nodes=(root_nodes + alone_nodes), scope='alone'),
TermMatcher(lang, non_unique_terms),
)
return self._ref_resolver_by_lang[lang]

def get_index_forest(self):
"""
Expand Down

0 comments on commit facd588

Please sign in to comment.