
Make most metrics work on GPU #3851

Merged: 8 commits merged into master from the metrics-gpu branch on Feb 27, 2020

Conversation

@bryant1410 (Contributor) commented Feb 26, 2020

Fix #3017.

I did the following: renamed unwrap_to_tensors to detach_tensors and changed it to move tensors to cuda() instead of cpu(). Then I ran all the tests with 2 GPUs (only 4 tests were skipped). The following failed:

  • spearman correlation
  • fbeta
  • entropy
  • boolean accuracy
  • bleu
  • auc
  • allennlp/tests/predictors/srl_test.py:82 (TestSrlPredictor.test_prediction_with_no_verbs)
  • allennlp/tests/predictors/coref_test.py:86 (TestCorefPredictor.test_replace_corefs)
  • simpleseq2seqtest
  • allennlp/tests/models/encoder_decoders/copynet_seq2seq_test.py:15 (CopyNetTest.test_model_can_train_save_load_predict)
  • coreftest
  • graphparsertest
  • allennlp/tests/interpret/simple_gradient_test.py:30 (TestSimpleGradient.test_simple_gradient_coref)

Then I went through all detach_tensors() callers and all metrics, checked that they are compatible with GPU usage, and added .cpu() calls where necessary.

The following should work fine on GPU:

  • attachment score
  • auc - __call__ should work fine on GPU, but get_metric needs CPU because it uses SciPy.
  • average
  • bleu
  • boolean accuracy
  • categorical accuracy
  • CoNLL Coref Scores - uses SciPy stuff, so I just converted to CPU.
  • covariance
  • entropy
  • evalb - doesn't use tensors; it doesn't even import torch.
  • f1
  • fbeta
  • MAE
  • mention recall - the only input tensor is converted to a list right away, so it's all CPU in the end.
  • pearson
  • perplexity
  • seq accuracy
  • span based
  • spearman - __call__ should work fine on GPU, but get_metric needs CPU because it uses SciPy (see the sketch after this list).
  • srl eval - doesn't use tensors; it doesn't even import torch.
  • unigram recall
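
A quick, hedged sketch of the auc/spearman point above: the class and names here are illustrative, not the actual AllenNLP metric code. Accumulation in __call__ can stay on whatever device the inputs live on, and the move to CPU happens only in get_metric, right before SciPy needs plain NumPy arrays.

```python
import scipy.stats
import torch


class SpearmanLikeMetric:
    """Toy accumulator illustrating the device handling, not the real AllenNLP metric."""

    def __init__(self) -> None:
        self._predictions = torch.zeros(0)
        self._gold_labels = torch.zeros(0)

    def __call__(self, predictions: torch.Tensor, gold_labels: torch.Tensor) -> None:
        # Accumulate on whatever device the inputs live on; no .cpu() needed here.
        device = predictions.device
        self._predictions = torch.cat([self._predictions.to(device), predictions.detach().reshape(-1).float()])
        self._gold_labels = torch.cat([self._gold_labels.to(device), gold_labels.detach().reshape(-1).float()])

    def get_metric(self) -> float:
        # SciPy works on NumPy arrays, so move to CPU only at read time.
        return scipy.stats.spearmanr(self._predictions.cpu().numpy(), self._gold_labels.cpu().numpy())[0]
```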

Then I ran the tests again and fixed one test that called .numpy() without calling .cpu() first. Then I ran all the tests again (2 GPUs; only 4 skipped tests) and everything passed. Then I removed the .cuda() call from detach_tensors and ran the tests once more (still with 2 GPUs available; only 4 skipped tests). They were successful. Feel free to try it yourselves, because CI won't do (multi-)GPU testing (remember to change detach_tensors so it does .cuda()).

(Note there are 4 unconditionally skipped tests.)

I realized all metrics have tests except for Average, Perplexity, and MentionRecall. However, if you look at what I changed in those files you'll see it's really minor and should still work. Note that I didn't modify Perplexity, but it subclasses Average. It'd still be good to have tests for those, though.

By the way, I saw a bunch of numerator / (denominator + 1e-13) expressions (and other similar values), which makes me think those won't work well on FP16. Ideally those should be eps arguments that could be changed when using FP16, for example; something like the sketch below.
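A hedged sketch of that suggestion (my own illustration, not code from this PR; the class and attribute names are made up). The point is the API: 1e-13 underflows to zero in float16, so a caller-supplied eps lets FP16 users pick a representable value.

```python
class F1WithConfigurableEps:
    """Toy F1-style metric illustrating an eps argument instead of a hard-coded 1e-13."""

    def __init__(self, eps: float = 1e-13) -> None:
        # 1e-13 underflows to 0.0 in float16; FP16 users could pass e.g. eps=1e-4 instead.
        self._eps = eps
        self._true_positives = 0.0
        self._false_positives = 0.0
        self._false_negatives = 0.0

    def get_metric(self) -> float:
        precision = self._true_positives / (self._true_positives + self._false_positives + self._eps)
        recall = self._true_positives / (self._true_positives + self._false_negatives + self._eps)
        return 2.0 * precision * recall / (precision + recall + self._eps)
```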

```diff
@@ -108,7 +108,7 @@ def _get_brevity_penalty(self) -> float:
         return math.exp(1.0 - self._reference_lengths / self._prediction_lengths)

     def _get_valid_tokens_mask(self, tensor: torch.LongTensor) -> torch.ByteTensor:
-        valid_tokens_mask = torch.ones(tensor.size(), dtype=torch.bool)
+        valid_tokens_mask = torch.ones_like(tensor, dtype=torch.bool)
```
@bryant1410 (Contributor, Author):

Using the _like variant also puts the result on the same device as the input tensor.
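A minimal illustration of that behavior (illustrative only, not part of the diff):

```python
import torch

t = torch.arange(6).view(2, 3)               # put this on any device you like
mask = torch.ones_like(t, dtype=torch.bool)  # inherits t's shape *and* device
assert mask.shape == t.shape and mask.device == t.device
```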

```diff
         return 0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)

     def get_recall(self):
-        if self.recall_numerator == 0:
+        if self.recall_denominator == 0:
```
@bryant1410 (Contributor, Author) commented Feb 26, 2020:

This and the next one were bugs IMO.

```diff
             return 0
         else:
-            return self.recall_numerator / float(self.recall_denominator)
+            return self.recall_numerator / self.recall_denominator
```
@bryant1410 (Contributor, Author):

These float castings are no longer necessary in Python 3. Before removing them, I checked whether the values could be tensors, so as not to remove the casts in cases where they were.

```diff
@@ -18,7 +18,7 @@ def __call__(
         batched_top_spans: torch.Tensor,
         batched_metadata: List[Dict[str, Any]],
     ):
-        for top_spans, metadata in zip(batched_top_spans.data.tolist(), batched_metadata):
+        for top_spans, metadata in zip(batched_top_spans.tolist(), batched_metadata):
```
@bryant1410 (Contributor, Author):

I think in previous PyTorch versions this was necessary, but not anymore.

```diff
         # the vectors, since each element in the predictions and gold_labels tensor is assumed
         # to be a separate observation.
         predictions = predictions.view(-1)
         gold_labels = gold_labels.view(-1)

+        self.total_predictions = self.total_predictions.to(predictions.device)
```
@bryant1410 (Contributor, Author):

Note that at initialization time we don't know which device we should use, so we move the tensor here. If it's already on that device, .to() is a no-op, so this is fine. (A short sketch of the pattern follows.)
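A hedged sketch of that lazy-device pattern (illustrative names only, not the actual Covariance code):

```python
import torch


class RunningTotal:
    """Toy accumulator showing state created on CPU and moved on first use."""

    def __init__(self) -> None:
        # At construction time we have no idea which device the inputs will be on.
        self.total_predictions = torch.zeros(1)

    def __call__(self, predictions: torch.Tensor) -> None:
        # .to() returns the tensor unchanged when it is already on the right device.
        self.total_predictions = self.total_predictions.to(predictions.device)
        self.total_predictions += predictions.detach().sum()
```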

@DeNeutoy (Contributor) left a comment:

@bryant1410 This is awesome and very impactful.

One question - we have GPU tests which we run occasionally in our CI. Is it easy to parametrise all of the metrics tests conditionally, such that they run on the CPU when only the CPU is available, and on both CPU and GPU when both are?

@bryant1410 (Contributor, Author) commented Feb 26, 2020

> One question - we have GPU tests which we run occasionally in our CI. Is it easy to parametrise all of the metrics tests conditionally, such that they run on the CPU when only the CPU is available, and on both CPU and GPU when both are?

The approach I can come up with is using pytest.param for every test function that should support both GPU and CPU, like this (I guess it works with self):

```python
@pytest.mark.parametrize("device", [
    "cpu",
    pytest.param(
        "cuda",
        marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
    ),
])
def test_func(self, device):
    ...
```

Does it make sense?

@bryant1410 (Contributor, Author) commented Feb 26, 2020

> The approach I can come up with is using pytest.param for every test function that should support both GPU and CPU [...] Does it make sense?

In JAX they do something similar.

@DeNeutoy (Contributor) commented:

Nice, is there a way to wrap that whole parametrization up as a single decorator? It's a bit verbose, and it would be nice to be able to do:

```python
@multi_device
def test_func(self, device):
    ...
```

This might not be possible with pytest; I know it's a bit finicky about how it uses those decorators.
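For what it's worth, here is a hedged sketch of one way such a multi_device helper could be written so that it also works on unittest-style test methods (this is just a guess at the idea, not necessarily what the PR ended up adding):

```python
import functools

import torch


def multi_device(test_method):
    """Run the decorated test on "cpu" and, when CUDA is available, on "cuda" too."""

    @functools.wraps(test_method)
    def wrapper(self):
        devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
        for device in devices:
            test_method(self, device=device)

    return wrapper
```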

Comment on lines +263 to +264
```python
with open(gold_file_path, "w") as gold_file, open(
    prediction_file_path, "w"
```
@bryant1410 (Contributor, Author):

I changed this because the test failed (the function was called twice), and the append mode wasn't necessary anyway.

@bryant1410 (Contributor, Author) commented:

I added that utility. I discovered that pytest and unittest don't play well together, and it's hard to parametrize in that context with the utilities they provide.

I had to change the tests to actually use the device, which also meant avoiding NumPy where I could. In the end, torch provides a testing module with assert_allclose that's convenient. It has good defaults, although when you specify rtol you also have to specify atol (either both or neither); I don't know why. I also had to replace some FloatTensor creations (and similar ones) with tensor, because of some exceptions (and I saw they are no longer recommended).
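A tiny illustration of that rtol/atol point with torch.testing.assert_allclose (illustrative values only):

```python
import torch
from torch.testing import assert_allclose

# Good dtype-dependent defaults when you pass no tolerances at all.
assert_allclose(torch.tensor([1 / 3]), torch.tensor([0.3333333]))

# If you want custom tolerances, rtol and atol must be given together.
assert_allclose(torch.tensor([0.33]), torch.tensor([1 / 3]), rtol=1e-2, atol=1e-2)
```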

As a general practice, I think we should specify the device when we create new tensors with ones, zeros, randn, rand, and tensor (or any constructor; not with the *_like variants such as ones_like, because they copy the device). A small illustration is below.
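A small illustration of that convention (illustrative only):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

a = torch.zeros(3, 4, device=device)   # constructors need an explicit device argument
b = torch.randn(3, 4, device=device)
c = torch.ones_like(a)                 # *_like variants already copy a's device (and dtype)
assert a.device == b.device == c.device
```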

Comment on lines -80 to +92
```diff
-        numpy.testing.assert_almost_equal(precisions, self.desired_precisions, decimal=2)
-        numpy.testing.assert_almost_equal(recalls, self.desired_recalls, decimal=2)
-        numpy.testing.assert_almost_equal(fscores, self.desired_fscores, decimal=2)
+        assert_allclose(precisions, self.desired_precisions)
+        assert_allclose(recalls, self.desired_recalls)
+        assert_allclose(fscores, self.desired_fscores)
```
@bryant1410 (Contributor, Author):

In many of these changes, the comparison is actually stricter now. It works because I changed expected values like 0.33 to 1 / 3.

```diff
@@ -44,14 +44,14 @@ def __call__(
         # Flatten predictions, gold_labels, and mask. We calculate the Spearman correlation between
         # the vectors, since each element in the predictions and gold_labels tensor is assumed
         # to be a separate observation.
-        predictions = predictions.view(-1)
-        gold_labels = gold_labels.view(-1)
+        predictions = predictions.reshape(-1)
```
@bryant1410 (Contributor, Author):

For some reason, this view sometimes fails on GPU. We can use reshape anyway: it tries to do a view, and falls back to a copy if it can't.
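A quick illustration of why reshape is the safer call (illustrative only; it shows the general non-contiguity case, not the specific GPU failure from the PR):

```python
import torch

x = torch.randn(4, 5).t()   # the transpose makes x non-contiguous
try:
    x.view(-1)              # view needs compatible strides and raises here
except RuntimeError as error:
    print("view failed:", error)

flat = x.reshape(-1)        # reshape returns a view when possible, otherwise a copy
```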

Comment on lines +214 to +215
```python
assert metric._ignore_classes == ["V"]  # type: ignore
assert metric._label_vocabulary == self.vocab.get_index_to_token_vocabulary(  # type: ignore
```
@bryant1410 (Contributor, Author):

I don't know why these mypy errors didn't appear before.

@DeNeutoy (Contributor) left a comment:

Sweeeeet, looks great, thanks @bryant1410 !

allennlp/common/testing/test_case.py (outdated, resolved)
```diff
@@ -40,3 +43,34 @@ def setUp(self):

     def tearDown(self):
         shutil.rmtree(self.TEST_DIR)


+def parametrize(arg_names: Iterable[str], arg_values: Iterable[Iterable[Any]]):
```
A contributor commented:

Very cute, this is really nice!

Co-Authored-By: Mark Neumann <markn@allenai.org>
```python
from allennlp.common.testing import AllenNlpTestCase, multi_device


class TestFromParams(AllenNlpTestCase):
```
Another contributor commented:

Class name here needs updating.

(I came to see what Mark thought looked cute, noticed a copy-paste bug.)

@bryant1410 (Contributor, Author):

Thanks! Good catch.

@bryant1410 (Contributor, Author):

I named the class TestTesting, after the module name; hope that's fine.

@DeNeutoy merged commit ddebbdc into allenai:master on Feb 27, 2020
@bryant1410 deleted the metrics-gpu branch on February 27, 2020 at 19:50
@bryant1410 mentioned this pull request on Feb 28, 2020
@bryant1410 (Contributor, Author) commented:

@DeNeutoy, related to this and to be on the safe side: in Trainer, shouldn't we wrap the call to get_metrics() in with torch.no_grad(): during training? (In validation it's already there.)
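A minimal, self-contained illustration of what no_grad buys there (nothing here is the actual Trainer code):

```python
import torch

predictions = torch.randn(4, 3, requires_grad=True).softmax(dim=-1)

mae = (predictions - 1.0).abs().mean()      # metric-style arithmetic keeps the autograd graph alive
assert mae.requires_grad

with torch.no_grad():
    mae = (predictions - 1.0).abs().mean()  # same arithmetic, but no graph is retained
assert not mae.requires_grad
```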

Successfully merging this pull request may close these issues:

  • Why tensors are moved to CPU when calculating metrics?