Skip to content

Commit

Permalink
BUG: ArrowExtensionArray._from_* accepts pyarrow arrays (#48264)
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke authored Sep 6, 2022
1 parent edf0fce commit 50c119d
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 20 deletions.
20 changes: 13 additions & 7 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,13 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy=False):
Construct a new ExtensionArray from a sequence of scalars.
"""
pa_dtype = to_pyarrow_type(dtype)
if isinstance(scalars, cls):
data = scalars._data
is_cls = isinstance(scalars, cls)
if is_cls or isinstance(scalars, (pa.Array, pa.ChunkedArray)):
if is_cls:
scalars = scalars._data
if pa_dtype:
data = data.cast(pa_dtype)
return cls(data)
scalars = scalars.cast(pa_dtype)
return cls(scalars)
else:
return cls(
pa.chunked_array(pa.array(scalars, type=pa_dtype, from_pandas=True))
Expand All @@ -242,7 +244,10 @@ def _from_sequence_of_strings(
Construct a new ExtensionArray from a sequence of strings.
"""
pa_type = to_pyarrow_type(dtype)
if pa.types.is_timestamp(pa_type):
if pa_type is None:
# Let pyarrow try to infer or raise
scalars = strings
elif pa.types.is_timestamp(pa_type):
from pandas.core.tools.datetimes import to_datetime

scalars = to_datetime(strings, errors="raise")
Expand Down Expand Up @@ -272,8 +277,9 @@ def _from_sequence_of_strings(

scalars = to_numeric(strings, errors="raise")
else:
# Let pyarrow try to infer or raise
scalars = strings
raise NotImplementedError(
f"Converting strings to {pa_type} is not implemented."
)
return cls._from_sequence(scalars, dtype=pa_type, copy=copy)

def __getitem__(self, item: PositionalIndexer):
Expand Down
25 changes: 14 additions & 11 deletions pandas/core/tools/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,20 @@ def _convert_listlike(arg, format):
format_found = False
for element in arg:
time_object = None
for time_format in formats:
try:
time_object = datetime.strptime(element, time_format).time()
if not format_found:
# Put the found format in front
fmt = formats.pop(formats.index(time_format))
formats.insert(0, fmt)
format_found = True
break
except (ValueError, TypeError):
continue
try:
time_object = time.fromisoformat(element)
except (ValueError, TypeError):
for time_format in formats:
try:
time_object = datetime.strptime(element, time_format).time()
if not format_found:
# Put the found format in front
fmt = formats.pop(formats.index(time_format))
formats.insert(0, fmt)
format_found = True
break
except (ValueError, TypeError):
continue

if time_object is not None:
times.append(time_object)
Expand Down
94 changes: 94 additions & 0 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import pytest

from pandas.compat import (
is_ci_environment,
is_platform_windows,
pa_version_under2p0,
pa_version_under3p0,
pa_version_under4p0,
Expand All @@ -37,6 +39,8 @@

pa = pytest.importorskip("pyarrow", minversion="1.0.1")

from pandas.core.arrays.arrow.array import ArrowExtensionArray

from pandas.core.arrays.arrow.dtype import ArrowDtype # isort:skip


Expand Down Expand Up @@ -224,6 +228,96 @@ def test_from_dtype(self, data, request):
)
super().test_from_dtype(data)

def test_from_sequence_pa_array(self, data, request):
# https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784
# data._data = pa.ChunkedArray
if pa_version_under3p0:
request.node.add_marker(
pytest.mark.xfail(
reason="ChunkedArray has no attribute combine_chunks",
)
)
result = type(data)._from_sequence(data._data)
tm.assert_extension_array_equal(result, data)
assert isinstance(result._data, pa.ChunkedArray)

result = type(data)._from_sequence(data._data.combine_chunks())
tm.assert_extension_array_equal(result, data)
assert isinstance(result._data, pa.ChunkedArray)

def test_from_sequence_pa_array_notimplemented(self, request):
if pa_version_under6p0:
request.node.add_marker(
pytest.mark.xfail(
raises=AttributeError,
reason="month_day_nano_interval not implemented by pyarrow.",
)
)
with pytest.raises(NotImplementedError, match="Converting strings to"):
ArrowExtensionArray._from_sequence_of_strings(
["12-1"], dtype=pa.month_day_nano_interval()
)

def test_from_sequence_of_strings_pa_array(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa_version_under3p0:
request.node.add_marker(
pytest.mark.xfail(
reason="ChunkedArray has no attribute combine_chunks",
)
)
elif pa.types.is_time64(pa_dtype) and pa_dtype.equals("time64[ns]"):
request.node.add_marker(
pytest.mark.xfail(
reason="Nanosecond time parsing not supported.",
)
)
elif pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"pyarrow doesn't support parsing {pa_dtype}",
)
)
elif pa.types.is_boolean(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason="Iterating over ChunkedArray[bool] returns PyArrow scalars.",
)
)
elif pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None:
if pa_version_under7p0:
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"pyarrow doesn't support string cast from {pa_dtype}",
)
)
elif is_platform_windows() and is_ci_environment():
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason=(
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
"on CI to path to the tzdata for pyarrow."
),
)
)
elif pa_version_under6p0 and pa.types.is_temporal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"pyarrow doesn't support string cast from {pa_dtype}",
)
)
pa_array = data._data.cast(pa.string())
result = type(data)._from_sequence_of_strings(pa_array, dtype=data.dtype)
tm.assert_extension_array_equal(result, data)

pa_array = pa_array.combine_chunks()
result = type(data)._from_sequence_of_strings(pa_array, dtype=data.dtype)
tm.assert_extension_array_equal(result, data)


@pytest.mark.xfail(
raises=NotImplementedError, reason="pyarrow.ChunkedArray backing is 1D."
Expand Down
7 changes: 5 additions & 2 deletions pandas/tests/tools/test_to_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np
import pytest

from pandas.compat import PY311

from pandas import Series
import pandas._testing as tm
from pandas.core.tools.datetimes import to_time as to_time_alias
Expand Down Expand Up @@ -40,8 +42,9 @@ def test_parsers_time(self, time_string):
def test_odd_format(self):
new_string = "14.15"
msg = r"Cannot convert arg \['14\.15'\] to a time"
with pytest.raises(ValueError, match=msg):
to_time(new_string)
if not PY311:
with pytest.raises(ValueError, match=msg):
to_time(new_string)
assert to_time(new_string, format="%H.%M") == time(14, 15)

def test_arraylike(self):
Expand Down

0 comments on commit 50c119d

Please sign in to comment.