Commit
using nccl ops from TRT-LLM namespace
apbose committed Oct 21, 2024
1 parent 1bb044f commit 195b1c4
Showing 3 changed files with 126 additions and 1 deletion.
57 changes: 56 additions & 1 deletion examples/distributed_inference/tensor_parallel_simple_example.py
@@ -5,18 +5,73 @@
import torch
import torch.nn as nn
import torch_tensorrt
import torch.distributed as dist
from torch.distributed._tensor import Shard
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch.fx.node import Target, Argument
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
from torch_tensorrt.dynamo.types import TRTTensor
import numpy as np
from torch_tensorrt.fx.converters.converter_utils import (
    set_layer_name,
)
import tensorrt as trt
import tensorrt_llm
import ctypes
import logging
"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""

plugin_lib_path = "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
try:
    ctypes.CDLL(plugin_lib_path)
    print("plugin loaded successfully")
except OSError as e:
    print(f"plugin load failed: {e}")
logger = trt.Logger(trt.Logger.VERBOSE)
trt.init_libnvinfer_plugins(None, "")
# Iterate over all registered plugin creators
plugin_registry = trt.get_plugin_registry()
for plugin_creator in plugin_registry.plugin_creator_list:
    print(
        f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
    )


@dynamo_tensorrt_converter(torch.ops._c10d_functional.all_gather_into_tensor.default)
def insert_gather_op(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    plug_inputs = [args[0]]
    # Look up the TRT-LLM AllGather plugin creator registered under the
    # "tensorrt_llm" namespace.
    allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        "AllGather", "1", "tensorrt_llm"
    )
    assert allgather_plg_creator is not None
    # The plugin is configured with the participating ranks and the tensor
    # dtype (fixed to fp16 here) passed as plugin fields.
    world_size = dist.get_world_size()
    group = list(range(world_size))
    group = trt.PluginField(
        "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
    )
    p_dtype = trt.float16
    pf_type = trt.PluginField(
        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
    )
    pfc = trt.PluginFieldCollection([group, pf_type])
    allgather = allgather_plg_creator.create_plugin("allgather", pfc)
    layer = ctx.net.add_plugin_v2(plug_inputs, allgather)
    set_layer_name(layer, target, name)
    return layer.get_output(0)


class ToyModel(nn.Module):
"""MLP based model"""
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/__init__.py
@@ -6,6 +6,7 @@
from .accumulate_fp32_matmul import accumulate_fp32_matmul
from .constant_folding import constant_fold
from .fuse_prims_broadcast import fuse_prims_broadcast
from .fuse_distributed_ops import fuse_distributed_ops
from .lower_linear import lower_linear
from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
from .pass_manager import DynamoPassManager
69 changes: 69 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py
@@ -0,0 +1,69 @@
import logging
from typing import Sequence

import torch

# Helper that runs dead-code elimination, lint, and recompile on the graph in place
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


# Wrappers that express a collective followed by its wait_tensor as a single
# callable, so that the pair can be replaced by one node in the graph.
def custom_fused_all_gather_op(args0, args1, args2):
    return torch.ops._c10d_functional.wait_tensor.default(
        torch.ops._c10d_functional.all_gather_into_tensor.default(args0, args1, args2)
    )


def custom_fused_reduce_scatter_op(args0, args1, args2, args3):
    return torch.ops._c10d_functional.wait_tensor.default(
        torch.ops._c10d_functional.reduce_scatter_tensor.default(
            args0, args1, args2, args3
        )
    )


def fuse_distributed_ops(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    modified_graph = False
    for node in gm.graph.nodes:
        # Match a collective op whose only user is the corresponding wait_tensor.
        if (
            node.target
            in (
                torch.ops._c10d_functional.all_gather_into_tensor.default,
                torch.ops._c10d_functional.reduce_scatter_tensor.default,
            )
            and len(node.users) == 1
            and list(node.users)[0].target
            == torch.ops._c10d_functional.wait_tensor.default
        ):
            wait_tensor_node = list(node.users)[0]
            if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
                fused_op = custom_fused_all_gather_op
                fused_op_args = (node.args[0], node.args[1], node.args[2])
            else:
                fused_op = custom_fused_reduce_scatter_op
                fused_op_args = (
                    node.args[0],
                    node.args[1],
                    node.args[2],
                    node.args[3],
                )
            with gm.graph.inserting_after(wait_tensor_node):
                fused_node = gm.graph.create_node(
                    op="call_function",
                    target=fused_op,  # the fused wrapper defined above
                    args=fused_op_args,
                )

            wait_tensor_node.replace_all_uses_with(fused_node)
            fused_node.meta.update(node.meta)
            modified_graph = True
            gm.graph.erase_node(wait_tensor_node)
            gm.graph.erase_node(node)

    # If the graph was modified, clean it up
    if modified_graph:
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(
            f"Graph after fusing wait_tensor and distributed op:\n{gm.graph}"
        )

    return gm
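
As a rough illustration (a sketch, not part of the commit), the snippet below hand-builds an FX graph containing the all_gather_into_tensor -> wait_tensor pattern and runs fuse_distributed_ops over it; the group size (2) and group name ("0") are placeholder values.

import torch
import torch.fx as fx

from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
    fuse_distributed_ops,
)

# Hand-build a graph with the pattern the pass looks for:
# all_gather_into_tensor whose single user is wait_tensor.
graph = fx.Graph()
x = graph.placeholder("x")
gathered = graph.call_function(
    torch.ops._c10d_functional.all_gather_into_tensor.default,
    args=(x, 2, "0"),  # (input, group_size, group_name) -- placeholder values
)
waited = graph.call_function(
    torch.ops._c10d_functional.wait_tensor.default, args=(gathered,)
)
graph.output(waited)
gm = fx.GraphModule(torch.nn.Module(), graph)

gm = fuse_distributed_ops(gm)
print(gm.graph)  # the pair is now a single call to custom_fused_all_gather_op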
