Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add torch.positive #4999

Merged
merged 8 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_normal.cpp
pass_level2/torch_ones.cpp
pass_level2/torch_ones_like.cpp
pass_level2/torch_positive.cpp
pass_level2/torch_prod.cpp
pass_level2/torch_quantize_per_tensor.cpp
pass_level2/torch_randn.cpp
Expand Down Expand Up @@ -436,7 +437,6 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/F_max_pool1d.cpp
pass_ncnn/F_max_pool2d.cpp
pass_ncnn/F_max_pool3d.cpp
pass_ncnn/F_mish.cpp
nicochen1118 marked this conversation as resolved.
Show resolved Hide resolved
pass_ncnn/F_normalize.cpp
pass_ncnn/F_pad.cpp
pass_ncnn/F_pixel_shuffle.cpp
Expand Down
40 changes: 40 additions & 0 deletions tools/pnnx/src/pass_level2/torch_positive.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class torch_positive : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
aten::positive op_0 1 1 input out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.positive";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_positive, 20)

} // namespace pnnx
4 changes: 4 additions & 0 deletions tools/pnnx/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,10 @@ pnnx_add_test(pnnx_fuse_pad_conv1d)
pnnx_add_test(pnnx_fuse_pad_conv2d)
pnnx_add_test(pnnx_fuse_pixel_unshuffle)

if(Torch_VERSION VERSION_GREATER_EQUAL "1.8")
pnnx_add_test(torch_positive)
endif()

if(Torch_VERSION VERSION_GREATER_EQUAL "1.9")
pnnx_add_test(F_mish)
pnnx_add_test(nn_Mish)
Expand Down
61 changes: 61 additions & 0 deletions tools/pnnx/tests/test_torch_positive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
x = torch.positive(x)
y = torch.positive(y)
z = torch.positive(z)
return x, y, z

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(1, 3, 16)
y = torch.rand(1, 5, 9, 11)
z = torch.rand(14, 8, 5, 9, 10)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_torch_positive.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_torch_positive.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]")

# pnnx inference
import test_torch_positive_pnnx
b = test_torch_positive_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)
Loading