deal with multi input tensors #99

Merged
merged 17 commits on Oct 18, 2022
11 changes: 10 additions & 1 deletion oneflow_onnx/oneflow2onnx/flow2onnx.py
@@ -33,6 +33,7 @@
from typing import Text, Optional, Dict, Callable, List

import numpy as np
import onnx
from onnx import helper, onnx_pb

import oneflow
@@ -266,6 +267,15 @@ def Export(
e
)
)
try:
from onnxsim import simplify

model = onnx.load(onnx_filename)
model_simp, check = simplify(model)
onnx.save(model_simp, onnx_filename)
except Exception:
logger.info("If you want to simplify the onnx model, please install onnxsim and run again!")

return


Expand Down Expand Up @@ -302,7 +312,6 @@ def ProcessFlowGraph(
TopologicalSort(g, continue_on_error)

g.UpdateProto()

logger.debug("Summay Stats:\n" "\toneflow ops: {}\n" "\toneflow attr: {}\n" "\tonnx mapped: {}\n" "\tonnx unmapped: {}".format(op_cnt, attr_cnt, mapped_op, unmapped_op))

return g
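For reference, the optional simplification pass added to Export above can be reproduced standalone. This is a minimal sketch assuming onnxsim is installed and a hypothetical model path; onnxsim.simplify returns the simplified model plus a flag indicating whether the result validated:

import onnx
from onnxsim import simplify

model = onnx.load("model.onnx")  # hypothetical path
model_simp, check = simplify(model)  # check is True when the simplified model passed validation
assert check, "simplified ONNX model could not be validated"
onnx.save(model_simp, "model.onnx")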
2 changes: 1 addition & 1 deletion oneflow_onnx/oneflow2onnx/handlers/array.py
@@ -340,7 +340,7 @@ def Version_1(cls, ctx, node, **kwargs):
if is_floating_value:
values = np.full(shape=shape, fill_value=floating_value, dtype=np.float32)
else:
values = np.full(shape=shape, fill_value=integer_value, dtype=np.float32)
values = np.full(shape=shape, fill_value=integer_value, dtype=np.int64)
output_name = node.output_tensor_names[0]
ctx.RemoveNode(node.name)
if is_floating_value:
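The one-line dtype change above matters because np.full stores the fill value in whatever dtype is requested, so an integer fill was previously cast to float32 and the emitted constant carried the wrong element type. A small illustration (the values are assumptions, not from the PR's tests):

import numpy as np

# Before the fix: the integer fill value is silently stored as float32.
bad = np.full(shape=(2,), fill_value=3, dtype=np.float32)  # array([3., 3.], dtype=float32)
# After the fix: integer fills keep an integer element type.
good = np.full(shape=(2,), fill_value=3, dtype=np.int64)  # array([3, 3])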
9 changes: 7 additions & 2 deletions oneflow_onnx/oneflow2onnx/handlers/math.py
@@ -57,11 +57,16 @@ def Version_6(cls, ctx, node, **kwargs):
node.input_tensor_names.append(scalar_node.output_tensor_names[0])


@flow_op("add_n", onnx_op="Sum")
@flow_op("add_n", onnx_op="Add")
class AddN:
@classmethod
def Version_1(cls, ctx, node, **kwargs):
pass
input_length = len(node.input_tensor_names)
if input_length > 2:
ctx.RemoveNode(node.name)
ctx.MakeNode("Sum", node.input_tensor_names, outputs=[node.output_tensor_names[0]], op_name_scope=node.name, name="mul")


@flow_op("bias_add", onnx_op="Add", flow_ibns=["a", "b"])
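The handler change above reflects the arity rules of the two ONNX ops: Add is strictly binary, while Sum accepts any number of inputs, so add_n with more than two inputs has to be rebuilt as a Sum node. A hedged sketch of the same decision using onnx.helper directly (the function name is illustrative, not the converter's API):

from onnx import helper

def make_add_n_node(inputs, output):
    # ONNX "Add" is strictly binary; "Sum" is variadic.
    op_type = "Add" if len(inputs) <= 2 else "Sum"
    return helper.make_node(op_type, inputs, [output])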
102 changes: 81 additions & 21 deletions oneflow_onnx/oneflow2onnx/util.py
@@ -21,21 +21,35 @@
from typing import Optional, Union, Tuple, List
from collections import OrderedDict
from oneflow_onnx.oneflow2onnx.flow2onnx import Export
os.environ['NVIDIA_TF32_OVERRIDE'] = '0'

def run_onnx(onnx_model_path: str, providers: List[str], ipt_dict: Optional[OrderedDict] = None, ort_optimize: bool = True,) -> Union[Tuple[OrderedDict, np.ndarray], np.ndarray]:
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"


def run_onnx(
onnx_model_path: str, providers: List[str], ipt_dict: Optional[OrderedDict] = None, ort_optimize: bool = True, input_tensor_range: List = None,
) -> Union[Tuple[OrderedDict, np.ndarray], np.ndarray]:
ort_sess_opt = ort.SessionOptions()
ort_sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED if ort_optimize else ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess = ort.InferenceSession(onnx_model_path, sess_options=ort_sess_opt, providers=providers)
assert len(sess.get_outputs()) == 1
assert len(sess.get_inputs()) <= 1

only_return_result = ipt_dict is not None

if ipt_dict is None:
ipt_dict = OrderedDict()
for ipt in sess.get_inputs():
ipt_data = np.random.uniform(low=-10, high=10, size=ipt.shape).astype(np.float32)
low, high = -10, 10
if input_tensor_range is not None:
low = input_tensor_range[0]
high = input_tensor_range[1]
ipt_data = np.random.uniform(low=low, high=high, size=ipt.shape)
if ipt.type == "tensor(int64)":
ipt_data = ipt_data.astype(np.int64)
elif ipt.type == "tensor(float)":
ipt_data = ipt_data.astype(np.float32)
elif ipt.type == "tensor(bool)":
ipt_data = ipt_data.astype(np.bool_)
else:
raise NotImplementedError(f"{ipt.type} is not supported now, please give a feedback in https://github.com/Oneflow-Inc/oneflow_convert/issues/new .")
ipt_dict[ipt.name] = ipt_data

onnx_res = sess.run([], ipt_dict)[0]
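A hypothetical call showing the new input_tensor_range knob (the model path is an assumption): when ipt_dict is omitted, run_onnx fabricates random inputs, now drawn from the given range instead of the default [-10, 10], and returns them together with the ONNX result:

ipt_dict, onnx_res = run_onnx(
    "model.onnx",  # hypothetical path
    providers=["CPUExecutionProvider"],
    input_tensor_range=[0, 1],
)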
@@ -64,9 +78,11 @@ def export_onnx_model(
onnx_model_dir = onnx_model_path
if os.path.isdir(onnx_model_path):
onnx_model_path = os.path.join(onnx_model_dir, "model.onnx")
print("Converting model to onnx....")
Export(
graph, flow_weight_dir, onnx_model_path, opset=opset, external_data=external_data, dynamic_batch_size=dynamic_batch_size,
)
print(f"Succeed converting model, save model to {onnx_model_path}")

def cleanup():
if os.path.exists(flow_weight_dir) and flow_weight_clean_flag:
@@ -88,31 +104,75 @@ def compare_result(


def convert_to_onnx_and_check(
graph, print_outlier=True, external_data=False, ort_optimize=True, opset=None, flow_weight_dir=None, onnx_model_path="/tmp", dynamic_batch_size=False, device="cpu",
graph, print_outlier=True, external_data=False, ort_optimize=True, opset=None, flow_weight_dir=None, onnx_model_path="/tmp", dynamic_batch_size=False, device="cpu", input_tensor_range=None,
):
onnx_model_path, cleanup = export_onnx_model(graph, external_data, opset, flow_weight_dir, onnx_model_path, dynamic_batch_size,)

if input_tensor_range is not None:
assert isinstance(input_tensor_range, list), f"input_tensor_range {input_tensor_range} must be a list, e.g. [-5, 5]"
assert len(input_tensor_range) == 2 and input_tensor_range[0] < input_tensor_range[1], f"input_tensor_range {input_tensor_range} must be an increasing list with two elements, e.g. [0, 10]"

if not dynamic_batch_size:
if ort.__version__ > "1.9.0":
ipt_dict, onnx_res = run_onnx(onnx_model_path, ["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider",], ort_optimize=ort_optimize,)
ipt_dict, onnx_res = run_onnx(
onnx_model_path, ["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider",], ort_optimize=ort_optimize, input_tensor_range=input_tensor_range,
)
else:
ipt_dict, onnx_res = run_onnx(onnx_model_path, ["CPUExecutionProvider"], ort_optimize=ort_optimize)
ipt_dict, onnx_res = run_onnx(onnx_model_path, ["CPUExecutionProvider"], ort_optimize=ort_optimize, input_tensor_range=input_tensor_range,)

oneflow_res = None

if device == "gpu":
if len(ipt_dict) == 0:
oneflow_res = graph()
else:
oneflow_res = graph(flow.tensor(*ipt_dict.values(), dtype=flow.float32).to("cuda"))
device_kwargs = dict(device="cuda")
elif device == "cpu":
device_kwargs = dict(device="cpu")
elif device == "gpu_global":
device_kwargs = dict(sbp=flow.sbp.broadcast, placement=flow.placement("cuda", ranks=[0]))
elif device == "cpu_global":
device_kwargs = dict(sbp=flow.sbp.broadcast, placement=flow.placement("cpu", ranks=[0]))
else:
if len(ipt_dict) == 0:
oneflow_res = graph()
else:
oneflow_res = graph(flow.tensor(*ipt_dict.values(), dtype=flow.float32))
if not isinstance(oneflow_res, np.ndarray):
raise NotImplementedError

if len(ipt_dict) == 0:
oneflow_res = graph()
else:
graph_input_tensor = []
for _, value in ipt_dict.items():
value_tensor = None
if value.dtype == "int64":
value_tensor = flow.tensor(value, dtype=flow.int64, **device_kwargs)
elif value.dtype == "float" or value.dtype == "float32":
value_tensor = flow.tensor(value, dtype=flow.float32, **device_kwargs)
elif value.dtype == "bool":
value_tensor = flow.tensor(value, dtype=flow.bool, **device_kwargs)
else:
raise NotImplementedError(f"{value.dtype} is not supported now, please give a feedback in https://github.com/Oneflow-Inc/oneflow_convert/issues/new .")
graph_input_tensor.append(value_tensor)

try:
oneflow_res = graph(graph_input_tensor)
except Exception:
print(
"\033[0;36mThe input tensors or weights compiled by nn.Graph are not in Eager Local mode (perhaps Eager Global mode?). In that case the result difference cannot be compared, so the inference result of the onnx model may be incorrect. We strongly recommend exporting onnx in Eager Local mode!\033[0m"
)

if oneflow_res is not None:
if not isinstance(oneflow_res, np.ndarray):
if flow.is_tensor(oneflow_res):
pass
elif isinstance(oneflow_res, dict):
oneflow_res = next(iter(oneflow_res.values()))
elif isinstance(oneflow_res, (list, tuple)):
oneflow_res = oneflow_res[0]
else:
raise NotImplementedError
if flow.is_tensor(oneflow_res):
if "global" in device:
oneflow_res = oneflow_res.to_local()
oneflow_res = oneflow_res.numpy()
else:
oneflow_res = oneflow_res[0].numpy()
compare_result(oneflow_res, onnx_res, print_outlier=print_outlier)

print("Comparing result between oneflow and onnx....")
compare_result(oneflow_res, onnx_res, print_outlier=print_outlier)
print("Compare succeed!")
# cleanup()
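Putting the new options together, a hypothetical end-to-end check might look like this (graph is assumed to be an existing oneflow.nn.Graph instance):

convert_to_onnx_and_check(
    graph,
    onnx_model_path="/tmp",
    device="gpu_global",  # run the OneFlow side as a global tensor placed on cuda rank 0
    input_tensor_range=[-1, 1],  # constrain the random test inputs
)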
12 changes: 10 additions & 2 deletions oneflow_onnx/onnx_wrapper.py
@@ -892,7 +892,7 @@ def _get_unvisited_child(g, node, not_visited):
label = [-1 for _ in range(n)]
stack = []
in_stack = dict()
not_visited = dict.fromkeys(range(n))
not_visited = dict.fromkeys(range(n - 1, -1, -1))
label_counter = n - 1

while not_visited:
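The change above only affects the order in which unvisited nodes are first picked: Python dicts preserve insertion order, so seeding not_visited in reverse makes the traversal start from the highest-numbered node, matching the decreasing label_counter. A tiny illustration of the ordering difference:

list(dict.fromkeys(range(4)))          # [0, 1, 2, 3]
list(dict.fromkeys(range(3, -1, -1)))  # [3, 2, 1, 0]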
@@ -921,6 +921,7 @@ def MakeGraph(self, doc, onnx_filename, external_data=False, graph_name="oneflow
"""
self.DeleteUnusedNodes(self.outputs)
self.TopologicalSort(self.get_nodes())

self.UpdateProto()

ops = []
@@ -1202,6 +1203,7 @@ def ExtractSubGraphNodes(self, outputs_name, input_checker=None, ignore_unused_p
Return:
a list of nodes
"""
res = []
res_set = set()
if not outputs_name:
return list(res_set)
@@ -1216,8 +1218,13 @@
if node.is_graph_input():
if node not in res_set:
res_set.add(node)
res = []
for i in range(len(self._nodes)):
if self._nodes[i].op_type == "input":
res.append(self._nodes[i])
res_set.discard(self._nodes[i])

return list(res_set)
return res + list(res_set)
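The net effect of the rewrite above is that graph-input nodes are returned first and in their original relative order, which is what makes multi-input graphs deterministic downstream. A condensed sketch of the same idea (names are illustrative, not the wrapper's API):

def order_inputs_first(all_nodes, reachable):
    # Graph inputs keep their original relative order and come first;
    # every other reachable node follows.
    inputs = [n for n in all_nodes if n.op_type == "input"]
    others = [n for n in reachable if n.op_type != "input"]
    return inputs + others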

def DeleteUnusedNodes(self, outputs_name):
"""Delete nodes not in subgraph ending with output_names."""
@@ -1233,6 +1240,7 @@ def DeleteUnusedNodes(self, outputs_name):
if attr_body_graphs:
for _, body_graph in attr_body_graphs.items():
body_graph.DeleteUnusedNodes(body_graph.outputs)

self.ResetNodes(related_nodes)

def SafeToRemoveNodes(self, to_delete):