Skip to content

Commit

Permalink
Removed kwargs and added them in constructor
Browse files Browse the repository at this point in the history
Signed-off-by: Rishabh Thakur <quic_ristha@quicinc.com>
  • Loading branch information
quic-ristha committed Jul 3, 2023
1 parent 49ddc44 commit f2956e4
Showing 1 changed file with 9 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -344,24 +344,25 @@ def forward(self, inp: torch.Tensor, roi: torch.Tensor, batch_indices: torch.Ten

class NonMaxSuppression(torch.nn.Module):
"""
Implementation od NMS Op in the form of nn.Module
Implementation of NMS Op in the form of nn.Module
"""
@staticmethod
def forward(*args, **kwargs) -> torch.Tensor:
def __init__(self, iou_threshold: float, max_output_boxes_per_class: int):
super().__init__()
self.iou_threshold = iou_threshold
self.max_output_boxes_per_class = max_output_boxes_per_class

def forward(self, *args) -> torch.Tensor:
"""
Forward-pass routine for NMS op
"""
batches_boxes = args[0]
batch_scores = args[1]
iou_thershold = kwargs['iou_threshold']
max_output_boxes_per_class = kwargs['max_output_boxes_per_class']

res = []
for index, (boxes, scores) in enumerate(zip(batches_boxes, batch_scores)):
for class_index, classes_score in enumerate(scores):
res_ = torchvision.ops.nms(boxes, classes_score, iou_thershold)
res_ = torchvision.ops.nms(boxes, classes_score, self.iou_threshold)
for val in res_:
res.append([index, class_index, val.detach()])
res = res[:(max_output_boxes_per_class*(index+1))]
res = res[:(self.max_output_boxes_per_class *(index+1))]
return torch.Tensor(res).type(torch.int64)

0 comments on commit f2956e4

Please sign in to comment.