diff --git a/requirements-test.txt b/requirements-test.txt index 2ace388..ca6a507 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -4,3 +4,4 @@ torch pytest>=8,<9 pytest-cov>=2,<3 +packaging diff --git a/tests/test_crf.py b/tests/test_crf.py index dc70340..48b5a3b 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -2,10 +2,10 @@ import math import random -from pytest import approx import pytest import torch import torch.nn as nn +from packaging.version import Version from torchcrf import CRF @@ -116,7 +116,7 @@ def test_works_with_mask(self): denominator = math.log(sum(math.exp(s) for s in all_scores)) manual_llh += numerator - denominator - assert llh.item() == approx(manual_llh) + assert_close(llh, manual_llh) llh.backward() # ensure gradients can be computed def test_works_without_mask(self): @@ -130,7 +130,7 @@ def test_works_without_mask(self): # No mask means the mask is all ones llh_mask = crf(emissions, tags, mask=torch.ones_like(tags).byte()) - assert llh_no_mask.item() == approx(llh_mask.item()) + assert_close(llh_no_mask, llh_mask) def test_batched_loss(self): crf = make_crf() @@ -154,7 +154,7 @@ def test_batched_loss(self): # shape: () total_llh += crf(emissions_, tags_) - assert llh.item() == approx(total_llh.item()) + assert_close(llh, total_llh) def test_reduction_none(self): crf = make_crf() @@ -186,8 +186,7 @@ def test_reduction_none(self): denominator = math.log(sum(math.exp(s) for s in all_scores)) manual_llh.append(numerator - denominator) - for llh_, manual_llh_ in zip(llh, manual_llh): - assert llh_.item() == approx(manual_llh_) + assert_close(llh, torch.tensor(manual_llh)) def test_reduction_mean(self): crf = make_crf() @@ -219,7 +218,7 @@ def test_reduction_mean(self): denominator = math.log(sum(math.exp(s) for s in all_scores)) manual_llh += numerator - denominator - assert llh.item() == approx(manual_llh / batch_size) + assert_close(llh, manual_llh / batch_size) def test_reduction_token_mean(self): crf = make_crf() @@ -258,7 +257,7 @@ def test_reduction_token_mean(self): manual_llh += numerator - denominator n_tokens += seq_len - assert llh.item() == approx(manual_llh / n_tokens) + assert_close(llh, manual_llh / n_tokens) def test_batch_first(self): crf = make_crf() @@ -281,7 +280,7 @@ def test_batch_first(self): tags = tags.transpose(0, 1) llh_bf = crf_bf(emissions, tags) - assert llh.item() == approx(llh_bf.item()) + assert_close(llh, llh_bf) def test_emissions_has_bad_number_of_dimension(self): emissions = torch.randn(1, 2) @@ -467,3 +466,13 @@ def test_first_timestep_mask_is_not_all_on(self): with pytest.raises(ValueError) as excinfo: crf.decode(emissions, mask=mask) assert 'mask of the first timestep must all be on' in str(excinfo.value) + + +if Version(torch.__version__) >= Version("1.9.0"): + + def assert_close(actual, expected): + torch.testing.assert_close(actual, expected, atol=1e-12, rtol=1e-6) +else: + + def assert_close(actual, expected): + torch.testing.assert_allclose(actual, expected, atol=1e-12, rtol=1e-6)