Skip to content

Commit

Permalink
[REVIEW] Add scheduler_file argument to support MNMG setup (#1593)
Browse files Browse the repository at this point in the history
## Add scheduler_file argument to support MNMG setup

### Overview:
The primary goal is to provide more flexibility and adaptability in how the Dask cluster for testing is configured.

### Changes:



1. **Allow connecting to an existing cluster** 
- The creation of the `LocalCUDACluster` instances is now contingent on the presence of a `SCHEDULER_FILE` environment variable. If this variable exists, the path to the Dask scheduler file is returned instead of creating a new cluster. This change allows the use of pre-existing clusters specified via the `SCHEDULER_FILE` environment variable.
   
   
 2.  **Remove UCX related flags as they are no longer needed**
- Removed specific flags (`enable_tcp_over_ucx`, `enable_nvlink`, `enable_infiniband`) previously used to initialize the `LocalCUDACluster`. These flags are no longer needed as of `Dask-CUDA 22.02` with `UCX >= 1.11.1`.
  See docs:  https://docs.rapids.ai/api/dask-cuda/nightly/examples/ucx/#localcudacluster-with-automatic-configuration




This could help in situations where test scenarios need to be run against a specific pre-existing cluster (especially for MNMG setups).


### Testing:

I tested using the following setup: 
Start Cluster:
```
dask scheduler --scheduler-file /raid/vjawa/scheduler.json &
dask-cuda-worker --scheduler-file /raid/vjawa/scheduler.json
```

Run Tests:
```
export SCHEDULER_FILE=/raid/vjawa/scheduler.json 
cd /home/nfs/vjawa/raft/python/raft-dask/raft_dask/test
pytest .
```

Authors:
  - Vibhu Jawa (https://github.com/VibhuJawa)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1593
  • Loading branch information
VibhuJawa authored Jun 22, 2023
1 parent cb77979 commit 6f0abae
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 33 deletions.
2 changes: 1 addition & 1 deletion python/raft-dask/raft_dask/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
13 changes: 13 additions & 0 deletions python/raft-dask/raft_dask/test/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
69 changes: 43 additions & 26 deletions python/raft-dask/raft_dask/test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,71 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
# Copyright (c) 2022-2023, NVIDIA CORPORATION.

import os

import pytest

from dask.distributed import Client
from dask_cuda import LocalCUDACluster, initialize
from dask_cuda import LocalCUDACluster

os.environ["UCX_LOG_LEVEL"] = "error"


enable_tcp_over_ucx = True
enable_nvlink = False
enable_infiniband = False


@pytest.fixture(scope="session")
def cluster():
cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0)
yield cluster
cluster.close()
scheduler_file = os.environ.get("SCHEDULER_FILE")
if scheduler_file:
yield scheduler_file
else:
cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0)
yield cluster
cluster.close()


@pytest.fixture(scope="session")
def ucx_cluster():
initialize.initialize(
create_cuda_context=True,
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_nvlink=enable_nvlink,
enable_infiniband=enable_infiniband,
)
cluster = LocalCUDACluster(
protocol="ucx",
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_nvlink=enable_nvlink,
enable_infiniband=enable_infiniband,
)
yield cluster
cluster.close()
scheduler_file = os.environ.get("SCHEDULER_FILE")
if scheduler_file:
yield scheduler_file
else:
cluster = LocalCUDACluster(
protocol="ucx",
)
yield cluster
cluster.close()


@pytest.fixture(scope="session")
def client(cluster):
client = Client(cluster)
client = create_client(cluster)
yield client
client.close()


@pytest.fixture()
def ucx_client(ucx_cluster):
client = Client(cluster)
client = create_client(ucx_cluster)
yield client
client.close()


def create_client(cluster):
    """
    Create a Dask distributed client for a specified cluster.

    Parameters
    ----------
    cluster : LocalCUDACluster instance or str
        If a ``LocalCUDACluster`` instance is provided, a client is created
        for it directly. If a string is provided, it is interpreted as the
        path to a Dask scheduler file, and a client is created for the
        cluster referenced by that scheduler file.

    Returns
    -------
    dask.distributed.Client
        A client connected to the specified cluster.
    """
    # A LocalCUDACluster object can be handed to Client directly; anything
    # else is assumed to be a scheduler-file path (see SCHEDULER_FILE usage
    # in the fixtures above).
    is_local_cluster = isinstance(cluster, LocalCUDACluster)
    return Client(cluster) if is_local_cluster else Client(scheduler_file=cluster)
11 changes: 5 additions & 6 deletions python/raft-dask/raft_dask/test/test_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

import pytest

from dask.distributed import Client, get_worker, wait
from dask.distributed import get_worker, wait

from .conftest import create_client

try:
from raft_dask.common import (
Expand All @@ -43,9 +45,7 @@


def test_comms_init_no_p2p(cluster):

client = Client(cluster)

client = create_client(cluster)
try:
cb = Comms(verbose=True)
cb.init()
Expand Down Expand Up @@ -121,8 +121,7 @@ def func_check_uid_on_worker(sessionId, uniqueId, dask_worker=None):


def test_handles(cluster):

client = Client(cluster)
client = create_client(cluster)

def _has_handle(sessionId):
return local_handle(sessionId, dask_worker=get_worker()) is not None
Expand Down

0 comments on commit 6f0abae

Please sign in to comment.