[Executor] CUDA Graph support padding batch (#2844)

* CUDA Graph supports padded batches

* Integrate the startup parameters for the graph optimization backend and support user-defined capture sizes

* Do not insert max_num_seqs when the user specifies a capture list

* Support setting the graph optimization config from a YAML file

* Update CUDA Graph CI

* Fix CI bug

* Fix CI bug
RAM
2025-07-16 10:49:01 +08:00
committed by GitHub
parent 61b3997b85
commit 0fad10b35a
30 changed files with 291 additions and 225 deletions
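For orientation, here is a minimal sketch (not part of the diff) of how the new graph optimization settings replace the removed enable_static_graph_inference and max_capture_batch_size flags, assuming the GraphOptimizationConfig constructor introduced further down in this commit; the capture sizes are illustrative only.

    from fastdeploy.engine.config import GraphOptimizationConfig

    # Replaces the old --enable-static-graph-inference / --max-capture-batch-size flags.
    graph_opt_cfg = GraphOptimizationConfig(
        graph_opt_level=1,                     # 0: dynamic graph, 1: static graph, 2: static graph + CINN backend
        use_cudagraph=True,
        cudagraph_capture_sizes=[1, 2, 4, 8],  # user-defined capture list; max_num_seqs is not inserted
    )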

View File

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint4
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -3,4 +3,5 @@ max_num_seqs: 96
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.71
tensor_parallel_size: 4
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wfp8afp8
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint4
enable_static_graph_inference: True
graph_optimization_config:
graph_opt_level: 1

View File

@@ -260,27 +260,84 @@ class DeviceConfig:
if hasattr(self, key):
setattr(self, key, value)
@dataclass
class GraphOptimizationConfig:
def init_with_cudagrpah_size(self,
cudagraph_capture_sizes: list[int]) -> None:
"""To complete the initialization of config,
we need to know the cudagraph sizes"""
if self.cudagraph_capture_sizes is None:
self.cudagraph_capture_sizes = cudagraph_capture_sizes
else:
dedup_sizes = list(set(self.cudagraph_capture_sizes))
if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
logger.info(("cudagraph sizes specified by model runner"
" %s is overridden by config %s"),
cudagraph_capture_sizes, dedup_sizes)
self.cudagraph_capture_sizes = dedup_sizes
"""
Configuration for compute graph level optimization.
"""
# sort to make sure cudagraph capture sizes are in descending order
"""The Top-level graph optimization contral corresponds to different backends.
- 0: dyncmic graph
- 1: static graph
- 2: static graph + cinn compilation backend
"""
graph_opt_level: int = 0
# CUDA Graph Config
""" Whether to use cudagraph.
- False: cudagraph is not used.
- True: cudagraph is used.
It requires that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
- With dynamic graph backend: ...
- With static graph backend: WIP
"""
use_cudagraph: bool = False
"""Sizes to capture cudagraph.
- None (default): capture sizes are inferred from llm config.
- list[int]: capture sizes are specified as given."""
cudagraph_capture_sizes: Optional[list[int]] = None
""" Number of warmup runs for cudagraph. """
cudagraph_num_of_warmups: int = 2
"""Whether to copy input tensors for cudagraph.
If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True."""
cudagraph_copy_inputs: bool = False
""" In static graph, this is an operation list that does not need to be captured by the CUDA graph.
CudaGraphBackend will split these operations from the static graph.
Example usage:
cudagraph_splitting_ops = ["paddle.unified_attention"]
Note: If you want to use the subgraph capture functionality in a dynamic graph,
you can manually split the model into multiple layers and apply the @support_cuda_graph decorator
only to the layer where CUDA Graph functionality is required.
"""
cudagraph_splitting_ops = Optional[list[str]]
"""" Whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops."""
full_cuda_graph: bool = True
max_capture_size: int = field(default=None, init=False) # type: ignore
batch_size_to_captured_size: dict[int,
int] = field(default=None,
init=False) # type: ignore
# CINN Config ...
def init_with_cudagrpah_size(
self,
max_num_seqs:int = 0
) -> None:
"""
Initialize cuda graph capture sizes and
pre-compute the mapping from batch size to padded graph size
"""
# Regular capture sizes
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size < max_num_seqs]
dedup_sizes = list(set(self.cudagraph_capture_sizes))
if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
logger.info(("cudagraph sizes specified by model runner"
" %s is overridden by config %s"),
cudagraph_capture_sizes, dedup_sizes)
self.cudagraph_capture_sizes = dedup_sizes
# Sort to make sure cudagraph capture sizes are in descending order
self.cudagraph_capture_sizes.sort(reverse=True)
self.max_capture_size = self.cudagraph_capture_sizes[
0] if self.cudagraph_capture_sizes else 0
# pre-compute the mapping from batch size to padded graph size
# Pre-compute the mapping from batch size to padded graph size
self.batch_size_to_captured_size = {}
for end, start in zip(self.cudagraph_capture_sizes,
self.cudagraph_capture_sizes[1:] + [0]):
@@ -292,68 +349,24 @@ class GraphOptimizationConfig:
self.batch_size_to_captured_size[
self.max_capture_size] = self.max_capture_size
def __init__(self,
enable_static_graph_inference: bool = False,
max_capture_batch_size: int = 64,
args = None):
"""The Top-level graph optimization contral corresponds to different backends.
- 0: dyncmic graph
- 1: static graph
- 2: static graph + cinn compilation backend
def _set_cudagraph_sizes(
self,
max_num_seqs:int = 0
):
"""
self.graph_opt_level: int = 0
# CUDA Graph Config
""" Whether to use cudagraph.
- False: cudagraph is not used.
- True: cudagraph is used.
It requires that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
- With dynamic graph backend: ...
- With static graph backend: WIP
Calculate a series of candidate capture batch sizes,
and then extract a portion of them as the capture list for the CUDA graph based on user input.
"""
self.use_cudagraph: bool = False
"""Sizes to capture cudagraph.
- None (default): capture sizes are inferred from llm config.
- list[int]: capture sizes are specified as given."""
self.cudagraph_capture_sizes: Optional[list[int]] = None
""" Number of warmup runs for cudagraph. """
self.cudagraph_num_of_warmups: int = 2
"""Whether to copy input tensors for cudagraph.
If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True."""
self.cudagraph_copy_inputs: bool = False
""" In static graph, this is an operation list that does not need to be captured by the CUDA graph.
CudaGraphBackend will split these operations from the static graph.
Example usage:
cudagraph_splitting_ops = ["paddle.unified_attention"]
# Batch Size [1, 2, 4, 8, 16, ... 120, 128]
draft_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
# Batch Size [128, 144, ... 240, 256]
draft_capture_sizes += [16 * i for i in range(9, 17)]
# Batch Size [256, 288, ... 992, 1024]
draft_capture_sizes += [32 * i for i in range(17, 33)]
Note: If want to use subgraph capture functionality in a dynamic graph,
can manually split the model into multiple layers and apply the @support_cuda_graph decorator
only to the layer where CUDA graph functionality is required.
"""
self.cudagraph_splitting_ops = Optional[list[str]]
""""whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops."""
self.full_cuda_graph: bool = False
draft_capture_sizes.append(max_num_seqs)
self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
self.max_capture_size: int = field(default=None, init=False) # type: ignore
self.batch_size_to_captured_size: dict[int,
int] = field(default=None,
init=False) # type: ignore
# CINN Config ...
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
capture_size = [i for i in range(1, max_capture_batch_size + 1)]
self.init_with_cudagrpah_size(cudagraph_capture_sizes=capture_size)
#TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
if enable_static_graph_inference:
self.graph_opt_level = 1
class LoadConfig:
"""
@@ -422,3 +435,13 @@ class FDConfig:
init=True) # type: ignore
kv_cache_config: KVCacheConfig = field(default=None,
init=True) # type: ignore
def __post_init__(self):
# Initialize cuda graph capture list
if self.graph_opt_config.cudagraph_capture_sizes is None:
self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs)
#TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
if self.graph_opt_config.graph_opt_level == 2:
self.graph_opt_config.graph_opt_level = 1

View File

@@ -18,7 +18,8 @@ from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional
from fastdeploy.engine.config import (CacheConfig, Config, ModelConfig,
from fastdeploy.engine.config import (CacheConfig, Config,
GraphOptimizationConfig, ModelConfig,
ParallelConfig, SpeculativeConfig,
TaskOption)
from fastdeploy.scheduler.config import SchedulerConfig
@@ -283,20 +284,13 @@ class EngineArgs:
"""
SplitWise Use, Results Writer Batch Size
"""
enable_static_graph_inference: bool = False
"""
Whether to use static mode
"""
use_cudagraph: bool = False
"""
Flags to enable Cuda Graph
"""
max_capture_batch_size: int = 64
graph_optimization_config: Optional[Dict[str, Any]] = None
"""
Maximum Batch Size for Cuda Graph Capture
NOTE: Now only support to capture continuous batch size,
Example:
max_capture_batch_size=64, FastDeploy will capture graphs for batches [1,64].
Configuration for graph optimization backend execution.
"""
enable_logprob: bool = False
@@ -399,21 +393,14 @@ class EngineArgs:
"default is None. The priority of this configuration "\
"is lower than that of the config file. " \
"More complex quantization methods need to be configured via the config file.")
model_group.add_argument(
"--enable-static-graph-inference",
action='store_true',
default=EngineArgs.enable_static_graph_inference,
help="Whether to use static mode; if enabled, " \
"'paddle.to_static' will be used to convert dynamic to static.")
model_group.add_argument("--use-cudagraph",
action='store_true',
default=EngineArgs.use_cudagraph,
help="Flags to enable cuda graph.")
model_group.add_argument("--max-capture-batch-size",
type=int,
default=EngineArgs.max_capture_batch_size,
help="Maximum of Batch Size for Warm Up.")
model_group.add_argument("--graph-optimization-config",
type=json.loads,
default=EngineArgs.graph_optimization_config,
help="")
model_group.add_argument("--guided-decoding-backend",
type=str,
default=EngineArgs.guided_decoding_backend,
@@ -757,6 +744,15 @@ class EngineArgs:
enable_custom_all_reduce=self.enable_custom_all_reduce
)
def create_graph_optimization_config(self) -> GraphOptimizationConfig:
"""
Create and return a GraphOptimizationConfig object based on the current settings.
"""
if self.graph_optimization_config is not None:
return GraphOptimizationConfig(**self.graph_optimization_config)
else:
return GraphOptimizationConfig()
def create_engine_config(self) -> Config:
"""
Create and return a Config object based on the current settings.
@@ -771,8 +767,9 @@ class EngineArgs:
else:
self.max_num_batched_tokens = self.max_model_len
scheduler_cfg = self.create_scheduler_config()
speculative_cfg = self.create_speculative_config()
graph_opt_cfg = self.create_graph_optimization_config()
graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
assert not (self.use_cudagraph and self.enable_prefix_caching), \
"Prefix caching cannot be used with CUDA graph"
@@ -804,9 +801,7 @@ class EngineArgs:
max_num_partial_prefills=self.max_num_partial_prefills,
max_long_partial_prefills=self.max_long_partial_prefills,
long_prefill_token_threshold=self.long_prefill_token_threshold,
enable_static_graph_inference=self.enable_static_graph_inference,
use_cudagraph=self.use_cudagraph,
max_capture_batch_size=self.max_capture_batch_size,
graph_optimization_config=graph_opt_cfg,
guided_decoding_backend=self.guided_decoding_backend,
disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
enable_custom_all_reduce=self.enable_custom_all_reduce,
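As a usage note (a sketch; the flag value shown is only representative), the new --graph-optimization-config argument is parsed with json.loads and handed to GraphOptimizationConfig by create_graph_optimization_config, after which the standalone --use-cudagraph flag is merged in:

    import json
    from fastdeploy.engine.config import GraphOptimizationConfig

    # What argparse yields for:
    #   --use-cudagraph --graph-optimization-config '{"graph_opt_level": 1, "cudagraph_capture_sizes": [1, 2, 4]}'
    raw = json.loads('{"graph_opt_level": 1, "cudagraph_capture_sizes": [1, 2, 4]}')

    graph_opt_cfg = GraphOptimizationConfig(**raw)   # create_graph_optimization_config()
    graph_opt_cfg.update_use_cudagraph(True)         # merge the standalone --use-cudagraph flag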

View File

@@ -430,6 +430,77 @@ class SpeculativeConfig:
llm_logger.info(
"=============================================================")
def __str__(self) -> str:
return self.to_json_string()
class GraphOptimizationConfig:
def __init__(
self,
graph_opt_level: Optional[int] = 0,
use_cudagraph: Optional[bool] = None,
cudagraph_capture_sizes: Optional[List[int]] = None,
**kwargs
):
"""
Graph Optimization Configuration class.
Attributes:
graph_opt_level: Compute graph optimization level
use_cudagraph: Use CUDA Graph or not
cudagraph_capture_sizes: Batch size list will be captured by CUDA Graph
"""
self.check_legality_parameters(graph_opt_level, use_cudagraph, cudagraph_capture_sizes, **kwargs)
self.graph_opt_level = graph_opt_level
self.use_cudagraph = use_cudagraph
self.cudagraph_capture_sizes = cudagraph_capture_sizes
def to_json_string(self):
"""
Convert graph_optimization_config to a JSON string.
"""
return json.dumps({
key: value
for key, value in self.__dict__.items()
})
def __str__(self) -> str:
return self.to_json_string()
def check_legality_parameters(
self,
graph_opt_level: Optional[int] = None,
use_cudagraph: Optional[bool] = None,
cudagraph_capture_sizes: Optional[List[int]] = None,
**kwargs
) -> None:
""" Check the legality of parameters passed in from the command line """
if graph_opt_level is not None:
assert graph_opt_level in [0, 1, 2], "In graph optimization config, graph_opt_level can only take the values of 0, 1 and 2."
if use_cudagraph is not None:
assert type(use_cudagraph) is bool, "In graph optimization config, the type of use_cudagraph must be bool."
if cudagraph_capture_sizes is not None:
assert type(cudagraph_capture_sizes) is list, "In graph optimization config, the type of cudagraph_capture_sizes must be list."
assert len(cudagraph_capture_sizes) > 0, "In graph optimization config, the capture sizes cannot be an empty list when CUDA Graph is enabled."
for key, value in kwargs.items():
raise ValueError(f"Invalid --graph-optimization-config parameter {key}")
def update_use_cudagraph(self, argument:bool):
"""
Unify the two ways a user can specify the use_cudagraph parameter:
'--use-cudagraph' and '--graph-optimization-config'
"""
if self.use_cudagraph is None:
# User only set '--use-cudagraph'
self.use_cudagraph = argument
else:
# User both set '--use-cudagraph' and '--graph-optimization-config'
if self.use_cudagraph is False and argument is True:
raise ValueError("Invalid parameter: Cannot set --use-cudagraph and --graph-optimization-config '{\"use_cudagraph\":false}' simultaneously.")
argument = self.use_cudagraph
class ParallelConfig:
"""
@@ -573,6 +644,7 @@ class Config:
max_num_batched_tokens: Optional[int] = None,
pod_ips: Optional[List[str]] = None,
speculative_config: Optional[Dict[str, Any]] = None,
graph_optimization_config: Optional[Dict[str, Any]] = None,
use_warmup: bool = False,
engine_worker_queue_port: int = 8002,
limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
@@ -584,9 +656,6 @@ class Config:
max_long_partial_prefills: int = 1,
long_prefill_token_threshold: int = 0,
reasoning_parser: str = None,
enable_static_graph_inference: bool = False,
use_cudagraph: bool = False,
max_capture_batch_size: int = 64,
guided_decoding_backend: Optional[str] = None,
disable_any_whitespace: bool = False,
enable_custom_all_reduce: bool = False,
@@ -609,6 +678,7 @@ class Config:
pod_ips (Optional[List[str]]): List of POD IPs. Default is None.
mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor. Default is None.
speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration. Default is None.
graph_optimization_config (Optional[Dict[str, Any]]): Graph optimization backend execution configuration. Default is None.
use_warmup (bool): Flag to use warmup. Default is False.
engine_worker_queue_port (int): Engine worker queue port. Default is 8002.
enable_mm (bool): Flag to enable multi-modal processing. Default is False.
@@ -643,9 +713,7 @@ class Config:
self.max_long_partial_prefills = max_long_partial_prefills
self.long_prefill_token_threshold = long_prefill_token_threshold
self.reasoning_parser = reasoning_parser
self.enable_static_graph_inference = enable_static_graph_inference
self.use_cudagraph = use_cudagraph
self.max_capture_batch_size = max_capture_batch_size
self.graph_optimization_config = graph_optimization_config
self.guided_decoding_backend = guided_decoding_backend
self.disable_any_whitespace = disable_any_whitespace
self.is_master = True
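A short illustration of the validation and precedence rules implemented by GraphOptimizationConfig above (a sketch with made-up values):

    from fastdeploy.engine.config import GraphOptimizationConfig

    # Only --use-cudagraph was given: the config value stays None and the flag wins.
    cfg = GraphOptimizationConfig()
    cfg.update_use_cudagraph(True)
    assert cfg.use_cudagraph is True

    # Conflicting settings are rejected: use_cudagraph disabled in the config
    # while --use-cudagraph is also passed raises a ValueError.
    cfg = GraphOptimizationConfig(use_cudagraph=False)
    try:
        cfg.update_use_cudagraph(True)
    except ValueError:
        pass

    # Unknown keys are rejected by check_legality_parameters, e.g.
    # GraphOptimizationConfig(max_capture_batch_size=64) raises ValueError.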

View File

@@ -1026,7 +1026,7 @@ class LLMEngine(object):
f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}"
f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
f" --graph_optimiaztion_config '{self.cfg.graph_optimization_config.to_json_string()}'"
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
f" --load_strategy {self.cfg.model_config.load_strategy}"
f" --enable_mm {self.cfg.enable_mm}")
@@ -1041,9 +1041,6 @@ class LLMEngine(object):
self.cfg.cache_config.enable_chunked_prefill,
"do_profile": self.do_profile,
"dynamic_load_weight": self.cfg.model_config.dynamic_load_weight,
"enable_static_graph_inference":
self.cfg.enable_static_graph_inference,
"use_cudagraph": self.cfg.use_cudagraph,
"disable_any_whitespace": self.cfg.disable_any_whitespace,
"enable-custom-all-reduce": self.cfg.parallel_config.enable_custom_all_reduce,
"enable_logprob": self.cfg.enable_logprob,

View File

@@ -17,11 +17,11 @@
import logging
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Optional
from fastdeploy.model_executor.layers.attention import AttentionBackend
from typing import Optional
import paddle
from fastdeploy.model_executor.layers.attention import AttentionBackend
logger = logging.getLogger(__name__)
@@ -64,8 +64,6 @@ class ForwardMeta():
# Use cuda graph in this step or not. Used to avoid run cuda graph when in dummy run or prefill stage.
step_use_cudagraph: bool = False
# Batch type flag
is_decode_batch: bool = False
# Attention backend object
attn_backend: AttentionBackend = None

View File

@@ -68,16 +68,20 @@ class CudaGraphPiecewiseBackend:
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_bs=shape)
logger.debug("[CUDA GRAPH] Created all batch size entry ")
logger.info(
f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, "
"Created all batch sizes entry."
)
def __call__(self, **kwargs):
# Get batch size
ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
batch_size = ids_remove_padding.shape[0]
padding_batch_size = self.batch_size_to_captured_size[batch_size]
logger.debug((
f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, ",
f"The padded batch size is :{padding_batch_size}"))
logger.debug(
f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, "
f"The padded batch size is :{padding_batch_size}"
)
entry = self.concrete_size_entries.get(padding_batch_size)
assert entry is not None, f"Batch size:{padding_batch_size} is not in cuda graph capture list."
@@ -96,10 +100,10 @@ class CudaGraphPiecewiseBackend:
for n in range(entry.num_finished_warmup, self.warm_up_size):
entry.num_finished_warmup += 1
entry.runnable(**kwargs)
logger.debug((
"[CUDA GRAPH] Warm up for batch size ",
f"{padding_batch_size}, finished ({n+1}/{entry.num_finished_warmup}) times"
))
logger.debug(
f"[CUDA GRAPH] Warm up for batch size {padding_batch_size}, "
f"finished ({n+1}/{entry.num_finished_warmup}) times"
)
# Store input addresses for debug
input_addresses = [
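The padded dispatch performed by the backend can be summarised with a toy, self-contained sketch (names and mapping values are illustrative; the real entries replay captured CUDA graphs rather than returning strings):

    # Toy dispatcher: pad the runtime batch size up to a captured size, then look up that entry.
    concrete_size_entries = {1: "graph_bs1", 2: "graph_bs2", 4: "graph_bs4", 8: "graph_bs8"}
    batch_size_to_captured_size = {1: 1, 2: 2, 3: 4, 4: 4, 5: 8, 6: 8, 7: 8, 8: 8}

    def replay(batch_size):
        padding_batch_size = batch_size_to_captured_size[batch_size]
        entry = concrete_size_entries.get(padding_batch_size)
        assert entry is not None, f"Batch size:{padding_batch_size} is not in cuda graph capture list."
        return entry

    assert replay(3) == "graph_bs4"   # 3 decode requests replay the graph captured for batch size 4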

View File

@@ -33,7 +33,8 @@ from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.model_executor.layers.attention.utils import \
init_rank_and_device_id
@dataclass
@@ -106,7 +107,7 @@ class AppendAttentionBackend(AttentionBackend):
if fd_config.parallel_config.expert_parallel_rank is None:
fd_config.parallel_config.expert_parallel_rank = 0
self.rank, self.device_id = init_rank_and_device_id(fd_config)
def init_attention_metadata(self, forward_meta: ForwardMeta):
@@ -134,8 +135,8 @@ class AppendAttentionBackend(AttentionBackend):
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.decoder_batch_ids,
metadata.decoder_tile_ids_per_batch,
metadata.decoder_batch_ids, # will copy to buffer
metadata.decoder_tile_ids_per_batch, # will copy to buffer
metadata.decoder_num_blocks,
metadata.max_len_kv,
metadata.set_max_lengths,

View File

@@ -53,9 +53,6 @@ class RolloutModelConfig:
enable_expert_parallell: bool = False,
ori_vocab_size: int = None,
quantization: str = "None",
enable_static_graph_inference: bool = False,
use_cudagraph: bool = False,
max_capture_batch_size: int = 64,
guided_decoding_backend: str = "off",
disable_any_whitespace: bool = True,
enable_logprob: bool = False,
@@ -95,9 +92,6 @@ class RolloutModelConfig:
self.enable_expert_parallell = enable_expert_parallell
self.ori_vocab_size = ori_vocab_size
self.quantization = quantization
self.enable_static_graph_inference = enable_static_graph_inference
self.use_cudagraph = use_cudagraph
self.max_capture_batch_size = max_capture_batch_size
self.guided_decoding_backend = guided_decoding_backend
self.disable_any_whitespace = disable_any_whitespace
self.enable_logprob = enable_logprob

View File

@@ -335,66 +335,44 @@ def download_model(url, output_dir, temp_tar):
class FlexibleArgumentParser(argparse.ArgumentParser):
"""
Extend argparse.ArgumentParser to support loading parameters from YAML files.
"""
def __init__(self, *args, config_arg='--config', sep='_', **kwargs):
super().__init__(*args, **kwargs)
self.sep = sep # Separator used to flatten nested dictionaries
# Create a temporary parser, used only to parse the --config argument
self.sep = sep
# Create a parser to parse the YAML file
self.tmp_parser = argparse.ArgumentParser(add_help=False)
self.tmp_parser.add_argument(config_arg,
type=str,
help='Path to YAML config file')
def parse_args(self, args=None, namespace=None):
# Use the temporary parser to parse out the --config argument
tmp_ns, remaining_args = self.tmp_parser.parse_known_args(args=args)
config_path = tmp_ns.config
# Load the YAML file and flatten the nested structure
config = {}
if config_path:
with open(config_path, 'r') as f:
loaded_config = yaml.safe_load(f)
config = self._flatten_dict(loaded_config)
config = loaded_config
# Get the dest names of all defined arguments
# Get declared parameters
defined_dests = {action.dest for action in self._actions}
# Filter out the arguments that are already defined
filtered_config = {
k: v
for k, v in config.items() if k in defined_dests
}
# Create or use an existing namespace object
# Set parameters
if namespace is None:
namespace = argparse.Namespace()
# Set the config parameters on the namespace
for key, value in filtered_config.items():
setattr(namespace, key, value)
# Parse the remaining arguments and override the defaults
return super().parse_args(args=remaining_args, namespace=namespace)
def _flatten_dict(self, d):
"""将嵌套字典展平为单层字典,键由分隔符连接"""
def _flatten(d, parent_key=''):
items = []
for k, v in d.items():
new_key = f"{parent_key}{self.sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(_flatten(v, new_key).items())
else:
items.append((new_key, v))
return dict(items)
return _flatten(d)
def resolve_obj_from_strname(strname: str):
module_name, obj_name = strname.rsplit(".", 1)
module = importlib.import_module(module_name)
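Because nested dictionaries are no longer flattened, a graph_optimization_config block in the YAML file now reaches the parser as a single dict-valued argument. A minimal sketch of what the parser sees (the YAML content is made up):

    import yaml

    config_text = (
        "max_num_seqs: 128\n"
        "graph_optimization_config:\n"
        "  graph_opt_level: 1\n"
        "  use_cudagraph: true\n"
        "  cudagraph_capture_sizes: [1, 2, 4]\n"
    )
    config = yaml.safe_load(config_text)

    # With the flattening removed, the nested block stays intact and is matched
    # against the 'graph_optimization_config' dest of the argument parser.
    assert config["graph_optimization_config"]["graph_opt_level"] == 1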

View File

@@ -748,10 +748,6 @@ class GCUModelRunner(ModelRunnerBase):
# 3. Prepare lora
# 4. Run model
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing
self.forward_meta.is_decode_batch = is_decode_batch
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta)
@@ -979,10 +975,6 @@ class GCUModelRunner(ModelRunnerBase):
# 2. Padding inputs for cuda graph
# 3. Execute model
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
self.forward_meta.is_decode_batch = is_decode_batch
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta)

View File

@@ -417,7 +417,10 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["seq_lens_decoder"][idx:idx + 1] = 0
self.share_inputs["step_idx"][idx:idx + 1] = 0
self.share_inputs["max_dec_len"][idx:idx + 1] = max_dec_len
self.share_inputs["min_dec_len"][idx:idx + 1] = max_dec_len
self.share_inputs["stop_flags"][idx:idx + 1] = False
self.share_inputs["top_p"][idx:idx + 1] = 0.0
self.share_inputs["temperature"][idx:idx + 1] = 1
self.share_inputs["first_token_ids"][
idx:idx + 1] = self.share_inputs["input_ids"][idx:idx + 1, :1]
@@ -759,6 +762,11 @@ class GPUModelRunner(ModelRunnerBase):
caches=self.share_inputs["caches"]
)
# Update Batch type for cuda graph
# TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0)
self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
# Initialize attention meta data
for attn_backend in self.attn_backends:
attn_backend.init_attention_metadata(self.forward_meta)
@@ -850,6 +858,7 @@ class GPUModelRunner(ModelRunnerBase):
Args:
num_tokens:
expected_decode_len: Expected number of tokens generated
in_capturing: Whether the CUDA graph is in the capturing state
"""
self._dummy_prefill_inputs(num_tokens=num_tokens,
batch_size=batch_size,
@@ -864,17 +873,16 @@ class GPUModelRunner(ModelRunnerBase):
# 1. Initialize forward meta and attention meta data
self._prepare_inputs()
# 2. Prepare lora
# 2. Padding inputs for cuda graph
self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
self.padding_cudagraph_inputs()
# 3. Run model
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing
self.forward_meta.is_decode_batch = is_decode_batch
if self.enable_mm:
hidden_states = model_output = self.model(self.share_inputs["ids_remove_padding"],
model_output = self.model(self.share_inputs["ids_remove_padding"],
self.share_inputs["image_features"],
self.forward_meta)
hidden_states = model_output
else:
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
@@ -1113,9 +1121,7 @@ class GPUModelRunner(ModelRunnerBase):
We plan to replace it with 'ModelForwardBatch'.
intermediate_tensors:
"""
# NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runner, the current runner is required to execute part of the model.
# NOTE(wufeisheng): For Expert Parallelism
if not self.not_need_stop():
self._execute_empty_input()
return None
@@ -1126,18 +1132,14 @@ class GPUModelRunner(ModelRunnerBase):
self.sampler.pre_process(skip_idx_list)
# 2. Padding inputs for cuda graph
self.padding_cudagraph_inputs()
# 3. Execute model
# TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
self.forward_meta.is_decode_batch = is_decode_batch
if self.enable_mm:
hidden_states = model_output = self.model(self.share_inputs["ids_remove_padding"],
model_output = self.model(self.share_inputs["ids_remove_padding"],
self.share_inputs["image_features"],
self.forward_meta)
hidden_states = model_output
else:
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
@@ -1399,6 +1401,18 @@ class GPUModelRunner(ModelRunnerBase):
self.dynamic_weight_manager._log_memory(
"dynamic weight manager update all memory")
def padding_cudagraph_inputs(self) -> None:
"""
Clean the buffers used by the CUDA graph when replaying it with a padded batch.
In FastDeploy, almost all input tensors have a buffer, so it is enough to keep those buffers clean when replaying the CUDA graph with a padded batch.
"""
# TODO(gongshaotian): Use more efficient implementation
if self.forward_meta.step_use_cudagraph:
num_empty_batch = (self.forward_meta.seq_lens_this_time == 0).sum()
for i in range(1, num_empty_batch + 1):
self.forward_meta.decoder_batch_ids[-i] = 0
self.forward_meta.decoder_tile_ids_per_batch[-i] = 0
def _init_image_preprocess(self) -> None:
processor = DataProcessor(
tokenizer_name=self.tokenizer_path,
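As a toy illustration of the buffer clean-up in padding_cudagraph_inputs above (pure Python with made-up values; the real code operates on paddle tensors held in forward_meta):

    # Two real decode requests padded up to a captured batch size of 4: the two
    # trailing slots have seq_lens_this_time == 0, so their buffer entries are zeroed.
    seq_lens_this_time = [5, 3, 0, 0]
    decoder_batch_ids = [0, 1, 7, 9]            # stale values left over from an earlier, larger step
    decoder_tile_ids_per_batch = [2, 2, 6, 6]

    num_empty_batch = sum(1 for s in seq_lens_this_time if s == 0)
    for i in range(1, num_empty_batch + 1):
        decoder_batch_ids[-i] = 0
        decoder_tile_ids_per_batch[-i] = 0

    assert decoder_batch_ids == [0, 1, 0, 0]
    assert decoder_tile_ids_per_batch == [2, 2, 0, 0]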

View File

@@ -715,10 +715,6 @@ class IluvatarModelRunner(ModelRunnerBase):
# 3. Prepare lora
# 4. Run model
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing
self.forward_meta.is_decode_batch = is_decode_batch
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta)
@@ -939,11 +935,6 @@ class IluvatarModelRunner(ModelRunnerBase):
# 2. Padding inputs for cuda graph
# 3. Execute model
# TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
self.forward_meta.is_decode_batch = is_decode_batch
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta)

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""
import argparse
import json
import time
from typing import List
@@ -516,18 +517,11 @@ def parse_args():
"default is None. The priority of this configuration "\
"is lower than that of the config file. " \
"More complex quantization methods need to be configured via the config file.")
parser.add_argument("--enable_static_graph_inference",
action='store_true',
help="Whether to use static mode; if enabled, " \
"'paddle.to_static' will be used to convert dynamic to static.")
parser.add_argument("--use_cudagraph",
action='store_true',
help="Flags to enable cuda graph.")
parser.add_argument("--max_capture_batch_size",
type=int,
default=64,
help="Maximum Batch Size for Cuda Graph Capture. " \
"If max_capture_batch_size set 64, FastDeploy will capture batch size in [1, 64]")
parser.add_argument("--graph_optimiaztion_config",
type=json.loads,
default=None,
help=" Configation of Graph optimization backend. "
)
parser.add_argument("--guided_decoding_backend",
type=str,
default="off",
@@ -579,9 +573,10 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
load_config = LoadConfig(vars(args))
graph_opt_config = GraphOptimizationConfig(
args.enable_static_graph_inference,
args.max_capture_batch_size,
vars(args))
use_cudagraph=args.graph_optimiaztion_config["use_cudagraph"],
graph_opt_level=args.graph_optimiaztion_config["graph_opt_level"],
cudagraph_capture_sizes=args.graph_optimiaztion_config["cudagraph_capture_sizes"]
)
# Note(tangbinhan): used for load_checkpoint
model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank

View File

@@ -91,7 +91,7 @@ def setup_and_run_server():
"--max-num-seqs", "128",
"--quantization", "wint4",
"--use-cudagraph",
"--max-capture-batch-size", "1"
"--graph-optimization-config", '{"cudagraph_capture_sizes": [1]}'
]
# Start subprocess in new process group