From 84981557d9ef83c09087acd21be8e7fd12e74314 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Mon, 18 Jul 2022 13:45:02 -0400 Subject: [PATCH 01/13] Add reinitialization hooks --- python/rmm/__init__.py | 2 ++ python/rmm/rmm.py | 25 +++++++++++++++++++++++++ python/rmm/tests/test_rmm.py | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/python/rmm/__init__.py b/python/rmm/__init__.py index bdb7f7c56..424bcaa35 100644 --- a/python/rmm/__init__.py +++ b/python/rmm/__init__.py @@ -22,8 +22,10 @@ RMMNumbaManager, _numba_memory_manager, is_initialized, + register_reinitialize_hook, reinitialize, rmm_cupy_allocator, + unregister_reinitialize_hook, ) __version__ = get_versions()["version"] diff --git a/python/rmm/rmm.py b/python/rmm/rmm.py index 3e99f51e2..2ca98963e 100644 --- a/python/rmm/rmm.py +++ b/python/rmm/rmm.py @@ -29,6 +29,9 @@ def __init__(self, errcode, msg): super(RMMError, self).__init__(msg) +_reinitialize_hooks = {} + + def reinitialize( pool_allocator=False, managed_memory=False, @@ -82,6 +85,10 @@ def reinitialize( Use `rmm.get_log_filenames()` to get the log file names corresponding to each device. """ + [ + hook(*args, **kwargs) + for (hook, (args, kwargs)) in reversed(_reinitialize_hooks.items()) + ] rmm.mr._initialize( pool_allocator=pool_allocator, managed_memory=managed_memory, @@ -231,3 +238,21 @@ def rmm_cupy_allocator(nbytes): ptr = cupy.cuda.memory.MemoryPointer(mem, 0) return ptr + + +def register_reinitialize_hook(func, *args, **kwargs): + """ + Register a hook to be called by `rmm.reinitialize()`. + + Hooks are called in the *reverse* order they are registered. + """ + _reinitialize_hooks[func] = (args, kwargs) + return func + + +def unregister_reinitialize_hook(func): + """ + Remove func from the list of hooks to be called by `rmm.reinitialize()`. + """ + _reinitialize_hooks.pop(func, None) + return func diff --git a/python/rmm/tests/test_rmm.py b/python/rmm/tests/test_rmm.py index 94f375aec..3e284d9f0 100644 --- a/python/rmm/tests/test_rmm.py +++ b/python/rmm/tests/test_rmm.py @@ -765,3 +765,38 @@ def deallocate_func(ptr, size): captured = capsys.readouterr() assert captured.out == "Allocating 256 bytes\nDeallocating 256 bytes\n" + + +def test_reinitialize_hooks(capsys): + def one(): + print("one") + + def two(): + print("two") + + def three(s): + print(s) + + rmm.register_reinitialize_hook(one) + rmm.register_reinitialize_hook(two) + rmm.register_reinitialize_hook(three, "four") + + rmm.reinitialize() + captured = capsys.readouterr() + assert captured.out == "four\ntwo\none\n" + + rmm.unregister_reinitialize_hook(one) + rmm.unregister_reinitialize_hook(two) + rmm.unregister_reinitialize_hook(three) + + rmm.register_reinitialize_hook(one) + rmm.register_reinitialize_hook(two) + rmm.unregister_reinitialize_hook(one) + + rmm.reinitialize() + captured = capsys.readouterr() + assert captured.out == "two\n" + + rmm.unregister_reinitialize_hook(one) + rmm.unregister_reinitialize_hook(two) + rmm.unregister_reinitialize_hook(three) From 402d57851bfb3e8f31c44b8f0c553c925b697907 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Mon, 18 Jul 2022 13:51:51 -0400 Subject: [PATCH 02/13] Better docs --- python/rmm/rmm.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/rmm/rmm.py b/python/rmm/rmm.py index 2ca98963e..0acfd927c 100644 --- a/python/rmm/rmm.py +++ b/python/rmm/rmm.py @@ -242,9 +242,17 @@ def rmm_cupy_allocator(nbytes): def register_reinitialize_hook(func, *args, **kwargs): """ - Register a hook to be called by `rmm.reinitialize()`. + Add a function to the list of functions that will be called before + `rmm.reinitialize()`. Hooks are called in the *reverse* order they are registered. + + Parameters + ---------- + func: callable + Function to be called before `rmm.reinitialize()` + args, kwargs + Positional and keyword arguments to bepassed to `func` """ _reinitialize_hooks[func] = (args, kwargs) return func @@ -252,7 +260,8 @@ def register_reinitialize_hook(func, *args, **kwargs): def unregister_reinitialize_hook(func): """ - Remove func from the list of hooks to be called by `rmm.reinitialize()`. + Remove `func` from the list of functions that will be called before + `rmm.reinitialize()`. """ _reinitialize_hooks.pop(func, None) return func From 337aa95d2136c3480c26bd207c705ac269596af5 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Mon, 18 Jul 2022 14:58:13 -0400 Subject: [PATCH 03/13] Use list, not dict --- python/rmm/rmm.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/python/rmm/rmm.py b/python/rmm/rmm.py index 0acfd927c..d32b42e4b 100644 --- a/python/rmm/rmm.py +++ b/python/rmm/rmm.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import ctypes +from itertools import filterfalse from cuda.cuda import CUdeviceptr, cuIpcGetMemHandle from numba import config, cuda @@ -29,7 +30,7 @@ def __init__(self, errcode, msg): super(RMMError, self).__init__(msg) -_reinitialize_hooks = {} +_reinitialize_hooks = [] def reinitialize( @@ -254,7 +255,8 @@ def register_reinitialize_hook(func, *args, **kwargs): args, kwargs Positional and keyword arguments to bepassed to `func` """ - _reinitialize_hooks[func] = (args, kwargs) + global _reinitialize_hooks + _reinitialize_hooks.append((func, args, kwargs)) return func @@ -262,6 +264,11 @@ def unregister_reinitialize_hook(func): """ Remove `func` from the list of functions that will be called before `rmm.reinitialize()`. + + If `func` was registered more than once, every instance of it will + be removed from the list of functions. """ - _reinitialize_hooks.pop(func, None) - return func + global _reinitialize_hooks + _reinitialize_hooks = list( + filterfalse(lambda x: x[0] == func, _reinitialize_hooks) + ) From 084b030bd9aff970b877f8d50271bda7439bfd32 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Mon, 18 Jul 2022 15:41:16 -0400 Subject: [PATCH 04/13] More testing for reinit hooks --- python/rmm/rmm.py | 2 +- python/rmm/tests/test_rmm.py | 71 ++++++++++++++++++++++++++---------- 2 files changed, 52 insertions(+), 21 deletions(-) diff --git a/python/rmm/rmm.py b/python/rmm/rmm.py index d32b42e4b..8afe1764f 100644 --- a/python/rmm/rmm.py +++ b/python/rmm/rmm.py @@ -88,7 +88,7 @@ def reinitialize( """ [ hook(*args, **kwargs) - for (hook, (args, kwargs)) in reversed(_reinitialize_hooks.items()) + for (hook, args, kwargs) in reversed(_reinitialize_hooks) ] rmm.mr._initialize( pool_allocator=pool_allocator, diff --git a/python/rmm/tests/test_rmm.py b/python/rmm/tests/test_rmm.py index 3e284d9f0..0b5193cc2 100644 --- a/python/rmm/tests/test_rmm.py +++ b/python/rmm/tests/test_rmm.py @@ -767,36 +767,67 @@ def deallocate_func(ptr, size): assert captured.out == "Allocating 256 bytes\nDeallocating 256 bytes\n" -def test_reinitialize_hooks(capsys): - def one(): - print("one") +@pytest.fixture +def make_reinit_hook(): + funcs = [] - def two(): - print("two") + def _make_reinit_hook(func, *args, **kwargs): + funcs.append(func) + rmm.register_reinitialize_hook(func, *args, **kwargs) + return func + + yield _make_reinit_hook + for func in funcs: + rmm.unregister_reinitialize_hook(func) - def three(s): - print(s) - rmm.register_reinitialize_hook(one) - rmm.register_reinitialize_hook(two) - rmm.register_reinitialize_hook(three, "four") +def test_reinit_hooks_register(make_reinit_hook, capsys): + make_reinit_hook(lambda: print("one")) + make_reinit_hook(lambda: print("two")) + make_reinit_hook(lambda x: print(x), "three") rmm.reinitialize() captured = capsys.readouterr() - assert captured.out == "four\ntwo\none\n" + assert captured.out == "three\ntwo\none\n" - rmm.unregister_reinitialize_hook(one) - rmm.unregister_reinitialize_hook(two) - rmm.unregister_reinitialize_hook(three) - rmm.register_reinitialize_hook(one) - rmm.register_reinitialize_hook(two) - rmm.unregister_reinitialize_hook(one) +def test_reinit_hooks_unregister(make_reinit_hook, capsys): + one = make_reinit_hook(lambda: print("one")) + make_reinit_hook(lambda: print("two")) + rmm.unregister_reinitialize_hook(one) rmm.reinitialize() captured = capsys.readouterr() assert captured.out == "two\n" - rmm.unregister_reinitialize_hook(one) - rmm.unregister_reinitialize_hook(two) - rmm.unregister_reinitialize_hook(three) + +def test_reinit_hooks_register_twice(make_reinit_hook, capsys): + def func_with_arg(x): + print(x) + + def func_without_arg(): + print("two") + + make_reinit_hook(func_with_arg, "one") + make_reinit_hook(func_without_arg) + make_reinit_hook(func_with_arg, "three") + make_reinit_hook(func_without_arg) + + rmm.reinitialize() + captured = capsys.readouterr() + assert captured.out == "two\nthree\ntwo\none\n" + + +def test_register_reinit_hook_twice(reinit_hooks, capsys): + # unregistering a twice-registered function + # should unregister both instances: + def func_with_arg(x): + print(x) + + make_reinit_hook(func_with_arg, "one") + make_reinit_hook(lambda: print("two")) + make_reinit_hook(func_with_arg, "three") + + rmm.reinitialize() + captured = capsys.readouterr() + assert captured.out == "two\n" From 61d6189680761f7e7938f64c5e9a1e6b15f0d4ae Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Mon, 18 Jul 2022 15:52:26 -0400 Subject: [PATCH 05/13] Fix test --- python/rmm/tests/test_rmm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/rmm/tests/test_rmm.py b/python/rmm/tests/test_rmm.py index 0b5193cc2..a0794481b 100644 --- a/python/rmm/tests/test_rmm.py +++ b/python/rmm/tests/test_rmm.py @@ -818,7 +818,7 @@ def func_without_arg(): assert captured.out == "two\nthree\ntwo\none\n" -def test_register_reinit_hook_twice(reinit_hooks, capsys): +def test_reinit_hooks_unregister_twice_registered(make_reinit_hook, capsys): # unregistering a twice-registered function # should unregister both instances: def func_with_arg(x): @@ -828,6 +828,7 @@ def func_with_arg(x): make_reinit_hook(lambda: print("two")) make_reinit_hook(func_with_arg, "three") + rmm.unregister_reinitialize_hook(func_with_arg) rmm.reinitialize() captured = capsys.readouterr() assert captured.out == "two\n" From 298d2b49c6a9ec98dcb9cef64963ebf978bf1b10 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Mon, 18 Jul 2022 17:17:30 -0400 Subject: [PATCH 06/13] Add a note about when hooks are used --- python/rmm/rmm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/rmm/rmm.py b/python/rmm/rmm.py index 8afe1764f..332844a2b 100644 --- a/python/rmm/rmm.py +++ b/python/rmm/rmm.py @@ -246,6 +246,10 @@ def register_reinitialize_hook(func, *args, **kwargs): Add a function to the list of functions that will be called before `rmm.reinitialize()`. + Typically, a library will use this function to register hooks that + are responsible for deleting any remaining internal references to + objects using device memory allocated by RMM. + Hooks are called in the *reverse* order they are registered. Parameters From 887f82a74b7465483c5e4dc5014fcf40ce464643 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Mon, 18 Jul 2022 17:24:33 -0400 Subject: [PATCH 07/13] Use a list rather than captured stdout --- python/rmm/tests/test_rmm.py | 52 +++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/python/rmm/tests/test_rmm.py b/python/rmm/tests/test_rmm.py index a0794481b..5d5c3c18f 100644 --- a/python/rmm/tests/test_rmm.py +++ b/python/rmm/tests/test_rmm.py @@ -781,54 +781,56 @@ def _make_reinit_hook(func, *args, **kwargs): rmm.unregister_reinitialize_hook(func) -def test_reinit_hooks_register(make_reinit_hook, capsys): - make_reinit_hook(lambda: print("one")) - make_reinit_hook(lambda: print("two")) - make_reinit_hook(lambda x: print(x), "three") +def test_reinit_hooks_register(make_reinit_hook): + L = [] + make_reinit_hook(lambda: L.append(1)) + make_reinit_hook(lambda: L.append(2)) + make_reinit_hook(lambda x: L.append(x), 3) rmm.reinitialize() - captured = capsys.readouterr() - assert captured.out == "three\ntwo\none\n" + assert L == [3, 2, 1] -def test_reinit_hooks_unregister(make_reinit_hook, capsys): - one = make_reinit_hook(lambda: print("one")) - make_reinit_hook(lambda: print("two")) +def test_reinit_hooks_unregister(make_reinit_hook): + L = [] + one = make_reinit_hook(lambda: L.append(1)) + make_reinit_hook(lambda: L.append(2)) rmm.unregister_reinitialize_hook(one) rmm.reinitialize() - captured = capsys.readouterr() - assert captured.out == "two\n" + assert L == [2] + +def test_reinit_hooks_register_twice(make_reinit_hook): + L = [] -def test_reinit_hooks_register_twice(make_reinit_hook, capsys): def func_with_arg(x): - print(x) + L.append(x) def func_without_arg(): - print("two") + L.append(2) - make_reinit_hook(func_with_arg, "one") + make_reinit_hook(func_with_arg, 1) make_reinit_hook(func_without_arg) - make_reinit_hook(func_with_arg, "three") + make_reinit_hook(func_with_arg, 3) make_reinit_hook(func_without_arg) rmm.reinitialize() - captured = capsys.readouterr() - assert captured.out == "two\nthree\ntwo\none\n" + assert L == [2, 3, 2, 1] -def test_reinit_hooks_unregister_twice_registered(make_reinit_hook, capsys): +def test_reinit_hooks_unregister_twice_registered(make_reinit_hook): # unregistering a twice-registered function # should unregister both instances: + L = [] + def func_with_arg(x): - print(x) + L.append(x) - make_reinit_hook(func_with_arg, "one") - make_reinit_hook(lambda: print("two")) - make_reinit_hook(func_with_arg, "three") + make_reinit_hook(func_with_arg, 1) + make_reinit_hook(lambda: L.append(2)) + make_reinit_hook(func_with_arg, 3) rmm.unregister_reinitialize_hook(func_with_arg) rmm.reinitialize() - captured = capsys.readouterr() - assert captured.out == "two\n" + assert L == [2] From 40a958faf130436df3acb1d2b94cfbe1b249b0c4 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Mon, 18 Jul 2022 17:25:37 -0400 Subject: [PATCH 08/13] hook->func --- python/rmm/rmm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/rmm/rmm.py b/python/rmm/rmm.py index 332844a2b..65ef609d3 100644 --- a/python/rmm/rmm.py +++ b/python/rmm/rmm.py @@ -87,8 +87,8 @@ def reinitialize( corresponding to each device. """ [ - hook(*args, **kwargs) - for (hook, args, kwargs) in reversed(_reinitialize_hooks) + func(*args, **kwargs) + for (func, args, kwargs) in reversed(_reinitialize_hooks) ] rmm.mr._initialize( pool_allocator=pool_allocator, From 442998c252de366745400358419f2c82df3ce40c Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Tue, 19 Jul 2022 10:28:59 -0400 Subject: [PATCH 09/13] Address review feedback --- python/rmm/rmm.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/python/rmm/rmm.py b/python/rmm/rmm.py index 65ef609d3..268e0a872 100644 --- a/python/rmm/rmm.py +++ b/python/rmm/rmm.py @@ -86,10 +86,9 @@ def reinitialize( Use `rmm.get_log_filenames()` to get the log file names corresponding to each device. """ - [ + for func, args, kwargs in reversed(_reinitialize_hooks): func(*args, **kwargs) - for (func, args, kwargs) in reversed(_reinitialize_hooks) - ] + rmm.mr._initialize( pool_allocator=pool_allocator, managed_memory=managed_memory, @@ -243,19 +242,22 @@ def rmm_cupy_allocator(nbytes): def register_reinitialize_hook(func, *args, **kwargs): """ - Add a function to the list of functions that will be called before - `rmm.reinitialize()`. + Add a function to the list of functions ("hooks") that will be + called before :py:func:`~rmm.reinitialize()`. - Typically, a library will use this function to register hooks that - are responsible for deleting any remaining internal references to - objects using device memory allocated by RMM. + A library may register hooks to perform any necessary cleanup + before RMM is reinitialized. For example, a library with an + internal cache of objects that use device memory allocated by RMM + can register a hook to release those references before RMM is + reinitialized, thus ensuring that the relevant device memory + resource can be deallocated Hooks are called in the *reverse* order they are registered. Parameters ---------- func: callable - Function to be called before `rmm.reinitialize()` + Function to be called before :py:func:`~rmm.reinitialize()` args, kwargs Positional and keyword arguments to bepassed to `func` """ @@ -266,11 +268,11 @@ def register_reinitialize_hook(func, *args, **kwargs): def unregister_reinitialize_hook(func): """ - Remove `func` from the list of functions that will be called before - `rmm.reinitialize()`. + Remove `func` from list of hooks that will be called before + :py:func:`~rmm.reinitialize()`. If `func` was registered more than once, every instance of it will - be removed from the list of functions. + be removed from the list of hooks. """ global _reinitialize_hooks _reinitialize_hooks = list( From 20a10e9afb89edb1f5e45a930acd34d7e0231172 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath <3190405+shwina@users.noreply.github.com> Date: Tue, 19 Jul 2022 10:32:31 -0400 Subject: [PATCH 10/13] Update python/rmm/rmm.py Co-authored-by: Lawrence Mitchell --- python/rmm/rmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/rmm/rmm.py b/python/rmm/rmm.py index 268e0a872..0948994ac 100644 --- a/python/rmm/rmm.py +++ b/python/rmm/rmm.py @@ -259,7 +259,7 @@ def register_reinitialize_hook(func, *args, **kwargs): func: callable Function to be called before :py:func:`~rmm.reinitialize()` args, kwargs - Positional and keyword arguments to bepassed to `func` + Positional and keyword arguments to be passed to `func` """ global _reinitialize_hooks _reinitialize_hooks.append((func, args, kwargs)) From 46cf3d4def036841f39eb5222c2dee3510d1212e Mon Sep 17 00:00:00 2001 From: Ashwin Srinath <3190405+shwina@users.noreply.github.com> Date: Tue, 19 Jul 2022 12:02:38 -0400 Subject: [PATCH 11/13] Update python/rmm/rmm.py Co-authored-by: Lawrence Mitchell --- python/rmm/rmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/rmm/rmm.py b/python/rmm/rmm.py index 0948994ac..2b1f7e93a 100644 --- a/python/rmm/rmm.py +++ b/python/rmm/rmm.py @@ -250,7 +250,7 @@ def register_reinitialize_hook(func, *args, **kwargs): internal cache of objects that use device memory allocated by RMM can register a hook to release those references before RMM is reinitialized, thus ensuring that the relevant device memory - resource can be deallocated + resource can be deallocated. Hooks are called in the *reverse* order they are registered. From 4f68c127204f91d379ecdf63ab623139a52930c4 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Tue, 19 Jul 2022 13:43:43 -0400 Subject: [PATCH 12/13] Update docstring --- python/rmm/rmm.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/rmm/rmm.py b/python/rmm/rmm.py index 268e0a872..a3ba770d9 100644 --- a/python/rmm/rmm.py +++ b/python/rmm/rmm.py @@ -245,14 +245,19 @@ def register_reinitialize_hook(func, *args, **kwargs): Add a function to the list of functions ("hooks") that will be called before :py:func:`~rmm.reinitialize()`. - A library may register hooks to perform any necessary cleanup - before RMM is reinitialized. For example, a library with an - internal cache of objects that use device memory allocated by RMM - can register a hook to release those references before RMM is + A user or library may register hooks to perform any necessary + cleanup before RMM is reinitialized. For example, a library with + an internal cache of objects that use device memory allocated by + RMM can register a hook to release those references before RMM is reinitialized, thus ensuring that the relevant device memory resource can be deallocated - Hooks are called in the *reverse* order they are registered. + Hooks are called in the *reverse* order they are registered. This + is useful, for example, when a library registers multiple hooks + and needs them to run in a specific order for cleanup to be safe. + Hooks cannot rely on being registered in a particular order + relative to hooks registered by other packages, since that is + determined by package import ordering. Parameters ---------- From d74efc0374c7a8f26a409845684076a5503633c1 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Tue, 19 Jul 2022 15:08:10 -0400 Subject: [PATCH 13/13] Address @bdice's comments --- python/rmm/rmm.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/rmm/rmm.py b/python/rmm/rmm.py index 51505298e..fd46b4793 100644 --- a/python/rmm/rmm.py +++ b/python/rmm/rmm.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import ctypes -from itertools import filterfalse from cuda.cuda import CUdeviceptr, cuIpcGetMemHandle from numba import config, cuda @@ -252,7 +251,7 @@ def register_reinitialize_hook(func, *args, **kwargs): reinitialized, thus ensuring that the relevant device memory resource can be deallocated. - Hooks are called in the *reverse* order they are registered. This + Hooks are called in the *reverse* order they are registered. This is useful, for example, when a library registers multiple hooks and needs them to run in a specific order for cleanup to be safe. Hooks cannot rely on being registered in a particular order @@ -261,7 +260,7 @@ def register_reinitialize_hook(func, *args, **kwargs): Parameters ---------- - func: callable + func : callable Function to be called before :py:func:`~rmm.reinitialize()` args, kwargs Positional and keyword arguments to be passed to `func` @@ -280,6 +279,4 @@ def unregister_reinitialize_hook(func): be removed from the list of hooks. """ global _reinitialize_hooks - _reinitialize_hooks = list( - filterfalse(lambda x: x[0] == func, _reinitialize_hooks) - ) + _reinitialize_hooks = [x for x in _reinitialize_hooks if x[0] != func]