Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prune the tree: move general functions to base class NNCV #18

Merged
merged 17 commits into from
Nov 16, 2022

Conversation

luigibonati
Copy link
Owner

Description

Moved fit function to NeuralNetworkCV class. In addition, added default functions for prepare_dataset, evaluate_dataset and train_epoch

List of changes

  • Moved fit in NNCV
  • Added prepare_dataset function (both for base class and children where needed)
  • Move training parameters (e.g. loss_type for tica) outside fit
  • add custom_train_epoch to change the workflow
  • add tests for custom_train
  • add custom_eval function
  • change log info (e.g. train loss) to dictionary
  • changed .to(device) into set_device

@luigibonati
Copy link
Owner Author

  • I did some refactoring here, moving a lot of stuff from cv classes to base NNCV (@EnricoTrizio see for instance the new TDA file).
  • The only breaking change is for DeepTICA, related to the parameters of the loss function (loss_type and n_eig) that are removed from the fit function and moved to a specific set_loss_function method
  • @EnricoTrizio I also changed how the training log is printed, so in principle if you want to print something else (e.g. contribution to tda loss) you could simply add new keys and values to the self.logs dictionary --> an example is in DeepTICA where I add a new key in the constructor and set it in the evaluate_dataset function

@lgtm-com
Copy link

lgtm-com bot commented Nov 16, 2022

This pull request fixes 1 alert when merging 82202c5 into adf73af - view on LGTM.com

fixed alerts:

  • 1 for Unused import

Heads-up: LGTM.com's PR analysis will be disabled on the 5th of December, and LGTM.com will be shut down ⏻ completely on the 16th of December 2022. Please enable GitHub code scanning, which uses the same CodeQL engine ⚙️ that powers LGTM.com. For more information, please check out our post on the GitHub blog.

@luigibonati luigibonati merged commit 053dedf into main Nov 16, 2022
@luigibonati luigibonati deleted the refactoring branch November 16, 2022 09:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant