From 3637046d18a9d62e0b1ee1d420b4edfa8bbb8fb9 Mon Sep 17 00:00:00 2001
From: tony
Date: Tue, 27 Aug 2024 08:18:34 -0400
Subject: [PATCH] feat(python): account for both multi-process and distributed
 torch workers (#2761)

---
 python/python/lance/torch/data.py | 10 ++--
 python/python/lance/torch/dist.py | 80 +++++++++++++++++++++++++++++++
 2 files changed, 83 insertions(+), 7 deletions(-)
 create mode 100644 python/python/lance/torch/dist.py

diff --git a/python/python/lance/torch/data.py b/python/python/lance/torch/data.py
index 16230bde43..0062060b08 100644
--- a/python/python/lance/torch/data.py
+++ b/python/python/lance/torch/data.py
@@ -26,6 +26,7 @@
     ShardedFragmentSampler,
     maybe_sample,
 )
+from .dist import get_global_rank, get_global_world_size
 
 __all__ = ["LanceDataset"]
 
@@ -254,13 +255,8 @@ def __iter__(self):
             rank = self.rank
             world_size = self.world_size
         else:
-            worker_info = torch.utils.data.get_worker_info()
-            if worker_info is not None:
-                rank = worker_info.id
-                world_size = worker_info.num_workers
-            else:
-                rank = None
-                world_size = None
+            rank = get_global_rank()
+            world_size = get_global_world_size()
         if self.shard_granularity is None:
             if rank is not None and world_size is not None:
                 sampler = ShardedFragmentSampler(rank=rank, world_size=world_size)
diff --git a/python/python/lance/torch/dist.py b/python/python/lance/torch/dist.py
new file mode 100644
index 0000000000..f77376e2f9
--- /dev/null
+++ b/python/python/lance/torch/dist.py
@@ -0,0 +1,80 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The Lance Authors
+
+"""PyTorch distributed utilities."""
+
+import torch.distributed as dist
+import torch.utils.data
+
+
+def get_dist_world_size() -> int:
+    """
+    Get the number of processes in the distributed training setup.
+
+    Returns:
+        int: The number of distributed processes if distributed training is
+        initialized, otherwise 1.
+    """
+    if dist.is_initialized():
+        return dist.get_world_size()
+    return 1
+
+
+def get_dist_rank() -> int:
+    """
+    Get the rank of the current process in the distributed training setup.
+
+    Returns:
+        int: The rank of the current process if distributed training is
+        initialized, otherwise 0.
+    """
+    if dist.is_initialized():
+        return int(dist.get_rank())
+    return 0
+
+
+def get_mp_world_size() -> int:
+    """
+    Get the number of worker processes for the current DataLoader.
+
+    Returns:
+        int: The number of worker processes if running in a DataLoader worker,
+        otherwise 1.
+    """
+    if (worker_info := torch.utils.data.get_worker_info()) is not None:
+        return worker_info.num_workers
+    return 1
+
+
+def get_mp_rank() -> int:
+    """
+    Get the rank of the current DataLoader worker process.
+
+    Returns:
+        int: The rank of the current DataLoader worker if running in a worker
+        process, otherwise 0.
+    """
+    if (worker_info := torch.utils.data.get_worker_info()) is not None:
+        return worker_info.id
+    return 0
+
+
+def get_global_world_size() -> int:
+    """
+    Get the global world size across distributed and multiprocessing contexts.
+
+    Returns:
+        int: The distributed world size times the DataLoader worker count.
+    """
+    return get_dist_world_size() * get_mp_world_size()
+
+
+def get_global_rank() -> int:
+    """
+    Get the global rank of the current process across distributed and
+    multiprocessing contexts.
+
+    Returns:
+        int: The global rank of the current process.
+    """
+    return get_dist_rank() * get_mp_world_size() + get_mp_rank()
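
The composition in get_global_rank() gives each (distributed process,
DataLoader worker) pair a unique shard index: the distributed rank selects a
contiguous block of get_mp_world_size() slots, and the worker id selects a
slot within that block. A minimal standalone sketch of the same arithmetic
(the global_rank() helper and the 2x3 topology below are hypothetical,
written only to illustrate the formula):

    def global_rank(dist_rank: int, mp_world_size: int, mp_rank: int) -> int:
        # Same composition as lance.torch.dist.get_global_rank():
        # each distributed process owns a contiguous block of worker ranks.
        return dist_rank * mp_world_size + mp_rank

    # E.g. two torchrun processes, each running a DataLoader with 3 workers.
    dist_world_size, mp_world_size = 2, 3
    ranks = [
        global_rank(d, mp_world_size, m)
        for d in range(dist_world_size)
        for m in range(mp_world_size)
    ]
    assert ranks == [0, 1, 2, 3, 4, 5]  # one distinct rank per worker

With a global world size of 2 * 3 == 6, ShardedFragmentSampler in data.py can
partition fragments across all six workers without overlap, whereas the
pre-patch code gave both torchrun processes the same worker ranks 0..2 and
therefore duplicated data across processes.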
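
In the degenerate single-process case (torch.distributed not initialized and
no DataLoader workers) the helpers fall back to rank 0 in a world of size 1,
so ShardedFragmentSampler(rank=0, world_size=1) visits every fragment, which
matches the behavior of the old rank=None / world_size=None path. A quick
sanity sketch, assuming the patched lance package is importable:

    import torch.distributed as dist
    import torch.utils.data

    from lance.torch.dist import get_global_rank, get_global_world_size

    # Plain interpreter: no process group, not inside a DataLoader worker.
    assert not dist.is_initialized()
    assert torch.utils.data.get_worker_info() is None

    assert get_global_world_size() == 1  # 1 (dist) * 1 (mp)
    assert get_global_rank() == 0  # 0 * 1 + 0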