mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[feat] support fa3 backend for pd disaggregated (#2695)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* support fa3 backend run in pd disaggregated * support fa3 backend run in pd disaggregated * support fa3 backend run in pd disaggregated * support fa3 backend run in pd disaggregated * delete use_fast_ffn
This commit is contained in:
@@ -90,7 +90,8 @@ class MTPProposer(Proposer):
|
||||
|
||||
self.model = get_model_from_loader(self.cfg)
|
||||
|
||||
def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
|
||||
def dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
|
||||
expected_decode_len: int):
|
||||
"""Set dummy prefill inputs to model_inputs"""
|
||||
max_dec_len = expected_decode_len + 1
|
||||
self.num_gpu_blocks = self.parallel_config.max_block_num
|
||||
@@ -130,10 +131,10 @@ class MTPProposer(Proposer):
|
||||
self.cache_kvs = {}
|
||||
|
||||
cache_type = self.parallel_config.dtype
|
||||
|
||||
if (self.quant_config and
|
||||
hasattr(self.quant_config, "kv_cache_quant_type") and
|
||||
self.quant_config.kv_cache_quant_type is not None):
|
||||
|
||||
if (self.quant_config
|
||||
and hasattr(self.quant_config, "kv_cache_quant_type")
|
||||
and self.quant_config.kv_cache_quant_type is not None):
|
||||
cache_type = 'uint8'
|
||||
|
||||
# Get kv cache shape
|
||||
@@ -190,8 +191,7 @@ class MTPProposer(Proposer):
|
||||
head_dim = self.model_config.head_dim
|
||||
|
||||
# Get the attention backend
|
||||
attn_cls = get_attention_backend(
|
||||
self.parallel_config.attention_backend)
|
||||
attn_cls = get_attention_backend()
|
||||
attn_backend = attn_cls(
|
||||
self.cfg,
|
||||
kv_num_heads=self.model_config.kv_num_heads,
|
||||
@@ -200,8 +200,8 @@ class MTPProposer(Proposer):
|
||||
)
|
||||
if attn_backend is None:
|
||||
raise NotImplementedError(
|
||||
f"{ self.parallel_config.attention_backend} attention backend"
|
||||
" is not support by GPUModelRunner")
|
||||
"Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
|
||||
)
|
||||
self.attn_backends.append(attn_backend)
|
||||
|
||||
def clear_dummy_input(self):
|
||||
|
Reference in New Issue
Block a user