[Cherry-pick] fix set_value with scalar grad #60930

Merged · 5 commits · Jan 19, 2024
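This cherry-pick makes gradients correct when a tensor slice is assigned a Python scalar. As the set_value_op.cc hunk below shows, the old grad maker fell back to a plain assign op whenever there was no ValueTensor input (i.e. the assigned value was a scalar), so the gradient flowing into the overwritten slice was never zeroed out. A minimal repro sketch of the scenario, assuming the public Python API (shapes and values are illustrative, not taken from the PR):

    import paddle

    x = paddle.ones([3, 3])
    x.stop_gradient = False
    y = x * 2                # keep x a leaf so x.grad is populated
    y[0] = 10.0              # scalar value: set_value without a ValueTensor input
    y.sum().backward()

    # Expected after the fix: row 0 of x.grad is zero (it was overwritten),
    # while the remaining rows are 2. With the old assign fallback the whole
    # incoming gradient was copied through, so row 0 was wrongly non-zero.
    print(x.grad)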
44 changes: 19 additions & 25 deletions paddle/fluid/operators/set_value_op.cc
@@ -151,32 +151,26 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {

protected:
void Apply(GradOpPtr<T> op) const override {
if (this->HasInput("ValueTensor")) {
op->SetType("set_value_grad");

op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("ValueTensor", this->Input("ValueTensor"));
if (this->HasInput("StartsTensorList")) {
op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
}
if (this->HasInput("EndsTensorList")) {
op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
}
if (this->HasInput("StepsTensorList")) {
op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
}

op->SetAttrMap(this->Attrs());

op->SetOutput(framework::GradVarName("ValueTensor"),
this->InputGrad("ValueTensor"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));

} else {
op->SetType("assign");
op->SetInput("X", this->OutputGrad("Out"));
op->SetOutput("Out", this->InputGrad("Input"));
op->SetType("set_value_grad");
op->SetInput("ValueTensor", this->Input("ValueTensor"));
op->SetOutput(framework::GradVarName("ValueTensor"),
this->InputGrad("ValueTensor"));

op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));

if (this->HasInput("StartsTensorList")) {
op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
}
if (this->HasInput("EndsTensorList")) {
op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
}
if (this->HasInput("StepsTensorList")) {
op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
}

op->SetAttrMap(this->Attrs());

op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
}
};
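In the hunk above, the rewritten SetValueGradMaker::Apply always emits a set_value_grad op, wiring Out@GRAD, the Starts/Ends/StepsTensorList inputs when present, ValueTensor@GRAD, and Input@GRAD, and it no longer branches to assign for the scalar case. For comparison with the scalar repro above, a sketch of the tensor-valued case (illustrative; public Python API assumed):

    import paddle

    x = paddle.zeros([3, 3])
    x.stop_gradient = False
    v = paddle.to_tensor([1.0, 2.0, 3.0])
    v.stop_gradient = False

    y = x * 1
    y[0] = v                 # tensor value: set_value_grad also produces v's gradient
    y.sum().backward()

    print(x.grad)            # row 0 zeroed out, remaining rows all ones
    print(v.grad)            # all ones: each element of v lands in exactly one position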

108 changes: 63 additions & 45 deletions paddle/fluid/pybind/eager_method.cc
@@ -1375,7 +1375,7 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,

// step3: Dealing with advanced indexing
std::vector<paddle::Tensor> transed_index;
std::vector<int> trans_back_dim;
std::vector<int> trans_back_dim, trans_dim;
int pos_of_new_dim = INT_MAX, rank_of_new_dim = 1;

paddle::Tensor transed_tensor = dealWithAdvancedIndex(out,
@@ -1385,7 +1385,8 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
&transed_index,
&trans_back_dim,
&pos_of_new_dim,
&rank_of_new_dim);
&rank_of_new_dim,
&trans_dim);

if (transed_index.size() == 1 &&
transed_index[0].dtype() == phi::DataType::BOOL) {
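The two __getitem__ hunks above thread an extra trans_dim output through dealWithAdvancedIndex; read together with the __setitem__ changes below, it records the axis permutation applied when advanced-indexed axes are moved during indexing. An illustrative case where such a permutation is involved (hypothetical shapes; public Python API assumed):

    import paddle

    x = paddle.arange(24, dtype='float32').reshape([2, 3, 4])
    idx = paddle.to_tensor([0, 2])
    y = x[:, idx]            # advanced index on a non-leading axis
    print(y.shape)           # [2, 2, 4]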
@@ -1607,58 +1608,70 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
&use_strided_slice);

// step2: Parse values
PADDLE_ENFORCE(
PyCheckTensor(value_obj),
platform::errors::InvalidArgument("The value must be a Tensor"));

std::vector<phi::Scalar> values;
paddle::Tensor value_tensor =
reinterpret_cast<TensorObject*>(value_obj)->tensor;
dealWithValues(tensor, value_obj, &values, has_advanced_index);

if (!has_advanced_index) {
// use set_value OP if there is no advanced index

// Release gil and do tracing
py::gil_scoped_release release;
// use inplace set_value_ operator
if (value_tensor.initialized() &&
(self->tensor.dtype() != value_tensor.dtype())) {
if (egr::Controller::Instance().GetAMPLevel() !=
paddle::imperative::AmpLevel::O0) {
paddle::small_vector<std::vector<paddle::Tensor>,
egr::kSlotSmallVectorSize>
tmps = {{self->tensor}, {value_tensor}};
auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
self->tensor = egr::EagerAmpAutoCast(
self->tensor.name(), self->tensor, amp_dtype, "set_value");
value_tensor = egr::EagerAmpAutoCast(
value_tensor.name(), value_tensor, amp_dtype, "set_value");
}
if (value_tensor.initialized()) {
if (self->tensor.dtype() != value_tensor.dtype()) {
value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
if (egr::Controller::Instance().GetAMPLevel() !=
paddle::imperative::AmpLevel::O0) {
paddle::small_vector<std::vector<paddle::Tensor>,
egr::kSlotSmallVectorSize>
tmps = {{self->tensor}, {value_tensor}};
auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
self->tensor = egr::EagerAmpAutoCast(
self->tensor.name(), self->tensor, amp_dtype, "set_value");
value_tensor = egr::EagerAmpAutoCast(
value_tensor.name(), value_tensor, amp_dtype, "set_value");
}
if (self->tensor.dtype() != value_tensor.dtype()) {
value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
}
}
}

// step3.1: Only basic indexing, use OP set_value.
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) {
ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor);
}
self->tensor = set_value_with_tensor__ad_func(self->tensor,
value_tensor,
slice_starts,
slice_ends,
slice_strides,
slice_axes,
decrease_axis,
none_axes);
if (PyCheckTensor(value_obj)) {
// pass the stop_gradient from value to tensor.
// pass stop gradient should be done after CheckInplace in
// set_value__dygraph_function.
if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
// step3.1: Only basic indexing, use OP set_value.
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) {
ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor);
}
self->tensor = set_value_with_tensor__ad_func(self->tensor,
value_tensor,
slice_starts,
slice_ends,
slice_strides,
slice_axes,
decrease_axis,
none_axes);
if (PyCheckTensor(value_obj)) {
// pass the stop_gradient from value to tensor.
// pass stop gradient should be done after CheckInplace in
// set_value__dygraph_function.
if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
}
}
} else {
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, self->tensor)) {
ConvertAllInputsToDistTensor(mesh, self->tensor);
}
self->tensor = set_value__ad_func(self->tensor,
slice_starts,
slice_ends,
slice_strides,
slice_axes,
decrease_axis,
none_axes,
{1},
values);
}
} else {
// step3.2: Case for there are advanced indexing.
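Up to this point the hunk covers the basic-indexing path: values are now parsed by dealWithValues, and the PADDLE_ENFORCE that required the value to already be a Tensor is gone. When the value is an initialized Tensor, the in-place set_value_with_tensor_ path is used (keeping the AMP and dtype casting from before); otherwise the new else branch calls set_value_ with the scalar list values. The hunks that follow handle the advanced-indexing case. How the two dispatch paths look from Python (illustrative sketch, public API assumed):

    import paddle

    x = paddle.zeros([4, 4])
    x[0, :] = 1.5                   # Python scalar -> set_value_ with a scalar list
    x[1, :] = paddle.ones([4])      # Tensor value  -> set_value_with_tensor_

    print(x[0, 0].item(), x[1, 0].item())   # 1.5 1.0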
@@ -1679,9 +1692,9 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
&use_strided_slice);

std::vector<paddle::Tensor> transed_index;
std::vector<int> trans_back_dim;
std::vector<int> trans_back_dim, trans_dim;

int pos_of_new_dim = 0, rank_of_new_dim = 0;
int pos_of_new_dim = INT_MAX, rank_of_new_dim = 1;

paddle::Tensor transed_sub_tensor =
dealWithAdvancedIndex(sub_tensor,
@@ -1691,7 +1704,8 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
&transed_index,
&trans_back_dim,
&pos_of_new_dim,
&rank_of_new_dim);
&rank_of_new_dim,
&trans_dim);

// Release gil and do tracing
py::gil_scoped_release release;
@@ -1714,6 +1728,10 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
}
}

if (value_tensor.dims().size() > 1 && pos_of_new_dim != 0) {
value_tensor = transpose_ad_func(value_tensor, trans_dim);
}

// TODO(zoooo0820) 1.Using inplace version index_put
// 2.Remove following code after backward bug fixed.
transed_sub_tensor = assign_ad_func(transed_sub_tensor);
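The final hunk adds a transpose of value_tensor by trans_dim when the value has more than one dimension and pos_of_new_dim != 0, so the value is laid out to match the transposed sub-tensor before index_put runs. A hypothetical example of the situation this targets (public Python API assumed; shapes are illustrative):

    import paddle

    x = paddle.zeros([2, 3, 4])
    idx = paddle.to_tensor([0, 2])
    v = paddle.ones([2, 2, 4])      # matches x[:, idx].shape
    x[:, idx] = v                   # advanced index on axis 1: the value is permuted internally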