Skip to content

Commit

Permalink
fix: prevent a buggy optimization in traced fasterrcnn
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Aug 16, 2022
1 parent 90d536e commit dab88ca
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 23 deletions.
46 changes: 23 additions & 23 deletions tests/ut_python/ut_tools_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,32 @@ def test_resnet50_export(self):
onnx_file = os.path.join(_temp_dir, "resnet50.onnx")
self.assertTrue(os.path.exists(onnx_file), onnx_file)

# def test_fasterrcnn_export(self):
# # Export model (not pretrained because we don't have permission for the cache)
# subprocess.run(["python3", "trace_torchvision.py", "-vp", "fasterrcnn_resnet50_fpn", "-o", _temp_dir])
# model_file = os.path.join(_temp_dir, "fasterrcnn_resnet50_fpn-cls91.pt")
# self.assertTrue(os.path.exists(model_file), model_file)
def test_fasterrcnn_export(self):
# Export model (not pretrained because we don't have permission for the cache)
subprocess.run(["python3", "trace_torchvision.py", "-vp", "fasterrcnn_resnet50_fpn", "-o", _temp_dir])
model_file = os.path.join(_temp_dir, "fasterrcnn_resnet50_fpn-cls91.pt")
self.assertTrue(os.path.exists(model_file), model_file)

# # Test inference
# rfcnn = torch.jit.load(model_file)
# rfcnn.train()
# model_losses, model_preds = rfcnn(*get_detection_input())
# self.assertTrue("total_loss" in model_losses)
# self.assertTrue(model_losses["total_loss"] > 0)
# self.assertAlmostEqual(
# model_losses["total_loss"].item(),
# sum([model_losses[l].item() for l in model_losses if l != "total_loss"]),
# delta = 0.0001
# )
# Test inference
rfcnn = torch.jit.load(model_file)
rfcnn.train()
model_losses, model_preds = rfcnn(*get_detection_input())
self.assertTrue("total_loss" in model_losses)
self.assertTrue(model_losses["total_loss"] > 0)
self.assertAlmostEqual(
model_losses["total_loss"].item(),
sum([model_losses[l].item() for l in model_losses if l != "total_loss"]),
delta = 0.0001
)

# rfcnn.eval()
# model_losses, model_preds = rfcnn(torch.rand(1, 3, 224, 224))
# self.assertTrue("boxes" in model_preds[0])
rfcnn.eval()
model_losses, model_preds = rfcnn(torch.rand(1, 3, 224, 224))
self.assertTrue("boxes" in model_preds[0])

# # Export to onnx
# subprocess.run(["python3", "trace_torchvision.py", "-vp", "fasterrcnn_resnet50_fpn", "-o", _temp_dir, "--to-onnx", "--weights", model_file])
# onnx_file = os.path.join(_temp_dir, "fasterrcnn_resnet50_fpn-cls91.onnx")
# self.assertTrue(os.path.exists(onnx_file), onnx_file)
# Export to onnx
subprocess.run(["python3", "trace_torchvision.py", "-vp", "fasterrcnn_resnet50_fpn", "-o", _temp_dir, "--to-onnx", "--weights", model_file])
onnx_file = os.path.join(_temp_dir, "fasterrcnn_resnet50_fpn-cls91.onnx")
self.assertTrue(os.path.exists(onnx_file), onnx_file)

def tearDown(self):
print("Removing all files in %s" % _temp_dir)
Expand Down
7 changes: 7 additions & 0 deletions tools/torch/trace_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class DetectionModel(torch.nn.Module):
def __init__(self, model):
super(DetectionModel, self).__init__()
self.model = model
self.str = ""

def forward(self, x, ids = None, bboxes = None, labels = None):
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tuple[Dict[str,Tensor], List[Dict[str, Tensor]]]
Expand Down Expand Up @@ -109,6 +110,12 @@ def forward(self, x, ids = None, bboxes = None, labels = None):
targ = {"boxes": bboxes[start:stop], "labels": labels[start:stop]}
l_targs.append(targ)

# XXX: This prevents a buggy optimization in torchscript.
# Try to remove this after next pytorch update
self.str = str(l_targs)
if l_x[0].shape[0] > 40000:
print(self.str)

losses, predictions = self.model(l_x, l_targs)

# Sum of all losses for finetuning (as done in vision/references/detection/engine.py)
Expand Down

0 comments on commit dab88ca

Please sign in to comment.