fix(python): change to_bagua_tensor API to support PyTorch 1.10 #338

Merged · 13 commits · Oct 31, 2021

Changes from 12 commits
8 changes: 4 additions & 4 deletions .buildkite/pipeline.yml
@@ -4,7 +4,7 @@ steps:
command: bash .buildkite/scripts/benchmark_master.sh
plugins:
- docker#v3.8.0:
image: "baguasys/bagua:latest"
image: "baguasys/bagua:master-pytorch-1.9.0-cuda11.1-cudnn8"
workdir: /upstream
user: root
propagate-environment: true
@@ -20,7 +20,7 @@ steps:
command: bash .buildkite/scripts/benchmark_worker.sh
plugins:
- docker#v3.8.0:
image: "baguasys/bagua:latest"
image: "baguasys/bagua:master-pytorch-1.9.0-cuda11.1-cudnn8"
workdir: /upstream
user: root
propagate-environment: true
@@ -34,7 +34,7 @@ steps:
command: bash .buildkite/scripts/benchmark.sh
plugins:
- docker#v3.8.0:
image: "baguasys/bagua:latest"
image: "baguasys/bagua:master-pytorch-1.9.0-cuda11.1-cudnn8"
workdir: /upstream
user: root
propagate-environment: true
@@ -48,7 +48,7 @@ steps:
command: bash .buildkite/scripts/run_pytest.sh
plugins:
- docker#v3.8.0:
image: "baguasys/bagua:latest"
image: "baguasys/bagua:master-pytorch-1.9.0-cuda11.1-cudnn8"
workdir: /upstream
user: root
propagate-environment: true
5 changes: 2 additions & 3 deletions .github/workflows/bagua-pypi-publish.yml
@@ -9,7 +9,7 @@ concurrency:
jobs:
publish:
runs-on: ubuntu-latest
container: baguasys/bagua:latest
container: baguasys/bagua:master-pytorch-1.9.0-cuda11.1-cudnn8
steps:
- uses: actions/checkout@v2
with:
@@ -38,7 +38,7 @@ jobs:

check_source_install:
runs-on: ubuntu-latest
container: baguasys/bagua:latest
container: baguasys/bagua:master-pytorch-1.9.0-cuda11.1-cudnn8
needs:
- publish
steps:
@@ -115,7 +115,6 @@ jobs:
with:
fetch-depth: 0
submodules: recursive

- name: setup python
uses: actions/setup-python@v2
with:
2 changes: 1 addition & 1 deletion .github/workflows/bagua-python-package-check.yml
@@ -12,7 +12,7 @@ on:
jobs:
build:
runs-on: ubuntu-latest
container: baguasys/bagua:latest
container: baguasys/bagua:master-pytorch-1.9.0-cuda11.1-cudnn8
steps:
- uses: actions/checkout@v2
with:
3 changes: 3 additions & 0 deletions .github/workflows/dockerhub.yml
@@ -17,6 +17,9 @@ jobs:
- cuda-version: "cuda11.1"
cudnn-version: "cudnn8"
pytorch-version: "pytorch-1.9.0"
- cuda-version: "cuda11.3"
cudnn-version: "cudnn8"
pytorch-version: "pytorch-1.10.0"
name: 'Build'
runs-on: ubuntu-latest
steps:
2 changes: 1 addition & 1 deletion .github/workflows/pytype.yml
@@ -19,7 +19,7 @@ jobs:
build:
# The type of runner that the job will run on
runs-on: ubuntu-latest
container: baguasys/bagua:latest
container: baguasys/bagua:master-pytorch-1.9.0-cuda11.1-cudnn8

# Steps represent a sequence of tasks that will be executed as part of the job
steps:
8 changes: 4 additions & 4 deletions bagua/torch_api/algorithms/decentralized.py
@@ -87,7 +87,7 @@ def hook():

def _init_states(self, bucket: BaguaBucket):
weight_tensor = bucket.flattened_tensor()
bucket._peer_weight = weight_tensor.to_bagua_tensor("peer_weight")
bucket._peer_weight = weight_tensor.ensure_bagua_tensor("peer_weight")

def init_operations(
self,
@@ -182,11 +182,11 @@ def _init_states(self, bucket: BaguaBucket):
left_peer_weight_tensor = bucket.flattened_tensor()
right_peer_weight_tensor = bucket.flattened_tensor()

bucket._weight = weight_tensor.to_bagua_tensor("weight")
bucket._left_peer_weight = left_peer_weight_tensor.to_bagua_tensor(
bucket._weight = weight_tensor.ensure_bagua_tensor("weight")
bucket._left_peer_weight = left_peer_weight_tensor.ensure_bagua_tensor(
"left_peer_weight"
)
bucket._right_peer_weight = right_peer_weight_tensor.to_bagua_tensor(
bucket._right_peer_weight = right_peer_weight_tensor.ensure_bagua_tensor(
"right_peer_weight"
)

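The call sites above now register the freshly flattened bucket tensors with ``ensure_bagua_tensor``, which converts a tensor in place and returns it, rather than building a separate proxy with ``to_bagua_tensor``. A minimal sketch of the updated pattern, assuming an already constructed ``BaguaBucket`` named ``bucket`` (the attribute names follow the diff above):

```python
# Sketch only: `bucket` is assumed to be an existing BaguaBucket.
weight = bucket.flattened_tensor()   # flat tensor spanning every tensor in the bucket
left = bucket.flattened_tensor()
right = bucket.flattened_tensor()

# ensure_bagua_tensor converts each tensor in place and returns it, so the
# registered Bagua tensors and the stored attributes are the same objects.
bucket._weight = weight.ensure_bagua_tensor("weight")
bucket._left_peer_weight = left.ensure_bagua_tensor("left_peer_weight")
bucket._right_peer_weight = right.ensure_bagua_tensor("right_peer_weight")
```
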
10 changes: 5 additions & 5 deletions bagua/torch_api/bucket.py
@@ -55,7 +55,7 @@ def __init__(
# padding tensor must be of name bagua_padding_tensor, so that they are always marked as ready for communication in the backend
self.padding_tensor = torch.zeros(
padding, dtype=self.tensors[0].dtype, device=self.tensors[0].device
).to_bagua_tensor(
).ensure_bagua_tensor(
"bagua_padding_tensor_bucket_" + name,
module_name=self.bagua_module_name,
)
@@ -243,7 +243,7 @@ def append_decentralized_synchronous_op(

Args:
peer_weight (BaguaTensor): A tensor used for averaging model with peers, should be of the same size
with the bucket tensors total size. Use ``self.flattened_tensor().to_bagua_tensor(...)`` to create such a tensor.
with the bucket tensors total size. Use ``self.flattened_tensor().ensure_bagua_tensor(...)`` to create such a tensor.
hierarchical (bool): Enable hierarchical communication, which means the GPUs on the same machine
will communicate with each other first. After that, machines do inter-node communication. This can
boost performance when the inter-node communication cost is high.
@@ -292,12 +292,12 @@ def append_low_precision_decentralized_synchronous_op(

Args:
weight (BaguaTensor): Model replica of current worker's local model. It should be of the same size
with the bucket tensors total size. Use ``self.flattened_tensor().to_bagua_tensor(...)`` to create such a tensor.
with the bucket tensors total size. Use ``self.flattened_tensor().ensure_bagua_tensor(...)`` to create such a tensor.
left_peer_weight (BaguaTensor): Model replica of current worker's left peer. It should be of the same size
with the bucket tensors total size. Use ``self.flattened_tensor().to_bagua_tensor(...)`` to create such a tensor,
with the bucket tensors total size. Use ``self.flattened_tensor().ensure_bagua_tensor(...)`` to create such a tensor,
then copy the initializing weights of current worker's left peer to the tensor.
right_peer_weight (BaguaTensor): Model replica of current worker's right peer. It should be of the same size
with the bucket tensors total size. Use ``self.flattened_tensor().to_bagua_tensor(...)`` to create such a tensor.
with the bucket tensors total size. Use ``self.flattened_tensor().ensure_bagua_tensor(...)`` to create such a tensor.
then copy the initializing weights of current worker's right peer to the tensor.
hierarchical (bool): Enable hierarchical communication, which means the GPUs on the same machine
will communicate with each other first. After that, machines do inter-node communication. This can
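As the docstrings above describe, the weight and peer replicas are created from a flattened copy of the bucket and registered as Bagua tensors before the op is appended. A hedged sketch of that call pattern, assuming an existing ``BaguaBucket`` named ``bucket``; the argument names come from the docstrings, and any further arguments the op requires are omitted here:

```python
# Sketch under the assumption that `bucket` is an existing BaguaBucket.
# The peer weight must match the bucket's total flattened size, so it is
# created from bucket.flattened_tensor() and registered with Bagua.
peer_weight = bucket.flattened_tensor().ensure_bagua_tensor(
    "peer_weight", module_name=bucket.bagua_module_name
)

# hierarchical=True makes GPUs on one machine average with each other first and
# then communicate across machines, which helps when inter-node links are slow.
bucket.append_decentralized_synchronous_op(
    peer_weight=peer_weight,
    hierarchical=True,
)
```
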
51 changes: 27 additions & 24 deletions bagua/torch_api/tensor.py
@@ -14,7 +14,7 @@ class BaguaTensor:
with additional methods.

A Bagua tensor is required to use Bagua's communication algorithms. Users can convert a PyTorch tensor to Bagua
tensor by :meth:`ensure_bagua_tensor` or :meth:`to_bagua_tensor`.
tensor by :meth:`ensure_bagua_tensor`.

Bagua tensor features a proxy structure, where the actual tensor used by backend is accessed via a **"Proxy Tensor"**.
The proxy tensor is registered in Bagua, whenever the Bagua backend needs a tensor (for example use it for
@@ -134,24 +134,6 @@ def ensure_bagua_tensor(
self._bagua_bucket = None
return self

def bagua_getter_closure(self) -> torch.Tensor:
"""Returns the tensor that will be used in runtime."""
return (
self._bagua_getter_closure(self)
if self._bagua_getter_closure is not None
else self
)

def bagua_setter_closure(self, tensor: torch.Tensor):
"""Sets the tensor that will be used in runtime to a new Pytorch tensor :attr:`tensor`.

Args:
tensor: The new tensor to be set to.
"""

assert self._bagua_setter_closure is not None
self._bagua_setter_closure(self, tensor)

def to_bagua_tensor(
self,
name: Optional[str] = None,
@@ -161,9 +161,13 @@
):
"""
Create a new Bagua tensor from a PyTorch tensor or parameter and return it.
The original tensor is not changed. A Bagua tensor is required to use
Bagua's communication algorithms. See :meth:`ensure_bagua_tensor` for more
information.
The new Bagua tensor will share the same storage with the input PyTorch tensor.
A Bagua tensor is required to use Bagua's communication algorithms.
See :meth:`ensure_bagua_tensor` for more information.

Caveat: Be aware that if the original tensor changes to use a different storage
using for example ``torch.Tensor.set_(...)``, the new Bagua tensor will still
use the old storage.

Args:
name: The unique name of the tensor.
@@ -173,15 +173,32 @@
getter_closure: A function that accepts a Pytorch tensor as its input and returns a Pytorch tensor as
its output. See :meth:`ensure_bagua_tensor`.
setter_closure: A function that accepts two Pytorch tensors as its inputs and returns nothing. See :meth:`ensure_bagua_tensor`.

Returns:
The new Bagua tensor sharing the same storage with the original tensor.
"""
new_tensor = torch.Tensor(cdata=self._cdata)
new_tensor = self.view(self.dtype)
return new_tensor.ensure_bagua_tensor(
name, module_name, getter_closure, setter_closure
)

def bagua_getter_closure(self) -> torch.Tensor:
"""Returns the tensor that will be used in runtime."""
return (
self._bagua_getter_closure(self)
if self._bagua_getter_closure is not None
else self
)

def bagua_setter_closure(self, tensor: torch.Tensor):
"""Sets the tensor that will be used in runtime to a new Pytorch tensor :attr:`tensor`.

Args:
tensor: The new tensor to be set to.
"""

assert self._bagua_setter_closure is not None
self._bagua_setter_closure(self, tensor)

def bagua_backend_tensor(self) -> B.BaguaTensorPy:
"""
Returns:
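The last hunk carries the core of the PyTorch 1.10 fix: the proxy tensor is no longer constructed with ``torch.Tensor(cdata=self._cdata)`` but with ``self.view(self.dtype)``, a same-dtype view that shares storage with the original tensor. A plain-PyTorch sketch of the storage semantics behind the new caveat (variable names here are illustrative, not part of Bagua's API):

```python
import torch

t = torch.zeros(4)
proxy = t.view(t.dtype)    # same-dtype view: new tensor object, same storage

t[0] = 1.0
assert proxy[0] == 1.0     # writes through t are visible through the proxy

# Caveat from the docstring above: rebinding t to different storage with set_()
# does not carry over to the proxy, which keeps pointing at the old storage.
t.set_(torch.ones(8))
assert proxy.numel() == 4
assert proxy[0] == 1.0
```
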
56 changes: 56 additions & 0 deletions docker/Dockerfile.pytorch-1.10.0-cuda11.3-cudnn8
@@ -0,0 +1,56 @@
FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel

RUN apt-get update && apt-get install -y curl software-properties-common wget sudo
RUN add-apt-repository ppa:git-core/ppa -y
RUN sed -i 's/mozilla\/DST_Root_CA_X3.crt/!mozilla\/DST_Root_CA_X3.crt/g' /etc/ca-certificates.conf && update-ca-certificates
RUN curl -sSf https://apt.kitware.com/kitware-archive.sh | sh
RUN apt-get update && apt-get install -y git cmake
RUN curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain stable -y
ENV PATH=/root/.cargo/bin:${PATH}
RUN cargo install mdbook mdbook-linkcheck mdbook-katex mdbook-open-on-gh

RUN yes | python3 -m pip install -U setuptools wheel build pip

ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64/stubs/:/usr/local/lib64:/usr/local/lib"
ENV LIBRARY_PATH="/usr/local/cuda/lib64/stubs/:/usr/local/lib64:/usr/local/lib"
ENV PKG_CONFIG_PATH="/usr/local/cuda/pkgconfig/"
ENV CUDA_LIBRARY_PATH="/usr/local/cuda/lib64/"


# OpenMPI version 4.0.3
RUN apt-get update -y && \
DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
bzip2 \
file \
hwloc \
libnuma-dev \
make \
openssh-client \
perl \
tar \
wget && \
rm -rf /var/lib/apt/lists/*
RUN mkdir -p /var/tmp && wget -q -nc --no-check-certificate -P /var/tmp https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.3.tar.bz2 && \
mkdir -p /var/tmp && tar -x -f /var/tmp/openmpi-4.0.3.tar.bz2 -C /var/tmp -j && \
cd /var/tmp/openmpi-4.0.3 && ./configure --disable-getpwuid --disable-oshmem --enable-fortran --enable-mca-no-build=btl-uct --enable-orterun-prefix-by-default --with-cuda --without-verbs && \
make -j$(nproc) && \
make -j$(nproc) install && \
rm -rf /var/tmp/openmpi-4.0.3 /var/tmp/openmpi-4.0.3.tar.bz2 && cd -

# hwloc
RUN mkdir -p /var/tmp && wget -q -nc --no-check-certificate -P /var/tmp https://download.open-mpi.org/release/hwloc/v2.5/hwloc-2.5.0.tar.bz2 && \
mkdir -p /var/tmp && tar -x -f /var/tmp/hwloc-2.5.0.tar.bz2 -C /var/tmp -j && \
cd /var/tmp/hwloc-2.5.0 && ./configure && \
make -j$(nproc) && \
make -j$(nproc) install && \
rm -rf /var/tmp/hwloc* && cd -

# Redis
RUN add-apt-repository ppa:redislabs/redis
RUN apt-get update && apt-get install -y redis
RUN yes | python3 -m pip install -U redis

RUN mkdir /bagua
COPY examples/ /bagua/examples
COPY ./ /var/tmp/bagua
RUN cd /var/tmp/bagua && python3 -m pip install . && cd - && rm -rf /var/tmp/bagua
3 changes: 2 additions & 1 deletion docker/Dockerfile.pytorch-1.9.0-cuda10.2-cudnn7
@@ -1,6 +1,7 @@
FROM pytorch/pytorch:1.9.0-cuda10.2-cudnn7-devel

RUN apt-get update && apt-get install -y curl software-properties-common wget
RUN apt-get update && apt-get install -y curl software-properties-common wget sudo
RUN add-apt-repository ppa:git-core/ppa -y
RUN sed -i 's/mozilla\/DST_Root_CA_X3.crt/!mozilla\/DST_Root_CA_X3.crt/g' /etc/ca-certificates.conf && update-ca-certificates
RUN curl -sSf https://apt.kitware.com/kitware-archive.sh | sh
RUN apt-get update && apt-get install -y git cmake
3 changes: 2 additions & 1 deletion docker/Dockerfile.pytorch-1.9.0-cuda11.1-cudnn8
@@ -1,6 +1,7 @@
FROM pytorch/pytorch:1.9.0-cuda11.1-cudnn8-devel

RUN apt-get update && apt-get install -y curl software-properties-common wget
RUN apt-get update && apt-get install -y curl software-properties-common wget sudo
RUN add-apt-repository ppa:git-core/ppa -y
RUN sed -i 's/mozilla\/DST_Root_CA_X3.crt/!mozilla\/DST_Root_CA_X3.crt/g' /etc/ca-certificates.conf && update-ca-certificates
RUN curl -sSf https://apt.kitware.com/kitware-archive.sh | sh
RUN apt-get update && apt-get install -y git cmake
2 changes: 1 addition & 1 deletion tests/comm/test_communicator.py
@@ -56,7 +56,7 @@ def abort():
data = torch.rand(10).cuda()

for _ in range(rank + 1):
comm.allreduce_inplace(data.to_bagua_tensor().bagua_backend_tensor(), 10)
comm.allreduce_inplace(data.ensure_bagua_tensor().bagua_backend_tensor(), 10)

comm_stream.synchronize()

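The test change illustrates the pattern user code follows after this PR: convert a tensor with ``ensure_bagua_tensor`` and pass its backend handle to the communicator. A hedged sketch, assuming the Bagua process group, CUDA device, and communicator ``comm`` are already set up as in the surrounding test; the second argument simply mirrors the value used there:

```python
import torch

# Assumes bagua.torch_api.init_process_group() has run and `comm` is an
# initialized communicator, as in tests/comm/test_communicator.py.
data = torch.rand(10).cuda()

# ensure_bagua_tensor converts `data` in place; bagua_backend_tensor() returns
# the handle the low-level communicator API expects.
comm.allreduce_inplace(data.ensure_bagua_tensor().bagua_backend_tensor(), 10)
```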