[FDConfig] Remove max_num_batched_tokens/max_num_seqs in parallel config (#4116)

* remove max_num_batched_tokens in parallel config
* remove max_num_seqs
* update test case
* fix test
* fix

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
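Every call site changed below moves these limits from fd_config.parallel_config to fd_config.scheduler_config. A minimal sketch of what the relocation means for callers, using simplified stand-in dataclasses rather than FastDeploy's real config classes (whose other fields are omitted here):

from dataclasses import dataclass, field


@dataclass
class ParallelConfig:
    # max_num_seqs / max_num_batched_tokens no longer live here.
    tensor_parallel_size: int = 1  # hypothetical remaining field


@dataclass
class SchedulerConfig:
    # Scheduling limits now live with the scheduler that enforces them.
    max_num_seqs: int = 256
    max_num_batched_tokens: int = 8192


@dataclass
class FDConfig:
    parallel_config: ParallelConfig = field(default_factory=ParallelConfig)
    scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig)


fd_config = FDConfig()
# Before this PR: fd_config.parallel_config.max_num_seqs
# After this PR:
assert fd_config.scheduler_config.max_num_seqs == 256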
@@ -152,7 +152,7 @@ class Attention(nn.Layer):
         self.cache_k_block_means = paddle.zeros(
             [
-                fd_config.parallel_config.max_num_seqs,
+                fd_config.scheduler_config.max_num_seqs,
                 moba_max_seq_length // moba_block_size,
                 self.kv_num_heads,
                 self.head_dim,
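Since cache_k_block_means is preallocated with one slot for every schedulable sequence, its footprint scales linearly with max_num_seqs. A rough sizing sketch, with hypothetical values for the MoBA and model settings (these are illustrative assumptions, not FastDeploy defaults):

max_num_seqs = 256            # fd_config.scheduler_config.max_num_seqs
moba_max_seq_length = 131072  # assumed MoBA setting
moba_block_size = 128         # assumed MoBA setting
kv_num_heads = 8              # assumed model setting
head_dim = 128                # assumed model setting

shape = [max_num_seqs, moba_max_seq_length // moba_block_size, kv_num_heads, head_dim]
num_elems = max_num_seqs * (moba_max_seq_length // moba_block_size) * kv_num_heads * head_dim
print(shape)                                 # [256, 1024, 8, 128]
print(f"{num_elems * 2 / 1024**3:.1f} GiB")  # 0.5 GiB at 2 bytes per element (fp16/bf16)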
@@ -156,7 +156,7 @@ class FlashAttentionBackend(AttentionBackend):
         self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
         self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
         self.zero_seq_enc_lens_for_decode = paddle.zeros(
-            shape=[fd_config.parallel_config.max_num_seqs, 1], dtype=paddle.int32
+            shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype=paddle.int32
         )

     def get_attntion_meta(self):
@@ -77,7 +77,7 @@ class MobaAttentionBackend(AttentionBackend):
         assert fd_config.moba_attention_config is not None, "moba_attention_config is None"
         self.block_size = fd_config.parallel_config.block_size
         self.max_seq_len = fd_config.parallel_config.max_model_len
-        self.max_num_seqs = fd_config.parallel_config.max_num_seqs
+        self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
         self.kv_num_heads = kv_num_heads
         self.num_heads = num_heads
         self.head_dim = fd_config.model_config.head_dim
@@ -86,7 +86,7 @@ class GCUFlashAttnBackend(AttentionBackend):
         self.attention_metadata: GCUFlashAttnMetadata = None
         self.block_size = fd_config.cache_config.block_size
         self.max_seq_len = fd_config.parallel_config.max_model_len
-        self.max_num_seqs = fd_config.parallel_config.max_num_seqs
+        self.max_num_seqs = fd_config.scheduler_config.max_num_seqs

         self.causal = getattr(fd_config.model_config, "causal", True)
@@ -84,7 +84,7 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
         self.attention_metadata: GCUMemEfficientAttnMetadata = None
         self.block_size = fd_config.cache_config.block_size
         self.max_seq_len = fd_config.parallel_config.max_model_len
-        self.max_num_seqs = fd_config.parallel_config.max_num_seqs
+        self.max_num_seqs = fd_config.scheduler_config.max_num_seqs

         self.causal = getattr(fd_config.model_config, "causal", True)
@@ -221,7 +221,7 @@ class Sampler(nn.Layer):
         ):
             early_stopper_cls = get_early_stopper_cls_from_stragegy(fd_config.early_stop_config.strategy)
             self.early_stopper = early_stopper_cls()
-            self.early_stopper.initialize(fd_config.parallel_config.max_num_seqs, fd_config.early_stop_config)
+            self.early_stopper.initialize(fd_config.scheduler_config.max_num_seqs, fd_config.early_stop_config)

     def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
         """set reasoning parser"""
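The early stopper takes max_num_seqs because it tracks state per batch slot. A minimal sketch of that sizing pattern; the class and methods here are hypothetical stand-ins, not FastDeploy's actual early-stopper API:

import numpy as np


class SketchEarlyStopper:
    """Hypothetical stand-in: keeps one row of state per schedulable slot."""

    def initialize(self, max_num_seqs: int, window: int = 8) -> None:
        # One score window per sequence slot, sized by the scheduler limit.
        self.trunc_scores = np.zeros((max_num_seqs, window), dtype=np.float32)

    def reset_slot(self, slot: int) -> None:
        # Clear a slot's history when its sequence finishes.
        self.trunc_scores[slot] = 0.0


stopper = SketchEarlyStopper()
stopper.initialize(max_num_seqs=256)
stopper.reset_slot(0)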