update docs
themattinthehatt committed May 7, 2024
1 parent d28060e commit 9f58476
Showing 2 changed files with 40 additions and 7 deletions.
15 changes: 8 additions & 7 deletions docs/source/faqs.rst
@@ -112,24 +112,25 @@ Model training

.. _faq_epoch:

.. dropdown:: What is and how many epochs should I use for training?
.. dropdown:: How many epochs should I use for training?

**What is an epoch?**
An epoch refers to one complete pass through the entire training dataset. During an epoch,
the model is trained on every sample in the dataset exactly once. Find more info
`here <https://lightning-pose.readthedocs.io/en/latest/source/user_guide/config_file.html#model-training-parameters>`
`here <https://lightning-pose.readthedocs.io/en/latest/source/user_guide/config_file.html#model-training-parameters>`_
(this link takes you to another set of docs specifically for Lightning Pose).

**With what value should I start?**
To train a full model, we recommend starting with the default of 300 epochs. To get a baseline
understanding of how the model performs, we recommend 50 epochs as the minimum number to get
a valid model to check.
**What are the trade-offs for increasing or decreasing that number?**
Increasing epochs may enhance convergence and accuracy but raise the risk of overfitting.
a valid model to check.

**What are the trade-offs for increasing or decreasing the number of epochs?**
Increasing the epochs may enhance convergence and accuracy but raises the risk of overfitting.
Conversely, fewer epochs might speed up training but risk underfitting. Balancing epochs is
crucial to minimize both underfitting and overfitting.
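
One way to make this trade-off concrete is to look at the validation loss per epoch and find where it stops improving. The following is a minimal sketch, not part of Lightning Pose: the function name ``best_epoch`` and the loss values are hypothetical and only illustrate the idea.

```python
def best_epoch(val_losses):
    """Return the 0-based index of the epoch with the lowest validation loss.

    Hypothetical helper for illustration; `val_losses` is assumed to hold
    one validation loss value per epoch, in training order.
    """
    if not val_losses:
        raise ValueError("val_losses must contain at least one value")
    return min(range(len(val_losses)), key=lambda i: val_losses[i])


# Made-up curve: the loss falls (underfitting early on), bottoms out,
# then rises again (a typical sign of overfitting).
example_val_losses = [1.0, 0.6, 0.4, 0.35, 0.38, 0.45]
print(best_epoch(example_val_losses))
```

If the lowest validation loss occurs well before the final epoch, training longer is unlikely to help; if it occurs at the very last epoch, the model may still be underfitting.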




Post-processing
---------------

32 changes: 32 additions & 0 deletions docs/source/tabs/train_status.rst
@@ -24,3 +24,35 @@ The earlier one is the supervised model, and the later one is the semi-supervise
If you don't see all your models in that tab,
hit the refresh button on the top right corner of the screen,
and the other models should appear.

Available metrics
-----------------

The following are the important metrics for all model types
(supervised, context, semi-supervised, etc.):

* ``train_supervised_loss``: this is the same as ``train_heatmap_mse_loss_weighted``, which is the
mean square error (MSE) between the true and predicted heatmaps on labeled training data
* ``train_supervised_rmse``: the root mean square error (RMSE) between the true and predicted
(x, y) coordinates on labeled training data; scale is in pixels
* ``val_supervised_loss``: this is the same as ``val_heatmap_mse_loss_weighted``, which is the
MSE between the true and predicted heatmaps on labeled validation data
* ``val_supervised_rmse``: the RMSE between the true and predicted (x, y) coordinates on labeled
validation data; scale is in pixels
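
To illustrate the difference between the two kinds of metric above, here is a minimal sketch in plain Python. It is not the Lightning Pose implementation: the function names, the flat-list heatmap representation, and the exact way errors are averaged across keypoints are assumptions made for illustration.

```python
import math


def heatmap_mse(true_heatmap, pred_heatmap):
    """Mean squared error between two heatmaps given as flat lists of pixel values."""
    if len(true_heatmap) != len(pred_heatmap) or not true_heatmap:
        raise ValueError("heatmaps must be non-empty and have the same length")
    squared_errors = [(t - p) ** 2 for t, p in zip(true_heatmap, pred_heatmap)]
    return sum(squared_errors) / len(squared_errors)


def coordinate_rmse(true_xy, pred_xy):
    """Root mean square error, in pixels, between lists of (x, y) coordinates.

    Assumption: the squared Euclidean distance is averaged over keypoints
    before taking the square root.
    """
    if len(true_xy) != len(pred_xy) or not true_xy:
        raise ValueError("coordinate lists must be non-empty and have the same length")
    squared_distances = [
        (tx - px) ** 2 + (ty - py) ** 2
        for (tx, ty), (px, py) in zip(true_xy, pred_xy)
    ]
    return math.sqrt(sum(squared_distances) / len(squared_distances))


# Two keypoints: the first prediction is 5 pixels off, the second is exact.
print(coordinate_rmse([(10, 20), (30, 40)], [(13, 24), (30, 40)]))
```

The heatmap loss is measured in heatmap-intensity units and is what the model optimizes, while the RMSE is expressed in pixels and is usually easier to interpret.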

The following are important metrics for the semi-supervised models:

* ``train_pca_multiview_loss_weighted``: the ``train_pca_multiview_loss`` (in pixels), which
measures multiview consistency, multiplied by the loss weight set in the configuration file.
This metric is only computed on batches of unlabeled training data.
* ``train_pca_singleview_loss_weighted``: the ``train_pca_singleview_loss`` (in pixels), which
measures pose plausibility, multiplied by the loss weight set in the configuration file.
This metric is only computed on batches of unlabeled training data.
* ``train_temporal_loss_weighted``: the ``train_temporal_loss`` (in pixels), which
measures temporal smoothness, multiplied by the loss weight set in the configuration file.
This metric is only computed on batches of unlabeled training data.
* ``total_unsupervised_importance``: a weight on all *weighted* unsupervised losses that linearly
increases from 0 to 1 over 100 epochs
* ``total_loss``: weighted supervised loss (``train_heatmap_mse_loss_weighted``) plus
``total_unsupervised_importance`` times the sum of all applicable weighted unsupervised losses
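
The relationship between ``total_unsupervised_importance`` and ``total_loss`` described above can be sketched as follows. This is an illustrative approximation, not the Lightning Pose code: the function names are hypothetical, and it assumes the ramp starts at epoch 0 and is updated once per epoch.

```python
def unsupervised_importance(epoch, ramp_epochs=100):
    """Weight that rises linearly from 0 to 1 over `ramp_epochs`, then stays at 1.

    Assumption: the ramp begins at epoch 0.
    """
    return min(max(epoch / ramp_epochs, 0.0), 1.0)


def total_loss(supervised_loss_weighted, unsupervised_losses_weighted, epoch):
    """Weighted supervised loss plus the ramped sum of weighted unsupervised losses."""
    importance = unsupervised_importance(epoch)
    return supervised_loss_weighted + importance * sum(unsupervised_losses_weighted)


# Halfway through the ramp, the unsupervised losses count at half strength.
print(total_loss(2.0, [1.0, 3.0], epoch=50))
```

Because of this ramp, ``total_loss`` can increase during the first 100 epochs even while each individual loss term is decreasing, so the individual metrics are often more informative early in training.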
