Skip to content

Commit

Permalink
Fix scoring with batches of size 1 by avoid squeezing the batch dim (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
fhieber committed Apr 21, 2022
1 parent c822e20 commit 23ffd29
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.1.12]

### Fixed

- Fix scoring with batches of size 1 (whic may occur when `|data| % batch_size == 1`.

## [3.1.11]

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '3.1.11'
__version__ = '3.1.12'
4 changes: 2 additions & 2 deletions sockeye/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def forward(self,

# Select the label log probability
# logprobs and scores: (batch_size, target_seq_len)
token_scores = logprobs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze()
token_scores = logprobs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
if self.score_type == C.SCORING_TYPE_NEGLOGPROB:
token_scores = -token_scores

Expand All @@ -80,7 +80,7 @@ def forward(self,
factor_scores = [] # type: List[pt.Tensor]
for factor_logit, factor_label in factor_logits_and_labels:
factor_logprobs = factor_logit.log_softmax(dim=-1)
factor_token_scores = factor_logprobs.gather(dim=-1, index=factor_label.unsqueeze(-1)).squeeze()
factor_token_scores = factor_logprobs.gather(dim=-1, index=factor_label.unsqueeze(-1)).squeeze(-1)
if self.score_type == C.SCORING_TYPE_NEGLOGPROB:
factor_token_scores = -factor_token_scores
fs = factor_token_scores.masked_fill_(factor_label == C.PAD_ID, .0).sum(dim=-1, keepdims=True) # type: ignore
Expand Down

0 comments on commit 23ffd29

Please sign in to comment.