Does Composer support a best checkpoint saver that can monitor metrics or losses and keep the best checkpoint? #2303
Comments
This is a great suggestion! We currently do have an EarlyStopping callback (see docs). Unfortunately, at this time we don't have a way to save the best checkpoint, but we will add it to our roadmap.
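For reference, wiring up that existing early-stopping callback looks roughly like the sketch below. This is only a sketch: the class name `EarlyStopper` and the argument names (`monitor`, `dataloader_label`, `patience`) are assumptions and should be checked against the Composer docs for your version.

```python
from composer import Trainer
from composer.callbacks import EarlyStopper

# Stop training once the monitored eval metric stops improving.
# Names below are assumptions; verify them against the Composer docs.
early_stopper = EarlyStopper(
    monitor="MulticlassAccuracy",  # metric logged during evaluation
    dataloader_label="eval",       # which dataloader's metrics to watch
    patience="5ep",                # how long to wait for an improvement
)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="100ep",
    callbacks=[early_stopper],
)
trainer.fit()
```

Note that this stops training early but does not by itself save or restore the best checkpoint, which is what the rest of this thread is about.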
I had just mentioned this same issue on the Mosaic Slack here last week :).
Are there any updates on this? This is a must-have feature.
Hi @mvpatel2000, we have a custom best checkpoint saver, but the issue is how to load the best checkpoint before evaluating, i.e. after the fit but before evaluation. Using the tooling provided by Composer, it does not seem to work in DDP. We are currently using ... We have not checked with DeepSpeed nor FSDP.
@priba would you mind sharing a code example please? We're happy to take a look (and ideally we can add a unit test to ensure this works).
Hi @mvpatel2000, thanks for your answer. Here are some snippets of the code we are using:

```python
class BestCheckpointSaver(Callback):
    ...

    def fit_end(self, state: State, logger: Logger) -> None:
        if not self.save_last:
            return
        self.save_checkpoint("last" + DEFAULT_CHECKPOINT_EXTENSION, state, logger, None)

    def save_checkpoint(self, filename: str, state: State, logger: Logger, metric: Any) -> Path | None:
        destination = format_name_with_dist_and_time(
            os.fspath(Path(self.folder) / filename), state.run_name, state.timestamp
        )
        saved_path = checkpoint.save_checkpoint(state, destination, weights_only=self._weights_only)
        if saved_path is None:  # not all ranks save
            return None
        self.saved_checkpoints.append((Path(saved_path), state.timestamp.copy(), metric))
        log.debug(f"Uploading checkpoint to {destination} ...")
        logger.upload_file(remote_file_name=destination, file_path=saved_path)
        return Path(saved_path)
```

Then, just before the eval call, we manually use a new entry point into the trainer:

````python
class MyTrainer(Trainer):
    ...

    def load_best_checkpoint(self) -> None:
        """Loads the best checkpoint as tracked by BestCheckpointSaver.

        Example:
            ```
            trainer = MyTrainer(
                model=model,
                max_duration="100ep",
                callbacks=[BestCheckpointSaver("cross-entropy", mode="min")],
                ...
            )
            trainer.fit(...)
            trainer.load_best_checkpoint()
            trainer.eval(...)
            ```
        """
        if is_model_ddp(self.state.model) or is_model_deepspeed(self.state.model) or is_model_fsdp(self.state.model):
            log.warn(
                "`load_best_checkpoint` is not implemented for DDP, DeepSpeed and FSDP. "
                "Last weights will be used. Run test independently to load the best model."
            )
            return
        checkpoint_savers = [callback for callback in self.state.callbacks if isinstance(callback, BestCheckpointSaver)]
        if len(checkpoint_savers) == 0:
            raise ValueError("A `BestCheckpointSaver` callback must be provided to call this method.")
        if len(checkpoint_savers) > 1:
            log.warning(f"Several `BestCheckpointSaver` were provided. Evaluating with {checkpoint_savers[0]!r}.")
        ckpt = checkpoint_savers[0].best_checkpoint
        if ckpt is None:
            raise ValueError("No best checkpoint was found. Have you run `.fit()` with an `eval_dataloader`?")
        checkpoint.load_checkpoint(os.fspath(ckpt[0]), self.state, self.logger)
````

We have tried different combinations, such as running the load on rank 0 only, broadcasting the checkpoint path, and even broadcasting the model weights, but we didn't manage to make it work in DDP.
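For readers reconstructing this callback, the elided `...` above presumably contains the metric-tracking logic. A hypothetical sketch of that part is below; `state.eval_metrics`, `self.monitor`, and `self.mode` are assumptions inferred from the docstring example (`BestCheckpointSaver("cross-entropy", mode="min")`), not code from the original comment.

```python
def eval_end(self, state: State, logger: Logger) -> None:
    # Hypothetical sketch: read the monitored metric from the trainer state
    # and keep the checkpoint only if the metric improved.
    metric = state.eval_metrics["eval"][self.monitor].compute().item()
    improved = (
        self.best_checkpoint is None
        or (self.mode == "min" and metric < self.best_checkpoint[2])
        or (self.mode == "max" and metric > self.best_checkpoint[2])
    )
    if improved:
        saved = self.save_checkpoint("best" + DEFAULT_CHECKPOINT_EXTENSION, state, logger, metric)
        if saved is not None:  # only ranks that actually saved track the path
            self.best_checkpoint = (saved, state.timestamp.copy(), metric)
```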
@priba looking into this... it would also be helpful if you could elaborate on what "it does not seem to work in DDP" means. Could you provide an error trace or some kind of description? This design seems more or less correct to me...
Hi @mvpatel2000, the issues we have are:
Only rank 0 saves the checkpoint to avoid duplicate work, but all ranks need to load the checkpoint with DDP (if you use FSDP, we support a path for broadcasting). If you share a filesystem, broadcasting the name should be sufficient, as you described.
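A minimal sketch of that "broadcast the name" approach, assuming a shared filesystem and that `saver` is the `BestCheckpointSaver` instance from the earlier comment. `dist.broadcast_object_list` mirrors `torch.distributed.broadcast_object_list`; if your Composer version's `dist` module does not expose it, the torch call works the same way.

```python
from composer.utils import checkpoint, dist

# Rank 0 knows the best checkpoint path; broadcast it so every rank loads the
# same file. All ranks must participate in the load under DDP.
best_path = [saver.best_checkpoint[0] if dist.get_global_rank() == 0 else None]
dist.broadcast_object_list(best_path, src=0)

checkpoint.load_checkpoint(str(best_path[0]), trainer.state, trainer.logger)
```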
This is an ugly PyTorch detail -- they add an extra level of `module.` wrapping around the model under DDP. Torch 2.3 exposes a new API from torch that addresses this.
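The new torch 2.3 API isn't named in the comment above, so that reference is left as-is. A long-standing workaround for the extra DDP prefix, independent of torch 2.3, is the helper sketched below; the checkpoint layout (`"state"` -> `"model"`) and the file path are assumptions for illustration.

```python
import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# Strip the "module." prefix that DistributedDataParallel adds to parameter
# names, so weights saved from a wrapped model load into an unwrapped one.
ckpt = torch.load("checkpoints/best.pt", map_location="cpu")
model_state = ckpt["state"]["model"]  # assumed Composer checkpoint layout
consume_prefix_in_state_dict_if_present(model_state, prefix="module.")
model.load_state_dict(model_state)
```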
Thanks for the info @mvpatel2000. If we broadcast the string, it does not throw any error message; it just skips loading the model, as the keys are not there, and throws a warning with the set of ignored keys. I don't think this is very informative for your case. I will check once support for 2.3 is released and report back to you. Thank you so much.
Ah, I think this is a bug. If you load from
Great! Please let me know if you encounter any errors after the release :)
Hi @mvpatel2000, I am afraid the problem persists. Here is some minimal working code:

```python
from composer import Trainer
from composer.algorithms import ChannelsLast, CutMix, LabelSmoothing, BlurPool
from composer.core import DataSpec
from composer.models import ComposerClassifier
from composer.utils import dist, checkpoint

import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms

# Define Model
num_classes: int = 10
resnet = torchvision.models.resnet18()
resnet.fc = nn.Linear(512, num_classes)
model = ComposerClassifier(module=resnet, num_classes=num_classes)

# Normalization constants
mean = (0.507, 0.487, 0.441)
std = (0.267, 0.256, 0.276)
batch_size = 1024
cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

# Download Data
data_directory = "./data"
train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)
eval_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)

# Build DataSpec
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, sampler=dist.get_sampler(train_dataset, drop_last=True, shuffle=True)
)
train_spec = DataSpec(train_dataloader, device_transforms=None, get_num_samples_in_batch=lambda batch: len(batch[0]))
eval_dataloader = torch.utils.data.DataLoader(
    eval_dataset, batch_size=batch_size, sampler=dist.get_sampler(eval_dataset, drop_last=False, shuffle=False)
)
eval_spec = DataSpec(eval_dataloader, device_transforms=None, get_num_samples_in_batch=lambda batch: len(batch[0]))

trainer = Trainer(
    model=model,
    train_dataloader=train_spec,
    eval_dataloader=eval_spec,
    max_duration="2ep",
    algorithms=[BlurPool(), LabelSmoothing(smoothing=0.1), CutMix(alpha=1.0), ChannelsLast()],
    save_folder="./checkpoints",
    save_filename="ep{epoch}.pt",
    save_latest_filename="latest",
)

trainer.fit()
checkpoint.load_checkpoint("./checkpoints/ep1.pt", trainer.state, trainer.logger)
trainer.eval()
```

This runs into key problems, as DDP wraps the model in an extra module. If I force the checkpoint loading to happen only on rank 0, it hangs whether or not I add a barrier:

```python
if dist.get_global_rank() == 0:
    checkpoint.load_checkpoint("./checkpoints/ep1.pt", trainer.state, trainer.logger)
else:
    dist.barrier()
```

Do you think there is any workaround?
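One detail worth noting about the rank-0-only snippet above: rank 0 never reaches the barrier that the other ranks are waiting on, which by itself can cause a hang. A symmetric version where every rank hits the same barrier is sketched below; this is only a sketch, and under DDP it still leaves the non-zero ranks with their old weights, so it is not a full fix.

```python
# Sketch: make every rank reach the same barrier so no rank waits forever.
# Under DDP this alone is usually not enough, because only rank 0's weights
# are updated by the load.
if dist.get_global_rank() == 0:
    checkpoint.load_checkpoint("./checkpoints/ep1.pt", trainer.state, trainer.logger)
dist.barrier()
```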
@priba Hm... your snippet seems to work for me.
Can you double check you are on Composer v0.22 and torch 2.3?
@mvpatel2000 my bad :(. I had just updated Composer and not torch. After the update, the snippet worked smoothly. Thanks for the support. Next week I will run tests on a real use case.
Great to hear! Please let me know if you encounter any further issues.
@mvpatel2000 out of curiosity, may I ask which GPUs you used to run this test? It's much faster than what I got. Are these H100s?
I was debugging this on 2xA100.
Hi @mvpatel2000! Thanks for providing feedback and insights into the latest feature in PyTorch. We are also using 2xA100 with the same code snippet as shared above, but our throughput is quite a bit lower: around ~4.1 ba/sec, while you reach >6.6 ba/sec. We are using A100s with 40GB of VRAM, CUDA 12.1, and the latest version of PyTorch. Are you using A100s with that amount of VRAM, or GPUs with 80GB? Would you mind sharing the output of
Could be CPU bottlenecked since it's CIFAR 🤷
🚀 Feature Request
From the Checkpointing section of the documentation, I didn't find a best checkpoint saver function.
A best checkpoint saver refers to a function that monitors user-specified metrics or losses and saves the checkpoint that achieves the best value of those metrics or losses during the training process.
So I'm wondering whether such a function is already supported.
In Keras, you can use the EarlyStopping function to meet this need:
https://stackoverflow.com/questions/48285129/saving-best-model-in-keras
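For reference, the pattern that Stack Overflow answer describes is usually Keras's `ModelCheckpoint` with `save_best_only=True` (rather than `EarlyStopping` itself). A minimal sketch, assuming a compiled `model` and validation data already exist:

```python
import tensorflow as tf

# Keras-style "best checkpoint saver": keep only the checkpoint with the best
# monitored validation metric seen so far during training.
best_ckpt = tf.keras.callbacks.ModelCheckpoint(
    filepath="best_model.keras",
    monitor="val_loss",
    mode="min",
    save_best_only=True,
)

model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=100, callbacks=[best_ckpt])
```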
Motivation
It would be convenient to have a best checkpoint saver function for cases where I want the best checkpoint for a particular metric or loss. For example, for a denoising task, I want the best checkpoint for the PSNR metric and another best checkpoint for the SSIM metric.
[Optional] Implementation
Extend the existing CheckpointSaver class in Composer.
Additional context