Skip to content

Commit

Permalink
Make compatible with torchscript
Browse files Browse the repository at this point in the history
  • Loading branch information
erksch authored and kmkurn committed Jun 9, 2024
1 parent cc449eb commit 623e340
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
fail-fast: false
matrix:
python-version: ["3.8"]
pytorch-version: ["1.4", "2.0"]
pytorch-version: ["1.4", "1.10", "2.0"]

steps:
- uses: actions/checkout@v2
Expand Down
11 changes: 11 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,17 @@ To obtain the most probable sequence of tags, use the `CRF.decode` method.
This method also accepts a mask tensor, see `CRF.decode` for details.

TorchScript
-----------

The ``CRF`` module is compatible with TorchScript on PyTorch ``>=1.10.0``.
To get a ``torch.jit.ScriptModule``, wrap a ``CRF`` instance inside ``torch.jit.script``.

.. code-block:: python
script_model = torch.jit.script(CRF(num_tags))
API documentation
=================

Expand Down
181 changes: 181 additions & 0 deletions tests/test_crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,3 +476,184 @@ def assert_close(actual, expected):

def assert_close(actual, expected):
torch.testing.assert_allclose(actual, expected, atol=1e-12, rtol=1e-6)


# TODO: reduce duplicated code below, maybe with pytest.mark.parametrize
@pytest.mark.skipif(
Version(torch.__version__) < Version("1.10.0"),
reason="torch version does not support torch script")
class TestTorchScriptForward:
def test_torch_scriptable(self):
crf = make_crf()
scripted_module = torch.jit.script(crf)
assert hasattr(scripted_module, 'decode')

def test_default_forward(self):
# Test default case
crf = make_crf()
crf_script = torch.jit.script(crf)
seq_length, batch_size = 3, 2
# shape: (seq_length, batch_size, num_tags)
emissions = make_emissions(crf, seq_length, batch_size)
# shape: (seq_length, batch_size)
tags = make_tags(crf, seq_length, batch_size)
# mask should have size of (seq_length, batch_size)
mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8).transpose(0, 1)
llh = crf(emissions, tags, mask=mask)
llh_scripted = crf_script(emissions, tags, mask=mask)
assert_close(llh_scripted, llh)

def test_without_mask(self):
crf = make_crf()
crf_script = torch.jit.script(crf)
seq_length, batch_size = 3, 2
# shape: (seq_length, batch_size, num_tags)
emissions = make_emissions(crf, seq_length, batch_size)
# shape: (seq_length, batch_size)
tags = make_tags(crf, seq_length, batch_size)
# Test scripted forward works without mask
llh_no_mask = crf(emissions, tags)
llh_no_mask_script = crf_script(emissions, tags)
assert_close(llh_no_mask_script, llh_no_mask)

def test_all_ones_mask(self):
crf = make_crf()
crf_script = torch.jit.script(crf)
seq_length, batch_size = 3, 2
# shape: (seq_length, batch_size, num_tags)
emissions = make_emissions(crf, seq_length, batch_size)
# shape: (seq_length, batch_size)
tags = make_tags(crf, seq_length, batch_size)

# No mask means the mask is all ones
llh_mask = crf(emissions, tags, mask=torch.ones_like(tags).byte())
llh_mask_script = crf_script(emissions, tags, mask=torch.ones_like(tags).byte())
assert_close(llh_mask_script, llh_mask)

def test_batched_forward(self):
crf = make_crf()
crf_script = torch.jit.script(crf)

# Test scripted forward in batched setting
batch_size = 10
# shape: (seq_length, batch_size, num_tags)
emissions_batch = make_emissions(crf, batch_size=batch_size)
# shape: (seq_length, batch_size)
tags_batch = make_tags(crf, batch_size=batch_size)
llh = crf(emissions_batch, tags_batch)
llh_script = crf_script(emissions_batch, tags_batch)
assert_close(llh_script, llh)

def test_reduction_none(self):
crf = make_crf()
crf_script = torch.jit.script(crf)

# shape: (seq_length, batch_size, num_tags)
emissions = make_emissions(crf)
# shape: (seq_length, batch_size)
tags = make_tags(crf)
llh = crf(emissions, tags, reduction='none')
llh_script = crf_script(emissions, tags, reduction='none')
assert_close(llh_script, llh)

def test_reduction_mean(self):
crf = make_crf()
crf_script = torch.jit.script(crf)

# shape: (seq_length, batch_size, num_tags)
emissions = make_emissions(crf)
# shape: (seq_length, batch_size)
tags = make_tags(crf)

llh = crf(emissions, tags, reduction='mean')
llh_script = crf_script(emissions, tags, reduction='mean')
assert_close(llh_script, llh)

def test_reduction_token_mean(self):
crf = make_crf()
crf_script = torch.jit.script(crf)

# shape: (seq_length, batch_size, num_tags)
emissions = make_emissions(crf)
# shape: (seq_length, batch_size)
tags = make_tags(crf)

mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8).transpose(0, 1)
llh = crf(emissions, tags, mask=mask, reduction='token_mean')
llh_script = crf_script(emissions, tags, mask=mask, reduction='token_mean')
assert_close(llh_script, llh)

def test_batch_first(self):
crf = make_crf()

# shape: (seq_length, batch_size, num_tags)
emissions = make_emissions(crf)
# shape: (seq_length, batch_size)
tags = make_tags(crf)

# Test scripted forward when running batch first mode
crf_bf = make_crf(batch_first=True)
# Copy parameter values from non-batch-first CRF; requires_grad must be False
# to avoid runtime error of in-place operation on a leaf variable
crf_bf.start_transitions.requires_grad_(False).copy_(crf.start_transitions)
crf_bf.end_transitions.requires_grad_(False).copy_(crf.end_transitions)
crf_bf.transitions.requires_grad_(False).copy_(crf.transitions)
crf_bf_script = torch.jit.script(crf_bf)
emissions = emissions.transpose(0, 1)
# shape: (batch_size, seq_length)
tags = tags.transpose(0, 1)
llh_bf = crf_bf(emissions, tags)
llh_bf_script = crf_bf_script(emissions, tags)
assert_close(llh_bf_script, llh_bf)


@pytest.mark.skipif(
Version(torch.__version__) < Version("1.10.0"),
reason="torch version does not support torch script")
class TestTorchScriptDecode:
def test_with_mask(self):
# Test decoding with a mask
crf = make_crf()
crf_script = torch.jit.script(crf)

seq_length, batch_size = 3, 2
# shape: (seq_length, batch_size, num_tags)
emissions = make_emissions(crf, seq_length, batch_size)
# mask should be (seq_length, batch_size)
mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8).transpose(0, 1)
best_tags = crf.decode(emissions, mask=mask)
best_tags_scripted = crf_script.decode(emissions, mask=mask)
assert best_tags == best_tags_scripted

def test_without_mask(self):
crf = make_crf()
crf_script = torch.jit.script(crf)

seq_length, batch_size = 3, 2
# shape: (seq_length, batch_size, num_tags)

emissions = make_emissions(crf, seq_length, batch_size)
best_tags_no_mask = crf.decode(emissions)
best_tags_no_mask_scripted = crf_script.decode(emissions)
assert best_tags_no_mask == best_tags_no_mask_scripted

def test_batch_first(self):
crf = make_crf()

seq_length, batch_size = 3, 2
# shape: (seq_length, batch_size, num_tags)

emissions = make_emissions(crf, seq_length, batch_size)

crf_bf = make_crf(batch_first=True)
# Copy parameter values from non-batch-first CRF; requires_grad must be False
# to avoid runtime error of in-place operation on a leaf variable
crf_bf.start_transitions.requires_grad_(False).copy_(crf.start_transitions)
crf_bf.end_transitions.requires_grad_(False).copy_(crf.end_transitions)
crf_bf.transitions.requires_grad_(False).copy_(crf.transitions)
crf_bf_script = torch.jit.script(crf_bf)
# shape: (batch_size, seq_length, num_tags)
emissions = emissions.transpose(0, 1)
best_tags_bf = crf_bf.decode(emissions)
best_tags_bf_script = crf_bf_script.decode(emissions)
assert best_tags_bf == best_tags_bf_script
18 changes: 12 additions & 6 deletions torchcrf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def forward(
assert reduction == 'token_mean'
return llh.sum() / mask.type_as(emissions).sum()

@torch.jit.export
def decode(self, emissions: torch.Tensor,
mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
"""Find the most likely tag sequence using Viterbi algorithm.
Expand Down Expand Up @@ -154,13 +155,15 @@ def _validate(
if emissions.shape[:2] != tags.shape:
raise ValueError(
'the first two dimensions of emissions and tags must match, '
f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')
f'got {(emissions.shape[0], emissions.shape[1])} and {(tags.shape[0], tags.shape[1])}'
)

if mask is not None:
if emissions.shape[:2] != mask.shape:
raise ValueError(
'the first two dimensions of emissions and mask must match, '
f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
f'got {(emissions.shape[0], emissions.shape[1])} and {(mask.shape[0], mask.shape[1])}'
)
no_empty_seq = not self.batch_first and mask[0].all()
no_empty_seq_bf = self.batch_first and mask[:, 0].all()
if not no_empty_seq and not no_empty_seq_bf:
Expand Down Expand Up @@ -270,7 +273,7 @@ def _viterbi_decode(self, emissions: torch.FloatTensor,
# Start transition and first emission
# shape: (batch_size, num_tags)
score = self.start_transitions + emissions[0]
history = []
history: List[torch.Tensor] = []

# score is a tensor of size (batch_size, num_tags) where for every batch,
# value at column j stores the score of the best tag sequence so far that ends
Expand Down Expand Up @@ -313,17 +316,20 @@ def _viterbi_decode(self, emissions: torch.FloatTensor,

# shape: (batch_size,)
seq_ends = mask.long().sum(dim=0) - 1
best_tags_list = []
best_tags_list: List[List[int]] = []

for idx in range(batch_size):
# Find the tag which maximizes the score at the last timestep; this is our best tag
# for the last timestep
_, best_last_tag = score[idx].max(dim=0)
best_tags = [best_last_tag.item()]
best_tags: List[int] = []
best_tags.append(best_last_tag.item())

# We trace back where the best last tag comes from, append that to our best tag
# sequence, and trace it back again, and so on
for hist in reversed(history[:seq_ends[idx]]):
# NOTE: reversed() cannot be used here because it is not supported by TorchScript,
# see https://github.com/pytorch/pytorch/issues/31772.
for hist in history[:seq_ends[idx]][::-1]:
best_last_tag = hist[idx][best_tags[-1]]
best_tags.append(best_last_tag.item())

Expand Down

0 comments on commit 623e340

Please sign in to comment.