Skip to content

Commit

Permalink
More testing.
Browse files Browse the repository at this point in the history
  • Loading branch information
jamadden committed Jun 29, 2023
1 parent 441d1d4 commit 34a41ab
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/relstorage/adapters/mysql/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ def visit_upsert_before_select(self, select):

ver_det = select.context.version_detector
if ver_det.requires_values_upsert_alias(None):
self.emit('SELECT * FROM (')
self.emit_w_padding_space('SELECT * FROM (')

def visit_upsert_after_select(self, select):
ver_det = select.context.version_detector
if ver_det.requires_values_upsert_alias(None):
self.emit(') AS excluded')
self.emit_w_padding_space(') AS excluded')


class MySQLDialect(DefaultDialect):
Expand Down
41 changes: 41 additions & 0 deletions src/relstorage/adapters/mysql/tests/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
class TestMySQLDialect(test_sql.TestUpsert):
keep_history = False
dialect = MySQLDialect()
REQUIRES_UPSERT = False

insert_or_replace = (
'INSERT INTO object_state(zoid, state, tid, state_size) '
Expand All @@ -48,3 +49,43 @@ class TestMySQLDialect(test_sql.TestUpsert):
'ON DUPLICATE KEY UPDATE state = VALUES(state), '
'tid = VALUES(tid), state_size = VALUES(state_size)'
)

def get_bind_context(self):
requires_upsert = self.REQUIRES_UPSERT
class Context:
dialect = MySQLDialect()
keep_history = False
@property
def version_detector(self):
return self
def requires_values_upsert_alias(self, _cursor):
return requires_upsert
return Context()

class TestMySQLDialect8019(TestMySQLDialect):
REQUIRES_UPSERT = True

insert_or_replace = (
'INSERT INTO object_state(zoid, state, tid, state_size) '
'VALUES (%s, %s, %s, %s) AS excluded '
'ON DUPLICATE KEY UPDATE '
'state = excluded.state, tid = excluded.tid, '
'state_size = excluded.state_size'
)

insert_or_replace_subquery = (
'INSERT INTO object_state(zoid, tid, state, state_size) '
'SELECT * FROM ( SELECT zoid, %s, state, COALESCE(LENGTH(state), 0) FROM temp_store '
'ORDER BY zoid ) AS excluded '
'ON DUPLICATE KEY UPDATE '
'state = excluded.state, tid = excluded.tid, '
'state_size = excluded.state_size'
)

upsert_unconstrained_subquery = (
'INSERT INTO object_state(zoid, tid, state, state_size) '
'SELECT * FROM ( SELECT zoid, %s, state, COALESCE(LENGTH(state), 0) FROM temp_store ) '
'AS excluded '
'ON DUPLICATE KEY UPDATE state = excluded.state, '
'tid = excluded.tid, state_size = excluded.state_size'
)
7 changes: 6 additions & 1 deletion src/relstorage/adapters/sql/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,8 @@ def visit_upsert_before_select(self, select): # pylint:disable=unused-argument
class _DefaultContext(object):
keep_history = True

class NoDialectFoundError(TypeError):
"Raised when we cannot find a dialect."

class DialectAware(object):
context = _DefaultContext()
Expand Down Expand Up @@ -537,7 +539,10 @@ def _find_dialect(self, context):
else:
return dialect.bind(context)
__traceback_info__ = getattr(context, '__dict__', ()) # vars() doesn't work on e.g., None
raise TypeError("Unable to bind to %s; no dialect found" % (context,))
raise NoDialectFoundError("Unable to bind %s to %s; no dialect found" % (
type(self),
context,
))

def bind(self, context, dialect=None):
assert self.context is DialectAware.context, "already bound"
Expand Down
20 changes: 17 additions & 3 deletions src/relstorage/adapters/sql/insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from __future__ import division
from __future__ import print_function

import logging

from zope.interface import implementer

from .query import Query
Expand All @@ -23,10 +25,13 @@
from .expressions import Expression
from .expressions import ParamMixin
from .dialect import DialectAware
from .dialect import NoDialectFoundError

# pylint objects to __compile_visit.*__
# pylint:disable=bad-dunder-name

logger = logging.getLogger(__name__)

class _ValuesPlaceholderList(ColumnList):
pass

Expand Down Expand Up @@ -86,7 +91,11 @@ def __compile_visit__(self, compiler):
self._visit_select(compiler)
else:
values = _InsertValuesClause(self.column_list)
values = values.bind(self.context)
try:
values = values.bind(self.context)
except NoDialectFoundError:
# Should only happen in testing.
logger.debug('Unable to find dialect', exc_info=True)
compiler.visit(values)
compiler.emit(self.epilogue)

Expand Down Expand Up @@ -178,11 +187,16 @@ def do_update(self, *columns):

@property
def update_clause(self):
return Update(EmptyExpression(), _BindableList(
update = Update(EmptyExpression(), _BindableList(
_UpsertAssignmentExpression(col, _ExcludedColumn(col.name))
for col
in self.update_columns # pylint:disable=not-an-iterable
)).bind(self.context)
))
try:
return update.bind(self.context)
except NoDialectFoundError:
# Should only happen in testing
return update

def _visit_command(self, compiler):
compiler.emit_keyword_upsert()
Expand Down
9 changes: 6 additions & 3 deletions src/relstorage/adapters/sql/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def test_insert_or_replace(self):
it.c.state,
it.c.tid,
it.c.state_size
).bind(self)
).bind(self.get_bind_context())

self.assertEqual(
str(stmt),
Expand Down Expand Up @@ -509,7 +509,7 @@ def test_insert_or_replace_subquery(self):
it.c.state,
it.c.tid,
it.c.state_size
).bind(self)
).bind(self.get_bind_context())

self.maxDiff = None
self.assertEqual(
Expand Down Expand Up @@ -549,10 +549,13 @@ def test_upsert_unconstrained_subquery(self):
it.c.state,
it.c.tid,
it.c.state_size
).bind(self)
).bind(self.get_bind_context())

self.maxDiff = None
self.assertEqual(
str(stmt),
self.upsert_unconstrained_subquery
)

def get_bind_context(self):
return self

0 comments on commit 34a41ab

Please sign in to comment.