Skip to content

Commit

Permalink
support rand_gamma on spark 3.4+
Browse files Browse the repository at this point in the history
  • Loading branch information
zeotuan committed Oct 6, 2024
1 parent f06904f commit 33746ab
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/core-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
spark: ["3.0.1", "3.1.3", "3.2.4", "3.3.4"]
spark: ["3.0.1", "3.1.3", "3.2.4", "3.3.4", "3.4.3", "3.5.3"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/unsafe-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
spark: ["3.2.4", "3.3.4"]
spark: ["3.2.4", "3.3.4", "3.4.3", "3.5.3"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
Expand Down
14 changes: 14 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import scala.language.postfixOps

Compile / scalafmtOnCompile := true

organization := "com.github.mrpowers"
Expand Down Expand Up @@ -48,6 +50,18 @@ lazy val unsafe = (project in file("unsafe"))
.settings(
commonSettings,
name := "unsafe",
// Add every directory under the main source root whose name matches
// `*spark_*<major>.<minor>*` (for the Spark version being built) and use its
// `scala` sub-directory as an additional main source directory.
Compile / unmanagedSourceDirectories ++= {
  sparkVersion match {
    case versionRegex(major, minor, _) =>
      val versionedDirs = (Compile / sourceDirectory).value ** s"*spark_*$major.$minor*"
      (versionedDirs / "scala").get
  }
},
// Add Spark-version-specific *test* sources: every directory under the test
// source root whose name matches `*spark_*<major>.<minor>*`, using its `scala`
// sub-directory. The lookup is rooted at `Test / sourceDirectory`; rooting it
// at the Compile source directory would compile the main sources a second
// time into the test classes and never pick up version-specific tests.
// NOTE(review): assumes version-specific tests (if any) live under the test
// source root, mirroring the layout used for main sources — confirm.
Test / unmanagedSourceDirectories ++= {
  sparkVersion match {
    case versionRegex(major, minor, _) =>
      val versionedDirs = (Test / sourceDirectory).value ** s"*spark_*$major.$minor*"
      (versionedDirs / "scala").get
  }
},
)

testFrameworks += new TestFramework("com.github.mrpowers.spark.daria.CustomFramework")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.math3.distribution.GammaDistribution
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.RandGamma.defaultSeedExpression
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandomAdapted

import scala.util.{Success, Try}

case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false)
extends TernaryExpression
with ExpectsInputTypes
Expand Down Expand Up @@ -43,7 +40,7 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
distribution = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), shapeVal, scaleVal)
}

def this() = this(defaultSeedExpression, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)
def this() = this(UnresolvedSeed, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)

def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)

Expand Down Expand Up @@ -87,10 +84,4 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
object RandGamma {
def apply(seed: Long, shape: Double, scale: Double): RandGamma =
RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType))

def defaultSeedExpression: Expression =
Try(Class.forName("org.apache.spark.sql.catalyst.analysis.UnresolvedSeed")) match {
case Success(clazz) => clazz.getConstructor().newInstance().asInstanceOf[Expression]
case _ => Literal(Utils.random.nextLong(), LongType)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.math3.distribution.GammaDistribution
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.util.random.XORShiftRandomAdapted

/**
 * Catalyst expression that produces samples from a Gamma distribution.
 *
 * A `GammaDistribution` is created per partition in `initializeInternal`, driven by an
 * `XORShiftRandomAdapted` generator seeded with `seed + partitionIndex`.
 *
 * @param child    seed expression; must have `IntegerType` or `LongType`
 * @param shape    expression giving the shape parameter of the distribution
 * @param scale    expression giving the scale parameter of the distribution
 * @param hideSeed when true, `sql` renders the call with an empty argument list
 */
case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false)
    extends TernaryExpression
    with ExpectsInputTypes
    with Nondeterministic
    with ExpressionWithRandomSeed {

  /** The expression that supplies the random seed (the first child). */
  def seedExpression: Expression = child

  /** Seed value, evaluated lazily from `seedExpression`; only Int and Long seeds are matched. */
  @transient protected lazy val seed: Long = seedExpression match {
    case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
    case e if e.dataType == LongType    => e.eval().asInstanceOf[Long]
  }

  /** Shape parameter, evaluated lazily and widened to Double. */
  @transient protected lazy val shapeVal: Double = evalAsDouble(shape)

  /** Scale parameter, evaluated lazily and widened to Double. */
  @transient protected lazy val scaleVal: Double = evalAsDouble(scale)

  /** Per-partition sampler; assigned in `initializeInternal`. */
  @transient private var distribution: GammaDistribution = _

  /**
   * Evaluates a numeric expression and widens the result to Double.
   *
   * Each numeric type is unboxed with its own runtime type: a `FloatType` expression
   * evaluates to a boxed Float, which cannot be cast directly to Double.
   */
  private def evalAsDouble(expr: Expression): Double = expr.dataType match {
    case IntegerType => expr.eval().asInstanceOf[Int].toDouble
    case LongType    => expr.eval().asInstanceOf[Long].toDouble
    case FloatType   => expr.eval().asInstanceOf[Float].toDouble
    case DoubleType  => expr.eval().asInstanceOf[Double]
  }

  /** Creates the sampler for this partition, offsetting the seed by the partition index. */
  override protected def initializeInternal(partitionIndex: Int): Unit = {
    distribution = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), shapeVal, scaleVal)
  }

  /** No-argument form: unresolved seed, shape 1.0, scale 1.0, arguments hidden in `sql`. */
  def this() = this(UnresolvedSeed, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)

  /** Three-argument form with `hideSeed` set to false. */
  def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)

  /** Returns a copy of this expression whose seed is the given literal value. */
  def withNewSeed(seed: Long): RandGamma = RandGamma(Literal(seed, LongType), shape, scale, hideSeed)

  /** Draws the next sample; the input row is not used. */
  protected def evalInternal(input: InternalRow): Double = distribution.sample()

  /**
   * Generated-code counterpart of `initializeInternal` / `evalInternal`: registers the
   * distribution as mutable state, initialises it per partition, and emits a `sample()` call.
   */
  def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val distributionClassName = classOf[GammaDistribution].getName
    val rngClassName          = classOf[XORShiftRandomAdapted].getName
    val disTerm               = ctx.addMutableState(distributionClassName, "distribution")
    ctx.addPartitionInitializationStatement(
      s"$disTerm = new $distributionClassName(new $rngClassName(${seed}L + partitionIndex), $shapeVal, $scaleVal);"
    )
    ev.copy(code = code"""
      final ${CodeGenerator.javaType(dataType)} ${ev.value} = $disTerm.sample();""", isNull = FalseLiteral)
  }

  /** Returns a new instance built from the same arguments. */
  def freshCopy(): RandGamma = RandGamma(child, shape, scale, hideSeed)

  // `hideSeed` is deliberately left out of the flattened argument list.
  override def flatArguments: Iterator[Any] = Iterator(child, shape, scale)

  override def prettyName: String = "rand_gamma"

  // When `hideSeed` is set, all three arguments are omitted from the rendered SQL.
  override def sql: String = s"rand_gamma(${if (hideSeed) "" else s"${child.sql}, ${shape.sql}, ${scale.sql}"})"

  override def stateful: Boolean = true

  def inputTypes: Seq[AbstractDataType] = Seq(LongType, DoubleType, DoubleType)

  def dataType: DataType = DoubleType

  def first: Expression = child

  def second: Expression = shape

  def third: Expression = scale

  protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
    copy(child = newFirst, shape = newSecond, scale = newThird)
}

object RandGamma {

  /**
   * Builds a [[RandGamma]] from plain values by wrapping each one in a typed literal:
   * the seed as `LongType`, shape and scale as `DoubleType`.
   */
  def apply(seed: Long, shape: Double, scale: Double): RandGamma = {
    val seedLiteral  = Literal(seed, LongType)
    val shapeLiteral = Literal(shape, DoubleType)
    val scaleLiteral = Literal(scale, DoubleType)
    RandGamma(seedLiteral, shapeLiteral, scaleLiteral)
  }
}

0 comments on commit 33746ab

Please sign in to comment.