Skip to content

Commit

Permalink
feat(linker): fully support traversing ArrayMapNode children
Browse files Browse the repository at this point in the history
  • Loading branch information
nsantacruz committed Nov 18, 2023
1 parent 872d36f commit 588c92b
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 47 deletions.
136 changes: 117 additions & 19 deletions sefaria/model/linker/referenceable_book_node.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,28 @@
import dataclasses
from typing import List, Union, Optional
from typing import List, Union, Optional, Tuple
from sefaria.model import abstract as abst
from sefaria.model import text
from sefaria.model import schema
from sefaria.system.exceptions import InputError


def subref(ref: text.Ref, section: int):
    """
    Return the sub-reference of `ref` selected by `section`.

    If the address type at `ref`'s deepest current section is "Talmud", the ref is rebuilt
    from its core dict: the last section is advanced by `section - 1` (so `section=1`
    leaves it unchanged) and `toSections` is set to a copy of `sections`.
    For every other address type this simply delegates to `ref.subref(section)`.
    """
    deepest_address_type = ref.index_node.addressTypes[len(ref.sections) - 1]
    if deepest_address_type != "Talmud":
        return ref.subref(section)
    core = ref._core_dict()
    core['sections'][-1] += section - 1
    # the end of the new ref mirrors its start
    core['toSections'] = core['sections'][:]
    return text.Ref(_obj=core)


def truncate_serialized_node_to_depth(serial_node: dict, depth: int) -> dict:
    """
    Return a shallow copy of `serial_node` whose per-depth list attributes are sliced with `[depth:]`.

    A positive `depth` drops the first `depth` entries of each list; a negative `depth`
    keeps only the last `abs(depth)` entries. List attributes missing from `serial_node`
    are skipped, all other keys are carried over untouched, and `serial_node` itself is
    not modified.
    """
    per_depth_attrs = ('addressTypes', 'sectionNames', 'lengths', 'referenceableSections')
    sliced = {attr: serial_node[attr][depth:] for attr in per_depth_attrs if attr in serial_node}
    # merging keeps the original key order while overriding only the sliced lists
    return {**serial_node, **sliced}


class ReferenceableBookNode:
Expand All @@ -20,6 +40,10 @@ def get_children(self, *args, **kwargs) -> List['ReferenceableBookNode']:
def is_default(self) -> bool:
    """Base implementation: always returns False."""
    return False

@property
def referenceable(self) -> bool:
    """Base implementation: always returns True."""
    return True


class NamedReferenceableBookNode(ReferenceableBookNode):

Expand All @@ -29,6 +53,10 @@ def __init__(self, titled_tree_node_or_index: Union[schema.TitledTreeNode, text.
if isinstance(titled_tree_node_or_index, text.Index):
self._titled_tree_node = titled_tree_node_or_index.nodes

@property
def referenceable(self):
    """Return the wrapped titled tree node's `referenceable` attribute, or True if it has none."""
    return getattr(self._titled_tree_node, 'referenceable', True)

def is_default(self):
    """Delegate to the wrapped titled tree node's `is_default()`."""
    return self._titled_tree_node.is_default()

Expand All @@ -40,20 +68,44 @@ def ref(self) -> text.Ref:

def _get_all_children(self) -> List[ReferenceableBookNode]:
    """
    Build the list of referenceable children for this node, dispatching on the wrapped schema object.

    - A numbered node that is segment-level dibur hamatchil yields a single `DiburHamatchilNodeSet`.
    - A childless `JaggedArrayNode` yields a single `NumberedReferenceableBookNode` wrapping it.
    - An `ArrayMapNode` with `refs` yields one `MonoReferenceableBookNode` per ref, numbered from 1.
    - An `ArrayMapNode` with `wholeRef` yields one `MonoReferenceableBookNode` per ref returned by
      `split_spanning_ref()`, using the whole ref's serialized index node sliced with `[-2:]`.
    - Otherwise the schema children (or the index's `referenceable_children()`) are each passed
      through `_transform_schema_node_to_referenceable`.
    """
    source = self._titled_tree_node_or_index

    # the schema node for this referenceable node has a dibur hamatchil child
    has_dh_child = isinstance(source, schema.NumberedTitledTreeNode) and source.is_segment_level_dibur_hamatchil()
    if has_dh_child:
        return [DiburHamatchilNodeSet({"container_refs": self.ref().normal()})]

    # the schema node for this referenceable is a JAN. JANs act as both named and numbered nodes
    is_leaf_jan = isinstance(source, schema.JaggedArrayNode) and len(source.children) == 0
    if is_leaf_jan:
        return [NumberedReferenceableBookNode(source)]

    if isinstance(source, text.Index):
        raw_children = source.referenceable_children()
    elif not isinstance(source, schema.ArrayMapNode):
        # Any other type of TitledTreeNode
        raw_children = self._titled_tree_node.children
    # TODO following two branches are very similar...
    elif getattr(source, 'refs', None):
        address_types = source.addressTypes
        section_names = source.sectionNames
        return [
            MonoReferenceableBookNode(address_types, section_names, position, text.Ref(tref))
            for position, tref in enumerate(source.refs, start=1)
        ]
    elif getattr(source, 'wholeRef', None):
        whole_ref = text.Ref(source.wholeRef)
        serialized_index_node = whole_ref.index_node.serialize()
        truncated_node = truncate_serialized_node_to_depth(serialized_index_node, -2)
        return [
            MonoReferenceableBookNode(numeric_equivalent=oref.section_ref().sections[0], ref=oref, **truncated_node)
            for oref in whole_ref.split_spanning_ref()
        ]
    else:
        # ArrayMapNode with neither `refs` nor `wholeRef`: fall back to its schema children
        raw_children = self._titled_tree_node.children

    return [self._transform_schema_node_to_referenceable(child) for child in raw_children]

def _get_children_from_array_map_node(self, node: schema.ArrayMapNode) -> List[ReferenceableBookNode]:
    # NOTE(review): unimplemented placeholder. The body is `pass`, so this returns None
    # despite the annotated list return type. Presumably meant to hold the ArrayMapNode
    # handling that is currently inlined in `_get_all_children` — confirm before calling.
    pass

@staticmethod
def _transform_schema_node_to_referenceable(schema_node: schema.TitledTreeNode) -> ReferenceableBookNode:
if isinstance(schema_node, schema.JaggedArrayNode) and (schema_node.is_default() or schema_node.parent is None):
Expand Down Expand Up @@ -84,27 +136,44 @@ class NumberedReferenceableBookNode(ReferenceableBookNode):
def __init__(self, ja_node: schema.NumberedTitledTreeNode):
    """Store `ja_node` as the schema node this referenceable node wraps."""
    self._ja_node = ja_node

@property
def referenceable(self):
    """Return the wrapped node's `referenceable` attribute, or True if it has none."""
    return getattr(self._ja_node, 'referenceable', True)

def is_default(self):
    """True only when the wrapped node reports `is_default()` and also has a parent."""
    return self._ja_node.is_default() and self._ja_node.parent is not None

def ref(self):
    """Return the wrapped node's ref."""
    return self._ja_node.ref()

def possible_subrefs(self, lang: str, initial_ref: text.Ref, section_str: str, fromSections=None) -> Tuple[List[text.Ref], List[bool]]:
    """
    Interpret `section_str` as a section address of this node and refine `initial_ref` accordingly.

    @param lang: language passed through to the address class
    @param initial_ref: ref that each candidate section is applied to (via `subref`)
    @param section_str: raw text of the section to parse
    @param fromSections: passed through to `get_all_possible_sections_from_string`
    @return: two parallel lists of equal length: the refined refs, and for each one whether its
             address class reports `can_match_out_of_order(lang, section_str)`.
             Both lists are empty if the string cannot be parsed.
    """
    try:
        possible_sections, possible_to_sections, addr_classes = self._address_class.get_all_possible_sections_from_string(lang, section_str, fromSections, strip_prefixes=True)
    except (IndexError, TypeError, KeyError):
        return [], []
    possible_subrefs = []
    can_match_out_of_order_list = []
    for sec, toSec, addr_class in zip(possible_sections, possible_to_sections, addr_classes):
        try:
            refined_ref = subref(initial_ref, sec)
            if toSec != sec:
                # ranged section: extend the ref to the end section
                refined_ref = refined_ref.to(subref(initial_ref, toSec))
            can_match_out_of_order = addr_class.can_match_out_of_order(lang, section_str)
        except (InputError, IndexError, AssertionError, AttributeError):
            # this candidate is not valid for `initial_ref`; skip it entirely
            continue
        # Append to both lists only after everything succeeded, so they always stay the
        # same length and callers can safely zip them together.
        possible_subrefs.append(refined_ref)
        can_match_out_of_order_list.append(can_match_out_of_order)
    return possible_subrefs, can_match_out_of_order_list

# TODO move these two properties to be private
@property
def address_class(self) -> schema.AddressType:
def _address_class(self) -> schema.AddressType:
return self._ja_node.address_class(0)

@property
def section_name(self) -> str:
def _section_name(self) -> str:
return self._ja_node.sectionNames[0]

def get_all_possible_sections_from_string(self, *args, **kwargs):
"""
wraps AddressType function with same name
@return:
"""
return self.address_class.get_all_possible_sections_from_string(*args, **kwargs)

def _get_next_referenceable_depth(self):
if self.is_default():
return 0
Expand All @@ -126,30 +195,59 @@ def get_children(self, context_ref=None, **kwargs) -> [ReferenceableBookNode]:
if serial['depth'] <= 1 and self._ja_node.is_segment_level_dibur_hamatchil():
return [DiburHamatchilNodeSet({"container_refs": context_ref.normal()})]
if (self._ja_node.depth - next_referenceable_depth) == 0:
if isinstance(self.address_class, schema.AddressTalmud):
if isinstance(self._address_class, schema.AddressTalmud):
serial['addressTypes'] = ["Amud"]
serial['sectionNames'] = ["Amud"]
serial['lengths'] = [1]
serial['referenceableSections'] = [True]
else:
return []
else:
for list_attr in ('addressTypes', 'sectionNames', 'lengths', 'referenceableSections'):
# truncate every list attribute by `next_referenceable_depth`
if list_attr not in serial: continue
serial[list_attr] = serial[list_attr][next_referenceable_depth:]
serial = truncate_serialized_node_to_depth(serial, next_referenceable_depth)
new_ja = schema.JaggedArrayNode(serial=serial, index=getattr(self, 'index', None), **kwargs)
return [NumberedReferenceableBookNode(new_ja)]

def matches_section_context(self, section_context: 'SectionContext') -> bool:
"""
Does the address in `self` match the address in `section_context`?
"""
if self.address_class.__class__ != section_context.addr_type.__class__: return False
if self.section_name != section_context.section_name: return False
if self._address_class.__class__ != section_context.addr_type.__class__: return False
if self._section_name != section_context.section_name: return False
return True


class MonoReferenceableBookNode(NumberedReferenceableBookNode):
    """
    Node that can only be referenced by one ref
    """

    def __init__(self, addressTypes: List[str], sectionNames: List[str], numeric_equivalent: int, ref: text.Ref, **ja_node_attrs):
        """
        Build a JaggedArrayNode from the given attributes and remember the single ref it stands for.

        @param addressTypes: address types for the synthesized node; its length becomes the node's depth
        @param sectionNames: section names for the synthesized node
        @param numeric_equivalent: the section number that selects this node in `possible_subrefs`
        @param ref: the ref this node resolves to
        @param ja_node_attrs: extra serialized attributes forwarded to JaggedArrayNode
                              (a "depth" entry here is overridden)
        """
        serial = {"addressTypes": addressTypes, "sectionNames": sectionNames}
        serial.update(ja_node_attrs)
        # depth is always derived from addressTypes, regardless of what was passed in
        serial["depth"] = len(addressTypes)
        super().__init__(schema.JaggedArrayNode(serial=serial))
        self._numeric_equivalent = numeric_equivalent
        self._ref = ref

    def ref(self):
        """Return the single ref this node stands for."""
        return self._ref

    def possible_subrefs(self, lang: str, initial_ref: text.Ref, section_str: str, fromSections=None) -> Tuple[List[text.Ref], List[bool]]:
        """
        Return `([ref], [True])` if `section_str` can be parsed as the non-ranged section equal to
        this node's numeric equivalent; otherwise return two empty lists.
        `initial_ref` is not used by this implementation.
        """
        try:
            sections, to_sections, _ = self._address_class.get_all_possible_sections_from_string(
                lang, section_str, fromSections, strip_prefixes=True)
        except (IndexError, TypeError, KeyError):
            return [], []
        # if any section matches numeric_equivalent, this node's ref is the subref.
        is_match = any(
            sec == self._numeric_equivalent and sec == to_sec
            for sec, to_sec in zip(sections, to_sections)
        )
        if is_match:
            return [self._ref], [True]
        return [], []


@dataclasses.dataclass
class DiburHamatchilMatch:
score: float
Expand Down
31 changes: 4 additions & 27 deletions sefaria/model/linker/resolved_ref_refiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,6 @@
from sefaria.model.text import Ref


def subref(ref: Ref, section: int):
if ref.index_node.addressTypes[len(ref.sections)-1] == "Talmud":
d = ref._core_dict()
d['sections'][-1] += (section-1)
d['toSections'] = d['sections'][:]
return Ref(_obj=d)
else:
return ref.subref(section)


class ResolvedRefRefiner(ABC):

def __init__(self, part_to_match: RawRefPart, node: ReferenceableBookNode, resolved_ref: 'ResolvedRef'):
Expand Down Expand Up @@ -88,31 +78,18 @@ def __refine_context_full(self) -> List['ResolvedRef']:
def __refine_context_free(self, lang: str, fromSections=None) -> List['ResolvedRef']:
if self.node is None:
return []
try:
possible_sections, possible_to_sections, addr_classes = self.node.get_all_possible_sections_from_string(lang, self.part_to_match.text, fromSections, strip_prefixes=True)
except (IndexError, TypeError, KeyError):
return []
possible_subrefs, can_match_out_of_order_list = self.node.possible_subrefs(lang, self.resolved_ref.ref, self.part_to_match.text, fromSections)
refined_refs = []
addr_classes_used = []
for sec, toSec, addr_class in zip(possible_sections, possible_to_sections, addr_classes):
if self._has_prev_unused_numbered_ref_part() and not addr_class.can_match_out_of_order(lang, self.part_to_match.text):
for refined_ref, can_match_out_of_order in zip(possible_subrefs, can_match_out_of_order_list):
if self._has_prev_unused_numbered_ref_part() and not can_match_out_of_order:
"""
If raw_ref has NUMBERED parts [a, b]
and part b matches before part a
and part b gets matched as AddressInteger
discard match because AddressInteger parts need to match in order
"""
continue
try:
refined_ref = subref(self.resolved_ref.ref, sec)
if toSec != sec:
to_ref = subref(self.resolved_ref.ref, toSec)
refined_ref = refined_ref.to(to_ref)
refined_refs += [refined_ref]
addr_classes_used += [addr_class]
except (InputError, IndexError, AssertionError, AttributeError):
continue

refined_refs += [refined_ref]
return [self._clone_resolved_ref(resolved_parts=self._get_resolved_parts(), node=self.node, ref=refined_ref) for refined_ref in refined_refs]


Expand Down
3 changes: 2 additions & 1 deletion sefaria/model/linker/resolved_ref_refiner_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from sefaria.model.linker.ref_part import RawRefPart, RefPartType
from sefaria.model.linker.referenceable_book_node import ReferenceableBookNode, NamedReferenceableBookNode, NumberedReferenceableBookNode
from sefaria.model.linker.referenceable_book_node import ReferenceableBookNode, NamedReferenceableBookNode, NumberedReferenceableBookNode, MonoReferenceableBookNode
from sefaria.model.linker.resolved_ref_refiner import ResolvedRefRefinerForDefaultNode, ResolvedRefRefinerForNumberedPart, ResolvedRefRefinerForDiburHamatchilPart, ResolvedRefRefinerForRangedPart, ResolvedRefRefinerForNamedNode, ResolvedRefRefiner, ResolvedRefRefinerCatchAll


Expand Down Expand Up @@ -48,6 +48,7 @@ def initialize_resolved_ref_refiner_factory() -> ResolvedRefRefinerFactory:
refiners_to_register = [
(key(is_default=True), ResolvedRefRefinerForDefaultNode),
(key(RefPartType.NUMBERED, node_class=NumberedReferenceableBookNode), ResolvedRefRefinerForNumberedPart),
(key(RefPartType.NUMBERED, node_class=MonoReferenceableBookNode), ResolvedRefRefinerForNumberedPart),
(key(RefPartType.RANGE, node_class=NumberedReferenceableBookNode), ResolvedRefRefinerForRangedPart),
(key(RefPartType.NAMED, node_class=NamedReferenceableBookNode), ResolvedRefRefinerForNamedNode),
(key(RefPartType.NUMBERED, node_class=NamedReferenceableBookNode), ResolvedRefRefinerForNamedNode),
Expand Down
4 changes: 4 additions & 0 deletions sefaria/model/linker/tests/linker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def test_resolved_raw_ref_clone():


@pytest.mark.parametrize(('resolver_data', 'expected_trefs'), [
# Using addressTypes of alt structs
[crrd(["@JT", "@Berakhot", "#2a"], lang="en"), ("Jerusalem Talmud Berakhot 1:1:7-11",)],
[crrd(["@JT", "@Berakhot", "@Chapter 1", "#2a"], lang="en"), ("Jerusalem Talmud Berakhot 1:1:7-11",)],
# Numbered JAs
[crrd(["@בבלי", "@ברכות", "#דף ב"]), ("Berakhot 2",)], # amud-less talmud
[crrd(["@ברכות", "#דף ב"]), ("Berakhot 2",)], # amud-less talmud
Expand Down

0 comments on commit 588c92b

Please sign in to comment.