Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support yolov5 more ops #95

Merged
merged 8 commits into from
Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/oneflow2onnx/nodes/CPU/test_upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_upsample_nearest_2d():
upsample_nearest2d_graph._compile(flow.randn(1, 1, 2, 2))

convert_to_onnx_and_check(upsample_nearest2d_graph, onnx_model_path="/tmp", opset=10)
convert_to_onnx_and_check(upsample_nearest2d_graph, onnx_model_path="/tmp", opset=12)


test_upsample_nearest_2d()
19 changes: 18 additions & 1 deletion oneflow_onnx/oneflow2onnx/handlers/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ def Version_5(cls, ctx, node, **kwargs):
onnx_pb.TensorProto.INT16,
onnx_pb.TensorProto.INT64,
]
shape_node = ctx.MakeConst(oneflow._oneflow_internal.UniqueStr("shape"), np.array(node.attrs.get("shape"), None),)

shape_node = ctx.MakeConst(oneflow._oneflow_internal.UniqueStr("shape"), np.array(node.attrs.get("shape"), None))
if node.attrs.get("shape") == []:
BBuf marked this conversation as resolved.
Show resolved Hide resolved
shape_node = ctx.MakeConst(oneflow._oneflow_internal.UniqueStr("shape"), np.array([]).astype(np.int64))

node.input_tensor_names = node.input_tensor_names + [shape_node.name]
if ctx.opset >= 8 or not need_casting:
# onnx reshape can handle the type - done
Expand Down Expand Up @@ -177,6 +181,19 @@ def Version_13(cls, ctx, node, **kwargs):
node.input_tensor_names.append(axis_node.output_tensor_names[0])


@flow_op("expand", "Expand")
class ExpandOp:
@classmethod
def Version_8(cls, ctx, node, **kwargs):
shape = node.attrs.get("expand_shape")
shape_node = ctx.MakeConst(oneflow._oneflow_internal.UniqueStr("shape"), np.array(shape).astype(np.int64))
node.input_tensor_names.append(shape_node.output_tensor_names[0])

@classmethod
def Version_13(cls, ctx, node, **kwargs):
cls.Version_8(ctx, node, **kwargs)


@flow_op("transpose", onnx_op="Transpose")
class Transpose:
@classmethod
Expand Down
32 changes: 32 additions & 0 deletions oneflow_onnx/oneflow2onnx/handlers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,38 @@ def Version_10(cls, ctx, node, **kwargs):
else:
raise NotImplementedError("Opset 10 don't support specify output_size attribute!")

@classmethod
def Version_11(cls, ctx, node, **kwargs):
node.attrs["coordinate_transformation_mode"] = "half_pixel"
node.attrs["mode"] = "nearest"
node.attrs["nearest_mode"] = "round_prefer_floor"
# onnx support nchw
# node.input_tensor_names.append("")
roi = []
input_shape = ctx.get_shape(node.input_tensor_names[0])
for i in range(len(input_shape)):
roi.append(0)
roi.append(input_shape[i])
roi_node = ctx.MakeConst(oneflow._oneflow_internal.UniqueStr("roi"), np.array(roi).astype(np.float32),)
node.input_tensor_names.append(roi_node.output_tensor_names[0])
if len(node.attrs["output_size"]) == 0:
scales = [1.0, 1.0]
scales.append(node.attrs["height_scale"])
scales.append(node.attrs["width_scale"])
scales_node = ctx.MakeConst(oneflow._oneflow_internal.UniqueStr("scales"), np.array(scales).astype(np.float32),)
node.input_tensor_names.append(scales_node.output_tensor_names[0])
node.input_tensor_names.append("")
else:
node_sizes = node.attrs["output_size"]
node.input_tensor_names.append("")
sizes = []
sizes.append(input_shape[0])
sizes.append(input_shape[1])
sizes.append(node_sizes[0])
sizes.append(node_sizes[1])
sizes_node = ctx.MakeConst(oneflow._oneflow_internal.UniqueStr("sizes"), np.array(sizes).astype(np.int64),)
node.input_tensor_names.append(sizes_node.output_tensor_names[0])

@classmethod
def Version_13(cls, ctx, node, **kwargs):
node.attrs["coordinate_transformation_mode"] = "half_pixel"
Expand Down
5 changes: 4 additions & 1 deletion oneflow_onnx/oneflow2onnx/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ def convert_to_onnx_and_check(
else:
oneflow_res = graph(flow.tensor(*ipt_dict.values(), dtype=flow.float32))
if not isinstance(oneflow_res, np.ndarray):
oneflow_res = oneflow_res.numpy()
if flow.is_tensor(oneflow_res):
oneflow_res = oneflow_res.numpy()
else:
oneflow_res = oneflow_res[0].numpy()
compare_result(oneflow_res, onnx_res, print_outlier=print_outlier)

# cleanup()