Skip to content

Commit

Permalink
Merge pull request #160 from zeotuan/ci-update
Browse files Browse the repository at this point in the history
- Run Test on Pull Request
- Separate native(unsafe) API into different project
  • Loading branch information
zeotuan authored Oct 5, 2024
2 parents c59d416 + abdf0b2 commit f06904f
Show file tree
Hide file tree
Showing 75 changed files with 213 additions and 135 deletions.
22 changes: 22 additions & 0 deletions .github/workflows/core-ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: core-ci

on:
push:
branches:
- main
pull_request:

jobs:
build:
strategy:
fail-fast: false
matrix:
spark: ["3.0.1", "3.1.3", "3.2.4", "3.3.4"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: olafurpg/setup-scala@v10
- name: Test
run: sbt -Dspark.testVersion=${{ matrix.spark }} +"project core" test
- name: Code Quality
run: sbt "project core" scalafmtCheckAll
10 changes: 5 additions & 5 deletions .github/workflows/ci.yml → .github/workflows/unsafe-ci.yml
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
name: ci
name: unsafe-ci

on:
push:
branches:
- main
pull_request:

jobs:
build:
strategy:
fail-fast: false
matrix:
scala: ["2.12.12"]
spark: ["3.0.1"]
spark: ["3.2.4", "3.3.4"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: olafurpg/setup-scala@v10
- name: Test
run: sbt -Dspark.testVersion=${{ matrix.spark }} ++${{ matrix.scala }} test
run: sbt -Dspark.testVersion=${{ matrix.spark }} +"project unsafe" test
- name: Code Quality
run: sbt scalafmtCheckAll
run: sbt "project unsafe" scalafmtCheckAll
2 changes: 1 addition & 1 deletion .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
version = 2.6.3

lineEndings = preserve
align = more
maxColumn = 150
docstrings = JavaDoc
51 changes: 43 additions & 8 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,62 @@ organization := "com.github.mrpowers"
name := "spark-daria"

version := "1.2.3"

crossScalaVersions := Seq("2.12.15", "2.13.8")
scalaVersion := "2.12.15"

val sparkVersion = "3.2.1"
val versionRegex = """^(.*)\.(.*)\.(.*)$""".r

val scala2_13 = "2.13.14"
val scala2_12 = "2.12.20"

val sparkVersion = System.getProperty("spark.testVersion", "3.3.4")
crossScalaVersions := {
sparkVersion match {
case versionRegex("3", m, _) if m.toInt >= 2 => Seq(scala2_12, scala2_13)
case versionRegex("3", _, _) => Seq(scala2_12)
}
}

scalaVersion := crossScalaVersions.value.head

lazy val commonSettings = Seq(
javaOptions ++= {
Seq("-Xms512M", "-Xmx2048M", "-Duser.timezone=GMT") ++ (if (System.getProperty("java.version").startsWith("1.8.0"))
Seq("-XX:+CMSClassUnloadingEnabled")
else Seq.empty)
},
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
"org.apache.spark" %% "spark-mllib" % sparkVersion % "provided",
"com.github.mrpowers" %% "spark-fast-tests" % "1.1.0" % "test",
"com.lihaoyi" %% "utest" % "0.7.11" % "test",
"com.lihaoyi" %% "os-lib" % "0.8.0" % "test"
),
)

lazy val core = (project in file("core"))
.settings(
commonSettings,
name := "core",
)

lazy val unsafe = (project in file("unsafe"))
.settings(
commonSettings,
name := "unsafe",
)

libraryDependencies += "org.apache.spark" %% "spark-sql" % sparkVersion % "provided"
libraryDependencies += "org.apache.spark" %% "spark-mllib" % sparkVersion % "provided"
libraryDependencies += "com.github.mrpowers" %% "spark-fast-tests" % "1.1.0" % "test"
libraryDependencies += "com.lihaoyi" %% "utest" % "0.7.11" % "test"
libraryDependencies += "com.lihaoyi" %% "os-lib" % "0.8.0" % "test"
testFrameworks += new TestFramework("com.github.mrpowers.spark.daria.CustomFramework")

credentials += Credentials(Path.userHome / ".sbt" / "sonatype_credentials")

Test / fork := true

javaOptions ++= Seq("-Xms512M", "-Xmx2048M", "-XX:+CMSClassUnloadingEnabled", "-Duser.timezone=GMT")

licenses := Seq("MIT" -> url("http://opensource.org/licenses/MIT"))

homepage := Some(url("https://github.com/MrPowers/spark-daria"))

developers ++= List(
Developer("MrPowers", "Matthew Powers", "@MrPowers", url("https://github.com/MrPowers"))
)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import org.apache.spark.sql.SparkSession
trait SparkSessionTestWrapper {

lazy val spark: SparkSession = {
SparkSession
val session = SparkSession
.builder()
.master("local")
.appName("spark session")
Expand All @@ -14,6 +14,8 @@ trait SparkSessionTestWrapper {
"1"
)
.getOrCreate()
session.sparkContext.setLogLevel("ERROR")
session
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -1047,22 +1047,6 @@ object TransformationsTest extends TestSuite with DataFrameComparer with ColumnC
}

'withParquetCompatibleColumnNames - {
"blows up if the column name is invalid for Parquet" - {
val df = spark
.createDF(
List(
("pablo")
),
List(
("Column That {Will} Break\t;", StringType, true)
)
)
val path = new java.io.File("./tmp/blowup/example").getCanonicalPath
val e = intercept[org.apache.spark.sql.AnalysisException] {
df.write.parquet(path)
}
}

"converts column names to be Parquet compatible" - {
val actualDF = spark
.createDF(
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.math3.distribution.GammaDistribution
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.RandGamma.defaultSeedExpression
import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandomAdapted

import scala.util.{Success, Try}

case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false)
extends TernaryExpression
with ExpectsInputTypes
with Stateful
with ExpressionWithRandomSeed {

def seedExpression: Expression = child

@transient protected lazy val seed: Long = seedExpression match {
case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
}

@transient protected lazy val shapeVal: Double = shape.dataType match {
case IntegerType => shape.eval().asInstanceOf[Int]
case LongType => shape.eval().asInstanceOf[Long]
case FloatType | DoubleType => shape.eval().asInstanceOf[Double]
}

@transient protected lazy val scaleVal: Double = scale.dataType match {
case IntegerType => scale.eval().asInstanceOf[Int]
case LongType => scale.eval().asInstanceOf[Long]
case FloatType | DoubleType => scale.eval().asInstanceOf[Double]
}

@transient private var distribution: GammaDistribution = _

protected def initializeInternal(partitionIndex: Int): Unit = {
distribution = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), shapeVal, scaleVal)
}

def this() = this(defaultSeedExpression, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)

def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)

def withNewSeed(seed: Long): RandGamma = RandGamma(Literal(seed, LongType), shape, scale, hideSeed)

protected def evalInternal(input: InternalRow): Double = distribution.sample()

def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val distributionClassName = classOf[GammaDistribution].getName
val rngClassName = classOf[XORShiftRandomAdapted].getName
val disTerm = ctx.addMutableState(distributionClassName, "distribution")
ctx.addPartitionInitializationStatement(
s"$disTerm = new $distributionClassName(new $rngClassName(${seed}L + partitionIndex), $shapeVal, $scaleVal);"
)
ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $disTerm.sample();""", isNull = FalseLiteral)
}

def freshCopy(): RandGamma = RandGamma(child, shape, scale, hideSeed)

override def flatArguments: Iterator[Any] = Iterator(child, shape, scale)

override def prettyName: String = "rand_gamma"

override def sql: String = s"rand_gamma(${if (hideSeed) "" else s"${child.sql}, ${shape.sql}, ${scale.sql}"})"

def inputTypes: Seq[AbstractDataType] = Seq(LongType, DoubleType, DoubleType)

def dataType: DataType = DoubleType

def first: Expression = child

def second: Expression = shape

def third: Expression = scale

protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
copy(child = newFirst, shape = newSecond, scale = newThird)
}

object RandGamma {
def apply(seed: Long, shape: Double, scale: Double): RandGamma =
RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType))

def defaultSeedExpression: Expression =
Try(Class.forName("org.apache.spark.sql.catalyst.analysis.UnresolvedSeed")) match {
case Success(clazz) => clazz.getConstructor().newInstance().asInstanceOf[Expression]
case _ => Literal(Utils.random.nextLong(), LongType)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@ object functions {
private def withExpr(expr: Expression): Column = Column(expr)

def randGamma(seed: Long, shape: Double, scale: Double): Column = withExpr(RandGamma(seed, shape, scale)).alias("gamma_random")
def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale)
def randGamma(): Column = randGamma(1.0, 1.0)
def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale)
def randGamma(): Column = randGamma(1.0, 1.0)

def randLaplace(seed: Long, mu: Double, beta: Double): Column = {
val mu_ = lit(mu)
val mu_ = lit(mu)
val beta_ = lit(beta)
val u = rand(seed)
val u = rand(seed)
when(u < 0.5, mu_ + beta_ * log(lit(2) * u))
.otherwise(mu_ - beta_ * log(lit(2) * (lit(1) - u)))
.alias("laplace_random")
}

def randLaplace(mu: Double, beta: Double): Column = randLaplace(Utils.random.nextLong, mu, beta)
def randLaplace(): Column = randLaplace(0.0, 1.0)
def randLaplace(): Column = randLaplace(0.0, 1.0)
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class XORShiftRandomAdapted(init: Long) extends java.util.Random(init: Long) wit
nextSeed ^= (nextSeed >>> 35)
nextSeed ^= (nextSeed << 4)
seed = nextSeed
(nextSeed & ((1L << bits) -1)).asInstanceOf[Int]
(nextSeed & ((1L << bits) - 1)).asInstanceOf[Int]
}

override def setSeed(s: Long): Unit = {
Expand All @@ -29,4 +29,3 @@ class XORShiftRandomAdapted(init: Long) extends java.util.Random(init: Long) wit
this.seed = XORShiftRandom.hashSeed(RandomGeneratorFactory.convertToLong(seed))
}
}

Loading

0 comments on commit f06904f

Please sign in to comment.