loadTable should indicate if it's for writing
cloud-fan committed Aug 15, 2024
1 parent def42d4 commit 8586259
Showing 11 changed files with 147 additions and 30 deletions.
@@ -110,6 +110,17 @@ public interface TableCatalog extends CatalogPlugin {
*/
Table loadTable(Identifier ident) throws NoSuchTableException;

+ /**
+ * A variant of {@link #loadTable(Identifier)} that indicates it's for data writing.
+ * Implementations can override this method to do additional handling for data writing, such as
+ * checking write permissions.
+ *
+ * @since 4.0.0
+ */
+ default Table loadTableForWrite(Identifier ident) throws NoSuchTableException {
+ return loadTable(ident);
+ }
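
A minimal sketch (not part of this diff) of how an implementation might use this hook, in Scala; `AccessControl` is a hypothetical permission service, and `InMemoryCatalog` is the test catalog this commit also extends in its test suite:

import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryCatalog, Table}

// Hypothetical stand-in for a real permission service; not a Spark API.
object AccessControl {
  def canWrite(ident: Identifier): Boolean = false
}

class PermissionCheckingCatalog extends InMemoryCatalog {
  override def loadTableForWrite(ident: Identifier): Table = {
    // Reject the write before any table metadata reaches the planner.
    if (!AccessControl.canWrite(ident)) {
      throw new RuntimeException(s"no write permission on $ident")
    }
    super.loadTableForWrite(ident) // the default implementation just calls loadTable(ident)
  }
}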

/**
* Load table metadata of a specific version by {@link Identifier identifier} from the catalog.
* <p>
@@ -1321,7 +1321,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
cachedConnectRelation
}.getOrElse(cachedRelation)
}.orElse {
- val table = CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec)
+ val forWrite = "true".equalsIgnoreCase(u.options.get(UnresolvedRelation.FOR_WRITE))
+ val table = CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec, forWrite)
val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming)
loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
@@ -124,10 +124,21 @@ case class UnresolvedRelation(

override def name: String = tableName

+ def forWrite: UnresolvedRelation = {
+ val newOptions = new java.util.HashMap[String, String]
+ newOptions.put(UnresolvedRelation.FOR_WRITE, "true")
+ newOptions.putAll(options)
+ copy(options = new CaseInsensitiveStringMap(newOptions))
+ }

final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_RELATION)
}

object UnresolvedRelation {
+ // An internal option of `UnresolvedRelation` to indicate that we look up this relation for data
+ // writing.
+ val FOR_WRITE = "__for_write__"

def apply(
tableIdentifier: TableIdentifier,
extraOptions: CaseInsensitiveStringMap,
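
A quick sketch (not part of this diff) of how the flag round-trips: `forWrite` stamps the internal option onto a copy of the relation, and the Analyzer change above reads it back out of `u.options`:

import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation

val rel = UnresolvedRelation(Seq("cat", "db", "t")) // hypothetical identifier
assert(rel.options.get(UnresolvedRelation.FOR_WRITE) == null) // not set by default

val writeRel = rel.forWrite // copy with "__for_write__" -> "true" in its options
assert("true".equalsIgnoreCase(writeRel.options.get(UnresolvedRelation.FOR_WRITE)))
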
@@ -455,7 +455,7 @@ class AstBuilder extends DataTypeAstBuilder
= visitInsertIntoTable(table)
withIdentClause(relationCtx, ident => {
val insertIntoStatement = InsertIntoStatement(
- createUnresolvedRelation(relationCtx, ident, options),
+ createUnresolvedRelation(relationCtx, ident, options, forWrite = true),
partition,
cols,
query,
@@ -473,7 +473,7 @@
= visitInsertOverwriteTable(table)
withIdentClause(relationCtx, ident => {
InsertIntoStatement(
- createUnresolvedRelation(relationCtx, ident, options),
+ createUnresolvedRelation(relationCtx, ident, options, forWrite = true),
partition,
cols,
query,
Expand All @@ -482,9 +482,10 @@ class AstBuilder extends DataTypeAstBuilder
byName)
})
case ctx: InsertIntoReplaceWhereContext =>
+ val options = Option(ctx.optionsClause())
withIdentClause(ctx.identifierReference, ident => {
OverwriteByExpression.byPosition(
- createUnresolvedRelation(ctx.identifierReference, ident, Option(ctx.optionsClause())),
+ createUnresolvedRelation(ctx.identifierReference, ident, options, forWrite = true),
query,
expression(ctx.whereClause().booleanExpression()))
})
@@ -569,7 +570,7 @@ class AstBuilder extends DataTypeAstBuilder

override def visitDeleteFromTable(
ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) {
- val table = createUnresolvedRelation(ctx.identifierReference)
+ val table = createUnresolvedRelation(ctx.identifierReference, forWrite = true)
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE")
val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
val predicate = if (ctx.whereClause() != null) {
@@ -581,7 +582,7 @@
}

override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
- val table = createUnresolvedRelation(ctx.identifierReference)
+ val table = createUnresolvedRelation(ctx.identifierReference, forWrite = true)
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE")
val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
val assignments = withAssignments(ctx.setClause().assignmentList())
@@ -604,7 +605,7 @@

override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) {
val withSchemaEvolution = ctx.EVOLUTION() != null
- val targetTable = createUnresolvedRelation(ctx.target)
+ val targetTable = createUnresolvedRelation(ctx.target, forWrite = true)
val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)

@@ -3115,10 +3116,17 @@ class AstBuilder extends DataTypeAstBuilder
*/
private def createUnresolvedRelation(
ctx: IdentifierReferenceContext,
- optionsClause: Option[OptionsClauseContext] = None): LogicalPlan = withOrigin(ctx) {
+ optionsClause: Option[OptionsClauseContext] = None,
+ forWrite: Boolean = false): LogicalPlan = withOrigin(ctx) {
val options = resolveOptions(optionsClause)
- withIdentClause(ctx, parts =>
- new UnresolvedRelation(parts, options, isStreaming = false))
+ withIdentClause(ctx, parts => {
+ val relation = new UnresolvedRelation(parts, options, isStreaming = false)
+ if (forWrite) {
+ relation.forWrite
+ } else {
+ relation
+ }
+ })
}

/**
@@ -3127,9 +3135,15 @@ class AstBuilder extends DataTypeAstBuilder
private def createUnresolvedRelation(
ctx: ParserRuleContext,
ident: Seq[String],
- optionsClause: Option[OptionsClauseContext]): UnresolvedRelation = withOrigin(ctx) {
+ optionsClause: Option[OptionsClauseContext],
+ forWrite: Boolean): UnresolvedRelation = withOrigin(ctx) {
val options = resolveOptions(optionsClause)
- new UnresolvedRelation(ident, options, isStreaming = false)
+ val relation = new UnresolvedRelation(ident, options, isStreaming = false)
+ if (forWrite) {
+ relation.forWrite
+ } else {
+ relation
+ }
}

private def resolveOptions(
@@ -5005,7 +5019,7 @@ class AstBuilder extends DataTypeAstBuilder
if (query.isDefined) {
CacheTableAsSelect(ident.head, query.get, source(ctx.query()), isLazy, options)
} else {
- CacheTable(createUnresolvedRelation(ctx.identifierReference, ident, None),
+ CacheTable(createUnresolvedRelation(ctx.identifierReference, ident, None, forWrite = false),
ident, isLazy, options)
}
})
@@ -403,9 +403,10 @@ private[sql] object CatalogV2Util {
def loadTable(
catalog: CatalogPlugin,
ident: Identifier,
- timeTravelSpec: Option[TimeTravelSpec] = None): Option[Table] =
+ timeTravelSpec: Option[TimeTravelSpec] = None,
+ forWrite: Boolean = false): Option[Table] =
try {
- Option(getTable(catalog, ident, timeTravelSpec))
+ Option(getTable(catalog, ident, timeTravelSpec, forWrite))
} catch {
case _: NoSuchTableException => None
case _: NoSuchDatabaseException => None
@@ -414,16 +415,22 @@
def getTable(
catalog: CatalogPlugin,
ident: Identifier,
- timeTravelSpec: Option[TimeTravelSpec] = None): Table = {
+ timeTravelSpec: Option[TimeTravelSpec] = None,
+ forWrite: Boolean = false): Table = {
if (timeTravelSpec.nonEmpty) {
+ assert(!forWrite, "Should not write to a table with time travel")
timeTravelSpec.get match {
case v: AsOfVersion =>
catalog.asTableCatalog.loadTable(ident, v.version)
case ts: AsOfTimestamp =>
catalog.asTableCatalog.loadTable(ident, ts.timestamp)
}
} else {
- catalog.asTableCatalog.loadTable(ident)
+ if (forWrite) {
+ catalog.asTableCatalog.loadTableForWrite(ident)
+ } else {
+ catalog.asTableCatalog.loadTable(ident)
+ }
}
}
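
The resulting dispatch, as a sketch (assuming some already-resolved `catalog: CatalogPlugin` and `ident: Identifier`):

// Reads keep the existing path; writes go through the new hook.
val readTable = CatalogV2Util.getTable(catalog, ident) // -> loadTable(ident)
val writeTable = CatalogV2Util.getTable(catalog, ident, forWrite = true) // -> loadTableForWrite(ident)
// Time travel stays read-only: a timeTravelSpec combined with forWrite = true
// would trip the assertion added above.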

@@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2._
+ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -473,7 +474,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
private def insertInto(catalog: CatalogPlugin, ident: Identifier): Unit = {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

- val table = catalog.asTableCatalog.loadTable(ident) match {
+ val table = catalog.asTableCatalog.loadTableForWrite(ident) match {
case _: V1Table =>
return insertInto(TableIdentifier(ident.name(), ident.namespace().headOption))
case t =>
@@ -504,7 +505,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
private def insertInto(tableIdent: TableIdentifier): Unit = {
runCommand(df.sparkSession) {
InsertIntoStatement(
- table = UnresolvedRelation(tableIdent),
+ table = UnresolvedRelation(tableIdent).forWrite,
partitionSpec = Map.empty[String, Option[String]],
Nil,
query = df.logicalPlan,
@@ -588,7 +589,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

val session = df.sparkSession
- val canUseV2 = lookupV2Provider().isDefined
+ val canUseV2 = lookupV2Provider().isDefined ||
+ df.sparkSession.sessionState.conf.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined

session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case nameParts @ NonSessionCatalogAndIdentifier(catalog, ident) =>
@@ -609,7 +611,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {

private def saveAsTable(
catalog: TableCatalog, ident: Identifier, nameParts: Seq[String]): Unit = {
- val tableOpt = try Option(catalog.loadTable(ident)) catch {
+ val tableOpt = try Option(catalog.loadTableForWrite(ident)) catch {
case _: NoSuchTableException => None
}

@@ -670,7 +672,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val catalog = df.sparkSession.sessionState.catalog
val qualifiedIdent = catalog.qualifyIdentifier(tableIdent)
val tableExists = catalog.tableExists(qualifiedIdent)
- val tableName = qualifiedIdent.unquotedString

(tableExists, mode) match {
case (true, SaveMode.Ignore) =>
@@ -168,7 +168,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
*/
@throws(classOf[NoSuchTableException])
def append(): Unit = {
- val append = AppendData.byName(UnresolvedRelation(tableName), logicalPlan, options.toMap)
+ val append = AppendData.byName(
+ UnresolvedRelation(tableName).forWrite, logicalPlan, options.toMap)
runCommand(append)
}

@@ -185,7 +186,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
@throws(classOf[NoSuchTableException])
def overwrite(condition: Column): Unit = {
val overwrite = OverwriteByExpression.byName(
- UnresolvedRelation(tableName), logicalPlan, condition.expr, options.toMap)
+ UnresolvedRelation(tableName).forWrite, logicalPlan, condition.expr, options.toMap)
runCommand(overwrite)
}

@@ -205,7 +206,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
@throws(classOf[NoSuchTableException])
def overwritePartitions(): Unit = {
val dynamicOverwrite = OverwritePartitionsDynamic.byName(
- UnresolvedRelation(tableName), logicalPlan, options.toMap)
+ UnresolvedRelation(tableName).forWrite, logicalPlan, options.toMap)
runCommand(dynamicOverwrite)
}

@@ -188,7 +188,7 @@ class MergeIntoWriter[T] private[sql] (
}

val merge = MergeIntoTable(
- UnresolvedRelation(tableName),
+ UnresolvedRelation(tableName).forWrite,
logicalPlan,
on.expr,
matchedActions,
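
Taken together, the three writer changes above mean the DataFrame write APIs now resolve their target with the write flag set. A usage sketch (assuming an existing DataFrame `df` and a hypothetical table name):

import org.apache.spark.sql.functions.col

// Each of these now loads "cat.db.t" via loadTableForWrite during analysis:
df.writeTo("cat.db.t").append()
df.writeTo("cat.db.t").overwrite(col("i") === 1)
df.writeTo("cat.db.t").overwritePartitions()
// df.write.insertInto / df.write.saveAsTable and MergeIntoWriter take the same path.
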
@@ -683,7 +683,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
}

private def supportsV1Command(catalog: CatalogPlugin): Boolean = {
- catalog.name().equalsIgnoreCase(CatalogManager.SESSION_CATALOG_NAME) &&
- !SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined
+ isSessionCatalog(catalog) &&
+ SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isEmpty
}
}
@@ -84,7 +84,7 @@ case class CreateTableAsSelectExec(
}
val table = Option(catalog.createTable(
ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
- partitioning.toArray, properties.asJava)).getOrElse(catalog.loadTable(ident))
+ partitioning.toArray, properties.asJava)).getOrElse(catalog.loadTableForWrite(ident))
writeToTable(catalog, table, writeOptions, ident, query)
}
}
@@ -164,7 +164,7 @@ case class ReplaceTableAsSelectExec(
}
val table = Option(catalog.createTable(
ident, getV2Columns(query.schema, catalog.useNullableQuerySchema),
- partitioning.toArray, properties.asJava)).getOrElse(catalog.loadTable(ident))
+ partitioning.toArray, properties.asJava)).getOrElse(catalog.loadTableForWrite(ident))
writeToTable(catalog, table, writeOptions, ident, query)
}
}
@@ -3703,6 +3703,63 @@ class DataSourceV2SQLSuiteV1Filter
}
}

test("SPARK-49246: read-only catalog") {
def checkWriteOperations(catalog: String): Unit = {
withSQLConf(s"spark.sql.catalog.$catalog" -> classOf[ReadOnlyCatalog].getName) {
val input = sql("SELECT 1")
val tbl = s"$catalog.default.t"
withTable(tbl) {
sql(s"CREATE TABLE $tbl (i INT)")
val df = sql(s"SELECT * FROM $tbl")
assert(df.collect().isEmpty)
assert(df.schema == new StructType().add("i", "int"))

intercept[RuntimeException](sql(s"INSERT INTO $tbl SELECT 1"))
intercept[RuntimeException](sql(s"INSERT INTO $tbl REPLACE WHERE i = 0 SELECT 1"))
intercept[RuntimeException] (sql(s"INSERT OVERWRITE $tbl SELECT 1"))
intercept[RuntimeException] (sql(s"DELETE FROM $tbl WHERE i = 0"))
intercept[RuntimeException] (sql(s"UPDATE $tbl SET i = 0"))
intercept[RuntimeException] {
sql(
s"""
|MERGE INTO $tbl USING (SELECT 1 i) AS source
|ON source.i = $tbl.i
|WHEN NOT MATCHED THEN INSERT *
|""".stripMargin)
}

intercept[RuntimeException](input.write.insertInto(tbl))
intercept[RuntimeException](input.write.mode("append").saveAsTable(tbl))
intercept[RuntimeException](input.writeTo(tbl).append())
intercept[RuntimeException](input.writeTo(tbl).overwrite(df.col("i") === 1))
intercept[RuntimeException](input.writeTo(tbl).overwritePartitions())
}

// Test CTAS
withTable(tbl) {
intercept[RuntimeException](sql(s"CREATE TABLE $tbl AS SELECT 1 i"))
}
withTable(tbl) {
intercept[RuntimeException](sql(s"CREATE OR REPLACE TABLE $tbl AS SELECT 1 i"))
}
withTable(tbl) {
intercept[RuntimeException](input.write.saveAsTable(tbl))
}
withTable(tbl) {
intercept[RuntimeException](input.writeTo(tbl).create())
}
withTable(tbl) {
intercept[RuntimeException](input.writeTo(tbl).createOrReplace())
}
}
}
// Reset CatalogManager to clear the materialized `spark_catalog` instance, so that we can
// configure a new implementation.
spark.sessionState.catalogManager.reset()
checkWriteOperations(SESSION_CATALOG_NAME)
checkWriteOperations("read_only_cat")
}

private def testNotSupportedV2Command(
sqlCommand: String,
sqlParams: String,
@@ -3771,3 +3828,17 @@ class V2CatalogSupportBuiltinDataSource extends InMemoryCatalog {
}
}

+ class ReadOnlyCatalog extends InMemoryCatalog {
+ override def createTable(
+ ident: Identifier,
+ columns: Array[ColumnV2],
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+ super.createTable(ident, columns, partitions, properties)
+ null
+ }
+
+ override def loadTableForWrite(ident: Identifier): Table = {
+ throw new RuntimeException("cannot write")
+ }
+ }
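
For reference, a sketch of this test catalog in action outside the suite (assuming a `spark` session):

spark.conf.set("spark.sql.catalog.read_only_cat", classOf[ReadOnlyCatalog].getName)
spark.sql("CREATE TABLE read_only_cat.default.t (i INT)") // DDL still succeeds
spark.sql("SELECT * FROM read_only_cat.default.t").collect() // reads still use loadTable
spark.sql("INSERT INTO read_only_cat.default.t SELECT 1") // throws RuntimeException("cannot write")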
