Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-275] Add raster function RS_SetSRID #817

Merged
merged 1 commit into from
Apr 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@
*/
package org.apache.sedona.common.raster;

import org.geotools.coverage.CoverageFactoryFinder;
import org.geotools.coverage.grid.GridCoordinates2D;
import org.geotools.coverage.grid.GridCoverage2D;
import org.geotools.coverage.grid.GridCoverageFactory;
import org.geotools.coverage.grid.GridGeometry2D;
import org.geotools.gce.geotiff.GeoTiffWriter;
import org.geotools.geometry.DirectPosition2D;
import org.geotools.geometry.Envelope2D;
import org.geotools.geometry.jts.ReferencedEnvelope;
import org.geotools.referencing.CRS;
import org.geotools.referencing.crs.DefaultEngineeringCRS;
import org.locationtech.jts.geom.*;
Expand Down Expand Up @@ -48,6 +52,18 @@ public static int numBands(GridCoverage2D raster) {
return raster.getNumSampleDimensions();
}

public static GridCoverage2D setSrid(GridCoverage2D raster, int srid) throws FactoryException {
CoordinateReferenceSystem crs;
if (srid == 0) {
crs = DefaultEngineeringCRS.CARTESIAN_2D;
} else {
crs = CRS.decode("EPSG:" + srid);
}
ReferencedEnvelope referencedEnvelope = new ReferencedEnvelope(raster.getEnvelope2D(), crs);
GridCoverageFactory gridCoverageFactory = CoverageFactoryFinder.getGridCoverageFactory(null);
return gridCoverageFactory.create(raster.getName().toString(), raster.getRenderedImage(), referencedEnvelope);
}

public static int srid(GridCoverage2D raster) throws FactoryException {
CoordinateReferenceSystem crs = raster.getCoordinateReferenceSystem();
if (crs instanceof DefaultEngineeringCRS) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package org.apache.sedona.common.raster;

import org.geotools.coverage.grid.GridCoverage2D;
import org.junit.Test;
import org.locationtech.jts.geom.Coordinate;
import org.locationtech.jts.geom.Geometry;
Expand Down Expand Up @@ -45,6 +46,20 @@ public void testNumBands() {
assertEquals(4, Functions.numBands(multiBandRaster));
}

@Test
public void testSetSrid() throws FactoryException {
assertEquals(0, Functions.srid(oneBandRaster));
assertEquals(4326, Functions.srid(multiBandRaster));

GridCoverage2D oneBandRasterWithUpdatedSrid = Functions.setSrid(oneBandRaster, 4326);
assertEquals(4326, Functions.srid(oneBandRasterWithUpdatedSrid));
assertEquals(4326, Functions.envelope(oneBandRasterWithUpdatedSrid).getSRID());
assertTrue(Functions.envelope(oneBandRasterWithUpdatedSrid).equalsTopo(Functions.envelope(oneBandRaster)));

GridCoverage2D multiBandRasterWithUpdatedSrid = Functions.setSrid(multiBandRaster, 0);
assertEquals(0 , Functions.srid(multiBandRasterWithUpdatedSrid));
}

@Test
public void testSrid() throws FactoryException {
assertEquals(0, Functions.srid(oneBandRaster));
Expand Down
14 changes: 14 additions & 0 deletions docs/api/sql/Raster-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ Output:
4
```

## RS_SetSRID

Introduction: Sets the spatial reference system identifier (SRID) of the raster geometry.

Format: `RS_SetSRID (raster: Raster, srid: Integer)`

Since: `v1.4.1`

Spark SQL example:
```sql
SELECT RS_SetSRID(raster, 4326)
FROM raster_table
```

### RS_SRID

Introduction: Returns the spatial reference system identifier (SRID) of the raster geometry.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ object Catalog {
function[RS_FromGeoTiff](),
function[RS_Envelope](),
function[RS_NumBands](),
function[RS_SetSRID](),
function[RS_SRID](),
function[RS_Value](1),
function[RS_Values](1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,17 @@
package org.apache.spark.sql.sedona_sql.expressions.raster

import org.apache.sedona.common.geometrySerde.GeometrySerializer
import org.apache.sedona.common.raster.Functions
import org.apache.sedona.common.raster.{Functions, Serde}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
import org.apache.spark.sql.sedona_sql.expressions.UserDataGeneratator
import org.apache.spark.sql.sedona_sql.expressions.{SerdeAware, UserDataGeneratator}
import org.apache.spark.sql.sedona_sql.expressions.implicits._
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits._
import org.apache.spark.sql.types._
import org.geotools.coverage.grid.GridCoverage2D



Expand Down Expand Up @@ -855,6 +856,34 @@ case class RS_NumBands(inputExpressions: Seq[Expression]) extends Expression wit
override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT)
}

case class RS_SetSRID(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback with ExpectsInputTypes with SerdeAware {
override def nullable: Boolean = true

override def eval(input: InternalRow): Any = {
Option(evalWithoutSerialization(input)).map(Serde.serialize).orNull
}

override def evalWithoutSerialization(input: InternalRow): GridCoverage2D = {
val raster = inputExpressions(0).toRaster(input)
val srid = inputExpressions(1).eval(input).asInstanceOf[Int]
if (raster == null) {
null
} else {
Functions.setSrid(raster, srid)
}
}

override def dataType: DataType = RasterUDT

override def children: Seq[Expression] = inputExpressions

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}

override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, IntegerType)
}

case class RS_SRID(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback with ExpectsInputTypes {
override def nullable: Boolean = true

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,17 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
assert(result == 1)
}

it("Passed RS_SetSRID should handle null values") {
val result = sparkSession.sql("select RS_SetSRID(null, 0)").first().get(0)
assert(result == null)
}

it("Passed RS_SetSRID with raster") {
val df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff")
val result = df.selectExpr("RS_SRID(RS_SetSRID(RS_FromGeoTiff(content), 4326))").first().getInt(0)
assert(result == 4326)
}

it("Passed RS_SRID should handle null values") {
val result = sparkSession.sql("select RS_SRID(null)").first().get(0)
assert(result == null)
Expand Down