Refactor PyTorch for explicit Lightning/Wandb hyperparameters #30

Merged: 34 commits merged into main on Oct 10, 2022

Conversation

@mwalmsley (Owner) commented on Oct 3, 2022

Major changes

  • Refactor the pytorch define_model API to pass all hparams as explicit simple variables (e.g. architecture_name) rather than convenient but implicit functions (e.g. loss_function, model). This is messier in that it requires a lot of args, but it allows tracking/sweeps by wandb and restoring from checkpoints by lightning (see the sketch after this list).
  • Add stochastic_depth_probability as a hyperparameter. In short, this sometimes randomly skips whole blocks during training. It was previously hard-coded and enabled by default, so users should see no change, but it can now be altered if desired.
  • Add efficientnetb2 and b4 as architecture options. No performance improvement in my current tests.
  • Add predict_on_catalog.py for pytorch, which allows easy predictions on a catalog dataframe (usage sketch below).
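As a sketch of what the explicit-hparams pattern enables (the class and argument names below are illustrative, not the exact zoobot interface; only architecture_name and stochastic_depth_probability are named in this PR):

# Illustrative sketch only: a LightningModule built from simple hparams rather than pre-built objects.
import pytorch_lightning as pl
import torch
import torchvision

class GalaxyClassifier(pl.LightningModule):  # hypothetical class name, for illustration
    def __init__(self, architecture_name='efficientnetb0', learning_rate=1e-3,
                 stochastic_depth_probability=0.2, num_classes=4):
        super().__init__()
        # save_hyperparameters() records every simple __init__ arg, so lightning can
        # restore the model from a checkpoint and wandb can log/sweep the same values
        self.save_hyperparameters()
        self.model = self.build_architecture(architecture_name, stochastic_depth_probability, num_classes)

    def build_architecture(self, architecture_name, stochastic_depth_probability, num_classes):
        # resolve the string hparam to an actual architecture inside the module,
        # instead of passing a ready-made model object into define_model
        if architecture_name == 'efficientnetb0':
            return torchvision.models.efficientnet_b0(
                num_classes=num_classes, stochastic_depth_prob=stochastic_depth_probability)
        raise ValueError(f'Unknown architecture_name: {architecture_name}')

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

# Because only simple values were saved, restoring needs no extra arguments:
# model = GalaxyClassifier.load_from_checkpoint('some_checkpoint.ckpt')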

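A hypothetical usage sketch for predict_on_catalog.py follows; the module path, function name, and arguments are assumptions based on the description above, not the actual interface (see the file in this PR for the real one).

# Assumed usage shape only; check predict_on_catalog.py for the real function signature.
import pandas as pd

from zoobot.pytorch.predictions import predict_on_catalog  # module path is an assumption

catalog = pd.read_csv('galaxy_catalog.csv')  # assumed to include a column of image file locations

predict_on_catalog.predict(
    catalog,
    model,  # a trained model, e.g. restored via load_from_checkpoint
    save_loc='predictions.csv',  # argument name is illustrative, not confirmed
)
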
Minor (but potentially breaking) changes

  • Rename the pytorch argument model_architecture to architecture_name, to make clearer that it is not the architecture function itself (may break the API).
  • Deprecate requirements.txt.
  • Pin an explicit pytorch-lightning version (may break a build).
  • Bump the package version ahead of the next pypi/pip release.

cc @patrikasvanagas, @camallen

@mwalmsley self-assigned this on Oct 3, 2022
@camallen (Collaborator) left a comment

LGTM - FWIW, I've been using this version of the code in the batch processing system.

@@ -26,7 +26,7 @@
         'torch == 1.10.1',
         'torchvision == 0.11.2',
         'torchaudio == 0.10.1',
-        'pytorch-lightning',
+        'pytorch-lightning==1.6.5',  # 1.7 requires protobuf version incompatible with tensorflow/tensorboard. Otherwise works.
@camallen (Collaborator) commented:

not blocking - in theory we could isolate the pytorch and tensorflow installs into their own python virtual envs, thus avoiding this conflict and allowing pytorch-lightning to resolve as high as possible.

Unless of course the tensorboard dependency is used in pytorch... then please ignore me.

@mwalmsley (Owner, Author) replied:

> we could isolate the pytorch and tensorflow installs to their own python virtual envs

Good idea but I don't know how to do that.

In practice there is very little advantage to being on the absolute latest package, and I suspect the tensorboard team will update their deps shortly.
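For reference, a lighter-weight alternative to separate virtual envs would be optional extras in setup.py, so the pytorch and tensorflow stacks only install on request and pytorch-lightning could float higher for pytorch-only users. This is only a sketch of that idea, not something this PR does, and the extra names are made up:

# Sketch only (not part of this PR): splitting the frameworks into optional extras.
from setuptools import setup, find_packages

setup(
    name='zoobot',
    packages=find_packages(),
    install_requires=[
        # framework-agnostic deps only
        'numpy',
        'pandas',
    ],
    extras_require={
        # installed via `pip install zoobot[pytorch]` or `pip install zoobot[tensorflow]`
        'pytorch': [
            'torch == 1.10.1',
            'torchvision == 0.11.2',
            'torchaudio == 0.10.1',
            'pytorch-lightning',  # free to resolve as high as possible
        ],
        'tensorflow': [
            'tensorflow',
        ],
    },
)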

    datamodule=datamodule,
    ckpt_path='best'  # can optionally point to a specific checkpoint here e.g. "/share/nas2/walml/repos/gz-decals-classifiers/results/early_stopping_1xgpu_greyscale/checkpoints/epoch=26-step=16847.ckpt"
)
# trainer.test(
@camallen (Collaborator) commented:

any reason to remove the test step after fitting? Does this speed up the system but impact quality?

@mwalmsley (Owner, Author) replied:

Testing routinely is bad practice, as you may accidentally tune your hparams to overfit on the test set. I was being a bit lazy when adding it here earlier. I have left it commented out as an example, with a warning note.
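For context, the commented-out call plus warning might look roughly like this (a sketch, not the exact lines in the PR):

# Only run the test set once, at the very end of a project.
# Running it routinely risks tuning your hyperparameters to the test set.
# trainer.test(
#     model=model,
#     datamodule=datamodule,
#     ckpt_path='best'  # or a path to a specific checkpoint
# )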

@mwalmsley merged commit c312a83 into main on Oct 10, 2022