Skip to content

Commit

Permalink
fix bug in print_tabular and add some new print_tabular mode (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 authored Apr 28, 2023
1 parent c80c77b commit 48c71f7
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 14 deletions.
4 changes: 3 additions & 1 deletion examples/resnet_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@
assert paddle.allclose(orig_output, traced_output)

print(f"python IR for {type(net).__name__}")
traced_layer.graph.print_tabular()
traced_layer.graph.print_tabular(print_mode="tabulate")
traced_layer.graph.print_tabular(print_mode="rich")
traced_layer.graph.print_tabular(print_mode="raw")
70 changes: 57 additions & 13 deletions src/paddlefx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,21 +277,65 @@ def python_code(self, root_module):
src = ''.join(body)
return src, free_vars

def print_tabular(self):
def print_tabular(self, print_mode="tabulate"):
"""Prints the intermediate representation of the graph in tabular
format.
Note that this API requires the ``tabulate`` module to be installed.
"""
try:
from tabulate import tabulate
except ImportError:
print(
"`print_tabular` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in self.nodes]
print(
tabulate(node_specs, headers=['opcode', 'name', 'target', 'args', 'kwargs'])
)
assert print_mode in ["raw", "tabulate", "rich"]
if print_mode == "raw":
node_specs = [
" ".join(
map(str, [v for v in [n.op, n.name, n.target, n.args, n.kwargs]])
)
for n in self.nodes
]
print(" ".join(['opcode', 'name', 'target', 'args', 'kwargs']))
print("\n".join(node_specs))
elif print_mode == "tabulate":
try:
from tabulate import tabulate

node_specs = [
[n.op, n.name, n.target, n.args, n.kwargs] for n in self.nodes
]
print(
tabulate(
node_specs,
headers=['opcode', 'name', 'target', 'args', 'kwargs'],
)
)
except ImportError:
import warnings

warnings.warn(
"`print_tabular` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
self.print_tabular("raw")
elif print_mode == "rich":
try:
import rich
import rich.table

table = rich.table.Table(
'opcode', 'name', 'target', 'args', 'kwargs', expand=True
)
for n in self.nodes:
table.add_row(
*map(
str, [v for v in [n.op, n.name, n.target, n.args, n.kwargs]]
)
)
rich.print(table)
except ImportError:
import warnings

warnings.warn(
"`print_tabular` relies on the library `rich`, "
"which could not be found on this machine. Run `pip "
"install rich` to install the library."
)
self.print_tabular("raw")
1 change: 1 addition & 0 deletions src/paddlefx/graph_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any

import paddle
import paddle.nn
import pydot

import paddlefx
Expand Down

0 comments on commit 48c71f7

Please sign in to comment.