Skip to content

Commit

Permalink
feat: support bagua-net (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
shjwudp authored and NOBLES5E committed Sep 15, 2021
1 parent 3ba5cfc commit bf166dc
Showing 1 changed file with 38 additions and 9 deletions.
47 changes: 38 additions & 9 deletions bagua-core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def update_to(self, b=1, bsize=1, tsize=None):
def download_url(url, output_path):
    """Download ``url`` to ``output_path`` while displaying a progress bar.

    Args:
        url: Address of the resource to fetch. Its last ``/``-separated
            segment is used as the progress bar description.
        output_path: Local file path the downloaded content is written to.
    """
    # The progress bar's ``update_to`` is handed to urlretrieve as its
    # report hook, so the bar advances as blocks arrive.
    with DownloadProgressBar(unit='B', unit_scale=True,
                             miniters=1, desc=url.split('/')[-1]) as t:
        urllib.request.urlretrieve(
            url, filename=output_path, reporthook=t.update_to)


def _make_nccl_url(public_version, filename):
Expand All @@ -52,29 +53,48 @@ def _make_nccl_record(cuda_version, full_version, public_version, filename_linux


# NCCL 2.10.3 download records, one per supported CUDA version, listed from
# newest to oldest. CUDA 11.0-11.3 all use the cuda11.0 archive, and CUDA
# 10.1 uses the cuda10.2 archive.
# NOTE(review): assumes _nccl_records is a list created earlier in the file.
_nccl_records.extend(
    _make_nccl_record(cuda_version, "2.10.3", "2.10", archive_name)
    for cuda_version, archive_name in (
        ("11.4", "nccl_2.10.3-1+cuda11.4_x86_64.txz"),
        ("11.3", "nccl_2.10.3-1+cuda11.0_x86_64.txz"),
        ("11.2", "nccl_2.10.3-1+cuda11.0_x86_64.txz"),
        ("11.1", "nccl_2.10.3-1+cuda11.0_x86_64.txz"),
        ("11.0", "nccl_2.10.3-1+cuda11.0_x86_64.txz"),
        ("10.2", "nccl_2.10.3-1+cuda10.2_x86_64.txz"),
        ("10.1", "nccl_2.10.3-1+cuda10.2_x86_64.txz"),
    )
)
library_records["nccl"] = _nccl_records


def install_baguanet(url, destination):
    """Download a bagua-net release archive and move its libraries into place.

    The archive at ``url`` is fetched into a temporary directory and unpacked
    there; every entry of the unpacked ``build`` directory is then moved into
    ``destination``. The temporary directory is removed afterwards.

    Args:
        url: Address of the archive. Its basename is used as the local file
            name, so ``shutil.unpack_archive`` infers the archive format from
            that name's extension.
        destination: Directory that receives the extracted files. It is
            created if it does not exist yet.

    Raises:
        FileNotFoundError: If the unpacked archive has no ``build`` directory.
    """
    # shutil.move only places a file *inside* ``destination`` when it is an
    # existing directory; otherwise the first file would be renamed to
    # ``destination`` itself and later files would clobber it.
    os.makedirs(destination, exist_ok=True)

    with tempfile.TemporaryDirectory() as tmpdir:
        archive_path = os.path.join(tmpdir, os.path.basename(url))
        print("Downloading {}...".format(url))
        # NOTE(review): the download is not checked against a checksum here.
        download_url(url, archive_path)

        outdir = os.path.join(tmpdir, "extract")
        shutil.unpack_archive(archive_path, outdir)

        lib_dir = os.path.join(outdir, 'build')
        for entry in os.listdir(lib_dir):
            shutil.move(os.path.join(lib_dir, entry), destination)


def install_lib(cuda, prefix, library):
record = None
lib_records = library_records
Expand Down Expand Up @@ -126,6 +146,13 @@ def install_lib(cuda, prefix, library):
subdir = os.listdir(outdir)
assert len(subdir) == 1
shutil.move(os.path.join(outdir, subdir[0]), destination)

# Install bagua-net
dst_dir = os.path.join(destination, 'bagua-net')
os.mkdir(dst_dir)
install_baguanet(
"https://github.com/BaguaSys/bagua-net/releases/download/v0.1.1/bagua-net_refs.tags.v0.1.1_x86_64.tar.gz",
dst_dir)
else:
assert False
print("Cleaning up...")
Expand Down Expand Up @@ -187,7 +214,9 @@ def install_dependency_library():
description="Core communication lib for Bagua.",
package_dir={"": "python/"},
packages=find_packages("python/"),
package_data={"": [".data/lib/libnccl.so"]},
package_data={"": [".data/lib/libnccl.so",
".data/bagua-net/libbagua_net.so",
".data/bagua-net/libnccl-net.so"]},
rust_extensions=[
RustExtension(
"bagua_core.bagua_core",
Expand Down

0 comments on commit bf166dc

Please sign in to comment.