Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-516] Add RS_Interpolate #1282

Merged
merged 11 commits into from
Mar 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.sedona.common.raster;

import org.apache.sedona.common.FunctionsGeoTools;
import org.apache.sedona.common.utils.RasterInterpolate;
import org.apache.sedona.common.utils.RasterUtils;
import org.geotools.coverage.CoverageFactoryFinder;
import org.geotools.coverage.GridSampleDimension;
Expand All @@ -30,6 +31,7 @@
import org.geotools.geometry.Envelope2D;
import org.geotools.referencing.crs.DefaultEngineeringCRS;
import org.geotools.referencing.operation.transform.AffineTransform2D;
import org.locationtech.jts.index.strtree.STRtree;
import org.opengis.coverage.grid.GridCoverage;
import org.opengis.coverage.grid.GridGeometry;
import org.opengis.metadata.spatial.PixelOrientation;
Expand All @@ -48,7 +50,6 @@
import java.util.Map;
import java.util.Objects;

import static java.lang.Double.NaN;
import static org.apache.sedona.common.raster.MapAlgebra.addBandFromArray;
import static org.apache.sedona.common.raster.MapAlgebra.bandAsArray;

Expand Down Expand Up @@ -434,4 +435,87 @@ private static double castRasterDataType(double value, int dataType) {
}
}

public static GridCoverage2D interpolate(GridCoverage2D inputRaster) throws IllegalArgumentException{
return interpolate(inputRaster, 2.0, "fixed", null, null, null);
}

public static GridCoverage2D interpolate(GridCoverage2D inputRaster, Double power) throws IllegalArgumentException{
return interpolate(inputRaster, power, "fixed", null, null, null);
}

public static GridCoverage2D interpolate(GridCoverage2D inputRaster, Double power, String mode) throws IllegalArgumentException{
return interpolate(inputRaster, power, mode, null, null, null);
}

public static GridCoverage2D interpolate(GridCoverage2D inputRaster, Double power, String mode, Double numPointsOrRadius) throws IllegalArgumentException{
return interpolate(inputRaster, power, mode, numPointsOrRadius, null, null);
}

public static GridCoverage2D interpolate(GridCoverage2D inputRaster, Double power, String mode, Double numPointsOrRadius, Double maxRadiusOrMinPoints) throws IllegalArgumentException{
return interpolate(inputRaster, power, mode, numPointsOrRadius, maxRadiusOrMinPoints, null);
}

public static GridCoverage2D interpolate(GridCoverage2D inputRaster, Double power, String mode, Double numPointsOrRadius, Double maxRadiusOrMinPoints, Integer band) throws IllegalArgumentException {
if (!mode.equalsIgnoreCase("variable") && !mode.equalsIgnoreCase("fixed")) {
throw new IllegalArgumentException("Invalid 'mode': '" + mode + "'. Expected one of: 'Variable', 'Fixed'.");
}

Raster rasterData = inputRaster.getRenderedImage().getData();
WritableRaster raster = rasterData.createCompatibleWritableRaster(RasterAccessors.getWidth(inputRaster), RasterAccessors.getHeight(inputRaster));
int width = raster.getWidth();
int height = raster.getHeight();
int numBands = raster.getNumBands();
GridSampleDimension [] gridSampleDimensions = inputRaster.getSampleDimensions();

if (band != null && (band < 1 || band > numBands)) {
throw new IllegalArgumentException("Band index out of range.");
}

// Interpolation for each band
for (int bandIndex=0; bandIndex < numBands; bandIndex++) {
if (band == null || bandIndex == band - 1) {
// Generate STRtree
STRtree strtree = RasterInterpolate.generateSTRtree(inputRaster, bandIndex);
Double noDataValue = RasterUtils.getNoDataValue(inputRaster.getSampleDimension(bandIndex));
int countNoDataValues = 0;

// Skip band if STRtree is empty or has all valid data pixels
if (strtree.isEmpty() || strtree.size() == width*height) {
continue;
}

if (mode.equalsIgnoreCase("variable") && strtree.size() < numPointsOrRadius) {
throw new IllegalArgumentException("Parameter 'numPoints' is larger than no. of valid pixels in band "+bandIndex+". Please choose an appropriate value");
}

// Perform interpolation
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
double value = rasterData.getSampleDouble(x, y, bandIndex);
if (Double.isNaN(value) || value == noDataValue) {
countNoDataValues ++;
double interpolatedValue = RasterInterpolate.interpolateIDW(x, y, strtree, width, height, power, mode, numPointsOrRadius, maxRadiusOrMinPoints);
interpolatedValue = (Double.isNaN(interpolatedValue)) ? noDataValue:interpolatedValue;
if (interpolatedValue != noDataValue) {
countNoDataValues --;
}
raster.setSample(x, y, bandIndex, interpolatedValue);
} else {
raster.setSample(x, y, bandIndex, value);
}
}
}

// If all noDataValues are interpolated, update band metadata (remove nodatavalue)
if (countNoDataValues == 0){
gridSampleDimensions[bandIndex] = RasterUtils.removeNoDataValue(inputRaster.getSampleDimension(bandIndex));
}
} else {
raster.setSamples(0, 0, raster.getWidth(), raster.getHeight(), band, rasterData.getSamples(0, 0, raster.getWidth(), raster.getHeight(), band, (double[]) null));
}
}

return RasterUtils.clone(raster, inputRaster.getGridGeometry(), gridSampleDimensions, inputRaster, null, true);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package org.apache.sedona.common.utils;

import org.geotools.coverage.grid.GridCoverage2D;
import org.locationtech.jts.geom.Coordinate;
import org.locationtech.jts.geom.Envelope;
import org.locationtech.jts.geom.GeometryFactory;
import org.locationtech.jts.geom.Point;
import org.locationtech.jts.index.strtree.STRtree;

import java.awt.image.Raster;
import java.util.*;

public class RasterInterpolate {
private RasterInterpolate() {}

public static STRtree generateSTRtree(GridCoverage2D inputRaster, int band) {
Raster rasterData = inputRaster.getRenderedImage().getData();
int width = rasterData.getWidth();
int height = rasterData.getHeight();
Double noDataValue = RasterUtils.getNoDataValue(inputRaster.getSampleDimension(band));
GeometryFactory geometryFactory = new GeometryFactory();
STRtree rtree = new STRtree();

for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
double value = rasterData.getSampleDouble(x, y, band);
if (!Double.isNaN(value) && value != noDataValue) {
Point point = geometryFactory.createPoint(new Coordinate(x, y));
RasterPoint rasterPoint = new RasterPoint(point, value, 0.0);
rtree.insert(new Envelope(point.getCoordinate()), rasterPoint);
}
}
}
rtree.build();
return rtree;
}

public static double interpolateIDW(int x, int y, STRtree strtree, int width, int height, double power, String mode, Double numPointsOrRadius, Double maxRadiusOrMinPoints) {
GeometryFactory geometryFactory = new GeometryFactory();
PriorityQueue<RasterPoint> minHeap = new PriorityQueue<>(Comparator.comparingDouble(RasterPoint::getDistance));

if (mode.equalsIgnoreCase("variable")) {
Double numPoints = (numPointsOrRadius==null) ? 12:numPointsOrRadius; // Default no. of points -> 12
Double maxRadius = (maxRadiusOrMinPoints==null) ? Math.sqrt((width*width)+(height*height)):maxRadiusOrMinPoints; // Default max radius -> diagonal of raster
List<RasterPoint> queryResult = strtree.query(new Envelope(x - maxRadius, x + maxRadius, y - maxRadius, y + maxRadius));
if (mode.equalsIgnoreCase("variable") && strtree.size() < numPointsOrRadius) {
throw new IllegalArgumentException("Parameter 'numPoints' defaulted to 12 which is larger than no. of valid pixels within the max search radius. Please choose an appropriate value");
}
for (RasterPoint rasterPoint : queryResult) {
if (numPoints<=0) {
break;
}
Point point = rasterPoint.getPoint();
double distance = point.distance(geometryFactory.createPoint(new Coordinate(x, y)));
rasterPoint.setDistance(distance);
minHeap.add(rasterPoint);
numPoints --;
}
} else if (mode.equalsIgnoreCase("fixed")) {
Double radius = (numPointsOrRadius==null) ? Math.sqrt((width*width)+(height*height)):numPointsOrRadius; // Default radius -> diagonal of raster
Double minPoints = (maxRadiusOrMinPoints==null) ? 0:maxRadiusOrMinPoints; // Default min no. of points -> 0
List<RasterPoint> queryResult = new ArrayList<>();
do {
queryResult.clear();
Envelope searchEnvelope = new Envelope(x - radius, x + radius, y - radius, y + radius);
queryResult = strtree.query(searchEnvelope);
// If minimum points requirement met, break the loop
if (queryResult.size() >= minPoints) {
break;
}
radius *= 1.5; // Increase radius by 50%
} while (true);

for (RasterPoint rasterPoint : queryResult) {
Point point = rasterPoint.getPoint();
double distance = point.distance(geometryFactory.createPoint(new Coordinate(x, y)));
if (distance <= 0 || distance > radius) {
continue;
}
rasterPoint.setDistance(distance);
minHeap.add(rasterPoint);
}
}

double numerator = 0.0;
double denominator = 0.0;

while (!minHeap.isEmpty()) {
RasterPoint rasterPoint = minHeap.poll();
double value = rasterPoint.getValue();
double distance = rasterPoint.getDistance();
double weight = 1.0 / Math.pow(distance, power);
numerator += weight * value;
denominator += weight;
}

double interpolatedValue = (denominator > 0 ? numerator / denominator : Double.NaN);
return interpolatedValue;
}

public static class RasterPoint {
private Point point; // JTS Point
private double value; // The associated value
private double distance; // Distance measure

public RasterPoint(Point point, double value, double distance) {
this.point = point;
this.value = value;
this.distance = distance;
}

public Point getPoint() {
return point;
}

public double getValue() {
return value;
}

public double getDistance() {
return distance;
}

public void setDistance(double distance) {
this.distance = distance;
}
}
}
Loading
Loading