Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the ability to register and unregister reinitialization hooks #1072

Merged
merged 14 commits into from
Jul 20, 2022
2 changes: 2 additions & 0 deletions python/rmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
RMMNumbaManager,
_numba_memory_manager,
is_initialized,
register_reinitialize_hook,
reinitialize,
rmm_cupy_allocator,
unregister_reinitialize_hook,
)

__version__ = get_versions()["version"]
49 changes: 49 additions & 0 deletions python/rmm/rmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def __init__(self, errcode, msg):
super(RMMError, self).__init__(msg)


_reinitialize_hooks = []


def reinitialize(
pool_allocator=False,
managed_memory=False,
Expand Down Expand Up @@ -82,6 +85,9 @@ def reinitialize(
Use `rmm.get_log_filenames()` to get the log file names
corresponding to each device.
"""
for func, args, kwargs in reversed(_reinitialize_hooks):
func(*args, **kwargs)

rmm.mr._initialize(
pool_allocator=pool_allocator,
managed_memory=managed_memory,
Expand Down Expand Up @@ -231,3 +237,46 @@ def rmm_cupy_allocator(nbytes):
ptr = cupy.cuda.memory.MemoryPointer(mem, 0)

return ptr


def register_reinitialize_hook(func, *args, **kwargs):
"""
Add a function to the list of functions ("hooks") that will be
called before :py:func:`~rmm.reinitialize()`.

A user or library may register hooks to perform any necessary
cleanup before RMM is reinitialized. For example, a library with
an internal cache of objects that use device memory allocated by
RMM can register a hook to release those references before RMM is
reinitialized, thus ensuring that the relevant device memory
resource can be deallocated.

Hooks are called in the *reverse* order they are registered. This
is useful, for example, when a library registers multiple hooks
and needs them to run in a specific order for cleanup to be safe.
Hooks cannot rely on being registered in a particular order
relative to hooks registered by other packages, since that is
determined by package import ordering.

Parameters
----------
func : callable
Function to be called before :py:func:`~rmm.reinitialize()`
args, kwargs
Positional and keyword arguments to be passed to `func`
"""
global _reinitialize_hooks
_reinitialize_hooks.append((func, args, kwargs))
return func


def unregister_reinitialize_hook(func):
"""
Remove `func` from list of hooks that will be called before
:py:func:`~rmm.reinitialize()`.

If `func` was registered more than once, every instance of it will
be removed from the list of hooks.
"""
global _reinitialize_hooks
_reinitialize_hooks = [x for x in _reinitialize_hooks if x[0] != func]
69 changes: 69 additions & 0 deletions python/rmm/tests/test_rmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,3 +765,72 @@ def deallocate_func(ptr, size):

captured = capsys.readouterr()
assert captured.out == "Allocating 256 bytes\nDeallocating 256 bytes\n"


@pytest.fixture
def make_reinit_hook():
funcs = []

def _make_reinit_hook(func, *args, **kwargs):
funcs.append(func)
rmm.register_reinitialize_hook(func, *args, **kwargs)
return func

yield _make_reinit_hook
for func in funcs:
rmm.unregister_reinitialize_hook(func)


def test_reinit_hooks_register(make_reinit_hook):
L = []
make_reinit_hook(lambda: L.append(1))
make_reinit_hook(lambda: L.append(2))
make_reinit_hook(lambda x: L.append(x), 3)

rmm.reinitialize()
assert L == [3, 2, 1]


def test_reinit_hooks_unregister(make_reinit_hook):
L = []
one = make_reinit_hook(lambda: L.append(1))
make_reinit_hook(lambda: L.append(2))

rmm.unregister_reinitialize_hook(one)
rmm.reinitialize()
assert L == [2]


def test_reinit_hooks_register_twice(make_reinit_hook):
L = []

def func_with_arg(x):
L.append(x)

def func_without_arg():
L.append(2)

make_reinit_hook(func_with_arg, 1)
make_reinit_hook(func_without_arg)
make_reinit_hook(func_with_arg, 3)
make_reinit_hook(func_without_arg)

rmm.reinitialize()
assert L == [2, 3, 2, 1]


def test_reinit_hooks_unregister_twice_registered(make_reinit_hook):
# unregistering a twice-registered function
# should unregister both instances:
L = []

def func_with_arg(x):
L.append(x)

make_reinit_hook(func_with_arg, 1)
make_reinit_hook(lambda: L.append(2))
make_reinit_hook(func_with_arg, 3)

rmm.unregister_reinitialize_hook(func_with_arg)
rmm.reinitialize()
assert L == [2]