Skip to content

Commit

Permalink
【Hackathon 5th No.13】【关联 PR】Added uint8&int8&int16 support for compar…
Browse files Browse the repository at this point in the history
…e_kernel -part (PaddlePaddle#58209)

* 【Hackathon 5th No.13】【关联 PR】Added int8 support for less_than

* update 1.0

* update test

* update v1.3

* Update test_compare_op.py

* Update test_compare_op.py

* Update test_compare_op.py

* Update test_compare_op.py,test=document_fix

* Update test_compare_op.py

* Update test_compare_op.py

* Update test_compare_op.py

* Update test_compare_op.py

* Update test_compare_op.py

* Update test_compare_op.py

* Update compare_kernel.cc

* Update compare_kernel.cu

* Update compare_kernel.cc

* Update compare_kernel.cu

* Update test_compare_op.py

* Update logic.py

* Update compare_kernel.cc

* Update compare_kernel.cc
  • Loading branch information
jjyaoao authored and SecretXV committed Nov 28, 2023
1 parent d7e7281 commit a27555f
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 21 deletions.
5 changes: 4 additions & 1 deletion paddle/phi/kernels/cpu/compare_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,18 @@ PD_REGISTER_KERNEL(equal_all,
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
uint8_t, \
int8_t, \
int16_t, \
int64_t, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}

PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/kps/compare_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,10 @@ PD_REGISTER_KERNEL(equal_all,
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
uint8_t, \
int8_t, \
int16_t, \
int64_t, \
float, \
double, \
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/kernels/legacy/cpu/compare_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ PD_REGISTER_KERNEL(less_than_raw,
ALL_LAYOUT,
phi::LessThanRawKernel,
bool,
uint8_t,
int8_t,
int16_t,
int,
int64_t,
Expand All @@ -131,6 +133,8 @@ PD_REGISTER_KERNEL(less_than_raw,
ALL_LAYOUT, \
phi::func##RawKernel, \
bool, \
uint8_t, \
int8_t, \
int16_t, \
int, \
int64_t, \
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/kernels/legacy/kps/compare_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ PD_REGISTER_KERNEL(less_than_raw,
ALL_LAYOUT,
phi::LessThanRawKernel,
bool,
uint8_t,
int8_t,
int16_t,
int,
int64_t,
Expand All @@ -155,8 +157,10 @@ PD_REGISTER_KERNEL(less_than_raw,
ALL_LAYOUT, \
phi::func##RawKernel, \
bool, \
uint8_t, \
int16_t, \
int, \
int8_t, \
int64_t, \
float, \
double, \
Expand Down
60 changes: 48 additions & 12 deletions python/paddle/tensor/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,8 +512,8 @@ def equal(x, y, name=None):
The output has no gradient.
Args:
x (Tensor): Tensor, data type is bool, float16, float32, float64, int32, int64.
y (Tensor): Tensor, data type is bool, float16, float32, float64, int32, int64.
x (Tensor): Tensor, data type is bool, float16, float32, float64, uint8, int8, int16, int32, int64.
y (Tensor): Tensor, data type is bool, float16, float32, float64, uint8, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Expand Down Expand Up @@ -551,6 +551,9 @@ def equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -565,6 +568,9 @@ def equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down Expand Up @@ -609,8 +615,8 @@ def greater_equal(x, y, name=None):
The output has no gradient.
Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Expand Down Expand Up @@ -639,6 +645,9 @@ def greater_equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -653,6 +662,9 @@ def greater_equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down Expand Up @@ -697,8 +709,8 @@ def greater_than(x, y, name=None):
The output has no gradient.
Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Expand Down Expand Up @@ -727,6 +739,9 @@ def greater_than(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -741,6 +756,9 @@ def greater_than(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down Expand Up @@ -785,8 +803,8 @@ def less_equal(x, y, name=None):
The output has no gradient.
Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Expand Down Expand Up @@ -816,6 +834,9 @@ def less_equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -830,6 +851,9 @@ def less_equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down Expand Up @@ -874,8 +898,8 @@ def less_than(x, y, name=None):
The output has no gradient.
Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Expand Down Expand Up @@ -905,6 +929,9 @@ def less_than(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -919,6 +946,9 @@ def less_than(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down Expand Up @@ -963,8 +993,8 @@ def not_equal(x, y, name=None):
The output has no gradient.
Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, uint8, int8, int16, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, uint8, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Expand Down Expand Up @@ -994,6 +1024,9 @@ def not_equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -1008,6 +1041,9 @@ def not_equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down
25 changes: 18 additions & 7 deletions test/legacy_test/test_compare_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,35 @@ def setUp(self):
def test_output(self):
self.check_output(check_cinn=True, check_pir=check_pir)

def test_errors(self):
def test_int16_support(self):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(name='x', shape=[-1, 2], dtype='int32')
y = paddle.static.data(name='y', shape=[-1, 2], dtype='int32')
a = paddle.static.data(name='a', shape=[-1, 2], dtype='int16')
b = paddle.static.data(name='b', shape=[-1, 2], dtype='int16')
op = eval("paddle.%s" % self.op_type)
self.assertRaises(TypeError, op, x=x, y=a)
self.assertRaises(TypeError, op, x=a, y=y)

try:
result = op(x=a, y=b)
except TypeError:
self.fail("TypeError should not be raised for int16 inputs")

cls_name = f"{op_type}_{typename}"
Cls.__name__ = cls_name
globals()[cls_name] = Cls


for _type_name in {'float32', 'float64', 'int32', 'int64', 'float16'}:
for _type_name in {
'float32',
'float64',
'uint8',
'int8',
'int16',
'int32',
'int64',
'float16',
}:
if _type_name == 'float64' and core.is_compiled_with_rocm():
_type_name = 'float32'
if _type_name == 'float16' and (not core.is_compiled_with_cuda()):
Expand Down Expand Up @@ -513,7 +524,7 @@ def test_check_output(self):


class TestCompareOpError(unittest.TestCase):
def test_errors(self):
def test_int16_support(self):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
Expand Down

0 comments on commit a27555f

Please sign in to comment.