Fix some issues and some refactoring
* Initialize parameters in `__init__` (fixes #1)
* Refactor tests
* Deprecate `summed` in favor of `reduce` (fixes #2)
* Rename setup.cfg to .flake8
kmkurn committed Dec 29, 2017
1 parent 0127d5f commit 559221d
Showing 5 changed files with 108 additions and 82 deletions.
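Taken together, the user-visible effect of this commit, as a minimal sketch reusing `num_tags`, `emissions`, and `tags` from the README example shown in the diff below:

>>> model = CRF(num_tags)            # parameters are now initialized in __init__, drawn from uniform(-0.1, 0.1)
>>> llh = model(emissions, tags)     # log likelihood summed over the batch (default)
>>> llh_each = model(emissions, tags, reduce=False)   # one log likelihood per sequence
>>> model.reset_parameters()         # re-draw the transition parameters if desired

The old `summed` keyword is still accepted for now but emits a `DeprecationWarning`; a second sketch after the diffs shows that path.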
setup.cfg → .flake8: file renamed without changes.
README.rst: 5 changes (0 additions, 5 deletions)
@@ -38,11 +38,6 @@ In the examples below, we will assume that these lines have been executed ::
 >>> emissions = torch.autograd.Variable(torch.randn(seq_length, batch_size, num_tags), requires_grad=True)
 >>> tags = torch.autograd.Variable(torch.LongTensor([[0, 1], [2, 4], [3, 1]]))  # (seq_length, batch_size)
 >>> model = CRF(num_tags)
->>> # Initialize model parameters
-... for p in model.parameters():
-...     _ = torch.nn.init.uniform(p, -1, 1)
-...
->>>
 
 Forward computation
 -------------------
setup.py: 2 changes (1 addition, 1 deletion)
@@ -8,7 +8,7 @@
 
 
 setup(name='pytorch-crf',
-      version='0.4.0',
+      version='0.4.1',
       description='Conditional random field in PyTorch',
       long_description=readme,
       url='https://github.com/kmkurn/pytorch-crf',
src/torchcrf/__init__.py: 30 changes (27 additions, 3 deletions)
@@ -1,4 +1,5 @@
 from typing import List, Optional, Union
+import warnings
 
 from torch.autograd import Variable
 import torch
@@ -47,11 +48,24 @@ def __init__(self, num_tags: int) -> None:
         self.end_transitions = nn.Parameter(torch.Tensor(num_tags))
         self.transitions = nn.Parameter(torch.Tensor(num_tags, num_tags))
 
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        """Initialize the transition parameters.
+
+        The parameters will be initialized randomly from a uniform distribution
+        between -0.1 and 0.1.
+        """
+        nn.init.uniform(self.start_transitions, -0.1, 0.1)
+        nn.init.uniform(self.end_transitions, -0.1, 0.1)
+        nn.init.uniform(self.transitions, -0.1, 0.1)
+
     def forward(self,
                 emissions: Variable,
                 tags: Variable,
                 mask: Optional[Variable] = None,
-                summed: bool = True) -> Variable:
+                reduce: bool = True,
+                **kwargs) -> Variable:
         """Compute the log likelihood of the given sequence of tags and emission score.
 
         Arguments
@@ -62,7 +76,7 @@ def forward(self,
             Sequence of tags as ``LongTensor`` of size ``(seq_length, batch_size)``.
         mask : :class:`~torch.autograd.Variable`, optional
             Mask tensor as ``ByteTensor`` of size ``(seq_length, batch_size)``.
-        summed : bool
+        reduce : bool
             Whether to sum the log likelihood over the batch.
 
         Returns
@@ -94,13 +108,23 @@ def forward(self,
         if not all(mask[0].data):
             raise ValueError('mask of the first timestep must all be on')
 
+        if 'summed' in kwargs:
+            msg = "keyword argument 'summed' is deprecated and will be removed in "\
+                  "future versions, please use 'reduce' instead"
+            warnings.warn(msg, DeprecationWarning, stacklevel=3)
+            reduce = kwargs.pop('summed')
+
+        if kwargs:
+            raise TypeError(
+                f"'{kwargs.popitem()[0]}' is an invalid keyword argument for this function")
+
         if mask is None:
             mask = Variable(self._new(*tags.size()).fill_(1)).byte()
 
         numerator = self._compute_joint_llh(emissions, tags, mask)
         denominator = self._compute_log_partition_function(emissions, mask)
         llh = numerator - denominator
-        return torch.sum(llh) if summed else llh
+        return llh if not reduce else torch.sum(llh)
 
     def decode(self,
                emissions: Union[Variable, torch.FloatTensor],
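For reference, a small illustrative sketch of how the deprecation shim above behaves, again reusing `model`, `emissions`, and `tags` from the README example (`foo` is just a hypothetical invalid keyword):

>>> import warnings
>>> with warnings.catch_warnings(record=True) as caught:
...     warnings.simplefilter('always')
...     llh = model(emissions, tags, summed=False)  # old keyword still works, mapped to reduce=False
...
>>> assert any(issubclass(w.category, DeprecationWarning) for w in caught)
>>> try:
...     model(emissions, tags, foo=True)  # any unrecognized keyword...
... except TypeError:
...     pass                              # ...is rejected with a TypeError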
