Use storeAssignmentPolicy for casts in DML commands
Follow spark.sql.storeAssignmentPolicy instead of spark.sql.ansi.enabled for casting behaviour in UPDATE and MERGE. By default, this errors out at runtime when an overflow happens.

Closes #1938

GitOrigin-RevId: c960a0521df27daa6ee231e0a1022d8756496785
olaky authored and allisonport-db committed Jul 26, 2023
1 parent 0626664 commit 6d78d43
Showing 8 changed files with 458 additions and 7 deletions.
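A minimal repro sketch of the new behaviour (the table name and values are hypothetical; assumes a Delta-enabled SparkSession and Spark's default spark.sql.storeAssignmentPolicy=ANSI):

spark.sql("CREATE TABLE t (id INT, small TINYINT) USING delta")
spark.sql("INSERT INTO t VALUES (1, 1)")
// 1000 does not fit into TINYINT. With this change the UPDATE fails at runtime
// with DELTA_CAST_OVERFLOW_IN_TABLE_WRITE; previously the cast followed
// spark.sql.ansi.enabled, which (when false) silently wrapped the value.
spark.sql("UPDATE t SET small = 1000 WHERE id = 1")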
8 changes: 8 additions & 0 deletions spark/src/main/resources/error/delta-error-classes.json
@@ -272,6 +272,14 @@
    ],
    "sqlState" : "0A000"
  },
  "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE" : {
    "message" : [
      "Failed to write a value of <sourceType> type into the <targetType> type column <columnName> due to an overflow.",
      "Use `try_cast` on the input value to tolerate overflow and return NULL instead.",
      "If necessary, set <storeAssignmentPolicyFlag> to \"LEGACY\" to bypass this error or set <updateAndMergeCastingFollowsAnsiEnabledFlag> to true to revert to the old behaviour and follow <ansiEnabledFlag> in UPDATE and MERGE."
    ],
    "sqlState" : "22003"
  },
  "DELTA_CDC_NOT_ALLOWED_IN_THIS_VERSION" : {
    "message" : [
      "Configuration delta.enableChangeDataFeed cannot be set. Change data feed from Delta is not yet available."
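For illustration, a hedged sketch of how the template above could render once the parameters are filled in (the BIGINT-to-INT write and the column name `small` are made up; the quoting assumes Spark's usual toSQLType/toSQLId conventions):

Failed to write a value of "BIGINT" type into the "INT" type column `small` due to an overflow.
Use `try_cast` on the input value to tolerate overflow and return NULL instead.
If necessary, set spark.sql.storeAssignmentPolicy to "LEGACY" to bypass this error or set spark.databricks.delta.updateAndMergeCastingFollowsAnsiEnabledFlag to true to revert to the old behaviour and follow spark.sql.ansi.enabled in UPDATE and MERGE.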
@@ -44,6 +44,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference,
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructField, StructType}

@@ -118,7 +119,8 @@ trait DocsPath {
*/
trait DeltaErrorsBase
    extends DocsPath
    with DeltaLogging
    with QueryErrorsBase {

  def baseDocsPath(spark: SparkSession): String = baseDocsPath(spark.sparkContext.getConf)

@@ -618,6 +620,22 @@ trait DeltaErrorsBase
    )
  }

  def castingCauseOverflowErrorInTableWrite(
      from: DataType,
      to: DataType,
      columnName: String): ArithmeticException = {
    new DeltaArithmeticException(
      errorClass = "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE",
      messageParameters = Map(
        "sourceType" -> toSQLType(from),
        "targetType" -> toSQLType(to),
        "columnName" -> toSQLId(columnName),
        "storeAssignmentPolicyFlag" -> SQLConf.STORE_ASSIGNMENT_POLICY.key,
        "updateAndMergeCastingFollowsAnsiEnabledFlag" ->
          DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG.key,
        "ansiEnabledFlag" -> SQLConf.ANSI_ENABLED.key))
  }

  def notADeltaTable(table: String): Throwable = {
    new DeltaAnalysisException(errorClass = "DELTA_NOT_A_DELTA_TABLE",
      messageParameters = Array(table))
@@ -81,3 +81,11 @@ class DeltaParseException(
    ParserUtils.position(ctx.getStop)
  ) with DeltaThrowable

class DeltaArithmeticException(
    errorClass: String,
    messageParameters: Map[String, String]) extends ArithmeticException with DeltaThrowable {
  override def getErrorClass: String = errorClass

  override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava
}
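A short usage sketch of the new exception type (runOverflowingUpdate is a hypothetical helper that performs a write which overflows; the error class and SQLSTATE are resolved from delta-error-classes.json via DeltaThrowable):

try {
  runOverflowingUpdate() // hypothetical write path that overflows a column
} catch {
  case e: DeltaArithmeticException =>
    // Both accessors are backed by the entry added to delta-error-classes.json.
    assert(e.getErrorClass == "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE")
    assert(e.getSqlState == "22003")
}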

@@ -21,10 +21,13 @@ import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.util.AnalysisHelper

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

/**
@@ -405,7 +408,109 @@ trait UpdateExpressionsSupport extends SQLConfHelper with AnalysisHelper {
}
}

  /**
   * Replaces 'CastSupport.cast'. Selects a cast based on 'spark.sql.storeAssignmentPolicy' if
   * 'spark.databricks.delta.updateAndMergeCastingFollowsAnsiEnabledFlag' is false, and based on
   * 'spark.sql.ansi.enabled' otherwise.
   */
  private def cast(child: Expression, dataType: DataType, columnName: String): Expression = {
    if (conf.getConf(DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG)) {
      return Cast(child, dataType, Option(conf.sessionLocalTimeZone))
    }

    conf.storeAssignmentPolicy match {
      case SQLConf.StoreAssignmentPolicy.LEGACY =>
        Cast(child, dataType, Some(conf.sessionLocalTimeZone), ansiEnabled = false)
      case SQLConf.StoreAssignmentPolicy.ANSI =>
        val cast = Cast(child, dataType, Some(conf.sessionLocalTimeZone), ansiEnabled = true)
        if (canCauseCastOverflow(cast)) {
          CheckOverflowInTableWrite(cast, columnName)
        } else {
          cast
        }
      case SQLConf.StoreAssignmentPolicy.STRICT =>
        UpCast(child, dataType)
    }
  }

  private def containsIntegralOrDecimalType(dt: DataType): Boolean = dt match {
    case _: IntegralType | _: DecimalType => true
    case a: ArrayType => containsIntegralOrDecimalType(a.elementType)
    case m: MapType =>
      containsIntegralOrDecimalType(m.keyType) || containsIntegralOrDecimalType(m.valueType)
    case s: StructType =>
      s.fields.exists(sf => containsIntegralOrDecimalType(sf.dataType))
    case _ => false
  }

  private def canCauseCastOverflow(cast: Cast): Boolean = {
    containsIntegralOrDecimalType(cast.dataType) &&
      !Cast.canUpCast(cast.child.dataType, cast.dataType)
  }
}

case class CheckOverflowInTableWrite(child: Expression, columnName: String)
  extends UnaryExpression {
  override protected def withNewChildInternal(newChild: Expression): Expression = {
    copy(child = newChild)
  }

  private def getCast: Option[Cast] = child match {
    case c: Cast => Some(c)
    case ExpressionProxy(c: Cast, _, _) => Some(c)
    case _ => None
  }

  override def eval(input: InternalRow): Any = try {
    child.eval(input)
  } catch {
    case e: ArithmeticException =>
      getCast match {
        case Some(cast) =>
          throw DeltaErrors.castingCauseOverflowErrorInTableWrite(
            cast.child.dataType,
            cast.dataType,
            columnName)
        case None => throw e
      }
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    getCast match {
      case Some(child) => doGenCodeWithBetterErrorMsg(ctx, ev, child)
      case None => child.genCode(ctx)
    }
  }

  def doGenCodeWithBetterErrorMsg(ctx: CodegenContext, ev: ExprCode, child: Cast): ExprCode = {
    val childGen = child.genCode(ctx)
    val exceptionClass = classOf[ArithmeticException].getCanonicalName
    assert(child.isInstanceOf[Cast])
    val cast = child.asInstanceOf[Cast]
    val fromDt =
      ctx.addReferenceObj("from", cast.child.dataType, cast.child.dataType.getClass.getName)
    val toDt = ctx.addReferenceObj("to", child.dataType, child.dataType.getClass.getName)
    val col = ctx.addReferenceObj("colName", columnName, "java.lang.String")
    // scalastyle:off line.size.limit
    ev.copy(code =
      code"""
        boolean ${ev.isNull} = true;
        ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
        try {
          ${childGen.code}
          ${ev.isNull} = ${childGen.isNull};
          ${ev.value} = ${childGen.value};
        } catch ($exceptionClass e) {
          throw org.apache.spark.sql.delta.DeltaErrors
            .castingCauseOverflowErrorInTableWrite($fromDt, $toDt, $col);
        }"""
    )
    // scalastyle:on line.size.limit
  }

  override def dataType: DataType = child.dataType

  override def sql: String = child.sql

  override def toString: String = child.toString
}
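To make the gating above concrete, a small sketch of when the ANSI path wraps a cast (the expectations are assumptions based on Spark's Cast.canUpCast, not tests from this patch):

import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types.{IntegerType, LongType}

// LONG -> INT narrows into an integral type: not an upcast, so the ANSI cast
// is wrapped in CheckOverflowInTableWrite to attach the column name on failure.
assert(!Cast.canUpCast(LongType, IntegerType))
// INT -> LONG is a widening upcast: the cast cannot overflow, so no wrapper.
assert(Cast.canUpCast(IntegerType, LongType))
// INT -> STRING has no integral or decimal target type, so
// containsIntegralOrDecimalType is false and that cast is also left unwrapped.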
@@ -23,7 +23,6 @@ import org.apache.spark.internal.config.ConfigBuilder
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.storage.StorageLevel

/**
* [[SQLConf]] entries for Delta features.
@@ -1254,6 +1253,15 @@ trait DeltaSQLConfBase
      .intConf
      .createWithDefault(100 * 1000)

  val UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG =
    buildConf("updateAndMergeCastingFollowsAnsiEnabledFlag")
      .internal()
      .doc("""If false, casting behaviour in implicit casts in UPDATE and MERGE follows
          |'spark.sql.storeAssignmentPolicy'. If true, these casts follow 'ansi.enabled'. This
          |was the default before Delta 3.5.""".stripMargin)
      .booleanConf
      .createWithDefault(false)

}

object DeltaSQLConf extends DeltaSQLConfBase
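For completeness, a hedged sketch of opting back into the old behaviour at session level (buildConf prepends the spark.databricks.delta prefix, matching the key referenced in the new error message):

// Revert UPDATE/MERGE casts to following spark.sql.ansi.enabled again.
spark.conf.set("spark.databricks.delta.updateAndMergeCastingFollowsAnsiEnabledFlag", "true")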
@@ -52,6 +52,7 @@ import org.apache.spark.sql.catalyst.expressions.Uuid
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
import org.apache.spark.sql.types.{CalendarIntervalType, DataTypes, DateType, IntegerType, StringType, StructField, StructType, TimestampNTZType}
@@ -60,7 +61,8 @@ trait DeltaErrorsSuiteBase
  extends QueryTest
  with SharedSparkSession with GivenWhenThen
  with DeltaSQLCommandTest
  with SQLTestUtils
  with QueryErrorsBase {

  val MAX_URL_ACCESS_RETRIES = 3
  val path = "/sample/path"
@@ -288,6 +290,24 @@
      assert(
        e.getMessage == s"$table is a view. Writes to a view are not supported.")
    }
    {
      val sourceType = IntegerType
      val targetType = DateType
      val columnName = "column_name"
      val e = intercept[DeltaArithmeticException] {
        throw DeltaErrors.castingCauseOverflowErrorInTableWrite(sourceType, targetType, columnName)
      }
      assert(e.getErrorClass == "DELTA_CAST_OVERFLOW_IN_TABLE_WRITE")
      assert(e.getSqlState == "22003")
      assert(e.getMessageParameters.get("sourceType") == toSQLType(sourceType))
      assert(e.getMessageParameters.get("targetType") == toSQLType(targetType))
      assert(e.getMessageParameters.get("columnName") == toSQLId(columnName))
      assert(e.getMessageParameters.get("storeAssignmentPolicyFlag")
        == SQLConf.STORE_ASSIGNMENT_POLICY.key)
      assert(e.getMessageParameters.get("updateAndMergeCastingFollowsAnsiEnabledFlag")
        == DeltaSQLConf.UPDATE_AND_MERGE_CASTING_FOLLOWS_ANSI_ENABLED_FLAG.key)
      assert(e.getMessageParameters.get("ansiEnabledFlag") == SQLConf.ANSI_ENABLED.key)
    }
    {
      val e = intercept[DeltaAnalysisException] {
        throw DeltaErrors.invalidColumnName(name = "col-1")