Skip to content

Commit

Permalink
fix(linker): fix bugs with outputting named entities.
Browse files (browse the repository at this point in the history)
  • Loading branch information
nsantacruz committed Oct 26, 2023
1 parent 873c496 commit 2ec609b
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions sefaria/model/linker/named_entity_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import defaultdict
from sefaria.model.linker.ref_part import RawRef, RawRefPart, SpanOrToken, span_inds, RefPartType, RawNamedEntity, NamedEntityType
from sefaria.helper.normalization import NormalizerComposer
from sefaria.model.topic import Topic, TopicSet
from sefaria.model.topic import Topic, TopicSet, RefTopicLink

try:
import spacy
Expand Down Expand Up @@ -78,22 +78,25 @@ def bulk_map_normal_output_to_original_input(self, input: List[str], raw_ref_lis
for temp_input, raw_ref_list in zip(input, raw_ref_list_list):
self.map_normal_output_to_original_input(temp_input, raw_ref_list)

def map_normal_output_to_original_input(self, input: str, raw_ref_list: List[RawRef]) -> None:
def map_normal_output_to_original_input(self, input: str, named_entities: List[RawNamedEntity]) -> None:
"""
Ref resolution ran on normalized input. Remap raw refs to original (non-normalized) input
"""
unnorm_doc = self._raw_ref_model.make_doc(input)
mapping = self._normalizer.get_mapping_after_normalization(input)
# this function name is waaay too long
conv = self._normalizer.convert_normalized_indices_to_unnormalized_indices
norm_inds = [raw_ref.char_indices for raw_ref in raw_ref_list]
norm_inds = [named_entity.char_indices for named_entity in named_entities]
unnorm_inds = conv(norm_inds, mapping)
unnorm_part_inds = []
for (raw_ref, (norm_raw_ref_start, _)) in zip(raw_ref_list, norm_inds):
for (named_entity, (norm_raw_ref_start, _)) in zip(named_entities, norm_inds):
raw_ref_parts = named_entity.raw_ref_parts if isinstance(named_entity, RawRef) else []
unnorm_part_inds += [conv([[norm_raw_ref_start + i for i in part.char_indices]
for part in raw_ref.raw_ref_parts], mapping)]
for raw_ref, temp_unnorm_inds, temp_unnorm_part_inds in zip(raw_ref_list, unnorm_inds, unnorm_part_inds):
raw_ref.map_new_indices(unnorm_doc, temp_unnorm_inds, temp_unnorm_part_inds)
for part in raw_ref_parts], mapping)]
for named_entity, temp_unnorm_inds, temp_unnorm_part_inds in zip(named_entities, unnorm_inds, unnorm_part_inds):
named_entity.map_new_indices(unnorm_doc, temp_unnorm_inds)
if isinstance(named_entity, RawRef):
named_entity.map_new_part_indices(temp_unnorm_part_inds)

@property
def raw_ref_model(self):
Expand Down Expand Up @@ -185,6 +188,18 @@ def __init__(self, raw_named_entity: RawNamedEntity, topic: Topic):
self.raw_named_entity = raw_named_entity
self.topic = topic

def to_ref_topic_link(self) -> RefTopicLink:
    """
    Package this resolved named entity as a RefTopicLink.

    The link is built from a dict with three keys:
      - "ref": always the empty string here.
      - "toTopic": the slug of self.topic, or the placeholder "N/A" when
        self.topic is falsy.
      - "charLevelData": the entity's start/end character indices (unpacked
        from raw_named_entity.char_indices) together with its text.

    :return: a RefTopicLink constructed from the dict described above
    """
    span_start, span_end = self.raw_named_entity.char_indices
    # Entities without a topic still produce a link, marked with "N/A".
    topic_slug = self.topic.slug if self.topic else "N/A"
    char_level_data = {
        "startChar": span_start,
        "endChar": span_end,
        "text": self.raw_named_entity.text,
    }
    return RefTopicLink({
        "ref": "",
        "toTopic": topic_slug,
        "charLevelData": char_level_data,
    })


class TopicMatcher:

Expand All @@ -208,16 +223,18 @@ def __init__(self, named_entity_recognizer: NamedEntityRecognizer, topic_matcher
self._named_entity_recognizer = named_entity_recognizer
self._topic_matcher = topic_matcher

def bulk_resolve_named_entities(self, inputs: List[str]) -> List[List[ResolvedNamedEntity]]:
def bulk_resolve_named_entities(self, inputs: List[str], with_failures=False) -> List[List[ResolvedNamedEntity]]:
all_named_entities = self._named_entity_recognizer.bulk_get_raw_named_entities(inputs)
resolved = []
for named_entities in all_named_entities:
temp_resolved = []
for named_entity in named_entities:
matched_topic = self._topic_matcher.match(named_entity.text)
if matched_topic:
if matched_topic or with_failures:
temp_resolved += [ResolvedNamedEntity(named_entity, matched_topic)]
resolved += [temp_resolved]
named_entity_list_list = [[rr.raw_named_entity for rr in inner_resolved] for inner_resolved in resolved]
self._named_entity_recognizer.bulk_map_normal_output_to_original_input(inputs, named_entity_list_list)
return resolved


Expand Down

0 comments on commit 2ec609b

Please sign in to comment.