Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 08:16:42 +08:00)
[Executor] CUDA Graph support padding batch (#2844)
* CUDA graph supports padding batch
* Integrate the startup parameters for the graph optimization backend and provide support for user-defined capture sizes
* Do not insert max_num_seqs when the user specifies a capture list
* Support setting the graph optimization config from a YAML file
* Update CUDA graph CI
* Fix CI bugs
@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint4
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 96
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.71
tensor_parallel_size: 4
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wfp8afp8
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint8
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -2,4 +2,5 @@ max_model_len: 32768
max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1

@@ -3,4 +3,5 @@ max_num_seqs: 128
kv_cache_ratio: 0.75
tensor_parallel_size: 1
quantization: wint4
enable_static_graph_inference: True
graph_optimization_config:
  graph_opt_level: 1
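Each CI config above swaps the enable_static_graph_inference boolean for a nested graph_optimization_config mapping. As a hedged illustration (assuming PyYAML and only the key names visible in the hunks), the new block parses to a plain dict that the engine can hand to its graph-optimization config:

import yaml  # PyYAML, already used by FastDeploy's YAML-aware argument parser

# Illustrative only: a trimmed config in the shape of the CI files above.
cfg = yaml.safe_load("""
max_num_seqs: 128
graph_optimization_config:
  graph_opt_level: 1
""")
print(cfg["graph_optimization_config"])  # {'graph_opt_level': 1}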
@@ -260,14 +260,71 @@ class DeviceConfig:
if hasattr(self, key):
setattr(self, key, value)

@dataclass
class GraphOptimizationConfig:
def init_with_cudagrpah_size(self,
cudagraph_capture_sizes: list[int]) -> None:
"""To complete the initialization of config,
we need to know the cudagraph sizes"""
if self.cudagraph_capture_sizes is None:
self.cudagraph_capture_sizes = cudagraph_capture_sizes
else:
"""
Configuration for compute graph level optimization.
"""

"""The Top-level graph optimization contral corresponds to different backends.
- 0: dyncmic graph
- 1: static graph
- 2: static graph + cinn compilation backend
"""
graph_opt_level: int = 0

# CUDA Graph Config
""" Whether to use cudagraph.
- False: cudagraph is not used.
- True: cudagraph is used.
It requires that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
- With dyncmic graph backend: ...
- With static grpah backend: WIP
"""
use_cudagraph: bool = False
"""Sizes to capture cudagraph.
- None (default): capture sizes are inferred from llm config.
- list[int]: capture sizes are specified as given."""
cudagraph_capture_sizes: Optional[list[int]] = None
""" Number of warmup runs for cudagraph. """
cudagraph_num_of_warmups: int = 2
"""Whether to copy input tensors for cudagraph.
If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True."""
cudagraph_copy_inputs: bool = False
""" In static graph, this is an operation list that does not need to be captured by the CUDA graph.
CudaGraphBackend will split these operations from the static graph.
Example usage:
cudagraph_splitting_ops = ["paddle.unified_attention"]

Note: If want to use subgraph capture functionality in a dynamic graph,
can manually split the model into multiple layers and apply the @support_cuda_graph decorator
only to the layer where CUDA graph functionality is required.
"""
cudagraph_splitting_ops = Optional[list[str]]
"""" Whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops."""
full_cuda_graph: bool = True

max_capture_size: int = field(default=None, init=False) # type: ignore
batch_size_to_captured_size: dict[int,
int] = field(default=None,
init=False) # type: ignore
# CINN Config ...

def init_with_cudagrpah_size(
self,
max_num_seqs:int = 0
) -> None:
"""
Initialize cuda graph capture sizes and
pre-compute the mapping from batch size to padded graph size
"""
# Regular capture sizes
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size < max_num_seqs]
dedup_sizes = list(set(self.cudagraph_capture_sizes))
if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
logger.info(("cudagraph sizes specified by model runner"

@@ -275,12 +332,12 @@ class GraphOptimizationConfig:
cudagraph_capture_sizes, dedup_sizes)
self.cudagraph_capture_sizes = dedup_sizes

# sort to make sure cudagraph capture sizes are in descending order
# Sort to make sure cudagraph capture sizes are in descending order
self.cudagraph_capture_sizes.sort(reverse=True)
self.max_capture_size = self.cudagraph_capture_sizes[
0] if self.cudagraph_capture_sizes else 0

# pre-compute the mapping from batch size to padded graph size
# Pre-compute the mapping from batch size to padded graph size
self.batch_size_to_captured_size = {}
for end, start in zip(self.cudagraph_capture_sizes,
self.cudagraph_capture_sizes[1:] + [0]):
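The loop above is the core of padded-batch replay: it precomputes, for every batch size up to the largest capture size, the nearest capture size to pad up to. A standalone sketch of that mapping, under a hypothetical helper name (the in-tree logic lives in init_with_cudagrpah_size and mutates the config object):

def build_padding_map(capture_sizes: list[int]) -> dict[int, int]:
    # Deduplicate and sort descending, as the config code above does.
    sizes = sorted(set(capture_sizes), reverse=True)
    max_capture_size = sizes[0] if sizes else 0
    mapping: dict[int, int] = {}
    # Every batch size in (start, end] is padded up to `end`.
    for end, start in zip(sizes, sizes[1:] + [0]):
        for bs in range(start + 1, end + 1):
            mapping[bs] = end
    mapping[max_capture_size] = max_capture_size
    return mapping

# A batch of 3 replays the graph captured for batch size 4; batches 5..8 replay the size-8 graph.
assert build_padding_map([1, 2, 4, 8])[3] == 4
assert build_padding_map([1, 2, 4, 8])[6] == 8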
@@ -292,68 +349,24 @@ class GraphOptimizationConfig:
self.batch_size_to_captured_size[
self.max_capture_size] = self.max_capture_size

def __init__(self,
enable_static_graph_inference: bool = False,
max_capture_batch_size: int = 64,
args = None):
"""The Top-level graph optimization contral corresponds to different backends.
- 0: dyncmic graph
- 1: static graph
- 2: static graph + cinn compilation backend
def _set_cudagraph_sizes(
self,
max_num_seqs:int = 0
):
"""
self.graph_opt_level: int = 0

# CUDA Graph Config
""" Whether to use cudagraph.
- False: cudagraph is not used.
- True: cudagraph is used.
It requires that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
- With dyncmic graph backend: ...
- With static grpah backend: WIP
Calculate a series of candidate capture batch sizes,
and then extract a portion of them as the capture list for the CUDA graph based on user input.
"""
self.use_cudagraph: bool = False
"""Sizes to capture cudagraph.
- None (default): capture sizes are inferred from llm config.
- list[int]: capture sizes are specified as given."""
self.cudagraph_capture_sizes: Optional[list[int]] = None
""" Number of warmup runs for cudagraph. """
self.cudagraph_num_of_warmups: int = 2
"""Whether to copy input tensors for cudagraph.
If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True."""
self.cudagraph_copy_inputs: bool = False
""" In static graph, this is an operation list that does not need to be captured by the CUDA graph.
CudaGraphBackend will split these operations from the static graph.
Example usage:
cudagraph_splitting_ops = ["paddle.unified_attention"]
# Batch Size [1, 2, 4, 8, 16, ... 120, 128]
draft_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
# Batch Size [128, 144, ... 240, 256]
draft_capture_sizes += [16 * i for i in range(9, 17)]
# Batch Size [256, 288, ... 992, 1024]
draft_capture_sizes += [32 * i for i in range(17, 33)]

Note: If want to use subgraph capture functionality in a dynamic graph,
can manually split the model into multiple layers and apply the @support_cuda_graph decorator
only to the layer where CUDA graph functionality is required.
"""
self.cudagraph_splitting_ops = Optional[list[str]]
""""whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops."""
self.full_cuda_graph: bool = False
draft_capture_sizes.append(max_num_seqs)
self.cudagraph_capture_sizes = sorted(draft_capture_sizes)

self.max_capture_size: int = field(default=None, init=False) # type: ignore
self.batch_size_to_captured_size: dict[int,
int] = field(default=None,
init=False) # type: ignore

# CINN Config ...

for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
capture_size = [i for i in range(1, max_capture_batch_size + 1)]
self.init_with_cudagrpah_size(cudagraph_capture_sizes=capture_size)
#TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
if enable_static_graph_inference:
self.graph_opt_level = 1

class LoadConfig:
"""

@@ -422,3 +435,13 @@ class FDConfig:
init=True) # type: ignore
kv_cache_config: KVCacheConfig = field(default=None,
init=True) # type: ignore

def __post_init__(self):
# Initialize cuda graph capture list
if self.graph_opt_config.cudagraph_capture_sizes is None:
self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs)

#TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
if self.graph_opt_config.graph_opt_level == 2:
self.graph_opt_config.graph_opt_level = 1
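For reference, a standalone sketch of the candidate-size ladder that _set_cudagraph_sizes appears to build when the user supplies no capture list; FDConfig.__post_init__ then passes max_num_seqs and init_with_cudagrpah_size filters and deduplicates the result. The helper name is illustrative:

def default_capture_sizes(max_num_seqs: int) -> list[int]:
    draft = [1, 2, 4] + [8 * i for i in range(1, 17)]   # 1, 2, 4, then 8..128 in steps of 8
    draft += [16 * i for i in range(9, 17)]             # 144..256 in steps of 16
    draft += [32 * i for i in range(17, 33)]            # 544..1024 in steps of 32
    draft.append(max_num_seqs)                          # always include the configured maximum
    return sorted(draft)

print(default_capture_sizes(128)[:6])  # [1, 2, 4, 8, 16, 24]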
@@ -18,7 +18,8 @@ from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional

from fastdeploy.engine.config import (CacheConfig, Config, ModelConfig,
from fastdeploy.engine.config import (CacheConfig, Config,
GraphOptimizationConfig, ModelConfig,
ParallelConfig, SpeculativeConfig,
TaskOption)
from fastdeploy.scheduler.config import SchedulerConfig

@@ -283,20 +284,13 @@ class EngineArgs:
"""
SplitWise Use, Results Writer Batch Size
"""
enable_static_graph_inference: bool = False
"""
Whether to use static mode
"""
use_cudagraph: bool = False
"""
Flags to enable Cuda Graph
"""
max_capture_batch_size: int = 64
graph_optimization_config: Optional[Dict[str, Any]] = None
"""
Maximum Batch Size for Cuda Graph Capture
NOTE: Now only support to capture continuous batch size,
Example:
max_capture_batch_size=64, FastDeploy will capture graphs for batches [1,64].
Configuration for graph optimization backend execution.
"""

enable_logprob: bool = False

@@ -399,21 +393,14 @@ class EngineArgs:
"default is None. The priority of this configuration "\
"is lower than that of the config file. " \
"More complex quantization methods need to be configured via the config file.")

model_group.add_argument(
"--enable-static-graph-inference",
action='store_true',
default=EngineArgs.enable_static_graph_inference,
help="Whether to use static mode; if enabled, " \
"'paddle.to_static' will be used to convert dynamic to static.")
model_group.add_argument("--use-cudagraph",
action='store_true',
default=EngineArgs.use_cudagraph,
help="Flags to enable cuda graph.")
model_group.add_argument("--max-capture-batch-size",
type=int,
default=EngineArgs.max_capture_batch_size,
help="Maximum of Batch Size for Warm Up.")
model_group.add_argument("--graph-optimization-config",
type=json.loads,
default=EngineArgs.graph_optimization_config,
help="")
model_group.add_argument("--guided-decoding-backend",
type=str,
default=EngineArgs.guided_decoding_backend,

@@ -757,6 +744,15 @@ class EngineArgs:
enable_custom_all_reduce=self.enable_custom_all_reduce
)

def create_graph_optimization_config(self) -> GraphOptimizationConfig:
"""
Create and retuan a GraphOptimizationConfig object based on the current settings.
"""
if self.graph_optimization_config is not None:
return GraphOptimizationConfig(**self.graph_optimization_config)
else:
return GraphOptimizationConfig()

def create_engine_config(self) -> Config:
"""
Create and return a Config object based on the current settings.

@@ -771,8 +767,9 @@ class EngineArgs:
else:
self.max_num_batched_tokens = self.max_model_len
scheduler_cfg = self.create_scheduler_config()

speculative_cfg = self.create_speculative_config()
graph_opt_cfg = self.create_graph_optimization_config()
graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)

assert not (self.use_cudagraph and self.enable_prefix_caching), \
"Prefix caching cannot be used with CUDA graph"

@@ -804,9 +801,7 @@ class EngineArgs:
max_num_partial_prefills=self.max_num_partial_prefills,
max_long_partial_prefills=self.max_long_partial_prefills,
long_prefill_token_threshold=self.long_prefill_token_threshold,
enable_static_graph_inference=self.enable_static_graph_inference,
use_cudagraph=self.use_cudagraph,
max_capture_batch_size=self.max_capture_batch_size,
graph_optimization_config=graph_opt_cfg,
guided_decoding_backend=self.guided_decoding_backend,
disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
enable_custom_all_reduce=self.enable_custom_all_reduce,
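On the CLI side, the old --enable-static-graph-inference and --max-capture-batch-size flags are folded into a single JSON-valued --graph-optimization-config argument parsed with json.loads, alongside the existing --use-cudagraph switch. A minimal, self-contained sketch of that wiring (a throwaway parser, not the real EngineArgs):

import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--use-cudagraph", action="store_true")
parser.add_argument("--graph-optimization-config", type=json.loads, default=None)

ns = parser.parse_args([
    "--use-cudagraph",
    "--graph-optimization-config",
    '{"graph_opt_level": 1, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
])
print(ns.graph_optimization_config["cudagraph_capture_sizes"])  # [1, 2, 4, 8]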
@@ -430,6 +430,77 @@ class SpeculativeConfig:
llm_logger.info(
"=============================================================")

def __str__(self) -> str:
return self.to_json_string()

class GraphOptimizationConfig:
def __init__(
self,
graph_opt_level: Optional[int] = 0,
use_cudagraph: Optional[bool] = None,
cudagraph_capture_sizes: Optional[List[int]] = None,
**kwargs
):
"""
Graph Optimization Configuration class.

Attributes:
graph_opt_level: Compute graph optimization level
use_cudagraph: Use CUDA Graph or not
cudagraph_capture_sizes: Batch size list will be captured by CUDA Graph
"""
self.check_legality_parameters(graph_opt_level, use_cudagraph, cudagraph_capture_sizes, **kwargs)

self.graph_opt_level = graph_opt_level
self.use_cudagraph = use_cudagraph
self.cudagraph_capture_sizes = cudagraph_capture_sizes

def to_json_string(self):
"""
Convert speculative_config to json string.
"""
return json.dumps({
key: value
for key, value in self.__dict__.items()
})

def __str__(self) -> str:
return self.to_json_string()

def check_legality_parameters(
self,
graph_opt_level: Optional[int] = None,
use_cudagraph: Optional[bool] = None,
cudagraph_capture_sizes: Optional[List[int]] = None,
**kwargs
) -> None:
""" Check the legality of parameters passed in from the command line """

if graph_opt_level is not None:
assert graph_opt_level in [0, 1, 2], "In graph optimization config, graph_opt_level can only take the values of 0, 1 and 2."
if use_cudagraph is not None:
assert type(use_cudagraph) is bool, "In graph optimization config, type of use_cudagraph must is bool."
if cudagraph_capture_sizes is not None:
assert type(cudagraph_capture_sizes) is list, "In graph optimization config, type of cudagraph_capture_sizes must is list."
assert len(cudagraph_capture_sizes) > 0, "In graph optimization config, When opening the CUDA graph, it is forbidden to set the capture sizes to an empty list."

for key, value in kwargs.items():
raise ValueError(f"Invalid --graph-optimization-config parameter {key}")

def update_use_cudagraph(self, argument:bool):
"""
Unified user specifies the use_cudagraph parameter through two methods,
'--use-cudagraph' and '--graph-optimization-config'
"""
if self.use_cudagraph is None:
# User only set '--use-cudagraph'
self.use_cudagraph = argument
else:
# User both set '--use-cudagraph' and '--graph-optimization-config'
if self.use_cudagraph is False and argument is True:
raise ValueError("Invalid parameter: Cannot set --use-cudagraph and --graph-optimization-config '{\"use_cudagraph\":false}' simultaneously.")
argument = self.use_cudagraph

class ParallelConfig:
"""

@@ -573,6 +644,7 @@ class Config:
max_num_batched_tokens: Optional[int] = None,
pod_ips: Optional[List[str]] = None,
speculative_config: Optional[Dict[str, Any]] = None,
graph_optimization_config: Optional[Dict[str, Any]] = None,
use_warmup: bool = False,
engine_worker_queue_port: int = 8002,
limit_mm_per_prompt: Optional[Dict[str, Any]] = None,

@@ -584,9 +656,6 @@ class Config:
max_long_partial_prefills: int = 1,
long_prefill_token_threshold: int = 0,
reasoning_parser: str = None,
enable_static_graph_inference: bool = False,
use_cudagraph: bool = False,
max_capture_batch_size: int = 64,
guided_decoding_backend: Optional[str] = None,
disable_any_whitespace: bool = False,
enable_custom_all_reduce: bool = False,

@@ -609,6 +678,7 @@ class Config:
pod_ips (Optional[List[str]]): List of POD IPs. Default is None.
mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor. Default is None.
speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration. Default is None.
graph_optimization_config (Optional[Dict[str, Any]]): Graph optimizaion backend execution configuration. Default is None.
use_warmup (bool): Flag to use warmup. Default is False.
engine_worker_queue_port (int): Engine worker queue port. Default is 8002.
enable_mm (bool): Flag to enable multi-modal processing. Default is False.

@@ -643,9 +713,7 @@ class Config:
self.max_long_partial_prefills = max_long_partial_prefills
self.long_prefill_token_threshold = long_prefill_token_threshold
self.reasoning_parser = reasoning_parser
self.enable_static_graph_inference = enable_static_graph_inference
self.use_cudagraph = use_cudagraph
self.max_capture_batch_size = max_capture_batch_size
self.graph_optimization_config = graph_optimization_config
self.guided_decoding_backend = guided_decoding_backend
self.disable_any_whitespace = disable_any_whitespace
self.is_master = True
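update_use_cudagraph reconciles the two ways a user can request CUDA graphs. A standalone sketch of that precedence rule, under a hypothetical function name (the real logic is the method shown above):

def resolve_use_cudagraph(config_value, cli_flag: bool) -> bool:
    if config_value is None:
        # Only --use-cudagraph was given; take the CLI flag as-is.
        return cli_flag
    if config_value is False and cli_flag is True:
        # Conflicting settings are rejected, matching the ValueError above.
        raise ValueError("Cannot set --use-cudagraph and "
                         "--graph-optimization-config '{\"use_cudagraph\":false}' simultaneously.")
    # Otherwise the value from --graph-optimization-config wins.
    return config_value

assert resolve_use_cudagraph(None, True) is True
assert resolve_use_cudagraph(True, False) is True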
@@ -1026,7 +1026,7 @@ class LLMEngine(object):
f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}"
f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
f" --graph_optimiaztion_config '{self.cfg.graph_optimization_config.to_json_string()}'"
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
f" --load_strategy {self.cfg.model_config.load_strategy}"
f" --enable_mm {self.cfg.enable_mm}")

@@ -1041,9 +1041,6 @@ class LLMEngine(object):
self.cfg.cache_config.enable_chunked_prefill,
"do_profile": self.do_profile,
"dynamic_load_weight": self.cfg.model_config.dynamic_load_weight,
"enable_static_graph_inference":
self.cfg.enable_static_graph_inference,
"use_cudagraph": self.cfg.use_cudagraph,
"disable_any_whitespace": self.cfg.disable_any_whitespace,
"enable-custom-all-reduce": self.cfg.parallel_config.enable_custom_all_reduce,
"enable_logprob": self.cfg.enable_logprob,
@@ -17,11 +17,11 @@
import logging
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Optional
from fastdeploy.model_executor.layers.attention import AttentionBackend
from typing import Optional

import paddle

from fastdeploy.model_executor.layers.attention import AttentionBackend

logger = logging.getLogger(__name__)

@@ -64,8 +64,6 @@ class ForwardMeta():

# Use cuda graph in this step or not. Used to avoid run cuda graph when in dummy run or prefill stage.
step_use_cudagraph: bool = False
# Batch type flag
is_decode_batch: bool = False

# Attention backend object
attn_backend: AttentionBackend = None
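ForwardMeta gains an is_decode_batch flag; the model runners below compute it the same way and only allow CUDA graph replay for pure decode batches. A minimal sketch of that test on a plain list (the runners apply it to the paddle tensor share_inputs["seq_lens_this_time"]):

def is_decode_batch(seq_lens_this_time: list[int]) -> bool:
    # A step is a pure decode batch when no request contributes more than one token.
    return not any(n > 1 for n in seq_lens_this_time)

print(is_decode_batch([1, 1, 1]))    # True  -> eligible for CUDA graph replay
print(is_decode_batch([128, 1]))     # False -> prefill present, run eagerly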
@@ -68,16 +68,20 @@ class CudaGraphPiecewiseBackend:
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_bs=shape)

logger.debug("[CUDA GRAPH] Created all batch size entry ")
logger.info(
f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, "
"Created all batch sizes entry."
)

def __call__(self, **kwargs):
# Get batch size
ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
batch_size = ids_remove_padding.shape[0]
padding_batch_size = self.batch_size_to_captured_size[batch_size]
logger.debug((
f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, ",
f"The padded batch size is :{padding_batch_size}"))
logger.debug(
f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, "
f"The padded batch size is :{padding_batch_size}"
)

entry = self.concrete_size_entries.get(padding_batch_size)
assert entry is not None, f"Batch size:{padding_batch_size} is not in cuda graph capture list."

@@ -96,10 +100,10 @@ class CudaGraphPiecewiseBackend:
for n in range(entry.num_finished_warmup, self.warm_up_size):
entry.num_finished_warmup += 1
entry.runnable(**kwargs)
logger.debug((
"[CUDA GRAPH] Warm up for batch size ",
f"{padding_batch_size}, finished ({n+1}/{entry.num_finished_warmup}) times"
))
logger.debug(
f"[CUDA GRAPH] Warm up for batch size {padding_batch_size}, "
f"finished ({n+1}/{entry.num_finished_warmup}) times"
)

# Store input addresses for debug
input_addresses = [
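With padding in place, CudaGraphPiecewiseBackend.__call__ no longer needs an exact match between the runtime batch size and a captured size: it rounds the batch up via batch_size_to_captured_size and replays the entry captured for that padded size. A self-contained sketch of the lookup, with literal dicts standing in for the precomputed map and the captured-graph entries:

# Capture list [1, 2, 4, 8]: every batch size is rounded up to the nearest
# captured size, and one entry (captured graph) exists per capture size.
batch_size_to_captured_size = {1: 1, 2: 2, 3: 4, 4: 4, 5: 8, 6: 8, 7: 8, 8: 8}
concrete_size_entries = {size: f"graph-for-bs-{size}" for size in (1, 2, 4, 8)}

batch_size = 3                                    # e.g. ids_remove_padding.shape[0]
padding_batch_size = batch_size_to_captured_size[batch_size]
entry = concrete_size_entries.get(padding_batch_size)
assert entry is not None, f"Batch size:{padding_batch_size} is not in cuda graph capture list."
print(entry)  # graph-for-bs-4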
@@ -33,7 +33,8 @@ from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.model_executor.layers.attention.utils import \
init_rank_and_device_id

@dataclass

@@ -134,8 +135,8 @@ class AppendAttentionBackend(AttentionBackend):
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.decoder_batch_ids,
metadata.decoder_tile_ids_per_batch,
metadata.decoder_batch_ids, # will copy to buffer
metadata.decoder_tile_ids_per_batch, # will copy to buffer
metadata.decoder_num_blocks,
metadata.max_len_kv,
metadata.set_max_lengths,
@@ -53,9 +53,6 @@ class RolloutModelConfig:
enable_expert_parallell: bool = False,
ori_vocab_size: int = None,
quantization: str = "None",
enable_static_graph_inference: bool = False,
use_cudagraph: bool = False,
max_capture_batch_size: int = 64,
guided_decoding_backend: str = "off",
disable_any_whitespace: bool = True,
enable_logprob: bool = False,

@@ -95,9 +92,6 @@ class RolloutModelConfig:
self.enable_expert_parallell = enable_expert_parallell
self.ori_vocab_size = ori_vocab_size
self.quantization = quantization
self.enable_static_graph_inference = enable_static_graph_inference
self.use_cudagraph = use_cudagraph
self.max_capture_batch_size = max_capture_batch_size
self.guided_decoding_backend = guided_decoding_backend
self.disable_any_whitespace = disable_any_whitespace
self.enable_logprob = enable_logprob
@@ -335,66 +335,44 @@ def download_model(url, output_dir, temp_tar):

class FlexibleArgumentParser(argparse.ArgumentParser):
"""
Extend argparse.ArgumentParser to support loading parameters from a YAML file.
Extend argparse.ArgumentParser to support loading parameters from YAML files.
"""

def __init__(self, *args, config_arg='--config', sep='_', **kwargs):
super().__init__(*args, **kwargs)
self.sep = sep  # Separator used when flattening nested dictionaries
# Create a temporary parser used only to parse the --config argument
self.sep = sep

# Create parser to prase yaml file
self.tmp_parser = argparse.ArgumentParser(add_help=False)
self.tmp_parser.add_argument(config_arg,
type=str,
help='Path to YAML config file')

def parse_args(self, args=None, namespace=None):
# Use the temporary parser to extract the --config argument
tmp_ns, remaining_args = self.tmp_parser.parse_known_args(args=args)
config_path = tmp_ns.config

# Load the YAML file and flatten the nested structure
config = {}
if config_path:
with open(config_path, 'r') as f:
loaded_config = yaml.safe_load(f)
config = self._flatten_dict(loaded_config)
config = loaded_config

# Get the dest names of all defined arguments
# Get declared parameters
defined_dests = {action.dest for action in self._actions}

# Filter down to the defined arguments
filtered_config = {
k: v
for k, v in config.items() if k in defined_dests
}

# Create or reuse the namespace object
# Set parameters
if namespace is None:
namespace = argparse.Namespace()

# Assign config values onto the namespace
for key, value in filtered_config.items():
setattr(namespace, key, value)

# Parse remaining arguments and override defaults
return super().parse_args(args=remaining_args, namespace=namespace)

def _flatten_dict(self, d):
"""Flatten a nested dict into a single-level dict, joining keys with the separator"""

def _flatten(d, parent_key=''):
items = []
for k, v in d.items():
new_key = f"{parent_key}{self.sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(_flatten(v, new_key).items())
else:
items.append((new_key, v))
return dict(items)

return _flatten(d)

def resolve_obj_from_strname(strname: str):
module_name, obj_name = strname.rsplit(".", 1)
module = importlib.import_module(module_name)
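The parser change above stops flattening the loaded YAML, so a nested graph_optimization_config block survives intact and is copied onto the namespace only when a matching argument is declared. A condensed standalone sketch of that flow (an in-memory YAML string stands in for the --config file; assumes PyYAML):

import argparse
import yaml

parser = argparse.ArgumentParser()
parser.add_argument("--max-num-seqs", type=int, default=256)
parser.add_argument("--graph-optimization-config", default=None)

config = yaml.safe_load("""
max_num_seqs: 128
kv_cache_ratio: 0.75
graph_optimization_config:
  graph_opt_level: 1
""")                                             # loaded as-is, no flattening

defined_dests = {action.dest for action in parser._actions}
namespace = argparse.Namespace()
for key, value in config.items():
    if key in defined_dests:                     # kv_cache_ratio is dropped: no matching argument
        setattr(namespace, key, value)

print(parser.parse_args([], namespace=namespace))
# Namespace(graph_optimization_config={'graph_opt_level': 1}, max_num_seqs=128)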
@@ -748,10 +748,6 @@ class GCUModelRunner(ModelRunnerBase):
# 3. Prepare lora

# 4. Run model
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing
self.forward_meta.is_decode_batch = is_decode_batch
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta)

@@ -979,10 +975,6 @@ class GCUModelRunner(ModelRunnerBase):
# 2. Padding inputs for cuda grph

# 3. Execute model
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
self.forward_meta.is_decode_batch = is_decode_batch
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta)
@@ -417,7 +417,10 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["seq_lens_decoder"][idx:idx + 1] = 0
self.share_inputs["step_idx"][idx:idx + 1] = 0
self.share_inputs["max_dec_len"][idx:idx + 1] = max_dec_len
self.share_inputs["min_dec_len"][idx:idx + 1] = max_dec_len
self.share_inputs["stop_flags"][idx:idx + 1] = False
self.share_inputs["top_p"][idx:idx + 1] = 0.0
self.share_inputs["temperature"][idx:idx + 1] = 1

self.share_inputs["first_token_ids"][
idx:idx + 1] = self.share_inputs["input_ids"][idx:idx + 1, :1]

@@ -759,6 +762,11 @@ class GPUModelRunner(ModelRunnerBase):
caches=self.share_inputs["caches"]
)

# Update Batch type for cuda graph
# TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0)
self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch

# Initialzie attention meta data
for attn_backend in self.attn_backends:
attn_backend.init_attention_metadata(self.forward_meta)

@@ -850,6 +858,7 @@ class GPUModelRunner(ModelRunnerBase):
Args:
num_tokens:
expected_decode_len: Expected number of tokens generated
in_capturing: Is cuda graph in capturing state
"""
self._dummy_prefill_inputs(num_tokens=num_tokens,
batch_size=batch_size,

@@ -864,17 +873,16 @@ class GPUModelRunner(ModelRunnerBase):
# 1. Initialize forward meta and attention meta data
self._prepare_inputs()

# 2. Prepare lora
# 2. Padding inputs for cuda graph
self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph
self.padding_cudagraph_inputs()

# 3. Run model
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing
self.forward_meta.is_decode_batch = is_decode_batch
if self.enable_mm:
hidden_states = model_output = self.model(self.share_inputs["ids_remove_padding"],
model_output = self.model(self.share_inputs["ids_remove_padding"],
self.share_inputs["image_features"],
self.forward_meta)
hidden_states = model_output
else:
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],

@@ -1113,9 +1121,7 @@ class GPUModelRunner(ModelRunnerBase):
We plan to replace it with 'ModelForwardBatch'.
intermediate_tensors:
"""
# NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state.
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runner, the current runner is required to execute part of the model.
# NOTE(wufeisheng): For Expert Parallelism
if not self.not_need_stop():
self._execute_empty_input()
return None

@@ -1126,18 +1132,14 @@ class GPUModelRunner(ModelRunnerBase):
self.sampler.pre_process(skip_idx_list)

# 2. Padding inputs for cuda graph
self.padding_cudagraph_inputs()

# 3. Execute model
# TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
self.forward_meta.is_decode_batch = is_decode_batch

if self.enable_mm:
hidden_states = model_output = self.model(self.share_inputs["ids_remove_padding"],
model_output = self.model(self.share_inputs["ids_remove_padding"],
self.share_inputs["image_features"],
self.forward_meta)
hidden_states = model_output
else:
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],

@@ -1399,6 +1401,18 @@ class GPUModelRunner(ModelRunnerBase):
self.dynamic_weight_manager._log_memory(
"dynamic weight manager update all memory")

def padding_cudagraph_inputs(self) -> None:
"""
Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
"""
# TODO(gongshaotian): Use more efficient implementation
if self.forward_meta.step_use_cudagraph:
num_empty_batch = (self.forward_meta.seq_lens_this_time == 0).sum()
for i in range(1, num_empty_batch + 1):
self.forward_meta.decoder_batch_ids[-i] = 0
self.forward_meta.decoder_tile_ids_per_batch[-i] = 0

def _init_image_preprocess(self) -> None:
processor = DataProcessor(
tokenizer_name=self.tokenizer_path,
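The new padding_cudagraph_inputs hook keeps shared buffers clean when a captured graph is replayed with a padded batch: slots that belong to padding (sequence length 0 this step) are zeroed so the attention kernels do not pick up stale indices. A list-based sketch of the idea (the real code mutates paddle tensors on forward_meta):

def clean_padded_slots(seq_lens_this_time: list[int],
                       decoder_batch_ids: list[int],
                       decoder_tile_ids_per_batch: list[int]) -> None:
    # Padded (empty) requests have a sequence length of 0 this step.
    num_empty_batch = sum(1 for n in seq_lens_this_time if n == 0)
    for i in range(1, num_empty_batch + 1):
        decoder_batch_ids[-i] = 0
        decoder_tile_ids_per_batch[-i] = 0

seq_lens = [5, 3, 0, 0]           # a real batch of 2, padded up to the captured size 4
batch_ids = [7, 9, 7, 9]
tile_ids = [1, 2, 1, 2]
clean_padded_slots(seq_lens, batch_ids, tile_ids)
print(batch_ids, tile_ids)        # [7, 9, 0, 0] [1, 2, 0, 0]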
@@ -715,10 +715,6 @@ class IluvatarModelRunner(ModelRunnerBase):
# 3. Prepare lora

# 4. Run model
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing
self.forward_meta.is_decode_batch = is_decode_batch
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta)

@@ -939,11 +935,6 @@ class IluvatarModelRunner(ModelRunnerBase):
# 2. Padding inputs for cuda grph

# 3. Execute model
# TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"]
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
self.forward_meta.is_decode_batch = is_decode_batch
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta)
@@ -14,6 +14,7 @@
# limitations under the License.
"""
import argparse
import json
import time
from typing import List

@@ -516,18 +517,11 @@ def parse_args():
"default is None. The priority of this configuration "\
"is lower than that of the config file. " \
"More complex quantization methods need to be configured via the config file.")
parser.add_argument("--enable_static_graph_inference",
action='store_true',
help="Whether to use static mode; if enabled, " \
"'paddle.to_static' will be used to convert dynamic to static.")
parser.add_argument("--use_cudagraph",
action='store_true',
help="Flags to enable cuda graph.")
parser.add_argument("--max_capture_batch_size",
type=int,
default=64,
help="Maximum Batch Size for Cuda Graph Capture. " \
"If max_capture_batch_size set 64, FastDeploy will capture batch size in [1, 64]")
parser.add_argument("--graph_optimiaztion_config",
type=json.loads,
default=None,
help=" Configation of Graph optimization backend. "
)
parser.add_argument("--guided_decoding_backend",
type=str,
default="off",

@@ -579,9 +573,10 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
load_config = LoadConfig(vars(args))

graph_opt_config = GraphOptimizationConfig(
args.enable_static_graph_inference,
args.max_capture_batch_size,
vars(args))
use_cudagraph=args.graph_optimiaztion_config["use_cudagraph"],
graph_opt_level=args.graph_optimiaztion_config["graph_opt_level"],
cudagraph_capture_sizes=args.graph_optimiaztion_config["cudagraph_capture_sizes"]
)

# Note(tangbinhan): used for load_checkpoint
model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
@@ -91,7 +91,7 @@ def setup_and_run_server():
"--max-num-seqs", "128",
"--quantization", "wint4",
"--use-cudagraph",
"--max-capture-batch-size", "1"
"--graph-optimization-config", '{"cudagraph_capture_sizes": [1]}'
]

# Start subprocess in new process group