From 83b9fdc052d817b08a312ebbd163b79e1c984505 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <46243324+zrr1999@users.noreply.github.com> Date: Fri, 5 May 2023 15:11:50 +0800 Subject: [PATCH] support more base Instructions and support resnet (#41) --- .github/workflows/test.yml | 1 + examples/graph_editing.py | 17 ++-- examples/resnet_dynamo.py | 27 +++++ examples/resnet_trace.py | 3 +- examples/simple_dynamo.py | 22 +++++ src/paddlefx/eval_frame.py | 21 +++- src/paddlefx/graph.py | 34 +++++-- src/paddlefx/graph_layer.py | 9 +- src/paddlefx/proxy.py | 7 +- src/paddlefx/symbolic_trace.py | 2 +- src/paddlefx/translator.py | 176 +++++++++++++++++++++++++++------ 11 files changed, 265 insertions(+), 54 deletions(-) create mode 100644 examples/resnet_dynamo.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 397d684..799557f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,6 +24,7 @@ jobs: - name: Install dependencies run: | + pip install --upgrade pip pip install -r requirements_dev.txt - name: Build diff --git a/examples/graph_editing.py b/examples/graph_editing.py index eb21de6..20bb0c3 100644 --- a/examples/graph_editing.py +++ b/examples/graph_editing.py @@ -13,18 +13,17 @@ def net(x, y): graph = traced_layer.graph print("Before editing:") -graph.print_tabular() +print(traced_layer.get_source()) for node in graph.nodes: if node.op == 'call_function': + with graph.inserting_after(node): + new_node = graph.create_node( + node.op, paddle.add, args=(node.args[0], node.args[0]), kwargs={} + ) + node.replace_all_uses_with(new_node) + graph.erase_node(node) break -with graph.inserting_after(node): - new_node = graph.create_node( - node.op, paddle.add, args=(node.args[0], node.args[0]), kwargs={} - ) - node.replace_all_uses_with(new_node) -graph.erase_node(node) - print("After editing:") -graph.print_tabular() +print(traced_layer.get_source()) diff --git a/examples/resnet_dynamo.py b/examples/resnet_dynamo.py new file mode 100644 index 0000000..e57d681 --- /dev/null +++ b/examples/resnet_dynamo.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import numpy as np +import paddle +import paddle.nn +import paddle.tensor + +from paddle.vision.models import resnet18 + +import paddlefx + + +def my_compiler(gl: paddlefx.GraphLayer, example_inputs: list[paddle.Tensor] = None): + print("my_compiler() called with FX graph:") + print(gl.get_source()) + gl.graph.print_tabular(print_mode="rich") + return gl.forward + + +net = resnet18() +optimized_net = paddlefx.optimize(my_compiler)(net) + +x = paddle.rand([1, 3, 224, 224]) +out = net(x) +res = optimized_net(x) + +np.testing.assert_equal(res.numpy(), out.numpy()) diff --git a/examples/resnet_trace.py b/examples/resnet_trace.py index f22bc64..4e108c9 100644 --- a/examples/resnet_trace.py +++ b/examples/resnet_trace.py @@ -17,6 +17,5 @@ assert paddle.allclose(orig_output, traced_output) print(f"python IR for {type(net).__name__}") +print(traced_layer.get_source()) traced_layer.graph.print_tabular(print_mode="tabulate") -traced_layer.graph.print_tabular(print_mode="rich") -traced_layer.graph.print_tabular(print_mode="raw") diff --git a/examples/simple_dynamo.py b/examples/simple_dynamo.py index f6dc50d..e0fbc5c 100644 --- a/examples/simple_dynamo.py +++ b/examples/simple_dynamo.py @@ -80,3 +80,25 @@ def inplace(a, b): optimized_res = optimized_foo(in_a, in_b) np.testing.assert_equal(original_res.numpy(), optimized_res.numpy()) + + +class ExampleNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.fc = [paddle.nn.Linear(1, 1), paddle.nn.Linear(1, 1)] + + def forward(self, a, b): + c = self.fc[0](a) + d = self.fc[1](b) + e = paddle.add(c, d) + return e + + +net = ExampleNet() + +optimized_func = paddlefx.optimize(my_compiler)(net) + +original_res = net(in_a, in_b) +optimized_res = optimized_func(in_a, in_b) +# TODO(zrr1999): `optimized_res` is the result of running the converted bytecode in the future. +np.testing.assert_equal(original_res.numpy(), optimized_res.numpy()) diff --git a/src/paddlefx/eval_frame.py b/src/paddlefx/eval_frame.py index 3b4299b..05c0b17 100644 --- a/src/paddlefx/eval_frame.py +++ b/src/paddlefx/eval_frame.py @@ -2,11 +2,13 @@ import dataclasses import dis +import inspect import types from typing import Callable import paddle +import paddle.nn from ._eval_frame import set_eval_frame from .translator import InstructionTranslator, convert_instruction @@ -24,6 +26,7 @@ def __init__(self, callback): def __enter__(self): self.old_callback = set_eval_frame(self.callback) + return self def __exit__(self, exc_type, exc_value, traceback): set_eval_frame(self.old_callback) @@ -45,7 +48,22 @@ def _compile( frame: types.FrameType, compiler_fn: Callable, ): + # TODO(zrr1999): This part can be removed when running the converted bytecode in the future. + paddle_modules = [ + "paddle.nn", + "paddle.fluid", + "paddle.tensor", + # TODO(zrr1999): add more modules + ] + module = inspect.getmodule(frame) + if module is None: + raise RuntimeError('Cannot find module for frame') + package_name = module.__name__ + code = frame.f_code + for paddle_module in paddle_modules: + if package_name.startswith(paddle_module): + return GuardedCode(code) instructions = list(map(convert_instruction, dis.get_instructions(code))) tracer = InstructionTranslator(instructions, frame, compiler_fn) @@ -64,9 +82,6 @@ def has_tensor_in_frame(frame: types.FrameType) -> bool: if frame.f_code.co_name == 'in_dygraph_mode': return False - # print(frame) - # print(dis.disassemble(frame.f_code)) - for v in frame.f_locals.values(): # TODO: supprt containers if isinstance(v, paddle.Tensor): diff --git a/src/paddlefx/graph.py b/src/paddlefx/graph.py index 7542d75..86a7629 100644 --- a/src/paddlefx/graph.py +++ b/src/paddlefx/graph.py @@ -21,10 +21,18 @@ def snake_case(s): def _qualified_name(func): + if hasattr(func, 'node'): + name = func.node.name + elif hasattr(func, '__name__'): + name = func.__name__ + elif hasattr(func, 'name'): + name = func.name + else: + raise NotImplementedError(f'cannot get name of {func}') + # things like getattr just appear in builtins - if getattr(builtins, func.__name__, None) is func: - return func.__name__ - name = func.__name__ + if getattr(builtins, name, None) is func: + return name module = _find_module_of_method(func) return f'{module}.{name}' @@ -42,7 +50,10 @@ def _is_illegal_name(name: str, obj: Any) -> bool: def _find_module_of_method(orig_method): - name = orig_method.__name__ + if hasattr(orig_method, '__name__'): + name = orig_method.__name__ + else: + name = orig_method.__class__.__name__ module = orig_method.__module__ if module is not None: return module @@ -138,7 +149,7 @@ def create_node(self, op, target=None, args=None, kwargs=None, name=None): 'placeholder', 'output', ) - args = () if args is None else args + args = () if args is None else tuple(args) kwargs = {} if kwargs is None else kwargs name = name if name is not None else self._name(target or op) if name[0].isdigit(): @@ -161,6 +172,10 @@ def output(self, result): def _name(self, op): if hasattr(op, '__name__'): op = op.__name__ + if hasattr(op, 'name'): + op = op.name + if hasattr(op, 'node'): + op = op.node.name if _is_magic(op): op = op[2:-2] @@ -185,6 +200,11 @@ def get_param(self, target): def placeholder(self, name): return self.create_node('placeholder', target=name, name=name.replace('*', '')) + def call_module(self, target, args, kwargs): + return self.create_node( + 'call_module', target, args, kwargs, name=target.replace('.', '_') + ) + def erase_node(self, to_erase: Node) -> None: if len(to_erase.users) > 0: raise RuntimeError( @@ -281,7 +301,9 @@ def print_tabular(self, print_mode="tabulate"): """Prints the intermediate representation of the graph in tabular format. - Note that this API requires the ``tabulate`` module to be installed. + Note that this API allows users to choose between using the ``raw``, + ``tabulate`` or ``rich`` mode. If the user specifies a mode that is not + installed, the API will automatically fall back on the ``raw`` mode. """ assert print_mode in ["raw", "tabulate", "rich"] if print_mode == "raw": diff --git a/src/paddlefx/graph_layer.py b/src/paddlefx/graph_layer.py index 063af44..286b58a 100644 --- a/src/paddlefx/graph_layer.py +++ b/src/paddlefx/graph_layer.py @@ -66,9 +66,11 @@ def __init__(self, root, graph: Graph): def _generate_forward(self): body, free_variables = self.graph.python_code(root_module='self') + if "self" not in free_variables: + free_variables.insert(0, "self") body = '\n'.join(' ' + line for line in body.split('\n')) + '\n' self.src = f"""\ -def forward(self, {', '.join(free_variables)}): +def forward({', '.join(free_variables)}): self = self.root {body} """ @@ -82,6 +84,11 @@ def forward(self, {', '.join(free_variables)}): for k, v in gbls.items(): setattr(cls, k, v) + def get_source(self, update: bool = True): + if update: + self._generate_forward() + return self.src + # copy an attribute value with qualified name 'target' from 'from_module' to 'to_module' # This installs empty Modules where none exist yet if they are subpaths of target diff --git a/src/paddlefx/proxy.py b/src/paddlefx/proxy.py index 61c086a..1bc41d4 100644 --- a/src/paddlefx/proxy.py +++ b/src/paddlefx/proxy.py @@ -27,7 +27,7 @@ def __init__(self, node: Node, tracer: Tracer): self.tracer = tracer def __repr__(self): - return f'Proxy({self.node.name})' + return f'{self.node.name}' def __getattr__(self, k): # note: not added to the graph yet, if this is a method call @@ -45,7 +45,7 @@ def __iter__(self): if current_instruction.opname == "UNPACK_SEQUENCE": return (self[i] for i in range(current_instruction.argval)) elif current_instruction.opname == "GET_ITER": - raise NotImplementedError() + return (self[i] for i in range(current_instruction.argval)) raise ValueError("Cannot find UNPACK_SEQUENCE instruction") @@ -66,6 +66,9 @@ def node(self): ).node return self._node + def __str__(self): + return f'{self.root}.{self.node.name}' + def __call__(self, *args, **kwargs): return _create_proxy( self.tracer, 'call_method', self.attr, (self.root,) + args, kwargs diff --git a/src/paddlefx/symbolic_trace.py b/src/paddlefx/symbolic_trace.py index 8d25012..83e1d77 100644 --- a/src/paddlefx/symbolic_trace.py +++ b/src/paddlefx/symbolic_trace.py @@ -229,7 +229,7 @@ def _proxy_placeholder(self, name): n = self.graph.create_node('placeholder', name, (), {}) return Proxy(n, self) - def create_node(self, op, target, args, kwargs, name=None): + def create_node(self, op, target, args=None, kwargs=None, name=None): return self.graph.create_node(op, target, args, kwargs, name) def create_arg(self, a): diff --git a/src/paddlefx/translator.py b/src/paddlefx/translator.py index 71b70f9..aff049b 100644 --- a/src/paddlefx/translator.py +++ b/src/paddlefx/translator.py @@ -56,9 +56,9 @@ def convert_instruction(i: dis.Instruction): def _binary_constructor(op_name: str): def _binary(self, inst: Instruction): op = getattr(operator, op_name) - args = list(reversed([self.stack.pop() for _ in range(2)])) + args = self.popn(2) res = self.output.create_node('call_function', op, args, {}) - self.stack.append(res) + self.push(res) return _binary @@ -66,8 +66,8 @@ def _binary(self, inst: Instruction): def _unary_constructor(op_name: str): def _unary(self, inst: Instruction): op = getattr(operator, op_name) - res = self.output.create_node('call_function', op, self.stack.pop(), {}) - self.stack.append(res) + res = self.output.create_node('call_function', op, self.pop(), {}) + self.push(res) return _unary @@ -100,6 +100,7 @@ def _not_impl(self, inst): 'ipow': 'INPLACE_POWER', 'isub': 'INPLACE_SUBTRACT', 'itruediv': 'INPLACE_TRUE_DIVIDE', + 'is_': 'IS_OP', } UNARY_MAPPER = {'not_': 'UNARY_NOT', 'inv': 'UNARY_INVERT'} @@ -155,16 +156,61 @@ def compile_subgraph(self): # add output node stack_values = list(self.stack) self.output.create_node('output', 'output', stack_values, {}) - - gl = GraphLayer(paddle.nn.Layer(), self.output.graph) + if self.frame.f_locals.get('self', None): + root = self.frame.f_locals.get('self') + else: + root = paddle.nn.Layer() + gl = GraphLayer(root, self.output.graph) self.call_user_compiler(gl) + def pop(self): + return self.stack.pop() + + def push(self, item): + return self.stack.append(item) + + def popn(self, n: int, reverse=True): + assert n >= 0 + if not n: + return [] + if reverse: + return list(reversed([self.pop() for _ in range(n)])) + else: + return [self.pop() for _ in range(n)] + + def call_function(self, fn, args, kwargs): + is_custom_call = False + for arg in args: + if isinstance(arg, (Proxy, paddle.Tensor)): + is_custom_call = True + break + for arg in kwargs: + if isinstance(arg, (Proxy, paddle.Tensor)): + is_custom_call = True + break + + # TODO: add `self.call_function` to handle more functions + if fn is print: + self.push(None) + elif fn is isinstance: + res = self.output.create_node('call_function', fn, args, kwargs) + self.push(res) + elif fn.__module__.startswith("paddle"): + if hasattr(fn, "forward"): + fn = fn.forward + res = self.output.create_node('call_function', fn, args, kwargs) + self.push(res) + elif is_custom_call: + raise NotImplementedError(f"custom_call is not supported") + else: + raise NotImplementedError(f"call function {fn} is not supported") + def LOAD_GLOBAL(self, inst: Instruction): name = inst.argval if name in self.frame.f_globals: - self.stack.append(self.frame.f_globals[name]) + self.push(self.frame.f_globals[name]) elif name in self.frame.f_builtins: - self.stack.append(self.frame.f_builtins[name]) + self.push(self.frame.f_builtins[name]) else: raise Exception(f"name '{name}' is not found") @@ -176,34 +222,104 @@ def POP_JUMP_IF_TRUE(self, inst: Instruction): def LOAD_CONST(self, inst: Instruction): value = inst.argval - self.stack.append(value) + self.push(value) + + def LOAD_ATTR(self, inst: Instruction): + obj = self.pop() + if isinstance(obj, Proxy) and obj.node.name.startswith("self"): + res = self.output.create_node('get_param', inst.argval) + self.push(res) + elif hasattr(obj, inst.argval): + value = getattr(obj, inst.argval) + self.push(value) + else: + self.push(None) + + def LOAD_METHOD(self, inst: Instruction): + target = self.pop() + if isinstance(target, str) and target.startswith("self"): + fn = f"{target}.{inst.argval}" + elif isinstance(target, Proxy) and target.node.name.startswith("self"): + fn = f"{target.node.name}.{inst.argval}" + else: + fn = getattr(target, inst.argval) + self.push(fn) + + def CALL_METHOD(self, inst: Instruction): + args = self.popn(inst.argval) + fn = self.pop() + if isinstance(fn, str): + if fn.startswith("self"): + res = self.output.create_node('call_module', fn[5:], args, {}) + else: + # TODO(zrr1999) call_method is not implemented. + raise NotImplemented + # res = self.output.create_node('call_method', fn, args, {}) + self.push(res) + else: + if hasattr(fn, "forward"): + fn = fn.forward + if fn is not None: + res = self.output.create_node('call_function', fn, args, {}) + self.push(res) + else: + self.push(None) def CALL_FUNCTION(self, inst: Instruction): - args = [self.stack.pop() for _ in range(inst.argval)] - fn = self.stack.pop() - - is_custom_call = False - for arg in args: - if isinstance(arg, (Proxy, paddle.Tensor)): - is_custom_call = True - break - - # TODO: add `self.call_function` to handle more functions - if fn == print: - self.stack.append(None) - elif is_custom_call: - raise NotImplementedError(f"custom_call is not supported") - else: - raise NotImplementedError(f"call function {fn} is not supported") + args = self.popn(inst.argval) + fn = self.pop() + self.call_function(fn, args, {}) + + def CALL_FUNCTION_KW(self, inst: Instruction): + argnames = self.pop() + args = self.popn(inst.argval) + fn = self.pop() + args, kwargs = args[: -len(argnames)], args[-len(argnames) :] + kwargs = dict(zip(argnames, kwargs)) + self.call_function(fn, args, kwargs) + + def BUILD_TUPLE(self, inst): + items = self.popn(inst.argval) + self.push(tuple(items)) + + def BUILD_LIST(self, inst): + items = self.popn(inst.argval) + self.push(items) + + def BUILD_MAP(self, inst): + items = self.popn(inst.argval * 2) + result = dict() + for k, v in zip(items[::2], items[1::2]): + result[k] = v + assert len(result) == len(items) / 2 + self.push(result) + + def BUILD_CONST_KEY_MAP(self, inst): + # TODO(zrr1999): add assert + keys = self.pop() + values = self.popn(inst.argval) + self.push(dict(zip(keys, values))) + + def BINARY_SUBSCR(self, inst): + idx = self.pop() + root = self.pop() + res = self.output.create_node('call_method', "__getitem__", [root, idx], {}) + self.push(res) + + def STORE_SUBSCR(self, inst): + value = self.pop() + idx = self.pop() + root = self.pop() + self.output.create_node('call_method', "__setitem__", [root, idx, value], {}) def POP_TOP(self, inst: Instruction): - value = self.stack.pop() + value = self.pop() def STORE_FAST(self, inst: Instruction): - self.f_locals[inst.argval] = self.stack.pop() + self.f_locals[inst.argval] = self.pop() def LOAD_FAST(self, inst: Instruction): - self.stack.append(self.f_locals[inst.argval]) + self.push(self.f_locals[inst.argval]) def RETURN_VALUE(self, inst: Instruction): self.compile_subgraph() @@ -220,9 +336,9 @@ def COMPARE_OP(self, inst: Instruction): 'is not': 'is_not', } op = getattr(operator, op_mapper[inst.argval]) - args = list(reversed([self.stack.pop() for _ in range(2)])) + args = self.popn(2) res = self.output.create_node('call_function', op, args, {}) - self.stack.append(res) + self.push(res) class InstructionTranslator(InstructionTranslatorBase):