Skip to content

Commit

Permalink
Reuse some tests from #5520
Browse files Browse the repository at this point in the history
  • Loading branch information
gjoseph92 committed Nov 19, 2021
1 parent 9315e4c commit 04833a3
Show file tree
Hide file tree
Showing 2 changed files with 248 additions and 0 deletions.
51 changes: 51 additions & 0 deletions distributed/shuffle/tests/test_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import string
from collections import Counter

import pytest

from ..common import npartitions_for, partition_range, worker_for


@pytest.mark.parametrize("npartitions", [1, 2, 3, 5])
@pytest.mark.parametrize("n_workers", [1, 2, 3, 5])
def test_worker_for_distribution(npartitions: int, n_workers: int):
    """Verify that `worker_for` spreads output partitions evenly over workers."""
    workers = list(string.ascii_lowercase[:n_workers])

    # Negative partition indices must be rejected up front.
    with pytest.raises(IndexError, match="Negative"):
        worker_for(-1, npartitions, workers)

    assignments = [worker_for(p, npartitions, workers) for p in range(npartitions)]

    # `partition_range` must agree with `worker_for`: every partition assigned
    # to worker `w` falls inside its range, and no other partition does.
    for w in workers:
        lo, hi = partition_range(w, npartitions, workers)
        for p, assigned in enumerate(assignments):
            if assigned == w:
                assert lo <= p <= hi
            else:
                assert p < lo or p > hi

    tally = Counter(assignments)
    assert len(tally) == min(npartitions, n_workers)

    # `npartitions_for` must predict the observed per-worker counts, including
    # workers that were calculated to receive zero output partitions.
    predicted = {w: npartitions_for(w, npartitions, workers) for w in workers}
    assert predicted.keys() == set(workers)
    assert tally == {w: n for w, n in predicted.items() if n != 0}

    # The distribution is balanced: if there's a remainder, some workers get
    # exactly one extra partition; otherwise all counts are equal.
    distinct_counts = sorted(set(tally.values()))
    assert len(distinct_counts) <= 2
    if len(distinct_counts) == 2:
        assert distinct_counts[1] - distinct_counts[0] == 1

    # Indices at or beyond `npartitions` must also be rejected.
    with pytest.raises(IndexError, match="does not exist"):
        worker_for(npartitions, npartitions, workers)
197 changes: 197 additions & 0 deletions distributed/shuffle/tests/test_shuffle_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

import pandas as pd
import pytest
from pandas.testing import assert_frame_equal

from distributed.utils_test import gen_cluster

from ..common import ShuffleId, npartitions_for, worker_for
from ..shuffle_worker import ShuffleState, ShuffleWorkerExtension

if TYPE_CHECKING:
from distributed import Client, Scheduler, Worker


@gen_cluster([("", 1)])
async def test_installation(s: Scheduler, worker: Worker):
    """The shuffle extension is installed on each worker with its handlers wired up."""
    extension = worker.extensions["shuffle"]
    assert isinstance(extension, ShuffleWorkerExtension)

    # Stream and RPC handlers are bound to the extension's own methods.
    assert worker.stream_handlers["shuffle_init"] == extension.shuffle_init
    assert worker.handlers["shuffle_receive"] == extension.shuffle_receive
    assert worker.handlers["shuffle_inputs_done"] == extension.shuffle_inputs_done

    # A fresh extension references its worker and holds no shuffle state yet.
    assert extension.worker is worker
    assert not extension.shuffles
    assert not extension.output_data


@gen_cluster([("", 1)])
async def test_init(s: Scheduler, worker: Worker):
    """`shuffle_init` registers a shuffle once and rejects duplicate registration."""
    ext: ShuffleWorkerExtension = worker.extensions["shuffle"]
    assert not ext.shuffles

    shuffle_id = ShuffleId("foo")
    participants = [worker.address, "tcp://foo"]
    n_out = 4

    ext.shuffle_init(shuffle_id, participants, n_out)
    # 2 = presumably this worker's share of the 4 output partitions across
    # 2 participants (matches `npartitions_for` semantics).
    expected_state = ShuffleState(participants, n_out, 2, barrier_reached=False)
    assert ext.shuffles == {shuffle_id: expected_state}

    # Registering the same ID again is an error...
    with pytest.raises(ValueError, match="already registered"):
        ext.shuffle_init(shuffle_id, [], 0)

    # ...and leaves the registry unchanged.
    assert list(ext.shuffles) == [shuffle_id]


@gen_cluster([("", 1)] * 4)
async def test_add_partition(s: Scheduler, *workers: Worker):
    """`add_partition` splits an input partition and ships each group to its worker."""
    exts: dict[str, ShuffleWorkerExtension] = {
        w.address: w.extensions["shuffle"] for w in workers
    }

    shuffle_id = ShuffleId("foo")
    npartitions = 8
    addrs = list(exts)
    column = "partition"

    for extension in exts.values():
        extension.shuffle_init(shuffle_id, addrs, npartitions)

    df = pd.DataFrame(
        {
            "A": ["a", "b", "c", "d", "e", "f", "g", "h"],
            column: [0, 1, 2, 3, 4, 5, 6, 7],
        }
    )

    dest = exts[addrs[0]]
    await dest.add_partition(df, shuffle_id, npartitions, column)

    # Every group must land, exactly once, on the worker `worker_for`
    # designates for its output partition.
    for out_i, group in df.groupby(column):
        out_i = int(out_i)
        dest = exts[worker_for(out_i, npartitions, addrs)]
        buffered = dest.output_data[shuffle_id][out_i]
        assert len(buffered) == 1
        assert_frame_equal(group, buffered[0])

    # Sending into an unregistered shuffle fails.
    with pytest.raises(ValueError, match="not registered"):
        await dest.add_partition(df, ShuffleId("bar"), npartitions, column)

    # TODO (resilience stage) test failed sends


@gen_cluster([("", 1)] * 4, client=True)
async def test_barrier(c: Client, s: Scheduler, *workers: Worker):
    """Reaching the barrier marks all participating workers done with inputs,
    and calling `barrier` a second time fails both locally and remotely.
    """
    exts: dict[str, ShuffleWorkerExtension] = {
        w.address: w.extensions["shuffle"] for w in workers
    }

    id = ShuffleId("foo")
    npartitions = 3
    addrs = list(exts)
    column = "partition"

    # Register the shuffle on every worker.
    for ext in exts.values():
        ext.shuffle_init(id, addrs, npartitions)

    partition = pd.DataFrame(
        {
            "A": ["a", "b", "c"],
            column: [0, 1, 2],
        }
    )
    first_ext = exts[addrs[0]]
    await first_ext.add_partition(partition, id, npartitions, column)

    await first_ext.barrier(id)

    # Check all workers have been informed of the barrier
    for addr, ext in exts.items():
        if npartitions_for(addr, npartitions, addrs):
            # Worker owns output partitions: its shuffle stays alive, flagged done.
            assert ext.shuffles[id].barrier_reached
        else:
            # No output partitions on this worker; shuffle already cleaned up
            assert not ext.shuffles
            assert not ext.output_data

    # Test check on self
    with pytest.raises(AssertionError, match="called multiple times"):
        await first_ext.barrier(id)

    # Reset the local flag so the self-check passes, forcing the failure to
    # come from the remote `inputs_done` handlers instead.
    first_ext.shuffles[id].barrier_reached = False

    # RPC to other workers fails
    with pytest.raises(AssertionError, match="`inputs_done` called again"):
        await first_ext.barrier(id)


@gen_cluster([("", 1)] * 4, client=True)
async def test_get_partition(c: Client, s: Scheduler, *workers: Worker):
    """After the barrier, each output partition is fetched exactly once from the
    worker that owns it; once all are retrieved, shuffle state is cleaned up.
    """
    exts: dict[str, ShuffleWorkerExtension] = {
        w.address: w.extensions["shuffle"] for w in workers
    }

    shuffle_id = ShuffleId("foo")
    npartitions = 8
    addrs = list(exts)
    column = "partition"

    for extension in exts.values():
        extension.shuffle_init(shuffle_id, addrs, npartitions)

    p1 = pd.DataFrame(
        {
            "A": ["a", "b", "c", "d", "e", "f", "g", "h"],
            "partition": [0, 1, 2, 3, 4, 5, 6, 6],
        }
    )
    p2 = pd.DataFrame(
        {
            "A": ["a", "b", "c", "d", "e", "f", "g", "h"],
            "partition": [0, 1, 2, 3, 0, 0, 2, 3],
        }
    )

    first_ext = exts[addrs[0]]
    # Feed two input partitions concurrently, then close the shuffle's inputs.
    await asyncio.gather(
        first_ext.add_partition(p1, shuffle_id, npartitions, column),
        first_ext.add_partition(p2, shuffle_id, npartitions, column),
    )
    await first_ext.barrier(shuffle_id)

    empty = pd.DataFrame({"A": [], column: []})

    # Asking the wrong worker for an output partition is an error
    # (partition 7 presumably belongs to a different worker than addrs[0]).
    with pytest.raises(AssertionError, match="was expected to go"):
        first_ext.get_output_partition(shuffle_id, 7, empty)

    combined = pd.concat([p1, p2])
    groups = combined.groupby("partition")
    for out_i in range(npartitions):
        owner = exts[worker_for(out_i, npartitions, addrs)]
        state = owner.shuffles[shuffle_id]
        remaining_before = state.out_parts_left

        result = owner.get_output_partition(shuffle_id, out_i, empty)

        # Output partitions with no input rows come back as the supplied
        # empty frame.
        try:
            expected = groups.get_group(out_i)
        except KeyError:
            expected = empty
        assert_frame_equal(expected, result)
        # Each fetch decrements the owner's remaining-outputs counter.
        assert state.out_parts_left == remaining_before - 1

    # Once all partitions are retrieved, shuffles are cleaned up
    for extension in exts.values():
        assert not extension.shuffles

    with pytest.raises(ValueError, match="not registered"):
        first_ext.get_output_partition(shuffle_id, 0, empty)

0 comments on commit 04833a3

Please sign in to comment.