mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-27 04:46:16 +08:00
[Feature] block sparse attention (#3668)
* 支持稀疏attn * fix bug * code style * fix moba attn get kv shape * 修复a100编译 * codestyle * code style * code style * code style * fix conflict * 增加单侧 * code style * 增加eblite 加载时间 * fix bug * for ci * for ci * for ci * for ci * 支持mlp block size 128 * 增加小算子单测 * fix 单测 mlp * 将环境变量加入到config里面 * fix rollout config * 修复显存 * add test server * add test server * fix mlp 最后一层使用full attn
This commit is contained in:
@@ -684,6 +684,67 @@ class GraphOptimizationConfig:
|
||||
argument = self.use_cudagraph
|
||||
|
||||
|
||||
class MobaAttentionConfig:
|
||||
def __init__(
|
||||
self,
|
||||
args,
|
||||
):
|
||||
self.moba_encoder_top_k_left: int = None
|
||||
self.moba_encoder_top_k_right: int = None
|
||||
"The sparse topk of encoder attention is located at [moba_encoder_top_k_left, moba_encoder top_k_right]"
|
||||
self.moba_decoder_top_k_left: int = None
|
||||
self.moba_decoder_top_k_right: int = None
|
||||
"The sparse topk of decoder attention is located at [moba_decoder_top_k_left, moba_decoder top_k_right]"
|
||||
self.moba_use_encoder_seq_limit: int = None
|
||||
"When the number of encdoer token is less than moba_use_encoder_seq_limit, it is not sparse"
|
||||
self.moba_use_decoder_seq_limit: int = None
|
||||
"When the number of decdoer token is less than moba_use_decoder_seq_limit, it is not sparse"
|
||||
self.moba_block_size: int = 128
|
||||
self.mlp_weight_name: str = "moba_mlp_weight.safetensors"
|
||||
self.moba_max_seq_length: int = 128 * 1024
|
||||
if args is not None:
|
||||
for key, value in args.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
if self.moba_use_encoder_seq_limit is None and self.moba_encoder_top_k_left is not None:
|
||||
self.moba_use_encoder_seq_limit = self.moba_encoder_top_k_left * self.moba_block_size
|
||||
if self.moba_use_decoder_seq_limit is None and self.moba_decoder_top_k_left is not None:
|
||||
self.moba_use_decoder_seq_limit = self.moba_decoder_top_k_left * self.moba_block_size
|
||||
self.check_legality_parameters()
|
||||
|
||||
def check_legality_parameters(
|
||||
self,
|
||||
) -> None:
|
||||
if self.moba_encoder_top_k_left is not None:
|
||||
assert self.moba_encoder_top_k_left > 0, "moba_encoder_top_k_left must large than 0"
|
||||
|
||||
if self.moba_encoder_top_k_right is not None:
|
||||
assert self.moba_encoder_top_k_right > 0, "moba_encoder_top_k_right must large than 0"
|
||||
assert (
|
||||
self.moba_encoder_top_k_right >= self.moba_encoder_top_k_left
|
||||
), "moba_encoder_top_k_right must large than moba_encoder_top_k_left"
|
||||
|
||||
if self.moba_decoder_top_k_left is not None:
|
||||
assert self.moba_decoder_top_k_left > 0, "moba_decoder_top_k_left must large than 0"
|
||||
|
||||
if self.moba_decoder_top_k_right is not None:
|
||||
assert self.moba_decoder_top_k_right > 0, "moba_decoder_top_k_right must large than 0"
|
||||
assert (
|
||||
self.moba_decoder_top_k_right >= self.moba_decoder_top_k_left
|
||||
), "moba_decoder_top_k_right must large than moba_decoder_top_k_left"
|
||||
|
||||
if self.moba_use_encoder_seq_limit is not None and self.moba_encoder_top_k_left is not None:
|
||||
assert self.moba_use_encoder_seq_limit >= self.moba_encoder_top_k_left * self.moba_block_size
|
||||
if self.moba_use_decoder_seq_limit is not None and self.moba_decoder_top_k_left is not None:
|
||||
assert self.moba_use_decoder_seq_limit >= self.moba_decoder_top_k_left * self.moba_block_size
|
||||
|
||||
def to_json_string(self):
|
||||
"""
|
||||
Convert moba_attention_config to json string.
|
||||
"""
|
||||
return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
|
||||
|
||||
|
||||
class EarlyStopConfig:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1038,6 +1099,7 @@ class FDConfig:
|
||||
decoding_config: DecodingConfig = None,
|
||||
quant_config: QuantConfigBase = None,
|
||||
graph_opt_config: GraphOptimizationConfig = None,
|
||||
moba_attention_config: MobaAttentionConfig = None,
|
||||
speculative_config: SpeculativeConfig = None,
|
||||
tokenizer: str = None,
|
||||
max_model_len: int = 8192,
|
||||
@@ -1072,7 +1134,7 @@ class FDConfig:
|
||||
self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config
|
||||
self.decoding_config: DecodingConfig = decoding_config # type: ignore
|
||||
self.cache_config: CacheConfig = cache_config # type: ignore
|
||||
|
||||
self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config
|
||||
# Initialize cuda graph capture list
|
||||
if self.graph_opt_config.cudagraph_capture_sizes is None:
|
||||
self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
|
||||
|
Reference in New Issue
Block a user