Skip to content

Commit

Permalink
[SEDONA-231] Redundant Serde Elimination (#792)
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdennis authored Mar 15, 2023
1 parent a7a2581 commit 7dfa2c3
Show file tree
Hide file tree
Showing 12 changed files with 243 additions and 44 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@
/.vscode/
.Rproj.user
__pycache__
/.bsp
/.scala-build
18 changes: 15 additions & 3 deletions common/src/main/java/org/apache/sedona/common/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ public static double azimuth(Geometry left, Geometry right) {
}

public static Geometry boundary(Geometry geometry) {
return geometry.getBoundary();
Geometry boundary = geometry.getBoundary();
if (boundary instanceof LinearRing) {
boundary = GEOMETRY_FACTORY.createLineString(boundary.getCoordinates());
}
return boundary;
}

public static Geometry buffer(Geometry geometry, double radius) {
Expand Down Expand Up @@ -236,7 +240,11 @@ public static Geometry interiorRingN(Geometry geometry, int n) {
if (geometry instanceof Polygon) {
Polygon polygon = (Polygon) geometry;
if (n < polygon.getNumInteriorRing()) {
return polygon.getInteriorRingN(n);
Geometry interiorRing = polygon.getInteriorRingN(n);
if (interiorRing instanceof LinearRing) {
interiorRing = GEOMETRY_FACTORY.createLineString(interiorRing.getCoordinates());
}
return interiorRing;
}
}
return null;
Expand All @@ -250,7 +258,11 @@ public static Geometry pointN(Geometry geometry, int n) {
}

public static Geometry exteriorRing(Geometry geometry) {
return GeomUtils.getExteriorRing(geometry);
Geometry ring = GeomUtils.getExteriorRing(geometry);
if (ring instanceof LinearRing) {
ring = GEOMETRY_FACTORY.createLineString(ring.getCoordinates());
}
return ring;
}

public static String asEWKT(Geometry geometry) {
Expand Down
14 changes: 7 additions & 7 deletions flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.junit.BeforeClass;
import org.junit.Test;
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.geom.LinearRing;
import org.locationtech.jts.geom.LineString;
import org.locationtech.jts.geom.Point;
import org.locationtech.jts.geom.Polygon;
import org.opengis.referencing.FactoryException;
Expand Down Expand Up @@ -62,7 +62,7 @@ public void testBoundary() {
Table polygonTable = tableEnv.sqlQuery("SELECT ST_GeomFromWKT('POLYGON ((1 1, 0 0, -1 1, 1 1))') AS geom");
Table boundaryTable = polygonTable.select(call(Functions.ST_Boundary.class.getSimpleName(), $("geom")));
Geometry result = (Geometry) first(boundaryTable).getField(0);
assertEquals("LINEARRING (1 1, 0 0, -1 1, 1 1)", result.toString());
assertEquals("LINESTRING (1 1, 0 0, -1 1, 1 1)", result.toString());
}

@Test
Expand Down Expand Up @@ -221,8 +221,8 @@ public void testGeometryN() {
public void testInteriorRingN() {
Table polygonTable = tableEnv.sqlQuery("SELECT ST_GeomFromText('POLYGON((7 9,8 7,11 6,15 8,16 6,17 7,17 10,18 12,17 14,15 15,11 15,10 13,9 12,7 9),(9 9,10 10,11 11,11 10,10 8,9 9),(12 14,15 14,13 11,12 14))') AS polygon");
Table resultTable = polygonTable.select(call(Functions.ST_InteriorRingN.class.getSimpleName(), $("polygon"), 1));
LinearRing linearRing = (LinearRing) first(resultTable).getField(0);
assertEquals("LINEARRING (12 14, 15 14, 13 11, 12 14)", linearRing.toString());
LineString lineString = (LineString) first(resultTable).getField(0);
assertEquals("LINESTRING (12 14, 15 14, 13 11, 12 14)", lineString.toString());
}

@Test
Expand Down Expand Up @@ -272,9 +272,9 @@ public void testNumInteriorRings() {
public void testExteriorRing() {
Table polygonTable = createPolygonTable(1);
Table linearRingTable = polygonTable.select(call(Functions.ST_ExteriorRing.class.getSimpleName(), $(polygonColNames[0])));
LinearRing linearRing = (LinearRing) first(linearRingTable).getField(0);
assertNotNull(linearRing);
Assert.assertEquals("LINEARRING (-0.5 -0.5, -0.5 0.5, 0.5 0.5, 0.5 -0.5, -0.5 -0.5)", linearRing.toString());
LineString lineString = (LineString) first(linearRingTable).getField(0);
assertNotNull(lineString);
Assert.assertEquals("LINESTRING (-0.5 -0.5, -0.5 0.5, 0.5 0.5, 0.5 -0.5, -0.5 -0.5)", lineString.toString());
}

@Test
Expand Down
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,12 @@
<artifactId>s2-geometry</artifactId>
<version>${googles2.version}</version>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<version>4.11.0</version>
<scope>test</scope>
</dependency>
</dependencies>
</dependencyManagement>
<repositories>
Expand Down
4 changes: 4 additions & 0 deletions sql/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.compat.version}</artifactId>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
</dependency>
</dependencies>
<build>
<sourceDirectory>src/main/scala</sourceDirectory>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,3 @@ class GeometryUDT extends UserDefinedType[Geometry] {
}

case object GeometryUDT extends org.apache.spark.sql.sedona_sql.UDT.GeometryUDT with scala.Serializable

Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,9 @@ case class ST_StartPoint(inputExpressions: Seq[Expression])

override protected def nullSafeEval(geometry: Geometry): Any = {
geometry match {
case line: LineString => line.getPointN(0).toGenericArrayData
case line: LineString => {
line.getPointN(0)
}
case _ => null
}
}
Expand All @@ -473,11 +475,23 @@ case class ST_Boundary(inputExpressions: Seq[Expression])


case class ST_MinimumBoundingRadius(inputExpressions: Seq[Expression])
extends UnaryGeometryExpression with FoldableExpression with CodegenFallback {
extends Expression with FoldableExpression with CodegenFallback {

override def nullable: Boolean = true

private val geometryFactory = new GeometryFactory()

override protected def nullSafeEval(geometry: Geometry): Any = {
getMinimumBoundingRadius(geometry)
/**
  * Evaluates the first input expression against `input` and computes the
  * minimum bounding radius of the resulting geometry.
  *
  * A child that is `SerdeAware` is asked for its unserialized value directly;
  * any other child is evaluated through `toGeometry`.
  *
  * @param input the row the child expression is evaluated against
  * @return the result of `getMinimumBoundingRadius`, or `null` when the child
  *         did not produce a `Geometry` (including when it produced `null`)
  */
override def eval(input: InternalRow): Any = {
  val child = inputExpressions(0)

  val evaluated = child match {
    case serdeAware: SerdeAware => serdeAware.evalWithoutSerialization(input)
    case _ => child.toGeometry(input)
  }

  evaluated match {
    case geom: Geometry => getMinimumBoundingRadius(geom)
    case _ => null
  }
}

private def getMinimumBoundingRadius(geom: Geometry): InternalRow = {
Expand Down Expand Up @@ -545,7 +559,7 @@ case class ST_EndPoint(inputExpressions: Seq[Expression])

override protected def nullSafeEval(geometry: Geometry): Any = {
geometry match {
case string: LineString => string.getEndPoint.toGenericArrayData
case string: LineString => string.getEndPoint
case _ => null
}
}
Expand Down Expand Up @@ -588,16 +602,24 @@ case class ST_Dump(inputExpressions: Seq[Expression])
extends UnaryGeometryExpression with FoldableExpression with CodegenFallback {

override protected def nullSafeEval(geometry: Geometry): Any = {
val geometryCollection = geometry match {
geometry match {
case collection: GeometryCollection => {
val numberOfGeometries = collection.getNumGeometries
(0 until numberOfGeometries).map(
index => collection.getGeometryN(index).toGenericArrayData
index => collection.getGeometryN(index)
).toArray
}
case geom: Geometry => Array(geom.toGenericArrayData)
case geom: Geometry => Array(geom)
}
}

override protected def serializeResult(result: Any): Any = {
result match {
case array: Array[Geometry] => ArrayData.toArrayData(
array.map(_.toGenericArrayData)
)
case _ => null
}
ArrayData.toArrayData(geometryCollection)
}

override def dataType: DataType = ArrayType(GeometryUDT)
Expand All @@ -613,7 +635,17 @@ case class ST_DumpPoints(inputExpressions: Seq[Expression])
extends UnaryGeometryExpression with FoldableExpression with CodegenFallback {

override protected def nullSafeEval(geometry: Geometry): Any = {
ArrayData.toArrayData(geometry.getPoints.map(geom => geom.toGenericArrayData))
geometry.getPoints.map(geom => geom).toArray
}

/**
  * Serializes the array produced by `nullSafeEval`.
  *
  * Each geometry in the array is converted with `toGenericArrayData` and the
  * converted elements are wrapped with `ArrayData.toArrayData`.
  *
  * @param result the unserialized evaluation result
  * @return the wrapped array, or `null` when `result` is not an
  *         `Array[Geometry]`
  */
override protected def serializeResult(result: Any): Any = result match {
  case points: Array[Geometry] =>
    ArrayData.toArrayData(points.map(_.toGenericArrayData))
  case _ => null
}

override def dataType: DataType = ArrayType(GeometryUDT)
Expand Down Expand Up @@ -842,7 +874,7 @@ case class ST_SymDifference(inputExpressions: Seq[Expression])
extends BinaryGeometryExpression with FoldableExpression with CodegenFallback {

override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
leftGeometry.symDifference(rightGeometry).toGenericArrayData
leftGeometry.symDifference(rightGeometry)
}

override def dataType: DataType = GeometryUDT
Expand All @@ -863,7 +895,7 @@ case class ST_Union(inputExpressions: Seq[Expression])
extends BinaryGeometryExpression with FoldableExpression with CodegenFallback {

override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
leftGeometry.union(rightGeometry).toGenericArrayData
leftGeometry.union(rightGeometry)
}

override def dataType: DataType = GeometryUDT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,42 +38,81 @@ trait FoldableExpression extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
}

abstract class UnaryGeometryExpression extends Expression with ExpectsInputTypes {
abstract class UnaryGeometryExpression extends Expression with SerdeAware with ExpectsInputTypes {
def inputExpressions: Seq[Expression]

override def nullable: Boolean = true

override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT)

override def eval(input: InternalRow): Any = {
val geometry = inputExpressions.head.toGeometry(input)
val result = evalWithoutSerialization(input)
serializeResult(result)
}

/**
  * Evaluates this expression and returns the raw result of `nullSafeEval`
  * without passing it through `serializeResult`.
  *
  * The single input expression is evaluated first: a `SerdeAware` child is
  * asked for its own unserialized value, any other child goes through
  * `toGeometry`.
  *
  * @param input the row the child expression is evaluated against
  * @return the value returned by `nullSafeEval`, or `null` when the child did
  *         not yield a `Geometry` (including when it yielded `null`)
  */
override def evalWithoutSerialization(input: InternalRow): Any = {
  val evaluated = inputExpressions.head match {
    case serdeAware: SerdeAware => serdeAware.evalWithoutSerialization(input)
    case other: Any => other.toGeometry(input)
  }

  evaluated match {
    case geom: Geometry => nullSafeEval(geom)
    case _ => null
  }
}

/**
  * Converts an evaluation result into the value returned from `eval`.
  *
  * @param result the unserialized value produced by `evalWithoutSerialization`
  * @return `toGenericArrayData` of the geometry when `result` is a `Geometry`;
  *         otherwise `result` itself, unchanged
  */
protected def serializeResult(result: Any): Any = result match {
  case geom: Geometry => geom.toGenericArrayData
  case other => other
}

protected def nullSafeEval(geometry: Geometry): Any


}

abstract class BinaryGeometryExpression extends Expression with ExpectsInputTypes {
abstract class BinaryGeometryExpression extends Expression with SerdeAware with ExpectsInputTypes {
def inputExpressions: Seq[Expression]

override def nullable: Boolean = true

override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT, GeometryUDT)

override def eval(input: InternalRow): Any = {
val leftGeometry = inputExpressions(0).toGeometry(input)
val rightGeometry = inputExpressions(1).toGeometry(input)
val result = evalWithoutSerialization(input)
serializeResult(result)
}

/**
  * Evaluates this expression and returns the raw result of `nullSafeEval`
  * without passing it through `serializeResult`.
  *
  * Both input expressions are evaluated, left first and then right. A
  * `SerdeAware` child is asked for its own unserialized value; any other
  * child goes through `toGeometry`.
  *
  * @param input the row the child expressions are evaluated against
  * @return the value returned by `nullSafeEval`, or `null` unless both
  *         children yielded a `Geometry`
  */
override def evalWithoutSerialization(input: InternalRow): Any = {
  // Shared evaluation rule for either operand.
  def evaluateChild(child: Expression): Any = child match {
    case serdeAware: SerdeAware => serdeAware.evalWithoutSerialization(input)
    case _ => child.toGeometry(input)
  }

  val lhs = evaluateChild(inputExpressions(0))
  val rhs = evaluateChild(inputExpressions(1))

  (lhs, rhs) match {
    case (left: Geometry, right: Geometry) => nullSafeEval(left, right)
    case _ => null
  }
}

/**
  * Converts an evaluation result into the value returned from `eval`.
  *
  * @param result the unserialized value produced by `evalWithoutSerialization`
  * @return `toGenericArrayData` of the geometry when `result` is a `Geometry`;
  *         otherwise `result` itself, unchanged
  */
protected def serializeResult(result: Any): Any = result match {
  case geom: Geometry => geom.toGenericArrayData
  case other => other
}

protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any
}

Expand Down Expand Up @@ -168,7 +207,7 @@ object InferredTypes {
abstract class InferredUnaryExpression[A1: InferrableType, R: InferrableType]
(f: (A1) => R)
(implicit val a1Tag: TypeTag[A1], implicit val rTag: TypeTag[R])
extends Expression with ImplicitCastInputTypes with CodegenFallback with Serializable {
extends Expression with ImplicitCastInputTypes with SerdeAware with CodegenFallback with Serializable {
import InferredTypes._

def inputExpressions: Seq[Expression]
Expand All @@ -187,10 +226,12 @@ abstract class InferredUnaryExpression[A1: InferrableType, R: InferrableType]

lazy val serialize = buildSerializer[R]

override def eval(input: InternalRow): Any = {
override def eval(input: InternalRow): Any = serialize(evalWithoutSerialization(input).asInstanceOf[R])

override def evalWithoutSerialization(input: InternalRow): Any = {
val value = extract(input)
if (value != null) {
serialize(f(value))
f(value)
} else {
null
}
Expand All @@ -200,7 +241,7 @@ abstract class InferredUnaryExpression[A1: InferrableType, R: InferrableType]
abstract class InferredBinaryExpression[A1: InferrableType, A2: InferrableType, R: InferrableType]
(f: (A1, A2) => R)
(implicit val a1Tag: TypeTag[A1], implicit val a2Tag: TypeTag[A2], implicit val rTag: TypeTag[R])
extends Expression with ImplicitCastInputTypes with CodegenFallback with Serializable {
extends Expression with ImplicitCastInputTypes with SerdeAware with CodegenFallback with Serializable {
import InferredTypes._

def inputExpressions: Seq[Expression]
Expand All @@ -220,11 +261,13 @@ abstract class InferredBinaryExpression[A1: InferrableType, A2: InferrableType,

lazy val serialize = buildSerializer[R]

override def eval(input: InternalRow): Any = {
override def eval(input: InternalRow): Any = serialize(evalWithoutSerialization(input).asInstanceOf[R])

override def evalWithoutSerialization(input: InternalRow): Any = {
val left = extractLeft(input)
val right = extractRight(input)
if (left != null && right != null) {
serialize(f(left, right))
f(left, right)
} else {
null
}
Expand All @@ -234,7 +277,7 @@ abstract class InferredBinaryExpression[A1: InferrableType, A2: InferrableType,
abstract class InferredTernaryExpression[A1: InferrableType, A2: InferrableType, A3: InferrableType, R: InferrableType]
(f: (A1, A2, A3) => R)
(implicit val a1Tag: TypeTag[A1], implicit val a2Tag: TypeTag[A2], implicit val a3Tag: TypeTag[A3], implicit val rTag: TypeTag[R])
extends Expression with ImplicitCastInputTypes with CodegenFallback with Serializable {
extends Expression with ImplicitCastInputTypes with SerdeAware with CodegenFallback with Serializable {
import InferredTypes._

def inputExpressions: Seq[Expression]
Expand All @@ -255,12 +298,14 @@ abstract class InferredTernaryExpression[A1: InferrableType, A2: InferrableType,

lazy val serialize = buildSerializer[R]

override def eval(input: InternalRow): Any = {
override def eval(input: InternalRow): Any = serialize(evalWithoutSerialization(input).asInstanceOf[R])

override def evalWithoutSerialization(input: InternalRow): Any = {
val first = extractFirst(input)
val second = extractSecond(input)
val third = extractThird(input)
if (first != null && second != null && third != null) {
serialize(f(first, second, third))
f(first, second, third)
} else {
null
}
Expand Down
Loading

0 comments on commit 7dfa2c3

Please sign in to comment.