diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java index b990f59bfd90e..387477d0f1911 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java @@ -110,6 +110,26 @@ public interface TableCatalog extends CatalogPlugin { */ Table loadTable(Identifier ident) throws NoSuchTableException; + /** + * Load table metadata by {@link Identifier identifier} from the catalog. Spark will write data + * into this table later. + *
<p>
+ * If the catalog supports views and contains a view for the identifier and not a table, this + * must throw {@link NoSuchTableException}. + * + * @param ident a table identifier + * @param writePrivileges the write privileges required by the subsequent write operation + * @return the table's metadata + * @throws NoSuchTableException If the table doesn't exist or is a view + * + * @since 3.5.3 + */ + default Table loadTable( + Identifier ident, + Set<TableWritePrivilege> writePrivileges) throws NoSuchTableException { + return loadTable(ident); + } + /** + * Load table metadata of a specific version by {@link Identifier identifier} from the catalog. *
<p>
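For readers of this patch, the sketch below (not part of the diff) shows how a connector might implement the new loadTable overload to veto a write before Spark plans it. The class name AclCheckingCatalog and the isAllowed policy are hypothetical placeholders, assuming a session catalog extension; only the loadTable(ident, writePrivileges) entry point comes from this change.

import java.util.{Set => JSet}

import org.apache.spark.sql.connector.catalog.{DelegatingCatalogExtension, Identifier, Table, TableWritePrivilege}

// Hypothetical catalog that rejects writes it does not allow. A real connector would
// consult its own ACL store instead of the hard-coded policy below.
class AclCheckingCatalog extends DelegatingCatalogExtension {

  // Placeholder policy: only INSERT is granted.
  private def isAllowed(ident: Identifier, privilege: TableWritePrivilege): Boolean =
    privilege == TableWritePrivilege.INSERT

  override def loadTable(ident: Identifier, writePrivileges: JSet[TableWritePrivilege]): Table = {
    // With this patch, Spark calls this overload (instead of loadTable(ident)) whenever the
    // table is the target of INSERT, UPDATE, DELETE or MERGE, so the check runs up front.
    writePrivileges.forEach { p =>
      if (!isAllowed(ident, p)) {
        throw new RuntimeException(s"Privilege $p is not granted on ${ident.name()}")
      }
    }
    super.loadTable(ident)
  }
}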
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java new file mode 100644 index 0000000000000..ca2d4ba9e7b4e --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog; + +/** + * The table write privileges that will be provided when loading a table. + * + * @since 3.5.3 + */ +public enum TableWritePrivilege { + /** + * The privilege for adding rows to the table. + */ + INSERT, + + /** + * The privilege for changing existing rows in the table. + */ + UPDATE, + + /** + * The privilege for deleting rows from the table. + */ + DELETE +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index fd0a0715b6344..463bd3c3a8a27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1309,8 +1309,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor cachedConnectRelation }.getOrElse(cachedRelation) }.orElse { - val table = CatalogV2Util.loadTable(catalog, ident, timeTravelSpec) - val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming) + val writePrivilegesString = + Option(u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) + val table = CatalogV2Util.loadTable( + catalog, ident, timeTravelSpec, writePrivilegesString) + val loaded = createRelation( + catalog, ident, table, u.clearWritePrivileges.options, u.isStreaming) loaded.foreach(AnalysisContext.get.relationCache.update(key, _)) u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId => loaded.map { loadedRelation => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index a03c3e317c104..81d92acc6e84a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Unary import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId +import org.apache.spark.sql.connector.catalog.TableWritePrivilege import
org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{DataType, Metadata, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -106,10 +107,36 @@ case class UnresolvedRelation( override def name: String = tableName + def requireWritePrivileges(privileges: Seq[TableWritePrivilege]): UnresolvedRelation = { + if (privileges.nonEmpty) { + val newOptions = new java.util.HashMap[String, String] + newOptions.putAll(options) + newOptions.put(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES, privileges.mkString(",")) + copy(options = new CaseInsensitiveStringMap(newOptions)) + } else { + this + } + } + + def clearWritePrivileges: UnresolvedRelation = { + if (options.containsKey(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) { + val newOptions = new java.util.HashMap[String, String] + newOptions.putAll(options) + newOptions.remove(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES) + copy(options = new CaseInsensitiveStringMap(newOptions)) + } else { + this + } + } + final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_RELATION) } object UnresolvedRelation { + // An internal option of `UnresolvedRelation` to specify the required write privileges when + // writing data to this relation. + val REQUIRED_WRITE_PRIVILEGES = "__required_write_privileges__" + def apply( tableIdentifier: TableIdentifier, extraOptions: CaseInsensitiveStringMap, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f28a6de9fc2c1..2b600743e1bd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, GeneratedColumn, IntervalUtils, ResolveDefaultColumns} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone} -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog, TableWritePrivilege} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryParsingErrors} @@ -336,7 +336,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { = visitInsertIntoTable(table) withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => { InsertIntoStatement( - createUnresolvedRelation(relationCtx, ident), + createUnresolvedRelation(relationCtx, ident, Seq(TableWritePrivilege.INSERT)), partition, cols, otherPlans.head, @@ -349,7 +349,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { = visitInsertOverwriteTable(table) withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => { InsertIntoStatement( - createUnresolvedRelation(relationCtx, ident), 
+ createUnresolvedRelation(relationCtx, ident, + Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)), partition, cols, otherPlans.head, @@ -360,7 +361,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { case ctx: InsertIntoReplaceWhereContext => withIdentClause(ctx.identifierReference, Seq(query), (ident, otherPlans) => { OverwriteByExpression.byPosition( - createUnresolvedRelation(ctx.identifierReference, ident), + createUnresolvedRelation(ctx.identifierReference, ident, + Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)), otherPlans.head, expression(ctx.whereClause().booleanExpression())) }) @@ -439,7 +441,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { override def visitDeleteFromTable( ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) { - val table = createUnresolvedRelation(ctx.identifierReference) + val table = createUnresolvedRelation( + ctx.identifierReference, writePrivileges = Seq(TableWritePrivilege.DELETE)) val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE") val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table) val predicate = if (ctx.whereClause() != null) { @@ -451,7 +454,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) { - val table = createUnresolvedRelation(ctx.identifierReference) + val table = createUnresolvedRelation( + ctx.identifierReference, writePrivileges = Seq(TableWritePrivilege.UPDATE)) val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE") val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table) val assignments = withAssignments(ctx.setClause().assignmentList()) @@ -473,10 +477,6 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) { - val targetTable = createUnresolvedRelation(ctx.target) - val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE") - val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable) - val sourceTableOrQuery = if (ctx.source != null) { createUnresolvedRelation(ctx.source) } else if (ctx.sourceQuery != null) { @@ -506,7 +506,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { s"Unrecognized matched action: ${clause.matchedAction().getText}") } } - } + }.toSeq val notMatchedActions = ctx.notMatchedClause().asScala.map { clause => { if (clause.notMatchedAction().INSERT() != null) { @@ -527,7 +527,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { s"Unrecognized matched action: ${clause.notMatchedAction().getText}") } } - } + }.toSeq val notMatchedBySourceActions = ctx.notMatchedBySourceClause().asScala.map { clause => { val notMatchedBySourceAction = clause.notMatchedBySourceAction() @@ -542,7 +542,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { s"Unrecognized matched action: ${clause.notMatchedBySourceAction().getText}") } } - } + }.toSeq if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) { throw QueryParsingErrors.mergeStatementWithoutWhenClauseError(ctx) } @@ -561,13 +561,19 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { throw 
QueryParsingErrors.nonLastNotMatchedBySourceClauseOmitConditionError(ctx) } + val targetTable = createUnresolvedRelation( + ctx.target, + writePrivileges = MergeIntoTable.getWritePrivileges( + matchedActions, notMatchedActions, notMatchedBySourceActions)) + val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE") + val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable) MergeIntoTable( aliasedTarget, aliasedSource, mergeCondition, - matchedActions.toSeq, - notMatchedActions.toSeq, - notMatchedBySourceActions.toSeq) + matchedActions, + notMatchedActions, + notMatchedBySourceActions) } /** @@ -2793,16 +2799,23 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { * Create an [[UnresolvedRelation]] from an identifier reference. */ private def createUnresolvedRelation( - ctx: IdentifierReferenceContext): LogicalPlan = withOrigin(ctx) { - withIdentClause(ctx, UnresolvedRelation(_)) + ctx: IdentifierReferenceContext, + writePrivileges: Seq[TableWritePrivilege] = Nil): LogicalPlan = withOrigin(ctx) { + withIdentClause(ctx, parts => { + val relation = new UnresolvedRelation(parts, isStreaming = false) + relation.requireWritePrivileges(writePrivileges) + }) } /** * Create an [[UnresolvedRelation]] from a multi-part identifier. */ private def createUnresolvedRelation( - ctx: ParserRuleContext, ident: Seq[String]): UnresolvedRelation = withOrigin(ctx) { - UnresolvedRelation(ident) + ctx: ParserRuleContext, + ident: Seq[String], + writePrivileges: Seq[TableWritePrivilege]): UnresolvedRelation = withOrigin(ctx) { + val relation = new UnresolvedRelation(ident, isStreaming = false) + relation.requireWritePrivileges(writePrivileges) } /** @@ -4601,7 +4614,9 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { if (query.isDefined) { CacheTableAsSelect(ident.head, query.get, source(ctx.query()), isLazy, options) } else { - CacheTable(createUnresolvedRelation(ctx.identifierReference, ident), ident, isLazy, options) + CacheTable( + createUnresolvedRelation(ctx.identifierReference, ident, writePrivileges = Nil), + ident, isLazy, options) } }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 805f277cf9f6a..d7669ac0b1d78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -758,6 +758,21 @@ case class MergeIntoTable( copy(targetTable = newLeft, sourceTable = newRight) } +object MergeIntoTable { + def getWritePrivileges( + matchedActions: Seq[MergeAction], + notMatchedActions: Seq[MergeAction], + notMatchedBySourceActions: Seq[MergeAction]): Seq[TableWritePrivilege] = { + val privileges = scala.collection.mutable.HashSet.empty[TableWritePrivilege] + (matchedActions.iterator ++ notMatchedActions ++ notMatchedBySourceActions).foreach { + case _: DeleteAction => privileges.add(TableWritePrivilege.DELETE) + case _: UpdateAction | _: UpdateStarAction => privileges.add(TableWritePrivilege.UPDATE) + case _: InsertAction | _: InsertStarAction => privileges.add(TableWritePrivilege.INSERT) + } + privileges.toSeq + } +} + sealed abstract class MergeAction extends Expression with Unevaluable { def condition: Option[Expression] override def nullable: Boolean = false diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 47c438f154ab9..f8f682e76cfc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -331,9 +331,10 @@ private[sql] object CatalogV2Util { def loadTable( catalog: CatalogPlugin, ident: Identifier, - timeTravelSpec: Option[TimeTravelSpec] = None): Option[Table] = + timeTravelSpec: Option[TimeTravelSpec] = None, + writePrivilegesString: Option[String] = None): Option[Table] = try { - Option(getTable(catalog, ident, timeTravelSpec)) + Option(getTable(catalog, ident, timeTravelSpec, writePrivilegesString)) } catch { case _: NoSuchTableException => None case _: NoSuchDatabaseException => None @@ -343,8 +344,10 @@ private[sql] object CatalogV2Util { def getTable( catalog: CatalogPlugin, ident: Identifier, - timeTravelSpec: Option[TimeTravelSpec] = None): Table = { + timeTravelSpec: Option[TimeTravelSpec] = None, + writePrivilegesString: Option[String] = None): Table = { if (timeTravelSpec.nonEmpty) { + assert(writePrivilegesString.isEmpty, "Should not write to a table with time travel") timeTravelSpec.get match { case v: AsOfVersion => catalog.asTableCatalog.loadTable(ident, v.version) @@ -352,7 +355,13 @@ private[sql] object CatalogV2Util { catalog.asTableCatalog.loadTable(ident, ts.timestamp) } } else { - catalog.asTableCatalog.loadTable(ident) + if (writePrivilegesString.isDefined) { + val writePrivileges = writePrivilegesString.get.split(",").map(_.trim) + .map(TableWritePrivilege.valueOf).toSet.asJava + catalog.asTableCatalog.loadTable(ident, writePrivileges) + } else { + catalog.asTableCatalog.loadTable(ident) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 6f36a8c9719cb..176c24d4e100f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -39,7 +39,16 @@ class DDLParserSuite extends AnalysisTest { } private def parseCompare(sql: String, expected: LogicalPlan): Unit = { - comparePlans(parsePlan(sql), expected, checkAnalysis = false) + // We don't care the write privileges in this suite. 
+ val parsed = parsePlan(sql).transform { + case u: UnresolvedRelation => u.clearWritePrivileges + case i: InsertIntoStatement => + i.table match { + case u: UnresolvedRelation => i.copy(table = u.clearWritePrivileges) + case _ => i + } + } + comparePlans(parsed, expected, checkAnalysis = false) } private def internalException(sqlText: String): SparkThrowable = { @@ -2614,15 +2623,15 @@ class DDLParserSuite extends AnalysisTest { val timestampTypeSql = s"INSERT INTO t PARTITION(part = timestamp'$timestamp') VALUES('a')" val binaryTypeSql = s"INSERT INTO t PARTITION(part = X'$binaryHexStr') VALUES('a')" - comparePlans(parsePlan(dateTypeSql), insertPartitionPlan("2019-01-02")) + parseCompare(dateTypeSql, insertPartitionPlan("2019-01-02")) withSQLConf(SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") { - comparePlans(parsePlan(intervalTypeSql), insertPartitionPlan(interval)) + parseCompare(intervalTypeSql, insertPartitionPlan(interval)) } - comparePlans(parsePlan(ymIntervalTypeSql), insertPartitionPlan("INTERVAL '1-2' YEAR TO MONTH")) - comparePlans(parsePlan(dtIntervalTypeSql), + parseCompare(ymIntervalTypeSql, insertPartitionPlan("INTERVAL '1-2' YEAR TO MONTH")) + parseCompare(dtIntervalTypeSql, insertPartitionPlan("INTERVAL '1 02:03:04.128462' DAY TO SECOND")) - comparePlans(parsePlan(timestampTypeSql), insertPartitionPlan(timestamp)) - comparePlans(parsePlan(binaryTypeSql), insertPartitionPlan(binaryStr)) + parseCompare(timestampTypeSql, insertPartitionPlan(timestamp)) + parseCompare(binaryTypeSql, insertPartitionPlan(binaryStr)) } test("SPARK-38335: Implement parser support for DEFAULT values for columns in tables") { @@ -2717,12 +2726,12 @@ class DDLParserSuite extends AnalysisTest { // In each of the following cases, the DEFAULT reference parses as an unresolved attribute // reference. We can handle these cases after the parsing stage, at later phases of analysis. - comparePlans(parsePlan("VALUES (1, 2, DEFAULT) AS val"), + parseCompare("VALUES (1, 2, DEFAULT) AS val", SubqueryAlias("val", UnresolvedInlineTable(Seq("col1", "col2", "col3"), Seq(Seq(Literal(1), Literal(2), UnresolvedAttribute("DEFAULT")))))) - comparePlans(parsePlan( - "INSERT INTO t PARTITION(part = date'2019-01-02') VALUES ('a', DEFAULT)"), + parseCompare( + "INSERT INTO t PARTITION(part = date'2019-01-02') VALUES ('a', DEFAULT)", InsertIntoStatement( UnresolvedRelation(Seq("t")), Map("part" -> Some("2019-01-02")), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 13474fe29de98..acc5a6ebddd2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -39,7 +39,16 @@ class PlanParserSuite extends AnalysisTest { import org.apache.spark.sql.catalyst.dsl.plans._ private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { - comparePlans(parsePlan(sqlCommand), plan, checkAnalysis = false) + // We don't care the write privileges in this suite. 
+ val parsed = parsePlan(sqlCommand).transform { + case u: UnresolvedRelation => u.clearWritePrivileges + case i: InsertIntoStatement => + i.table match { + case u: UnresolvedRelation => i.copy(table = u.clearWritePrivileges) + case _ => i + } + } + comparePlans(parsed, plan, checkAnalysis = false) } private def parseException(sqlText: String): SparkThrowable = { @@ -1033,57 +1042,56 @@ class PlanParserSuite extends AnalysisTest { errorClass = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'b'", "hint" -> "")) - comparePlans( - parsePlan("SELECT /*+ HINT */ * FROM t"), + assertEqual( + "SELECT /*+ HINT */ * FROM t", UnresolvedHint("HINT", Seq.empty, table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ BROADCASTJOIN(u) */ * FROM t"), + assertEqual( + "SELECT /*+ BROADCASTJOIN(u) */ * FROM t", UnresolvedHint("BROADCASTJOIN", Seq($"u"), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ MAPJOIN(u) */ * FROM t"), + assertEqual( + "SELECT /*+ MAPJOIN(u) */ * FROM t", UnresolvedHint("MAPJOIN", Seq($"u"), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t"), + assertEqual( + "SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t", UnresolvedHint("STREAMTABLE", Seq($"a", $"b", $"c"), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ INDEX(t, emp_job_ix) */ * FROM t"), + assertEqual( + "SELECT /*+ INDEX(t, emp_job_ix) */ * FROM t", UnresolvedHint("INDEX", Seq($"t", $"emp_job_ix"), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`"), + assertEqual( + "SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`", UnresolvedHint("MAPJOIN", Seq(UnresolvedAttribute.quoted("default.t")), table("default.t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"), + assertEqual( + "SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a", UnresolvedHint("MAPJOIN", Seq($"t"), table("t").where(Literal(true)).groupBy($"a")($"a")).orderBy($"a".asc)) - comparePlans( - parsePlan("SELECT /*+ COALESCE(10) */ * FROM t"), + assertEqual( + "SELECT /*+ COALESCE(10) */ * FROM t", UnresolvedHint("COALESCE", Seq(Literal(10)), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION(100) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION(100) */ * FROM t", UnresolvedHint("REPARTITION", Seq(Literal(100)), table("t").select(star()))) - comparePlans( - parsePlan( - "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t"), + assertEqual( + "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t", InsertIntoStatement(table("s"), Map.empty, Nil, UnresolvedHint("REPARTITION", Seq(Literal(100)), UnresolvedHint("COALESCE", Seq(Literal(500)), UnresolvedHint("COALESCE", Seq(Literal(10)), table("t").select(star())))), overwrite = false, ifPartitionNotExists = false)) - comparePlans( - parsePlan("SELECT /*+ BROADCASTJOIN(u), REPARTITION(100) */ * FROM t"), + assertEqual( + "SELECT /*+ BROADCASTJOIN(u), REPARTITION(100) */ * FROM t", UnresolvedHint("BROADCASTJOIN", Seq($"u"), UnresolvedHint("REPARTITION", Seq(Literal(100)), table("t").select(star())))) @@ -1094,49 +1102,48 @@ class PlanParserSuite extends AnalysisTest { errorClass = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'+'", "hint" -> "")) - comparePlans( - parsePlan("SELECT /*+ REPARTITION(c) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION(c) */ * FROM t", 
UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("c")), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION(100, c) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION(100, c) */ * FROM t", UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION(100, c), COALESCE(50) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION(100, c), COALESCE(50) */ * FROM t", UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")), UnresolvedHint("COALESCE", Seq(Literal(50)), table("t").select(star())))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50) */ * FROM t", UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")), UnresolvedHint("BROADCASTJOIN", Seq($"u"), UnresolvedHint("COALESCE", Seq(Literal(50)), table("t").select(star()))))) - comparePlans( - parsePlan( - """ - |SELECT - |/*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50), REPARTITION(300, c) */ - |* FROM t - """.stripMargin), + assertEqual( + """ + |SELECT + |/*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50), REPARTITION(300, c) */ + |* FROM t + """.stripMargin, UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")), UnresolvedHint("BROADCASTJOIN", Seq($"u"), UnresolvedHint("COALESCE", Seq(Literal(50)), UnresolvedHint("REPARTITION", Seq(Literal(300), UnresolvedAttribute("c")), table("t").select(star())))))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION_BY_RANGE(c) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION_BY_RANGE(c) */ * FROM t", UnresolvedHint("REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("c")), table("t").select(star()))) - comparePlans( - parsePlan("SELECT /*+ REPARTITION_BY_RANGE(100, c) */ * FROM t"), + assertEqual( + "SELECT /*+ REPARTITION_BY_RANGE(100, c) */ * FROM t", UnresolvedHint("REPARTITION_BY_RANGE", Seq(Literal(100), UnresolvedAttribute("c")), table("t").select(star()))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 4de6b944bc868..84f02c723136b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -30,12 +30,15 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSel import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Implicits, CatalogV2Util, Identifier, SupportsCatalogOptions, Table, TableCatalog, TableProvider, V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.catalog.TableWritePrivilege +import org.apache.spark.sql.connector.catalog.TableWritePrivilege._ import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import 
org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.StructType @@ -448,7 +451,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def insertInto(catalog: CatalogPlugin, ident: Identifier): Unit = { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - val table = catalog.asTableCatalog.loadTable(ident) match { + val table = catalog.asTableCatalog.loadTable(ident, getWritePrivileges.toSet.asJava) match { case _: V1Table => return insertInto(TableIdentifier(ident.name(), ident.namespace().headOption)) case t => @@ -479,7 +482,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def insertInto(tableIdent: TableIdentifier): Unit = { runCommand(df.sparkSession) { InsertIntoStatement( - table = UnresolvedRelation(tableIdent), + table = UnresolvedRelation(tableIdent).requireWritePrivileges(getWritePrivileges), partitionSpec = Map.empty[String, Option[String]], Nil, query = df.logicalPlan, @@ -488,6 +491,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } + private def getWritePrivileges: Seq[TableWritePrivilege] = mode match { + case SaveMode.Overwrite => Seq(INSERT, DELETE) + case _ => Seq(INSERT) + } + private def getBucketSpec: Option[BucketSpec] = { if (sortColumnNames.isDefined && numBuckets.isEmpty) { throw QueryCompilationErrors.sortByWithoutBucketingError() @@ -557,7 +565,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ val session = df.sparkSession - val canUseV2 = lookupV2Provider().isDefined + val canUseV2 = lookupV2Provider().isDefined || + df.sparkSession.sessionState.conf.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match { case nameParts @ NonSessionCatalogAndIdentifier(catalog, ident) => @@ -578,7 +587,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def saveAsTable( catalog: TableCatalog, ident: Identifier, nameParts: Seq[String]): Unit = { - val tableOpt = try Option(catalog.loadTable(ident)) catch { + val tableOpt = try Option(catalog.loadTable(ident, getWritePrivileges.toSet.asJava)) catch { case _: NoSuchTableException => None } @@ -639,7 +648,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val catalog = df.sparkSession.sessionState.catalog val qualifiedIdent = catalog.qualifyIdentifier(tableIdent) val tableExists = catalog.tableExists(qualifiedIdent) - val tableName = qualifiedIdent.unquotedString (tableExists, mode) match { case (true, SaveMode.Ignore) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index 7ca9c7ef71d67..09d884af05b18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -24,6 +24,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException, UnresolvedIdentifier, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, Bucket, Days, Hours, Literal, Months, Years} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OptionList, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, UnresolvedTableSpec} +import 
org.apache.spark.sql.connector.catalog.TableWritePrivilege._ import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference, Transform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution @@ -146,7 +147,9 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) */ @throws(classOf[NoSuchTableException]) def append(): Unit = { - val append = AppendData.byName(UnresolvedRelation(tableName), logicalPlan, options.toMap) + val append = AppendData.byName( + UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)), + logicalPlan, options.toMap) runCommand(append) } @@ -163,7 +166,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) @throws(classOf[NoSuchTableException]) def overwrite(condition: Column): Unit = { val overwrite = OverwriteByExpression.byName( - UnresolvedRelation(tableName), logicalPlan, condition.expr, options.toMap) + UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)), + logicalPlan, condition.expr, options.toMap) runCommand(overwrite) } @@ -183,7 +187,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) @throws(classOf[NoSuchTableException]) def overwritePartitions(): Unit = { val dynamicOverwrite = OverwritePartitionsDynamic.byName( - UnresolvedRelation(tableName), logicalPlan, options.toMap) + UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)), + logicalPlan, options.toMap) runCommand(dynamicOverwrite) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index aee243b6529da..c96ce2daa49e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -669,7 +669,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) } private def supportsV1Command(catalog: CatalogPlugin): Boolean = { - catalog.name().equalsIgnoreCase(CatalogManager.SESSION_CATALOG_NAME) && - !SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined + isSessionCatalog(catalog) && + SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isEmpty } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out index f37e31bdb389c..522cfb0cbbd28 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out @@ -196,7 +196,7 @@ ExplainCommand 'Aggregate ['key], ['key, unresolvedalias('MIN('val), None)], For -- !query EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4 -- !query analysis -ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [], false, false, false, false, ExtendedMode +ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false, ExtendedMode -- !query diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out index f37e31bdb389c..522cfb0cbbd28 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out @@ -196,7 +196,7 @@ ExplainCommand 'Aggregate ['key], ['key, unresolvedalias('MIN('val), None)], For -- !query EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4 -- !query analysis -ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [], false, false, false, false, ExtendedMode +ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false, ExtendedMode -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out index 3c2677c936f9c..54fa9ca418cc1 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out @@ -1081,7 +1081,7 @@ EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4 struct -- !query output == Parsed Logical Plan == -'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [], false, false, false, false +'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false +- 'Project [*] +- 'UnresolvedRelation [explain_temp4], [], false diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index f54c6c5e44f2e..20314b5f9b93a 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -1023,7 +1023,7 @@ EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4 struct -- !query output == Parsed Logical Plan == -'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [], false, false, false, false +'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false +- 'Project [*] +- 'UnresolvedRelation [explain_temp4], [], false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 4c6ae425291d0..27a0b731021eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -3442,6 +3442,73 @@ class DataSourceV2SQLSuiteV1Filter } } + test("SPARK-49246: read-only catalog") { + def assertPrivilegeError(f: => Unit, privilege: String): Unit = { + val e = intercept[RuntimeException](f) + assert(e.getMessage.contains(privilege)) + } + + def checkWriteOperations(catalog: String): Unit = { + withSQLConf(s"spark.sql.catalog.$catalog" -> classOf[ReadOnlyCatalog].getName) { + val input = sql("SELECT 1") + val tbl = s"$catalog.default.t" + withTable(tbl) { + sql(s"CREATE TABLE $tbl (i INT)") + val df = sql(s"SELECT * FROM $tbl") + assert(df.collect().isEmpty) + assert(df.schema == new StructType().add("i", "int")) + + assertPrivilegeError(sql(s"INSERT INTO $tbl SELECT 1"), "INSERT") + assertPrivilegeError( + sql(s"INSERT INTO $tbl REPLACE WHERE i = 0 SELECT 1"), "DELETE,INSERT") + assertPrivilegeError(sql(s"INSERT OVERWRITE $tbl SELECT 1"), "DELETE,INSERT") + assertPrivilegeError(sql(s"DELETE FROM $tbl WHERE i = 0"), "DELETE") + assertPrivilegeError(sql(s"UPDATE $tbl SET i = 0"), "UPDATE") + assertPrivilegeError( + sql(s""" + |MERGE INTO $tbl 
USING (SELECT 1 i) AS source + |ON source.i = $tbl.i + |WHEN MATCHED THEN UPDATE SET * + |WHEN NOT MATCHED THEN INSERT * + |WHEN NOT MATCHED BY SOURCE THEN DELETE + |""".stripMargin), + "DELETE,INSERT,UPDATE" + ) + + assertPrivilegeError(input.write.insertInto(tbl), "INSERT") + assertPrivilegeError(input.write.mode("overwrite").insertInto(tbl), "DELETE,INSERT") + assertPrivilegeError(input.write.mode("append").saveAsTable(tbl), "INSERT") + assertPrivilegeError(input.write.mode("overwrite").saveAsTable(tbl), "DELETE,INSERT") + assertPrivilegeError(input.writeTo(tbl).append(), "INSERT") + assertPrivilegeError(input.writeTo(tbl).overwrite(df.col("i") === 1), "DELETE,INSERT") + assertPrivilegeError(input.writeTo(tbl).overwritePartitions(), "DELETE,INSERT") + } + + // Test CTAS + withTable(tbl) { + // assertPrivilegeError(sql(s"CREATE TABLE $tbl AS SELECT 1 i"), "INSERT") + } + withTable(tbl) { + // assertPrivilegeError(sql(s"CREATE OR REPLACE TABLE $tbl AS SELECT 1 i"), "INSERT") + } + withTable(tbl) { + // assertPrivilegeError(input.write.saveAsTable(tbl), "INSERT") + } + withTable(tbl) { + // assertPrivilegeError(input.writeTo(tbl).create(), "INSERT") + } + withTable(tbl) { + // assertPrivilegeError(input.writeTo(tbl).createOrReplace(), "INSERT") + } + } + } + // Reset CatalogManager to clear the materialized `spark_catalog` instance, so that we can + // configure a new implementation. + spark.sessionState.catalogManager.reset() + checkWriteOperations(SESSION_CATALOG_NAME) + checkWriteOperations("read_only_cat") + } + private def testNotSupportedV2Command( sqlCommand: String, sqlParams: String, @@ -3517,3 +3584,19 @@ class V2CatalogSupportBuiltinDataSource extends InMemoryCatalog { } } +class ReadOnlyCatalog extends InMemoryCatalog { + override def createTable( + ident: Identifier, + columns: Array[ColumnV2], + partitions: Array[Transform], + properties: jutil.Map[String, String]): Table = { + super.createTable(ident, columns, partitions, properties) + } + + override def loadTable( + ident: Identifier, + writePrivileges: jutil.Set[TableWritePrivilege]): Table = { + throw new RuntimeException("cannot write with " + + writePrivileges.asScala.toSeq.map(_.toString).sorted.mkString(",")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala index a2f3d872a68e9..2979d3cdcab56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, CatalogV2Util, Column, ColumnDefaultValue, Identifier, SupportsRowLevelOperations, TableCapability, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, CatalogV2Util, Column, ColumnDefaultValue, Identifier, SupportsRowLevelOperations, TableCapability, TableCatalog, TableWritePrivilege} import org.apache.spark.sql.connector.expressions.{LiteralValue, Transform} import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog import 
org.apache.spark.sql.internal.SQLConf @@ -160,6 +160,8 @@ abstract class AlignAssignmentsSuiteBase extends AnalysisTest { case name => throw new NoSuchTableException(Seq(name)) } }) + when(newCatalog.loadTable(any(), any[java.util.Set[TableWritePrivilege]]())) + .thenCallRealMethod() when(newCatalog.name()).thenReturn("cat") newCatalog } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 8eb0d5456c111..d738270699bd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn, AnalysisOnlyCom import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId import org.apache.spark.sql.connector.FakeV2Provider -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Column, ColumnDefaultValue, Identifier, SupportsDelete, Table, TableCapability, TableCatalog, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Column, ColumnDefaultValue, Identifier, SupportsDelete, Table, TableCapability, TableCatalog, TableWritePrivilege, V1Table} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.connector.expressions.{LiteralValue, Transform} import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1} @@ -157,6 +157,8 @@ class PlanResolutionSuite extends AnalysisTest { case name => throw new NoSuchTableException(Seq(name)) } }) + when(newCatalog.loadTable(any(), any[java.util.Set[TableWritePrivilege]]())) + .thenCallRealMethod() when(newCatalog.name()).thenReturn("testcat") newCatalog } @@ -174,6 +176,8 @@ class PlanResolutionSuite extends AnalysisTest { case name => throw new NoSuchTableException(Seq(name)) } }) + when(newCatalog.loadTable(any(), any[java.util.Set[TableWritePrivilege]]())) + .thenCallRealMethod() when(newCatalog.name()).thenReturn(CatalogManager.SESSION_CATALOG_NAME) newCatalog }
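To see how the pieces fit together, here is a hedged end-to-end sketch (also not part of the patch) that wires the hypothetical AclCheckingCatalog from the earlier sketch in as the v2 session catalog and runs two of the rewritten statements; the database and table names are placeholders.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  // With the DataFrameWriter change above, configuring a custom v2 session catalog also routes
  // saveAsTable/insertInto through the v2 path.
  .config("spark.sql.catalog.spark_catalog", classOf[AclCheckingCatalog].getName)
  .getOrCreate()

// The parser attaches __required_write_privileges__=DELETE to the UnresolvedRelation and the
// analyzer forwards it as loadTable(ident, Set(DELETE)), so the catalog above rejects this
// statement before any rows are touched.
spark.sql("DELETE FROM db.target WHERE id = 0")

// For MERGE, the privileges are derived from the WHEN clauses by
// MergeIntoTable.getWritePrivileges: UPDATE and INSERT in this example.
spark.sql(
  """MERGE INTO db.target t USING db.source s ON t.id = s.id
    |WHEN MATCHED THEN UPDATE SET *
    |WHEN NOT MATCHED THEN INSERT *""".stripMargin)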