Skip to content

Commit

Permalink
fix(torch): correct output name for onnx classification model
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Dec 7, 2021
1 parent 8a34aa1 commit a03eb87
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tools/torch/trace_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
parser.add_argument('--print-models', action='store_true', help="Print all the available models names and exit")
parser.add_argument('--to-dd-native', action='store_true', help="Prepare the model so that the weights can be loaded on native model with dede")
parser.add_argument('--to-onnx', action="store_true", help="If specified, export to onnx instead of jit.")
parser.add_argument('--onnx_out', type=str, default="prob", help="Name of onnx output")
parser.add_argument('--weights', type=str, help="If not None, these weights will be embedded in the model before exporting")
parser.add_argument('-a', "--all", action='store_true', help="Export all available models")
parser.add_argument('-v', "--verbose", action='store_true', help="Set logging level to INFO")
Expand Down Expand Up @@ -336,11 +337,13 @@ def get_detection_input(batch_size=1, img_width=224, img_height=224):
# remove extension
filename = filename[:-3] + ".onnx"
example = get_image_input(args.batch_size, args.img_width, args.img_height)

# change for detection
torch.onnx.export(
model, example, filename,
export_params=True, verbose=args.verbose,
opset_version=11, do_constant_folding=True,
input_names=["input"], output_names=["output"])
input_names=["input"], output_names=[args.onnx_out])
# dynamic_axes={"input":{0:"batch_size"},"output":{0:"batch_size"}}
else:
logging.info("Saving to %s", filename)
Expand Down

0 comments on commit a03eb87

Please sign in to comment.