Feature/sg 946 simpify torch schedulers integration #1230

Merged
2 changes: 1 addition & 1 deletion documentation/source/Checkpoints.md
@@ -42,7 +42,7 @@ The checkpoint keys:
- `optimizer_state_dict`: The state_dict of the optimizer (state_dict).
- `scaler_state_dict`: Optional - only present when training with [mixed_precision=True](average_mixed_precision.md). The state_dict of Trainer.scaler.
- `ema_net`: Optional - only present when training with [ema=True](EMA.md). The EMA model's state_dict. Note that `average_model.pth` lacks this entry even if ema=True since the average model's snapshots are of the EMA network already (i.e., the "net" entry is already an average of the EMA snapshots).

- `torch_scheduler_state_dict`: Optional - only present when using a native torch lr scheduler (see [LRScheduling](LRScheduling.md)).
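
As a quick illustration, here is a minimal sketch of inspecting these keys (assuming the default `CKPT_ROOT_DIR/EXPERIMENT_NAME` layout and a hypothetical experiment name):

```python
import torch

# Hypothetical checkpoint path - adjust CKPT_ROOT_DIR and the experiment name to your setup.
ckpt = torch.load("checkpoints/my_experiment/ckpt_latest.pth", map_location="cpu")

# "torch_scheduler_state_dict" is listed only when a native torch lr scheduler was used.
print(ckpt.keys())
```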
## Remote Checkpoint Saving with SG Loggers

SG supports remote checkpoint saving using 3rd party tools (for example, [Weights & Biases](https://www.google.com/aclk?sa=l&ai=DChcSEwi1iaLxhYj9AhXejWgJHZYqCGIYABAAGgJ3Zg&sig=AOD64_30zInAUka20YKKdULr8PHnLnLWgg&q&adurl&ved=2ahUKEwiKxZvxhYj9AhUzTKQEHSJwCkcQ0Qx6BAgGEAE)).
100 changes: 80 additions & 20 deletions documentation/source/LRScheduling.md
@@ -3,7 +3,10 @@
When training deep neural networks, it is often useful to reduce the learning rate as training progresses. This can be done by using pre-defined learning rate schedules or adaptive learning rate methods.
Learning rate scheduling type is controlled by the training parameter `lr_mode`. From `Trainer.train(...)` docs:

`lr_mode` : Union[str, Mapping]

When str:

Learning rate scheduling policy, one of ['step','poly','cosine','function'].

'step' refers to constant updates at epoch numbers passed through `lr_updates`. Each update decays the learning rate by `lr_decay_factor`.
@@ -202,36 +205,93 @@ Note that internally, Trainer unpacks [training_params to the scheduler callback
### Using PyTorch's Native LR Schedulers (torch.optim.lr_scheduler)

PyTorch offers a [wide variety of learning rate schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate).
They can all be easily used by passing a Mapping through the `lr_mode` parameter, following a simple API.
From `Trainer.train(...)` docs:

When Mapping, refers to a torch.optim.lr_scheduler._LRScheduler, following the below API:

lr_mode = {LR_SCHEDULER_CLASS_NAME: {**LR_SCHEDULER_KWARGS, "phase": XXX, "metric_name": XXX}}

Where "phase" (of Phase type) controls when to call torch.optim.lr_scheduler._LRScheduler.step().
For instance, in order to:
- Update LR on each batch: Use phase: Phase.TRAIN_BATCH_END
- Update LR after each epoch: Use phase: Phase.TRAIN_EPOCH_END

The "metric_name" refers to the metric to watch (See docs for "metric_to_watch" in train(...)
https://docs.deci.ai/super-gradients/docstring/training/sg_trainer.html) when using
ReduceLROnPlateau. In any other case this kwarg is ignored.

**LR_SCHEDULER_KWARGS are simply passed to the torch scheduler's __init__.



For example:

    lr_mode = {"StepLR": {"gamma": 0.1, "step_size": 1, "phase": Phase.TRAIN_EPOCH_END}}

is equivalent to the following training code:

    from torch.optim.lr_scheduler import StepLR
    ...
    optimizer = ...
    scheduler = StepLR(optimizer=optimizer, gamma=0.1, step_size=1)

    for epoch in range(num_epochs):
        train_epoch(...)
        scheduler.step()
    ...
A complete training example:

```python
...

trainer = Trainer("torch_Scheduler_example")
train_dataloader = ...
valid_dataloader = ...
model = ...

train_params = {
    "max_epochs": 2,
    "lr_mode": {"StepLR": {"gamma": 0.1, "step_size": 1, "phase": Phase.TRAIN_EPOCH_END}},
    "lr_warmup_epochs": 0,
    "initial_lr": 0.1,
    "loss": torch.nn.CrossEntropyLoss(),
    "optimizer": "SGD",
    "criterion_params": {},
    "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
    "train_metrics_list": [Accuracy()],
    "valid_metrics_list": [Accuracy()],
    "metric_to_watch": "Accuracy",
    "greater_metric_to_watch_is_better": True,
}
trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)
```

As stated above, for ReduceLROnPlateau we need to pass a `metric_name`, which follows the same rules as the training parameter `metric_to_watch` (see the [metrics guide](Metrics.md) if you are not familiar with it).
For example:

```python
trainer = Trainer("torch_ROP_Scheduler_example")
train_dataloader = ...
valid_dataloader = ...
model = ...

train_params = {
    "max_epochs": 2,
    "lr_mode": {"ReduceLROnPlateau": {"patience": 0, "phase": Phase.TRAIN_EPOCH_END, "metric_name": "Accuracy"}},
    "lr_warmup_epochs": 0,
    "initial_lr": 0.1,
    "loss": torch.nn.CrossEntropyLoss(),
    "optimizer": "SGD",
    "criterion_params": {},
    "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
    "train_metrics_list": [Accuracy()],
    "valid_metrics_list": [Accuracy()],
    "metric_to_watch": "Accuracy",
    "greater_metric_to_watch_is_better": True,
}

trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)
```

The scheduler's `state_dict` is saved under the `torch_scheduler_state_dict` entry inside the checkpoint during training,
allowing training to resume from the same scheduler state.
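
What happens on resume mirrors the usual torch pattern; a minimal sketch (with a placeholder model and a hypothetical checkpoint path), assuming the checkpoint was produced by the `StepLR` example above:

```python
import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(10, 10)  # placeholder model, for illustration only
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=1, gamma=0.1)

# Hypothetical path - adjust CKPT_ROOT_DIR and the experiment name to your setup.
ckpt = torch.load("checkpoints/torch_Scheduler_example/ckpt_latest.pth", map_location="cpu")

# Restoring this entry puts the scheduler back exactly where it left off,
# which is what the Trainer does internally when resuming.
scheduler.load_state_dict(ckpt["torch_scheduler_state_dict"])
```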
14 changes: 14 additions & 0 deletions src/super_gradients/common/registry/registry.py
@@ -144,6 +144,20 @@ def decorator(cls: Callable) -> Callable:
Optimizers.ADAMW: optim.AdamW,
Optimizers.RMS_PROP: optim.RMSprop,
}

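# Native torch.optim.lr_scheduler classes, resolvable by class name when `lr_mode` is passed as a Mapping.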
TORCH_LR_SCHEDULERS = {
"StepLR": torch.optim.lr_scheduler.StepLR,
"LambdaLR": torch.optim.lr_scheduler.LambdaLR,
"MultiStepLR": torch.optim.lr_scheduler.MultiStepLR,
"ConstantLR": torch.optim.lr_scheduler.ConstantLR,
"CosineAnnealingLR": torch.optim.lr_scheduler.CosineAnnealingLR,
"CosineAnnealingWarmRestarts": torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
"CyclicLR": torch.optim.lr_scheduler.CyclicLR,
"ExponentialLR": torch.optim.lr_scheduler.ExponentialLR,
"ReduceLROnPlateau": torch.optim.lr_scheduler.ReduceLROnPlateau,
"LinearLR": torch.optim.lr_scheduler.LinearLR,
}

register_optimizer = create_register_decorator(registry=OPTIMIZERS)

PROCESSINGS = {}
@@ -10,7 +10,11 @@ resume_from_remote_sg_logger: False # bool (default=False), When true, ckpt_name
# IMPORTANT: For WandB loggers, one must also pass the run id through the wandb_id arg in sg_logger_params.

ckpt_name: ckpt_latest.pth # The checkpoint (.pth file) filename in CKPT_ROOT_DIR/EXPERIMENT_NAME/ to use when resume=True and resume_path=None
lr_mode: # Union[str, Mapping]
# when str: Learning rate scheduling policy, one of ['step','poly','cosine','function']
# when Mapping: refers to a torch.optim.lr_scheduler._LRScheduler, following the below API: lr_mode = {LR_SCHEDULER_CLASS_NAME: {**LR_SCHEDULER_KWARGS, "phase": XXX, "metric_name": XXX}}

lr_schedule_function: # Learning rate scheduling function to be used when `lr_mode` is 'function'.
lr_warmup_epochs: 0 # number of epochs for learning rate warm up - see https://arxiv.org/pdf/1706.02677.pdf (Section 2.2).
lr_warmup_steps: 0 # number of warmup steps (Used when warmup_mode=linear_batch_step)
71 changes: 57 additions & 14 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -26,6 +26,7 @@
from super_gradients.modules.repvgg_block import fuse_repvgg_blocks_residual_branches

from super_gradients.training.utils.sg_trainer_utils import get_callable_param_names
from super_gradients.training.utils.callbacks.callbacks import create_lr_scheduler_callback, LRSchedulerCallback
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
@@ -84,7 +85,7 @@
MetricsUpdateCallback,
LRCallbackBase,
)
from super_gradients.common.registry.registry import LR_SCHEDULERS_CLS_DICT, LR_WARMUP_CLS_DICT
from super_gradients.common.registry.registry import LR_WARMUP_CLS_DICT
from super_gradients.common.environment.device_utils import device_config
from super_gradients.training.utils import HpmStruct
from super_gradients.common.environment.cfg_utils import load_experiment_cfg, add_params_to_cfg, load_recipe
@@ -212,6 +213,7 @@ def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[Mu
self.max_valid_batches = None

self._epoch_start_logging_values = {}
self._torch_lr_scheduler = None

@property
def device(self) -> str:
@@ -600,6 +602,9 @@ def _save_checkpoint(
if processing_params is not None:
state["processing_params"] = processing_params

if self._torch_lr_scheduler is not None:
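# Persist the native torch scheduler state so the LR schedule can be restored when resuming.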
state["torch_scheduler_state_dict"] = self._torch_lr_scheduler.state_dict()

# SAVES CURRENT MODEL AS ckpt_latest
self.sg_logger.add_checkpoint(tag="ckpt_latest.pth", state_dict=state, global_step=epoch)

@@ -721,7 +726,9 @@ def train(
Decay factor to apply to the learning rate at each update when `lr_mode='step'`.


- `lr_mode` : Union[str, Mapping]

When str:

Learning rate scheduling policy, one of ['step','poly','cosine','function'].

@@ -734,6 +741,37 @@

'function' refers to a user-defined learning rate scheduling function that is passed through `lr_schedule_function`.



When Mapping, refers to a torch.optim.lr_scheduler._LRScheduler, following the below API:

lr_mode = {LR_SCHEDULER_CLASS_NAME: {**LR_SCHEDULER_KWARGS, "phase": XXX, "metric_name": XXX}}

Where "phase" (of Phase type) controls when to call torch.optim.lr_scheduler._LRScheduler.step().

The "metric_name" refers to the metric to watch (See docs for "metric_to_watch" in train(...)
https://docs.deci.ai/super-gradients/docstring/training/sg_trainer.html) when using
ReduceLROnPlateau. In any other case this kwarg is ignored.

**LR_SCHEDULER_KWARGS are simply passed to the torch scheduler's __init__.


For example:

    lr_mode = {"StepLR": {"gamma": 0.1, "step_size": 1, "phase": Phase.TRAIN_EPOCH_END}}

is equivalent to the following training code:

    from torch.optim.lr_scheduler import StepLR
    ...
    optimizer = ...
    scheduler = StepLR(optimizer=optimizer, gamma=0.1, step_size=1)

    for epoch in range(num_epochs):
        train_epoch(...)
        scheduler.step()
    ...



- `lr_schedule_function` : Union[callable,None]

Learning rate scheduling function to be used when `lr_mode` is 'function'.
@@ -1133,18 +1171,6 @@ def forward(self, inputs, targets):
self.phase_callbacks = self.training_params.phase_callbacks or []
self.phase_callbacks = ListFactory(CallbacksFactory()).get(self.phase_callbacks)

if self.lr_mode is not None:
sg_lr_callback_cls = LR_SCHEDULERS_CLS_DICT[self.lr_mode]
self.phase_callbacks.append(
sg_lr_callback_cls(
train_loader_len=len(self.train_loader),
net=self.net,
training_params=self.training_params,
update_param_groups=self.update_param_groups,
**self.training_params.to_dict(),
)
)

warmup_mode = self.training_params.warmup_mode
warmup_callback_cls = None
if isinstance(warmup_mode, str):
@@ -1197,6 +1223,23 @@ def forward(self, inputs, targets):
else:
raise UnsupportedOptimizerFormat()

if self.lr_mode is not None:
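# Resolve lr_mode into a callback: an SG LR policy callback when lr_mode is a str, or an LRSchedulerCallback wrapping a native torch scheduler when it is a Mapping.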
lr_scheduler_callback = create_lr_scheduler_callback(
lr_mode=self.lr_mode,
train_loader=self.train_loader,
net=self.net,
training_params=self.training_params,
update_param_groups=self.update_param_groups,
optimizer=self.optimizer,
)
self.phase_callbacks.append(lr_scheduler_callback)

# NEED ACCESS TO THE UNDERLYING TORCH SCHEDULER FOR LOADING/SAVING ITS STATE_DICT
if isinstance(lr_scheduler_callback, LRSchedulerCallback):
self._torch_lr_scheduler = lr_scheduler_callback.scheduler
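# When resuming, the state saved under "torch_scheduler_state_dict" is loaded back below (see _save_checkpoint).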
if self.load_checkpoint:
self._torch_lr_scheduler.load_state_dict(self.checkpoint["torch_scheduler_state_dict"])

# VERIFY GRADIENT CLIPPING VALUE
if self.training_params.clip_grad_norm is not None and self.training_params.clip_grad_norm <= 0:
raise TypeError("Params", "Invalid clip_grad_norm")