Skip to content

Commit

Permalink
[SEDONA-624] Bind references in distance expression to relations lazily to avoid exception in query plan canonicalization (#1518)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kontinuation committed Jul 9, 2024
1 parent 9c3c67d commit 493ec74
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ case class DistanceJoinExec(
with TraitJoinQueryExec
with Logging {

private val boundRadius = if (distanceBoundToLeft) {
private lazy val boundRadius = if (distanceBoundToLeft) {
BindReferences.bindReference(distance, left.output)
} else {
BindReferences.bindReference(distance, right.output)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,15 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
}
}

describe("Sedona-SQL Spatial Join Test with SELECT *") {
describe("Sedona-SQL Spatial Join Test with SELECT * and SELECT COUNT(*)") {
val joinConditions = Table(
"join condition",
"ST_Contains(df1.geom, df2.geom)",
"ST_Contains(df2.geom, df1.geom)",
"ST_Distance(df1.geom, df2.geom) < 1.0",
"ST_Distance(df2.geom, df1.geom) < 1.0")
"ST_Distance(df2.geom, df1.geom) < 1.0",
"ST_Distance(df1.geom, df2.geom) < df1.dist",
"ST_Distance(df1.geom, df2.geom) < df2.dist")

forAll(joinConditions) { joinCondition =>
it(s"should SELECT * in join query with $joinCondition produce correct result") {
Expand All @@ -120,6 +122,16 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
assert(result === expected)
}

// Runs the join as a COUNT(*) aggregate and checks that the single returned
// count equals the length of what buildExpectedResult produces for the same
// join condition.
it(s"should SELECT COUNT(*) in join query with $joinCondition produce correct result") {
  val countQuery = s"SELECT COUNT(*) FROM df1 JOIN df2 ON $joinCondition"
  val countRows = sparkSession.sql(countQuery).collect()
  // The query yields one row; its first column is read as a Long.
  val actualCount = countRows.head.getLong(0)
  val expectedCount = buildExpectedResult(joinCondition).length
  assert(actualCount === expectedCount)
}

it(
s"should SELECT * in join query with $joinCondition produce correct result, broadcast the left side") {
val resultAll = sparkSession
Expand All @@ -131,6 +143,17 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
assert(result === expected)
}

// Same COUNT(*) check as the plain join, but the query carries a
// BROADCAST(df1) hint so the left relation is requested as the broadcast side.
it(
  s"should SELECT COUNT(*) in join query with $joinCondition produce correct result, broadcast the left side") {
  val hintedQuery =
    s"SELECT /*+ BROADCAST(df1) */ COUNT(*) FROM df1 JOIN df2 ON $joinCondition"
  val countRows = sparkSession.sql(hintedQuery).collect()
  // The query yields one row; its first column is read as a Long.
  val actualCount = countRows.head.getLong(0)
  val expectedCount = buildExpectedResult(joinCondition).length
  assert(actualCount === expectedCount)
}

it(
s"should SELECT * in join query with $joinCondition produce correct result, broadcast the right side") {
val resultAll = sparkSession
Expand All @@ -141,6 +164,17 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
assert(result.nonEmpty)
assert(result === expected)
}

// Same COUNT(*) check as the plain join, but the query carries a
// BROADCAST(df2) hint so the right relation is requested as the broadcast side.
it(
  s"should SELECT COUNT(*) in join query with $joinCondition produce correct result, broadcast the right side") {
  val hintedQuery =
    s"SELECT /*+ BROADCAST(df2) */ COUNT(*) FROM df1 JOIN df2 ON $joinCondition"
  val countRows = sparkSession.sql(hintedQuery).collect()
  // The query yields one row; its first column is read as a Long.
  val actualCount = countRows.head.getLong(0)
  val expectedCount = buildExpectedResult(joinCondition).length
  assert(actualCount === expectedCount)
}
}
}

Expand Down Expand Up @@ -192,7 +226,7 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
}
}

describe("Spatial join should work with dataframe containing 0 partitions") {
describe("Spatial join should work with dataframe containing various number of partitions") {
val queries = Table(
"join queries",
"SELECT * FROM df1 JOIN dfEmpty WHERE ST_Intersects(df1.geom, dfEmpty.geom)",
Expand All @@ -203,7 +237,7 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
"SELECT /*+ BROADCAST(dfEmpty) */ * FROM dfEmpty JOIN df1 WHERE ST_Intersects(df1.geom, dfEmpty.geom)")

forAll(queries) { query =>
it(s"Legacy join: $query") {
it(s"empty dataframes: $query") {
withConf(Map(spatialJoinPartitionSideConfKey -> "left")) {
val resultRows = sparkSession.sql(query).collect()
assert(resultRows.isEmpty)
Expand Down

0 comments on commit 493ec74

Please sign in to comment.