Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-326] Improve raster algebra functions: RS_Array and RS_MultiplyFactor #907

Merged
merged 1 commit into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,6 @@ public static Geometry polygonFromEnvelope(double minX, double minY, double maxX
}

public static Geometry geomFromGeoHash(String geoHash, Integer precision) {
System.out.println(geoHash);
System.out.println(precision);
try {
return GeoHashDecoder.decode(geoHash, precision);
} catch (GeoHashDecoder.InvalidGeoHashException e) {
Expand Down
2 changes: 1 addition & 1 deletion docs/api/sql/Raster-loader.md
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ Output:

Introduction: Create an array that is filled by the given value

Format: `RS_Array(length:Int, value: Decimal)`
Format: `RS_Array(length:Int, value: Double)`

Since: `v1.1.0`

Expand Down
4 changes: 3 additions & 1 deletion docs/api/sql/Raster-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ val multiplyDF = spark.sql("select RS_Multiply(band1, band2) as multiplyBands fr

Introduction: Multiply a factor to a spectral band in a geotiff image

Format: `RS_MultiplyFactor (Band1: Array[Double], Factor: Int)`
Format: `RS_MultiplyFactor (Band1: Array[Double], Factor: Double)`

Since: `v1.1.0`

Expand All @@ -528,6 +528,8 @@ val multiplyFactorDF = spark.sql("select RS_MultiplyFactor(band1, 2) as multiply

```

This function only accepts integer as factor before `v1.5.0`.

### RS_Normalize

Introduction: Normalize the value in the array to [0, 255]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package org.apache.spark.sql.sedona_sql.expressions.raster
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.ImplicitCastInputTypes
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.sedona_sql.expressions.UserDataGeneratator
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -126,14 +127,14 @@ case class RS_GetBand(inputExpressions: Seq[Expression])
}

case class RS_Array(inputExpressions: Seq[Expression])
extends Expression with CodegenFallback with UserDataGeneratator {
extends Expression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator {
override def nullable: Boolean = false

override def eval(inputRow: InternalRow): Any = {
// This is an expression which takes one input expressions
assert(inputExpressions.length == 2)
val len =inputExpressions(0).eval(inputRow).asInstanceOf[Int]
val num = inputExpressions(1).eval(inputRow).asInstanceOf[Decimal].toDouble
val num = inputExpressions(1).eval(inputRow).asInstanceOf[Double]
val result = createarray(len, num)
new GenericArrayData(result)
}
Expand All @@ -148,6 +149,8 @@ case class RS_Array(inputExpressions: Seq[Expression])
result
}

override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, DoubleType)

override def dataType: DataType = ArrayType(DoubleType)

override def children: Seq[Expression] = inputExpressions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.sedona.common.raster.{MapAlgebra, Serde}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
import org.apache.spark.sql.catalyst.expressions.ImplicitCastInputTypes
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
Expand Down Expand Up @@ -352,30 +353,31 @@ case class RS_Count(inputExpressions: Seq[Expression])

// Multiply a factor to all values of a band
case class RS_MultiplyFactor(inputExpressions: Seq[Expression])
extends Expression with CodegenFallback with UserDataGeneratator {
extends Expression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator {
assert(inputExpressions.length == 2)

override def nullable: Boolean = false

override def eval(inputRow: InternalRow): Any = {
val band = inputExpressions(0).eval(inputRow).asInstanceOf[ArrayData].toDoubleArray()
val target = inputExpressions(1).eval(inputRow).asInstanceOf[Int]
new GenericArrayData(multiply(band, target))
val factor = inputExpressions(1).eval(inputRow).asInstanceOf[Double]
new GenericArrayData(multiply(band, factor))

}

private def multiply(band: Array[Double], target: Int):Array[Double] = {
private def multiply(band: Array[Double], factor: Double):Array[Double] = {

var result = new Array[Double](band.length)
for(i<-0 until band.length) {

result(i) = band(i)*target
result(i) = band(i) * factor

}
result

}

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType(DoubleType), DoubleType)

override def dataType: DataType = ArrayType(DoubleType)

override def children: Seq[Expression] = inputExpressions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
assert(inputDf.first().getAs[mutable.WrappedArray[Double]](0) == expectedDF.first().getAs[mutable.WrappedArray[Double]](0))
}

it("Passed RS_MultiplyFactor with double factor") {
val inputDf = Seq((Seq(200.0, 400.0, 600.0))).toDF("Band")
val expectedDF = Seq((Seq(20.0, 40.0, 60.0))).toDF("multiply")
val actualDF = inputDf.selectExpr("RS_MultiplyFactor(Band, 0.1) as multiply")
assert(actualDF.first().getAs[mutable.WrappedArray[Double]](0) == expectedDF.first().getAs[mutable.WrappedArray[Double]](0))
}
}

describe("Should pass basic statistical tests") {
Expand Down Expand Up @@ -202,6 +208,13 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
df = df.selectExpr("RS_Normalize(Band) as normalizedBand")
assert(df.first().getAs[mutable.WrappedArray[Double]](0)(1) == 255)
}

it("should pass RS_Array") {
val df = sparkSession.sql("SELECT RS_Array(6, 1e-6) as band")
val result = df.first().getAs[mutable.WrappedArray[Double]](0)
assert(result.length == 6)
assert(result sameElements Array.fill[Double](6)(1e-6))
}
}

describe("Should pass all transformation tests") {
Expand Down
Loading