From 6828ba57b0d15d9f8cba418d2acd0174e789e971 Mon Sep 17 00:00:00 2001 From: Chantal Loncle <82039410+bog-walk@users.noreply.github.com> Date: Fri, 23 Feb 2024 17:55:13 -0500 Subject: [PATCH] fix: EXPOSED-301 Update with join throws if additionalConstraint provided Attempting to use any join function in its full overload form, with an additionalConstraint argument, results in an exception if an update() is called instead of a select(). This occurs because the additionalConstraint conditions are being correctly appended to generated SQL, but their arguments are not being first registered with the prepared update statement. The arguments in the update statement now correctly register any additionalConstraint and take into account their order if an update with where is called. --- .../exposed/sql/statements/UpdateStatement.kt | 38 ++++++++++++++----- .../sql/tests/shared/dml/UpdateTests.kt | 23 +++++++++-- 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateStatement.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateStatement.kt index c828f79733..8c91ebe2f8 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateStatement.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateStatement.kt @@ -6,6 +6,8 @@ import org.jetbrains.exposed.sql.statements.api.PreparedStatementApi import org.jetbrains.exposed.sql.vendors.H2Dialect.H2CompatibilityMode import org.jetbrains.exposed.sql.vendors.H2FunctionProvider import org.jetbrains.exposed.sql.vendors.OracleDialect +import org.jetbrains.exposed.sql.vendors.PostgreSQLDialect +import org.jetbrains.exposed.sql.vendors.SQLServerDialect import org.jetbrains.exposed.sql.vendors.currentDialect import org.jetbrains.exposed.sql.vendors.h2Mode @@ -45,20 +47,38 @@ open class UpdateStatement(val targetsSet: ColumnSet, val limit: Int?, val where } override fun arguments(): Iterable>> = QueryBuilder(true).run { + val dialect = currentDialect when { - targetsSet is Join && currentDialect is OracleDialect -> { - where?.toQueryBuilder(this) - values.forEach { - registerArgument(it.key, it.value) - } + targetsSet is Join && dialect is OracleDialect -> { + registerAdditionalArgs(targetsSet) + registerWhereArg() + registerUpdateArgs() + } + targetsSet is Join && (dialect is SQLServerDialect || dialect is PostgreSQLDialect) -> { + registerUpdateArgs() + registerAdditionalArgs(targetsSet) + registerWhereArg() + } + targetsSet is Join -> { + registerAdditionalArgs(targetsSet) + registerUpdateArgs() + registerWhereArg() } else -> { - values.forEach { - registerArgument(it.key, it.value) - } - where?.toQueryBuilder(this) + registerUpdateArgs() + registerWhereArg() } } if (args.isNotEmpty()) listOf(args) else emptyList() } + + private fun QueryBuilder.registerWhereArg() { where?.toQueryBuilder(this) } + + private fun QueryBuilder.registerUpdateArgs() { values.forEach { registerArgument(it.key, it.value) } } + + private fun QueryBuilder.registerAdditionalArgs(join: Join) { + join.joinParts.forEach { + it.additionalConstraint?.invoke(SqlExpressionBuilder)?.toQueryBuilder(this) + } + } } diff --git a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/UpdateTests.kt b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/UpdateTests.kt index 36616a8766..328d8ea781 100644 --- a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/UpdateTests.kt +++ b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/UpdateTests.kt @@ -68,7 +68,7 @@ class UpdateTests : DatabaseTestsBase() { } @Test - fun testUpdateWithJoin01() { + fun testUpdateWithSingleJoin() { withCitiesAndUsers(exclude = listOf(TestDB.SQLITE)) { _, users, userData -> val join = users.innerJoin(userData) join.update { @@ -80,11 +80,22 @@ class UpdateTests : DatabaseTestsBase() { assertEquals(it[users.name], it[userData.comment]) assertEquals(123, it[userData.value]) } + + val joinWithConstraint = users.innerJoin(userData, { users.id }, { userData.user_id }) { users.id eq "smth" } + joinWithConstraint.update { + it[userData.comment] = users.name + it[userData.value] = 0 + } + + joinWithConstraint.selectAll().forEach { + assertEquals(it[users.name], it[userData.comment]) + assertEquals(0, it[userData.value]) + } } } @Test - fun testUpdateWithJoin02() { + fun testUpdateWithMultipleJoins() { withCitiesAndUsers(exclude = TestDB.allH2TestDB + TestDB.SQLITE) { cities, users, userData -> val join = cities.innerJoin(users).innerJoin(userData) join.update { @@ -109,7 +120,7 @@ class UpdateTests : DatabaseTestsBase() { val tableAId = reference("table_a_id", tableA) } - val supportWhere = TestDB.entries - TestDB.allH2TestDB - TestDB.SQLITE + TestDB.H2_ORACLE + val supportWhere = TestDB.entries - TestDB.allH2TestDB.toSet() - TestDB.SQLITE + TestDB.H2_ORACLE withTables(tableA, tableB) { testingDb -> val aId = tableA.insertAndGetId { it[foo] = "foo" } @@ -119,6 +130,7 @@ class UpdateTests : DatabaseTestsBase() { } val join = tableA.innerJoin(tableB) + val joinWithConstraint = tableA.innerJoin(tableB, { tableA.id }, { tableB.tableAId }) { tableB.bar eq "foo" } if (testingDb in supportWhere) { join.update({ tableA.foo eq "foo" }) { @@ -127,6 +139,11 @@ class UpdateTests : DatabaseTestsBase() { join.selectAll().single().also { assertEquals("baz", it[tableB.bar]) } + + joinWithConstraint.update({ tableA.foo eq "foo" }) { + it[tableB.bar] = "baz" + } + assertEquals(0, joinWithConstraint.selectAll().count()) } else { expectException { join.update({ tableA.foo eq "foo" }) {