[Optimization] Compile RDMA by default, reduce CUDA graph buffer size in multimodal models, fix some config bugs (#5121)
* default compile rdma, reduce cudagraph buffer size in mm, fix some config logic
* update
* update
* fix bug
* enhance rdma compile
* fix
.github/workflows/_build_linux.yml (vendored) | 2
@@ -164,7 +164,7 @@ jobs:
             python -m pip install -r requirements.txt
             python -m pip install wheel
             # Compile RDMA
-            export ENABLE_FD_RDMA=1
+            export FD_ENABLE_RDMA_COMPILE=1
             bash build.sh 1 python false [${COMPILE_ARCH}]
             ls ./dist/*.whl
             '
@@ -902,6 +902,12 @@ class GraphOptimizationConfig:
             draft_capture_sizes.append(max_capture_size)
         self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
 
+    def filter_capture_size(self, tp_size: int = 1):
+        """When TSP is used, capture size must be divisible by tp size."""
+        self.cudagraph_capture_sizes = [
+            draft_size for draft_size in self.cudagraph_capture_sizes if (draft_size % tp_size == 0)
+        ]
+
     def to_json_string(self):
         """
         Convert speculative_config to json string.
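For reference, a minimal standalone sketch of what the new filter_capture_size hook does to the capture-size list; the values below are made up for illustration, not taken from the repository.

# Hypothetical capture sizes and tensor-parallel degree, for illustration only.
cudagraph_capture_sizes = [1, 2, 3, 4, 6, 8, 16]
tp_size = 4

# Same divisibility rule as GraphOptimizationConfig.filter_capture_size:
# when TSP is enabled, only batch sizes divisible by tp_size may be captured.
cudagraph_capture_sizes = [s for s in cudagraph_capture_sizes if s % tp_size == 0]
print(cudagraph_capture_sizes)  # -> [4, 8, 16]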
@@ -1628,7 +1634,15 @@ class FDConfig:
         if self.device_config is not None and self.device_config.device_type != "cuda":
             self.graph_opt_config.use_cudagraph = False
             logger.info(f"CUDAGraph only support on GPU, current device type is {self.device_config.device_type}!")
 
+        if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph:
+            if self.scheduler_config.max_num_seqs < self.parallel_config.tensor_parallel_size:
+                self.parallel_config.use_sequence_parallel_moe = False
+                logger.info(
+                    "Warning: sequence parallel moe do not support max_num_seqs < tensor_parallel_size when cudagraph enabled. We set use_sequence_parallel_moe to False."
+                )
+            else:
+                # It will hang when real batch_size < tp_size
+                self.graph_opt_config.filter_capture_size(tp_size=self.parallel_config.tensor_parallel_size)
         if self.model_config.enable_mm and self.graph_opt_config.use_cudagraph:
             self.cache_config.enable_prefix_caching = False
             logger.info("Multi-modal models do not support prefix caching when using CUDAGraph!")
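A hedged restatement of the guard above as a standalone function. Names and values are illustrative; the real logic lives in FDConfig and mutates the config objects in place.

def resolve_tsp_with_cudagraph(max_num_seqs, tp_size, capture_sizes):
    """Illustrative only: returns (use_sequence_parallel_moe, capture_sizes)."""
    if max_num_seqs < tp_size:
        # Sequence-parallel MoE cannot form a full tp_size batch, so disable it.
        return False, capture_sizes
    # Otherwise keep it, but drop capture sizes not divisible by tp_size,
    # since replaying a graph with real batch_size < tp_size would hang.
    return True, [s for s in capture_sizes if s % tp_size == 0]

print(resolve_tsp_with_cudagraph(max_num_seqs=2, tp_size=4, capture_sizes=[1, 2, 4, 8]))
# -> (False, [1, 2, 4, 8])
print(resolve_tsp_with_cudagraph(max_num_seqs=64, tp_size=4, capture_sizes=[1, 2, 4, 8]))
# -> (True, [4, 8])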
@@ -512,8 +512,10 @@ class EngineArgs:
                 raise ValueError(
                     "Please set --rdma_comm_ports argument when using " "rdma cache transfer protocol."
                 )
-            if len(self.rdma_comm_ports) != self.tensor_parallel_size:
-                raise ValueError("The number of rdma comm ports must be equal to tensor parallel size.")
+            if len(self.rdma_comm_ports) != self.tensor_parallel_size * self.data_parallel_size:
+                raise ValueError(
+                    f"The number of rdma comm ports must be equal to number of ranks ({self.data_parallel_size=} * {self.tensor_parallel_size=} = {self.data_parallel_size * self.tensor_parallel_size}), but got {len(self.rdma_comm_ports)}."
+                )
 
         if envs.ENABLE_V1_KVCACHE_SCHEDULER == 1:
             if "ipc" in self.cache_transfer_protocol:
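The validation now counts one RDMA comm port per rank rather than per tensor-parallel shard. A small hedged example with made-up sizes:

# Illustrative values only.
data_parallel_size = 2
tensor_parallel_size = 4
rdma_comm_ports = list(range(7000, 7008))  # 8 ports

required = data_parallel_size * tensor_parallel_size  # one port per rank
if len(rdma_comm_ports) != required:
    raise ValueError(f"need {required} rdma comm ports, got {len(rdma_comm_ports)}")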
@@ -570,10 +570,11 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
         self.ernie = Ernie4_5_VLModel(fd_config=fd_config)
 
         # Persistent buffers for CUDA graphs.
-        self._input_embeddings = paddle.zeros(
-            [fd_config.model_config.max_model_len, fd_config.model_config.hidden_size],
-            dtype=fd_config.model_config.dtype,
-        )
+        if fd_config.graph_opt_config.use_cudagraph:
+            self._decoder_input_embeddings = paddle.zeros(
+                [fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
+                dtype=fd_config.model_config.dtype,
+            )
 
         self.ori_vocab_size = fd_config.model_config.ori_vocab_size
 
@@ -783,10 +784,13 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
                 image_features=image_features,
                 image_token_num=vl_moe_meta.num_image_patch_id.item(),
             )
-            self._input_embeddings.copy_(input_embeddings, False)
+
+        if forward_meta.step_use_cudagraph:
+            self._decoder_input_embeddings.copy_(input_embeddings, False)
+            input_embeddings = self._decoder_input_embeddings
 
         hidden_states = self.ernie(
-            input_embeddings=self._input_embeddings,
+            input_embeddings=input_embeddings,
             ids_remove_padding=ids_remove_padding,
             forward_meta=forward_meta,
             vl_moe_meta=vl_moe_meta,
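The same pattern repeats in the reward model, PaddleOCR-VL, and Qwen2.5-VL hunks below: allocate one fixed buffer sized by max_capture_size, and on graph-captured decode steps copy the fresh embeddings into it so the captured graph always reads from the same address. A hedged, self-contained sketch of the idea; shapes, dtype, and names are illustrative, not the repository's exact code.

import paddle

max_capture_size, hidden_size = 256, 1024  # illustrative sizes

# Persistent buffer allocated once; CUDA graph replay requires stable addresses.
decoder_input_embeddings = paddle.zeros([max_capture_size, hidden_size], dtype="float32")

def prepare_embeddings(input_embeddings, step_use_cudagraph):
    if step_use_cudagraph:
        # In-place copy into the persistent buffer, then hand the buffer onward.
        decoder_input_embeddings.copy_(input_embeddings, False)
        return decoder_input_embeddings
    return input_embeddings

# Example decode step already padded to the capture size.
step = paddle.ones([max_capture_size, hidden_size], dtype="float32")
out = prepare_embeddings(step, step_use_cudagraph=True)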
@@ -59,10 +59,11 @@ class Ernie4_5_VLMoeRewardBaseModel(nn.Layer):
         self.head_dtype = paddle.bfloat16
 
         # Persistent buffers for CUDA graphs.
-        self._input_embeddings = paddle.zeros(
-            [fd_config.parallel_config.max_model_len, fd_config.model_config.hidden_size],
-            dtype=fd_config.model_config.dtype,
-        )
+        if fd_config.graph_opt_config.use_cudagraph:
+            self._decoder_input_embeddings = paddle.zeros(
+                [fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
+                dtype=fd_config.model_config.dtype,
+            )
 
         self.rm_head = nn.Sequential(
             (
@@ -112,10 +113,13 @@ class Ernie4_5_VLMoeRewardBaseModel(nn.Layer):
                 image_features=image_features,
                 image_token_num=vl_moe_meta.image_token_num.item(),
             )
-            self._input_embeddings.copy_(input_embeddings, False)
+
+        if forward_meta.step_use_cudagraph:
+            self._decoder_input_embeddings.copy_(input_embeddings, False)
+            input_embeddings = self._decoder_input_embeddings
 
         hidden_states = self.ernie(
-            input_embeddings=self._input_embeddings,
+            input_embeddings=input_embeddings,
             ids_remove_padding=ids_remove_padding,
             forward_meta=forward_meta,
             vl_moe_meta=vl_moe_meta,
@@ -132,10 +132,11 @@ class PaddleOCRVLForConditionalGeneration(ModelForCasualLM):
         )
 
         # Persistent buffers for CUDA graphs.
-        self._decoder_input_embeddings = paddle.zeros(
-            [fd_config.scheduler_config.max_num_seqs, fd_config.model_config.hidden_size],
-            dtype=fd_config.model_config.dtype,
-        )
+        if fd_config.graph_opt_config.use_cudagraph:
+            self._decoder_input_embeddings = paddle.zeros(
+                [fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
+                dtype=fd_config.model_config.dtype,
+            )
 
     @paddle.no_grad()
     def load_weights(self, weights_iterator) -> None:
@@ -242,15 +243,11 @@ class PaddleOCRVLForConditionalGeneration(ModelForCasualLM):
 
         if forward_meta.step_use_cudagraph:
             self._decoder_input_embeddings.copy_(input_embeddings, False)
-
-            hidden_states = self.model(
-                input_embeddings=self._decoder_input_embeddings,
-                forward_meta=forward_meta,
-            )
-        else:
-            hidden_states = self.model(
-                input_embeddings=input_embeddings,
-                forward_meta=forward_meta,
-            )
+            input_embeddings = self._decoder_input_embeddings
+
+        hidden_states = self.model(
+            input_embeddings=input_embeddings,
+            forward_meta=forward_meta,
+        )
 
         return hidden_states
@@ -152,10 +152,11 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
         self.model = Qwen2_5_VLModel(fd_config=fd_config)
 
         # Persistent buffers for CUDA graphs.
-        self._input_embeddings = paddle.zeros(
-            [fd_config.model_config.max_model_len, fd_config.model_config.hidden_size],
-            dtype=fd_config.model_config.dtype,
-        )
+        if fd_config.graph_opt_config.use_cudagraph:
+            self._decoder_input_embeddings = paddle.zeros(
+                [fd_config.graph_opt_config.max_capture_size, fd_config.model_config.hidden_size],
+                dtype=fd_config.model_config.dtype,
+            )
 
         self.ori_vocab_size = fd_config.model_config.ori_vocab_size
 
@@ -290,10 +291,13 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
             input_embeddings = self.get_input_embeddings(
                 ids_remove_padding=ids_remove_padding, image_features=image_features
            )
-            self._input_embeddings.copy_(input_embeddings, False)
+
+        if forward_meta.step_use_cudagraph:
+            self._decoder_input_embeddings.copy_(input_embeddings, False)
+            input_embeddings = self._decoder_input_embeddings
 
         hidden_states = self.model(
-            input_embeddings=self._input_embeddings,
+            input_embeddings=input_embeddings,
             ids_remove_padding=ids_remove_padding,
             image_features=image_features,
             forward_meta=forward_meta,
setup.py | 68
@@ -14,10 +14,12 @@
 # limitations under the License.
 """
 
+import glob
 import os
 import re
 import subprocess
 import sys
+from functools import lru_cache
 from pathlib import Path
 
 import paddle
@@ -180,6 +182,68 @@ def get_device_type():
         return "cpu"
 
 
+def check_header(header_path):
+    return os.path.exists(header_path)
+
+
+def check_library(lib_name):
+    # search /usr/lib /usr/lib64 /lib /lib64 etc.
+    paths = [
+        "/usr/lib",
+        "/usr/lib32",
+        "/usr/lib64",
+        "/usr/lib/x86_64-linux-gnu",
+        "/lib",
+        "/lib32",
+        "/lib64",
+        "/usr/local/lib",
+        "/usr/local/lib64",
+    ]
+    for p in paths:
+        if glob.glob(os.path.join(p, lib_name)):
+            return True
+    return False
+
+
+def check_rdma_packages():
+    results = {}
+
+    # libibverbs-dev
+    results["libibverbs header"] = check_header("/usr/include/infiniband/verbs.h")
+    results["libibverbs library"] = check_library("libibverbs.so*") or check_library("libibverbs.so")
+
+    # librdmacm-dev
+    results["librdmacm header"] = check_header("/usr/include/rdma/rdma_cma.h")
+    results["librdmacm library"] = check_library("librdmacm.so*") or check_library("librdmacm.so")
+
+    print("===== RDMA Library Check Results =====")
+    for k, v in results.items():
+        status = "FOUND" if v else "NOT FOUND"
+        print(f"{k:25}: {status}")
+
+    print("\n== Summary ==")
+    if all(results.values()):
+        print("All required RDMA libraries are installed.")
+        return True
+    else:
+        print("Some RDMA libraries are missing. Suggested commands:")
+        print("\nUbuntu/Debian:")
+        print("  sudo apt-get install -y libibverbs-dev librdmacm-dev")
+        print("\nCentOS/RHEL:")
+        print("  sudo yum install -y libibverbs-devel librdmacm-devel")
+        return False
+
+
+@lru_cache(maxsize=1)
+def rdma_comm_supported():
+    supported = (
+        get_device_type() in ["gpu", "xpu"]
+        and check_rdma_packages()
+        and os.getenv("FD_ENABLE_RDMA_COMPILE", "1") == "1"
+    )
+    return supported
+
+
 def get_name():
     """get package name"""
     return "fastdeploy-" + get_device_type()
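Taken together with the workflow change above, RDMA support is now opt-out: rdma_comm_supported() treats FD_ENABLE_RDMA_COMPILE as defaulting to "1" and additionally requires a GPU/XPU build plus the ibverbs/rdmacm development packages. A hedged sketch of the same gate as standalone code; the search paths and the device string are assumptions for illustration.

import glob
import os

def _has_lib(pattern, dirs=("/usr/lib", "/usr/lib64", "/usr/lib/x86_64-linux-gnu")):
    # Same idea as check_library: match versioned sonames such as libibverbs.so.1.
    return any(glob.glob(os.path.join(d, pattern)) for d in dirs)

device_type = "gpu"  # assumed; setup.py derives this from the installed paddle build
rdma_build_enabled = (
    device_type in ("gpu", "xpu")
    and os.path.exists("/usr/include/infiniband/verbs.h")
    and _has_lib("libibverbs.so*")
    and os.path.exists("/usr/include/rdma/rdma_cma.h")
    and _has_lib("librdmacm.so*")
    and os.getenv("FD_ENABLE_RDMA_COMPILE", "1") == "1"  # set to "0" to opt out
)
print("build RDMA KV-cache transfer extension:", rdma_build_enabled)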
@@ -237,10 +301,10 @@ setup(
                     version=None,
                 )
             ]
-            if os.getenv("ENABLE_FD_RDMA", "0") == "1"
+            if rdma_comm_supported()
             else []
         ),
-    cmdclass=cmdclass_dict if os.getenv("ENABLE_FD_RDMA", "0") == "1" else {},
+    cmdclass=cmdclass_dict if rdma_comm_supported() else {},
     zip_safe=False,
     classifiers=[
         "Programming Language :: Python :: 3",