diff --git a/yolort/relay/trt_graphsurgeon.py b/yolort/relay/trt_graphsurgeon.py index 38d33e2c..8941755b 100644 --- a/yolort/relay/trt_graphsurgeon.py +++ b/yolort/relay/trt_graphsurgeon.py @@ -75,7 +75,10 @@ def __init__( logger.info(f"Loaded saved model from {model_path}") if input_sample is not None: + self.batch_size = input_sample.shape[0] input_sample = input_sample.to(device=device) + else: + self.batch_size = 1 model_path = model_path.with_suffix(".onnx") model.to_onnx(model_path, input_sample=input_sample, enable_dynamic=enable_dynamic) logger.info("PyTorch2ONNX graph created successfully") @@ -85,7 +88,6 @@ def __init__( # Fold constants via ONNX-GS that PyTorch2ONNX may have missed self.graph.fold_constants() - self.batch_size = 1 self.precision = precision self.simplify = simplify