Skip to content

Commit

Permalink
[SPARK-49540][PS] Unify the usage of distributed_sequence_id
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

in PySpark Classic, it was used via a dataframe method `withSequenceColumn`, while in PySpark Connect, it was used as an internal function

This PR unifies the usage of `distributed_sequence_id`

### Why are the changes needed?
code refactoring

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
updated tests

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#48028 from zhengruifeng/func_withSequenceColumn.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
  • Loading branch information
zhengruifeng committed Sep 9, 2024
1 parent 37b39b4 commit a3b918e
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 24 deletions.
18 changes: 5 additions & 13 deletions python/pyspark/pandas/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
from pyspark.sql.utils import is_timestamp_ntz_preferred, is_remote
from pyspark import pandas as ps
from pyspark.pandas.spark import functions as SF
from pyspark.pandas._typing import Label
from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
from pyspark.pandas.data_type_ops.base import DataTypeOps
Expand Down Expand Up @@ -938,19 +939,10 @@ def attach_distributed_sequence_column(
+--------+---+
"""
if len(sdf.columns) > 0:
if is_remote():
from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.sql.connect.expressions import DistributedSequenceID

return sdf.select(
ConnectColumn(DistributedSequenceID()).alias(column_name),
"*",
)
else:
return PySparkDataFrame(
sdf._jdf.toDF().withSequenceColumn(column_name),
sdf.sparkSession,
)
return sdf.select(
SF.distributed_sequence_id().alias(column_name),
"*",
)
else:
cnt = sdf.count()
if cnt > 0:
Expand Down
12 changes: 12 additions & 0 deletions python/pyspark/pandas/spark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,18 @@ def null_index(col: Column) -> Column:
return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))


def distributed_sequence_id() -> Column:
if is_remote():
from pyspark.sql.connect.functions.builtin import _invoke_function

return _invoke_function("distributed_sequence_id")
else:
from pyspark import SparkContext

sc = SparkContext._active_spark_context
return Column(sc._jvm.PythonSQLUtils.distributed_sequence_id())


def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
if is_remote():
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
Expand Down
8 changes: 0 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2010,14 +2010,6 @@ class Dataset[T] private[sql](
// For Python API
////////////////////////////////////////////////////////////////////////////

/**
* It adds a new long column with the name `name` that increases one by one.
* This is for 'distributed-sequence' default index in pandas API on Spark.
*/
private[sql] def withSequenceColumn(name: String) = {
select(column(DistributedSequenceID()).alias(name), col("*"))
}

/**
* Converts a JavaRDD to a PythonRDD.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,13 @@ private[sql] object PythonSQLUtils extends Logging {
def pandasCovar(col1: Column, col2: Column, ddof: Int): Column =
Column.internalFn("pandas_covar", col1, col2, lit(ddof))

/**
* A long column that increases one by one.
* This is for 'distributed-sequence' default index in pandas API on Spark.
*/
def distributed_sequence_id(): Column =
Column.internalFn("distributed_sequence_id")

def unresolvedNamedLambdaVariable(name: String): Column =
Column(internal.UnresolvedNamedLambdaVariable.apply(name))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
package org.apache.spark.sql

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, explode, sum, year}
import org.apache.spark.sql.functions.{col, count, explode, sum, year}
import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
Expand Down Expand Up @@ -404,7 +405,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y")))

// Test for AttachDistributedSequence
val df13 = df1.withSequenceColumn("seq")
val df13 = df1.select(distributed_sequence_id().alias("seq"), col("*"))
val df14 = df13.filter($"value" === "A2")
assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2")))
assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.scalatest.matchers.should.Matchers._

import org.apache.spark.SparkException
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
Expand Down Expand Up @@ -2316,7 +2317,8 @@ class DataFrameSuite extends QueryTest
}

test("SPARK-36338: DataFrame.withSequenceColumn should append unique sequence IDs") {
val ids = spark.range(10).repartition(5).withSequenceColumn("default_index")
val ids = spark.range(10).repartition(5).select(
distributed_sequence_id().alias("default_index"), col("id"))
assert(ids.collect().map(_.getLong(0)).toSet === Range(0, 10).toSet)
assert(ids.take(5).map(_.getLong(0)).toSet === Range(0, 5).toSet)
}
Expand Down

0 comments on commit a3b918e

Please sign in to comment.