Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-171] Add ST_SetPoint to Apache Sedona #694

Merged
merged 5 commits into from
Oct 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions common/src/main/java/org/apache/sedona/common/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,21 @@ public static Geometry removePoint(Geometry linestring, int position) {
return null;
}

public static Geometry setPoint(Geometry linestring, int position, Geometry point) {
if (linestring instanceof LineString) {
List<Coordinate> coordinates = new ArrayList<>(Arrays.asList(linestring.getCoordinates()));
if (-coordinates.size() <= position && position < coordinates.size()) {
if (position < 0) {
coordinates.set(coordinates.size() + position, point.getCoordinate());
} else {
coordinates.set(position, point.getCoordinate());
}
return GEOMETRY_FACTORY.createLineString(coordinates.toArray(new Coordinate[0]));
}
}
return null;
}

public static Geometry lineFromMultiPoint(Geometry geometry) {
if(!(geometry instanceof MultiPoint)) {
return null;
Expand Down
24 changes: 24 additions & 0 deletions docs/api/flink/Function.md
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,30 @@ SELECT ST_RemovePoint(ST_GeomFromText("LINESTRING(0 0, 1 1, 1 0)"), 1)

Output: `LINESTRING(0 0, 1 0)`

## ST_SetPoint

Introduction: Replace Nth point of linestring with given point. Index is 0-based. Negative index are counted backwards, e.g., -1 is last point.

Format: `ST_SetPoint (linestring: geometry, index: integer, point: geometry)`

Since: `v1.3.0`

Example:

```SQL
SELECT ST_SetPoint(ST_GeomFromText('LINESTRING (0 0, 0 1, 1 1)'), 2, ST_GeomFromText('POINT (1 0)')) AS geom
```

Result:

```
+--------------------------------+
| geom |
+--------------------------------+
| LINESTRING (0 0, 0 1, 1 0) |
+--------------------------------+
```

## ST_SetSRID

Introduction: Sets the spatial refence system identifier (SRID) of the geometry.
Expand Down
24 changes: 24 additions & 0 deletions docs/api/sql/Function.md
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,30 @@ Result:
+---------------------------------------------------------------+
```

## ST_SetPoint

Introduction: Replace Nth point of linestring with given point. Index is 0-based. Negative index are counted backwards, e.g., -1 is last point.

Format: `ST_SetPoint (linestring: geometry, index: integer, point: geometry)`

Since: `v1.3.0`

Example:

```SQL
SELECT ST_SetPoint(ST_GeomFromText('LINESTRING (0 0, 0 1, 1 1)'), 2, ST_GeomFromText('POINT (1 0)')) AS geom
```

Result:

```
+--------------------------+
|geom |
+--------------------------+
|LINESTRING (0 0, 0 1, 1 0)|
+--------------------------+
```

## ST_SetSRID

Introduction: Sets the spatial refence system identifier (SRID) of the geometry.
Expand Down
1 change: 1 addition & 0 deletions flink/src/main/java/org/apache/sedona/flink/Catalog.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ public static UserDefinedFunction[] getFuncs() {
new Functions.ST_Normalize(),
new Functions.ST_AddPoint(),
new Functions.ST_RemovePoint(),
new Functions.ST_SetPoint(),
new Functions.ST_LineFromMultiPoint(),
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,16 @@ public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.j
}
}

public static class ST_SetPoint extends ScalarFunction {
@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o1, int position,
@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o2) {
Geometry linestring = (Geometry) o1;
Geometry point = (Geometry) o2;
return org.apache.sedona.common.Functions.setPoint(linestring, position, point);
}
}

public static class ST_LineFromMultiPoint extends ScalarFunction {
@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o) {
Expand Down
14 changes: 11 additions & 3 deletions flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -435,28 +435,36 @@ public void testNormalize() {
public void testAddPoint() {
Table pointTable = tableEnv.sqlQuery("SELECT ST_AddPoint(ST_GeomFromWKT('LINESTRING (0 0, 1 1)'), ST_GeomFromWKT('POINT (2 2)'))");
assertEquals("LINESTRING (0 0, 1 1, 2 2)", first(pointTable).getField(0).toString());

}

@Test
public void testAddPointWithIndex() {
Table pointTable = tableEnv.sqlQuery("SELECT ST_AddPoint(ST_GeomFromWKT('LINESTRING (0 0, 1 1)'), ST_GeomFromWKT('POINT (2 2)'), 1)");
assertEquals("LINESTRING (0 0, 2 2, 1 1)", first(pointTable).getField(0).toString());

}

@Test
public void testRemovePoint() {
Table pointTable = tableEnv.sqlQuery("SELECT ST_RemovePoint(ST_GeomFromWKT('LINESTRING (0 0, 1 1, 2 2)'))");
assertEquals("LINESTRING (0 0, 1 1)", first(pointTable).getField(0).toString());

}

@Test
public void testRemovePointWithIndex() {
Table pointTable = tableEnv.sqlQuery("SELECT ST_RemovePoint(ST_GeomFromWKT('LINESTRING (0 0, 1 1, 2 2)'), 1)");
assertEquals("LINESTRING (0 0, 2 2)", first(pointTable).getField(0).toString());
}

@Test
public void testSetPoint() {
Table pointTable = tableEnv.sqlQuery("SELECT ST_SetPoint(ST_GeomFromWKT('LINESTRING (0 0, 1 1, 2 2)'), 0, ST_GeomFromWKT('POINT (3 3)'))");
assertEquals("LINESTRING (3 3, 1 1, 2 2)", first(pointTable).getField(0).toString());
}

@Test
public void testSetPointWithNegativeIndex() {
Table pointTable = tableEnv.sqlQuery("SELECT ST_SetPoint(ST_GeomFromWKT('LINESTRING (0 0, 1 1, 2 2)'), -1, ST_GeomFromWKT('POINT (3 3)'))");
assertEquals("LINESTRING (0 0, 1 1, 3 3)", first(pointTable).getField(0).toString());
}

@Test
Expand Down
17 changes: 17 additions & 0 deletions python/sedona/sql/st_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"ST_PrecisionReduce",
"ST_RemovePoint",
"ST_Reverse",
"ST_SetPoint",
"ST_SetSRID",
"ST_SRID",
"ST_StartPoint",
Expand Down Expand Up @@ -827,6 +828,22 @@ def ST_Reverse(geometry: ColumnOrName) -> Column:
return _call_st_function("ST_Reverse", geometry)


@validate_argument_types
def ST_SetPoint(line_string: ColumnOrName, index: Union[ColumnOrName, int], point: ColumnOrName) -> Column:
"""Replace a point in a linestring.

:param line_string: Linestring geometry column which contains the point to be replaced.
:type line_string: ColumnOrName
:param index: Index for the point to be replaced, 0-based, negative values start from the end so -1 is the last point.
:type index: Union[ColumnOrName, int]
:param point: Point geometry column to be newly set.
:type point: ColumnOrName
:return: Linestring geometry column with the replaced point, or null if the index is out of bounds.
:rtype: Column
"""
return _call_st_function("ST_SetPoint", (line_string, index, point))


@validate_argument_types
def ST_SetSRID(geometry: ColumnOrName, srid: Union[ColumnOrName, int]) -> Column:
"""Set the SRID for geometry.
Expand Down
4 changes: 4 additions & 0 deletions python/tests/sql/test_dataframe_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
(stf.ST_PrecisionReduce, ("geom", 1), "precision_reduce_point", "", "POINT (0.1 0.2)"),
(stf.ST_RemovePoint, ("line", 1), "linestring_geom", "", "LINESTRING (0 0, 2 0, 3 0, 4 0, 5 0)"),
(stf.ST_Reverse, ("line",), "linestring_geom", "", "LINESTRING (5 0, 4 0, 3 0, 2 0, 1 0, 0 0)"),
(stf.ST_SetPoint, ("line", 1, lambda: f.expr("ST_Point(1.0, 1.0)")), "linestring_geom", "", "LINESTRING (0 0, 1 1, 2 0, 3 0, 4 0, 5 0)"),
(stf.ST_SetSRID, ("point", 3021), "point_geom", "ST_SRID(geom)", 3021),
(stf.ST_SimplifyPreserveTopology, ("geom", 0.2), "0.9_poly", "", "POLYGON ((0 0, 1 0, 1 1, 0 0))"),
(stf.ST_SRID, ("point",), "point_geom", "", 0),
Expand Down Expand Up @@ -226,6 +227,9 @@
(stf.ST_RemovePoint, ("", None)),
(stf.ST_RemovePoint, ("", 1.0)),
(stf.ST_Reverse, (None,)),
(stf.ST_SetPoint, (None, 1, "")),
(stf.ST_SetPoint, ("", None, "")),
(stf.ST_SetPoint, ("", 1, None)),
(stf.ST_SetSRID, (None, 3021)),
(stf.ST_SetSRID, ("", None)),
(stf.ST_SetSRID, ("", 3021.0)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ object Catalog {
ST_NumInteriorRings,
ST_AddPoint,
ST_RemovePoint,
ST_SetPoint,
ST_IsRing,
ST_FlipCoordinates,
ST_LineSubstring,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,14 @@ case class ST_RemovePoint(inputExpressions: Seq[Expression])
}
}

case class ST_SetPoint(inputExpressions: Seq[Expression])
extends InferredTernaryExpression(Functions.setPoint) {

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
}

case class ST_IsRing(inputExpressions: Seq[Expression])
extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,40 @@ abstract class InferredBinaryExpression[A1: InferrableType, A2: InferrableType,
}
}
}

abstract class InferredTernaryExpression[A1: InferrableType, A2: InferrableType, A3: InferrableType, R: InferrableType]
(f: (A1, A2, A3) => R)
(implicit val a1Tag: TypeTag[A1], implicit val a2Tag: TypeTag[A2], implicit val a3Tag: TypeTag[A3], implicit val rTag: TypeTag[R])
extends Expression with ImplicitCastInputTypes with CodegenFallback with Serializable {
import InferredTypes._

def inputExpressions: Seq[Expression]
assert(inputExpressions.length == 3)

override def children: Seq[Expression] = inputExpressions

override def toString: String = s" **${getClass.getName}** "

override def inputTypes: Seq[AbstractDataType] = Seq(inferSparkType[A1], inferSparkType[A2], inferSparkType[A3])

override def nullable: Boolean = true

override def dataType = inferSparkType[R]

lazy val extractFirst = buildExtractor[A1](inputExpressions(0))
lazy val extractSecond = buildExtractor[A2](inputExpressions(1))
lazy val extractThird = buildExtractor[A3](inputExpressions(2))

lazy val serialize = buildSerializer[R]

override def eval(input: InternalRow): Any = {
val first = extractFirst(input)
val second = extractSecond(input)
val third = extractThird(input)
if (first != null && second != null && third != null) {
serialize(f(first, second, third))
} else {
null
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ object st_functions extends DataFrameAPI {
def ST_Reverse(geometry: Column): Column = wrapExpression[ST_Reverse](geometry)
def ST_Reverse(geometry: String): Column = wrapExpression[ST_Reverse](geometry)

def ST_SetPoint(lineString: Column, index: Column, point: Column): Column = wrapExpression[ST_SetPoint](lineString, index, point)
def ST_SetPoint(lineString: String, index: Int, point: String): Column = wrapExpression[ST_SetPoint](lineString, index, point)

def ST_SetSRID(geometry: Column, srid: Column): Column = wrapExpression[ST_SetSRID](geometry, srid)
def ST_SetSRID(geometry: String, srid: Int): Column = wrapExpression[ST_SetSRID](geometry, srid)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,14 @@ class dataFrameAPITestScala extends TestBaseScala {
assert(actualResult == expectedResult)
}

it("Passed ST_SetPoint") {
val baseDf = sparkSession.sql("SELECT ST_GeomFromWKT('LINESTRING (0 0, 1 0)') AS line, ST_Point(1.0, 1.0) AS point")
val df = baseDf.select(ST_SetPoint("line", 1, "point"))
val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText()
val expectedResult = "LINESTRING (0 0, 1 1)"
assert(actualResult == expectedResult)
}

it("Passed ST_IsRing") {
val baseDf = sparkSession.sql("SELECT ST_GeomFromWKT('LINESTRING (0 0, 1 0, 1 1, 0 0)') AS geom")
val df = baseDf.select(ST_IsRing("geom"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,24 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
calculateStRemovePointOption("MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10, 40 40))", 3) shouldBe None
}

it("Should correctly set using ST_SetPoint") {
calculateStSetPointOption("Linestring(0 0, 1 1, 1 0, 0 0)", 0, "Point(0 1)") shouldBe Some("LINESTRING (0 1, 1 1, 1 0, 0 0)")
calculateStSetPointOption("Linestring(0 0, 1 1, 1 0, 0 0)", 1, "Point(0 1)") shouldBe Some("LINESTRING (0 0, 0 1, 1 0, 0 0)")
calculateStSetPointOption("Linestring(0 0, 1 1, 1 0, 0 0)", 2, "Point(0 1)") shouldBe Some("LINESTRING (0 0, 1 1, 0 1, 0 0)")
calculateStSetPointOption("Linestring(0 0, 1 1, 1 0, 0 0)", 3, "Point(0 1)") shouldBe Some("LINESTRING (0 0, 1 1, 1 0, 0 1)")
calculateStSetPointOption("Linestring(0 0, 1 1, 1 0, 0 0)", 4, "Point(0 1)") shouldBe None
calculateStSetPointOption("Linestring(0 0, 1 1, 1 0, 0 0)", -1, "Point(0 1)") shouldBe Some("LINESTRING (0 0, 1 1, 1 0, 0 1)")
calculateStSetPointOption("Linestring(0 0, 1 1, 1 0, 0 0)", -2, "Point(0 1)") shouldBe Some("LINESTRING (0 0, 1 1, 0 1, 0 0)")
calculateStSetPointOption("Linestring(0 0, 1 1, 1 0, 0 0)", -3, "Point(0 1)") shouldBe Some("LINESTRING (0 0, 0 1, 1 0, 0 0)")
calculateStSetPointOption("Linestring(0 0, 1 1, 1 0, 0 0)", -4, "Point(0 1)") shouldBe Some("LINESTRING (0 1, 1 1, 1 0, 0 0)")
calculateStSetPointOption("Linestring(0 0, 1 1, 1 0, 0 0)", -5, "Point(0 1)") shouldBe None
calculateStSetPointOption("POINT(0 1)", 0, "Point(0 1)") shouldBe None
calculateStSetPointOption("POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), (1 1, 2 1, 2 2, 1 2, 1 1))", 0, "Point(0 1)") shouldBe None
calculateStSetPointOption("GEOMETRYCOLLECTION (POINT (40 10), LINESTRING (10 10, 20 20, 10 40))", 0, "Point(0 1)") shouldBe None
calculateStSetPointOption("MULTIPOLYGON (((30 20, 45 40, 10 40, 30 20)), ((15 5, 40 10, 10 20, 5 10, 15 5)))", 0, "Point(0 1)") shouldBe None
calculateStSetPointOption("MULTILINESTRING ((10 10, 20 20, 10 40, 10 10), (40 40, 30 30, 40 20, 30 10, 40 40))", 0, "Point(0 1)") shouldBe None
}

it("Should pass ST_IsRing") {
calculateStIsRing("LINESTRING(0 0, 0 1, 1 0, 1 1, 0 0)") shouldBe Some(false)
calculateStIsRing("LINESTRING(2 0, 2 2, 3 3)") shouldBe Some(false)
Expand Down Expand Up @@ -1152,6 +1170,15 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
.filter("geom is not null")
.selectExpr("ST_AsText(geom)").as[String].collect()

private def calculateStSetPointOption(wktA: String, index: Int, wktB: String): Option[String] =
calculateStSetPoint(wktA, index, wktB).headOption

private def calculateStSetPoint(wktA: String, index: Int, wktB: String): Array[String] =
Seq(Tuple3(wktReader.read(wktA), index, wktReader.read(wktB))).toDF("geomA", "index", "geomB")
.selectExpr(s"ST_SetPoint(geomA, index, geomB) as geom")
.filter("geom is not null")
.selectExpr("ST_AsText(geom)").as[String].collect()

it("Passed ST_NumGeometries") {
Given("Some different types of geometries in a DF")
// Test data
Expand Down