Fix AOTI CPP GEMM Template issue without freezing (pytorch#136421)
**Summary**
Fixes issue pytorch#135106. For AOTI, the weight has the following Inductor IR:
```
ReinterpretView(
  StorageBox(
    ConstantBuffer(name='L__self___mlp_0_weight', layout=FixedLayout('cpu', torch.float32, size=[64, 128], stride=[128, 1]))
  ),
  FixedLayout('cpu', torch.float32, size=[128, 64], stride=[1, 128]),
  origins=OrderedSet([addmm])
)
```
In the post-processing step of the GEMM template, the weight used was the pre-permutation one, leading to correctness issues. This PR addresses that by viewing the weight back to its expected size and stride before the weight prepack.
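The failure mode can be illustrated with plain tensors (a minimal sketch, not the Inductor code path): prepacking the raw storage instead of the view it represents hands the GEMM a mis-strided operand, while `as_strided` with the view's size and stride recovers the intended weight.
```
import torch

# The weight constant's storage: size [64, 128], stride [128, 1] (contiguous).
storage = torch.randn(64, 128)

# The ReinterpretView above presents the same storage as size [128, 64] with
# stride [1, 128], i.e. the transpose sharing the same buffer.
viewed = storage.as_strided((128, 64), (1, 128))
assert torch.equal(viewed, storage.t())

x = torch.randn(16, 128)
correct = x @ viewed                  # GEMM against the permuted weight
wrong = x @ storage.reshape(128, 64)  # GEMM against the pre-permutation buffer
assert not torch.allclose(correct, wrong)
```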

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_aot_inductor.py -k test_misc_1_max_autotune_True_non_abi_compatible_cpu
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_aoti_linear
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_aoti_linear_multi_view_operations
```
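
Note on `test_aoti_linear_multi_view_operations`: the chained `permute` + `view` over the 3-D weight constant collapses into a single reinterpretation of its storage, exactly the ReinterpretView layout quoted in the summary. A quick standalone check with plain tensors (outside Inductor):
```
import torch

in_features, out_features = 128, 64
w3d = torch.randn(out_features // 2, 2, in_features)  # the 3-D constant: [32, 2, 128]

# permute(2, 0, 1) yields size [128, 32, 2] with stride [1, 256, 128]; the two
# trailing dims merge under view(), leaving size [128, 64] and stride [1, 128].
w2d = w3d.permute(2, 0, 1).view(in_features, out_features)
assert w2d.size() == (128, 64) and w2d.stride() == (1, 128)
assert w2d.data_ptr() == w3d.data_ptr()  # same storage, no copy
```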

Pull Request resolved: pytorch#136421
Approved by: https://github.com/jgong5, https://github.com/desertfire
leslie-fang-intel authored and pytorchmergebot committed Oct 9, 2024
1 parent be0b752 commit 0b8048c
Showing 3 changed files with 132 additions and 11 deletions.
4 changes: 1 addition & 3 deletions test/inductor/test_aot_inductor.py
@@ -3313,9 +3313,7 @@ def forward(self, values, offsets):
            model, example_inputs_list, dynamic_shapes=dynamic_shapes
        )

-    # max_autotune is disabled due to https://github.com/pytorch/pytorch/issues/135106
-    # @common_utils.parametrize("max_autotune", [False, True])
-    @common_utils.parametrize("max_autotune", [False])
+    @common_utils.parametrize("max_autotune", [True, False])
    def test_misc_1(self, max_autotune):
        if self.device == "cpu" and IS_MACOS and max_autotune:
            raise unittest.SkipTest("max_autotune not supported on macos")
105 changes: 105 additions & 0 deletions test/inductor/test_cpu_select_algorithm.py
@@ -1623,6 +1623,111 @@ def forward(self, x):
            self.common(mod, (v,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

+    @inductor_config.patch({"freezing": False})
+    @patches
+    @torch.no_grad
+    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
+    @parametrize("batch_size", (16,))
+    @parametrize("in_features", (128,))
+    @parametrize("out_features", (64,))
+    @parametrize("bias", (True,))
+    @dtypes(
+        torch.float,
+    )
+    def test_aoti_linear(self, batch_size, in_features, out_features, bias, dtype):
+        try:
+            try:
+                from . import test_aot_inductor_utils
+            except ImportError:
+                import test_aot_inductor_utils
+        except Exception:
+            # Skip this UT if the import failed
+            return
+
+        class M(torch.nn.Module):
+            def __init__(self, bias=bias) -> None:
+                super().__init__()
+                self.mlp = torch.nn.Sequential(
+                    torch.nn.Linear(in_features, out_features, bias=bias),
+                    torch.nn.ReLU(),
+                )
+
+            def forward(self, x):
+                return self.mlp(x)
+
+        assert torch._inductor.config.freezing is False
+
+        counters.clear()
+        v = torch.randn(batch_size, in_features).to(dtype=dtype)
+        mod = M(bias=bias).to(dtype=dtype).eval()
+        torch._dynamo.reset()
+        torch._inductor.metrics.reset()
+        torch.manual_seed(0)
+        with verify(dtype) as (atol, rtol), torch.no_grad():
+            expected = mod(v)
+            actual = test_aot_inductor_utils.AOTIRunnerUtil.run(
+                "cpu",
+                mod,
+                (v,),
+            )
+            self.assertEqual(actual, expected, atol=atol, rtol=rtol)
+            self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
+
+    @inductor_config.patch({"freezing": False})
+    @patches
+    @torch.no_grad
+    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
+    @parametrize("batch_size", (16,))
+    @parametrize("in_features", (128,))
+    @parametrize("out_features", (64,))
+    @dtypes(
+        torch.float,
+    )
+    def test_aoti_linear_multi_view_operations(
+        self, batch_size, in_features, out_features, dtype
+    ):
+        try:
+            try:
+                from . import test_aot_inductor_utils
+            except ImportError:
+                import test_aot_inductor_utils
+        except Exception:
+            # Skip this UT if the import failed
+            return
+
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.bias = torch.randn(out_features)
+                self.weight = torch.randn(out_features // 2, 2, in_features)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                tmp = torch.addmm(
+                    self.bias,
+                    x,
+                    self.weight.permute(2, 0, 1).view(in_features, out_features),
+                )
+                return self.relu(tmp)
+
+        assert torch._inductor.config.freezing is False
+
+        counters.clear()
+        v = torch.randn(batch_size, in_features).to(dtype=dtype)
+        mod = M().to(dtype=dtype).eval()
+        torch._dynamo.reset()
+        torch._inductor.metrics.reset()
+        torch.manual_seed(0)
+        with verify(dtype) as (atol, rtol), torch.no_grad():
+            expected = mod(v)
+            actual = test_aot_inductor_utils.AOTIRunnerUtil.run(
+                "cpu",
+                mod,
+                (v,),
+            )
+            self.assertEqual(actual, expected, atol=atol, rtol=rtol)
+            self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)


@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
class _DynamicShapesTestBase(BaseTestSelectAlgorithm):
34 changes: 26 additions & 8 deletions torch/_inductor/codegen/cpp_gemm_template.py
@@ -555,6 +555,19 @@ def reorder_and_filter(inputs, layout_or_out):
            assert len(input_indices) >= 2
            return [inputs[idx] for idx in input_indices], layout_or_out

+        new_inputs, new_layout = reorder_and_filter(input_nodes, layout)
+        assert new_inputs[1].get_name() in V.graph.constants
+        is_mkldnn_wgt = V.graph.constants[new_inputs[1].get_name()].is_mkldnn
+        if is_mkldnn_wgt:
+            # Viewing an mkldnn tensor shouldn't happen; we can extend the
+            # implementation if it does.
+            assert not isinstance(new_inputs[1], ir.BaseView)
+        assert isinstance(new_inputs[1].layout, ir.FixedLayout)
+        # Note that the layout of an MKLDNN tensor carries the wrong stride
+        view_size = new_inputs[1].layout.size
+        view_stride = new_inputs[1].layout.stride
+        view_offset = new_inputs[1].layout.offset
+
        def maybe_to_dense(inputs, layout_or_out):
            new_inputs = list(inputs)
            if isinstance(inputs[1], torch.Tensor):
@@ -563,12 +576,19 @@ def maybe_to_dense(inputs, layout_or_out):
            return new_inputs, layout_or_out

        def normalize_shapes(inputs, layout_or_out):
-            if not trans_w:
-                return inputs, layout_or_out
            new_inputs = list(inputs)
-            X = inputs[0]
-            W = inputs[1]
-            B = inputs[2] if has_bias else None
+            if not is_mkldnn_wgt and isinstance(new_inputs[1], torch.Tensor):
+                # Assuming that W is the storage of an unwrapped view,
+                # view it back here.
+                new_inputs[1] = new_inputs[1].as_strided(
+                    view_size, view_stride, view_offset
+                )
+
+            if not trans_w:
+                return new_inputs, layout_or_out
+            X = new_inputs[0]
+            W = new_inputs[1]
+            B = new_inputs[2] if has_bias else None
            if isinstance(W, ir.IRNode):
                if trans_w:
                    if not isinstance(W, ir.TensorBox):
@@ -593,9 +613,7 @@ def normalize_shapes(inputs, layout_or_out):

        # TODO(jgong5): decide proper number of threads per problem size
        num_threads = parallel_num_threads()
-        new_inputs, _ = normalize_shapes(
-            *maybe_to_dense(*reorder_and_filter(input_nodes, layout))
-        )
+        new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout))
        m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1])
        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
            new_inputs[0].get_dtype()
