Extraction of Region Proposals is not deterministic #2480

Closed
varunnrao opened this issue Jan 12, 2021 · 2 comments
Labels
upstream issues Issues in other libraries

Comments

@varunnrao

varunnrao commented Jan 12, 2021

Instructions To Reproduce the Issue:

  1. Full runnable code or full changes you made:
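import cv2
import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Rectangle

import detectron2.data.transforms as T
from detectron2.config import get_cfg
from detectron2.modeling import build_model
from detectron2.structures import ImageList

plt.rcParams['figure.figsize'] = [12, 8]
plt.rcParams['figure.dpi'] = 100
CONFIG_PATH = "config.yaml"
MODEL_PATH = "model_final.pth"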


cfg = get_cfg()
cfg.merge_from_file(CONFIG_PATH)
cfg.MODEL.WEIGHTS = MODEL_PATH
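# NOTE: build_model() only builds the architecture with randomly initialized
# parameters; setting cfg.MODEL.WEIGHTS alone does not load the checkpoint
# (detectron2 loads it with DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)).
model = build_model(cfg)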
model.eval()
transform_gen = T.ResizeShortestEdge(
    [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
)
input_format = cfg.INPUT.FORMAT
device = torch.device(cfg.MODEL.DEVICE)
num_channels = len(cfg.MODEL.PIXEL_MEAN)
pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(num_channels, 1, 1)
pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).view(num_channels, 1, 1)
normalizer = lambda x: (x - pixel_mean) / pixel_std

img = cv2.imread('image0008.jpg')
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))


bbox_list = []
logit_list = []
if input_format == "RGB":
    # cv2.imread returns BGR; flip the channels if the model expects RGB input
    img = img[:, :, ::-1]
input_height, input_width = img.shape[:2]
img_transform = transform_gen.get_transform(img).apply_image(img)
img_tensor = torch.as_tensor(img_transform.astype("float32").transpose(2, 0, 1))
img_tensor = normalizer(img_tensor)
img_gpu = img_tensor.to(device)

# ImageList pads the image to a multiple of size_divisibility; it can also batch a list of images
img_list = ImageList.from_tensors([img_gpu], model.backbone.size_divisibility)

features = model.backbone(img_list.tensor)

# See https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
# for the output format of the proposal generator.
# It returns (proposals, losses); proposals is a list with one Instances per image.
proposals, _ = model.proposal_generator(img_list, features)
proposal_boxes = proposals[0].proposal_boxes
# Rescale boxes from the resized (unpadded) input back to the original image size
scale_x, scale_y = (input_width / proposals[0].image_size[1], input_height / proposals[0].image_size[0])
proposal_boxes.scale(scale_x, scale_y)

# Proposal boxes are in XYXY (x0, y0, x1, y1) format, even though dataset annotations may use XYXY or XYWH
objectness_logits = proposals[0].objectness_logits

for bbox, logit in zip(proposal_boxes, objectness_logits):
    bbox, logit = bbox.cpu().numpy(), logit.cpu().numpy().item()
    # Convert XYXY to integer (x, y, w, h) with inclusive width/height (+1)
    x, y, w, h = int(bbox[0]), int(bbox[1]), int(bbox[2]) - int(bbox[0]) + 1, int(bbox[3]) - int(bbox[1]) + 1
    plt.gca().add_patch(Rectangle((x, y), w, h, linewidth=1, edgecolor='r', facecolor='none'))
    bbox_list.append([x, y, w, h])
    logit_list.append(logit)
plt.show()
  2. What exact command you ran:
    I ran the script above.
  3. Full logs you observed:
    The output lists bbox_list and logit_list differ for the same image each time I run the script.
    Furthermore, the box coordinates obtained from full model inference (model(inputs), rather than just the proposal generator) or from DefaultPredictor do not appear to match those obtained at the RPN stage.
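The script's coordinate handling (ResizeShortestEdge before the backbone, then rescaling proposals by the original size over proposals[0].image_size) can be checked in isolation. The stdlib-only sketch below assumes ResizeShortestEdge scales the shorter side to min_size, caps the longer side at max_size and rounds to the nearest pixel; that is my reading of its behaviour, not detectron2's actual code, and resized_shape / proposal_to_original_xywh are hypothetical names. The scale factors use the resized size rather than the padded ImageList tensor size, because image_size on the proposal Instances is the size before padding.

```python
def resized_shape(height, width, min_size, max_size):
    """Hypothetical helper: (new_h, new_w) after resizing the shorter edge to
    min_size and capping the longer edge at max_size. Assumed to mirror
    ResizeShortestEdge's arithmetic; a sketch, not detectron2's code."""
    scale = min_size / min(height, width)
    if height < width:
        new_h, new_w = min_size, scale * width
    else:
        new_h, new_w = scale * height, min_size
    if max(new_h, new_w) > max_size:
        cap = max_size / max(new_h, new_w)
        new_h, new_w = new_h * cap, new_w * cap
    # Round to the nearest integer pixel size
    return int(new_h + 0.5), int(new_w + 0.5)


def proposal_to_original_xywh(box_xyxy, resized_hw, original_hw):
    """Hypothetical helper: map an XYXY box from resized-image coordinates to
    an integer (x, y, w, h) box in original-image coordinates, using the same
    inclusive '+1' width/height convention as the script."""
    scale_x = original_hw[1] / resized_hw[1]
    scale_y = original_hw[0] / resized_hw[0]
    x0, y0 = box_xyxy[0] * scale_x, box_xyxy[1] * scale_y
    x1, y1 = box_xyxy[2] * scale_x, box_xyxy[3] * scale_y
    x, y = int(x0), int(y0)
    return x, y, int(x1) - x + 1, int(y1) - y + 1
```

For example, with a 600x800 image and min_size=800, max_size=1333 (illustrative values, not taken from the issue's config), resized_shape(600, 800, 800, 1333) returns (800, 1067), and a resized-space box (100, 200, 300, 400) maps back to (74, 150, 151, 151).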

Expected behavior:

I would expect the box coordinates to stay the same across runs when the image, config, and model remain unchanged. However, every run produces different output.
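To make "every run differs" measurable, one option is to save bbox_list and logit_list from two runs (for example with json.dump) and compare them. The stdlib-only sketch below does that position by position; compare_runs, RunDiff and identical are hypothetical names, and the positional comparison assumes both runs list proposals in the same order (zip stops at the shorter run, and a length mismatch is reported separately).

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class RunDiff:
    same_length: bool      # both runs returned the same number of proposals
    max_box_diff: float    # largest absolute coordinate difference, position by position
    max_logit_diff: float  # largest absolute objectness-logit difference

    def identical(self, tol: float = 0.0) -> bool:
        """True if the two runs match within an absolute tolerance."""
        return self.same_length and self.max_box_diff <= tol and self.max_logit_diff <= tol


def compare_runs(boxes_a: Sequence[Sequence[float]], logits_a: Sequence[float],
                 boxes_b: Sequence[Sequence[float]], logits_b: Sequence[float]) -> RunDiff:
    """Hypothetical helper: compare two runs' bbox_list/logit_list entry by entry."""
    max_box = 0.0
    for box_a, box_b in zip(boxes_a, boxes_b):
        for coord_a, coord_b in zip(box_a, box_b):
            max_box = max(max_box, abs(coord_a - coord_b))
    max_logit = max((abs(a - b) for a, b in zip(logits_a, logits_b)), default=0.0)
    return RunDiff(
        same_length=len(boxes_a) == len(boxes_b) and len(logits_a) == len(logits_b),
        max_box_diff=float(max_box),
        max_logit_diff=float(max_logit),
    )
```

Small differences (on the order of floating-point noise) point toward numeric non-determinism, whereas large differences in both boxes and logits suggest the model itself differs between runs.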

  • Are there specific pre/post-processing steps missing?
  • Is the output of the RPN stage expected to be non-deterministic?
  • My goal is to use the outputs of the RPN (image crops) to build another classifier, so I need them to be deterministic. Is this possible?
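For the stated goal of turning RPN proposals into crops for a separate classifier, the crops typically need to be limited to the highest-scoring proposals and clipped to the image bounds. The stdlib-only sketch below works on the script's bbox_list/logit_list format ([x, y, w, h] with inclusive width/height); select_crops and top_k are hypothetical names, and returning (y0, y1, x0, x1) bounds assumes the crops are then taken by array slicing such as img[y0:y1, x0:x1].

```python
from typing import List, Sequence, Tuple


def select_crops(bbox_list: Sequence[Sequence[int]], logit_list: Sequence[float],
                 image_height: int, image_width: int,
                 top_k: int = 10) -> List[Tuple[int, int, int, int]]:
    """Hypothetical helper: return up to top_k (y0, y1, x0, x1) slice bounds,
    highest objectness logit first, clipped to the image; boxes that become
    empty after clipping are skipped."""
    order = sorted(range(len(logit_list)), key=lambda i: logit_list[i], reverse=True)
    crops = []
    for i in order:
        x, y, w, h = bbox_list[i]
        # Width/height are inclusive (+1), so the exclusive end is x + w / y + h
        x0, y0 = max(0, x), max(0, y)
        x1, y1 = min(image_width, x + w), min(image_height, y + h)
        if x1 > x0 and y1 > y0:
            crops.append((y0, y1, x0, x1))
        if len(crops) == top_k:
            break
    return crops
```

Sorting by logit before cropping keeps the selection stable only if the logits themselves are reproducible, which is why the determinism question above matters for this use case.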

Environment:

Provide your environment information using the following command:

wget -nc -q https://github.com/facebookresearch/detectron2/raw/master/detectron2/utils/collect_env.py && python collect_env.py
wget: /home/ubuntu/anaconda3/lib/libuuid.so.1: no version information available (required by wget)
------------------------  ----------------------------------------------------------------------------------
sys.platform              linux
Python                    3.6.12 |Anaconda, Inc.| (default, Sep  8 2020, 23:10:56) [GCC 7.3.0]
numpy                     1.19.2
detectron2                0.1.1 
detectron2 compiler       GCC 5.4
detectron2 CUDA compiler  10.0
detectron2 arch flags     sm_37
DETECTRON2_ENV_MODULE     <not set>
PyTorch                   1.4.0 
PyTorch debug build       False
CUDA available            True
GPU 0,1,2,3,4,5,6,7       Tesla K80
CUDA_HOME                 /usr/local/cuda-10.0
NVCC                      Cuda compilation tools, release 10.0, V10.0.130
Pillow                    8.0.1
torchvision               0.5.0 
torchvision arch flags    sm_35, sm_50, sm_60, sm_70, sm_75
cv2                       3.4.2
------------------------  --------
PyTorch built with:
  - GCC 7.3
  - Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v0.21.1 (Git Hash 7d2fd500bc78936d1d648ca713b901012f470dbc)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CUDA Runtime 10.0
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37
  - CuDNN 7.6.3
  - Magma 2.5.1
  - Build settings: BLAS=MKL, BUILD_NAMEDTENSOR=OFF, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -fopenmp -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -O2 -fPIC -Wno-narrowing -Wall -Wextra -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Wno-stringop-overflow, DISABLE_NUMA=1, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_STATIC_DISPATCH=OFF,



@ppwwyyxx
Contributor

PyTorch is not deterministic: https://pytorch.org/docs/stable/notes/randomness.html. Therefore there is probably not much we can do, unless there is evidence that detectron2 introduces more non-determinism than PyTorch itself.

@ppwwyyxx ppwwyyxx added the upstream issues Issues in other libraries label Jan 13, 2021
@varunnrao
Author

varunnrao commented Jan 13, 2021

I don't think PyTorch non-determinism is causing the issue here. I am using a pretrained detectron2 Cascade Mask R-CNN model for inference.

It doesn't seem right that the outputs differ at the RPN stage while the final outputs after the RoI heads are reproducible.

Hence I suspect that some non-determinism is introduced by detectron2. Could you please take another look at the code, in case there are obvious mistakes in the pre/post-processing or model loading?

Could you please reopen this issue? It doesn't seem resolved. Thanks.
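One way to check the reported mismatch between RPN proposals and the final detections is to compute, for each final box, its best IoU against all proposals. Note that the RoI box head predicts regression deltas relative to the proposals (and a cascade model refines them over several stages), so final boxes are not expected to coincide exactly with any proposal. Both sets must be in the same coordinate frame: the script rescales proposals to the original image, and DefaultPredictor's output boxes are already in original-image coordinates. The stdlib-only sketch below assumes XYXY boxes; iou and best_match_ious are hypothetical names.

```python
from typing import List, Sequence

Box = Sequence[float]  # (x0, y0, x1, y1), continuous coordinates (no '+1')


def iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two XYXY boxes; 0.0 if the union is empty."""
    ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
    ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix1 - ix0) * max(0.0, iy1 - iy0)
    area_a = max(0.0, a[2] - a[0]) * max(0.0, a[3] - a[1])
    area_b = max(0.0, b[2] - b[0]) * max(0.0, b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def best_match_ious(final_boxes: Sequence[Box], proposals: Sequence[Box]) -> List[float]:
    """Hypothetical helper: for each final detection, the highest IoU achieved
    by any RPN proposal (0.0 when there are no proposals)."""
    return [max((iou(f, p) for p in proposals), default=0.0) for f in final_boxes]
```

High best-match IoUs would indicate that the final detections do originate from nearby proposals, even if the coordinates are not identical.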

@github-actions github-actions bot locked as resolved and limited conversation to collaborators Nov 10, 2021