Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix device place error in fcn mask head #6374

Merged
merged 1 commit into from
Oct 26, 2021
Merged

Fix device place error in fcn mask head #6374

merged 1 commit into from
Oct 26, 2021

Conversation

st9007a
Copy link
Contributor

@st9007a st9007a commented Oct 26, 2021

Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

Motivation

The following error was raised when using mmdet.apis.inference_detector run HTC model on CUDA device :

  File "/usr/local/lib/python3.8/dist-packages/mmdet/apis/inference.py", line 148, in inference_detector
    results = model(return_loss=False, rescale=True, **data)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/mmcv/runner/fp16_utils.py", line 97, in new_func
    return old_func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/mmdet/models/detectors/base.py", line 174, in forward
    return self.forward_test(img, img_metas, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/mmdet/models/detectors/base.py", line 147, in forward_test
    return self.simple_test(imgs[0], img_metas[0], **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/mmdet/models/detectors/two_stage.py", line 182, in simple_test
    return self.roi_head.simple_test(
  File "/usr/local/lib/python3.8/dist-packages/mmdet/models/roi_heads/htc_roi_head.py", line 489, in simple_test
    segm_result = self.mask_head[-1].get_seg_masks(
  File "/usr/local/lib/python3.8/dist-packages/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py", line 252, in get_seg_masks
    bboxes = bboxes / scale_factor
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Go deeper in the code of mmdet/models/roi_heads/mask_heads/fcn_mask_head.py, I found that scale_factor is created by torch.Tensor in line 248 and compute with boxes in line 252. If the model is on a CUDA device, bboxes will be also put on the CUDA device. But scale_factor created by torch.Tensor is on the CPU device. So, a runtime error will be raised when running inference.

242         if not isinstance(scale_factor, torch.Tensor):
243             if isinstance(scale_factor, float):
244                 scale_factor = np.array([scale_factor] * 4)
245                 warn('Scale_factor should be a Tensor or ndarray '
246                      'with shape (4,), float would be deprecated. ')
247             assert isinstance(scale_factor, np.ndarray)
248             scale_factor = torch.Tensor(scale_factor)
249
250         if rescale:
251             img_h, img_w = ori_shape[:2]
252             bboxes = bboxes / scale_factor

Modification

Move scale_factor to the device of bboxes and also cast its dtype.

BC-breaking (Optional)

Does the modification introduce changes that break the backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness.
  3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMCls.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

Pre-commit hook triggered isort to fix import order in this commit.
@CLAassistant
Copy link

CLAassistant commented Oct 26, 2021

CLA assistant check
All committers have signed the CLA.

Copy link
Collaborator

@jshilong jshilong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ZwwWayne ZwwWayne merged commit 69704b2 into open-mmlab:master Oct 26, 2021
ZwwWayne pushed a commit to ZwwWayne/mmdetection that referenced this pull request Jul 19, 2022
…open-mmlab#6374)

Pre-commit hook triggered isort to fix import order in this commit.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants