Skip to content

Commit

Permalink
fix: [~] bitwiseNot, & assert equalities correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
eruizalo committed Aug 10, 2022
1 parent e371aaa commit e43dda5
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package syntax

import cats.implicits._
import org.apache.spark.sql.Column
import org.apache.spark.sql.{functions => f}
import org.apache.spark.sql.catalyst.expressions.{ShiftLeft, ShiftRight, ShiftRightUnsigned}

private[syntax] trait NumericColumns2_31 {
Expand Down Expand Up @@ -46,6 +47,14 @@ private[syntax] trait NumericColumns2_31 {
(column.elem, numBits.elem)
.mapN((c, n) => new Column(ShiftRightUnsigned(c.expr, n.expr)))
.toDC

/**
  * Applies the bitwise NOT operator (~) to this numeric column.
  *
  * The underlying Spark column is passed to `functions.bitwiseNOT` and the
  * result is wrapped back into a doric column of the same type `T`.
  *
  * @group Numeric Type
  * @return
  *   a `DoricColumn[T]` built from the result of `functions.bitwiseNOT`
  * @see [[org.apache.spark.sql.functions.bitwiseNOT]]
  */
def bitwiseNot: DoricColumn[T] =
  column.elem
    .map(sparkColumn => f.bitwiseNOT(sparkColumn))
    .toDC
}

}
14 changes: 11 additions & 3 deletions core/src/test/scala/doric/Equalities.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,24 @@ object Equalities {
}
}

// Tolerance shared by the approximate-equality instances in this object.
// It is a Double; the Float instance narrows it with `.toFloat`.
private lazy val tolerance = 0.00001
/**
  * Approximate equality for `Double`.
  *
  * Two doubles compare equal when the left one is within `tolerance` of the
  * right one. A right-hand value that is not a `Double` never compares equal.
  */
implicit val eqDouble: Equality[Double] = new Equality[Double] {
  override def areEqual(a: Double, b: Any): Boolean = (a, b) match {
    case (x: Double, y: Double) => x === y +- tolerance
    case _                      => false
  }
}

/**
  * Approximate equality for `Float`.
  *
  * The shared `tolerance` is narrowed to `Float` before the comparison.
  * A right-hand value that is not a `Float` never compares equal.
  */
implicit val eqFloat: Equality[Float] = new Equality[Float] {
  override def areEqual(a: Float, b: Any): Boolean = b match {
    case other: Float => a === other +- tolerance.toFloat
    case _            => false
  }
}

/**
  * Approximate equality for Scala `BigDecimal`.
  *
  * Two values compare equal when the left one is within `tolerance` of the
  * right one. A right-hand value that is not a `BigDecimal` never compares
  * equal.
  */
implicit val eqBigDecimal: Equality[BigDecimal] = new Equality[BigDecimal] {
  override def areEqual(a: BigDecimal, b: Any): Boolean = (a, b) match {
    case (x: BigDecimal, y: BigDecimal) => x === y +- tolerance
    case _                              => false
  }
}
Expand All @@ -46,7 +54,7 @@ object Equalities {
/**
  * Approximate equality for `java.math.BigDecimal`.
  *
  * The left value is equal to the right one when it lies in the closed
  * interval `[y - tolerance, y + tolerance]`. A right-hand value that is not
  * a `java.math.BigDecimal` never compares equal.
  */
override def areEqual(a: java.math.BigDecimal, b: Any): Boolean =
  (a, b) match {
    case (x: java.math.BigDecimal, y: java.math.BigDecimal) =>
      // Both bounds are inclusive.
      x >= (y - tolerance) && x <= (y + tolerance)
    case _ => false
  }
}
Expand Down
5 changes: 2 additions & 3 deletions core/src/test/scala/doric/TypedColumnTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ trait TypedColumnTest extends Matchers with DatasetComparer {
df: DataFrame,
expected: List[Option[T]]
): Unit = {
import Equalities._

val eqCond: BooleanColumn = SparkType[T].dataType match {
case _: MapType =>
Expand Down Expand Up @@ -141,10 +140,10 @@ trait TypedColumnTest extends Matchers with DatasetComparer {
)

if (expected.nonEmpty) {
doricColumns.map {
assert(doricColumns.map {
case Some(x: java.lang.Double) if x.isNaN => None
case x => x
} === expected
} === expected)
}
}

Expand Down
60 changes: 30 additions & 30 deletions core/src/test/scala/doric/syntax/AggregationColumnsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class AggregationColumnsSpec
"keyCol",
sum(colInt("col1")),
f.sum("col1"),
List(Some(6L), Some(3L))
List(Some(3L), Some(6L))
)
}

Expand All @@ -43,7 +43,7 @@ class AggregationColumnsSpec
"keyCol",
sum(colFloat("col1")),
f.sum("col1"),
List(Some(6.5d), Some(3.0d))
List(Some(3.0d), Some(6.5d))
)
}
}
Expand All @@ -62,7 +62,7 @@ class AggregationColumnsSpec
"keyCol",
count(colInt("col1")),
f.count(f.col("col1")),
List(Some(2L), Some(1L))
List(Some(1L), Some(2L))
)
}

Expand All @@ -77,7 +77,7 @@ class AggregationColumnsSpec
"keyCol",
count("col1"),
f.count("col1"),
List(Some(2L), Some(1L))
List(Some(1L), Some(2L))
)
}
}
Expand All @@ -96,7 +96,7 @@ class AggregationColumnsSpec
"keyCol",
first(colInt("col1")),
f.first(f.col("col1")),
List(Some(1), Some(3))
List(Some(3), Some(1))
)
}

Expand All @@ -111,7 +111,7 @@ class AggregationColumnsSpec
"keyCol",
first(colInt("col1"), ignoreNulls = true),
f.first(f.col("col1"), ignoreNulls = true),
List(Some(5), Some(3))
List(Some(3), Some(5))
)
}

Expand All @@ -126,7 +126,7 @@ class AggregationColumnsSpec
"keyCol",
first(colInt("col1"), ignoreNulls = false),
f.first(f.col("col1"), ignoreNulls = false),
List(None, Some(3))
List(Some(3), None)
)
}
}
Expand All @@ -145,7 +145,7 @@ class AggregationColumnsSpec
"keyCol",
last(colInt("col1")),
f.last(f.col("col1")),
List(Some(5), Some(3))
List(Some(3), Some(5))
)
}

Expand All @@ -160,7 +160,7 @@ class AggregationColumnsSpec
"keyCol",
last(colInt("col1"), ignoreNulls = true),
f.last(f.col("col1"), ignoreNulls = true),
List(Some(1), Some(3))
List(Some(3), Some(1))
)
}

Expand All @@ -175,7 +175,7 @@ class AggregationColumnsSpec
"keyCol",
last(colInt("col1"), ignoreNulls = false),
f.last(f.col("col1"), ignoreNulls = false),
List(None, Some(3))
List(Some(3), None)
)
}
}
Expand All @@ -194,7 +194,7 @@ class AggregationColumnsSpec
"keyCol",
aproxCountDistinct("col1"),
f.approx_count_distinct("col1"),
List(Some(2L), Some(1L))
List(Some(1L), Some(2L))
)
}

Expand All @@ -211,7 +211,7 @@ class AggregationColumnsSpec
"keyCol",
aproxCountDistinct("col1", 0.05),
f.approx_count_distinct("col1", 0.05),
List(Some(2L), Some(1L))
List(Some(1L), Some(2L))
)
}

Expand All @@ -226,7 +226,7 @@ class AggregationColumnsSpec
"keyCol",
aproxCountDistinct(colInt("col1")),
f.approx_count_distinct(f.col("col1")),
List(Some(2L), Some(1L))
List(Some(1L), Some(2L))
)
}

Expand All @@ -243,7 +243,7 @@ class AggregationColumnsSpec
"keyCol",
aproxCountDistinct(colInt("col1"), 0.05),
f.approx_count_distinct(f.col("col1"), 0.05),
List(Some(2L), Some(1L))
List(Some(1L), Some(2L))
)
}
}
Expand Down Expand Up @@ -281,7 +281,7 @@ class AggregationColumnsSpec
"keyCol",
collectList(colInt("col1")),
f.collect_list(f.col("col1")),
List(Some(Array(1, 5)), Some(Array(3)))
List(Some(Array(3)), Some(Array(1, 5)))
)
}
}
Expand All @@ -301,7 +301,7 @@ class AggregationColumnsSpec
"keyCol",
collectSet(colInt("col1")),
f.collect_set(f.col("col1")),
List(Some(Array(1, 5)), Some(Array(3)))
List(Some(Array(3)), Some(Array(1, 5)))
)
}
}
Expand Down Expand Up @@ -341,7 +341,7 @@ class AggregationColumnsSpec
"keyCol",
countDistinct(colDouble("col1"), colString("col2")),
f.countDistinct(f.col("col1"), f.col("col2")),
List(Some(3L), Some(1L))
List(Some(1L), Some(3L))
)
}

Expand All @@ -356,7 +356,7 @@ class AggregationColumnsSpec
"keyCol",
countDistinct("col1", "col2"),
f.countDistinct("col1", "col2"),
List(Some(2L), Some(1L))
List(Some(1L), Some(2L))
)
}
}
Expand Down Expand Up @@ -432,7 +432,7 @@ class AggregationColumnsSpec
"keyCol",
max(colDouble("col1")),
f.max(f.col("col1")),
List(Some(4.0), Some(6.0))
List(Some(6.0), Some(4.0))
)
}
}
Expand All @@ -451,7 +451,7 @@ class AggregationColumnsSpec
"keyCol",
min(colDouble("col1")),
f.min(f.col("col1")),
List(Some(3.0), Some(6.0))
List(Some(6.0), Some(3.0))
)
}
}
Expand All @@ -470,7 +470,7 @@ class AggregationColumnsSpec
"keyCol",
mean(colDouble("col1")),
f.mean(f.col("col1")),
List(Some(3.5), Some(6.0))
List(Some(6.0), Some(3.5))
)
}
}
Expand All @@ -491,7 +491,7 @@ class AggregationColumnsSpec
"keyCol",
skewness(colDouble("col1")),
f.skewness(f.col("col1")),
List(Some(1.1135657469022011), None)
List(None, Some(1.1135657469022011))
)
}
}
Expand All @@ -511,7 +511,7 @@ class AggregationColumnsSpec
"keyCol",
stdDev(colDouble("col1")),
f.stddev(f.col("col1")),
List(Some(1.0), None)
List(None, Some(1.0))
)
}

Expand All @@ -527,7 +527,7 @@ class AggregationColumnsSpec
"keyCol",
stdDevSamp(colDouble("col1")),
f.stddev_samp(f.col("col1")),
List(Some(1.0), None)
List(None, Some(1.0))
)
}
}
Expand All @@ -547,7 +547,7 @@ class AggregationColumnsSpec
"keyCol",
stdDevPop(colDouble("col1")),
f.stddev_pop(f.col("col1")),
List(Some(0.816496580927726), Some(0.0))
List(Some(0.0), Some(0.816496580927726))
)
}
}
Expand All @@ -569,7 +569,7 @@ class AggregationColumnsSpec
new Column(
Sum(f.col("col1").expr).toAggregateExpression(isDistinct = true)
),
List(Some(4L), Some(6L))
List(Some(6L), Some(4L))
)
}

Expand All @@ -587,7 +587,7 @@ class AggregationColumnsSpec
new Column(
Sum(f.col("col1").expr).toAggregateExpression(isDistinct = true)
),
List(Some(4.0), Some(6.0))
List(Some(6.0), Some(4.0))
)
}
}
Expand All @@ -607,7 +607,7 @@ class AggregationColumnsSpec
"keyCol",
variance(colDouble("col1")),
f.variance(f.col("col1")),
List(Some(1.0), None)
List(None, Some(1.0))
)
}

Expand All @@ -623,7 +623,7 @@ class AggregationColumnsSpec
"keyCol",
varSamp(colDouble("col1")),
f.var_samp(f.col("col1")),
List(Some(1.0), None)
List(None, Some(1.0))
)
}
}
Expand All @@ -643,7 +643,7 @@ class AggregationColumnsSpec
"keyCol",
varPop(colDouble("col1")),
f.var_pop(f.col("col1")),
List(Some(0.6666666666666666), Some(0.0))
List(Some(0.0), Some(0.6666666666666666))
)
}
}
Expand Down
Loading

0 comments on commit e43dda5

Please sign in to comment.