Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-475] Add RS_NormalizeAll #1221

Merged
merged 9 commits into from
Feb 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,128 @@ public static double[] normalize(double[] band) {
return result;
}

public static GridCoverage2D normalizeAll(GridCoverage2D rasterGeom) {
return normalizeAll(rasterGeom, 0d, 255d, null, null, null, true);
}

public static GridCoverage2D normalizeAll(GridCoverage2D rasterGeom, double minLim, double maxLim) {
return normalizeAll(rasterGeom, minLim, maxLim, null, null, null, true);
}

public static GridCoverage2D normalizeAll(GridCoverage2D rasterGeom, double minLim, double maxLim, double noDataValue) {
return normalizeAll(rasterGeom, minLim, maxLim, noDataValue, null, null, true);
}

public static GridCoverage2D normalizeAll(GridCoverage2D rasterGeom, double minLim, double maxLim, Double noDataValue, boolean normalizeAcrossBands) {
return normalizeAll(rasterGeom, minLim, maxLim, noDataValue, null, null, normalizeAcrossBands);
}

public static GridCoverage2D normalizeAll(GridCoverage2D rasterGeom, double minLim, double maxLim, Double noDataValue, Double minValue, Double maxValue) {
return normalizeAll(rasterGeom, minLim, maxLim, noDataValue, minValue, maxValue, true);
}

/**
*
* @param rasterGeom Raster to be normalized
* @param minLim Lower limit of normalization range
* @param maxLim Upper limit of normalization range
* @param noDataValue NoDataValue used in raster
* @param minValue Minimum value in raster
* @param maxValue Maximum value in raster
* @param normalizeAcrossBands flag to determine the normalization method
* @return a raster with all values in all bands normalized between minLim and maxLim
*/
public static GridCoverage2D normalizeAll(GridCoverage2D rasterGeom, double minLim, double maxLim, Double noDataValue, Double minValue, Double maxValue, boolean normalizeAcrossBands) {
if (minLim > maxLim) {
throw new IllegalArgumentException("minLim cannot be greater than maxLim");
}

int numBands = rasterGeom.getNumSampleDimensions();
RenderedImage renderedImage = rasterGeom.getRenderedImage();
int rasterDataType = renderedImage.getSampleModel().getDataType();

double globalMin = minValue != null ? minValue : Double.MAX_VALUE;
double globalMax = maxValue != null ? maxValue : -Double.MAX_VALUE;

// Initialize arrays to store band-wise min and max values
double[] minValues = new double[numBands];
double[] maxValues = new double[numBands];
Arrays.fill(minValues, Double.MAX_VALUE);
Arrays.fill(maxValues, -Double.MAX_VALUE);

// Compute global min and max values across all bands if necessary and not provided
if (minValue == null || maxValue == null) {
for (int bandIndex = 0; bandIndex < numBands; bandIndex++) {
double[] bandValues = bandAsArray(rasterGeom, bandIndex + 1);
double bandNoDataValue = RasterUtils.getNoDataValue(rasterGeom.getSampleDimension(bandIndex));

if (noDataValue == null) {
noDataValue = maxLim;
}

for (double val : bandValues) {
if (val != bandNoDataValue) {
if (normalizeAcrossBands) {
globalMin = Math.min(globalMin, val);
globalMax = Math.max(globalMax, val);
} else {
minValues[bandIndex] = Math.min(minValues[bandIndex], val);
maxValues[bandIndex] = Math.max(maxValues[bandIndex], val);
}
}
}
}
} else {
globalMin = minValue;
globalMax = maxValue;
}

// Normalize each band
for (int bandIndex = 0; bandIndex < numBands; bandIndex++) {
double[] bandValues = bandAsArray(rasterGeom, bandIndex + 1);
double bandNoDataValue = RasterUtils.getNoDataValue(rasterGeom.getSampleDimension(bandIndex));
double currentMin = normalizeAcrossBands ? globalMin : (minValue != null ? minValue : minValues[bandIndex]);
double currentMax = normalizeAcrossBands ? globalMax : (maxValue != null ? maxValue : maxValues[bandIndex]);

if (Double.compare(currentMax, currentMin) == 0) {
Arrays.fill(bandValues, minLim);
} else {
for (int i = 0; i < bandValues.length; i++) {
if (bandValues[i] != bandNoDataValue) {
double normalizedValue = minLim + ((bandValues[i] - currentMin) * (maxLim - minLim)) / (currentMax - currentMin);
bandValues[i] = castRasterDataType(normalizedValue, rasterDataType);
} else {
bandValues[i] = noDataValue;
}
}
}

// Update the raster with the normalized band and noDataValue
rasterGeom = addBandFromArray(rasterGeom, bandValues, bandIndex+1);
rasterGeom = RasterBandEditors.setBandNoDataValue(rasterGeom, bandIndex+1, noDataValue);
}

return rasterGeom;
}

private static double castRasterDataType(double value, int dataType) {
switch (dataType) {
case DataBuffer.TYPE_BYTE:
return (byte) value;
case DataBuffer.TYPE_SHORT:
return (short) value;
case DataBuffer.TYPE_INT:
return (int) value;
case DataBuffer.TYPE_USHORT:
return (char) value;
case DataBuffer.TYPE_FLOAT:
return (float) value;
case DataBuffer.TYPE_DOUBLE:
default:
return value;
}
}

/**
* @param band1 band values
* @param band2 band values
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.opengis.referencing.FactoryException;

import java.awt.image.DataBuffer;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Random;

import static org.junit.Assert.*;
Expand Down Expand Up @@ -321,6 +323,122 @@ public void testNormalize() {
assertArrayEquals(expected, actual, 0.1d);
}

@Test
public void testNormalizeAll() throws FactoryException {
GridCoverage2D raster1 = RasterConstructors.makeEmptyRaster(2, 4, 4, 0, 0, 1);
GridCoverage2D raster2 = RasterConstructors.makeEmptyRaster(2, 4, 4, 0, 0, 1);
GridCoverage2D raster3 = RasterConstructors.makeEmptyRaster(2, "I", 4, 4, 0, 0, 1);
GridCoverage2D raster4 = RasterConstructors.makeEmptyRaster(2, 4, 4, 0, 0, 1);
GridCoverage2D raster5 = RasterConstructors.makeEmptyRaster(2, 4, 4, 0, 0, 1);

for (int band = 1; band <= 2; band++) {
double[] bandValues1 = new double[4 * 4];
double[] bandValues2 = new double[4 * 4];
double[] bandValues3 = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16};
double[] bandValues4 = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,0};
double[] bandValues5 = new double[4 * 4];
for (int i = 0; i < bandValues1.length; i++) {
bandValues1[i] = (i) * band;
bandValues2[i] = (1) * (band-1);
bandValues5[i] = i + ((band-1)*15);
}
raster1 = MapAlgebra.addBandFromArray(raster1, bandValues1, band);
raster2 = MapAlgebra.addBandFromArray(raster2, bandValues2, band);
raster3 = MapAlgebra.addBandFromArray(raster3, bandValues3, band);
raster4 = MapAlgebra.addBandFromArray(raster4, bandValues4, band);
raster4 = RasterBandEditors.setBandNoDataValue(raster4, band, 0.0);
raster5 = MapAlgebra.addBandFromArray(raster5, bandValues5, band);
}
raster3 = RasterBandEditors.setBandNoDataValue(raster3, 1, 16.0);
raster3 = RasterBandEditors.setBandNoDataValue(raster3, 2, 1.0);

GridCoverage2D normalizedRaster1 = MapAlgebra.normalizeAll(raster1, 0, 255, -9999.0, false);
GridCoverage2D normalizedRaster2 = MapAlgebra.normalizeAll(raster1, 256d, 511d, -9999.0, false);
GridCoverage2D normalizedRaster3 = MapAlgebra.normalizeAll(raster2);
GridCoverage2D normalizedRaster4 = MapAlgebra.normalizeAll(raster3, 0, 255, 95.0);
GridCoverage2D normalizedRaster5 = MapAlgebra.normalizeAll(raster4, 0, 255);
GridCoverage2D normalizedRaster6 = MapAlgebra.normalizeAll(raster5, 0.0, 255.0, -9999.0, 0.0, 30.0);
GridCoverage2D normalizedRaster7 = MapAlgebra.normalizeAll(raster5, 0, 255, -9999.0, false);

double[] expected1 = {0.0, 17.0, 34.0, 51.0, 68.0, 85.0, 102.0, 119.0, 136.0, 153.0, 170.0, 187.0, 204.0, 221.0, 238.0, 255.0};
double[] expected2 = {256.0, 273.0, 290.0, 307.0, 324.0, 341.0, 358.0, 375.0, 392.0, 409.0, 426.0, 443.0, 460.0, 477.0, 494.0, 511.0};
double[] expected3 = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
double[] expected4 = {0.0, 17.0, 34.0, 51.0, 68.0, 85.0, 102.0, 119.0, 136.0, 153.0, 170.0, 187.0, 204.0, 221.0, 238.0, 95.0};
double[] expected5 = {95.0, 17.0, 34.0, 51.0, 68.0, 85.0, 102.0, 119.0, 136.0, 153.0, 170.0, 187.0, 204.0, 221.0, 238.0, 255.0};
double[] expected6 = {0.0, 18.214285714285715, 36.42857142857143, 54.642857142857146, 72.85714285714286, 91.07142857142857, 109.28571428571429, 127.5, 145.71428571428572, 163.92857142857142, 182.14285714285714, 200.35714285714286, 218.57142857142858, 236.78571428571428, 255.0, 255.0};

// Step 3: Validate the results for each band
for (int band = 1; band <= 2; band++) {
double[] normalizedBand1 = MapAlgebra.bandAsArray(normalizedRaster1, band);
double[] normalizedBand2 = MapAlgebra.bandAsArray(normalizedRaster2, band);
double[] normalizedBand5 = MapAlgebra.bandAsArray(normalizedRaster5, band);
double[] normalizedBand6 = MapAlgebra.bandAsArray(normalizedRaster6, band);
double[] normalizedBand7 = MapAlgebra.bandAsArray(normalizedRaster7, band);
double normalizedMin6 = Arrays.stream(normalizedBand6).min().getAsDouble();
double normalizedMax6 = Arrays.stream(normalizedBand6).max().getAsDouble();

assertEquals(Arrays.toString(expected1), Arrays.toString(normalizedBand1));
assertEquals(Arrays.toString(expected2), Arrays.toString(normalizedBand2));
assertEquals(Arrays.toString(expected6), Arrays.toString(normalizedBand5));
assertEquals(Arrays.toString(expected1), Arrays.toString(normalizedBand7));

assertEquals(0+((band-1)*127.5), normalizedMin6, 0.01d);
assertEquals(127.5+((band-1)*127.5), normalizedMax6, 0.01d);
}

assertEquals(95.0, RasterUtils.getNoDataValue(normalizedRaster4.getSampleDimension(0)), 0.01d);
assertEquals(95.0, RasterUtils.getNoDataValue(normalizedRaster4.getSampleDimension(1)), 0.01d);

assertEquals(Arrays.toString(expected3), Arrays.toString(MapAlgebra.bandAsArray(normalizedRaster3, 1)));
assertEquals(Arrays.toString(expected4), Arrays.toString(MapAlgebra.bandAsArray(normalizedRaster4, 1)));
assertEquals(Arrays.toString(expected5), Arrays.toString(MapAlgebra.bandAsArray(normalizedRaster4, 2)));
}

@Test
public void testNormalizeAll2() throws FactoryException {
String[] pixelTypes = {"B", "I", "S", "US", "F", "D"}; // Byte, Integer, Short, Unsigned Short, Float, Double
for (String pixelType : pixelTypes) {
testNormalizeAll2(10, 10, pixelType);
}
}

private void testNormalizeAll2(int width, int height, String pixelType) throws FactoryException {
// Create an empty raster with the specified pixel type
GridCoverage2D raster = RasterConstructors.makeEmptyRaster(1, pixelType, width, height, 10, 20, 1);

// Fill raster
double[] bandValues = new double[width * height];
for (int i = 0; i < bandValues.length; i++) {
bandValues[i] = i;
}
raster = MapAlgebra.addBandFromArray(raster, bandValues, 1);

GridCoverage2D normalizedRaster = MapAlgebra.normalizeAll(raster, 0, 255);

// Check the normalized values and data type
double[] normalizedBandValues = MapAlgebra.bandAsArray(normalizedRaster, 1);
for (int i = 0; i < bandValues.length; i++) {
double expected = (bandValues[i] - 0) * (255 - 0) / (99 - 0);
double actual = normalizedBandValues[i];
switch (normalizedRaster.getRenderedImage().getSampleModel().getDataType()) {
case DataBuffer.TYPE_BYTE:
case DataBuffer.TYPE_SHORT:
case DataBuffer.TYPE_USHORT:
case DataBuffer.TYPE_INT:
assertEquals((int) expected, (int) actual);
break;
default:
assertEquals(expected, actual, 0.01);
}
}

// Assert the data type remains as expected
int resultDataType = normalizedRaster.getRenderedImage().getSampleModel().getDataType();
int expectedDataType = RasterUtils.getDataTypeCode(pixelType);
assertEquals(expectedDataType, resultDataType);
}


@Test
public void testNormalizedDifference() {
double[] band1 = new double[] {960, 1067, 107, 20, 1868};
Expand Down
41 changes: 41 additions & 0 deletions docs/api/sql/Raster-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -2480,6 +2480,47 @@ Spark SQL Example:
SELECT RS_Normalize(band)
```

### RS_NormalizeAll

Introduction: Normalizes values in all bands of a raster between a given normalization range. The function maintains the data type of the raster values by ensuring that the normalized values are cast back to the original data type of each band in the raster. By default, the values are normalized to range [0, 255]. RS_NormalizeAll can take upto 6 of the following arguments.

- `raster`: The raster to be normalized.
- `minLim` and `maxLim` (Optional): The lower and upper limits of the normalization range. By default, normalization range is set to [0, 255].
- `noDataValue` (Optional): Defines the value to be used for missing or invalid data in raster bands. By default, noDataValue is set to `maxLim`.
- `minValue` and `maxValue` (Optional): Optionally, specific minimum and maximum values of the input raster can be provided. If not provided, these values are computed from the raster data.
- `normalizeAcrossBands` (Optional): A boolean flag to determine the normalization method. If set to true (default), normalization is performed across all bands based on global min and max values. If false, each band is normalized individually based on its own min and max values.

!!! Warning
Using a noDataValue that falls within the normalization range can lead to loss of valid data. If any data value within a raster band matches the specified noDataValue, it will be replaced and cannot be distinguished or recovered later. Exercise caution in selecting a noDataValue to avoid unintentional data alteration.

Formats:
```
RS_NormalizeAll (raster: Raster)`
```
```
RS_NormalizeAll (raster: Raster, minLim: Double, maxLim: Double)
```
```
RS_NormalizeAll (raster: Raster, minLim: Double, maxLim: Double, noDataValue: Double)
```
```
RS_NormalizeAll (raster: Raster, minLim: Double, maxLim: Double, noDataValue: Double, normalizeAcrossBands: Boolean)
```
```
RS_NormalizeAll (raster: Raster, minLim: Double, maxLim: Double, noDataValue: Double, minValue: Double, maxValue: Double)
```
```
RS_NormalizeAll (raster: Raster, minLim: Double, maxLim: Double, noDataValue: Double, minValue: Double, maxValue: Double, normalizeAcrossBands: Boolean)
```

Since: `v1.6.0`

Spark SQL Example:

```sql
SELECT RS_NormalizeAll(raster, 0, 1)
```

### RS_NormalizedDifference

Introduction: Returns Normalized Difference between two bands(band2 and band1) in a Geotiff image(example: NDVI, NDBI)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ object Catalog {
function[RS_LogicalOver](),
function[RS_Array](),
function[RS_Normalize](),
function[RS_NormalizeAll](),
function[RS_AddBandFromArray](),
function[RS_BandAsArray](),
function[RS_MapAlgebra](null),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,14 @@ case class RS_Normalize(inputExpressions: Seq[Expression]) extends InferredExpre
}
}

case class RS_NormalizeAll(inputExpressions: Seq[Expression]) extends InferredExpression(
inferrableFunction1(MapAlgebra.normalizeAll), inferrableFunction3(MapAlgebra.normalizeAll), inferrableFunction4(MapAlgebra.normalizeAll), inferrableFunction5(MapAlgebra.normalizeAll), inferrableFunction6(MapAlgebra.normalizeAll), inferrableFunction7(MapAlgebra.normalizeAll)
) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
}

case class RS_AddBandFromArray(inputExpressions: Seq[Expression])
extends InferredExpression(nullTolerantInferrableFunction3(MapAlgebra.addBandFromArray), nullTolerantInferrableFunction4(MapAlgebra.addBandFromArray), inferrableFunction2(MapAlgebra.addBandFromArray)) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,15 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
assert(df.first().getAs[mutable.WrappedArray[Double]](0)(1) == 255)
}

it("should pass RS_NormalizeAll") {
var df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff")
df = df.selectExpr("RS_FromGeoTiff(content) as raster")
val result1 = df.selectExpr("RS_NormalizeAll(raster, 0, 255) as normalized").first().get(0)
val result2 = df.selectExpr("RS_NormalizeAll(raster, 0, 255, 0) as normalized").first().get(0)
assert(result1.isInstanceOf[GridCoverage2D])
assert(result2.isInstanceOf[GridCoverage2D])
}

it("should pass RS_Array") {
val df = sparkSession.sql("SELECT RS_Array(6, 1e-6) as band")
val result = df.first().getAs[mutable.WrappedArray[Double]](0)
Expand Down
Loading