
Does Composer support a best-checkpoint saver that can monitor checkpoints for the best metrics or losses? #2303

Open
mayujie opened this issue Jun 14, 2023 · 20 comments
Labels: enhancement (New (engineering) enhancements, such as features or API changes.)

@mayujie commented Jun 14, 2023

🚀 Feature Request

From the Checkpointing section of the documentation, I couldn't find a supported best-checkpoint saver.
A best-checkpoint saver is a function that monitors user-specified metrics or losses and saves the checkpoint that achieves the best value of that metric or loss during training.
So I'm wondering whether such functionality is already supported.

In Keras, you can use callbacks to achieve this:
https://stackoverflow.com/questions/48285129/saving-best-model-in-keras
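
For illustration, here is a rough sketch of the behavior I mean, using Keras's ModelCheckpoint with save_best_only=True as in the linked answer. The model, dataset, and metric names (`model`, `train_ds`, `val_ds`, `val_psnr`, `val_ssim`) are just placeholders:

```python
import tensorflow as tf

# Placeholders: `model`, `train_ds`, and `val_ds` stand in for the user's own
# model and datasets; `val_psnr` / `val_ssim` assume PSNR and SSIM metrics
# are tracked during validation.
best_psnr = tf.keras.callbacks.ModelCheckpoint(
    filepath="best_psnr.h5",
    monitor="val_psnr",
    mode="max",
    save_best_only=True,  # keep only the checkpoint with the best PSNR so far
)
best_ssim = tf.keras.callbacks.ModelCheckpoint(
    filepath="best_ssim.h5",
    monitor="val_ssim",
    mode="max",
    save_best_only=True,
)

model.fit(train_ds, validation_data=val_ds, epochs=100,
          callbacks=[best_psnr, best_ssim])
```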

Motivation

It would be convenient to have a best-checkpoint saver for cases where I want the best checkpoint for a particular metric or loss. For example, for a denoising task, I want the best checkpoint for the PSNR metric and another best checkpoint for the SSIM metric.

[Optional] Implementation

This could be implemented in the CheckpointSaver class in Composer.

Additional context

@mayujie added the enhancement label on Jun 14, 2023
@mvpatel2000 (Contributor)

This is a great suggestion!

We currently have an EarlyStopper callback (see docs). Unfortunately, at this time we don't have a way to save the best checkpoint, but we will add it to our roadmap.

@ez2rok commented Jul 26, 2023

I mentioned this same issue on the Mosaic Slack here just last week :).

@yasinalm

Are there any updates on this? This is a must-have feature.

@mvpatel2000 (Contributor)

@yasinalm Unfortunately, we haven't added anything to this effect yet. As a workaround, I'd recommend extending the EarlyStopper callback to call save_checkpoint on the checkpoint saver callback or the Trainer.
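
Very roughly, something like the following could work. This is an untested sketch, not an official Composer API: it assumes the monitored metric is reachable via `state.eval_metrics[label][name]`, and it reuses `composer.utils.checkpoint.save_checkpoint` directly (the callback name and hyperparameters are made up for the example):

```python
from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import checkpoint


class BestMetricCheckpointer(Callback):
    """Sketch: save a checkpoint whenever the monitored eval metric improves."""

    def __init__(self, monitor: str, dataloader_label: str = 'eval', mode: str = 'max'):
        self.monitor = monitor
        self.dataloader_label = dataloader_label
        self.mode = mode
        self.best = None

    def eval_end(self, state: State, logger: Logger) -> None:
        # Assumption: the metric object lives under state.eval_metrics[label][name].
        metric = state.eval_metrics[self.dataloader_label][self.monitor]
        value = float(metric.compute())
        improved = (
            self.best is None
            or (self.mode == 'max' and value > self.best)
            or (self.mode == 'min' and value < self.best)
        )
        if improved:
            self.best = value
            # Only rank 0 actually writes the file; other ranks get None back.
            checkpoint.save_checkpoint(state, f'best-{self.monitor}.pt', weights_only=True)
```

You'd then pass it via `callbacks=[BestMetricCheckpointer('MulticlassAccuracy')]` and load the saved file before calling eval.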

@priba (Contributor) commented Apr 24, 2024

Hi @mvpatel2000, we have a custom best-checkpoint saver, but the issue is how to load the best checkpoint before evaluating, i.e., after fit but before evaluation.

Using the tooling provided by Composer, this does not seem to work with DDP. We are currently using checkpoint.load_checkpoint, and we have tried loading it on a single rank as well as broadcasting the checkpoint to several ranks, without luck. Wrapping the model with DistributedDataParallel changes the key mapping.

We have not checked with DeepSpeed or FSDP.

@mvpatel2000 (Contributor)

> Hi @mvpatel2000, we have a custom best-checkpoint saver, but the issue is how to load the best checkpoint before evaluating, i.e., after fit but before evaluation.
>
> Using the tooling provided by Composer, this does not seem to work with DDP. We are currently using checkpoint.load_checkpoint, and we have tried loading it on a single rank as well as broadcasting the checkpoint to several ranks, without luck. Wrapping the model with DistributedDataParallel changes the key mapping.
>
> We have not checked with DeepSpeed or FSDP.

@priba would you mind sharing a code example, please? We're happy to take a look (and ideally we can add a unit test to ensure this works).

@priba (Contributor) commented Apr 25, 2024

Hi @mvpatel2000, thanks for your answer. Here are some snippets of the code we are using.

import logging
import os
from pathlib import Path
from typing import Any

from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import checkpoint, format_name_with_dist_and_time

log = logging.getLogger(__name__)

# DEFAULT_CHECKPOINT_EXTENSION, self.folder, self.save_last, self._weights_only,
# self.saved_checkpoints, and self.best_checkpoint are defined elsewhere in our code.


class BestCheckpointSaver(Callback):

    ...

    def fit_end(self, state: State, logger: Logger) -> None:
        if not self.save_last:
            return

        self.save_checkpoint("last" + DEFAULT_CHECKPOINT_EXTENSION, state, logger, None)

    def save_checkpoint(self, filename: str, state: State, logger: Logger, metric: Any) -> Path | None:
        destination = format_name_with_dist_and_time(
            os.fspath(Path(self.folder) / filename), state.run_name, state.timestamp
        )
        saved_path = checkpoint.save_checkpoint(state, destination, weights_only=self._weights_only)

        if saved_path is None:  # not all ranks save
            return None

        self.saved_checkpoints.append((Path(saved_path), state.timestamp.copy(), metric))
        log.debug(f"Uploading checkpoint to {destination} ...")
        logger.upload_file(remote_file_name=destination, file_path=saved_path)
        return Path(saved_path)

Then just before the eval call, we manually use a new entry point into the trainer:

class MyTrainer(Trainer):
    ...

    def load_best_checkpoint(self) -> None:
        """Loads the best checkpoint as tracked by BestCheckpointSaver.

        Example:
            ```
            trainer = MyTrainer(
                model=model,
                max_duration="100ep",
                callbacks=[BestCheckpointSaver("cross-entropy", mode="min")],
                ...
            )
            trainer.fit(...)
            trainer.load_best_checkpoint()
            trainer.eval(...)
            ```
        """

        if is_model_ddp(self.state.model) or is_model_deepspeed(self.state.model) or is_model_fsdp(self.state.model):
            log.warning(
                "`load_best_checkpoint` is not implemented for DDP, DeepSpeed, or FSDP. "
                "The last weights will be used. Run the test independently to load the best model."
            )
            return

        checkpoint_savers = [callback for callback in self.state.callbacks if isinstance(callback, BestCheckpointSaver)]

        if len(checkpoint_savers) == 0:
            raise ValueError("A `BestCheckpointSaver` callback must be provided to call this method.")

        if len(checkpoint_savers) > 1:
            log.warning(f"Several `BestCheckpointSaver` were provided. Evaluating with {checkpoint_savers[0]!r}.")

        ckpt = checkpoint_savers[0].best_checkpoint
        if ckpt is None:
            raise ValueError("No best checkpoint was found. Have you ran `.fit()` with an `eval_dataloader`?")

        checkpoint.load_checkpoint(os.fspath(ckpt[0]), self.state, self.logger)

We have tried different combinations, such as loading on rank 0, broadcasting the checkpoint path, and even broadcasting the model weights, but we didn't manage to make it work with DDP.

@mvpatel2000 (Contributor)

@priba looking into this... it would also be helpful if you could elaborate on what "it does not seem to work in DDP" means. Could you provide an error trace or some kind of description? This design seems more or less correct to me...

@priba (Contributor) commented Apr 26, 2024

Hi @mvpatel2000, the issues we have are:

  • As it is, ranks other than 0 will not have a checkpoint path, so the ValueError on `ckpt is None` is raised.
  • If we load the checkpoint only on rank 0, it just runs into a deadlock.
  • In our case, we broadcast the checkpoint path string to all ranks so we can load it everywhere; however, the keys of the state dict do not match because DDP adds a `module.` prefix. I've seen that this should be stripped inside load_checkpoint, but I am not sure why it is not happening.

@mvpatel2000 (Contributor) commented Apr 26, 2024

@priba

> • As it is, ranks other than 0 will not have a checkpoint path, so the ValueError on `ckpt is None` is raised.
> • If we load the checkpoint only on rank 0, it just runs into a deadlock.
> • In our case, we broadcast the checkpoint path string to all ranks so we can load it everywhere,

Only rank 0 saves the checkpoint, to avoid duplicate work, but with DDP all ranks need to load the checkpoint (if you use FSDP, we support a path for broadcasting). If you share a filesystem, broadcasting the name should be sufficient, as you described.
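
Roughly, the pattern I have in mind is the following sketch, assuming a shared filesystem; `best_path`, `state`, and `logger` are placeholders for your objects, and you could equally use torch.distributed.broadcast_object_list directly:

```python
from composer.utils import checkpoint, dist

# Rank 0 is the only rank that saved, so it is the only rank that knows the
# path. Broadcast the path string so every rank loads the same file.
path_container = [best_path if dist.get_global_rank() == 0 else None]
dist.broadcast_object_list(path_container, src=0)
best_path = path_container[0]

# All ranks must call load_checkpoint together; loading on rank 0 only can
# deadlock (as you observed), since parts of the load are collective.
checkpoint.load_checkpoint(best_path, state, logger)
```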

> • however, the keys of the state dict do not match because DDP adds a `module.` prefix. I've seen that this should be stripped inside load_checkpoint, but I am not sure why it is not happening.

This is an ugly PyTorch detail -- DDP adds an extra `module.` prefix for reasons I am not familiar with. Composer should automatically strip it on state_dict calls to state, but maybe this code path fails. Could you share the exact error?

Torch 2.3 exposes a new API, get_model_state_dict, which handles this internally in PyTorch. So when we release next week with 2.3 support and the new APIs, it might just go away for you.
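
For reference, the two ways around the prefix issue look roughly like this sketch (`ddp_model` is a placeholder for the DDP-wrapped model):

```python
from torch.distributed.checkpoint.state_dict import get_model_state_dict  # torch >= 2.3

# Option 1: manually strip the `module.` prefix DDP adds so the keys match
# the unwrapped model.
raw_sd = ddp_model.state_dict()  # keys look like 'module.bn1.bias', ...
clean_sd = {k.removeprefix('module.'): v for k, v in raw_sd.items()}

# Option 2: let the new torch API handle wrapper prefixes for you.
clean_sd = get_model_state_dict(ddp_model)
```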

@priba (Contributor) commented Apr 28, 2024

Thanks for the info @mvpatel2000

If we broadcast the string, it does not throw any error message. It just skips loading the model because the keys are not there and throws a warning with the set of ignored keys, which I don't think is very informative in this case.

I will check once support for 2.3 is released and report back to you. Thank you so much

@mvpatel2000 (Contributor)

> If we broadcast the string, it does not throw any error message. It just skips loading the model because the keys are not there and throws a warning with the set of ignored keys, which I don't think is very informative in this case.

Ah, I think this is a bug. If you load from the Trainer, it will raise an error by default if the keys are not there (load_strict_model_weights defaults to True). However, if you call load_checkpoint directly, it defaults to False, whereas it should be consistent. I've opened a PR to fix this: #3219
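
For context, the two code paths look roughly like this sketch (the checkpoint path, model, and dataloader are placeholders):

```python
from composer import Trainer
from composer.utils import checkpoint

# Loading through the Trainer: mismatched keys raise an error by default,
# because load_strict_model_weights defaults to True.
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='1ep',
    load_path='./checkpoints/ep1.pt',
)

# Calling load_checkpoint directly: before #3219, key mismatches were only
# warned about and the mismatched weights were silently skipped.
checkpoint.load_checkpoint('./checkpoints/ep1.pt', trainer.state, trainer.logger)
```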

> I will check once support for 2.3 is released and report back to you. Thank you so much

Great! Please let me know if you encounter any errors after the release :)

@priba (Contributor) commented May 3, 2024

Hi @mvpatel2000, I am afraid the problem persists. Here is a minimal working example:

from composer import Trainer
from composer.algorithms import ChannelsLast, CutMix, LabelSmoothing, BlurPool
from composer.core import DataSpec
from composer.models import ComposerClassifier
from composer.utils import dist, checkpoint
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms


# Define Model
num_classes: int = 10
resnet = torchvision.models.resnet18()
resnet.fc = nn.Linear(512, num_classes)
model = ComposerClassifier(module=resnet, num_classes=num_classes)


# Normalization constants
mean = (0.507, 0.487, 0.441)
std = (0.267, 0.256, 0.276)
batch_size = 1024
cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

# Download Data
data_directory = "./data"
train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)
eval_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)

# Build DataSpec
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, sampler=dist.get_sampler(train_dataset, drop_last=True, shuffle=True)
)
train_spec = DataSpec(train_dataloader, device_transforms=None, get_num_samples_in_batch=lambda batch: len(batch[0]))

eval_dataloader = torch.utils.data.DataLoader(
    eval_dataset, batch_size=batch_size, sampler=dist.get_sampler(eval_dataset, drop_last=False, shuffle=False)
)
eval_spec = DataSpec(eval_dataloader, device_transforms=None, get_num_samples_in_batch=lambda batch: len(batch[0]))

trainer = Trainer(
    model=model,
    train_dataloader=train_spec,
    eval_dataloader=eval_spec,
    max_duration="2ep",
    algorithms=[BlurPool(), LabelSmoothing(smoothing=0.1), CutMix(alpha=1.0), ChannelsLast()],
    save_folder="./checkpoints",
    save_filename="ep{epoch}.pt",
    save_latest_filename="latest",
)
trainer.fit()
checkpoint.load_checkpoint("./checkpoints/ep1.pt", trainer.state, trainer.logger)
trainer.eval()

This runs into key problems because DDP wraps the model and prefixes the keys with `module.`: Found these unexpected keys in the checkpoint: module.bn1.bias, ....

If I force the checkpoint loading to happen only on rank 0, it hangs whether or not I use dist.barrier():

if dist.get_global_rank() == 0:
    checkpoint.load_checkpoint("./checkpoints/ep1.pt", trainer.state, trainer.logger)
else:
    dist.barrier()

Do you think there is any workaround?

@mvpatel2000 (Contributor)

@priba Hm... your snippet seems to work for me:

train          Epoch   0:  100%|_________________________| 25/25 [00:04<00:00,  6.76ba/s, loss/train/total=2.1885]
/mnt/workdisk/mihirp/composer/composer/core/data_spec.py:37: UserWarning: Cannot split tensor of length 904 into batches of size 1024. As it is smaller, no splitting will be done. This may happen on the last batch of a dataset if it is a smaller size than the microbatch size.
  warnings.warn(

eval           Epoch   0:  100%|_________________________| 5/5 [00:00<00:00,  6.62ba/s, metrics/eval/CrossEntropy=2.0713, metrics/eval/MulticlassAccuracy=0.2756]
train          Epoch   1:  100%|_________________________| 25/25 [00:03<00:00,  6.36ba/s, loss/train/total=1.9583]
eval           Epoch   1:  100%|_________________________| 5/5 [00:00<00:00,  6.92ba/s, metrics/eval/CrossEntropy=1.7521, metrics/eval/MulticlassAccuracy=0.3656]
eval           Epoch   0:  100%|_________________________| 5/5 [00:00<00:00,  7.07ba/s, metrics/eval/CrossEntropy=2.0718, metrics/eval/MulticlassAccuracy=0.2750]

Can you double check you are on Composer v0.22 and torch 2.3?

@priba (Contributor) commented May 4, 2024

@mvpatel2000 my bad :(. I only updated Composer, not torch. After updating torch as well, the snippet worked smoothly. Thanks for the support.

Next week I will run tests on a real use case.

@mvpatel2000 (Contributor)

Great to hear! Please let me know if you encounter any further issues.

@priba (Contributor) commented May 6, 2024

@mvpatel2000 out of curiosity, may I ask which GPUs you used to run this test? It's much faster than what I got. Are these H100s?

@mvpatel2000 (Contributor)

> @mvpatel2000 out of curiosity, may I ask which GPUs you used to run this test? It's much faster than what I got. Are these H100s?

I was debugging this on 2xA100.

@antoinebrl (Contributor)

Hi @mvpatel2000! Thanks for providing feedback and insights into the latest features in PyTorch.

We are also using 2xA100 with the same code snippet as shared above, but our throughput is quite a bit lower, around ~4.1 ba/s, whereas you reach >6.6 ba/s. We are using A100s with 40GB of VRAM, CUDA 12.1, and the latest version of PyTorch. Are you using A100s with that amount of VRAM or GPUs with 80GB? Would you mind sharing the output of composer_collect_env so I can compare?

@mvpatel2000 (Contributor)

> We are also using 2xA100 with the same code snippet as shared above, but our throughput is quite a bit lower, around ~4.1 ba/s, whereas you reach >6.6 ba/s.

---------------------------------
System Environment Report
Created: 2024-05-14 17:23:20 UTC
---------------------------------

PyTorch information
-------------------

CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB

Nvidia driver version: 535.129.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      48 bits physical, 48 bits virtual
CPU(s):                             64
On-line CPU(s) list:                0-63
Thread(s) per core:                 1
Core(s) per socket:                 32
Socket(s):                          2
NUMA node(s):                       8
Vendor ID:                          AuthenticAMD
CPU family:                         25
Model:                              1
Model name:                         AMD EPYC 7513 32-Core Processor
Stepping:                           1
Frequency boost:                    enabled
CPU MHz:                            2600.000
CPU max MHz:                        3681.6399
CPU min MHz:                        1500.0000
BogoMIPS:                           5199.55
Virtualization:                     AMD-V
L1d cache:                          2 MiB
L1i cache:                          2 MiB
L2 cache:                           32 MiB
L3 cache:                           256 MiB
NUMA node0 CPU(s):                  0-7
NUMA node1 CPU(s):                  8-15
NUMA node2 CPU(s):                  16-23
NUMA node3 CPU(s):                  24-31
NUMA node4 CPU(s):                  32-39
NUMA node5 CPU(s):                  40-47
NUMA node6 CPU(s):                  48-55
NUMA node7 CPU(s):                  56-63
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] pytorch-ranger==0.1.1
[pip3] torch==2.3.0+cu121
[pip3] torch-optimizer==0.3.0
[pip3] torchmetrics==1.3.2
[pip3] torchvision==0.18.0+cu121
[pip3] triton==2.3.0
[conda] Could not collect


Composer information
--------------------
Composer version: 0.22.0
Composer commit hash: None
Host processor model name: AMD EPYC 7513 32-Core Processor
Host processor core count: 64
Number of nodes: 0
Accelerator model name: NVIDIA A100-SXM4-40GB
Accelerators per node: 2
CUDA Device Count: 2

Could be CPU bottlenecked since it's CIFAR 🤷
