Skip to content

Commit

Permalink
Merge pull request #1483 from wenhoujx/master
Browse files (browse the repository at this point in the history)
Add int16 and int32 support to get_as_arrow.
  • Loading branch information
ray6080 committed Apr 24, 2023
2 parents 0b046a6 + 7408e23 commit 09106ce
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 8 deletions.
6 changes: 6 additions & 0 deletions src/common/arrow/arrow_converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ void ArrowConverter::setArrowFormat(
case DataTypeID::INT64: {
child.format = "l";
} break;
case DataTypeID::INT32: {
child.format = "i";
} break;
case DataTypeID::INT16: {
child.format = "s";
} break;
case DataTypeID::DOUBLE: {
child.format = "g";
} break;
Expand Down
24 changes: 24 additions & 0 deletions src/common/arrow/arrow_row_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ std::unique_ptr<ArrowVector> ArrowRowBatch::createVector(
case INT64: {
templateInitializeVector<INT64>(result.get(), typeInfo, capacity);
} break;
case INT32: {
templateInitializeVector<INT32>(result.get(), typeInfo, capacity);
} break;
case INT16: {
templateInitializeVector<INT16>(result.get(), typeInfo, capacity);
} break;
case DOUBLE: {
templateInitializeVector<DOUBLE>(result.get(), typeInfo, capacity);
} break;
Expand Down Expand Up @@ -252,6 +258,12 @@ void ArrowRowBatch::copyNonNullValue(
case INT64: {
templateCopyNonNullValue<INT64>(vector, typeInfo, value, pos);
} break;
case INT32: {
templateCopyNonNullValue<INT32>(vector, typeInfo, value, pos);
} break;
case INT16: {
templateCopyNonNullValue<INT16>(vector, typeInfo, value, pos);
} break;
case DOUBLE: {
templateCopyNonNullValue<DOUBLE>(vector, typeInfo, value, pos);
} break;
Expand Down Expand Up @@ -317,6 +329,12 @@ void ArrowRowBatch::copyNullValue(ArrowVector* vector, Value* value, std::int64_
case INT64: {
templateCopyNullValue<INT64>(vector, pos);
} break;
case INT32: {
templateCopyNullValue<INT32>(vector, pos);
} break;
case INT16: {
templateCopyNullValue<INT16>(vector, pos);
} break;
case DOUBLE: {
templateCopyNullValue<DOUBLE>(vector, pos);
} break;
Expand Down Expand Up @@ -449,6 +467,12 @@ ArrowArray* ArrowRowBatch::convertVectorToArray(
case INT64: {
return templateCreateArray<INT64>(vector, typeInfo);
}
case INT32: {
return templateCreateArray<INT32>(vector, typeInfo);
}
case INT16: {
return templateCreateArray<INT16>(vector, typeInfo);
}
case DOUBLE: {
return templateCreateArray<DOUBLE>(vector, typeInfo);
}
Expand Down
26 changes: 18 additions & 8 deletions tools/python_api/test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,45 @@ def test_to_arrow(establish_connection):
conn, db = establish_connection

def _test_primitive_data_types(_conn):
query = "MATCH (a:person) RETURN a.age, a.isStudent, a.eyeSight, a.birthdate, a.registerTime," \
query = "MATCH (a:person) RETURN a.age, to_int32(a.age), to_int16(a.age), a.isStudent, a.eyeSight, a.birthdate, a.registerTime," \
" a.lastJobDuration, a.fName ORDER BY a.ID"
arrow_tbl = _conn.execute(query).get_as_arrow(8)
assert arrow_tbl.num_columns == 7
assert arrow_tbl.num_columns == 9

age_col = arrow_tbl.column(0)
assert age_col.type == pa.int64()
assert age_col.length() == 8
assert age_col.to_pylist() == [35, 30, 45, 20, 20, 25, 40, 83]

is_student_col = arrow_tbl.column(1)
age_32_col = arrow_tbl.column(1)
assert age_32_col.type == pa.int32()
assert age_32_col.length() == 8
assert age_32_col.to_pylist() == [35, 30, 45, 20, 20, 25, 40, 83]

age_16_col = arrow_tbl.column(2)
assert age_16_col.type == pa.int16()
assert age_16_col.length() == 8
assert age_16_col.to_pylist() == [35, 30, 45, 20, 20, 25, 40, 83]

is_student_col = arrow_tbl.column(3)
assert is_student_col.type == pa.bool_()
assert is_student_col.length() == 8
assert is_student_col.to_pylist() == [True, True, False, False, False, True, False, False]

eye_sight_col = arrow_tbl.column(2)
eye_sight_col = arrow_tbl.column(4)
assert eye_sight_col.type == pa.float64()
assert eye_sight_col.length() == 8
assert eye_sight_col.to_pylist() == [5.0, 5.1, 5.0, 4.8, 4.7, 4.5, 4.9, 4.9]

birthdate_col = arrow_tbl.column(3)
birthdate_col = arrow_tbl.column(5)
assert birthdate_col.type == pa.date32()
assert birthdate_col.length() == 8
assert birthdate_col.to_pylist() == [datetime.date(1900, 1, 1), datetime.date(1900, 1, 1),
datetime.date(1940, 6, 22), datetime.date(1950, 7, 23),
datetime.date(1980, 10, 26), datetime.date(1980, 10, 26),
datetime.date(1980, 10, 26), datetime.date(1990, 11, 27)]

register_time_col = arrow_tbl.column(4)
register_time_col = arrow_tbl.column(6)
assert register_time_col.type == pa.timestamp('us')
assert register_time_col.length() == 8
assert register_time_col.to_pylist() == [
Expand All @@ -48,7 +58,7 @@ def _test_primitive_data_types(_conn):
datetime.datetime(1976, 12, 23, 11, 21, 42), datetime.datetime(1972, 7, 31, 13, 22, 30, 678559),
datetime.datetime(1976, 12, 23, 4, 41, 42), datetime.datetime(2023, 2, 21, 13, 25, 30)]

last_job_duration_col = arrow_tbl.column(5)
last_job_duration_col = arrow_tbl.column(7)
assert last_job_duration_col.type == pa.duration('ms')
assert last_job_duration_col.length() == 8
assert last_job_duration_col.to_pylist() == [datetime.timedelta(days=99, seconds=36334, microseconds=628000),
Expand All @@ -60,7 +70,7 @@ def _test_primitive_data_types(_conn):
datetime.timedelta(microseconds=125000),
datetime.timedelta(days=541, seconds=57600, microseconds=24000)]

f_name_col = arrow_tbl.column(6)
f_name_col = arrow_tbl.column(8)
assert f_name_col.type == pa.string()
assert f_name_col.length() == 8
assert f_name_col.to_pylist() == ["Alice", "Bob", "Carol", "Dan", "Elizabeth", "Farooq", "Greg",
Expand Down

0 comments on commit 09106ce

Please sign in to comment.