Use lists for broadcast #2

Open · wants to merge 1 commit into base: single-broadcast
3 changes: 3 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -95,6 +95,9 @@ def asdict_zerocopy(self,
             for field in fields(self) if field.name not in skip_fields
         }
 
+    def values_list(self) -> List[Any]:
+        return [getattr(self, field.name) for field in fields(self)]
+
 
 T = TypeVar("T", bound=AttentionMetadata)
 
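For context, values_list() simply flattens a metadata dataclass into a plain list in field-declaration order, so the receiving rank can later re-materialize it positionally. A minimal, runnable sketch of the same pattern, not part of the diff (ToyMetadata is a hypothetical stand-in for an AttentionMetadata subclass):

from dataclasses import dataclass, fields
from typing import Any, List


@dataclass
class ToyMetadata:
    # Hypothetical stand-in; a real AttentionMetadata has many more fields.
    num_prefills: int
    slot_mapping: Any  # a torch.Tensor in vLLM

    def values_list(self) -> List[Any]:
        # Same idea as the PR: all field values, in declaration order.
        return [getattr(self, field.name) for field in fields(self)]


meta = ToyMetadata(num_prefills=2, slot_mapping=[0, 1, 2])
assert meta.values_list() == [2, [0, 1, 2]]

The counterpart make_metadata(*values) call in model_runner.py relies on both ranks agreeing on this field order, so the flattening only works if sender and receiver run the same dataclass definition.
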
11 changes: 10 additions & 1 deletion vllm/attention/backends/flashinfer.py
@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from typing import Any, Dict, List, Optional, Set, Tuple, Type
 
 import flashinfer
@@ -138,6 +138,15 @@ def asdict_zerocopy(self,
         skip_fields.add('decode_wrapper')
         return super().asdict_zerocopy(skip_fields)
 
+    def values_list(self) -> List[Any]:
+        # We need to skip the decode_wrapper field since it cannot be
+        # broadcasted with nccl when TP is enabled.
+        return [
+            getattr(self, field.name)
+            if field.name != 'decode_wrapper' else None
+            for field in fields(self)
+        ]
+
     @property
     def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
         # Currently chunked prefill is not supported
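
The decode_wrapper field cannot be broadcast with nccl when TP is enabled (per the comment above), so it is masked out with None in the flattened list and each rank keeps or rebuilds its own wrapper locally. A small sketch of the masking pattern only, not part of the diff (ToyFlashInferMetadata and its fields are hypothetical):

from dataclasses import dataclass, fields
from typing import Any, List


@dataclass
class ToyFlashInferMetadata:
    seq_lens: Any
    decode_wrapper: Any  # not broadcastable, so it is replaced with None

    def values_list(self) -> List[Any]:
        return [
            getattr(self, field.name)
            if field.name != 'decode_wrapper' else None
            for field in fields(self)
        ]


meta = ToyFlashInferMetadata(seq_lens=[3, 5], decode_wrapper=object())
assert meta.values_list() == [[3, 5], None]
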
65 changes: 45 additions & 20 deletions vllm/distributed/communication_op.py
@@ -3,6 +3,7 @@
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
+from itertools import chain
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -196,28 +197,25 @@ def broadcast_object_list(obj_list: List[Any],
 TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
 
 
-def _split_tensor_dict(
-    tensor_dict: Dict[str, Union[torch.Tensor, Any]]
+def _split_object_tensor_list(
+    object_list: List[Union[torch.Tensor, Any]]
 ) -> Tuple[List[Any], List[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
-    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
-    by its metadata.
+    1. The input list with tensors replaced by its metadata.
     2. A list of tensors.
     """
     metadata_list = []
     tensor_list = []
-    for key, value in tensor_dict.items():
+    for value in object_list:
         if isinstance(value, torch.Tensor):
             # Note: we cannot use `value.device` here,
             # because it contains not only the device type but also the device
             # index (e.g. "cuda:0"). We only need the device type.
             # receiving side will set the device index.
             device = "cpu" if value.is_cpu else "cuda"
-            metadata_list.append(
-                (key, TensorMetadata(device, value.dtype, value.size())))
             tensor_list.append(value)
-        else:
-            metadata_list.append((key, value))
+            value = TensorMetadata(device, value.dtype, value.size())
+        metadata_list.append(value)
     return metadata_list, tensor_list
 
 
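To make the split concrete: the sender turns its payload into (a) a picklable metadata list in which every tensor is replaced by a TensorMetadata entry, and (b) a list of the raw tensors that are later sent over the tensor group. A runnable, single-process sketch of the same logic, not part of the diff (the local _split helper simply mirrors _split_object_tensor_list):

from collections import namedtuple
from typing import Any, List, Tuple, Union

import torch

TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def _split(object_list: List[Union[torch.Tensor, Any]]
           ) -> Tuple[List[Any], List[torch.Tensor]]:
    metadata_list, tensor_list = [], []
    for value in object_list:
        if isinstance(value, torch.Tensor):
            device = "cpu" if value.is_cpu else "cuda"
            tensor_list.append(value)
            value = TensorMetadata(device, value.dtype, value.size())
        metadata_list.append(value)
    return metadata_list, tensor_list


metadata_list, tensor_list = _split([3, torch.zeros(2, 2), "hello"])
# metadata_list[1] is TensorMetadata(device='cpu', dtype=torch.float32, size=torch.Size([2, 2]))
# tensor_list holds the single (2, 2) tensor
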
@@ -247,19 +245,48 @@ def broadcast_tensor_dict(
     metadata_group: Optional[ProcessGroup] = None,
     callsite_id: int = 0,
 ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    """Broadcast the input list comprising tensors and/or other objects.
+    `group` is used to broadcast the tensors, while `metadata_group` is used
+    to broadcast the metadata of the list (e.g. tensor sizes, dtypes,
+    non-tensor objects).
+    `callsite_id` should be obtained for a particular call site by
+    calling register_broadcast_callsite().
+    """
+    if tensor_dict is not None:
+        assert isinstance(tensor_dict, dict), \
+            f"Expecting a dictionary, got {type(tensor_dict)}"
+        send_list = list(chain(tensor_dict.keys(), tensor_dict.values()))
+        broadcast_tensor_list(send_list, src, group, metadata_group,
+                              callsite_id)
+        return tensor_dict
+
+    recv_list = broadcast_tensor_list(None, src, group, metadata_group,
+                                      callsite_id)
+    assert recv_list is not None
+    count = len(recv_list) // 2
+    return {recv_list[i]: recv_list[count + i] for i in range(count)}
+
+
+def broadcast_tensor_list(
+    object_list: Optional[List[Union[torch.Tensor, Any]]] = None,
+    src: int = 0,
+    group: Optional[ProcessGroup] = None,
+    metadata_group: Optional[ProcessGroup] = None,
+    callsite_id: int = 0,
+) -> Optional[List[Union[torch.Tensor, Any]]]:
     """Broadcast the input tensor dictionary.
     `group` is used to broadcast the tensors, while `metadata_group` is used
     to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
     dtypes).
     `callsite_id` should be obtained for a particular call site by
-    calling register_broadcast_callsite()
+    calling register_broadcast_callsite().
     """
     group = group or torch.distributed.group.WORLD
 
     # Bypass the function if we are using only 1 GPU.
     if (not torch.distributed.is_initialized()
             or torch.distributed.get_world_size(group=group) == 1):
-        return tensor_dict
+        return object_list
 
     metadata_group = metadata_group or get_cpu_world_group()
     ranks = torch.distributed.get_process_group_ranks(group)
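
The rewritten broadcast_tensor_dict is now only a compatibility shim: it packs the dict's keys followed by its values into one flat list, hands that to broadcast_tensor_list, and rebuilds the dict on the receiving side from the two halves. The pack/unpack arithmetic in isolation, with no process group involved (toy dict, for illustration only):

from itertools import chain

tensor_dict = {"num_seq_groups": 4, "blocks_to_copy": [0, 1]}

# Sender side: keys first, then values, as in broadcast_tensor_dict.
send_list = list(chain(tensor_dict.keys(), tensor_dict.values()))
# ['num_seq_groups', 'blocks_to_copy', 4, [0, 1]]

# Receiver side: the first half are keys, the second half the matching values.
count = len(send_list) // 2
rebuilt = {send_list[i]: send_list[count + i] for i in range(count)}
assert rebuilt == tensor_dict

This relies on dict insertion order being stable (guaranteed since Python 3.7) and on the keys themselves being picklable, since they now travel inside the metadata payload.
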
@@ -270,10 +297,8 @@
 
     rank = torch.distributed.get_rank()
     if rank == src:
-        assert isinstance(
-            tensor_dict,
-            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
-        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
+        assert isinstance(object_list, list)
+        metadata_list, tensor_list = _split_object_tensor_list(object_list)
         pickled = pickle.dumps(metadata_list)
         size = len(pickled) + 8
 
@@ -331,11 +356,11 @@
 
     recv_metadata = pickle.loads(BROADCAST_BUFFER[8:callsite_tensor_size])
 
-    tensor_dict = {}
+    object_list = []
     async_handles = []
-    for key, value in recv_metadata:
+    for value in recv_metadata:
         if not isinstance(value, TensorMetadata):
-            tensor_dict[key] = value
+            object_list.append(value)
             continue
 
         tensor = torch.empty(value.size,
@@ -351,8 +376,8 @@
                                              async_op=True)
         async_handles.append(handle)
 
-        tensor_dict[key] = tensor
+        object_list.append(tensor)
 
     for async_handle in async_handles:
         async_handle.wait()
-    return tensor_dict
+    return object_list
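
On the receiving ranks, the unpickled metadata list is enough to pre-allocate destination tensors; every TensorMetadata entry becomes an empty tensor that the real code then fills with an async torch.distributed.broadcast, while everything else is passed through unchanged. A single-process sketch of that reconstruction step, not part of the diff (rebuild is an illustrative helper, and the actual NCCL transfers are omitted):

from collections import namedtuple
from typing import Any, List

import torch

TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def rebuild(recv_metadata: List[Any]) -> List[Any]:
    object_list = []
    for value in recv_metadata:
        if isinstance(value, TensorMetadata):
            # The real implementation broadcasts the sender's data into this
            # buffer with torch.distributed.broadcast(..., async_op=True).
            value = torch.empty(value.size, dtype=value.dtype, device=value.device)
        object_list.append(value)
    return object_list


recv_metadata = [3, TensorMetadata("cpu", torch.float32, torch.Size([2, 2])), "hello"]
out = rebuild(recv_metadata)
assert out[0] == 3 and out[1].shape == (2, 2) and out[2] == "hello"
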
51 changes: 22 additions & 29 deletions vllm/worker/model_runner.py
@@ -9,7 +9,7 @@
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
-from vllm.distributed import (broadcast_tensor_dict, graph_capture,
+from vllm.distributed import (broadcast_tensor_list, graph_capture,
                               register_broadcast_callsite)
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
@@ -621,40 +621,33 @@ def prepare_input_tensors(
                 seq_group_metadata_list, seq_lens, query_lens, self.device,
                 self.pin_memory)
 
-            metadata_dict = {
-                "input_tokens": input_tokens,
-                "input_positions": input_positions,
-                "selected_token_indices":
+            metadata_list = [
+                input_tokens,
+                input_positions,
                 sampling_metadata.selected_token_indices,
-                "lora_requests": lora_requests,
-                "lora_mapping": lora_mapping,
-                "multi_modal_input": multi_modal_input,
-                "num_prefill_tokens": num_prefill_tokens,
-                "num_decode_tokens": num_decode_tokens,
-                "slot_mapping": slot_mapping,
-                "num_prefills": num_prefills,
-            }
+                lora_requests,
+                lora_mapping,
+                multi_modal_input,
+            ]
             if attn_metadata:
-                metadata_dict.update(attn_metadata.asdict_zerocopy())
-            broadcast_tensor_dict(
-                metadata_dict,
+                metadata_list.extend(attn_metadata.values_list())
+            broadcast_tensor_list(
+                metadata_list,
                 src=0,
                 callsite_id=self.prepare_input_tensors_callsite_id)
         else:
-            metadata_dict = broadcast_tensor_dict(
+            metadata_list = broadcast_tensor_list(
                 src=0, callsite_id=self.prepare_input_tensors_callsite_id)
-            input_tokens = metadata_dict.pop("input_tokens")
-            input_positions = metadata_dict.pop("input_positions")
-            selected_token_indices = metadata_dict.pop(
-                "selected_token_indices")
-            lora_mapping = metadata_dict.pop("lora_mapping")
-            lora_requests = metadata_dict.pop("lora_requests")
-            multi_modal_input = metadata_dict.pop("multi_modal_input")
-            if metadata_dict:
-                attn_metadata = self.attn_backend.make_metadata(
-                    **metadata_dict)
-            else:
-                attn_metadata = None
+            (
+                input_tokens,
+                input_positions,
+                selected_token_indices,
+                lora_requests,
+                lora_mapping,
+                multi_modal_input,
+            ) = metadata_list[:6]
+            attn_metadata = None if len(metadata_list) == 6 else \
+                self.attn_backend.make_metadata(*metadata_list[6:])
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
                 selected_token_indices=selected_token_indices,
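
The new contract in prepare_input_tensors is purely positional: the first six entries of the broadcast list are a fixed prefix, and whatever follows is the flattened attention metadata fed back into make_metadata positionally. A schematic version of the worker-side unpacking with toy values, not part of the diff:

# Flat list as a worker rank might receive it (toy values).
metadata_list = ["tokens", "positions", "sel_idx", None, None, None, 7, 9]

prefix, attn_fields = metadata_list[:6], metadata_list[6:]
(input_tokens, input_positions, selected_token_indices,
 lora_requests, lora_mapping, multi_modal_input) = prefix

# In vLLM this would be self.attn_backend.make_metadata(*attn_fields);
# a plain tuple stands in for the reconstructed metadata object here.
attn_metadata = None if not attn_fields else tuple(attn_fields)
assert attn_metadata == (7, 9)

Because the unpacking is positional, the driver-side list construction and values_list() must keep their ordering in lockstep; reordering either side silently scrambles the worker's view of the metadata.
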
24 changes: 10 additions & 14 deletions vllm/worker/worker.py
@@ -1,15 +1,15 @@
"""A GPU worker class."""
import gc
import os
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import List, Optional, Set, Tuple, Union

import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
from vllm.distributed import (broadcast_tensor_list,
ensure_model_parallel_initialized,
init_distributed_environment,
register_broadcast_callsite,
@@ -259,19 +259,15 @@ def execute_model(
             blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                           device=self.device,
                                           dtype=torch.int64).view(-1, 2)
-            data: Dict[str, Any] = {
-                "num_seq_groups": num_seq_groups,
-                "blocks_to_swap_in": blocks_to_swap_in,
-                "blocks_to_swap_out": blocks_to_swap_out,
-                "blocks_to_copy": blocks_to_copy,
-            }
-            broadcast_tensor_dict(data, src=0, callsite_id=callsite_id)
+            data = [
+                num_seq_groups, blocks_to_swap_in, blocks_to_swap_out,
+                blocks_to_copy
+            ]
+            broadcast_tensor_list(data, src=0, callsite_id=callsite_id)
         else:
-            data = broadcast_tensor_dict(src=0, callsite_id=callsite_id)
-            num_seq_groups = data["num_seq_groups"]
-            blocks_to_swap_in = data["blocks_to_swap_in"]
-            blocks_to_swap_out = data["blocks_to_swap_out"]
-            blocks_to_copy = data["blocks_to_copy"]
+            data = broadcast_tensor_list(src=0, callsite_id=callsite_id)
+            (num_seq_groups, blocks_to_swap_in, blocks_to_swap_out,
+             blocks_to_copy) = data
 
         self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
 