From 408179ee3836150cb42d8f947863122554740d2d Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Sun, 4 Feb 2024 13:19:28 +0100 Subject: [PATCH 01/29] adding _get_dist_sampler function --- composer/trainer/trainer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 46843efa50..98cb0dfcb9 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -465,6 +465,15 @@ def _generate_run_name() -> str: return generated_run_name +def _get_dist_sampler(dataloader: Any) -> DistributedSampler | None: + if not isinstance(dataloader, DataLoader): + return + if dataloader.sampler is not None: + return dataloader.sampler + else: + return dataloader.batch_sampler + + class Trainer: """Train models with Composer algorithms. From 514a8b11ea7f010b61a0c8570cf600e88dbedaf6 Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Sun, 4 Feb 2024 13:26:16 +0100 Subject: [PATCH 02/29] apply _get_dist_sampler across all trainer.py --- composer/trainer/trainer.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 98cb0dfcb9..225aca1ba0 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2272,8 +2272,9 @@ def _spin_dataloaders_to_cur_epoch(self): eval_state = self.state.dataset_resumption.get('eval', {}) for evaluator in self.state.evaluators: dataloader = evaluator.dataloader.dataloader - if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler): - dataloader.sampler.set_epoch(0) + sampler = _get_dist_sampler(dataloader) + if isinstance(sampler, DistributedSampler): + sampler.set_epoch(0) if evaluator.label not in eval_state: for _ in dataloader: break @@ -2283,8 +2284,9 @@ def _spin_dataloaders_to_cur_epoch(self): assert dataloader is not None, 'train dataloader is set on state after FIT_START' if 'train' not in self.state.dataset_resumption: for epoch in range(int(self.state.timestamp.epoch)): - if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler): - dataloader.sampler.set_epoch(epoch) + sampler = _get_dist_sampler(dataloader) + if isinstance(sampler, DistributedSampler): + sampler.set_epoch(epoch) for _ in dataloader: break @@ -2366,9 +2368,10 @@ def _train_loop(self) -> None: self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value}) dataloader = self.state.dataloader - if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler): - dataloader.sampler.set_epoch(int(self.state.timestamp.epoch)) - + sampler = _get_dist_sampler(dataloader) + if isinstance(sampler, DistributedSampler): + sampler.set_epoch(int(self.state.timestamp.epoch)) + for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)): # Spin dataloader forward unless dataloader handles internally with dataset_resumption if self.spin_dataloaders and 'train' not in self.state.dataset_resumption and batch_idx < int( @@ -3221,19 +3224,19 @@ def _eval_loop( drop_last = None dataset_len = None last_batch = False - if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler): + sampler = _get_dist_sampler(dataloader) + if isinstance(sampler, DistributedSampler): # The distributed sampler uses `set_epoch` to set the random seed # Because evaluation can run on each batch, we use the batch to seed the sampler # so each evaluation will get a proper shuffle. 
# The epoch provided to `set_epoch` need not be sequential, so this is fine. - dist_sampler = dataloader.sampler - dist_sampler.set_epoch(int(self.state.timestamp.batch)) + sampler.set_epoch(int(self.state.timestamp.batch)) drop_last = dataloader.drop_last # Only compute the dataset length if drop_last is False, as otherwise we don't need # to remove any duplicate samples. if drop_last == False: try: - dataset_len = len(dist_sampler.dataset) # type: ignore + dataset_len = len(sampler.dataset) # type: ignore except AttributeError: warnings.warn( "DistributedSampler's dataset does not have length defined. When " From ef5540d9201cd5d447e94f635ca06191a16ef888 Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Sun, 17 Mar 2024 16:36:36 +0100 Subject: [PATCH 03/29] Switch condition to check if batch_sampler if filled first since sampler is always defined --- composer/trainer/trainer.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 225aca1ba0..6214cee54e 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -48,7 +48,7 @@ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from torch.nn.parallel import DistributedDataParallel from torch.optim.lr_scheduler import LRScheduler -from torch.utils.data import DataLoader, DistributedSampler +from torch.utils.data import DataLoader, DistributedSampler, Sampler from torchmetrics import Metric from composer.callbacks import CheckpointSaver, MemorySnapshot, OOMObserver, OptimizerMonitor @@ -465,13 +465,12 @@ def _generate_run_name() -> str: return generated_run_name -def _get_dist_sampler(dataloader: Any) -> DistributedSampler | None: +def _get_sampler(dataloader: Any) -> Sampler | None: if not isinstance(dataloader, DataLoader): return - if dataloader.sampler is not None: - return dataloader.sampler - else: + if dataloader.batch_sampler is not None: return dataloader.batch_sampler + return dataloader.sampler class Trainer: @@ -2272,7 +2271,7 @@ def _spin_dataloaders_to_cur_epoch(self): eval_state = self.state.dataset_resumption.get('eval', {}) for evaluator in self.state.evaluators: dataloader = evaluator.dataloader.dataloader - sampler = _get_dist_sampler(dataloader) + sampler = _get_sampler(dataloader) if isinstance(sampler, DistributedSampler): sampler.set_epoch(0) if evaluator.label not in eval_state: @@ -2284,7 +2283,7 @@ def _spin_dataloaders_to_cur_epoch(self): assert dataloader is not None, 'train dataloader is set on state after FIT_START' if 'train' not in self.state.dataset_resumption: for epoch in range(int(self.state.timestamp.epoch)): - sampler = _get_dist_sampler(dataloader) + sampler = _get_sampler(dataloader) if isinstance(sampler, DistributedSampler): sampler.set_epoch(epoch) for _ in dataloader: @@ -2368,10 +2367,10 @@ def _train_loop(self) -> None: self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value}) dataloader = self.state.dataloader - sampler = _get_dist_sampler(dataloader) + sampler = _get_sampler(dataloader) if isinstance(sampler, DistributedSampler): sampler.set_epoch(int(self.state.timestamp.epoch)) - + for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)): # Spin dataloader forward unless dataloader handles internally with dataset_resumption if self.spin_dataloaders and 'train' not in self.state.dataset_resumption and batch_idx < int( @@ -3224,7 +3223,7 @@ def _eval_loop( drop_last = None dataset_len = None last_batch = False 
- sampler = _get_dist_sampler(dataloader) + sampler = _get_sampler(dataloader) if isinstance(sampler, DistributedSampler): # The distributed sampler uses `set_epoch` to set the random seed # Because evaluation can run on each batch, we use the batch to seed the sampler From 413c4c3087014af0577424662f6a4cc39866dd7f Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Sun, 17 Mar 2024 16:39:57 +0100 Subject: [PATCH 04/29] adding docstring --- composer/trainer/trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 6214cee54e..a3c9c615e3 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -466,6 +466,10 @@ def _generate_run_name() -> str: def _get_sampler(dataloader: Any) -> Sampler | None: + """Checks if dataloader is `torch.utils.data.DataLoader` and return the batch_sampler if defined, + else the regular sampler. + If `dataloader` is not a `torch.utils.data.DataLoader`, returns None. + """ if not isinstance(dataloader, DataLoader): return if dataloader.batch_sampler is not None: From ed8535ffd473594ed5f478fd2ef0f94c1c6b2041 Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Sun, 17 Mar 2024 17:22:47 +0100 Subject: [PATCH 05/29] removing antipattenr from _get_sampler and linting --- composer/trainer/trainer.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index a3c9c615e3..e02b20a5ea 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -465,13 +465,11 @@ def _generate_run_name() -> str: return generated_run_name -def _get_sampler(dataloader: Any) -> Sampler | None: - """Checks if dataloader is `torch.utils.data.DataLoader` and return the batch_sampler if defined, - else the regular sampler. - If `dataloader` is not a `torch.utils.data.DataLoader`, returns None. +def _get_sampler(dataloader: DataLoader) -> Sampler | Iterable: + """Fetch the sampler from a `dataloader`. + + Returns `dalaoder.batch_sampler` is defined, else `dataloader.sampler` (always defined in `Dataloader.__init__`). 
""" - if not isinstance(dataloader, DataLoader): - return if dataloader.batch_sampler is not None: return dataloader.batch_sampler return dataloader.sampler @@ -2275,7 +2273,7 @@ def _spin_dataloaders_to_cur_epoch(self): eval_state = self.state.dataset_resumption.get('eval', {}) for evaluator in self.state.evaluators: dataloader = evaluator.dataloader.dataloader - sampler = _get_sampler(dataloader) + sampler = _get_sampler(dataloader) if isinstance(dataloader, DataLoader) else None if isinstance(sampler, DistributedSampler): sampler.set_epoch(0) if evaluator.label not in eval_state: @@ -2287,7 +2285,7 @@ def _spin_dataloaders_to_cur_epoch(self): assert dataloader is not None, 'train dataloader is set on state after FIT_START' if 'train' not in self.state.dataset_resumption: for epoch in range(int(self.state.timestamp.epoch)): - sampler = _get_sampler(dataloader) + sampler = _get_sampler(dataloader) if isinstance(dataloader, DataLoader) else None if isinstance(sampler, DistributedSampler): sampler.set_epoch(epoch) for _ in dataloader: @@ -2371,7 +2369,7 @@ def _train_loop(self) -> None: self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value}) dataloader = self.state.dataloader - sampler = _get_sampler(dataloader) + sampler = _get_sampler(dataloader) if isinstance(dataloader, DataLoader) else None if isinstance(sampler, DistributedSampler): sampler.set_epoch(int(self.state.timestamp.epoch)) @@ -3227,8 +3225,8 @@ def _eval_loop( drop_last = None dataset_len = None last_batch = False - sampler = _get_sampler(dataloader) - if isinstance(sampler, DistributedSampler): + sampler = _get_sampler(dataloader) if isinstance(dataloader, DataLoader) else None + if isinstance(sampler, DistributedSampler) and isinstance(dataloader, DataLoader): # The distributed sampler uses `set_epoch` to set the random seed # Because evaluation can run on each batch, we use the batch to seed the sampler # so each evaluation will get a proper shuffle. 
From ca7873f9728638839ee5f5bed3af99847e58f0be Mon Sep 17 00:00:00 2001 From: Jane Zhang Date: Mon, 18 Mar 2024 13:13:35 -0700 Subject: [PATCH 06/29] removed default attrs from exception in the attrs dict (#3126) --- composer/loggers/mosaicml_logger.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/composer/loggers/mosaicml_logger.py b/composer/loggers/mosaicml_logger.py index 296a7db862..a00f6f43c0 100644 --- a/composer/loggers/mosaicml_logger.py +++ b/composer/loggers/mosaicml_logger.py @@ -271,12 +271,16 @@ def dict_to_str(data: Dict[str, Any]): def exception_to_json_serializable_dict(exc: Exception): """Converts exception into a JSON serializable dictionary for run metadata.""" + default_exc_attrs = set(dir(Exception())) exc_data = {'class': exc.__class__.__name__, 'message': str(exc), 'attributes': {}} + for attr in dir(exc): - if not attr.startswith('__') and attr not in ['args', 'with_traceback']: - # ignore the traceback and default args in exception object + # Exclude default attributes and special methods + if attr not in default_exc_attrs and not attr.startswith('__'): try: value = getattr(exc, attr) + if callable(value): + continue if isinstance(value, (str, int, float, bool, list, dict, type(None))): exc_data['attributes'][attr] = value else: From 027156ec1be96baa0c947b11ec2bb37ab055436f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 18 Mar 2024 17:40:10 -0400 Subject: [PATCH 07/29] Bump coverage[toml] from 7.4.3 to 7.4.4 (#3121) Bumps [coverage[toml]](https://github.com/nedbat/coveragepy) from 7.4.3 to 7.4.4. - [Release notes](https://github.com/nedbat/coveragepy/releases) - [Changelog](https://github.com/nedbat/coveragepy/blob/master/CHANGES.rst) - [Commits](https://github.com/nedbat/coveragepy/compare/7.4.3...7.4.4) --- updated-dependencies: - dependency-name: coverage[toml] dependency-type: direct:development update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3a04fd3a20..2be590eadc 100644 --- a/setup.py +++ b/setup.py @@ -102,7 +102,7 @@ def package_files(prefix: str, directory: str, extension: str): # Should manually update dependency versions occassionally. 
'custom_inherit==2.4.1', 'junitparser==3.1.2', - 'coverage[toml]==7.4.3', + 'coverage[toml]==7.4.4', 'fasteners==0.18', # object store tests require fasteners 'pytest==7.4.4', 'ipython==8.11.0', From 9f768a4d49381a749de0b5aa135e355143ffdb12 Mon Sep 17 00:00:00 2001 From: Practicinginhell <147235329+Practicinginhell@users.noreply.github.com> Date: Tue, 19 Mar 2024 04:42:14 +0700 Subject: [PATCH 08/29] Refactor initialization (#3127) Co-authored-by: practicinginhell --- composer/models/initializers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/composer/models/initializers.py b/composer/models/initializers.py index c3764b92a7..2a6b14e341 100644 --- a/composer/models/initializers.py +++ b/composer/models/initializers.py @@ -29,19 +29,19 @@ def get_initializer(self) -> Callable[[torch.nn.Module], None]: """ def kaiming_normal(w: nn.Module): - if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d): + if isinstance(w, (torch.nn.Linear, torch.nn.Conv2d)): torch.nn.init.kaiming_normal_(w.weight) def kaiming_uniform(w: nn.Module): - if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d): + if isinstance(w, (torch.nn.Linear, torch.nn.Conv2d)): torch.nn.init.kaiming_uniform_(w.weight) def xavier_uniform(w: nn.Module): - if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d): + if isinstance(w, (torch.nn.Linear, torch.nn.Conv2d)): torch.nn.init.xavier_uniform_(w.weight) def xavier_normal(w: nn.Module): - if isinstance(w, torch.nn.Linear) or isinstance(w, torch.nn.Conv2d): + if isinstance(w, (torch.nn.Linear, torch.nn.Conv2d)): torch.nn.init.xavier_normal_(w.weight) def bn_ones(w: nn.Module): From 06b8afd1ab3239b3adc38b392239800478361f0f Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 18 Mar 2024 15:55:05 -0700 Subject: [PATCH 09/29] Bump databricks sdk version (#3128) --- composer/utils/object_store/uc_object_store.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/composer/utils/object_store/uc_object_store.py b/composer/utils/object_store/uc_object_store.py index 4b6b5dfe51..dce8b6e5ca 100644 --- a/composer/utils/object_store/uc_object_store.py +++ b/composer/utils/object_store/uc_object_store.py @@ -24,7 +24,7 @@ def _wrap_errors(uri: str, e: Exception): from databricks.sdk.core import DatabricksError - from databricks.sdk.errors.mapping import NotFound + from databricks.sdk.errors.platform import NotFound if isinstance(e, DatabricksError): if isinstance(e, NotFound) or e.error_code == _NOT_FOUND_ERROR_CODE: # type: ignore raise FileNotFoundError(f'Object {uri} not found') from e diff --git a/setup.py b/setup.py index 2be590eadc..b770eac736 100644 --- a/setup.py +++ b/setup.py @@ -226,7 +226,7 @@ def package_files(prefix: str, directory: str, extension: str): extra_deps['pandas'] = ['pandas>=2.0.0,<3.0'] -extra_deps['databricks'] = ['databricks-sdk==0.18.0'] +extra_deps['databricks'] = ['databricks-sdk==0.22.0'] extra_deps['all'] = {dep for deps in extra_deps.values() for dep in deps} From c9b41f78c6c1c838e48dd0ecd89f04749a458d58 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 18 Mar 2024 19:07:45 -0400 Subject: [PATCH 10/29] Update packaging requirement from <23.3,>=21.3.0 to >=21.3.0,<24.1 (#3122) Updates the requirements on [packaging](https://github.com/pypa/packaging) to permit the latest version. 
- [Release notes](https://github.com/pypa/packaging/releases) - [Changelog](https://github.com/pypa/packaging/blob/main/CHANGELOG.rst) - [Commits](https://github.com/pypa/packaging/compare/21.3...24.0) --- updated-dependencies: - dependency-name: packaging dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Mihir Patel --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b770eac736..08f7f6b6e1 100644 --- a/setup.py +++ b/setup.py @@ -88,7 +88,7 @@ def package_files(prefix: str, directory: str, extension: str): 'coolname>=1.1.0,<3', 'tabulate==0.9.0', # for auto-generating tables 'py-cpuinfo>=8.0.0,<10', - 'packaging>=21.3.0,<23.3', + 'packaging>=21.3.0,<24.1', 'importlib-metadata>=5.0.0,<7', 'mosaicml-cli>=0.5.25,<0.7', ] From 0c42c510b7b5cce6550e2a24eda339260a0c7dea Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Mon, 18 Mar 2024 16:43:14 -0700 Subject: [PATCH 11/29] remove rng from save weights only ckpt (#3129) --- composer/utils/checkpoint.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index e4ae4d683f..3d0f81504c 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -1038,7 +1038,6 @@ def _save_checkpoint( 'integrations': state._get_integrations_state_dict(), 'metadata': state._get_state_metadata(), }, - 'rng': reproducibility.get_rng_state(), } else: state_dict = { @@ -1055,7 +1054,7 @@ def _save_checkpoint( # Ensure state exists state_dict['state'] = state_dict.get('state', {}) - if state.fsdp_sharded_state_dict_enabled: + if state.fsdp_sharded_state_dict_enabled and not weights_only: # Only rank 0 saves RNG if dist.get_global_rank() > 0: state_dict.pop('rng') @@ -1064,7 +1063,7 @@ def _save_checkpoint( # requires a top level state dict key for the optimizer. # See https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/checkpoint/optimizer.py#L271 # for more info. - if version.parse(torch.__version__) < version.parse('2.2.9') and not weights_only: + if version.parse(torch.__version__) < version.parse('2.2.9'): state_dict['optimizers'] = state_dict['state'].pop('optimizers') log.debug('State dict created.') From ccaeec5da7e4fe4e40be81a365b8eb7bac497916 Mon Sep 17 00:00:00 2001 From: Matt Broadway Date: Wed, 20 Mar 2024 00:23:14 +0000 Subject: [PATCH 12/29] More compression options (#3118) * fix documentation about which file extensions perform compression * refactor tar handling when writing checkpoint files * add CLI compressor options to checkpoint saving * update documentation * restructure CLI compressor handling. 
Added more tests * fix skipping test when compressor not installed * add documentation and fix some code style issues * use correct file extension for zstd * add missing full stop * fix failing tests * remove type ignore Co-authored-by: Mihir Patel * combine imports into one * fix capitalisation Co-authored-by: Mihir Patel --------- Co-authored-by: Mihir Patel --- composer/callbacks/checkpoint_saver.py | 27 +++++-- composer/utils/__init__.py | 10 +++ composer/utils/checkpoint.py | 74 +++++++++++------ composer/utils/compression.py | 94 ++++++++++++++++++++++ composer/utils/file_helpers.py | 3 +- docker/Dockerfile | 8 ++ tests/trainer/test_checkpoint.py | 107 ++++++++++++++++++++++--- tests/utils/test_compression.py | 57 +++++++++++++ tests/utils/test_file_helpers.py | 8 ++ 9 files changed, 343 insertions(+), 45 deletions(-) create mode 100644 composer/utils/compression.py create mode 100644 tests/utils/test_compression.py diff --git a/composer/callbacks/checkpoint_saver.py b/composer/callbacks/checkpoint_saver.py index f1ed674ded..482b3862a1 100644 --- a/composer/callbacks/checkpoint_saver.py +++ b/composer/callbacks/checkpoint_saver.py @@ -30,6 +30,7 @@ is_model_deepspeed, partial_format, ) +from composer.utils.compression import get_compressor, is_compressed_pt from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY log = logging.getLogger(__name__) @@ -101,14 +102,25 @@ class CheckpointSaver(Callback): # noqa: D101 may attempt to write to the same file(s), leading to corrupted checkpoints. If no tarball file extension is specified, ``'.tar'`` will be used. - * To use compression (regardless of whether DeepSpeed is enabled), set the file extension - to ``'.tar.gz'``, ``'.tgz'``, ``'.tar.bzip'``, or ``'.tar.lzma'`` (depending on the desired - compression algorithm). + * To write to compressed tar files (regardless of whether DeepSpeed is enabled), set the file + extension to ``'.tar.gz'``, ``'.tgz'``, ``'.tar.bz2'``, or ``'.tar.lzma'`` (depending on the + desired compression algorithm). + + * To write to compressed pt files (when DeepSpeed is disabled), set the file extension to + ``'.pt.bz2'``, ``'.pt.gz'``, ``'.pt.lz4'``, ``'.pt.lzma'``, ``'.pt.lzo'``, ``'.pt.xz'``, + ``'.pt.zst'`` + (depending on the desired algorithm). You must have the corresponding CLI tool installed. + ``lz4`` is a good choice for a modest space saving while being very fast to compress. .. warning:: - Using compression will block the training loop while checkpoints are being compressed. As such, we - recommend saving checkpoints without compression. + Using compression will block the training loop while checkpoints are being compressed and the + compressibility of checkpoints can vary significantly depending on your setup. As such, we + recommend saving checkpoints without compression by default. + + If you have the ``lz4`` command available on your system, you may want to try saving as ``.pt.lz4`` + as the overhead is minimal (usually less than a second) and the saved space can sometimes + be significant (1% - 40%). 
Consider the following scenario where: @@ -283,6 +295,11 @@ def __init__( latest_filename = str(latest_filename) if latest_filename is not None else None latest_remote_file_name = str(latest_remote_file_name) if latest_remote_file_name is not None else None + # want to fail early if a required CLI tool is missing to ensure no training time is wasted + for name in [filename, remote_file_name, latest_filename, latest_remote_file_name]: + if name is not None and is_compressed_pt(name): + get_compressor(name).check_exists() + if not callable(save_interval): save_interval = create_interval_scheduler(save_interval) self.save_interval = save_interval diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py index b0151cdaaf..eb047736a4 100644 --- a/composer/utils/__init__.py +++ b/composer/utils/__init__.py @@ -23,6 +23,12 @@ get_composer_env_dict, print_env, ) +from composer.utils.compression import ( + KNOWN_COMPRESSORS, + CliCompressor, + get_compressor, + is_compressed_pt, +) from composer.utils.device import get_device, is_hpu_installed, is_xla_installed from composer.utils.eval_client import EvalClient, LambdaEvalClient, LocalEvalClient, MosaicMLLambdaEvalClient from composer.utils.file_helpers import ( @@ -131,4 +137,8 @@ 'MosaicMLLambdaEvalClient', 'partial_format', 'VersionedDeprecationWarning', + 'is_compressed_pt', + 'CliCompressor', + 'get_compressor', + 'KNOWN_COMPRESSORS', ] diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 3d0f81504c..11513f5cb5 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -27,6 +27,7 @@ from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner from composer.utils import dist, reproducibility +from composer.utils.compression import get_compressor, is_compressed_pt from composer.utils.file_helpers import ( FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, extract_path_from_symlink, @@ -724,6 +725,14 @@ def download_checkpoint( rank_n_checkpoint_filepath if fsdp_sharded_state_dict_enabled else rank_zero_checkpoint_filepath ) + if is_compressed_pt(path): + original_path = path + path = os.path.splitext(path)[0] + compressor = get_compressor(original_path) + with open(path, 'wb') as out_file: + with compressor.decompress(original_path) as in_file: + shutil.copyfileobj(in_file, out_file) + checkpoint_is_sharded = fsdp_sharded_state_dict_enabled or deepspeed_sharded_checkpoint try: if not checkpoint_is_sharded and dist.get_local_rank() == 0: @@ -1079,10 +1088,7 @@ def _save_checkpoint( expect_file = True log.debug('Saving deepspeed checkpoints to %s...', save_filename) if dist.get_global_rank() == 0: - with open(save_filename, 'wb') as f: - torch.save(state_dict, f) - if is_tar(save_filename): - _compress_file(save_filename, basename=_COMPOSER_STATES_FILENAME) + _write_checkpoint_file(state_dict, save_filename) _save_deepspeed_model(state.deepspeed_model, save_filename) # Save sharded checkpoint @@ -1121,14 +1127,9 @@ def _save_checkpoint( # Save monolith checkpoint elif dist.get_global_rank() == 0: expect_file = True - with open(save_filename, 'wb') as f: - log.debug(f'Saving monolithic checkpoint to {save_filename}') - torch.save(state_dict, f) - + log.debug(f'Saving monolithic checkpoint to {save_filename}') + _write_checkpoint_file(state_dict, save_filename) log.debug(f'Global rank 0 done saving checkpoint to disk at {save_filename}.') - - if is_tar(save_filename): - _compress_file(save_filename, basename=_COMPOSER_STATES_FILENAME) else: log.debug(f'Only rank 0 is saving a checkpoint, so rank 
{dist.get_global_rank()} skips checkpointing.') @@ -1142,18 +1143,29 @@ def _save_checkpoint( return None -def _compress_file(filename: str, basename: str): - """Replace a file with its compressed version. +def _write_checkpoint_file(state_dict: Dict[str, Any], filename: str) -> None: + """Write the given checkpoint state to the given path. Compressing if indicated to do so by the file extension.""" + if is_tar(filename): + log.debug('Writing checkpoint tar file %s', filename) + write_mode = _get_write_mode(filename) - The contents will be called ``basename`` inside - the compressed archive. - """ - write_mode = _get_write_mode(filename) + with tempfile.TemporaryDirectory(prefix='checkpoint') as tmpdir: + with open(os.path.join(tmpdir, _COMPOSER_STATES_FILENAME), 'wb') as f: + torch.save(state_dict, f) - with tempfile.TemporaryDirectory() as tmpdir: - shutil.move(filename, os.path.join(tmpdir, basename)) - with tarfile.open(filename, write_mode) as tarball: - tarball.add(tmpdir, arcname='') + with tarfile.open(filename, write_mode) as tarball: + tarball.add(tmpdir, arcname='') + + elif is_compressed_pt(filename): + log.debug('Writing compressed checkpoint %s', filename) + compressor = get_compressor(filename) + with compressor.compress(filename) as f: + torch.save(state_dict, f) + + else: + log.debug('Writing uncompressed checkpoint %s', filename) + with open(filename, 'wb') as f: + torch.save(state_dict, f) def _save_deepspeed_model(model, filename: str): @@ -1207,14 +1219,24 @@ def save_checkpoint( may attempt to write to the same file(s), leading to corrupted checkpoints. If no tarball file extension is specified, ``.tar`` will be used. - * To use compression (regardless of whether DeepSpeed is enabled), set the file extension - to ``'.tar.gz'``, ``'.tgz'``, ``'.tar.bzip'``, or ``'.tar.lzma'`` (depending on the desired - compression algorithm). + * To write to compressed tar files (regardless of whether DeepSpeed is enabled), set the file + extension to ``'.tar.gz'``, ``'.tgz'``, ``'.tar.bz2'``, or ``'.tar.lzma'`` (depending on the + desired compression algorithm). + + * To write to compressed pt files (when DeepSpeed is disabled), set the file extension to + ``'.pt.bz2'``, ``'.pt.gz'``, ``'.pt.lz4'``, ``'.pt.lzma'``, ``'.pt.lzo'``, ``'.pt.xz'``, ``'.pt.zst'`` + (depending on the desired algorithm). You must have the corresponding CLI tool installed. + ``lz4`` is a good choice for a modest space saving while being very fast to compress. .. warning:: - Using compression will block the training loop while checkpoints are being compressed. As such, we - recommend saving checkpoints without compression. + Using compression will block the training loop while checkpoints are being compressed and the + compressibility of checkpoints can vary significantly depending on your setup. As such, we + recommend saving checkpoints without compression by default. + + If you have the ``lz4`` command available on your system, you may want to try saving as ``.pt.lz4`` + as the overhead is minimal (usually less than a second) and the saved space can sometimes + be significant (1% - 40%). 
Consider the following scenario, where: diff --git a/composer/utils/compression.py b/composer/utils/compression.py new file mode 100644 index 0000000000..f5206887d3 --- /dev/null +++ b/composer/utils/compression.py @@ -0,0 +1,94 @@ +# Copyright 2024 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""Utilities for creating and loading compressed files.""" + +import shutil +import subprocess +from contextlib import contextmanager +from typing import IO, Iterator, List, Optional + +__all__ = ['is_compressed_pt', 'CliCompressor', 'get_compressor', 'KNOWN_COMPRESSORS'] + + +class CompressorNotFound(FileNotFoundError): + pass + + +def is_compressed_pt(filename: str) -> bool: + """Whether the filename is for a directly compressed pt file. + + Whether the extension of the given filename indicates that the file contains a raw compressed stream + of a single pt file without a container (like tar). + """ + parts = filename.split('.') + return len(parts) >= 2 and parts[-2] == 'pt' + + +class CliCompressor: + """Base class for data compression CLI tools.""" + + def __init__(self, extension: str, cmd: Optional[str] = None) -> None: + self.extension = extension + self.cmd = cmd if cmd is not None else extension + + @property + def exists(self) -> bool: + return shutil.which(self.cmd) is not None + + def check_exists(self) -> None: + if not self.exists: + raise CompressorNotFound(f'Could not find command "{self.cmd}" in the PATH.') + + def _compress_cmd(self) -> List[str]: + return [self.cmd] + + @contextmanager + def compress(self, filename: str) -> Iterator[IO[bytes]]: + self.check_exists() + with open(filename, 'wb') as f: + proc = subprocess.Popen( + self._compress_cmd(), + stdin=subprocess.PIPE, + stdout=f, + ) + assert proc.stdin is not None + yield proc.stdin + proc.stdin.close() + proc.wait() + + def _decompress_cmd(self, filename: str) -> List[str]: + return [self.cmd, '-dc', filename] + + @contextmanager + def decompress(self, in_filename: str) -> Iterator[IO[bytes]]: + self.check_exists() + proc = subprocess.Popen( + self._decompress_cmd(in_filename), + stdout=subprocess.PIPE, + ) + assert proc.stdout is not None + yield proc.stdout + proc.wait() + + +def get_compressor(filename: str) -> CliCompressor: + """Obtain the compressor that supports the format of the given file.""" + if not is_compressed_pt(filename): + raise ValueError(f'The given filename does not correspond to a compressed file: "{filename}".') + extension = filename.split('.')[-1] + for c in KNOWN_COMPRESSORS: + if c.extension == extension: + return c + raise CompressorNotFound(f'Could not find compressor for "{filename}".') + + +KNOWN_COMPRESSORS = [ + CliCompressor('bz2', 'bzip2'), + CliCompressor('gz', 'gzip'), + CliCompressor('lz4'), + CliCompressor('lzma'), + CliCompressor('lzo', 'lzop'), + CliCompressor('xz'), + CliCompressor('zst', 'zstd'), +] diff --git a/composer/utils/file_helpers.py b/composer/utils/file_helpers.py index d58db366b0..59acb7b4ef 100644 --- a/composer/utils/file_helpers.py +++ b/composer/utils/file_helpers.py @@ -110,7 +110,8 @@ def is_tar(name: Union[str, pathlib.Path]) -> bool: Returns: bool: Whether ``name`` is a tarball. 
""" - return any(str(name).endswith(x) for x in ('.tar', '.tgz', '.tar.gz', '.tar.bz2', '.tar.lzma')) + parts = str(name).split('.') + return len(parts) > 1 and ('tar' in parts[-2:] or parts[-1] == 'tgz') def ensure_folder_is_empty(folder_name: Union[str, pathlib.Path]): diff --git a/docker/Dockerfile b/docker/Dockerfile index 24261bb10a..504ad78469 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -124,6 +124,14 @@ RUN apt-get update && \ autotools-dev \ automake \ libtool \ + # Compressors + bzip2 \ + gzip \ + lz4 \ + lzma \ + lzop \ + xz-utils \ + zstd \ # Development tools tmux \ htop && \ diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 1e391dbb74..797dc4b112 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -3,8 +3,10 @@ import contextlib import copy +import io import os import pathlib +import re import shutil import tarfile import tempfile @@ -29,7 +31,13 @@ from composer.trainer import trainer from composer.trainer.trainer import Trainer from composer.utils import dist, is_tar, reproducibility -from composer.utils.checkpoint import _ensure_valid_checkpoint, glob_filter +from composer.utils.checkpoint import ( + _COMPOSER_STATES_FILENAME, + _ensure_valid_checkpoint, + _write_checkpoint_file, + glob_filter, +) +from composer.utils.compression import CliCompressor, CompressorNotFound, get_compressor, is_compressed_pt from composer.utils.object_store.object_store import ObjectStore from composer.utils.object_store.s3_object_store import S3ObjectStore from tests.common import ( @@ -60,17 +68,24 @@ def load_state_dict(self, state: Dict[str, Any]) -> None: self.random_value = state['random_value'] -def _load_checkpoint(filename: Union[str, pathlib.Path]): +def _load_checkpoint(filename: Union[str, pathlib.Path]) -> Dict[str, Any]: filename = str(filename).format(rank=0) - if not is_tar(filename): + if is_tar(filename): + with tempfile.TemporaryDirectory() as tmp_dir: + with tarfile.open(filename) as tarball: + tarball.extractall(tmp_dir) + states_path = os.path.join(tmp_dir, _COMPOSER_STATES_FILENAME) + return torch.load(states_path, map_location='cpu') + + elif is_compressed_pt(filename): + compressor = get_compressor(filename) + with compressor.decompress(filename) as f: + data = io.BytesIO(f.read()) # loading requires random access + return torch.load(data, map_location='cpu') + + else: return torch.load(filename, map_location='cpu') - with tempfile.TemporaryDirectory() as tmp_dir: - with tarfile.open(filename) as tarball: - tarball.extractall(tmp_dir) - states_path = os.path.join(tmp_dir, 'composer_states.pt') - return torch.load(states_path, map_location='cpu') - def _assert_checkpoints_equivalent(file1, file2, atol=0.0, rtol=0.0): # TODO: consider merging with _assert_checkpoints_equal @@ -214,6 +229,27 @@ def test_checkpoint_saver_folder_filename_path(folder: Union[str, pathlib.Path], assert checkpoint_saver.filename.filename == str(filename) +def test_checkpoint_invalid_compressor(monkeypatch: pytest.MonkeyPatch): + with pytest.raises( + CompressorNotFound, + match=re.escape('Could not find compressor for "foo.pt.unknown_compressor".'), + ): + CheckpointSaver(filename='foo.pt.unknown_compressor') + + import composer.utils.compression + monkeypatch.setattr( + composer.utils.compression, + 'KNOWN_COMPRESSORS', + [CliCompressor('unknown_compressor', 'unknown_compressor_cmd')], + ) + + with pytest.raises( + CompressorNotFound, + match=re.escape('Could not find command "unknown_compressor_cmd" in the 
PATH'), + ): + CheckpointSaver(filename='foo.pt.unknown_compressor') + + @pytest.mark.parametrize( 'remote_file_name,latest_filename,latest_remote_file_name', [ @@ -305,6 +341,43 @@ def test_other_uris_error_out(self, uri: str): def test_local_paths_work(self, local_path: str): self.get_trainer(save_folder=local_path) + def test_write_checkpoint_pt_file(self, tmp_path: pathlib.Path): + state = {'foo': 123} + checkpoint_path = tmp_path / 'checkpoint.pt' + _write_checkpoint_file(state, str(checkpoint_path)) + assert _load_checkpoint(checkpoint_path) == state + + def test_write_checkpoint_tar_file(self, tmp_path: pathlib.Path): + state = {'foo': 123} + checkpoint_path_1 = tmp_path / 'checkpoint_uncompressed.tar' + _write_checkpoint_file(state, str(checkpoint_path_1)) + assert _load_checkpoint(checkpoint_path_1) == state + + checkpoint_path_2 = tmp_path / 'checkpoint_compressed.tar.gz' + _write_checkpoint_file(state, str(checkpoint_path_2)) + assert _load_checkpoint(checkpoint_path_2) == state + + assert checkpoint_path_1.read_bytes() != checkpoint_path_2.read_bytes() + assert checkpoint_path_1.stat().st_size > checkpoint_path_2.stat().st_size + + checkpoint_path_3 = tmp_path / 'checkpoint.tar.unknownalgorithm' + with pytest.raises(ValueError, match='does not end with a valid tarfile extension'): + _write_checkpoint_file(state, str(checkpoint_path_3)) + assert not checkpoint_path_3.exists() + + @pytest.mark.skipif(shutil.which('lz4') is None, reason='lz4 command not found') + def test_write_directly_compressed_pickle(self, tmp_path: pathlib.Path): + state = {'foo': 123} + checkpoint_path_uncompressed = tmp_path / 'checkpoint_uncompressed.pt' + _write_checkpoint_file(state, str(checkpoint_path_uncompressed)) + + checkpoint_path = tmp_path / 'checkpoint_uncompressed.pt.lz4' + _write_checkpoint_file(state, str(checkpoint_path)) + assert _load_checkpoint(checkpoint_path) == state + assert checkpoint_path.exists() + + assert checkpoint_path_uncompressed.stat().st_size > checkpoint_path.stat().st_size + @pytest.mark.parametrize( 'save_folder,expected_path', [ @@ -561,8 +634,9 @@ def _metrics_equal(self, train_metrics_1, train_metrics_2, eval_metrics_1, eval_ def get_trainer( self, model=None, - max_duration='2ep', - latest_filename='latest-rank{rank}.pt', + max_duration: str = '2ep', + latest_filename: str = 'latest-rank{rank}.pt', + file_extension: str = '.pt', **kwargs, ): if model is None: @@ -592,7 +666,7 @@ def get_trainer( save_interval='1ep', eval_interval='1ep', save_latest_filename=latest_filename, - save_filename='ep{epoch}.pt', + save_filename='ep{epoch}' + file_extension, max_duration=max_duration, optimizers=optimizer, schedulers=ExponentialScheduler(gamma=0.9), @@ -621,6 +695,7 @@ def get_logger(self, tmp_path: pathlib.Path): @world_size(1, 2) @device('cpu', 'gpu') + @pytest.mark.parametrize('file_extension', ['.pt', '.tar.gz', '.pt.lz4']) @pytest.mark.parametrize('use_object_store', [True, False]) @pytest.mark.parametrize('delete_local', [True, False]) @pytest.mark.parametrize('test_slashed', [True, False]) @@ -629,6 +704,7 @@ def test_autoresume( self, device: str, tmp_path: pathlib.Path, + file_extension: str, use_object_store: bool, delete_local: bool, test_slashed: bool, @@ -641,11 +717,16 @@ def test_autoresume( if use_object_store: pytest.importorskip('libcloud') - latest_filename = 'latest-rank{rank}.pt' + latest_filename = 'latest-rank{rank}' + file_extension if test_slashed: latest_filename = 'testdir/' + latest_filename + + if is_compressed_pt(latest_filename) and not 
get_compressor(latest_filename).exists: + pytest.skip(reason=f'compressor not found for {latest_filename}') + trainer_1 = self.get_trainer( latest_filename=latest_filename, + file_extension=file_extension, save_folder='first', device=device, run_name='big-chungus', diff --git a/tests/utils/test_compression.py b/tests/utils/test_compression.py new file mode 100644 index 0000000000..2af427274e --- /dev/null +++ b/tests/utils/test_compression.py @@ -0,0 +1,57 @@ +# Copyright 2024 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +import re +from pathlib import Path + +import pytest + +from composer.utils.compression import ( + KNOWN_COMPRESSORS, + CliCompressor, + CompressorNotFound, + get_compressor, + is_compressed_pt, +) + + +def test_is_compressed_pt() -> None: + assert is_compressed_pt('x.pt.lz4') + assert is_compressed_pt('x.pt.unknown') + assert is_compressed_pt('pt.lz4') + assert is_compressed_pt('pt.unknown') + assert is_compressed_pt('x.y.pt.lz4') + assert is_compressed_pt('x.y.pt.unknown') + + assert not is_compressed_pt('') + assert not is_compressed_pt('x.lz4') + assert not is_compressed_pt('x.tar.lz4') + + +def test_get_invalid_compressor() -> None: + with pytest.raises(CompressorNotFound, match=re.escape('Could not find compressor for "foo.pt.unknown".')): + get_compressor('foo.pt.unknown') + + +def test_compressor_not_found() -> None: + compressor = CliCompressor('foobar', 'unknown_compressor') + assert not compressor.exists + with pytest.raises(CompressorNotFound): + compressor.check_exists() + + +@pytest.mark.parametrize('compressor', KNOWN_COMPRESSORS) +def test_compressor(tmp_path: Path, compressor: CliCompressor) -> None: + if not compressor.exists: + pytest.skip(reason=f'compressor {compressor.cmd} not found') + + test_file = tmp_path / 'my_file' + test_data = b'foo foo foo' + + with compressor.compress(str(test_file)) as f: + f.write(test_data) + assert test_file.exists() + assert test_file.read_bytes() != test_data + + with compressor.decompress(str(test_file)) as f: + assert f.read() == test_data diff --git a/tests/utils/test_file_helpers.py b/tests/utils/test_file_helpers.py index df215fea59..e5a8d17314 100644 --- a/tests/utils/test_file_helpers.py +++ b/tests/utils/test_file_helpers.py @@ -195,11 +195,19 @@ def test_get_file_local_path_not_found(): def test_is_tar(): assert is_tar('x.tar') + assert is_tar('foo.bar.tar') assert is_tar('x.tgz') assert is_tar('x.tar.gz') assert is_tar('x.tar.bz2') assert is_tar('x.tar.lzma') + assert is_tar('x.tar.foo') + assert is_tar('tar.xyz') + assert not is_tar('') assert not is_tar('x') + assert not is_tar('tar.foo.xyz') + assert not is_tar('x.y') + assert not is_tar('x.y.z') + assert not is_tar('tar') def test_format_name_with_dist(): From 2fdbf454bca2c41f1fc8baa06356fcfa6e96faee Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 19 Mar 2024 21:16:59 -0700 Subject: [PATCH 13/29] Only broadcast distcp files (#3130) * filter * remove metadata skip --- composer/utils/checkpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 11513f5cb5..0116b79314 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -288,15 +288,15 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): receiver = dist.get_global_rank() != rank_in_first_replica # Send list of files to all ranks - file_list = [sorted(os.listdir(self.destination_path))] + file_list = [ + file_name for file_name in 
sorted(os.listdir(self.destination_path)) if file_name.endswith('.distcp') + ] dist.broadcast_object_list(file_list, src=rank_in_first_replica, group=replicate_process_group) file_list = file_list[0] log.debug(f'List of files to broadcast: {file_list}') # Send each file to the appropriate rank for file_name in file_list: - if 'metadata' in file_name: # All ranks already have the metadata file - continue if dist.get_local_rank() == 0: # Only 1 rank per node needs to transfer file full_path = os.path.join(self.destination_path, file_name) log.debug(f'Transferring {full_path=}') From 88677c34f859b5a16a2b46a5e9fdd8d7be007b2c Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Wed, 20 Mar 2024 14:57:59 -0700 Subject: [PATCH 14/29] Bump version to 0.21 (#3132) * bump version to 021 * fix lint * fix docstring * remove old ones --- composer/_version.py | 2 +- composer/core/state.py | 11 ---- composer/core/time.py | 14 ----- composer/metrics/nlp.py | 3 +- composer/models/huggingface.py | 2 +- docker/README.md | 6 +- docker/build_matrix.yaml | 66 ++-------------------- docker/generate_build_matrix.py | 36 +----------- docs/source/notes/distributed_training.rst | 14 ++--- 9 files changed, 18 insertions(+), 136 deletions(-) diff --git a/composer/_version.py b/composer/_version.py index fffe771b0c..ad813aba8b 100644 --- a/composer/_version.py +++ b/composer/_version.py @@ -3,4 +3,4 @@ """The Composer Version.""" -__version__ = '0.20.1' +__version__ = '0.21.0' diff --git a/composer/core/state.py b/composer/core/state.py index a21b142d42..ac9c5e0064 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -44,7 +44,6 @@ is_model_deepspeed, reproducibility, ) -from composer.utils.warnings import VersionedDeprecationWarning if TYPE_CHECKING: import deepspeed @@ -790,16 +789,6 @@ def fsdp_state_dict_type(self): def fsdp_sharded_state_dict_enabled(self): return self.fsdp_config is not None and self.fsdp_enabled and self.fsdp_state_dict_type == 'sharded' - @property - def fsdp_elastic_sharded_enabled(self): - warnings.warn( - VersionedDeprecationWarning( - 'state.fsdp_elastic_sharded_enabled is deprecated.', - remove_version='0.21.0', - ), - ) - return self.fsdp_sharded_state_dict_enabled - @property def fsdp_device_mesh(self): if self.fsdp_enabled: diff --git a/composer/core/time.py b/composer/core/time.py index 35c17d74f9..f05b521614 100644 --- a/composer/core/time.py +++ b/composer/core/time.py @@ -19,12 +19,10 @@ import datetime import re -import warnings from typing import Any, Dict, Generic, Optional, TypeVar, Union, cast from composer.core.serializable import Serializable from composer.utils import StringEnum -from composer.utils.warnings import VersionedDeprecationWarning __all__ = ['TimeUnit', 'Time', 'Timestamp', 'ensure_time'] @@ -540,18 +538,6 @@ def state_dict(self) -> Dict[str, Any]: 'batch_wct': self.batch_wct, } - def get_state(self) -> Dict[str, Union[Time[int], datetime.timedelta]]: - """Returns all values of the timestamp object in a dictionary. - - Returns: - Dict[str, Union[Time[int], datetime.timedelta]]: All values of the timestamp object. 
- """ - warnings.warn( - VersionedDeprecationWarning('core.time.Timestamp.get_state is deprecated.', remove_version='0.21.0'), - ) - - return self.state_dict() - def load_state_dict(self, state: Dict[str, Any]) -> None: self._epoch = Time(state['epoch'], TimeUnit.EPOCH) self._batch = Time(state['batch'], TimeUnit.BATCH) diff --git a/composer/metrics/nlp.py b/composer/metrics/nlp.py index 4b4a0218b5..5082ec87ee 100644 --- a/composer/metrics/nlp.py +++ b/composer/metrics/nlp.py @@ -247,7 +247,7 @@ def update( ): """Abstract interface for computing an in-context learning metrics. - The `output_logits` argument is deprecated and will be removed in v0.21 while it's functionality will + The `output_logits` argument is deprecated and will be removed in v0.22 while it's functionality will be moved to `outputs`. Args: @@ -255,6 +255,7 @@ def update( to compute the metric. output_logits (torch.Tensor): The model outputs evaluated on the batch `input_ids` labels (torch.Tensor): The correct outputs. + outputs (torch.Tensor): The model outputs evaluated on the batch `input_ids`. Raises: NotImplementedError: Abstract method must be implemented by subclasses diff --git a/composer/models/huggingface.py b/composer/models/huggingface.py index 9bb7d62b82..e0ad6fdf6d 100644 --- a/composer/models/huggingface.py +++ b/composer/models/huggingface.py @@ -513,7 +513,7 @@ def eval_forward(self, batch, outputs: Optional[Any] = None): warnings.warn( VersionedDeprecationWarning( '`generation_length` has been deprecated in favor of passing `max_new_tokens` directly into `generation_kwargs`.', - remove_version='0.21.0', + remove_version='0.22.0', ), ) if 'generation_kwargs' in batch: diff --git a/docker/README.md b/docker/README.md index 73b4b1e13b..5e5d943ee0 100644 --- a/docker/README.md +++ b/docker/README.md @@ -15,8 +15,8 @@ all dependencies for both NLP and Vision models. They are built on top of the | Composer Version | CUDA Support | Docker Tag | |--------------------|----------------|----------------------------------------------------------------| -| 0.20.1 | Yes | `mosaicml/composer:latest`, `mosaicml/composer:0.20.1` | -| 0.20.1 | No | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.20.1_cpu` | +| 0.21.0 | Yes | `mosaicml/composer:latest`, `mosaicml/composer:0.21.0` | +| 0.21.0 | No | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.21.0_cpu` | **Note**: For a lightweight installation, we recommended using a [MosaicML PyTorch Image](#pytorch-images) and manually @@ -30,8 +30,6 @@ To install composer, once inside the image, run `pip install mosaicml`. 
| Linux Distro | Flavor | PyTorch Version | CUDA Version | Python Version | Docker Tags | |----------------|----------|-------------------|---------------------|------------------|------------------------------------------------------------------------------------------| -| Ubuntu 20.04 | Base | 2.3.0 | 12.1.1 (Infiniband) | 3.11 | `mosaicml/pytorch:2.3.0_cu121-nightly20240110-python3.11-ubuntu20.04` | -| Ubuntu 20.04 | Base | 2.3.0 | 12.1.1 (EFA) | 3.11 | `mosaicml/pytorch:2.3.0_cu121-nightly20240110-python3.11-ubuntu20.04-aws` | | Ubuntu 20.04 | Base | 2.2.1 | 12.1.1 (Infiniband) | 3.11 | `mosaicml/pytorch:2.2.1_cu121-python3.11-ubuntu20.04` | | Ubuntu 20.04 | Base | 2.2.1 | 12.1.1 (EFA) | 3.11 | `mosaicml/pytorch:2.2.1_cu121-python3.11-ubuntu20.04-aws` | | Ubuntu 20.04 | Base | 2.2.1 | cpu | 3.11 | `mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04` | diff --git a/docker/build_matrix.yaml b/docker/build_matrix.yaml index 31e3e1ba27..6b150ab7e0 100644 --- a/docker/build_matrix.yaml +++ b/docker/build_matrix.yaml @@ -190,65 +190,11 @@ - mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04 TARGET: pytorch_stage TORCHVISION_VERSION: 0.15.2 -- AWS_OFI_NCCL_VERSION: v1.7.4-aws - BASE_IMAGE: nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 - CUDA_VERSION: 12.1.1 - IMAGE_NAME: torch-nightly-2-3-0-20240110-cu121-python3-11-aws - MOFED_VERSION: '' - NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471 - brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 - brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 - brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471 - brand=tesla,driver>=510,driver<511 brand=unknown,driver>=510,driver<511 brand=nvidia,driver>=510,driver<511 - brand=nvidiartx,driver>=510,driver<511 brand=geforce,driver>=510,driver<511 brand=geforcertx,driver>=510,driver<511 - brand=quadro,driver>=510,driver<511 brand=quadrortx,driver>=510,driver<511 brand=titan,driver>=510,driver<511 - brand=titanrtx,driver>=510,driver<511 brand=tesla,driver>=515,driver<516 brand=unknown,driver>=515,driver<516 - brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=515,driver<516 brand=geforce,driver>=515,driver<516 - brand=geforcertx,driver>=515,driver<516 brand=quadro,driver>=515,driver<516 brand=quadrortx,driver>=515,driver<516 - brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=515,driver<516 brand=tesla,driver>=525,driver<526 - brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526 - brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526 - brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526 - PYTHON_VERSION: '3.11' - PYTORCH_NIGHTLY_URL: https://download.pytorch.org/whl/nightly/cu121 - PYTORCH_NIGHTLY_VERSION: dev20240110+cu121 - PYTORCH_VERSION: 2.3.0 - TAGS: - - mosaicml/pytorch:2.3.0_cu121-nightly20240110-python3.11-ubuntu20.04-aws - TARGET: pytorch_stage - TORCHVISION_VERSION: 0.18.0 -- AWS_OFI_NCCL_VERSION: '' - BASE_IMAGE: nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 - CUDA_VERSION: 12.1.1 - IMAGE_NAME: torch-nightly-2-3-0-20240110-cu121-python3-11 - MOFED_VERSION: 5.5-1.0.3.2 - NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 
brand=tesla,driver>=470,driver<471 - brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 - brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 - brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471 - brand=tesla,driver>=510,driver<511 brand=unknown,driver>=510,driver<511 brand=nvidia,driver>=510,driver<511 - brand=nvidiartx,driver>=510,driver<511 brand=geforce,driver>=510,driver<511 brand=geforcertx,driver>=510,driver<511 - brand=quadro,driver>=510,driver<511 brand=quadrortx,driver>=510,driver<511 brand=titan,driver>=510,driver<511 - brand=titanrtx,driver>=510,driver<511 brand=tesla,driver>=515,driver<516 brand=unknown,driver>=515,driver<516 - brand=nvidia,driver>=515,driver<516 brand=nvidiartx,driver>=515,driver<516 brand=geforce,driver>=515,driver<516 - brand=geforcertx,driver>=515,driver<516 brand=quadro,driver>=515,driver<516 brand=quadrortx,driver>=515,driver<516 - brand=titan,driver>=515,driver<516 brand=titanrtx,driver>=515,driver<516 brand=tesla,driver>=525,driver<526 - brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526 - brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526 - brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526 - PYTHON_VERSION: '3.11' - PYTORCH_NIGHTLY_URL: https://download.pytorch.org/whl/nightly/cu121 - PYTORCH_NIGHTLY_VERSION: dev20240110+cu121 - PYTORCH_VERSION: 2.3.0 - TAGS: - - mosaicml/pytorch:2.3.0_cu121-nightly20240110-python3.11-ubuntu20.04 - TARGET: pytorch_stage - TORCHVISION_VERSION: 0.18.0 - AWS_OFI_NCCL_VERSION: '' BASE_IMAGE: nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 - COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.20.1 + COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.21.0 CUDA_VERSION: 12.1.1 - IMAGE_NAME: composer-0-20-1 + IMAGE_NAME: composer-0-21-0 MOFED_VERSION: 5.5-1.0.3.2 NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471 brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 @@ -269,15 +215,15 @@ PYTORCH_NIGHTLY_VERSION: '' PYTORCH_VERSION: 2.1.2 TAGS: - - mosaicml/composer:0.20.1 + - mosaicml/composer:0.21.0 - mosaicml/composer:latest TARGET: composer_stage TORCHVISION_VERSION: 0.16.2 - AWS_OFI_NCCL_VERSION: '' BASE_IMAGE: ubuntu:20.04 - COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.20.1 + COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.21.0 CUDA_VERSION: '' - IMAGE_NAME: composer-0-20-1-cpu + IMAGE_NAME: composer-0-21-0-cpu MOFED_VERSION: 5.5-1.0.3.2 NVIDIA_REQUIRE_CUDA_OVERRIDE: '' PYTHON_VERSION: '3.10' @@ -285,7 +231,7 @@ PYTORCH_NIGHTLY_VERSION: '' PYTORCH_VERSION: 2.1.2 TAGS: - - mosaicml/composer:0.20.1_cpu + - mosaicml/composer:0.21.0_cpu - mosaicml/composer:latest_cpu TARGET: composer_stage TORCHVISION_VERSION: 0.16.2 diff --git a/docker/generate_build_matrix.py b/docker/generate_build_matrix.py index a45c08228d..0b2405417b 100644 --- a/docker/generate_build_matrix.py +++ b/docker/generate_build_matrix.py @@ -228,44 +228,10 @@ def _main(): pytorch_entries.append(entry) - nightly_entry_311_aws = { - 'AWS_OFI_NCCL_VERSION': 'v1.7.4-aws', - 'BASE_IMAGE': 'nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04', - 'CUDA_VERSION': '12.1.1', - 'IMAGE_NAME': 
'torch-nightly-2-3-0-20240110-cu121-python3-11-aws', - 'MOFED_VERSION': '', - 'NVIDIA_REQUIRE_CUDA_OVERRIDE': _get_cuda_override('12.1.1'), - 'PYTHON_VERSION': '3.11', - 'PYTORCH_VERSION': '2.3.0', - 'PYTORCH_NIGHTLY_URL': 'https://download.pytorch.org/whl/nightly/cu121', - 'PYTORCH_NIGHTLY_VERSION': 'dev20240110+cu121', - 'TAGS': ['mosaicml/pytorch:2.3.0_cu121-nightly20240110-python3.11-ubuntu20.04-aws'], - 'TARGET': 'pytorch_stage', - 'TORCHVISION_VERSION': '0.18.0', - } - pytorch_entries.append(nightly_entry_311_aws) - - nightly_entry_311 = { - 'AWS_OFI_NCCL_VERSION': '', - 'BASE_IMAGE': 'nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04', - 'CUDA_VERSION': '12.1.1', - 'IMAGE_NAME': 'torch-nightly-2-3-0-20240110-cu121-python3-11', - 'MOFED_VERSION': '5.5-1.0.3.2', - 'NVIDIA_REQUIRE_CUDA_OVERRIDE': _get_cuda_override('12.1.1'), - 'PYTHON_VERSION': '3.11', - 'PYTORCH_VERSION': '2.3.0', - 'PYTORCH_NIGHTLY_URL': 'https://download.pytorch.org/whl/nightly/cu121', - 'PYTORCH_NIGHTLY_VERSION': 'dev20240110+cu121', - 'TAGS': ['mosaicml/pytorch:2.3.0_cu121-nightly20240110-python3.11-ubuntu20.04'], - 'TARGET': 'pytorch_stage', - 'TORCHVISION_VERSION': '0.18.0', - } - pytorch_entries.append(nightly_entry_311) - composer_entries = [] # The `GIT_COMMIT` is a placeholder and Jenkins will substitute it with the actual git commit for the `composer_staging` images - composer_versions = ['0.20.1'] # Only build images for the latest composer version + composer_versions = ['0.21.0'] # Only build images for the latest composer version composer_python_versions = [PRODUCTION_PYTHON_VERSION] # just build composer against the latest for product in itertools.product(composer_python_versions, composer_versions, cuda_options): diff --git a/docs/source/notes/distributed_training.rst b/docs/source/notes/distributed_training.rst index cab087f3b8..c64b51dca2 100644 --- a/docs/source/notes/distributed_training.rst +++ b/docs/source/notes/distributed_training.rst @@ -395,18 +395,14 @@ It does this by gathering the model state to the global rank 0 device, unflatten If `load_monolith_rank0_only=True`, then when loading checkpoints the global rank 0 device will load in the checkpoint file and scatter the model and optimizer state to the other ranks, which will will dramatically reduce the memory usage on system. Otherwise, all ranks will separately load in the checkpoint file. -2. :code:`state_dict_type='local'` -For save: each rank saves out the flattened model state shard they are -responsibile for to a distinct checkpoint file. For load, each rank loads in the checkpoint file -corresponding to their shard. **Note: state_dict_type='local' is deprecated in Composer for torch versions 2.0.0 or higher.** - -3. :code:`state_dict_type='sharded'` -Each rank saves out an unflattened shard. For loading, similar to ``state_dict_type='local'``, each rank -loads in the checkpoint file corresponding to their unflattened shard. **Note: state_dict_type='sharded' is the recommended setting for sharded checkpointing in Composer for torch versions 2.0.0 or higher.** +2. :code:`state_dict_type='sharded'` +Each rank saves out an unflattened shard. For loading, each rank loads in the checkpoint file +corresponding to their unflattened shard. +**Note: state_dict_type='sharded' is the recommended setting for sharded checkpointing in Composer for torch versions 2.0.0 or higher.** See `The FSDP docs `__ for more info. 
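
As a rough illustration of the sharded-checkpointing setup described above (a sketch only, not taken from this patch: the model and dataloader variables, folder name, and interval are placeholders, and the exact `fsdp_config` keys should be checked against the installed Composer release):

from composer import Trainer

trainer = Trainer(
    model=model,                        # placeholder ComposerModel
    train_dataloader=train_dataloader,  # placeholder dataloader
    max_duration='10ep',
    fsdp_config={
        'sharding_strategy': 'FULL_SHARD',
        'state_dict_type': 'sharded',                      # each rank saves its own shard
        'sharded_ckpt_prefix_dir': 'ep{epoch}-ba{batch}',  # per-checkpoint shard directory
    },
    save_folder='checkpoints',
    save_interval='1ep',
)
trainer.fit()
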
-If you use sharded checkpoints (`state_dict_type='sharded'` or `state_dict_type='local'`), your run will save as many files as you have +If you use sharded checkpoints (`state_dict_type='sharded'`), your run will save as many files as you have ranks at each checkpointing event (plus one metadata file for torch versions 2.0.0 or higher). This can quicky pollute your `save_folder` with a lot of files after a couple checkpointing events. To help keep your checkpoint shard files organized, Composer will save each set of shards in it's own prefix directory, which you can configure by using `'sharded_ckpt_prefix_dir'` (default value `sharded_ckpt_prefix_dir='ep{epoch}-ba{batch}'`). Checkpoint shards will be saved to From 5161f598360dba4d1d5f51a948b885b476b333e5 Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Mon, 25 Mar 2024 11:01:38 +0100 Subject: [PATCH 15/29] Adding a test for checking the update of epoch on batch_sampler --- tests/trainer/test_checkpoint.py | 97 ++++++++++++++++++++++++++++---- 1 file changed, 87 insertions(+), 10 deletions(-) diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 797dc4b112..c3e6c6477f 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -20,7 +20,7 @@ import torch.distributed from packaging import version from pytest import MonkeyPatch -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset, DistributedSampler from composer.algorithms import NoOpModel from composer.callbacks import CheckpointSaver @@ -1246,6 +1246,8 @@ def get_trainer( precision='fp32', max_duration='2ep', train_subset_num_batches=5, + use_batch_sampler: bool = False, + with_eval_dataloader: bool = True, **kwargs, ): model = SimpleModel() @@ -1257,18 +1259,54 @@ def get_trainer( eval_dataset = RandomClassificationDataset(size=12) train_batch_size = 2 - return Trainer( - model=model, - train_dataloader=DataLoader( - dataset=train_dataset, - batch_size=train_batch_size, - sampler=dist.get_sampler(train_dataset), - ), - eval_dataloader=DataLoader( + class _DistributedBatchSampler(DistributedSampler): + def __init__( + self, dataset: Dataset, batch_size: int, num_replicas: Optional[int] = None, + rank: Optional[int] = None, shuffle: bool = True, + seed: int = 0, drop_last: bool = False, + ): + super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last) + self._batch_size =batch_size + + def __iter__(self): + indices = list(super().__iter__()) + for ind_ in range(len(self)): + yield indices[ind_*self._batch_size: (ind_+1)*self._batch_size] + + def __len__(self) -> int: + return self.num_samples//self._batch_size + + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=train_batch_size, + sampler=dist.get_sampler(train_dataset), + ) if not use_batch_sampler else DataLoader( + dataset=train_dataset, + batch_sampler=_DistributedBatchSampler(dataset=train_dataset, drop_last=True, + shuffle=True, + num_replicas=dist.get_world_size(), + rank=dist.get_global_rank(), batch_size=train_batch_size), + ) + + if with_eval_dataloader is True: + eval_dataloader = DataLoader( dataset=eval_dataset, batch_size=2, sampler=dist.get_sampler(eval_dataset), - ), + ) if not use_batch_sampler else DataLoader( + dataset=eval_dataset, + batch_sampler=_DistributedBatchSampler(dataset=eval_dataset, drop_last=False, + shuffle=False, + num_replicas=dist.get_world_size(), + rank=dist.get_global_rank(), batch_size=train_batch_size), + ) + else: 
+ eval_dataloader = None + + return Trainer( + model=model, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, device_train_microbatch_size=train_batch_size // 2, precision=precision, train_subset_num_batches=train_subset_num_batches, @@ -1412,6 +1450,45 @@ def test_resumption( save_folder / 'second' / final_checkpoint, ) + @world_size(2) + @pytest.mark.parametrize( + 'device', + [ + pytest.param('gpu', marks=pytest.mark.gpu), + pytest.param('cpu'), + ], + ) + @pytest.mark.parametrize('max_duration', [1, 2]) + @pytest.mark.filterwarnings('ignore:An unexpected prefix is detected. This case.*') + @pytest.mark.filterwarnings( + 'ignore:``FullyShardedDataParallel.scatter_full_optim_state_dict``is being deprecated and is replaced by.*', + ) + def test_set_dataloaders_to_cur_epoch( + self, + world_size: int, + device: str, + max_duration: int, + tmp_path: pathlib.Path, + ): + # All ranks use rank 0 folder + tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path)) + save_folder = pathlib.Path(tmp_paths[0]) + + trainer = self.get_trainer( + save_folder=os.path.join(save_folder, 'first'), + device=device, + precision='fp32', + max_duration=f'{max_duration}ep', + train_subset_num_batches=2, + use_batch_sampler=True, + with_eval_dataloader=False, + ) + + trainer.fit() + + # Epochs count starts at O + assert trainer.state.train_dataloader.batch_sampler.epoch == max_duration - 1 + @pytest.mark.parametrize( 'world_size', [ From c731f67fb2d4935bc79b838fe129dc1c0be9d161 Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Mon, 25 Mar 2024 11:03:38 +0100 Subject: [PATCH 16/29] proper formating --- tests/trainer/test_checkpoint.py | 53 ++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index c3e6c6477f..2e7e9a4c25 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -1260,21 +1260,34 @@ def get_trainer( train_batch_size = 2 class _DistributedBatchSampler(DistributedSampler): + def __init__( - self, dataset: Dataset, batch_size: int, num_replicas: Optional[int] = None, - rank: Optional[int] = None, shuffle: bool = True, - seed: int = 0, drop_last: bool = False, + self, + dataset: Dataset, + batch_size: int, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, ): - super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last) - self._batch_size =batch_size + super().__init__( + dataset=dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + seed=seed, + drop_last=drop_last + ) + self._batch_size = batch_size def __iter__(self): indices = list(super().__iter__()) for ind_ in range(len(self)): - yield indices[ind_*self._batch_size: (ind_+1)*self._batch_size] - + yield indices[ind_ * self._batch_size:(ind_ + 1) * self._batch_size] + def __len__(self) -> int: - return self.num_samples//self._batch_size + return self.num_samples // self._batch_size train_dataloader = DataLoader( dataset=train_dataset, @@ -1282,10 +1295,14 @@ def __len__(self) -> int: sampler=dist.get_sampler(train_dataset), ) if not use_batch_sampler else DataLoader( dataset=train_dataset, - batch_sampler=_DistributedBatchSampler(dataset=train_dataset, drop_last=True, - shuffle=True, - num_replicas=dist.get_world_size(), - rank=dist.get_global_rank(), batch_size=train_batch_size), + 
batch_sampler=_DistributedBatchSampler( + dataset=train_dataset, + drop_last=True, + shuffle=True, + num_replicas=dist.get_world_size(), + rank=dist.get_global_rank(), + batch_size=train_batch_size + ), ) if with_eval_dataloader is True: @@ -1295,10 +1312,14 @@ def __len__(self) -> int: sampler=dist.get_sampler(eval_dataset), ) if not use_batch_sampler else DataLoader( dataset=eval_dataset, - batch_sampler=_DistributedBatchSampler(dataset=eval_dataset, drop_last=False, - shuffle=False, - num_replicas=dist.get_world_size(), - rank=dist.get_global_rank(), batch_size=train_batch_size), + batch_sampler=_DistributedBatchSampler( + dataset=eval_dataset, + drop_last=False, + shuffle=False, + num_replicas=dist.get_world_size(), + rank=dist.get_global_rank(), + batch_size=train_batch_size + ), ) else: eval_dataloader = None From 570558d762113b6608192a08f3d6d691bc6520a5 Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Sun, 4 Feb 2024 13:19:28 +0100 Subject: [PATCH 17/29] adding _get_dist_sampler function --- composer/trainer/trainer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 064c22a73b..80e183e404 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -465,6 +465,15 @@ def _generate_run_name() -> str: return generated_run_name +def _get_dist_sampler(dataloader: Any) -> DistributedSampler | None: + if not isinstance(dataloader, DataLoader): + return + if dataloader.sampler is not None: + return dataloader.sampler + else: + return dataloader.batch_sampler + + class Trainer: """Train models with Composer algorithms. From ab09d434e24a64e9aecd43311dae4e445776aa3d Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Sun, 4 Feb 2024 13:26:16 +0100 Subject: [PATCH 18/29] apply _get_dist_sampler across all trainer.py --- composer/trainer/trainer.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 80e183e404..bcdd93bbde 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2281,8 +2281,9 @@ def _spin_dataloaders_to_cur_epoch(self): eval_state = self.state.dataset_resumption.get('eval', {}) for evaluator in self.state.evaluators: dataloader = evaluator.dataloader.dataloader - if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler): - dataloader.sampler.set_epoch(0) + sampler = _get_dist_sampler(dataloader) + if isinstance(sampler, DistributedSampler): + sampler.set_epoch(0) if evaluator.label not in eval_state: for _ in dataloader: break @@ -2292,8 +2293,9 @@ def _spin_dataloaders_to_cur_epoch(self): assert dataloader is not None, 'train dataloader is set on state after FIT_START' if 'train' not in self.state.dataset_resumption: for epoch in range(int(self.state.timestamp.epoch)): - if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler): - dataloader.sampler.set_epoch(epoch) + sampler = _get_dist_sampler(dataloader) + if isinstance(sampler, DistributedSampler): + sampler.set_epoch(epoch) for _ in dataloader: break @@ -2375,9 +2377,10 @@ def _train_loop(self) -> None: self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value}) dataloader = self.state.dataloader - if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler): - dataloader.sampler.set_epoch(int(self.state.timestamp.epoch)) - + sampler = _get_dist_sampler(dataloader) + if 
isinstance(sampler, DistributedSampler): + sampler.set_epoch(int(self.state.timestamp.epoch)) + for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)): # Spin dataloader forward unless dataloader handles internally with dataset_resumption if self.spin_dataloaders and 'train' not in self.state.dataset_resumption and batch_idx < int( @@ -3234,19 +3237,19 @@ def _eval_loop( drop_last = None dataset_len = None last_batch = False - if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler): + sampler = _get_dist_sampler(dataloader) + if isinstance(sampler, DistributedSampler): # The distributed sampler uses `set_epoch` to set the random seed # Because evaluation can run on each batch, we use the batch to seed the sampler # so each evaluation will get a proper shuffle. # The epoch provided to `set_epoch` need not be sequential, so this is fine. - dist_sampler = dataloader.sampler - dist_sampler.set_epoch(int(self.state.timestamp.batch)) + sampler.set_epoch(int(self.state.timestamp.batch)) drop_last = dataloader.drop_last # Only compute the dataset length if drop_last is False, as otherwise we don't need # to remove any duplicate samples. if drop_last == False: try: - dataset_len = len(dist_sampler.dataset) # type: ignore + dataset_len = len(sampler.dataset) # type: ignore except AttributeError: warnings.warn( "DistributedSampler's dataset does not have length defined. When " From df7332532bfaf7e4467cfc50b19abd2467064c9b Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Sun, 17 Mar 2024 16:36:36 +0100 Subject: [PATCH 19/29] Switch condition to check if batch_sampler if filled first since sampler is always defined --- composer/trainer/trainer.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index bcdd93bbde..e9f480d11a 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -48,7 +48,7 @@ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from torch.nn.parallel import DistributedDataParallel from torch.optim.lr_scheduler import LRScheduler -from torch.utils.data import DataLoader, DistributedSampler +from torch.utils.data import DataLoader, DistributedSampler, Sampler from torchmetrics import Metric from composer.callbacks import CheckpointSaver, MemorySnapshot, OOMObserver, OptimizerMonitor @@ -465,13 +465,12 @@ def _generate_run_name() -> str: return generated_run_name -def _get_dist_sampler(dataloader: Any) -> DistributedSampler | None: +def _get_sampler(dataloader: Any) -> Sampler | None: if not isinstance(dataloader, DataLoader): return - if dataloader.sampler is not None: - return dataloader.sampler - else: + if dataloader.batch_sampler is not None: return dataloader.batch_sampler + return dataloader.sampler class Trainer: @@ -2281,7 +2280,7 @@ def _spin_dataloaders_to_cur_epoch(self): eval_state = self.state.dataset_resumption.get('eval', {}) for evaluator in self.state.evaluators: dataloader = evaluator.dataloader.dataloader - sampler = _get_dist_sampler(dataloader) + sampler = _get_sampler(dataloader) if isinstance(sampler, DistributedSampler): sampler.set_epoch(0) if evaluator.label not in eval_state: @@ -2293,7 +2292,7 @@ def _spin_dataloaders_to_cur_epoch(self): assert dataloader is not None, 'train dataloader is set on state after FIT_START' if 'train' not in self.state.dataset_resumption: for epoch in range(int(self.state.timestamp.epoch)): - sampler = 
_get_dist_sampler(dataloader) + sampler = _get_sampler(dataloader) if isinstance(sampler, DistributedSampler): sampler.set_epoch(epoch) for _ in dataloader: @@ -2377,10 +2376,10 @@ def _train_loop(self) -> None: self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value}) dataloader = self.state.dataloader - sampler = _get_dist_sampler(dataloader) + sampler = _get_sampler(dataloader) if isinstance(sampler, DistributedSampler): sampler.set_epoch(int(self.state.timestamp.epoch)) - + for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)): # Spin dataloader forward unless dataloader handles internally with dataset_resumption if self.spin_dataloaders and 'train' not in self.state.dataset_resumption and batch_idx < int( @@ -3237,7 +3236,7 @@ def _eval_loop( drop_last = None dataset_len = None last_batch = False - sampler = _get_dist_sampler(dataloader) + sampler = _get_sampler(dataloader) if isinstance(sampler, DistributedSampler): # The distributed sampler uses `set_epoch` to set the random seed # Because evaluation can run on each batch, we use the batch to seed the sampler From 61da1b8e8ac21232346410ef2360cd99f333e540 Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Sun, 17 Mar 2024 16:39:57 +0100 Subject: [PATCH 20/29] adding docstring --- composer/trainer/trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index e9f480d11a..e12eb4d9b1 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -466,6 +466,10 @@ def _generate_run_name() -> str: def _get_sampler(dataloader: Any) -> Sampler | None: + """Checks if dataloader is `torch.utils.data.DataLoader` and return the batch_sampler if defined, + else the regular sampler. + If `dataloader` is not a `torch.utils.data.DataLoader`, returns None. + """ if not isinstance(dataloader, DataLoader): return if dataloader.batch_sampler is not None: From 12245cc1d0a6b14d9e729183d48fc4d2545a7f30 Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Sun, 17 Mar 2024 17:22:47 +0100 Subject: [PATCH 21/29] removing antipattenr from _get_sampler and linting --- composer/trainer/trainer.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index e12eb4d9b1..d225aa914e 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -465,13 +465,11 @@ def _generate_run_name() -> str: return generated_run_name -def _get_sampler(dataloader: Any) -> Sampler | None: - """Checks if dataloader is `torch.utils.data.DataLoader` and return the batch_sampler if defined, - else the regular sampler. - If `dataloader` is not a `torch.utils.data.DataLoader`, returns None. +def _get_sampler(dataloader: DataLoader) -> Sampler | Iterable: + """Fetch the sampler from a `dataloader`. + + Returns `dalaoder.batch_sampler` is defined, else `dataloader.sampler` (always defined in `Dataloader.__init__`). 
""" - if not isinstance(dataloader, DataLoader): - return if dataloader.batch_sampler is not None: return dataloader.batch_sampler return dataloader.sampler @@ -2284,7 +2282,7 @@ def _spin_dataloaders_to_cur_epoch(self): eval_state = self.state.dataset_resumption.get('eval', {}) for evaluator in self.state.evaluators: dataloader = evaluator.dataloader.dataloader - sampler = _get_sampler(dataloader) + sampler = _get_sampler(dataloader) if isinstance(dataloader, DataLoader) else None if isinstance(sampler, DistributedSampler): sampler.set_epoch(0) if evaluator.label not in eval_state: @@ -2296,7 +2294,7 @@ def _spin_dataloaders_to_cur_epoch(self): assert dataloader is not None, 'train dataloader is set on state after FIT_START' if 'train' not in self.state.dataset_resumption: for epoch in range(int(self.state.timestamp.epoch)): - sampler = _get_sampler(dataloader) + sampler = _get_sampler(dataloader) if isinstance(dataloader, DataLoader) else None if isinstance(sampler, DistributedSampler): sampler.set_epoch(epoch) for _ in dataloader: @@ -2380,7 +2378,7 @@ def _train_loop(self) -> None: self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value}) dataloader = self.state.dataloader - sampler = _get_sampler(dataloader) + sampler = _get_sampler(dataloader) if isinstance(dataloader, DataLoader) else None if isinstance(sampler, DistributedSampler): sampler.set_epoch(int(self.state.timestamp.epoch)) @@ -3240,8 +3238,8 @@ def _eval_loop( drop_last = None dataset_len = None last_batch = False - sampler = _get_sampler(dataloader) - if isinstance(sampler, DistributedSampler): + sampler = _get_sampler(dataloader) if isinstance(dataloader, DataLoader) else None + if isinstance(sampler, DistributedSampler) and isinstance(dataloader, DataLoader): # The distributed sampler uses `set_epoch` to set the random seed # Because evaluation can run on each batch, we use the batch to seed the sampler # so each evaluation will get a proper shuffle. 
From b2c147af78c5e0635bb911eb371c233af379aadb Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Mon, 25 Mar 2024 11:01:38 +0100 Subject: [PATCH 22/29] Adding a test for checking the update of epoch on batch_sampler --- tests/trainer/test_checkpoint.py | 97 ++++++++++++++++++++++++++++---- 1 file changed, 87 insertions(+), 10 deletions(-) diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 797dc4b112..c3e6c6477f 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -20,7 +20,7 @@ import torch.distributed from packaging import version from pytest import MonkeyPatch -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset, DistributedSampler from composer.algorithms import NoOpModel from composer.callbacks import CheckpointSaver @@ -1246,6 +1246,8 @@ def get_trainer( precision='fp32', max_duration='2ep', train_subset_num_batches=5, + use_batch_sampler: bool = False, + with_eval_dataloader: bool = True, **kwargs, ): model = SimpleModel() @@ -1257,18 +1259,54 @@ def get_trainer( eval_dataset = RandomClassificationDataset(size=12) train_batch_size = 2 - return Trainer( - model=model, - train_dataloader=DataLoader( - dataset=train_dataset, - batch_size=train_batch_size, - sampler=dist.get_sampler(train_dataset), - ), - eval_dataloader=DataLoader( + class _DistributedBatchSampler(DistributedSampler): + def __init__( + self, dataset: Dataset, batch_size: int, num_replicas: Optional[int] = None, + rank: Optional[int] = None, shuffle: bool = True, + seed: int = 0, drop_last: bool = False, + ): + super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last) + self._batch_size =batch_size + + def __iter__(self): + indices = list(super().__iter__()) + for ind_ in range(len(self)): + yield indices[ind_*self._batch_size: (ind_+1)*self._batch_size] + + def __len__(self) -> int: + return self.num_samples//self._batch_size + + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=train_batch_size, + sampler=dist.get_sampler(train_dataset), + ) if not use_batch_sampler else DataLoader( + dataset=train_dataset, + batch_sampler=_DistributedBatchSampler(dataset=train_dataset, drop_last=True, + shuffle=True, + num_replicas=dist.get_world_size(), + rank=dist.get_global_rank(), batch_size=train_batch_size), + ) + + if with_eval_dataloader is True: + eval_dataloader = DataLoader( dataset=eval_dataset, batch_size=2, sampler=dist.get_sampler(eval_dataset), - ), + ) if not use_batch_sampler else DataLoader( + dataset=eval_dataset, + batch_sampler=_DistributedBatchSampler(dataset=eval_dataset, drop_last=False, + shuffle=False, + num_replicas=dist.get_world_size(), + rank=dist.get_global_rank(), batch_size=train_batch_size), + ) + else: + eval_dataloader = None + + return Trainer( + model=model, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, device_train_microbatch_size=train_batch_size // 2, precision=precision, train_subset_num_batches=train_subset_num_batches, @@ -1412,6 +1450,45 @@ def test_resumption( save_folder / 'second' / final_checkpoint, ) + @world_size(2) + @pytest.mark.parametrize( + 'device', + [ + pytest.param('gpu', marks=pytest.mark.gpu), + pytest.param('cpu'), + ], + ) + @pytest.mark.parametrize('max_duration', [1, 2]) + @pytest.mark.filterwarnings('ignore:An unexpected prefix is detected. 
This case.*') + @pytest.mark.filterwarnings( + 'ignore:``FullyShardedDataParallel.scatter_full_optim_state_dict``is being deprecated and is replaced by.*', + ) + def test_set_dataloaders_to_cur_epoch( + self, + world_size: int, + device: str, + max_duration: int, + tmp_path: pathlib.Path, + ): + # All ranks use rank 0 folder + tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path)) + save_folder = pathlib.Path(tmp_paths[0]) + + trainer = self.get_trainer( + save_folder=os.path.join(save_folder, 'first'), + device=device, + precision='fp32', + max_duration=f'{max_duration}ep', + train_subset_num_batches=2, + use_batch_sampler=True, + with_eval_dataloader=False, + ) + + trainer.fit() + + # Epochs count starts at O + assert trainer.state.train_dataloader.batch_sampler.epoch == max_duration - 1 + @pytest.mark.parametrize( 'world_size', [ From 68e6d9d14b06ccca55ec19977000f6509ffe4a15 Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Mon, 25 Mar 2024 11:03:38 +0100 Subject: [PATCH 23/29] proper formating --- tests/trainer/test_checkpoint.py | 53 ++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index c3e6c6477f..2e7e9a4c25 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -1260,21 +1260,34 @@ def get_trainer( train_batch_size = 2 class _DistributedBatchSampler(DistributedSampler): + def __init__( - self, dataset: Dataset, batch_size: int, num_replicas: Optional[int] = None, - rank: Optional[int] = None, shuffle: bool = True, - seed: int = 0, drop_last: bool = False, + self, + dataset: Dataset, + batch_size: int, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, ): - super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last) - self._batch_size =batch_size + super().__init__( + dataset=dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + seed=seed, + drop_last=drop_last + ) + self._batch_size = batch_size def __iter__(self): indices = list(super().__iter__()) for ind_ in range(len(self)): - yield indices[ind_*self._batch_size: (ind_+1)*self._batch_size] - + yield indices[ind_ * self._batch_size:(ind_ + 1) * self._batch_size] + def __len__(self) -> int: - return self.num_samples//self._batch_size + return self.num_samples // self._batch_size train_dataloader = DataLoader( dataset=train_dataset, @@ -1282,10 +1295,14 @@ def __len__(self) -> int: sampler=dist.get_sampler(train_dataset), ) if not use_batch_sampler else DataLoader( dataset=train_dataset, - batch_sampler=_DistributedBatchSampler(dataset=train_dataset, drop_last=True, - shuffle=True, - num_replicas=dist.get_world_size(), - rank=dist.get_global_rank(), batch_size=train_batch_size), + batch_sampler=_DistributedBatchSampler( + dataset=train_dataset, + drop_last=True, + shuffle=True, + num_replicas=dist.get_world_size(), + rank=dist.get_global_rank(), + batch_size=train_batch_size + ), ) if with_eval_dataloader is True: @@ -1295,10 +1312,14 @@ def __len__(self) -> int: sampler=dist.get_sampler(eval_dataset), ) if not use_batch_sampler else DataLoader( dataset=eval_dataset, - batch_sampler=_DistributedBatchSampler(dataset=eval_dataset, drop_last=False, - shuffle=False, - num_replicas=dist.get_world_size(), - rank=dist.get_global_rank(), batch_size=train_batch_size), + batch_sampler=_DistributedBatchSampler( + 
dataset=eval_dataset, + drop_last=False, + shuffle=False, + num_replicas=dist.get_world_size(), + rank=dist.get_global_rank(), + batch_size=train_batch_size + ), ) else: eval_dataloader = None From b3e6551cf654b4c723acf89034a91d9601bd61f4 Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Wed, 27 Mar 2024 14:53:35 +0000 Subject: [PATCH 24/29] Chang _get_sampler to _get_distributed_sampler to bypass torch default batch_sampler --- composer/trainer/trainer.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index d225aa914e..73afddc98e 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -48,7 +48,7 @@ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from torch.nn.parallel import DistributedDataParallel from torch.optim.lr_scheduler import LRScheduler -from torch.utils.data import DataLoader, DistributedSampler, Sampler +from torch.utils.data import DataLoader, DistributedSampler from torchmetrics import Metric from composer.callbacks import CheckpointSaver, MemorySnapshot, OOMObserver, OptimizerMonitor @@ -465,14 +465,18 @@ def _generate_run_name() -> str: return generated_run_name -def _get_sampler(dataloader: DataLoader) -> Sampler | Iterable: - """Fetch the sampler from a `dataloader`. +def _get_distributed_sampler(dataloader: DataLoader) -> DistributedSampler | None: + """Fetch a distributed sampler from a `dataloader` if it exists est returns None. - Returns `dalaoder.batch_sampler` is defined, else `dataloader.sampler` (always defined in `Dataloader.__init__`). + Checks first the batch_sampler, then the sampler. + If no DistributedSampler is found, returns None. """ - if dataloader.batch_sampler is not None: + if isinstance(dataloader.batch_sampler, DistributedSampler): return dataloader.batch_sampler - return dataloader.sampler + if isinstance(dataloader.sampler, DistributedSampler): + return dataloader.sampler + + return class Trainer: @@ -2282,7 +2286,7 @@ def _spin_dataloaders_to_cur_epoch(self): eval_state = self.state.dataset_resumption.get('eval', {}) for evaluator in self.state.evaluators: dataloader = evaluator.dataloader.dataloader - sampler = _get_sampler(dataloader) if isinstance(dataloader, DataLoader) else None + sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None if isinstance(sampler, DistributedSampler): sampler.set_epoch(0) if evaluator.label not in eval_state: @@ -2294,7 +2298,7 @@ def _spin_dataloaders_to_cur_epoch(self): assert dataloader is not None, 'train dataloader is set on state after FIT_START' if 'train' not in self.state.dataset_resumption: for epoch in range(int(self.state.timestamp.epoch)): - sampler = _get_sampler(dataloader) if isinstance(dataloader, DataLoader) else None + sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None if isinstance(sampler, DistributedSampler): sampler.set_epoch(epoch) for _ in dataloader: @@ -2378,7 +2382,7 @@ def _train_loop(self) -> None: self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value}) dataloader = self.state.dataloader - sampler = _get_sampler(dataloader) if isinstance(dataloader, DataLoader) else None + sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None if isinstance(sampler, DistributedSampler): sampler.set_epoch(int(self.state.timestamp.epoch)) @@ -3238,7 +3242,7 @@ def _eval_loop( drop_last = None dataset_len = None 
last_batch = False - sampler = _get_sampler(dataloader) if isinstance(dataloader, DataLoader) else None + sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None if isinstance(sampler, DistributedSampler) and isinstance(dataloader, DataLoader): # The distributed sampler uses `set_epoch` to set the random seed # Because evaluation can run on each batch, we use the batch to seed the sampler From 76b888b2e15058d45d949f7716196951bfd9add1 Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Thu, 28 Mar 2024 14:07:53 +0000 Subject: [PATCH 25/29] Move batch_sampler instanciation for test --- tests/trainer/test_checkpoint.py | 35 +++++++++++++++++--------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 81e08dce8b..621e36dacc 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -1289,20 +1289,30 @@ def __iter__(self): def __len__(self) -> int: return self.num_samples // self._batch_size + train_batch_sampler = _DistributedBatchSampler( + dataset=train_dataset, + drop_last=True, + shuffle=True, + num_replicas=dist.get_world_size(), + rank=dist.get_global_rank(), + batch_size=train_batch_size, + ) + eval_batch_sampler = _DistributedBatchSampler( + dataset=eval_dataset, + drop_last=False, + shuffle=False, + num_replicas=dist.get_world_size(), + rank=dist.get_global_rank(), + batch_size=train_batch_size, + ) + train_dataloader = DataLoader( dataset=train_dataset, batch_size=train_batch_size, sampler=dist.get_sampler(train_dataset), ) if not use_batch_sampler else DataLoader( dataset=train_dataset, - batch_sampler=_DistributedBatchSampler( - dataset=train_dataset, - drop_last=True, - shuffle=True, - num_replicas=dist.get_world_size(), - rank=dist.get_global_rank(), - batch_size=train_batch_size, - ), + batch_sampler=train_batch_sampler, ) if with_eval_dataloader is True: @@ -1312,14 +1322,7 @@ def __len__(self) -> int: sampler=dist.get_sampler(eval_dataset), ) if not use_batch_sampler else DataLoader( dataset=eval_dataset, - batch_sampler=_DistributedBatchSampler( - dataset=eval_dataset, - drop_last=False, - shuffle=False, - num_replicas=dist.get_world_size(), - rank=dist.get_global_rank(), - batch_size=train_batch_size, - ), + batch_sampler=eval_batch_sampler, ) else: eval_dataloader = None From dd45d3e302ffbeff7f95f029511e91425022091d Mon Sep 17 00:00:00 2001 From: Alexandre Ghelfi Date: Thu, 28 Mar 2024 15:19:33 +0000 Subject: [PATCH 26/29] fix batch_sampler setting in eval_loop --- composer/trainer/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 73afddc98e..91a52e22ef 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -3248,6 +3248,7 @@ def _eval_loop( # Because evaluation can run on each batch, we use the batch to seed the sampler # so each evaluation will get a proper shuffle. # The epoch provided to `set_epoch` need not be sequential, so this is fine. 
+ dist_sampler = sampler sampler.set_epoch(int(self.state.timestamp.batch)) drop_last = dataloader.drop_last # Only compute the dataset length if drop_last is False, as otherwise we don't need From 448d53189a0aa668190e9072e49a421559b292a1 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 2 Apr 2024 16:40:48 -0400 Subject: [PATCH 27/29] trainer --- composer/trainer/trainer.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 91a52e22ef..6ff090988e 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -465,18 +465,13 @@ def _generate_run_name() -> str: return generated_run_name -def _get_distributed_sampler(dataloader: DataLoader) -> DistributedSampler | None: - """Fetch a distributed sampler from a `dataloader` if it exists est returns None. - - Checks first the batch_sampler, then the sampler. - If no DistributedSampler is found, returns None. - """ +def _get_distributed_sampler(dataloader: DataLoader) -> Optional[DistributedSampler]: + """Fetch a distributed sampler from a `dataloader` if it exists.""" if isinstance(dataloader.batch_sampler, DistributedSampler): return dataloader.batch_sampler if isinstance(dataloader.sampler, DistributedSampler): return dataloader.sampler - - return + return None class Trainer: @@ -2281,7 +2276,7 @@ def _spin_dataloaders_to_cur_epoch(self): """ log.debug('Spinning the dataloaders') - # spin the evaluator dataloaders once to initialize its sampler deterministically + # Spin the evaluator dataloaders once to initialize its sampler deterministically # so it does not affect any other RNG reads eval_state = self.state.dataset_resumption.get('eval', {}) for evaluator in self.state.evaluators: @@ -2293,12 +2288,12 @@ def _spin_dataloaders_to_cur_epoch(self): for _ in dataloader: break - # spin the train dataloader's sampler to get to the state of the desired epoch + # Spin the train dataloader's sampler to get to the state of the desired epoch dataloader = self.state.dataloader assert dataloader is not None, 'train dataloader is set on state after FIT_START' if 'train' not in self.state.dataset_resumption: + sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None for epoch in range(int(self.state.timestamp.epoch)): - sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None if isinstance(sampler, DistributedSampler): sampler.set_epoch(epoch) for _ in dataloader: @@ -3238,24 +3233,22 @@ def _eval_loop( metric.reset() dataloader = self.state.dataloader - dist_sampler = None drop_last = None dataset_len = None last_batch = False - sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None - if isinstance(sampler, DistributedSampler) and isinstance(dataloader, DataLoader): + dist_sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None + if isinstance(dist_sampler, DistributedSampler) and isinstance(dataloader, DataLoader): # The distributed sampler uses `set_epoch` to set the random seed # Because evaluation can run on each batch, we use the batch to seed the sampler # so each evaluation will get a proper shuffle. # The epoch provided to `set_epoch` need not be sequential, so this is fine. 
- dist_sampler = sampler - sampler.set_epoch(int(self.state.timestamp.batch)) + dist_sampler.set_epoch(int(self.state.timestamp.batch)) drop_last = dataloader.drop_last # Only compute the dataset length if drop_last is False, as otherwise we don't need # to remove any duplicate samples. if drop_last == False: try: - dataset_len = len(sampler.dataset) # type: ignore + dataset_len = len(dist_sampler.dataset) # type: ignore except AttributeError: warnings.warn( "DistributedSampler's dataset does not have length defined. When " From 522afe4127cb281a92b28fcb2c5ad8b289453b91 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 2 Apr 2024 16:45:00 -0400 Subject: [PATCH 28/29] fix --- tests/trainer/test_checkpoint.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 621e36dacc..8e23086762 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -1475,13 +1475,6 @@ def test_resumption( ) @world_size(2) - @pytest.mark.parametrize( - 'device', - [ - pytest.param('gpu', marks=pytest.mark.gpu), - pytest.param('cpu'), - ], - ) @pytest.mark.parametrize('max_duration', [1, 2]) @pytest.mark.filterwarnings('ignore:An unexpected prefix is detected. This case.*') @pytest.mark.filterwarnings( @@ -1490,7 +1483,6 @@ def test_resumption( def test_set_dataloaders_to_cur_epoch( self, world_size: int, - device: str, max_duration: int, tmp_path: pathlib.Path, ): @@ -1500,7 +1492,6 @@ def test_set_dataloaders_to_cur_epoch( trainer = self.get_trainer( save_folder=os.path.join(save_folder, 'first'), - device=device, precision='fp32', max_duration=f'{max_duration}ep', train_subset_num_batches=2, @@ -1512,7 +1503,7 @@ def test_set_dataloaders_to_cur_epoch( assert isinstance(trainer.state.train_dataloader, DataLoader) assert isinstance(trainer.state.train_dataloader.batch_sampler, DistributedSampler) - # Epochs count starts at O + # Epoch count starts at O assert trainer.state.train_dataloader.batch_sampler.epoch == max_duration - 1 @pytest.mark.parametrize( From fdb2efa0a83e8e7acba95e435547c622445d48c3 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 2 Apr 2024 17:41:03 -0400 Subject: [PATCH 29/29] lint --- tests/trainer/test_checkpoint.py | 71 +++++++++++++++++--------------- 1 file changed, 38 insertions(+), 33 deletions(-) diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 8e23086762..c6712c8f43 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -1289,41 +1289,46 @@ def __iter__(self): def __len__(self) -> int: return self.num_samples // self._batch_size - train_batch_sampler = _DistributedBatchSampler( - dataset=train_dataset, - drop_last=True, - shuffle=True, - num_replicas=dist.get_world_size(), - rank=dist.get_global_rank(), - batch_size=train_batch_size, - ) - eval_batch_sampler = _DistributedBatchSampler( - dataset=eval_dataset, - drop_last=False, - shuffle=False, - num_replicas=dist.get_world_size(), - rank=dist.get_global_rank(), - batch_size=train_batch_size, - ) - - train_dataloader = DataLoader( - dataset=train_dataset, - batch_size=train_batch_size, - sampler=dist.get_sampler(train_dataset), - ) if not use_batch_sampler else DataLoader( - dataset=train_dataset, - batch_sampler=train_batch_sampler, - ) + if use_batch_sampler: + train_batch_sampler = _DistributedBatchSampler( + dataset=train_dataset, + drop_last=True, + shuffle=True, + num_replicas=dist.get_world_size(), + rank=dist.get_global_rank(), 
+ batch_size=train_batch_size, + ) + train_dataloader = DataLoader( + dataset=train_dataset, + batch_sampler=train_batch_sampler, + ) + else: + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=train_batch_size, + sampler=dist.get_sampler(train_dataset), + ) if with_eval_dataloader is True: - eval_dataloader = DataLoader( - dataset=eval_dataset, - batch_size=2, - sampler=dist.get_sampler(eval_dataset), - ) if not use_batch_sampler else DataLoader( - dataset=eval_dataset, - batch_sampler=eval_batch_sampler, - ) + if use_batch_sampler: + eval_batch_sampler = _DistributedBatchSampler( + dataset=eval_dataset, + drop_last=False, + shuffle=False, + num_replicas=dist.get_world_size(), + rank=dist.get_global_rank(), + batch_size=train_batch_size, + ) + eval_dataloader = DataLoader( + dataset=eval_dataset, + batch_sampler=eval_batch_sampler, + ) + else: + eval_dataloader = DataLoader( + dataset=eval_dataset, + batch_size=train_batch_size, + sampler=dist.get_sampler(eval_dataset), + ) else: eval_dataloader = None