[SEDONA-280] Add ST_GeometricMedian #831

Merged
merged 12 commits into from
May 18, 2023
160 changes: 155 additions & 5 deletions common/src/main/java/org/apache/sedona/common/Functions.java
@@ -25,6 +25,7 @@
import org.locationtech.jts.algorithm.MinimumBoundingCircle;
import org.locationtech.jts.algorithm.hull.ConcaveHull;
import org.locationtech.jts.geom.*;
import org.locationtech.jts.geom.impl.CoordinateArraySequence;
import org.locationtech.jts.geom.util.GeometryFixer;
import org.locationtech.jts.io.gml2.GMLWriter;
import org.locationtech.jts.io.kml.KMLWriter;
@@ -42,18 +43,18 @@
import org.opengis.referencing.operation.TransformException;
import org.wololo.jts2geojson.GeoJSONWriter;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.google.common.geometry.S2.DBL_EPSILON;


public class Functions {
private static final GeometryFactory GEOMETRY_FACTORY = new GeometryFactory();
private static Geometry EMPTY_POLYGON = GEOMETRY_FACTORY.createPolygon(null, null);
private static GeometryCollection EMPTY_GEOMETRY_COLLECTION = GEOMETRY_FACTORY.createGeometryCollection(null);
private static final double DEFAULT_TOLERANCE = 1e-6;
private static final int DEFAULT_MAX_ITER = 1000;

public static double area(Geometry geometry) {
return geometry.getArea();
@@ -730,4 +731,153 @@ public static Geometry collectionExtract(Geometry geometry) {
return GEOMETRY_FACTORY.createGeometryCollection();
}


// ported from https://github.com/postgis/postgis/blob/f6ed58d1fdc865d55d348212d02c11a10aeb2b30/liblwgeom/lwgeom_median.c
// geometry ST_GeometricMedian ( geometry g , float8 tolerance , int max_iter , boolean fail_if_not_converged );

private static double distance3d(Coordinate p1, Coordinate p2) {
double dx = p2.x - p1.x;
double dy = p2.y - p1.y;
double dz = p2.z - p1.z;
return Math.sqrt(dx * dx + dy * dy + dz * dz);
}

private static void distances(Coordinate curr, Coordinate[] points, double[] distances) {
for(int i = 0; i < points.length; i++) {
distances[i] = distance3d(curr, points[i]);
}
}

private static double iteratePoints(Coordinate curr, Coordinate[] points, double[] distances) {
Coordinate next = new Coordinate(0, 0, 0);
double delta = 0;
double denom = 0;
boolean hit = false;
distances(curr, points, distances);

for (int i = 0; i < points.length; i++) {
/* use a smaller epsilon here than the one in PostGIS's FP_IS_ZERO macro, otherwise the calculation may not converge */
double distance = distances[i];
if (distance > DBL_EPSILON) {
Coordinate coordinate = points[i];
next.x += coordinate.x / distance;
next.y += coordinate.y / distance;
next.z += coordinate.z / distance;
denom += 1.0 / distance;
} else {
hit = true;
}
}
/* negative weight shouldn't get here */
//assert(denom >= 0);

/* denom is zero for a multipoint made up of a single point, in which case we have already converged exactly */
if (denom > DBL_EPSILON) {
next.x /= denom;
next.y /= denom;
next.z /= denom;

/* If any of the intermediate points in the calculation is found in the
* set of input points, the standard Weiszfeld method gets stuck with a
* divide-by-zero.
*
* To get ourselves out of the hole, we follow an alternate procedure to
* get the next iteration, as described in:
*
* Vardi, Y. and Zhang, C. (2001) "A modified Weiszfeld algorithm for the
* Fermat-Weber location problem." Math. Program., Ser. A 90: 559-566.
* DOI 10.1007/s101070100222
*
* Available online at the time of this writing at
* http://www.stat.rutgers.edu/home/cunhui/papers/43.pdf
*/
if (hit) {
double dx = 0;
double dy = 0;
double dz = 0;
for (int i = 0; i < points.length; i++) {
double distance = distances[i];
if (distance > DBL_EPSILON) {
Coordinate coordinate = points[i];
dx += (coordinate.x - curr.x) / distance;
dy += (coordinate.y - curr.y) / distance;
dz += (coordinate.z - curr.z) / distance;
}
}
/* note: despite its name, dSqr holds the Euclidean norm of (dx, dy, dz), not its square */
double dSqr = Math.sqrt(dx*dx + dy*dy + dz*dz);
/* Avoid division by zero if the intermediate point is the median */
if (dSqr > DBL_EPSILON) {
double rInv = Math.max(0, 1.0 / dSqr);
next.x = (1.0 - rInv)*next.x + rInv*curr.x;
next.y = (1.0 - rInv)*next.y + rInv*curr.y;
next.z = (1.0 - rInv)*next.z + rInv*curr.z;
}
}
delta = distance3d(curr, next);
curr.x = next.x;
curr.y = next.y;
curr.z = next.z;
}
return delta;
}

private static Coordinate initGuess(Coordinate[] points) {
Coordinate guess = new Coordinate(0, 0, 0);
for (Coordinate point : points) {
guess.x += point.x / points.length;
guess.y += point.y / points.length;
guess.z += point.z / points.length;
}
return guess;
}

private static Coordinate[] extractCoordinates(Geometry geometry) {
Coordinate[] points = geometry.getCoordinates();
if(points.length == 0)
return points;
boolean is3d = !Double.isNaN(points[0].z);
Coordinate[] coordinates = new Coordinate[points.length];
for(int i = 0; i < points.length; i++) {
coordinates[i] = points[i].copy();
if(!is3d)
coordinates[i].z = 0.0;
}
return coordinates;
}

public static Geometry geometricMedian(Geometry geometry, double tolerance, int maxIter, boolean failIfNotConverged) throws Exception {
String geometryType = geometry.getGeometryType();
if(!(Geometry.TYPENAME_POINT.equals(geometryType) || Geometry.TYPENAME_MULTIPOINT.equals(geometryType))) {
throw new Exception("Unsupported geometry type: " + geometryType);
}
Coordinate[] coordinates = extractCoordinates(geometry);
if(coordinates.length == 0)
return new Point(null, GEOMETRY_FACTORY);
Coordinate median = initGuess(coordinates);
double delta = Double.MAX_VALUE;
double[] distances = new double[coordinates.length]; // preallocate to reduce gc pressure for large iterations
for(int i = 0; i < maxIter && delta > tolerance; i++)
delta = iteratePoints(median, coordinates, distances);
if (failIfNotConverged && delta > tolerance)
throw new Exception(String.format("Median failed to converge within %.1E after %d iterations.", tolerance, maxIter));
boolean is3d = !Double.isNaN(geometry.getCoordinate().z);
if(!is3d)
median.z = Double.NaN;
Point point = new Point(new CoordinateArraySequence(new Coordinate[]{median}), GEOMETRY_FACTORY);
point.setSRID(geometry.getSRID());
return point;
}

public static Geometry geometricMedian(Geometry geometry, double tolerance, int maxIter) throws Exception {
return geometricMedian(geometry, tolerance, maxIter, false);
}

public static Geometry geometricMedian(Geometry geometry, double tolerance) throws Exception {
return geometricMedian(geometry, tolerance, DEFAULT_MAX_ITER, false);
}

public static Geometry geometricMedian(Geometry geometry) throws Exception {
return geometricMedian(geometry, DEFAULT_TOLERANCE, DEFAULT_MAX_ITER, false);
}

}
@@ -14,22 +14,38 @@
package org.apache.sedona.common;

import com.google.common.geometry.S2CellId;
import org.apache.sedona.common.utils.GeomUtils;
import com.google.common.math.DoubleMath;
import org.apache.sedona.common.utils.S2Utils;
import org.junit.Test;
import org.locationtech.jts.geom.*;
import org.locationtech.jts.io.WKTReader;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.*;

public class FunctionsTest {
public static final GeometryFactory GEOMETRY_FACTORY = new GeometryFactory();

protected static final double FP_TOLERANCE = 1e-12;
protected static final CoordinateSequenceComparator COORDINATE_SEQUENCE_COMPARATOR = new CoordinateSequenceComparator(2){
@Override
protected int compareCoordinate(CoordinateSequence s1, CoordinateSequence s2, int i, int dimension) {
for (int d = 0; d < dimension; d++) {
double ord1 = s1.getOrdinate(i, d);
double ord2 = s2.getOrdinate(i, d);
int comp = DoubleMath.fuzzyCompare(ord1, ord2, FP_TOLERANCE);
if (comp != 0) return comp;
}
return 0;
}
};

private final WKTReader wktReader = new WKTReader();

private Coordinate[] coordArray(double... coordValues) {
Coordinate[] coords = new Coordinate[(int)(coordValues.length / 2)];
for (int i = 0; i < coordValues.length; i += 2) {
@@ -389,4 +405,40 @@ public void getGoogleS2CellIDsAllSameLevel() {
expects.add(10);
assertEquals(expects, levels);
}

@Test
public void geometricMedian() throws Exception {
MultiPoint multiPoint = GEOMETRY_FACTORY.createMultiPointFromCoords(
coordArray(1480,0, 620,0));
Geometry actual = Functions.geometricMedian(multiPoint);
Geometry expected = wktReader.read("POINT (1050 0)");
assertEquals(0, expected.compareTo(actual, COORDINATE_SEQUENCE_COMPARATOR));
}

@Test
public void geometricMedianTolerance() throws Exception {
MultiPoint multiPoint = GEOMETRY_FACTORY.createMultiPointFromCoords(
coordArray(0,0, 10,1, 5,1, 20,20));
Geometry actual = Functions.geometricMedian(multiPoint, 1e-15);
Geometry expected = wktReader.read("POINT (5 1)");
assertEquals(0, expected.compareTo(actual, COORDINATE_SEQUENCE_COMPARATOR));
}

@Test
public void geometricMedianUnsupported() {
LineString lineString = GEOMETRY_FACTORY.createLineString(
coordArray(1480,0, 620,0));
Exception e = assertThrows(Exception.class, () -> Functions.geometricMedian(lineString));
assertEquals("Unsupported geometry type: LineString", e.getMessage());
}

@Test
public void geometricMedianFailConverge() {
MultiPoint multiPoint = GEOMETRY_FACTORY.createMultiPointFromCoords(
coordArray(12,5, 62,7, 100,-1, 100,-5, 10,20, 105,-5));
Exception e = assertThrows(Exception.class,
() -> Functions.geometricMedian(multiPoint, 1e-6, 5, true));
assertEquals("Median failed to converge within 1.0E-06 after 5 iterations.", e.getMessage());
}

}
31 changes: 30 additions & 1 deletion docs/api/flink/Function.md
@@ -355,6 +355,36 @@ Result:
+-----------------------------+
```

## ST_GeometricMedian

Introduction: Computes the approximate geometric median of a Point or MultiPoint geometry using the Weiszfeld algorithm; other geometry types raise an error. The geometric median is a measure of centrality that is less sensitive to outlier points than the centroid.

The algorithm iterates until the distance moved between successive iterations falls below the supplied `tolerance`. If this has not happened after `maxIter` iterations, the function raises an error only when `failIfNotConverged` is `true`; with the default `false`, it returns the last estimate.

If no `tolerance` is provided, the default value `1e-6` is used.

Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter: integer, failIfNotConverged: boolean)`

Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter: integer)`

Format: `ST_GeometricMedian(geom: geometry, tolerance: float)`

Format: `ST_GeometricMedian(geom: geometry)`

Default parameters: `tolerance: 1e-6, maxIter: 1000, failIfNotConverged: false`

Since: `1.4.1`

Example:
```sql
SELECT ST_GeometricMedian(ST_GeomFromWKT('MULTIPOINT((0 0), (1 1), (2 2), (200 200))'))
```

Output:
```
POINT (1.9761550281255005 1.9761550281255005)
```
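
As a rough illustration of the Weiszfeld iteration named in the introduction, the following self-contained Java sketch computes a 2D estimate for the example input. It is not the Sedona implementation: the class and method names are hypothetical, it handles 2D only, and it simply skips input points that coincide with the current estimate rather than applying the Vardi-Zhang correction referenced in the Java source.

```java
import java.util.Arrays;

// Hypothetical illustration class; not part of Sedona or JTS.
class WeiszfeldSketch {

    /** Arithmetic mean of the points, used as the starting guess. */
    static double[] centroid(double[][] pts) {
        double[] c = new double[2];
        for (double[] p : pts) {
            c[0] += p[0] / pts.length;
            c[1] += p[1] / pts.length;
        }
        return c;
    }

    /** Plain Weiszfeld iteration: repeatedly re-weight each point by 1/distance. */
    static double[] median(double[][] pts, double tolerance, int maxIter) {
        double[] curr = centroid(pts);
        for (int iter = 0; iter < maxIter; iter++) {
            double sx = 0, sy = 0, denom = 0;
            for (double[] p : pts) {
                double d = Math.hypot(p[0] - curr[0], p[1] - curr[1]);
                if (d > 1e-12) { // assumption: skip coincident points to avoid dividing by zero
                    sx += p[0] / d;
                    sy += p[1] / d;
                    denom += 1.0 / d;
                }
            }
            if (denom == 0) {
                break; // every input point coincides with the current estimate
            }
            double nx = sx / denom;
            double ny = sy / denom;
            double delta = Math.hypot(nx - curr[0], ny - curr[1]);
            curr[0] = nx;
            curr[1] = ny;
            if (delta <= tolerance) {
                break; // the estimate moved less than the tolerance
            }
        }
        return curr;
    }

    public static void main(String[] args) {
        double[][] pts = {{0, 0}, {1, 1}, {2, 2}, {200, 200}};
        System.out.println("centroid = " + Arrays.toString(centroid(pts)));
        System.out.println("median   = " + Arrays.toString(median(pts, 1e-6, 1000)));
    }
}
```

For this input the centroid is (50.75, 50.75), pulled far toward the outlier at (200 200), while the median estimate stays between (1 1) and (2 2).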

## ST_GeometryN

Introduction: Return the 0-based Nth geometry if the geometry is a GEOMETRYCOLLECTION, (MULTI)POINT, (MULTI)LINESTRING, MULTICURVE or (MULTI)POLYGON. Otherwise, return null
@@ -927,4 +957,3 @@ SELECT ST_ZMin(ST_GeomFromText('LINESTRING(1 3 4, 5 6 7)'))
```

Output: `4.0`

31 changes: 31 additions & 0 deletions docs/api/sql/Function.md
@@ -542,6 +542,36 @@ Result:
+-----------------------------+
```

## ST_GeometricMedian

Introduction: Computes the approximate geometric median of a Point or MultiPoint geometry using the Weiszfeld algorithm; other geometry types raise an error. The geometric median is a measure of centrality that is less sensitive to outlier points than the centroid.

The algorithm iterates until the distance moved between successive iterations falls below the supplied `tolerance`. If this has not happened after `maxIter` iterations, the function raises an error only when `failIfNotConverged` is `true`; with the default `false`, it returns the last estimate.

If no `tolerance` is provided, the default value `1e-6` is used.

Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter: integer, failIfNotConverged: boolean)`

Format: `ST_GeometricMedian(geom: geometry, tolerance: float, maxIter: integer)`

Format: `ST_GeometricMedian(geom: geometry, tolerance: float)`

Format: `ST_GeometricMedian(geom: geometry)`

Default parameters: `tolerance: 1e-6, maxIter: 1000, failIfNotConverged: false`

Since: `1.4.1`

Example:
```sql
SELECT ST_GeometricMedian(ST_GeomFromWKT('MULTIPOINT((0 0), (1 1), (2 2), (200 200))'))
```

Output:
```
POINT (1.9761550281255005 1.9761550281255005)
```
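
The interaction between `tolerance`, `maxIter`, and `failIfNotConverged` can be sketched with a generic fixed-point loop. This is a minimal, hypothetical Java illustration rather than Sedona code: the class name, the `iterate` helper, and the one-dimensional halving step are all assumptions chosen to keep the example small.

```java
import java.util.function.DoubleUnaryOperator;

// Hypothetical illustration class; the names here are not Sedona API.
class ConvergenceSketch {

    /**
     * Applies `step` until the estimate moves by no more than `tolerance`
     * or `maxIter` iterations have run. Throws only when the caller asked
     * for a failure on non-convergence; otherwise returns the last estimate.
     */
    static double iterate(DoubleUnaryOperator step, double start,
                          double tolerance, int maxIter, boolean failIfNotConverged) {
        double curr = start;
        double delta = Double.MAX_VALUE;
        for (int i = 0; i < maxIter && delta > tolerance; i++) {
            double next = step.applyAsDouble(curr);
            delta = Math.abs(next - curr);
            curr = next;
        }
        if (failIfNotConverged && delta > tolerance) {
            throw new IllegalStateException(String.format(
                    "Failed to converge within %.1E after %d iterations.", tolerance, maxIter));
        }
        return curr;
    }

    public static void main(String[] args) {
        DoubleUnaryOperator halve = x -> x / 2; // contracts toward 0

        // Enough iterations: converges, so the flag has no effect.
        System.out.println(iterate(halve, 1.0, 1e-6, 1000, true));

        // Too few iterations, flag off: best-effort estimate is returned.
        System.out.println(iterate(halve, 1.0, 1e-6, 5, false));

        // Too few iterations, flag on: an error is raised instead.
        try {
            iterate(halve, 1.0, 1e-6, 5, true);
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

With the flag left at `false`, a caller cannot tell a converged result from a best-effort one, so setting `failIfNotConverged` to `true` may be worth considering whenever a small `maxIter` is supplied.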

## ST_GeometryN

Introduction: Return the 0-based Nth geometry if the geometry is a GEOMETRYCOLLECTION, (MULTI)POINT, (MULTI)LINESTRING, MULTICURVE or (MULTI)POLYGON. Otherwise, return null
@@ -1560,3 +1590,4 @@ SELECT ST_ZMin(ST_GeomFromText('LINESTRING(1 3 4, 5 6 7)'))
```

Output: `4.0`

3 changes: 2 additions & 1 deletion flink/src/main/java/org/apache/sedona/flink/Catalog.java
@@ -89,7 +89,8 @@ public static UserDefinedFunction[] getFuncs() {
new Functions.ST_SetPoint(),
new Functions.ST_LineFromMultiPoint(),
new Functions.ST_Split(),
new Functions.ST_S2CellIDs()
new Functions.ST_S2CellIDs(),
new Functions.ST_GeometricMedian()
};
}
