Skip to content

Commit

Permalink
Release 0.4.0
Browse files Browse the repository at this point in the history
* Upgrade to PyTorch 0.3.0

* Use fancy indexing now that it's fixed in 0.3.0

* Make squeeze dimension explicit
  • Loading branch information
kmkurn committed Dec 7, 2017
1 parent 349ee22 commit 0127d5f
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 30 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Requirements
============

- Python 3.6
- PyTorch 0.2
- PyTorch 0.3.0

Installation
============
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ pycodestyle==2.3.1
pycparser==2.18
pyflakes==1.6.0
pytest==3.2.5
torch==0.2.0.post4
PyYAML==3.12
torch==0.3.0.post4
typed-ast==1.1.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


setup(name='pytorch-crf',
version='0.3.1',
version='0.4.0',
description='Conditional random field in PyTorch',
long_description=readme,
url='https://github.com/kmkurn/pytorch-crf',
Expand Down
33 changes: 6 additions & 27 deletions src/torchcrf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def forward(self,
raise ValueError('mask of the first timestep must all be on')

if mask is None:
mask = Variable(self._new(*tags.size()).fill_(1))
mask = Variable(self._new(*tags.size()).fill_(1)).byte()

numerator = self._compute_joint_llh(emissions, tags, mask)
denominator = self._compute_log_partition_function(emissions, mask)
Expand Down Expand Up @@ -163,49 +163,28 @@ def _compute_joint_llh(self,
assert all(mask[0].data)

seq_length = emissions.size(0)
batch_size = emissions.size(1)
mask = mask.float()

# Start transition score
llh = self.start_transitions[tags[0]] # (batch_size,)

broadcast_transition = (
self.transitions
# Add dimension for batch_size
.view(1, self.num_tags, self.num_tags)
# Copy the transition matrix for all batch
.expand(batch_size, self.num_tags, self.num_tags)
)

for i in range(seq_length - 1):
cur_tag, next_tag = tags[i], tags[i+1]
# Emission score for current tag
llh += emissions[i].gather(1, cur_tag.view(-1, 1)).squeeze() * mask[i]
llh += emissions[i].gather(1, cur_tag.view(-1, 1)).squeeze(1) * mask[i]
# Transition score to next tag
transition_score = (
broadcast_transition
# Copy the batch current tag for all possible next tags, and select the current
# tag from the transition matrix
.gather(1, cur_tag.view(batch_size, 1, 1).expand(batch_size, 1, self.num_tags))
# Squeeze to (batch_size, num_tags); this stores the transition score to every
# possible next tags for each batch
.squeeze(1)
# Select the next tag
.gather(1, next_tag.view(batch_size, 1))
# Squeeze to (batch_size,)
.squeeze()
)
transition_score = self.transitions[cur_tag, next_tag]
# Only add transition score if the next tag is not masked (mask == 1)
llh += transition_score * mask[i+1]

# Find last tag index
last_tag_indices = mask.sum(0).long().data - 1 # (batch_size,)
last_tags = tags.gather(0, last_tag_indices.view(1, -1)).squeeze()
last_tag_indices = mask.long().sum(0) - 1 # (batch_size,)
last_tags = tags.gather(0, last_tag_indices.view(1, -1)).squeeze(0)

# End transition score
llh += self.end_transitions[last_tags]
# Emission score for the last tag, if mask is valid (mask == 1)
llh += emissions[-1].gather(1, last_tags.view(-1, 1)).squeeze() * mask[-1]
llh += emissions[-1].gather(1, last_tags.view(-1, 1)).squeeze(1) * mask[-1]

return llh

Expand Down

0 comments on commit 0127d5f

Please sign in to comment.