diff --git a/python/rmm/__init__.py b/python/rmm/__init__.py index bdb7f7c56..424bcaa35 100644 --- a/python/rmm/__init__.py +++ b/python/rmm/__init__.py @@ -22,8 +22,10 @@ RMMNumbaManager, _numba_memory_manager, is_initialized, + register_reinitialize_hook, reinitialize, rmm_cupy_allocator, + unregister_reinitialize_hook, ) __version__ = get_versions()["version"] diff --git a/python/rmm/rmm.py b/python/rmm/rmm.py index 3e99f51e2..fd46b4793 100644 --- a/python/rmm/rmm.py +++ b/python/rmm/rmm.py @@ -29,6 +29,9 @@ def __init__(self, errcode, msg): super(RMMError, self).__init__(msg) +_reinitialize_hooks = [] + + def reinitialize( pool_allocator=False, managed_memory=False, @@ -82,6 +85,9 @@ def reinitialize( Use `rmm.get_log_filenames()` to get the log file names corresponding to each device. """ + for func, args, kwargs in reversed(_reinitialize_hooks): + func(*args, **kwargs) + rmm.mr._initialize( pool_allocator=pool_allocator, managed_memory=managed_memory, @@ -231,3 +237,46 @@ def rmm_cupy_allocator(nbytes): ptr = cupy.cuda.memory.MemoryPointer(mem, 0) return ptr + + +def register_reinitialize_hook(func, *args, **kwargs): + """ + Add a function to the list of functions ("hooks") that will be + called before :py:func:`~rmm.reinitialize()`. + + A user or library may register hooks to perform any necessary + cleanup before RMM is reinitialized. For example, a library with + an internal cache of objects that use device memory allocated by + RMM can register a hook to release those references before RMM is + reinitialized, thus ensuring that the relevant device memory + resource can be deallocated. + + Hooks are called in the *reverse* order they are registered. This + is useful, for example, when a library registers multiple hooks + and needs them to run in a specific order for cleanup to be safe. + Hooks cannot rely on being registered in a particular order + relative to hooks registered by other packages, since that is + determined by package import ordering. + + Parameters + ---------- + func : callable + Function to be called before :py:func:`~rmm.reinitialize()` + args, kwargs + Positional and keyword arguments to be passed to `func` + """ + global _reinitialize_hooks + _reinitialize_hooks.append((func, args, kwargs)) + return func + + +def unregister_reinitialize_hook(func): + """ + Remove `func` from list of hooks that will be called before + :py:func:`~rmm.reinitialize()`. + + If `func` was registered more than once, every instance of it will + be removed from the list of hooks. + """ + global _reinitialize_hooks + _reinitialize_hooks = [x for x in _reinitialize_hooks if x[0] != func] diff --git a/python/rmm/tests/test_rmm.py b/python/rmm/tests/test_rmm.py index 94f375aec..5d5c3c18f 100644 --- a/python/rmm/tests/test_rmm.py +++ b/python/rmm/tests/test_rmm.py @@ -765,3 +765,72 @@ def deallocate_func(ptr, size): captured = capsys.readouterr() assert captured.out == "Allocating 256 bytes\nDeallocating 256 bytes\n" + + +@pytest.fixture +def make_reinit_hook(): + funcs = [] + + def _make_reinit_hook(func, *args, **kwargs): + funcs.append(func) + rmm.register_reinitialize_hook(func, *args, **kwargs) + return func + + yield _make_reinit_hook + for func in funcs: + rmm.unregister_reinitialize_hook(func) + + +def test_reinit_hooks_register(make_reinit_hook): + L = [] + make_reinit_hook(lambda: L.append(1)) + make_reinit_hook(lambda: L.append(2)) + make_reinit_hook(lambda x: L.append(x), 3) + + rmm.reinitialize() + assert L == [3, 2, 1] + + +def test_reinit_hooks_unregister(make_reinit_hook): + L = [] + one = make_reinit_hook(lambda: L.append(1)) + make_reinit_hook(lambda: L.append(2)) + + rmm.unregister_reinitialize_hook(one) + rmm.reinitialize() + assert L == [2] + + +def test_reinit_hooks_register_twice(make_reinit_hook): + L = [] + + def func_with_arg(x): + L.append(x) + + def func_without_arg(): + L.append(2) + + make_reinit_hook(func_with_arg, 1) + make_reinit_hook(func_without_arg) + make_reinit_hook(func_with_arg, 3) + make_reinit_hook(func_without_arg) + + rmm.reinitialize() + assert L == [2, 3, 2, 1] + + +def test_reinit_hooks_unregister_twice_registered(make_reinit_hook): + # unregistering a twice-registered function + # should unregister both instances: + L = [] + + def func_with_arg(x): + L.append(x) + + make_reinit_hook(func_with_arg, 1) + make_reinit_hook(lambda: L.append(2)) + make_reinit_hook(func_with_arg, 3) + + rmm.unregister_reinitialize_hook(func_with_arg) + rmm.reinitialize() + assert L == [2]