From e371aaa4f09376ff558d27e10c3d74e0c02ad992 Mon Sep 17 00:00:00 2001
From: Eduardo Ruiz
Date: Mon, 1 Aug 2022 16:55:18 +0200
Subject: [PATCH 1/2] feat: [+] remaining non-aggregate functions (resolves #68)

---
 .github/labeler.yml                           |  2 +-
 .../scala/doric/syntax/NumericColumns.scala   | 24 ++++++++++
 .../scala/doric/syntax/NumericColumns32.scala |  9 ++++
 .../doric/syntax/NumericOperationsSpec.scala  | 47 ++++++++++++++++++-
 .../scala/doric/syntax/NumericUtilsSpec.scala | 26 ++++++++++
 5 files changed, 106 insertions(+), 2 deletions(-)

diff --git a/.github/labeler.yml b/.github/labeler.yml
index dc2f6e860..c5f81374d 100644
--- a/.github/labeler.yml
+++ b/.github/labeler.yml
@@ -2,7 +2,7 @@ CI/CD:
   - .github/**/*.yml
 
 dependencies:
-  - any: [ '.github/dependabot.yml', 'build.sbt', 'project/**/*', '.scala-steward.conf', '.scalafix.conf', '.scalafmt.conf' ]
+  - any: [ '.github/dependabot.yml', 'build.sbt', 'project/**/*', '.scala-steward.conf', '.scalafix.conf', '.scalafmt.conf', 'project/build.properties' ]
 
 documentation:
   - any: [ 'docs/**/*', 'notebooks/**/*', '*.md' ]

diff --git a/core/src/main/scala/doric/syntax/NumericColumns.scala b/core/src/main/scala/doric/syntax/NumericColumns.scala
index 34432fa04..607e214c2 100644
--- a/core/src/main/scala/doric/syntax/NumericColumns.scala
+++ b/core/src/main/scala/doric/syntax/NumericColumns.scala
@@ -411,6 +411,21 @@ private[syntax] trait NumericColumns {
       * @see [[org.apache.spark.sql.functions.tanh(e:org\.apache\.spark\.sql\.Column)* org.apache.spark.sql.functions.tanh]]
       */
     def tanh: DoubleColumn = column.elem.map(f.tanh).toDC
+
+    /**
+      * Unary minus, i.e. negate the expression.
+      *
+      * @example {{{
+      * // Select the amount column and negate all values.
+      * // Scala:
+      * df.select( -df("amount") )
+      * }}}
+      *
+      * @todo DayTimeIntervalType & YearMonthIntervalType
+      * @group Numeric Type
+      * @see [[org.apache.spark.sql.functions.negate]]
+      */
+    def negate: DoricColumn[T] = column.elem.map(f.negate).toDC
   }
 
   /**
@@ -572,6 +587,15 @@ private[syntax] trait NumericColumns {
     def round(scale: IntegerColumn): DoricColumn[T] =
       (column.elem, scale.elem)
         .mapN((c, s) => new Column(Round(c.expr, s.expr)))
         .toDC
+
+    /**
+      * Returns col1 if it is not NaN, or col2 if col1 is NaN.
+      *
+      * @group Numeric Type
+      * @see [[org.apache.spark.sql.functions.nanvl]]
+      */
+    def naNvl(col2: DoricColumn[T]): DoricColumn[T] =
+      (column.elem, col2.elem).mapN(f.nanvl).toDC
   }
 }

diff --git a/core/src/main/spark_3.2_3.3/scala/doric/syntax/NumericColumns32.scala b/core/src/main/spark_3.2_3.3/scala/doric/syntax/NumericColumns32.scala
index adbebe9ac..eb6504e6e 100644
--- a/core/src/main/spark_3.2_3.3/scala/doric/syntax/NumericColumns32.scala
+++ b/core/src/main/spark_3.2_3.3/scala/doric/syntax/NumericColumns32.scala
@@ -3,6 +3,7 @@ package syntax
 
 import cats.implicits._
 import org.apache.spark.sql.Column
+import org.apache.spark.sql.{functions => f}
 import org.apache.spark.sql.catalyst.expressions.{ShiftLeft, ShiftRight, ShiftRightUnsigned}
 
 private[syntax] trait NumericColumns32 {
@@ -46,6 +47,14 @@ private[syntax] trait NumericColumns32 {
       (column.elem, numBits.elem)
         .mapN((c, n) => new Column(ShiftRightUnsigned(c.expr, n.expr)))
         .toDC
+
+    /**
+      * Computes bitwise NOT (~) of a number.
+      *
+      * @group Numeric Type
+      * @see [[org.apache.spark.sql.functions.bitwise_not]]
+      */
+    def bitwiseNot: DoricColumn[T] = column.elem.map(f.bitwise_not).toDC
   }
 
 }

diff --git a/core/src/test/scala/doric/syntax/NumericOperationsSpec.scala b/core/src/test/scala/doric/syntax/NumericOperationsSpec.scala
index 02dc17973..f476cab17 100644
--- a/core/src/test/scala/doric/syntax/NumericOperationsSpec.scala
+++ b/core/src/test/scala/doric/syntax/NumericOperationsSpec.scala
@@ -3,7 +3,7 @@ package syntax
 
 import doric.types.{NumericType, SparkType}
 import doric.types.SparkType.Primitive
-import org.apache.spark.sql.catalyst.expressions.{ShiftLeft, ShiftRight, ShiftRightUnsigned}
+import org.apache.spark.sql.catalyst.expressions.{BitwiseNot, ShiftLeft, ShiftRight, ShiftRightUnsigned}
 import org.apache.spark.sql.{Column, DataFrame, SparkSession, functions => f}
 import org.scalatest.funspec.AnyFunSpecLike
 
@@ -333,6 +333,15 @@ trait NumericOperationsSpec
         f.tanh
       )
     }
+
+    it(s"negate function $numTypeStr") {
+      testDoricSpark[T, T](
+        List(Some(-1), Some(1), Some(2), None),
+        List(Some(1), Some(-1), Some(2), None),
+        _.negate,
+        f.negate
+      )
+    }
   }
 }
 
@@ -434,6 +443,19 @@ trait NumericOperationsSpec
         shiftRightUnsignedBefore32(_, f.lit(numBits))
       )
     }
+
+    it(s"bitwiseNot function $numTypeStr") {
+      // Aux function as it is deprecated since 3.2, otherwise specs would get complicated
+      val bitwiseNotBefore32: Column => Column =
+        col => new Column(BitwiseNot(col.expr))
+      val numBits = 2
+      testDoricSpark[T, T](
+        List(Some(0), Some(4), Some(20), None),
+        List(Some(0), Some(1), Some(5), None),
+        _.bitwiseNot,
+        bitwiseNotBefore32
+      )
+    }
   }
 }
 
@@ -510,6 +532,29 @@ trait NumericOperationsSpec
         f.round(_, 2)
       )
     }
+
+    it(s"naNvl function with param $numTypeStr") {
+      testDoricSparkDecimals2[T, T, T](
+        List(
+          (Some(-1.466f), Some(-2.0f)),
+          (Some(0f), Some(0.7111f)),
+          (Some(1f / 0f), Some(1.0f)),
+          (None, Some(1.0f)),
+          (Some(1.0f), None),
+          (None, None)
+        ),
+        List(
+          Some(-1.466f),
+          Some(0f),
+          Some(1.0f),
+          Some(1.0f),
+          Some(1.0f),
+          None
+        ),
+        _.naNvl(_),
+        f.nanvl
+      )
+    }
   }
 }

diff --git a/core/src/test/scala/doric/syntax/NumericUtilsSpec.scala b/core/src/test/scala/doric/syntax/NumericUtilsSpec.scala
index dd56ba1bf..8b5e34299 100644
--- a/core/src/test/scala/doric/syntax/NumericUtilsSpec.scala
+++ b/core/src/test/scala/doric/syntax/NumericUtilsSpec.scala
@@ -105,4 +105,30 @@ protected trait NumericUtilsSpec extends TypedColumnTest {
     )
   }
 
+  def testDoricSparkDecimals2[
+      T1: Primitive: ClassTag: TypeTag,
+      T2: Primitive: ClassTag: TypeTag,
+      O: Primitive: ClassTag: TypeTag
+  ](
+      input: List[(Option[Float], Option[Float])],
+      output: List[Option[O]],
+      doricFun: (DoricColumn[T1], DoricColumn[T2]) => DoricColumn[O],
+      sparkFun: (Column, Column) => Column
+  )(implicit
+      spark: SparkSession,
+      funT1: FromFloat[T1],
+      funT2: FromFloat[T2]
+  ): Unit = {
+    import spark.implicits._
+    val df = input
+      .map { case (x, y) => (x.map(funT1), y.map(funT2)) }
+      .toDF("col1", "col2")
+
+    df.testColumns2("col1", "col2")(
+      (c1, c2) => doricFun(col[T1](c1), col[T2](c2)),
+      (c1, c2) => sparkFun(f.col(c1), f.col(c2)),
+      output
+    )
+  }
+
 }

From e43dda52cfb6af49b1ca4d07bf80a45a5cde9c01 Mon Sep 17 00:00:00 2001
From: Eduardo Ruiz
Date: Wed, 10 Aug 2022 16:53:41 +0200
Subject: [PATCH 2/2] fix: [~] bitwiseNot & assert equalities correctly

---
 .../doric/syntax/NumericColumns2_31.scala     |  9 +++
 core/src/test/scala/doric/Equalities.scala    | 14 ++++-
 .../test/scala/doric/TypedColumnTest.scala    |  5 +-
 .../doric/syntax/AggregationColumnsSpec.scala | 60 +++++++++----------
 .../doric/syntax/NumericOperationsSpec.scala  | 44 ++++++++------
 .../scala/doric/syntax/NumericUtilsSpec.scala |  9 +--
 .../scala/doric/syntax/Numeric31Spec.scala    |  1 +
 7 files changed, 83 insertions(+), 59 deletions(-)

diff --git a/core/src/main/spark_2.4_3.0_3.1/scala/doric/syntax/NumericColumns2_31.scala b/core/src/main/spark_2.4_3.0_3.1/scala/doric/syntax/NumericColumns2_31.scala
index 16ed3916f..64b97cb5e 100644
--- a/core/src/main/spark_2.4_3.0_3.1/scala/doric/syntax/NumericColumns2_31.scala
+++ b/core/src/main/spark_2.4_3.0_3.1/scala/doric/syntax/NumericColumns2_31.scala
@@ -3,6 +3,7 @@ package syntax
 
 import cats.implicits._
 import org.apache.spark.sql.Column
+import org.apache.spark.sql.{functions => f}
 import org.apache.spark.sql.catalyst.expressions.{ShiftLeft, ShiftRight, ShiftRightUnsigned}
 
 private[syntax] trait NumericColumns2_31 {
@@ -46,6 +47,14 @@ private[syntax] trait NumericColumns2_31 {
       (column.elem, numBits.elem)
         .mapN((c, n) => new Column(ShiftRightUnsigned(c.expr, n.expr)))
         .toDC
+
+    /**
+      * Computes bitwise NOT (~) of a number.
+      *
+      * @group Numeric Type
+      * @see [[org.apache.spark.sql.functions.bitwiseNOT]]
+      */
+    def bitwiseNot: DoricColumn[T] = column.elem.map(f.bitwiseNOT).toDC
   }
 
 }

diff --git a/core/src/test/scala/doric/Equalities.scala b/core/src/test/scala/doric/Equalities.scala
index 123ad2a17..45ab56509 100644
--- a/core/src/test/scala/doric/Equalities.scala
+++ b/core/src/test/scala/doric/Equalities.scala
@@ -27,16 +27,24 @@ object Equalities {
     }
   }
 
+  private lazy val tolerance = 0.00001
   implicit val eqDouble: Equality[Double] = new Equality[Double] {
     override def areEqual(a: Double, b: Any): Boolean = (a, b) match {
-      case (x: Double, y: Double) => x === y +- 0.00001
+      case (x: Double, y: Double) => x === y +- tolerance
       case _ => false
     }
   }
 
+  implicit val eqFloat: Equality[Float] = new Equality[Float] {
+    override def areEqual(a: Float, b: Any): Boolean = (a, b) match {
+      case (x: Float, y: Float) => x === y +- tolerance.toFloat
+      case _ => false
+    }
+  }
+
   implicit val eqBigDecimal: Equality[BigDecimal] = new Equality[BigDecimal] {
     override def areEqual(a: BigDecimal, b: Any): Boolean = (a, b) match {
-      case (x: BigDecimal, y: BigDecimal) => x === y +- 0.00001
+      case (x: BigDecimal, y: BigDecimal) => x === y +- tolerance
       case _ => false
     }
   }
@@ -46,7 +54,7 @@ object Equalities {
     override def areEqual(a: java.math.BigDecimal, b: Any): Boolean =
       (a, b) match {
         case (x: java.math.BigDecimal, y: java.math.BigDecimal) =>
-          x >= (y - 0.00001) && x <= (y + 0.00001)
+          x >= (y - tolerance) && x <= (y + tolerance)
         case _ => false
       }
   }

diff --git a/core/src/test/scala/doric/TypedColumnTest.scala b/core/src/test/scala/doric/TypedColumnTest.scala
index f3e83422a..942b14fcb 100644
--- a/core/src/test/scala/doric/TypedColumnTest.scala
+++ b/core/src/test/scala/doric/TypedColumnTest.scala
@@ -103,7 +103,6 @@ trait TypedColumnTest extends Matchers with DatasetComparer {
       df: DataFrame,
       expected: List[Option[T]]
   ): Unit = {
-    import Equalities._
 
     val eqCond: BooleanColumn = SparkType[T].dataType match {
       case _: MapType =>
@@ -141,10 +140,10 @@ trait TypedColumnTest extends Matchers with DatasetComparer {
     )
 
     if (expected.nonEmpty) {
-      doricColumns.map {
+      assert(doricColumns.map {
         case Some(x: java.lang.Double) if x.isNaN => None
         case x                                    => x
-      } === expected
+      } === expected)
     }
   }

diff --git a/core/src/test/scala/doric/syntax/AggregationColumnsSpec.scala b/core/src/test/scala/doric/syntax/AggregationColumnsSpec.scala
index 791e3525f..11131266b 100644
--- a/core/src/test/scala/doric/syntax/AggregationColumnsSpec.scala
+++ b/core/src/test/scala/doric/syntax/AggregationColumnsSpec.scala
@@ -28,7 +28,7 @@ class AggregationColumnsSpec
         "keyCol",
         sum(colInt("col1")),
         f.sum("col1"),
-        List(Some(6L), Some(3L))
+        List(Some(3L), Some(6L))
       )
     }
@@ -43,7 +43,7 @@ class AggregationColumnsSpec
         "keyCol",
         sum(colFloat("col1")),
         f.sum("col1"),
-        List(Some(6.5d), Some(3.0d))
+        List(Some(3.0d), Some(6.5d))
       )
     }
@@ -62,7 +62,7 @@ class AggregationColumnsSpec
         "keyCol",
         count(colInt("col1")),
         f.count(f.col("col1")),
-        List(Some(2L), Some(1L))
+        List(Some(1L), Some(2L))
       )
     }
@@ -77,7 +77,7 @@ class AggregationColumnsSpec
         "keyCol",
         count("col1"),
         f.count("col1"),
-        List(Some(2L), Some(1L))
+        List(Some(1L), Some(2L))
       )
     }
@@ -96,7 +96,7 @@ class AggregationColumnsSpec
         "keyCol",
         first(colInt("col1")),
         f.first(f.col("col1")),
-        List(Some(1), Some(3))
+        List(Some(3), Some(1))
       )
     }
@@ -111,7 +111,7 @@ class AggregationColumnsSpec
         "keyCol",
         first(colInt("col1"), ignoreNulls = true),
         f.first(f.col("col1"), ignoreNulls = true),
-        List(Some(5), Some(3))
+        List(Some(3), Some(5))
       )
     }
@@ -126,7 +126,7 @@ class AggregationColumnsSpec
         "keyCol",
         first(colInt("col1"), ignoreNulls = false),
         f.first(f.col("col1"), ignoreNulls = false),
-        List(None, Some(3))
+        List(Some(3), None)
      )
     }
@@ -145,7 +145,7 @@ class AggregationColumnsSpec
         "keyCol",
         last(colInt("col1")),
         f.last(f.col("col1")),
-        List(Some(5), Some(3))
+        List(Some(3), Some(5))
       )
     }
@@ -160,7 +160,7 @@ class AggregationColumnsSpec
         "keyCol",
         last(colInt("col1"), ignoreNulls = true),
         f.last(f.col("col1"), ignoreNulls = true),
-        List(Some(1), Some(3))
+        List(Some(3), Some(1))
       )
     }
@@ -175,7 +175,7 @@ class AggregationColumnsSpec
         "keyCol",
         last(colInt("col1"), ignoreNulls = false),
         f.last(f.col("col1"), ignoreNulls = false),
-        List(None, Some(3))
+        List(Some(3), None)
       )
     }
@@ -194,7 +194,7 @@ class AggregationColumnsSpec
         "keyCol",
         aproxCountDistinct("col1"),
         f.approx_count_distinct("col1"),
-        List(Some(2L), Some(1L))
+        List(Some(1L), Some(2L))
       )
     }
@@ -211,7 +211,7 @@ class AggregationColumnsSpec
         "keyCol",
         aproxCountDistinct("col1", 0.05),
         f.approx_count_distinct("col1", 0.05),
-        List(Some(2L), Some(1L))
+        List(Some(1L), Some(2L))
       )
     }
@@ -226,7 +226,7 @@ class AggregationColumnsSpec
         "keyCol",
         aproxCountDistinct(colInt("col1")),
         f.approx_count_distinct(f.col("col1")),
-        List(Some(2L), Some(1L))
+        List(Some(1L), Some(2L))
       )
     }
@@ -243,7 +243,7 @@ class AggregationColumnsSpec
         "keyCol",
         aproxCountDistinct(colInt("col1"), 0.05),
         f.approx_count_distinct(f.col("col1"), 0.05),
-        List(Some(2L), Some(1L))
+        List(Some(1L), Some(2L))
       )
     }
@@ -281,7 +281,7 @@ class AggregationColumnsSpec
         "keyCol",
         collectList(colInt("col1")),
         f.collect_list(f.col("col1")),
-        List(Some(Array(1, 5)), Some(Array(3)))
+        List(Some(Array(3)), Some(Array(1, 5)))
       )
     }
@@ -301,7 +301,7 @@ class AggregationColumnsSpec
         "keyCol",
         collectSet(colInt("col1")),
         f.collect_set(f.col("col1")),
-        List(Some(Array(1, 5)), Some(Array(3)))
+        List(Some(Array(3)), Some(Array(1, 5)))
       )
     }
@@ -341,7 +341,7 @@ class AggregationColumnsSpec
         "keyCol",
         countDistinct(colDouble("col1"), colString("col2")),
         f.countDistinct(f.col("col1"), f.col("col2")),
-        List(Some(3L), Some(1L))
+        List(Some(1L), Some(3L))
       )
     }
@@ -356,7 +356,7 @@ class AggregationColumnsSpec
         "keyCol",
         countDistinct("col1", "col2"),
         f.countDistinct("col1", "col2"),
-        List(Some(2L), Some(1L))
+        List(Some(1L), Some(2L))
       )
     }
@@ -432,7 +432,7 @@ class AggregationColumnsSpec
         "keyCol",
         max(colDouble("col1")),
         f.max(f.col("col1")),
-        List(Some(4.0), Some(6.0))
+        List(Some(6.0), Some(4.0))
       )
     }
@@ -451,7 +451,7 @@ class AggregationColumnsSpec
         "keyCol",
         min(colDouble("col1")),
         f.min(f.col("col1")),
-        List(Some(3.0), Some(6.0))
+        List(Some(6.0), Some(3.0))
       )
     }
@@ -470,7 +470,7 @@ class AggregationColumnsSpec
         "keyCol",
         mean(colDouble("col1")),
         f.mean(f.col("col1")),
-        List(Some(3.5), Some(6.0))
+        List(Some(6.0), Some(3.5))
       )
     }
@@ -491,7 +491,7 @@ class AggregationColumnsSpec
         "keyCol",
         skewness(colDouble("col1")),
         f.skewness(f.col("col1")),
-        List(Some(1.1135657469022011), None)
+        List(None, Some(1.1135657469022011))
       )
     }
@@ -511,7 +511,7 @@ class AggregationColumnsSpec
         "keyCol",
         stdDev(colDouble("col1")),
         f.stddev(f.col("col1")),
-        List(Some(1.0), None)
+        List(None, Some(1.0))
       )
     }
@@ -527,7 +527,7 @@ class AggregationColumnsSpec
         "keyCol",
         stdDevSamp(colDouble("col1")),
         f.stddev_samp(f.col("col1")),
-        List(Some(1.0), None)
+        List(None, Some(1.0))
       )
     }
@@ -547,7 +547,7 @@ class AggregationColumnsSpec
         "keyCol",
         stdDevPop(colDouble("col1")),
         f.stddev_pop(f.col("col1")),
-        List(Some(0.816496580927726), Some(0.0))
+        List(Some(0.0), Some(0.816496580927726))
       )
     }
@@ -569,7 +569,7 @@ class AggregationColumnsSpec
         new Column(
           Sum(f.col("col1").expr).toAggregateExpression(isDistinct = true)
         ),
-        List(Some(4L), Some(6L))
+        List(Some(6L), Some(4L))
       )
     }
@@ -587,7 +587,7 @@ class AggregationColumnsSpec
         new Column(
           Sum(f.col("col1").expr).toAggregateExpression(isDistinct = true)
         ),
-        List(Some(4.0), Some(6.0))
+        List(Some(6.0), Some(4.0))
       )
     }
@@ -607,7 +607,7 @@ class AggregationColumnsSpec
         "keyCol",
         variance(colDouble("col1")),
         f.variance(f.col("col1")),
-        List(Some(1.0), None)
+        List(None, Some(1.0))
       )
     }
@@ -623,7 +623,7 @@ class AggregationColumnsSpec
         "keyCol",
         varSamp(colDouble("col1")),
         f.var_samp(f.col("col1")),
-        List(Some(1.0), None)
+        List(None, Some(1.0))
       )
     }
@@ -643,7 +643,7 @@ class AggregationColumnsSpec
         "keyCol",
         varPop(colDouble("col1")),
         f.var_pop(f.col("col1")),
-        List(Some(0.6666666666666666), Some(0.0))
+        List(Some(0.0), Some(0.6666666666666666))
       )
     }

diff --git a/core/src/test/scala/doric/syntax/NumericOperationsSpec.scala b/core/src/test/scala/doric/syntax/NumericOperationsSpec.scala
index f476cab17..eeec71565 100644
--- a/core/src/test/scala/doric/syntax/NumericOperationsSpec.scala
+++ b/core/src/test/scala/doric/syntax/NumericOperationsSpec.scala
@@ -1,10 +1,12 @@
 package doric
 package syntax
 
-import doric.types.{NumericType, SparkType}
+import Equalities._
 import doric.types.SparkType.Primitive
+import doric.types.{NumericType, SparkType}
 import org.apache.spark.sql.catalyst.expressions.{BitwiseNot, ShiftLeft, ShiftRight, ShiftRightUnsigned}
 import org.apache.spark.sql.{Column, DataFrame, SparkSession, functions => f}
+import org.scalactic.Equality
 import org.scalatest.funspec.AnyFunSpecLike
@@ -17,7 +19,7 @@ trait NumericOperationsSpec
   def df: DataFrame
 
   import scala.reflect.runtime.universe._
-  def test[T: NumericType: Primitive: ClassTag: TypeTag]()(implicit
+  def test[T: NumericType: Primitive: ClassTag: TypeTag: Equality]()(implicit
       spark: SparkSession,
       fun: FromInt[T]
   ): Unit = {
@@ -84,7 +86,7 @@ trait NumericOperationsSpec
     it(s"atan function $numTypeStr") {
       testDoricSpark[T, Double](
         List(Some(-1), Some(1), Some(2), None),
-        List(Some(-0.78538), Some(0.78540), Some(1.10715), None),
+        List(Some(-0.785398), Some(0.78540), Some(1.10715), None),
         _.atan,
         f.atan
       )
@@ -337,7 +339,7 @@ trait NumericOperationsSpec
     it(s"negate function $numTypeStr") {
       testDoricSpark[T, T](
         List(Some(-1), Some(1), Some(2), None),
-        List(Some(1), Some(-1), Some(2), None),
+        List(Some(1), Some(-1), Some(-2), None),
         _.negate,
         f.negate
       )
@@ -345,9 +347,9 @@ trait NumericOperationsSpec
     }
   }
 
-  def testIntegrals[T: IntegralType: ClassTag: TypeTag]()(implicit
+  def testIntegrals[T: IntegralType: ClassTag: TypeTag: Equality]()(implicit
       spark: SparkSession,
-      sparkTypeT: SparkType[T],
+      sparkTypeT: Primitive[T],
       fun: FromInt[T]
   ): Unit = {
     val numTypeStr = getClassName[T]
@@ -448,10 +450,10 @@ trait NumericOperationsSpec
       // Aux function as it is deprecated since 3.2, otherwise specs would get complicated
       val bitwiseNotBefore32: Column => Column =
         col => new Column(BitwiseNot(col.expr))
-      val numBits = 2
+
       testDoricSpark[T, T](
-        List(Some(0), Some(4), Some(20), None),
-        List(Some(0), Some(1), Some(5), None),
+        List(Some(0), Some(4), Some(-20), None),
+        List(Some(-1), Some(-5), Some(19), None),
         _.bitwiseNot,
         bitwiseNotBefore32
       )
@@ -459,8 +461,9 @@ trait NumericOperationsSpec
     }
   }
 
-  def testDecimals[T: NumWithDecimalsType: Primitive: ClassTag: TypeTag]()(
-      implicit
+  def testDecimals[
+      T: NumWithDecimalsType: Primitive: ClassTag: TypeTag: Equality
+  ]()(implicit
      spark: SparkSession,
      fun: FromFloat[T]
   ): Unit = {
@@ -478,10 +481,10 @@ trait NumericOperationsSpec
     }
 
     it(s"bRound function with param $numTypeStr") {
-      val scale = 2
+      val scale = 5
       testDoricSparkDecimals[T, T](
-        List(Some(-0.2567f), Some(0.811f), Some(0.0f), None),
-        List(Some(-0.26f), Some(0.81f), Some(0.0f), None),
+        List(Some(-0.256777f), Some(0.811111f), Some(0.0f), None),
+        List(Some(-0.25678f), Some(0.81111f), Some(0.0f), None),
         _.bRound(scale.lit),
         f.bround(_, scale)
       )
@@ -525,11 +528,12 @@ trait NumericOperationsSpec
     }
 
     it(s"round function with param $numTypeStr") {
+      val scale = 5
       testDoricSparkDecimals[T, T](
-        List(Some(-1.466f), Some(0.7111f), Some(1.0f), None),
-        List(Some(-1.47f), Some(0.71f), Some(1.0f), None),
-        _.round(2.lit),
-        f.round(_, 2)
+        List(Some(-1.466666f), Some(0.7111111f), Some(1.0f), None),
+        List(Some(-1.46667f), Some(0.71111f), Some(1.0f), None),
+        _.round(scale.lit),
+        f.round(_, scale)
       )
     }
@@ -538,7 +542,8 @@ trait NumericOperationsSpec
       List(
         (Some(-1.466f), Some(-2.0f)),
         (Some(0f), Some(0.7111f)),
-        (Some(1f / 0f), Some(1.0f)),
+        (Some(Float.NaN), Some(1.0f)),
+        (Some(1.0f), Some(Float.NaN)),
         (None, Some(1.0f)),
         (Some(1.0f), None),
         (None, None)
@@ -548,6 +553,7 @@ trait NumericOperationsSpec
         Some(0f),
         Some(1.0f),
         Some(1.0f),
+        None,
         Some(1.0f),
         None
       ),

diff --git a/core/src/test/scala/doric/syntax/NumericUtilsSpec.scala b/core/src/test/scala/doric/syntax/NumericUtilsSpec.scala
index 8b5e34299..0277825ed 100644
--- a/core/src/test/scala/doric/syntax/NumericUtilsSpec.scala
+++ b/core/src/test/scala/doric/syntax/NumericUtilsSpec.scala
@@ -5,6 +5,7 @@ import doric.types.SparkType
 import doric.{DoricColumn, TypedColumnTest}
 import doric.types.SparkType.Primitive
 import org.apache.spark.sql.{Column, DataFrame, SparkSession, functions => f}
+import org.scalactic.Equality
 
 import scala.reflect.{ClassTag, classTag}
 import scala.reflect.runtime.universe._
@@ -37,7 +38,7 @@ protected trait NumericUtilsSpec extends TypedColumnTest {
 
   def testDoricSpark[
       T: SparkType: ClassTag: TypeTag,
-      O: SparkType: ClassTag: TypeTag
+      O: SparkType: ClassTag: TypeTag: Equality
   ](
       input: List[Option[Int]],
       output: List[Option[O]],
@@ -60,7 +61,7 @@ protected trait NumericUtilsSpec extends TypedColumnTest {
   def testDoricSpark2[
       T1: Primitive: ClassTag: TypeTag,
       T2: Primitive: ClassTag: TypeTag,
-      O: Primitive: ClassTag: TypeTag
+      O: Primitive: ClassTag: TypeTag: Equality
   ](
       input: List[(Option[Int], Option[Int])],
       output: List[Option[O]],
@@ -85,7 +86,7 @@ protected trait NumericUtilsSpec extends TypedColumnTest {
 
   def testDoricSparkDecimals[
       T: Primitive: ClassTag: TypeTag,
-      O: Primitive: ClassTag: TypeTag
+      O: Primitive: ClassTag: TypeTag: Equality
   ](
       input: List[Option[Float]],
       output: List[Option[O]],
@@ -108,7 +109,7 @@ protected trait NumericUtilsSpec extends TypedColumnTest {
   def testDoricSparkDecimals2[
       T1: Primitive: ClassTag: TypeTag,
       T2: Primitive: ClassTag: TypeTag,
-      O: Primitive: ClassTag: TypeTag
+      O: Primitive: ClassTag: TypeTag: Equality
   ](
       input: List[(Option[Float], Option[Float])],
       output: List[Option[O]],

diff --git a/core/src/test/spark_3.1_3.2_3.3/scala/doric/syntax/Numeric31Spec.scala b/core/src/test/spark_3.1_3.2_3.3/scala/doric/syntax/Numeric31Spec.scala
index a7c8d195c..d281c3925 100644
--- a/core/src/test/spark_3.1_3.2_3.3/scala/doric/syntax/Numeric31Spec.scala
+++ b/core/src/test/spark_3.1_3.2_3.3/scala/doric/syntax/Numeric31Spec.scala
@@ -1,6 +1,7 @@
 package doric
 package syntax
 
+import Equalities._
 import doric.types.NumericType
 import doric.types.SparkType.Primitive
 import org.apache.spark.sql.{DataFrame, SparkSession, functions => f}
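
For reference, a minimal sketch of the new syntax this series adds in use, assuming doric's usual column selectors (colInt, colDouble) and a local SparkSession as in the project's own tests; the data and column names below are hypothetical:

    import doric._
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // Hypothetical frame: an amount plus a fallback to use when amount is NaN
    val df = Seq[(Option[Double], Option[Double])](
      (Some(4.0), Some(1.0)),
      (Some(Double.NaN), Some(2.0)),
      (None, Some(3.0))
    ).toDF("amount", "fallback")

    df.select(
      colDouble("amount").negate,                      // unary minus, wraps f.negate
      colDouble("amount").naNvl(colDouble("fallback")) // NaN -> fallback; null stays null
    ).show()

    // bitwiseNot is exposed on integral columns; it is cross-compiled so that
    // Spark <= 3.1 delegates to f.bitwiseNOT and Spark >= 3.2 to f.bitwise_not
    Seq(0, 4, -20).toDF("n").select(colInt("n").bitwiseNot).show() // -1, -5, 19

Note that naNvl only replaces NaN, not null: per the fixed spec, a null first column passes through as null even when the second column is defined.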
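The corrected bitwiseNot expectations in PATCH 2/2 follow the two's-complement identity ~x == -x - 1, while the values the first patch expected (0, 1, 5) are what a right shift by two produces, which, together with the stray val numBits = 2 the fix deletes, suggests the test was copy-pasted from the shift specs. A REPL-checkable sketch:

    // Two's complement: ~x == -x - 1
    assert((~0) == -1)
    assert((~4) == -5)
    assert((~(-20)) == 19)

    // What the first patch's expectations actually encoded: a right shift by two
    assert((4 >> 2) == 1)
    assert((20 >> 2) == 5)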
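The TypedColumnTest change is what surfaces every corrected expectation in this series: a bare scalatest === inside a block is just a Boolean expression whose result is discarded, so the mis-ordered aggregation lists, the negate values, and the bitwiseNot values had been passing silently. A minimal illustration of the pitfall, assuming a suite with scalatest's Assertions or Matchers in scope:

    val actual   = List(Some(3L), Some(6L))
    val expected = List(Some(6L), Some(3L))

    actual === expected         // evaluates to false and is silently thrown away
    assert(actual === expected) // the fix: a mismatch now throws TestFailedException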