Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-19 06:54:41 +08:00)
Revert "【FIX】Change the name of sparse attn from moba to plas (#3845)" (#4001)
This reverts commit e31c8f7336.
@@ -690,63 +690,63 @@ class GraphOptimizationConfig:
         argument = self.use_cudagraph
 
 
-class PlasAttentionConfig:
+class MobaAttentionConfig:
     def __init__(
         self,
         args,
     ):
-        self.plas_encoder_top_k_left: int = None
-        self.plas_encoder_top_k_right: int = None
-        "The sparse topk of encoder attention is located at [plas_encoder_top_k_left, plas_encoder top_k_right]"
-        self.plas_decoder_top_k_left: int = None
-        self.plas_decoder_top_k_right: int = None
-        "The sparse topk of decoder attention is located at [plas_decoder_top_k_left, plas_decoder top_k_right]"
-        self.plas_use_encoder_seq_limit: int = None
-        "When the number of encdoer token is less than plas_use_encoder_seq_limit, it is not sparse"
-        self.plas_use_decoder_seq_limit: int = None
-        "When the number of decdoer token is less than plas_use_decoder_seq_limit, it is not sparse"
-        self.plas_block_size: int = 128
-        self.mlp_weight_name: str = "plas_attention_mlp_weight.safetensors"
-        self.plas_max_seq_length: int = 128 * 1024
+        self.moba_encoder_top_k_left: int = None
+        self.moba_encoder_top_k_right: int = None
+        "The sparse topk of encoder attention is located at [moba_encoder_top_k_left, moba_encoder top_k_right]"
+        self.moba_decoder_top_k_left: int = None
+        self.moba_decoder_top_k_right: int = None
+        "The sparse topk of decoder attention is located at [moba_decoder_top_k_left, moba_decoder top_k_right]"
+        self.moba_use_encoder_seq_limit: int = None
+        "When the number of encdoer token is less than moba_use_encoder_seq_limit, it is not sparse"
+        self.moba_use_decoder_seq_limit: int = None
+        "When the number of decdoer token is less than moba_use_decoder_seq_limit, it is not sparse"
+        self.moba_block_size: int = 128
+        self.mlp_weight_name: str = "moba_mlp_weight.safetensors"
+        self.moba_max_seq_length: int = 128 * 1024
         if args is not None:
             for key, value in args.items():
                 if hasattr(self, key):
                     setattr(self, key, value)
-        if self.plas_use_encoder_seq_limit is None and self.plas_encoder_top_k_left is not None:
-            self.plas_use_encoder_seq_limit = self.plas_encoder_top_k_left * self.plas_block_size
-        if self.plas_use_decoder_seq_limit is None and self.plas_decoder_top_k_left is not None:
-            self.plas_use_decoder_seq_limit = self.plas_decoder_top_k_left * self.plas_block_size
+        if self.moba_use_encoder_seq_limit is None and self.moba_encoder_top_k_left is not None:
+            self.moba_use_encoder_seq_limit = self.moba_encoder_top_k_left * self.moba_block_size
+        if self.moba_use_decoder_seq_limit is None and self.moba_decoder_top_k_left is not None:
+            self.moba_use_decoder_seq_limit = self.moba_decoder_top_k_left * self.moba_block_size
         self.check_legality_parameters()
 
     def check_legality_parameters(
         self,
     ) -> None:
-        if self.plas_encoder_top_k_left is not None:
-            assert self.plas_encoder_top_k_left > 0, "plas_encoder_top_k_left must large than 0"
+        if self.moba_encoder_top_k_left is not None:
+            assert self.moba_encoder_top_k_left > 0, "moba_encoder_top_k_left must large than 0"
 
-        if self.plas_encoder_top_k_right is not None:
-            assert self.plas_encoder_top_k_right > 0, "plas_encoder_top_k_right must large than 0"
+        if self.moba_encoder_top_k_right is not None:
+            assert self.moba_encoder_top_k_right > 0, "moba_encoder_top_k_right must large than 0"
             assert (
-                self.plas_encoder_top_k_right >= self.plas_encoder_top_k_left
-            ), "plas_encoder_top_k_right must large than plas_encoder_top_k_left"
+                self.moba_encoder_top_k_right >= self.moba_encoder_top_k_left
+            ), "moba_encoder_top_k_right must large than moba_encoder_top_k_left"
 
-        if self.plas_decoder_top_k_left is not None:
-            assert self.plas_decoder_top_k_left > 0, "plas_decoder_top_k_left must large than 0"
+        if self.moba_decoder_top_k_left is not None:
+            assert self.moba_decoder_top_k_left > 0, "moba_decoder_top_k_left must large than 0"
 
-        if self.plas_decoder_top_k_right is not None:
-            assert self.plas_decoder_top_k_right > 0, "plas_decoder_top_k_right must large than 0"
+        if self.moba_decoder_top_k_right is not None:
+            assert self.moba_decoder_top_k_right > 0, "moba_decoder_top_k_right must large than 0"
             assert (
-                self.plas_decoder_top_k_right >= self.plas_decoder_top_k_left
-            ), "plas_decoder_top_k_right must large than plas_decoder_top_k_left"
+                self.moba_decoder_top_k_right >= self.moba_decoder_top_k_left
+            ), "moba_decoder_top_k_right must large than moba_decoder_top_k_left"
 
-        if self.plas_use_encoder_seq_limit is not None and self.plas_encoder_top_k_left is not None:
-            assert self.plas_use_encoder_seq_limit >= self.plas_encoder_top_k_left * self.plas_block_size
-        if self.plas_use_decoder_seq_limit is not None and self.plas_decoder_top_k_left is not None:
-            assert self.plas_use_decoder_seq_limit >= self.plas_decoder_top_k_left * self.plas_block_size
+        if self.moba_use_encoder_seq_limit is not None and self.moba_encoder_top_k_left is not None:
+            assert self.moba_use_encoder_seq_limit >= self.moba_encoder_top_k_left * self.moba_block_size
+        if self.moba_use_decoder_seq_limit is not None and self.moba_decoder_top_k_left is not None:
+            assert self.moba_use_decoder_seq_limit >= self.moba_decoder_top_k_left * self.moba_block_size
 
     def to_json_string(self):
         """
-        Convert plas_attention_config to json string.
+        Convert moba_attention_config to json string.
         """
         return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
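For context, a minimal usage sketch of the reverted MobaAttentionConfig: the field names and the top_k_left * block_size fallback mirror the hunk above, while the concrete top-k values are illustrative assumptions, not repository defaults.

    from fastdeploy.config import MobaAttentionConfig

    # Illustrative values only; any key not passed keeps the defaults shown above.
    cfg = MobaAttentionConfig(
        {
            "moba_encoder_top_k_left": 2,
            "moba_encoder_top_k_right": 4,
            "moba_decoder_top_k_left": 2,
            "moba_decoder_top_k_right": 8,
        }
    )
    # moba_use_encoder_seq_limit falls back to 2 * 128 = 256, and
    # check_legality_parameters() enforces 0 < top_k_left <= top_k_right.
    print(cfg.moba_use_encoder_seq_limit)  # 256
    print(cfg.to_json_string())            # only the non-None fields are serialized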
@@ -1105,7 +1105,7 @@ class FDConfig:
         decoding_config: DecodingConfig = None,
         quant_config: QuantConfigBase = None,
         graph_opt_config: GraphOptimizationConfig = None,
-        plas_attention_config: PlasAttentionConfig = None,
+        moba_attention_config: MobaAttentionConfig = None,
         speculative_config: SpeculativeConfig = None,
         tokenizer: str = None,
         max_model_len: int = 8192,
@@ -1140,7 +1140,7 @@ class FDConfig:
         self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config
         self.decoding_config: DecodingConfig = decoding_config  # type: ignore
         self.cache_config: CacheConfig = cache_config  # type: ignore
-        self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
+        self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config
         # Initialize cuda graph capture list
         if self.graph_opt_config.cudagraph_capture_sizes is None:
             self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
@@ -28,9 +28,9 @@ from fastdeploy.config import (
     FDConfig,
     GraphOptimizationConfig,
     LoadConfig,
+    MobaAttentionConfig,
     ModelConfig,
     ParallelConfig,
-    PlasAttentionConfig,
     SpeculativeConfig,
     TaskOption,
 )
@@ -342,9 +342,9 @@ class EngineArgs:
     """
     Configuration for graph optimization backend execution.
     """
-    plas_attention_config: Optional[Dict[str, Any]] = None
+    moba_attention_config: Optional[Dict[str, Any]] = None
     """
-    Configuration for plas attention.
+    Configuration for moba attention.
     """
 
     enable_logprob: bool = False
@@ -559,9 +559,9 @@ class EngineArgs:
             help="",
         )
         model_group.add_argument(
-            "--plas-attention-config",
+            "--moba-attention-config",
             type=json.loads,
-            default=EngineArgs.plas_attention_config,
+            default=EngineArgs.moba_attention_config,
             help="",
         )
         model_group.add_argument(
@@ -959,17 +959,17 @@ class EngineArgs:
                 graph_optimization_args[k] = v
         return GraphOptimizationConfig(graph_optimization_args)
 
-    def create_plas_attention_config(self) -> PlasAttentionConfig:
+    def create_moba_attention_config(self) -> MobaAttentionConfig:
         """
-        Create and retuan a PlasAttentionConfig object based on the current settings.
+        Create and retuan a MobaAttentionConfig object based on the current settings.
         """
         attention_args = asdict(self)
-        if self.plas_attention_config is not None:
-            for k, v in self.plas_attention_config.items():
+        if self.moba_attention_config is not None:
+            for k, v in self.moba_attention_config.items():
                 attention_args[k] = v
-            return PlasAttentionConfig(attention_args)
+            return MobaAttentionConfig(attention_args)
         else:
-            return PlasAttentionConfig(None)
+            return MobaAttentionConfig(None)
 
     def create_early_stop_config(self) -> EarlyStopConfig:
         """
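A sketch of how the JSON given to --moba-attention-config reaches this method: the flag name and json.loads parsing come from the hunks above, while the payload below is an assumed example.

    import json

    # What argparse stores after `--moba-attention-config '<json>'` (type=json.loads).
    moba_attention_config = json.loads(
        '{"moba_encoder_top_k_left": 2, "moba_encoder_top_k_right": 4}'
    )
    # create_moba_attention_config() then overlays these keys onto asdict(self)
    # and returns MobaAttentionConfig(attention_args); keys the config class does
    # not define are ignored because __init__ only sets attributes it already has.
    print(moba_attention_config["moba_encoder_top_k_left"])  # 2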
@@ -1025,7 +1025,7 @@ class EngineArgs:
         scheduler_cfg = self.create_scheduler_config()
         graph_opt_cfg = self.create_graph_optimization_config()
         graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
-        plas_attention_config = self.create_plas_attention_config()
+        moba_attention_config = self.create_moba_attention_config()
 
         early_stop_cfg = self.create_early_stop_config()
         early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
@@ -1063,7 +1063,7 @@ class EngineArgs:
             max_long_partial_prefills=self.max_long_partial_prefills,
             long_prefill_token_threshold=self.long_prefill_token_threshold,
             graph_opt_config=graph_opt_cfg,
-            plas_attention_config=plas_attention_config,
+            moba_attention_config=moba_attention_config,
             guided_decoding_backend=self.guided_decoding_backend,
             disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
             early_stop_config=early_stop_cfg,
@@ -493,7 +493,7 @@ class LLMEngine:
             f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
             f" --reasoning_parser {self.cfg.reasoning_parser}"
             f" --load_choices {self.cfg.load_config.load_choices}"
-            f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'"
+            f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
             f" --ips {ips}"
         )
@@ -20,7 +20,7 @@ from .block_multihead_attn_backend import BlockAttentionBackend
 from .flash_attn_backend import FlashAttentionBackend
 from .iluvatar_attn_backend import IluvatarAttnBackend
 from .mla_attention_backend import MLAAttentionBackend
-from .moba_attention_backend import PlasAttentionBackend
+from .moba_attention_backend import MobaAttentionBackend
 from .native_paddle_backend import PaddleNativeAttnBackend
 from .xpu_attn_backend import XPUAttentionBackend
 
@@ -35,5 +35,5 @@ __all__ = [
     "IluvatarAttnBackend",
     "BlockAttentionBackend",
     "Attention",
-    "PlasAttentionBackend",
+    "MobaAttentionBackend",
 ]
@@ -119,19 +119,19 @@ class Attention(nn.Layer):
         self.init_weight()
 
         if (
-            fd_config.plas_attention_config is not None
-            and fd_config.plas_attention_config.plas_encoder_top_k_left is not None
-            and fd_config.plas_attention_config.plas_encoder_top_k_right is not None
-            and fd_config.plas_attention_config.plas_decoder_top_k_left is not None
-            and fd_config.plas_attention_config.plas_decoder_top_k_right is not None
+            fd_config.moba_attention_config is not None
+            and fd_config.moba_attention_config.moba_encoder_top_k_left is not None
+            and fd_config.moba_attention_config.moba_encoder_top_k_right is not None
+            and fd_config.moba_attention_config.moba_decoder_top_k_left is not None
+            and fd_config.moba_attention_config.moba_decoder_top_k_right is not None
         ):
             mlp_weight_path = os.path.join(
-                fd_config.model_config.model, fd_config.plas_attention_config.mlp_weight_name
+                fd_config.model_config.model, fd_config.moba_attention_config.mlp_weight_name
             )
-            self.plas_use_mlp = mlp_weight_path is not None and os.path.exists(mlp_weight_path)
-            plas_block_size = fd_config.plas_attention_config.plas_block_size
-            plas_max_seq_length = fd_config.plas_attention_config.plas_max_seq_length
-            if self.plas_use_mlp:
+            self.moba_use_mlp = mlp_weight_path is not None and os.path.exists(mlp_weight_path)
+            moba_block_size = fd_config.moba_attention_config.moba_block_size
+            moba_max_seq_length = fd_config.moba_attention_config.moba_max_seq_length
+            if self.moba_use_mlp:
                 mlp_weight = {}
                 with safe_open(mlp_weight_path, framework="np", device="cpu") as f:
                     for key_name in f.keys():
@@ -148,12 +148,12 @@ class Attention(nn.Layer):
                         * self.kv_num_heads : (fd_config.parallel_config.tensor_parallel_rank + 1)
                         * self.kv_num_heads
                     ]
-                    assert self.attn_gate_weight.shape[1] % plas_block_size == 0
+                    assert self.attn_gate_weight.shape[1] % moba_block_size == 0
 
             self.cache_k_block_means = paddle.zeros(
                 [
                     fd_config.parallel_config.max_num_seqs,
-                    plas_max_seq_length // plas_block_size,
+                    moba_max_seq_length // moba_block_size,
                     self.kv_num_heads,
                     self.head_dim,
                 ],
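A back-of-envelope check of the cache_k_block_means shape above, assuming the MobaAttentionConfig defaults shown earlier in this diff (moba_max_seq_length = 128 * 1024, moba_block_size = 128):

    moba_max_seq_length = 128 * 1024
    moba_block_size = 128
    # Second dimension of cache_k_block_means: one mean vector per key block.
    blocks_per_seq = moba_max_seq_length // moba_block_size
    print(blocks_per_seq)  # 1024
    # Full shape: [max_num_seqs, 1024, kv_num_heads, head_dim]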
@@ -39,7 +39,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
 
 
 @dataclass
-class PlasAttentionMetadata(AttentionMetadata):
+class MobaAttentionMetadata(AttentionMetadata):
     """
     AppendAttentionMetadata
     """
@@ -54,7 +54,7 @@ class PlasAttentionMetadata(AttentionMetadata):
     max_dec_len_this_time: int = 0
 
 
-class PlasAttentionBackend(AttentionBackend):
+class MobaAttentionBackend(AttentionBackend):
     """
     The backend class that uses paddle native attention implementation.
     Which is used only for testing purpose.
@@ -70,11 +70,11 @@ class PlasAttentionBackend(AttentionBackend):
         decoder_block_shape_q: int = -1,
     ) -> None:
         """
-        PlasAttentionBackend __init__
+        MobaAttentionBackend __init__
         """
         super().__init__()
-        self.attention_metadata: PlasAttentionMetadata = None
-        assert fd_config.plas_attention_config is not None, "plas_attention_config is None"
+        self.attention_metadata: MobaAttentionMetadata = None
+        assert fd_config.moba_attention_config is not None, "moba_attention_config is None"
         self.block_size = fd_config.parallel_config.block_size
         self.max_seq_len = fd_config.parallel_config.max_model_len
         self.max_num_seqs = fd_config.parallel_config.max_num_seqs
@@ -83,18 +83,18 @@ class PlasAttentionBackend(AttentionBackend):
         self.head_dim = fd_config.model_config.head_dim
         self.num_layers: int = fd_config.model_config.num_hidden_layers
         self.attn_block_m = 128
-        self.plas_block_size = fd_config.plas_attention_config.plas_block_size
-        self.plas_encoder_top_k_left = int(fd_config.plas_attention_config.plas_encoder_top_k_left)
-        self.plas_encoder_top_k_right = int(fd_config.plas_attention_config.plas_encoder_top_k_right)
-        self.plas_use_encoder_seq_limit = int(fd_config.plas_attention_config.plas_use_encoder_seq_limit)
-        self.plas_decoder_top_k_left = int(fd_config.plas_attention_config.plas_decoder_top_k_left)
-        self.plas_decoder_top_k_right = int(fd_config.plas_attention_config.plas_decoder_top_k_right)
-        self.plas_use_decoder_seq_limit = int(fd_config.plas_attention_config.plas_use_decoder_seq_limit)
-        self.plas_max_seq_length = fd_config.plas_attention_config.plas_max_seq_length
+        self.moba_block_size = fd_config.moba_attention_config.moba_block_size
+        self.moba_encoder_top_k_left = int(fd_config.moba_attention_config.moba_encoder_top_k_left)
+        self.moba_encoder_top_k_right = int(fd_config.moba_attention_config.moba_encoder_top_k_right)
+        self.moba_use_encoder_seq_limit = int(fd_config.moba_attention_config.moba_use_encoder_seq_limit)
+        self.moba_decoder_top_k_left = int(fd_config.moba_attention_config.moba_decoder_top_k_left)
+        self.moba_decoder_top_k_right = int(fd_config.moba_attention_config.moba_decoder_top_k_right)
+        self.moba_use_decoder_seq_limit = int(fd_config.moba_attention_config.moba_use_decoder_seq_limit)
+        self.moba_max_seq_length = fd_config.moba_attention_config.moba_max_seq_length
 
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Init the metadata for a forward pass."""
-        metadata = PlasAttentionMetadata()
+        metadata = MobaAttentionMetadata()
         metadata._dtype = paddle.get_default_dtype()
         metadata.cu_seq_q_pack, metadata.cu_seqlens_k, metadata.q_pack_tokens = get_cur_cu_seq_len_k(
             forward_meta.seq_lens_encoder,
@@ -116,7 +116,7 @@ class PlasAttentionBackend(AttentionBackend):
             [k_token_num + self.attn_block_m, self.kv_num_heads * self.head_dim], dtype=metadata._dtype
         )
         self.attention_metadata = metadata
-        assert self.max_seq_len <= self.plas_max_seq_length
+        assert self.max_seq_len <= self.moba_max_seq_length
 
     def get_kv_cache_shape(
         self,
@@ -186,13 +186,13 @@ class PlasAttentionBackend(AttentionBackend):
             self.max_seq_len,
             attention_metadata.max_enc_len_this_time,
             attention_metadata.max_dec_len_this_time,
-            self.plas_encoder_top_k_left,
-            self.plas_encoder_top_k_right,
-            self.plas_use_encoder_seq_limit,
-            self.plas_decoder_top_k_left,
-            self.plas_decoder_top_k_right,
-            self.plas_use_decoder_seq_limit,
-            layer.plas_use_mlp,
+            self.moba_encoder_top_k_left,
+            self.moba_encoder_top_k_right,
+            self.moba_use_encoder_seq_limit,
+            self.moba_decoder_top_k_left,
+            self.moba_decoder_top_k_right,
+            self.moba_use_decoder_seq_limit,
+            layer.moba_use_mlp,
             getattr(layer, "cache_quant_type_str", "none"),
         )[0]
         return out
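A quick sanity check of the constraint the backend asserts above, using values taken from elsewhere in this diff (max_model_len defaults to 8192 in FDConfig; moba_max_seq_length defaults to 128 * 1024 in MobaAttentionConfig):

    max_model_len = 8192              # FDConfig default shown earlier in this diff
    moba_max_seq_length = 128 * 1024  # MobaAttentionConfig default shown earlier
    # Mirrors the assert in init_attention_metadata: the served context length
    # must fit within the configured sparse-attention range.
    assert max_model_len <= moba_max_seq_length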
@@ -26,7 +26,7 @@ class _Backend(enum.Enum):
     MLA_ATTN = enum.auto()
     FLASH_ATTN = enum.auto()
     BLOCK_ATTN = enum.auto()
-    PLAS_ATTN = enum.auto()
+    MOBA_ATTN = enum.auto()
 
 
 class Platform:
@@ -64,9 +64,9 @@ class CUDAPlatform(Platform):
         elif selected_backend == _Backend.FLASH_ATTN:
             logger.info("Using FLASH ATTN backend.")
             return "fastdeploy.model_executor.layers.attention.FlashAttentionBackend"
-        elif selected_backend == _Backend.PLAS_ATTN:
-            logger.info("Using PLAS ATTN backend.")
-            return "fastdeploy.model_executor.layers.attention.PlasAttentionBackend"
+        elif selected_backend == _Backend.MOBA_ATTN:
+            logger.info("Using MOBA ATTN backend.")
+            return "fastdeploy.model_executor.layers.attention.MobaAttentionBackend"
         else:
             raise ValueError(
                 "Invalid attention backend you specified.\n"
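A self-contained sketch of the dispatch pattern in the hunk above; resolve_attn_backend is a hypothetical stand-in for the CUDAPlatform method that contains this elif chain, and only the MOBA branch is reproduced:

    from enum import Enum, auto

    class _Backend(Enum):
        MLA_ATTN = auto()
        FLASH_ATTN = auto()
        BLOCK_ATTN = auto()
        MOBA_ATTN = auto()

    def resolve_attn_backend(selected_backend: _Backend) -> str:
        # Mirrors the reverted branch: the backend is returned as an import path.
        if selected_backend == _Backend.MOBA_ATTN:
            return "fastdeploy.model_executor.layers.attention.MobaAttentionBackend"
        raise ValueError("Invalid attention backend you specified.\n")

    print(resolve_attn_backend(_Backend.MOBA_ATTN))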
@@ -59,7 +59,7 @@ class RolloutModelConfig:
         graph_optimization_config: str = None,
         early_stop_config: str = None,
         local_rank: int = 0,
-        plas_attention_config: str = None,
+        moba_attention_config: str = None,
         data_parallel_size: int = 1,
     ):
         # Required parameters
@@ -106,7 +106,7 @@ class RolloutModelConfig:
         self.local_rank = local_rank
         self.early_stop_config = early_stop_config
         self.ips = None
-        self.plas_attention_config = plas_attention_config
+        self.moba_attention_config = moba_attention_config
 
     def __str__(self):
         return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
@@ -34,9 +34,9 @@ from fastdeploy.config import (
     FDConfig,
     GraphOptimizationConfig,
     LoadConfig,
+    MobaAttentionConfig,
     ModelConfig,
     ParallelConfig,
-    PlasAttentionConfig,
     SpeculativeConfig,
 )
 from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
@@ -561,10 +561,10 @@ def parse_args():
         help="Configuration of Graph optimization backend.",
     )
     parser.add_argument(
-        "--plas_attention_config",
+        "--moba_attention_config",
         type=json.loads,
         default=None,
-        help="Configation of plas attention.",
+        help="Configation of moba attention.",
     )
     parser.add_argument(
         "--guided_decoding_backend",
@@ -677,7 +677,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
 
     graph_opt_config = GraphOptimizationConfig(args.graph_optimization_config)
 
-    plas_attention_config = PlasAttentionConfig(args.plas_attention_config)
+    moba_attention_config = MobaAttentionConfig(args.moba_attention_config)
 
     early_stop_config = EarlyStopConfig(args.early_stop_config)
 
@@ -777,7 +777,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         cache_config=cache_config,
         engine_worker_queue_port=args.engine_worker_queue_port,
         ips=args.ips,
-        plas_attention_config=plas_attention_config,
+        moba_attention_config=moba_attention_config,
     )
     update_fd_config_for_mm(fd_config)