[Speculative Decoding][MTP] Support MTP in EP+DP+TP mode (#4614)

* support many mtp features

* support mtp reshard in rl mode

* fix function

* support mtp ep

* support mtp in hybrid-dp-tp mode

* enable scheduler_v1 by default with mtp

Author: freeliuzc
Committed: 2025-10-28 16:02:47 +08:00 (committed by GitHub)
Parent: b4014834a9
Commit: c63361fd1d
10 changed files with 124 additions and 74 deletions

View File

@@ -442,8 +442,7 @@ class EngineArgs:
             raise NotImplementedError("Only CUDA platform supports logprob.")
         if self.speculative_config is not None and self.logprobs_mode.startswith("processed"):
             raise NotImplementedError("processed_logprobs not support in speculative.")
-        if self.speculative_config is not None:
-            envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
         if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma":
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
         if not current_platform.is_cuda() and not current_platform.is_xpu():
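
Note: this hunk, together with the config-initialization change at the end of the commit, drops the blanket rule that speculative decoding disables the V1 KV cache scheduler; only the non-RDMA splitwise path still turns it off. A minimal standalone sketch of the remaining gating, using a stand-in envs object rather than the real fastdeploy module:

# Hedged sketch: gating after this change (names modeled on the diff, not the real module).
class _Envs:
    ENABLE_V1_KVCACHE_SCHEDULER = 1

envs = _Envs()

def apply_kv_scheduler_gating(splitwise_role: str, cache_transfer_protocol: str) -> None:
    # Speculative decoding no longer disables the V1 scheduler here.
    if splitwise_role != "mixed" and cache_transfer_protocol != "rdma":
        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0

apply_kv_scheduler_gating("prefill", "ipc")
assert envs.ENABLE_V1_KVCACHE_SCHEDULER == 0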

View File

@@ -92,6 +92,8 @@ class AppendAttentionBackend(AttentionBackend):
         self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
             fd_config.model_config, "use_3d_rope", False
         )
+        if fd_config.speculative_config.model_type != "main":
+            self.rope_3d = False
         self.causal: bool = getattr(fd_config.model_config, "causal", True)
         self.speculative_method: str = fd_config.speculative_config.method
         self.use_speculate: bool = self.speculative_method is not None
@@ -364,7 +366,7 @@ class AppendAttentionBackend(AttentionBackend):
                 getattr(layer, "cache_v_zp", None),
                 layer.linear_shift,
                 layer.linear_smooth,
-                forward_meta.attn_mask_offsets,
+                None if self.use_speculate else forward_meta.attn_mask_offsets,
                 metadata.kv_signal_data_list[layer.layer_id],
                 getattr(layer, "q_norm_weight", None),
                 getattr(layer, "k_norm_weight", None),
@@ -383,7 +385,7 @@ class AppendAttentionBackend(AttentionBackend):
                 metadata.max_partition_size,
                 metadata.encoder_max_partition_size,
                 self.speculate_max_draft_token_num + 1,
-                self.causal,
+                self.causal or self.use_speculate,
                 self.speculative_method is not None,
                 sliding_window,
             )
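
Note: the two argument changes above route speculative runs around the attention-mask offsets and force the causal path. A tiny illustrative helper (names are mine, not the backend's API) showing the selection logic:

# Hedged sketch of the argument selection introduced above (standalone, illustrative names).
def select_attn_args(use_speculate: bool, causal: bool, attn_mask_offsets):
    mask_offsets = None if use_speculate else attn_mask_offsets
    causal_flag = causal or use_speculate  # speculative decoding always takes the causal path
    return mask_offsets, causal_flag

print(select_attn_args(True, False, [0, 4, 8]))   # (None, True)
print(select_attn_args(False, True, [0, 4, 8]))   # ([0, 4, 8], True)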

View File

@@ -18,7 +18,7 @@ import paddle
 from paddle import nn
 from paddle.distributed import fleet

-from fastdeploy.model_executor.utils import set_weight_attrs
+from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs

 from .utils import get_tensor
@@ -53,44 +53,61 @@ class ParallelEHProjection(nn.Layer):
             self.bias_key = prefix + ".bias"
         else:
             self.bias_key = None
-        self.use_ep = fd_config.parallel_config.use_ep
+        self.fd_config = fd_config
+        self.tp_group = fd_config.parallel_config.tp_group
         self.column_cut = True
+        self.nranks = fd_config.parallel_config.tensor_parallel_size

         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear

-        if self.use_ep:
-            self.weight = self.create_parameter(
-                shape=[embedding_dim, num_embeddings],
-                dtype=paddle.get_default_dtype(),
-                is_bias=False,
-            )
-        else:
-            if self.column_cut:
-                need_gather = True
-                self.linear = ColumnParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    gather_output=need_gather,
-                    fuse_matmul_bias=False,  # False diff更小
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": True})
-                if self.bias_key is not None:
-                    set_weight_attrs(self.linear.bias, {"output_dim": True})
-            else:
-                self.linear = RowParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    input_is_parallel=False,
-                    fuse_matmul_bias=False,  # False diff更小
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": False})
+        if self.column_cut:
+            need_gather = True
+            self.linear = ColumnParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                gather_output=need_gather,
+                fuse_matmul_bias=False,  # False diff更小
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.bias_key is not None:
+                set_weight_attrs(
+                    self.linear.bias,
+                    {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
+                )
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
+        else:
+            self.linear = RowParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                input_is_parallel=False,
+                fuse_matmul_bias=False,  # False diff更小
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
+        set_weight_attrs(
+            self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
+        )

     def load_state_dict(self, state_dict):
         """
@@ -100,17 +117,14 @@ class ParallelEHProjection(nn.Layer):
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """
-        if self.use_ep:
-            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
-        else:
-            weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
-            if self.linear.weight.shape != weight_tensor.shape:
-                weight_tensor = weight_tensor.transpose([1, 0])
-            self.linear.weight.set_value(weight_tensor)
+        weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
+        if self.linear.weight.shape != weight_tensor.shape:
+            weight_tensor = weight_tensor.transpose([1, 0])
+        self.linear.weight.set_value(weight_tensor)

         if self.bias_key is not None:
             bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
             self.linear.bias.set_value(bias)

     def forward(self, input):
         """
@@ -123,8 +137,5 @@
             Tensor: The output tensor after processing through the layer.
         """
         logits = input
-        if self.use_ep:
-            logits = paddle.matmul(logits, self.weight)
-        else:
-            logits = self.linear(logits)
+        logits = self.linear(logits)
         return logits
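
Note: with the EP-specific weight branch gone, every rank now goes through self.linear, and load_state_dict transposes the checkpoint tensor when it arrives in [num_embeddings, embedding_dim] order. A NumPy stand-in for that shape check, assuming nothing about the real get_tensor helper:

# Hedged sketch of the transpose-on-mismatch convention used in load_state_dict above.
import numpy as np

def load_projection_weight(expected_shape, checkpoint_tensor):
    # Transpose only when the stored layout does not match the layer's layout.
    if list(checkpoint_tensor.shape) != list(expected_shape):
        checkpoint_tensor = checkpoint_tensor.transpose(1, 0)
    assert list(checkpoint_tensor.shape) == list(expected_shape)
    return checkpoint_tensor

w = load_projection_weight([1024, 32000], np.zeros((32000, 1024), dtype=np.float16))
print(w.shape)  # (1024, 32000)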

View File

@@ -72,6 +72,11 @@ class DefaultModelLoader(BaseModelLoader):
             # register rl model
             import fastdeploy.rl  # noqa

+            if fd_config.speculative_config.model_type != "mtp":
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+            else:
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
             architectures = architectures + "RL"
             context = paddle.LazyGuard()
         else:

View File

@@ -65,6 +65,11 @@ class DefaultModelLoaderV1(BaseModelLoader):
             # register rl model
             import fastdeploy.rl  # noqa

+            if fd_config.speculative_config.model_type != "mtp":
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+            else:
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
             architectures = architectures + "RL"
             enable_cache, _, weight_cache_context = is_weight_cache_enabled(fd_config)
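
Note: both loaders now pick the RL architecture name from the speculative model type before appending the "RL" suffix. A standalone sketch of the mapping (the helper name is illustrative, not part of the loaders):

# Hedged sketch of the architecture remapping added in both loaders above.
def resolve_rl_architecture(architecture: str, speculative_model_type: str) -> str:
    if speculative_model_type != "mtp":
        architecture = architecture.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
    else:
        architecture = architecture.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
    return architecture + "RL"

print(resolve_rl_architecture("Ernie5ForCausalLM", "main"))  # Ernie5MoeForCausalLMRL
print(resolve_rl_architecture("Ernie5ForCausalLM", "mtp"))   # Ernie5MTPForCausalLMRL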

View File

@@ -502,7 +502,7 @@ class TokenProcessor:
     def _compute_speculative_status(self):
         # TODO(liuzichang): Supplement more statistics
-        interval = 10
+        interval = 1
         if self.speculative_stats_step % interval == 0:
             accept_ratio = 1 - self.total_step * 1.0 / self.number_of_output_tokens
             spec_logger.info(
@@ -593,6 +593,9 @@ class TokenProcessor:
                             + accept_num[i]
                         ].tolist()
                         if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
+                            if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                                if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
+                                    self.resource_manager.reschedule_preempt_task(task_id)
                             continue
                     else:
                         token_id = int(tokens[i, 0])
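
Note: the added branch lets a request that produced no valid token in this step be handed back to the V1 scheduler instead of being silently skipped. A self-contained sketch with a toy resource manager (the real to_be_rescheduled_request_id_set / reschedule_preempt_task semantics may differ):

# Hedged sketch of the reschedule-on-empty-output path added above.
class _ResourceManager:
    def __init__(self):
        self.to_be_rescheduled_request_id_set = {"req-1"}
        self.rescheduled = []

    def reschedule_preempt_task(self, task_id):
        self.to_be_rescheduled_request_id_set.discard(task_id)
        self.rescheduled.append(task_id)

def handle_empty_output(task_id, enable_v1_scheduler, rm):
    if enable_v1_scheduler and task_id in rm.to_be_rescheduled_request_id_set:
        rm.reschedule_preempt_task(task_id)

rm = _ResourceManager()
handle_empty_output("req-1", True, rm)
print(rm.rescheduled)  # ['req-1']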

View File

@@ -17,11 +17,10 @@
 import os
 import time
 from multiprocessing.shared_memory import SharedMemory
-from typing import Any, Dict
+from typing import Any, Dict, List

 import numpy as np
 import paddle
-from paddle import nn
 from paddleformers.utils.log import logger

 from fastdeploy.config import FDConfig
@@ -31,7 +30,7 @@ from fastdeploy.inter_communicator import ModelWeightsStatus
 class DynamicWeightManager:
     """Manages model weights loading, updating and shared state across processes."""

-    def __init__(self, fd_config: FDConfig, model: nn.Layer):
+    def __init__(self, fd_config: FDConfig, models):
         """Initialize with config and model instances."""
         self.fd_config = fd_config
         self.load_config = fd_config.load_config
@@ -42,7 +41,10 @@
         self.meta_src_id = self._get_gpu_id()
         self.first_load = True
         self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
-        self.model: nn.Layer = model
+        if not isinstance(models, List):
+            self.model_list = [models]
+        else:
+            self.model_list = models
         self._capture_model_state()
         self.update_parameters()
         self.finalize_update()
@@ -55,9 +57,10 @@
     @paddle.no_grad()
     def _capture_model_state(self):
         """Capture and store initial model parameters state."""
-        for name, param in self.model.state_dict().items():
-            logger.debug(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
-            self.state_dict[name] = param
+        for model in self.model_list:
+            for name, param in model.state_dict().items():
+                logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
+                self.state_dict[name] = param

     def update_parameters(self, pid: int = 0) -> None:
         """Core method to update model parameters based on strategy."""
@@ -137,8 +140,9 @@
             paddle.device.cuda.empty_cache()

         # step2: release model weight
-        for param in self.model.state_dict().values():
-            param._clear_data()
+        for model in self.model_list:
+            for param in model.state_dict().values():
+                param._clear_data()

         self._verify_parameters("clearance")
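
Note: DynamicWeightManager now accepts either a single model or a list of models, so the MTP draft model's weights can be captured and cleared alongside the main model's during RL resharding. A minimal sketch of the normalization, using toy objects in place of nn.Layer:

# Hedged sketch of the single-model/list normalization used above (illustrative only).
def normalize_models(models):
    return [models] if not isinstance(models, list) else models

def capture_state(models) -> dict:
    state = {}
    for model in normalize_models(models):
        state.update(model.state_dict())
    return state

class _Toy:  # stand-in for a paddle nn.Layer
    def __init__(self, name):
        self._name = name
    def state_dict(self):
        return {f"{self._name}.weight": [0.0]}

print(sorted(capture_state([_Toy("main"), _Toy("mtp")]).keys()))  # ['main.weight', 'mtp.weight']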

View File

@@ -38,13 +38,20 @@
         Init Speculative proposer
         """
         fd_config.parallel_config.tp_group = None
+        fd_config.parallel_config.ep_group = None
         self.fd_config = deepcopy(fd_config)
         fd_config.parallel_config.tp_group = dist.get_group(
             fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
+        fd_config.parallel_config.ep_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        )
         self.fd_config.parallel_config.tp_group = dist.get_group(
             fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
+        self.fd_config.parallel_config.ep_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        )
         self.parallel_config = self.fd_config.parallel_config
         self.model_config = self.fd_config.model_config
         self.speculative_config = self.fd_config.speculative_config
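
Note: the proposer deep-copies fd_config, and communicator group handles generally cannot be deep-copied, so tp_group and now also ep_group are nulled before the copy and re-fetched for both the original config and the copy. A standalone illustration of that pattern with stand-in classes (not paddle.distributed):

# Hedged sketch of the "null, deepcopy, re-attach" pattern used above.
from copy import deepcopy

class _Group:                       # stand-in for a process-group handle
    def __deepcopy__(self, memo):
        raise TypeError("process groups cannot be deep-copied")

class _ParallelCfg:
    def __init__(self, group):
        self.tp_group = group
        self.ep_group = group

def clone_config(cfg: _ParallelCfg, tp_group, ep_group) -> _ParallelCfg:
    cfg.tp_group = None             # detach non-copyable members
    cfg.ep_group = None
    copied = deepcopy(cfg)
    for c in (cfg, copied):         # re-attach to both original and copy
        c.tp_group = tp_group
        c.ep_group = ep_group
    return copied

g = _Group()
clone = clone_config(_ParallelCfg(g), g, g)
print(clone.tp_group is g, clone.ep_group is g)  # True True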

View File

@@ -96,7 +96,8 @@ class MTPProposer(Proposer):
         """
         Update config for MTP from global config
         """
-        self.model_config.architectures[0] = "Ernie4_5_MTPForCausalLM"
+        self.forward_meta: ForwardMeta = None
+        self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP")
         self.speculative_config.sharing_model = main_model
         self.model_config.num_hidden_layers = 1
         self.model_config.model = self.speculative_config.model
@@ -169,6 +170,9 @@
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
             max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
+        if kv_cache_quant_type == "block_wise_fp8":
+            kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
+        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
         if not profile and (
             self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
         ):
@@ -178,8 +182,8 @@
                 self.num_main_model_layers + self.model_config.num_hidden_layers,
             ):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
-                key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
-                val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
+                key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
+                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
                 key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
                 cache_kvs_list.append(key_cache)
                 value_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -199,6 +203,17 @@
                     fill_value=0,
                     dtype=cache_type,
                 )
+                if kv_cache_quant_type == "block_wise_fp8":
+                    self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
+                    self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
             self.model_inputs["caches"] = list(self.cache_kvs.values())
             for value in self.cache_kvs.values():
                 del value
@@ -430,11 +445,10 @@
         if "caches" not in self.model_inputs:
            self.initialize_kv_cache()
         req_len = len(req_dicts)
-        # has_prefill_task = False
-        # has_decode_task = False
         for i in range(req_len):
             request = req_dicts[i]
-            logger.info(f"{i}th request-{request.request_id}: {request}")
+            logger.debug(f"{i}th request-{request.request_id}: {request}")
             idx = request.idx
             if request.task_type.value == RequestType.PREFILL.value:  # prefill task
                 prefill_start_index = request.prefill_start_index
@@ -688,7 +702,7 @@
                 self.max_model_len,
                 self.model_inputs["substep"],
             )
-            if self.role == "prefill":
+            if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
                 mtp_save_first_token(
                     self.model_inputs["base_model_draft_tokens"],
                     self.model_inputs["not_need_stop"],
@@ -820,11 +834,18 @@
             )

             if self.parallel_config.tensor_parallel_size > 1:
-                paddle.distributed.broadcast(sampled_token_ids, 0)
+                paddle.distributed.broadcast(
+                    sampled_token_ids,
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )

             self._post_process(sampled_token_ids)

             if substep != self.num_model_steps - 1:
                 self._get_self_hidden_states(hidden_states)
+        else:
+            if hasattr(self.model, "empty_input_forward"):
+                self.model.empty_input_forward()

     def _get_self_hidden_states(self, hidden_states):
         target_hidden_states = eagle_get_self_hidden_states(
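
Note: the broadcast source is no longer hard-coded to global rank 0; assuming the global rank layout places all TP ranks of a data-parallel group contiguously (which is what the expression dp_rank * tp_size implies), the first TP rank of data-parallel group d is d * tensor_parallel_size, and the call also scopes the broadcast to the TP group. The arithmetic, as a tiny sketch:

# Hedged sketch of the broadcast-source arithmetic introduced above (pure arithmetic,
# no distributed runtime needed).
def broadcast_src_rank(data_parallel_rank: int, tensor_parallel_size: int) -> int:
    return data_parallel_rank * tensor_parallel_size

# Example: dp=2, tp=4 -> global ranks [0..7]; dp group 1 broadcasts from rank 4.
assert broadcast_src_rank(0, 4) == 0
assert broadcast_src_rank(1, 4) == 4
print([broadcast_src_rank(d, 4) for d in range(2)])  # [0, 4]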

View File

@@ -778,13 +778,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
     logger.info(f"- Load strategy: {load_config.load_strategy}")

-    if (
-        args.speculative_config is not None
-        and ("method" in args.speculative_config)
-        and (args.speculative_config["method"] is not None)
-    ):
-        logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not support speculative decoding now.")
-        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
     if args.splitwise_role != "mixed" and args.cache_transfer_protocol != "rdma":
         envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
     if not current_platform.is_cuda() and not current_platform.is_xpu():