torchmetrics Accuracy() fails if get_metrics() is called before test_on_dataset #272

Open
arthur-thuy opened this issue Sep 8, 2023 · 2 comments
Labels
bug Something isn't working

Comments

@arthur-thuy
Contributor

Describe the bug
The torchmetrics Accuracy() class raises RuntimeError: You have to have determined mode. when wrapper.get_metrics() is called before wrapper.test_on_dataset(), or when wrapper.test_on_dataset() is never called.

In contrast, Baal's Accuracy() class handles this by returning 'test_accuracy': nan.

To Reproduce
In this gist, a LeNet-5 model with MC Dropout is trained on the entire MNIST data for 1 epoch (no acquisitions).

By default, the script uses Baal's Accuracy() class; the --torchmetrics option adds the torchmetrics Accuracy() class alongside it. The script also evaluates on the test set by default; the --no-test option skips that evaluation.

Running python baal_error_torchmetrics.py --no-test:

/home/abthuy/Documents/PhD research/active-uncertainty/src/baal_error_torchmetrics.py:45: UserWarning: You have chosen to seed training. This will turn on the CUDNN deterministic setting, which can slow down your training considerably!
  set_seed(config.seed)
Use GPU: NVIDIA RTX A5000
labelling 100 observations
[1724909-MainThread] [baal.modelwrapper:train_on_dataset:83] 2023-09-08T08:28:32.092743Z [info     ] Starting training              dataset=100 epoch=1
[1724909-MainThread] [baal.modelwrapper:train_on_dataset:94] 2023-09-08T08:28:33.700075Z [info     ] Training complete              train_loss=2.309847593307495
Elapsed training time: 0:0:1
{'dataset_size': 100,
 'test_accuracy': nan,
 'test_loss': nan,
 'train_accuracy': 0.08999999612569809,
 'train_loss': 2.309847593307495}
Elapsed total time: 0:0:2

Running python baal_error_torchmetrics.py --torchmetrics:

/home/abthuy/Documents/PhD research/active-uncertainty/src/baal_error_torchmetrics.py:45: UserWarning: You have chosen to seed training. This will turn on the CUDNN deterministic setting, which can slow down your training considerably!
  set_seed(config.seed)
Use GPU: NVIDIA RTX A5000
labelling 100 observations
[1724520-MainThread] [baal.modelwrapper:train_on_dataset:83] 2023-09-08T08:26:52.913297Z [info     ] Starting training              dataset=100 epoch=1
[1724520-MainThread] [baal.modelwrapper:train_on_dataset:94] 2023-09-08T08:26:54.735623Z [info     ] Training complete              train_loss=2.309847593307495
Elapsed training time: 0:0:1
[1724520-MainThread] [baal.modelwrapper:test_on_dataset:123] 2023-09-08T08:26:54.736589Z [info     ] Starting evaluating            dataset=10000
[1724520-MainThread] [baal.modelwrapper:test_on_dataset:133] 2023-09-08T08:26:57.792895Z [info     ] Evaluation complete            test_loss=2.2263691425323486
{'dataset_size': 100,
 'test_accuracy': 0.19924747943878174,
 'test_loss': 2.2263691425323486,
 'test_torch_accuracy': 0.19900000095367432,
 'train_accuracy': 0.08999999612569809,
 'train_loss': 2.309847593307495,
 'train_torch_accuracy': 0.09000000357627869}
Elapsed total time: 0:0:6

Running python baal_error_torchmetrics.py --torchmetrics --no-test:

/home/abthuy/Documents/PhD research/active-uncertainty/src/baal_error_torchmetrics.py:45: UserWarning: You have chosen to seed training. This will turn on the CUDNN deterministic setting, which can slow down your training considerably!
  set_seed(config.seed)
Use GPU: NVIDIA RTX A5000
labelling 100 observations
[1724780-MainThread] [baal.modelwrapper:train_on_dataset:83] 2023-09-08T08:27:53.001640Z [info     ] Starting training              dataset=100 epoch=1
[1724780-MainThread] [baal.modelwrapper:train_on_dataset:94] 2023-09-08T08:27:54.803111Z [info     ] Training complete              train_loss=2.309847593307495
Elapsed training time: 0:0:1
/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: The ``compute`` method of metric Accuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)
Traceback (most recent call last):
  File "/home/abthuy/Documents/PhD research/active-uncertainty/src/baal_error_torchmetrics.py", line 206, in <module>
    main()
  File "/home/abthuy/Documents/PhD research/active-uncertainty/src/baal_error_torchmetrics.py", line 115, in main
    pprint(wrapper.get_metrics())
  File "/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/baal/metrics/mixin.py", line 71, in get_metrics
    metrics = {
  File "/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/baal/metrics/mixin.py", line 72, in <dictcomp>
    met_name: get_value(met) for met_name, met in self.metrics.items() if filter in met_name
  File "/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/baal/metrics/mixin.py", line 66, in get_value
    val = met.compute().detach().cpu().numpy()
  File "/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/torchmetrics/metric.py", line 531, in wrapped_func
    value = compute(*args, **kwargs)
  File "/home/abthuy/anaconda3/envs/al_uncertainty/lib/python3.9/site-packages/torchmetrics/classification/accuracy.py", line 266, in compute
    raise RuntimeError("You have to have determined mode.")
RuntimeError: You have to have determined mode.

Expected behavior
When the test set has not been evaluated, wrapper.get_metrics() should report 'test_torch_accuracy': nan for the torchmetrics Accuracy() class, just as it reports 'test_accuracy': nan for Baal's Accuracy() class.
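
One way to get this behavior could be to fall back to nan whenever compute() raises. The following is a minimal sketch, not Baal's actual implementation: StandInMetric, safe_value, and this get_metrics are hypothetical names that use only the standard library, and StandInMetric merely imitates a metric whose compute() raises the RuntimeError from the traceback when update() was never called.

```python
import math


class StandInMetric:
    """Hypothetical stand-in for a metric that cannot compute before update()."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, correct, total):
        self.correct += correct
        self.total += total

    def compute(self):
        if self.total == 0:
            # Imitates the error message seen in the traceback.
            raise RuntimeError("You have to have determined mode.")
        return self.correct / self.total


def safe_value(metric):
    """Return the metric value, or nan if the metric cannot be computed yet."""
    try:
        return metric.compute()
    except RuntimeError:
        return math.nan


def get_metrics(metrics, name_filter=""):
    """Collect values for every metric whose name contains name_filter."""
    return {
        name: safe_value(metric)
        for name, metric in metrics.items()
        if name_filter in name
    }


train_acc = StandInMetric()
train_acc.update(correct=9, total=100)  # training ran, so this metric has state

metrics = {
    "train_torch_accuracy": train_acc,
    "test_torch_accuracy": StandInMetric(),  # never updated: no test pass
}
result = get_metrics(metrics)
print(result)
```

Catching RuntimeError broadly would also hide unrelated failures inside compute(), so a real fix might prefer to check whether the metric has been updated before computing it.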

Version:

  • OS: Ubuntu 20.04
  • Python: 3.9.16
  • Baal version: 1.8.0

Additional context
/

@arthur-thuy arthur-thuy added the bug Something isn't working label Sep 8, 2023
@arthur-thuy
Contributor Author

I found a workaround for my own code. When I want the training metrics before running test_on_dataset, I call wrapper.get_metrics("train"), and torchmetrics does not raise an error.
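
The workaround presumably succeeds because of the name filter visible in the traceback (`if filter in met_name`): metrics whose names do not contain "train" are skipped before compute() is ever called on them. Here is a small sketch of that mechanism, using a hypothetical StandInMetric and a simplified get_metrics with a name_filter parameter rather than Baal's real classes.

```python
class StandInMetric:
    """Hypothetical metric: compute() fails until a value has been recorded."""

    def __init__(self, value=None):
        self.value = value

    def compute(self):
        if self.value is None:
            # Imitates the error message seen in the traceback.
            raise RuntimeError("You have to have determined mode.")
        return self.value


def get_metrics(metrics, name_filter=""):
    # The condition is checked first, so compute() only runs for matching names.
    return {
        name: metric.compute()
        for name, metric in metrics.items()
        if name_filter in name
    }


metrics = {
    "train_torch_accuracy": StandInMetric(0.09),
    "test_torch_accuracy": StandInMetric(),  # test pass never ran
}

# Filtering on "train" never touches the unupdated test metric.
train_only = get_metrics(metrics, "train")
print(train_only)

# Without a filter, the unupdated test metric is computed and raises.
try:
    get_metrics(metrics)
    unfiltered_raised = False
except RuntimeError:
    unfiltered_raised = True
print(unfiltered_raised)
```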

@Dref360
Member

Dref360 commented Sep 12, 2023

Thank you for submitting this issue!

I'll take a deeper look over the weekend. Unfortunately, I don't have a good idea of how to fix this right now.
