chore: fix ValueRanges computation in symbolic nodes #2918

Merged: 4 commits, Jun 17, 2024
Changes from 3 commits
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1779,6 +1779,7 @@ def aten_ops_add(
)


@dynamo_tensorrt_converter(operator.mul, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar, supports_dynamic_shapes=True)
def aten_ops_mul(
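The new `operator.mul` registration matters because dynamo records shape arithmetic on symbolic sizes as plain Python `operator.mul` calls (the test added below notes it is "added internally by PyTorch"). The sketch below is not part of the diff; it is an illustrative, hand-built FX graph showing the kind of node this converter now covers with dynamic shapes.

# Illustrative, hand-built FX graph (not from this PR): a call_function node
# targeting operator.mul, the kind dynamo can emit for symbolic-size arithmetic
# such as batch*seq.
import operator

import torch.fx as fx

g = fx.Graph()
batch = g.placeholder("batch")  # stands in for a SymInt-valued size
seq = g.placeholder("seq")      # stands in for a SymInt-valued size
flat = g.call_function(operator.mul, (batch, seq))  # target that aten_ops_mul now handles
g.output(flat)
print(g)  # prints the graph IR, including the operator.mul call_function node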
8 changes: 6 additions & 2 deletions py/torch_tensorrt/dynamo/partitioning/common.py
@@ -35,8 +35,12 @@ def construct_dynamic_input(
node = dim.node
expr = node.expr
shape_env = node.shape_env
var_range = shape_env.var_to_range.get(expr, None)
var_val = shape_env.var_to_val.get(expr, None)
var_range = shape_env.var_to_range.get(expr, None) or shape_env.bound_sympy(
expr
)
var_val = shape_env.var_to_val.get(expr, None) or expr.xreplace(
shape_env.var_to_val
)
assert var_range, var_val
# Torchdynamo 0/1 specialization outlier
if var_range.lower == 2:
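The added fallbacks cover the case where `expr` is a compound sympy expression (a derived dimension such as 2*s0) rather than a bare symbol: `var_to_range` and `var_to_val` are keyed by plain symbols, so `.get(expr, None)` returns None and the range and example value have to be recomputed. A minimal sketch with plain sympy follows; the dicts are hypothetical stand-ins for the ShapeEnv bookkeeping, not the real API.

# Minimal sketch (plain sympy, hypothetical dicts standing in for
# shape_env.var_to_range / shape_env.var_to_val) of why the direct .get()
# lookups can miss and how the fallbacks recover a range and a value.
import sympy

s0 = sympy.Symbol("s0", positive=True, integer=True)
var_to_range = {s0: (2, 10)}         # per-symbol bounds
var_to_val = {s0: sympy.Integer(4)}  # concrete value seen at trace time

expr = 2 * s0  # a derived dimension; not a key of either dict

assert var_to_range.get(expr, None) is None  # direct lookup misses compound exprs

# Fallback for the range: bound the expression from its symbols' ranges
# (the PR uses shape_env.bound_sympy(expr); done here by substitution).
lower = expr.xreplace({s0: sympy.Integer(var_to_range[s0][0])})
upper = expr.xreplace({s0: sympy.Integer(var_to_range[s0][1])})
print(lower, upper)  # 4 20

# Fallback for the value: substitute known symbol values into the expression,
# which is what expr.xreplace(shape_env.var_to_val) does in the patch.
print(expr.xreplace(var_to_val))  # 8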
75 changes: 48 additions & 27 deletions tests/py/dynamo/models/test_dyn_models.py
@@ -64,11 +64,6 @@ def forward(self, x):
cos_sim > COSINE_THRESHOLD,
msg=f"test_dyn_full_compile model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)
# Clean up model env
torch._dynamo.reset()

with torch.no_grad():
torch.cuda.empty_cache()


@unittest.skip(
@@ -128,12 +123,6 @@ def forward(self, x):
msg=f"test_base_dynamic_fallback model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
torch._dynamo.reset()

with torch.no_grad():
torch.cuda.empty_cache()


@pytest.mark.unit
def test_view(ir):
@@ -185,12 +174,6 @@ def forward(self, x):
msg=f"test_view model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
torch._dynamo.reset()

with torch.no_grad():
torch.cuda.empty_cache()


@pytest.mark.unit
def test_resnet_dynamic(ir):
@@ -234,12 +217,6 @@ def test_resnet_dynamic(ir):
msg=f"test_resnet_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
torch._dynamo.reset()

with torch.no_grad():
torch.cuda.empty_cache()


@pytest.mark.unit
def test_view(ir):
@@ -284,8 +261,52 @@ def forward(self, x):
msg=f"test_base_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
torch._dynamo.reset()

with torch.no_grad():
torch.cuda.empty_cache()
@pytest.mark.unit
def test_linear(ir):
    """
    Tests the model with linear op and operator.mul (added internally by PyTorch)
    with dynamic shapes
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(10, 10)

        def forward(self, x):
            return self.linear1(x)

    model = MyModule().eval().cuda()

    compile_spec = {
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.float},
        "ir": ir,
        "min_block_size": 1,
    }
    inputs_bs2 = torch.randn(2, 2, 10).to("cuda")
    if ir == "torch_compile":
        torch._dynamo.mark_dynamic(inputs_bs2, 0, min=1, max=10)
        torch._dynamo.mark_dynamic(inputs_bs2, 1, min=1, max=10)
        # Compile the model
        trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
        trt_model(inputs_bs2)
    elif ir == "dynamo":
        dynamic_shapes = (
            {
                0: torch.export.Dim("batch_size", min=1, max=10),
                1: torch.export.Dim("seq_len", max=10),
            },
        )
        exp_program = torch.export.export(
            model, (inputs_bs2,), dynamic_shapes=dynamic_shapes
        )
        trt_model = torchtrt.dynamo.compile(exp_program, [inputs_bs2], **compile_spec)

    input_bs6_s3 = torch.randn((6, 3, 10)).to("cuda")
    cos_sim = cosine_similarity(model(input_bs6_s3), trt_model(input_bs6_s3))
    assertions.assertTrue(
        cos_sim > COSINE_THRESHOLD,
        msg=f"test_linear model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
    )