Skip to content

Commit

Permalink
change order of multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
umangyadav committed Jul 18, 2023
1 parent dea2a09 commit 0da1e71
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 31 deletions.
16 changes: 8 additions & 8 deletions src/onnx/parse_batchnorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
auto x_rank = x_lens.size();
if(x_rank == 1 or x_rank == 2)
{
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto numer = info.add_broadcastable_binary_op("sub", args[0], args[3]);
auto var_eps = info.add_broadcastable_binary_op("add", args[4], eps);
auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto div0 = info.add_broadcastable_binary_op("mul", numer, rsqrt);
auto r0 = info.add_broadcastable_binary_op("mul", div0, args[1]);
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], args[3]);
auto var_eps = info.add_broadcastable_binary_op("add", args[4], eps);
auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto mul0 = info.add_broadcastable_binary_op("mul", args[1], rsqrt);
auto r0 = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
return info.add_broadcastable_binary_op("add", r0, args[2]);
}
else if(x_rank > 2)
Expand All @@ -82,8 +82,8 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto div0 = info.add_broadcastable_binary_op("mul", x_sub_mean, rsqrt);
auto r0 = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze);
auto mul0 = info.add_broadcastable_binary_op("mul", scale_unsqueeze, rsqrt);
auto r0 = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
}
else
Expand Down
4 changes: 2 additions & 2 deletions src/tf/parse_batchnorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
auto x_sub_mean = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
auto rsqrt = info.add_instruction(make_op("rsqrt"), var_eps);
auto div0 = info.add_broadcastable_binary_op("mul", x_sub_mean, rsqrt);
auto r0 = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze);
auto mul0 = info.add_broadcastable_binary_op("mul", scale_unsqueeze, rsqrt);
auto r0 = info.add_broadcastable_binary_op("mul", x_sub_mean, mul0);
return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
}
};
Expand Down
29 changes: 15 additions & 14 deletions test/onnx/onnx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,8 @@ TEST_CASE(batch_norm_flat_test)
auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps});
auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), {var_eps});
auto div0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, scale});
auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, bias});

auto prog = optimize_onnx("batch_norm_flat_test.onnx");
Expand All @@ -468,9 +468,9 @@ TEST_CASE(batch_norm_rank_2_test)

auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps});
auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, scale});
auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), {var_eps});
auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, bias});

auto prog = optimize_onnx("batch_norm_rank_2_test.onnx");
Expand Down Expand Up @@ -498,8 +498,8 @@ TEST_CASE(batch_norm_1d_test)
auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale});
auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});

auto prog = optimize_onnx("batch_norm_1d_test.onnx");
Expand All @@ -524,11 +524,11 @@ TEST_CASE(batch_norm_2d_test)
auto usq_mean = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), mean);
auto usq_var = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), var);

auto numer = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("mul"), {numer, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale});
auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});

auto prog = optimize_onnx("batch_norm_2d_test.onnx");
Expand Down Expand Up @@ -559,9 +559,10 @@ TEST_CASE(batch_norm_3d_test)
auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale});
auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});

auto prog = optimize_onnx("batch_norm_3d_test.onnx");

EXPECT(p == prog);
Expand Down
13 changes: 6 additions & 7 deletions test/tf/tf_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,8 @@ TEST_CASE(batchnorm_test)
auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale});

auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});

auto prog = optimize_tf("batchnorm_test.pb", true);
Expand Down Expand Up @@ -237,8 +236,8 @@ TEST_CASE(batchnorm_half_test)
auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale});
auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});

auto prog = optimize_tf("batchnorm_half_test.pb", true);
Expand Down Expand Up @@ -267,8 +266,8 @@ TEST_CASE(batchnormv3_test)
auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, usq_mean});
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {usq_var, eps});
auto rsqrt = mm->add_instruction(migraphx::make_op("rsqrt"), var_eps);
auto div0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {div0, usq_scale});
auto mul0 = add_common_op(*mm, migraphx::make_op("mul"), {usq_scale, rsqrt});
auto r0 = add_common_op(*mm, migraphx::make_op("mul"), {x_sub_mean, mul0});
add_common_op(*mm, migraphx::make_op("add"), {r0, usq_bias});

auto prog = optimize_tf("batchnormv3_test.pb", true);
Expand Down

0 comments on commit 0da1e71

Please sign in to comment.