[PIR] modify ir Backward for prune #59100

Merged · 23 commits · Nov 20, 2023

Changes from all commits (23 commits)
bce9b3b
tmp
xiaoguoguo626807 Aug 30, 2023
c2341a5
fix conflict
xiaoguoguo626807 Aug 30, 2023
4d30fdd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Aug 31, 2023
cae7604
modify ci bug
xiaoguoguo626807 Sep 19, 2023
c94252d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 19, 2023
305ed20
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 20, 2023
3aa6686
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 22, 2023
6c553e6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 22, 2023
3b3b5ea
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 25, 2023
7e8e095
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 25, 2023
9c09a56
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Oct 7, 2023
cae57c1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Oct 7, 2023
d52fe87
[PIR]Migrate maximum into pir
0x45f Oct 8, 2023
9e5a0b1
Polish code
0x45f Oct 9, 2023
2218be2
add ir_grad of static_gradient
xiaoguoguo626807 Oct 9, 2023
b190b2f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Oct 9, 2023
2ce9d92
Merge commit 'refs/pull/57929/head' of https://github.com/PaddlePaddl…
xiaoguoguo626807 Oct 9, 2023
02040b1
add test
xiaoguoguo626807 Oct 9, 2023
ae9b38a
Merge branch 'develop', commit 'refs/pull/57956/head' of https://gith…
xiaoguoguo626807 Oct 9, 2023
464106f
tmp
xiaoguoguo626807 Nov 16, 2023
e8421b1
modify backward
xiaoguoguo626807 Nov 17, 2023
ff2bcf2
modify
xiaoguoguo626807 Nov 17, 2023
30521e5
modify segment
xiaoguoguo626807 Nov 17, 2023
88 changes: 57 additions & 31 deletions python/paddle/autograd/ir_backward.py
@@ -58,9 +58,10 @@ def update_no_grad_set_by_stopgradient(block, no_grad_set):
no_grad_set.add(value)


def update_bwdop_structure(backward_ops, op_to_opgrad_list, grad_op):
backward_ops.append(grad_op)
op_to_opgrad_list.append(grad_op)
def update_bwdop_structure(backward_ops, op_to_opgrad_list, grad_op_list):
for grad_op in grad_op_list:
backward_ops.append(grad_op)
op_to_opgrad_list.append(grad_op)
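
The signature change above turns update_bwdop_structure from a per-op helper into one that registers a whole group of grad ops, so call sites that previously needed two calls (for example a combine op plus a sum op) now make one. A minimal runnable sketch of the new shape, with plain strings standing in for PIR grad ops:

def update_bwdop_structure(backward_ops, op_to_opgrad_list, grad_op_list):
    for grad_op in grad_op_list:
        backward_ops.append(grad_op)
        op_to_opgrad_list.append(grad_op)

# Strings stand in for grad ops; the real code passes PIR operation objects.
backward_ops = []
op_to_opgrad = {"fwd_op": []}

# Previously: one call per grad op. Now: a single call for the whole group.
update_bwdop_structure(backward_ops, op_to_opgrad["fwd_op"], ["combine_op", "sum_op"])

assert backward_ops == ["combine_op", "sum_op"]
assert op_to_opgrad["fwd_op"] == ["combine_op", "sum_op"]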


def prepare_grad_outputs(grad_outputs, outputs, state):
@@ -87,18 +88,19 @@ def prepare_grad_outputs(grad_outputs, outputs, state):
for i, grad in enumerate(grad_outputs):
output = outputs[i]
# fwd : op1 -> op2 -> op3 -> output
# bwd : op1G <- op2G <- op3G <- outputG <- fillop/feedop
# bwd : op1G <- op2G <- op3G <- outputG <- full_likeop/feedop
if grad is None:
output_grad = paddle.full_like(
output,
1.0,
dtype=output.dtype,
)
fillop = output_grad.get_defining_op()
full_likeop = output_grad.get_defining_op()
Contributor (review comment):

Wouldn't full_like_op, full_op and data_op read more naturally than full_likeop, fullop and dataop?

Author (reply):

Will change this in a follow-up.

fullop = full_likeop.operand_source(1).get_defining_op()
update_bwdop_structure(
backward_ops,
state.op_to_opgrad[output.get_defining_op()],
fillop,
[full_likeop, fullop],
)
state.value_to_valuegrad[output] = [[output_grad]]
else:
@@ -116,7 +118,7 @@ def prepare_grad_outputs(grad_outputs, outputs, state):
update_bwdop_structure(
backward_ops,
state.op_to_opgrad[output.get_defining_op()],
feedop,
[feedop],
)
state.value_to_valuegrad[output] = [[grad]]

@@ -138,12 +140,13 @@ def prepare_grad_outputs(grad_outputs, outputs, state):
0.0,
opresult.dtype,
)
fillop = grad_value.get_defining_op()
full_likeop = grad_value.get_defining_op()
fullop = full_likeop.operand_source(1).get_defining_op()

update_bwdop_structure(
backward_ops,
state.op_to_opgrad[opresult.get_defining_op()],
fillop,
[full_likeop, fullop],
)
state.value_to_valuegrad[opresult] = [[grad_value]]

@@ -383,11 +386,9 @@ def make_output_with_output_grad(op):
combineop = bwd_block.ops[len(bwd_block.ops) - 2]
sumop = bwd_block.ops[len(bwd_block.ops) - 1]
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], combineop
)
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], sumop
backward_ops, state.op_to_opgrad[op], [combineop, sumop]
)

state.value_to_valuegrad[value] = [[sumop.result(0)]]
state.value_to_sumvaluegrad[value] = state.value_to_valuegrad[
value
@@ -426,10 +427,13 @@ def make_output_with_output_grad(op):
0.0,
dtype=value.dtype,
)
fillop = grad_value.get_defining_op()
full_likeop = grad_value.get_defining_op()
fullop = full_likeop.operand_source(1).get_defining_op()

update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], fillop
backward_ops,
state.op_to_opgrad[op],
[full_likeop, fullop],
)
zero_flag[i] = True

@@ -548,10 +552,12 @@ def update_input_grad_map(op, input_grads):
after_ops_num = len(bwd_block.ops)

# update grad_op structure
for i in range(before_ops_num, after_ops_num):
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], bwd_block.ops[i]
)
bwd_ops = [
bwd_block.ops[i] for i in range(before_ops_num, after_ops_num)
]
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], bwd_ops
)

# update input_grad map
update_input_grad_map(op, input_grads)
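
A plain-Python sketch of the bookkeeping in this hunk (a list stands in for bwd_block.ops): every op appended to the block while this grad rule runs is captured in one slice and handed to update_bwdop_structure in a single call, rather than looping and calling it once per op as before.

# A list stands in for bwd_block.ops; strings stand in for ops.
bwd_block_ops = ["existing_op"]               # ops already in the backward block
before_ops_num = len(bwd_block_ops)

bwd_block_ops += ["grad_op_b", "grad_op_a"]   # ops emitted by this op's grad rule
after_ops_num = len(bwd_block_ops)

bwd_ops = [bwd_block_ops[i] for i in range(before_ops_num, after_ops_num)]
print(bwd_ops)  # ['grad_op_b', 'grad_op_a'] -- passed on in one update_bwdop_structure call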
@@ -570,10 +576,9 @@ def update_input_grad_map(op, input_grads):
combineop = bwd_block.ops[len(bwd_block.ops) - 2]
sumop = bwd_block.ops[len(bwd_block.ops) - 1]
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], combineop
)
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], sumop
backward_ops,
state.op_to_opgrad[op],
[combineop, sumop],
)
state.value_to_valuegrad[value] = [[sumop.result(0)]]
state.value_to_sumvaluegrad[
@@ -585,20 +590,35 @@ def update_input_grad_map(op, input_grads):
state.op_to_opgrad[op] = []


def create_backward_prune_set(inputs, outputs, no_grad_set, state):
outputs_set = set()
def prepare_backward_prune_set(inputs, outputs):
outputs_fwd_set = set()
for input_ in inputs:
if not input_.use_empty():
for item in input_.first_use().owner().operands_source():
if state.value_to_valuegrad[item] != []:
outputs_set.add(state.value_to_valuegrad[item][0][0])
outputs_fwd_set.add(item)
else:
logging.warning("input privided by inputs has no use")

inputs_set = set()
inputs_fwd_set = set()
for output in outputs:
if state.value_to_valuegrad[output] != []:
inputs_set.add(state.value_to_valuegrad[output][0][0])
inputs_fwd_set.add(output)

return outputs_fwd_set, inputs_fwd_set


def create_backward_prune_set(
outputs_fwd_set, inputs_fwd_set, no_grad_set, state
):
outputs_set = set()
for item in outputs_fwd_set:
if state.value_to_valuegrad[item] != []:
outputs_set.add(state.value_to_valuegrad[item][0][0])

inputs_set = set()
for item in inputs_fwd_set:
if state.value_to_valuegrad[item] != []:
inputs_set.add(state.value_to_valuegrad[item][0][0])

inputs_set_tmp = set()
for out_grad in inputs_set:
if not out_grad.use_empty():
@@ -660,13 +680,19 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
block, effective_forward_ops, no_grad_set, inputs, complete_outputs
)

outputs_fwd_set, inputs_fwd_set = prepare_backward_prune_set(
inputs, complete_outputs
)

append_backward_ops(
block, block, effective_forward_ops, no_grad_set, backward_ops, state
)

# now value_to_valuegrad should be value <-> value (add sum op for the same value's gradvalue)
outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set(
inputs, complete_outputs, no_grad_set, state
outputs_fwd_set, inputs_fwd_set, no_grad_set, state
)

_, remove_ops = prune_ops(
backward_ops, inputs_set, outputs_set, no_gradvar_set
)
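Taken together, this file splits prune-set construction into two phases: prepare_backward_prune_set collects the relevant forward values before append_backward_ops runs, and create_backward_prune_set later maps those forward values to their grad values once state.value_to_valuegrad has been filled (see the reordered calls in calc_gradient_helper above). A simplified runnable sketch of the two phases, with plain dicts and strings instead of PIR values and State, and with the first_use()/operands_source() walk from the real code omitted:

# Phase 1: before backward ops exist, remember which forward values anchor the prune sets.
def prepare_backward_prune_set(inputs_fwd, outputs_fwd):
    outputs_fwd_set = set(inputs_fwd)   # real code: operands of the ops consuming `inputs`
    inputs_fwd_set = set(outputs_fwd)   # the forward outputs themselves
    return outputs_fwd_set, inputs_fwd_set

# Phase 2: after backward ops exist, translate forward values into their grad values,
# skipping any value that never received a gradient.
def create_backward_prune_set(outputs_fwd_set, inputs_fwd_set, value_to_valuegrad):
    outputs_set = {value_to_valuegrad[v] for v in outputs_fwd_set if value_to_valuegrad.get(v)}
    inputs_set = {value_to_valuegrad[v] for v in inputs_fwd_set if value_to_valuegrad.get(v)}
    return outputs_set, inputs_set

value_to_valuegrad = {"x": "x_grad", "out": "out_grad"}
outs_fwd, ins_fwd = prepare_backward_prune_set(["x"], ["out"])
print(create_backward_prune_set(outs_fwd, ins_fwd, value_to_valuegrad))
# ({'x_grad'}, {'out_grad'})
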
2 changes: 1 addition & 1 deletion test/ir/pir/test_ir_backward.py
@@ -101,7 +101,7 @@ def test_no_grad_set(self):
out = paddle.mean(tanh_out)
input_grad = grad(out, input, no_grad_vars=[input])
self.assertEqual(
pir_program.global_block().ops[-1].name(), "pd_op.full"
pir_program.global_block().ops[-1].name(), "pd_op.mean"
)

def test_split(self):
8 changes: 5 additions & 3 deletions test/legacy_test/test_segment_ops.py
@@ -127,7 +127,7 @@ def test_check_output(self):
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(["X"], "Out")
self.check_grad(["X"], "Out", check_pir=True)

def convert_bf16(self):
if self.dtype == np.uint16:
@@ -277,7 +277,7 @@ def test_check_output(self):
self.check_output_with_place(self.place, check_pir=True)

def test_check_grad(self):
self.check_grad_with_place(self.place, ["X"], "Out")
self.check_grad_with_place(self.place, ["X"], "Out", check_pir=True)


@unittest.skipIf(
@@ -300,6 +300,7 @@ def test_check_grad(self):
["X"],
"Out",
user_defined_grads=[self.gradient],
check_pir=True,
)


@@ -323,6 +324,7 @@ def test_check_grad(self):
["X"],
"Out",
user_defined_grads=[self.gradient],
check_pir=True,
)


@@ -341,7 +343,7 @@ def test_check_output(self):
self.check_output_with_place(self.place, check_pir=True)

def test_check_grad(self):
self.check_grad_with_place(self.place, ["X"], "Out")
self.check_grad_with_place(self.place, ["X"], "Out", check_pir=True)


class API_SegmentOpsTest(unittest.TestCase):