Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-97982: Reuse PyUnicode_Count in unicode_count #98025

Merged
merged 10 commits into from
Oct 12, 2022
10 changes: 10 additions & 0 deletions Lib/test/test_unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ def test_count(self):
self.checkequal(0, 'a' * 10, 'count', 'a\u0102')
self.checkequal(0, 'a' * 10, 'count', 'a\U00100304')
self.checkequal(0, '\u0102' * 10, 'count', '\u0102\U00100304')
# test subclass
class MyStr(str):
pass
self.checkequal(3, MyStr('aaa'), 'count', 'a')

def test_find(self):
string_tests.CommonTest.test_find(self)
Expand Down Expand Up @@ -3002,6 +3006,12 @@ def test_count(self):
self.assertEqual(unicode_count(uni, ch, 0, len(uni)), 1)
self.assertEqual(unicode_count(st, ch, 0, len(st)), 0)

# subclasses should still work
class MyStr(str):
pass

self.assertEqual(unicode_count(MyStr('aab'), 'a', 0, 3), 2)

# Test PyUnicode_FindChar()
@support.cpython_only
@unittest.skipIf(_testcapi is None, 'need _testcapi module')
Expand Down
86 changes: 26 additions & 60 deletions Objects/unicodeobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -8964,21 +8964,20 @@ _PyUnicode_InsertThousandsGrouping(
return count;
}


Py_ssize_t
PyUnicode_Count(PyObject *str,
PyObject *substr,
Py_ssize_t start,
Py_ssize_t end)
static Py_ssize_t
unicode_count_impl(PyObject *str,
vstinner marked this conversation as resolved.
Show resolved Hide resolved
PyObject *substr,
Py_ssize_t start,
Py_ssize_t end)
{
assert(PyUnicode_Check(str));
assert(PyUnicode_Check(substr));

Py_ssize_t result;
int kind1, kind2;
const void *buf1 = NULL, *buf2 = NULL;
Py_ssize_t len1, len2;

if (ensure_unicode(str) < 0 || ensure_unicode(substr) < 0)
return -1;

kind1 = PyUnicode_KIND(str);
kind2 = PyUnicode_KIND(substr);
if (kind1 < kind2)
Expand All @@ -8998,6 +8997,7 @@ PyUnicode_Count(PyObject *str,
goto onError;
}

// We don't reuse `anylib_count` here because of the explicit casts.
switch (kind1) {
case PyUnicode_1BYTE_KIND:
result = ucs1lib_count(
Expand Down Expand Up @@ -9033,6 +9033,18 @@ PyUnicode_Count(PyObject *str,
return -1;
}

Py_ssize_t
PyUnicode_Count(PyObject *str,
PyObject *substr,
Py_ssize_t start,
Py_ssize_t end)
{
if (ensure_unicode(str) < 0 || ensure_unicode(substr) < 0)
return -1;

return unicode_count_impl(str, substr, start, end);
}

Py_ssize_t
PyUnicode_Find(PyObject *str,
PyObject *substr,
Expand Down Expand Up @@ -10848,62 +10860,16 @@ unicode_count(PyObject *self, PyObject *args)
PyObject *substring = NULL; /* initialize to fix a compiler warning */
Py_ssize_t start = 0;
Py_ssize_t end = PY_SSIZE_T_MAX;
PyObject *result;
int kind1, kind2;
const void *buf1, *buf2;
Py_ssize_t len1, len2, iresult;
Py_ssize_t result;

if (!parse_args_finds_unicode("count", args, &substring, &start, &end))
return NULL;

kind1 = PyUnicode_KIND(self);
kind2 = PyUnicode_KIND(substring);
if (kind1 < kind2)
return PyLong_FromLong(0);

len1 = PyUnicode_GET_LENGTH(self);
len2 = PyUnicode_GET_LENGTH(substring);
ADJUST_INDICES(start, end, len1);
if (end - start < len2)
return PyLong_FromLong(0);

buf1 = PyUnicode_DATA(self);
buf2 = PyUnicode_DATA(substring);
if (kind2 != kind1) {
buf2 = unicode_askind(kind2, buf2, len2, kind1);
if (!buf2)
return NULL;
}
switch (kind1) {
case PyUnicode_1BYTE_KIND:
iresult = ucs1lib_count(
((const Py_UCS1*)buf1) + start, end - start,
buf2, len2, PY_SSIZE_T_MAX
);
break;
case PyUnicode_2BYTE_KIND:
iresult = ucs2lib_count(
((const Py_UCS2*)buf1) + start, end - start,
buf2, len2, PY_SSIZE_T_MAX
);
break;
case PyUnicode_4BYTE_KIND:
iresult = ucs4lib_count(
((const Py_UCS4*)buf1) + start, end - start,
buf2, len2, PY_SSIZE_T_MAX
);
break;
default:
Py_UNREACHABLE();
}

result = PyLong_FromSsize_t(iresult);

assert((kind2 == kind1) == (buf2 == PyUnicode_DATA(substring)));
if (kind2 != kind1)
PyMem_Free((void *)buf2);
result = unicode_count_impl(self, substring, start, end);
if (result == -1)
return NULL;

return result;
return PyLong_FromSsize_t(result);
}

/*[clinic input]
Expand Down