From e72148fe545866c66b9c2d96997b8efc414ba6cc Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Thu, 13 Jun 2024 12:33:10 -0700
Subject: [PATCH 1/3] chore: fix ValueRanges computation in symbolic nodes

---
 .../dynamo/conversion/aten_ops_converters.py |  1 +
 .../dynamo/partitioning/common.py            |  8 ++-
 tests/py/dynamo/models/test_dyn_models.py    | 62 +++++++++++++++++--
 3 files changed, 63 insertions(+), 8 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 8846497348..9aaf2fa3f9 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1752,6 +1752,7 @@ def aten_ops_add(
     )
 
 
+@dynamo_tensorrt_converter(operator.mul, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar, supports_dynamic_shapes=True)
 def aten_ops_mul(
diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py
index 9ac677484f..8350b027b6 100644
--- a/py/torch_tensorrt/dynamo/partitioning/common.py
+++ b/py/torch_tensorrt/dynamo/partitioning/common.py
@@ -35,8 +35,12 @@ def construct_dynamic_input(
         node = dim.node
         expr = node.expr
         shape_env = node.shape_env
-        var_range = shape_env.var_to_range.get(expr, None)
-        var_val = shape_env.var_to_val.get(expr, None)
+        var_range = shape_env.var_to_range.get(expr, None) or shape_env.bound_sympy(
+            expr
+        )
+        var_val = shape_env.var_to_val.get(expr, None) or expr.xreplace(
+            shape_env.var_to_val
+        )
         assert var_range, var_val
         # Torchdynamo 0/1 specialization outlier
         if var_range.lower == 2:
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index 4c6b98e555..d63dc96bae 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -11,9 +11,9 @@
 assertions = unittest.TestCase()
 
 
-@unittest.skip(
-    "Skipping this test for now due to constraint violation error: https://github.com/pytorch/TensorRT/issues/2794"
-)
+# @unittest.skip(
+#     "Skipping this test for now due to constraint violation error: https://github.com/pytorch/TensorRT/issues/2794"
+# )
 @pytest.mark.unit
 def test_base_dynamic(ir):
     """
@@ -71,9 +71,9 @@ def forward(self, x):
     torch.cuda.empty_cache()
 
 
-@unittest.skip(
-    "Skipping this test for now due to constraint violation error: https://github.com/pytorch/TensorRT/issues/2794"
-)
+# @unittest.skip(
+#     "Skipping this test for now due to constraint violation error: https://github.com/pytorch/TensorRT/issues/2794"
+# )
 @pytest.mark.unit
 def test_base_dynamic_fallback(ir):
     """
@@ -289,3 +289,53 @@ def forward(self, x):
 
     with torch.no_grad():
         torch.cuda.empty_cache()
+
+
+@pytest.mark.unit
+def test_linear(ir):
+    """
+    Tests the model with linear op and operator.mul (added internally by PyTorch)
+    with dynamic shapes
+    """
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear1 = torch.nn.Linear(10, 10)
+
+        def forward(self, x):
+            return self.linear1(x)
+
+    model = MyModule().eval().cuda()
+
+    compile_spec = {
+        "device": torchtrt.Device("cuda:0"),
+        "enabled_precisions": {torch.float},
+        "ir": ir,
+        "min_block_size": 1,
+    }
+    inputs_bs2 = torch.randn(2, 2, 10).to("cuda")
+    if ir == "torch_compile":
+        torch._dynamo.mark_dynamic(inputs_bs2, 0, min=1, max=10)
+        torch._dynamo.mark_dynamic(inputs_bs2, 1, min=1, max=10)
+        # Compile the model
+        trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
+        trt_model(inputs_bs2)
+    elif ir == "dynamo":
+        dynamic_shapes = (
+            {
+                0: torch.export.Dim("batch_size", min=1, max=10),
+                1: torch.export.Dim("seq_len", max=10),
+            },
+        )
+        exp_program = torch.export.export(
+            model, (inputs_bs2,), dynamic_shapes=dynamic_shapes
+        )
+        trt_model = torchtrt.dynamo.compile(exp_program, [inputs_bs2], **compile_spec)
+
+    input_bs6_s3 = torch.randn((6, 3, 10)).to("cuda")
+    cos_sim = cosine_similarity(model(input_bs6_s3), trt_model(input_bs6_s3))
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_linear model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )

From 295f0fcfb977970a4db618761df65c397199bca8 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Thu, 13 Jun 2024 12:41:26 -0700
Subject: [PATCH 2/3] chore: clean up test script

---
 tests/py/dynamo/models/test_dyn_models.py | 41 ++++------------------
 1 file changed, 6 insertions(+), 35 deletions(-)

diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index d63dc96bae..3fd34de2ea 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -11,9 +11,9 @@
 assertions = unittest.TestCase()
 
 
-# @unittest.skip(
-#     "Skipping this test for now due to constraint violation error: https://github.com/pytorch/TensorRT/issues/2794"
-# )
+@unittest.skip(
+    "Skipping this test for now due to constraint violation error: https://github.com/pytorch/TensorRT/issues/2794"
+)
 @pytest.mark.unit
 def test_base_dynamic(ir):
     """
@@ -64,16 +64,11 @@ def forward(self, x):
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_dyn_full_compile model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
-    # Clean up model env
-    torch._dynamo.reset()
-
-    with torch.no_grad():
-        torch.cuda.empty_cache()
 
 
-# @unittest.skip(
-#     "Skipping this test for now due to constraint violation error: https://github.com/pytorch/TensorRT/issues/2794"
-# )
+@unittest.skip(
+    "Skipping this test for now due to constraint violation error: https://github.com/pytorch/TensorRT/issues/2794"
+)
 @pytest.mark.unit
 def test_base_dynamic_fallback(ir):
     """
@@ -128,12 +123,6 @@ def forward(self, x):
         msg=f"test_base_dynamic_fallback model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
 
-    # Clean up model env
-    torch._dynamo.reset()
-
-    with torch.no_grad():
-        torch.cuda.empty_cache()
-
 
 @pytest.mark.unit
 def test_view(ir):
@@ -185,12 +174,6 @@ def forward(self, x):
         msg=f"test_view model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
 
-    # Clean up model env
-    torch._dynamo.reset()
-
-    with torch.no_grad():
-        torch.cuda.empty_cache()
-
 
 @pytest.mark.unit
 def test_resnet_dynamic(ir):
@@ -234,12 +217,6 @@ def test_resnet_dynamic(ir):
         msg=f"test_resnet_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
    )
 
-    # Clean up model env
-    torch._dynamo.reset()
-
-    with torch.no_grad():
-        torch.cuda.empty_cache()
-
 
 @pytest.mark.unit
 def test_view(ir):
@@ -284,12 +261,6 @@ def forward(self, x):
         msg=f"test_base_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
 
-    # Clean up model env
-    torch._dynamo.reset()
-
-    with torch.no_grad():
-        torch.cuda.empty_cache()
-
 
 @pytest.mark.unit
 def test_linear(ir):

From c1947f60c5742af117766c0024d37af97c07b898 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Fri, 14 Jun 2024 16:00:58 -0700
Subject: [PATCH 3/3] chore: add comments

---
 py/torch_tensorrt/dynamo/partitioning/common.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py
index 8350b027b6..fdc55126ee 100644
--- a/py/torch_tensorrt/dynamo/partitioning/common.py
+++ b/py/torch_tensorrt/dynamo/partitioning/common.py
@@ -35,6 +35,10 @@ def construct_dynamic_input(
         node = dim.node
         expr = node.expr
         shape_env = node.shape_env
+        # An expr can be an independent SymInt node (e.g. s0 or s1) or a composition of them (e.g. 48*s0 or s0*s1).
+        # For an expr that involves symbolic computation, bound_sympy evaluates its value range:
+        # https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.ShapeEnv.html#torch.fx.experimental.symbolic_shapes.ShapeEnv.bound_sympy
+        # expr.xreplace replaces the symbolic variables with their current values and computes the expression.
         var_range = shape_env.var_to_range.get(expr, None) or shape_env.bound_sympy(
             expr
         )
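
Note on the fix in construct_dynamic_input (a minimal standalone sketch, not part of the patches above): shape_env.var_to_range and shape_env.var_to_val are keyed by the base sympy symbols (s0, s1, ...), so a composite shape expression such as 48*s0 misses on direct dictionary lookup. expr.xreplace(shape_env.var_to_val) substitutes the current hints and evaluates the product, and shape_env.bound_sympy(expr) propagates the per-symbol ValueRanges through the expression analogously. The snippet below uses plain sympy with made-up symbols and values to illustrate the lookup-vs-substitution behavior:

    # Hypothetical values: s0=4, s1=7 stand in for sizes a ShapeEnv would record.
    import sympy

    s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)
    var_to_val = {s0: sympy.Integer(4), s1: sympy.Integer(7)}  # current size hints

    for expr in (s0, 48 * s0, s0 * s1):
        hit = var_to_val.get(expr)  # succeeds only for a bare symbol, else None
        val = hit if hit is not None else expr.xreplace(var_to_val)
        print(expr, "->", val)  # s0 -> 4, 48*s0 -> 192, s0*s1 -> 28

This is why the fallback is needed for graphs containing ops like view/reshape that emit composite symbolic dimensions, which in turn is how operator.mul on SymInts ends up in the graph and needs the converter registered in PATCH 1/3.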