From 62f673835c4a66f87cf6f949eaff43c8b014619b Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 3 Sep 2023 12:36:10 -0700 Subject: [PATCH] Prototype of deregister and with db.register_functions, refs #589 --- sqlite_utils/db.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py index de8d1c33..24ae842f 100644 --- a/sqlite_utils/db.py +++ b/sqlite_utils/db.py @@ -413,6 +413,7 @@ def register_function( deterministic: bool = False, replace: bool = False, name: Optional[str] = None, + deregister: bool = False, ): """ ``fn`` will be made available as a function within SQL, with the same name and number @@ -434,28 +435,42 @@ def upper(value): :param deterministic: set ``True`` for functions that always returns the same output for a given input :param replace: set ``True`` to replace an existing function with the same name - otherwise throw an error :param name: name of the SQLite function - if not specified, the Python function name will be used + :param deregister: set ``True`` to deregister the function """ def register(fn): fn_name = name or fn.__name__ + fn_to_register = fn arity = len(inspect.signature(fn).parameters) - if not replace and (fn_name, arity) in self._registered_functions: + if ( + not replace + and not deregister + and (fn_name, arity) in self._registered_functions + ): return fn kwargs = {} + if deregister: + fn_to_register = None registered = False if deterministic: # Try this, but fall back if sqlite3.NotSupportedError try: self.conn.create_function( - fn_name, arity, fn, **dict(kwargs, deterministic=True) + fn_name, + arity, + fn_to_register, + **dict(kwargs, deterministic=True), ) registered = True except (sqlite3.NotSupportedError, TypeError): # TypeError is Python 3.7 "function takes at most 3 arguments" pass if not registered: - self.conn.create_function(fn_name, arity, fn, **kwargs) - self._registered_functions.add((fn_name, arity)) + self.conn.create_function(fn_name, arity, fn_to_register, **kwargs) + if deregister: + self._registered_functions.remove((fn_name, arity)) + else: + self._registered_functions.add((fn_name, arity)) return fn if fn is None: @@ -463,6 +478,18 @@ def register(fn): else: register(fn) + @contextlib.contextmanager + def register_functions(self, *functions: Callable): + "Register functions for the duration of a with block, then unregister them" + for func in functions: + self.register_function(func) + try: + yield + finally: + # Unregister the functions + for func in functions: + self.register_function(func, deregister=True) + def register_fts4_bm25(self): "Register the ``rank_bm25(match_info)`` function used for calculating relevance with SQLite FTS4." self.register_function(rank_bm25, deterministic=True, replace=True)