
Fix randomness issue in spark_stratified_split() #1654

Merged 1 commit into staging from simonz/_do_stratification_spark on Mar 1, 2022

Conversation

@simonzhaoms (Collaborator) commented Feb 26, 2022

Description

Bhrigu reported this issue when he used spark_stratified_split() to split data into train and test sets. He found that when the data were large enough, some rows appeared in both train and test, which should not happen. Below is a Databricks notebook source file with a simple example that reproduces the issue.

Because Spark distributes data in partitions across the nodes of the cluster, the F.rand() call used in the code may be invoked several times in parallel on different partitions over the window defined in the code, which is non-deterministic even when the seed is set.

This PR moves the random number generation to the level of the entire DataFrame instead of within each window, so that the random column is generated once up front and the split becomes deterministic.

# Databricks notebook source
import recommenders
recommenders.__version__

# Out[1]: '1.0.0'

# COMMAND ----------

spark.version

# Out[2]: '3.1.2'

# COMMAND ----------

nrows = 4262892
ncounts = 2
nusers = int(nrows / ncounts)
col_user = "user"
col_item = "item"
seed = 42

# COMMAND ----------

import numpy as np
import pandas as pd

# Create a dummy user-item dataframe where each user has 2 items
pd_df = pd.DataFrame({
  col_user: np.repeat(range(nusers), ncounts),
  col_item: np.tile(range(ncounts), nusers)
})
data = spark.createDataFrame(pd_df)

from recommenders.datasets.spark_splitters import spark_stratified_split
train, test = spark_stratified_split(data, ratio=0.8, col_user=col_user, col_item=col_item, seed=seed)

# COMMAND ----------

data.show(4)

# +----+----+
# |user|item|
# +----+----+
# |   0|   0|
# |   0|   1|
# |   1|   0|
# |   1|   1|
# +----+----+

# COMMAND ----------

# Check if train and test have duplicate rows
duplicates = train.join(test, on=["user", "item"], how="inner")
duplicate_user_id = duplicates.collect()[0][0]
train.where(f"user = {duplicate_user_id}").show()
test.where(f"user = {duplicate_user_id}").show()

# +------+----+
# |  user|item|
# +------+----+
# |865899|   0|
# +------+----+
# 
# +------+----+
# |  user|item|
# +------+----+
# |865899|   0|
# +------+----+

# COMMAND ----------

# Due to lazy evaluation and randomness, the result may vary and may not be duplicated
train.where(f"user = {duplicate_user_id}").show()
test.where(f"user = {duplicate_user_id}").show()

# +------+----+
# |  user|item|
# +------+----+
# |865899|   1|
# +------+----+
# 
# +------+----+
# |  user|item|
# +------+----+
# |865899|   0|
# +------+----+
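The pattern behind the fix can be illustrated outside of Spark as well. The sketch below is a pandas analogy, not the actual PR diff; the function name stratified_split_pandas and the helper columns _rand and _rank are made up for illustration. It generates the random column once at the DataFrame level and only then ranks rows within each user group, which is the same idea the PR applies:

```python
import numpy as np
import pandas as pd

def stratified_split_pandas(df, ratio=0.8, col_user="user", seed=42):
    # Draw ONE random value per row, up front, over the whole DataFrame.
    # Re-drawing random numbers inside each per-user group (the analogue of
    # evaluating F.rand() inside the window on each partition in Spark) is
    # what made the split non-deterministic.
    rng = np.random.default_rng(seed)
    df = df.assign(_rand=rng.random(len(df)))
    # Rank rows within each user by the pre-computed random column.
    df["_rank"] = df.groupby(col_user)["_rand"].rank(method="first")
    counts = df.groupby(col_user)[col_user].transform("count")
    train = df[df["_rank"] <= counts * ratio].drop(columns=["_rand", "_rank"])
    test = df[df["_rank"] > counts * ratio].drop(columns=["_rand", "_rank"])
    return train, test
```

Because every row's random value is fixed before the per-user ranking, the train/test assignment cannot change between evaluations, so the two splits can never share a row.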

Related Issues

Checklist:

  • I have followed the contribution guidelines and code style for this project.
  • I have added tests covering my contributions.
  • I have updated the documentation accordingly.
  • This PR is being made to staging branch and not to main branch.

@miguelgfierro (Collaborator) left a comment


lgtm

@miguelgfierro miguelgfierro merged commit 67b0323 into staging Mar 1, 2022
@miguelgfierro miguelgfierro deleted the simonz/_do_stratification_spark branch March 1, 2022 19:18