diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java index b990f59bfd90e..387477d0f1911 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java @@ -110,6 +110,26 @@ public interface TableCatalog extends CatalogPlugin { */ Table loadTable(Identifier ident) throws NoSuchTableException; + /** + * Load table metadata by {@link Identifier identifier} from the catalog. Spark will write data + * into this table later. + *
+ * If the catalog supports views and contains a view for the identifier and not a table, this
+ * must throw {@link NoSuchTableException}.
+ *
+ * @param ident a table identifier
+ * @param writePrivileges
+ * @return the table's metadata
+ * @throws NoSuchTableException If the table doesn't exist or is a view
+ *
+ * @since 3.5.3
+ */
+ default Table loadTable(
+ Identifier ident,
+ Set
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java
new file mode 100644
index 0000000000000..ca2d4ba9e7b4e
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog;
+
+/**
+ * The table write privileges that will be provided when loading a table.
+ *
+ * @since 3.5.3
+ */
+public enum TableWritePrivilege {
+ /**
+ * The privilege for adding rows to the table.
+ */
+ INSERT,
+
+ /**
+ * The privilege for changing existing rows in th table.
+ */
+ UPDATE,
+
+ /**
+ * The privilege for deleting rows from the table.
+ */
+ DELETE
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index fd0a0715b6344..463bd3c3a8a27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1309,8 +1309,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
cachedConnectRelation
}.getOrElse(cachedRelation)
}.orElse {
- val table = CatalogV2Util.loadTable(catalog, ident, timeTravelSpec)
- val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming)
+ val writePrivilegesString =
+ Option(u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES))
+ val table = CatalogV2Util.loadTable(
+ catalog, ident, timeTravelSpec, writePrivilegesString)
+ val loaded = createRelation(
+ catalog, ident, table, u.clearWritePrivileges.options, u.isStreaming)
loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
loaded.map { loadedRelation =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index a03c3e317c104..81d92acc6e84a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Unary
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
+import org.apache.spark.sql.connector.catalog.TableWritePrivilege
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.{DataType, Metadata, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -106,10 +107,36 @@ case class UnresolvedRelation(
override def name: String = tableName
+ def requireWritePrivileges(privileges: Seq[TableWritePrivilege]): UnresolvedRelation = {
+ if (privileges.nonEmpty) {
+ val newOptions = new java.util.HashMap[String, String]
+ newOptions.putAll(options)
+ newOptions.put(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES, privileges.mkString(","))
+ copy(options = new CaseInsensitiveStringMap(newOptions))
+ } else {
+ this
+ }
+ }
+
+ def clearWritePrivileges: UnresolvedRelation = {
+ if (options.containsKey(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) {
+ val newOptions = new java.util.HashMap[String, String]
+ newOptions.putAll(options)
+ newOptions.remove(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)
+ copy(options = new CaseInsensitiveStringMap(newOptions))
+ } else {
+ this
+ }
+ }
+
final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_RELATION)
}
object UnresolvedRelation {
+ // An internal option of `UnresolvedRelation` to specify the required write privileges when
+ // writing data to this relation.
+ val REQUIRED_WRITE_PRIVILEGES = "__required_write_privileges__"
+
def apply(
tableIdentifier: TableIdentifier,
extraOptions: CaseInsensitiveStringMap,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index f28a6de9fc2c1..2b600743e1bd8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, GeneratedColumn, IntervalUtils, ResolveDefaultColumns}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone}
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog, TableWritePrivilege}
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryParsingErrors}
@@ -336,7 +336,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
= visitInsertIntoTable(table)
withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => {
InsertIntoStatement(
- createUnresolvedRelation(relationCtx, ident),
+ createUnresolvedRelation(relationCtx, ident, Seq(TableWritePrivilege.INSERT)),
partition,
cols,
otherPlans.head,
@@ -349,7 +349,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
= visitInsertOverwriteTable(table)
withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => {
InsertIntoStatement(
- createUnresolvedRelation(relationCtx, ident),
+ createUnresolvedRelation(relationCtx, ident,
+ Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)),
partition,
cols,
otherPlans.head,
@@ -360,7 +361,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
case ctx: InsertIntoReplaceWhereContext =>
withIdentClause(ctx.identifierReference, Seq(query), (ident, otherPlans) => {
OverwriteByExpression.byPosition(
- createUnresolvedRelation(ctx.identifierReference, ident),
+ createUnresolvedRelation(ctx.identifierReference, ident,
+ Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)),
otherPlans.head,
expression(ctx.whereClause().booleanExpression()))
})
@@ -439,7 +441,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
override def visitDeleteFromTable(
ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) {
- val table = createUnresolvedRelation(ctx.identifierReference)
+ val table = createUnresolvedRelation(
+ ctx.identifierReference, writePrivileges = Seq(TableWritePrivilege.DELETE))
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE")
val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
val predicate = if (ctx.whereClause() != null) {
@@ -451,7 +454,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
}
override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
- val table = createUnresolvedRelation(ctx.identifierReference)
+ val table = createUnresolvedRelation(
+ ctx.identifierReference, writePrivileges = Seq(TableWritePrivilege.UPDATE))
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE")
val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
val assignments = withAssignments(ctx.setClause().assignmentList())
@@ -473,10 +477,6 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
}
override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) {
- val targetTable = createUnresolvedRelation(ctx.target)
- val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
- val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)
-
val sourceTableOrQuery = if (ctx.source != null) {
createUnresolvedRelation(ctx.source)
} else if (ctx.sourceQuery != null) {
@@ -506,7 +506,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
s"Unrecognized matched action: ${clause.matchedAction().getText}")
}
}
- }
+ }.toSeq
val notMatchedActions = ctx.notMatchedClause().asScala.map {
clause => {
if (clause.notMatchedAction().INSERT() != null) {
@@ -527,7 +527,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
s"Unrecognized matched action: ${clause.notMatchedAction().getText}")
}
}
- }
+ }.toSeq
val notMatchedBySourceActions = ctx.notMatchedBySourceClause().asScala.map {
clause => {
val notMatchedBySourceAction = clause.notMatchedBySourceAction()
@@ -542,7 +542,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
s"Unrecognized matched action: ${clause.notMatchedBySourceAction().getText}")
}
}
- }
+ }.toSeq
if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) {
throw QueryParsingErrors.mergeStatementWithoutWhenClauseError(ctx)
}
@@ -561,13 +561,19 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
throw QueryParsingErrors.nonLastNotMatchedBySourceClauseOmitConditionError(ctx)
}
+ val targetTable = createUnresolvedRelation(
+ ctx.target,
+ writePrivileges = MergeIntoTable.getWritePrivileges(
+ matchedActions, notMatchedActions, notMatchedBySourceActions))
+ val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
+ val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)
MergeIntoTable(
aliasedTarget,
aliasedSource,
mergeCondition,
- matchedActions.toSeq,
- notMatchedActions.toSeq,
- notMatchedBySourceActions.toSeq)
+ matchedActions,
+ notMatchedActions,
+ notMatchedBySourceActions)
}
/**
@@ -2793,16 +2799,23 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
* Create an [[UnresolvedRelation]] from an identifier reference.
*/
private def createUnresolvedRelation(
- ctx: IdentifierReferenceContext): LogicalPlan = withOrigin(ctx) {
- withIdentClause(ctx, UnresolvedRelation(_))
+ ctx: IdentifierReferenceContext,
+ writePrivileges: Seq[TableWritePrivilege] = Nil): LogicalPlan = withOrigin(ctx) {
+ withIdentClause(ctx, parts => {
+ val relation = new UnresolvedRelation(parts, isStreaming = false)
+ relation.requireWritePrivileges(writePrivileges)
+ })
}
/**
* Create an [[UnresolvedRelation]] from a multi-part identifier.
*/
private def createUnresolvedRelation(
- ctx: ParserRuleContext, ident: Seq[String]): UnresolvedRelation = withOrigin(ctx) {
- UnresolvedRelation(ident)
+ ctx: ParserRuleContext,
+ ident: Seq[String],
+ writePrivileges: Seq[TableWritePrivilege]): UnresolvedRelation = withOrigin(ctx) {
+ val relation = new UnresolvedRelation(ident, isStreaming = false)
+ relation.requireWritePrivileges(writePrivileges)
}
/**
@@ -4601,7 +4614,9 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
if (query.isDefined) {
CacheTableAsSelect(ident.head, query.get, source(ctx.query()), isLazy, options)
} else {
- CacheTable(createUnresolvedRelation(ctx.identifierReference, ident), ident, isLazy, options)
+ CacheTable(
+ createUnresolvedRelation(ctx.identifierReference, ident, writePrivileges = Nil),
+ ident, isLazy, options)
}
})
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 805f277cf9f6a..d7669ac0b1d78 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -758,6 +758,21 @@ case class MergeIntoTable(
copy(targetTable = newLeft, sourceTable = newRight)
}
+object MergeIntoTable {
+ def getWritePrivileges(
+ matchedActions: Seq[MergeAction],
+ notMatchedActions: Seq[MergeAction],
+ notMatchedBySourceActions: Seq[MergeAction]): Seq[TableWritePrivilege] = {
+ val privileges = scala.collection.mutable.HashSet.empty[TableWritePrivilege]
+ (matchedActions.iterator ++ notMatchedActions ++ notMatchedBySourceActions).foreach {
+ case _: DeleteAction => privileges.add(TableWritePrivilege.DELETE)
+ case _: UpdateAction | _: UpdateStarAction => privileges.add(TableWritePrivilege.UPDATE)
+ case _: InsertAction | _: InsertStarAction => privileges.add(TableWritePrivilege.INSERT)
+ }
+ privileges.toSeq
+ }
+}
+
sealed abstract class MergeAction extends Expression with Unevaluable {
def condition: Option[Expression]
override def nullable: Boolean = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
index 47c438f154ab9..f8f682e76cfc8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
@@ -331,9 +331,10 @@ private[sql] object CatalogV2Util {
def loadTable(
catalog: CatalogPlugin,
ident: Identifier,
- timeTravelSpec: Option[TimeTravelSpec] = None): Option[Table] =
+ timeTravelSpec: Option[TimeTravelSpec] = None,
+ writePrivilegesString: Option[String] = None): Option[Table] =
try {
- Option(getTable(catalog, ident, timeTravelSpec))
+ Option(getTable(catalog, ident, timeTravelSpec, writePrivilegesString))
} catch {
case _: NoSuchTableException => None
case _: NoSuchDatabaseException => None
@@ -343,8 +344,10 @@ private[sql] object CatalogV2Util {
def getTable(
catalog: CatalogPlugin,
ident: Identifier,
- timeTravelSpec: Option[TimeTravelSpec] = None): Table = {
+ timeTravelSpec: Option[TimeTravelSpec] = None,
+ writePrivilegesString: Option[String] = None): Table = {
if (timeTravelSpec.nonEmpty) {
+ assert(writePrivilegesString.isEmpty, "Should not write to a table with time travel")
timeTravelSpec.get match {
case v: AsOfVersion =>
catalog.asTableCatalog.loadTable(ident, v.version)
@@ -352,7 +355,13 @@ private[sql] object CatalogV2Util {
catalog.asTableCatalog.loadTable(ident, ts.timestamp)
}
} else {
- catalog.asTableCatalog.loadTable(ident)
+ if (writePrivilegesString.isDefined) {
+ val writePrivileges = writePrivilegesString.get.split(",").map(_.trim)
+ .map(TableWritePrivilege.valueOf).toSet.asJava
+ catalog.asTableCatalog.loadTable(ident, writePrivileges)
+ } else {
+ catalog.asTableCatalog.loadTable(ident)
+ }
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index 6f36a8c9719cb..176c24d4e100f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -39,7 +39,16 @@ class DDLParserSuite extends AnalysisTest {
}
private def parseCompare(sql: String, expected: LogicalPlan): Unit = {
- comparePlans(parsePlan(sql), expected, checkAnalysis = false)
+ // We don't care the write privileges in this suite.
+ val parsed = parsePlan(sql).transform {
+ case u: UnresolvedRelation => u.clearWritePrivileges
+ case i: InsertIntoStatement =>
+ i.table match {
+ case u: UnresolvedRelation => i.copy(table = u.clearWritePrivileges)
+ case _ => i
+ }
+ }
+ comparePlans(parsed, expected, checkAnalysis = false)
}
private def internalException(sqlText: String): SparkThrowable = {
@@ -2614,15 +2623,15 @@ class DDLParserSuite extends AnalysisTest {
val timestampTypeSql = s"INSERT INTO t PARTITION(part = timestamp'$timestamp') VALUES('a')"
val binaryTypeSql = s"INSERT INTO t PARTITION(part = X'$binaryHexStr') VALUES('a')"
- comparePlans(parsePlan(dateTypeSql), insertPartitionPlan("2019-01-02"))
+ parseCompare(dateTypeSql, insertPartitionPlan("2019-01-02"))
withSQLConf(SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") {
- comparePlans(parsePlan(intervalTypeSql), insertPartitionPlan(interval))
+ parseCompare(intervalTypeSql, insertPartitionPlan(interval))
}
- comparePlans(parsePlan(ymIntervalTypeSql), insertPartitionPlan("INTERVAL '1-2' YEAR TO MONTH"))
- comparePlans(parsePlan(dtIntervalTypeSql),
+ parseCompare(ymIntervalTypeSql, insertPartitionPlan("INTERVAL '1-2' YEAR TO MONTH"))
+ parseCompare(dtIntervalTypeSql,
insertPartitionPlan("INTERVAL '1 02:03:04.128462' DAY TO SECOND"))
- comparePlans(parsePlan(timestampTypeSql), insertPartitionPlan(timestamp))
- comparePlans(parsePlan(binaryTypeSql), insertPartitionPlan(binaryStr))
+ parseCompare(timestampTypeSql, insertPartitionPlan(timestamp))
+ parseCompare(binaryTypeSql, insertPartitionPlan(binaryStr))
}
test("SPARK-38335: Implement parser support for DEFAULT values for columns in tables") {
@@ -2717,12 +2726,12 @@ class DDLParserSuite extends AnalysisTest {
// In each of the following cases, the DEFAULT reference parses as an unresolved attribute
// reference. We can handle these cases after the parsing stage, at later phases of analysis.
- comparePlans(parsePlan("VALUES (1, 2, DEFAULT) AS val"),
+ parseCompare("VALUES (1, 2, DEFAULT) AS val",
SubqueryAlias("val",
UnresolvedInlineTable(Seq("col1", "col2", "col3"), Seq(Seq(Literal(1), Literal(2),
UnresolvedAttribute("DEFAULT"))))))
- comparePlans(parsePlan(
- "INSERT INTO t PARTITION(part = date'2019-01-02') VALUES ('a', DEFAULT)"),
+ parseCompare(
+ "INSERT INTO t PARTITION(part = date'2019-01-02') VALUES ('a', DEFAULT)",
InsertIntoStatement(
UnresolvedRelation(Seq("t")),
Map("part" -> Some("2019-01-02")),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 13474fe29de98..acc5a6ebddd2e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -39,7 +39,16 @@ class PlanParserSuite extends AnalysisTest {
import org.apache.spark.sql.catalyst.dsl.plans._
private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = {
- comparePlans(parsePlan(sqlCommand), plan, checkAnalysis = false)
+ // We don't care the write privileges in this suite.
+ val parsed = parsePlan(sqlCommand).transform {
+ case u: UnresolvedRelation => u.clearWritePrivileges
+ case i: InsertIntoStatement =>
+ i.table match {
+ case u: UnresolvedRelation => i.copy(table = u.clearWritePrivileges)
+ case _ => i
+ }
+ }
+ comparePlans(parsed, plan, checkAnalysis = false)
}
private def parseException(sqlText: String): SparkThrowable = {
@@ -1033,57 +1042,56 @@ class PlanParserSuite extends AnalysisTest {
errorClass = "PARSE_SYNTAX_ERROR",
parameters = Map("error" -> "'b'", "hint" -> ""))
- comparePlans(
- parsePlan("SELECT /*+ HINT */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ HINT */ * FROM t",
UnresolvedHint("HINT", Seq.empty, table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ BROADCASTJOIN(u) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ BROADCASTJOIN(u) */ * FROM t",
UnresolvedHint("BROADCASTJOIN", Seq($"u"), table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ MAPJOIN(u) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ MAPJOIN(u) */ * FROM t",
UnresolvedHint("MAPJOIN", Seq($"u"), table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t",
UnresolvedHint("STREAMTABLE", Seq($"a", $"b", $"c"), table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ INDEX(t, emp_job_ix) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ INDEX(t, emp_job_ix) */ * FROM t",
UnresolvedHint("INDEX", Seq($"t", $"emp_job_ix"), table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`"),
+ assertEqual(
+ "SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`",
UnresolvedHint("MAPJOIN", Seq(UnresolvedAttribute.quoted("default.t")),
table("default.t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"),
+ assertEqual(
+ "SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a",
UnresolvedHint("MAPJOIN", Seq($"t"),
table("t").where(Literal(true)).groupBy($"a")($"a")).orderBy($"a".asc))
- comparePlans(
- parsePlan("SELECT /*+ COALESCE(10) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ COALESCE(10) */ * FROM t",
UnresolvedHint("COALESCE", Seq(Literal(10)),
table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ REPARTITION(100) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ REPARTITION(100) */ * FROM t",
UnresolvedHint("REPARTITION", Seq(Literal(100)),
table("t").select(star())))
- comparePlans(
- parsePlan(
- "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t"),
+ assertEqual(
+ "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t",
InsertIntoStatement(table("s"), Map.empty, Nil,
UnresolvedHint("REPARTITION", Seq(Literal(100)),
UnresolvedHint("COALESCE", Seq(Literal(500)),
UnresolvedHint("COALESCE", Seq(Literal(10)),
table("t").select(star())))), overwrite = false, ifPartitionNotExists = false))
- comparePlans(
- parsePlan("SELECT /*+ BROADCASTJOIN(u), REPARTITION(100) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ BROADCASTJOIN(u), REPARTITION(100) */ * FROM t",
UnresolvedHint("BROADCASTJOIN", Seq($"u"),
UnresolvedHint("REPARTITION", Seq(Literal(100)),
table("t").select(star()))))
@@ -1094,49 +1102,48 @@ class PlanParserSuite extends AnalysisTest {
errorClass = "PARSE_SYNTAX_ERROR",
parameters = Map("error" -> "'+'", "hint" -> ""))
- comparePlans(
- parsePlan("SELECT /*+ REPARTITION(c) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ REPARTITION(c) */ * FROM t",
UnresolvedHint("REPARTITION", Seq(UnresolvedAttribute("c")),
table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ REPARTITION(100, c) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ REPARTITION(100, c) */ * FROM t",
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ REPARTITION(100, c), COALESCE(50) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ REPARTITION(100, c), COALESCE(50) */ * FROM t",
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
UnresolvedHint("COALESCE", Seq(Literal(50)),
table("t").select(star()))))
- comparePlans(
- parsePlan("SELECT /*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50) */ * FROM t",
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
UnresolvedHint("BROADCASTJOIN", Seq($"u"),
UnresolvedHint("COALESCE", Seq(Literal(50)),
table("t").select(star())))))
- comparePlans(
- parsePlan(
- """
- |SELECT
- |/*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50), REPARTITION(300, c) */
- |* FROM t
- """.stripMargin),
+ assertEqual(
+ """
+ |SELECT
+ |/*+ REPARTITION(100, c), BROADCASTJOIN(u), COALESCE(50), REPARTITION(300, c) */
+ |* FROM t
+ """.stripMargin,
UnresolvedHint("REPARTITION", Seq(Literal(100), UnresolvedAttribute("c")),
UnresolvedHint("BROADCASTJOIN", Seq($"u"),
UnresolvedHint("COALESCE", Seq(Literal(50)),
UnresolvedHint("REPARTITION", Seq(Literal(300), UnresolvedAttribute("c")),
table("t").select(star()))))))
- comparePlans(
- parsePlan("SELECT /*+ REPARTITION_BY_RANGE(c) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ REPARTITION_BY_RANGE(c) */ * FROM t",
UnresolvedHint("REPARTITION_BY_RANGE", Seq(UnresolvedAttribute("c")),
table("t").select(star())))
- comparePlans(
- parsePlan("SELECT /*+ REPARTITION_BY_RANGE(100, c) */ * FROM t"),
+ assertEqual(
+ "SELECT /*+ REPARTITION_BY_RANGE(100, c) */ * FROM t",
UnresolvedHint("REPARTITION_BY_RANGE", Seq(Literal(100), UnresolvedAttribute("c")),
table("t").select(star())))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 4de6b944bc868..84f02c723136b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -30,12 +30,15 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSel
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Implicits, CatalogV2Util, Identifier, SupportsCatalogOptions, Table, TableCatalog, TableProvider, V1Table}
import org.apache.spark.sql.connector.catalog.TableCapability._
+import org.apache.spark.sql.connector.catalog.TableWritePrivilege
+import org.apache.spark.sql.connector.catalog.TableWritePrivilege._
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -448,7 +451,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
private def insertInto(catalog: CatalogPlugin, ident: Identifier): Unit = {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
- val table = catalog.asTableCatalog.loadTable(ident) match {
+ val table = catalog.asTableCatalog.loadTable(ident, getWritePrivileges.toSet.asJava) match {
case _: V1Table =>
return insertInto(TableIdentifier(ident.name(), ident.namespace().headOption))
case t =>
@@ -479,7 +482,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
private def insertInto(tableIdent: TableIdentifier): Unit = {
runCommand(df.sparkSession) {
InsertIntoStatement(
- table = UnresolvedRelation(tableIdent),
+ table = UnresolvedRelation(tableIdent).requireWritePrivileges(getWritePrivileges),
partitionSpec = Map.empty[String, Option[String]],
Nil,
query = df.logicalPlan,
@@ -488,6 +491,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}
}
+ private def getWritePrivileges: Seq[TableWritePrivilege] = mode match {
+ case SaveMode.Overwrite => Seq(INSERT, DELETE)
+ case _ => Seq(INSERT)
+ }
+
private def getBucketSpec: Option[BucketSpec] = {
if (sortColumnNames.isDefined && numBuckets.isEmpty) {
throw QueryCompilationErrors.sortByWithoutBucketingError()
@@ -557,7 +565,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
val session = df.sparkSession
- val canUseV2 = lookupV2Provider().isDefined
+ val canUseV2 = lookupV2Provider().isDefined ||
+ df.sparkSession.sessionState.conf.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined
session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case nameParts @ NonSessionCatalogAndIdentifier(catalog, ident) =>
@@ -578,7 +587,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
private def saveAsTable(
catalog: TableCatalog, ident: Identifier, nameParts: Seq[String]): Unit = {
- val tableOpt = try Option(catalog.loadTable(ident)) catch {
+ val tableOpt = try Option(catalog.loadTable(ident, getWritePrivileges.toSet.asJava)) catch {
case _: NoSuchTableException => None
}
@@ -639,7 +648,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val catalog = df.sparkSession.sessionState.catalog
val qualifiedIdent = catalog.qualifyIdentifier(tableIdent)
val tableExists = catalog.tableExists(qualifiedIdent)
- val tableName = qualifiedIdent.unquotedString
(tableExists, mode) match {
case (true, SaveMode.Ignore) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
index 7ca9c7ef71d67..09d884af05b18 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException, UnresolvedIdentifier, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Bucket, Days, Hours, Literal, Months, Years}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OptionList, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, UnresolvedTableSpec}
+import org.apache.spark.sql.connector.catalog.TableWritePrivilege._
import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference, Transform}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
@@ -146,7 +147,9 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
*/
@throws(classOf[NoSuchTableException])
def append(): Unit = {
- val append = AppendData.byName(UnresolvedRelation(tableName), logicalPlan, options.toMap)
+ val append = AppendData.byName(
+ UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)),
+ logicalPlan, options.toMap)
runCommand(append)
}
@@ -163,7 +166,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
@throws(classOf[NoSuchTableException])
def overwrite(condition: Column): Unit = {
val overwrite = OverwriteByExpression.byName(
- UnresolvedRelation(tableName), logicalPlan, condition.expr, options.toMap)
+ UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
+ logicalPlan, condition.expr, options.toMap)
runCommand(overwrite)
}
@@ -183,7 +187,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
@throws(classOf[NoSuchTableException])
def overwritePartitions(): Unit = {
val dynamicOverwrite = OverwritePartitionsDynamic.byName(
- UnresolvedRelation(tableName), logicalPlan, options.toMap)
+ UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
+ logicalPlan, options.toMap)
runCommand(dynamicOverwrite)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
index aee243b6529da..c96ce2daa49e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
@@ -669,7 +669,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
}
private def supportsV1Command(catalog: CatalogPlugin): Boolean = {
- catalog.name().equalsIgnoreCase(CatalogManager.SESSION_CATALOG_NAME) &&
- !SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined
+ isSessionCatalog(catalog) &&
+ SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isEmpty
}
}
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out
index f37e31bdb389c..522cfb0cbbd28 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/explain-aqe.sql.out
@@ -196,7 +196,7 @@ ExplainCommand 'Aggregate ['key], ['key, unresolvedalias('MIN('val), None)], For
-- !query
EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4
-- !query analysis
-ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [], false, false, false, false, ExtendedMode
+ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false, ExtendedMode
-- !query
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out
index f37e31bdb389c..522cfb0cbbd28 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/explain.sql.out
@@ -196,7 +196,7 @@ ExplainCommand 'Aggregate ['key], ['key, unresolvedalias('MIN('val), None)], For
-- !query
EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4
-- !query analysis
-ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [], false, false, false, false, ExtendedMode
+ExplainCommand 'InsertIntoStatement 'UnresolvedRelation [explain_temp5], [__required_write_privileges__=INSERT], false, false, false, false, ExtendedMode
-- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
index 3c2677c936f9c..54fa9ca418cc1 100644
--- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
@@ -1081,7 +1081,7 @@ EXPLAIN EXTENDED INSERT INTO TABLE explain_temp5 SELECT * FROM explain_temp4
struct