
fix: remove calls to Pytorch Dataset len #8647

Merged · 7 commits into main on Mar 27, 2024

Conversation

@wes-turner wes-turner (Contributor) commented Jan 5, 2024

Description

Pytorch Datasets (torch.utils.data.Dataset) aren't guaranteed to implement __len__: Datasets are either "map-style" or "iterable-style"; map-style Datasets must implement __len__, while iterable-style Datasets only may. The __len__ on a Pytorch DataLoader may pass the call through to its Dataset.
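
For illustration, here is a minimal sketch of the two Dataset styles and of how len() behaves on a DataLoader wrapping each; the class names are made up for this example:

import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset

class MapStyle(Dataset):
    # Map-style: must implement __getitem__ and __len__.
    def __getitem__(self, idx):
        return torch.tensor([idx])

    def __len__(self):
        return 8

class IterStyle(IterableDataset):
    # Iterable-style: implements __iter__; __len__ is optional and omitted here.
    def __iter__(self):
        return iter(torch.tensor([i]) for i in range(8))

len(DataLoader(MapStyle()))   # 8: the DataLoader delegates to the Dataset's __len__
len(DataLoader(IterStyle()))  # raises TypeError: there is no __len__ to delegate to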

A det.pytorch.PyTorchTrial is typically constructed from a det.pytorch.DataLoader. det.pytorch.DataLoader cannot, itself, front an iterable-style Pytorch Dataset. It is, however, possible to construct a det.pytorch.PyTorchTrial with an unwrapped torch.utils.data.Dataset if context.experimental.disable_dataset_reproducibility_checks() is called in the PyTorchTrial's __init__.
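
As a hedged sketch of that escape hatch (the trial below is hypothetical and abbreviated: a real trial also implements build_training_data_loader, train_batch, and an evaluate method):

import torch
import determined.pytorch as det_pytorch

class IterStyle(torch.utils.data.IterableDataset):
    def __iter__(self):
        return iter(torch.tensor([i]) for i in range(8))

class NoLenTrial(det_pytorch.PyTorchTrial):
    def __init__(self, context: det_pytorch.PyTorchTrialContext) -> None:
        self.context = context
        # Opt out of the reproducibility checks so the data loader builders
        # may return bare torch DataLoaders.
        self.context.experimental.disable_dataset_reproducibility_checks()

    def build_validation_data_loader(self) -> torch.utils.data.DataLoader:
        # An unwrapped DataLoader over an iterable-style Dataset: no __len__.
        return torch.utils.data.DataLoader(IterStyle())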

Before this patch, during a PyTorchTrialContext.run we called len on the trial's validation dataloader. Per the above, it had been possible to construct a trial with a validation dataloader that did not have __len__ implemented, and in this case run would raise a runtime TypeError exception.

It turns out, though, that those existing calls to __len__ weren't actually necessary. This patch removes them with no functional change in behavior:

  • Instead of calling len(validation_loader) to check for emptiness before iterating, count the batches while iterating and raise the same error if the loader turns out to be empty (see the sketch after this list).
  • Removes a call to len whose result was entirely ignored.
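
As a minimal sketch of the new check (this mirrors the idx = -1 pattern discussed in review below, not the harness code verbatim; the loader and error type are stand-ins):

validation_loader = []  # stand-in for the real loader; empty here to show the error path

idx = -1  # sentinel: stays -1 only if the loader yields nothing
for idx, batch in enumerate(validation_loader):
    ...  # run validation on the batch
if idx == -1:
    raise RuntimeError("validation_loader is empty")  # same error as before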

This PR also makes a couple of "continuous improvement" changes, moving a few pieces of code around and renaming variables so that the logic is a little more obvious.

Test Plan

I've tested this by hand: I modified build_validation_data_loader in the example https://github.com/determined-ai/determined/blob/main/examples/tutorials/mnist_pytorch/train.py to return a DataLoader for a torch.utils.data.Dataset subclass that doesn't implement __len__, then ran the https://github.com/determined-ai/determined/blob/main/harness/tests/experiment/pytorch/test_examples.py tests on it. Without the patch, validation fails because of the call to __len__; with the patch, validation succeeds.

We don't have any unit tests of the function I modified (_PyTorchTrialController._run or its caller _PyTorchTrialController.run), and this doesn't quite seem like a big enough patch to justify creating unit tests for the class. It also doesn't seem appropriate to create another end-to-end test just to ensure __len__ isn't called on a validation loader. Maybe more automated tests aren't needed?

For the release party, if you'd like to test this yourself, create a PyTorchTrial object of a class that implements build_validation_data_loader to return a plain, unwrapped torch.utils.data.DataLoader that itself has no __len__, then run a training loop for this trial object (see the sketch below).
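
A sketch of that reproduction; the dataset name is illustrative, and the surrounding trial is assumed to match one of the existing examples:

import torch

class NoLenDataset(torch.utils.data.IterableDataset):
    # Iterable-style, deliberately without __len__.
    def __iter__(self):
        return iter(torch.tensor([i]) for i in range(4))

loader = torch.utils.data.DataLoader(NoLenDataset())
try:
    len(loader)
except TypeError:
    print("len() is unavailable, as expected")

# Return `loader` from the trial's build_validation_data_loader() (with
# reproducibility checks disabled, as above) and run a training loop:
# with this patch, validation completes instead of raising TypeError.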

Commentary (optional)

Checklist

  • Changes have been manually QA'd
  • User-facing API changes need the "User-facing API Change" label.
  • Release notes should be added as a separate file under docs/release-notes/.
    See Release Note for details.
  • Licenses should be included for new code which was copied and/or modified from any external code.

Ticket

[MLG-1022]

@cla-bot cla-bot bot added the cla-signed label Jan 5, 2024

netlify bot commented Jan 5, 2024

Deploy Preview for determined-ui canceled.

🔨 Latest commit: 7cb26c5
🔍 Latest deploy log: https://app.netlify.com/sites/determined-ui/deploys/660477492cdedd00086a070f

@wes-turner wes-turner force-pushed the wes/pytorchtrial-no-dataset-len branch from 71b78f4 to 44e0e57 Compare January 11, 2024 22:22

codecov bot commented Jan 11, 2024

Codecov Report

Attention: Patch coverage is 78.57143%, with 3 lines in your changes missing coverage. Please review.

Project coverage is 47.14%. Comparing base (18dd29e) to head (7cb26c5).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #8647      +/-   ##
==========================================
- Coverage   47.16%   47.14%   -0.02%     
==========================================
  Files        1150     1150              
  Lines      141674   141671       -3     
  Branches     2415     2417       +2     
==========================================
- Hits        66814    66786      -28     
- Misses      74670    74695      +25     
  Partials      190      190              
Flag Coverage Δ
backend 42.86% <ø> (-0.07%) ⬇️
harness 64.32% <78.57%> (-0.01%) ⬇️
web 38.96% <ø> (ø)

Flags with carried forward coverage won't be shown.

Files Coverage Δ
harness/determined/pytorch/_pytorch_trial.py 80.21% <78.57%> (+0.05%) ⬆️

... and 3 files with indirect coverage changes

@wes-turner wes-turner force-pushed the wes/pytorchtrial-no-dataset-len branch 4 times, most recently from 1452720 to a3c2ae0 Compare January 12, 2024 22:32
@wes-turner wes-turner marked this pull request as ready for review January 16, 2024 22:56
@wes-turner wes-turner requested a review from a team as a code owner January 16, 2024 22:56
for callback in self.callbacks.values():
    callback.on_validation_epoch_start()

idx = -1
Member

You are taking a check which was fully contained in two consecutive lines and spreading it over a wider area here. I don't want you to rewrite it because what you have is simple and easy, but can you just put a comment above idx = -1 to explain why -1 is significant?

Contributor Author

That's a nice idea. Done.

@@ -1059,12 +1054,17 @@ def _validate(self, searcher_op: Optional[core.SearcherOperation] = None) -> Dic
# common than evaluate_batch() and we can't know how the user processed their
# validation data.
if self._evaluate_batch_defined():
    # Reshape and sum.
    # TODO: remove the type directive once we upgrade to mypy >= 1.7.0
    inputs_total, batches_total = [sum(n) for n in zip(*input_counts)]  # type: ignore
Member

input_counts doesn't seem to be defined in the evaluate_full_dataset codepath

Contributor Author
@wes-turner wes-turner Jan 17, 2024

The code is kind of unfortunately structured and pretty hard to figure out:

1. if evaluate_batch:
2.   do stuff in a managed way
3.   if chief:
4.     do managed stuff relevant to the chief
5. else:  # evaluate_full
6.   do whatever stuff evaluate_full says to do
7. 
8. if chief:
9.   if evaluate_batch:
10.     report batch-specific detail
11.  report general stuff

Before patch:

  • num_inputs defined / calculated line 2
  • input_counts defined line 2 (from gathered num_inputs)
  • num_inputs re-defined line 4 when evaluate_batch (from input_counts)
  • num_inputs independently defined line 6 when evaluate_full to be given a meaning similar to that on line 4 (from len(validation_loader))
  • num_inputs (per its second definition) used line 10 and nowhere else

Pre-patch, num_inputs was being defined when both evaluate_batch and evaluate_full, but only used during evaluate_batch.

After patch:

  • num_inputs defined / calculated line 2
  • input_counts defined line 2 (from gathered num_inputs)
  • inputs_total defined line 10 (from input_counts)
  • inputs_total used line 10

Effects of this patch:

  • num_inputs not defined during evaluate_full (which is fine, because it hadn't been used there)
  • gathered calculation from input_counts moved to where it's used (line 10)
  • gathered calculation from input_counts given a new name so num_inputs isn't overloaded
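
As a self-contained toy of just the reshape-and-sum step (all names and values here are illustrative stand-ins for the harness code):

# Per-worker (num_inputs, num_batches) pairs, as if gathered from two workers.
input_counts = [(32, 4), (32, 4)]

# Post-patch shape: derive the totals at the point of reporting, under fresh
# names, instead of overloading num_inputs.
inputs_total, batches_total = [sum(n) for n in zip(*input_counts)]
print(inputs_total, batches_total)  # 64 8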

Comment on lines 1575 to 1579
For full "determined" functionality, this must return an instance of
:py:class:`determined.pytorch.DataLoader`. It can also return an unwrapped
:py:class:`torch.utils.data.DataLoader` if you need more control over the underlying
DataLoader and are willing to sacrifice some Determined features (ex: automatic data
sharding).
Member

Thank you for catching that our typing and docstring had fallen out of correctness.

However, this is a user-facing docstring, and given that I think this addition is much too vague. What is full "determined" functionality anyway?

I wouldn't recommend answering that question here; I would link to our existing docs on the subject.

Something like:

Users with a MapDataset will normally return a :class:`determined.pytorch.DataLoader`, but users with an IterableDataset or with other advanced needs may return a bare ``torch.utils.data.DataLoader`` if they follow the steps described in :ref:`pytorch-reproducible-dataset`.

Contributor Author

I agree about the vagueness. It made me feel a little gross, too.

Ideally, the docstring should say why a user would choose which class to return. It's somewhat easy to say "if you can't return a det.pytorch.DataLoader, return the other one." It's harder to come up with an explanation for "why managed" that isn't vague.

I like your idea of referring to the docs for that. That's something docs.determined.ai should do. The doc you suggested does have a nice note explaining how, but I can't find anything on the site explaining why.

I think a voice chat is probably the best way to work something out from here. When you get to this and you've got the time to, could you please give me a call?

Contributor Author

Ended up with a little from column A, a little from column B.

"""
Defines the data loader to use during validation.

Must return an instance of :py:class:`determined.pytorch.DataLoader`.
For full "determined" functionality, this must return an instance of
:py:class:`determined.pytorch.DataLoader`. It can also return an unwrapped
Member

skip the :py in :py:class:..., as it is not necessary

Contributor Author
@wes-turner wes-turner Jan 17, 2024

Nice. Thanks! (also changed in another existing place)

Member
@rb-determined-ai rb-determined-ai left a comment

You are missing the most important usage of len(data_loader) which is here.

@wes-turner wes-turner (Contributor Author)

> You are missing the most important usage of len(data_loader) which is here.

When __len__ doesn't exist for a DataLoader, it'll throw a TypeError. I believe, then, that line isn't covered by MLG-1022. The author of that line kept this in mind.

I saw in #4303 (comment) that when you last looked at this line you preferred a different solution then, too.

I understand that there can be inaccuracy both from this line and from variable-length datasets, too, and that counting examples is more foolproof. But I haven't worked out what the implications of that inaccuracy might be (or the cost of a solution) well enough to create a ticket for it.

Contributor
@azhou-determined azhou-determined left a comment

nice

Datasets aren't guaranteed to have a `__len__` implemented. This removes
two calls to `__len__` that weren't actually necessary.

This is functionally an abstract method, but actually marking it as
abstract or making it raise an exception would change the interface
people have to implement, a breaking change. This is a problem for
tomorrow.
@wes-turner wes-turner force-pushed the wes/pytorchtrial-no-dataset-len branch from cc0bc65 to 7cb26c5 Compare March 27, 2024 19:45
@wes-turner wes-turner requested a review from a team as a code owner March 27, 2024 19:45
@wes-turner wes-turner requested review from MikhailKardash and removed request for MikhailKardash March 27, 2024 19:45
@wes-turner wes-turner merged commit 8f5de35 into main Mar 27, 2024
87 of 92 checks passed
@wes-turner wes-turner deleted the wes/pytorchtrial-no-dataset-len branch March 27, 2024 22:26