[SEDONA-217] Automatically broadcast small datasets. #730

Merged
merged 1 commit on Dec 13, 2022
21 changes: 21 additions & 0 deletions core/src/main/java/org/apache/sedona/core/utils/SedonaConf.java
@@ -25,6 +25,7 @@
import org.apache.sedona.core.enums.JoinSparitionDominantSide;
import org.apache.spark.sql.RuntimeConfig;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.util.Utils;
import org.locationtech.jts.geom.Envelope;

import java.io.Serializable;
@@ -54,6 +55,8 @@ public class SedonaConf

private GridType joinGridType;

private long autoBroadcastJoinThreshold;

public static SedonaConf fromActiveSession() {
return new SedonaConf(SparkSession.active().conf());
}
@@ -70,6 +73,11 @@ public SedonaConf(RuntimeConfig runtimeConfig)
this.joinBuildSide = JoinBuildSide.getBuildSide(runtimeConfig.get("sedona.join.indexbuildside", "left"));
this.joinSparitionDominantSide = JoinSparitionDominantSide.getJoinSparitionDominantSide(runtimeConfig.get("sedona.join.spatitionside", "left"));
this.fallbackPartitionNum = Integer.parseInt(runtimeConfig.get("sedona.join.numpartition", "-1"));
this.autoBroadcastJoinThreshold = bytesFromString(
runtimeConfig.get("sedona.join.autoBroadcastJoinThreshold",
runtimeConfig.get("spark.sql.autoBroadcastJoinThreshold")
)
);
}

public boolean getUseIndex()
@@ -113,6 +121,11 @@ public int getFallbackPartitionNum()
return fallbackPartitionNum;
}

public long getAutoBroadcastJoinThreshold()
{
return autoBroadcastJoinThreshold;
}

public String toString()
{
try {
@@ -132,4 +145,12 @@ public String toString()
return null;
}
}

static long bytesFromString(String str) {
if (str.startsWith("-")) {
return -1 * Utils.byteStringAsBytes(str.substring(1));
} else {
return Utils.byteStringAsBytes(str);
}
}
}
@@ -53,4 +53,11 @@ public void testDatasetBoundary() {
Envelope datasetBoundary = SedonaConf.fromActiveSession().getDatasetBoundary();
assertEquals("Env[1.0 : 2.0, 3.0 : 4.0]", datasetBoundary.toString());
}

@Test
public void testBytesFromString() {
assertEquals(-1, SedonaConf.bytesFromString("-1"));
assertEquals(1024, SedonaConf.bytesFromString("1k"));
assertEquals(2097152, SedonaConf.bytesFromString("2MB"));
}
}
7 changes: 6 additions & 1 deletion docs/api/sql/Parameter.md
@@ -28,6 +28,11 @@ sparkSession.conf.set("sedona.global.index","false")
* Spatial index type, only valid when "sedona.global.index" is true
* Default: quadtree
* Possible values: rtree, quadtree
* sedona.join.autoBroadcastJoinThreshold
* Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when performing a join.
Setting this value to -1 disables automatic broadcasting (see the example after this file's changes).
* Default: the same as spark.sql.autoBroadcastJoinThreshold
* Possible values: any integer with an optional byte suffix, e.g. 10MB or 512KB
* sedona.join.gridtype
* Spatial partitioning grid type for join query
* Default: kdbtree
@@ -43,4 +48,4 @@ sparkSession.conf.set("sedona.global.index","false")
* sedona.join.spatitionside **(Advanced users only!)**
* The dominant side in spatial partitioning stage
* Default: left
* Possible values: left, right
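A minimal sketch of setting this parameter from a Scala Spark session, mirroring the existing sedona.global.index snippet in the same document; the 10MB threshold below is only an illustrative value:

```scala
// Broadcast any join side smaller than 10MB automatically (illustrative threshold).
sparkSession.conf.set("sedona.join.autoBroadcastJoinThreshold", "10MB")

// Disable automatic broadcasting entirely; explicit broadcast hints still work.
sparkSession.conf.set("sedona.join.autoBroadcastJoinThreshold", "-1")
```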
@@ -78,12 +78,28 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
}

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case Join(left, right, joinType, condition, JoinHint(leftHint, rightHint)) => { // SPARK3 anchor
// case Join(left, right, Inner, condition) => { // SPARK2 anchor
val broadcastLeft = leftHint.exists(_.strategy.contains(BROADCAST)) // SPARK3 anchor
val broadcastRight = rightHint.exists(_.strategy.contains(BROADCAST)) // SPARK3 anchor
// val broadcastLeft = left.isInstanceOf[ResolvedHint] && left.asInstanceOf[ResolvedHint].hints.broadcast // SPARK2 anchor
// val broadcastRight = right.isInstanceOf[ResolvedHint] && right.asInstanceOf[ResolvedHint].hints.broadcast // SPARK2 anchor
case Join(left, right, joinType, condition, JoinHint(leftHint, rightHint)) => {
var broadcastLeft = leftHint.exists(_.strategy.contains(BROADCAST))
var broadcastRight = rightHint.exists(_.strategy.contains(BROADCAST))

/*
If either side is small enough, we can automatically broadcast it, just like Spark does.
This only applies to inner joins, as there is no optimized fallback plan for other join types.
It's better for users to be explicit about broadcasting for other join types than to see wildly
different behavior depending on data size.
*/
if (!broadcastLeft && !broadcastRight && joinType == Inner) {
val canAutoBroadCastLeft = canAutoBroadcastBySize(left)
val canAutoBroadCastRight = canAutoBroadcastBySize(right)
if (canAutoBroadCastLeft && canAutoBroadCastRight) {
// Both sides can be broadcasted. Choose the smallest side.
broadcastLeft = left.stats.sizeInBytes <= right.stats.sizeInBytes
broadcastRight = !broadcastLeft
} else {
broadcastLeft = canAutoBroadCastLeft
broadcastRight = canAutoBroadCastRight
}
}

val queryDetection: Option[JoinQueryDetection] = condition match {
case Some(predicate: ST_Predicate) =>
@@ -137,6 +153,9 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
Nil
}

private def canAutoBroadcastBySize(plan: LogicalPlan) =
plan.stats.sizeInBytes != 0 && plan.stats.sizeInBytes <= SedonaConf.fromActiveSession.getAutoBroadcastJoinThreshold

/**
* Returns true if specified expression has at least one reference and all its references
* map to the output of the specified plan.
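A standalone Scala sketch of the broadcast-side decision rule in the hunk above, using hypothetical plan sizes instead of real LogicalPlan statistics (chooseBroadcastSide and its arguments are illustrative, not Sedona API):

```scala
// Mirrors the rule above: a side qualifies if its size is non-zero and within the
// threshold; if both qualify, broadcast the smaller one.
def chooseBroadcastSide(leftSize: BigInt, rightSize: BigInt, threshold: Long): (Boolean, Boolean) = {
  def fits(size: BigInt): Boolean = size != 0 && size <= threshold
  (fits(leftSize), fits(rightSize)) match {
    case (true, true) => if (leftSize <= rightSize) (true, false) else (false, true)
    case other        => other
  }
}

// Example: a 5MB left side against a 2GB right side with a 10MB threshold
// broadcasts only the left side.
chooseBroadcastSide(BigInt(5L << 20), BigInt(2L << 30), 10L << 20) // (true, false)
```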
@@ -1408,4 +1408,26 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
}
}

describe("Sedona SQL automatic broadcast") {
it("Datasets smaller than threshold should be broadcasted") {
val polygonDf = buildPolygonDf.repartition(3).alias("polygon")
val pointDf = buildPointDf.repartition(5).alias("point")
val df = polygonDf.join(pointDf, expr("ST_Contains(polygon.polygonshape, point.pointshape)"))
sparkSession.conf.set("sedona.global.index", "true")
sparkSession.conf.set("sedona.join.autoBroadcastJoinThreshold", "10mb")

assert(df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
sparkSession.conf.set("sedona.join.autoBroadcastJoinThreshold", "-1")
}

it("Datasets larger than threshold should not be broadcasted") {
val polygonDf = buildPolygonDf.repartition(3).alias("polygon")
val pointDf = buildPointDf.repartition(5).alias("point")
val df = polygonDf.join(pointDf, expr("ST_Contains(polygon.polygonshape, point.pointshape)"))
sparkSession.conf.set("sedona.global.index", "true")
sparkSession.conf.set("sedona.join.autoBroadcastJoinThreshold", "-1")

assert(df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 0)
}
}
}
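For context alongside these tests, a hypothetical usage sketch (polygonDf and pointDf stand in for the aliased dataframes built above; the 10MB threshold is illustrative):

```scala
import org.apache.spark.sql.functions.{broadcast, expr}

// With the threshold set, a small side of an inner spatial join is broadcast automatically.
sparkSession.conf.set("sedona.join.autoBroadcastJoinThreshold", "10MB")
val autoJoined = polygonDf.join(pointDf,
  expr("ST_Contains(polygon.polygonshape, point.pointshape)"))

// Explicit broadcast hints keep working and remain the way to broadcast for
// non-inner join types, per the comment in JoinQueryDetector above.
val hinted = polygonDf.join(broadcast(pointDf),
  expr("ST_Contains(polygon.polygonshape, point.pointshape)"))
```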
@@ -37,6 +37,8 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
config("spark.kryo.registrator", classOf[SedonaKryoRegistrator].getName).
master("local[*]").appName("sedonasqlScalaTest")
.config("spark.sql.warehouse.dir", warehouseLocation)
// We need to be explicit about broadcasting in tests.
.config("sedona.join.autoBroadcastJoinThreshold", "-1")
.getOrCreate()

val resourceFolder = System.getProperty("user.dir") + "/../core/src/test/resources/"