Skip to content

Commit

Permalink
support more base Instructions and support resnet (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 authored May 5, 2023
1 parent 48c71f7 commit 83b9fdc
Show file tree
Hide file tree
Showing 11 changed files with 265 additions and 54 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ jobs:

- name: Install dependencies
run: |
pip install --upgrade pip
pip install -r requirements_dev.txt
- name: Build
Expand Down
17 changes: 8 additions & 9 deletions examples/graph_editing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,17 @@ def net(x, y):
graph = traced_layer.graph

print("Before editing:")
graph.print_tabular()
print(traced_layer.get_source())

for node in graph.nodes:
if node.op == 'call_function':
with graph.inserting_after(node):
new_node = graph.create_node(
node.op, paddle.add, args=(node.args[0], node.args[0]), kwargs={}
)
node.replace_all_uses_with(new_node)
graph.erase_node(node)
break

with graph.inserting_after(node):
new_node = graph.create_node(
node.op, paddle.add, args=(node.args[0], node.args[0]), kwargs={}
)
node.replace_all_uses_with(new_node)
graph.erase_node(node)

print("After editing:")
graph.print_tabular()
print(traced_layer.get_source())
27 changes: 27 additions & 0 deletions examples/resnet_dynamo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from __future__ import annotations

import numpy as np
import paddle
import paddle.nn
import paddle.tensor

from paddle.vision.models import resnet18

import paddlefx


def my_compiler(gl: paddlefx.GraphLayer, example_inputs: list[paddle.Tensor] = None):
print("my_compiler() called with FX graph:")
print(gl.get_source())
gl.graph.print_tabular(print_mode="rich")
return gl.forward


net = resnet18()
optimized_net = paddlefx.optimize(my_compiler)(net)

x = paddle.rand([1, 3, 224, 224])
out = net(x)
res = optimized_net(x)

np.testing.assert_equal(res.numpy(), out.numpy())
3 changes: 1 addition & 2 deletions examples/resnet_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,5 @@
assert paddle.allclose(orig_output, traced_output)

print(f"python IR for {type(net).__name__}")
print(traced_layer.get_source())
traced_layer.graph.print_tabular(print_mode="tabulate")
traced_layer.graph.print_tabular(print_mode="rich")
traced_layer.graph.print_tabular(print_mode="raw")
22 changes: 22 additions & 0 deletions examples/simple_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,25 @@ def inplace(a, b):
optimized_res = optimized_foo(in_a, in_b)

np.testing.assert_equal(original_res.numpy(), optimized_res.numpy())


class ExampleNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.fc = [paddle.nn.Linear(1, 1), paddle.nn.Linear(1, 1)]

def forward(self, a, b):
c = self.fc[0](a)
d = self.fc[1](b)
e = paddle.add(c, d)
return e


net = ExampleNet()

optimized_func = paddlefx.optimize(my_compiler)(net)

original_res = net(in_a, in_b)
optimized_res = optimized_func(in_a, in_b)
# TODO(zrr1999): `optimized_res` is the result of running the converted bytecode in the future.
np.testing.assert_equal(original_res.numpy(), optimized_res.numpy())
21 changes: 18 additions & 3 deletions src/paddlefx/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import dataclasses
import dis
import inspect
import types

from typing import Callable

import paddle
import paddle.nn

from ._eval_frame import set_eval_frame
from .translator import InstructionTranslator, convert_instruction
Expand All @@ -24,6 +26,7 @@ def __init__(self, callback):

def __enter__(self):
self.old_callback = set_eval_frame(self.callback)
return self

def __exit__(self, exc_type, exc_value, traceback):
set_eval_frame(self.old_callback)
Expand All @@ -45,7 +48,22 @@ def _compile(
frame: types.FrameType,
compiler_fn: Callable,
):
# TODO(zrr1999): This part can be removed when running the converted bytecode in the future.
paddle_modules = [
"paddle.nn",
"paddle.fluid",
"paddle.tensor",
# TODO(zrr1999): add more modules
]
module = inspect.getmodule(frame)
if module is None:
raise RuntimeError('Cannot find module for frame')
package_name = module.__name__

code = frame.f_code
for paddle_module in paddle_modules:
if package_name.startswith(paddle_module):
return GuardedCode(code)
instructions = list(map(convert_instruction, dis.get_instructions(code)))

tracer = InstructionTranslator(instructions, frame, compiler_fn)
Expand All @@ -64,9 +82,6 @@ def has_tensor_in_frame(frame: types.FrameType) -> bool:
if frame.f_code.co_name == 'in_dygraph_mode':
return False

# print(frame)
# print(dis.disassemble(frame.f_code))

for v in frame.f_locals.values():
# TODO: supprt containers
if isinstance(v, paddle.Tensor):
Expand Down
34 changes: 28 additions & 6 deletions src/paddlefx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,18 @@ def snake_case(s):


def _qualified_name(func):
if hasattr(func, 'node'):
name = func.node.name
elif hasattr(func, '__name__'):
name = func.__name__
elif hasattr(func, 'name'):
name = func.name
else:
raise NotImplementedError(f'cannot get name of {func}')

# things like getattr just appear in builtins
if getattr(builtins, func.__name__, None) is func:
return func.__name__
name = func.__name__
if getattr(builtins, name, None) is func:
return name
module = _find_module_of_method(func)
return f'{module}.{name}'

Expand All @@ -42,7 +50,10 @@ def _is_illegal_name(name: str, obj: Any) -> bool:


def _find_module_of_method(orig_method):
name = orig_method.__name__
if hasattr(orig_method, '__name__'):
name = orig_method.__name__
else:
name = orig_method.__class__.__name__
module = orig_method.__module__
if module is not None:
return module
Expand Down Expand Up @@ -138,7 +149,7 @@ def create_node(self, op, target=None, args=None, kwargs=None, name=None):
'placeholder',
'output',
)
args = () if args is None else args
args = () if args is None else tuple(args)
kwargs = {} if kwargs is None else kwargs
name = name if name is not None else self._name(target or op)
if name[0].isdigit():
Expand All @@ -161,6 +172,10 @@ def output(self, result):
def _name(self, op):
if hasattr(op, '__name__'):
op = op.__name__
if hasattr(op, 'name'):
op = op.name
if hasattr(op, 'node'):
op = op.node.name

if _is_magic(op):
op = op[2:-2]
Expand All @@ -185,6 +200,11 @@ def get_param(self, target):
def placeholder(self, name):
return self.create_node('placeholder', target=name, name=name.replace('*', ''))

def call_module(self, target, args, kwargs):
return self.create_node(
'call_module', target, args, kwargs, name=target.replace('.', '_')
)

def erase_node(self, to_erase: Node) -> None:
if len(to_erase.users) > 0:
raise RuntimeError(
Expand Down Expand Up @@ -281,7 +301,9 @@ def print_tabular(self, print_mode="tabulate"):
"""Prints the intermediate representation of the graph in tabular
format.
Note that this API requires the ``tabulate`` module to be installed.
Note that this API allows users to choose between using the ``raw``,
``tabulate`` or ``rich`` mode. If the user specifies a mode that is not
installed, the API will automatically fall back on the ``raw`` mode.
"""
assert print_mode in ["raw", "tabulate", "rich"]
if print_mode == "raw":
Expand Down
9 changes: 8 additions & 1 deletion src/paddlefx/graph_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,11 @@ def __init__(self, root, graph: Graph):

def _generate_forward(self):
body, free_variables = self.graph.python_code(root_module='self')
if "self" not in free_variables:
free_variables.insert(0, "self")
body = '\n'.join(' ' + line for line in body.split('\n')) + '\n'
self.src = f"""\
def forward(self, {', '.join(free_variables)}):
def forward({', '.join(free_variables)}):
self = self.root
{body}
"""
Expand All @@ -82,6 +84,11 @@ def forward(self, {', '.join(free_variables)}):
for k, v in gbls.items():
setattr(cls, k, v)

def get_source(self, update: bool = True):
if update:
self._generate_forward()
return self.src


# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
# This installs empty Modules where none exist yet if they are subpaths of target
Expand Down
7 changes: 5 additions & 2 deletions src/paddlefx/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, node: Node, tracer: Tracer):
self.tracer = tracer

def __repr__(self):
return f'Proxy({self.node.name})'
return f'{self.node.name}'

def __getattr__(self, k):
# note: not added to the graph yet, if this is a method call
Expand All @@ -45,7 +45,7 @@ def __iter__(self):
if current_instruction.opname == "UNPACK_SEQUENCE":
return (self[i] for i in range(current_instruction.argval))
elif current_instruction.opname == "GET_ITER":
raise NotImplementedError()
return (self[i] for i in range(current_instruction.argval))
raise ValueError("Cannot find UNPACK_SEQUENCE instruction")


Expand All @@ -66,6 +66,9 @@ def node(self):
).node
return self._node

def __str__(self):
return f'{self.root}.{self.node.name}'

def __call__(self, *args, **kwargs):
return _create_proxy(
self.tracer, 'call_method', self.attr, (self.root,) + args, kwargs
Expand Down
2 changes: 1 addition & 1 deletion src/paddlefx/symbolic_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def _proxy_placeholder(self, name):
n = self.graph.create_node('placeholder', name, (), {})
return Proxy(n, self)

def create_node(self, op, target, args, kwargs, name=None):
def create_node(self, op, target, args=None, kwargs=None, name=None):
return self.graph.create_node(op, target, args, kwargs, name)

def create_arg(self, a):
Expand Down
Loading

0 comments on commit 83b9fdc

Please sign in to comment.