Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix HSDP with ShardDegree < 8 #3313

Merged
merged 8 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
if self.device_mesh is not None and self.device_mesh.ndim == 2:
# Broadcast file to all replicas
replicate_process_group = self.device_mesh.get_group(0)
shard_process_group = self.device_mesh.get_group(1)
shard_size = self.device_mesh.size(1)
rank_in_first_replica = dist.get_global_rank() % shard_size
sender = dist.get_global_rank() == rank_in_first_replica
Expand All @@ -304,7 +305,9 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):

# Send each file to the appropriate rank
for file_name in file_list:
if dist.get_local_rank() == 0: # Only 1 rank per node needs to transfer file
if dist.get_local_rank() == 0 or (
dist.get_global_rank(shard_process_group) == 0 # pyright: ignore[reportGeneralTypeIssues]
): # Only 1 rank per node needs to transfer file
full_path = os.path.join(self.destination_path, file_name)
log.debug(f'Transferring {full_path=}')
file_object = [None]
Expand All @@ -318,7 +321,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
)
received_file_object = file_object[0]
assert received_file_object is not None
if receiver and not os.path.exists(full_path):
if receiver and not os.path.exists(full_path) and dist.get_local_rank() == 0:
mvpatel2000 marked this conversation as resolved.
Show resolved Hide resolved
with open(full_path, 'wb') as f:
f.write(received_file_object['content'])

Expand Down
18 changes: 14 additions & 4 deletions composer/utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,23 @@ def get_world_size() -> int:
)


def get_global_rank() -> int:
"""Returns the global rank of the current process, which is on ``[0; WORLD_SIZE - 1]``.
def get_global_rank(group: Optional[dist.ProcessGroup] = None) -> int:
"""Returns the global rank of the current process in the input PG, which is on ``[0; group.WORLD_SIZE - 1]``.

Args:
group (ProcessGroup, optional): The process group. If ``None``, it will return env_var ``RANK``

Returns:
int: The global rank.
int: The global rank in input process group.
"""
return _get_distributed_config_var(env_var='RANK', human_name='global rank', default=0, fetch_fn_name='get_rank')
if group is None:
return _get_distributed_config_var(
env_var='RANK',
human_name='global rank',
default=0,
fetch_fn_name='get_rank',
)
return dist.get_rank(group)


def get_local_world_size() -> int:
Expand Down
Loading