
Commit 3637046
feat(python): account for both multi-process and distributed torch workers (lancedb#2761)
tonyf authored and gagan-bhullar-tech committed Sep 13, 2024
1 parent 637128c commit 3637046
Showing 2 changed files with 83 additions and 7 deletions.
10 changes: 3 additions & 7 deletions python/python/lance/torch/data.py
@@ -26,6 +26,7 @@
     ShardedFragmentSampler,
     maybe_sample,
 )
+from .dist import get_global_rank, get_global_world_size

 __all__ = ["LanceDataset"]

@@ -254,13 +255,8 @@ def __iter__(self):
             rank = self.rank
             world_size = self.world_size
         else:
-            worker_info = torch.utils.data.get_worker_info()
-            if worker_info is not None:
-                rank = worker_info.id
-                world_size = worker_info.num_workers
-            else:
-                rank = None
-                world_size = None
+            rank = get_global_rank()
+            world_size = get_global_world_size()
         if self.shard_granularity is None:
             if rank is not None and world_size is not None:
                 sampler = ShardedFragmentSampler(rank=rank, world_size=world_size)
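With this change, when rank and world_size are not passed to LanceDataset explicitly, __iter__ derives them from both the torch.distributed process group and the DataLoader worker context, so every global worker samples a disjoint set of fragments. A minimal usage sketch, assuming an already-initialized process group; the dataset path and constructor arguments below are illustrative, not the project's documented example:

# Hedged sketch: arguments other than the dataset path are assumptions.
from torch.utils.data import DataLoader

from lance.torch.data import LanceDataset

ds = LanceDataset("/data/table.lance", batch_size=32)  # illustrative path

# LanceDataset is an IterableDataset that yields ready-made batches, so the
# DataLoader does no extra batching. With a distributed world size of W and
# num_workers=N, the W * N global workers now shard the fragments between them.
loader = DataLoader(ds, batch_size=None, num_workers=4)
for batch in loader:
    ...  # training step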
80 changes: 80 additions & 0 deletions python/python/lance/torch/dist.py
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors

"""Pytorch Distributed Utilities"""

import torch.distributed as dist
import torch.utils.data


def get_dist_world_size() -> int:
    """
    Get the number of processes in the distributed training setup.

    Returns:
        int: The number of distributed processes if distributed training is initialized,
            otherwise 1.
    """
    if dist.is_initialized():
        return dist.get_world_size()
    return 1


def get_dist_rank() -> int:
    """
    Get the rank of the current process in the distributed training setup.

    Returns:
        int: The rank of the current process if distributed training is initialized,
            otherwise 0.
    """
    if dist.is_initialized():
        return int(dist.get_rank())
    return 0


def get_mp_world_size() -> int:
    """
    Get the number of worker processes for the current DataLoader.

    Returns:
        int: The number of worker processes if running in a DataLoader worker,
            otherwise 1.
    """
    if (worker_info := torch.utils.data.get_worker_info()) is not None:
        return worker_info.num_workers
    return 1


def get_mp_rank() -> int:
    """
    Get the rank of the current DataLoader worker process.

    Returns:
        int: The rank of the current DataLoader worker if running in a worker process,
            otherwise 0.
    """
    if (worker_info := torch.utils.data.get_worker_info()) is not None:
        return worker_info.id
    return 0


def get_global_world_size() -> int:
    """
    Get the global world size across distributed and multiprocessing contexts.

    Returns:
        int: The product of the distributed world size and the number of
            DataLoader workers, each defaulting to 1 outside its context.
    """
    return get_dist_world_size() * get_mp_world_size()


def get_global_rank() -> int:
    """
    Get the global rank of the current process across distributed and
    multiprocessing contexts.

    Returns:
        int: The global rank of the current process.
    """
    return get_dist_rank() * get_mp_world_size() + get_mp_rank()
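get_global_rank() composes the two contexts: the distributed rank selects a block of worker slots and the DataLoader worker id selects a slot within that block, so every (process, worker) pair maps to a unique rank in [0, get_global_world_size()). A quick worked example of the arithmetic with 2 distributed processes and 3 DataLoader workers each (numbers chosen for illustration):

# global_rank = dist_rank * mp_world_size + mp_rank
dist_world_size, mp_world_size = 2, 3
for dist_rank in range(dist_world_size):
    for mp_rank in range(mp_world_size):
        print(dist_rank, mp_rank, dist_rank * mp_world_size + mp_rank)
# (0, 0) -> 0, (0, 1) -> 1, (0, 2) -> 2, (1, 0) -> 3, (1, 1) -> 4, (1, 2) -> 5
# 2 * 3 = 6 distinct global ranks, matching get_global_world_size().

The multiprocessing half can also be observed directly: inside a DataLoader worker process, torch.utils.data.get_worker_info() is non-None, so get_mp_rank() and get_mp_world_size() report that worker's position. A self-contained probe (WorkerProbe is a made-up class for illustration, not part of the commit):

from torch.utils.data import DataLoader, IterableDataset

from lance.torch.dist import get_mp_rank, get_mp_world_size


class WorkerProbe(IterableDataset):
    """Yield the (worker rank, worker count) seen inside each DataLoader worker."""

    def __iter__(self):
        yield get_mp_rank(), get_mp_world_size()


if __name__ == "__main__":
    # batch_size=None disables collation, so each worker's tuple passes through as-is.
    for rank, world_size in DataLoader(WorkerProbe(), num_workers=2, batch_size=None):
        print(rank, world_size)  # one line per worker: "0 2" and "1 2"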
