diff --git a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateStatement.kt b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateStatement.kt index c828f79733..8c91ebe2f8 100644 --- a/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateStatement.kt +++ b/exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/statements/UpdateStatement.kt @@ -6,6 +6,8 @@ import org.jetbrains.exposed.sql.statements.api.PreparedStatementApi import org.jetbrains.exposed.sql.vendors.H2Dialect.H2CompatibilityMode import org.jetbrains.exposed.sql.vendors.H2FunctionProvider import org.jetbrains.exposed.sql.vendors.OracleDialect +import org.jetbrains.exposed.sql.vendors.PostgreSQLDialect +import org.jetbrains.exposed.sql.vendors.SQLServerDialect import org.jetbrains.exposed.sql.vendors.currentDialect import org.jetbrains.exposed.sql.vendors.h2Mode @@ -45,20 +47,38 @@ open class UpdateStatement(val targetsSet: ColumnSet, val limit: Int?, val where } override fun arguments(): Iterable>> = QueryBuilder(true).run { + val dialect = currentDialect when { - targetsSet is Join && currentDialect is OracleDialect -> { - where?.toQueryBuilder(this) - values.forEach { - registerArgument(it.key, it.value) - } + targetsSet is Join && dialect is OracleDialect -> { + registerAdditionalArgs(targetsSet) + registerWhereArg() + registerUpdateArgs() + } + targetsSet is Join && (dialect is SQLServerDialect || dialect is PostgreSQLDialect) -> { + registerUpdateArgs() + registerAdditionalArgs(targetsSet) + registerWhereArg() + } + targetsSet is Join -> { + registerAdditionalArgs(targetsSet) + registerUpdateArgs() + registerWhereArg() } else -> { - values.forEach { - registerArgument(it.key, it.value) - } - where?.toQueryBuilder(this) + registerUpdateArgs() + registerWhereArg() } } if (args.isNotEmpty()) listOf(args) else emptyList() } + + private fun QueryBuilder.registerWhereArg() { where?.toQueryBuilder(this) } + + private fun QueryBuilder.registerUpdateArgs() { values.forEach { registerArgument(it.key, it.value) } } + + private fun QueryBuilder.registerAdditionalArgs(join: Join) { + join.joinParts.forEach { + it.additionalConstraint?.invoke(SqlExpressionBuilder)?.toQueryBuilder(this) + } + } } diff --git a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/UpdateTests.kt b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/UpdateTests.kt index 36616a8766..328d8ea781 100644 --- a/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/UpdateTests.kt +++ b/exposed-tests/src/test/kotlin/org/jetbrains/exposed/sql/tests/shared/dml/UpdateTests.kt @@ -68,7 +68,7 @@ class UpdateTests : DatabaseTestsBase() { } @Test - fun testUpdateWithJoin01() { + fun testUpdateWithSingleJoin() { withCitiesAndUsers(exclude = listOf(TestDB.SQLITE)) { _, users, userData -> val join = users.innerJoin(userData) join.update { @@ -80,11 +80,22 @@ class UpdateTests : DatabaseTestsBase() { assertEquals(it[users.name], it[userData.comment]) assertEquals(123, it[userData.value]) } + + val joinWithConstraint = users.innerJoin(userData, { users.id }, { userData.user_id }) { users.id eq "smth" } + joinWithConstraint.update { + it[userData.comment] = users.name + it[userData.value] = 0 + } + + joinWithConstraint.selectAll().forEach { + assertEquals(it[users.name], it[userData.comment]) + assertEquals(0, it[userData.value]) + } } } @Test - fun testUpdateWithJoin02() { + fun testUpdateWithMultipleJoins() { withCitiesAndUsers(exclude = TestDB.allH2TestDB + TestDB.SQLITE) { cities, users, userData -> val join = cities.innerJoin(users).innerJoin(userData) join.update { @@ -109,7 +120,7 @@ class UpdateTests : DatabaseTestsBase() { val tableAId = reference("table_a_id", tableA) } - val supportWhere = TestDB.entries - TestDB.allH2TestDB - TestDB.SQLITE + TestDB.H2_ORACLE + val supportWhere = TestDB.entries - TestDB.allH2TestDB.toSet() - TestDB.SQLITE + TestDB.H2_ORACLE withTables(tableA, tableB) { testingDb -> val aId = tableA.insertAndGetId { it[foo] = "foo" } @@ -119,6 +130,7 @@ class UpdateTests : DatabaseTestsBase() { } val join = tableA.innerJoin(tableB) + val joinWithConstraint = tableA.innerJoin(tableB, { tableA.id }, { tableB.tableAId }) { tableB.bar eq "foo" } if (testingDb in supportWhere) { join.update({ tableA.foo eq "foo" }) { @@ -127,6 +139,11 @@ class UpdateTests : DatabaseTestsBase() { join.selectAll().single().also { assertEquals("baz", it[tableB.bar]) } + + joinWithConstraint.update({ tableA.foo eq "foo" }) { + it[tableB.bar] = "baz" + } + assertEquals(0, joinWithConstraint.selectAll().count()) } else { expectException { join.update({ tableA.foo eq "foo" }) {