Skip to content

Commit

Permalink
refactor: numeric types improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
eruizalo committed Dec 23, 2021
1 parent 3337ca1 commit 8b216dd
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 50 deletions.
15 changes: 4 additions & 11 deletions core/src/main/scala/doric/syntax/AggregationColumns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package doric
package syntax

import cats.implicits.{catsSyntaxTuple2Semigroupal, toTraverseOps}
import doric.types.{DoubleC, NumericDecimalsType, NumericIntegerType, NumericType}
import doric.types.{DoubleC, NumericType}
import org.apache.spark.sql.{functions => f}

private[syntax] trait AggregationColumns {
Expand All @@ -13,16 +13,9 @@ private[syntax] trait AggregationColumns {
* @group Aggregation Numeric Type
* @see [[org.apache.spark.sql.functions.sum(e:* org.apache.spark.sql.functions.sum]]
*/
def sum2Long[T: NumericIntegerType](col: DoricColumn[T]): LongColumn =
col.elem.map(f.sum).toDC

/**
* Aggregate function: returns the sum of all values in the expression.
*
* @group Aggregation Numeric Type
* @see [[org.apache.spark.sql.functions.sum(e:* org.apache.spark.sql.functions.sum]]
*/
def sum2Double[T: NumericDecimalsType](col: DoricColumn[T]): DoubleColumn =
def sum[T](col: DoricColumn[T])(implicit
nt: NumericType[T]
): DoricColumn[nt.Sum] =
col.elem.map(f.sum).toDC

/**
Expand Down
9 changes: 2 additions & 7 deletions core/src/main/scala/doric/syntax/NumericColumns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package syntax

import cats.implicits.catsSyntaxTuple2Semigroupal
import doric.DoricColumn.sparkFunction
import doric.types.{NumericDecimalsType, NumericType}
import doric.types.NumericType
import org.apache.spark.sql.Column
import org.apache.spark.sql.{functions => f}
import org.apache.spark.sql.catalyst.expressions.{FormatNumber, FromUnixTime, Rand, Randn}
Expand Down Expand Up @@ -179,19 +179,14 @@ private[syntax] trait NumericColumns {
def timestampSeconds: TimestampColumn =
column.elem.map(f.timestamp_seconds).toDC

}

implicit class NumericDecimalsOpsSyntax[T: NumericDecimalsType](
column: DoricColumn[T]
) {

/**
* Checks if the value of the column is not a number
* @group All Types
* @return
* Boolean DoricColumn
*/
def isNaN: BooleanColumn = column.elem.map(_.isNaN).toDC

}

implicit class LongOperationsSyntax(
Expand Down
9 changes: 0 additions & 9 deletions core/src/main/scala/doric/types/NumericDecimalsType.scala

This file was deleted.

9 changes: 0 additions & 9 deletions core/src/main/scala/doric/types/NumericIntegerType.scala

This file was deleted.

25 changes: 23 additions & 2 deletions core/src/main/scala/doric/types/NumericType.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,29 @@
package doric
package types

trait NumericType[T] extends NumericIntegerType[T] with NumericDecimalsType[T]
trait NumericType[T] {
type Sum
}

object NumericType {
def apply[T]: NumericType[T] = new NumericType[T] {}
implicit val intNumeric: NumericType[Int] {
type Sum = Long
} = new NumericType[Int] {
type Sum = Long
}
implicit val longNumeric: NumericType[Long] {
type Sum = Long
} = new NumericType[Long] {
type Sum = Long
}
implicit val floatNumeric: NumericType[Float] {
type Sum = Double
} = new NumericType[Float] {
type Sum = Double
}
implicit val doubleNumeric: NumericType[Double] {
type Sum = Double
} = new NumericType[Double] {
type Sum = Double
}
}
20 changes: 10 additions & 10 deletions core/src/test/scala/doric/sem/AggregationOpsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,48 @@ class AggregationOpsSpec extends DoricTestElements {

it("can use original spark aggregateFunctions") {
df.groupByCName(str)
.agg(colInt(num).pipe(sum2Long(_)) as sum1)
.agg(colInt(num).pipe(sum(_)) as sum1)
.validateColumnType(colLong(sum1))

assertThrows[DoricMultiError] {
df.groupByCName(str)
.agg(colLong(num).pipe(sum2Long(_)) as sum1)
.agg(colLong(num).pipe(sum(_)) as sum1)
}
}

it("groupBy") {
df.groupBy(concat(col(str), col(str)) as conc)
.agg(col[Int](num).pipe(sum2Long(_)) as sum1)
.agg(col[Int](num).pipe(sum(_)) as sum1)
.validateColumnType(colString(conc))
.validateColumnType(colLong(sum1))

assertThrows[DoricMultiError] {
df.groupBy(col[String](str2))
.agg(col[Int](num).pipe(sum2Long(_)) as sum1)
.agg(col[Int](num).pipe(sum(_)) as sum1)
}
}

it("cube") {
df.cube(concat(col(str), col(str)) as conc)
.agg(col[Int](num).pipe(sum2Long(_)) as sum1)
.agg(col[Int](num).pipe(sum(_)) as sum1)
.validateColumnType(colString(conc))
.validateColumnType(colLong(sum1))

assertThrows[DoricMultiError] {
df.cube(col[String](str2))
.agg(col[Int](num).pipe(sum2Long(_)) as sum1)
.agg(col[Int](num).pipe(sum(_)) as sum1)
}
}

it("rollup") {
df.rollup(concat(col(str), col(str)) as conc)
.agg(col[Int](num).pipe(sum2Long(_)) as sum1)
.agg(col[Int](num).pipe(sum(_)) as sum1)
.validateColumnType(colString(conc))
.validateColumnType(colLong(sum1))

assertThrows[DoricMultiError] {
df.rollup(col[String](str2))
.agg(col[Int](num).pipe(sum2Long(_)) as sum1)
.agg(col[Int](num).pipe(sum(_)) as sum1)
}
}

Expand All @@ -71,7 +71,7 @@ class AggregationOpsSpec extends DoricTestElements {
df.groupBy(concat(col(str), col(str)) as conc)
.pivot(colInt(num2))(List(1, 4))
.agg(
col[Int](num).pipe(sum2Long(_)) as sum1,
col[Int](num).pipe(sum(_)) as sum1,
col[Int](num).pipe(first(_)) as firstC
)
.validateColumnType(colString(conc))
Expand All @@ -83,7 +83,7 @@ class AggregationOpsSpec extends DoricTestElements {
assertThrows[DoricMultiError] {
df.groupBy(concat(col(str), col(str)) as conc)
.pivot(colString(num2))(List("1", "4"))
.agg(col[Int](num).pipe(sum2Long(_)) as sum1)
.agg(col[Int](num).pipe(sum(_)) as sum1)
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions core/src/test/scala/doric/syntax/AggregationColumnsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class AggregationColumnsSpec

df.testAggregation(
"keyCol",
sum2Long(colInt("col1")),
sum(colInt("col1")),
f.sum("col1"),
List(Some(6L), Some(3L))
)
Expand All @@ -40,7 +40,7 @@ class AggregationColumnsSpec

df.testAggregation(
"keyCol",
sum2Double(colFloat("col1")),
sum(colFloat("col1")),
f.sum("col1"),
List(Some(6.5), Some(3.0))
)
Expand Down

0 comments on commit 8b216dd

Please sign in to comment.