diff --git a/examples/resnet_trace.py b/examples/resnet_trace.py index 5c51855..f22bc64 100644 --- a/examples/resnet_trace.py +++ b/examples/resnet_trace.py @@ -17,4 +17,6 @@ assert paddle.allclose(orig_output, traced_output) print(f"python IR for {type(net).__name__}") -traced_layer.graph.print_tabular() +traced_layer.graph.print_tabular(print_mode="tabulate") +traced_layer.graph.print_tabular(print_mode="rich") +traced_layer.graph.print_tabular(print_mode="raw") diff --git a/src/paddlefx/graph.py b/src/paddlefx/graph.py index 93c3564..7542d75 100644 --- a/src/paddlefx/graph.py +++ b/src/paddlefx/graph.py @@ -277,21 +277,65 @@ def python_code(self, root_module): src = ''.join(body) return src, free_vars - def print_tabular(self): + def print_tabular(self, print_mode="tabulate"): """Prints the intermediate representation of the graph in tabular format. Note that this API requires the ``tabulate`` module to be installed. """ - try: - from tabulate import tabulate - except ImportError: - print( - "`print_tabular` relies on the library `tabulate`, " - "which could not be found on this machine. Run `pip " - "install tabulate` to install the library." - ) - node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in self.nodes] - print( - tabulate(node_specs, headers=['opcode', 'name', 'target', 'args', 'kwargs']) - ) + assert print_mode in ["raw", "tabulate", "rich"] + if print_mode == "raw": + node_specs = [ + " ".join( + map(str, [v for v in [n.op, n.name, n.target, n.args, n.kwargs]]) + ) + for n in self.nodes + ] + print(" ".join(['opcode', 'name', 'target', 'args', 'kwargs'])) + print("\n".join(node_specs)) + elif print_mode == "tabulate": + try: + from tabulate import tabulate + + node_specs = [ + [n.op, n.name, n.target, n.args, n.kwargs] for n in self.nodes + ] + print( + tabulate( + node_specs, + headers=['opcode', 'name', 'target', 'args', 'kwargs'], + ) + ) + except ImportError: + import warnings + + warnings.warn( + "`print_tabular` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." + ) + self.print_tabular("raw") + elif print_mode == "rich": + try: + import rich + import rich.table + + table = rich.table.Table( + 'opcode', 'name', 'target', 'args', 'kwargs', expand=True + ) + for n in self.nodes: + table.add_row( + *map( + str, [v for v in [n.op, n.name, n.target, n.args, n.kwargs]] + ) + ) + rich.print(table) + except ImportError: + import warnings + + warnings.warn( + "`print_tabular` relies on the library `rich`, " + "which could not be found on this machine. Run `pip " + "install rich` to install the library." + ) + self.print_tabular("raw") diff --git a/src/paddlefx/graph_viewer.py b/src/paddlefx/graph_viewer.py index bc8dc7f..419b9bc 100644 --- a/src/paddlefx/graph_viewer.py +++ b/src/paddlefx/graph_viewer.py @@ -3,6 +3,7 @@ from typing import Any import paddle +import paddle.nn import pydot import paddlefx