Skip to content

Commit

Permalink
Fix exporting ONNX model with fixed batch sizes (#508) (#509)
Browse files Browse the repository at this point in the history
  • Loading branch information
laugh12321 committed Jan 2, 2024
1 parent 672ae82 commit 8463348
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion yolort/relay/trt_graphsurgeon.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ def __init__(
logger.info(f"Loaded saved model from {model_path}")

if input_sample is not None:
self.batch_size = input_sample.shape[0]
input_sample = input_sample.to(device=device)
else:
self.batch_size = 1
model_path = model_path.with_suffix(".onnx")
model.to_onnx(model_path, input_sample=input_sample, enable_dynamic=enable_dynamic)
logger.info("PyTorch2ONNX graph created successfully")
Expand All @@ -85,7 +88,6 @@ def __init__(

# Fold constants via ONNX-GS that PyTorch2ONNX may have missed
self.graph.fold_constants()
self.batch_size = 1
self.precision = precision
self.simplify = simplify

Expand Down

0 comments on commit 8463348

Please sign in to comment.