Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Allow setting *_pool_size with human-readable string #1670

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions python/rmm/rmm/_lib/helper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,30 @@ cdef dict BYTE_SIZES = {
pattern = re.compile(r"^([0-9]+(?:\.[0-9]*)?)[\t ]*((?i:(?:[kmgtp]i?)?b))?$")

cdef object parse_bytes(object s):
""" Parse byte string to numbers
"""Parse a string or integer into a number of bytes.

Parameters
----------
s : int | str
Size in bytes
Size in bytes. If an integer is provided, it is returned as-is.
A string is parsed as a floating point number with an (optional,
case-insensitive) byte-specifier, both SI prefixes (kb, mb, ..., pb)
and binary prefixes (kib, mib, ..., pib) are supported.

Returns
-------
Requested size in bytes as an integer.
Raises
wence- marked this conversation as resolved.
Show resolved Hide resolved
------
ValueError
If it is not possible to parse the input as a byte specification.
"""
cdef str suffix
cdef double n
cdef int multiplier

if isinstance(s, int):
return int(s)
return s

match = pattern.match(s)

Expand Down
2 changes: 1 addition & 1 deletion python/rmm/rmm/_lib/memory_resource.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ cdef class CudaAsyncMemoryResource(DeviceMemoryResource):
----------
initial_pool_size : int | str, optional
Initial pool size in bytes. By default, half the available memory
on the device is used.
on the device is used. A string argument is parsed using `parse_bytes`.
release_threshold: int, optional
Release threshold in bytes. If the pool size grows beyond this
value, unused memory held by the pool will be released at the
Expand Down
2 changes: 2 additions & 0 deletions python/rmm/rmm/rmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ def reinitialize(
When `pool_allocator` is True, this indicates the initial pool size in
bytes. By default, 1/2 of the total GPU memory is used.
When `pool_allocator` is False, this argument is ignored if provided.
Matt711 marked this conversation as resolved.
Show resolved Hide resolved
A string argument is parsed using `parse_bytes`.
maximum_pool_size : int | str, default None
When `pool_allocator` is True, this indicates the maximum pool size in
bytes. By default, the total available memory on the GPU is used.
When `pool_allocator` is False, this argument is ignored if provided.
Matt711 marked this conversation as resolved.
Show resolved Hide resolved
A string argument is parsed using `parse_bytes`.
devices : int or List[int], default 0
GPU device IDs to register. By default registers only GPU 0.
logging : bool, default False
Expand Down
20 changes: 19 additions & 1 deletion python/rmm/rmm/tests/test_rmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,12 +524,30 @@ def test_reinitialize_initial_pool_size_gt_max():
with pytest.raises(RuntimeError) as e:
rmm.reinitialize(
pool_allocator=True,
initial_pool_size="2KiB",
initial_pool_size=1 << 11,
maximum_pool_size=1 << 10,
)
assert "Initial pool size exceeds the maximum pool size" in str(e.value)
wence- marked this conversation as resolved.
Show resolved Hide resolved


def test_reinitialize_with_valid_str_arg_pool_size():
rmm.reinitialize(
pool_allocator=True,
initial_pool_size="2kib",
maximum_pool_size="8kib",
)


def test_reinitialize_with_invalid_str_arg_pool_size():
with pytest.raises(ValueError) as e:
rmm.reinitialize(
pool_allocator=True,
initial_pool_size="2k", # 2kb valid, not 2k
maximum_pool_size="8k",
)
assert "Could not parse" in str(e.value)


@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("nelem", _nelems)
@pytest.mark.parametrize("alloc", _allocs)
Expand Down
Loading