
Upgrade GPU CI to PyTorch 1.13 #15583

Merged: 57 commits merged into master on Nov 12, 2022. Changes shown below are from 45 commits.

Commits (57):
49b8e9e
Upgrade GPU CI to PyTorch 1.13
awaelchli Nov 8, 2022
8b5f104
Carlos suggestion
awaelchli Nov 8, 2022
bfb3912
11.6.1
awaelchli Nov 8, 2022
e40ecb0
Merge branch 'master' into ci/gpu-pytorch-1-13
Borda Nov 8, 2022
b8245c1
Remove FIXME
carmocca Nov 9, 2022
4602792
Fix
carmocca Nov 9, 2022
559b200
PUSH_TO_HUB
carmocca Nov 9, 2022
fad88fd
One more
carmocca Nov 9, 2022
2bc6880
Remove unnecessary changes
carmocca Nov 9, 2022
524fb77
Skip colossalai
carmocca Nov 9, 2022
25bf852
Missed 1
carmocca Nov 9, 2022
aab1e18
Revert "PUSH_TO_HUB"
carmocca Nov 10, 2022
bf05912
Uncomment colossalai
carmocca Nov 10, 2022
277144d
Allowlist
carmocca Nov 10, 2022
4f697a8
Consistent config
carmocca Nov 10, 2022
7550cb8
Also for benchmarks
carmocca Nov 10, 2022
289ef81
Merge branch 'master' into ci/gpu-pytorch-1-13
Borda Nov 10, 2022
785dc64
strategies
Borda Nov 10, 2022
bc9d2b2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2022
706f56d
bagua
Borda Nov 10, 2022
ad6fbd3
push
Borda Nov 10, 2022
42b09c5
Merge branch 'master' into ci/gpu-pytorch-1-13
carmocca Nov 10, 2022
43dd511
Update .azure/gpu-tests-pytorch.yml
carmocca Nov 10, 2022
1452f8e
build
Borda Nov 10, 2022
d2d4a2f
.
Borda Nov 10, 2022
3ce2705
fix test leaking vars
awaelchli Nov 10, 2022
a0d323f
update env var allowlist
awaelchli Nov 10, 2022
64f3827
pytorch removed reshard_after_forward attribute
awaelchli Nov 10, 2022
5a6b3bd
guard bagua import
awaelchli Nov 10, 2022
f0b5166
Revert "guard bagua import"
awaelchli Nov 10, 2022
2fac900
we're closing in 15 mins
awaelchli Nov 10, 2022
7e15a28
last call
awaelchli Nov 10, 2022
2ee106f
you have the right to remain silent
awaelchli Nov 10, 2022
afef95e
push
awaelchli Nov 10, 2022
9e4310d
.0
carmocca Nov 10, 2022
d9f47fb
Benchmarks just use DDP
carmocca Nov 10, 2022
7045290
try to fix
awaelchli Nov 11, 2022
40f2756
fix
awaelchli Nov 11, 2022
502b20e
i shot the sheriff
awaelchli Nov 11, 2022
d0a8f1c
x
awaelchli Nov 11, 2022
d0321bb
super user do
awaelchli Nov 11, 2022
eceb038
i am your father nooooo
awaelchli Nov 11, 2022
cc4e74f
move
awaelchli Nov 11, 2022
d33fe04
bump max nccl version
awaelchli Nov 11, 2022
3840def
single run command
awaelchli Nov 11, 2022
146513d
Apply suggestions from code review
carmocca Nov 11, 2022
496482e
update
awaelchli Nov 12, 2022
627ec9b
cuda 11.6.1 is going to work
awaelchli Nov 12, 2022
32e87b4
push to hub
awaelchli Nov 12, 2022
161a3c4
update benchmark job
awaelchli Nov 12, 2022
01204bd
horovod
awaelchli Nov 12, 2022
c88977e
debug benchmark
awaelchli Nov 12, 2022
0c3eb2d
Update .github/checkgroup.yml
awaelchli Nov 12, 2022
d5a3d14
restore removed lines
awaelchli Nov 12, 2022
30340d8
Merge branch 'master' into ci/gpu-pytorch-1-13
awaelchli Nov 12, 2022
61593a3
Apply suggestions from code review
awaelchli Nov 12, 2022
4381edb
Apply suggestions from code review
Borda Nov 12, 2022
36 changes: 30 additions & 6 deletions .azure/gpu-benchmark.yml
@@ -37,7 +37,7 @@ jobs:
variables:
DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' )
container:
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.12-cuda11.6.1"
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.13-cuda11.7.0"
options: "--gpus=all --shm-size=32g"
workspace:
clean: all
@@ -46,19 +46,43 @@ jobs:

- bash: |
echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)"
cuda_ver=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda.split('.')[:2])))")
echo "##vso[task.setvariable variable=CUDA_VERSION_MM]$cuda_ver"
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${cuda_ver}/torch_stable.html"
displayName: 'set env. vars'

- bash: |
pip install -e .[strategies] --find-links ${TORCH_URL}
echo $CUDA_VISIBLE_DEVICES
echo $TORCH_URL
lspci | egrep 'VGA|3D'
whereis nvidia
nvidia-smi
which python && which pip
python --version
pip --version
pip list
displayName: 'Image info & NVIDIA'

- bash: |
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'bagua' not in line] ; open(fname, 'w').writelines(lines)"
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'colossalai' not in line] ; open(fname, 'w').writelines(lines)"

PYTORCH_VERSION=$(python -c "import torch; print(torch.__version__.split('+')[0])")
python ./requirements/pytorch/adjust-versions.py requirements/pytorch/base.txt ${PYTORCH_VERSION}
displayName: 'Adjust dependencies'
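The three `python -c` one-liners in the step above rewrite `requirements/pytorch/strategies.txt` in place, dropping horovod, bagua, and colossalai. Unrolled into a readable helper (the function name is mine, the behavior matches the one-liners):

```python
def drop_requirement(fname: str, package: str) -> None:
    # Keep every line of the requirements file that does not mention
    # `package`, then write the file back in place.
    with open(fname) as f:
        kept = [line for line in f if package not in line]
    with open(fname, "w") as f:
        f.writelines(kept)
```

The CI calls this pattern once per excluded strategy rather than once with a list, which keeps each one-liner independent.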

- bash: pip install -e .[dev,strategies,examples] --find-links ${TORCH_URL}
env:
PACKAGE_NAME: pytorch
FREEZE_REQUIREMENTS: 1
PACKAGE_NAME: "pytorch"
FREEZE_REQUIREMENTS: "1"
displayName: 'Install package'

- bash: |
set -e
pip list
python requirements/collect_env_details.py
python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu == 2, f'GPU: {mgpu}'"
displayName: 'Env details'

- bash: python -m pytest benchmarks -v --durations=0
env:
PL_RUNNING_BENCHMARKS: "1"
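The `set env. vars` step in this job derives the `--find-links` wheel index URL from the CUDA runtime version. A standalone sketch of that string manipulation (the function name is mine, not part of the pipeline):

```python
def torch_wheel_url(cuda_version: str) -> str:
    # Mirror the pipeline's one-liner: join the first two version
    # components, e.g. "11.7.0" -> "117", and build the wheel index URL.
    cuda_mm = "".join(cuda_version.split(".")[:2])
    return f"https://download.pytorch.org/whl/cu{cuda_mm}/torch_stable.html"
```

In the pipeline the input comes from `torch.version.cuda`, which reports only major.minor (e.g. `"11.7"`); taking `[:2]` makes the same code tolerate a full `"11.7.0"` as well.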
23 changes: 15 additions & 8 deletions .azure/gpu-tests-lite.yml
@@ -40,7 +40,7 @@ jobs:
variables:
DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' )
container:
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.12-cuda11.6.1"
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.13-cuda11.7.0"
# default shm size is 64m. Increase it to avoid:
# 'Error while creating shared memory: unhandled system error, NCCL version 2.7.8'
options: "--gpus=all --shm-size=2gb"
@@ -50,6 +50,14 @@ jobs:

steps:
- bash: |
echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)"
cuda_ver=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda.split('.')[:2])))")
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${cuda_ver}/torch_stable.html"
displayName: 'set env. vars'

- bash: |
echo $CUDA_VISIBLE_DEVICES
echo $TORCH_URL
lspci | egrep 'VGA|3D'
whereis nvidia
nvidia-smi
@@ -60,22 +68,21 @@
displayName: 'Image info & NVIDIA'

- bash: |
echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)"
displayName: 'set visible devices'
PYTORCH_VERSION=$(python -c "import torch; print(torch.__version__.split('+')[0])")
python ./requirements/pytorch/adjust-versions.py requirements/lite/base.txt ${PYTORCH_VERSION}
python ./requirements/pytorch/adjust-versions.py requirements/lite/examples.txt ${PYTORCH_VERSION}
displayName: 'Adjust dependencies'

- bash: |
set -e
CUDA_VERSION_MM=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda.split('.')[:2])))")
pip install -e .[dev,strategies,examples] --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html
pip list
pip install -e .[dev,strategies,examples] --find-links ${TORCH_URL}
env:
PACKAGE_NAME: "lite"
FREEZE_REQUIREMENTS: "1"
displayName: 'Install package & dependencies'

- bash: |
set -e
echo $CUDA_VISIBLE_DEVICES
pip list
python requirements/collect_env_details.py
python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu == 2, f'GPU: {mgpu}'"
displayName: 'Env details'
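The `Adjust dependencies` step above strips the local build tag from `torch.__version__` before handing it to `adjust-versions.py`. That split is worth making explicit (helper name is mine):

```python
def bare_torch_version(version: str) -> str:
    # "1.13.0+cu117" -> "1.13.0". Per PEP 440, the local build tag
    # (e.g. the CUDA variant) follows a "+"; versions without one
    # pass through unchanged.
    return version.split("+")[0]
```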
20 changes: 13 additions & 7 deletions .azure/gpu-tests-pytorch.yml
@@ -39,9 +39,12 @@ jobs:
- job: testing
strategy:
matrix:
# TODO: package parametrization
'PyTorch - stable':
'PyTorch & strategies': # this uses torch 1.12 as not all strategies support 1.13 yet
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.12-cuda11.6.1"
scope: "strategies"
'PyTorch - latest':
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.13-cuda11.7.0"
scope: ""
# how long to run the job before automatically cancelling
timeoutInMinutes: "80"
# how much time to give 'run always even if cancelled tasks' before stopping them
@@ -90,11 +93,11 @@ jobs:
python ./requirements/pytorch/adjust-versions.py requirements/pytorch/examples.txt ${PYTORCH_VERSION}
displayName: 'Adjust dependencies'

- bash: pip install -e .[strategies] -r requirements/pytorch/devel.txt -r requirements/pytorch/examples.txt --find-links ${TORCH_URL}
- bash: pip install -e .[dev,examples] --find-links ${TORCH_URL}
env:
PACKAGE_NAME: "pytorch"
FREEZE_REQUIREMENTS: "1"
displayName: 'Install package'
displayName: 'Install package & extras'

- bash: |
set -e
@@ -106,14 +109,17 @@
CUDA_VERSION_COLOSSALAI=$(python -c "print([ver for ver in [11.3, 11.1] if $CUDA_VERSION_MM_COLOSSALAI >= ver][0])")
pip install "colossalai==0.1.10+torch${PYTORCH_VERSION_COLOSSALAI}cu${CUDA_VERSION_COLOSSALAI}" --find-links https://release.colossalai.org

pip list
displayName: 'Install dependencies'
pip install -r requirements/pytorch/strategies.txt --find-links ${TORCH_URL}

python requirements/pytorch/check-avail-strategies.py
condition: eq(variables['scope'], 'strategies')
displayName: 'Install strategies'

- bash: |
set -e
pip list
python requirements/collect_env_details.py
python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu == 2, f'GPU: {mgpu}'"
python requirements/pytorch/check-avail-strategies.py
python requirements/pytorch/check-avail-extras.py
displayName: 'Env details'

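The conditional `Install strategies` step above picks a ColossalAI wheel by selecting the newest CUDA build that the local toolkit can satisfy. A sketch of that selection logic (names are mine; the shipped versions follow the `[11.3, 11.1]` list in the pipeline):

```python
def pick_colossalai_cuda(local_cuda: float, shipped=(11.3, 11.1)) -> float:
    # ColossalAI 0.1.10 ships cu113 and cu111 wheels only; take the
    # newest one not exceeding the local CUDA version.
    candidates = [ver for ver in shipped if local_cuda >= ver]
    if not candidates:
        raise RuntimeError(f"no ColossalAI wheel for CUDA {local_cuda}")
    return candidates[0]
```

The pipeline's one-liner indexes `[0]` without the guard, so a CUDA toolkit older than 11.1 would fail with an `IndexError` there.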
2 changes: 2 additions & 0 deletions .github/checkgroup.yml
@@ -128,12 +128,14 @@ subprojects:
- "build-cuda (3.9, 1.10, 11.3.1)"
- "build-cuda (3.9, 1.11, 11.3.1)"
- "build-cuda (3.9, 1.12, 11.6.1)"
- "build-cuda (3.9, 1.13, 11.7.0)"
- "build-hpu (1.5.0, 1.11.0)"
- "build-ipu (3.9, 1.10)"
- "build-NGC"
- "build-pl (3.9, 1.10, 11.3.1)"
- "build-pl (3.9, 1.11, 11.3.1)"
- "build-pl (3.9, 1.12, 11.6.1)"
# TODO: add 1.13
- "build-xla (3.7, 1.12)"

# SECTION: lightning_lite
6 changes: 4 additions & 2 deletions .github/workflows/ci-pytorch-dockers.yml
@@ -2,7 +2,7 @@ name: Docker

on:
push:
branches: [master, "release/*"]
branches: ["*"] # FIXME
pull_request:
branches: [master, "release/*"]
types: [opened, reopened, ready_for_review, synchronize] # added `ready_for_review` since draft is skipped
@@ -24,7 +24,7 @@
cancel-in-progress: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}

env:
PUSH_TO_HUB: ${{ github.event_name == 'schedule' }}
PUSH_TO_HUB: true # FIXME ${{ github.event_name == 'schedule' }}

jobs:
build-pl:
@@ -39,6 +39,7 @@ jobs:
- {python_version: "3.9", pytorch_version: "1.10", cuda_version: "11.3.1"}
- {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1"}
- {python_version: "3.9", pytorch_version: "1.12", cuda_version: "11.6.1"}
# TODO: add 1.13
steps:
- uses: actions/checkout@v3
- uses: docker/setup-buildx-action@v2
@@ -100,6 +101,7 @@ jobs:
- {python_version: "3.9", pytorch_version: "1.10", cuda_version: "11.3.1"}
- {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1"}
- {python_version: "3.9", pytorch_version: "1.12", cuda_version: "11.6.1"}
- {python_version: "3.9", pytorch_version: "1.13", cuda_version: "11.7.0"}
steps:
- uses: actions/checkout@v3
- uses: docker/setup-buildx-action@v2
3 changes: 2 additions & 1 deletion .github/workflows/release-docker.yml
@@ -2,7 +2,7 @@ name: Docker

on:
push:
branches: [master, "release/*"]
branches: ["*"]
release:
types: [published]

@@ -19,6 +19,7 @@ jobs:
- {python_version: "3.9", pytorch_version: "1.10", cuda_version: "11.3.1"}
- {python_version: "3.9", pytorch_version: "1.11", cuda_version: "11.3.1"}
- {python_version: "3.9", pytorch_version: "1.12", cuda_version: "11.6.1"}
- {python_version: "3.9", pytorch_version: "1.13", cuda_version: "11.7.0"}
steps:
- name: Checkout
uses: actions/checkout@v3
42 changes: 27 additions & 15 deletions dockers/base-cuda/Dockerfile
@@ -13,12 +13,13 @@
# limitations under the License.

ARG UBUNTU_VERSION=20.04
ARG CUDA_VERSION=11.3.1
ARG CUDA_VERSION=11.7.0


FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}

ARG PYTHON_VERSION=3.9
ARG PYTORCH_VERSION=1.12
ARG PYTORCH_VERSION=1.13

SHELL ["/bin/bash", "-c"]
# https://techoverflow.net/2019/05/18/how-to-fix-configuring-tzdata-interactive-input-when-building-docker-images/
@@ -35,11 +36,16 @@ ENV \
RUN \
# TODO: Remove the manual key installation once the base image is updated.
# https://github.com/NVIDIA/nvidia-docker/issues/1631
apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \
# https://github.com/NVIDIA/nvidia-docker/issues/1631#issuecomment-1264715214
apt-get update && apt-get install -y wget && \
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \
mkdir -p /etc/apt/keyrings/ && mv 3bf863cc.pub /etc/apt/keyrings/ && \
echo "deb [signed-by=/etc/apt/keyrings/3bf863cc.pub] https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /" > /etc/apt/sources.list.d/cuda.list && \
apt-get update && \
apt-get update -qq --fix-missing && \
NCCL_VER=$(dpkg -s libnccl2 | grep '^Version:' | awk -F ' ' '{print $2}' | awk -F '-' '{print $1}' | grep -ve '^\s*$') && \
CUDA_VERSION_MM="${CUDA_VERSION%.*}" && \
MAX_ALLOWED_NCCL=2.11.4 && \
MAX_ALLOWED_NCCL=2.15.5 && \
TO_INSTALL_NCCL=$(echo -e "$MAX_ALLOWED_NCCL\n$NCCL_VER" | sort -V | head -n1)-1+cuda${CUDA_VERSION_MM} && \
apt-get install -y --no-install-recommends --allow-downgrades --allow-change-held-packages \
build-essential \
@@ -132,19 +138,26 @@ RUN \

RUN \
# install Bagua
CUDA_VERSION_MM=$(python -c "print(''.join('$CUDA_VERSION'.split('.')[:2]))") && \
CUDA_VERSION_BAGUA=$(python -c "print([ver for ver in [116,113,111,102] if $CUDA_VERSION_MM >= ver][0])") && \
pip install "bagua-cuda$CUDA_VERSION_BAGUA" && \
if [[ "$CUDA_VERSION_MM" = "$CUDA_VERSION_BAGUA" ]]; then python -c "import bagua_core; bagua_core.install_deps()"; fi && \
python -c "import bagua; print(bagua.__version__)"
if [[ $PYTORCH_VERSION != "1.13" ]]; then \
CUDA_VERSION_MM=$(python -c "print(''.join('$CUDA_VERSION'.split('.')[:2]))") ; \
CUDA_VERSION_BAGUA=$(python -c "print([ver for ver in [116,113,111,102] if $CUDA_VERSION_MM >= ver][0])") ; \
pip install "bagua-cuda$CUDA_VERSION_BAGUA" ; \
if [[ "$CUDA_VERSION_MM" = "$CUDA_VERSION_BAGUA" ]]; then \
python -c "import bagua_core; bagua_core.install_deps()"; \
fi ; \
python -c "import bagua; print(bagua.__version__)"; \
fi

RUN \
# install ColossalAI
PYTORCH_VERSION_COLOSSALAI=$(python -c "import torch; print(torch.__version__.split('+')[0][:4])") ; \
CUDA_VERSION_MM_COLOSSALAI=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda)))") ; \
CUDA_VERSION_COLOSSALAI=$(python -c "print([ver for ver in [11.3, 11.1] if $CUDA_VERSION_MM_COLOSSALAI >= ver][0])") ; \
pip install "colossalai==0.1.10+torch${PYTORCH_VERSION_COLOSSALAI}cu${CUDA_VERSION_COLOSSALAI}" --find-links https://release.colossalai.org ; \
python -c "import colossalai; print(colossalai.__version__)" ; \
# TODO: 1.13 wheels are not released, remove skip once they are
if [[ $PYTORCH_VERSION != "1.13" ]]; then \
PYTORCH_VERSION_COLOSSALAI=$(python -c "import torch; print(torch.__version__.split('+')[0][:4])") ; \
CUDA_VERSION_MM_COLOSSALAI=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda)))") ; \
CUDA_VERSION_COLOSSALAI=$(python -c "print([ver for ver in [11.3, 11.1] if $CUDA_VERSION_MM_COLOSSALAI >= ver][0])") ; \
pip install "colossalai==0.1.10+torch${PYTORCH_VERSION_COLOSSALAI}cu${CUDA_VERSION_COLOSSALAI}" --find-links https://release.colossalai.org ; \
python -c "import colossalai; print(colossalai.__version__)" ; \
fi

RUN \
# install rest of strategies
@@ -163,5 +176,4 @@ RUN \
python -c "import sys; ver = sys.version_info ; assert f'{ver.major}.{ver.minor}' == '$PYTHON_VERSION', ver" && \
python -c "import torch; assert torch.__version__.startswith('$PYTORCH_VERSION'), torch.__version__" && \
python requirements/pytorch/check-avail-extras.py && \
python requirements/pytorch/check-avail-strategies.py && \
rm -rf requirements/
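The Dockerfile's NCCL cap works by piping the system NCCL version and `MAX_ALLOWED_NCCL` through `sort -V | head -n1`, i.e. it installs whichever of the two is older. The same selection expressed in Python (helper name is mine):

```python
def nccl_to_install(system_version: str, max_allowed: str = "2.15.5") -> str:
    # Emulate `sort -V | head -n1`: compare tuples of integer version
    # components and return the smaller version string.
    as_tuple = lambda v: tuple(int(part) for part in v.split("."))
    return min(system_version, max_allowed, key=as_tuple)
```

Bumping `MAX_ALLOWED_NCCL` from 2.11.4 to 2.15.5 in this PR widens that cap so the newer NCCL shipped with the CUDA 11.7 base image is accepted.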
1 change: 1 addition & 0 deletions tests/tests_lite/conftest.py
@@ -54,6 +54,7 @@ def restore_env_variables():
"HOROVOD_FUSION_THRESHOLD",
"RANK", # set by DeepSpeed
"POPLAR_ENGINE_OPTIONS", # set by IPUStrategy
"CUDA_MODULE_LOADING", # leaked since PyTorch 1.13
}
leaked_vars.difference_update(allowlist)
assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
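The conftest change above grows the allowlist because importing PyTorch 1.13 sets `CUDA_MODULE_LOADING` as a side effect, which would otherwise trip the fixture. The core leak check reduces to (a sketch; the helper name is mine, the fixture itself works on a snapshot of `os.environ`):

```python
import os

def leaked_env_vars(env_before: dict, allowlist: set) -> set:
    # Variables present after the test but absent before it, minus the
    # allowlisted ones that libraries are known to set as side effects.
    return (set(os.environ) - set(env_before)) - allowlist
```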
3 changes: 3 additions & 0 deletions tests/tests_pytorch/conftest.py
@@ -72,6 +72,9 @@ def restore_env_variables():
"HOROVOD_FUSION_THRESHOLD",
"RANK", # set by DeepSpeed
"POPLAR_ENGINE_OPTIONS", # set by IPUStrategy
"CUDA_MODULE_LOADING", # leaked since PyTorch 1.13
"KMP_INIT_AT_FORK", # leaked since PyTorch 1.13
"KMP_DUPLICATE_LIB_OK", # leaked since PyTorch 1.13
}
leaked_vars.difference_update(allowlist)
assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
@@ -61,19 +61,13 @@ def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
def _assert_layer_fsdp_instance(self) -> None:
assert isinstance(self.layer, FullyShardedDataParallel)
assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)
# root should not be resharding
assert self.layer.reshard_after_forward is False

precision = torch.float16 if self.precision == 16 else torch.bfloat16
assert self.layer.mixed_precision.param_dtype == precision
assert self.layer.mixed_precision.reduce_dtype == precision
assert self.layer.mixed_precision.buffer_dtype == precision

for layer_num in [0, 2]:
assert isinstance(self.layer.module[layer_num], FullyShardedDataParallel)
# Assert that the nested layers are set reshard_after_forward to True
assert self.layer.module[layer_num].reshard_after_forward is True

assert self.layer[layer_num].mixed_precision.param_dtype == precision
assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
assert self.layer[layer_num].mixed_precision.buffer_dtype == precision
@@ -106,9 +100,6 @@ def _assert_layer_fsdp_instance(self) -> None:
precision = torch.float16 if self.precision == 16 else torch.bfloat16
for layer_num in [0, 2]:
assert isinstance(self.layer[layer_num], FullyShardedDataParallel)
# Assert that the nested layers are set reshard_after_forward to True
assert self.layer[layer_num].reshard_after_forward

assert self.layer[layer_num].mixed_precision.param_dtype == precision
assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
assert self.layer[layer_num].mixed_precision.buffer_dtype == precision