Skip to content

Commit

Permalink
fix: update python tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
vibhatha committed Jul 18, 2023
1 parent 6eef6ab commit 0855add
Showing 1 changed file with 41 additions and 14 deletions.
55 changes: 41 additions & 14 deletions python/pyarrow/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from collections import OrderedDict
from collections.abc import Iterable
from datetime import datetime, timedelta
import pickle
import sys
import weakref
Expand Down Expand Up @@ -2449,24 +2450,50 @@ def test_invalid_non_join_column():
assert exp_error_msg in str(excinfo.value)


# NOTE(review): this span is a rendered diff hunk whose +/- markers and
# indentation were stripped, so it is not runnable as written. It appears to
# interleave the pre-commit test (test_aggregate_hash_functions, hard-coded
# to "hash_list") with its post-commit replacement
# (test_invalid_scalar_aggregate_functions, parametrized over "hash_list" and
# "hash_one") -- confirm against the repository before reusing any of it.
def test_aggregate_hash_functions():
# Presumably the post-commit header: one test run per function name listed.
@pytest.mark.parametrize("function_name", ["hash_list", "hash_one"])
def test_invalid_scalar_aggregate_functions(function_name):
# Five-row table: string 'key' column, integer 'value' column.
table = pa.table({'key': ['a', 'a', 'b', 'b', 'a'], 'value': [11, 112, 0, 1, 2]})

# `aggregates` is never read in the visible lines -- presumably a line the
# commit removed.
aggregates = [(["value"], "hash_list", None, "value_list")]
# Empty key list: group_by is called with no grouping columns.
keys = []

# Both variants expect the aggregate call to raise pa.lib.ArrowInvalid and
# then check that the message names the function as a hash aggregate.
with pytest.raises(pa.lib.ArrowInvalid) as excinfo:
# Presumably the pre-commit `with` body and assertion (literal "hash_list";
# `res` is assigned but not used afterwards).
res = table.group_by(keys).aggregate([(['value'], 'hash_list')])
assert "The provided function (hash_list) is a hash aggregate function." in str(
excinfo.value)

# Presumably the post-commit `with` body and assertion: same call and message
# check, driven by the `function_name` parameter.
table.group_by(keys).aggregate([(['value'], function_name)])
assert "The provided function ({}) is a hash aggregate function.".format(
function_name) in str(excinfo.value)


@pytest.fixture(params=[
    # Each case is (declared Arrow type, data for the 'value' column,
    # the distinct values of that data listed in ascending order).
    (
        pa.int32(),
        [11, 11, 10, 12, 12],
        [10, 11, 12],
    ),
    (
        pa.float32(),
        [11.1, 11.1, 10.1, 12.1, 12.1],
        [10.1, 11.1, 12.1],
    ),
    (
        pa.bool_(),
        [True, True, False, True, False],
        [False, True],
    ),
    (
        pa.timestamp('s'),
        [
            datetime(2022, 1, 1),
            datetime(2022, 1, 1),
            datetime(2022, 2, 2),
            datetime(2022, 2, 3),
            datetime(2022, 1, 3),
        ],
        [
            datetime(2022, 1, 1),
            datetime(2022, 1, 3),
            datetime(2022, 2, 2),
            datetime(2022, 2, 3),
        ],
    ),
    (
        pa.time32('s'),
        [
            timedelta(hours=1),
            timedelta(hours=1),
            timedelta(hours=2),
            timedelta(hours=2),
            timedelta(hours=3),
        ],
        [
            timedelta(hours=1),
            timedelta(hours=2),
            timedelta(hours=3),
        ],
    ),
    (
        pa.duration('s'),
        [
            timedelta(days=1),
            timedelta(days=1),
            timedelta(days=2),
            timedelta(days=2),
            timedelta(days=3),
        ],
        [
            timedelta(days=1),
            timedelta(days=2),
            timedelta(days=3),
        ],
    ),
    (
        pa.utf8(),
        ['ab', 'ab', 'cd', 'cd', 'ef'],
        ['ab', 'cd', 'ef'],
    ),
])
def typed_table(request):
    """Provide a two-column table and the distinct values of its 'value' column.

    The fixture is parametrized: every test that requests it runs once per
    case listed in ``params``.

    Returns
    -------
    tuple
        ``(table, distinct_values)`` where ``table`` has a fixed five-row
        'key' column and a 'value' column holding the case's data, and
        ``distinct_values`` is the case's expected list of distinct values.

    Notes
    -----
    The first element of each case (the declared Arrow type) is unpacked but
    not passed to ``pa.table``; the 'value' column is built from the Python
    values alone.
    """
    _declared_type, column_values, distinct_values = request.param
    table = pa.table({
        'key': ['a', 'a', 'b', 'b', 'a'],
        'value': column_values,
    })
    return table, distinct_values

# NOTE(review): this span is a rendered diff hunk whose +/- markers and
# indentation were stripped, so it is not runnable as written. It appears to
# interleave the pre-commit test (no arguments, builds its own table, no
# assertion) with the post-commit test (takes the `typed_table` fixture and
# asserts on the sorted result) -- confirm against the repository before
# reusing any of it.
def test_scalar_aggregate_distinct_functions():
# Presumably pre-commit: a fixed five-row table with integer values.
table = pa.table({'key': ['a', 'a', 'b', 'b', 'a'], 'value': [11, 11, 10, 12, 12]})

# `aggregates` is never read in the visible lines -- presumably a line the
# commit removed.
aggregates = [(["value"], "distinct", None, "value_distinct")]
# Presumably the post-commit header: the table and the expected distinct
# values both come from the `typed_table` fixture.
def test_scalar_aggregate_distinct_functions(typed_table):
table, expected_distinct_values = typed_table
# Empty key list: group_by is called with no grouping columns.
keys = []
# Run the 'distinct' aggregation over the 'value' column.
res_table = table.group_by(keys).aggregate([(['value'], 'distinct')])
# Sort the result by 'value_distinct' ascending before comparing it with the
# expected list.
sort_indices = pc.sort_indices(
res_table, sort_keys=[("value_distinct", "ascending")])
sorted_table = pc.take(res_table, sort_indices)

# Presumably pre-commit leftovers: a function lookup printed for debugging
# and an aggregate result (`res`) that is not checked afterwards.
func = pc.get_function("distinct")
print(func)
res = table.group_by(keys).aggregate([(['value'], 'distinct')])
# Presumably the post-commit assertion: the sorted 'value_distinct' column
# must equal the expected distinct values.
assert sorted_table.column('value_distinct').to_pylist() == expected_distinct_values

0 comments on commit 0855add

Please sign in to comment.