support more base Instructions and support resnet #41

Merged
merged 29 commits into from
May 5, 2023

Contributor

@zrr1999 zrr1999 commented Apr 29, 2023

Adds support for:

  • LOAD_ATTR
  • LOAD_METHOD
  • CALL_METHOD
  • CALL_FUNCTION_KW
  • BUILD_xxx
  • BINARY_SUBSCR
  • STORE_SUBSCR

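paddlefx.optimize can now capture a ResNet model and convert it into an FX graph.

During _compile, functions that belong to paddle modules are skipped. This makes calls such as paddle.add work without tracing into the branches that execute the dynamic graph and the static graph.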
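
A minimal sketch of the skip check described above, assuming the module name of the frame is available as a string. The names PADDLE_MODULES and should_skip are hypothetical and are not taken from this PR; the actual code returns a GuardedCode wrapping the original code object.

```python
# Hypothetical prefix list; the PR's actual list is not shown in this thread.
PADDLE_MODULES = ("paddle.",)


def should_skip(package_name: str) -> bool:
    """Return True when a frame belongs to a paddle module and should
    keep running its original bytecode instead of being traced."""
    # str.startswith accepts a tuple of prefixes, so no explicit loop is needed.
    return package_name == "paddle" or package_name.startswith(PADDLE_MODULES)


print(should_skip("paddle.tensor.math"))  # frame inside paddle
print(should_skip("my_project.models"))   # user code
```

Matching on "paddle." with the trailing dot keeps the sketch from also matching an unrelated package whose name merely begins with "paddle".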

@zrr1999 zrr1999 changed the title from "translate more bytecode ops" to "support more base Instructions" Apr 29, 2023
@zrr1999 zrr1999 marked this pull request as ready for review April 30, 2023 07:49
@zrr1999 zrr1999 mentioned this pull request Apr 30, 2023
16 tasks
@zrr1999 zrr1999 marked this pull request as draft May 1, 2023 08:26
@zrr1999 zrr1999 marked this pull request as ready for review May 1, 2023 10:37
@zrr1999 zrr1999 changed the title from "support more base Instructions" to "support more base Instructions and support resnet" May 1, 2023
@zrr1999 zrr1999 marked this pull request as draft May 1, 2023 10:49
@zrr1999 zrr1999 marked this pull request as ready for review May 1, 2023 10:49
@gglin001 gglin001 requested a review from jzhang533 May 4, 2023 02:53
@@ -31,6 +31,7 @@ def func(a, b):
in_a = paddle.rand([3, 4])
in_b = paddle.rand([3, 4])
out = paddle.add(in_a, in_b)
# out = paddle.add(out, out)
Collaborator

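Commented-out code that is no longer needed should be removed. There are several more instances elsewhere, as well as leftover print calls.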

Contributor Author

OK, I've removed the print calls and the commented-out code.


original_res = func(in_a, in_b)
optimized_res = optimized_func(in_a, in_b)
np.testing.assert_equal(original_res.numpy(), optimized_res.numpy())
Collaborator

As far as I know, this still runs the original bytecode, right? So the comparison here doesn't seem to mean much? @gglin001

Could you add a NOTE or TODO here?

Collaborator

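This still runs the original bytecode, right? So the comparison here doesn't seem to mean much?

Yes. For now the comparison only confirms that the original code is returned, not the converted code produced after tracing.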

Contributor Author

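As far as I know, this still runs the original bytecode, right? So the comparison here doesn't seem to mean much? @gglin001

Could you add a NOTE or TODO here?

I added a comment here: # TODO(zrr1999): optimized_res is the result of running the converted bytecode in the future.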

    code = frame.f_code
    for paddle_module in paddle_modules:
        if package_name.startswith(paddle_module):
            return GuardedCode(code)
Collaborator

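During _compile, functions that belong to paddle modules are skipped. This makes calls such as paddle.add work without tracing into the branches that execute the dynamic graph and the static graph.

Same as above: this is needed only because we can currently run just the original bytecode. Once the converted bytecode is run, the eval frame should never be entered for these functions. Since this PR is meant to verify that the bytecode needed by ResNet is fully supported, keeping this for now is fine. This logic should be removable once the converted bytecode is run.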

Contributor Author

I added a TODO here as well: This part can be removed when running the converted bytecode in the future.

    def pop(self):
        return self.stack.pop()

    def append(self, item):
Collaborator

Given stack semantics, wouldn't push be a better name?

The several self.stack.append calls in the code could all be replaced with self.push.

Contributor Author

OK, these have all been replaced.

    def IS_OP(self, inst: Instruction):
        args = list(reversed([self.pop() for _ in range(2)]))
        res = self.output.create_node('call_function', operator.is_, args, {})
        self.stack.append(res)
Collaborator

Could this be folded into BINARY_MAPPER? It looks like it could reuse _binary_constructor.
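
A rough sketch of the idea, assuming BINARY_MAPPER is a table from opcode names to operator functions and _binary_constructor builds one handler per entry. The table contents, the handler body, and the Translator class below are hypothetical stand-ins, since the PR's real definitions are not shown in this thread; a plain tuple stands in for the graph node.

```python
import operator

# Hypothetical opcode-to-operator table, not the PR's actual BINARY_MAPPER.
BINARY_MAPPER = {
    "BINARY_ADD": operator.add,
    "BINARY_SUBSCR": operator.getitem,
    "IS_OP": operator.is_,
}


def _binary_constructor(op_name):
    """Build a handler that pops two operands and pushes a call record."""
    fn = BINARY_MAPPER[op_name]

    def handler(self, inst=None):
        right = self.stack.pop()  # pushed last, popped first
        left = self.stack.pop()
        # The real translator would create a graph node here.
        self.stack.append(("call_function", fn, (left, right)))

    return handler


class Translator:
    def __init__(self):
        self.stack = []


# Attach one generated handler per opcode name.
for _name in BINARY_MAPPER:
    setattr(Translator, _name, _binary_constructor(_name))

t = Translator()
t.stack.extend(["x", None])
t.IS_OP()
print(t.stack[-1])
```

Note that in CPython 3.9 and later IS_OP carries an invert flag in its argument to express "is not", which this sketch ignores.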

Collaborator

Also, the list(reversed([self.pop() for _ in range(n)])) pattern used in various places should be able to reuse self.popn(n).

Contributor Author

list(reversed([self.pop() for _ in range(n)])) has been replaced with self.popn(n, reverse=True), and IS_OP has been added to BINARY_MAPPER.
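
For illustration, a small operand stack with push, pop, and popn might look like the sketch below. The Stack class and its internals are hypothetical; only the method names and the reverse parameter come from this thread.

```python
class Stack:
    """A small operand stack, loosely modeled on the one discussed here."""

    def __init__(self):
        self._items = []

    def push(self, item):
        self._items.append(item)

    def pop(self):
        return self._items.pop()

    def popn(self, n, reverse=True):
        # Popping yields the most recently pushed item first.
        popped = [self.pop() for _ in range(n)]
        # reverse=True restores the original push order.
        return popped[::-1] if reverse else popped


stack = Stack()
for value in ("a", "b", "c"):
    stack.push(value)
print(stack.popn(2))  # ['b', 'c']
```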

    if k == "self":
        self.f_locals[k] = self.output._proxy_placeholder(k)
    else:
        self.f_locals[k] = self.output._proxy_placeholder(k)
Collaborator

Hmm, is there any difference between these two branches? I can't spot one :joy:

Contributor Author

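This was there because it used to conflict with the handling of self in _generate_forward. I experimented with it many times and then forgot to revert it. The check for whether self is present is now done directly in _generate_forward.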

@zrr1999 zrr1999 requested a review from SigureMo May 4, 2023 10:26
Collaborator

@SigureMo SigureMo left a comment

Hmm, I tried running this locally, and the result of the dynamo trace looks a bit odd?

my_compiler() called with FX graph:
opcode         name        target                             args                                 kwargs
-------------  ----------  ---------------------------------  -----------------------------------  --------
placeholder    self        self                               ()                                   {}
placeholder    x           x                                  ()                                   {}
call_function  getattr_1   <built-in function getattr>        (self, 'conv1')                      {}
call_function  getattr_2   <built-in function getattr>        (getattr_1, 'forward')               {}
call_function  getattr_3   <built-in function getattr>        (getattr_2, '__name__')              {}
call_function  getattr_3   Proxy(getattr_2)                   [Proxy(x)]                           {}
call_function  getattr_4   <built-in function getattr>        (self, 'bn1')                        {}
call_function  getattr_5   <built-in function getattr>        (getattr_4, 'forward')               {}
call_function  getattr_6   <built-in function getattr>        (getattr_5, '__name__')              {}
call_function  getattr_6   Proxy(getattr_5)                   [getattr_3]                          {}
call_function  getattr_7   <built-in function getattr>        (self, 'relu')                       {}
call_function  getattr_8   <built-in function getattr>        (getattr_7, 'forward')               {}
call_function  getattr_9   <built-in function getattr>        (getattr_8, '__name__')              {}
call_function  getattr_9   Proxy(getattr_8)                   [getattr_6]                          {}
call_function  getattr_10  <built-in function getattr>        (self, 'maxpool')                    {}
call_function  getattr_11  <built-in function getattr>        (getattr_10, 'forward')              {}
call_function  getattr_12  <built-in function getattr>        (getattr_11, '__name__')             {}
call_function  getattr_12  Proxy(getattr_11)                  [getattr_9]                          {}
call_function  getattr_13  <built-in function getattr>        (self, 'layer1')                     {}
call_function  getattr_14  <built-in function getattr>        (getattr_13, 'forward')              {}
call_function  getattr_15  <built-in function getattr>        (getattr_14, '__name__')             {}
call_function  getattr_15  Proxy(getattr_14)                  [getattr_12]                         {}
call_function  getattr_16  <built-in function getattr>        (self, 'layer2')                     {}
call_function  getattr_17  <built-in function getattr>        (getattr_16, 'forward')              {}
call_function  getattr_18  <built-in function getattr>        (getattr_17, '__name__')             {}
call_function  getattr_18  Proxy(getattr_17)                  [getattr_15]                         {}
call_function  getattr_19  <built-in function getattr>        (self, 'layer3')                     {}
call_function  getattr_20  <built-in function getattr>        (getattr_19, 'forward')              {}
call_function  getattr_21  <built-in function getattr>        (getattr_20, '__name__')             {}
call_function  getattr_21  Proxy(getattr_20)                  [getattr_18]                         {}
call_function  getattr_22  <built-in function getattr>        (self, 'layer4')                     {}
call_function  getattr_23  <built-in function getattr>        (getattr_22, 'forward')              {}
call_function  getattr_24  <built-in function getattr>        (getattr_23, '__name__')             {}
call_function  getattr_24  Proxy(getattr_23)                  [getattr_21]                         {}
call_function  getattr_25  <built-in function getattr>        (self, 'avgpool')                    {}
call_function  getattr_26  <built-in function getattr>        (getattr_25, 'forward')              {}
call_function  getattr_27  <built-in function getattr>        (getattr_26, '__name__')             {}
call_function  getattr_27  Proxy(getattr_26)                  [getattr_24]                         {}
call_function  gt          <built-in function gt>             [0, Proxy(getattr_31)]               {}
call_function  flatten_1   <function flatten at 0x125377280>  [1, getattr_27]                      {}
call_function  getattr_28  <built-in function getattr>        (self, 'fc')                         {}
call_function  getattr_29  <built-in function getattr>        (getattr_28, 'forward')              {}
call_function  getattr_30  <built-in function getattr>        (getattr_29, '__name__')             {}
call_function  getattr_30  Proxy(getattr_29)                  [flatten_1]                          {}
output         output      output                             [Proxy(getattr_32), gt, getattr_30]  {}
call_function  getattr_31  <built-in function getattr>        (self, 'num_classes')                {}
call_function  getattr_32  <built-in function getattr>        (self, 'with_pool')                  {}

It looks quite different from what a direct FX trace produces, and adding gl.forward(paddle.rand([2, 3, 224, 224])) inside my_compiler doesn't seem to run either?

Also, why do the args contain both tuples and lists?
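
One visible symptom in the table above is that gt uses getattr_31, which is only defined after the output row. A minimal way to detect this kind of use-before-definition is sketched below, using a simplified list of (name, input names) pairs rather than the real graph classes. The function name and the trimmed node list are illustrative only and are not verbatim rows from the table.

```python
def find_use_before_def(nodes):
    """Return (node, missing_input) pairs where an input is not yet defined."""
    defined = set()
    problems = []
    for name, inputs in nodes:
        for dep in inputs:
            if dep not in defined:
                problems.append((name, dep))
        defined.add(name)
    return problems


# A simplified, shortened node list in the spirit of the table above.
nodes = [
    ("self", []),
    ("getattr_27", ["self"]),
    ("gt", ["getattr_31"]),
    ("flatten_1", ["getattr_27"]),
    ("output", ["getattr_32", "gt", "flatten_1"]),
    ("getattr_31", ["self"]),
    ("getattr_32", ["self"]),
]
print(find_use_before_def(nodes))
# [('gt', 'getattr_31'), ('output', 'getattr_32')]
```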

        self.push(None)

    def CALL_FUNCTION(self, inst: Instruction):
        args = [self.pop() for _ in range(inst.argval)]
Collaborator

@SigureMo SigureMo May 4, 2023

Isn't this reversed? It seems nothing actually needs reverse=False?

For example, for the following function call:

import dis


def foo():
    bar(1, 2, 3)


dis.dis(foo)

the bytecode is:

 24           0 LOAD_GLOBAL              0 (bar)
              2 LOAD_CONST               1 (1)
              4 LOAD_CONST               2 (2)
              6 LOAD_CONST               3 (3)
              8 CALL_FUNCTION            3
             10 POP_TOP
             12 LOAD_CONST               0 (None)
             14 RETURN_VALUE

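The arguments 1, 2, and 3 are pushed in order, so popping them one by one yields them in reverse order, and a reverse is required. BUILD_TUPLE and similar instructions behave the same way: values are pushed in order and must be reversed after popping.

For details, see the dis documentation and "Python - ceval.c source".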
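
The ordering argument can be checked with a plain list acting as the value stack. This is only a simulation of the push and pop order, not the interpreter's actual implementation, and simulate_call is a hypothetical helper name.

```python
def simulate_call(arg_values):
    """Mimic how positional arguments reach the stack and leave it."""
    stack = []
    for value in arg_values:  # LOAD_CONST 1, 2, 3 in source order
        stack.append(value)
    # Popping one at a time returns the last pushed value first.
    popped = [stack.pop() for _ in arg_values]
    return popped, popped[::-1]


popped, restored = simulate_call([1, 2, 3])
print(popped)    # [3, 2, 1]
print(restored)  # [1, 2, 3]
```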

Contributor Author

OK, I've changed all of these to pop in reverse order consistently.

Contributor Author

zrr1999 commented May 4, 2023

Hmm, I tried running this locally, and the result of the dynamo trace looks a bit odd?

It looks quite different from what a direct FX trace produces, and adding gl.forward(paddle.rand([2, 3, 224, 224])) inside my_compiler doesn't seem to run either?

Also, why do the args contain both tuples and lists?

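The Python codegen still seems to have some problems; I'll look into it tomorrow and see whether I can figure it out. It is odd that other nodes currently appear after output. I fixed print_table: the nodes were previously created as call_function and are now call_module, so the output should be basically consistent with resnet_trace. For args, I added code in the graph that converts them to tuples.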
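
A minimal sketch of what such a tuple conversion could look like, assuming args may contain nested lists. The helper name to_tuple is hypothetical, and the code actually added to the graph in this PR is not shown in the thread.

```python
def to_tuple(value):
    """Recursively turn lists into tuples; leave other values untouched."""
    if isinstance(value, (list, tuple)):
        return tuple(to_tuple(item) for item in value)
    return value


print(to_tuple(["x", [1, 2], ("a", ["b"])]))
# ('x', (1, 2), ('a', ('b',)))
```

Normalizing at node creation time means every consumer of the graph sees one container type for args.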

Contributor Author

zrr1999 commented May 4, 2023

Hmm, I tried running this locally, and the result of the dynamo trace looks a bit odd?

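It looks quite different from what a direct FX trace produces, and adding gl.forward(paddle.rand([2, 3, 224, 224])) inside my_compiler doesn't seem to run either?

Also, why do the args contain both tuples and lists?

It should work now.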

Collaborator

@gglin001 gglin001 left a comment

LGTM, 👍

Collaborator

@SigureMo SigureMo left a comment

LGTM~
