[SPARK-49540][PS] Unify the usage of distributed_sequence_id #48028
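Context for the diffs below: `distributed_sequence_id` is the expression behind the 'distributed-sequence' default index in the pandas API on Spark, and this PR routes all of its call sites through a single helper. A minimal sketch of where the expression surfaces for users, assuming a plain local Spark session (the data is illustrative):

```python
# Minimal sketch: the 'distributed-sequence' default index in pandas API on Spark
# is the user-facing feature backed by the distributed_sequence_id expression.
import pyspark.pandas as ps

ps.set_option("compute.default_index_type", "distributed-sequence")
psdf = ps.DataFrame({"value": ["A1", "A2", "A3"]})
print(psdf)  # the 0, 1, 2 index is computed with distributed_sequence_id
```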

18 changes: 5 additions & 13 deletions python/pyspark/pandas/internal.py
@@ -43,6 +43,7 @@
)
from pyspark.sql.utils import is_timestamp_ntz_preferred, is_remote
from pyspark import pandas as ps
from pyspark.pandas.spark import functions as SF
from pyspark.pandas._typing import Label
from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
from pyspark.pandas.data_type_ops.base import DataTypeOps
@@ -938,19 +939,10 @@ def attach_distributed_sequence_column(
+--------+---+
"""
if len(sdf.columns) > 0:
if is_remote():
from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.sql.connect.expressions import DistributedSequenceID

return sdf.select(
ConnectColumn(DistributedSequenceID()).alias(column_name),
"*",
)
else:
return PySparkDataFrame(
sdf._jdf.toDF().withSequenceColumn(column_name),
sdf.sparkSession,
)
return sdf.select(
SF.distributed_sequence_id().alias(column_name),
"*",
)
else:
cnt = sdf.count()
if cnt > 0:
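For reference, the helper touched in this hunk is a static method on `InternalFrame`. A hedged sketch of calling it directly, mirroring the doctest fragment above (the input rows are illustrative):

```python
from pyspark.sql import SparkSession
from pyspark.pandas.internal import InternalFrame

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame([("a",), ("b",), ("c",)], ["value"])

# After this change the same SF.distributed_sequence_id() expression is used on both
# Spark Connect and classic sessions; the result is a 0-based, gap-free sequence column.
InternalFrame.attach_distributed_sequence_column(sdf, column_name="sequence").show()
```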
12 changes: 12 additions & 0 deletions python/pyspark/pandas/spark/functions.py
@@ -174,6 +174,18 @@ def null_index(col: Column) -> Column:
return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))


def distributed_sequence_id() -> Column:
if is_remote():
from pyspark.sql.connect.functions.builtin import _invoke_function

return _invoke_function("distributed_sequence_id")
else:
from pyspark import SparkContext

sc = SparkContext._active_spark_context
return Column(sc._jvm.PythonSQLUtils.distributed_sequence_id())


def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
if is_remote():
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
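The new helper follows the same remote/classic dispatch pattern as its neighbors in this module (`null_index`, `collect_top_k`): Spark Connect sessions go through `_invoke_function`, classic sessions through the JVM `PythonSQLUtils` bridge. A short usage sketch under those assumptions:

```python
from pyspark.sql import SparkSession
from pyspark.pandas.spark import functions as SF

spark = SparkSession.builder.getOrCreate()
sdf = spark.range(10).repartition(5)

# Callers no longer need to branch on is_remote(); the dispatch happens inside
# distributed_sequence_id() itself.
sdf.select(SF.distributed_sequence_id().alias("seq"), "*").show()
```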
8 changes: 0 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1994,14 +1994,6 @@ class Dataset[T] private[sql](
// For Python API
////////////////////////////////////////////////////////////////////////////

/**
* It adds a new long column with the name `name` that increases one by one.
* This is for 'distributed-sequence' default index in pandas API on Spark.
*/
private[sql] def withSequenceColumn(name: String) = {
select(column(DistributedSequenceID()).alias(name), col("*"))
}

/**
* Converts a JavaRDD to a PythonRDD.
*/
@@ -176,6 +176,13 @@ private[sql] object PythonSQLUtils extends Logging {
def pandasCovar(col1: Column, col2: Column, ddof: Int): Column =
Column.internalFn("pandas_covar", col1, col2, lit(ddof))

/**
* A long column that increases one by one.
* This is for 'distributed-sequence' default index in pandas API on Spark.
*/
def distributed_sequence_id(): Column =
Column.internalFn("distributed_sequence_id")

def unresolvedNamedLambdaVariable(name: String): Column =
Column(internal.UnresolvedNamedLambdaVariable.apply(name))

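As a hedged side note on semantics (not part of this PR): the expression yields consecutive IDs 0..n-1, which is what distinguishes it from `monotonically_increasing_id` and why the 'distributed-sequence' index relies on it. An illustrative comparison:

```python
from pyspark.sql import SparkSession, functions as F
from pyspark.pandas.spark import functions as SF

spark = SparkSession.builder.getOrCreate()
sdf = spark.range(6).repartition(3)
sdf.select(
    SF.distributed_sequence_id().alias("seq"),      # 0..5, consecutive across partitions
    F.monotonically_increasing_id().alias("mono"),  # monotonic but sparse 64-bit IDs
    "id",
).show()
```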
@@ -18,10 +18,11 @@
package org.apache.spark.sql

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, explode, sum, year}
import org.apache.spark.sql.functions.{col, count, explode, sum, year}
import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -404,7 +405,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y")))

// Test for AttachDistributedSequence
val df13 = df1.withSequenceColumn("seq")
val df13 = df1.select(distributed_sequence_id().alias("seq"), col("*"))
val df14 = df13.filter($"value" === "A2")
assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2")))
assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2")))
@@ -29,6 +29,7 @@ import org.scalatest.matchers.should.Matchers._

import org.apache.spark.SparkException
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -2316,7 +2317,8 @@ class DataFrameSuite extends QueryTest
}

test("SPARK-36338: DataFrame.withSequenceColumn should append unique sequence IDs") {
val ids = spark.range(10).repartition(5).withSequenceColumn("default_index")
val ids = spark.range(10).repartition(5).select(
distributed_sequence_id().alias("default_index"), col("id"))
assert(ids.collect().map(_.getLong(0)).toSet === Range(0, 10).toSet)
assert(ids.take(5).map(_.getLong(0)).toSet === Range(0, 5).toSet)
}
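A hedged Python analogue of the Scala assertion above, under the same assumptions (10 rows over 5 partitions, helper imported as `SF` as in this PR):

```python
from pyspark.sql import SparkSession
from pyspark.pandas.spark import functions as SF

spark = SparkSession.builder.getOrCreate()
ids = spark.range(10).repartition(5).select(
    SF.distributed_sequence_id().alias("default_index"), "id")

# Sequence IDs should be exactly 0..9; the Scala test also checks that the
# first 5 rows take 0..4.
assert {row.default_index for row in ids.collect()} == set(range(10))
```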