
Commit

[Zero-Dim] Support paddle.sum/mean/loss api output 0D,test=allcase (P…
zhwesky2010 committed May 8, 2023
1 parent 95d95fb commit 3be3b57
Showing 57 changed files with 380 additions and 258 deletions.
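As a quick orientation, here is a minimal sketch of the user-visible behavior this commit targets (assuming the zero-dimensional output behavior the commit title describes; this example is not part of the diff itself): a full reduction with paddle.sum or paddle.mean now yields a 0-D tensor of shape [] instead of a 1-D tensor of shape [1].

import paddle

x = paddle.rand([3, 4])

# Full reduction: the result is now a 0-D tensor (shape []), previously shape [1].
print(paddle.sum(x).shape)    # expected: []
print(paddle.mean(x).shape)   # expected: []

# keepdim=True still keeps one entry per reduced axis.
print(paddle.sum(x, keepdim=True).shape)   # [1, 1]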
Original file line number Diff line number Diff line change
@@ -318,7 +318,7 @@ void sum_grad(const Tensor& x,
if (!keepdim) {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 1; i < x_dim_size; i++) {
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
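The hunk above fixes the reduce-all branch of sum_grad: the loop previously started at i = 1 and silently dropped axis 0 from the rebuilt axis list. A rough Python analogue of the corrected loop (the helper name is illustrative, not a Paddle internal):

def reduce_all_axes(x_dim_size):
    # After the fix, every input axis participates in the reduce-all
    # gradient reshape, including axis 0.
    return [i for i in range(x_dim_size)]

assert reduce_all_axes(3) == [0, 1, 2]   # the old loop produced [1, 2]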
7 changes: 2 additions & 5 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
@@ -4005,9 +4005,6 @@ DDim OriginReduceInferDim(const MetaTensor& x,
out_dim_vector.push_back(x.dims().at(i));
}
}
if (x_rank > 0 && out_dim_vector.size() == 0) {
out_dim_vector.push_back(1);
}

DDim out_dim = phi::make_ddim(out_dim_vector);
return out_dim;
@@ -4024,14 +4021,14 @@ DDim OriginReduceInferDimForIntArrayAxis(const MetaTensor& x,
if (keep_dim) {
vec_dim = std::vector<int64_t>(x.dims().size(), 1);
} else {
vec_dim = {1};
vec_dim = {};
}
} else {
if (keep_dim) {
vec_dim = std::vector<int64_t>(x.dims().size(), -1);
} else {
auto x_rank = static_cast<size_t>(x.dims().size());
if (vec_axis.size() >= x_rank) {
if (vec_axis.size() > x_rank) {
vec_dim = {-1};
} else {
vec_dim = std::vector<int64_t>(x.dims().size() - vec_axis.size(), -1);
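With the infer-meta changes above (dropping the forced push_back(1) and using an empty vec_dim), reducing every axis without keepdim now infers an empty, i.e. 0-D, output shape. A small sketch of the shapes one would expect in dygraph mode under that rule:

import paddle

x = paddle.ones([2, 3, 4])

print(paddle.sum(x, axis=[0, 1, 2]).shape)                # expected: [] (was [1])
print(paddle.sum(x, axis=[1, 2]).shape)                   # [2] -- partial reductions unchanged
print(paddle.sum(x, axis=[0, 1, 2], keepdim=True).shape)  # [1, 1, 1]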
20 changes: 13 additions & 7 deletions python/paddle/distributed/auto_parallel/completion.py
Original file line number Diff line number Diff line change
@@ -1688,7 +1688,7 @@ def complete_update_annotation(self, serial_main_program):
world_ranks
)
out_dist_attr.dims_mapping = [
-1 for _ in range(len(out_var.shape))
-1 for _ in out_var.shape
]
self._dist_context.set_tensor_dist_attr_for_program(
out_var, out_dist_attr
@@ -1732,7 +1732,9 @@ def complete_update_annotation(self, serial_main_program):
len(out_var.shape) == 1
and out_var.shape[0] == 1
)
out_dist_attr.dims_mapping = [-1]
out_dist_attr.dims_mapping = [
-1 for _ in out_var.shape
]
self._dist_context.set_tensor_dist_attr_for_program(
out_var, out_dist_attr
)
@@ -1802,16 +1804,20 @@ def complete_update_annotation(self, serial_main_program):
param.name, ref_dims_mapping
)
learning_var = vars[op.input("LearningRate")[0]]
op_dist_attr.set_input_dims_mapping(learning_var.name, [-1])
op_dist_attr.set_input_dims_mapping(
learning_var.name, [-1 for _ in learning_var.shape]
)
op_dist_attr.set_output_dims_mapping(
learning_var.name, [-1]
learning_var.name, [-1 for _ in learning_var.shape]
)

if not learning_rate_completed:
learning_rate_completed = True
var_dist_attr = TensorDistAttr()
var_dist_attr.process_mesh = ProcessMesh(world_ranks)
var_dist_attr.dims_mapping = [-1]
var_dist_attr.dims_mapping = [
-1 for _ in learning_var.shape
]
self._dist_context.set_tensor_dist_attr_for_program(
learning_var, var_dist_attr
)
@@ -1841,10 +1847,10 @@ def complete_update_annotation(self, serial_main_program):
):
input_var_attr.dims_mapping = [-1]
op_dist_attr.set_input_dims_mapping(
input_var.name, [-1]
input_var.name, [-1 for _ in input_var.shape]
)
op_dist_attr.set_output_dims_mapping(
input_var.name, [-1]
input_var.name, [-1 for _ in input_var.shape]
)
else:
input_var_attr.dims_mapping = ref_dims_mapping
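The [-1 for _ in var.shape] pattern used throughout this file keeps a tensor's dims_mapping the same length as its rank, so a 0-D tensor such as the learning rate gets an empty mapping instead of a hard-coded [-1]. A hypothetical illustration with plain lists (not the real TensorDistAttr objects):

def replicated_dims_mapping(shape):
    # One -1 ("replicated") entry per tensor dimension; a 0-D tensor gets [].
    return [-1 for _ in shape]

assert replicated_dims_mapping([]) == []             # 0-D learning rate / loss
assert replicated_dims_mapping([1]) == [-1]          # old 1-D scalar convention
assert replicated_dims_mapping([8, 16]) == [-1, -1]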
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/engine.py
Original file line number Diff line number Diff line change
@@ -511,7 +511,7 @@ def _prepare_logger(
loss_indices = fetch_indices[group_idx]
assert len(loss_indices) <= 1
for idx in loss_indices:
logs["loss"] = outs[idx][0]
logs["loss"] = outs[idx]
group_idx += 1
# logging metrics
dist_context = self._dist_contexts[mode]
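Because the fetched loss is now a 0-D value rather than a shape-[1] array, the logger stores outs[idx] directly instead of indexing its first element. A minimal sketch of reading such a result, assuming the fetch path returns NumPy arrays as usual:

import numpy as np

loss = np.array(0.25)     # 0-D fetch result for the loss
print(loss.shape)         # ()
print(float(loss))        # 0.25 -- no trailing [0] index needed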
Original file line number Diff line number Diff line change
@@ -393,7 +393,7 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):

for var_name in act_grad_names:
var_dim_mapping = op_dist_attr.get_input_dims_mapping(var_name)
# consider that the variable's shape is None
# consider that the variable's shape is [], which is 0D
# TODO utilize the batch_dim attr instead of "0" in future
batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1

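A 0-D gradient has an empty dims_mapping, so reading var_dim_mapping[0] unconditionally would raise an IndexError; the guarded form used in this and the following operator files falls back to -1 (no batch axis). A minimal sketch of that guard:

def batch_axis(var_dim_mapping):
    # Axis 0 is treated as the batch axis when it exists; an empty
    # mapping (0-D tensor) reports -1, i.e. not sharded along a batch axis.
    return var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1

assert batch_axis([0, -1]) == 0
assert batch_axis([]) == -1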
Original file line number Diff line number Diff line change
@@ -159,7 +159,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
):
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = (
var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
)
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
break
Original file line number Diff line number Diff line change
@@ -101,7 +101,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
):
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = (
var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
)
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
break
Original file line number Diff line number Diff line change
@@ -252,7 +252,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
backward_op.input("Ids")[0]
)
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
18 changes: 9 additions & 9 deletions python/paddle/distributed/auto_parallel/operators/dist_matmul.py
Original file line number Diff line number Diff line change
@@ -651,7 +651,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
backward_op.input("X")[0]
)
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
@@ -1028,7 +1028,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
backward_op.input("X")[0]
)
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
@@ -1365,7 +1365,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
backward_op.input("X")[0]
)
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
@@ -1552,7 +1552,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
backward_op.input("X")[0]
)
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
@@ -1929,7 +1929,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
backward_op.input("X")[0]
)
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
@@ -2264,7 +2264,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
backward_op.input("X")[0]
)
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
@@ -2449,7 +2449,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
backward_op.input("X")[0]
)
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
@@ -2832,7 +2832,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
backward_op.input("X")[0]
)
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
@@ -3178,7 +3178,7 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
backward_op.input("X")[0]
)
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
if (
batch_size_axis > -1
and mesh_shape[batch_size_axis] > 1
Original file line number Diff line number Diff line change
@@ -120,7 +120,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)

mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = (
var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
)
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
@@ -377,7 +379,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)

mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = (
var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
)
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
@@ -637,7 +641,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)

mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = (
var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
)
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
Original file line number Diff line number Diff line change
@@ -100,7 +100,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
):
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = (
var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
)
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
break
Original file line number Diff line number Diff line change
@@ -94,7 +94,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)

mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = (
var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
)
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
Original file line number Diff line number Diff line change
@@ -183,7 +183,9 @@ def calc_bwd_cost(self, dist_op, ctx, cluster):
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)

mesh_shape = process_mesh.shape
batch_size_axis = var_dim_mapping[0]
batch_size_axis = (
var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
)
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
parallel_axis = batch_size_axis
attrs = {"use_calc_stream": True}
Original file line number Diff line number Diff line change
@@ -1727,7 +1727,9 @@ def _complete_sub_update_program(self, sub_program_dist_context):
len(out_var.shape) == 1
and out_var.shape[0] == 1
)
out_dist_attr.dims_mapping = [-1]
out_dist_attr.dims_mapping = [
-1 for _ in out_var.shape
]
sub_program_dist_context.set_tensor_dist_attr_for_program(
out_var, out_dist_attr
)
@@ -1798,17 +1800,19 @@ def _complete_sub_update_program(self, sub_program_dist_context):
)
learning_var = vars[op.input("LearningRate")[0]]
op_dist_attr.set_input_dims_mapping(
learning_var.name, [-1]
learning_var.name, [-1 for i in learning_var.shape]
)
op_dist_attr.set_output_dims_mapping(
learning_var.name, [-1]
learning_var.name, [-1 for i in learning_var.shape]
)

if not learning_rate_completed:
learning_rate_completed = True
var_dist_attr = TensorDistAttr()
var_dist_attr.process_mesh = world_ranks
var_dist_attr.dims_mapping = [-1]
var_dist_attr.dims_mapping = [
-1 for i in learning_var.shape
]
sub_program_dist_context.set_tensor_dist_attr_for_program(
learning_var, var_dist_attr
)
13 changes: 9 additions & 4 deletions python/paddle/distributed/auto_parallel/utils.py
Original file line number Diff line number Diff line change
@@ -1466,7 +1466,8 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
), "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part.".format(
op_desc.type(), idx, mapping
)
batch_dim_mappings.append(dims_mapping[0])
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
@@ -1480,7 +1481,8 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
), "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part.".format(
op_desc.type(), idx, mapping
)
batch_dim_mappings.append(dims_mapping[0])
if len(dims_mapping) >= 1:
batch_dim_mappings.append(dims_mapping[0])
else:
assert (
dims_mapping[0] == -1
@@ -1505,7 +1507,7 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if compatible_dim_mapping != dims_mapping[0]:
if len(dims_mapping) >= 1 and compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
for arg_name in op_desc.output_arg_names():
@@ -1514,7 +1516,10 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if compatible_dim_mapping != dims_mapping[0]:
if (
len(dims_mapping) >= 1
and compatible_dim_mapping != dims_mapping[0]
):
dims_mapping[0] = compatible_dim_mapping
changed = True
else: