Skip to content

Commit

Permalink
feat(linker): basic NamedEntityResolver
Browse files Browse the repository at this point in the history
  • Loading branch information
nsantacruz committed Oct 24, 2023
1 parent c212a58 commit 0005b82
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions sefaria/model/linker/named_entity_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import defaultdict
from sefaria.model.linker.ref_part import RawRef, RawRefPart, SpanOrToken, span_inds, RefPartType, RawNamedEntity, NamedEntityType
from sefaria.helper.normalization import NormalizerComposer
from sefaria.model.topic import Topic, TopicSet

try:
import spacy
Expand Down Expand Up @@ -176,3 +177,48 @@ def _get_dh_continuation(ispan: int, ipart: int, named_entities: List[RawNamedEn
dh_cont = part_span.doc[part_span_end:next_part_span_start]

return dh_cont


class ResolvedNamedEntity:
    """Pair a raw named entity with the topic it was matched to.

    Attributes are stored exactly as passed in; nothing is validated or copied.
    """

    def __init__(self, raw_named_entity: RawNamedEntity, topic: Topic):
        # topic: the Topic this entity was resolved to
        self.topic = topic
        # raw_named_entity: the entity as it was found in the input text
        self.raw_named_entity = raw_named_entity


class TopicMatcher:
    """Exact-title lookup from a string to a topic.

    Two dictionaries are built at construction time: one from slug to topic
    and one from title to slug. `match` is then a pair of dictionary lookups.
    """

    def __init__(self, lang: str, topics=None):
        """Index `topics` by slug and by their titles in `lang`.

        :param lang: language passed through to `topic.get_titles(lang=...)`.
        :param topics: iterable of topics to index. Each item must expose
            `.slug` and `.get_titles(lang=..., with_disambiguation=False)`.
            When omitted (None), a fresh `TopicSet()` is used.
        """
        # Only fall back to TopicSet() when nothing was passed. A truthiness
        # test here would treat an explicitly empty collection as "not given".
        if topics is None:
            topics = TopicSet()
        self._slug_topic_map = {}
        self._title_slug_map = {}
        # Single pass so that one-shot iterables (e.g. generators) fill both
        # maps; iterating twice would leave the title map empty for them.
        for topic in topics:
            self._slug_topic_map[topic.slug] = topic
            for title in topic.get_titles(lang=lang, with_disambiguation=False):
                # If several topics share a title, the last one iterated wins.
                self._title_slug_map[title] = topic.slug

    def match(self, text) -> Optional[Topic]:
        """Return the topic whose indexed title equals `text`, else None.

        The comparison is an exact dictionary lookup on `text`; no
        normalization is applied here.
        """
        slug = self._title_slug_map.get(text)
        if slug is None:
            return None
        return self._slug_topic_map[slug]


class NamedEntityResolver:
    """Recognize named entities in text and attach a matched topic to each.

    Combines a recognizer (which produces raw named entities per input) with
    a topic matcher (which maps an entity's text to a topic, or to nothing).
    """

    def __init__(self, named_entity_recognizer: NamedEntityRecognizer, topic_matcher: TopicMatcher):
        self._topic_matcher = topic_matcher
        self._named_entity_recognizer = named_entity_recognizer

    def bulk_resolve_named_entities(self, inputs: List[str]) -> List[List[ResolvedNamedEntity]]:
        """Resolve the named entities found in every input string.

        :param inputs: strings handed to the recognizer's
            `bulk_get_raw_named_entities`.
        :return: one list per group of entities the recognizer returns, holding
            a `ResolvedNamedEntity` for each entity whose `.text` the topic
            matcher matched. Entities with no (truthy) match are omitted.
        """
        entities_per_input = self._named_entity_recognizer.bulk_get_raw_named_entities(inputs)
        resolved_per_input = []
        for raw_entities in entities_per_input:
            # Lazily pair each entity with its match so the matcher is called
            # exactly once per entity, in order.
            candidates = ((raw, self._topic_matcher.match(raw.text)) for raw in raw_entities)
            resolved_per_input.append(
                [ResolvedNamedEntity(raw, topic) for raw, topic in candidates if topic]
            )
        return resolved_per_input



0 comments on commit 0005b82

Please sign in to comment.