Skip to content

Commit

Permalink
[SEDONA-369] Add ST_DWithin (#1175)
Browse files Browse the repository at this point in the history
* Add ST_DWithin

* Add documentation for ST_DWithin

* Remove unwanted code

* removed null check test for ST_DWithin

* Fix EOF lint error

* Add explanation for ST_DWithin

* Remove CRS checking logic in ST_DWithin
  • Loading branch information
iGN5117 authored Jan 2, 2024
1 parent a7b6f6e commit 1cc4c82
Show file tree
Hide file tree
Showing 17 changed files with 241 additions and 19 deletions.
3 changes: 3 additions & 0 deletions common/src/main/java/org/apache/sedona/common/Predicates.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,7 @@ public static boolean disjoint(Geometry leftGeometry, Geometry rightGeometry) {
/**
 * Tests two geometries for exact equality by delegating to
 * {@code Geometry.equalsExact}.
 *
 * @param leftGeometry geometry on which {@code equalsExact} is invoked
 * @param rightGeometry geometry passed as the argument to {@code equalsExact}
 * @return the result of {@code leftGeometry.equalsExact(rightGeometry)}
 */
public static boolean orderingEquals(Geometry leftGeometry, Geometry rightGeometry) {
return leftGeometry.equalsExact(rightGeometry);
}
/**
 * Tests whether two geometries are within the given distance of each other by
 * delegating to {@code Geometry.isWithinDistance}.
 *
 * @param leftGeometry geometry on which {@code isWithinDistance} is invoked
 * @param rightGeometry geometry passed as the first argument to {@code isWithinDistance}
 * @param distance distance threshold forwarded unchanged to {@code isWithinDistance}
 * @return the result of {@code leftGeometry.isWithinDistance(rightGeometry, distance)}
 */
public static boolean dWithin(Geometry leftGeometry, Geometry rightGeometry, double distance) {
// NOTE(review): no null check is performed here; a null leftGeometry fails on this call.
return leftGeometry.isWithinDistance(rightGeometry, distance);
}
}
18 changes: 1 addition & 17 deletions common/src/test/java/org/apache/sedona/common/FunctionsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

import static org.junit.Assert.*;

public class FunctionsTest {
public class FunctionsTest extends TestBase {
public static final GeometryFactory GEOMETRY_FACTORY = new GeometryFactory();

protected static final double FP_TOLERANCE = 1e-12;
Expand All @@ -53,22 +53,6 @@ protected int compareCoordinate(CoordinateSequence s1, CoordinateSequence s2, in

private final WKTReader wktReader = new WKTReader();

/**
 * Packs a flat list of ordinates laid out as x0, y0, x1, y1, ... into an
 * array holding one 2D Coordinate per (x, y) pair.
 */
private Coordinate[] coordArray(double... coordValues) {
    final Coordinate[] pairs = new Coordinate[coordValues.length / 2];
    int offset = 0;
    while (offset < coordValues.length) {
        final double x = coordValues[offset];
        final double y = coordValues[offset + 1];
        pairs[offset / 2] = new Coordinate(x, y);
        offset += 2;
    }
    return pairs;
}

/**
 * Packs a flat list of ordinates laid out as x0, y0, z0, x1, y1, z1, ... into
 * an array holding one 3D Coordinate per (x, y, z) triple.
 */
private Coordinate[] coordArray3d(double... coordValues) {
    final Coordinate[] triples = new Coordinate[coordValues.length / 3];
    int offset = 0;
    while (offset < coordValues.length) {
        final double x = coordValues[offset];
        final double y = coordValues[offset + 1];
        final double z = coordValues[offset + 2];
        triples[offset / 3] = new Coordinate(x, y, z);
        offset += 3;
    }
    return triples;
}

@Test
public void asEWKT() throws Exception{
GeometryFactory geometryFactory = new GeometryFactory(new PrecisionModel(), 4236);
Expand Down
66 changes: 66 additions & 0 deletions common/src/test/java/org/apache/sedona/common/PredicatesTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sedona.common;

import org.junit.Test;
import org.locationtech.jts.geom.Coordinate;
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.geom.GeometryFactory;

import static org.junit.Assert.*;

/**
 * Unit tests for {@link Predicates#dWithin}.
 */
public class PredicatesTest extends TestBase {

    private static final GeometryFactory GEOMETRY_FACTORY = new GeometryFactory();

    /** Points (1, 1) and (2, 2) are sqrt(2) (about 1.414) apart, which is within 1.42. */
    @Test
    public void testDWithinSuccess() {
        Geometry first = GEOMETRY_FACTORY.createPoint(new Coordinate(1, 1));
        Geometry second = GEOMETRY_FACTORY.createPoint(new Coordinate(2, 2));
        assertTrue(Predicates.dWithin(first, second, 1.42));
    }

    /** The unit square ends at x = 1 and the other polygon starts at x = 3: a gap of 2 exceeds 1.2. */
    @Test
    public void testDWithinFailure() {
        Geometry unitSquare = GEOMETRY_FACTORY.createPolygon(coordArray(0, 0, 0, 1, 1, 1, 1, 0, 0, 0));
        Geometry farSquare = GEOMETRY_FACTORY.createPolygon(coordArray(3, 0, 3, 3, 6, 3, 6, 0, 3, 0));
        assertFalse(Predicates.dWithin(unitSquare, farSquare, 1.2));
    }

    /**
     * The far polygon alone is too distant, but the collection also contains the point
     * (1.1, 0), which lies 0.1 from the unit square, so the collection is within 1.2.
     */
    @Test
    public void testDWithinGeomCollection() {
        Geometry unitSquare = GEOMETRY_FACTORY.createPolygon(coordArray(0, 0, 0, 1, 1, 1, 1, 0, 0, 0));
        Geometry farSquare = GEOMETRY_FACTORY.createPolygon(coordArray(3, 0, 3, 3, 6, 3, 6, 0, 3, 0));
        Geometry nearPoint = GEOMETRY_FACTORY.createPoint(new Coordinate(1.1, 0));
        Geometry[] members = {farSquare, nearPoint};
        Geometry collection = GEOMETRY_FACTORY.createGeometryCollection(members);
        assertTrue(Predicates.dWithin(unitSquare, collection, 1.2));
    }
}
39 changes: 39 additions & 0 deletions common/src/test/java/org/apache/sedona/common/TestBase.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sedona.common;

import org.locationtech.jts.geom.Coordinate;

/**
 * Base class providing coordinate-array helpers to the test classes that extend it.
 */
public class TestBase {
    /**
     * Builds a 2D coordinate array from a flat list of ordinates.
     *
     * @param coordValues ordinates laid out as {@code x0, y0, x1, y1, ...}
     * @return one {@link Coordinate} per (x, y) pair, in input order
     * @throws IllegalArgumentException if the number of values is not a multiple of 2
     */
    public Coordinate[] coordArray(double... coordValues) {
        requireMultipleOf(coordValues.length, 2);
        Coordinate[] coords = new Coordinate[coordValues.length / 2];
        for (int i = 0; i < coords.length; i++) {
            int base = 2 * i;
            coords[i] = new Coordinate(coordValues[base], coordValues[base + 1]);
        }
        return coords;
    }

    /**
     * Builds a 3D coordinate array from a flat list of ordinates.
     *
     * @param coordValues ordinates laid out as {@code x0, y0, z0, x1, y1, z1, ...}
     * @return one {@link Coordinate} per (x, y, z) triple, in input order
     * @throws IllegalArgumentException if the number of values is not a multiple of 3
     */
    public Coordinate[] coordArray3d(double... coordValues) {
        requireMultipleOf(coordValues.length, 3);
        Coordinate[] coords = new Coordinate[coordValues.length / 3];
        for (int i = 0; i < coords.length; i++) {
            int base = 3 * i;
            coords[i] = new Coordinate(coordValues[base], coordValues[base + 1], coordValues[base + 2]);
        }
        return coords;
    }

    // Fails fast with a readable message instead of an index error halfway through the loop.
    private static void requireMultipleOf(int valueCount, int dimension) {
        if (valueCount % dimension != 0) {
            throw new IllegalArgumentException(
                    "Expected a multiple of " + dimension + " ordinate values but got " + valueCount);
        }
    }
}
20 changes: 20 additions & 0 deletions docs/api/flink/Predicate.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,26 @@ Output:
true
```

## ST_DWithin

Introduction: Returns true if 'leftGeometry' and 'rightGeometry' are within a specified 'distance'. This function checks whether the shortest distance between the two geometries is <= the provided distance.

Format: `ST_DWithin (leftGeometry: Geometry, rightGeometry: Geometry, distance: Double)`

Since: `v1.5.1`

Example:

```sql
SELECT ST_DWithin(ST_GeomFromWKT('POINT (0 0)'), ST_GeomFromWKT('POINT (1 0)'), 2.5)
```

Output:

```
true
```

## ST_Equals

Introduction: Return true if A is equal to B
Expand Down
20 changes: 20 additions & 0 deletions docs/api/sql/Predicate.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,26 @@ Output:
true
```

## ST_DWithin

Introduction: Returns true if 'leftGeometry' and 'rightGeometry' are within a specified 'distance'. This function checks whether the shortest distance between the two geometries is <= the provided distance.

Format: `ST_DWithin (leftGeometry: Geometry, rightGeometry: Geometry, distance: Double)`

Since: `v1.5.1`

Spark SQL Example:

```sql
SELECT ST_DWithin(ST_GeomFromWKT('POINT (0 0)'), ST_GeomFromWKT('POINT (1 0)'), 2.5)
```

Output:

```
true
```

## ST_Equals

Introduction: Return true if A is equal to B
Expand Down
1 change: 1 addition & 0 deletions flink/src/main/java/org/apache/sedona/flink/Catalog.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ public static UserDefinedFunction[] getPredicates() {
new Predicates.ST_OrderingEquals(),
new Predicates.ST_Overlaps(),
new Predicates.ST_Touches(),
new Predicates.ST_DWithin()
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,20 @@ public Boolean eval(@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jt
return org.apache.sedona.common.Predicates.touches(geom1, geom2);
}
}

/**
 * Flink scalar function exposing {@code org.apache.sedona.common.Predicates.dWithin}.
 */
public static class ST_DWithin
        extends ScalarFunction {

    public ST_DWithin() {
    }

    /**
     * Casts both RAW arguments to JTS geometries and forwards them, together with the
     * distance, to {@code Predicates.dWithin}.
     *
     * @param o1 first geometry, received as a RAW value bridged to a JTS Geometry
     * @param o2 second geometry, received as a RAW value bridged to a JTS Geometry
     * @param distance distance threshold forwarded to {@code Predicates.dWithin}
     * @return the result of {@code Predicates.dWithin}
     */
    @DataTypeHint("Boolean")
    public Boolean eval(
            @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o1,
            @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o2,
            @DataTypeHint("Double") Double distance)
    {
        // NOTE(review): distance is unboxed for the call below, so a null value would raise
        // a NullPointerException - confirm how null arguments reach this function.
        return org.apache.sedona.common.Predicates.dWithin((Geometry) o1, (Geometry) o2, distance);
    }
}
}
23 changes: 23 additions & 0 deletions flink/src/test/java/org/apache/sedona/flink/PredicateTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,23 @@
*/
package org.apache.sedona.flink;

import org.apache.calcite.runtime.Geometries;
import org.apache.flink.table.api.Table;
import org.apache.sedona.common.utils.GeomUtils;
import org.apache.sedona.flink.expressions.Predicates;
import org.junit.BeforeClass;
import org.junit.Test;
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.geom.GeometryFactory;

import static org.junit.Assert.assertEquals;
import static org.apache.flink.table.api.Expressions.$;
import static org.apache.flink.table.api.Expressions.call;
import static org.junit.Assert.assertThrows;

public class PredicateTest extends TestBase{

private static final GeometryFactory GEOMETRY_FACTORY = new GeometryFactory();
@BeforeClass
public static void onceExecutedBeforeAll() {
initialize();
Expand Down Expand Up @@ -122,4 +129,20 @@ public void testTouches() {
Boolean actual = (Boolean) first(table).getField(0);
assertEquals(true, actual);
}

// Points (0 0) and (1 0) are 1 unit apart; the test expects a distance of 1 to match.
@Test
public void testDWithin() {
    Table points = tableEnv.sqlQuery("SELECT ST_GeomFromWKT('POINT (0 0)') as origin, ST_GeomFromWKT('POINT (1 0)') as p1");
    Table result = points.select(call(Predicates.ST_DWithin.class.getSimpleName(), $("origin"), $("p1"), 1));
    Boolean actual = (Boolean) first(result).getField(0);
    assertEquals(true, actual);
}

// Points (0 0) and (5 0) are 5 units apart; the test expects a distance of 2 not to match.
@Test
public void testDWithinFailure() {
    Table points = tableEnv.sqlQuery("SELECT ST_GeomFromWKT('POINT (0 0)') as origin, ST_GeomFromWKT('POINT (5 0)') as p1");
    Table result = points.select(call(Predicates.ST_DWithin.class.getSimpleName(), $("origin"), $("p1"), 2));
    Boolean actual = (Boolean) first(result).getField(0);
    assertEquals(false, actual);
}
}
14 changes: 14 additions & 0 deletions python/sedona/sql/st_predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from functools import partial

from pyspark.sql import Column
from typing import Union

from sedona.sql.dataframe_api import ColumnOrName, call_sedona_function, validate_argument_types

Expand All @@ -32,6 +33,7 @@
"ST_Overlaps",
"ST_Touches",
"ST_Within",
"ST_DWithin"
]


Expand Down Expand Up @@ -190,3 +192,15 @@ def ST_CoveredBy(a: ColumnOrName, b: ColumnOrName) -> Column:
:rtype: Column
"""
return _call_predicate_function("ST_CoveredBy", (a, b))

@validate_argument_types
def ST_DWithin(a: ColumnOrName, b: ColumnOrName, distance: Union[ColumnOrName, float]) -> Column:
    """
    Check if geometry a is within 'distance' units of geometry b

    :param a: Geometry column to check
    :type a: ColumnOrName
    :param b: Geometry column to check
    :type b: ColumnOrName
    :param distance: distance units to check the within predicate
    :type distance: Union[ColumnOrName, float]
    :return: True if a is within distance units of Geometry b
    :rtype: Column
    """
    return _call_predicate_function("ST_DWithin", (a, b, distance))
3 changes: 3 additions & 0 deletions python/tests/sql/test_dataframe_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@
(stp.ST_Within, (lambda: f.expr("ST_Point(0.0, 0.0)"), "geom"), "triangle_geom", "", False),
(stp.ST_Covers, ("geom", lambda: f.expr("ST_Point(0.0, 0.0)")), "triangle_geom", "", True),
(stp.ST_CoveredBy, (lambda: f.expr("ST_Point(0.0, 0.0)"), "geom"), "triangle_geom", "", True),
(stp.ST_DWithin, ("origin", "point", 5.0), "origin_and_point", "", True),

# aggregates
(sta.ST_Envelope_Aggr, ("geom",), "exploded_points", "", "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))"),
Expand Down Expand Up @@ -423,6 +424,8 @@ def base_df(self, request):
return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('GEOMETRYCOLLECTION(POINT(1 1), LINESTRING(0 0, 1 1, 2 2))') AS geom")
elif request.param == "point_and_line":
return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POINT (0.0 1.0)') AS point, ST_GeomFromWKT('LINESTRING (0 0, 1 0, 2 0, 3 0, 4 0, 5 0)') AS line")
elif request.param == "origin_and_point":
return TestDataFrameAPI.spark.sql("SELECT ST_GeomFromWKT('POINT (0 0)') AS origin, ST_GeomFromWKT('POINT (1 0)') as point")
raise ValueError(f"Invalid base_df name passed: {request.param}")

def _id_test_configuration(val):
Expand Down
6 changes: 6 additions & 0 deletions python/tests/sql/test_predicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,9 @@ def test_st_ordering_equals_ok(self):
assert order_equals.take(1)[0][0]
assert not not_order_equals_diff_geom.take(1)[0][0]
assert not not_order_equals_diff_order.take(1)[0][0]

def test_st_dwithin(self):
    # Points (0 0) and (2 0) are 2 units apart, so a distance of 3 is expected to match.
    points_df = self.spark.sql("select ST_GeomFromWKT('POINT (0 0)') as origin, ST_GeomFromWKT('POINT (2 0)') as point_1")
    points_df.createOrReplaceTempView("test_table")
    result_row = self.spark.sql("select ST_DWithin(origin, point_1, 3) from test_table").head()
    assert result_row[0] is True
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ object Catalog {
function[ST_Angle](),
function[ST_Degrees](),
function[ST_HausdorffDistance](-1),
function[ST_DWithin](),
// Expression for rasters
function[RS_NormalizedDifference](),
function[RS_Mean](),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression,
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType}
import org.locationtech.jts.geom.Geometry
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._

abstract class ST_Predicate extends Expression
with FoldableExpression
Expand Down Expand Up @@ -251,3 +252,11 @@ case class ST_OrderingEquals(inputExpressions: Seq[Expression])
copy(inputExpressions = newChildren)
}
}

/**
 * Expression for `ST_DWithin`, built by passing `Predicates.dWithin` to
 * `InferredExpression`.
 *
 * @param inputExpressions child expressions of this expression
 */
case class ST_DWithin(inputExpressions: Seq[Expression])
extends InferredExpression(Predicates.dWithin _) {

// Returns a copy of this expression with its children replaced by newChildren.
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,7 @@ object st_predicates extends DataFrameAPI {

def ST_CoveredBy(a: Column, b: Column): Column = wrapExpression[ST_CoveredBy](a, b)
def ST_CoveredBy(a: String, b: String): Column = wrapExpression[ST_CoveredBy](a, b)
// Overload taking both geometries and the distance as Columns.
def ST_DWithin(a: Column, b: Column, distance: Column): Column = wrapExpression[ST_DWithin](a, b, distance)

// Overload taking the geometries as Strings and the distance as a Double literal.
// NOTE(review): the strings appear to be column names, as in the DataFrame API test - confirm against wrapExpression.
def ST_DWithin(a: String, b: String, distance: Double): Column = wrapExpression[ST_DWithin](a, b, distance)
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.sedona_sql.expressions.st_aggregates._
import org.apache.spark.sql.sedona_sql.expressions.st_constructors._
import org.apache.spark.sql.sedona_sql.expressions.st_functions._
import org.apache.spark.sql.sedona_sql.expressions.st_predicates._
import org.junit.Assert.assertEquals
import org.junit.Assert.{assertEquals, assertTrue}
import org.locationtech.jts.geom.{Geometry, Polygon}
import org.locationtech.jts.io.WKTWriter
import org.locationtech.jts.operation.buffer.BufferParameters
Expand Down Expand Up @@ -1243,5 +1243,12 @@ class dataFrameAPITestScala extends TestBaseScala {
val actual = df.take(1)(0).get(0).asInstanceOf[String]
assert(expected == actual)
}

it("Passed ST_DWithin") {
  // Points (0 0) and (1 0) are 1 unit apart, so a distance of 2.0 is expected to match.
  val points = sparkSession.sql("SELECT ST_GeomFromWKT('POINT (0 0)') as origin, ST_GeomFromWKT('POINT (1 0)') as point")
  val firstRow = points.select(ST_DWithin("origin", "point", 2.0)).head()
  val withinDistance = firstRow.get(0).asInstanceOf[Boolean]
  assertTrue(withinDistance)
}
}
}
Loading

0 comments on commit 1cc4c82

Please sign in to comment.