[SPARK-49031] Implement validation for the TransformWithStateExec operator using OperatorStateMetadataV2

### What changes were proposed in this pull request?

Implement validation for the TransformWithStateExec operator, so that a query cannot be restarted with a different TimeMode or OutputMode, or with invalid state variable changes (such as a state variable changing type between runs).
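For illustration, a sketch of the kind of restart this validation now rejects. The query shape, processor, and input names are assumptions for the example, not taken from this PR:

```scala
// Run 1: the query is started with OutputMode.Update(); its state and
// OperatorStateMetadataV2 are written to the checkpoint.
val counts = input
  .groupByKey(_.key)
  .transformWithState(new CountProcessor(), TimeMode.ProcessingTime(), OutputMode.Update())

// Run 2: the same checkpoint is reused, but the OutputMode differs. With this
// change, the restart fails with STATE_STORE_INVALID_CONFIG_AFTER_RESTART
// instead of running against incompatible state.
val countsBroken = input
  .groupByKey(_.key)
  .transformWithState(new CountProcessor(), TimeMode.ProcessingTime(), OutputMode.Append())
```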

### Why are the changes needed?

If an invalid change is made to the query between restarts, we want the query to fail rather than run against incompatible checkpointed state.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit tests

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#47508 from ericm-db/validation.

Authored-by: Eric Marnadi <eric.marnadi@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
ericm-db authored and HeartSaVioR committed Jul 30, 2024
1 parent acb2fec commit 40d94b6
Showing 8 changed files with 553 additions and 18 deletions.
18 changes: 18 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -3824,6 +3824,12 @@
],
"sqlState" : "42802"
},
"STATEFUL_PROCESSOR_DUPLICATE_STATE_VARIABLE_DEFINED" : {
"message" : [
"State variable with name <stateVarName> has already been defined in the StatefulProcessor."
],
"sqlState" : "42802"
},
"STATEFUL_PROCESSOR_INCORRECT_TIME_MODE_TO_ASSIGN_TTL" : {
"message" : [
"Cannot use TTL for state=<stateName> in timeMode=<timeMode>, use TimeMode.ProcessingTime() instead."
@@ -3873,12 +3879,24 @@
],
"sqlState" : "42802"
},
"STATE_STORE_INVALID_CONFIG_AFTER_RESTART" : {
"message" : [
"Cannot change <configName> from <oldConfig> to <newConfig> between restarts. Please set <configName> to <oldConfig>, or restart with a new checkpoint directory."
],
"sqlState" : "42K06"
},
"STATE_STORE_INVALID_PROVIDER" : {
"message" : [
"The given State Store Provider <inputClass> does not extend org.apache.spark.sql.execution.streaming.state.StateStoreProvider."
],
"sqlState" : "42K06"
},
"STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE" : {
"message" : [
"Cannot change <stateVarName> to <newType> between query restarts. Please set <stateVarName> to <oldType>, or restart with a new checkpoint directory."
],
"sqlState" : "42K06"
},
"STATE_STORE_KEY_ROW_FORMAT_VALIDATION_FAILURE" : {
"message" : [
"The streaming query failed to validate written state for key row.",
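For reference, the restart-validation conditions render like this once parameters are substituted (the parameter values below are hypothetical examples):

```
[STATE_STORE_INVALID_CONFIG_AFTER_RESTART] Cannot change outputMode from Update
to Append between restarts. Please set outputMode to Update, or restart with a
new checkpoint directory.

[STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE] Cannot change countState to ListState
between query restarts. Please set countState to ValueState, or restart with a
new checkpoint directory.
```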
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
-import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter}
+import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataReader, OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -213,6 +213,27 @@ class IncrementalExecution(
statefulOp match {
case ssw: StateStoreWriter =>
val metadata = ssw.operatorStateMetadata(stateSchemaPaths)
// Validate the new metadata against the previously committed metadata, if any.
if (isFirstBatch && currentBatchId != 0) {
// If we are restarting from a different checkpoint directory
// there may be a mismatch between the stateful operators in the
// physical plan and the metadata.
val oldMetadata = try {
OperatorStateMetadataReader.createReader(
new Path(checkpointLocation, ssw.getStateInfo.operatorId.toString),
hadoopConf, ssw.operatorStateMetadataVersion).read()
} catch {
case e: Exception =>
logWarning(log"Error reading metadata path for stateful operator. This " +
log"may due to no prior committed batch, or previously run on lower " +
log"versions: ${MDC(ERROR, e.getMessage)}")
None
}
oldMetadata match {
case Some(oldMetadata) => ssw.validateNewMetadata(oldMetadata, metadata)
case None =>
}
}
val metadataWriter = OperatorStateMetadataWriter.createWriter(
new Path(checkpointLocation, ssw.getStateInfo.operatorId.toString),
hadoopConf,
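As a minimal, self-contained sketch of the control flow added above (Spark types stubbed out; an illustration of the dispatch, not the real API):

```scala
sealed trait OperatorMetadata
final case class OperatorMetadataV2(operatorPropertiesJson: String) extends OperatorMetadata

def validateOnRestart(
    isFirstBatch: Boolean,  // first batch planned by this run of the query
    currentBatchId: Long,   // 0 means a brand-new query, not a restart
    readCommittedMetadata: () => Option[OperatorMetadata],
    validate: OperatorMetadata => Unit): Unit = {
  // Only a restart needs validation: the first batch of this run, but not
  // batch 0 overall. A missing or unreadable metadata file is tolerated
  // (mirroring the try/catch-to-None above), since the previous run may have
  // had no committed batch or may predate OperatorStateMetadataV2.
  if (isFirstBatch && currentBatchId != 0) {
    readCommittedMetadata().foreach(validate)
  }
}
```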
@@ -301,17 +301,32 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
extends StatefulProcessorHandleImplBase(timeMode, keyExprEnc) {

// Because this is only happening on the driver side, there is only
-// one task modifying and accessing this map at a time
+// one task modifying and accessing these maps at a time
private[sql] val columnFamilySchemas: mutable.Map[String, StateStoreColFamilySchema] =
new mutable.HashMap[String, StateStoreColFamilySchema]()

private val stateVariableInfos: mutable.Map[String, TransformWithStateVariableInfo] =
new mutable.HashMap[String, TransformWithStateVariableInfo]()

def getColumnFamilySchemas: Map[String, StateStoreColFamilySchema] = columnFamilySchemas.toMap

def getStateVariableInfos: Map[String, TransformWithStateVariableInfo] = stateVariableInfos.toMap

private def checkIfDuplicateVariableDefined(stateVarName: String): Unit = {
if (columnFamilySchemas.contains(stateVarName)) {
throw StateStoreErrors.duplicateStateVariableDefined(stateVarName)
}
}

override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = {
verifyStateVarOperations("get_value_state", PRE_INIT)
val colFamilySchema = StateStoreColumnFamilySchemaUtils.
getValueStateSchema(stateName, keyExprEnc, valEncoder, false)
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = TransformWithStateVariableUtils.
getValueState(stateName, ttlEnabled = false)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[ValueState[T]]
}

@@ -322,15 +337,23 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
verifyStateVarOperations("get_value_state", PRE_INIT)
val colFamilySchema = StateStoreColumnFamilySchemaUtils.
getValueStateSchema(stateName, keyExprEnc, valEncoder, true)
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = TransformWithStateVariableUtils.
getValueState(stateName, ttlEnabled = true)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[ValueState[T]]
}

override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = {
verifyStateVarOperations("get_list_state", PRE_INIT)
val colFamilySchema = StateStoreColumnFamilySchemaUtils.
getListStateSchema(stateName, keyExprEnc, valEncoder, false)
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = TransformWithStateVariableUtils.
getListState(stateName, ttlEnabled = false)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[ListState[T]]
}

@@ -341,7 +364,11 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
verifyStateVarOperations("get_list_state", PRE_INIT)
val colFamilySchema = StateStoreColumnFamilySchemaUtils.
getListStateSchema(stateName, keyExprEnc, valEncoder, true)
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = TransformWithStateVariableUtils.
getListState(stateName, ttlEnabled = true)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[ListState[T]]
}

@@ -352,7 +379,11 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
verifyStateVarOperations("get_map_state", PRE_INIT)
val colFamilySchema = StateStoreColumnFamilySchemaUtils.
getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, false)
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = TransformWithStateVariableUtils.
getMapState(stateName, ttlEnabled = false)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[MapState[K, V]]
}

@@ -365,6 +396,9 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi
val colFamilySchema = StateStoreColumnFamilySchemaUtils.
getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, true)
// Guard against re-registering an existing variable name, matching the
// other state-variable getters above.
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = TransformWithStateVariableUtils.
getMapState(stateName, ttlEnabled = true)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[MapState[K, V]]
}

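With the stateVariableInfos tracking above, registering two state variables under one name now fails during driver-side planning with STATEFUL_PROCESSOR_DUPLICATE_STATE_VARIABLE_DEFINED. A hypothetical processor that would trigger it (the encoders and omitted overrides are assumptions for the sketch):

```scala
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming.{ListState, OutputMode, StatefulProcessor, TimeMode, ValueState}

class DuplicateStateProcessor extends StatefulProcessor[String, String, String] {
  @transient private var count: ValueState[Long] = _
  @transient private var names: ListState[String] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    count = getHandle.getValueState[Long]("myState", Encoders.scalaLong)
    // Same name, different variable type: checkIfDuplicateVariableDefined
    // throws before the second schema is recorded.
    names = getHandle.getListState[String]("myState", Encoders.STRING)
  }

  // handleInputRows and the remaining overrides are omitted for brevity.
}
```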
@@ -21,10 +21,6 @@ import java.util.concurrent.TimeUnit.NANOSECONDS

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.json4s.JsonAST.JValue
-import org.json4s.JsonDSL._
-import org.json4s.JString
-import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
@@ -125,6 +121,12 @@ case class TransformWithStateExec(
columnFamilySchemas
}

private def getStateVariableInfos(): Map[String, TransformWithStateVariableInfo] = {
val stateVariableInfos = getDriverProcessorHandle().getStateVariableInfos
closeProcessorHandle()
stateVariableInfos
}

/**
* This method is used for the driver-side stateful processor after we
* have collected all the necessary schemas.
@@ -423,12 +425,12 @@
Array(StateStoreMetadataV2(
StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions, stateSchemaPaths.head))

-val operatorPropertiesJson: JValue =
-  ("timeMode" -> JString(timeMode.toString)) ~
-  ("outputMode" -> JString(outputMode.toString))
-
-val json = compact(render(operatorPropertiesJson))
-OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json)
+val operatorProperties = TransformWithStateOperatorProperties(
+  timeMode.toString,
+  outputMode.toString,
+  getStateVariableInfos().values.toList
+)
+OperatorStateMetadataV2(operatorInfo, stateStoreInfo, operatorProperties.json)
}

private def stateSchemaDirPath(): Path = {
@@ -441,6 +443,23 @@
new Path(new Path(storeNamePath, "_metadata"), "schema")
}

override def validateNewMetadata(
oldOperatorMetadata: OperatorStateMetadata,
newOperatorMetadata: OperatorStateMetadata): Unit = {
(oldOperatorMetadata, newOperatorMetadata) match {
case (
oldMetadataV2: OperatorStateMetadataV2,
newMetadataV2: OperatorStateMetadataV2) =>
val oldOperatorProps = TransformWithStateOperatorProperties.fromJson(
oldMetadataV2.operatorPropertiesJson)
val newOperatorProps = TransformWithStateOperatorProperties.fromJson(
newMetadataV2.operatorPropertiesJson)
TransformWithStateOperatorProperties.validateOperatorProperties(
oldOperatorProps, newOperatorProps)
case (_, _) =>
}
}

override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver

@@ -666,4 +685,3 @@ object TransformWithStateExec {
}
}
// scalastyle:on argcount
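Given the json methods in the new TransformWithStateVariableUtils file below, the operatorPropertiesJson embedded in OperatorStateMetadataV2 would look roughly like this for a query with a single non-TTL value state (pretty-printed here for readability; the stored string is compact, and the variable name is hypothetical):

```json
{
  "timeMode" : "ProcessingTime",
  "outputMode" : "Update",
  "stateVariables" : [ {
    "stateName" : "countState",
    "stateVariableType" : "ValueState",
    "ttlEnabled" : false
  } ]
}
```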

@@ -0,0 +1,149 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming

import org.json4s.DefaultFormats
import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.streaming.StateVariableType.StateVariableType
import org.apache.spark.sql.execution.streaming.state.StateStoreErrors

/**
* This file contains utility classes and functions for managing state variables in
* the operatorProperties field of the OperatorStateMetadata for TransformWithState.
* We use these utils to read and write state variable information for validation purposes.
*/
object TransformWithStateVariableUtils {
def getValueState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo = {
TransformWithStateVariableInfo(stateName, StateVariableType.ValueState, ttlEnabled)
}

def getListState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo = {
TransformWithStateVariableInfo(stateName, StateVariableType.ListState, ttlEnabled)
}

def getMapState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo = {
TransformWithStateVariableInfo(stateName, StateVariableType.MapState, ttlEnabled)
}
}

// Enum of possible State Variable types
object StateVariableType extends Enumeration {
type StateVariableType = Value
val ValueState, ListState, MapState = Value
}

case class TransformWithStateVariableInfo(
stateName: String,
stateVariableType: StateVariableType,
ttlEnabled: Boolean) {
def jsonValue: JValue = {
("stateName" -> JString(stateName)) ~
("stateVariableType" -> JString(stateVariableType.toString)) ~
("ttlEnabled" -> JBool(ttlEnabled))
}

def json: String = {
compact(render(jsonValue))
}
}

object TransformWithStateVariableInfo {

def fromJson(json: String): TransformWithStateVariableInfo = {
implicit val formats: DefaultFormats.type = DefaultFormats
val parsed = JsonMethods.parse(json).extract[Map[String, Any]]
fromMap(parsed)
}

def fromMap(map: Map[String, Any]): TransformWithStateVariableInfo = {
val stateName = map("stateName").asInstanceOf[String]
val stateVariableType = StateVariableType.withName(
map("stateVariableType").asInstanceOf[String])
val ttlEnabled = map("ttlEnabled").asInstanceOf[Boolean]
TransformWithStateVariableInfo(stateName, stateVariableType, ttlEnabled)
}
}

case class TransformWithStateOperatorProperties(
timeMode: String,
outputMode: String,
stateVariables: List[TransformWithStateVariableInfo]) {

def json: String = {
val stateVariablesJson = stateVariables.map(_.jsonValue)
val json =
("timeMode" -> timeMode) ~
("outputMode" -> outputMode) ~
("stateVariables" -> stateVariablesJson)
compact(render(json))
}
}

object TransformWithStateOperatorProperties extends Logging {
def fromJson(json: String): TransformWithStateOperatorProperties = {
implicit val formats: DefaultFormats.type = DefaultFormats
val jsonMap = JsonMethods.parse(json).extract[Map[String, Any]]
TransformWithStateOperatorProperties(
jsonMap("timeMode").asInstanceOf[String],
jsonMap("outputMode").asInstanceOf[String],
jsonMap("stateVariables").asInstanceOf[List[Map[String, Any]]].map { stateVarMap =>
TransformWithStateVariableInfo.fromMap(stateVarMap)
}
)
}

// Confirms that the operator properties and state variables have changed
// only in acceptable ways after a query restart. If a property has changed
// in an unacceptable way, this function throws an exception.
def validateOperatorProperties(
oldOperatorProperties: TransformWithStateOperatorProperties,
newOperatorProperties: TransformWithStateOperatorProperties): Unit = {
if (oldOperatorProperties.timeMode != newOperatorProperties.timeMode) {
throw StateStoreErrors.invalidConfigChangedAfterRestart(
"timeMode", oldOperatorProperties.timeMode, newOperatorProperties.timeMode)
}

if (oldOperatorProperties.outputMode != newOperatorProperties.outputMode) {
throw StateStoreErrors.invalidConfigChangedAfterRestart(
"outputMode", oldOperatorProperties.outputMode, newOperatorProperties.outputMode)
}

val oldStateVariableInfos = oldOperatorProperties.stateVariables
val newStateVariableInfos = newOperatorProperties.stateVariables.map { stateVarInfo =>
stateVarInfo.stateName -> stateVarInfo
}.toMap
oldStateVariableInfos.foreach { oldInfo =>
val newInfo = newStateVariableInfos.get(oldInfo.stateName)
newInfo match {
case Some(stateVarInfo) =>
if (oldInfo.stateVariableType != stateVarInfo.stateVariableType) {
throw StateStoreErrors.invalidVariableTypeChange(
stateVarInfo.stateName,
oldInfo.stateVariableType.toString,
stateVarInfo.stateVariableType.toString
)
}
case None =>
}
}
}
}
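A usage sketch of the validation above, with hypothetical values. Note that a state variable that disappears in the new run is tolerated (the `case None =>` arm), while a type change is not:

```scala
val oldProps = TransformWithStateOperatorProperties(
  timeMode = "ProcessingTime",
  outputMode = "Update",
  stateVariables = List(
    TransformWithStateVariableUtils.getValueState("countState", ttlEnabled = false)))

// Same variable name re-registered as a ListState after restart.
val newProps = oldProps.copy(stateVariables = List(
  TransformWithStateVariableUtils.getListState("countState", ttlEnabled = false)))

// Throws STATE_STORE_INVALID_VARIABLE_TYPE_CHANGE.
TransformWithStateOperatorProperties.validateOperatorProperties(oldProps, newProps)
```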