Skip to content

Commit

Permalink
fix: EXPOSED-485 ClassCastException when eager loading referrersOn wi…
Browse files Browse the repository at this point in the history
…th uuid().references() (#2198)

* fix: EXPOSED-485 ClassCastException when eager loading referrersOn with uuid().references()

Using Column.references() invoked on a UUIDColumnType that targets an EntityIDColumnType<UUID>
causes a ClassCastException when referrersOn or backReferencedOn are eager loaded,
for the same type mismatch as detailed in the previous related issues.

This extends the previous fix to include Referrers and BackReference objects in
preloadRelations(), by forcing the query clause to use the underlying/wrapped type
of EntityIDColumnType if there is a type mismatch between refColumn.referree and refIds.

* test: Add corresponding test in `UIntIdTableEntityTest` and `ULongIdTableEntityTest`

---------

Co-authored-by: Jocelyne <jocelyne.abihaidar@jetbrains.com>
  • Loading branch information
bog-walk and joc-a committed Aug 19, 2024
1 parent 03afa02 commit c6a035f
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,10 @@ abstract class EntityClass<ID : Comparable<ID>, out T : Entity<ID>>(
val entities = getEntities(forUpdate, findQuery).distinct()

entities.groupByReference(refColumn = refColumn).forEach { (id, values) ->
val parentEntityId: EntityID<*> = parentTable.selectAll().where { refColumn.referee as Column<SID> eq id }
val castReferee = refColumn.referee
.takeUnless { it?.columnType is EntityIDColumnType<*> && id !is EntityID<*> }
?: (refColumn.referee?.columnType as EntityIDColumnType<*>).idColumn
val parentEntityId: EntityID<*> = parentTable.selectAll().where { castReferee as Column<SID> eq id }
.single()[parentTable.id]

cache.getOrPutReferrers(parentEntityId, refColumn) { SizedCollection(values) }.also {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,13 @@ private fun <ID : Comparable<ID>> List<Entity<ID>>.preloadRelations(
refColumns.map { (child, parent) -> child to (parent.lookup() as EntityID<*>).value }
}

fun Entity<*>.getRefereeId(refereeColumn: Column<*>, delegateRefColumn: Column<*>): Any {
val refereeValue = refereeColumn.lookup()
return refereeValue.takeUnless {
delegateRefColumn.columnType !is EntityIDColumnType<*> && it is EntityID<*>
} ?: (refereeValue as EntityID<*>).value
}

val directRelations = filterRelationsForEntity(entity, relations)
directRelations.forEach { prop ->
when (val refObject = getReferenceObjectFromDelegatedProperty(entity, prop)) {
Expand Down Expand Up @@ -302,7 +309,8 @@ private fun <ID : Comparable<ID>> List<Entity<ID>>.preloadRelations(
(refObject as Referrers<ID, Entity<ID>, *, Entity<*>, Any>).allReferences.let { refColumns ->
val delegateRefColumn = refObject.reference
if (hasSingleReferenceWithReferee(refColumns)) {
val refIds = this.map { it.run { delegateRefColumn.referee<Any>()!!.lookup() } }
val castReferee = delegateRefColumn.referee<Any>()!!
val refIds = this.map { entity -> entity.getRefereeId(castReferee, delegateRefColumn) }
refObject.factory.warmUpReferences(refIds, delegateRefColumn)
} else {
val refIds = this.map { it.getCompositeReferrerId(refColumns) }
Expand All @@ -326,7 +334,8 @@ private fun <ID : Comparable<ID>> List<Entity<ID>>.preloadRelations(
(refObject.delegate as Referrers<ID, Entity<ID>, *, Entity<*>, Any>).allReferences.let { refColumns ->
val delegateRefColumn = refObject.delegate.reference
if (hasSingleReferenceWithReferee(refColumns)) {
val refIds = this.map { it.run { delegateRefColumn.referee<Any>()!!.lookup() } }
val castReferee = delegateRefColumn.referee<Any>()!!
val refIds = this.map { entity -> entity.getRefereeId(castReferee, delegateRefColumn) }
refObject.delegate.factory.warmUpReferences(refIds, delegateRefColumn)
} else {
val refIds = this.map { it.getCompositeReferrerId(refColumns) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import org.jetbrains.exposed.dao.id.LongIdTable
import org.jetbrains.exposed.dao.with
import org.jetbrains.exposed.sql.Column
import org.jetbrains.exposed.sql.exists
import org.jetbrains.exposed.sql.insert
import org.jetbrains.exposed.sql.insertAndGetId
import org.jetbrains.exposed.sql.tests.DatabaseTestsBase
import org.jetbrains.exposed.sql.tests.shared.assertEquals
Expand Down Expand Up @@ -123,21 +122,26 @@ class LongIdTableEntityTest : DatabaseTestsBase() {
val cId = LongIdTables.Cities.insertAndGetId {
it[name] = "City A"
}
LongIdTables.Towns.insert {
val tId = LongIdTables.Towns.insertAndGetId {
it[cityId] = cId.value
}

// lazy loaded reference
// lazy loaded referencedOn
val town1 = LongIdTables.Town.all().single()
assertEquals(cId, town1.city.id)

// eager loaded reference
// eager loaded referencedOn
val town1WithCity = LongIdTables.Town.all().with(LongIdTables.Town::city).single()
assertEquals(cId, town1WithCity.city.id)

// lazy loaded referrersOn
val city1 = LongIdTables.City.all().single()
val towns = city1.towns
assertEquals(cId, towns.first().city.id)

// eager loaded referrersOn
val city1WithTowns = LongIdTables.City.all().with(LongIdTables.City::towns).single()
assertEquals(tId, city1WithTowns.towns.first().id)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import org.jetbrains.exposed.dao.UIntEntity
import org.jetbrains.exposed.dao.UIntEntityClass
import org.jetbrains.exposed.dao.id.EntityID
import org.jetbrains.exposed.dao.id.UIntIdTable
import org.jetbrains.exposed.dao.with
import org.jetbrains.exposed.sql.Column
import org.jetbrains.exposed.sql.exists
import org.jetbrains.exposed.sql.insertAndGetId
import org.jetbrains.exposed.sql.tests.DatabaseTestsBase
import org.jetbrains.exposed.sql.tests.shared.assertEquals
import org.junit.Test
Expand Down Expand Up @@ -78,6 +81,35 @@ class UIntIdTableEntityTest : DatabaseTestsBase() {
assertEquals(false, allPeople.contains(Pair("Tanu Arora", "Pune")))
}
}

@Test
fun testForeignKeyBetweenUIntAndEntityIDColumns() {
withTables(UIntIdTables.Cities, UIntIdTables.Towns) {
val cId = UIntIdTables.Cities.insertAndGetId {
it[name] = "City A"
}
val tId = UIntIdTables.Towns.insertAndGetId {
it[cityId] = cId.value
}

// lazy loaded referencedOn
val town1 = UIntIdTables.Town.all().single()
assertEquals(cId, town1.city.id)

// eager loaded referencedOn
val town1WithCity = UIntIdTables.Town.all().with(UIntIdTables.Town::city).single()
assertEquals(cId, town1WithCity.city.id)

// lazy loaded referrersOn
val city1 = UIntIdTables.City.all().single()
val towns = city1.towns
assertEquals(cId, towns.first().city.id)

// eager loaded referrersOn
val city1WithTowns = UIntIdTables.City.all().with(UIntIdTables.City::towns).single()
assertEquals(tId, city1WithTowns.towns.first().id)
}
}
}

object UIntIdTables {
Expand All @@ -89,6 +121,7 @@ object UIntIdTables {
companion object : UIntEntityClass<City>(Cities)

var name by Cities.name
val towns by Town referrersOn Towns.cityId
}

object People : UIntIdTable() {
Expand All @@ -102,4 +135,14 @@ object UIntIdTables {
var name by People.name
var city by City referencedOn People.cityId
}

object Towns : UIntIdTable("towns") {
val cityId: Column<UInt> = uinteger("city_id").references(Cities.id)
}

class Town(id: EntityID<UInt>) : UIntEntity(id) {
companion object : UIntEntityClass<Town>(Towns)

var city by City referencedOn Towns.cityId
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import org.jetbrains.exposed.dao.ULongEntity
import org.jetbrains.exposed.dao.ULongEntityClass
import org.jetbrains.exposed.dao.id.EntityID
import org.jetbrains.exposed.dao.id.ULongIdTable
import org.jetbrains.exposed.dao.with
import org.jetbrains.exposed.sql.Column
import org.jetbrains.exposed.sql.exists
import org.jetbrains.exposed.sql.insertAndGetId
import org.jetbrains.exposed.sql.tests.DatabaseTestsBase
import org.jetbrains.exposed.sql.tests.shared.assertEquals
import org.junit.Test
Expand Down Expand Up @@ -78,6 +81,35 @@ class ULongIdTableEntityTest : DatabaseTestsBase() {
assertEquals(false, allPeople.contains(Pair("Tanu Arora", "Pune")))
}
}

@Test
fun testForeignKeyBetweenULongAndEntityIDColumns() {
withTables(ULongIdTables.Cities, ULongIdTables.Towns) {
val cId = ULongIdTables.Cities.insertAndGetId {
it[name] = "City A"
}
val tId = ULongIdTables.Towns.insertAndGetId {
it[cityId] = cId.value
}

// lazy loaded referencedOn
val town1 = ULongIdTables.Town.all().single()
assertEquals(cId, town1.city.id)

// eager loaded referencedOn
val town1WithCity = ULongIdTables.Town.all().with(ULongIdTables.Town::city).single()
assertEquals(cId, town1WithCity.city.id)

// lazy loaded referrersOn
val city1 = ULongIdTables.City.all().single()
val towns = city1.towns
assertEquals(cId, towns.first().city.id)

// eager loaded referrersOn
val city1WithTowns = ULongIdTables.City.all().with(ULongIdTables.City::towns).single()
assertEquals(tId, city1WithTowns.towns.first().id)
}
}
}

object ULongIdTables {
Expand All @@ -89,6 +121,7 @@ object ULongIdTables {
companion object : ULongEntityClass<City>(Cities)

var name by Cities.name
val towns by Town referrersOn Towns.cityId
}

object People : ULongIdTable() {
Expand All @@ -102,4 +135,14 @@ object ULongIdTables {
var name by People.name
var city by City referencedOn People.cityId
}

object Towns : ULongIdTable("towns") {
val cityId: Column<ULong> = ulong("city_id").references(Cities.id)
}

class Town(id: EntityID<ULong>) : ULongEntity(id) {
companion object : ULongEntityClass<Town>(Towns)

var city by City referencedOn Towns.cityId
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import org.jetbrains.exposed.dao.id.UUIDTable
import org.jetbrains.exposed.dao.with
import org.jetbrains.exposed.sql.Column
import org.jetbrains.exposed.sql.exists
import org.jetbrains.exposed.sql.insert
import org.jetbrains.exposed.sql.insertAndGetId
import org.jetbrains.exposed.sql.tests.DatabaseTestsBase
import org.jetbrains.exposed.sql.tests.shared.assertEquals
Expand Down Expand Up @@ -170,21 +169,26 @@ class UUIDTableEntityTest : DatabaseTestsBase() {
val cId = UUIDTables.Cities.insertAndGetId {
it[name] = "City A"
}
UUIDTables.Towns.insert {
val tId = UUIDTables.Towns.insertAndGetId {
it[cityId] = cId.value
}

// lazy loaded reference
// lazy loaded referencedOn
val town1 = UUIDTables.Town.all().single()
assertEquals(cId, town1.city.id)

// eager loaded reference
// eager loaded referencedOn
val town1WithCity = UUIDTables.Town.all().with(UUIDTables.Town::city).single()
assertEquals(cId, town1WithCity.city.id)

// lazy loaded referrersOn
val city1 = UUIDTables.City.all().single()
val towns = city1.towns
assertEquals(cId, towns.first().city.id)

// eager loaded referrersOn
val city1WithTowns = UUIDTables.City.all().with(UUIDTables.City::towns).single()
assertEquals(tId, city1WithTowns.towns.first().id)
}
}
}

0 comments on commit c6a035f

Please sign in to comment.