
Correct behavior for argument gpus in Trainer #561

Merged
1 commit merged on Nov 30, 2019

Conversation

mpariente
Contributor

This PR fixes #558

With this change, gpus=0 and gpus=[] in Trainer no longer mean GPU training. This makes the API more consistent IMO.

pytorch_lightning/trainer/trainer.py
@williamFalcon
Contributor

@mpariente @Borda we need a test for this

@williamFalcon
Contributor

williamFalcon commented Nov 29, 2019

Almost, but not quite right. Here's what this PR implies

| gpus | cuda_avail() | on_gpu |
|------|--------------|--------|
| None | T | None |
| None | F | None |
| 0 | T | F |
| 0 | F | F |
| 1 | T | T |
| 1 | F | F |
| 2 | T | T |
| 2 | F | F |
| [] | T | [] |
| [] | F | [] |
| [0] | T | T |
| [0] | F | F |

The outcome we actually want is (* means need to fix):

| gpus | cuda_avail() | on_gpu |
|------|--------------|--------|
| None | T | F* |
| None | F | F* |
| 0 | T | F |
| 0 | F | F |
| 1 | T | T |
| 1 | F | F |
| 2 | T | T |
| 2 | F | F |
| [] | T | F* |
| [] | F | F* |
| [0] | T | T |
| [0] | F | F |
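The desired mapping can be sketched as a small helper (a sketch only: `determine_on_gpu` is a hypothetical name, not the actual Trainer code, and `cuda_available` stands in for `torch.cuda.is_available()`):

```python
def determine_on_gpu(gpus, cuda_available):
    """Resolve the on_gpu flag from the requested gpus argument.

    gpus may be None, an int count, or a list of device indices.
    Empty requests (None, 0, []) mean CPU training, and CUDA must
    be available for the flag to ever be True.
    """
    # bool() collapses None/0/[] to False, so the result is always boolean.
    return bool(gpus) and cuda_available

# Reproduces the desired truth table (T/F columns):
table = [
    (None, True, False), (None, False, False),
    (0, True, False), (0, False, False),
    (1, True, True), (1, False, False),
    (2, True, True), (2, False, False),
    ([], True, False), ([], False, False),
    ([0], True, True), ([0], False, False),
]
for gpus, cuda, expected in table:
    assert determine_on_gpu(gpus, cuda) == expected
```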

@mpariente
Contributor Author

Which test would you like? I can write one.
What I implemented actually works as in the second table, not as the first one.

@mpariente
Contributor Author

mpariente commented Nov 29, 2019

```python
gpus_options = [0, 1, [], [0], None, "", "0, 3"]
for cuda_is_available in [True, False]:
    for gpus in gpus_options:
        # This is what I suggest using
        on_gpu = True if (gpus and cuda_is_available) else False
        # As opposed to this, which does not always return a boolean
        # on_gpu = gpus and cuda_is_available
        print('gpus={}, torch.cuda.is_available={}, '
              'on_gpu={}'.format(gpus, cuda_is_available, on_gpu))
```

@williamFalcon
Contributor

The tables imply that on_gpu should always be a boolean.

@mpariente
Contributor Author

mpariente commented Nov 29, 2019

With what I committed, on_gpu will always be a boolean, so that:

  • on_gpu will be False if torch.cuda.is_available() returns False.
  • on_gpu will be False when gpus is 0, [], None or "".
  • on_gpu will be True otherwise.

This is what we want, right?

@williamFalcon
Contributor

gpus can also be a string no?

@mpariente
Contributor Author

Yes, of the form "0, 2, 3".
If gpus = "", self.on_gpu will be False, which is also the intended behavior, I believe. I updated the small script above accordingly.
By the way, with gpus = "", this will currently break as well.

@jeffling
Contributor

jeffling commented Nov 29, 2019

@mpariente We should also add "0" as a test case. In this case, it should mean 'use GPU 0' since it's evaluated the same as a list.

The above does concern me a bit; I'll make an issue to discuss the usage of gpus in general.

#563
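One way to reconcile the string cases discussed above (a sketch under assumptions; `parse_gpus_string` is a hypothetical helper, not the merged Lightning code): treat "" as CPU and any non-empty string as a comma-separated list of device indices, so "0" means 'use GPU 0'.

```python
def parse_gpus_string(gpus):
    """Turn a gpus string into a list of device indices.

    "" -> []  (CPU), "0" -> [0], "0, 3" -> [0, 3]
    """
    if not gpus.strip():
        return []
    return [int(idx) for idx in gpus.split(',')]

print(parse_gpus_string(""))      # []
print(parse_gpus_string("0"))     # [0]
print(parse_gpus_string("0, 3"))  # [0, 3]
```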

@mpariente
Contributor Author

@williamFalcon Any update on what should be done for this to be merged please?

@williamFalcon williamFalcon merged commit df7b6d9 into Lightning-AI:master Nov 30, 2019
@williamFalcon
Contributor

@mpariente looks great! thanks for the PR

@mpariente
Contributor Author

Thanks.

Also, should I make a PR to support gpus="" ?
Trainer(gpus="") fails and the fix is quite simple.

With gpus="", we could have this usage:
`parser.add_argument('--gpus', type=str, default="")`
which is quite practical, although this currently works:
`parser.add_argument('--gpus', type=str, default=None)`
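For illustration, the argparse usage mentioned above could look like this (a sketch using only the standard-library argparse; the defaults shown are the behavior under discussion, not Lightning's actual API):

```python
import argparse

parser = argparse.ArgumentParser()
# Default "" would mean CPU training; a non-empty value
# like "0, 3" selects those GPU indices.
parser.add_argument('--gpus', type=str, default="")

args = parser.parse_args([])                  # no flag given
print(repr(args.gpus))                        # ''

args = parser.parse_args(['--gpus', '0, 3'])
print(repr(args.gpus))                        # '0, 3'
```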

Successfully merging this pull request may close these issues.

trainer.on_gpu == True for gpus = 0 or gpus=[]