Skip to content

Commit

Permalink
extract annotations test update
Browse files Browse the repository at this point in the history
  • Loading branch information
samuel500 committed Aug 29, 2022
1 parent 0026583 commit b8b00bd
Showing 1 changed file with 24 additions and 33 deletions.
57 changes: 24 additions & 33 deletions tests/trainer/test_information_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,17 +157,25 @@ def test_7_extract_test_document(self):
result = self.pipeline.extract(document=test_document)

assert type(result) is dict

res_doc = extraction_result_to_document(test_document, result)
tests_annotations = res_doc.annotations(use_correct=False)

assert len(tests_annotations) == 20

self.tests_annotations = res_doc.annotations(use_correct=False)
tests_bools = []
for ann_i, expected in enumerate(entity_results_data):
ann = tests_annotations[ann_i]
ann_tuple = (ann.label.name, ann.start_offset, ann.end_offset)
tests_bools.append(ann_tuple == expected[1])

assert len(self.tests_annotations) == 20
assert tests_bools == [True] * len(tests_bools)

# @pytest.mark.parametrize("ann_i,expected", entity_results_data)
# def test_8_test_annotations(self, ann_i, expected):
# """"""
# anns = self.tests_annotations[ann_i]
# assert (anns[ann_i].label.name, anns[ann_i].start_offset, anns[ann_i].end_offset) == expected
# anns = self.tests_annotations[ann_i]
# assert (anns[ann_i].label.name, anns[ann_i].start_offset, anns[ann_i].end_offset) == expected
# https://github.com/pytest-dev/pytest/issues/541


class TestSeparateLabelsEntityMultiClassModel(unittest.TestCase):
Expand Down Expand Up @@ -228,36 +236,19 @@ def test_7_extract_test_document(self):
result = self.pipeline.extract(document=test_document)

assert type(result) is dict

res_doc = extraction_result_to_document(test_document, result)
tests_annotations = res_doc.annotations(use_correct=False)

anns = res_doc.annotations(use_correct=False)

assert len(anns) == 20

assert (anns[0].label.name, anns[0].start_offset, anns[0].end_offset) == ('Austellungsdatum', 159, 169)
assert (anns[1].label.name, anns[1].start_offset, anns[1].end_offset) == ('Personalausweis', 352, 357)
assert (anns[2].label.name, anns[2].start_offset, anns[2].end_offset) == ('Steuerklasse', 365, 366)
assert (anns[3].label.name, anns[3].start_offset, anns[3].end_offset) == ('Personalausweis', 1194, 1199)
assert (anns[4].label.name, anns[4].start_offset, anns[4].end_offset) == ('Gesamt-Brutto', 1498, 1504)
assert (anns[5].label.name, anns[5].start_offset, anns[5].end_offset) == ('Vorname', 1507, 1518)
assert (anns[6].label.name, anns[6].start_offset, anns[6].end_offset) == ('Nachname', 1519, 1527)
assert (anns[7].label.name, anns[7].start_offset, anns[7].end_offset) == ('Gesamt-Brutto', 1582, 1587)
assert (anns[8].label.name, anns[8].start_offset, anns[8].end_offset) == ('Lohnart', 1758, 1762)
assert (anns[9].label.name, anns[9].start_offset, anns[9].end_offset) == ('Bezeichnung', 1763, 1769)
assert (anns[10].label.name, anns[10].start_offset, anns[10].end_offset) == ('Betrag', 1831, 1839)
assert (anns[11].label.name, anns[11].start_offset, anns[11].end_offset) == ('Gesamt-Brutto', 2111, 2119)
assert (anns[12].label.name, anns[12].start_offset, anns[12].end_offset) == ('Sozialversicherung', 2255, 2262)
assert (anns[13].label.name, anns[13].start_offset, anns[13].end_offset) == ('Sozialversicherung', 2269, 2274)
assert (anns[14].label.name, anns[14].start_offset, anns[14].end_offset) == ('Sozialversicherung', 2281, 2285)
assert (anns[15].label.name, anns[15].start_offset, anns[15].end_offset) == ('Sozialversicherung', 2292, 2296)
assert (anns[16].label.name, anns[16].start_offset, anns[16].end_offset) == (
'Steuerrechtliche Abzüge',
2324,
2330,
)
assert (anns[17].label.name, anns[17].start_offset, anns[17].end_offset) == ('Netto-Verdienst', 3004, 3012)
assert (anns[18].label.name, anns[18].start_offset, anns[18].end_offset) == ('Steuer-Brutto', 3141, 3149)
assert (anns[19].label.name, anns[19].start_offset, anns[19].end_offset) == ('Auszahlungsbetrag', 3777, 3785)
assert len(tests_annotations) == 20

tests_bools = []
for ann_i, expected in enumerate(entity_results_data):
ann = tests_annotations[ann_i]
ann_tuple = (ann.label.name, ann.start_offset, ann.end_offset)
tests_bools.append(ann_tuple == expected[1])

assert tests_bools == [True] * len(tests_bools)


class TestInformationExtraction(unittest.TestCase):
Expand Down

3 comments on commit b8b00bd

@atraining
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@samuel500 please refer to my comment before - I was refering to parametrized to see what is failing see

@pytest.mark.parametrize("test_input, expected, document_id", test_data_percentage)
- please revise it again

@samuel500
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There appears to be an issue with parametrize when used inside a unittest class: pytest-dev/pytest#541 .
One alternative workaround would be to use the parametrized package as they did here: seleniumbase/SeleniumBase#395

@da-konfuzio
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I like the idea of adding parameterized to the SDK

Please sign in to comment.