Skip to content

Commit

Permalink
Prototype of deregister and with db.register_functions, refs #589
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Sep 3, 2023
1 parent 1260bdc commit 62f6738
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions sqlite_utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ def register_function(
deterministic: bool = False,
replace: bool = False,
name: Optional[str] = None,
deregister: bool = False,
):
"""
``fn`` will be made available as a function within SQL, with the same name and number
Expand All @@ -434,35 +435,61 @@ def upper(value):
:param deterministic: set ``True`` for functions that always returns the same output for a given input
:param replace: set ``True`` to replace an existing function with the same name - otherwise throw an error
:param name: name of the SQLite function - if not specified, the Python function name will be used
:param deregister: set ``True`` to deregister the function
"""

def register(fn):
fn_name = name or fn.__name__
fn_to_register = fn
arity = len(inspect.signature(fn).parameters)
if not replace and (fn_name, arity) in self._registered_functions:
if (
not replace
and not deregister
and (fn_name, arity) in self._registered_functions
):
return fn
kwargs = {}
if deregister:
fn_to_register = None
registered = False
if deterministic:
# Try this, but fall back if sqlite3.NotSupportedError
try:
self.conn.create_function(
fn_name, arity, fn, **dict(kwargs, deterministic=True)
fn_name,
arity,
fn_to_register,
**dict(kwargs, deterministic=True),
)
registered = True
except (sqlite3.NotSupportedError, TypeError):
# TypeError is Python 3.7 "function takes at most 3 arguments"
pass
if not registered:
self.conn.create_function(fn_name, arity, fn, **kwargs)
self._registered_functions.add((fn_name, arity))
self.conn.create_function(fn_name, arity, fn_to_register, **kwargs)
if deregister:
self._registered_functions.remove((fn_name, arity))
else:
self._registered_functions.add((fn_name, arity))
return fn

if fn is None:
return register
else:
register(fn)

@contextlib.contextmanager
def register_functions(self, *functions: Callable):
"Register functions for the duration of a with block, then unregister them"
for func in functions:
self.register_function(func)
try:
yield
finally:
# Unregister the functions
for func in functions:
self.register_function(func, deregister=True)

def register_fts4_bm25(self):
"Register the ``rank_bm25(match_info)`` function used for calculating relevance with SQLite FTS4."
self.register_function(rank_bm25, deterministic=True, replace=True)
Expand Down

0 comments on commit 62f6738

Please sign in to comment.