Skip to content

Commit

Permalink
[Improvement] max_batches support to training log and tqdm progress b…
Browse files Browse the repository at this point in the history
…ar. (#1554)

* Added max_batches support to training log and tqdm progress bar.

* Added changing string in accordance which parameter is used (len(loader) of max_batches)

* Replaced stopping condition for the epoch with a smaller one
  • Loading branch information
hakuryuu96 committed Oct 23, 2023
1 parent 68a96bb commit 749a9c7
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 17 deletions.
42 changes: 29 additions & 13 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,12 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl
# SET THE MODEL IN training STATE
self.net.train()

expected_iterations = len(self.train_loader) if self.max_train_batches is None else self.max_train_batches

# THE DISABLE FLAG CONTROLS WHETHER THE PROGRESS BAR IS SILENT OR PRINTS THE LOGS
with tqdm(self.train_loader, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode) as progress_bar_train_loader:
with tqdm(
self.train_loader, total=expected_iterations, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode
) as progress_bar_train_loader:
progress_bar_train_loader.set_description(f"Train epoch {context.epoch}")

# RESET/INIT THE METRIC LOGGERS
Expand All @@ -471,6 +475,9 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl
context.update_context(loss_avg_meter=loss_avg_meter, metrics_compute_fn=self.train_metrics)

for batch_idx, batch_items in enumerate(progress_bar_train_loader):
if expected_iterations <= batch_idx:
break

batch_items = core_utils.tensor_container_to_device(batch_items, device_config.device, non_blocking=True)
inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)

Expand Down Expand Up @@ -510,9 +517,6 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl
progress_bar_train_loader.set_postfix(**pbar_message_dict)
self.phase_callback_handler.on_train_batch_end(context)

if self.max_train_batches is not None and self.max_train_batches - 1 <= batch_idx:
break

self.train_monitored_values = sg_trainer_utils.update_monitored_values_dict(
monitored_values_dict=self.train_monitored_values, new_values_dict=pbar_message_dict
)
Expand Down Expand Up @@ -1331,21 +1335,23 @@ def forward(self, inputs, targets):

self.ckpt_best_name = self.training_params.ckpt_best_name

self.max_train_batches = self.training_params.max_train_batches
self.max_valid_batches = self.training_params.max_valid_batches

if self.training_params.max_train_batches is not None:
if self.training_params.max_train_batches > len(self.train_loader):
logger.warning("max_train_batches is greater than len(self.train_loader) and will have no effect.")
self.max_train_batches = len(self.train_loader)
elif self.training_params.max_train_batches <= 0:
raise ValueError("max_train_batches must be positive.")

if self.training_params.max_valid_batches is not None:
if self.training_params.max_valid_batches > len(self.valid_loader):
logger.warning("max_valid_batches is greater than len(self.valid_loader) and will have no effect.")
self.max_valid_batches = len(self.valid_loader)
elif self.training_params.max_valid_batches <= 0:
raise ValueError("max_valid_batches must be positive.")

self.max_train_batches = self.training_params.max_train_batches
self.max_valid_batches = self.training_params.max_valid_batches

# STATE ATTRIBUTE SET HERE FOR SUBSEQUENT TRAIN() CALLS
self._first_backward = True

Expand Down Expand Up @@ -1394,6 +1400,7 @@ def forward(self, inputs, targets):
batch_accumulate=self.batch_accumulate,
train_dataset_length=len(self.train_loader.dataset),
train_dataloader_len=len(self.train_loader),
max_train_batches=self.max_train_batches,
)

processing_params = self._get_preprocessing_from_valid_loader()
Expand Down Expand Up @@ -2014,7 +2021,12 @@ def _validate_epoch(self, context: PhaseContext, silent_mode: bool = False) -> D
self._reset_metrics()
self.valid_metrics.to(device_config.device)
return self.evaluate(
data_loader=self.valid_loader, metrics=self.valid_metrics, evaluation_type=EvaluationType.VALIDATION, epoch=context.epoch, silent_mode=silent_mode
data_loader=self.valid_loader,
metrics=self.valid_metrics,
evaluation_type=EvaluationType.VALIDATION,
epoch=context.epoch,
silent_mode=silent_mode,
max_batches=self.max_valid_batches,
)

def _test_epoch(self, data_loader: DataLoader, context: PhaseContext, silent_mode: bool = False, dataset_name: str = "") -> Dict[str, float]:
Expand Down Expand Up @@ -2047,6 +2059,7 @@ def evaluate(
silent_mode: bool = False,
metrics_progress_verbose: bool = False,
dataset_name: str = "",
max_batches: Optional[int] = None,
) -> Dict[str, float]:
"""
Evaluates the model on given dataloader and metrics.
Expand Down Expand Up @@ -2081,7 +2094,11 @@ def evaluate(
loss_logging_items_names=self.loss_logging_items_names,
)

with tqdm(data_loader, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode) as progress_bar_data_loader:
expected_iterations = len(data_loader) if max_batches is None else max_batches

with tqdm(
data_loader, total=expected_iterations, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode
) as progress_bar_data_loader:

if not silent_mode:
# PRINT TITLES
Expand All @@ -2091,9 +2108,11 @@ def evaluate(
if epoch:
pbar_start_msg += f" epoch {epoch}"
progress_bar_data_loader.set_description(pbar_start_msg)

with torch.no_grad():
for batch_idx, batch_items in enumerate(progress_bar_data_loader):
if evaluation_type == EvaluationType.VALIDATION and expected_iterations <= batch_idx:
break

batch_items = core_utils.tensor_container_to_device(batch_items, device_config.device, non_blocking=True)
inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)

Expand Down Expand Up @@ -2128,9 +2147,6 @@ def evaluate(

progress_bar_data_loader.set_postfix(**pbar_message_dict)

if evaluation_type == EvaluationType.VALIDATION and self.max_valid_batches is not None and self.max_valid_batches - 1 <= batch_idx:
break

logging_values = get_logging_values(loss_avg_meter, metrics, self.criterion)
# NEED TO COMPUTE METRICS FOR THE FIRST TIME IF PROGRESS VERBOSITY IS NOT SET
if not metrics_progress_verbose:
Expand Down
23 changes: 19 additions & 4 deletions src/super_gradients/training/utils/sg_trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,19 +447,34 @@ def get_callable_param_names(obj: callable) -> Tuple[str]:


def log_main_training_params(
multi_gpu: MultiGPUMode, num_gpus: int, batch_size: int, batch_accumulate: int, train_dataset_length: int, train_dataloader_len: int
multi_gpu: MultiGPUMode,
num_gpus: int,
batch_size: int,
batch_accumulate: int,
train_dataset_length: int,
train_dataloader_len: int,
max_train_batches: Optional[int] = None,
):
"""Log training parameters"""

iterations_per_epoch = int(train_dataloader_len) if max_train_batches is None else max_train_batches
gradients_updates_per_epoch = int(iterations_per_epoch / batch_accumulate)
what_used_str = "len(train_loader)" if max_train_batches is None else "max_train_batches"

msg = (
"TRAINING PARAMETERS:\n"
f" - Mode: {multi_gpu.name if multi_gpu else 'Single GPU'}\n"
f" - Number of GPUs: {num_gpus if 'cuda' in device_config.device else 0:<10} ({torch.cuda.device_count()} available on the machine)\n"
f" - Dataset size: {train_dataset_length:<10} (len(train_set))\n"
f" - Full dataset size: {train_dataset_length:<10} (len(train_set))\n"
f" - Batch size per GPU: {batch_size:<10} (batch_size)\n"
f" - Batch Accumulate: {batch_accumulate:<10} (batch_accumulate)\n"
f" - Total batch size: {num_gpus * batch_size:<10} (num_gpus * batch_size)\n"
f" - Effective Batch size: {num_gpus * batch_size * batch_accumulate:<10} (num_gpus * batch_size * batch_accumulate)\n"
f" - Iterations per epoch: {int(train_dataloader_len):<10} (len(train_loader))\n"
f" - Gradient updates per epoch: {int(train_dataloader_len / batch_accumulate):<10} (len(train_loader) / batch_accumulate)\n"
f" - Iterations per epoch: {iterations_per_epoch:<10} ({what_used_str})\n"
f" - Gradient updates per epoch: {gradients_updates_per_epoch:<10} ({what_used_str} / batch_accumulate)\n"
)

logger.info(msg)

if max_train_batches:
logger.warning(f"max_train_batch is set to {max_train_batches}. This limits the number of iterations per epoch and gradient updates per epoch.")

0 comments on commit 749a9c7

Please sign in to comment.