From 846334816de2580f254ff9ad8d9f17599d0f319e Mon Sep 17 00:00:00 2001 From: Laugh Date: Tue, 2 Jan 2024 19:55:20 +0800 Subject: [PATCH] Fix exporting ONNX model with fixed batch sizes (#508) (#509) --- yolort/relay/trt_graphsurgeon.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/yolort/relay/trt_graphsurgeon.py b/yolort/relay/trt_graphsurgeon.py index 38d33e2c..8941755b 100644 --- a/yolort/relay/trt_graphsurgeon.py +++ b/yolort/relay/trt_graphsurgeon.py @@ -75,7 +75,10 @@ def __init__( logger.info(f"Loaded saved model from {model_path}") if input_sample is not None: + self.batch_size = input_sample.shape[0] input_sample = input_sample.to(device=device) + else: + self.batch_size = 1 model_path = model_path.with_suffix(".onnx") model.to_onnx(model_path, input_sample=input_sample, enable_dynamic=enable_dynamic) logger.info("PyTorch2ONNX graph created successfully") @@ -85,7 +88,6 @@ def __init__( # Fold constants via ONNX-GS that PyTorch2ONNX may have missed self.graph.fold_constants() - self.batch_size = 1 self.precision = precision self.simplify = simplify