[SPARK-49246][SQL] TableCatalog#loadTable should indicate if it's for writing

### What changes were proposed in this pull request?

For custom catalogs that have access control, read and write permissions can differ. However, Spark currently always calls `TableCatalog#loadTable` to look up a table, regardless of whether the lookup is for a read or a write.

This PR adds an overload of `loadTable` in `TableCatalog` that also takes the required write privileges: `loadTable(Identifier, Set<TableWritePrivilege>)`. All write commands now call this new method to look up tables. The new method has a default implementation that simply calls the original `loadTable`, so there is no breaking change.

### Why are the changes needed?

To allow more fine-grained access control for custom catalogs.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New tests.

### Was this patch authored or co-authored using generative AI tooling?

No.
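For catalog implementers, the new hook can be used roughly as in the minimal sketch below. `AclCatalog` and the `AccessController` stub are hypothetical names made up for illustration; only the two `loadTable` signatures, `TableWritePrivilege`, and `DelegatingCatalogExtension` are real API.

```scala
import java.util.{Set => JSet}

import org.apache.spark.sql.connector.catalog.{DelegatingCatalogExtension, Identifier, Table, TableWritePrivilege}

// Hypothetical ACL hook; a real catalog would consult an external policy service.
object AccessController {
  def checkRead(ident: Identifier): Unit = ()
  def checkWrite(ident: Identifier, privileges: JSet[TableWritePrivilege]): Unit = ()
}

class AclCatalog extends DelegatingCatalogExtension {

  // Read path: plain lookups only need read permission.
  override def loadTable(ident: Identifier): Table = {
    AccessController.checkRead(ident)
    super.loadTable(ident)
  }

  // Write path: Spark now passes exactly the privileges the command needs,
  // e.g. {INSERT, DELETE} for INSERT OVERWRITE, so the check can be precise.
  override def loadTable(
      ident: Identifier,
      writePrivileges: JSet[TableWritePrivilege]): Table = {
    AccessController.checkWrite(ident, writePrivileges)
    super.loadTable(ident)
  }
}
```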

Closes #47772 from cloud-fan/write.

Lead-authored-by: Wenchen Fan <wenchen@databricks.com>
Co-authored-by: Wenchen Fan <cloud0fan@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit b6164e6)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
cloud-fan and cloud-fan committed Aug 21, 2024
1 parent c06906d commit 027a14b
Showing 19 changed files with 344 additions and 96 deletions.
sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
@@ -110,6 +110,26 @@ public interface TableCatalog extends CatalogPlugin {
*/
Table loadTable(Identifier ident) throws NoSuchTableException;

/**
* Load table metadata by {@link Identifier identifier} from the catalog. Spark will write data
* into this table later.
* <p>
* If the catalog supports views and contains a view for the identifier and not a table, this
* must throw {@link NoSuchTableException}.
*
* @param ident a table identifier
* @param writePrivileges the write privileges required by the upcoming write operation
* @return the table's metadata
* @throws NoSuchTableException If the table doesn't exist or is a view
*
* @since 3.5.3
*/
default Table loadTable(
Identifier ident,
Set<TableWritePrivilege> writePrivileges) throws NoSuchTableException {
return loadTable(ident);
}

/**
* Load table metadata of a specific version by {@link Identifier identifier} from the catalog.
* <p>
sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.catalog;

/**
* The table write privileges that will be provided when loading a table.
*
* @since 3.5.3
*/
public enum TableWritePrivilege {
/**
* The privilege for adding rows to the table.
*/
INSERT,

/**
* The privilege for changing existing rows in the table.
*/
UPDATE,

/**
* The privilege for deleting rows from the table.
*/
DELETE
}
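As a reader's guide to the parser changes further down: each write command requests a specific subset of these privileges. The mapping below is only a summary of the AstBuilder hunks in this commit, not new API.

```scala
import org.apache.spark.sql.connector.catalog.TableWritePrivilege._

// Privileges requested per SQL command (MERGE INTO derives its set from the
// WHEN clauses via MergeIntoTable.getWritePrivileges, added below).
val privilegesByCommand = Map(
  "INSERT INTO"                   -> Seq(INSERT),
  "INSERT OVERWRITE"              -> Seq(INSERT, DELETE),
  "INSERT INTO ... REPLACE WHERE" -> Seq(INSERT, DELETE),
  "DELETE FROM"                   -> Seq(DELETE),
  "UPDATE"                        -> Seq(UPDATE)
)
```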
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1309,8 +1309,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
cachedConnectRelation
}.getOrElse(cachedRelation)
}.orElse {
val table = CatalogV2Util.loadTable(catalog, ident, timeTravelSpec)
val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming)
val writePrivilegesString =
Option(u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES))
val table = CatalogV2Util.loadTable(
catalog, ident, timeTravelSpec, writePrivilegesString)
val loaded = createRelation(
catalog, ident, table, u.clearWritePrivileges.options, u.isStreaming)
loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
loaded.map { loadedRelation =>
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Unary
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
import org.apache.spark.sql.connector.catalog.TableWritePrivilege
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.{DataType, Metadata, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -106,10 +107,36 @@ case class UnresolvedRelation(

override def name: String = tableName

def requireWritePrivileges(privileges: Seq[TableWritePrivilege]): UnresolvedRelation = {
if (privileges.nonEmpty) {
val newOptions = new java.util.HashMap[String, String]
newOptions.putAll(options)
newOptions.put(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES, privileges.mkString(","))
copy(options = new CaseInsensitiveStringMap(newOptions))
} else {
this
}
}

def clearWritePrivileges: UnresolvedRelation = {
if (options.containsKey(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) {
val newOptions = new java.util.HashMap[String, String]
newOptions.putAll(options)
newOptions.remove(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)
copy(options = new CaseInsensitiveStringMap(newOptions))
} else {
this
}
}

final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_RELATION)
}

object UnresolvedRelation {
// An internal option of `UnresolvedRelation` to specify the required write privileges when
// writing data to this relation.
val REQUIRED_WRITE_PRIVILEGES = "__required_write_privileges__"

def apply(
tableIdentifier: TableIdentifier,
extraOptions: CaseInsensitiveStringMap,
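The two helpers above carry the requirement through analysis as a hidden entry in the relation's options map, which the Analyzer strips off again via `clearWritePrivileges` (see the Analyzer hunk above). A minimal sketch of the round trip, assuming spark-catalyst on the classpath:

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.connector.catalog.TableWritePrivilege

val rel = UnresolvedRelation(Seq("t"))
  .requireWritePrivileges(Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE))

// The privileges are stored as a comma-separated internal option.
rel.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)   // "INSERT,DELETE"

// clearWritePrivileges removes the marker before the relation is materialized.
rel.clearWritePrivileges.options.containsKey(
  UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)                  // false
```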
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, GeneratedColumn, IntervalUtils, ResolveDefaultColumns}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone}
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog}
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog, TableWritePrivilege}
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryParsingErrors}
@@ -336,7 +336,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
= visitInsertIntoTable(table)
withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => {
InsertIntoStatement(
createUnresolvedRelation(relationCtx, ident),
createUnresolvedRelation(relationCtx, ident, Seq(TableWritePrivilege.INSERT)),
partition,
cols,
otherPlans.head,
@@ -349,7 +349,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
= visitInsertOverwriteTable(table)
withIdentClause(relationCtx, Seq(query), (ident, otherPlans) => {
InsertIntoStatement(
createUnresolvedRelation(relationCtx, ident),
createUnresolvedRelation(relationCtx, ident,
Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)),
partition,
cols,
otherPlans.head,
@@ -360,7 +361,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
case ctx: InsertIntoReplaceWhereContext =>
withIdentClause(ctx.identifierReference, Seq(query), (ident, otherPlans) => {
OverwriteByExpression.byPosition(
createUnresolvedRelation(ctx.identifierReference, ident),
createUnresolvedRelation(ctx.identifierReference, ident,
Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)),
otherPlans.head,
expression(ctx.whereClause().booleanExpression()))
})
@@ -439,7 +441,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {

override def visitDeleteFromTable(
ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) {
val table = createUnresolvedRelation(ctx.identifierReference)
val table = createUnresolvedRelation(
ctx.identifierReference, writePrivileges = Seq(TableWritePrivilege.DELETE))
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE")
val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
val predicate = if (ctx.whereClause() != null) {
@@ -451,7 +454,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
}

override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
val table = createUnresolvedRelation(ctx.identifierReference)
val table = createUnresolvedRelation(
ctx.identifierReference, writePrivileges = Seq(TableWritePrivilege.UPDATE))
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE")
val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
val assignments = withAssignments(ctx.setClause().assignmentList())
@@ -473,10 +477,6 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
}

override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) {
val targetTable = createUnresolvedRelation(ctx.target)
val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)

val sourceTableOrQuery = if (ctx.source != null) {
createUnresolvedRelation(ctx.source)
} else if (ctx.sourceQuery != null) {
@@ -506,7 +506,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
s"Unrecognized matched action: ${clause.matchedAction().getText}")
}
}
}
}.toSeq
val notMatchedActions = ctx.notMatchedClause().asScala.map {
clause => {
if (clause.notMatchedAction().INSERT() != null) {
@@ -527,7 +527,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
s"Unrecognized matched action: ${clause.notMatchedAction().getText}")
}
}
}
}.toSeq
val notMatchedBySourceActions = ctx.notMatchedBySourceClause().asScala.map {
clause => {
val notMatchedBySourceAction = clause.notMatchedBySourceAction()
@@ -542,7 +542,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
s"Unrecognized matched action: ${clause.notMatchedBySourceAction().getText}")
}
}
}
}.toSeq
if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) {
throw QueryParsingErrors.mergeStatementWithoutWhenClauseError(ctx)
}
@@ -561,13 +561,19 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
throw QueryParsingErrors.nonLastNotMatchedBySourceClauseOmitConditionError(ctx)
}

val targetTable = createUnresolvedRelation(
ctx.target,
writePrivileges = MergeIntoTable.getWritePrivileges(
matchedActions, notMatchedActions, notMatchedBySourceActions))
val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)
MergeIntoTable(
aliasedTarget,
aliasedSource,
mergeCondition,
matchedActions.toSeq,
notMatchedActions.toSeq,
notMatchedBySourceActions.toSeq)
matchedActions,
notMatchedActions,
notMatchedBySourceActions)
}

/**
@@ -2793,16 +2799,23 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
* Create an [[UnresolvedRelation]] from an identifier reference.
*/
private def createUnresolvedRelation(
ctx: IdentifierReferenceContext): LogicalPlan = withOrigin(ctx) {
withIdentClause(ctx, UnresolvedRelation(_))
ctx: IdentifierReferenceContext,
writePrivileges: Seq[TableWritePrivilege] = Nil): LogicalPlan = withOrigin(ctx) {
withIdentClause(ctx, parts => {
val relation = new UnresolvedRelation(parts, isStreaming = false)
relation.requireWritePrivileges(writePrivileges)
})
}

/**
* Create an [[UnresolvedRelation]] from a multi-part identifier.
*/
private def createUnresolvedRelation(
ctx: ParserRuleContext, ident: Seq[String]): UnresolvedRelation = withOrigin(ctx) {
UnresolvedRelation(ident)
ctx: ParserRuleContext,
ident: Seq[String],
writePrivileges: Seq[TableWritePrivilege]): UnresolvedRelation = withOrigin(ctx) {
val relation = new UnresolvedRelation(ident, isStreaming = false)
relation.requireWritePrivileges(writePrivileges)
}

/**
@@ -4601,7 +4614,9 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
if (query.isDefined) {
CacheTableAsSelect(ident.head, query.get, source(ctx.query()), isLazy, options)
} else {
CacheTable(createUnresolvedRelation(ctx.identifierReference, ident), ident, isLazy, options)
CacheTable(
createUnresolvedRelation(ctx.identifierReference, ident, writePrivileges = Nil),
ident, isLazy, options)
}
})
}
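With these parser changes, the privilege marker is attached the moment a write statement is parsed. A quick way to observe it with the stock catalyst parser — a sketch; the table does not need to exist at parse time:

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

val plan = CatalystSqlParser.parsePlan("DELETE FROM t WHERE id = 1")

// The parsed plan contains an UnresolvedRelation tagged with the required
// write privileges as an internal option.
val privileges = plan.collectFirst {
  case r: UnresolvedRelation =>
    r.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)
}
// privileges == Some("DELETE")
```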
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -758,6 +758,21 @@ case class MergeIntoTable(
copy(targetTable = newLeft, sourceTable = newRight)
}

object MergeIntoTable {
def getWritePrivileges(
matchedActions: Seq[MergeAction],
notMatchedActions: Seq[MergeAction],
notMatchedBySourceActions: Seq[MergeAction]): Seq[TableWritePrivilege] = {
val privileges = scala.collection.mutable.HashSet.empty[TableWritePrivilege]
(matchedActions.iterator ++ notMatchedActions ++ notMatchedBySourceActions).foreach {
case _: DeleteAction => privileges.add(TableWritePrivilege.DELETE)
case _: UpdateAction | _: UpdateStarAction => privileges.add(TableWritePrivilege.UPDATE)
case _: InsertAction | _: InsertStarAction => privileges.add(TableWritePrivilege.INSERT)
}
privileges.toSeq
}
}

sealed abstract class MergeAction extends Expression with Unevaluable {
def condition: Option[Expression]
override def nullable: Boolean = false
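`getWritePrivileges` above is what MERGE uses to request only the privileges its WHEN clauses actually need. A small sketch of the behavior for a MERGE that only deletes matched rows and inserts unmatched ones (action conditions are `None` for brevity):

```scala
import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, InsertStarAction, MergeIntoTable}

val privileges = MergeIntoTable.getWritePrivileges(
  matchedActions = Seq(DeleteAction(None)),
  notMatchedActions = Seq(InsertStarAction(None)),
  notMatchedBySourceActions = Nil)
// privileges contains DELETE and INSERT (in unspecified order) but not UPDATE,
// so the catalog is not asked for more than the statement can actually do.
```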
sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
@@ -331,9 +331,10 @@ private[sql] object CatalogV2Util {
def loadTable(
catalog: CatalogPlugin,
ident: Identifier,
timeTravelSpec: Option[TimeTravelSpec] = None): Option[Table] =
timeTravelSpec: Option[TimeTravelSpec] = None,
writePrivilegesString: Option[String] = None): Option[Table] =
try {
Option(getTable(catalog, ident, timeTravelSpec))
Option(getTable(catalog, ident, timeTravelSpec, writePrivilegesString))
} catch {
case _: NoSuchTableException => None
case _: NoSuchDatabaseException => None
@@ -343,16 +344,24 @@
def getTable(
catalog: CatalogPlugin,
ident: Identifier,
timeTravelSpec: Option[TimeTravelSpec] = None): Table = {
timeTravelSpec: Option[TimeTravelSpec] = None,
writePrivilegesString: Option[String] = None): Table = {
if (timeTravelSpec.nonEmpty) {
assert(writePrivilegesString.isEmpty, "Should not write to a table with time travel")
timeTravelSpec.get match {
case v: AsOfVersion =>
catalog.asTableCatalog.loadTable(ident, v.version)
case ts: AsOfTimestamp =>
catalog.asTableCatalog.loadTable(ident, ts.timestamp)
}
} else {
catalog.asTableCatalog.loadTable(ident)
if (writePrivilegesString.isDefined) {
val writePrivileges = writePrivilegesString.get.split(",").map(_.trim)
.map(TableWritePrivilege.valueOf).toSet.asJava
catalog.asTableCatalog.loadTable(ident, writePrivileges)
} else {
catalog.asTableCatalog.loadTable(ident)
}
}
}

Expand Down
(The remaining 12 of the 19 changed files are not rendered here.)
