Convert models to TorchScript #46

Closed · bfortuner opened this issue Oct 13, 2019 · 39 comments
Labels: enhancement (Improvements or good new features)

@bfortuner

Do you have any examples of how to convert these models into a format runnable in C++?

@ppwwyyxx (Contributor)

TorchScript does not currently support these models.

ppwwyyxx added the enhancement label Oct 13, 2019
@bfortuner (Author)

Got it. Is that because some of the ops aren't supported yet? Is there another way to deploy these models to a C++ environment (e.g., ONNX → Caffe2 or TensorRT)?

Is this lack of support true for object detection models in general, or is it more specific to the SOTA implementations in Detectron?

How much work would it be to get one of these models into a C++-compatible format?

Thanks!

@ppwwyyxx (Contributor) commented Oct 13, 2019

We're working on getting TorchScript support. ONNX/Caffe2 deployment support (discussed in #8) was internal at first, but has since been released.

@bfortuner (Author)

Thank you! Do you know if this lack of support is true for object detection models in general, or is it more specific to the SOTA implementations in Detectron?

And second, does TorchScript support the operations used in Detectron, or does it require source changes to make this work?

@bfortuner (Author)

I'm curious to know the best way to deploy object detection models trained in PyTorch to an optimized format runnable in C++.

@lucasjinreal

@bfortuner Using libtorch does not gain much acceleration in terms of speed. Exporting to ONNX and converting to a TensorRT engine is the best way to deploy these models.

Also, onnxruntime tries to support all ops on top of its TensorRT provider, but many of them are not supported and have to run on the CPU.
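
For reference, the basic shape of an ONNX export for a plain tensor-in/tensor-out model is sketched below (the model, input shape, and names are hypothetical; detectron2 models need the dedicated export path discussed later in this thread):

import torch

# Hypothetical stand-in model: any eval-mode module that maps a tensor to tensors.
model = torch.nn.Conv2d(3, 16, 3).eval()
dummy_input = torch.randn(1, 3, 800, 800)  # tracing input; the shape is just an example

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",           # file to hand to TensorRT / onnxruntime
    opset_version=11,
    input_names=["image"],
    output_names=["features"],
)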

@nikolausWest

@ppwwyyxx Thanks for the added clarity. Could you expand at all on what you mean by "take some time to be ready"? Is that something like the next release, or more like some unknown distant future?

@bfortuner (Author)

Yeah, I'm wondering if there is a tutorial/paper about recommended approaches to C++ deployment with PyTorch. It seems there are a lot of different ways, but it's not clear what the "best" way is, or what the PyTorch team recommends going forward. I'll post in the PyTorch discussion forum!

@fmassa commented Oct 15, 2019

FYI, torchvision models (including Faster R-CNN and Mask R-CNN) will soon support exporting to both ONNX and TorchScript; see pytorch/vision#1461, pytorch/vision#1407 and pytorch/vision#1401 for some representative PRs.

I believe the learnings from this conversion work on torchvision models will be very helpful for planning how to make detectron2 models exportable to TorchScript.

@bfortuner (Author) commented Oct 15, 2019

Thanks for the update! I'm curious to know whether TorchScript itself needs changes too (are there any hard blockers), or whether it's mostly on our end to make our code compatible with the current TorchScript API.

The PRs above suggest it will still be a burden for our developers to bring their SOTA models into production.

@fmassa commented Oct 16, 2019

@bfortuner I think it will be a two-sided change: TorchScript support for Python features will continue improving, but users might need to adapt their code a bit to better fit what is currently supported.
This means avoiding some libraries in the inference code path (like numpy, scipy, etc.).

As pytorch/vision#1407 already shows, a complicated model such as Mask R-CNN can already be converted to TorchScript without changing the code too much (although the original code took some precautions to avoid using too many Python features).

cc @suo, who can give a more accurate picture of TorchScript.

@gslotman

Say I want to convert a detectron2 Mask R-CNN model to C++ (ideally using torchscript/libtorch); what's the current best approach? I tried various things last week, but without a good solution. Things I tried (using recent detectron2, pytorch and torchvision code):

  1. Naively try to convert some of the blocks to torchscript using torch.jit.script. This will fail on various things. Example stack trace:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py", line 1255, in script
    return torch.jit._recursive.recursive_script(obj)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 534, in recursive_script
    return create_script_module(nn_module, infer_methods_to_compile(nn_module))
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 296, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, cpp_module, stubs)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 336, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py", line 1593, in _construct
    init_fn(script_module)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 328, in init_fn
    scripted = recursive_script(orig_value)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 534, in recursive_script
    return create_script_module(nn_module, infer_methods_to_compile(nn_module))
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 296, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, cpp_module, stubs)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 340, in create_script_module_impl
    create_methods_from_stubs(concrete_type, stubs)
  File "/usr/local/lib/python3.6/dist-packages/torch/jit/_recursive.py", line 259, in create_methods_from_stubs
    concrete_type._create_methods(defs, rcbs, defaults)
RuntimeError: 
Unknown type name 'torch.nn.SyncBatchNorm':
  File "/detectron2_repo/detectron2/layers/wrappers.py", line 64
            # https://github.com/pytorch/pytorch/issues/12013
            assert not isinstance(
                self.norm, torch.nn.SyncBatchNorm
                           ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            ), "SyncBatchNorm does not support empty inputs!"

I get similar errors when trying to convert other layers (for example, torchscript didn't support del statements in some of the forward passes, etc.).

  2. Second, I tried to simply copy the weights from detectron2 to a Mask R-CNN model defined in pytorch/torchvision. I limited myself to the backbone architecture; example definition in torchvision:

backbone = torchvision.models.detection.backbone_utils.resnet_fpn_backbone('resnet50', True)
Although I succeeded in copying the weights, of course the model.backbone.forward calls ended up giving different results, most likely due to slightly different definitions of the two (detectron2 vs. torchvision) architectures and forward passes.

  3. My third and final try was to use the onnx/caffe2 exporter. This more or less worked (i.e., I ended up with a caffe2 model, though I haven't compared the outputs of the models yet). However, only afterwards did I realize that I couldn't import ONNX models back into PyTorch, and adding caffe2 support to our deployment would be quite cumbersome, since we just switched from caffe to libtorch... To me it seems that exporting to ONNX should be quite similar to exporting to torchscript, so maybe it's easy to change the caffe2 exporter to torchscript? (A sketch of this export path follows below.)
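
For reference, a minimal sketch of driving detectron2's caffe2/ONNX export path as it later stabilized; treat the exact API names (Caffe2Tracer, model_zoo.get_config, save_protobuf) as assumptions that depend on the detectron2 version in use:

import torch
from detectron2 import model_zoo
from detectron2.export import Caffe2Tracer

cfg = model_zoo.get_config("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
model = model_zoo.get("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", trained=True)
model.eval()

# detectron2 models take a list of {"image": CHW tensor} dicts as inputs.
sample_inputs = [{"image": torch.zeros(3, 800, 800)}]  # placeholder; a real image traces more reliably
tracer = Caffe2Tracer(cfg, model, sample_inputs)

onnx_model = tracer.export_onnx()      # an onnx.ModelProto for other runtimes
caffe2_model = tracer.export_caffe2()  # runnable with the caffe2 runtime
caffe2_model.save_protobuf("./caffe2_output")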

@cbasavaraj

Hi all, looks like PyTorch 1.4 and torchvision 0.5 have made progress on this and a couple of related issues. When will we see the updates rolling out to detectron2? Please see my related question here on the forum: https://discuss.pytorch.org/t/pytorch-1-4-torchvision-0-5-vs-detectron/67002

@tengerye

Hi, I am also having some problems with JIT conversion. It raises an error:

RuntimeError: 
Unknown type name 'torch.nn.SyncBatchNorm':
  File "/detectron2/detectron2/layers/wrappers.py", line 67
            # https://github.com/pytorch/pytorch/issues/12013
            assert not isinstance(
                self.norm, torch.nn.SyncBatchNorm
                           ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            ), "SyncBatchNorm does not support empty inputs!"

Since ONNX conversion fixes the input size, it is not suitable in my case. Any help, please?

@tengerye

An obvious disadvantage of ONNX is that we need to fix the input size, but some detection models can take variable-size inputs. JIT support is necessary and urgent.
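
For what it's worth, torch.onnx.export does accept a dynamic_axes argument that marks dimensions as variable, at least for models that trace cleanly; a minimal sketch with a hypothetical stand-in model (real detection models also need all their ops to be exportable, which was the harder problem at the time):

import torch

class TinyNet(torch.nn.Module):
    def forward(self, image: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.avg_pool2d(image, 2)

torch.onnx.export(
    TinyNet().eval(),
    torch.randn(1, 3, 480, 640),
    "tiny.onnx",
    input_names=["image"],
    output_names=["out"],
    dynamic_axes={"image": {2: "height", 3: "width"}},  # mark H and W as variable
)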

@GitHubChrischen

Any good news?

@ppwwyyxx (Contributor)

Progress has been made on this issue recently (https://github.com/facebookresearch/detectron2/pulls?q=is%3Apr+author%3Achenbohua3+), and if everything goes well, most models should be scriptable within a few months.

@rosebbb commented Aug 4, 2020

> [quoting @gslotman's earlier comment in full]

Very thorough try. Did you figure out a way to export the model to an ONNX model that can be loaded by another runtime, or to a TorchScript model?

@hyc-xyz commented Aug 31, 2020

Subscribing to this thread.

@tkuenzle

Thanks a lot for all the amazing work being done on this project; it's much appreciated!

I understand that detectron models are currently not scriptable with TorchScript. @ppwwyyxx could you please elaborate on what exactly is missing for making Mask R-CNN and PointRend scriptable? Is it blocked by pytorch/pytorch#36061?

@ppwwyyxx (Contributor)

pytorch/pytorch#36061 is the main blocker.

@tkuenzle commented Oct 30, 2020

Replacing the lists of modules with nn.ModuleList works pretty well, and we are able to script the models (although we have to retrain them because of it); a sketch of this change is below. Now we are running into pytorch/pytorch#46944, because detectron2 relies on classes like Instances and Boxes, which are not included in the scripted model, and thus the model is not runnable in a non-Python environment.
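
For illustration, a minimal sketch of the kind of change involved (the module and attribute names here are hypothetical, not detectron2's actual ones):

import torch
from torch import nn

class FPNLike(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list of modules is invisible to TorchScript, so
        # scripting fails on attributes like:
        #   self.lateral_convs = [nn.Conv2d(256, 256, 1) for _ in range(4)]
        # nn.ModuleList registers the submodules, making them scriptable;
        # it also changes the state_dict keys, which is why pretrained
        # checkpoints stop loading and retraining was needed.
        self.lateral_convs = nn.ModuleList(nn.Conv2d(256, 256, 1) for _ in range(4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv in self.lateral_convs:  # iterating an nn.ModuleList is scriptable
            x = conv(x)
        return x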

@ppwwyyxx do you think it makes sense to wait for proper support for classes in TorchScript, or rather to change the implementation of Instances, etc., to be based on e.g. NamedTuple, as in pytorch/pytorch#42258?

@ppwwyyxx (Contributor) commented Nov 1, 2020

I haven't gotten to that step yet (since we can't break pre-trained models), but I'll go double-check the story around scripted classes in C++.
Btw, with the latest GitHub version, tracing (with fixed batch size) already works fine under the patch_builtin_len() context manager (except for some postprocessing, which is often not used in deployment), and using it in C++ is probably more straightforward.

@tkuenzle commented Nov 2, 2020

It turns out that converting the Instances to a Dict[str, Tensor] as the final output does the trick, and the model can be run in C++ even though it uses Instances internally (sketched below). Sorry for the confusion; it looks like everything is working as expected!
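
For illustration, the shape of such a final-output conversion (a hypothetical sketch; the field names follow detectron2's conventions, and the inner model still uses Instances internally):

from typing import Dict, List

import torch
from torch import nn

class TensorOutputWrapper(nn.Module):
    """Flattens Instances outputs into plain tensors at the model boundary."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, inputs: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]:
        outputs = self.model(inputs)
        results: List[Dict[str, torch.Tensor]] = []
        for out in outputs:
            inst = out["instances"]
            results.append({
                "pred_boxes": inst.pred_boxes.tensor,  # Boxes -> raw Nx4 tensor
                "scores": inst.scores,
                "pred_classes": inst.pred_classes,
            })
        return results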

@cbasavaraj

Hi @tkuenzle, do you gain anything in terms of time per frame with the C++/libtorch version? For a single frame, or maybe by running multiple C++ threads in parallel?
And would it be possible to either share code or outline the main steps you had to take to make scripting work? Thanks.

@tkuenzle commented Nov 11, 2020

I cannot really comment on time per frame, because our focus is on running the model on mobile devices. I don't think sharing code would be that helpful, because it mostly depends on which models you want to script. Thanks to the work of @chenbohua3, most of the heads are scriptable already, and thus the effort to make complete models scriptable is rather small.

The main steps you have to take are the following:

  1. Use export_torchscript_with_instances to export your model.
  2. Fix any TorchScript errors in the detectron2 repo. This will mainly consist of:
    • Replacing lists of modules with nn.ModuleList (you will need to retrain the models because of this)
    • Adding Python type hints for non-tensor arguments
    • Replacing some Python expressions that are not supported by TorchScript with equivalent supported expressions
    • Ignoring code branches that you do not need by adding assert not torch.jit.is_scripting()
  3. Extract the needed fields of the instances in the last layer. You could, for example, define a wrapper module that takes the original model as input and has the following forward method (assuming you are only interested in pred_masks):
   def forward(self, input):
       output = self.model(input)
       return [o["instances"].pred_masks for o in output]

I hope this helps!

@ppwwyyxx (Contributor)

FYI, we just added support for scripting & tracing of the most common models (R-CNN and RetinaNet). They can now be exported to the TorchScript format successfully. (PyTorch built from the master branch is required.)

There aren't proper APIs & docs yet, but basic usage is shown in the unittests:

class TestScripting(unittest.TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def testMaskRCNN(self):
        self._test_rcnn_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def testRetinaNet(self):
        self._test_retinanet_model("COCO-Detection/retinanet_R_50_FPN_3x.yaml")

    def _test_rcnn_model(self, config_path):
        model = model_zoo.get(config_path, trained=True)
        model.eval()
        fields = {
            "proposal_boxes": Boxes,
            "objectness_logits": Tensor,
            "pred_boxes": Boxes,
            "scores": Tensor,
            "pred_classes": Tensor,
            "pred_masks": Tensor,
        }
        script_model = export_torchscript_with_instances(model, fields)
        inputs = [{"image": get_sample_coco_image()}]
        with torch.no_grad():
            instance = model.inference(inputs, do_postprocess=False)[0]
            scripted_instance = script_model.inference(inputs, do_postprocess=False)[0].to_instances()
        assert_instances_allclose(instance, scripted_instance)

    def _test_retinanet_model(self, config_path):
        model = model_zoo.get(config_path, trained=True)
        model.eval()
        fields = {
            "pred_boxes": Boxes,
            "scores": Tensor,
            "pred_classes": Tensor,
        }
        script_model = export_torchscript_with_instances(model, fields)
        img = get_sample_coco_image()
        inputs = [{"image": img}]
        with torch.no_grad():
            instance = model(inputs)[0]["instances"]
            scripted_instance = script_model(inputs)[0].to_instances()
            scripted_instance = detector_postprocess(scripted_instance, img.shape[1], img.shape[2])
        assert_instances_allclose(instance, scripted_instance)


@unittest.skipIf(
    os.environ.get("CIRCLECI") or TORCH_VERSION < (1, 8), "Insufficient Pytorch version"
)
class TestTracing(unittest.TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def testMaskRCNN(self):
        class WrapModel(nn.ModuleList):
            def forward(self, image):
                inputs = [{"image": image}]
                outputs = self[0].inference(inputs, do_postprocess=False)[0]
                size = outputs.image_size
                if torch.jit.is_tracing():
                    assert isinstance(size, torch.Tensor)
                else:
                    size = torch.as_tensor(size)
                return (
                    size,
                    outputs.pred_classes,
                    outputs.pred_boxes.tensor,
                    outputs.scores,
                    outputs.pred_masks,
                )

            @staticmethod
            def convert_output(output):
                r = Instances(tuple(output[0]))
                r.pred_classes = output[1]
                r.pred_boxes = Boxes(output[2])
                r.scores = output[3]
                r.pred_masks = output[4]
                return r

        self._test_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", WrapModel)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def testRetinaNet(self):
        class WrapModel(nn.ModuleList):
            def forward(self, image):
                inputs = [{"image": image}]
                outputs = self[0].forward(inputs)[0]["instances"]
                size = outputs.image_size
                if torch.jit.is_tracing():
                    assert isinstance(size, torch.Tensor)
                else:
                    size = torch.as_tensor(size)
                return (
                    size,
                    outputs.pred_classes,
                    outputs.pred_boxes.tensor,
                    outputs.scores,
                )

            @staticmethod
            def convert_output(output):
                r = Instances(tuple(output[0]))
                r.pred_classes = output[1]
                r.pred_boxes = Boxes(output[2])
                r.scores = output[3]
                return r

        self._test_model("COCO-Detection/retinanet_R_50_FPN_3x.yaml", WrapModel)

    def _test_model(self, config_path, WrapperCls):
        # TODO wrapper should be handled by export API in the future
        model = model_zoo.get(config_path, trained=True)
        image = get_sample_coco_image()
        model = WrapperCls([model])
        model.eval()
        with torch.no_grad(), patch_builtin_len():
            small_image = nn.functional.interpolate(image, scale_factor=0.5)
            # trace with a different image, and the trace must still work
            traced_model = torch.jit.trace(model, (small_image,))
            output = WrapperCls.convert_output(model(image))
            traced_output = WrapperCls.convert_output(traced_model(image))
        assert_instances_allclose(output, traced_output)

@tkuenzle

Thanks a lot, that's great news @ppwwyyxx! Would you be willing to accept PRs for making some of the other models scriptable?

@danielgordon10

@ppwwyyxx When I run the test, I get this error (Python 3.6.9, torch 1.8.0.dev20201110):

RuntimeError:
Module 'ResNet' has no attribute 'stages' (This attribute exists on the Python module, but we failed to convert Python type: 'list' to a TorchScript type.):
  File "detectron2/modeling/backbone/resnet.py", line 437
        if "stem" in self._out_features:
            outputs["stem"] = x
        for name, stage in zip(self.stage_names, self.stages):
                                                 ~~~~~~~~~~~ <--- HERE
            x = stage(x)
            if name in self._out_features:

Seems to be related to the issue above. Is there something I'm supposed to do to preprocess the models so they don't have lists and instead have ModuleLists?

@danielgordon10

If I add

model.backbone.bottom_up.stages = nn.ModuleList(model.backbone.bottom_up.stages)
model.backbone.lateral_convs = nn.ModuleList(model.backbone.lateral_convs)
model.backbone.output_convs = nn.ModuleList(model.backbone.output_convs)

it seems to work, but only for a single image. Does batched mode not work yet?

@ppwwyyxx (Contributor)

@danielgordon10 Your PyTorch is still not new enough.

@danielgordon10

@ppwwyyxx What's the minimum PyTorch version? That was yesterday's nightly.

ppwwyyxx changed the title from "TorchScript / C++ Examples?" to "Convert models to TorchScript" Nov 11, 2020
@ppwwyyxx (Contributor) commented Nov 11, 2020

It now requires yesterday's PyTorch commits, which are supposed to be in today's nightly.
Once a few other ongoing PyTorch features are implemented, we expect to require them as well.

I'm closing this issue because the scope is too general (I'm also renaming it so it only involves TorchScript) and the majority of the work is done. There are some remaining TODOs about usability that should be addressed as separate issues.

Thanks a lot to the PyTorch JIT team and @chenbohua3 @bddpqq from Alibaba for making this happen!

@Muratoter commented Feb 26, 2021

> [quoting @ppwwyyxx's scripting & tracing announcement and test code above]

I was able to successfully convert the model, thank you. But when I use the model in an Android project, I get the following error:

2021-02-27 02:25:36.537 21060-21060/org.pytorch.demo.imagesegmentation E/AndroidRuntime: FATAL EXCEPTION: main
    Process: org.pytorch.demo.imagesegmentation, PID: 21060
    java.lang.RuntimeException: Unable to start activity ComponentInfo{org.pytorch.demo.imagesegmentation/org.pytorch.imagesegmentation.MainActivity}: com.facebook.jni.CppException: 
    Unknown builtin op: torchvision::nms.
    Could not find any similar ops to torchvision::nms. This op may not exist or may not be currently supported in TorchScript.
    :
      File "/usr/local/lib/python3.7/dist-packages/torchvision/ops/boxes.py", line 42
        """
        _assert_has_ops()
        return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
               ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    Serialized   File "code/__torch__/torchvision/ops/boxes.py", line 26
      _8 = __torch__.torchvision.extension._assert_has_ops
      _9 = _8()
      _10 = ops.torchvision.nms(boxes, scores, iou_threshold)
            ~~~~~~~~~~~~~~~~~~~ <--- HERE
      return _10
    

@fmassa commented Apr 10, 2021

@Muratoter please follow the instructions in https://github.com/pytorch/android-demo-app/tree/master/D2Go to get detectron2 models running on Android.

Note that you need to add

implementation 'org.pytorch:torchvision_ops:0.9.0'

to your build.gradle file

@sctrueew commented Apr 23, 2021

Hello,

I trained a model and converted it to model.ts successfully. Can we use it on Windows 10? I get an error when loading the model.

The error is on this line:

torch::jit::load("model.ts")
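
One way to narrow this down is to first check that the exported file loads and runs under Python; if it does, the C++ failure is usually missing custom ops (e.g. torchvision::nms, as in the Android error above) or a libtorch version mismatch. A minimal sketch:

import torch

# If this load succeeds, the TorchScript file itself is fine, and the C++
# error likely means the torchvision ops were never registered there.
model = torch.jit.load("model.ts")
model.eval()
print(model.graph)  # inspect which ops the model calls, e.g. torchvision::nms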

@WarriorMmb

Hello all, I have solved the TorchScript integration, with accurate results.

@SorourMo

> [quoting @sctrueew's comment above]

@sctrueew Hi. I've got the same problem on Windows 10. I've tried two pytorch/libtorch versions (retraining the model for each) with CUDA 10, 10.2, and 11.3. None of them worked. Could you share how you installed torchvision C++ on Windows?
