Skip to content

Commit

Permalink
Fixed issue with torch 1.12 issue with arange not supporting fp16 for…
Browse files Browse the repository at this point in the history
… CPU device. (#1574)
  • Loading branch information
BloodAxe committed Oct 26, 2023
1 parent 23b4f7a commit 1f15c76
Showing 1 changed file with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def forward(self, inputs):
if not self.training:
outputs_logits.append(output.clone())
if self.grid[i].shape[2:4] != output.shape[2:4]:
self.grid[i] = self._make_grid(nx, ny, dtype=reg_output.dtype).to(output.device)
self.grid[i] = self._make_grid(nx, ny, dtype=reg_output.dtype, device=output.device)

xy = (output[..., :2] + self.grid[i].to(output.device)) * self.stride[i]
wh = torch.exp(output[..., 2:4]) * self.stride[i]
Expand All @@ -279,12 +279,14 @@ def forward(self, inputs):
return outputs if self.training else (torch.cat(outputs, 1), outputs_logits)

@staticmethod
def _make_grid(nx: int, ny: int, dtype: torch.dtype):
def _make_grid(nx: int, ny: int, dtype: torch.dtype, device: torch.device):
y, x = torch.arange(ny, dtype=torch.float32, device=device), torch.arange(nx, dtype=torch.float32, device=device)

if torch_version_is_greater_or_equal(1, 10):
# https://github.com/pytorch/pytorch/issues/50276
yv, xv = torch.meshgrid([torch.arange(ny, dtype=dtype), torch.arange(nx, dtype=dtype)], indexing="ij")
yv, xv = torch.meshgrid([y, x], indexing="ij")
else:
yv, xv = torch.meshgrid([torch.arange(ny, dtype=dtype), torch.arange(nx, dtype=dtype)])
yv, xv = torch.meshgrid([y, x])
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).to(dtype)


Expand Down

0 comments on commit 1f15c76

Please sign in to comment.