mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[TSP] Support qwen3 moe tsp + cudagraph (#4871)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* support qwen3_moe tsp mode
* fix
* fix
* update
* update
* update
* fix
* support external_rmsnorm
* update
* fix
@@ -40,6 +40,8 @@ When using FastDeploy to deploy models (including offline inference and service
| ```use_cudagraph``` | `bool` | __[DEPRECATED]__ CUDAGraph is enabled by default since version 2.3. It is recommended to read [graph_optimization.md](./features/graph_optimization.md) carefully before enabling it. |
| ```graph_optimization_config``` | `dict[str]` | Configures parameters related to computation graph optimization. The default value is `'{"use_cudagraph":true, "graph_opt_level":0}'`. See [graph_optimization.md](./features/graph_optimization.md) for a detailed description. |
| ```disable_custom_all_reduce``` | `bool` | Disable Custom all-reduce, default: False |
| ```use_internode_ll_two_stage``` | `bool` | Use two-stage communication in DeepEP MoE, default: False |
| ```disable_sequence_parallel_moe``` | `bool` | Disable sequence-parallel MoE (see the usage sketch below this table), default: False |
| ```splitwise_role``` | `str` | Whether to enable splitwise inference, default value: mixed, supported parameters: ["mixed", "decode", "prefill"] |
| ```innode_prefill_ports``` | `str` | Internal engine startup ports for prefill instances (only required for single-machine PD separation), default: None |
| ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `off`, default: `off` |
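For orientation, here is a minimal offline-inference sketch of how the parameters documented above might be passed. The import path, model id, and the assumption that these options are accepted as keyword arguments are illustrative only and are not taken from this diff.

```python
# Hypothetical usage sketch of the deployment parameters documented above.
from fastdeploy import LLM  # assumed entry point

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",  # placeholder model id
    graph_optimization_config={"use_cudagraph": True, "graph_opt_level": 0},
    use_internode_ll_two_stage=False,
    disable_sequence_parallel_moe=False,  # keep sequence-parallel MoE enabled (the default)
    guided_decoding_backend="off",
)
```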
@@ -38,6 +38,8 @@
| ```use_cudagraph``` | `bool` | __[DEPRECATED]__ CUDAGraph is enabled by default since version 2.3; see [graph_optimization.md](./features/graph_optimization.md) for details |
| ```graph_optimization_config``` | `dict[str]` | Configures parameters related to computation graph optimization; the default value is `'{"use_cudagraph":true, "graph_opt_level":0}'`; see [graph_optimization.md](./features/graph_optimization.md) for details |
| ```disable_custom_all_reduce``` | `bool` | Disable Custom all-reduce, default: False |
| ```use_internode_ll_two_stage``` | `bool` | Whether to use two-stage communication in DeepEP MoE, default: False |
| ```disable_sequence_parallel_moe``` | `bool` | Disable the sequence-parallel optimization in TP+EP, default: False |
| ```splitwise_role``` | `str` | Whether to enable splitwise inference, default value: mixed, supported parameters: ["mixed", "decode", "prefill"] |
| ```innode_prefill_ports``` | `str` | Internal engine startup ports for prefill instances (only required for single-machine PD separation), default: None |
| ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `off`, default: `off` |
@@ -307,8 +307,8 @@ class ModelConfig:
Read configuration information from environment variables and update the object's attributes.
If an attribute is not present or is an empty string in the environment variables, use the default value.
"""
self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
self.stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
self.max_stop_seqs_num = envs.FD_MAX_STOP_SEQS_NUM
self.stop_seqs_max_len = envs.FD_STOP_SEQS_MAX_LEN

def reset_config_value(key, value):
if not hasattr(self, key.lower()):
@@ -548,6 +548,8 @@ class ParallelConfig:
self.do_profile: bool = False
# Use internode_ll_two_stage or not
self.use_internode_ll_two_stage: bool = False
# disable sequence parallel moe
self.disable_sequence_parallel_moe: bool = False

self.pod_ip: str = None
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
@@ -577,14 +579,14 @@ class ParallelConfig:
else:
self.pd_disaggregation_mode = "None"

# ep+tp strategy: "all_reduce" or "all_to_all"
# all_reduce: qkv_linear + attn + out_linear + allreduce
# all_to_all: allgather + qkv_linear + attn + all2all + out_linear
self.ep_tp_strategy = envs.FD_EP_TP_STRATEGY
assert self.ep_tp_strategy in [
"all_reduce",
"all_to_all",
], f"FD_EP_TP_STRATEGY: '{self.ep_tp_strategy}' is not supported, only supports 'all_reduce' or 'all_to_all'."
# disable_sequence_parallel_moe: qkv_linear + attn + out_linear + allreduce
# use_sequence_parallel_moe: allgather + qkv_linear + attn + all2all + out_linear
self.use_sequence_parallel_moe = (
(not self.disable_sequence_parallel_moe)
and self.expert_parallel_size > 1
and self.tensor_parallel_size > 1
)
logger.info(f"use_sequence_parallel_moe: {self.use_sequence_parallel_moe}")

def set_communicate_group(self):
# different tp group id
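A standalone restatement of the gating introduced in this hunk, handy for sanity-checking when sequence-parallel MoE actually turns on (an illustrative sketch, not the ParallelConfig class itself):

```python
# Minimal sketch of the use_sequence_parallel_moe derivation shown above.
def resolve_use_sequence_parallel_moe(
    disable_sequence_parallel_moe: bool,
    expert_parallel_size: int,
    tensor_parallel_size: int,
) -> bool:
    # Enabled by default, but only meaningful when both EP and TP are active.
    return (
        (not disable_sequence_parallel_moe)
        and expert_parallel_size > 1
        and tensor_parallel_size > 1
    )

assert resolve_use_sequence_parallel_moe(False, 2, 2) is True   # EP + TP: on
assert resolve_use_sequence_parallel_moe(True, 2, 2) is False   # user opted out
assert resolve_use_sequence_parallel_moe(False, 1, 8) is False  # no expert parallelism
```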
@@ -240,7 +240,7 @@ class EngineArgs:
disable_custom_all_reduce: bool = False
"""
Flag to enable the custom all-reduce kernel.
Flag to disable the custom all-reduce kernel.
"""

use_internode_ll_two_stage: bool = False
@@ -248,6 +248,19 @@ class EngineArgs:
Flag to use the internode_ll_two_stage kernel.
"""

disable_sequence_parallel_moe: bool = False
"""
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens result in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# This optimization is enabled by default, and can be disabled by using this flag.
"""

engine_worker_queue_port: str = "0"
"""
Port for worker queue communication.
@@ -766,6 +779,12 @@ class EngineArgs:
default=EngineArgs.use_internode_ll_two_stage,
help="Flag to use the internode_ll_two_stage kernel.",
)
parallel_group.add_argument(
"--disable-sequence-parallel-moe",
action="store_true",
default=EngineArgs.disable_sequence_parallel_moe,
help="Flag to disable the sequence parallel moe.",
)
parallel_group.add_argument(
"--max-num-seqs",
type=int,
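A self-contained argparse mirror of the flag wiring above, so the behaviour of the new switch can be tried in isolation (an illustrative sketch, not the FastDeploy parser):

```python
import argparse

# Standalone mirror of the --disable-sequence-parallel-moe flag added above.
parser = argparse.ArgumentParser()
parallel_group = parser.add_argument_group("Parallel Configuration")
parallel_group.add_argument(
    "--disable-sequence-parallel-moe",
    action="store_true",
    default=False,
    help="Flag to disable the sequence parallel MoE.",
)

args = parser.parse_args(["--disable-sequence-parallel-moe"])
assert args.disable_sequence_parallel_moe is True
```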
@@ -842,6 +842,8 @@ class AsyncLLMEngine:
"dynamic_load_weight": self.cfg.load_config.dynamic_load_weight,
"disable_any_whitespace": self.cfg.structured_outputs_config.disable_any_whitespace,
"disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce,
"use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage,
"disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
"enable_logprob": self.cfg.model_config.enable_logprob,
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
}
@@ -286,7 +286,7 @@ class LLMEngine:
if request.get("stop_seqs_len") is not None:
stop_seqs_len = request.get("stop_seqs_len")
max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
max_stop_seqs_num = envs.FD_MAX_STOP_SEQS_NUM
if len(stop_seqs_len) > max_stop_seqs_num:
error_msg = (
f"Length of stop ({stop_seqs_len}) exceeds the limit max_stop_seqs_num({max_stop_seqs_num})."
@@ -294,7 +294,7 @@ class LLMEngine:
)
llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=400)
stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
stop_seqs_max_len = envs.FD_STOP_SEQS_MAX_LEN
for single_stop_seq_len in stop_seqs_len:
if single_stop_seq_len > stop_seqs_max_len:
error_msg = (
@@ -568,6 +568,7 @@ class LLMEngine:
"disable_any_whitespace": self.cfg.structured_outputs_config.disable_any_whitespace,
"disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce,
"use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage,
"disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
"enable_logprob": self.cfg.model_config.enable_logprob,
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
}
@@ -235,7 +235,7 @@ class EngineClient:
if "stop_seqs_len" in task:
stop_seqs_len = task["stop_seqs_len"]
max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
max_stop_seqs_num = envs.FD_MAX_STOP_SEQS_NUM
if len(stop_seqs_len) > max_stop_seqs_num:
error_msg = (
f"Length of stop ({stop_seqs_len}) exceeds the limit max_stop_seqs_num({max_stop_seqs_num})."
@@ -243,7 +243,7 @@ class EngineClient:
)
api_server_logger.error(error_msg)
raise EngineError(error_msg, error_code=400)
stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
stop_seqs_max_len = envs.FD_STOP_SEQS_MAX_LEN
for single_stop_seq_len in stop_seqs_len:
if single_stop_seq_len > stop_seqs_max_len:
error_msg = (
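The two hunks above apply the same stop-sequence limit checks in LLMEngine and EngineClient; a condensed sketch of that validation, assuming the env limits are already integers as in the envs.py change below:

```python
def validate_stop_seqs(stop_seqs_len, max_stop_seqs_num: int, stop_seqs_max_len: int) -> None:
    # Reject requests with too many stop sequences.
    if len(stop_seqs_len) > max_stop_seqs_num:
        raise ValueError(
            f"Length of stop ({stop_seqs_len}) exceeds the limit max_stop_seqs_num({max_stop_seqs_num})."
        )
    # Reject any single stop sequence that is too long.
    for single_stop_seq_len in stop_seqs_len:
        if single_stop_seq_len > stop_seqs_max_len:
            raise ValueError(
                f"Stop sequence length {single_stop_seq_len} exceeds stop_seqs_max_len({stop_seqs_max_len})."
            )

validate_stop_seqs([3, 5], max_stop_seqs_num=5, stop_seqs_max_len=8)  # passes
```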
@@ -35,9 +35,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Model download cache directory.
"FD_MODEL_CACHE": lambda: os.getenv("FD_MODEL_CACHE", None),
# Maximum number of stop sequences.
"FD_MAX_STOP_SEQS_NUM": lambda: os.getenv("FD_MAX_STOP_SEQS_NUM", "5"),
"FD_MAX_STOP_SEQS_NUM": lambda: int(os.getenv("FD_MAX_STOP_SEQS_NUM", "5")),
# Maximum length of stop sequences.
"FD_STOP_SEQS_MAX_LEN": lambda: os.getenv("FD_STOP_SEQS_MAX_LEN", "8"),
"FD_STOP_SEQS_MAX_LEN": lambda: int(os.getenv("FD_STOP_SEQS_MAX_LEN", "8")),
# GPU devices that will be used. This is a string that
# split by comma, such as 0,1,2.
"CUDA_VISIBLE_DEVICES": lambda: os.getenv("CUDA_VISIBLE_DEVICES", None),
@@ -159,10 +159,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")),
"FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0")),
"FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")),
# ep+tp strategy: "all_reduce" or "all_to_all"
# all_reduce: qkv_linear + attn + out_linear + allreduce
# all_to_all: allgather + qkv_linear + attn + all2all + out_linear
"FD_EP_TP_STRATEGY": lambda: os.getenv("FD_EP_TP_STRATEGY", "all_reduce"),
}
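The effect of the envs.py change above is that FD_MAX_STOP_SEQS_NUM and FD_STOP_SEQS_MAX_LEN are converted to int in the registry itself, so call sites can drop their own int(...) wrappers; a minimal sketch of that pattern:

```python
import os

# Sketch of the registry pattern above: each entry is a lambda so the
# environment is read lazily, and the int() conversion now lives here.
environment_variables = {
    "FD_MAX_STOP_SEQS_NUM": lambda: int(os.getenv("FD_MAX_STOP_SEQS_NUM", "5")),
    "FD_STOP_SEQS_MAX_LEN": lambda: int(os.getenv("FD_STOP_SEQS_MAX_LEN", "8")),
}

max_stop_seqs_num = environment_variables["FD_MAX_STOP_SEQS_NUM"]()  # already an int
assert isinstance(max_stop_seqs_num, int)
```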
@@ -550,7 +550,6 @@ class DataProcessor(BaseDataProcessor):
tokenize=False,
split_special_tokens=False,
add_special_tokens=False,
return_tensors="pd",
**kwargs,
)
request["prompt_tokens"] = spliced_message
@@ -823,9 +823,7 @@ class RowParallelLinear(LinearBase):
self.hidden_size = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim
self.split_token = (
self.ep_size > 1
and self.tp_size > 1
and fd_config.parallel_config.ep_tp_strategy == "all_to_all"
fd_config.parallel_config.use_sequence_parallel_moe
and layer_id >= fd_config.model_config.moe_layer_start_index
and layer_id < fd_config.model_config.num_hidden_layers
)
@@ -853,7 +851,7 @@ class RowParallelLinear(LinearBase):
self.quant_method.create_weights(
self,
split_axis=0,
output_dim=False,
output_dim=None if self.split_token else False,
weight_loader=(
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
),
@@ -877,7 +875,7 @@ class RowParallelLinear(LinearBase):
paddle.distributed.alltoall(out, x, group=self.tp_group)
out.reshape_([self.tp_size, -1, x.shape[1]])
out = paddle.transpose(out, [1, 0, 2])
out.reshape_([x.shape[0] // self.tp_size, self.hidden_size])
out.reshape_([x.shape[0] // self.tp_size, self.input_size])
return out

def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
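A NumPy stand-in for the post-all-to-all reshuffle in the RowParallelLinear hunk above (a sketch of the layout transform, not the Paddle kernel): the received buffer holds tp_size contiguous blocks, one per peer rank, each carrying this rank's token slice with that peer's slice of the hidden dimension; the reshape/transpose/reshape restores a token-major layout with the full hidden width.

```python
import numpy as np

tp_size, local_tokens, hidden_per_rank = 2, 3, 4
recv = np.arange(tp_size * local_tokens * hidden_per_rank, dtype=np.float32).reshape(
    tp_size * local_tokens, hidden_per_rank
)

out = recv.reshape(tp_size, local_tokens, hidden_per_rank)  # one block per peer rank
out = out.transpose(1, 0, 2)                                # token-major across peers
out = out.reshape(local_tokens, tp_size * hidden_per_rank)  # full hidden per local token
assert out.shape == (local_tokens, tp_size * hidden_per_rank)
```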
@@ -137,7 +137,6 @@ class FusedMoE(nn.Layer):
self.ep_size = fd_config.parallel_config.expert_parallel_size
self.ep_rank = fd_config.parallel_config.expert_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group
self.ep_tp_strategy = self.fd_config.parallel_config.ep_tp_strategy
# NOTE(Zhenyu Li): just supports tp_size = 1 when ep_size > 1 in MOE now.
if self.ep_size > 1:
self.tp_size = 1
@@ -582,20 +581,18 @@ class FusedMoE(nn.Layer):
Forward split allgather function.
"""
token_num = x.shape[0]
tp_size = self.fd_config.parallel_config.tensor_parallel_size
tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
token_num_per_rank = (token_num + tp_size - 1) // tp_size
token_num_per_rank = (token_num + self.tp_size - 1) // self.tp_size
# AllGather will hang when the data shapes on multi-ranks are different!
part_x = paddle.zeros(shape=[token_num_per_rank, x.shape[1]], dtype=x.dtype)
start_offset = tp_rank * token_num_per_rank
end_offset = (tp_rank + 1) * token_num_per_rank
start_offset = self.tp_rank * token_num_per_rank
end_offset = (self.tp_rank + 1) * token_num_per_rank
if start_offset >= token_num:
start_offset = token_num
if end_offset > token_num:
end_offset = token_num
part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
out = self.quant_method.apply(self, part_x, gate)
multi_outs = paddle.zeros([token_num_per_rank * tp_size, x.shape[1]], dtype=x.dtype)
multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, x.shape[1]], dtype=x.dtype)
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
out = multi_outs[:token_num, :]
return out
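A sketch of the fixed-size per-rank slicing used in forward_split_allgather above: every rank processes a ceil-divided slice so the subsequent all_gather sees identical shapes on all ranks, and ranks past the end of the batch simply carry zero-padded rows.

```python
def rank_slice(token_num: int, tp_size: int, tp_rank: int):
    # Ceil division keeps the slice size identical on every rank.
    token_num_per_rank = (token_num + tp_size - 1) // tp_size
    start = min(tp_rank * token_num_per_rank, token_num)
    end = min((tp_rank + 1) * token_num_per_rank, token_num)
    return token_num_per_rank, start, end

assert rank_slice(token_num=5, tp_size=2, tp_rank=0) == (3, 0, 3)
assert rank_slice(token_num=5, tp_size=2, tp_rank=1) == (3, 3, 5)  # only 2 real rows, 1 padded
```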
@@ -612,8 +609,12 @@ class FusedMoE(nn.Layer):
"""
token_num = x.shape[0]
tp_size = self.fd_config.parallel_config.tensor_parallel_size
if self.ep_size > 1 and tp_size > 1 and self.ep_tp_strategy == "all_reduce" and token_num >= tp_size:
if (
self.ep_size > 1
and self.tp_size > 1
and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
and token_num >= self.tp_size
):
out = self.forward_split_allgather(x, gate)
else:
out = self.quant_method.apply(self, x, gate)
@@ -100,27 +100,19 @@ class RMSNorm(nn.Layer):
self.begin_norm_axis: int = begin_norm_axis

self.layer_id = layer_id
parallel_config = self.fd_config.parallel_config
self.ep_size = parallel_config.expert_parallel_size
self.tp_size = parallel_config.tensor_parallel_size
self.tp_rank = parallel_config.tensor_parallel_rank
self.tp_group = parallel_config.tp_group
self.ep_tp_strategy = parallel_config.ep_tp_strategy
self.moe_layer_start_index = self.fd_config.model_config.moe_layer_start_index
self.ep_size = self.fd_config.parallel_config.expert_parallel_size
self.tp_size = self.fd_config.parallel_config.tensor_parallel_size
self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
self.tp_group = self.fd_config.parallel_config.tp_group
is_input_norm = prefix.endswith(".input_layernorm")
is_last_norm = prefix.endswith(".norm")
self.split_x = (
self.ep_size > 1
and self.tp_size > 1
and self.ep_tp_strategy == "all_to_all"
and self.layer_id == self.moe_layer_start_index
self.fd_config.parallel_config.use_sequence_parallel_moe
and self.layer_id == self.fd_config.model_config.moe_layer_start_index
and is_input_norm
)
self.allgather_out = (
self.ep_size > 1
and self.tp_size > 1
and self.ep_tp_strategy == "all_to_all"
and ((self.layer_id > self.moe_layer_start_index and is_input_norm) or is_last_norm)
self.allgather_out = self.fd_config.parallel_config.use_sequence_parallel_moe and (
(self.layer_id > self.fd_config.model_config.moe_layer_start_index and is_input_norm) or is_last_norm
)

self.init_weight()
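A restatement of the placement rules above as a small predicate (a sketch only): when sequence-parallel MoE is active, the residual is split at the first MoE layer's input layernorm, and the normalized output is gathered back at later input layernorms and at the final norm.

```python
def norm_split_gather_flags(use_sp_moe, layer_id, moe_start, is_input_norm, is_last_norm):
    split_x = use_sp_moe and layer_id == moe_start and is_input_norm
    allgather_out = use_sp_moe and ((layer_id > moe_start and is_input_norm) or is_last_norm)
    return split_x, allgather_out

assert norm_split_gather_flags(True, 3, 3, True, False) == (True, False)   # first MoE layer: split
assert norm_split_gather_flags(True, 4, 3, True, False) == (False, True)   # later layers: gather
assert norm_split_gather_flags(True, 40, 3, False, True) == (False, True)  # final norm: gather
```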
@@ -193,6 +185,7 @@ class RMSNorm(nn.Layer):
|
||||
x,
|
||||
residual_input: Optional[paddle.Tensor] = None,
|
||||
forward_meta: Optional[ForwardMeta] = None,
|
||||
external_rmsnorm: Optional[Callable] = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Defines the forward computation of the layer.
|
||||
@@ -215,37 +208,46 @@ class RMSNorm(nn.Layer):
|
||||
if residual_input is not None:
|
||||
residual_input_dtype = residual_input.dtype
|
||||
residual_input = residual_input.astype(self.weight.dtype)
|
||||
if current_platform.is_gcu():
|
||||
if residual_input is None:
|
||||
norm_out = rms_norm(x, self.weight, self.eps)
|
||||
return norm_out.astype(x_dtype)
|
||||
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
|
||||
else:
|
||||
norm_out = self.norm_func(
|
||||
x,
|
||||
norm_weight=self.weight,
|
||||
norm_bias=None,
|
||||
epsilon=self.eps,
|
||||
begin_norm_axis=self.begin_norm_axis,
|
||||
bias=self.bias,
|
||||
residual=residual_input,
|
||||
quant_scale=(-1 if self.quant_scale is None else self.quant_scale),
|
||||
quant_round_type=self.quant_round_type,
|
||||
quant_max_bound=self.quant_max_bound,
|
||||
quant_min_bound=self.quant_min_bound,
|
||||
)
|
||||
out = norm_out[0].astype(x_dtype)
|
||||
residual_out = norm_out[1].astype(residual_input_dtype) if residual_input is not None else None
|
||||
|
||||
if self.split_x:
|
||||
residual_out = self.split(residual_out)
|
||||
if self.allgather_out:
|
||||
out = self.allgather(out, forward_meta.ids_remove_padding.shape[0])
|
||||
|
||||
if residual_input is None:
|
||||
return out
|
||||
residual_out = x
|
||||
if external_rmsnorm is None:
|
||||
if current_platform.is_gcu():
|
||||
if residual_input is None:
|
||||
norm_out = rms_norm(x, self.weight, self.eps)
|
||||
return norm_out.astype(x_dtype), residual_out
|
||||
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
|
||||
else:
|
||||
norm_out = self.norm_func(
|
||||
x,
|
||||
norm_weight=self.weight,
|
||||
norm_bias=None,
|
||||
epsilon=self.eps,
|
||||
begin_norm_axis=self.begin_norm_axis,
|
||||
bias=self.bias,
|
||||
residual=residual_input,
|
||||
quant_scale=(-1 if self.quant_scale is None else self.quant_scale),
|
||||
quant_round_type=self.quant_round_type,
|
||||
quant_max_bound=self.quant_max_bound,
|
||||
quant_min_bound=self.quant_min_bound,
|
||||
)
|
||||
else:
|
||||
return out, residual_out
|
||||
if residual_input is not None:
|
||||
x = x + residual_input
|
||||
norm_out = external_rmsnorm(x, self.weight, self.eps), x
|
||||
|
||||
out = norm_out[0].astype(x_dtype)
|
||||
if residual_input is not None:
|
||||
residual_out = norm_out[1].astype(residual_input_dtype)
|
||||
|
||||
if self.split_x:
|
||||
assert residual_out is not None
|
||||
residual_out = self.split(residual_out)
|
||||
if self.allgather_out:
|
||||
assert forward_meta is not None
|
||||
out = self.allgather(out, forward_meta.ids_remove_padding.shape[0])
|
||||
|
||||
return out, residual_out
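A hedged reference for the external_rmsnorm hook added above: any callable with the signature (x, weight, eps) -> normalized_x can be supplied, and the layer adds the residual itself before invoking it. The NumPy version below only illustrates that contract; it is not the fused kernel.

```python
import numpy as np

def reference_rmsnorm(x: np.ndarray, weight: np.ndarray, eps: float) -> np.ndarray:
    # Root-mean-square normalization over the last axis, then elementwise scale.
    variance = np.mean(np.square(x), axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight

x = np.ones((2, 4), dtype=np.float32)
w = np.ones(4, dtype=np.float32)
out = reference_rmsnorm(x, w, eps=1e-6)
assert out.shape == x.shape
```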
|
||||
|
||||
|
||||
class LayerNorm(nn.Layer):
|
||||
|
||||
@@ -130,7 +130,9 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
|
||||
dtype="float32",
|
||||
is_bias=False,
|
||||
)
|
||||
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
|
||||
extra_weight_attrs["output_dim"] = (
|
||||
not extra_weight_attrs["output_dim"] if extra_weight_attrs["output_dim"] is not None else None
|
||||
)
|
||||
|
||||
extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
|
||||
set_weight_attrs(
|
||||
|
||||
@@ -269,7 +269,7 @@ def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfi
|
||||
|
||||
if fd_config.parallel_config.tensor_parallel_size > 1:
|
||||
no_tp_action_keys = copy.deepcopy(num_local_ffn_keys)
|
||||
if fd_config.parallel_config.ep_tp_strategy == "all_to_all":
|
||||
if fd_config.parallel_config.use_sequence_parallel_moe:
|
||||
for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
|
||||
k = f"ernie.layers.{i}.self_attn.o_proj.weight"
|
||||
if k in weight_list:
|
||||
|
||||
@@ -271,6 +271,7 @@ class DeepseekV3MLAAttention(nn.Layer):
|
||||
input_size=self.num_attention_heads * self.v_head_dim,
|
||||
output_size=self.hidden_size,
|
||||
with_bias=False,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
self.kv_b_proj_bmm = KVBatchLinear(
|
||||
@@ -344,13 +345,13 @@ class DeepseekV3MLAAttention(nn.Layer):
|
||||
[self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
|
||||
)
|
||||
|
||||
query = self.q_a_layernorm(query)
|
||||
query = self.q_a_layernorm(query)[0]
|
||||
query = self.q_b_proj(query)
|
||||
query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
|
||||
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
|
||||
|
||||
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
|
||||
compressed_kv = self.kv_a_layernorm(compressed_kv)
|
||||
compressed_kv = self.kv_a_layernorm(compressed_kv)[0]
|
||||
|
||||
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
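An illustration of the calling convention behind the trailing [0] added at the q_a_layernorm / kv_a_layernorm call sites above: the norm layers now always return a (normalized, residual_out) pair, so callers that carry no residual keep only the first element. The function below is a stand-in for that contract, not the FastDeploy layer.

```python
def rmsnorm_forward(x, residual_input=None):
    normalized = x  # placeholder for the fused RMSNorm result
    residual_out = x if residual_input is None else residual_input
    return normalized, residual_out

query = rmsnorm_forward([0.1, 0.2])[0]             # norm-only call sites take element 0
hidden, residual = rmsnorm_forward([0.1], [0.0])   # decoder layers unpack both values
```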
|
||||
|
||||
@@ -479,6 +480,7 @@ class DeepSeekV3DecoderLayer(nn.Layer):
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -486,6 +488,7 @@ class DeepSeekV3DecoderLayer(nn.Layer):
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
@@ -504,11 +507,9 @@ class DeepSeekV3DecoderLayer(nn.Layer):
|
||||
mask_encoder_batch: paddle.Tensor,
|
||||
):
|
||||
""" """
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual_input=residual, forward_meta=forward_meta
|
||||
)
|
||||
|
||||
hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)
|
||||
|
||||
@@ -588,8 +589,7 @@ class DeepSeekV3Model(nn.Layer):
|
||||
position_ids,
|
||||
mask_encoder_batch,
|
||||
)
|
||||
hidden_states = hidden_states + residual
|
||||
out = self.norm(hidden_states)
|
||||
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -312,6 +312,7 @@ class Ernie4_5_DecoderLayer(nn.Layer):
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
@@ -329,18 +330,9 @@ class Ernie4_5_DecoderLayer(nn.Layer):
|
||||
hidden_states: paddle.Tensor,
|
||||
residual: paddle.Tensor = None,
|
||||
):
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(
|
||||
hidden_states,
|
||||
forward_meta=forward_meta,
|
||||
)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states,
|
||||
residual,
|
||||
forward_meta=forward_meta,
|
||||
)
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual_input=residual, forward_meta=forward_meta
|
||||
)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
@@ -350,7 +342,6 @@ class Ernie4_5_DecoderLayer(nn.Layer):
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states,
|
||||
residual,
|
||||
forward_meta=forward_meta,
|
||||
)
|
||||
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
@@ -455,9 +446,7 @@ class Ernie4_5_Model(nn.Layer):
|
||||
for i in range(self.num_layers):
|
||||
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
out = self.norm(hidden_states, forward_meta=forward_meta)
|
||||
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
|
||||
|
||||
if current_platform.is_iluvatar() and forward_meta.attn_backend.mixed:
|
||||
out = forward_meta.attn_backend.reverse_transpose(out)
|
||||
|
||||
@@ -318,7 +318,7 @@ class Ernie4_5_MTPModel(nn.Layer):
|
||||
"""
|
||||
inputs_embedding = self.embed_tokens(ids_remove_padding=ids_remove_padding)
|
||||
inputs_embedding = paddle.concat(
|
||||
[self.enorm(inputs_embedding), self.hnorm(previous_hidden_states)],
|
||||
[self.enorm(inputs_embedding)[0], self.hnorm(previous_hidden_states)[0]],
|
||||
axis=-1,
|
||||
)
|
||||
hidden_states = self.eh_proj(inputs_embedding)
|
||||
@@ -326,9 +326,7 @@ class Ernie4_5_MTPModel(nn.Layer):
|
||||
for i in range(self.num_layers):
|
||||
hidden_states, residual = self.mtp_block[i](forward_meta, hidden_states, residual)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.norm(hidden_states, residual)[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -358,6 +358,7 @@ class Ernie4_5_VLDecoderLayer(nn.Layer):
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -365,6 +366,7 @@ class Ernie4_5_VLDecoderLayer(nn.Layer):
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
@@ -380,11 +382,9 @@ class Ernie4_5_VLDecoderLayer(nn.Layer):
|
||||
residual: paddle.Tensor = None,
|
||||
vl_moe_meta: VLMoEMeta = None,
|
||||
):
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual_input=residual, forward_meta=forward_meta
|
||||
)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
@@ -546,8 +546,7 @@ class Ernie4_5_VLModel(nn.Layer):
|
||||
vl_moe_meta,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
out = self.norm(hidden_states)
|
||||
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -194,6 +194,7 @@ class Glm4MoeAttention(nn.Layer):
|
||||
prefix=f"{prefix}.o_proj",
|
||||
input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim,
|
||||
output_size=fd_config.model_config.hidden_size,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
self.attn = Attention(
|
||||
@@ -229,8 +230,8 @@ class Glm4MoeAttention(nn.Layer):
|
||||
|
||||
if self.use_qk_norm:
|
||||
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
|
||||
q = self.q_norm(q.reshape([-1, self.num_heads, self.head_dim])).reshape(q.shape)
|
||||
k = self.k_norm(k.reshape([-1, self.num_kv_heads, self.head_dim])).reshape(k.shape)
|
||||
q = self.q_norm(q.reshape([-1, self.num_heads, self.head_dim]))[0].reshape(q.shape)
|
||||
k = self.k_norm(k.reshape([-1, self.num_kv_heads, self.head_dim]))[0].reshape(k.shape)
|
||||
qkv_out = paddle.concat([q, k, v], axis=-1)
|
||||
|
||||
atten_out = self.attn(
|
||||
@@ -275,6 +276,7 @@ class Glm4MoeDecoderLayer(nn.Layer):
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -282,6 +284,7 @@ class Glm4MoeDecoderLayer(nn.Layer):
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -291,11 +294,9 @@ class Glm4MoeDecoderLayer(nn.Layer):
|
||||
residual: paddle.Tensor = None,
|
||||
):
|
||||
""" """
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual_input=residual, forward_meta=forward_meta
|
||||
)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
@@ -366,9 +367,8 @@ class Glm4MoeModel(nn.Layer):
|
||||
|
||||
for i in range(self.num_layers):
|
||||
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
out = self.norm(hidden_states)
|
||||
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -150,6 +150,7 @@ class GptOssDecoderLayer(nn.Layer):
|
||||
hidden_size=hidden_size,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
layer_id=layer_id,
|
||||
)
|
||||
self.mlp = GptOssMoe(fd_config, layer_id, prefix=f"{prefix}.mlp")
|
||||
|
||||
@@ -159,11 +160,9 @@ class GptOssDecoderLayer(nn.Layer):
|
||||
hidden_states: paddle.Tensor,
|
||||
residual: paddle.Tensor = None,
|
||||
):
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual_input=residual, forward_meta=forward_meta
|
||||
)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
@@ -213,9 +212,8 @@ class GptOssModel(nn.Layer):
|
||||
residual = None
|
||||
for i in range(self.num_layers):
|
||||
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.norm(hidden_states, residual)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from typing import Dict, Optional, Union
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from paddleformers.transformers import PretrainedModel
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
@@ -104,9 +103,7 @@ class PaddleOCRVLModel(nn.Layer):
|
||||
for i in range(self.num_layers):
|
||||
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
out = self.norm(hidden_states)
|
||||
out = self.norm(hidden_states, residual)[0]
|
||||
|
||||
return out
|
||||
|
||||
@@ -257,94 +254,3 @@ class PaddleOCRVLForConditionalGeneration(ModelForCasualLM):
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PaddleOCRVLPretrainedModel(PretrainedModel):
|
||||
|
||||
config_class = FDConfig
|
||||
|
||||
def _init_weight(self, layer):
|
||||
"""
|
||||
_init_weight
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def arch_name(self):
|
||||
return "PaddleOCRVLForConditionalGeneration"
|
||||
|
||||
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
|
||||
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid
|
||||
from fastdeploy.model_executor.models.utils import WeightMeta
|
||||
|
||||
weight_infos = [
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.weight",
|
||||
True,
|
||||
tsm.GQA,
|
||||
),
|
||||
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.weight", False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.weight",
|
||||
True,
|
||||
tsm.PairFused,
|
||||
),
|
||||
WeightMeta(f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.weight", False),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.TEXT_EXPERT_ID}}}.up_gate_proj.weight",
|
||||
True,
|
||||
tsm.PairFused,
|
||||
),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.TEXT_EXPERT_ID}}}.down_proj.weight",
|
||||
False,
|
||||
),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.IMG_EXPERT_ID}}}.up_gate_proj.weight",
|
||||
True,
|
||||
tsm.PairFused,
|
||||
),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.IMG_EXPERT_ID}}}.down_proj.weight",
|
||||
False,
|
||||
),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.weight",
|
||||
True,
|
||||
tsm.PairFused,
|
||||
),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.weight",
|
||||
False,
|
||||
),
|
||||
WeightMeta(
|
||||
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.weight",
|
||||
False,
|
||||
),
|
||||
WeightMeta(".embed_tokens.weight", False),
|
||||
WeightMeta("lm_head.weight", True),
|
||||
]
|
||||
|
||||
weight_vison = [
|
||||
# resampler_model
|
||||
WeightMeta("ernie.resampler_model.spatial_linear.0.weight", False),
|
||||
WeightMeta("resampler_model.spatial_linear.0.weight", False),
|
||||
# vision
|
||||
WeightMeta(
|
||||
f"vision_model.blocks.{{{layerid.LAYER_ID}}}.attn.proj.weight",
|
||||
False,
|
||||
),
|
||||
WeightMeta(f"vision_model.blocks.{{{layerid.LAYER_ID}}}.mlp.fc2.weight", False),
|
||||
WeightMeta(f"vision_model.blocks.{{{layerid.LAYER_ID}}}.mlp.fc1.weight", True),
|
||||
WeightMeta(f"vision_model.blocks.{{{layerid.LAYER_ID}}}.mlp.fc1.bias", True),
|
||||
WeightMeta(
|
||||
f"vision_model.blocks.{{{layerid.LAYER_ID}}}.attn.qkv.weight",
|
||||
True,
|
||||
tsm.GQA,
|
||||
),
|
||||
WeightMeta(
|
||||
f"vision_model.blocks.{{{layerid.LAYER_ID}}}.attn.qkv.bias",
|
||||
True,
|
||||
tsm.GQA,
|
||||
),
|
||||
]
|
||||
|
||||
@@ -176,6 +176,7 @@ class Qwen2DecoderLayer(nn.Layer):
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
@@ -193,11 +194,9 @@ class Qwen2DecoderLayer(nn.Layer):
|
||||
):
|
||||
""" """
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual_input=residual, forward_meta=forward_meta
|
||||
)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
@@ -285,9 +284,7 @@ class Qwen2Model(nn.Layer):
|
||||
for i in range(self.num_layers):
|
||||
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
out = self.norm(hidden_states)
|
||||
out = self.norm(hidden_states, residual)[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -124,9 +124,7 @@ class Qwen2_5_VLModel(nn.Layer):
|
||||
residual,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
out = self.norm(hidden_states)
|
||||
out = self.norm(hidden_states, residual)[0]
|
||||
|
||||
return out
|
||||
|
||||
@@ -262,21 +260,6 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
|
||||
|
||||
return logits
|
||||
|
||||
def empty_input_forward(self):
|
||||
"""
|
||||
empty_input_forward
|
||||
"""
|
||||
fake_hidden_states = paddle.empty(
|
||||
shape=[0, self.fd_config.model_config.hidden_size],
|
||||
dtype=paddle.get_default_dtype(),
|
||||
)
|
||||
for i in range(
|
||||
self.fd_config.model_config.moe_layer_start_index,
|
||||
self.fd_config.model_config.num_hidden_layers,
|
||||
):
|
||||
self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states)
|
||||
self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
|
||||
@@ -66,6 +66,7 @@ class Qwen3Attention(nn.Layer):
|
||||
prefix=f"{prefix}.o_proj",
|
||||
input_size=fd_config.model_config.head_dim * fd_config.model_config.num_attention_heads,
|
||||
output_size=fd_config.model_config.hidden_size,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
self.attn = Attention(
|
||||
@@ -114,11 +115,11 @@ class Qwen3Attention(nn.Layer):
|
||||
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
|
||||
|
||||
q_by_head = q.reshape([*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim])
|
||||
q_by_head = self.q_norm(q_by_head)
|
||||
q_by_head = self.q_norm(q_by_head)[0]
|
||||
q = q_by_head.reshape(q.shape)
|
||||
|
||||
k_by_head = k.reshape([*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim])
|
||||
k_by_head = self.k_norm(k_by_head)
|
||||
k_by_head = self.k_norm(k_by_head)[0]
|
||||
k = k_by_head.reshape(k.shape)
|
||||
|
||||
qkv_out = paddle.concat([q, k, v], axis=-1)
|
||||
@@ -216,9 +217,7 @@ class Qwen3Model(nn.Layer):
|
||||
for i in range(self.num_layers):
|
||||
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
out = self.norm(hidden_states)
|
||||
out = self.norm(hidden_states, residual)[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -167,15 +167,17 @@ class Qwen3DecoderLayer(nn.Layer):
|
||||
self.input_layernorm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=1e-6,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=1e-6,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
@@ -192,11 +194,9 @@ class Qwen3DecoderLayer(nn.Layer):
|
||||
residual: paddle.Tensor = None,
|
||||
):
|
||||
""" """
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual_input=residual, forward_meta=forward_meta
|
||||
)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
@@ -251,7 +251,7 @@ class Qwen3MoeModel(nn.Layer):
|
||||
self.norm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=1e-6,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.norm",
|
||||
)
|
||||
|
||||
@@ -275,16 +275,14 @@ class Qwen3MoeModel(nn.Layer):
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
forward_meta: ForwardMeta,
|
||||
):
|
||||
""" """
|
||||
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
|
||||
|
||||
residual = None
|
||||
|
||||
for i in range(self.num_layers):
|
||||
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
out = self.norm(hidden_states)
|
||||
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -660,6 +660,11 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="enable custom all-reduce",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable_sequence_parallel_moe",
|
||||
action="store_true",
|
||||
help="disable sequence parallel moe",
|
||||
)
|
||||
parser.add_argument("--splitwise_role", type=str, default="mixed", help="splitwise role")
|
||||
parser.add_argument(
|
||||
"--tensor_parallel_size",
|
||||
|
||||
@@ -360,6 +360,7 @@ export BKCL_RDMA_VERBS=1
|
||||
|
||||
export enable_expert_parallel=1
|
||||
export enable_tensor_parallel=1
|
||||
export disable_sequence_parallel_moe=1
|
||||
|
||||
python -m pytest -s --timeout=600 tests/ci_use/XPU_45T/run_ep.py
|
||||
ep_exit_code=$?
|
||||
@@ -373,6 +374,7 @@ unset XSHMEM_QP_NUM_PER_RANK
|
||||
unset BKCL_RDMA_VERBS
|
||||
unset enable_expert_parallel
|
||||
unset enable_tensor_parallel
|
||||
unset disable_sequence_parallel_moe
|
||||
stop_processes
|
||||
|
||||
if [ ${ep_exit_code} -ne 0 ]; then
|
||||
@@ -400,7 +402,6 @@ export BKCL_RDMA_VERBS=1
|
||||
|
||||
export enable_expert_parallel=1
|
||||
export enable_tensor_parallel=1
|
||||
export FD_EP_TP_STRATEGY=all_to_all
|
||||
|
||||
python -m pytest -s --timeout=600 tests/ci_use/XPU_45T/run_ep.py
|
||||
ep_exit_code=$?
|
||||
@@ -414,7 +415,6 @@ unset XSHMEM_QP_NUM_PER_RANK
|
||||
unset BKCL_RDMA_VERBS
|
||||
unset enable_expert_parallel
|
||||
unset enable_tensor_parallel
|
||||
unset FD_EP_TP_STRATEGY
|
||||
stop_processes
|
||||
|
||||
if [ ${ep_exit_code} -ne 0 ]; then
|
||||
|
||||
@@ -26,6 +26,7 @@ def test_fd_ep():
enable_expert_parallel = strtobool(os.getenv("enable_expert_parallel", "1"))
enable_tensor_parallel = strtobool(os.getenv("enable_tensor_parallel", "0"))
disable_sequence_parallel_moe = strtobool(os.getenv("disable_sequence_parallel_moe", "0"))
print(f"enable_expert_parallel: {enable_expert_parallel}, enable_tensor_parallel: {enable_tensor_parallel}")
if enable_expert_parallel:
if enable_tensor_parallel:
@@ -47,6 +48,7 @@ def test_fd_ep():
enable_expert_parallel=enable_expert_parallel,
tensor_parallel_size=tensor_parallel_size,
data_parallel_size=data_parallel_size,
disable_sequence_parallel_moe=disable_sequence_parallel_moe,
max_model_len=8192,
quantization="wint4",
engine_worker_queue_port=engine_worker_queue_port,