Skip to content

Commit

Permalink
[SEDONA-478] Make Sedona geometry functions and spatial join working …
Browse files Browse the repository at this point in the history
…without GeoTools (#1398)
  • Loading branch information
Kontinuation committed May 3, 2024
1 parent c35efdb commit 64570db
Show file tree
Hide file tree
Showing 17 changed files with 193 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,14 @@
package org.apache.sedona.sql

import org.apache.sedona.sql.UDF.RasterUdafCatalog
import org.apache.sedona.sql.utils.GeoToolsCoverageAvailability.{gridClassName, isGeoToolsAvailable}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.sedona_sql.UDT.RasterUdtRegistratorWrapper
import org.apache.spark.sql.{SparkSession, functions}
import org.slf4j.{Logger, LoggerFactory}

object RasterRegistrator {
val logger: Logger = LoggerFactory.getLogger(getClass)
private val gridClassName = "org.geotools.coverage.grid.GridCoverage2D"

// Helper method to check if GridCoverage2D is available
private def isGeoToolsAvailable: Boolean = {
try {
Class.forName(gridClassName, true, Thread.currentThread().getContextClassLoader)
true
} catch {
case _: ClassNotFoundException =>
logger.warn("Geotools was not found on the classpath. Raster operations will not be available.")
false
}
}

def registerAll(sparkSession: SparkSession): Unit = {
if (isGeoToolsAvailable) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sedona.sql.utils

import org.apache.sedona.sql.RasterRegistrator.logger

/**
 * A helper object to check if GeoTools GridCoverage2D is available on the classpath.
 *
 * The probe loads (and initializes) the class named by [[gridClassName]] through the
 * current thread's context class loader. The result is cached in a `lazy val`, so the
 * lookup and its warning happen at most once.
 */
object GeoToolsCoverageAvailability {
  /** Fully qualified name of the class whose presence is used as the GeoTools marker. */
  val gridClassName = "org.geotools.coverage.grid.GridCoverage2D"

  /**
   * `true` when [[gridClassName]] can be loaded, `false` otherwise.
   *
   * Besides `ClassNotFoundException` (class absent), `NoClassDefFoundError` is also
   * treated as "not available": `Class.forName` with `initialize = true` raises it when
   * the class itself is present but one of the classes it links against is missing.
   * Letting that error escape would break callers that only want to know whether raster
   * support can be used.
   */
  lazy val isGeoToolsAvailable: Boolean = {
    try {
      Class.forName(gridClassName, true, Thread.currentThread().getContextClassLoader)
      true
    } catch {
      case _: ClassNotFoundException =>
        logger.warn("Geotools was not found on the classpath. Raster operations will not be available.")
        false
      case e: NoClassDefFoundError =>
        // The marker class exists but cannot be linked; report the cause so the
        // incomplete classpath can be diagnosed.
        logger.warn(
          "Geotools was found on the classpath but could not be loaded. Raster operations will not be available.",
          e)
        false
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.spark.sql.sedona_sql.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.{RasterEnhancer, RasterInputExpressionEnhancer}
import org.apache.spark.sql.types.{ArrayType, DataTypes, UserDefinedType}

import scala.reflect.runtime.universe.{Type, typeOf}
import org.geotools.coverage.grid.GridCoverage2D

/**
 * Raster-related type evidence, type checks, Spark data types and (de)serialization
 * helpers used by inferred expressions.
 *
 * Everything in this object refers to `GridCoverage2D` directly, so touching it requires
 * the GeoTools coverage classes to be loadable.
 */
object InferrableRasterTypes {
  /** Evidence that `GridCoverage2D` may be used as an inferred argument or return type. */
  implicit val gridCoverage2DInstance: InferrableType[GridCoverage2D] =
    new InferrableType[GridCoverage2D] {}

  /** Evidence that `Array[GridCoverage2D]` may be used as an inferred argument or return type. */
  implicit val gridCoverage2DArrayInstance: InferrableType[Array[GridCoverage2D]] =
    new InferrableType[Array[GridCoverage2D]] {}

  /** Whether `t` is exactly the `GridCoverage2D` type. */
  def isRasterType(t: Type): Boolean = {
    t =:= typeOf[GridCoverage2D]
  }

  /** Whether `t` is exactly the `Array[GridCoverage2D]` type. */
  def isRasterArrayType(t: Type): Boolean = {
    t =:= typeOf[Array[GridCoverage2D]]
  }

  /** Spark user-defined type used for a single raster. */
  val rasterUDT: UserDefinedType[_] = RasterUDT

  /** Spark array type whose elements are rasters. */
  val rasterUDTArray: ArrayType = DataTypes.createArrayType(RasterUDT)

  /** Evaluates `expr` against `input` and converts the result to a raster via `toRaster`. */
  def rasterExtractor(expr: Expression)(input: InternalRow): Any = {
    expr.toRaster(input)
  }

  /**
   * Serializes a single raster result.
   *
   * @param output a `GridCoverage2D`, or `null`
   * @return the serialized raster, or `null` when `output` is `null`
   */
  def rasterSerializer(output: Any): Any = output match {
    case null => null
    case value => value.asInstanceOf[GridCoverage2D].serialize
  }

  /**
   * Serializes an array of rasters into Spark `ArrayData`.
   *
   * Each raster is disposed right after it has been serialized.
   *
   * @param output an `Array[GridCoverage2D]`, or `null`
   * @return `ArrayData` holding the serialized rasters, or `null` when `output` is `null`
   */
  def rasterArraySerializer(output: Any): Any = output match {
    case null => null
    case value =>
      val coverages = value.asInstanceOf[Array[GridCoverage2D]]
      val payloads = coverages.map { coverage =>
        val bytes = coverage.serialize
        coverage.dispose(true)
        bytes
      }
      ArrayData.toArrayData(payloads)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, DataTypes, DoubleType, IntegerType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.geom.Geometry
import org.apache.spark.sql.sedona_sql.expressions.implicits._
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits._
import org.geotools.coverage.grid.GridCoverage2D

import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
import scala.reflect.runtime.universe.TypeTag
Expand Down Expand Up @@ -75,14 +73,10 @@ abstract class InferredExpression(fSeq: InferrableFunction *)
// This is a compile time type shield for the types we are able to infer. Anything
// other than these types will cause a compilation error. This is the Scala
// 2 way of making a union type.
sealed class InferrableType[T: TypeTag]
class InferrableType[T: TypeTag]
object InferrableType {
implicit val geometryInstance: InferrableType[Geometry] =
new InferrableType[Geometry] {}
implicit val gridCoverage2DInstance: InferrableType[GridCoverage2D] =
new InferrableType[GridCoverage2D] {}
implicit val gridCoverage2DArrayInstance: InferrableType[Array[GridCoverage2D]] =
new InferrableType[Array[GridCoverage2D]] {}
implicit val geometryArrayInstance: InferrableType[Array[Geometry]] =
new InferrableType[Array[Geometry]] {}
implicit val javaDoubleInstance: InferrableType[java.lang.Double] =
Expand Down Expand Up @@ -127,8 +121,8 @@ object InferredTypes {
expr => input => expr.toGeometry(input)
} else if (t =:= typeOf[Array[Geometry]]) {
expr => input => expr.toGeometryArray(input)
} else if (t =:= typeOf[GridCoverage2D]) {
expr => input => expr.toRaster(input)
} else if (InferredRasterExpression.isRasterType(t)) {
InferredRasterExpression.rasterExtractor
} else if (t =:= typeOf[Array[Double]]) {
expr => input => expr.eval(input).asInstanceOf[ArrayData].toDoubleArray()
} else if (t =:= typeOf[String]) {
Expand Down Expand Up @@ -156,14 +150,8 @@ object InferredTypes {
} else {
null
}
} else if (t =:= typeOf[GridCoverage2D]) {
output => {
if (output != null) {
output.asInstanceOf[GridCoverage2D].serialize
} else {
null
}
}
} else if (InferredRasterExpression.isRasterType(t)) {
InferredRasterExpression.rasterSerializer
} else if (t =:= typeOf[String]) {
output =>
if (output != null) {
Expand Down Expand Up @@ -194,19 +182,8 @@ object InferredTypes {
} else {
null
}
} else if (t =:= typeOf[Array[GridCoverage2D]]) {
output =>
if (output != null) {
val rasters = output.asInstanceOf[Array[GridCoverage2D]]
val serialized = rasters.map { raster =>
val serialized = raster.serialize
raster.dispose(true)
serialized
}
ArrayData.toArrayData(serialized)
} else {
null
}
} else if (InferredRasterExpression.isRasterArrayType(t)) {
InferredRasterExpression.rasterArraySerializer
} else if (t =:= typeOf[Option[Boolean]]) {
output =>
if (output != null) {
Expand All @@ -224,10 +201,10 @@ object InferredTypes {
GeometryUDT
} else if (t =:= typeOf[Array[Geometry]] || t =:= typeOf[java.util.List[Geometry]]) {
DataTypes.createArrayType(GeometryUDT)
} else if (t =:= typeOf[GridCoverage2D]) {
RasterUDT
} else if (t =:= typeOf[Array[GridCoverage2D]]) {
DataTypes.createArrayType(RasterUDT)
} else if (InferredRasterExpression.isRasterType(t)) {
InferredRasterExpression.rasterUDT
} else if (InferredRasterExpression.isRasterArrayType(t)) {
InferredRasterExpression.rasterUDTArray
} else if (t =:= typeOf[java.lang.Double]) {
DoubleType
} else if (t =:= typeOf[java.lang.Integer]) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.spark.sql.sedona_sql.expressions

import org.apache.sedona.sql.utils.GeoToolsCoverageAvailability.isGeoToolsAvailable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.{ArrayType, UserDefinedType}

import scala.reflect.runtime.universe.{Type, typeOf}

/**
 * GeoTools-aware facade over `InferrableRasterTypes`.
 *
 * Every member consults `isGeoToolsAvailable` first and only refers to
 * `InferrableRasterTypes` when it is `true`; otherwise a neutral fallback
 * (`false`, `null`, or a function returning `null`) is produced.
 */
object InferredRasterExpression {

  /**
   * Picks `onAvailable` when GeoTools is available and `fallback` otherwise.
   * `onAvailable` is by-name so that it is not evaluated when GeoTools is missing.
   */
  private def whenGeoToolsAvailable[T](fallback: T)(onAvailable: => T): T =
    if (isGeoToolsAvailable) onAvailable else fallback

  /** `true` only if GeoTools is available and `t` is the raster type. */
  def isRasterType(t: Type): Boolean =
    if (!isGeoToolsAvailable) false else InferrableRasterTypes.isRasterType(t)

  /** `true` only if GeoTools is available and `t` is the raster array type. */
  def isRasterArrayType(t: Type): Boolean =
    if (!isGeoToolsAvailable) false else InferrableRasterTypes.isRasterArrayType(t)

  /** The raster UDT, or `null` when GeoTools is not available. */
  def rasterUDT: UserDefinedType[_] =
    whenGeoToolsAvailable[UserDefinedType[_]](null)(InferrableRasterTypes.rasterUDT)

  /** The array-of-raster data type, or `null` when GeoTools is not available. */
  def rasterUDTArray: ArrayType =
    whenGeoToolsAvailable[ArrayType](null)(InferrableRasterTypes.rasterUDTArray)

  /**
   * Extracts a raster from an expression evaluated on a row. Chosen once, when this
   * object is initialized; yields `null` for every input when GeoTools is not available.
   */
  val rasterExtractor: Expression => InternalRow => Any =
    whenGeoToolsAvailable[Expression => InternalRow => Any](_ => _ => null) {
      expr => input => InferrableRasterTypes.rasterExtractor(expr)(input)
    }

  /**
   * Serializes a single raster result. Chosen once, when this object is initialized;
   * yields `null` for every input when GeoTools is not available.
   */
  val rasterSerializer: Any => Any =
    whenGeoToolsAvailable[Any => Any]((_: Any) => null) {
      output => InferrableRasterTypes.rasterSerializer(output)
    }

  /**
   * Serializes an array of rasters. Chosen once, when this object is initialized;
   * yields `null` for every input when GeoTools is not available.
   */
  val rasterArraySerializer: Any => Any =
    whenGeoToolsAvailable[Any => Any]((_: Any) => null) {
      output => InferrableRasterTypes.rasterArraySerializer(output)
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,10 @@ package org.apache.spark.sql.sedona_sql.expressions
import org.apache.sedona.sql.utils.GeometrySerializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types.{ByteType, DataTypes}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.geom.{Geometry, GeometryFactory, Point}

import java.util

object implicits {

implicit class InputExpressionEnhancer(inputExpression: Expression) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.sedona.common.raster.GeometryFunctions
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._

case class RS_ConvexHull(inputExpressions: Seq[Expression]) extends InferredExpression(GeometryFunctions.convexHull _) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression

/// Calculate Normalized Difference between two bands
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package org.apache.spark.sql.sedona_sql.expressions.raster
import org.apache.sedona.common.raster.PixelFunctionEditors
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression

case class RS_SetValues(inputExpressions: Seq[Expression]) extends InferredExpression(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, DoubleType, IntegerType, StructType}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package org.apache.spark.sql.sedona_sql.expressions.raster
import org.apache.sedona.common.raster.RasterAccessors
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression

case class RS_NumBands(inputExpressions: Seq[Expression]) extends InferredExpression(RasterAccessors.numBands _) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.RasterInputExpressionEnhancer
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
import org.geotools.coverage.grid.GridCoverage2D
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package org.apache.spark.sql.sedona_sql.expressions.raster
import org.apache.sedona.common.raster.RasterBandEditors
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression

case class RS_SetBandNoDataValue(inputExpressions: Seq[Expression]) extends InferredExpression(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, Gener
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression
import org.apache.spark.sql.sedona_sql.expressions.raster.implicits.{RasterEnhancer, RasterInputExpressionEnhancer}
import org.apache.spark.sql.types.{ArrayType, BooleanType, Decimal, IntegerType, NullType, StructType}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package org.apache.spark.sql.sedona_sql.expressions.raster
import org.apache.sedona.common.raster.RasterEditors
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression

case class RS_SetSRID(inputExpressions: Seq[Expression]) extends InferredExpression(RasterEditors.setSrid _) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package org.apache.spark.sql.sedona_sql.expressions.raster
import org.apache.sedona.common.raster.RasterOutputs
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sedona_sql.expressions.InferrableFunctionConverter._
import org.apache.spark.sql.sedona_sql.expressions.InferrableRasterTypes._
import org.apache.spark.sql.sedona_sql.expressions.InferredExpression

case class RS_AsGeoTiff(inputExpressions: Seq[Expression])
Expand Down
Loading

0 comments on commit 64570db

Please sign in to comment.