Fix docstrings
BloodAxe committed Sep 27, 2023
1 parent bf13606 commit a603519
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions src/super_gradients/training/utils/callbacks/callbacks.py
@@ -1085,12 +1085,27 @@ def __init__(
max_images: int = -1,
):
"""
 :param metric: Metric, will be the metric which is monitored.
-:param metric_component_name:
-:param loss_to_monitor:
-:param max:
-:param freq: Frequency (in epochs) of performing this callback. 1 means every epoch. 2 means every other epoch. Default is 1.
+:param metric_component_name: In case metric returns multiple values (as Mapping),
+       the value at metric.compute()[metric_component_name] will be the one monitored.
+:param loss_to_monitor: str, name of the loss corresponding to the 'criterion' passed through training_params in Trainer.train(...).
+       Monitoring a loss follows the same naming logic as metric_to_watch in Trainer.train(..), so the value should be:
+
+       if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
+           <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.
+
+       If a single item is returned rather than a tuple:
+           <LOSS_CLASS.__name__>.
+
+       When there is no such attribute and criterion.forward(..) returns a tuple:
+           <LOSS_CLASS.__name__>"/"Loss_"<IDX>
+:param max: bool, whether to take the batch corresponding to the max value of the metric/loss or
+       the minimum (default=False).
+:param freq: int, epoch frequency to perform all of the above (default=1).
+:param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
+:param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
+:param max_images: Maximum images to save. If -1, save all images.
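The loss_to_monitor naming rules in the docstring can be sketched as a small standalone function. Everything here (`monitored_loss_name`, `DummyLoss`, `PlainLoss`) is hypothetical illustration, not part of the super_gradients API:

```python
# Hypothetical helper mirroring the loss_to_monitor naming convention
# described in the docstring above; names are illustrative only.


def monitored_loss_name(criterion, returns_tuple, component_name=None, idx=0):
    cls_name = criterion.__class__.__name__
    if hasattr(criterion, "component_names") and returns_tuple:
        # <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>
        return f"{cls_name}/{component_name}"
    if not returns_tuple:
        # a single item is returned rather than a tuple: just the class name
        return cls_name
    # no component_names attribute, but forward(..) returns a tuple
    return f"{cls_name}/Loss_{idx}"


class DummyLoss:
    component_names = ["cls", "reg"]


class PlainLoss:
    pass


print(monitored_loss_name(DummyLoss(), returns_tuple=True, component_name="cls"))  # DummyLoss/cls
print(monitored_loss_name(PlainLoss(), returns_tuple=False))  # PlainLoss
print(monitored_loss_name(PlainLoss(), returns_tuple=True, idx=0))  # PlainLoss/Loss_0
```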
@@ -1182,7 +1197,7 @@ def _on_batch_end(self, context: PhaseContext) -> None:
if self.metric_component_name is not None:
if not isinstance(score, Mapping) or (isinstance(score, Mapping) and self.metric_component_name not in score.keys()):
raise RuntimeError(
-    f"metric_component_name: {self.metric_component_name} is not a component " f"of the monitored metric: {self.metric.__class__.__name__}"
+    f"metric_component_name: {self.metric_component_name} is not a component of the monitored metric: {self.metric.__class__.__name__}"
)
score = score[self.metric_component_name]
elif len(score) > 1:
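The validation in the hunk above can be read as a small standalone function. This is a sketch; `extract_component` is not part of the library, and the metric name is passed in as a plain string:

```python
from collections.abc import Mapping


def extract_component(score, metric_component_name, metric_name):
    # Sketch of the check in _on_batch_end: when a component name is given,
    # the computed score must be a Mapping containing that key; otherwise
    # fail loudly with the same error message as the callback.
    if metric_component_name is not None:
        if not isinstance(score, Mapping) or metric_component_name not in score:
            raise RuntimeError(
                f"metric_component_name: {metric_component_name} is not a component of the monitored metric: {metric_name}"
            )
        return score[metric_component_name]
    return score


print(extract_component({"mAP": 0.5, "F1": 0.7}, "F1", "DetectionMetrics"))  # 0.7
```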