Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Jun 12, 2020
1 parent b3e0459 commit d786048
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions tests/metrics/functional/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,18 @@ def test_f1_score(pred, target, exp_score):


@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [
pytest.param(1, 1., 40),
pytest.param(None, 1., 39),
pytest.param(1, 1., 42),
pytest.param(None, 1., 42),
])
def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
# TODO: move back the pred and target to test func arguments
# if you fix the array inside the function, you'd also have fix the shape,
# because when the array changes, you also have to fix the shape
seed_everything(0)
pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100
target = torch.tensor([int(bool(idx % 2)) for idx in range(100)])
sample_weight = torch.ones_like(pred) * sample_weight if sample_weight is not None else None
target = torch.tensor([0, 1] * 50, dtype=torch.int)
if sample_weight is not None:
sample_weight = torch.ones_like(pred) * sample_weight

fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label)

Expand Down

0 comments on commit d786048

Please sign in to comment.