feat: support adaptive avg pool2d and pool3d dynamo converters
zewenli98 committed Mar 21, 2024
1 parent db11ce6 · commit e2e0aa5
Showing 3 changed files with 225 additions and 3 deletions.
40 changes: 39 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2185,7 +2185,7 @@ def aten_ops_avg_pool(


@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default)
-def aten_ops_adaptive_avg_pool(
+def aten_ops_adaptive_avg_pool1d(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
@@ -2202,6 +2202,44 @@ def aten_ops_adaptive_avg_pool(
)


def adaptive_pool_static_input_validator(pool_node: Node) -> bool:
    # Convert only when every requested output dim is a static, positive int;
    # anything else falls back to PyTorch execution.
    output_size = args_bounds_check(pool_node.args, 1)
    return all(x > 0 for x in output_size)
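The validator gates which graph nodes the converters below may claim: only calls whose requested output size is fully static and positive are lowered to TRT, while everything else falls back to PyTorch execution. A minimal sketch of the contract it enforces (the tuples below are hypothetical stand-ins for a torch.fx node's args):

```python
# Hypothetical args mirroring aten.adaptive_avg_pool2d(input, output_size)
ok_args = ("input_tensor", [7, 7])
bad_args = ("input_tensor", [0, 7])

assert all(x > 0 for x in ok_args[1])        # convertible: static, positive dims
assert not all(x > 0 for x in bad_args[1])   # rejected: runs in PyTorch instead
```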


@dynamo_tensorrt_converter(
torch.ops.aten.adaptive_avg_pool2d.default,
capability_validator=adaptive_pool_static_input_validator,
)
@dynamo_tensorrt_converter(
torch.ops.aten._adaptive_avg_pool2d.default,
capability_validator=adaptive_pool_static_input_validator,
)
@dynamo_tensorrt_converter(
torch.ops.aten.adaptive_avg_pool3d.default,
capability_validator=adaptive_pool_static_input_validator,
)
@dynamo_tensorrt_converter(
torch.ops.aten._adaptive_avg_pool3d.default,
capability_validator=adaptive_pool_static_input_validator,
)
def aten_ops_adaptive_avg_poolNd(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.pool.adaptive_avg_poolNd(
ctx,
target,
source_ir=SourceIR.ATEN,
name=name,
input=args[0],
output_size=args[1],
)
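Stacking four @dynamo_tensorrt_converter decorators registers this one function for the public and underscore-prefixed ATen ops in both the 2d and 3d variants. A quick eager-mode check of why a single converter can serve a public/internal pair (a sketch assuming a local PyTorch install; the 3d pair behaves the same way):

```python
import torch

x = torch.randn(1, 3, 32, 32)
public = torch.ops.aten.adaptive_avg_pool2d.default(x, [7, 7])
internal = torch.ops.aten._adaptive_avg_pool2d.default(x, [7, 7])
assert torch.allclose(public, internal)  # identical semantics, one converter
```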


def max_pool_param_validator(pool_node: Node) -> bool:
dilation = args_bounds_check(pool_node.args, 4, 1)
ceil_mode = args_bounds_check(pool_node.args, 5, False)
186 changes: 185 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/pool.py
@@ -6,7 +6,10 @@
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    extend_attr_to_tuple,
+    get_positive_dim,
+)
from torch_tensorrt.fx.converters.converter_utils import (
has_dynamic_shape,
set_layer_name,
@@ -169,3 +172,184 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int:

output = impl.cat.cat(ctx, target, source_ir, f"{name}_cat", output_list, dim=-1)
return output
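For reference, the start_index/end_index helpers visible in the hunk context implement PyTorch's adaptive-pooling window arithmetic. A standalone sketch of that arithmetic (for intuition only, not part of the diff):

```python
import math

def start_index(idx: int, out_dim: int, in_dim: int) -> int:
    return (idx * in_dim) // out_dim

def end_index(idx: int, out_dim: int, in_dim: int) -> int:
    return math.ceil((idx + 1) * in_dim / out_dim)

# in_dim=5, out_dim=3 -> overlapping windows [0, 2), [1, 4), [3, 5)
assert [(start_index(i, 3, 5), end_index(i, 3, 5)) for i in range(3)] == [
    (0, 2),
    (1, 4),
    (3, 5),
]
```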


def adaptive_avg_poolNd(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
output_size: Sequence[int],
) -> TRTTensor:
input_rank = len(input.shape)

    if input_rank == 3:  # TRT pooling layers require >= 4D input, so prepend a batch dim
input = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape", input, (1, *input.shape)
)

extend_len = len(output_size)
output_size = list(output_size)
original_input = input

    # repeat_interleave the input along any axis whose output dim exceeds its input dim
    input_shape = input.shape
    insert_axes = []
for axis in range(1, extend_len + 1):
axis = -axis
        positive_axis = get_positive_dim(
            axis, input_rank
        )  # positive form of the axis, used when building new shapes below
input_dim = input_shape[axis]
output_dim = output_size[axis]
diff = output_dim - input_dim
if diff > 0: # the dim of output is larger than input
times = output_dim // input_dim
remainder = output_dim % input_dim
        if (
            diff == 2 and remainder == 2
        ):  # case 1: output_dim == input_dim + 2 and is not a multiple of input_dim
            insert_axes.append(axis)
            remainder -= 1
            output_size[axis] -= 1

        if (
            remainder + 1 == input_dim
        ):  # case 2: remainder + 1 == input_dim; repeat_interleave the whole input once more
            remainder = 0
            times += 1
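        # Worked example (illustrative numbers, not from the diff): input_dim = 3,
        # output_dim = 7 -> times = 2, remainder = 1; the first and last slices are
        # repeated 3x (mirrored via `flags` below) and the middle one 2x, stretching
        # the axis from 3 to 8, which the stride/kernel arithmetic later pools to 7.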

        flags = []  # mirrored slice indices that get one extra repeat
        concat_list = []
        # repeat each slice along this axis either `times` or `times + 1` times
        for j in range(input_dim):
single_elem = impl.select.select(
ctx, target, source_ir, f"{name}_select_{axis}_{j}", input, axis, j
)
new_shape = list(single_elem.shape)
new_shape.insert(positive_axis, 1)
single_elem = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_{axis}_{j}",
single_elem,
new_shape,
)
            if remainder > 0 or j in flags:
                # give this slice, and symmetrically its mirror, one extra repeat
                concat_list.extend([single_elem] * (times + 1))
                remainder -= 2
                flags.append(input_dim - j - 1)
else:
concat_list.extend([single_elem] * times)
out = impl.cat.cat(
ctx, target, source_ir, f"{name}_cat_{axis}", concat_list, axis
)
input = out

stride = tuple(
input.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
)
kernel_size = tuple(
input.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
for i in range(extend_len)
)
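    # Uniform-pooling arithmetic, continuing the illustrative numbers above: a
    # stretched length of 8 with output_size 7 gives stride = 8 // 7 = 1 and
    # kernel = 8 - (7 - 1) * 1 = 2, i.e. seven stride-1 windows of width 2.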

    # Pooling would be a no-op (unit window and stride), so return the input directly
if all(s == 1 for s in stride) and all(k == 1 for k in kernel_size):
if input_rank == 3: # reshape back to 3D
input = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_back",
input,
(*input.shape[1:],),
)
return input

layer = ctx.net.add_pooling_nd(
input=input, type=trt.PoolingType.AVERAGE, window_size=kernel_size
)
layer.stride_nd = stride
set_layer_name(layer, target, f"{name}_pooling_{extend_len}d", source_ir)

output = layer.get_output(0)

    # For case 1, split the pooled output and re-insert the middle slice(s) of the original input
    for axis in insert_axes:
positive_axis = get_positive_dim(axis, input_rank)
input_dim = input_shape[axis]
output_dim = output_size[axis]
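        # Illustrative numbers: for input_dim = 4 and a requested output of 6,
        # case 1 trimmed output_dim to 5 before pooling; the even branch splits
        # those 5 pooled slices as [2, 1, 2] and swaps in the two middle input
        # slices, restoring the requested 6.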
        if input_dim % 2 == 1:  # odd input dim: one middle slice to re-insert
mid = impl.select.select(
ctx,
target,
source_ir,
f"{name}_select_{axis}",
original_input,
axis,
input_dim // 2,
)
new_shape = list(mid.shape)
new_shape.insert(positive_axis, 1)
mid = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_{axis}", mid, new_shape
)
split_output = impl.split.split(
ctx, target, source_ir, f"{name}_split_{axis}", output, 2, axis
)
split_output.insert(1, mid)
output = impl.cat.cat(
ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
)
        else:  # even input dim: two middle slices to re-insert
mid1 = impl.select.select(
ctx,
target,
source_ir,
f"{name}_select_{axis}",
original_input,
axis,
input_dim // 2 - 1,
)
new_shape = list(mid1.shape)
new_shape.insert(positive_axis, 1)
mid1 = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_{axis}", mid1, new_shape
)
mid2 = impl.select.select(
ctx,
target,
source_ir,
f"{name}_select_{axis}",
original_input,
axis,
input_dim // 2,
)
mid2 = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_{axis}", mid2, new_shape
)
split_output = impl.split.split(
ctx,
target,
source_ir,
f"{name}_split_{axis}",
output,
[output_dim // 2, 1, output_dim // 2],
axis,
)
split_output[1] = mid1
split_output.insert(2, mid2)
output = impl.cat.cat(
ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
)

if input_rank == 3: # reshape back to 3D
output = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],)
)

return output
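End to end, the converter is exercised through the usual dynamo compile path. A minimal usage sketch (assumes a CUDA-enabled torch_tensorrt build; module and shapes are illustrative):

```python
import torch
import torch_tensorrt

class Net(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.adaptive_avg_pool2d(x, (7, 7))

model = Net().eval().cuda()
x = torch.randn(1, 3, 32, 32).cuda()
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[x])
print(trt_model(x).shape)  # torch.Size([1, 3, 7, 7])
```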
2 changes: 1 addition & 1 deletion tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py
@@ -3,7 +3,7 @@
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

-from harness import DispatchTestCase
+from .harness import DispatchTestCase


class TestAdaptiveAvgPoolConverter(DispatchTestCase):
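A representative case in this harness might look like the following sketch (run_test is the DispatchTestCase entry point used throughout these suites; the module and shapes here are illustrative, not the committed tests):

```python
    def test_adaptive_avg_pool2d(self):
        class Pool(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.adaptive_avg_pool2d.default(x, [2, 2])

        self.run_test(Pool(), [torch.randn(1, 3, 5, 5)])
```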