diff --git a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py index 570c35ae21..691eaeecee 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py @@ -344,24 +344,25 @@ def forward(self, inp: torch.Tensor, roi: torch.Tensor, batch_indices: torch.Ten class NonMaxSuppression(torch.nn.Module): """ - Implementation od NMS Op in the form of nn.Module + Implementation of NMS Op in the form of nn.Module """ - @staticmethod - def forward(*args, **kwargs) -> torch.Tensor: + def __init__(self, iou_threshold: float, max_output_boxes_per_class: int): + super().__init__() + self.iou_threshold = iou_threshold + self.max_output_boxes_per_class = max_output_boxes_per_class + + def forward(self, *args) -> torch.Tensor: """ Forward-pass routine for NMS op """ batches_boxes = args[0] batch_scores = args[1] - iou_thershold = kwargs['iou_threshold'] - max_output_boxes_per_class = kwargs['max_output_boxes_per_class'] res = [] for index, (boxes, scores) in enumerate(zip(batches_boxes, batch_scores)): for class_index, classes_score in enumerate(scores): - res_ = torchvision.ops.nms(boxes, classes_score, iou_thershold) + res_ = torchvision.ops.nms(boxes, classes_score, self.iou_threshold) for val in res_: res.append([index, class_index, val.detach()]) - res = res[:(max_output_boxes_per_class*(index+1))] + res = res[:(self.max_output_boxes_per_class *(index+1))] return torch.Tensor(res).type(torch.int64) -