-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Report CycleGAN validation metrics correctly to wandb #2131
base: master
Are you sure you want to change the base?
Conversation
.pre-commit-config.yaml
Outdated
@@ -1,5 +1,10 @@ | |||
exclude: "external/gcsfs/" | |||
repos: | |||
- repo: https://github.com/psf/black |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like black to run before flake8 so that it can auto-fix flake8 issues before flake8 runs.
generator | ||
discriminator_optimizer: configuration for the optimizer used to train the | ||
discriminator | ||
optimizer: configuration for the optimizer used to train the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Merging these was necessary so a wandb sweep can operate on the learning rate, there's no way to pair two hyperparameters for wandb sweeps.
@@ -316,76 +314,8 @@ def _init_targets(self, shape: Tuple[int, ...]): | |||
torch.Tensor(shape).fill_(0.0).to(DEVICE), requires_grad=False | |||
) | |||
|
|||
def evaluate_on_dataset( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function was never actually used (it was called, but only on validation data, and I never provided validation datasets before now).
@@ -395,6 +325,8 @@ def train_on_batch( | |||
[sample, time, tile, channel, y, x] | |||
real_b: a batch of data from domain B, should have shape | |||
[sample, time, tile, channel, y, x] | |||
training: if True, the model will be trained, otherwise we will |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This allows getting training metrics on validation data.
@@ -51,12 +50,11 @@ class CycleGANNetworkConfig: | |||
cycle_weight: weight of the cycle loss | |||
generator_weight: weight of the generator's gan loss | |||
discriminator_weight: weight of the discriminator gan loss | |||
reload_path: path to a directory containing a saved CycleGAN model to use |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was just a missing docstring entry.
Currently the CycleGAN training routine does not report the same training losses on validation data. This PR refactors the code to use training losses for validation data, produce one wandb report per epoch, and adds regularization loss as an output metric.
Refactored public API:
Significant internal changes:
Coverage reports (updated automatically):