Skip to content

Commit

Permalink
Fix Exporting ONNX Model with Fixed Batch Size of 1 Using export_tensorrt_engine (#508)
Browse files Browse the repository at this point in the history
  • Loading branch information
laugh12321 committed Dec 31, 2023
1 parent 672ae82 commit fbfef69
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion yolort/relay/trt_graphsurgeon.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ def __init__(
logger.info(f"Loaded saved model from {model_path}")

if input_sample is not None:
self.batch_size = input_sample.shape[0]
input_sample = input_sample.to(device=device)
else:
self.batch_size = 1
model_path = model_path.with_suffix(".onnx")
model.to_onnx(model_path, input_sample=input_sample, enable_dynamic=enable_dynamic)
logger.info("PyTorch2ONNX graph created successfully")
Expand All @@ -85,7 +88,6 @@ def __init__(

# Fold constants via ONNX-GS that PyTorch2ONNX may have missed
self.graph.fold_constants()
self.batch_size = 1
self.precision = precision
self.simplify = simplify

Expand Down

0 comments on commit fbfef69

Please sign in to comment.