[AutoParallel] shard_dataloader support list inputs (#62229)
* [AutoParallel] shard_dataloader support list inputs

* add an example

* fix doc example error

* add doc

* fix

* fix

* fix doc
deepllz authored Mar 1, 2024
1 parent 7ea78b6 commit e5404f0
Showing 5 changed files with 448 additions and 25 deletions.
195 changes: 170 additions & 25 deletions python/paddle/distributed/auto_parallel/api.py
@@ -2018,22 +2018,22 @@ def __init__(
process_id, self._meshes
)
)
if input_keys is not None:
assert len(input_keys) == 2, "input_keys lengths must be 2"

self._all_inputs_in_one_mesh = len(self._meshes) == 1
self._input_keys = input_keys
self._shard_dims = self._process_shard_dims(shard_dims)

mesh_index = self._get_mesh_idx(process_id)
if mesh_index == -1:
mesh, shard_dim = self._get_mesh_and_shard_dim(process_id)
if mesh is None:
mesh = to_list(self._meshes[0])[0]
shard_dim = to_list(self._shard_dims[0])[0]
dp_rank = 0
dp_world_size = self._meshes[0].get_dim_size(self._shard_dims[0])
dp_world_size = mesh.get_dim_size(shard_dim)
else:
dp_rank = self._meshes[mesh_index].get_rank_by_dim_and_process_id(
self._shard_dims[mesh_index], process_id
)
dp_world_size = self._meshes[mesh_index].get_dim_size(
self._shard_dims[mesh_index]
)
dp_rank = mesh.get_rank_by_dim_and_process_id(shard_dim, process_id)
dp_world_size = mesh.get_dim_size(shard_dim)

if is_dataset_splitted is True or shard_dims is None:
self._dataloader = dataloader
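
The hunk above replaces the index-based mesh lookup with a direct (mesh, shard_dim) lookup. A minimal sketch of how the data-parallel rank and world size fall out of the two ProcessMesh calls used in the diff; the mesh shape, dim names, and process id below are illustrative, not taken from the commit:

    import paddle.distributed as dist

    # Hypothetical 2-D mesh: 2-way "dp" x 2-way "mp"; the process id is likewise illustrative.
    mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])
    shard_dim = "dp"
    process_id = 2

    dp_world_size = mesh.get_dim_size(shard_dim)                          # 2 ranks along "dp"
    dp_rank = mesh.get_rank_by_dim_and_process_id(shard_dim, process_id)  # process 2 sits in dp row 1
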
@@ -2074,7 +2074,13 @@ def __init__(

def _process_shard_dims(self, shard_dims):
if isinstance(shard_dims, (int, str)) or shard_dims is None:
return [shard_dims] * len(self._meshes)
res = []
for i in range(len(self._meshes)):
if isinstance(self._meshes[i], (list, tuple)):
res.append([shard_dims] * len(self._meshes[i]))
else:
res.append(shard_dims)
return res
else:
if len(shard_dims) != len(self._meshes):
raise ValueError(
@@ -2084,16 +2090,30 @@ def _process_shard_dims(self, shard_dims):
)
return shard_dims

def _get_mesh_idx(self, process_id):
def _get_mesh_and_shard_dim(self, process_id):
for i in range(len(self._meshes)):
if process_id in self._meshes[i]._process_ids:
return i
return -1
if isinstance(self._meshes[i], (list, tuple)):
for j in range(len(self._meshes[i])):
if process_id in self._meshes[i][j]._process_ids:
return self._meshes[i][j], self._shard_dims[i][j]
else:
if process_id in self._meshes[i]._process_ids:
return self._meshes[i], self._shard_dims[i]
return None, None

def _process_id_in_multi_meshes(self, process_id):
count = 0
for i in range(len(self._meshes)):
if process_id in self._meshes[i]._process_ids:
flatten_meshes = []
for mesh in self._meshes:
if isinstance(mesh, (list, tuple)):
flatten_meshes.extend(mesh)
else:
flatten_meshes.append(mesh)

# NOTE(zhengzhonghui): User may set the same mesh for different inputs, so we need to unique the meshes
unique_meshes = list(set(flatten_meshes))
for mesh in unique_meshes:
if process_id in mesh._process_ids:
count += 1
return count > 1
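
The normalization above mirrors the nesting of `meshes`: a scalar `shard_dims` is broadcast into a nested list wherever a mesh entry is itself a list. A standalone sketch of that broadcast (the function name is ours, for illustration only; `mesh0`/`mesh1` in the comment are placeholders):

    def broadcast_shard_dims(meshes, shard_dims):
        # A scalar (or None) shard_dims is repeated to mirror the nesting of `meshes`.
        if isinstance(shard_dims, (int, str)) or shard_dims is None:
            return [
                [shard_dims] * len(m) if isinstance(m, (list, tuple)) else shard_dims
                for m in meshes
            ]
        # Otherwise there must already be one entry per mesh.
        if len(shard_dims) != len(meshes):
            raise ValueError("shard_dims must have one entry per mesh")
        return shard_dims

    # e.g. broadcast_shard_dims([[mesh0, mesh0], mesh1], "dp") -> [["dp", "dp"], "dp"]
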

@@ -2123,27 +2143,96 @@ def _get_mesh_and_placement(self, index):
placements.append(dist.Replicate())
return mesh, placements

def _get_meshes_and_placements_for_list_input(self, index, length):
if self._all_inputs_in_one_mesh:
meshes = [self._meshes[0]] * length
shard_dims = [self._shard_dims[0]] * length
else:
meshes = self._meshes[index]
if isinstance(meshes, (list, tuple)):
assert len(meshes) == length
else:
meshes = [meshes] * length
shard_dims = self._shard_dims[index]
if isinstance(shard_dims, (list, tuple)):
assert len(shard_dims) == length
else:
shard_dims = [shard_dims] * length

placements = []
for i in range(length):
if shard_dims[i] is not None:
placement = [dist.Shard(0)]
else:
placement = [dist.Replicate()]
for _ in range(1, len(meshes[i]._shape)):
placement.append(dist.Replicate())
placements.append(placement)
return meshes, placements

def _dtensors_from_list_input(self, list_tensors, meshes, placements):
dist_data = []
for j in range(len(list_tensors)):
dist_data.append(
dtensor_from_local(list_tensors[j], meshes[j], placements[j])
)
return dist_data

def _get_batch(self, batch_data):
if isinstance(batch_data, (list, tuple)):
if self._all_inputs_in_one_mesh is False:
assert len(batch_data) == len(self._meshes)
dist_batch_data = []
for i in range(len(batch_data)):
mesh, placements = self._get_mesh_and_placement(i)
dist_batch_data.append(
dtensor_from_local(batch_data[i], mesh, placements)
)
input_data = batch_data[i]
if isinstance(input_data, (list, tuple)):
(
meshes,
placements,
) = self._get_meshes_and_placements_for_list_input(
i, len(input_data)
)
dist_batch_data.append(
self._dtensors_from_list_input(
input_data, meshes, placements
)
)
elif isinstance(input_data, paddle.Tensor):
mesh, placements = self._get_mesh_and_placement(i)
dist_batch_data.append(
dtensor_from_local(input_data, mesh, placements)
)
else:
raise ValueError(
f"Unsupported input_data type {type(input_data)}"
)
return dist_batch_data
elif isinstance(batch_data, dict):
if self._all_inputs_in_one_mesh is False:
assert len(self._input_keys) == len(self._meshes)
dist_batch_data = {}
for i in range(len(self._input_keys)):
key = self._input_keys[i]
mesh, placements = self._get_mesh_and_placement(i)
dist_batch_data[key] = dtensor_from_local(
batch_data[key], mesh, placements
)
input_data = batch_data[key]
if isinstance(input_data, (list, tuple)):
(
meshes,
placements,
) = self._get_meshes_and_placements_for_list_input(
i, len(input_data)
)
dist_batch_data[key] = self._dtensors_from_list_input(
input_data, meshes, placements
)
elif isinstance(input_data, paddle.Tensor):
mesh, placements = self._get_mesh_and_placement(i)
dist_batch_data[key] = dtensor_from_local(
batch_data[key], mesh, placements
)
else:
raise ValueError(
f"Unsupported input_data type {type(input_data)}"
)
return dist_batch_data
else:
raise ValueError(f"Unsupported batch_data type {type(batch_data)}")
@@ -2173,7 +2262,9 @@ def shard_dataloader(
It will split the dataset only if is_dataset_splitted is False and shard_dims is not None.
Args:
dataloader (paddle.io.DataLoader): The dataloader to be sharded.
dataloader (paddle.io.DataLoader): The dataloader to be sharded. The output of the dataloader
must be a list or dict of paddle.Tensor with 2 elements, i.e. [input_data, label] or
{"input_data": input_data, "label": label}; input_data and label may themselves be lists to support multiple inputs.
meshes (ProcessMesh|list[ProcessMesh]|tuple[ProcessMesh]): The mesh list of the dataloader,
identifying which mesh each input is on. If len(meshes) == 1 or type(meshes) == ProcessMesh,
all the inputs are on the same mesh.
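
Concretely, the accepted shapes of `meshes` (a sketch; `mesh0`/`mesh1` are placeholder ProcessMesh objects, and the nested form follows the "[[mesh0, mesh0], mesh1]" comment in example-2 below):

    meshes = mesh0                    # a single ProcessMesh: every input lives on it
    meshes = [mesh0, mesh1]           # one mesh per input
    meshes = [[mesh0, mesh0], mesh1]  # input 0 is itself a list; each of its elements gets its own mesh
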
@@ -2191,6 +2282,7 @@
Examples:
.. code-block:: python
:name: example-1
>>> import paddle
>>> import paddle.distributed as dist
@@ -2286,6 +2378,59 @@
>>> # RUN_STATIC=1 python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" {test_case}.py
>>> # RUN_STATIC=0 python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" {test_case}.py
.. code-block:: python
:name: example-2
>>> import paddle
>>> import paddle.distributed as dist
>>> from paddle.io import BatchSampler, DataLoader, Dataset
>>> import numpy as np
>>> mesh0 = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['dp', 'mp'])
>>> mesh1 = dist.ProcessMesh([[4, 5], [6, 7]], dim_names=['dp', 'mp'])
>>> class RandomDataset(Dataset):
... def __init__(self, seq_len, hidden, num_samples=8):
... super().__init__()
... self.seq_len = seq_len
... self.hidden = hidden
... self.num_samples = num_samples
... self.inputs1 = [
... np.random.uniform(size=[self.seq_len, self.hidden]).astype(
... "float32"
... )
... for _ in range(num_samples)
... ]
... self.inputs2 = [
... np.random.uniform(size=[self.seq_len, self.hidden]).astype(
... "float32"
... )
... for _ in range(num_samples)
... ]
... self.labels = [
... np.array(index, dtype="float32") for index in range(num_samples)
... ]
... def __getitem__(self, index):
... return {
... "inputs": [self.inputs1[index], self.inputs2[index]],
... "label": self.labels[index],
... }
... def __len__(self):
... return self.num_samples
>>> dataset = RandomDataset(4, 8)
>>> sampler = BatchSampler(
... dataset,
... batch_size=2,
... )
>>> dataloader = DataLoader(
... dataset,
... batch_sampler=sampler,
... )
>>> dist_dataloader = dist.shard_dataloader(
... dataloader=dataloader,
... meshes=[mesh0, mesh1], # or [[mesh0, mesh0], mesh1]
... shard_dims="dp",
... input_keys=["inputs", "label"],
... )
"""

return ShardDataloader(
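
To tie example-2 together: once the dataloader is wrapped, each batch is a dict whose "inputs" entry is a list of two distributed tensors and whose "label" entry is a single distributed tensor. A hedged sketch of consuming it; the iteration style is assumed to follow example-1 (collapsed above), and the loop body is illustrative only:

    for step, batch in enumerate(dist_dataloader()):
        input1, input2 = batch["inputs"]   # list input -> one distributed tensor per element
        label = batch["label"]
        # ... feed input1/input2 to the model and compute the loss against label ...
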
8 changes: 8 additions & 0 deletions test/auto_parallel/hybrid_strategy/CMakeLists.txt
@@ -73,3 +73,11 @@ if((WITH_GPU) AND (LINUX))
set_tests_properties(test_semi_auto_parallel_global_input
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID")
endif()
if((WITH_GPU) AND (LINUX))
py_test_modules(
test_semi_auto_parallel_multi_inputs MODULES
test_semi_auto_parallel_multi_inputs ENVS
"http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_semi_auto_parallel_multi_inputs
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID")
endif()
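
Assuming py_test_modules registers the module under the same name, the new case can presumably be run on its own from the build directory with `ctest -R test_semi_auto_parallel_multi_inputs`.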
