[SEDONA-262] Don't optimize equi-join by default, add an option to configure when to optimize spatial joins #797

Merged 1 commit on Mar 13, 2023
SpatialJoinOptimizationMode.java (new file)
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sedona.core.enums;

public enum SpatialJoinOptimizationMode {
/**
* Don't optimize spatial joins; leave them as they are (Cartesian join or broadcast nested loop join).
*/
NONE,

/**
* Optimize all spatial joins, even if the join is an equi-join. For example, for a range join like this:
* <p>{@code SELECT * FROM A, B WHERE A.x = B.x AND ST_Contains(A.geom, B.geom)}
* <p>The join will still be optimized to a spatial range join.
*/
ALL,

/**
* Optimize spatial joins that are not equi-joins. This is the default mode.
* <p>For example, for a range join like this:
* <p>{@code SELECT * FROM A, B WHERE A.x = B.x AND ST_Contains(A.geom, B.geom)}
* <p>It won't be optimized as a spatial join, since it is an equi-join (with equi-condition: {@code A.x = B.x}), and
* could be executed by a sort-merge join or hash join.
*/
NONEQUI;

public static SpatialJoinOptimizationMode getSpatialJoinOptimizationMode(String str) {
for (SpatialJoinOptimizationMode me : SpatialJoinOptimizationMode.values()) {
if (me.name().equalsIgnoreCase(str)) { return me; }
}
return null;
}
}
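As a quick sanity check of the lookup helper above, a minimal Scala sketch (assuming only the Sedona core module on the classpath; the object name is illustrative):

```scala
import org.apache.sedona.core.enums.SpatialJoinOptimizationMode

object OptimizationModeDemo extends App {
  // Matching is case-insensitive, so any casing of the enum name resolves.
  println(SpatialJoinOptimizationMode.getSpatialJoinOptimizationMode("NonEqui")) // NONEQUI
  println(SpatialJoinOptimizationMode.getSpatialJoinOptimizationMode("ALL"))     // ALL
  // Unrecognized values fall through the loop and yield null,
  // so callers should be prepared to handle a missing mode.
  println(SpatialJoinOptimizationMode.getSpatialJoinOptimizationMode("bogus"))   // null
}
```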
SedonaConf.java
@@ -23,6 +23,7 @@
import org.apache.sedona.core.enums.IndexType;
import org.apache.sedona.core.enums.JoinBuildSide;
import org.apache.sedona.core.enums.JoinSparitionDominantSide;
import org.apache.sedona.core.enums.SpatialJoinOptimizationMode;
import org.apache.spark.sql.RuntimeConfig;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
@@ -57,6 +58,8 @@ public class SedonaConf

private long autoBroadcastJoinThreshold;

private SpatialJoinOptimizationMode spatialJoinOptimizationMode;

public static SedonaConf fromActiveSession() {
return new SedonaConf(SparkSession.active().conf());
}
@@ -78,6 +81,8 @@ public SedonaConf(RuntimeConfig runtimeConfig)
runtimeConfig.get("spark.sql.autoBroadcastJoinThreshold")
)
);
this.spatialJoinOptimizationMode = SpatialJoinOptimizationMode.getSpatialJoinOptimizationMode(
runtimeConfig.get("sedona.join.optimizationmode", "nonequi"));
}

public boolean getUseIndex()
@@ -153,4 +158,8 @@ static long bytesFromString(String str) {
return Utils.byteStringAsBytes(str);
}
}

public SpatialJoinOptimizationMode getSpatialJoinOptimizationMode() {
return spatialJoinOptimizationMode;
}
}
7 changes: 7 additions & 0 deletions docs/api/sql/Parameter.md
@@ -49,3 +49,10 @@ sparkSession.conf.set("sedona.global.index","false")
* The dominant side in spatial partitioning stage
* Default: left
* Possible values: left, right
* sedona.join.optimizationmode **(Advanced users only!)**
* When should Sedona optimize spatial join SQL queries
* Default: nonequi
* Possible values:
* all: Always optimize spatial join queries, even for equi-joins.
* none: Disable optimization for spatial joins.
* nonequi: Optimize spatial join queries that are not equi-joins.
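For example, the mode can be toggled at runtime like the other parameters on this page; a short sketch (assuming an active `sparkSession`):

```scala
// Always plan spatial joins with Sedona's optimized join operators,
// even when an equality condition such as A.x = B.x is also present.
sparkSession.conf.set("sedona.join.optimizationmode", "all")

// Restore the default: only joins without an equi-condition are optimized.
sparkSession.conf.set("sedona.join.optimizationmode", "nonequi")

// The effective mode can be read back through SedonaConf.
val mode = org.apache.sedona.core.utils.SedonaConf.fromActiveSession()
  .getSpatialJoinOptimizationMode
```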
ExpressionUtils.scala (new file)
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.sedona_sql.optimization

import org.apache.spark.sql.catalyst.expressions.{And, Expression}

/**
* This class contains helper methods for transforming catalyst expressions.
*/
object ExpressionUtils {
/**
* This is a polyfill for running on Spark 3.0 while compiling against Spark 3.3. We'd really like to mix in
* `PredicateHelper` here, but the class hierarchy of `PredicateHelper` changed between Spark 3.0 and 3.3, so
* it would raise `java.lang.ClassNotFoundException: org.apache.spark.sql.catalyst.expressions.AliasHelper`
* at runtime on Spark 3.0.
*
* @param condition filter condition to split
* @return A list of conjunctive conditions
*/
def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
condition match {
case And(cond1, cond2) =>
splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2)
case other => other :: Nil
}
}
}
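To illustrate the helper, a minimal sketch (assuming spark-catalyst and this module on the classpath; the literal-based condition is just a stand-in for real column references):

```scala
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, GreaterThan, Literal}
import org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.splitConjunctivePredicates

val a = EqualTo(Literal(1), Literal(1))
val b = GreaterThan(Literal(2), Literal(1))
val c = EqualTo(Literal(3), Literal(3))

// a AND (b AND c) is recursively flattened into its conjuncts.
assert(splitConjunctivePredicates(And(a, And(b, c))) == Seq(a, b, c))

// Anything that is not an And is returned as a single-element list.
assert(splitConjunctivePredicates(a) == Seq(a))
```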
SpatialFilterPushDownForGeoParquet.scala
@@ -56,6 +56,7 @@ import org.apache.spark.sql.sedona_sql.expressions.ST_OrderingEquals
import org.apache.spark.sql.sedona_sql.expressions.ST_Overlaps
import org.apache.spark.sql.sedona_sql.expressions.ST_Touches
import org.apache.spark.sql.sedona_sql.expressions.ST_Within
import org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.splitConjunctivePredicates
import org.apache.spark.sql.types.DoubleType
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.geom.Point
@@ -171,20 +172,4 @@ class SpatialFilterPushDownForGeoParquet(sparkSession: SparkSession) extends Rule
case _ => None
}
}

-/**
- * This is a polyfill for running on Spark 3.0 while compiling against Spark 3.3. We'd really like to mixin
- * `PredicateHelper` here, but the class hierarchy of `PredicateHelper` has changed between Spark 3.0 and 3.3 so
- * it would raise `java.lang.ClassNotFoundException: org.apache.spark.sql.catalyst.expressions.AliasHelper`
- * at runtime on Spark 3.0.
- * @param condition filter condition to split
- * @return A list of conjunctive conditions
- */
-private def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
-  condition match {
-    case And(cond1, cond2) =>
-      splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2)
-    case other => other :: Nil
-  }
-}
}
JoinQueryDetector.scala
@@ -18,15 +18,16 @@
*/
package org.apache.spark.sql.sedona_sql.strategy.join

-import org.apache.sedona.core.enums.IndexType
+import org.apache.sedona.core.enums.{IndexType, SpatialJoinOptimizationMode}
import org.apache.sedona.core.spatialOperator.SpatialPredicate
import org.apache.sedona.core.utils.SedonaConf
import org.apache.spark.sql.{SparkSession, Strategy}
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan, LessThanOrEqual}
+import org.apache.spark.sql.catalyst.expressions.{And, EqualNullSafe, EqualTo, Expression, LessThan, LessThanOrEqual}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, Inner, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, NaturalJoin, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.sedona_sql.expressions._
import org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.splitConjunctivePredicates


case class JoinQueryDetection(
@@ -44,7 +45,7 @@ case class JoinQueryDetection(
* and ST_Intersects(a, b).
*
* Plans `DistanceJoinExec` for inner joins on spatial relationship ST_Distance(a, b) < r.
*
* Plans `BroadcastIndexJoinExec` for inner joins on spatial relationships with a broadcast hint.
*/
class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
@@ -78,7 +79,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
}

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-case Join(left, right, joinType, condition, JoinHint(leftHint, rightHint)) => {
+case Join(left, right, joinType, condition, JoinHint(leftHint, rightHint)) if optimizationEnabled(left, right, condition) => {
var broadcastLeft = leftHint.exists(_.strategy.contains(BROADCAST))
var broadcastRight = rightHint.exists(_.strategy.contains(BROADCAST))

@@ -144,7 +145,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
planSpatialJoin(left, right, Seq(leftShape, rightShape), joinType, spatialPredicate, extraCondition)
case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, Some(distance))) =>
planDistanceJoin(left, right, Seq(leftShape, rightShape), joinType, distance, spatialPredicate, extraCondition)
case None =>
Nil
}
}
@@ -153,6 +154,16 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
Nil
}

private def optimizationEnabled(left: LogicalPlan, right: LogicalPlan, condition: Option[Expression]): Boolean = {
val sedonaConf = new SedonaConf(sparkSession.conf)
sedonaConf.getSpatialJoinOptimizationMode match {
case SpatialJoinOptimizationMode.NONE => false
case SpatialJoinOptimizationMode.ALL => true
case SpatialJoinOptimizationMode.NONEQUI => !isEquiJoin(left, right, condition)
case mode => throw new IllegalArgumentException(s"Unknown spatial join optimization mode: $mode")
}
}

private def canAutoBroadcastBySize(plan: LogicalPlan) =
plan.stats.sizeInBytes != 0 && plan.stats.sizeInBytes <= SedonaConf.fromActiveSession.getAutoBroadcastJoinThreshold

@@ -161,7 +172,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
* map to the output of the specified plan.
*/
private def matches(expr: Expression, plan: LogicalPlan): Boolean =
-expr.references.nonEmpty && expr.references.forall(plan.outputSet.contains(_))
+expr.references.nonEmpty && expr.references.subsetOf(plan.outputSet)

private def matchExpressionsToPlans(exprA: Expression,
exprB: Expression,
Expand Down Expand Up @@ -327,4 +338,26 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
Nil
}
}

/**
* Checks whether the given condition makes the join between the given plans an equi-join. This method
* replicates the logic of [[org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys.unapply]] but does
* not populate the join keys.
*
* @param left left side of the join
* @param right right side of the join
* @param condition join condition
* @return true if the condition is an equi-join between the given plans
*/
private def isEquiJoin(left: LogicalPlan, right: LogicalPlan, condition: Option[Expression]): Boolean = {
val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
predicates.exists {
case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false
case EqualTo(l, r) if matches(l, left) && matches(r, right) => true
case EqualTo(l, r) if matches(l, right) && matches(r, left) => true
case EqualNullSafe(l, r) if matches(l, left) && matches(r, right) => true
case EqualNullSafe(l, r) if matches(l, right) && matches(r, left) => true
case _ => false
}
}
}
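To make the classification concrete, a small sketch of the underlying catalyst checks (hypothetical attribute and relation stand-ins, assuming spark-catalyst on the classpath; `isEquiJoin` itself is private, so the sketch reproduces its tests directly):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.IntegerType

// Stand-ins for tables A and B from the javadoc example.
val ax = AttributeReference("x", IntegerType)()
val bx = AttributeReference("x", IntegerType)()
val left = LocalRelation(ax)
val right = LocalRelation(bx)

// A.x = B.x: each side references exactly one side of the join, so this
// conjunct marks the join as an equi-join; under NONEQUI mode the spatial
// optimization is therefore skipped.
val equi = EqualTo(ax, bx)
assert(equi.left.references.subsetOf(left.outputSet) &&
  equi.right.references.subsetOf(right.outputSet))

// A.x = 1: one side has no references, so it is only a filter and does
// not make the join an equi-join.
val filterOnly = EqualTo(ax, Literal(1))
assert(filterOnly.right.references.isEmpty)
```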
51 changes: 51 additions & 0 deletions sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.sedona_sql.expressions.st_constructors.ST_GeomFromText
import org.apache.spark.sql.sedona_sql.strategy.join.{BroadcastIndexJoinExec, DistanceJoinExec, RangeJoinExec}
import org.apache.spark.sql.types.IntegerType
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.io.WKTReader
@@ -140,6 +141,47 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
}
}

describe("Spatial join in Sedona SQL should be configurable using sedona.join.optimizationmode") {
it("Optimize all spatial joins when sedona.join.optimizationmode = all") {
withOptimizationMode("all") {
prepareTempViewsForTestData()
val df = sparkSession.sql("SELECT df1.id, df2.id FROM df1 JOIN df2 ON df1.id = df2.id AND ST_Intersects(df1.geom, df2.geom)")
assert(isUsingOptimizedSpatialJoin(df))
val expectedResult = buildExpectedResult("ST_Intersects(df1.geom, df2.geom)")
.filter { case (id1, id2) => id1 == id2 }
verifyResult(expectedResult, df)
}
}

it("Only optimize non-equi-joins when sedona.join.optimizationmode = nonequi") {
withOptimizationMode("nonequi") {
prepareTempViewsForTestData()
val df = sparkSession.sql("SELECT df1.id, df2.id FROM df1 JOIN df2 ON ST_Intersects(df1.geom, df2.geom)")
assert(isUsingOptimizedSpatialJoin(df))
val df2 = sparkSession.sql("SELECT df1.id, df2.id FROM df1 JOIN df2 ON df1.id = df2.id AND ST_Intersects(df1.geom, df2.geom)")
assert(!isUsingOptimizedSpatialJoin(df2))
}
}

it("Won't optimize spatial joins when sedona.join.optimizationmode = none") {
withOptimizationMode("none") {
prepareTempViewsForTestData()
val df = sparkSession.sql("SELECT df1.id, df2.id FROM df1 JOIN df2 ON ST_Intersects(df1.geom, df2.geom)")
assert(!isUsingOptimizedSpatialJoin(df))
}
}
}

private def withOptimizationMode(mode: String)(body: => Unit): Unit = {
val oldOptimizationMode = sparkSession.conf.get("sedona.join.optimizationmode", "nonequi")
try {
sparkSession.conf.set("sedona.join.optimizationmode", mode)
body
} finally {
sparkSession.conf.set("sedona.join.optimizationmode", oldOptimizationMode)
}
}

private def prepareTempViewsForTestData(): (DataFrame, DataFrame) = {
val df1 = sparkSession.read.format("csv").option("header", "false").option("delimiter", testDataDelimiter)
.load(spatialJoinLeftInputLocation)
@@ -207,8 +249,17 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
}

def verifyResult(expected: Seq[(Int, Int)], result: DataFrame): Unit = {
isUsingOptimizedSpatialJoin(result)
val actual = result.collect().map(row => (row.getInt(0), row.getInt(1))).sorted
assert(actual.nonEmpty)
assert(actual === expected)
}

def isUsingOptimizedSpatialJoin(df: DataFrame): Boolean = {
df.queryExecution.executedPlan.collect {
case _: BroadcastIndexJoinExec |
_: DistanceJoinExec |
_: RangeJoinExec => true
}.nonEmpty
}
}