diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index 92d4a3357319f..4be345201ba65 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -43,6 +43,7 @@
 )
 from pyspark.sql.utils import is_timestamp_ntz_preferred, is_remote
 from pyspark import pandas as ps
+from pyspark.pandas.spark import functions as SF
 from pyspark.pandas._typing import Label
 from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
 from pyspark.pandas.data_type_ops.base import DataTypeOps
@@ -938,19 +939,10 @@ def attach_distributed_sequence_column(
         +--------+---+
         """
         if len(sdf.columns) > 0:
-            if is_remote():
-                from pyspark.sql.connect.column import Column as ConnectColumn
-                from pyspark.sql.connect.expressions import DistributedSequenceID
-
-                return sdf.select(
-                    ConnectColumn(DistributedSequenceID()).alias(column_name),
-                    "*",
-                )
-            else:
-                return PySparkDataFrame(
-                    sdf._jdf.toDF().withSequenceColumn(column_name),
-                    sdf.sparkSession,
-                )
+            return sdf.select(
+                SF.distributed_sequence_id().alias(column_name),
+                "*",
+            )
         else:
             cnt = sdf.count()
             if cnt > 0:
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index 6aaa63956c14b..4bcf07f6f6503 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -174,6 +174,18 @@ def null_index(col: Column) -> Column:
         return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
 
 
+def distributed_sequence_id() -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions.builtin import _invoke_function
+
+        return _invoke_function("distributed_sequence_id")
+    else:
+        from pyspark import SparkContext
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.distributed_sequence_id())
+
+
 def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
     if is_remote():
         from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 870571b533d09..0fab60a948423 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2010,14 +2010,6 @@ class Dataset[T] private[sql](
   // For Python API
   ////////////////////////////////////////////////////////////////////////////
 
-  /**
-   * It adds a new long column with the name `name` that increases one by one.
-   * This is for 'distributed-sequence' default index in pandas API on Spark.
-   */
-  private[sql] def withSequenceColumn(name: String) = {
-    select(column(DistributedSequenceID()).alias(name), col("*"))
-  }
-
   /**
    * Converts a JavaRDD to a PythonRDD.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 7dbc586f64730..93082740cca64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -176,6 +176,13 @@ private[sql] object PythonSQLUtils extends Logging {
   def pandasCovar(col1: Column, col2: Column, ddof: Int): Column =
     Column.internalFn("pandas_covar", col1, col2, lit(ddof))
 
+  /**
+   * A long column that increases one by one.
+   * This is for 'distributed-sequence' default index in pandas API on Spark.
+   */
+  def distributed_sequence_id(): Column =
+    Column.internalFn("distributed_sequence_id")
+
   def unresolvedNamedLambdaVariable(name: String): Column =
     Column(internal.UnresolvedNamedLambdaVariable.apply(name))
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index 310b5a62c908a..d888b09d76eac 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -18,10 +18,11 @@
 package org.apache.spark.sql
 
 import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id
 import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
 import org.apache.spark.sql.expressions.Window
-import org.apache.spark.sql.functions.{count, explode, sum, year}
+import org.apache.spark.sql.functions.{col, count, explode, sum, year}
 import org.apache.spark.sql.internal.ExpressionUtils.column
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -404,7 +405,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
     assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y")))
 
     // Test for AttachDistributedSequence
-    val df13 = df1.withSequenceColumn("seq")
+    val df13 = df1.select(distributed_sequence_id().alias("seq"), col("*"))
     val df14 = df13.filter($"value" === "A2")
     assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2")))
     assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2")))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b1c41033fd760..9bfbdda33c36d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -29,6 +29,7 @@ import org.scalatest.matchers.should.Matchers._
 
 import org.apache.spark.SparkException
 import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -2316,7 +2317,8 @@ class DataFrameSuite extends QueryTest
   }
 
   test("SPARK-36338: DataFrame.withSequenceColumn should append unique sequence IDs") {
-    val ids = spark.range(10).repartition(5).withSequenceColumn("default_index")
+    val ids = spark.range(10).repartition(5).select(
+      distributed_sequence_id().alias("default_index"), col("id"))
     assert(ids.collect().map(_.getLong(0)).toSet === Range(0, 10).toSet)
     assert(ids.take(5).map(_.getLong(0)).toSet === Range(0, 5).toSet)
   }
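
A minimal usage sketch (not part of the patch) showing the select pattern that replaces Dataset.withSequenceColumn. It assumes a classic (non-Connect) Spark session and a build that already includes this change, so that SF.distributed_sequence_id() resolves to the new internal function:

    from pyspark.sql import SparkSession
    from pyspark.pandas.spark import functions as SF  # module patched by this PR

    spark = SparkSession.builder.getOrCreate()
    sdf = spark.range(10).repartition(5)

    # Attach a zero-based, one-by-one increasing sequence column, as the
    # 'distributed-sequence' default index in pandas API on Spark does.
    with_seq = sdf.select(SF.distributed_sequence_id().alias("default_index"), "*")
    with_seq.show()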