diff --git a/.scalafix.conf b/.scalafix.conf index 9fc172821..ac69e82e6 100644 --- a/.scalafix.conf +++ b/.scalafix.conf @@ -1,7 +1,7 @@ OrganizeImports { blankLines = Auto coalesceToWildcardImportThreshold = 5 - groups = ["habla.", "org.apache.spark.", "*"] + groups = ["doric.", "org.apache.spark.", "*"] groupedImports = Merge importSelectorsOrder = Ascii removeUnused = true diff --git a/build.sbt b/build.sbt index 791c37cc7..58e54d202 100644 --- a/build.sbt +++ b/build.sbt @@ -1,4 +1,32 @@ -import sbt.Compile +import sbt.{Compile, Def} + +val sparkDefaultShortVersion = "3.1" +val spark30Version = "3.0.3" +val spark31Version = "3.1.3" +val spark32Version = "3.2.1" + +val versionRegex = """^(.*)\.(.*)\.(.*)$""".r + +val scala212 = "2.12.15" +val scala213 = "2.13.8" + +val sparkShort: String => String = { + case "3.0" => spark30Version + case "3.1" => spark31Version + case "3.2" => spark32Version +} + +val sparkLong2ShortVersion: String => String = { + case versionRegex("3", "0", _) => "3.0" + case versionRegex("3", "1", _) => "3.1" + case versionRegex("3", "2", _) => "3.2" +} + +val scalaVersionSelect: String => String = { + case versionRegex("3", "0", _) => scala212 + case versionRegex("3", "1", _) => scala212 + case versionRegex("3", "2", _) => scala212 +} ThisBuild / organization := "org.hablapps" ThisBuild / homepage := Some(url("https://github.com/hablapps/doric")) @@ -13,14 +41,23 @@ ThisBuild / developers := List( url("https://github.com/alfonsorr") ), Developer( - "AlfonsoRR", + "eruizalo", "Eduardo Ruiz", "", url("https://github.com/eruizalo") ) ) - -Global / scalaVersion := "2.12.15" +val sparkVersion = settingKey[String]("Spark version") +Global / sparkVersion := + System.getProperty( + "sparkVersion", + sparkShort( + System.getProperty("sparkShortVersion", sparkDefaultShortVersion) + ) + ) +Global / scalaVersion := scalaVersionSelect(sparkVersion.value) +Global / publish / skip := true +Global / publishArtifact := false // scaladoc settings Compile / doc / scalacOptions ++= Seq("-groups") @@ -45,20 +82,32 @@ scmInfo := Some( updateOptions := updateOptions.value.withLatestSnapshots(false) -val sparkVersion = "3.1.3" +val configSpark = Seq( + sparkVersion := System.getProperty( + "sparkVersion", + sparkShort( + System.getProperty("sparkShortVersion", sparkDefaultShortVersion) + ) + ) +) + lazy val core = project .in(file("core")) .settings( - name := "doric", - run / fork := true, + configSpark, + name := "doric_" + sparkLong2ShortVersion(sparkVersion.value), + run / fork := true, + publish / skip := false, + publishArtifact := true, + scalaVersion := scalaVersionSelect(sparkVersion.value), libraryDependencies ++= Seq( - "org.apache.spark" %% "spark-sql" % sparkVersion % "provided", - "org.typelevel" %% "cats-core" % "2.7.0", - "com.lihaoyi" %% "sourcecode" % "0.2.8", - "io.monix" %% "newtypes-core" % "0.2.1", - "com.github.mrpowers" %% "spark-daria" % "1.2.3" % "test", - "com.github.mrpowers" %% "spark-fast-tests" % "1.2.0" % "test", - "org.scalatest" %% "scalatest" % "3.2.11" % "test" + "org.apache.spark" %% "spark-sql" % sparkVersion.value % "provided", + "org.typelevel" %% "cats-core" % "2.7.0", + "com.lihaoyi" %% "sourcecode" % "0.2.8", + "io.monix" %% "newtypes-core" % "0.2.1", + "com.github.mrpowers" %% "spark-daria" % "1.2.3" % "test", + "com.github.mrpowers" %% "spark-fast-tests" % "1.2.0" % "test", + "org.scalatest" %% "scalatest" % "3.2.11" % "test" ), // docs run / fork := true, @@ -68,23 +117,63 @@ lazy val core = project "-implicits", "-skip-packages", 
"org.apache.spark" - ) + ), + Compile / unmanagedSourceDirectories ++= { + sparkVersion.value match { + case versionRegex("3", "0", _) => + Seq( + (Compile / sourceDirectory)(_ / "spark_3.0_mount" / "scala"), + (Compile / sourceDirectory)(_ / "spark_3.0_3.1" / "scala") + ).join.value + case versionRegex("3", "1", _) => + Seq( + (Compile / sourceDirectory)(_ / "spark_3.0_3.1" / "scala"), + (Compile / sourceDirectory)(_ / "spark_3.1" / "scala"), + (Compile / sourceDirectory)(_ / "spark_3.1_mount" / "scala") + ).join.value + case versionRegex("3", "2", _) => + Seq( + (Compile / sourceDirectory)(_ / "spark_3.1" / "scala"), + (Compile / sourceDirectory)(_ / "spark_3.2" / "scala"), + (Compile / sourceDirectory)(_ / "spark_3.2_mount" / "scala") + ).join.value + } + }, + Test / unmanagedSourceDirectories ++= { + sparkVersion.value match { + case versionRegex("3", "0", _) => + Seq.empty[Def.Initialize[File]].join.value + case versionRegex("3", "1", _) => + Seq( + (Test / sourceDirectory)(_ / "spark_3.1" / "scala") + ).join.value + case versionRegex("3", "2", _) => + Seq( + (Test / sourceDirectory)(_ / "spark_3.1" / "scala"), + (Test / sourceDirectory)(_ / "spark_3.2" / "scala") + ).join.value + } + } ) lazy val docs = project .in(file("docs")) .dependsOn(core) .settings( - run / fork := true, + configSpark, + run / fork := true, + publish / skip := true, + publishArtifact := false, run / javaOptions += "-XX:MaxJavaStackTraceDepth=10", - mdocIn := baseDirectory.value / "docs", + scalaVersion := scalaVersionSelect(sparkVersion.value), + mdocIn := baseDirectory.value / "docs", libraryDependencies ++= Seq( - "org.apache.spark" %% "spark-sql" % sparkVersion + "org.apache.spark" %% "spark-sql" % sparkVersion.value ), mdocVariables := Map( "VERSION" -> version.value, "STABLE_VERSION" -> "0.0.2", - "SPARK_VERSION" -> sparkVersion + "SPARK_VERSION" -> sparkVersion.value ), mdocExtraArguments := Seq( "--clean-target" diff --git a/core/src/main/scala/doric/syntax/AggregationColumns.scala b/core/src/main/scala/doric/syntax/AggregationColumns.scala index 1e80fdbc5..a817a642f 100644 --- a/core/src/main/scala/doric/syntax/AggregationColumns.scala +++ b/core/src/main/scala/doric/syntax/AggregationColumns.scala @@ -2,8 +2,10 @@ package doric package syntax import cats.implicits.{catsSyntaxTuple2Semigroupal, toTraverseOps} -import doric.types.{DoubleC, NumericType} -import org.apache.spark.sql.{functions => f} +import doric.types.NumericType + +import org.apache.spark.sql.{Column, functions => f} +import org.apache.spark.sql.catalyst.expressions.aggregate.Sum private[syntax] trait AggregationColumns { @@ -252,70 +254,6 @@ private[syntax] trait AggregationColumns { def mean[T: NumericType](col: DoricColumn[T]): DoubleColumn = col.elem.map(f.mean).toDC - /** - * Aggregate function: returns the approximate `percentile` of the numeric column `col` which - * is the smallest value in the ordered `col` values (sorted from least to greatest) such that - * no more than `percentage` of `col` values is less than the value or equal to that value. - * - * @param percentage each value must be between 0.0 and 1.0. - * @param accuracy controls approximation accuracy at the cost of memory. Higher value of accuracy - * yields better accuracy, 1.0/accuracy is the relative error of the approximation. - * @note Support NumericType, DateType and TimestampType since their internal types are all numeric, - * and can be easily cast to double for processing. 
- * @group Aggregation DoubleC Type - * @see [[org.apache.spark.sql.functions.percentile_approx]] - */ - def percentileApprox[T: DoubleC]( - col: DoricColumn[T], - percentage: Array[Double], - accuracy: Int - ): ArrayColumn[T] = { - require( - percentage.forall(x => x >= 0.0 && x <= 1.0), - "Each value of percentage must be between 0.0 and 1.0." - ) - require( - accuracy >= 0 && accuracy < Int.MaxValue, - s"The accuracy provided must be a literal between (0, ${Int.MaxValue}]" + - s" (current value = $accuracy)" - ) - col.elem - .map(f.percentile_approx(_, f.lit(percentage), f.lit(accuracy))) - .toDC - } - - /** - * Aggregate function: returns the approximate `percentile` of the numeric column `col` which - * is the smallest value in the ordered `col` values (sorted from least to greatest) such that - * no more than `percentage` of `col` values is less than the value or equal to that value. - * - * @param percentage must be between 0.0 and 1.0. - * @param accuracy controls approximation accuracy at the cost of memory. Higher value of accuracy - * yields better accuracy, 1.0/accuracy is the relative error of the approximation. - * @note Support NumericType, DateType and TimestampType since their internal types are all numeric, - * and can be easily cast to double for processing. - * @group Aggregation DoubleC Type - * @see [[org.apache.spark.sql.functions.percentile_approx]] - */ - def percentileApprox[T: DoubleC]( - col: DoricColumn[T], - percentage: Double, - accuracy: Int - ): DoricColumn[T] = { - require( - percentage >= 0.0 && percentage <= 1.0, - "Percentage must be between 0.0 and 1.0." - ) - require( - accuracy >= 0 && accuracy < Int.MaxValue, - s"The accuracy provided must be a literal between (0, ${Int.MaxValue}]" + - s" (current value = $accuracy)" - ) - col.elem - .map(f.percentile_approx(_, f.lit(percentage), f.lit(accuracy))) - .toDC - } - /** * Aggregate function: returns the skewness of the values in a group. * @@ -361,7 +299,11 @@ private[syntax] trait AggregationColumns { def sumDistinct[T](col: DoricColumn[T])(implicit nt: NumericType[T] ): DoricColumn[nt.Sum] = - col.elem.map(f.sumDistinct).toDC + col.elem + .map(e => + new Column(Sum(e.expr).toAggregateExpression(isDistinct = true)) + ) + .toDC /** * Aggregate function: alias for `var_samp`. diff --git a/core/src/main/scala/doric/syntax/ArrayColumns.scala b/core/src/main/scala/doric/syntax/ArrayColumns.scala index e73b46e9f..a634adcdc 100644 --- a/core/src/main/scala/doric/syntax/ArrayColumns.scala +++ b/core/src/main/scala/doric/syntax/ArrayColumns.scala @@ -3,7 +3,8 @@ package syntax import cats.implicits._ import doric.types.CollectionType -import org.apache.spark.sql.{Column, Row, functions => f} + +import org.apache.spark.sql.{Column, functions => f} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.LambdaFunction.identity @@ -469,12 +470,14 @@ private[syntax] trait ArrayColumns { * end if `start` is negative) with the specified `length`. 
* * @note - * if `start` == 0 an exception will be thrown + * if `start` == 0 an exception will be thrown * @group Array Type * @see [[org.apache.spark.sql.functions.slice(x:org\.apache\.spark\.sql\.Column,start:org\.apache\.spark\.sql\.Column,length* org.apache.spark.sql.functions.slice]] */ def slice(start: IntegerColumn, length: IntegerColumn): ArrayColumn[T] = - (col.elem, start.elem, length.elem).mapN(f.slice).toDC + (col.elem, start.elem, length.elem) + .mapN((a, b, c) => new Column(Slice(a.expr, b.expr, c.expr))) + .toDC /** * Merge two given arrays, element-wise, into a single array using a function. diff --git a/core/src/main/scala/doric/syntax/BinaryColumns.scala b/core/src/main/scala/doric/syntax/BinaryColumns.scala index 1161ef55a..d5f87e1ab 100644 --- a/core/src/main/scala/doric/syntax/BinaryColumns.scala +++ b/core/src/main/scala/doric/syntax/BinaryColumns.scala @@ -1,10 +1,10 @@ package doric package syntax -import cats.implicits.{catsSyntaxTuple2Semigroupal, toTraverseOps} +import cats.implicits.toTraverseOps import doric.types.{BinaryType, SparkType} -import org.apache.spark.sql.catalyst.expressions.Decode -import org.apache.spark.sql.{Column, functions => f} + +import org.apache.spark.sql.{functions => f} private[syntax] trait BinaryColumns { @@ -76,21 +76,6 @@ private[syntax] trait BinaryColumns { * @see [[org.apache.spark.sql.functions.base64]] */ def base64: StringColumn = column.elem.map(f.base64).toDC - - /** - * Computes the first argument into a string from a binary using the provided character set - * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). - * If either argument is null, the result will also be null. - * - * @group Binary Type - * @see [[org.apache.spark.sql.functions.decode]] - */ - def decode(charset: StringColumn): StringColumn = - (column.elem, charset.elem) - .mapN((col, char) => { - new Column(Decode(col.expr, char.expr)) - }) - .toDC } } diff --git a/core/src/main/scala/doric/syntax/BooleanColumns.scala b/core/src/main/scala/doric/syntax/BooleanColumns.scala index e9c683661..6f1fb8b45 100644 --- a/core/src/main/scala/doric/syntax/BooleanColumns.scala +++ b/core/src/main/scala/doric/syntax/BooleanColumns.scala @@ -1,8 +1,8 @@ package doric package syntax -import cats.implicits._ import doric.DoricColumn.sparkFunction + import org.apache.spark.sql.{functions => f} private[syntax] trait BooleanColumns { @@ -61,24 +61,5 @@ private[syntax] trait BooleanColumns { */ def ||(other: DoricColumn[Boolean]): DoricColumn[Boolean] = or(other) - - /** - * Returns null if the condition is true, and throws an exception otherwise. - * - * @throws java.lang.RuntimeException if the condition is false - * @group Boolean Type - * @see [[org.apache.spark.sql.functions.assert_true(c:org\.apache\.spark\.sql\.Column):* org.apache.spark.sql.functions.assert_true]] - */ - def assertTrue: NullColumn = column.elem.map(f.assert_true).toDC - - /** - * Returns null if the condition is true; throws an exception with the error message otherwise. 
- * - * @throws java.lang.RuntimeException if the condition is false - * @group Boolean Type - * @see [[org.apache.spark.sql.functions.assert_true(c:org\.apache\.spark\.sql\.Column,e:* org.apache.spark.sql.functions.assert_true]] - */ - def assertTrue(msg: StringColumn): NullColumn = - (column.elem, msg.elem).mapN(f.assert_true).toDC } } diff --git a/core/src/main/scala/doric/syntax/NumericColumns.scala b/core/src/main/scala/doric/syntax/NumericColumns.scala index 9e50b6d11..75ac3b7e5 100644 --- a/core/src/main/scala/doric/syntax/NumericColumns.scala +++ b/core/src/main/scala/doric/syntax/NumericColumns.scala @@ -4,8 +4,8 @@ package syntax import cats.implicits._ import doric.DoricColumn.sparkFunction import doric.types.NumericType -import org.apache.spark.sql.Column -import org.apache.spark.sql.{functions => f} + +import org.apache.spark.sql.{Column, functions => f} import org.apache.spark.sql.catalyst.expressions.{FormatNumber, FromUnixTime, Rand, Randn} private[syntax] trait NumericColumns { @@ -14,7 +14,7 @@ private[syntax] trait NumericColumns { * Returns the current Unix timestamp (in seconds) as a long. * * @note All calls of `unix_timestamp` within the same query return the same value - * (i.e. the current timestamp is calculated at the start of query evaluation). + * (i.e. the current timestamp is calculated at the start of query evaluation). * * @group Numeric Type * @see [[org.apache.spark.sql.functions.unix_timestamp()* org.apache.spark.sql.functions.unix_timestamp]] @@ -170,15 +170,6 @@ private[syntax] trait NumericColumns { }) .toDC - /** - * Creates timestamp from the number of seconds since UTC epoch. - * - * @group Numeric Type - * @see [[org.apache.spark.sql.functions.timestamp_seconds]] - */ - def timestampSeconds: TimestampColumn = - column.elem.map(f.timestamp_seconds).toDC - /** * Checks if the value of the column is not a number * @group All Types diff --git a/core/src/main/scala/doric/syntax/StringColumns.scala b/core/src/main/scala/doric/syntax/StringColumns.scala index 5d12c93c9..724dd7c62 100644 --- a/core/src/main/scala/doric/syntax/StringColumns.scala +++ b/core/src/main/scala/doric/syntax/StringColumns.scala @@ -623,20 +623,5 @@ private[syntax] trait StringColumns { new Column(new ParseToTimestamp(str.expr, tsFormat.expr)) ) .toDC - - /** - * ******************************************************** - * MISC FUNCTIONS - * ******************************************************** - */ - - /** - * Throws an exception with the provided error message. 
- * - * @throws java.lang.RuntimeException with the error message - * @group String Type - * @see [[org.apache.spark.sql.functions.raise_error]] - */ - def raiseError: NullColumn = s.elem.map(f.raise_error).toDC } } diff --git a/core/src/main/spark_3.0_3.1/scala/doric/syntax/BinaryColumns30_31.scala b/core/src/main/spark_3.0_3.1/scala/doric/syntax/BinaryColumns30_31.scala new file mode 100644 index 000000000..453cb102e --- /dev/null +++ b/core/src/main/spark_3.0_3.1/scala/doric/syntax/BinaryColumns30_31.scala @@ -0,0 +1,32 @@ +package doric +package syntax + +import cats.implicits.catsSyntaxTuple2Semigroupal +import doric.types.{BinaryType, SparkType} + +import org.apache.spark.sql.catalyst.expressions.Decode +import org.apache.spark.sql.Column + +private[syntax] trait BinaryColumns30_31 { + + implicit class BinaryOperationsSyntax30_31[T: BinaryType: SparkType]( + column: DoricColumn[T] + ) { + + /** + * Computes the first argument into a string from a binary using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. + * + * @group Binary Type + * @see [[org.apache.spark.sql.functions.decode]] + */ + def decode(charset: StringColumn): StringColumn = + (column.elem, charset.elem) + .mapN((col, char) => { + new Column(Decode(col.expr, char.expr)) + }) + .toDC + } + +} diff --git a/core/src/main/scala/doric/syntax/All.scala b/core/src/main/spark_3.0_mount/scala/doric/syntax/All.scala similarity index 83% rename from core/src/main/scala/doric/syntax/All.scala rename to core/src/main/spark_3.0_mount/scala/doric/syntax/All.scala index 105293b88..498a6034d 100644 --- a/core/src/main/scala/doric/syntax/All.scala +++ b/core/src/main/spark_3.0_mount/scala/doric/syntax/All.scala @@ -1,8 +1,7 @@ -package doric -package syntax +package doric.syntax private[doric] trait All - extends ArrayColumns + extends ArrayColumns with TypeMatcher with CommonColumns with DStructs @@ -18,3 +17,4 @@ private[doric] trait All with CNameOps with BinaryColumns with Interpolators + with BinaryColumns30_31 diff --git a/core/src/main/spark_3.1/scala/doric/syntax/AggregationColumns31.scala b/core/src/main/spark_3.1/scala/doric/syntax/AggregationColumns31.scala new file mode 100644 index 000000000..f1dc6e724 --- /dev/null +++ b/core/src/main/spark_3.1/scala/doric/syntax/AggregationColumns31.scala @@ -0,0 +1,73 @@ +package doric +package syntax + +import doric.types.DoubleC + +import org.apache.spark.sql.{functions => f} + +private[syntax] trait AggregationColumns31 { + + /** + * Aggregate function: returns the approximate `percentile` of the numeric column `col` which + * is the smallest value in the ordered `col` values (sorted from least to greatest) such that + * no more than `percentage` of `col` values is less than the value or equal to that value. + * + * @param percentage each value must be between 0.0 and 1.0. + * @param accuracy controls approximation accuracy at the cost of memory. Higher value of accuracy + * yields better accuracy, 1.0/accuracy is the relative error of the approximation. + * @note Support NumericType, DateType and TimestampType since their internal types are all numeric, + * and can be easily cast to double for processing. 
+ * @group Aggregation DoubleC Type + * @see [[org.apache.spark.sql.functions.percentile_approx]] + */ + def percentileApprox[T: DoubleC]( + col: DoricColumn[T], + percentage: Array[Double], + accuracy: Int + ): ArrayColumn[T] = { + require( + percentage.forall(x => x >= 0.0 && x <= 1.0), + "Each value of percentage must be between 0.0 and 1.0." + ) + require( + accuracy >= 0 && accuracy < Int.MaxValue, + s"The accuracy provided must be a literal between (0, ${Int.MaxValue}]" + + s" (current value = $accuracy)" + ) + col.elem + .map(f.percentile_approx(_, f.lit(percentage), f.lit(accuracy))) + .toDC + } + + /** + * Aggregate function: returns the approximate `percentile` of the numeric column `col` which + * is the smallest value in the ordered `col` values (sorted from least to greatest) such that + * no more than `percentage` of `col` values is less than the value or equal to that value. + * + * @param percentage must be between 0.0 and 1.0. + * @param accuracy controls approximation accuracy at the cost of memory. Higher value of accuracy + * yields better accuracy, 1.0/accuracy is the relative error of the approximation. + * @note Support NumericType, DateType and TimestampType since their internal types are all numeric, + * and can be easily cast to double for processing. + * @group Aggregation DoubleC Type + * @see [[org.apache.spark.sql.functions.percentile_approx]] + */ + def percentileApprox[T: DoubleC]( + col: DoricColumn[T], + percentage: Double, + accuracy: Int + ): DoricColumn[T] = { + require( + percentage >= 0.0 && percentage <= 1.0, + "Percentage must be between 0.0 and 1.0." + ) + require( + accuracy >= 0 && accuracy < Int.MaxValue, + s"The accuracy provided must be a literal between (0, ${Int.MaxValue}]" + + s" (current value = $accuracy)" + ) + col.elem + .map(f.percentile_approx(_, f.lit(percentage), f.lit(accuracy))) + .toDC + } +} diff --git a/core/src/main/spark_3.1/scala/doric/syntax/BooleanColumns31.scala b/core/src/main/spark_3.1/scala/doric/syntax/BooleanColumns31.scala new file mode 100644 index 000000000..abfd30f4d --- /dev/null +++ b/core/src/main/spark_3.1/scala/doric/syntax/BooleanColumns31.scala @@ -0,0 +1,36 @@ +package doric +package syntax + +import cats.implicits._ + +import org.apache.spark.sql.{functions => f} + +private[syntax] trait BooleanColumns31 { + + /** + * @group Boolean Type + */ + implicit class BooleanOperationsSyntax31( + column: DoricColumn[Boolean] + ) { + + /** + * Returns null if the condition is true, and throws an exception otherwise. + * + * @throws java.lang.RuntimeException if the condition is false + * @group Boolean Type + * @see [[org.apache.spark.sql.functions.assert_true(c:org\.apache\.spark\.sql\.Column):* org.apache.spark.sql.functions.assert_true]] + */ + def assertTrue: NullColumn = column.elem.map(f.assert_true).toDC + + /** + * Returns null if the condition is true; throws an exception with the error message otherwise. 
+ * + * @throws java.lang.RuntimeException if the condition is false + * @group Boolean Type + * @see [[org.apache.spark.sql.functions.assert_true(c:org\.apache\.spark\.sql\.Column,e:* org.apache.spark.sql.functions.assert_true]] + */ + def assertTrue(msg: StringColumn): NullColumn = + (column.elem, msg.elem).mapN(f.assert_true).toDC + } +} diff --git a/core/src/main/spark_3.1/scala/doric/syntax/NumericColumns31.scala b/core/src/main/spark_3.1/scala/doric/syntax/NumericColumns31.scala new file mode 100644 index 000000000..35980b7a7 --- /dev/null +++ b/core/src/main/spark_3.1/scala/doric/syntax/NumericColumns31.scala @@ -0,0 +1,23 @@ +package doric +package syntax + +import doric.types.NumericType + +import org.apache.spark.sql.{functions => f} + +private[syntax] trait NumericColumns31 { + implicit class NumericOperationsSyntax31[T: NumericType]( + column: DoricColumn[T] + ) { + + /** + * Creates timestamp from the number of seconds since UTC epoch. + * + * @group Numeric Type + * @see [[org.apache.spark.sql.functions.timestamp_seconds]] + */ + def timestampSeconds: TimestampColumn = + column.elem.map(f.timestamp_seconds).toDC + } + +} diff --git a/core/src/main/spark_3.1/scala/doric/syntax/StringColumns31.scala b/core/src/main/spark_3.1/scala/doric/syntax/StringColumns31.scala new file mode 100644 index 000000000..c3b2182a0 --- /dev/null +++ b/core/src/main/spark_3.1/scala/doric/syntax/StringColumns31.scala @@ -0,0 +1,25 @@ +package doric +package syntax + +import org.apache.spark.sql.{functions => f} + +private[syntax] trait StringColumns31 { + + implicit class StringOperationsSyntax31(s: DoricColumn[String]) { + + /** + * ******************************************************** + * MISC FUNCTIONS + * ******************************************************** + */ + + /** + * Throws an exception with the provided error message. 
+ * + * @throws java.lang.RuntimeException with the error message + * @group String Type + * @see [[org.apache.spark.sql.functions.raise_error]] + */ + def raiseError: NullColumn = s.elem.map(f.raise_error).toDC + } +} diff --git a/core/src/main/spark_3.1_mount/scala/doric/syntax/All.scala b/core/src/main/spark_3.1_mount/scala/doric/syntax/All.scala new file mode 100644 index 000000000..40bbfa6fd --- /dev/null +++ b/core/src/main/spark_3.1_mount/scala/doric/syntax/All.scala @@ -0,0 +1,25 @@ +package doric +package syntax + +private[doric] trait All + extends ArrayColumns + with TypeMatcher + with CommonColumns + with DStructs + with LiteralConversions + with MapColumns + with NumericColumns + with DateColumns + with TimestampColumns + with BooleanColumns + with StringColumns + with ControlStructures + with AggregationColumns + with CNameOps + with BinaryColumns + with Interpolators + with AggregationColumns31 + with BooleanColumns31 + with NumericColumns31 + with StringColumns31 + with BinaryColumns30_31 diff --git a/core/src/main/spark_3.2/scala/doric/syntax/BinaryColumns32.scala b/core/src/main/spark_3.2/scala/doric/syntax/BinaryColumns32.scala new file mode 100644 index 000000000..2a19dfb07 --- /dev/null +++ b/core/src/main/spark_3.2/scala/doric/syntax/BinaryColumns32.scala @@ -0,0 +1,32 @@ +package doric +package syntax + +import cats.implicits.catsSyntaxTuple2Semigroupal +import doric.types.{BinaryType, SparkType} + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.StringDecode + +private[syntax] trait BinaryColumns32 { + + implicit class BinaryOperationsSyntax32[T: BinaryType : SparkType]( + column: DoricColumn[T] + ) { + + /** + * Computes the first argument into a string from a binary using the provided character set + * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + * If either argument is null, the result will also be null. 
+ * + * @group Binary Type + * @see [[org.apache.spark.sql.functions.decode]] + */ + def decode(charset: StringColumn): StringColumn = + (column.elem, charset.elem) + .mapN((col, char) => { + new Column(StringDecode(col.expr, char.expr)) + }) + .toDC + } + +} diff --git a/core/src/main/spark_3.2_mount/scala/doric/syntax/All.scala b/core/src/main/spark_3.2_mount/scala/doric/syntax/All.scala new file mode 100644 index 000000000..dcc3858ec --- /dev/null +++ b/core/src/main/spark_3.2_mount/scala/doric/syntax/All.scala @@ -0,0 +1,24 @@ +package doric.syntax + +private[doric] trait All + extends ArrayColumns + with TypeMatcher + with CommonColumns + with DStructs + with LiteralConversions + with MapColumns + with NumericColumns + with DateColumns + with TimestampColumns + with BooleanColumns + with StringColumns + with ControlStructures + with AggregationColumns + with CNameOps + with BinaryColumns + with Interpolators + with AggregationColumns31 + with BooleanColumns31 + with NumericColumns31 + with StringColumns31 + with BinaryColumns32 diff --git a/core/src/test/scala/doric/DoricColumnSpec.scala b/core/src/test/scala/doric/DoricColumnSpec.scala index 62a634d19..886246462 100644 --- a/core/src/test/scala/doric/DoricColumnSpec.scala +++ b/core/src/test/scala/doric/DoricColumnSpec.scala @@ -121,9 +121,9 @@ class DoricColumnSpec extends DoricTestElements with EitherValues { val error = dCol.elem.run(df).toEither.left.value.head error shouldBe an[SparkErrorWrapper] - error.getMessage should include( - "cannot resolve '`nonExistentCol`' given input columns" - ) + error.getMessage should include("cannot resolve '") + error.getMessage should include("nonExistentCol") + error.getMessage should include("' given input columns") } } diff --git a/core/src/test/scala/doric/TypedColumnTest.scala b/core/src/test/scala/doric/TypedColumnTest.scala index 841975184..c1506a721 100644 --- a/core/src/test/scala/doric/TypedColumnTest.scala +++ b/core/src/test/scala/doric/TypedColumnTest.scala @@ -1,15 +1,16 @@ package doric +import scala.reflect._ +import scala.reflect.runtime.universe._ + import com.github.mrpowers.spark.fast.tests.DatasetComparer import doric.implicitConversions.stringCname import doric.types.{Casting, SparkType} -import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame, Encoder, RelationalGroupedDataset, functions => f} import org.scalactic._ import org.scalatest.matchers.should.Matchers -import scala.reflect._ -import scala.reflect.runtime.universe._ +import org.apache.spark.sql.{Column, DataFrame, Encoder, RelationalGroupedDataset, functions => f} +import org.apache.spark.sql.types._ trait TypedColumnTest extends Matchers with DatasetComparer { @@ -70,8 +71,12 @@ trait TypedColumnTest extends Matchers with DatasetComparer { s"${if (expected.nonEmpty) s"\nExpected: $expected"}" ) - if (expected.nonEmpty) - doricColumns should contain theSameElementsAs expected + if (expected.nonEmpty) { + doricColumns.map { + case Some(x: java.lang.Double) if x.isNaN => None + case x => x + } should contain theSameElementsAs expected + } } implicit class ValidateColumnGroupType(gr: RelationalGroupedDataset) { diff --git a/core/src/test/scala/doric/sem/JoinOpsSpec.scala b/core/src/test/scala/doric/sem/JoinOpsSpec.scala index f5cb48cf5..8f1b4a3f7 100644 --- a/core/src/test/scala/doric/sem/JoinOpsSpec.scala +++ b/core/src/test/scala/doric/sem/JoinOpsSpec.scala @@ -67,7 +67,9 @@ class JoinOpsSpec extends DoricTestElements with Matchers with EitherValues { errors.errors.toChain .get(1) .get 
- .message shouldBe "Cannot resolve column name \"" + id + "entifier\" among (" + id + ", " + otherColumn + ")" + .message should startWith( + "Cannot resolve column name \"" + id + "entifier\" among (" + id + ", " + otherColumn + ")" + ) val joinFunction2: DoricJoinColumn = LeftDF(colLong(id)) === RightDF.colLong(id) @@ -87,7 +89,9 @@ class JoinOpsSpec extends DoricTestElements with Matchers with EitherValues { errors2.errors.toChain .get(1) .get - .message shouldBe "Cannot resolve column name \"" + id + "entifier\" among (" + id + ", " + otherColumn + ")" + .message should startWith( + "Cannot resolve column name \"" + id + "entifier\" among (" + id + ", " + otherColumn + ")" + ) } it("should prevent key ambiguity with innerJoinDropRightKey") { diff --git a/core/src/test/scala/doric/sem/TransformOpsSpec.scala b/core/src/test/scala/doric/sem/TransformOpsSpec.scala index 211190b23..69065d7a7 100644 --- a/core/src/test/scala/doric/sem/TransformOpsSpec.scala +++ b/core/src/test/scala/doric/sem/TransformOpsSpec.scala @@ -119,7 +119,9 @@ class TransformOpsSpec "b" -> colLong("id") ) } - error.getMessage shouldBe "Found duplicate column(s) in given column names: `a`, `b`" + error.getMessage should startWith( + "Found duplicate column(s) in given column names: `a`, `b`" + ) } } } diff --git a/core/src/test/scala/doric/syntax/AggregationColumnsSpec.scala b/core/src/test/scala/doric/syntax/AggregationColumnsSpec.scala index 4bafe2a91..f91fdeeb6 100644 --- a/core/src/test/scala/doric/syntax/AggregationColumnsSpec.scala +++ b/core/src/test/scala/doric/syntax/AggregationColumnsSpec.scala @@ -1,13 +1,13 @@ package doric package syntax +import doric.implicitConversions.stringCname +import doric.Equalities._ import org.scalatest.EitherValues import org.scalatest.matchers.should.Matchers -import org.apache.spark.sql.{functions => f} -import doric.implicitConversions.stringCname -import Equalities._ -import java.sql.Date +import org.apache.spark.sql.{Column, functions => f} +import org.apache.spark.sql.catalyst.expressions.aggregate.Sum class AggregationColumnsSpec extends DoricTestElements @@ -475,149 +475,6 @@ class AggregationColumnsSpec } } - describe("percentileApprox doric function") { - import spark.implicits._ - val df = List( - ("k1", 0.0), - ("k1", 1.0), - ("k1", 2.0), - ("k1", 10.0) - ).toDF("keyCol", "col1") - val dfDate = List( - ("k1", Date.valueOf("2021-12-05")), - ("k1", Date.valueOf("2021-12-06")), - ("k1", Date.valueOf("2021-12-07")), - ("k1", Date.valueOf("2021-12-01")) - ).toDF("keyCol", "col1") - - it( - "should work as spark percentile_approx function working with percentile array & double type" - ) { - val percentage = Array(0.5, 0.4, 0.1) - val accuracy = 100 - df.testAggregation( - "keyCol", - percentileApprox( - colDouble("col1"), - percentage, - accuracy - ), - f.percentile_approx( - f.col("col1"), - f.lit(percentage), - f.lit(accuracy) - ), - List(Some(Array(1.0, 1.0, 0.0))) - ) - } - - it( - "should work as spark percentile_approx function working with percentile array & date type" - ) { - val percentage = Array(0.5, 0.4, 0.1) - val accuracy = 100 - - dfDate.testAggregation( - "keyCol", - percentileApprox( - colDate("col1"), - percentage, - accuracy - ), - f.percentile_approx( - f.col("col1"), - f.lit(percentage), - f.lit(accuracy) - ), - List( - Some( - Array( - Date.valueOf("2021-12-05"), - Date.valueOf("2021-12-05"), - Date.valueOf("2021-12-01") - ) - ) - ) - ) - } - - it("should throw an exception if percentile array & wrong percentile") { - val msg = 
intercept[java.lang.IllegalArgumentException] { - df.select( - percentileApprox(colDouble("col1"), Array(-0.5, 0.4, 0.1), 100) - ).collect() - } - - msg.getMessage shouldBe "requirement failed: Each value of percentage must be between 0.0 and 1.0." - } - - it("should throw an exception if percentile array & wrong accuracy") { - val msg = intercept[java.lang.IllegalArgumentException] { - df.select( - percentileApprox(colDouble("col1"), Array(0.5, 0.4, 0.1), -1) - ).collect() - } - - msg.getMessage shouldBe s"requirement failed: The accuracy provided must be a literal between (0, ${Int.MaxValue}] (current value = -1)" - } - - it( - "should work as spark percentile_approx function working with percentile double & double type" - ) { - val percentage = 0.5 - val accuracy = 100 - - df.testAggregation( - "keyCol", - percentileApprox(colDouble("col1"), percentage, accuracy), - f.percentile_approx(f.col("col1"), f.lit(percentage), f.lit(accuracy)), - List(Some(1.0)) - ) - } - - it( - "should work as spark percentile_approx function working percentile double & date type" - ) { - val percentage = 0.5 - val accuracy = 100 - - dfDate.testAggregation( - "keyCol", - percentileApprox( - colDate("col1"), - percentage, - accuracy - ), - f.percentile_approx( - f.col("col1"), - f.lit(percentage), - f.lit(accuracy) - ), - List(Some(Date.valueOf("2021-12-05"))) - ) - } - - it("should throw an exception if percentile double & wrong percentile") { - val msg = intercept[java.lang.IllegalArgumentException] { - df.select( - percentileApprox(colDouble("col1"), -0.5, 100) - ).collect() - } - - msg.getMessage shouldBe "requirement failed: Percentage must be between 0.0 and 1.0." - } - - it("should throw an exception if percentile double & wrong accuracy") { - val msg = intercept[java.lang.IllegalArgumentException] { - df.select( - percentileApprox(colDouble("col1"), 0.5, -1) - ).collect() - } - - msg.getMessage shouldBe s"requirement failed: The accuracy provided must be a literal between (0, ${Int.MaxValue}] (current value = -1)" - } - } - describe("skewness doric function") { import spark.implicits._ @@ -709,7 +566,9 @@ class AggregationColumnsSpec df.testAggregation( "keyCol", sumDistinct(colLong("col1")), - f.sumDistinct(f.col("col1")), + new Column( + Sum(f.col("col1").expr).toAggregateExpression(isDistinct = true) + ), List(Some(4L), Some(6L)) ) } @@ -725,7 +584,9 @@ class AggregationColumnsSpec df.testAggregation( "keyCol", sumDistinct(colDouble("col1")), - f.sumDistinct(f.col("col1")), + new Column( + Sum(f.col("col1").expr).toAggregateExpression(isDistinct = true) + ), List(Some(4.0), Some(6.0)) ) } diff --git a/core/src/test/scala/doric/syntax/ArrayColumnsSpec.scala b/core/src/test/scala/doric/syntax/ArrayColumnsSpec.scala index a59394c81..25bd14956 100644 --- a/core/src/test/scala/doric/syntax/ArrayColumnsSpec.scala +++ b/core/src/test/scala/doric/syntax/ArrayColumnsSpec.scala @@ -1,10 +1,11 @@ package doric package syntax -import org.apache.spark.sql.{functions => f} import org.scalatest.EitherValues import org.scalatest.matchers.should.Matchers +import org.apache.spark.sql.{functions => f} + class ArrayColumnsSpec extends DoricTestElements with EitherValues @@ -47,7 +48,9 @@ class ArrayColumnsSpec .left .value .head - .message shouldBe "Cannot resolve column name \"something2\" among (col, something)" + .message should startWith( + "Cannot resolve column name \"something2\" among (col, something)" + ) colArrayInt("col") .transform(_ => colString("something")) @@ -93,7 +96,9 @@ class ArrayColumnsSpec .left 
.value .head - .message shouldBe "Cannot resolve column name \"something2\" among (col, something)" + .message should startWith( + "Cannot resolve column name \"something2\" among (col, something)" + ) } it( @@ -122,8 +127,9 @@ class ArrayColumnsSpec .value errors.toChain.size shouldBe 2 + val end = if (spark.version.take(3) <= "3.0") ";" else "" errors.map(_.message).toChain.toList shouldBe List( - "Cannot resolve column name \"something2\" among (col, something)", + "Cannot resolve column name \"something2\" among (col, something)" + end, "The column with name 'something' is of type StringType and it was expected to be IntegerType" ) } @@ -158,10 +164,11 @@ class ArrayColumnsSpec .value errors.toChain.size shouldBe 3 + val end = if (spark.version.take(3) <= "3.0") ";" else "" errors.map(_.message).toChain.toList shouldBe List( - "Cannot resolve column name \"something2\" among (col, something)", + "Cannot resolve column name \"something2\" among (col, something)" + end, "The column with name 'something' is of type StringType and it was expected to be IntegerType", - "Cannot resolve column name \"something3\" among (col, something)" + "Cannot resolve column name \"something3\" among (col, something)" + end ) } diff --git a/core/src/test/scala/doric/syntax/AsSpec.scala b/core/src/test/scala/doric/syntax/AsSpec.scala index 4b150c29d..bb5e1d0d7 100644 --- a/core/src/test/scala/doric/syntax/AsSpec.scala +++ b/core/src/test/scala/doric/syntax/AsSpec.scala @@ -27,9 +27,9 @@ class AsSpec extends DoricTestElements with EitherValues with Matchers { val originalColumn = sparkCol("error").asDoric[Int] val errors = originalColumn.elem.run(df).toEither.left.value errors.length shouldBe 1 - errors.head.message.take( - 57 - ) shouldBe "cannot resolve '`error`' given input columns: [int, str];" + val errorMessage = errors.head.message.take(57) + errorMessage should startWith("cannot resolve") + errorMessage should include("given input columns: [int, str];") errors.head.location.fileName.value shouldBe "AsSpec.scala" } diff --git a/core/src/test/scala/doric/syntax/BinaryColumnsSpec.scala b/core/src/test/scala/doric/syntax/BinaryColumnsSpec.scala index e87af3ab6..e9568ab2d 100644 --- a/core/src/test/scala/doric/syntax/BinaryColumnsSpec.scala +++ b/core/src/test/scala/doric/syntax/BinaryColumnsSpec.scala @@ -1,11 +1,11 @@ package doric package syntax -import doric.implicitConversions.stringCname -import org.apache.spark.sql.{functions => f} import org.scalatest.EitherValues import org.scalatest.matchers.should.Matchers +import org.apache.spark.sql.{functions => f} + class BinaryColumnsSpec extends DoricTestElements with EitherValues diff --git a/core/src/test/scala/doric/syntax/BooleanColumnsSpec.scala b/core/src/test/scala/doric/syntax/BooleanColumnsSpec.scala index 560187075..6258f603d 100644 --- a/core/src/test/scala/doric/syntax/BooleanColumnsSpec.scala +++ b/core/src/test/scala/doric/syntax/BooleanColumnsSpec.scala @@ -6,8 +6,6 @@ import org.scalatest.matchers.should.Matchers import org.apache.spark.sql.{functions => f} -import doric.implicitConversions.stringCname - class BooleanColumnsSpec extends DoricTestElements with EitherValues @@ -91,45 +89,4 @@ class BooleanColumnsSpec } } - describe("assertTrue doric function") { - import spark.implicits._ - - it("should do nothing if assertion is true") { - val df = Seq(true, true, true) - .toDF("col1") - - df.testColumns("col1")( - c => colBoolean(c).assertTrue, - c => f.assert_true(f.col(c)), - List(None, None, None) - ) - } - - it("should throw an 
exception if assertion is false") { - val df = Seq(true, false, false) - .toDF("col1") - - intercept[java.lang.RuntimeException] { - df.select( - colBoolean("col1").assertTrue - ).collect() - } - } - - it("should throw an exception if assertion is false with a message") { - val df = Seq(true, false, false) - .toDF("col1") - - val errorMessage = "this is an error message" - - val exception = intercept[java.lang.RuntimeException] { - df.select( - colBoolean("col1").assertTrue(errorMessage.lit) - ).collect() - } - - exception.getMessage shouldBe errorMessage - } - } - } diff --git a/core/src/test/scala/doric/syntax/NumericOperationsSpec.scala b/core/src/test/scala/doric/syntax/NumericOperationsSpec.scala index 71fd0fa17..501917e73 100644 --- a/core/src/test/scala/doric/syntax/NumericOperationsSpec.scala +++ b/core/src/test/scala/doric/syntax/NumericOperationsSpec.scala @@ -1,15 +1,12 @@ package doric package syntax -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} + import doric.types.{NumericType, SparkType} import org.scalatest.funspec.AnyFunSpecLike -import org.apache.spark.sql.{functions => f} -import org.apache.spark.sql.DataFrame - -import java.sql.Timestamp -import doric.implicitConversions.stringCname +import org.apache.spark.sql.{DataFrame, functions => f} trait NumericOperationsSpec extends AnyFunSpecLike with TypedColumnTest { @@ -241,55 +238,6 @@ class NumericSpec extends NumericOperationsSpec with SparkSessionTestWrapper { } } - describe("timestampSeconds doric function") { - import spark.implicits._ - - it("should work as spark timestamp_seconds function with integers") { - val df = List(Some(123), Some(1), None) - .toDF("col1") - - df.testColumns("col1")( - c => colInt(c).timestampSeconds, - c => f.timestamp_seconds(f.col(c)), - List( - Some(Timestamp.valueOf("1970-01-01 00:02:03")), - Some(Timestamp.valueOf("1970-01-01 00:00:01")), - None - ) - ) - } - - it("should work as spark timestamp_seconds function with longs") { - val df = List(Some(123L), Some(1L), None) - .toDF("col1") - - df.testColumns("col1")( - c => colLong(c).timestampSeconds, - c => f.timestamp_seconds(f.col(c)), - List( - Some(Timestamp.valueOf("1970-01-01 00:02:03")), - Some(Timestamp.valueOf("1970-01-01 00:00:01")), - None - ) - ) - } - - it("should work as spark timestamp_seconds function with doubles") { - val df = List(Some(123.2), Some(1.9), None) - .toDF("col1") - - df.testColumns("col1")( - c => colDouble(c).timestampSeconds, - c => f.timestamp_seconds(f.col(c)), - List( - Some(Timestamp.valueOf("1970-01-01 00:02:03.2")), - Some(Timestamp.valueOf("1970-01-01 00:00:01.9")), - None - ) - ) - } - } - describe("fromUnixTime doric function") { import spark.implicits._ @@ -315,14 +263,15 @@ class NumericSpec extends NumericOperationsSpec with SparkSessionTestWrapper { ) } - it("should fail if wrong pattern is given") { - val df = List(Some(123L), Some(1L), None) - .toDF("col1") - - intercept[java.lang.IllegalArgumentException]( - df.select(colLong("col1").fromUnixTime("wrong pattern".lit)) - .collect() - ) + if (spark.version.take(3) > "3.0") { + it("should fail if wrong pattern is given") { + val df = List(Some(123L), Some(1L), None) + .toDF("col1") + intercept[java.lang.IllegalArgumentException]( + df.select(colLong("col1").fromUnixTime("wrong pattern".lit)) + .collect() + ) + } } } diff --git a/core/src/test/scala/doric/syntax/StringColumnsSpec.scala b/core/src/test/scala/doric/syntax/StringColumnsSpec.scala index babfeffc8..bc32e2262 100644 --- 
a/core/src/test/scala/doric/syntax/StringColumnsSpec.scala +++ b/core/src/test/scala/doric/syntax/StringColumnsSpec.scala @@ -7,7 +7,6 @@ import org.scalatest.EitherValues import org.scalatest.matchers.should.Matchers import org.apache.spark.sql.{functions => f} -import org.apache.spark.sql.types.NullType class StringColumnsSpec extends DoricTestElements @@ -843,28 +842,6 @@ class StringColumnsSpec } } - describe("raiseError doric function") { - import spark.implicits._ - - val df = List("this is an error").toDF("errorMsg") - - it("should work as spark raise_error function") { - import java.lang.{RuntimeException => exception} - - val doricErr = intercept[exception] { - val res = df.select(colString("errorMsg").raiseError) - - res.schema.head.dataType shouldBe NullType - res.collect() - } - val sparkErr = intercept[exception] { - df.select(f.raise_error(f.col("errorMsg"))).collect() - } - - doricErr.getMessage shouldBe sparkErr.getMessage - } - } - describe("encode doric function") { import spark.implicits._ @@ -928,11 +905,13 @@ class StringColumnsSpec ) } - it("should fail if malformed format") { - intercept[java.lang.IllegalArgumentException]( - df.select(colString("dateCol").unixTimestamp("yabcd".lit)) - .collect() - ) + if (spark.version.take(3) > "3.0") { + it("should fail if malformed format") { + intercept[java.lang.IllegalArgumentException]( + df.select(colString("dateCol").unixTimestamp("yabcd".lit)) + .collect() + ) + } } } diff --git a/core/src/test/scala/doric/syntax/TypeMatcherSpec.scala b/core/src/test/scala/doric/syntax/TypeMatcherSpec.scala index 96159799d..b6a13b5f4 100644 --- a/core/src/test/scala/doric/syntax/TypeMatcherSpec.scala +++ b/core/src/test/scala/doric/syntax/TypeMatcherSpec.scala @@ -63,7 +63,9 @@ class TypeMatcherSpec val errors = testColumn.elem.run(df).toEither.left.value errors.length shouldBe 1 - errors.head.message shouldBe "Cannot resolve column name \"int2\" among (colArr, int, str)" + errors.head.message should startWith( + "Cannot resolve column name \"int2\" among (colArr, int, str)" + ) } it( @@ -76,7 +78,9 @@ class TypeMatcherSpec val errors = testColumn.elem.run(df).toEither.left.value errors.length shouldBe 1 - errors.head.message shouldBe "Cannot resolve column name \"int3\" among (colArr, int, str)" + errors.head.message should startWith( + "Cannot resolve column name \"int3\" among (colArr, int, str)" + ) } it( diff --git a/core/src/test/spark_3.1/scala/doric/syntax/AggregationColumns31Spec.scala b/core/src/test/spark_3.1/scala/doric/syntax/AggregationColumns31Spec.scala new file mode 100644 index 000000000..b2d8a5157 --- /dev/null +++ b/core/src/test/spark_3.1/scala/doric/syntax/AggregationColumns31Spec.scala @@ -0,0 +1,160 @@ +package doric +package syntax + +import doric.implicitConversions.stringCname +import doric.Equalities._ +import java.sql.Date +import org.scalatest.EitherValues +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.sql.{functions => f} + +class AggregationColumns31Spec + extends DoricTestElements + with EitherValues + with Matchers { + + describe("percentileApprox doric function") { + import spark.implicits._ + val df = List( + ("k1", 0.0), + ("k1", 1.0), + ("k1", 2.0), + ("k1", 10.0) + ).toDF("keyCol", "col1") + val dfDate = List( + ("k1", Date.valueOf("2021-12-05")), + ("k1", Date.valueOf("2021-12-06")), + ("k1", Date.valueOf("2021-12-07")), + ("k1", Date.valueOf("2021-12-01")) + ).toDF("keyCol", "col1") + + it( + "should work as spark percentile_approx function working with percentile array & 
double type" + ) { + val percentage = Array(0.5, 0.4, 0.1) + val accuracy = 100 + df.testAggregation( + "keyCol", + percentileApprox( + colDouble("col1"), + percentage, + accuracy + ), + f.percentile_approx( + f.col("col1"), + f.lit(percentage), + f.lit(accuracy) + ), + List(Some(Array(1.0, 1.0, 0.0))) + ) + } + + it( + "should work as spark percentile_approx function working with percentile array & date type" + ) { + val percentage = Array(0.5, 0.4, 0.1) + val accuracy = 100 + + dfDate.testAggregation( + "keyCol", + percentileApprox( + colDate("col1"), + percentage, + accuracy + ), + f.percentile_approx( + f.col("col1"), + f.lit(percentage), + f.lit(accuracy) + ), + List( + Some( + Array( + Date.valueOf("2021-12-05"), + Date.valueOf("2021-12-05"), + Date.valueOf("2021-12-01") + ) + ) + ) + ) + } + + it("should throw an exception if percentile array & wrong percentile") { + val msg = intercept[java.lang.IllegalArgumentException] { + df.select( + percentileApprox(colDouble("col1"), Array(-0.5, 0.4, 0.1), 100) + ).collect() + } + + msg.getMessage shouldBe "requirement failed: Each value of percentage must be between 0.0 and 1.0." + } + + it("should throw an exception if percentile array & wrong accuracy") { + val msg = intercept[java.lang.IllegalArgumentException] { + df.select( + percentileApprox(colDouble("col1"), Array(0.5, 0.4, 0.1), -1) + ).collect() + } + + msg.getMessage shouldBe s"requirement failed: The accuracy provided must be a literal between (0, ${Int.MaxValue}] (current value = -1)" + } + + it( + "should work as spark percentile_approx function working with percentile double & double type" + ) { + val percentage = 0.5 + val accuracy = 100 + + df.testAggregation( + "keyCol", + percentileApprox(colDouble("col1"), percentage, accuracy), + f.percentile_approx(f.col("col1"), f.lit(percentage), f.lit(accuracy)), + List(Some(1.0)) + ) + } + + it( + "should work as spark percentile_approx function working percentile double & date type" + ) { + val percentage = 0.5 + val accuracy = 100 + + dfDate.testAggregation( + "keyCol", + percentileApprox( + colDate("col1"), + percentage, + accuracy + ), + f.percentile_approx( + f.col("col1"), + f.lit(percentage), + f.lit(accuracy) + ), + List(Some(Date.valueOf("2021-12-05"))) + ) + } + + it("should throw an exception if percentile double & wrong percentile") { + val msg = intercept[java.lang.IllegalArgumentException] { + df.select( + percentileApprox(colDouble("col1"), -0.5, 100) + ).collect() + } + + msg.getMessage shouldBe "requirement failed: Percentage must be between 0.0 and 1.0." 
+ } + + it("should throw an exception if percentile double & wrong accuracy") { + val msg = intercept[java.lang.IllegalArgumentException] { + df.select( + percentileApprox(colDouble("col1"), 0.5, -1) + ).collect() + } + + msg.getMessage shouldBe s"requirement failed: The accuracy provided must be a literal between (0, ${Int.MaxValue}] (current value = -1)" + } + } + +} diff --git a/core/src/test/spark_3.1/scala/doric/syntax/BooleanColumns31Spec.scala b/core/src/test/spark_3.1/scala/doric/syntax/BooleanColumns31Spec.scala new file mode 100644 index 000000000..270954661 --- /dev/null +++ b/core/src/test/spark_3.1/scala/doric/syntax/BooleanColumns31Spec.scala @@ -0,0 +1,55 @@ +package doric +package syntax + +import org.scalatest.EitherValues +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.sql.{functions => f} + +class BooleanColumns31Spec + extends DoricTestElements + with EitherValues + with Matchers { + + describe("assertTrue doric function") { + import spark.implicits._ + + it("should do nothing if assertion is true") { + val df = Seq(true, true, true) + .toDF("col1") + + df.testColumns("col1")( + c => colBoolean(c).assertTrue, + c => f.assert_true(f.col(c)), + List(None, None, None) + ) + } + + it("should throw an exception if assertion is false") { + val df = Seq(true, false, false) + .toDF("col1") + + intercept[java.lang.RuntimeException] { + df.select( + colBoolean("col1").assertTrue + ).collect() + } + } + + it("should throw an exception if assertion is false with a message") { + val df = Seq(true, false, false) + .toDF("col1") + + val errorMessage = "this is an error message" + + val exception = intercept[java.lang.RuntimeException] { + df.select( + colBoolean("col1").assertTrue(errorMessage.lit) + ).collect() + } + + exception.getMessage shouldBe errorMessage + } + } + +} diff --git a/core/src/test/spark_3.1/scala/doric/syntax/Numeric31Spec.scala b/core/src/test/spark_3.1/scala/doric/syntax/Numeric31Spec.scala new file mode 100644 index 000000000..412470af8 --- /dev/null +++ b/core/src/test/spark_3.1/scala/doric/syntax/Numeric31Spec.scala @@ -0,0 +1,63 @@ +package doric +package syntax + +import java.sql.Timestamp +import org.scalatest.funspec.AnyFunSpecLike + +import org.apache.spark.sql.{functions => f} + +class Numeric31Spec + extends SparkSessionTestWrapper + with AnyFunSpecLike + with TypedColumnTest { + + describe("timestampSeconds doric function") { + import spark.implicits._ + + it("should work as spark timestamp_seconds function with integers") { + val df = List(Some(123), Some(1), None) + .toDF("col1") + + df.testColumns("col1")( + c => colInt(c).timestampSeconds, + c => f.timestamp_seconds(f.col(c)), + List( + Some(Timestamp.valueOf("1970-01-01 00:02:03")), + Some(Timestamp.valueOf("1970-01-01 00:00:01")), + None + ) + ) + } + + it("should work as spark timestamp_seconds function with longs") { + val df = List(Some(123L), Some(1L), None) + .toDF("col1") + + df.testColumns("col1")( + c => colLong(c).timestampSeconds, + c => f.timestamp_seconds(f.col(c)), + List( + Some(Timestamp.valueOf("1970-01-01 00:02:03")), + Some(Timestamp.valueOf("1970-01-01 00:00:01")), + None + ) + ) + } + + it("should work as spark timestamp_seconds function with doubles") { + val df = List(Some(123.2), Some(1.9), None) + .toDF("col1") + + df.testColumns("col1")( + c => colDouble(c).timestampSeconds, + c => f.timestamp_seconds(f.col(c)), + List( + Some(Timestamp.valueOf("1970-01-01 00:02:03.2")), + Some(Timestamp.valueOf("1970-01-01 00:00:01.9")), + None + ) + ) + } + } 
+ +} diff --git a/core/src/test/spark_3.1/scala/doric/syntax/StringColumns31Spec.scala b/core/src/test/spark_3.1/scala/doric/syntax/StringColumns31Spec.scala new file mode 100644 index 000000000..4c72d13ab --- /dev/null +++ b/core/src/test/spark_3.1/scala/doric/syntax/StringColumns31Spec.scala @@ -0,0 +1,37 @@ +package doric +package syntax + +import org.scalatest.EitherValues +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.sql.{functions => f} +import org.apache.spark.sql.types.NullType + +class StringColumns31Spec + extends DoricTestElements + with EitherValues + with Matchers { + + describe("raiseError doric function") { + import spark.implicits._ + + val df = List("this is an error").toDF("errorMsg") + + it("should work as spark raise_error function") { + import java.lang.{RuntimeException => exception} + + val doricErr = intercept[exception] { + val res = df.select(colString("errorMsg").raiseError) + + res.schema.head.dataType shouldBe NullType + res.collect() + } + val sparkErr = intercept[exception] { + df.select(f.raise_error(f.col("errorMsg"))).collect() + } + + doricErr.getMessage shouldBe sparkErr.getMessage + } + } + +}
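---

Usage note (commentary, not part of the patch): per the `build.sbt` changes above, the Spark version used for the build is resolved from the `sparkVersion` or `sparkShortVersion` JVM system properties, defaulting to the 3.1 line, and the `core` artifact name and mounted source/test directories follow from that choice. The sketch below is an assumption-labeled, self-contained Scala illustration of that resolution order (property names and version constants are taken from the diff; it is not code from the repository itself):

// Sketch only: mirrors the precedence used by the patched build.sbt settings.
// Example invocations (assuming sbt forwards -D options as system properties):
//   sbt -DsparkShortVersion=3.2 core/test        // resolves to 3.2.1
//   sbt -DsparkVersion=3.0.3 core/publishLocal   // an exact long version wins
object SparkVersionResolution extends App {
  val sparkDefaultShortVersion = "3.1"
  val shortToLong = Map("3.0" -> "3.0.3", "3.1" -> "3.1.3", "3.2" -> "3.2.1")

  // Precedence: explicit long version, else short version, else the default short version.
  val resolved: String =
    sys.props.getOrElse(
      "sparkVersion",
      shortToLong(sys.props.getOrElse("sparkShortVersion", sparkDefaultShortVersion))
    )

  println(s"Building against Spark $resolved")
}

With the resolved version in hand, the build selects the matching `spark_3.x` and `spark_3.x_mount` source directories, which is why version-gated features such as `percentileApprox`, `assertTrue`, `timestampSeconds`, and `raiseError` move into the `spark_3.1` tree in this patch.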