diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py index 28f15011190f2..c63f8ce3a58c9 100644 --- a/python/paddle/distributed/auto_parallel/api.py +++ b/python/paddle/distributed/auto_parallel/api.py @@ -2018,22 +2018,22 @@ def __init__( process_id, self._meshes ) ) + if input_keys is not None: + assert len(input_keys) == 2, "input_keys lengths must be 2" self._all_inputs_in_one_mesh = len(self._meshes) == 1 self._input_keys = input_keys self._shard_dims = self._process_shard_dims(shard_dims) - mesh_index = self._get_mesh_idx(process_id) - if mesh_index == -1: + mesh, shard_dim = self._get_mesh_and_shard_dim(process_id) + if mesh is None: + mesh = to_list(self._meshes[0])[0] + shard_dim = to_list(self._shard_dims[0])[0] dp_rank = 0 - dp_world_size = self._meshes[0].get_dim_size(self._shard_dims[0]) + dp_world_size = mesh.get_dim_size(shard_dim) else: - dp_rank = self._meshes[mesh_index].get_rank_by_dim_and_process_id( - self._shard_dims[mesh_index], process_id - ) - dp_world_size = self._meshes[mesh_index].get_dim_size( - self._shard_dims[mesh_index] - ) + dp_rank = mesh.get_rank_by_dim_and_process_id(shard_dim, process_id) + dp_world_size = mesh.get_dim_size(shard_dim) if is_dataset_splitted is True or shard_dims is None: self._dataloader = dataloader @@ -2074,7 +2074,13 @@ def __init__( def _process_shard_dims(self, shard_dims): if isinstance(shard_dims, (int, str)) or shard_dims is None: - return [shard_dims] * len(self._meshes) + res = [] + for i in range(len(self._meshes)): + if isinstance(self._meshes[i], (list, tuple)): + res.append([shard_dims] * len(self._meshes[i])) + else: + res.append(shard_dims) + return res else: if len(shard_dims) != len(self._meshes): raise ValueError( @@ -2084,16 +2090,30 @@ def _process_shard_dims(self, shard_dims): ) return shard_dims - def _get_mesh_idx(self, process_id): + def _get_mesh_and_shard_dim(self, process_id): for i in range(len(self._meshes)): - if process_id in self._meshes[i]._process_ids: - return i - return -1 + if isinstance(self._meshes[i], (list, tuple)): + for j in range(len(self._meshes[i])): + if process_id in self._meshes[i][j]._process_ids: + return self._meshes[i][j], self._shard_dims[i][j] + else: + if process_id in self._meshes[i]._process_ids: + return self._meshes[i], self._shard_dims[i] + return None, None def _process_id_in_multi_meshes(self, process_id): count = 0 - for i in range(len(self._meshes)): - if process_id in self._meshes[i]._process_ids: + flatten_meshes = [] + for mesh in self._meshes: + if isinstance(mesh, (list, tuple)): + flatten_meshes.extend(mesh) + else: + flatten_meshes.append(mesh) + + # NOTE(zhengzhonghui): User may set the same mesh for different inputs, so we need to unique the meshes + unique_meshes = list(set(flatten_meshes)) + for mesh in unique_meshes: + if process_id in mesh._process_ids: count += 1 return count > 1 @@ -2123,16 +2143,69 @@ def _get_mesh_and_placement(self, index): placements.append(dist.Replicate()) return mesh, placements + def _get_meshes_and_placements_for_list_input(self, index, length): + if self._all_inputs_in_one_mesh: + meshes = [self._meshes[0]] * length + shard_dims = [self._shard_dims[0]] * length + else: + meshes = self._meshes[index] + if isinstance(meshes, (list, tuple)): + assert len(meshes) == length + else: + meshes = [meshes] * length + shard_dims = self._shard_dims[index] + if isinstance(shard_dims, (list, tuple)): + assert len(shard_dims) == length + else: + shard_dims = [shard_dims] * length + + 
placements = [] + for i in range(length): + if shard_dims[i] is not None: + placement = [dist.Shard(0)] + else: + placement = [dist.Replicate()] + for _ in range(1, len(meshes[i]._shape)): + placement.append(dist.Replicate()) + placements.append(placement) + return meshes, placements + + def _dtensors_from_list_input(self, list_tensors, meshes, placements): + dist_data = [] + for j in range(len(list_tensors)): + dist_data.append( + dtensor_from_local(list_tensors[j], meshes[j], placements[j]) + ) + return dist_data + def _get_batch(self, batch_data): if isinstance(batch_data, (list, tuple)): if self._all_inputs_in_one_mesh is False: assert len(batch_data) == len(self._meshes) dist_batch_data = [] for i in range(len(batch_data)): - mesh, placements = self._get_mesh_and_placement(i) - dist_batch_data.append( - dtensor_from_local(batch_data[i], mesh, placements) - ) + input_data = batch_data[i] + if isinstance(input_data, (list, tuple)): + ( + meshes, + placements, + ) = self._get_meshes_and_placements_for_list_input( + i, len(input_data) + ) + dist_batch_data.append( + self._dtensors_from_list_input( + input_data, meshes, placements + ) + ) + elif isinstance(input_data, paddle.Tensor): + mesh, placements = self._get_mesh_and_placement(i) + dist_batch_data.append( + dtensor_from_local(input_data, mesh, placements) + ) + else: + raise ValueError( + f"Unsupported input_data type {type(input_data)}" + ) return dist_batch_data elif isinstance(batch_data, dict): if self._all_inputs_in_one_mesh is False: @@ -2140,10 +2213,26 @@ def _get_batch(self, batch_data): dist_batch_data = {} for i in range(len(self._input_keys)): key = self._input_keys[i] - mesh, placements = self._get_mesh_and_placement(i) - dist_batch_data[key] = dtensor_from_local( - batch_data[key], mesh, placements - ) + input_data = batch_data[key] + if isinstance(input_data, (list, tuple)): + ( + meshes, + placements, + ) = self._get_meshes_and_placements_for_list_input( + i, len(input_data) + ) + dist_batch_data[key] = self._dtensors_from_list_input( + input_data, meshes, placements + ) + elif isinstance(input_data, paddle.Tensor): + mesh, placements = self._get_mesh_and_placement(i) + dist_batch_data[key] = dtensor_from_local( + batch_data[key], mesh, placements + ) + else: + raise ValueError( + f"Unsupported input_data type {type(input_data)}" + ) return dist_batch_data else: raise ValueError(f"Unsupported batch_data type {type(batch_data)}") @@ -2173,7 +2262,9 @@ def shard_dataloader( only if is_dataset_splitted is False and shard_dims is not None, it will do split. Args: - dataloader (paddle.io.DataLoader): The dataloader to be sharded. + dataloader (paddle.io.DataLoader): The dataloader to be sharded. the output of dataloader + must be a list or dict of paddle.Tensor with 2 elements, i.e. [input_data, label] or + {"input_data": input_data, "label": label}, input_data and label can be a list to support multiple inputs. meshes (ProcessMesh|list[ProcessMesh]|tuple[ProcessMesh]): The mesh list of the dataloader. Identify which mesh the input is on. if len(meshes) == 1 or type(meshes) == ProcessMesh, all the inputs are on the same mesh. @@ -2191,6 +2282,7 @@ def shard_dataloader( Examples: .. code-block:: python + :name: example-1 >>> import paddle >>> import paddle.distributed as dist @@ -2286,6 +2378,59 @@ def shard_dataloader( >>> # RUN_STATIC=1 python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" {test_case}.py >>> # RUN_STATIC=0 python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" {test_case}.py + .. 
code-block:: python + :name: example-2 + + >>> import paddle + >>> import paddle.distributed as dist + >>> from paddle.io import BatchSampler, DataLoader, Dataset + >>> import numpy as np + >>> mesh0 = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['dp', 'mp']) + >>> mesh1 = dist.ProcessMesh([[4, 5], [6, 7]], dim_names=['dp', 'mp']) + >>> class RandomDataset(Dataset): + ... def __init__(self, seq_len, hidden, num_samples=8): + ... super().__init__() + ... self.seq_len = seq_len + ... self.hidden = hidden + ... self.num_samples = num_samples + ... self.inputs1 = [ + ... np.random.uniform(size=[self.seq_len, self.hidden]).astype( + ... "float32" + ... ) + ... for _ in range(num_samples) + ... ] + ... self.inputs2 = [ + ... np.random.uniform(size=[self.seq_len, self.hidden]).astype( + ... "float32" + ... ) + ... for _ in range(num_samples) + ... ] + ... self.labels = [ + ... np.array(index, dtype="float32") for index in range(num_samples) + ... ] + ... def __getitem__(self, index): + ... return { + ... "inputs": [self.inputs1[index], self.inputs2[index]], + ... "label": self.labels[index], + ... } + ... def __len__(self): + ... return self.num_samples + + >>> dataset = RandomDataset(4, 8) + >>> sampler = BatchSampler( + ... dataset, + ... batch_size=2, + ... ) + >>> dataloader = DataLoader( + ... dataset, + ... batch_sampler=sampler, + ... ) + >>> dist_dataloader = dist.shard_dataloader( + ... dataloader=dataloader, + ... meshes=[mesh0, mesh1], # or [[mesh0, mesh0], mesh1] + ... shard_dims="dp", + ... input_keys=["inputs", "label"], + ... ) """ return ShardDataloader( diff --git a/test/auto_parallel/hybrid_strategy/CMakeLists.txt b/test/auto_parallel/hybrid_strategy/CMakeLists.txt index 08a9f42c02a1f..063b1b5873e74 100644 --- a/test/auto_parallel/hybrid_strategy/CMakeLists.txt +++ b/test/auto_parallel/hybrid_strategy/CMakeLists.txt @@ -73,3 +73,11 @@ if((WITH_GPU) AND (LINUX)) set_tests_properties(test_semi_auto_parallel_global_input PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID") endif() +if((WITH_GPU) AND (LINUX)) + py_test_modules( + test_semi_auto_parallel_multi_inputs MODULES + test_semi_auto_parallel_multi_inputs ENVS + "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") + set_tests_properties(test_semi_auto_parallel_multi_inputs + PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID") +endif() diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_multi_inputs.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_multi_inputs.py new file mode 100644 index 0000000000000..a7166ca901d09 --- /dev/null +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_multi_inputs.py @@ -0,0 +1,212 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import numpy as np + +import paddle +import paddle.distributed as dist +from paddle.io import BatchSampler, DataLoader, Dataset + +SEQ_LEN = 4 +HIDDLE_SIZE = 8 +global_mesh = dist.ProcessMesh( + [[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dim_names=['pp', 'dp', 'mp'] +) +mesh0 = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=['dp', 'mp']) +mesh1 = dist.ProcessMesh([[4, 5], [6, 7]], dim_names=['dp', 'mp']) + + +class MlpModel(paddle.nn.Layer): + def __init__(self, variable_initial_values, run_single_process=False): + super().__init__() + self.w0 = self.create_parameter( + shape=[HIDDLE_SIZE, HIDDLE_SIZE], + default_initializer=paddle.nn.initializer.Assign( + variable_initial_values[0] + ), + ) + self.w1 = self.create_parameter( + shape=[HIDDLE_SIZE, HIDDLE_SIZE], + default_initializer=paddle.nn.initializer.Assign( + variable_initial_values[1] + ), + ) + if run_single_process is False: + self.w0 = dist.shard_tensor( + self.w0, + mesh0, + [dist.Replicate(), dist.Shard(1)], + ) + self.w1 = dist.shard_tensor( + self.w1, + mesh1, + [dist.Replicate(), dist.Shard(0)], + ) + self.run_single_process = run_single_process + + def forward(self, input1, input2): + x = input1 + input2 + # x: [bs, seq_len, hidden] + # forward on mesh0 + y = paddle.matmul(x, self.w0) + # forward on mesh1 + if self.run_single_process is False: + y = dist.reshard(y, mesh1, [dist.Shard(0), dist.Shard(2)]) + z = paddle.matmul(y, self.w1) + return z + + +class RandomDataset(Dataset): + def __init__(self, seq_len, hidden, num_samples=8): + super().__init__() + self.seq_len = seq_len + self.hidden = hidden + self.num_samples = num_samples + self.inputs1 = [ + np.random.uniform(size=[self.seq_len, self.hidden]).astype( + "float32" + ) + for _ in range(num_samples) + ] + self.inputs2 = [ + np.random.uniform(size=[self.seq_len, self.hidden]).astype( + "float32" + ) + for _ in range(num_samples) + ] + self.labels = [ + np.array(index, dtype="float32") for index in range(num_samples) + ] + + def __getitem__(self, index): + return { + "inputs": [self.inputs1[index], self.inputs2[index]], + "label": self.labels[index], + } + + def __len__(self): + return self.num_samples + + +def create_dataloader(): + dataset = RandomDataset(SEQ_LEN, HIDDLE_SIZE) + sampler = BatchSampler( + dataset, + batch_size=2, + ) + dataloader = DataLoader( + dataset, + batch_sampler=sampler, + ) + return dataloader + + +def get_variable_initial_value(var_num=2): + res = [] + for i in range(var_num): + res.append( + paddle.uniform( + shape=[HIDDLE_SIZE, HIDDLE_SIZE], + dtype=paddle.float32, + min=-0.0001, + max=0.0001, + ) + ) + return res + + +def loss_fn(logits, label): + # logits: [bs, seq_len, hidden], label: [bs] + loss = paddle.nn.MSELoss(reduction="sum") + logits = paddle.sum(logits, axis=[1, 2]) + return loss(logits, label) + + +class TestSemiAutoParallelMultiInputs: + def __init__(self): + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._run_static = eval(os.getenv("run_static")) + paddle.seed(self._seed) + np.random.seed(self._seed) + paddle.set_device(self._backend) + self.dataloader = create_dataloader() + self.variable_initial_values = get_variable_initial_value() + self.single_process_loss = self.get_single_process_loss() + + def get_single_process_loss(self): + model = MlpModel( + variable_initial_values=self.variable_initial_values, + run_single_process=True, + ) + opt = paddle.optimizer.AdamW( + learning_rate=0.001, parameters=model.parameters() + ) + for step, data in enumerate(self.dataloader()): + 
input1, input2 = data["inputs"] + logits = model(input1, input2) + label = data["label"] + loss = loss_fn(logits, label) + loss.backward() + opt.step() + opt.clear_grad() + return loss.numpy() + + def test_basic(self): + model = MlpModel(variable_initial_values=self.variable_initial_values) + opt = paddle.optimizer.AdamW( + learning_rate=0.001, parameters=model.parameters() + ) + dist_dataloader = dist.shard_dataloader( + dataloader=self.dataloader, + meshes=[mesh0, mesh1], # or [[mesh0, mesh0], mesh1] + shard_dims="dp", + input_keys=["inputs", "label"], + ) + cur_rank = paddle.distributed.get_rank() + if self._run_static: + dist_model = dist.to_static(model, dist_dataloader, loss_fn, opt) + + for step, data in enumerate(dist_dataloader()): + input1, input2 = data["inputs"] + label = data["label"] + loss = dist_model(input1, input2, label) + + if cur_rank in [5, 7]: + loss = paddle.to_tensor(loss) + group = paddle.distributed.new_group([5, 7]) + dist.all_reduce(loss, group=group) + else: + dist_opt = dist.shard_optimizer(opt) + for step, data in enumerate(dist_dataloader()): + input1, input2 = data["inputs"] + logits = model(input1, input2) + label = data["label"] + loss = loss_fn(logits, label) + loss.backward() + dist_opt.step() + dist_opt.clear_grad() + if cur_rank in [5, 7]: + np.testing.assert_allclose( + loss.numpy(), self.single_process_loss, rtol=1e-06, verbose=True + ) + + def run_test_case(self): + self.test_basic() + + +if __name__ == '__main__': + TestSemiAutoParallelMultiInputs().run_test_case() diff --git a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_multi_inputs.py b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_multi_inputs.py new file mode 100644 index 0000000000000..e172ba1da70f5 --- /dev/null +++ b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_multi_inputs.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import collective.test_communication_api_base as test_base + + +class TestSemiAutoParallelMultiInputs(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp( + num_of_devices=8, + timeout=120, + nnode=1, + ) + self._default_envs = { + "dtype": "float32", + "seed": "1024", + } + self._changeable_envs = {"backend": ["gpu"]} + + def test_dynamic(self): + self._default_envs.update({"run_static": "0"}) + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_multi_inputs.py", + user_defined_envs=envs, + ) + + def test_static(self): + self._default_envs.update({"run_static": "1"}) + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_multi_inputs.py", + user_defined_envs=envs, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/hybrid_strategy/testslist.csv b/test/auto_parallel/hybrid_strategy/testslist.csv index 5791b71d0d5ff..2fac60515b51a 100644 --- a/test/auto_parallel/hybrid_strategy/testslist.csv +++ b/test/auto_parallel/hybrid_strategy/testslist.csv @@ -8,3 +8,4 @@ test_semi_auto_parallel_llama_model_amp,LINUX,GPU,180,HYBRID,test_runner.py,,,ht test_semi_auto_parallel_hybrid_sharding_strategy,LINUX,GPU,120,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_global_mesh_reshard,LINUX,GPU,120,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_semi_auto_parallel_global_input,LINUX,GPU,120,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_semi_auto_parallel_multi_inputs,LINUX,GPU,120,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
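
Reviewer note (not part of the patch): the core of this change is that ShardDataloader now accepts a nested mesh list such as [[mesh0, mesh0], mesh1], so each element of a list-valued input can carry its own mesh and shard dim. The sketch below reproduces, in plain Python, the normalization performed by the new _process_shard_dims and _get_mesh_and_shard_dim; FakeMesh and the literal process ids are illustrative stand-ins for dist.ProcessMesh, used only so the control flow can be exercised without launching eight ranks.

# Standalone sketch: FakeMesh mimics the only piece of dist.ProcessMesh
# these helpers touch, the _process_ids attribute.
class FakeMesh:
    def __init__(self, name, process_ids):
        self.name = name
        self._process_ids = process_ids

    def __repr__(self):
        return self.name


def process_shard_dims(meshes, shard_dims):
    # Mirrors _process_shard_dims: a scalar shard_dims is broadcast per mesh,
    # and per sub-mesh when one input owns a list of meshes.
    if isinstance(shard_dims, (int, str)) or shard_dims is None:
        res = []
        for mesh in meshes:
            if isinstance(mesh, (list, tuple)):
                res.append([shard_dims] * len(mesh))
            else:
                res.append(shard_dims)
        return res
    if len(shard_dims) != len(meshes):
        raise ValueError("shard_dims must have the same length as meshes")
    return shard_dims


def get_mesh_and_shard_dim(process_id, meshes, shard_dims):
    # Mirrors _get_mesh_and_shard_dim: walk the (possibly nested) mesh list and
    # return the (mesh, shard_dim) pair owning this process, else (None, None).
    for mesh, dim in zip(meshes, shard_dims):
        if isinstance(mesh, (list, tuple)):
            for sub_mesh, sub_dim in zip(mesh, dim):
                if process_id in sub_mesh._process_ids:
                    return sub_mesh, sub_dim
        elif process_id in mesh._process_ids:
            return mesh, dim
    return None, None


mesh0 = FakeMesh("mesh0", [0, 1, 2, 3])
mesh1 = FakeMesh("mesh1", [4, 5, 6, 7])
meshes = [[mesh0, mesh0], mesh1]  # "inputs" is a 2-element list, "label" a tensor

shard_dims = process_shard_dims(meshes, "dp")
print(shard_dims)                                     # [['dp', 'dp'], 'dp']
print(get_mesh_and_shard_dim(5, meshes, shard_dims))  # (mesh1, 'dp')
print(get_mesh_and_shard_dim(2, meshes, shard_dims))  # (mesh0, 'dp')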
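
The placement built for each element of a list input in _get_meshes_and_placements_for_list_input is: dist.Shard(0) on the first mesh axis when a shard dim is given for that input (so the batch axis of the local tensor is split), dist.Replicate() otherwise, and dist.Replicate() on every remaining mesh axis. A minimal sketch of that rule, with strings standing in for the dist.Shard / dist.Replicate placement objects and mesh_ndim standing in for len(mesh._shape):

def placements_for(mesh_ndim, shard_dim):
    # Matches _get_meshes_and_placements_for_list_input: only the first mesh
    # axis may shard the batch dimension; all other axes replicate.
    first = "Shard(0)" if shard_dim is not None else "Replicate()"
    return [first] + ["Replicate()"] * (mesh_ndim - 1)


# A 2-D ['dp', 'mp'] mesh, with shard_dims="dp" versus shard_dims=None:
print(placements_for(2, "dp"))   # ['Shard(0)', 'Replicate()']
print(placements_for(2, None))   # ['Replicate()', 'Replicate()']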