Skip to content

Commit

Permalink
feat: [+] remaining non-aggregate functions (resolves hablapps#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
eruizalo committed Aug 1, 2022
1 parent 58b7a2a commit e11fdf6
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ CI/CD:
- .github/**/*.yml

dependencies:
- any: [ '.github/dependabot.yml', 'build.sbt', 'project/**/*', '.scala-steward.conf', '.scalafix.conf', '.scalafmt.conf' ]
- any: [ '.github/dependabot.yml', 'build.sbt', 'project/**/*', '.scala-steward.conf', '.scalafix.conf', '.scalafmt.conf', 'project/build.properties' ]

documentation:
- any: [ 'docs/**/*', 'notebooks/**/*', '*.md' ]
Expand Down
24 changes: 24 additions & 0 deletions core/src/main/scala/doric/syntax/NumericColumns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,21 @@ private[syntax] trait NumericColumns {
* @see [[org.apache.spark.sql.functions.tanh(e:org\.apache\.spark\.sql\.Column)* org.apache.spark.sql.functions.tanh]]
*/
def tanh: DoubleColumn = column.elem.map(f.tanh).toDC

/**
* Unary minus, i.e. negate the expression.
*
* @example {{{
* // Select the amount column and negates all values.
* // Scala:
* df.select( -df("amount") )
* }}}
*
* @todo DayTimeIntervalType & YearMonthIntervalType
* @group Numeric Type
* @see [[org.apache.spark.sql.functions.negate]]
*/
def negate: DoricColumn[T] = column.elem.map(f.negate).toDC
}

/**
Expand Down Expand Up @@ -572,6 +587,15 @@ private[syntax] trait NumericColumns {
def round(scale: IntegerColumn): DoricColumn[T] = (column.elem, scale.elem)
.mapN((c, s) => new Column(Round(c.expr, s.expr)))
.toDC

/**
* Returns col1 if it is not NaN, or col2 if col1 is NaN.
*
* @group Numeric Type
* @see [[org.apache.spark.sql.functions.nanvl]]
*/
def naNvl(col2: DoricColumn[T]): DoricColumn[T] =
(column.elem, col2.elem).mapN(f.nanvl).toDC
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package syntax

import cats.implicits._
import org.apache.spark.sql.Column
import org.apache.spark.sql.{functions => f}
import org.apache.spark.sql.catalyst.expressions.{ShiftLeft, ShiftRight, ShiftRightUnsigned}

private[syntax] trait NumericColumns32 {
Expand Down Expand Up @@ -46,6 +47,14 @@ private[syntax] trait NumericColumns32 {
(column.elem, numBits.elem)
.mapN((c, n) => new Column(ShiftRightUnsigned(c.expr, n.expr)))
.toDC

/**
* Computes bitwise NOT (~) of a number.
*
* @group Numeric Type
* @see [[org.apache.spark.sql.functions.bitwise_not]]
*/
def bitwiseNot: DoricColumn[T] = column.elem.map(f.bitwise_not).toDC
}

}
47 changes: 46 additions & 1 deletion core/src/test/scala/doric/syntax/NumericOperationsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package syntax

import doric.types.{NumericType, SparkType}
import doric.types.SparkType.Primitive
import org.apache.spark.sql.catalyst.expressions.{ShiftLeft, ShiftRight, ShiftRightUnsigned}
import org.apache.spark.sql.catalyst.expressions.{BitwiseNot, ShiftLeft, ShiftRight, ShiftRightUnsigned}
import org.apache.spark.sql.{Column, DataFrame, SparkSession, functions => f}
import org.scalatest.funspec.AnyFunSpecLike

Expand Down Expand Up @@ -333,6 +333,15 @@ trait NumericOperationsSpec
f.tanh
)
}

it(s"negate function $numTypeStr") {
testDoricSpark[T, T](
List(Some(-1), Some(1), Some(2), None),
List(Some(1), Some(-1), Some(2), None),
_.negate,
f.negate
)
}
}
}

Expand Down Expand Up @@ -434,6 +443,19 @@ trait NumericOperationsSpec
shiftRightUnsignedBefore32(_, f.lit(numBits))
)
}

it(s"bitwiseNot function $numTypeStr") {
// Aux function as it is deprecated since 3.2, otherwise specs would get complicated
val bitwiseNotBefore32: Column => Column =
col => new Column(BitwiseNot(col.expr))
val numBits = 2
testDoricSpark[T, T](
List(Some(0), Some(4), Some(20), None),
List(Some(0), Some(1), Some(5), None),
_.bitwiseNot,
bitwiseNotBefore32
)
}
}
}

Expand Down Expand Up @@ -510,6 +532,29 @@ trait NumericOperationsSpec
f.round(_, 2)
)
}

it(s"naNvl function with param $numTypeStr") {
testDoricSparkDecimals2[T, T, T](
List(
(Some(-1.466f), Some(-2.0f)),
(Some(0f), Some(0.7111f)),
(Some(1f / 0f), Some(1.0f)),
(None, Some(1.0f)),
(Some(1.0f), None),
(None, None)
),
List(
Some(-1.466f),
Some(0f),
Some(1.0f),
Some(1.0f),
Some(1.0f),
None
),
_.naNvl(_),
f.nanvl
)
}
}
}
}
Expand Down
26 changes: 26 additions & 0 deletions core/src/test/scala/doric/syntax/NumericUtilsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,30 @@ protected trait NumericUtilsSpec extends TypedColumnTest {
)
}

def testDoricSparkDecimals2[
T1: Primitive: ClassTag: TypeTag,
T2: Primitive: ClassTag: TypeTag,
O: Primitive: ClassTag: TypeTag
](
input: List[(Option[Float], Option[Float])],
output: List[Option[O]],
doricFun: (DoricColumn[T1], DoricColumn[T2]) => DoricColumn[O],
sparkFun: (Column, Column) => Column
)(implicit
spark: SparkSession,
funT1: FromFloat[T1],
funT2: FromFloat[T2]
): Unit = {
import spark.implicits._
val df = input
.map { case (x, y) => (x.map(funT1), y.map(funT2)) }
.toDF("col1", "col2")

df.testColumns2("col1", "col2")(
(c1, c2) => doricFun(col[T1](c1), col[T2](c2)),
(c1, c2) => sparkFun(f.col(c1), f.col(c2)),
output
)
}

}

0 comments on commit e11fdf6

Please sign in to comment.