Skip to content

Commit

Permalink
fix(linker): keep track of substitution end indices which seem to be …
Browse files Browse the repository at this point in the history
…critical to get normalization mapping to work in all cases
  • Loading branch information
nsantacruz committed Nov 12, 2023
1 parent fb5ad5f commit 5c97e57
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
25 changes: 15 additions & 10 deletions sefaria/helper/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,29 +100,36 @@ def get_mapping_after_normalization(self, text, removal_list=None, reverse=False
removal_list = removal_list or self.find_text_to_remove(text, **kwargs)
total_removed = 0
removal_map = {}
subst_end_indexes = set()
for (start, end), subst in removal_list:
normalized_text_index = start if reverse else (start + min(len(subst), end-start) - total_removed)
curr_removed = end - start - len(subst)
if curr_removed != 0:
total_removed += curr_removed
removal_map[normalized_text_index] = total_removed
return removal_map
if len(subst) > 0:
subst_end_indexes.add(normalized_text_index + 1)
return removal_map, subst_end_indexes

def norm_to_unnorm_indices(self, text, normalized_indices, removal_list=None, reverse=False, **kwargs):
    """Convenience wrapper around the two-step index translation.

    Computes the removal mapping (and substitution end indices) for *text*,
    then uses them to translate *normalized_indices* back to indices in the
    unnormalized text. See convert_normalized_indices_to_unnormalized_indices
    for the meaning of *reverse*.
    """
    mapping, end_indices = self.get_mapping_after_normalization(
        text, removal_list, reverse, **kwargs)
    return self.convert_normalized_indices_to_unnormalized_indices(
        normalized_indices, mapping, end_indices, reverse)

@staticmethod
def convert_normalized_indices_to_unnormalized_indices(normalized_indices, removal_map, reverse=False, alignment_mode='contract'):
def convert_normalized_indices_to_unnormalized_indices(normalized_indices, removal_map, subst_end_indices, reverse=False):
"""
normalized_indices - list of tuples where each tuple is (x, y) x being start index, y is end index + 1
removal_map - return value of get_mapping_after_normalization()
subst_end_indices - set of indices in the normalized text that fall immediately after a non-empty substitution; return value of get_mapping_after_normalization(). Used so that a range whose end touches a substitution is expanded to include the corresponding removal.
reverse - if True, normalized_indices are actually unnormalized indices and removal_map was calculated using reverse=True in get_mapping_after_normalization()
"""
removal_keys = sorted(removal_map.keys())
unnormalized_indices = []
sign = -1 if reverse else 1
for start, end in normalized_indices:
unnorm_start_index = bisect_right(removal_keys, start) - 1

bisect_end_index = end if (start == end or alignment_mode == 'expand') else end - 1
bisect_end_index = end if (start == end or end in subst_end_indices) else end - 1
unnorm_end_index = bisect_right(removal_keys, bisect_end_index) - 1

unnorm_start = start if unnorm_start_index < 0 else start + (sign * removal_map[removal_keys[unnorm_start_index]])
Expand Down Expand Up @@ -262,8 +269,8 @@ def find_text_to_remove(self, s, **kwargs):
text_to_remove_inds, text_to_remove_repls = [], []
else:
text_to_remove_inds, text_to_remove_repls = zip(*curr_text_to_remove)
for mapping in reversed(mappings):
text_to_remove_inds = step.convert_normalized_indices_to_unnormalized_indices(text_to_remove_inds, mapping, alignment_mode='expand')
for mapping, subst_end_indices in reversed(mappings):
text_to_remove_inds = step.convert_normalized_indices_to_unnormalized_indices(text_to_remove_inds, mapping, subst_end_indices)
curr_text_to_remove = list(zip(text_to_remove_inds, text_to_remove_repls))

# merge any overlapping ranges
Expand Down Expand Up @@ -433,7 +440,6 @@ def char_indices_from_word_indices(input_string, word_ranges, split_regex=None):
count += len(word)
end = count
word_indices.append((start, end))
removal_map = regex_normalizer.get_mapping_after_normalization(input_string)
normalized_char_indices = []
for i, words in enumerate(word_ranges):
first_word, last_word = [w if w < len(word_indices) else -1 for w in words]
Expand All @@ -443,7 +449,7 @@ def char_indices_from_word_indices(input_string, word_ranges, split_regex=None):
word_indices[last_word][1] if last_word >= 0 else -1
)
)
return regex_normalizer.convert_normalized_indices_to_unnormalized_indices(normalized_char_indices, removal_map)
return regex_normalizer.norm_to_unnorm_indices(input_string, normalized_char_indices)


@lru_cache(maxsize=32)
Expand All @@ -461,9 +467,8 @@ def word_index_from_char_index(full_string, char_index, split_regex=r'\s+'):

def sanitized_words_to_unsanitized_words(input_string, sanitized_string, sanitization_method, sanitized_word_ranges):
normalizer = FunctionNormalizer(sanitization_method)
removal_map = normalizer.get_mapping_after_normalization(input_string)
sanitized_char_ranges = char_indices_from_word_indices(sanitized_string, sanitized_word_ranges)
unsanitzied_char_ranges = normalizer.convert_normalized_indices_to_unnormalized_indices(sanitized_char_ranges, removal_map)
unsanitzied_char_ranges = normalizer.norm_to_unnorm_indices(input_string, sanitized_char_ranges)
# for char_range in unsanitied_char_ranges:
# word_range = tuple(word_index_from_char_index(input_string, i) for i in char_range)
# stuff.append(word_range)
Expand Down
6 changes: 3 additions & 3 deletions sefaria/model/linker/named_entity_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,16 @@ def map_normal_output_to_original_input(self, input: str, named_entities: List[R
Ref resolution ran on normalized input. Remap raw refs to original (non-normalized) input
"""
unnorm_doc = self._raw_ref_model.make_doc(input)
mapping = self._normalizer.get_mapping_after_normalization(input)
mapping, subst_end_indices = self._normalizer.get_mapping_after_normalization(input)
# this function name is waaay too long
conv = self._normalizer.convert_normalized_indices_to_unnormalized_indices
norm_inds = [named_entity.char_indices for named_entity in named_entities]
unnorm_inds = conv(norm_inds, mapping)
unnorm_inds = conv(norm_inds, mapping, subst_end_indices)
unnorm_part_inds = []
for (named_entity, (norm_raw_ref_start, _)) in zip(named_entities, norm_inds):
raw_ref_parts = named_entity.raw_ref_parts if isinstance(named_entity, RawRef) else []
unnorm_part_inds += [conv([[norm_raw_ref_start + i for i in part.char_indices]
for part in raw_ref_parts], mapping)]
for part in raw_ref_parts], mapping, subst_end_indices)]
for named_entity, temp_unnorm_inds, temp_unnorm_part_inds in zip(named_entities, unnorm_inds, unnorm_part_inds):
named_entity.map_new_indices(unnorm_doc, temp_unnorm_inds)
if isinstance(named_entity, RawRef):
Expand Down
3 changes: 1 addition & 2 deletions sefaria/model/linker/tests/linker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,7 @@ def test_map_new_indices(crrd_params):
n = linker.get_ner()._normalizer
norm_text = n.normalize(text)
norm_doc = nlp.make_doc(norm_text)
mapping = n.get_mapping_after_normalization(text, reverse=True)
norm_part_indices = n.convert_normalized_indices_to_unnormalized_indices(part_indices, mapping, reverse=True)
norm_part_indices = n.norm_to_unnorm_indices(text, part_indices, reverse=True)
norm_part_spans = [norm_doc.char_span(s, e) for (s, e) in norm_part_indices]
norm_part_token_inds = []
for span in norm_part_spans:
Expand Down

0 comments on commit 5c97e57

Please sign in to comment.