Skip to content

Commit

Permalink
Use torch.testing.assert_close instead of approx
Browse files Browse the repository at this point in the history
The tests outputs `RuntimeError: Can't call numpy() on Tensor that
requires grad. Use tensor.detach().numpy() instead.` for torch above
v1.4. The error seems to came from trying to convert tensor to scalar.

Replace the pytest approx() with `torch.testing.assert_close()`
to let torch handle the conversion.
Use `assert_allclose()` for torch versions before 1.9.0.
  • Loading branch information
leejuyuu authored and kmkurn committed Jun 6, 2024
1 parent 481a801 commit cc449eb
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
1 change: 1 addition & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
torch
pytest>=8,<9
pytest-cov>=2,<3
packaging
27 changes: 18 additions & 9 deletions tests/test_crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import math
import random

from pytest import approx
import pytest
import torch
import torch.nn as nn
from packaging.version import Version

from torchcrf import CRF

Expand Down Expand Up @@ -116,7 +116,7 @@ def test_works_with_mask(self):
denominator = math.log(sum(math.exp(s) for s in all_scores))
manual_llh += numerator - denominator

assert llh.item() == approx(manual_llh)
assert_close(llh, manual_llh)
llh.backward() # ensure gradients can be computed

def test_works_without_mask(self):
Expand All @@ -130,7 +130,7 @@ def test_works_without_mask(self):
# No mask means the mask is all ones
llh_mask = crf(emissions, tags, mask=torch.ones_like(tags).byte())

assert llh_no_mask.item() == approx(llh_mask.item())
assert_close(llh_no_mask, llh_mask)

def test_batched_loss(self):
crf = make_crf()
Expand All @@ -154,7 +154,7 @@ def test_batched_loss(self):
# shape: ()
total_llh += crf(emissions_, tags_)

assert llh.item() == approx(total_llh.item())
assert_close(llh, total_llh)

def test_reduction_none(self):
crf = make_crf()
Expand Down Expand Up @@ -186,8 +186,7 @@ def test_reduction_none(self):
denominator = math.log(sum(math.exp(s) for s in all_scores))
manual_llh.append(numerator - denominator)

for llh_, manual_llh_ in zip(llh, manual_llh):
assert llh_.item() == approx(manual_llh_)
assert_close(llh, torch.tensor(manual_llh))

def test_reduction_mean(self):
crf = make_crf()
Expand Down Expand Up @@ -219,7 +218,7 @@ def test_reduction_mean(self):
denominator = math.log(sum(math.exp(s) for s in all_scores))
manual_llh += numerator - denominator

assert llh.item() == approx(manual_llh / batch_size)
assert_close(llh, manual_llh / batch_size)

def test_reduction_token_mean(self):
crf = make_crf()
Expand Down Expand Up @@ -258,7 +257,7 @@ def test_reduction_token_mean(self):
manual_llh += numerator - denominator
n_tokens += seq_len

assert llh.item() == approx(manual_llh / n_tokens)
assert_close(llh, manual_llh / n_tokens)

def test_batch_first(self):
crf = make_crf()
Expand All @@ -281,7 +280,7 @@ def test_batch_first(self):
tags = tags.transpose(0, 1)
llh_bf = crf_bf(emissions, tags)

assert llh.item() == approx(llh_bf.item())
assert_close(llh, llh_bf)

def test_emissions_has_bad_number_of_dimension(self):
emissions = torch.randn(1, 2)
Expand Down Expand Up @@ -467,3 +466,13 @@ def test_first_timestep_mask_is_not_all_on(self):
with pytest.raises(ValueError) as excinfo:
crf.decode(emissions, mask=mask)
assert 'mask of the first timestep must all be on' in str(excinfo.value)


if Version(torch.__version__) >= Version("1.9.0"):

def assert_close(actual, expected):
torch.testing.assert_close(actual, expected, atol=1e-12, rtol=1e-6)
else:

def assert_close(actual, expected):
torch.testing.assert_allclose(actual, expected, atol=1e-12, rtol=1e-6)

0 comments on commit cc449eb

Please sign in to comment.