Skip to content

Commit

Permalink
【Hackathon 5th No.13】【关联 PR】Added int8 support for less_than
Browse files Browse the repository at this point in the history
  • Loading branch information
jjyaoao committed Oct 18, 2023
1 parent 9690055 commit 9431401
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 4 deletions.
3 changes: 2 additions & 1 deletion paddle/phi/kernels/kps/compare_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ void EqualAllKernel(const Context& ctx,
} // namespace phi

#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(less_than, KPS, ALL_LAYOUT, phi::LessThanKernel, int) {
PD_REGISTER_KERNEL(
less_than, KPS, ALL_LAYOUT, phi::LessThanKernel, int, int8_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(less_equal, KPS, ALL_LAYOUT, phi::LessEqualKernel, int) {
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/xpu/compare_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ PD_REGISTER_KERNEL(less_than,
ALL_LAYOUT,
phi::LessThanKernel,
int,
int8_t,
int64_t,
float,
phi::dtype::float16) {
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/tensor/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,7 @@ def less_than(x, y, name=None):
"float16",
"float32",
"float64",
"int8",
"int32",
"int64",
"uint16",
Expand All @@ -921,6 +922,7 @@ def less_than(x, y, name=None):
"float16",
"float32",
"float64",
"int8",
"int32",
"int64",
"uint16",
Expand Down
37 changes: 34 additions & 3 deletions test/legacy_test/test_compare_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ def test_output(self):
def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.static.data(name='x', shape=[-1, 2], dtype='int32')
y = paddle.static.data(name='y', shape=[-1, 2], dtype='int32')
a = paddle.static.data(name='a', shape=[-1, 2], dtype='int16')
x = paddle.static.data(name='x', shape=[-1, 2], dtype=typename)
y = paddle.static.data(name='y', shape=[-1, 2], dtype=typename)
error_dtype = 'int16' if typename != 'int16' else 'int32'
a = paddle.static.data(
name='a', shape=[-1, 2], dtype=error_dtype
)

op = eval("paddle.%s" % self.op_type)
self.assertRaises(TypeError, op, x=x, y=a)
self.assertRaises(TypeError, op, x=a, y=y)
Expand All @@ -67,6 +71,8 @@ def test_errors(self):
create_test_class('equal', _type_name, lambda _a, _b: _a == _b, True)
create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b)

create_test_class('less_than', 'int8', lambda _a, _b: _a < _b)


def create_paddle_case(op_type, callback):
class PaddleCls(unittest.TestCase):
Expand Down Expand Up @@ -477,6 +483,31 @@ def test_check_output(self):
create_bf16_case('not_equal', lambda _a, _b: _a != _b)


# add int8 tests
def create_int8_case(op_type, callback, check_pir=False):
class TestCompareOpInt8Op(op_test.OpTest):
def setUp(self):
self.op_type = op_type
self.dtype = np.int8
self.python_api = eval("paddle." + op_type)

x = np.random.randint(-128, 127, size=[5, 5]).astype(np.int8)
y = np.random.randint(-128, 127, size=[5, 5]).astype(np.int8)
real_result = callback(x, y)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': real_result}

def test_check_output(self):
self.check_output(check_cinn=True, check_pir=check_pir)

cls_name = f"Int8TestCase_{op_type}"
TestCompareOpInt8Op.__name__ = cls_name
globals()[cls_name] = TestCompareOpInt8Op


create_int8_case('less_than', lambda _a, _b: _a < _b)


class TestCompareOpError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
Expand Down

0 comments on commit 9431401

Please sign in to comment.