[SEDONA-221] Outer join throws NPE for null geometries. #749

Merged · 1 commit · Jan 20, 2023
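In short: before this patch, the streamed side of a broadcast index join carried each row as a JTS Geometry with the matching UnsafeRow tucked into getUserData, so a null geometry (e.g. from ST_GeomFromText(null)) hit a NullPointerException before an outer or anti join could emit its null-padded result. The diff below switches the stream to (Geometry, UnsafeRow) tuples and guards the index lookup. A minimal sketch of the guard pattern, assuming a JTS STRtree as the broadcast index (the helper name is ours, not the PR's):

import java.util.Collections
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.index.strtree.STRtree

// Null geometries yield no candidates instead of throwing an NPE.
def candidatesFor(geom: Geometry, index: STRtree): java.util.List[_] =
  if (geom == null) Collections.EMPTY_LIST
  else index.query(geom.getEnvelopeInternal)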
@@ -20,6 +20,7 @@ package org.apache.spark.sql.sedona_sql.strategy.join

import org.apache.sedona.core.spatialOperator.{SpatialPredicate, SpatialPredicateEvaluators}
import org.apache.sedona.core.spatialOperator.SpatialPredicateEvaluators.SpatialPredicateEvaluator
+import org.apache.sedona.sql.utils.GeometrySerializer

import scala.collection.JavaConverters._
import org.apache.spark.broadcast.Broadcast
@@ -29,12 +30,14 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, GenericInternalRow, JoinedRow, Predicate, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
import org.apache.spark.sql.sedona_sql.execution.SedonaBinaryExecNode
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.geom.prep.{PreparedGeometry, PreparedGeometryFactory}
import org.locationtech.jts.index.SpatialIndex

+import java.util.Collections
import scala.collection.mutable

case class BroadcastIndexJoinExec(
@@ -68,6 +71,10 @@ case class BroadcastIndexJoinExec(
}
}

+override lazy val metrics = Map(
+"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))


private val (streamed, broadcast) = indexBuildSide match {
case LeftSide => (right, left.asInstanceOf[SpatialIndexExec])
case RightSide => (left, right.asInstanceOf[SpatialIndexExec])
@@ -115,34 +122,34 @@ case class BroadcastIndexJoinExec(
SpatialPredicateEvaluators.create(SpatialPredicate.inverse(spatialPredicate))
}

-private def innerJoin(streamIter: Iterator[Geometry], index: Broadcast[SpatialIndex]): Iterator[InternalRow] = {
+private def innerJoin(streamIter: Iterator[(Geometry, UnsafeRow)], index: Broadcast[SpatialIndex]): Iterator[InternalRow] = {
val factory = new PreparedGeometryFactory()
val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
val joinedRow = new JoinedRow
-streamIter.flatMap { srow =>
-joinedRow.withLeft(srow.getUserData.asInstanceOf[UnsafeRow])
-index.value.query(srow.getEnvelopeInternal)
+streamIter.flatMap { case (geom, row) =>
+joinedRow.withLeft(row)
+index.value.query(geom.getEnvelopeInternal)
.iterator.asScala.asInstanceOf[Iterator[Geometry]]
-.filter(candidate => evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, { factory.create(candidate) }), srow))
+.filter(candidate => evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, { factory.create(candidate) }), geom))
.map(candidate => joinedRow.withRight(candidate.getUserData.asInstanceOf[UnsafeRow]))
.filter(boundCondition)
}
}

private def semiJoin(
-streamIter: Iterator[Geometry], index: Broadcast[SpatialIndex]
+streamIter: Iterator[(Geometry, UnsafeRow)], index: Broadcast[SpatialIndex]
): Iterator[InternalRow] = {
val factory = new PreparedGeometryFactory()
val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
val joinedRow = new JoinedRow
-streamIter.flatMap { srow =>
-val left = srow.getUserData.asInstanceOf[UnsafeRow]
+streamIter.flatMap { case (geom, row) =>
+val left = row
joinedRow.withLeft(left)
-val anyMatches = index.value.query(srow.getEnvelopeInternal)
+val anyMatches = index.value.query(geom.getEnvelopeInternal)
.iterator.asScala.asInstanceOf[Iterator[Geometry]]
.filter(candidate => evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, {
factory.create(candidate)
-}), srow))
+}), geom))
.map(candidate => joinedRow.withRight(candidate.getUserData.asInstanceOf[UnsafeRow]))
.exists(boundCondition)

@@ -155,19 +162,19 @@ case class BroadcastIndexJoinExec(
}

private def antiJoin(
-streamIter: Iterator[Geometry], index: Broadcast[SpatialIndex]
+streamIter: Iterator[(Geometry, UnsafeRow)], index: Broadcast[SpatialIndex]
): Iterator[InternalRow] = {
val factory = new PreparedGeometryFactory()
val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
val joinedRow = new JoinedRow
-streamIter.flatMap { srow =>
-val left = srow.getUserData.asInstanceOf[UnsafeRow]
-joinedRow.withLeft(left)
-val anyMatches = index.value.query(srow.getEnvelopeInternal)
+streamIter.flatMap { case (geom, row) =>
+val left = row
+joinedRow.withLeft(row)
+val anyMatches = (if (geom == null) Collections.EMPTY_LIST else index.value.query(geom.getEnvelopeInternal))
.iterator.asScala.asInstanceOf[Iterator[Geometry]]
.filter(candidate => evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, {
factory.create(candidate)
-}), srow))
+}), geom))
.map(candidate => joinedRow.withRight(candidate.getUserData.asInstanceOf[UnsafeRow]))
.exists(boundCondition)

@@ -180,20 +187,20 @@ case class BroadcastIndexJoinExec(
}

private def outerJoin(
-streamIter: Iterator[Geometry], index: Broadcast[SpatialIndex]
+streamIter: Iterator[(Geometry, UnsafeRow)], index: Broadcast[SpatialIndex]
): Iterator[InternalRow] = {
val factory = new PreparedGeometryFactory()
val preparedGeometries = new mutable.HashMap[Geometry, PreparedGeometry]
val joinedRow = new JoinedRow
val nullRow = new GenericInternalRow(broadcast.output.length)

-streamIter.flatMap { srow =>
-joinedRow.withLeft(srow.getUserData.asInstanceOf[UnsafeRow])
-val candidates = index.value.query(srow.getEnvelopeInternal)
+streamIter.flatMap { case (geom, row) =>
+joinedRow.withLeft(row)
+val candidates = (if (geom == null) Collections.EMPTY_LIST else index.value.query(geom.getEnvelopeInternal))
.iterator.asScala.asInstanceOf[Iterator[Geometry]]
.filter(candidate => evaluator.eval(preparedGeometries.getOrElseUpdate(candidate, {
factory.create(candidate)
-}), srow))
+}), geom))

new RowIterator {
private var found = false
@@ -218,20 +225,15 @@ case class BroadcastIndexJoinExec(
}

override protected def doExecute(): RDD[InternalRow] = {
+val numOutputRows = longMetric("numOutputRows")
val boundStreamShape = BindReferences.bindReference(streamShape, streamed.output)
val streamResultsRaw = streamed.execute().asInstanceOf[RDD[UnsafeRow]]

val broadcastIndex = broadcast.executeBroadcast[SpatialIndex]()

-// If there's a distance and the objects are being broadcast, we need to build the expanded envelope on the window stream side
-val streamShapes = distance match {
-case Some(distanceExpression) if indexBuildSide != windowJoinSide =>
-toExpandedEnvelopeRDD(streamResultsRaw, boundStreamShape, BindReferences.bindReference(distanceExpression, streamed.output))
-case _ =>
-toSpatialRDD(streamResultsRaw, boundStreamShape)
-}
+val streamShapes = createStreamShapes(streamResultsRaw, boundStreamShape)

-streamShapes.getRawSpatialRDD.rdd.mapPartitions { streamedIter =>
+streamShapes.mapPartitions { streamedIter =>
val joinedIter = joinType match {
case _: InnerLike =>
innerJoin(streamedIter, broadcastIndex)
@@ -248,11 +250,40 @@ case class BroadcastIndexJoinExec(

val resultProj = createResultProjection()
joinedIter.map { r =>
+numOutputRows += 1
resultProj(r)
}
}
}

+private def createStreamShapes(streamResultsRaw: RDD[UnsafeRow], boundStreamShape: Expression) = {
+// If there's a distance and the objects are being broadcast, we need to build the expanded envelope on the window stream side
+distance match {
+case Some(distanceExpression) if indexBuildSide != windowJoinSide =>
+streamResultsRaw.map(row => {
+val geom = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
+if (geom == null) {
+(null, row)
+} else {
+val geometry = GeometrySerializer.deserialize(geom)
+val radius = BindReferences.bindReference(distanceExpression, streamed.output).eval(row).asInstanceOf[Double]
+val envelope = geometry.getEnvelopeInternal
+envelope.expandBy(radius)
+(geometry.getFactory.toGeometry(envelope), row)
+}
+})
+case _ =>
+streamResultsRaw.map(row => {
+val geom = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
+if (geom == null) {
+(null, row)
+} else {
+(GeometrySerializer.deserialize(geom), row)
+}
+})
+}
+}

protected def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = {
copy(left = newLeft, right = newRight)
}
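A note on the design (our reading of the diff, not a comment from the PR): carrying the row alongside the geometry as a (Geometry, UnsafeRow) tuple, instead of stashing it in getUserData, lets a null shape flow all the way to the join that knows its semantics, where outer joins pad with nulls and anti joins keep the row. A rough sketch of that flow with simplified types (the helper and its parameters are hypothetical):

import org.locationtech.jts.geom.Geometry

// Outer-join-like handling of a possibly null streamed geometry.
def outerLike[Row](stream: Iterator[(Geometry, Row)],
                   matches: Geometry => Iterator[Row],
                   nullRow: Row): Iterator[(Row, Row)] =
  stream.flatMap { case (geom, left) =>
    val cands = if (geom == null) Iterator.empty else matches(geom)
    if (cands.hasNext) cands.map(right => (left, right))
    else Iterator((left, nullRow)) // null-padded row instead of an NPE
  }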
@@ -1408,7 +1408,7 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
}
}

describe("Sedona SQL automatic broadcast") {
describe("Sedona-SQL Automatic broadcast") {
it("Datasets smaller than threshold should be broadcasted") {
val polygonDf = buildPolygonDf.repartition(3).alias("polygon")
val pointDf = buildPointDf.repartition(5).alias("point")
@@ -1430,4 +1430,32 @@
assert(df.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 0)
}
}

describe("Sedona-SQL Broadcast join with null geometries") {
it("Left outer join with nulls on left side") {
import sparkSession.implicits._
val left = Seq(("1", "POINT(1 1)"), ("2", "POINT(1 1)"), ("3", "POINT(1 1)"), ("4", null))
.toDF("seq", "left_geom")
.withColumn("left_geom", expr("ST_GeomFromText(left_geom)"))
val right = Seq("POLYGON((2 0, 2 2, 0 2, 0 0, 2 0))")
.toDF("right_geom")
.withColumn("right_geom", expr("ST_GeomFromText(right_geom)"))
val result = left.join(broadcast(right), expr("ST_Intersects(left_geom, right_geom)"), "left")
assert(result.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
assert(result.count() == 4)
}

it("Left anti join with nulls on left side") {
import sparkSession.implicits._
val left = Seq(("1", "POINT(1 1)"), ("2", "POINT(1 1)"), ("3", "POINT(1 1)"), ("4", null))
.toDF("seq", "left_geom")
.withColumn("left_geom", expr("ST_GeomFromText(left_geom)"))
val right = Seq("POLYGON((2 0, 2 2, 0 2, 0 0, 2 0))")
.toDF("right_geom")
.withColumn("right_geom", expr("ST_GeomFromText(right_geom)"))
val result = left.join(broadcast(right), expr("ST_Intersects(left_geom, right_geom)"), "left_anti")
assert(result.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
assert(result.count() == 1)
}
}
}
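As a usage note (our example, built from the tests above, not part of the PR): three of the four left rows fall inside the polygon, so the left outer join returns the null-geometry row padded with nulls (4 rows total) and the left anti join keeps exactly that unmatched row (1 row), where both paths previously hit the NullPointerException this PR fixes. In a Sedona-registered session:

import org.apache.spark.sql.functions.{broadcast, expr}
import sparkSession.implicits._

val left = Seq(("1", "POINT(1 1)"), ("2", null)).toDF("seq", "left_geom")
  .withColumn("left_geom", expr("ST_GeomFromText(left_geom)"))
val right = Seq("POLYGON((2 0, 2 2, 0 2, 0 0, 2 0))").toDF("right_geom")
  .withColumn("right_geom", expr("ST_GeomFromText(right_geom)"))

// seq 2 comes back with a null right_geom instead of the query throwing.
left.join(broadcast(right), expr("ST_Intersects(left_geom, right_geom)"), "left").show()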