[FDConfig]Remove max_num_batched_tokens/max_num_seqs in parallel config (#4116)

* remove max_num_batched_tokens in parallel config

* remove max_num_seqs

* update test case

* fix test

* fix

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
YuanRisheng
2025-09-17 10:43:35 +08:00
committed by GitHub
parent c01a756912
commit 2e9e53ff7e
30 changed files with 169 additions and 131 deletions

View File

@@ -152,7 +152,7 @@ class Attention(nn.Layer):
self.cache_k_block_means = paddle.zeros(
[
fd_config.parallel_config.max_num_seqs,
fd_config.scheduler_config.max_num_seqs,
moba_max_seq_length // moba_block_size,
self.kv_num_heads,
self.head_dim,

View File

@@ -156,7 +156,7 @@ class FlashAttentionBackend(AttentionBackend):
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
self.zero_seq_enc_lens_for_decode = paddle.zeros(
shape=[fd_config.parallel_config.max_num_seqs, 1], dtype=paddle.int32
shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype=paddle.int32
)
def get_attntion_meta(self):

View File

@@ -77,7 +77,7 @@ class MobaAttentionBackend(AttentionBackend):
assert fd_config.moba_attention_config is not None, "moba_attention_config is None"
self.block_size = fd_config.parallel_config.block_size
self.max_seq_len = fd_config.parallel_config.max_model_len
self.max_num_seqs = fd_config.parallel_config.max_num_seqs
self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
self.kv_num_heads = kv_num_heads
self.num_heads = num_heads
self.head_dim = fd_config.model_config.head_dim

View File

@@ -86,7 +86,7 @@ class GCUFlashAttnBackend(AttentionBackend):
self.attention_metadata: GCUFlashAttnMetadata = None
self.block_size = fd_config.cache_config.block_size
self.max_seq_len = fd_config.parallel_config.max_model_len
self.max_num_seqs = fd_config.parallel_config.max_num_seqs
self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
self.causal = getattr(fd_config.model_config, "causal", True)

View File

@@ -84,7 +84,7 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
self.attention_metadata: GCUMemEfficientAttnMetadata = None
self.block_size = fd_config.cache_config.block_size
self.max_seq_len = fd_config.parallel_config.max_model_len
self.max_num_seqs = fd_config.parallel_config.max_num_seqs
self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
self.causal = getattr(fd_config.model_config, "causal", True)

View File

@@ -221,7 +221,7 @@ class Sampler(nn.Layer):
):
early_stopper_cls = get_early_stopper_cls_from_stragegy(fd_config.early_stop_config.strategy)
self.early_stopper = early_stopper_cls()
self.early_stopper.initialize(fd_config.parallel_config.max_num_seqs, fd_config.early_stop_config)
self.early_stopper.initialize(fd_config.scheduler_config.max_num_seqs, fd_config.early_stop_config)
def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
"""set reasoning parser"""