Skip to content

Commit

Permalink
Release v0.5.0
Browse files Browse the repository at this point in the history
* Remove support for keyword argument `summed` in `forward`
* Remove unneeded asterisks when creating tensors
* Convert mask to long before summing to avoid overflow
* Set status to beta
  • Loading branch information
kmkurn committed Jan 4, 2018
1 parent 90be424 commit b200a38
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 40 deletions.
2 changes: 0 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ Description

This package provides an implementation of `conditional random field <https://en.wikipedia.org/wiki/Conditional_random_field>`_ (CRF) in PyTorch. This implementation borrows mostly from `AllenNLP CRF module <https://github.com/allenai/allennlp/blob/master/allennlp/modules/conditional_random_field.py>`_ with some modifications.

NOTE: This software is still in alpha version; every minor version change introduces backward incompatibility.

Requirements
============

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@


setup(name='pytorch-crf',
version='0.4.1',
version='0.5.0',
description='Conditional random field in PyTorch',
long_description=readme,
url='https://github.com/kmkurn/pytorch-crf',
author='Kemal Kurniawan',
author_email='kemal@kkurniawan.com',
license='MIT',
classifiers=[
'Development Status :: 3 - Alpha',
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: MIT License',
Expand Down
19 changes: 4 additions & 15 deletions src/torchcrf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import List, Optional, Union
import warnings

from torch.autograd import Variable
import torch
Expand Down Expand Up @@ -65,7 +64,7 @@ def forward(self,
tags: Variable,
mask: Optional[Variable] = None,
reduce: bool = True,
**kwargs) -> Variable:
) -> Variable:
"""Compute the log likelihood of the given sequence of tags and emission score.
Arguments
Expand Down Expand Up @@ -108,18 +107,8 @@ def forward(self,
if not all(mask[0].data):
raise ValueError('mask of the first timestep must all be on')

if 'summed' in kwargs:
msg = "keyword argument 'summed' is deprecated and will be removed in "\
"future versions, please use 'reduce' instead"
warnings.warn(msg, DeprecationWarning, stacklevel=3)
reduce = kwargs.pop('summed')

if kwargs:
raise TypeError(
f"'{kwargs.popitem()[0]}' is an invalid keyword argument for this function")

if mask is None:
mask = Variable(self._new(*tags.size()).fill_(1)).byte()
mask = Variable(self._new(tags.size()).fill_(1)).byte()

numerator = self._compute_joint_llh(emissions, tags, mask)
denominator = self._compute_log_partition_function(emissions, mask)
Expand Down Expand Up @@ -159,7 +148,7 @@ def decode(self,
if isinstance(emissions, Variable):
emissions = emissions.data
if mask is None:
mask = self._new(*emissions.size()[:2]).fill_(1).byte()
mask = self._new(emissions.size()[:2]).fill_(1).byte()
elif isinstance(mask, Variable):
mask = mask.data

Expand All @@ -169,7 +158,7 @@ def decode(self,

best_tags = []
for emission, mask_ in zip(emissions, mask):
seq_length = mask_.sum()
seq_length = mask_.long().sum()
best_tags.append(self._viterbi_decode(emission[:seq_length]))
return best_tags

Expand Down
21 changes: 0 additions & 21 deletions tests/test_crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,27 +214,6 @@ def test_first_timestep_mask_is_not_all_on(self):
crf(emissions, tags, mask=mask)
assert 'mask of the first timestep must all be on' in str(excinfo.value)

def test_warning_when_kwarg_summed_is_used(self, recwarn):
crf = make_crf()
emissions = make_emissions(num_tags=crf.num_tags)
tags = make_tags(num_tags=crf.num_tags)

crf(emissions, tags, summed=False)

w = recwarn.pop(DeprecationWarning)
msg = "keyword argument 'summed' is deprecated and will be removed in "\
"future versions, please use 'reduce' instead"
assert msg in str(w.message)

def test_unknown_kwargs(self):
crf = make_crf()
emissions = make_emissions(num_tags=crf.num_tags)
tags = make_tags(num_tags=crf.num_tags)

with pytest.raises(TypeError) as excinfo:
crf(emissions, tags, foo='foo')
assert "'foo' is an invalid keyword argument for this function" in str(excinfo.value)


class TestDecode(object):
def test_batched_decode_is_correct(self):
Expand Down

0 comments on commit b200a38

Please sign in to comment.