Reproducing MIST function and results #3

Closed
bradezard131 opened this issue Aug 6, 2020 · 18 comments

@bradezard131
Contributor

I'm struggling to reproduce the "MIST w/o Reg" results from Table 5. From my understanding, this should be the same network structure as OICR, but using the MIST algorithm (Algorithm 1) rather than the usual top-1 box. I have a working implementation of the OICR network structure; it achieves ~41% mAP with OICR and ~44% with PCL using the dilated VGG-16 backbone.

I have tried the original OICR and PCL hyperparameters (LR=1e-3, WD=5e-4, BS=2 or 4, 5 scales) as well as the new ones in the appendix (LR=1e-2, WD=1e-4, BS=8, 6 scales), and have been unable to exceed 25% mAP with p=15%. My implementation of the pseudo-labelling function is below:

```python
from math import ceil

import torch
from torchvision import ops


@torch.no_grad()
def mist_label(preds, rois, label, reg_preds=None, p=0.15, tau=0.2):
    preds = (preds if preds.shape[-1] == label.shape[-1] else preds[:, 1:]).clone()  # remove background class if present
    keep_count = int(ceil(p * preds.size(0)))
    all_overlaps = ops.box_iou(rois, rois)  # compute all overlaps once and index in, rather than recomputing each time
    klasses = label.nonzero(as_tuple=True)[0]
    gt_labels = -torch.ones((preds.size(0),), dtype=torch.long, device=preds.device)
    gt_weights = -torch.ones((preds.size(0),), dtype=preds.dtype, device=preds.device)

    for c in klasses:
        cls_prob_tmp = preds[:, c]
        sort_idx = cls_prob_tmp.argsort(descending=True)[:keep_count]  # top p percent of proposals

        # add the top-scoring proposal
        keep_idxs = [sort_idx[0].item()]

        # add the rest, if they overlap every already-kept box by less than tau
        for idx in sort_idx[1:]:
            if (all_overlaps[idx, keep_idxs] < tau).all():
                keep_idxs.append(idx.item())

        # add them to the GT set unless they're already selected by a class with a higher score
        keep_idxs = torch.tensor(keep_idxs, device=preds.device)  # keep indices on the same device as the mask
        is_higher_scoring_class = cls_prob_tmp[keep_idxs] > gt_weights[keep_idxs]
        keep_idxs = keep_idxs[is_higher_scoring_class]
        gt_labels[keep_idxs] = c + 1
        gt_weights[keep_idxs] = cls_prob_tmp[keep_idxs]

    kept = gt_labels > 0

    gt_boxes, gt_labels, gt_weights = rois[kept], gt_labels[kept], gt_weights[kept]

    # Adjust boxes with regression, if available (reg_preds assumed to be (N, C + 1, 4) deltas)
    if reg_preds is not None:
        box_tfms = reg_preds[kept, gt_labels]
        # the deltas act on centre/size, but rois are xyxy (as box_iou requires), so convert first
        ctr = ops.box_convert(gt_boxes, in_fmt="xyxy", out_fmt="cxcywh")
        ctr[:, :2] += ctr[:, 2:] * box_tfms[:, :2]  # G_x = P_x + P_w * t_x
        ctr[:, 2:] *= box_tfms[:, 2:].exp()  # G_w = P_w * e^{t_w}
        gt_boxes = ops.box_convert(ctr, in_fmt="cxcywh", out_fmt="xyxy")

    return gt_boxes, gt_labels, gt_weights
```
@jason718
Contributor

jason718 commented Aug 7, 2020

Thanks for the information.

It looks correct to me, although I haven't run this code line by line. It's strange that you can only reach 25% mAP on VOC2007(?). Some of my thoughts:

  1. If only the top-1 prediction is selected, our algorithm should perform the same as OICR. Have you verified that, just as a sanity check?
  2. The gt_boxes, gt_labels, and gt_weights generated here should then be used to choose FG/BG proposals, as done in Fast R-CNN and OICR/PCL (https://github.com/ppengtang/oicr/blob/0e77f63f1086cef5c836fa2c07300019546f79c5/lib/oicr_layer/layer.py#L224). Did you do that too?

@bradezard131
Contributor Author

Yes and yes. As I lower p, performance approaches OICR; as I increase it, performance drops.

Do you use the trick from their ECCV paper, ignoring boxes with IoU < 0.1? Also, can you confirm whether, for this particular experiment (VGG, no regression), you used the original hyperparameters or the new ones in the appendix? Any warmup or other tricks?

@jason718
Copy link
Contributor

We tried their trick, but it doesn't help our algorithm much.

We use the same parameters as in the appendix, with 200 warm-up iterations. I don't think warm-up is the cause, though, because it doesn't change the results much.

For other things:

  1. Just to double-check: as in OICR, we have 3 student (refinement) branches, and we use our algorithm to generate pseudo-labels from each branch for the next one. The results are averaged over the 3 student branches.
  2. At 25%, it seems like something is fundamentally wrong rather than a small detail. We have tested MIST w/o Reg multiple times at different development stages, and it has always been better than vanilla OICR.
  3. There are some things I'm curious about: have you tried visualizing the results? Can you spot any obvious problems? What does the per-class accuracy look like?
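Averaging over the 3 student branches at test time can be sketched as follows. This assumes each refinement head outputs C + 1 logits per proposal with background in column 0, as in OICR-style heads; `average_refinement_scores` is a hypothetical name.

```python
import torch


def average_refinement_scores(branch_logits):
    """Average softmax scores over the refinement (student) branches.

    branch_logits: list of (N, C + 1) logit tensors, column 0 = background (assumed layout).
    Returns (N, C) foreground scores per proposal.
    """
    probs = torch.stack([logits.softmax(dim=-1) for logits in branch_logits]).mean(dim=0)
    return probs[:, 1:]  # drop the background column
```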

FYI, here is the config we used (maskrcnn-benchmark format):

```yaml
DATALOADER:
  SIZE_DIVISIBILITY: 32
INPUT:
  MIN_SIZE_TRAIN: (480, 576, 688, 864, 1000, 1200)
  MAX_SIZE_TRAIN: 2000
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 2000
SOLVER:
  IMS_PER_BATCH: 8
  BASE_LR: 0.01
  WEIGHT_DECAY: 0.0001
  WARMUP_ITERS: 200
  STEPS: (20000, 26700)
  MAX_ITER: 30000
  CHECKPOINT_PERIOD: 5000
TEST:
  BBOX_AUG:
    ENABLED: True
    HEUR: "AVG"
    H_FLIP: True
    SCALES: (480, 576, 688, 864, 1000, 1200)
    MAX_SIZE: 2000
    SCALE_H_FLIP: True
```
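The SOLVER block above implies a linear warm-up followed by step decays. A minimal sketch of the resulting learning rate per iteration, assuming the maskrcnn-benchmark defaults that the config does not set (GAMMA=0.1, WARMUP_FACTOR=1/3, linear warm-up); `warmup_multistep_lr` is a hypothetical standalone function, not the library's scheduler class.

```python
from bisect import bisect_right


def warmup_multistep_lr(it, base_lr=0.01, steps=(20000, 26700), gamma=0.1,
                        warmup_iters=200, warmup_factor=1.0 / 3):
    """Learning rate at iteration `it` (gamma and warmup_factor are assumed defaults)."""
    factor = 1.0
    if it < warmup_iters:
        alpha = it / warmup_iters
        factor = warmup_factor * (1 - alpha) + alpha  # ramps linearly from warmup_factor to 1
    # multiply by gamma once for every step boundary already passed
    return base_lr * factor * gamma ** bisect_right(steps, it)
```

Under these assumptions the LR starts at 0.01/3, reaches 0.01 at iteration 200, and drops to 0.001 at 20000 and 0.0001 at 26700.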

@bradezard131
Contributor Author

  1. Yes, exactly the same structure as OICR, as confirmed by the fact that I can replicate OICR and PCL without issue. My codebase has an inheritance structure, so for PCL and your MIST model I just inherit from OICR and replace the build_loss function.
  2. Yes, I'm coming to that conclusion.
  3. No, nothing jumps out at me; it looks like what I'd get from a good run of WSDDN or another ~25% mAP network. Per-class performance seems fairly normal: classes like Car and Motorbike perform significantly better than average, while classes like Person and Chair do quite poorly.

I refactored, moving my sort outside the loop and using torchvision.ops.nms instead of the inner loop. Performance improved slightly, to about 28%.

One thing I thought of is that the formulation in your paper doesn't guarantee that pseudo-labels are kept for every class in the image labels. Take the first image in the VOC07 test set (the man with the dog in his jacket), which has the labels Person and Dog. Every region selected for Dog could also be selected for Person. Your paper says that in this case you use the highest-scoring class, so all the Dog boxes could end up labelled Person, since that class tends to be confident (very often over 95%, even though it is a low-accuracy class). You then have no boxes labelled Dog, even though you know there is a dog in the image. I tried forcing the top box for each class (the one OICR would pick) to always be kept, even if another class wanted it as a secondary box, and this seems to have helped, although I am still far behind the reported result. Do you do this, or do you have some other solution to this problem?
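One way to implement the workaround described above (protect each class's top-1 box, and resolve other conflicts in favour of the higher-scoring class) is sketched below. `resolve_conflicts` is hypothetical; the per-class candidate lists are assumed to come from top-p selection plus NMS, sorted by score. If two classes share the same top-1 box, the higher-scoring class keeps it.

```python
import torch


def resolve_conflicts(selected, scores):
    """Assign each selected proposal to a single class.

    selected: dict {class_idx: 1-D LongTensor of proposal indices, highest score first}
    scores: (N, C) class scores for the image's proposals.
    Returns dict {proposal_idx: class_idx}.
    """
    # Each class's top-1 proposal is protected (the box OICR would pick).
    protected = {}
    for c, idxs in selected.items():
        top = int(idxs[0])
        if top not in protected or scores[top, c] > scores[top, protected[top]]:
            protected[top] = c

    assignment = {}
    for c, idxs in selected.items():
        for i in idxs.tolist():
            if i in protected:
                assignment[i] = protected[i]  # top-1 boxes always keep their own class
            elif i not in assignment or scores[i, c] > scores[i, assignment[i]]:
                assignment[i] = c  # otherwise the higher-scoring class wins
    return assignment
```

In the Person/Dog case, Dog's top-1 box stays Dog even though Person scores higher on it, so every image-level class keeps at least one pseudo-label.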

Also, can you confirm that MIN_SIZE_TEST is ignored when augmentation is enabled? I'm fairly sure it is, but I want to be certain.

@bradezard131
Contributor Author

Could you confirm whether all layers use the same learning rate, and which layers are frozen?

I noticed some variation when playing with the learning rate. OICR/PCL typically use 10x LR for the refinement heads. I got better (although still not good) results for MIST using 10x LR for the WSDDN head too, and better still with the backbone at 0.001 and the heads at 0.01.
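The split described above (backbone at 0.001, heads at 0.01) maps onto PyTorch optimizer parameter groups. A minimal sketch; `backbone` and `heads` are hypothetical stand-in modules rather than the real VGG-16/OICR layers, and momentum and weight decay follow values mentioned elsewhere in this thread.

```python
import torch
from torch import nn

# Hypothetical stand-ins for the VGG-16 backbone and the WSDDN/refinement heads.
backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
heads = nn.ModuleDict({
    "midn": nn.Linear(8, 20),
    "ref0": nn.Linear(8, 21),
})

optimizer = torch.optim.SGD(
    [
        {"params": backbone.parameters(), "lr": 1e-3},  # backbone at 0.001
        {"params": heads.parameters()},                 # heads fall back to the default lr below
    ],
    lr=1e-2, momentum=0.9, weight_decay=1e-4,
)
```

Each group keeps its own `lr`, so a scheduler that multiplies all groups by the same factor preserves the 10x ratio.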

@jason718
Contributor

jason718 commented Aug 17, 2020

Thanks so much for the feedback! Here are some of my thoughts; let me know if I've missed or misunderstood anything.

  1. You raised a good point about the "conflict" case, and I think I might know the reason. The code attached here deals with conflicts at the pseudo-label level (is_higher_scoring_class in this function), but we actually deal with them at the FG/BG level.
    To make it clearer: our algorithm first generates pseudo-GT using MIST, and then generates FG/BG. We don't reject conflicting/redundant pseudo-GT in the first stage (as done here through is_higher_scoring_class). As a result, the same box can appear in the pseudo-GT as both dog and person. We then generate FG/BG boxes where (1) if a proposal overlaps w/ a pseudo-GT (e.g., dog), it's a FG box for that class (dog); (2) since one proposal can only have one class label, when a proposal overlaps w/ multiple pseudo-GT, we use the one with the higher IoU. If a proposal overlaps w/ multiple pseudo-GT with the same IoU, we use the one w/ higher score.
    It's hard to say how big a difference this will make. It's still very strange to me that this implementation can't outperform OICR. I'll also try your implementation in my codebase to see what the real problem is.

  2. Yes, MIN_SIZE_TEST isn't used in multi-scale testing.

  3. Not all layers are fine-tuned. For VGG-16, the first two stages are frozen (same as: https://github.com/ppengtang/oicr/blob/cd136fa4c64cadb19bc3bdf5ba4238b8a55166ca/models/VGG16/train.prototxt#L20).

  4. I actually used the same LR (i.e., 0.01) for all layers, except that I made the loss for the first refinement head 3x larger than the others, to balance the losses. More specifically, the loss weights computed from the WSDDN head using OICR/MIST are usually smaller than the weights from a refinement head (due to the double softmax and element-wise multiplication). We found this slightly improves results.

@bradezard131
Contributor Author

Hey there, a couple more questions in this unending struggle:

  1. Do you use dropout in the FC layers when training? I noticed that although the Caffe implementation of OICR/PCL does, the PyTorch implementation does not. I was wondering whether you used it as a reference when writing your models, and this quirk might have carried over.
  2. How do you get the OICR loss weight for BG classes? When OICR computes its weights, every proposal is assigned to the pseudo-GT box it overlaps most and takes that box's weight. If a proposal overlaps no pseudo-GT box, it defaults to the first box. Does your reassignment for equal overlaps mean that you reassign these background proposals too?
  3. A long shot, but is there an ETA for the reference code? Now that ECCV and your second paper are behind us, it would be really appreciated.
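The default described in question 2 comes from `argmax` returning the first maximal index, so an all-zero IoU row resolves to pseudo-GT 0. A tiny NumPy illustration with made-up numbers, assuming the weight assignment is a row-wise `argmax` over the proposal-by-pseudo-GT IoU matrix as the question describes:

```python
import numpy as np

# Made-up IoU matrix: 3 proposals x 2 pseudo-GT boxes.
overlaps = np.array([
    [0.7, 0.1],   # overlaps pseudo-GT 0 most
    [0.2, 0.6],   # overlaps pseudo-GT 1 most
    [0.0, 0.0],   # overlaps nothing
])
gt_weights = np.array([0.9, 0.4])  # pseudo-GT scores used as loss weights

# np.argmax returns the first occurrence on ties, so the all-zero row maps to pseudo-GT 0.
gt_assignment = overlaps.argmax(axis=1)
roi_weights = gt_weights[gt_assignment]  # the non-overlapping proposal inherits weight 0.9
```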

@jason718
Contributor

jason718 commented Sep 1, 2020

  1. Yes, I did use dropout in the FC layers. There are actually lots of small details in the OICR/PCL model (e.g., dilation, remove the last max-pooling, remove the last relu...); I followed exactly the same architecture as their Caffe implementation.
  2. Loss weights for BG: we follow the OICR practice here (https://github.com/ppengtang/oicr/blob/cd136fa4c64cadb19bc3bdf5ba4238b8a55166ca/lib/oicr_layer/layer.py#L120). We didn't do anything special to reassign BG RoIs.

@bradezard131
Contributor Author

bradezard131 commented Sep 2, 2020

> remove the last relu

Which layer has a ReLU removed? The only layers I see without a ReLU are the output layers, which are followed by softmax.

@bradezard131
Contributor Author

Are there any updates on an ETA for the code release? So far this is proving very difficult to reproduce.

The pseudo-labelling function, as you've now described it, appears to boil down to "apply NMS to the top 15% of boxes for each positive class". The label assignment strategy is the same as OICR's, except when multiple pseudo-GT boxes overlap a given proposal by the same amount; then you assign it to the box with the highest predicted class confidence (essentially choosing the "most correct" assignment). I am still unable to exceed OICR's performance in my replication.

@bradezard131
Contributor Author

I'm still struggling to reproduce anything close to your results, and I'm sure some details must still be missing from the paper. Below is the code for the key functions as you've described them. The network structure is identical to OICR as defined in the original Caffe implementation. The dataset is VOC at 6 scales (the 5 scales that are standard practice in WSOD, plus min_side=1000) with their horizontal flips. I have tried both including and excluding the "difficult" examples in the training set (these are typically excluded, but I saw that maskrcnn-benchmark includes them). The optimiser is SGD with the parameters in your config, plus momentum=0.9 and a warmup factor of 1/3. I've tried both with and without the modified bias parameters (no weight decay and 2x LR for biases). At test time, the scores for all scales and flips are averaged for each image, NMS is applied (I've tried both 0.3, as is standard for OICR-based models, and 0.5, the maskrcnn-benchmark default), and the top 100 predictions per image are retained (the maskrcnn-benchmark default, although disabling it makes minimal difference).

Overall, my experiments have shown strictly worse results for increasing values of p. As p approaches top-1, mAP is around 43%, and it falls to about 35% as p approaches 15%.

```python
import torch
import torch.nn.functional as F
from torchvision import ops

EPS = 1e-6  # numerical floor; this value is an assumption, EPS was not defined in the original snippet


def build_loss(preds, image_labels, rois):
    # WSDDN (MIDN) image-level loss: sum proposal scores into per-class image scores
    midn_loss = F.binary_cross_entropy(preds['midn'].sum(0).clamp(EPS, 1 - EPS), image_labels)

    # MIDN -> ref0. Scaling by 3 leaves the ranking (and NMS) unchanged and only triples
    # the pseudo-label weights, i.e. the first refinement loss is 3x larger.
    gt_boxes, gt_labels, gt_weights = mist_label(preds['midn'] * 3., rois, image_labels)
    pseudo_labels, weights, tfm_tgt = mist_sample_rois(preds['ref0'].softmax(-1), rois, gt_boxes, gt_labels, gt_weights)
    ref0_loss = weighted_softmax_with_loss(preds['ref0'], pseudo_labels, weights)

    # ref0 -> ref1
    gt_boxes, gt_labels, gt_weights = mist_label(preds['ref0'].softmax(-1), rois, image_labels)
    pseudo_labels, weights, tfm_tgt = mist_sample_rois(preds['ref1'].softmax(-1), rois, gt_boxes, gt_labels, gt_weights)
    ref1_loss = weighted_softmax_with_loss(preds['ref1'], pseudo_labels, weights)

    # ref1 -> ref2
    gt_boxes, gt_labels, gt_weights = mist_label(preds['ref1'].softmax(-1), rois, image_labels)
    pseudo_labels, weights, tfm_tgt = mist_sample_rois(preds['ref2'].softmax(-1), rois, gt_boxes, gt_labels, gt_weights)
    ref2_loss = weighted_softmax_with_loss(preds['ref2'], pseudo_labels, weights)

    total_loss = midn_loss + ref0_loss + ref1_loss + ref2_loss
    return total_loss


@torch.no_grad()
def mist_label(preds, rois, label, p=0.15, tau=0.2):
    preds = (preds if preds.shape[-1] == label.shape[-1] else preds[..., 1:]).clone()  # remove background class if present
    keep_count = max(1, int(p * preds.size(0)))  # always keep at least the top-1 box
    klasses = label.nonzero(as_tuple=True)[0]

    gt_boxes, gt_scores, gt_labels = [], [], []
    for klass in klasses:
        c_scores = preds[..., klass]
        sort_idxs = c_scores.argsort(dim=0, descending=True)[:keep_count]  # top p percent of proposals
        boxes = rois[sort_idxs]
        c_scores = c_scores[sort_idxs]
        keep_idxs = ops.nms(boxes, c_scores, tau)  # drop boxes with IoU > tau against a higher-scoring box
        gt_boxes.append(boxes[keep_idxs])
        gt_scores.append(c_scores[keep_idxs])
        gt_labels.append(torch.full_like(keep_idxs, int(klass) + 1))  # 1-indexed labels, 0 = background
    gt_boxes = torch.cat(gt_boxes, 0)
    gt_labels = torch.cat(gt_labels, 0)
    gt_weights = torch.cat(gt_scores, 0)

    return gt_boxes, gt_labels, gt_weights


@torch.no_grad()
def mist_sample_rois(preds, rois, gt_boxes, gt_labels, gt_weights, bg_threshold=0.5):
    overlaps = ops.box_iou(rois, gt_boxes)
    max_overlaps, gt_assignment = overlaps.max(dim=1)

    # Compute assignment: among pseudo-GTs tied for the max IoU, pick the one the student scores highest
    bg = max_overlaps < bg_threshold
    fg = ~bg
    maximally_overlapping = (max_overlaps[fg].unsqueeze(1) - overlaps[fg]) < EPS
    gt_assignment[fg] = (maximally_overlapping * preds[fg][:, gt_labels]).argmax(1)

    # Construct labels
    labels = gt_labels.gather(0, gt_assignment)
    labels[bg] = 0

    # Construct weights
    weights = gt_weights.gather(0, gt_assignment)

    # Calculate the regression target on centre/size (replaces the undefined to_xywh helper)
    G = ops.box_convert(gt_boxes, in_fmt="xyxy", out_fmt="cxcywh")[gt_assignment]
    P = ops.box_convert(rois, in_fmt="xyxy", out_fmt="cxcywh")
    T = torch.empty_like(rois)
    T[:, :2] = (G[:, :2] - P[:, :2]) / P[:, 2:]
    T[:, 2:] = (G[:, 2:] / P[:, 2:]).log()

    return labels, weights, T


def weighted_softmax_with_loss(score: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # per-proposal cross-entropy, scaled by each proposal's pseudo-label weight
    loss = -weights * F.log_softmax(score, dim=-1).gather(-1, labels.long().unsqueeze(-1)).squeeze(-1)
    valid_sum = weights.gt(1e-12).float().sum()  # number of proposals with a non-zero weight
    if valid_sum < EPS:
        return loss.sum() / loss.numel()
    else:
        return loss.sum() / valid_sum
```

@jason718
Contributor

  1. We removed the ReLU before RoI pooling, probably because it gave a small performance boost. However, I tested it again and it doesn't make much difference.
  2. Did you insert this code into this repo (wsod)? Please let me know how I can reproduce these issues exactly. This code seems to use maskrcnn-benchmark functions, so if I understand correctly, I can insert it into maskrcnn-benchmark, right?

The results are very strange. My OICR re-implementation got similar numbers to yours, but MIST never hurt. I'm wondering what tricky detail is breaking everything.

@bradezard131
Contributor Author

No, I have an updated repo. I will refactor my code over the weekend and try to put it up by Monday. I have claimed your paper in the PapersWithCode Reproducibility Challenge 2020, so I will need to produce a public implementation anyway. I hope to reproduce your results and also get some practice writing a paper-like document along the way.

@bradezard131
Contributor Author

@jason718
This took a little longer than anticipated, but here's an extraction of the relevant code from my messy working repo. I struggle to get access to more than one GPU (a Titan RTX), so by default it runs a full batch on one GPU, but I have also tried accumulating batches (i.e., 8 calls to .backward() before calling optimiser.step()) with no better luck. I also tried 4 GPUs with 2 images each and got similar results, so I'd be somewhat surprised if the issue were there.

https://github.com/bradezard131/repro
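The batch-accumulation setup mentioned above is equivalent to a true batch of 8 only if each micro-batch loss is scaled so that the summed gradients match the batch mean. A minimal check with a toy `nn.Linear` model (a stand-in, not the detector), assuming a per-sample mean loss and no batch-dependent layers such as BatchNorm:

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1)  # toy stand-in for the detector
images, targets = torch.randn(8, 4), torch.randn(8, 1)
loss_fn = nn.MSELoss()  # mean over the batch

# Full batch of 8 in one backward pass.
model.zero_grad()
loss_fn(model(images), targets).backward()
full_grad = model.weight.grad.clone()

# Same data as 8 micro-batches of 1: .backward() adds into .grad, so scale each loss by 1/8.
model.zero_grad()
for x, y in zip(images.split(1), targets.split(1)):
    (loss_fn(model(x), y) / 8).backward()
accum_grad = model.weight.grad.clone()
# optimizer.step() would then be called once, after the loop.
```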

@bradezard131
Contributor Author

I have now also tried 8 GPUs with batch size 1 on each, and the results are the same as with accumulation or a true batch size of 8. I also noticed that your paper doesn't mention initialisation, so I tried PyTorch's default init instead of the normal init used in previous works, which got me up to ~35.5 mAP.
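The two initialisation choices compared above can be sketched for a linear head as follows. The std of 0.01 for the "normal init" is an assumption for illustration (a common choice for detection heads), not a value stated in this thread; `init_head` is a hypothetical helper.

```python
import torch
from torch import nn


def init_head(layer: nn.Linear, scheme: str = "normal", std: float = 0.01) -> nn.Linear:
    """Initialise a linear head with N(0, std^2) weights and zero bias, or with PyTorch's default."""
    if scheme == "normal":
        nn.init.normal_(layer.weight, mean=0.0, std=std)  # std=0.01 is an assumed value
        nn.init.zeros_(layer.bias)
    elif scheme == "default":
        layer.reset_parameters()  # PyTorch's built-in nn.Linear init (uniform, scaled by fan-in)
    else:
        raise ValueError(f"unknown scheme: {scheme}")
    return layer
```

PyTorch's default bounds each weight by 1/sqrt(fan_in), about 0.044 for a 512-input layer, so it starts the heads noticeably less concentrated around zero than N(0, 0.01^2).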

@bradezard131
Contributor Author

bradezard131 commented Oct 5, 2020

> If a proposal overlaps w/ multiple pseudo-GT with the same IoU, we use the one w/ higher score.

A question about this part: do you use the teacher's or the student's score as the tiebreaker at this stage?

@jason718
Contributor

jason718 commented Oct 5, 2020

Replied by email; please check.
