feat(python): account for both multi-process and distributed torch workers #2761

Merged · 9 commits · Aug 27, 2024
Changes from 8 commits
10 changes: 3 additions & 7 deletions python/python/lance/torch/data.py
@@ -26,6 +26,7 @@
     ShardedFragmentSampler,
     maybe_sample,
 )
+from .dist import get_global_rank, get_global_world_size
 
 __all__ = ["LanceDataset"]
 
@@ -254,13 +255,8 @@ def __iter__(self):
             rank = self.rank
             world_size = self.world_size
         else:
-            worker_info = torch.utils.data.get_worker_info()
-            if worker_info is not None:
-                rank = worker_info.id
-                world_size = worker_info.num_workers
-            else:
-                rank = None
-                world_size = None
+            rank = get_global_rank()
+            world_size = get_global_world_size()
         if self.shard_granularity is None:
             if rank is not None and world_size is not None:
                 sampler = ShardedFragmentSampler(rank=rank, world_size=world_size)
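For context, a minimal sketch of how the new fallback behaves end to end. It assumes a torchrun launch with two processes and a four-worker DataLoader; the LanceDataset constructor arguments and dataset path are illustrative and may differ between versions:

import torch.distributed as dist
from torch.utils.data import DataLoader

from lance.torch.data import LanceDataset

# Assumed launch: `torchrun --nproc-per-node=2 train.py`, which sets the
# env vars init_process_group needs.
dist.init_process_group(backend="gloo")

# rank/world_size are not passed, so each DataLoader worker falls back to
# get_global_rank()/get_global_world_size(): 2 processes x 4 workers gives
# 8 disjoint shards, one per (process, worker) pair.
dataset = LanceDataset("my_dataset.lance", batch_size=64)  # path is hypothetical
loader = DataLoader(dataset, num_workers=4, batch_size=None)

for batch in loader:
    ...  # each worker iterates only over its own shard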
79 changes: 79 additions & 0 deletions python/python/lance/torch/dist.py
@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors

""" Pytorch Distributed Utilities """

import torch.distributed as dist
import torch.utils.data


def get_dist_world_size() -> int:
"""
Get the number of processes in the distributed training setup.

Returns:
int: The number of distributed processes if distributed training is initialized,
otherwise 1.
"""
if dist.is_initialized():
return dist.get_world_size()
return 1


def get_dist_rank() -> int:
"""
Get the rank of the current process in the distributed training setup.

Returns:
int: The rank of the current process if distributed training is initialized,
otherwise 0.
"""
if dist.is_initialized():
return int(dist.get_rank())
return 0


def get_mp_world_size() -> int:
"""
Get the number of worker processes for the current DataLoader.

Returns:
int: The number of worker processes if running in a DataLoader worker,
otherwise 1.
"""
if (worker_info := torch.utils.data.get_worker_info()) is not None:
return worker_info.num_workers
return 1


def get_mp_rank() -> int:
"""
Get the rank of the current DataLoader worker process.

Returns:
int: The rank of the current DataLoader worker if running in a worker process,
otherwise 0.
"""
if (worker_info := torch.utils.data.get_worker_info()) is not None:
return worker_info.id
return 0


def get_global_world_size() -> int:
"""
Get the global world size across distributed and multiprocessing contexts.

Returns:
        int: The product of the distributed world size and the number of
            DataLoader worker processes.
"""
return get_dist_world_size() * get_mp_world_size()


def get_global_rank() -> int:
"""
Get the global rank of the current process across distributed and multiprocessing contexts.

Returns:
        int: The global rank of the current process, computed as
            dist_rank * mp_world_size + mp_rank, so every
            (process, DataLoader worker) pair receives a unique rank.
"""
return get_dist_rank() * get_mp_world_size() + get_mp_rank()
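As a quick sanity check of the fallback path, a sketch of what these helpers return in a plain single-process script (no torchrun, no DataLoader workers):

from lance.torch.dist import get_global_rank, get_global_world_size

# With no process group initialized and no DataLoader worker context,
# both helpers fall back to single-process defaults.
assert get_global_world_size() == 1
assert get_global_rank() == 0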