Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the ability to register and unregister reinitialization hooks #1072

Merged
merged 14 commits into from
Jul 20, 2022
2 changes: 2 additions & 0 deletions python/rmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
RMMNumbaManager,
_numba_memory_manager,
is_initialized,
register_reinitialize_hook,
reinitialize,
rmm_cupy_allocator,
unregister_reinitialize_hook,
)

__version__ = get_versions()["version"]
25 changes: 25 additions & 0 deletions python/rmm/rmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def __init__(self, errcode, msg):
super(RMMError, self).__init__(msg)


_reinitialize_hooks = {}


def reinitialize(
pool_allocator=False,
managed_memory=False,
Expand Down Expand Up @@ -82,6 +85,10 @@ def reinitialize(
Use `rmm.get_log_filenames()` to get the log file names
corresponding to each device.
"""
[
hook(*args, **kwargs)
for (hook, (args, kwargs)) in reversed(_reinitialize_hooks.items())
]
shwina marked this conversation as resolved.
Show resolved Hide resolved
rmm.mr._initialize(
pool_allocator=pool_allocator,
managed_memory=managed_memory,
Expand Down Expand Up @@ -231,3 +238,21 @@ def rmm_cupy_allocator(nbytes):
ptr = cupy.cuda.memory.MemoryPointer(mem, 0)

return ptr


def register_reinitialize_hook(func, *args, **kwargs):
"""
Register a hook to be called by `rmm.reinitialize()`.

Hooks are called in the *reverse* order they are registered.
shwina marked this conversation as resolved.
Show resolved Hide resolved
"""
_reinitialize_hooks[func] = (args, kwargs)
shwina marked this conversation as resolved.
Show resolved Hide resolved
return func


def unregister_reinitialize_hook(func):
"""
Remove func from the list of hooks to be called by `rmm.reinitialize()`.
"""
_reinitialize_hooks.pop(func, None)
return func
35 changes: 35 additions & 0 deletions python/rmm/tests/test_rmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,3 +765,38 @@ def deallocate_func(ptr, size):

captured = capsys.readouterr()
assert captured.out == "Allocating 256 bytes\nDeallocating 256 bytes\n"


def test_reinitialize_hooks(capsys):
def one():
print("one")

def two():
print("two")

def three(s):
print(s)

rmm.register_reinitialize_hook(one)
rmm.register_reinitialize_hook(two)
rmm.register_reinitialize_hook(three, "four")

rmm.reinitialize()
captured = capsys.readouterr()
assert captured.out == "four\ntwo\none\n"

rmm.unregister_reinitialize_hook(one)
rmm.unregister_reinitialize_hook(two)
rmm.unregister_reinitialize_hook(three)

rmm.register_reinitialize_hook(one)
rmm.register_reinitialize_hook(two)
rmm.unregister_reinitialize_hook(one)

rmm.reinitialize()
captured = capsys.readouterr()
assert captured.out == "two\n"

rmm.unregister_reinitialize_hook(one)
rmm.unregister_reinitialize_hook(two)
rmm.unregister_reinitialize_hook(three)