Initial mixed-precision training (#196)
* Initial multi-precision training

Adds fp16 support via apex.amp
Also switches communication to apex.DistributedDataParallel

* Add Apex install to dockerfile

* Fixes from @fmassa review

Added support to tools/test_net.py
SOLVER.MIXED_PRECISION -> DTYPE \in {float32, float16}
A missing apex.amp installation now raises ImportError

* Remove extraneous apex DDP import

* Move to new amp API
slayton58 authored and fmassa committed Apr 19, 2019
1 parent bf04379 commit 08fcf12
Showing 12 changed files with 70 additions and 4 deletions.
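
For orientation, the sequence below is a minimal sketch of the amp workflow this commit adopts (the "new amp API" from the last bullet): amp.initialize wraps the model and optimizer, and amp.scale_loss wraps the backward pass. It assumes apex is installed and a CUDA device is available; the linear model and random data are placeholders, not the maskrcnn-benchmark detector.

import torch
from apex import amp

model = torch.nn.Linear(16, 16).cuda()                    # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# "O1" enables mixed precision with dynamic loss scaling; "O0" keeps pure fp32,
# mirroring DTYPE == "float16" vs "float32" in the config changes below.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for _ in range(3):
    x = torch.randn(8, 16, device="cuda")
    loss = model(x).float().sum()
    optimizer.zero_grad()
    # Under "O0" this context is effectively a plain backward().
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()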
6 changes: 6 additions & 0 deletions INSTALL.md
@@ -38,6 +38,12 @@ git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI
python setup.py build_ext install

# install apex
cd $INSTALL_DIR
git clone https://github.com/NVIDIA/apex.git
cd apex
python setup.py install --cuda_ext --cpp_ext

# install PyTorch Detection
cd $INSTALL_DIR
git clone https://github.com/facebookresearch/maskrcnn-benchmark.git
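
A quick way to confirm the apex step above built correctly is to import the frontend and the compiled helpers. This is a hedged check: amp is the Python-level frontend, while amp_C and apex_C are assumed to be the extension modules produced by --cuda_ext and --cpp_ext.

from apex import amp   # Python-level mixed-precision frontend
import amp_C, apex_C   # compiled extensions (assumed module names)
print("apex OK:", amp.__name__)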
5 changes: 5 additions & 0 deletions docker/Dockerfile
@@ -46,6 +46,11 @@ RUN git clone https://github.com/cocodataset/cocoapi.git \
&& cd cocoapi/PythonAPI \
&& python setup.py build_ext install

# install apex
RUN git clone https://github.com/NVIDIA/apex.git \
&& cd apex \
&& python setup.py install --cuda_ext --cpp_ext

# install PyTorch Detection
ARG FORCE_CUDA="1"
ENV FORCE_CUDA=${FORCE_CUDA}
10 changes: 10 additions & 0 deletions maskrcnn_benchmark/config/defaults.py
@@ -431,3 +431,13 @@
_C.OUTPUT_DIR = "."

_C.PATHS_CATALOG = os.path.join(os.path.dirname(__file__), "paths_catalog.py")

# ---------------------------------------------------------------------------- #
# Precision options
# ---------------------------------------------------------------------------- #

# Precision of input, allowable: (float32, float16)
_C.DTYPE = "float32"

# Enable verbosity in apex.amp
_C.AMP_VERBOSE = False
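
These keys behave like any other yacs option, so they can be overridden without editing defaults.py. A hedged sketch, assuming the standard config object exported by maskrcnn_benchmark.config:

from maskrcnn_benchmark.config import cfg

# Flat key/value overrides, the same mechanism the training script's `opts`
# argument typically feeds into merge_from_list.
cfg.merge_from_list(["DTYPE", "float16", "AMP_VERBOSE", "True"])
assert cfg.DTYPE == "float16" and cfg.AMP_VERBOSE is True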
6 changes: 5 additions & 1 deletion maskrcnn_benchmark/engine/trainer.py
@@ -9,6 +9,7 @@
from maskrcnn_benchmark.utils.comm import get_world_size
from maskrcnn_benchmark.utils.metric_logger import MetricLogger

from apex import amp

def reduce_loss_dict(loss_dict):
"""
@@ -73,7 +74,10 @@ def do_train(
meters.update(loss=losses_reduced, **loss_dict_reduced)

optimizer.zero_grad()
losses.backward()
# Note: If mixed precision is not used, this ends up doing nothing
# Otherwise apply loss scaling for mixed-precision recipe
with amp.scale_loss(losses, optimizer) as scaled_losses:
scaled_losses.backward()
optimizer.step()

batch_time = time.time() - end
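
For intuition, here is a dependency-free sketch of what the scale_loss context is compensating for: fp16 gradients underflow easily, so the loss is scaled up before backward and the gradients are unscaled before the step. The fixed factor of 512 is a hypothetical static scale; apex.amp manages the scale dynamically and skips steps on overflow.

import torch

def scaled_backward(loss, optimizer, loss_scale=512.0):
    # Scale the loss so small fp16 gradients do not flush to zero in backward...
    (loss * loss_scale).backward()
    # ...then unscale the accumulated gradients before optimizer.step().
    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.grad is not None:
                p.grad.div_(loss_scale)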
7 changes: 7 additions & 0 deletions maskrcnn_benchmark/layers/batch_norm.py
@@ -17,6 +17,13 @@ def __init__(self, n):
self.register_buffer("running_var", torch.ones(n))

def forward(self, x):
# Cast all fixed parameters to half() if necessary
if x.dtype == torch.float16:
self.weight = self.weight.half()
self.bias = self.bias.half()
self.running_mean = self.running_mean.half()
self.running_var = self.running_var.half()

scale = self.weight * self.running_var.rsqrt()
bias = self.bias - self.running_mean * scale
scale = scale.reshape(1, -1, 1, 1)
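
With those casts in place, the frozen batch-norm layer accepts half-precision activations produced under amp. A small hedged usage sketch (CUDA assumed, since the fp16 path is GPU-only):

import torch
from maskrcnn_benchmark.layers.batch_norm import FrozenBatchNorm2d

bn = FrozenBatchNorm2d(8).cuda()
x = torch.randn(2, 8, 4, 4, device="cuda", dtype=torch.float16)
print(bn(x).dtype)  # torch.float16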
6 changes: 5 additions & 1 deletion maskrcnn_benchmark/layers/nms.py
@@ -2,6 +2,10 @@
# from ._utils import _C
from maskrcnn_benchmark import _C

nms = _C.nms
from apex import amp

# Only valid with fp32 inputs - give AMP the hint
nms = amp.float_function(_C.nms)

# nms.__doc__ = """
# This function performs Non-maximum suppression"""
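
amp.float_function is apex's hook for ops that only have fp32 kernels: once amp is active, half-precision tensor arguments are cast to float before the wrapped call, and it is a transparent passthrough otherwise. A hedged sketch of the same pattern around a hypothetical fp32-only op:

import torch
from apex import amp

def fp32_only_kernel(boxes, scores, iou_threshold):
    # Stand-in for a compiled op such as _C.nms (hypothetical body).
    return torch.arange(boxes.shape[0], device=boxes.device)

nms_like = amp.float_function(fp32_only_kernel)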
3 changes: 2 additions & 1 deletion maskrcnn_benchmark/layers/roi_align.py
@@ -7,6 +7,7 @@

from maskrcnn_benchmark import _C

from apex import amp

class _ROIAlign(Function):
@staticmethod
@@ -46,14 +47,14 @@ def backward(ctx, grad_output):

roi_align = _ROIAlign.apply


class ROIAlign(nn.Module):
def __init__(self, output_size, spatial_scale, sampling_ratio):
super(ROIAlign, self).__init__()
self.output_size = output_size
self.spatial_scale = spatial_scale
self.sampling_ratio = sampling_ratio

@amp.float_function
def forward(self, input, rois):
return roi_align(
input, rois, self.output_size, self.spatial_scale, self.sampling_ratio
2 changes: 2 additions & 0 deletions maskrcnn_benchmark/layers/roi_pool.py
@@ -7,6 +7,7 @@

from maskrcnn_benchmark import _C

from apex import amp

class _ROIPool(Function):
@staticmethod
@@ -52,6 +53,7 @@ def __init__(self, output_size, spatial_scale):
self.output_size = output_size
self.spatial_scale = spatial_scale

@amp.float_function
def forward(self, input, rois):
return roi_pool(input, rois, self.output_size, self.spatial_scale)

2 changes: 1 addition & 1 deletion maskrcnn_benchmark/modeling/poolers.py
@@ -116,7 +116,7 @@ def forward(self, x, boxes):
for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
idx_in_level = torch.nonzero(levels == level).squeeze(1)
rois_per_level = rois[idx_in_level]
result[idx_in_level] = pooler(per_level_feature, rois_per_level)
result[idx_in_level] = pooler(per_level_feature, rois_per_level).to(dtype)

return result

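
The trailing .to(dtype) matters because, under amp, the pooler forward is registered as a float function and returns fp32 even when the feature maps (and the preallocated result buffer) are fp16, and indexed assignment requires matching dtypes. A small hedged illustration with stand-in tensors:

import torch

result = torch.zeros(4, 3, dtype=torch.float16, device="cuda")  # fp16 buffer
fp32_rows = torch.randn(2, 3, device="cuda")                    # fp32 op output
idx = torch.tensor([0, 2], device="cuda")
# Without the cast, this indexed assignment raises a dtype-mismatch error.
result[idx] = fp32_rows.to(result.dtype)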
5 changes: 5 additions & 0 deletions maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py
@@ -111,11 +111,16 @@ def expand_masks(mask, padding):
pad2 = 2 * padding
scale = float(M + pad2) / M
padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2))

padded_mask[:, :, padding:-padding, padding:-padding] = mask
return padded_mask, scale


def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
# Need to work on the CPU, where fp16 isn't supported - cast to float to avoid this
mask = mask.float()
box = box.float()

padded_mask, scale = expand_masks(mask[None], padding=padding)
mask = padded_mask[0, 0]
box = expand_boxes(box[None], scale)[0]
10 changes: 10 additions & 0 deletions tools/test_net.py
@@ -17,6 +17,12 @@
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir

# Check if we can enable mixed-precision via apex.amp
try:
from apex import amp
except ImportError:
raise ImportError('Use APEX for mixed precision via apex.amp')


def main():
parser = argparse.ArgumentParser(description="PyTorch Object Detection Inference")
@@ -61,6 +67,10 @@ def main():
model = build_detection_model(cfg)
model.to(cfg.MODEL.DEVICE)

# Initialize mixed-precision if necessary
use_mixed_precision = cfg.DTYPE == 'float16'
amp_handle = amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE)

output_dir = cfg.OUTPUT_DIR
checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)
_ = checkpointer.load(cfg.MODEL.WEIGHT)
12 changes: 12 additions & 0 deletions tools/train_net.py
@@ -25,6 +25,13 @@
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir

# See if we can use apex.DistributedDataParallel instead of the torch default,
# and enable mixed-precision via apex.amp
try:
from apex import amp
except ImportError:
raise ImportError('Use APEX for multi-precision via apex.amp')


def train(cfg, local_rank, distributed):
model = build_detection_model(cfg)
@@ -34,6 +41,11 @@ def train(cfg, local_rank, distributed):
optimizer = make_optimizer(cfg, model)
scheduler = make_lr_scheduler(cfg, optimizer)

# Initialize mixed-precision training
use_mixed_precision = cfg.DTYPE == "float16"
amp_opt_level = 'O1' if use_mixed_precision else 'O0'
model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)

if distributed:
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[local_rank], output_device=local_rank,
