Skip to content

Commit

Permalink
Final code clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Sep 15, 2021
1 parent eb932b9 commit b548b7a
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions references/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
target = target.clone()

if target.ndim == 1:
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32)
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

if torch.rand(1).item() >= self.p:
return batch, target

# It's faster to roll the batch by one instead of shuffling it to create image pairs
batch_rolled = batch.roll(1, 0)
target_rolled = target.roll(1)
target_rolled = target.roll(1, 0)

# Implemented as on mixup paper, page 3.
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
Expand Down Expand Up @@ -132,14 +132,14 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
target = target.clone()

if target.ndim == 1:
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32)
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

if torch.rand(1).item() >= self.p:
return batch, target

# It's faster to roll the batch by one instead of shuffling it to create image pairs
batch_rolled = batch.roll(1, 0)
target_rolled = target.roll(1)
target_rolled = target.roll(1, 0)

# Implemented as on cutmix paper, page 12 (with minor corrections on typos).
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
Expand Down

0 comments on commit b548b7a

Please sign in to comment.