-
Notifications
You must be signed in to change notification settings - Fork 417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Metrics refactor pt1 #1411
Metrics refactor pt1 #1411
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM, a few questions and suggested code changes for style nits.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, please see my nit about the TODO
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One more quick change, and I think this is ready to merge!
This PR implements the second half of the metrics refactor (#1411) that updates the training loop and all the Composer models. The training loop now uses the previously added state.train_metrics and state.eval_metrics from #1411 to perform all training and evaluation on the correct set of metrics. The main changes with respect to models is that the models now have split up validate() into eval_forward() and update_metrics() methods, which run an evaluation forward pass and update the metrics with the outputs of the evaluation forward pass respectively. This is mainly to get rid of the double forward pass (#467).
This PR implements part 1 of the Metrics Refactor, which entails removing metrics from the
Evaluator
class and storing raw deep-copied metrics in theState
class. TheEvaluator
class now stores evaluation metric names instead of metric instances, which are matched against the model-definedmodel.val_metrics
to indicate which metrics will be computed at eval time. Additionally, instead of storing all computed metrics as part ofstate.current_metrics
, both the raw non-computed training and validation metrics are now stored separately as part ofstate.train_metrics
andstate.eval_metrics
.All tests updated and passing, regression tests will be conducted once PR 2 is ready.
Note: PR for part 2 will also touch a lot of the same code, so expect some refactors and clean up (mostly in the
Trainer
class) after this PR is merged in.Closes CO-679