Unify server-side and model-side Config (Part1) (#3018)

* move cache config

* fix mtp
Author: YuanRisheng (committed by GitHub)
Date: 2025-07-28 10:51:52 +08:00
Parent: 8f426c1690
Commit: 6ccc10ad47
23 changed files with 243 additions and 289 deletions


@@ -186,12 +186,8 @@ class ParallelConfig:
        self.dtype: str = "bfloat16"
        # Encoder's decoder num
        self.enc_dec_block_num: int = 1
-        # KV cache ratio for input
-        self.kv_cache_ratio: float = 0.7
        # First token id
        self.first_token_id: int = 1
-        # Gpu memory utilization
-        self.gpu_memory_utilization: float = 0.9
        # Process ID of engine
        self.engine_pid: Optional[int] = None
        # Do profile or not
@@ -200,12 +196,8 @@ class ParallelConfig:
        self.pad_token_id: int = -1
        #
        self.eos_tokens_lens: int = 2
-        # Enable chunked prefill
-        self.enable_chunked_prefill: bool = False
        self.max_num_batched_tokens: int = 2048
-        # enable prefix cache
-        self.enable_prefix_caching = None
        # splitwise role
        self.splitwise_role: str = "mixed"
        # guided decoding backend
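
Editor's note on the relocation above: the cache-related fields dropped from ParallelConfig here (kv_cache_ratio, gpu_memory_utilization, enable_chunked_prefill, enable_prefix_caching) move onto the new CacheConfig introduced in the next hunk. A minimal access-pattern sketch, assuming an FDConfig instance named fd_config (illustrative only, not part of the diff):

    # before this commit the knobs lived on ParallelConfig:
    #   fd_config.parallel_config.kv_cache_ratio
    #   fd_config.parallel_config.enable_prefix_caching
    # after this commit they live on CacheConfig:
    ratio = fd_config.cache_config.kv_cache_ratio              # default 0.75
    prefix_caching = fd_config.cache_config.enable_prefix_caching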
@@ -440,10 +432,153 @@ class LoRAConfig:
    pass


-class KVCacheConfig:
-    """KV Cache Config"""
-
-    cache_quant_dtype: str = "none"
+class CacheConfig:
+    """
+    Configuration for the KV cache.
+
+    Attributes:
+        block_size (int): Size of a cache block in number of tokens.
+        gpu_memory_utilization (float): Fraction of GPU memory to use for model execution.
+        cache_dtype (str): Data type for kv cache storage. Default is 'bfloat16'.
+        num_gpu_blocks_override (Optional[int]): Number of GPU blocks to use.
+            Overrides profiled num_gpu_blocks if provided.
+        kv_cache_ratio (float): Ratio for calculating the maximum block number.
+        enc_dec_block_num (int): Number of encoder-decoder blocks.
+        prealloc_dec_block_slot_num_threshold (int): Token slot threshold for allocating next blocks during decoding.
+        enable_prefix_caching (bool): Flag to enable prefix caching.
+    """
+
+    def __init__(self, args):
+        """
+        Initialize the CacheConfig class.
+
+        Args:
+            block_size (int): Size of a cache block in number of tokens.
+            gpu_memory_utilization (float): Fraction of GPU memory to use.
+            cache_dtype (str): Data type for cache storage. Default is 'bfloat16'.
+            num_gpu_blocks_override (Optional[int]): Override for number of GPU blocks.
+            num_cpu_blocks (Optional[int]): Number of CPU blocks.
+            kv_cache_ratio (float): Ratio for max block calculation.
+            enc_dec_block_num (int): Number of encoder-decoder blocks.
+            prealloc_dec_block_slot_num_threshold (int): Token slot threshold for allocating next blocks during decoding, used when ENABLE_V1_KVCACHE_SCHEDULER=1.
+            enable_prefix_caching (bool): Enable prefix caching.
+        """
+        self.block_size = 64
+        self.gpu_memory_utilization = 0.9
+        self.num_gpu_blocks_override = None
+        self.kv_cache_ratio = 0.75
+        self.enc_dec_block_num = 2
+        self.prealloc_dec_block_slot_num_threshold = 5
+        self.cache_dtype = "bfloat16"
+        self.model_cfg = None
+        self.enable_chunked_prefill = False
+        self.rdma_comm_ports = None
+        self.cache_transfer_protocol = None
+        self.pd_comm_port = None
+        self.enable_prefix_caching = False
+        self.enable_ssd_cache = False
+        self.cache_queue_port = None
+        self.swap_space = None
+
+        for key, value in args.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+
+        if self.rdma_comm_ports is not None and isinstance(self.rdma_comm_ports, str):
+            self.rdma_comm_ports = self.rdma_comm_ports.split(",")
+
+        if self.pd_comm_port is not None and isinstance(self.pd_comm_port, str):
+            self.pd_comm_port = [int(port) for port in self.pd_comm_port.split(",")]
+
+        if self.swap_space is None:
+            self.enable_hierarchical_cache = False
+        else:
+            self.enable_hierarchical_cache = True
+
+        if self.model_cfg is not None:
+            if hasattr(self.model_cfg, "quantization_config"):
+                self.cache_dtype = self.model_cfg.quantization_config.get("kv_cache_quant_type", self.cache_dtype)
+            if (
+                hasattr(self.model_cfg, "num_key_value_heads")
+                and hasattr(self.model_cfg, "num_key_value_heads")
+                and self.model_cfg.num_key_value_heads is not None
+                and int(self.model_cfg.num_key_value_heads) > 0
+            ):
+                kv_num_head = int(self.model_cfg.num_key_value_heads)
+            else:
+                kv_num_head = self.model_cfg.num_attention_heads
+            self.model_cfg.kv_num_head = kv_num_head
+            # TODO check name
+            if "int4" in self.cache_dtype.lower() or "float4" in self.cache_dtype.lower():
+                byte_size = 0.5
+                self.cache_dtype = "uint8"
+            elif "int8" in self.cache_dtype.lower() or "float8" in self.cache_dtype.lower():
+                self.cache_dtype = "uint8"
+                byte_size = 1
+            else:
+                byte_size = 2
+            self.each_token_cache_space = int(
+                self.model_cfg.num_layers * kv_num_head * self.model_cfg.head_dim * byte_size
+            )
+            self.bytes_per_block = int(self.each_token_cache_space * self.block_size)
+            self.bytes_per_layer_per_block = int(
+                self.block_size
+                * self.model_cfg.kv_num_head
+                * self.model_cfg.head_dim
+                // args["tensor_parallel_size"]
+                * byte_size
+            )
+
+        if self.swap_space is None:
+            self.num_cpu_blocks = 0
+        else:
+            self.num_cpu_blocks = int(self.swap_space * 1024**3 / self.bytes_per_block)
+        self._verify_args()
+
+    def metrics_info(self):
+        """Convert cache_config to dict(key: str, value: str) for prometheus metrics info."""
+        return {key: str(value) for key, value in self.__dict__.items()}
+
+    def _verify_args(self):
+        if self.gpu_memory_utilization > 1.0:
+            raise ValueError("GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.")
+        if self.kv_cache_ratio > 1.0:
+            raise ValueError("KV cache ratio must be less than 1.0. Got " f"{self.kv_cache_ratio}.")
+
+    def postprocess(self, num_total_tokens, number_of_tasks):
+        """
+        calculate block num
+        """
+        self.dec_token_num = self.enc_dec_block_num * self.block_size
+        if self.num_gpu_blocks_override is not None:
+            self.total_block_num = self.num_gpu_blocks_override
+            self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
+        else:
+            length = num_total_tokens // number_of_tasks
+            block_num = (length + self.block_size - 1 + self.dec_token_num) // self.block_size
+            self.total_block_num = block_num * number_of_tasks
+            self.prefill_kvcache_block_num = self.total_block_num
+        logger.info(f"Doing profile, the total_block_num:{self.total_block_num}")
+
+    def reset(self, num_gpu_blocks):
+        """
+        reset gpu block number
+        """
+        self.total_block_num = num_gpu_blocks
+        self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
+        logger.info(
+            f"Reset block num, the total_block_num:{self.total_block_num},"
+            f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}"
+        )
+
+    def print(self):
+        """
+        print all config
+        """
+        logger.info("Cache Configuration Information :")
+        for k, v in self.__dict__.items():
+            logger.info("{:<20}:{:<6}{}".format(k, "", v))
+        logger.info("=============================================================")


class DecodingConfig:
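
As a quick sanity check on the block accounting in CacheConfig.postprocess() and reset() above, a small worked example with assumed values (editor's sketch, not part of the diff):

    # assume block_size=64, enc_dec_block_num=2, kv_cache_ratio=0.75,
    # and a profile run with num_total_tokens=8192 split over number_of_tasks=4
    block_size, enc_dec_block_num, kv_cache_ratio = 64, 2, 0.75
    dec_token_num = enc_dec_block_num * block_size                        # 128 decode tokens per task
    length = 8192 // 4                                                    # 2048 tokens per task
    block_num = (length + block_size - 1 + dec_token_num) // block_size   # (2048 + 63 + 128) // 64 = 34
    total_block_num = block_num * 4                                       # 136 blocks reserved while profiling
    # after profiling, reset(num_gpu_blocks) repartitions the measured capacity:
    prefill_kvcache_block_num = int(1000 * kv_cache_ratio)                # e.g. 750 of 1000 blocks go to prefill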
@@ -477,7 +612,7 @@ class FDConfig:
    quant_config: Optional[QuantConfigBase] = None
    graph_opt_config: Optional[GraphOptimizationConfig] = None
    decoding_config: DecodingConfig = field(default=None, init=True)  # type: ignore
-    kv_cache_config: KVCacheConfig = field(default=None, init=True)  # type: ignore
+    cache_config: CacheConfig = field(default=None, init=True)  # type: ignore

    def __post_init__(self):
        # Initialize cuda graph capture list


@@ -19,8 +19,8 @@ from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional

+from fastdeploy.config import CacheConfig
from fastdeploy.engine.config import (
-    CacheConfig,
    Config,
    GraphOptimizationConfig,
    ModelConfig,
@@ -770,28 +770,6 @@ class EngineArgs:
            load_strategy=self.load_strategy,
        )

-    def create_cache_config(self, model_cfg) -> CacheConfig:
-        """
-        Create and return a CacheConfig object based on the current settings.
-        """
-        return CacheConfig(
-            block_size=self.block_size,
-            tensor_parallel_size=self.tensor_parallel_size,
-            gpu_memory_utilization=self.gpu_memory_utilization,
-            num_gpu_blocks_override=self.num_gpu_blocks_override,
-            kv_cache_ratio=self.kv_cache_ratio,
-            prealloc_dec_block_slot_num_threshold=self.prealloc_dec_block_slot_num_threshold,
-            enable_prefix_caching=self.enable_prefix_caching,
-            swap_space=self.swap_space,
-            cache_queue_port=self.cache_queue_port,
-            model_cfg=model_cfg,
-            enable_chunked_prefill=self.enable_chunked_prefill,
-            enc_dec_block_num=self.static_decode_blocks,
-            rdma_comm_ports=self.rdma_comm_ports,
-            cache_transfer_protocol=self.cache_transfer_protocol,
-            pd_comm_port=self.pd_comm_port,
-        )
-
    def create_speculative_config(self) -> SpeculativeConfig:
        """ """
        if self.speculative_config is not None:
@@ -864,12 +842,16 @@ class EngineArgs:
            self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce
        ), "enable_custom_all_reduce must be used with tensor_parallel_size>1"

+        all_dict = asdict(self)
+        all_dict["model_cfg"] = model_cfg
+        cache_cfg = CacheConfig(all_dict)
+
        return Config(
            model_name_or_path=self.model,
            model_config=model_cfg,
            scheduler_config=scheduler_cfg,
            tokenizer=self.tokenizer,
-            cache_config=self.create_cache_config(model_cfg),
+            cache_config=cache_cfg,
            parallel_config=self.create_parallel_config(),
            max_model_len=self.max_model_len,
            tensor_parallel_size=self.tensor_parallel_size,
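
The construction above replaces the removed create_cache_config() keyword plumbing with a single dict hand-off: every EngineArgs field is dumped with asdict() (already imported at the top of this file), the model config is attached under "model_cfg", and CacheConfig keeps only the keys it already defines as attributes via its hasattr/setattr loop. A condensed sketch of that behaviour with hypothetical variable names (not part of the diff):

    args_dict = asdict(engine_args)       # e.g. {"block_size": 64, "swap_space": None, "max_model_len": 8192, ...}
    args_dict["model_cfg"] = model_cfg    # model-side config rides along in the same dict
    cache_cfg = CacheConfig(args_dict)    # keys CacheConfig does not define, such as "max_model_len", are ignored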


@@ -20,6 +20,7 @@ from datetime import datetime
from typing import Any, Dict, List, Literal, Optional

from fastdeploy import envs
+from fastdeploy.config import CacheConfig
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.utils import (
@@ -157,170 +158,6 @@ class ModelConfig:
        llm_logger.info("=============================================================")


-class CacheConfig:
-    """
-    Configuration for the KV cache.
-
-    Attributes:
-        block_size (int): Size of a cache block in number of tokens.
-        gpu_memory_utilization (float): Fraction of GPU memory to use for model execution.
-        cache_dtype (str): Data type for kv cache storage. Default is 'bfloat16'.
-        num_gpu_blocks_override (Optional[int]): Number of GPU blocks to use.
-            Overrides profiled num_gpu_blocks if provided.
-        kv_cache_ratio (float): Ratio for calculating the maximum block number.
-        enc_dec_block_num (int): Number of encoder-decoder blocks.
-        prealloc_dec_block_slot_num_threshold (int): Token slot threshold for allocating next blocks during decoding.
-        enable_prefix_caching (bool): Flag to enable prefix caching.
-    """
-
-    def __init__(
-        self,
-        block_size: int,
-        gpu_memory_utilization: float,
-        cache_dtype: str = "bfloat16",
-        num_gpu_blocks_override: Optional[int] = None,
-        swap_space: Optional[int] = None,
-        kv_cache_ratio: float = 0.75,
-        enc_dec_block_num: int = 2,
-        prealloc_dec_block_slot_num_threshold: int = 5,
-        tensor_parallel_size: int = 1,
-        enable_prefix_caching=False,
-        enable_ssd_cache=False,
-        model_cfg=None,
-        cache_queue_port=None,
-        enable_chunked_prefill=False,
-        rdma_comm_ports=None,
-        cache_transfer_protocol=None,
-        pd_comm_port=None,
-    ):
-        """
-        Initialize the CacheConfig class.
-
-        Args:
-            block_size (int): Size of a cache block in number of tokens.
-            gpu_memory_utilization (float): Fraction of GPU memory to use.
-            cache_dtype (str): Data type for cache storage. Default is 'bfloat16'.
-            num_gpu_blocks_override (Optional[int]): Override for number of GPU blocks.
-            num_cpu_blocks (Optional[int]): Number of CPU blocks.
-            kv_cache_ratio (float): Ratio for max block calculation.
-            enc_dec_block_num (int): Number of encoder-decoder blocks.
-            prealloc_dec_block_slot_num_threshold (int): Token slot threshold for allocating next blocks during decoding, used when ENABLE_V1_KVCACHE_SCHEDULER=1.
-            enable_prefix_caching (bool): Enable prefix caching.
-        """
-        self.block_size = block_size
-        self.gpu_memory_utilization = gpu_memory_utilization
-        self.num_gpu_blocks_override = num_gpu_blocks_override
-        self.kv_cache_ratio = kv_cache_ratio
-        self.enc_dec_block_num = enc_dec_block_num
-        self.prealloc_dec_block_slot_num_threshold = prealloc_dec_block_slot_num_threshold
-        self.cache_dtype = cache_dtype
-        if hasattr(model_cfg, "quantization_config"):
-            self.cache_dtype = model_cfg.quantization_config.get("kv_cache_quant_type", cache_dtype)
-        self.enable_chunked_prefill = enable_chunked_prefill
-        self.rdma_comm_ports = rdma_comm_ports
-        self.cache_transfer_protocol = cache_transfer_protocol
-        self.pd_comm_port = pd_comm_port
-        if rdma_comm_ports is not None and isinstance(rdma_comm_ports, str):
-            self.rdma_comm_ports = rdma_comm_ports.split(",")
-        if pd_comm_port is not None and isinstance(pd_comm_port, str):
-            self.pd_comm_port = [int(port) for port in pd_comm_port.split(",")]
-        self.enable_prefix_caching = enable_prefix_caching
-        if swap_space is None:
-            self.enable_hierarchical_cache = False
-        else:
-            self.enable_hierarchical_cache = True
-        self.enable_ssd_cache = enable_ssd_cache
-        self.model_cfg = model_cfg
-        self.cache_queue_port = cache_queue_port
-        self.swap_space = swap_space
-        if (
-            hasattr(self.model_cfg, "num_key_value_heads")
-            and hasattr(self.model_cfg, "num_key_value_heads")
-            and self.model_cfg.num_key_value_heads is not None
-            and int(self.model_cfg.num_key_value_heads) > 0
-        ):
-            kv_num_head = int(self.model_cfg.num_key_value_heads)
-        else:
-            kv_num_head = self.model_cfg.num_attention_heads
-        self.model_cfg.kv_num_head = kv_num_head
-        # TODO check name
-        if "int4" in self.cache_dtype.lower() or "float4" in self.cache_dtype.lower():
-            byte_size = 0.5
-            self.cache_dtype = "uint8"
-        elif "int8" in self.cache_dtype.lower() or "float8" in self.cache_dtype.lower():
-            self.cache_dtype = "uint8"
-            byte_size = 1
-        else:
-            byte_size = 2
-        self.each_token_cache_space = int(
-            self.model_cfg.num_layers * kv_num_head * self.model_cfg.head_dim * byte_size
-        )
-        self.bytes_per_block = int(self.each_token_cache_space * self.block_size)
-        self.bytes_per_layer_per_block = int(
-            self.block_size * self.model_cfg.kv_num_head * self.model_cfg.head_dim // tensor_parallel_size * byte_size
-        )
-        if self.swap_space is None:
-            self.num_cpu_blocks = 0
-        else:
-            self.num_cpu_blocks = int(self.swap_space * 1024**3 / self.bytes_per_block)
-        self._verify_args()
-
-    def metrics_info(self):
-        """Convert cache_config to dict(key: str, value: str) for prometheus metrics info."""
-        return {key: str(value) for key, value in self.__dict__.items()}
-
-    def _verify_args(self):
-        if self.gpu_memory_utilization > 1.0:
-            raise ValueError("GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.")
-        if self.kv_cache_ratio > 1.0:
-            raise ValueError("KV cache ratio must be less than 1.0. Got " f"{self.kv_cache_ratio}.")
-
-    def postprocess(self, num_total_tokens, number_of_tasks):
-        """
-        calculate block num
-        """
-        self.dec_token_num = self.enc_dec_block_num * self.block_size
-        if self.num_gpu_blocks_override is not None:
-            self.total_block_num = self.num_gpu_blocks_override
-            self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
-        else:
-            length = num_total_tokens // number_of_tasks
-            block_num = (length + self.block_size - 1 + self.dec_token_num) // self.block_size
-            self.total_block_num = block_num * number_of_tasks
-            self.prefill_kvcache_block_num = self.total_block_num
-        llm_logger.info(f"Doing profile, the total_block_num:{self.total_block_num}")
-
-    def reset(self, num_gpu_blocks):
-        """
-        reset gpu block number
-        """
-        self.total_block_num = num_gpu_blocks
-        self.prefill_kvcache_block_num = int(self.total_block_num * self.kv_cache_ratio)
-        llm_logger.info(
-            f"Reset block num, the total_block_num:{self.total_block_num},"
-            f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}"
-        )
-
-    def print(self):
-        """
-        print all config
-        """
-        llm_logger.info("Cache Configuration Information :")
-        for k, v in self.__dict__.items():
-            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
-        llm_logger.info("=============================================================")
-
-
class SpeculativeConfig:
    """
    Speculative Decoding Configuration class.


@@ -95,7 +95,7 @@ class AppendAttentionBackend(AttentionBackend):
        """
        super().__init__()
        self.attention_metadata: AppendAttentionMetadata = None
-        self.block_size: int = fd_config.parallel_config.block_size
+        self.block_size: int = fd_config.cache_config.block_size
        self.max_seq_len: int = fd_config.parallel_config.max_model_len
        self.rope_theta: float = (
            10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta


@@ -85,7 +85,7 @@ class BlockAttentionBackend(AttentionBackend):
        """
        super().__init__()
        self.attention_metadata: BlockAttentionMetadata = None
-        self.block_size = fd_config.parallel_config.block_size
+        self.block_size = fd_config.cache_config.block_size
        self.max_seq_len = fd_config.parallel_config.max_model_len
        self.rope_theta = 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
        self.rank = fd_config.parallel_config.tensor_parallel_rank


@@ -113,7 +113,7 @@ class FlashAttentionBackend(AttentionBackend):
        self.num_heads = num_heads
        self.head_dim = fd_config.model_config.head_dim
        self.attn_outputsize_tp = self.num_heads * self.head_dim
-        self.block_size = fd_config.parallel_config.block_size
+        self.block_size = fd_config.cache_config.block_size
        self.num_layers: int = fd_config.model_config.num_hidden_layers

        self.speculative_method = fd_config.speculative_config.method


@@ -94,8 +94,8 @@ class IluvatarAttnBackend(AttentionBackend):
    ):
        super().__init__()
        self.attention_metadata = IluvatarAttentionMetadata()
-        self.attention_metadata.block_size = llm_config.parallel_config.block_size
-        assert llm_config.parallel_config.enc_dec_block_num == 0, "Iluvatar does not support yet"
+        self.attention_metadata.block_size = llm_config.cache_config.block_size
+        assert llm_config.cache_config.enc_dec_block_num == 0, "Iluvatar does not support yet"
        self.attention_metadata.max_context_len = llm_config.parallel_config.max_model_len
        self.attention_metadata.causal = getattr(llm_config.model_config, "causal", True)


@@ -113,7 +113,7 @@ class MLAAttentionBackend(AttentionBackend):
        self.attention_metadata: MLAAttentionMetadata = None
        # Basic configuration
-        self.block_size: int = fd_config.parallel_config.block_size
+        self.block_size: int = fd_config.cache_config.block_size
        self.max_seq_len: int = fd_config.parallel_config.max_model_len
        self.rope_theta: float = (
            10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta


@@ -92,7 +92,7 @@ class XPUAttentionBackend(AttentionBackend):
        super().__init__()
        self.attention_metadata: XPUAttentionMetadata = None
        # TODO(gongshaotian): Use fd_config parameters in the correct location
-        self.block_size: int = fd_config.parallel_config.block_size
+        self.block_size: int = fd_config.cache_config.block_size
        self.max_seq_len: int = fd_config.parallel_config.max_model_len
        self.rope_theta: float = (
            10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta


@@ -82,7 +82,7 @@ class GCUFlashAttnBackend(AttentionBackend):
        """
        super().__init__()
        self.attention_metadata: GCUFlashAttnMetadata = None
-        self.block_size = fd_config.parallel_config.block_size
+        self.block_size = fd_config.cache_config.block_size
        self.max_seq_len = fd_config.parallel_config.max_model_len
        self.max_num_seqs = fd_config.parallel_config.max_num_seqs


@@ -80,7 +80,7 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
        """
        super().__init__()
        self.attention_metadata: GCUMemEfficientAttnMetadata = None
-        self.block_size = fd_config.parallel_config.block_size
+        self.block_size = fd_config.cache_config.block_size
        self.max_seq_len = fd_config.parallel_config.max_model_len
        self.max_num_seqs = fd_config.parallel_config.max_num_seqs


@@ -38,7 +38,7 @@ class Proposer(ABC):
        self.parallel_config = self.cfg.parallel_config
        self.model_config = self.cfg.model_config
        self.speculative_config = self.cfg.speculative_config
-        self.kv_cache_config = self.cfg.kv_cache_config
+        self.cache_config = self.cfg.cache_config
        self.quant_config = self.cfg.quant_config

        self.max_num_seqs = self.parallel_config.max_num_seqs


@@ -97,10 +97,10 @@ class MTPProposer(Proposer):
            num_tokens // batch_size,
            self.parallel_config.max_model_len - max_dec_len,
        )
-        input_length = int(full_length * self.parallel_config.kv_cache_ratio)
+        input_length = int(full_length * self.cache_config.kv_cache_ratio)
        block_num = (
-            input_length + self.parallel_config.block_size - 1
-        ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num
+            input_length + self.cache_config.block_size - 1
+        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num

        for i in range(batch_size):
            idx = i
@@ -141,7 +141,7 @@ class MTPProposer(Proposer):
            max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
        )
        if not self.parallel_config.do_profile and (
-            self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
+            self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
        ):
            cache_kvs_list = []
            for i in range(
@@ -219,14 +219,14 @@ class MTPProposer(Proposer):
        self.main_model_num_gpu_blocks = num_gpu_blocks
        self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
-        if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+        if not (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
            self.initialize_kv_cache()

        # Reset free list
        free_list = list(
            range(
                self.num_gpu_blocks - 1,
-                int(self.main_model_num_gpu_blocks * self.parallel_config.kv_cache_ratio) - 1,
+                int(self.main_model_num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
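
The free-list reset above carves out the block range the draft (MTP) model may use without touching the portion reserved for the main model's prefill cache. A small numeric sketch with assumed values (editor's illustration, not part of the diff):

    main_model_num_gpu_blocks = 1000
    num_gpu_blocks = int(1000 * 1.1)      # draft model sees 1100 blocks with num_gpu_block_expand_ratio=1.1
    kv_cache_ratio = 0.75
    free_list = list(range(num_gpu_blocks - 1, int(main_model_num_gpu_blocks * kv_cache_ratio) - 1, -1))
    # -> [1099, 1098, ..., 750]; blocks below int(1000 * 0.75) stay with the main model's prefill cache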
@@ -299,7 +299,7 @@ class MTPProposer(Proposer):
        self.free_list = list(
            range(
                self.parallel_config.total_block_num - 1,
-                int(self.parallel_config.total_block_num * self.parallel_config.kv_cache_ratio) - 1,
+                int(self.parallel_config.total_block_num * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
@@ -371,7 +371,7 @@ class MTPProposer(Proposer):
        ]
        self.model_inputs["pre_ids"][idx : idx + 1] = -1
        self.model_inputs["step_idx"][idx : idx + 1] = 0
-        if self.parallel_config.enable_chunked_prefill:
+        if self.cache_config.enable_chunked_prefill:
            token_chunk_size = request.prefill_chunk_info[0]
            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
            self.model_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
@@ -640,7 +640,7 @@ class MTPProposer(Proposer):
            self.model_inputs["used_list_len"],
            self.model_inputs["free_list"],
            self.model_inputs["free_list_len"],
-            self.parallel_config.block_size,
+            self.cache_config.block_size,
            self.max_draft_token_num,
        )


@@ -89,9 +89,7 @@ class DcuWorker(GpuWorker):
        paddle.device.cuda.empty_cache()
        paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run
        available_kv_cache_memory = (
-            total_gpu_memory * self.parallel_config.gpu_memory_utilization
-            - after_used_gpu_memory
-            - paddle_peak_increase
+            total_gpu_memory * self.cache_config.gpu_memory_utilization - after_used_gpu_memory - paddle_peak_increase
        )
        available_kv_cache_memory += model_block_memory_used * self.parallel_config.total_block_num


@@ -207,7 +207,7 @@ class GCUModelRunner(ModelRunnerBase):
        self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)

        # Use chunked prefill
-        if self.parallel_config.enable_chunked_prefill:
+        if self.cache_config.enable_chunked_prefill:
            request.set("chunk_idx", 1)
            logger.info(f"prefill_chunk_info: {request.prefill_chunk_info}")
            token_chunk_size = request.prefill_chunk_info[0]
@@ -287,10 +287,10 @@ class GCUModelRunner(ModelRunnerBase):
            num_tokens // batch_size,
            self.parallel_config.max_model_len - max_dec_len,
        )
-        input_length = int(full_length * self.parallel_config.kv_cache_ratio)
+        input_length = int(full_length * self.cache_config.kv_cache_ratio)
        block_num = (
-            input_length + self.parallel_config.block_size - 1
-        ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num
+            input_length + self.cache_config.block_size - 1
+        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num

        for i in range(batch_size):
            idx = i
@@ -417,15 +417,15 @@ class GCUModelRunner(ModelRunnerBase):
        # Set block tables
        pre_max_block_num = (
-            self.parallel_config.max_model_len + self.parallel_config.block_size - 1
-        ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num
+            self.parallel_config.max_model_len + self.cache_config.block_size - 1
+        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
        self.share_inputs["block_tables"] = paddle.full([max_num_seqs, pre_max_block_num], -1, dtype="int32")

        # Initialize free list
        free_list = list(
            range(
                self.parallel_config.total_block_num - 1,
-                int(self.parallel_config.total_block_num * self.parallel_config.kv_cache_ratio) - 1,
+                int(self.parallel_config.total_block_num * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
@@ -608,9 +608,7 @@ class GCUModelRunner(ModelRunnerBase):
        )
        # local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        if not profile and (
-            self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
-        ):
+        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
            raise NotImplementedError("prefix_caching is not support by GCUModelRunner.")
        else:
            for i in range(self.model_config.num_hidden_layers):
@@ -795,7 +793,7 @@ class GCUModelRunner(ModelRunnerBase):
        """
        Update chunked prefill related parameters
        """
-        if not self.parallel_config.enable_chunked_prefill:
+        if not self.cache_config.enable_chunked_prefill:
            return
        for task in tasks:
            if task.get("prefill_chunk_info", None) is None:
@@ -861,7 +859,7 @@ class GCUModelRunner(ModelRunnerBase):
            A list of indices corresponding to the requests that need to be skipped.
        """
        skip_idx_list = []
-        if not self.parallel_config.enable_chunked_prefill or self.guided_backend is None:
+        if not self.cache_config.enable_chunked_prefill or self.guided_backend is None:
            return skip_idx_list

        for task in model_forward_batch:
@@ -1079,7 +1077,7 @@ class GCUModelRunner(ModelRunnerBase):
        free_list = list(
            range(
                self.num_gcu_blocks - 1,
-                int(self.num_gcu_blocks * self.parallel_config.kv_cache_ratio) - 1,
+                int(self.num_gcu_blocks * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
@@ -1123,7 +1121,7 @@ class GCUModelRunner(ModelRunnerBase):
            if self.speculative_method in ["mtp"]
            else self.model_config.num_hidden_layers
        )
-        required_memory = byte_of_dtype * 2 * (self.parallel_config.block_size * hidden_dim) * num_layers  # k + v
+        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
        return required_memory

    def not_need_stop(self) -> bool:


@@ -339,7 +339,7 @@ class GPUModelRunner(ModelRunnerBase):
        self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)

        # Use chunked prefill
-        if self.parallel_config.enable_chunked_prefill:
+        if self.cache_config.enable_chunked_prefill:
            request.set("chunk_idx", 1)
            logger.info(f"prefill_chunk_info: {request.prefill_chunk_info}")
            token_chunk_size = request.prefill_chunk_info[0]
@@ -467,10 +467,10 @@ class GPUModelRunner(ModelRunnerBase):
            num_tokens // batch_size,
            self.parallel_config.max_model_len - max_dec_len,
        )
-        input_length = int(full_length * self.parallel_config.kv_cache_ratio)
+        input_length = int(full_length * self.cache_config.kv_cache_ratio)
        block_num = (
-            input_length + self.parallel_config.block_size - 1
-        ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num
+            input_length + self.cache_config.block_size - 1
+        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num

        for i in range(batch_size):
            idx = i
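
For the dummy-run sizing above, the per-request block budget now comes entirely from cache_config. A worked example with assumed values (not from the diff):

    # assume max_model_len=8192, max_dec_len=1024, num_tokens=16384, batch_size=2,
    # kv_cache_ratio=0.75, block_size=64, enc_dec_block_num=2
    full_length = min(16384 // 2, 8192 - 1024)     # 7168
    input_length = int(full_length * 0.75)         # 5376 prompt tokens per dummy request
    block_num = (5376 + 64 - 1) // 64 + 2          # 84 prefill blocks + 2 decode blocks = 86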
@@ -602,15 +602,15 @@ class GPUModelRunner(ModelRunnerBase):
        # Set block tables
        pre_max_block_num = (
-            self.parallel_config.max_model_len + self.parallel_config.block_size - 1
-        ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num
+            self.parallel_config.max_model_len + self.cache_config.block_size - 1
+        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
        self.share_inputs["block_tables"] = paddle.full([max_num_seqs, pre_max_block_num], -1, dtype="int32")

        # Initialize free list
        free_list = list(
            range(
                self.parallel_config.total_block_num - 1,
-                int(self.parallel_config.total_block_num * self.parallel_config.kv_cache_ratio) - 1,
+                int(self.parallel_config.total_block_num * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
@@ -689,7 +689,7 @@ class GPUModelRunner(ModelRunnerBase):
            self.share_inputs["step_seq_lens_decoder"],
            self.share_inputs["block_tables"],
            self.share_inputs["is_block_step"],
-            self.parallel_config.block_size,
+            self.cache_config.block_size,
        )

        # Remove padding
@@ -833,9 +833,7 @@ class GPUModelRunner(ModelRunnerBase):
        )
        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        if not profile and (
-            self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
-        ):
+        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
            cache_kvs_list = []
            for i in range(self.model_config.num_hidden_layers):
                key_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -1015,7 +1013,7 @@ class GPUModelRunner(ModelRunnerBase):
            sampler_output=sampler_output,
            model_output=model_output_data,
            share_inputs=self.share_inputs,
-            block_size=self.parallel_config.block_size,
+            block_size=self.cache_config.block_size,
            speculative_decoding=self.speculative_decoding,
            skip_save_output=True,
        )
@@ -1031,10 +1029,10 @@ class GPUModelRunner(ModelRunnerBase):
            self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
            step_cuda(
                self.share_inputs,
-                self.parallel_config.block_size,
-                self.parallel_config.enc_dec_block_num,
+                self.cache_config.block_size,
+                self.cache_config.enc_dec_block_num,
                self.speculative_config,
-                self.parallel_config.enable_prefix_caching,
+                self.cache_config.enable_prefix_caching,
            )

            if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
@@ -1044,7 +1042,7 @@ class GPUModelRunner(ModelRunnerBase):
        """
        Update chunked prefill related parameters
        """
-        if not self.parallel_config.enable_chunked_prefill:
+        if not self.cache_config.enable_chunked_prefill:
            return
        for task in tasks:
            if task.get("prefill_chunk_info", None) is None:
@@ -1144,7 +1142,7 @@ class GPUModelRunner(ModelRunnerBase):
            A list of indices corresponding to the requests that need to be skipped.
        """
        skip_idx_list = []
-        if not self.parallel_config.enable_chunked_prefill or self.guided_backend is None:
+        if not self.cache_config.enable_chunked_prefill or self.guided_backend is None:
            return skip_idx_list

        for task in model_forward_batch:
@@ -1283,7 +1281,7 @@ class GPUModelRunner(ModelRunnerBase):
            sampler_output=sampler_output,
            model_output=model_output_data,
            share_inputs=self.share_inputs,
-            block_size=self.parallel_config.block_size,
+            block_size=self.cache_config.block_size,
            save_each_rank=self.parallel_config.use_ep,
            speculative_decoding=self.speculative_decoding,
            skip_save_output=skip_save_output,
@@ -1302,10 +1300,10 @@ class GPUModelRunner(ModelRunnerBase):
        if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
            step_cuda(
                self.share_inputs,
-                self.parallel_config.block_size,
-                self.parallel_config.enc_dec_block_num,
+                self.cache_config.block_size,
+                self.cache_config.enc_dec_block_num,
                self.speculative_config,
-                self.parallel_config.enable_prefix_caching,
+                self.cache_config.enable_prefix_caching,
            )

        self._update_chunked_prefill(model_forward_batch)
@@ -1379,7 +1377,7 @@ class GPUModelRunner(ModelRunnerBase):
        free_list = list(
            range(
                self.num_gpu_blocks - 1,
-                int(self.num_gpu_blocks * self.parallel_config.kv_cache_ratio) - 1,
+                int(self.num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
@@ -1425,7 +1423,7 @@ class GPUModelRunner(ModelRunnerBase):
            if self.speculative_method in ["mtp"]
            else self.model_config.num_hidden_layers
        )
-        required_memory = byte_of_dtype * 2 * (self.parallel_config.block_size * hidden_dim) * num_layers  # k + v
+        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
        return required_memory

    def not_need_stop(self) -> bool:


@@ -137,7 +137,7 @@ class GpuWorker(WorkerBase):
        pynvml.nvmlShutdown()

        available_kv_cache_memory = (
-            after_run_meminfo.total * self.parallel_config.gpu_memory_utilization
+            after_run_meminfo.total * self.cache_config.gpu_memory_utilization
            - after_run_meminfo.used
            - paddle_peak_increase
        )
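
The expression above is the budget that ultimately sizes the KV cache: the fraction of device memory granted by cache_config.gpu_memory_utilization, minus what the weights and runtime buffers occupy after the profile run, minus the peak growth observed during that run. A rough worked example with assumed numbers (editor's sketch):

    total = 80 * 1024**3                  # 80 GiB device
    gpu_memory_utilization = 0.9
    used_after_run = 30 * 1024**3         # weights + buffers measured after the profile run
    paddle_peak_increase = 4 * 1024**3    # peak activation growth during the profile run
    available_kv_cache_memory = total * gpu_memory_utilization - used_after_run - paddle_peak_increase
    # = 72 GiB - 30 GiB - 4 GiB = 38 GiB left for KV cache blocks; dividing by the per-block byte
    # size reported by the runner's cal_theortical_kvcache() then gives the usable block count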


@@ -189,7 +189,7 @@ class IluvatarModelRunner(ModelRunnerBase):
        self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)

        # Use chunked prefill
-        if self.parallel_config.enable_chunked_prefill:
+        if self.cache_config.enable_chunked_prefill:
            request.set("chunk_idx", 1)
            logger.info(f"prefill_chunk_info: {request.prefill_chunk_info}")
            token_chunk_size = request.prefill_chunk_info[0]
@@ -257,10 +257,10 @@ class IluvatarModelRunner(ModelRunnerBase):
            num_tokens // batch_size,
            self.parallel_config.max_model_len - max_dec_len,
        )
-        input_length = int(full_length * self.parallel_config.kv_cache_ratio)
+        input_length = int(full_length * self.cache_config.kv_cache_ratio)
        block_num = (
-            input_length + self.parallel_config.block_size - 1
-        ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num
+            input_length + self.cache_config.block_size - 1
+        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num

        for i in range(batch_size):
            idx = i
@@ -383,15 +383,15 @@ class IluvatarModelRunner(ModelRunnerBase):
        # Set block tables
        pre_max_block_num = (
-            self.parallel_config.max_model_len + self.parallel_config.block_size - 1
-        ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num
+            self.parallel_config.max_model_len + self.cache_config.block_size - 1
+        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
        self.share_inputs["block_tables"] = paddle.full([max_num_seqs, pre_max_block_num], -1, dtype="int32")

        # Initialize free list
        free_list = list(
            range(
                self.parallel_config.total_block_num - 1,
-                int(self.parallel_config.total_block_num * self.parallel_config.kv_cache_ratio) - 1,
+                int(self.parallel_config.total_block_num * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
@@ -574,7 +574,7 @@ class IluvatarModelRunner(ModelRunnerBase):
        )

        if not self.parallel_config.do_profile and (
-            self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
+            self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
        ):
            raise NotImplementedError("Iluvatar does not support yet")
        else:
@@ -733,10 +733,10 @@ class IluvatarModelRunner(ModelRunnerBase):
            self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
            step_cuda(
                self.share_inputs,
-                self.parallel_config.block_size,
-                self.parallel_config.enc_dec_block_num,
+                self.cache_config.block_size,
+                self.cache_config.enc_dec_block_num,
                self.speculative_config,
-                self.parallel_config.enable_prefix_caching,
+                self.cache_config.enable_prefix_caching,
            )

            if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
@@ -746,7 +746,7 @@ class IluvatarModelRunner(ModelRunnerBase):
        """
        Update chunked prefill related parameters
        """
-        if not self.parallel_config.enable_chunked_prefill:
+        if not self.cache_config.enable_chunked_prefill:
            return

        for task in tasks:
@@ -815,7 +815,7 @@ class IluvatarModelRunner(ModelRunnerBase):
            A list of indices corresponding to the requests that need to be skipped.
        """
        skip_idx_list = []
-        if not self.parallel_config.enable_chunked_prefill or self.guided_backend is None:
+        if not self.cache_config.enable_chunked_prefill or self.guided_backend is None:
            return skip_idx_list

        for task in model_forward_batch:
@@ -952,10 +952,10 @@ class IluvatarModelRunner(ModelRunnerBase):
            self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
            step_cuda(
                self.share_inputs,
-                self.parallel_config.block_size,
-                self.parallel_config.enc_dec_block_num,
+                self.cache_config.block_size,
+                self.cache_config.enc_dec_block_num,
                self.speculative_config,
-                self.parallel_config.enable_prefix_caching,
+                self.cache_config.enable_prefix_caching,
            )

        self._update_chunked_prefill(model_forward_batch)
@@ -1023,7 +1023,7 @@ class IluvatarModelRunner(ModelRunnerBase):
        free_list = list(
            range(
                self.num_gpu_blocks - 1,
-                int(self.num_gpu_blocks * self.parallel_config.kv_cache_ratio) - 1,
+                int(self.num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
@@ -1066,7 +1066,7 @@ class IluvatarModelRunner(ModelRunnerBase):
            if self.speculative_method in ["mtp"]
            else self.model_config.num_hidden_layers
        )
-        required_memory = byte_of_dtype * 2 * (self.parallel_config.block_size * hidden_dim) * num_layers  # k + v
+        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
        return required_memory

    def not_need_stop(self) -> bool:


@@ -44,6 +44,7 @@ class ModelRunnerBase(ABC):
        self.parallel_config = fd_config.parallel_config
        self.graph_opt_config = fd_config.graph_opt_config
        self.quant_config = fd_config.quant_config
+        self.cache_config = fd_config.cache_config
        # ... config

        self.device = device


@@ -50,6 +50,7 @@ class WorkerBase(ABC):
        self.load_config = fd_config.load_config
        self.parallel_config = fd_config.parallel_config
        self.device_config = fd_config.device_config
+        self.cache_config = fd_config.cache_config
        # ... config

        # Device and Runner


@@ -25,6 +25,7 @@ import paddle.distributed as dist
from paddle.distributed import fleet

from fastdeploy.config import (
+    CacheConfig,
    DecodingConfig,
    DeviceConfig,
    ErnieArchitectures,
@@ -140,6 +141,7 @@ class PaddleDisWorkerProc:
        self.local_rank = local_rank
        self.fd_config = fd_config
        self.parallel_config = fd_config.parallel_config
+        self.cache_config = fd_config.cache_config

        # TODO(gongshaotian): Use worker factory to get worker
        self.worker = get_worker(fd_config=fd_config, local_rank=self.local_rank, rank=self.ranks)
@@ -404,7 +406,7 @@ class PaddleDisWorkerProc:
            logger.info(f"------- num_blocks_global: {num_blocks_local} --------")

            # wait engine launch cache_manager
-            if self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
+            if self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
                launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
                self.launched_cache_manager_signal = IPCSignal(
                    name="launched_cache_manager_signal",
@@ -607,6 +609,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
    decoding_config = DecodingConfig(vars(args))
    speculative_config = SpeculativeConfig(vars(args))
    parallel_config = ParallelConfig(vars(args))
+    cache_config = CacheConfig(vars(args))
    parallel_config.tensor_parallel_size = args.tensor_parallel_size
    parallel_config.tensor_parallel_rank = local_rank % args.tensor_parallel_size
    parallel_config.expert_parallel_size = args.expert_parallel_size
@@ -707,6 +710,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
        decoding_config=decoding_config,
        quant_config=quant_config,
        graph_opt_config=graph_opt_config,
+        cache_config=cache_config,
    )

    update_fd_config_for_mm(fd_config)
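
On the worker side the same flat-dict pattern applies: each sub-config consumes vars(args) and keeps only the keys it recognizes, and the resulting CacheConfig is passed to FDConfig next to the other configs. A condensed sketch of the flow in this file (names as in the diff, values illustrative, other FDConfig arguments elided):

    args_dict = vars(args)
    parallel_config = ParallelConfig(args_dict)
    cache_config = CacheConfig(args_dict)     # server side and model side now share one cache config
    fd_config = FDConfig(
        parallel_config=parallel_config,
        cache_config=cache_config,
        # ... model_config, speculative_config, decoding_config, quant_config, graph_opt_config ...
    )
    # downstream, runners and workers read fd_config.cache_config.block_size,
    # fd_config.cache_config.enable_prefix_caching, and so on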


@@ -428,15 +428,15 @@ class XPUModelRunner(ModelRunnerBase):
        # Set block tables
        pre_max_block_num = (
-            self.parallel_config.max_model_len + self.parallel_config.block_size - 1
-        ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num
+            self.parallel_config.max_model_len + self.cache_config.block_size - 1
+        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
        self.share_inputs["block_tables"] = paddle.full([max_num_seqs, pre_max_block_num], -1, dtype="int32")

        # Initialize free list
        free_list = list(
            range(
                self.parallel_config.total_block_num - 1,
-                int(self.parallel_config.total_block_num * self.parallel_config.kv_cache_ratio) - 1,
+                int(self.parallel_config.total_block_num * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
@@ -598,8 +598,8 @@ class XPUModelRunner(ModelRunnerBase):
        full_length = min(num_tokens // batch_size, self.parallel_config.max_model_len - 10)
        input_length = int(full_length - 512)
        block_num = (
-            input_length + self.parallel_config.block_size - 1
-        ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num
+            input_length + self.cache_config.block_size - 1
+        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num

        for i in range(batch_size):
            idx = i
@@ -707,8 +707,8 @@ class XPUModelRunner(ModelRunnerBase):
        self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
        step_paddle(
            self.share_inputs,
-            self.parallel_config.block_size,
-            self.parallel_config.enc_dec_block_num,
+            self.cache_config.block_size,
+            self.cache_config.enc_dec_block_num,
        )

        return None
@@ -764,7 +764,7 @@ class XPUModelRunner(ModelRunnerBase):
        required_memory = (
            byte_of_dtype
            * 2  # k + v
-            * (self.parallel_config.block_size * hidden_dim)
+            * (self.cache_config.block_size * hidden_dim)
            * self.model_config.num_hidden_layers
        )
        return required_memory
@@ -784,7 +784,7 @@ class XPUModelRunner(ModelRunnerBase):
        free_list = list(
            range(
                self.num_gpu_blocks - 1,
-                int(self.num_gpu_blocks * self.parallel_config.kv_cache_ratio) - 1,
+                int(self.num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )


@@ -104,7 +104,7 @@ class XpuWorker(WorkerBase):
        self.model_runner.prepare_profile()
        self.model_runner.profile_run()

-        total_available_memory = int(total_memory * self.parallel_config.gpu_memory_utilization)
+        total_available_memory = int(total_memory * self.cache_config.gpu_memory_utilization)
        used_memory = xpu_get_used_global_memory(self.local_rank)
        available_kv_cache_memory = total_available_memory - used_memory
        model_block_memory_used = self.cal_theortical_kvcache()