Commit 89ba48b

address comments

cloud-fan committed Aug 16, 2024
1 parent 8586259 commit 89ba48b
Showing 18 changed files with 269 additions and 141 deletions.
sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java
@@ -111,13 +111,22 @@ public interface TableCatalog extends CatalogPlugin {
Table loadTable(Identifier ident) throws NoSuchTableException;

/**
* A variant of {@link #loadTable(Identifier)} that indicates it's for data writing.
* Implementations can override this method to do additional handling for data writing, such as
* checking write permissions.
* Load table metadata by {@link Identifier identifier} from the catalog. Spark will write data
* into this table later.
* <p>
* If the catalog supports views and contains a view for the identifier and not a table, this
* must throw {@link NoSuchTableException}.
*
* @param ident a table identifier
* @param writePrivileges the write privileges that Spark requires for the upcoming write, such
*                        as INSERT, UPDATE or DELETE
* @return the table's metadata
* @throws NoSuchTableException If the table doesn't exist or is a view
*
* @since 4.0.0
*/
default Table loadTableForWrite(Identifier ident) throws NoSuchTableException {
default Table loadTable(
Identifier ident,
Set<TableWritePrivilege> writePrivileges) throws NoSuchTableException {
return loadTable(ident);
}

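The default implementation delegates to the single-argument overload, so existing catalogs keep working unchanged. A catalog that enforces access control can override the new overload to reject a statement at table-resolution time, before any write is planned. A minimal Scala sketch follows; the AccessControl object is a hypothetical stand-in for a real authorization service, while DelegatingCatalogExtension, Identifier and TableWritePrivilege are real Spark APIs:

import java.util.{Set => JSet}
import org.apache.spark.sql.connector.catalog.{DelegatingCatalogExtension, Identifier, Table, TableWritePrivilege}

// Hypothetical authorization service; replace with a real permission check.
object AccessControl {
  def isAllowed(ident: Identifier, privilege: TableWritePrivilege): Boolean =
    privilege != TableWritePrivilege.DELETE // e.g. this principal may not delete
}

// A session catalog extension that verifies the requested write privileges
// before delegating to the built-in catalog.
class PermissionCheckingCatalog extends DelegatingCatalogExtension {
  override def loadTable(
      ident: Identifier,
      writePrivileges: JSet[TableWritePrivilege]): Table = {
    writePrivileges.forEach { p =>
      if (!AccessControl.isAllowed(ident, p)) {
        throw new UnsupportedOperationException(s"Missing $p privilege on ${ident.name}")
      }
    }
    super.loadTable(ident, writePrivileges)
  }
}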
sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableWritePrivilege.java
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.catalog;

/**
* The table write privileges that will be provided when loading a table.
*
* @since 4.0.0
*/
public enum TableWritePrivilege {
/**
* The privilege for adding rows to the table.
*/
INSERT,

/**
* The privilege for changing existing rows in the table.
*/
UPDATE,

/**
* The privilege for deleting rows from the table.
*/
DELETE
}
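These values round-trip through a comma-separated string when the parser attaches them to a relation as an internal option (see UnresolvedRelation and CatalogV2Util below). A small Scala sketch of that encoding, mirroring the mkString and valueOf calls elsewhere in this commit:

import scala.jdk.CollectionConverters._
import org.apache.spark.sql.connector.catalog.TableWritePrivilege

// Serialize: Seq(INSERT, DELETE) becomes "INSERT,DELETE".
val serialized = Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE).mkString(",")
// Parse it back into the java.util.Set that TableCatalog.loadTable expects.
val parsed = serialized.split(",").map(_.trim).map(TableWritePrivilege.valueOf).toSet.asJava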
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1321,9 +1321,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
cachedConnectRelation
}.getOrElse(cachedRelation)
}.orElse {
val forWrite = "true".equalsIgnoreCase(u.options.get(UnresolvedRelation.FOR_WRITE))
val table = CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec, forWrite)
val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming)
val writePrivilegesString =
Option(u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES))
val table = CatalogV2Util.loadTable(
catalog, ident, finalTimeTravelSpec, writePrivilegesString)
val loaded = createRelation(
catalog, ident, table, u.clearWritePrivileges.options, u.isStreaming)
loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
loaded.map { loadedRelation =>
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, LeafNode, Lo
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
import org.apache.spark.sql.connector.catalog.TableWritePrivilege
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.{DataType, Metadata, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -124,20 +124,35 @@ case class UnresolvedRelation(

override def name: String = tableName

def forWrite: UnresolvedRelation = {
val newOptions = new java.util.HashMap[String, String]
newOptions.put(UnresolvedRelation.FOR_WRITE, "true")
newOptions.putAll(options)
copy(options = new CaseInsensitiveStringMap(newOptions))
def requireWritePrivileges(privileges: Seq[TableWritePrivilege]): UnresolvedRelation = {
if (privileges.nonEmpty) {
val newOptions = new java.util.HashMap[String, String]
newOptions.putAll(options)
newOptions.put(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES, privileges.mkString(","))
copy(options = new CaseInsensitiveStringMap(newOptions))
} else {
this
}
}

def clearWritePrivileges: UnresolvedRelation = {
if (options.containsKey(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) {
val newOptions = new java.util.HashMap[String, String]
newOptions.putAll(options)
newOptions.remove(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)
copy(options = new CaseInsensitiveStringMap(newOptions))
} else {
this
}
}

final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_RELATION)
}

object UnresolvedRelation {
// An internal option of `UnresolvedRelation` to indicate that we look up this relation for data
// writing.
val FOR_WRITE = "__for_write__"
// An internal option of `UnresolvedRelation` to specify the required write privileges when
// writing data to this relation.
val REQUIRED_WRITE_PRIVILEGES = "__required_write_privileges__"

def apply(
tableIdentifier: TableIdentifier,
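Taken together, requireWritePrivileges and clearWritePrivileges bracket the lifetime of the internal option: the parser sets it, and the analyzer strips it once the table has been loaded, so it never leaks into the resolved relation. A hypothetical walkthrough (table name made up):

import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.connector.catalog.TableWritePrivilege

// The parser tags the relation with the privileges the statement needs...
val tagged = UnresolvedRelation(Seq("db", "target"))
  .requireWritePrivileges(Seq(TableWritePrivilege.UPDATE, TableWritePrivilege.DELETE))
assert(tagged.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES) == "UPDATE,DELETE")

// ...and the analyzer removes the option again after loading the table.
val cleared = tagged.clearWritePrivileges
assert(!cleared.options.containsKey(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES))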
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone}
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog}
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog, TableWritePrivilege}
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors, QueryParsingErrors, SqlScriptingErrors}
@@ -455,7 +455,7 @@ class AstBuilder extends DataTypeAstBuilder
= visitInsertIntoTable(table)
withIdentClause(relationCtx, ident => {
val insertIntoStatement = InsertIntoStatement(
createUnresolvedRelation(relationCtx, ident, options, forWrite = true),
createUnresolvedRelation(relationCtx, ident, options, Seq(TableWritePrivilege.INSERT)),
partition,
cols,
query,
@@ -473,7 +473,8 @@
= visitInsertOverwriteTable(table)
withIdentClause(relationCtx, ident => {
InsertIntoStatement(
createUnresolvedRelation(relationCtx, ident, options, forWrite = true),
createUnresolvedRelation(relationCtx, ident, options,
Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)),
partition,
cols,
query,
@@ -485,7 +486,8 @@
val options = Option(ctx.optionsClause())
withIdentClause(ctx.identifierReference, ident => {
OverwriteByExpression.byPosition(
createUnresolvedRelation(ctx.identifierReference, ident, options, forWrite = true),
createUnresolvedRelation(ctx.identifierReference, ident, options,
Seq(TableWritePrivilege.INSERT, TableWritePrivilege.DELETE)),
query,
expression(ctx.whereClause().booleanExpression()))
})
@@ -570,7 +572,8 @@

override def visitDeleteFromTable(
ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) {
val table = createUnresolvedRelation(ctx.identifierReference, forWrite = true)
val table = createUnresolvedRelation(
ctx.identifierReference, writePrivileges = Seq(TableWritePrivilege.DELETE))
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE")
val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
val predicate = if (ctx.whereClause() != null) {
@@ -582,7 +585,8 @@
}

override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
val table = createUnresolvedRelation(ctx.identifierReference, forWrite = true)
val table = createUnresolvedRelation(
ctx.identifierReference, writePrivileges = Seq(TableWritePrivilege.UPDATE))
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE")
val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
val assignments = withAssignments(ctx.setClause().assignmentList())
@@ -605,9 +609,6 @@

override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) {
val withSchemaEvolution = ctx.EVOLUTION() != null
val targetTable = createUnresolvedRelation(ctx.target, forWrite = true)
val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)

val sourceTableOrQuery = if (ctx.source != null) {
createUnresolvedRelation(ctx.source)
@@ -638,7 +639,7 @@ class AstBuilder extends DataTypeAstBuilder
s"Unrecognized matched action: ${clause.matchedAction().getText}")
}
}
}
}.toSeq
val notMatchedActions = ctx.notMatchedClause().asScala.map {
clause => {
if (clause.notMatchedAction().INSERT() != null) {
@@ -659,7 +660,7 @@
s"Unrecognized matched action: ${clause.notMatchedAction().getText}")
}
}
}
}.toSeq
val notMatchedBySourceActions = ctx.notMatchedBySourceClause().asScala.map {
clause => {
val notMatchedBySourceAction = clause.notMatchedBySourceAction()
@@ -674,7 +675,7 @@
s"Unrecognized matched action: ${clause.notMatchedBySourceAction().getText}")
}
}
}
}.toSeq
if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) {
throw QueryParsingErrors.mergeStatementWithoutWhenClauseError(ctx)
}
@@ -693,13 +694,19 @@
throw QueryParsingErrors.nonLastNotMatchedBySourceClauseOmitConditionError(ctx)
}

val targetTable = createUnresolvedRelation(
ctx.target,
writePrivileges = MergeIntoTable.getWritePrivileges(
matchedActions, notMatchedActions, notMatchedBySourceActions))
val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)
MergeIntoTable(
aliasedTarget,
aliasedSource,
mergeCondition,
matchedActions.toSeq,
notMatchedActions.toSeq,
notMatchedBySourceActions.toSeq,
matchedActions,
notMatchedActions,
notMatchedBySourceActions,
withSchemaEvolution)
}

@@ -3117,15 +3124,11 @@ class AstBuilder extends DataTypeAstBuilder
private def createUnresolvedRelation(
ctx: IdentifierReferenceContext,
optionsClause: Option[OptionsClauseContext] = None,
forWrite: Boolean = false): LogicalPlan = withOrigin(ctx) {
writePrivileges: Seq[TableWritePrivilege] = Nil): LogicalPlan = withOrigin(ctx) {
val options = resolveOptions(optionsClause)
withIdentClause(ctx, parts => {
val relation = new UnresolvedRelation(parts, options, isStreaming = false)
if (forWrite) {
relation.forWrite
} else {
relation
}
relation.requireWritePrivileges(writePrivileges)
})
}

@@ -3136,14 +3139,10 @@
ctx: ParserRuleContext,
ident: Seq[String],
optionsClause: Option[OptionsClauseContext],
forWrite: Boolean): UnresolvedRelation = withOrigin(ctx) {
writePrivileges: Seq[TableWritePrivilege]): UnresolvedRelation = withOrigin(ctx) {
val options = resolveOptions(optionsClause)
val relation = new UnresolvedRelation(ident, options, isStreaming = false)
if (forWrite) {
relation.forWrite
} else {
relation
}
relation.requireWritePrivileges(writePrivileges)
}

private def resolveOptions(
@@ -5019,7 +5018,8 @@
if (query.isDefined) {
CacheTableAsSelect(ident.head, query.get, source(ctx.query()), isLazy, options)
} else {
CacheTable(createUnresolvedRelation(ctx.identifierReference, ident, None, forWrite = false),
CacheTable(
createUnresolvedRelation(ctx.identifierReference, ident, None, writePrivileges = Nil),
ident, isLazy, options)
}
})
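The net effect of these parser changes is a fixed mapping from each statement to the privileges it requests: INSERT INTO requests INSERT; INSERT OVERWRITE and OVERWRITE ... WHERE request INSERT and DELETE; UPDATE requests UPDATE; DELETE requests DELETE; MERGE derives its set from the WHEN clauses; CACHE TABLE requests none. A usage sketch, assuming the hypothetical PermissionCheckingCatalog from above is installed as the session catalog and the tables exist:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[1]")
  .config("spark.sql.catalog.spark_catalog", classOf[PermissionCheckingCatalog].getName)
  .getOrCreate()

// Each statement below causes the analyzer to call loadTable on the target
// table with the privileges shown in the trailing comment.
spark.sql("INSERT INTO db.target SELECT * FROM db.src")      // {INSERT}
spark.sql("INSERT OVERWRITE db.target SELECT * FROM db.src") // {INSERT, DELETE}
spark.sql("UPDATE db.target SET col = 1")                    // {UPDATE}
spark.sql("DELETE FROM db.target WHERE col = 0")             // {DELETE}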
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -794,6 +794,21 @@ case class MergeIntoTable(
copy(targetTable = newLeft, sourceTable = newRight)
}

object MergeIntoTable {
def getWritePrivileges(
matchedActions: Seq[MergeAction],
notMatchedActions: Seq[MergeAction],
notMatchedBySourceActions: Seq[MergeAction]): Seq[TableWritePrivilege] = {
val privileges = scala.collection.mutable.HashSet.empty[TableWritePrivilege]
(matchedActions.iterator ++ notMatchedActions ++ notMatchedBySourceActions).foreach {
case _: DeleteAction => privileges.add(TableWritePrivilege.DELETE)
case _: UpdateAction | _: UpdateStarAction => privileges.add(TableWritePrivilege.UPDATE)
case _: InsertAction | _: InsertStarAction => privileges.add(TableWritePrivilege.INSERT)
}
privileges.toSeq
}
}

sealed abstract class MergeAction extends Expression with Unevaluable {
def condition: Option[Expression]
override def nullable: Boolean = false
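For example, a MERGE with only a WHEN MATCHED THEN DELETE clause and a WHEN NOT MATCHED THEN INSERT * clause requires the DELETE and INSERT privileges but not UPDATE. A short illustration, constructing the actions directly rather than via the parser:

import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, InsertStarAction, MergeIntoTable}

// WHEN MATCHED THEN DELETE  +  WHEN NOT MATCHED THEN INSERT *
val privileges = MergeIntoTable.getWritePrivileges(
  matchedActions = Seq(DeleteAction(None)),
  notMatchedActions = Seq(InsertStarAction(None)),
  notMatchedBySourceActions = Nil)
// privileges contains DELETE and INSERT; order is unspecified since it is
// collected through a HashSet.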
sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
@@ -404,9 +404,9 @@ private[sql] object CatalogV2Util {
catalog: CatalogPlugin,
ident: Identifier,
timeTravelSpec: Option[TimeTravelSpec] = None,
forWrite: Boolean = false): Option[Table] =
writePrivilegesString: Option[String] = None): Option[Table] =
try {
Option(getTable(catalog, ident, timeTravelSpec, forWrite))
Option(getTable(catalog, ident, timeTravelSpec, writePrivilegesString))
} catch {
case _: NoSuchTableException => None
case _: NoSuchDatabaseException => None
@@ -416,18 +416,20 @@
catalog: CatalogPlugin,
ident: Identifier,
timeTravelSpec: Option[TimeTravelSpec] = None,
forWrite: Boolean = false): Table = {
writePrivilegesString: Option[String] = None): Table = {
if (timeTravelSpec.nonEmpty) {
assert(!forWrite, "Should not write to a table with time travel")
assert(writePrivilegesString.isEmpty, "Should not write to a table with time travel")
timeTravelSpec.get match {
case v: AsOfVersion =>
catalog.asTableCatalog.loadTable(ident, v.version)
case ts: AsOfTimestamp =>
catalog.asTableCatalog.loadTable(ident, ts.timestamp)
}
} else {
if (forWrite) {
catalog.asTableCatalog.loadTableForWrite(ident)
if (writePrivilegesString.isDefined) {
val writePrivileges = writePrivilegesString.get.split(",").map(_.trim)
.map(TableWritePrivilege.valueOf).toSet.asJava
catalog.asTableCatalog.loadTable(ident, writePrivileges)
} else {
catalog.asTableCatalog.loadTable(ident)
}
