Skip to content

Commit

Permalink
parameterize nms gpu test
Browse files (browse the repository at this point in the history)
  • Loading branch information
qqaatw committed Jul 17, 2023
1 parent 70f3906 commit b1cf619
Showing 1 changed file with 13 additions and 24 deletions.
37 changes: 13 additions & 24 deletions test/test_ops.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -751,47 +751,36 @@ def test_qnms(self, iou, scale, zero_point):

torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))

@needs_cuda
@pytest.mark.parametrize(
"device",
(
pytest.param("cuda", marks=pytest.mark.needs_cuda),
pytest.param("mps", marks=pytest.mark.needs_mps),
),
)
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
def test_nms_cuda(self, iou, dtype=torch.float64):
def test_nms_gpu(self, iou, device, dtype=torch.float64):
dtype = torch.float32 if device == "mps" else dtype
tol = 1e-3 if dtype is torch.half else 1e-5
err_msg = "NMS incompatible between CPU and CUDA for IoU={}"

boxes, scores = self._create_tensors_with_iou(1000, iou)
r_cpu = ops.nms(boxes, scores, iou)
r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)

is_eq = torch.allclose(r_cpu, r_cuda.cpu())
if not is_eq:
# if the indices are not the same, ensure that it's because the scores
# are duplicate
is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
assert is_eq, err_msg.format(iou)

@needs_mps
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
def test_nms_mps(self, iou, dtype=torch.float32):
tol = 1e-3 if dtype is torch.half else 1e-5
err_msg = "NMS incompatible between CPU and MPS for IoU={}"

boxes, scores = self._create_tensors_with_iou(1000, iou)
r_cpu = ops.nms(boxes, scores, iou)
r_mps = ops.nms(boxes.to("mps"), scores.to("mps"), iou)
r_gpu = ops.nms(boxes.to(device), scores.to(device), iou)

print(r_cpu.size(), r_mps.size())
is_eq = torch.allclose(r_cpu, r_mps.cpu())
is_eq = torch.allclose(r_cpu, r_gpu.cpu())
if not is_eq:
# if the indices are not the same, ensure that it's because the scores
# are duplicate
is_eq = torch.allclose(scores[r_cpu], scores[r_mps.cpu()], rtol=tol, atol=tol)
is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol)
assert is_eq, err_msg.format(iou)

@needs_cuda
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
@pytest.mark.parametrize("dtype", (torch.float, torch.half))
def test_autocast(self, iou, dtype):
with torch.cuda.amp.autocast():
self.test_nms_cuda(iou=iou, dtype=dtype)
self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")

@pytest.mark.parametrize(
"device",
Expand Down

0 comments on commit b1cf619

Please sign in to comment.