Mirror of https://github.com/PaddlePaddle/FastDeploy.git, last synced 2025-12-24 13:28:13 +08:00.
[Cherry-Pick][Speculative Decoding][BugFix]Fix attention bug in spec decoding(#5460)
Co-authored-by: freeliuzc <23568094+freeliuzc@users.noreply.github.com>
This commit is contained in:
@@ -485,9 +485,6 @@ class LLMEngine:
         if self.cfg.scheduler_config.splitwise_role == "prefill":
             variables["FLAGS_fmt_write_cache_completed_signal"] = 1

-        if self.cfg.model_config.enable_mm:
-            variables["FLAGS_max_partition_size"] = 1024
-
         command_prefix = ""
         for k, v in variables.items():
             command_prefix += f"{k}={v} "
@@ -152,6 +152,9 @@ class AppendAttentionBackend(AttentionBackend):
         self.head_dim: int = fd_config.model_config.head_dim
         self.num_layers: int = fd_config.model_config.num_hidden_layers
         self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 1024))
+        # split kv still has bug in speculative decoding
+        if self.speculative_method is not None:
+            self.max_partition_size = self.max_seq_len
         self.encoder_block_shape_q: int = encoder_block_shape_q
         self.decoder_block_shape_q: int = decoder_block_shape_q
Reference in new issue · Block a user