[Speculative Decoding][MTP] Support MTP in EP+DP+TP mode (#4614)

* support many MTP features

* support MTP resharding in RL mode

* fix function

* support MTP with EP

* support MTP in hybrid DP-TP mode

* enable scheduler_v1 by default with MTP
freeliuzc
2025-10-28 16:02:47 +08:00
committed by GitHub
parent b4014834a9
commit c63361fd1d
10 changed files with 124 additions and 74 deletions

View File

@@ -442,8 +442,7 @@ class EngineArgs:
raise NotImplementedError("Only CUDA platform supports logprob.")
if self.speculative_config is not None and self.logprobs_mode.startswith("processed"):
raise NotImplementedError("processed_logprobs not support in speculative.")
if self.speculative_config is not None:
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma":
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if not current_platform.is_cuda() and not current_platform.is_xpu():
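
Net effect of this hunk (the matching block is also deleted from initialize_fd_config at the end of this commit): speculative decoding no longer forces the V1 KV-cache scheduler off, in line with the "enable scheduler_v1 by default with MTP" bullet; only the splitwise, non-RDMA path still disables it. A minimal sketch of the resulting gating, with the standalone helper name made up for illustration:

def v1_kvcache_scheduler_enabled(splitwise_role: str, cache_transfer_protocol: str) -> bool:
    # Before this commit, any speculative_config also set ENABLE_V1_KVCACHE_SCHEDULER = 0;
    # that branch is removed, so only the splitwise/non-RDMA case disables the scheduler.
    if splitwise_role != "mixed" and cache_transfer_protocol != "rdma":
        return False
    return True

assert v1_kvcache_scheduler_enabled("mixed", "ipc") is True
assert v1_kvcache_scheduler_enabled("prefill", "ipc") is False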

View File

@@ -92,6 +92,8 @@ class AppendAttentionBackend(AttentionBackend):
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
fd_config.model_config, "use_3d_rope", False
)
if fd_config.speculative_config.model_type != "main":
self.rope_3d = False
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method: str = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
@@ -364,7 +366,7 @@ class AppendAttentionBackend(AttentionBackend):
getattr(layer, "cache_v_zp", None),
layer.linear_shift,
layer.linear_smooth,
forward_meta.attn_mask_offsets,
None if self.use_speculate else forward_meta.attn_mask_offsets,
metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
@@ -383,7 +385,7 @@ class AppendAttentionBackend(AttentionBackend):
metadata.max_partition_size,
metadata.encoder_max_partition_size,
self.speculate_max_draft_token_num + 1,
self.causal,
self.causal or self.use_speculate,
self.speculative_method is not None,
sliding_window,
)
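
The two call-site changes read as: when speculative decoding is active, the attention mask offsets are dropped and the causal flag is forced on (presumably because verifying the appended draft tokens requires a causal mask). A small sketch of that argument selection; the helper name is illustrative only:

def speculate_attention_args(use_speculate: bool, causal: bool, attn_mask_offsets):
    # Mask offsets only apply to the non-speculative path.
    mask_offsets = None if use_speculate else attn_mask_offsets
    # Speculative verification always runs with a causal mask.
    is_causal = causal or use_speculate
    return mask_offsets, is_causal

assert speculate_attention_args(True, False, object()) == (None, True)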

View File

@@ -18,7 +18,7 @@ import paddle
from paddle import nn
from paddle.distributed import fleet
from fastdeploy.model_executor.utils import set_weight_attrs
from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
from .utils import get_tensor
@@ -53,44 +53,61 @@ class ParallelEHProjection(nn.Layer):
self.bias_key = prefix + ".bias"
else:
self.bias_key = None
self.use_ep = fd_config.parallel_config.use_ep
self.fd_config = fd_config
self.tp_group = fd_config.parallel_config.tp_group
self.column_cut = True
self.nranks = fd_config.parallel_config.tensor_parallel_size
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
if self.use_ep:
self.weight = self.create_parameter(
shape=[embedding_dim, num_embeddings],
dtype=paddle.get_default_dtype(),
is_bias=False,
)
else:
if self.column_cut:
need_gather = True
self.linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
mp_group=self.tp_group,
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False, # False gives a smaller numerical diff
)
set_weight_attrs(self.linear.weight, {"output_dim": True})
set_weight_attrs(
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
},
)
if self.bias_key is not None:
set_weight_attrs(self.linear.bias, {"output_dim": True})
set_weight_attrs(
self.linear.bias,
{"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
)
if self.nranks > 1:
set_weight_attrs(self.linear.weight, {"output_dim": True})
else:
self.linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
mp_group=self.tp_group,
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False, # False gives a smaller numerical diff
)
set_weight_attrs(self.linear.weight, {"output_dim": False})
set_weight_attrs(
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
},
)
if self.nranks > 1:
set_weight_attrs(self.linear.weight, {"output_dim": True})
set_weight_attrs(
self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
)
def load_state_dict(self, state_dict):
"""
@@ -100,9 +117,6 @@ class ParallelEHProjection(nn.Layer):
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
if self.use_ep:
self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
else:
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
@@ -123,8 +137,5 @@ class ParallelEHProjection(nn.Layer):
Tensor: The output tensor after processing through the layer.
"""
logits = input
if self.use_ep:
logits = paddle.matmul(logits, self.weight)
else:
logits = self.linear(logits)
return logits
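
forward() keeps two paths: under expert parallelism the layer owns the full [embedding_dim, num_embeddings] weight and applies a plain matmul, otherwise the tensor-parallel linear (now built on the explicit tp_group) performs the sharded projection and gather. A rough sketch, assuming a paddle environment and using a free function in place of the class method:

import paddle

def eh_projection_forward(hidden, use_ep: bool, full_weight=None, tp_linear=None):
    if use_ep:
        # EP rank: the unsharded weight lives on this rank, so a matmul is enough.
        return paddle.matmul(hidden, full_weight)
    # TP rank: Column/RowParallelLinear does the shard-local matmul plus gather.
    return tp_linear(hidden)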

View File

@@ -72,6 +72,11 @@ class DefaultModelLoader(BaseModelLoader):
# register rl model
import fastdeploy.rl # noqa
if fd_config.speculative_config.model_type != "mtp":
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
else:
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
architectures = architectures + "RL"
context = paddle.LazyGuard()
else:

View File

@@ -65,6 +65,11 @@ class DefaultModelLoaderV1(BaseModelLoader):
# register rl model
import fastdeploy.rl # noqa
if fd_config.speculative_config.model_type != "mtp":
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
else:
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
architectures = architectures + "RL"
enable_cache, _, weight_cache_context = is_weight_cache_enabled(fd_config)
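
Both loaders now remap the architecture name before registering the RL variant: the main model maps Ernie5ForCausalLM to Ernie5MoeForCausalLM, an MTP draft model maps it to Ernie5MTPForCausalLM, and both get an "RL" suffix. A standalone sketch of that string transform (helper name invented here):

def rl_architecture(architectures: str, model_type: str) -> str:
    # MTP draft models use the MTP architecture; everything else the MoE one.
    if model_type != "mtp":
        architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
    else:
        architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
    return architectures + "RL"

assert rl_architecture("Ernie5ForCausalLM", "mtp") == "Ernie5MTPForCausalLMRL"
assert rl_architecture("Ernie5ForCausalLM", "main") == "Ernie5MoeForCausalLMRL"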

View File

@@ -502,7 +502,7 @@ class TokenProcessor:
def _compute_speculative_status(self):
# TODO(liuzichang): Supplement more statistics
interval = 10
interval = 1
if self.speculative_stats_step % interval == 0:
accept_ratio = 1 - self.total_step * 1.0 / self.number_of_output_tokens
spec_logger.info(
@@ -593,6 +593,9 @@ class TokenProcessor:
+ accept_num[i]
].tolist()
if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
self.resource_manager.reschedule_preempt_task(task_id)
continue
else:
token_id = int(tokens[i, 0])
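
With the V1 KV-cache scheduler now allowed alongside speculative decoding, a request whose speculative step produced no valid token may have been preempted, and the new branch reschedules it instead of just skipping it. A condensed sketch of the control flow; the resource-manager attributes come straight from the diff, the helper itself is illustrative:

def handle_empty_draft(task_id, token_ids, recovery_stop, v1_scheduler_on, resource_manager):
    if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
        if v1_scheduler_on and task_id in resource_manager.to_be_rescheduled_request_id_set:
            # The request was preempted under the V1 scheduler; put it back in the queue.
            resource_manager.reschedule_preempt_task(task_id)
        return True  # corresponds to the `continue`: emit no token for this request
    return False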

View File

@@ -17,11 +17,10 @@
import os
import time
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Dict
from typing import Any, Dict, List
import numpy as np
import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
@@ -31,7 +30,7 @@ from fastdeploy.inter_communicator import ModelWeightsStatus
class DynamicWeightManager:
"""Manages model weights loading, updating and shared state across processes."""
def __init__(self, fd_config: FDConfig, model: nn.Layer):
def __init__(self, fd_config: FDConfig, models):
"""Initialize with config and model instances."""
self.fd_config = fd_config
self.load_config = fd_config.load_config
@@ -42,7 +41,10 @@ class DynamicWeightManager:
self.meta_src_id = self._get_gpu_id()
self.first_load = True
self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
self.model: nn.Layer = model
if not isinstance(models, List):
self.model_list = [models]
else:
self.model_list = models
self._capture_model_state()
self.update_parameters()
self.finalize_update()
@@ -55,8 +57,9 @@ class DynamicWeightManager:
@paddle.no_grad()
def _capture_model_state(self):
"""Capture and store initial model parameters state."""
for name, param in self.model.state_dict().items():
logger.debug(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
for model in self.model_list:
for name, param in model.state_dict().items():
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
self.state_dict[name] = param
def update_parameters(self, pid: int = 0) -> None:
@@ -137,7 +140,8 @@ class DynamicWeightManager:
paddle.device.cuda.empty_cache()
# step2: release model weight
for param in self.model.state_dict().values():
for model in self.model_list:
for param in model.state_dict().values():
param._clear_data()
self._verify_parameters("clearance")
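
DynamicWeightManager now accepts either a single nn.Layer or a list of layers (e.g. the main model plus the MTP draft model for RL resharding), and every state-dict walk iterates that list. The normalization and capture boil down to the sketch below (it assumes parameter names are distinct across models):

def normalize_models(models):
    # Accept one model or a list of models.
    return models if isinstance(models, list) else [models]

def capture_state(model_list):
    state_dict = {}
    for model in model_list:
        for name, param in model.state_dict().items():
            # A later model with a clashing name would overwrite the earlier entry.
            state_dict[name] = param
    return state_dict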

View File

@@ -38,13 +38,20 @@ class Proposer(ABC):
Init Speculative proposer
"""
fd_config.parallel_config.tp_group = None
fd_config.parallel_config.ep_group = None
self.fd_config = deepcopy(fd_config)
fd_config.parallel_config.tp_group = dist.get_group(
fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
)
fd_config.parallel_config.ep_group = dist.get_group(
fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
)
self.fd_config.parallel_config.tp_group = dist.get_group(
fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
)
self.fd_config.parallel_config.ep_group = dist.get_group(
fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
)
self.parallel_config = self.fd_config.parallel_config
self.model_config = self.fd_config.model_config
self.speculative_config = self.fd_config.speculative_config
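
The proposer deep-copies the engine config for the draft model, but the communicator groups are detached first (presumably because process-group handles do not survive deepcopy) and then re-resolved by group id on both the original and the copy. A sketch of the pattern; get_group is passed in rather than hard-coding paddle.distributed, and the gid arithmetic mirrors the diff:

import copy

def fork_config_with_groups(fd_config, get_group, gid_offset):
    pc = fd_config.parallel_config
    pc.tp_group = None
    pc.ep_group = None
    draft_config = copy.deepcopy(fd_config)
    # The tp group id is keyed by the DP rank, the ep group id by the DP world size.
    tp_gid = pc.data_parallel_rank + gid_offset
    ep_gid = pc.data_parallel_size + gid_offset
    for cfg in (fd_config, draft_config):
        cfg.parallel_config.tp_group = get_group(tp_gid)
        cfg.parallel_config.ep_group = get_group(ep_gid)
    return draft_config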

View File

@@ -96,7 +96,8 @@ class MTPProposer(Proposer):
"""
Update config for MTP from global config
"""
self.model_config.architectures[0] = "Ernie4_5_MTPForCausalLM"
self.forward_meta: ForwardMeta = None
self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP")
self.speculative_config.sharing_model = main_model
self.model_config.num_hidden_layers = 1
self.model_config.model = self.speculative_config.model
@@ -169,6 +170,9 @@ class MTPProposer(Proposer):
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
)
if kv_cache_quant_type == "block_wise_fp8":
kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
if not profile and (
self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
):
@@ -178,8 +182,8 @@ class MTPProposer(Proposer):
self.num_main_model_layers + self.model_config.num_hidden_layers,
):
key_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
cache_kvs_list.append(key_cache)
value_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -199,6 +203,17 @@ class MTPProposer(Proposer):
fill_value=0,
dtype=cache_type,
)
if kv_cache_quant_type == "block_wise_fp8":
self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
shape=kv_cache_scale_shape,
fill_value=0,
dtype=paddle.get_default_dtype(),
)
self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
shape=kv_cache_scale_shape,
fill_value=0,
dtype=paddle.get_default_dtype(),
)
self.model_inputs["caches"] = list(self.cache_kvs.values())
for value in self.cache_kvs.values():
del value
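
When the KV cache is quantized block-wise to FP8, each key/value cache gains a companion scale tensor whose shape keeps the first three axes of the cache shape and drops the last one (the head-dim axis under the usual cache layout). A sketch of the allocation, assuming paddle and a 4-D kv_cache_shape:

import paddle

def alloc_block_wise_fp8_scales(kv_cache_shape):
    # Scales keep the leading dims of the cache and drop the trailing axis.
    scale_shape = list(kv_cache_shape[:3])
    dtype = paddle.get_default_dtype()
    key_scales = paddle.full(shape=scale_shape, fill_value=0, dtype=dtype)
    value_scales = paddle.full(shape=scale_shape, fill_value=0, dtype=dtype)
    return key_scales, value_scales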
@@ -430,11 +445,10 @@ class MTPProposer(Proposer):
if "caches" not in self.model_inputs:
self.initialize_kv_cache()
req_len = len(req_dicts)
# has_prefill_task = False
# has_decode_task = False
for i in range(req_len):
request = req_dicts[i]
logger.info(f"{i}th request-{request.request_id}: {request}")
logger.debug(f"{i}th request-{request.request_id}: {request}")
idx = request.idx
if request.task_type.value == RequestType.PREFILL.value: # prefill task
prefill_start_index = request.prefill_start_index
@@ -688,7 +702,7 @@ class MTPProposer(Proposer):
self.max_model_len,
self.model_inputs["substep"],
)
if self.role == "prefill":
if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
mtp_save_first_token(
self.model_inputs["base_model_draft_tokens"],
self.model_inputs["not_need_stop"],
@@ -820,11 +834,18 @@ class MTPProposer(Proposer):
)
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(sampled_token_ids, 0)
paddle.distributed.broadcast(
sampled_token_ids,
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
self._post_process(sampled_token_ids)
if substep != self.num_model_steps - 1:
self._get_self_hidden_states(hidden_states)
else:
if hasattr(self.model, "empty_input_forward"):
self.model.empty_input_forward()
def _get_self_hidden_states(self, hidden_states):
target_hidden_states = eagle_get_self_hidden_states(
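
Under hybrid DP+TP the draft-token broadcast can no longer target global rank 0: the source is the first TP rank of this data-parallel replica (dp_rank * tp_size in global numbering) and the collective is restricted to the replica's TP group; the empty_input_forward() call on the last substep presumably keeps EP ranks participating in MoE collectives even when there is no local work. A sketch of the broadcast, assuming paddle.distributed is initialized:

import paddle.distributed as dist

def broadcast_draft_tokens(sampled_token_ids, dp_rank, tp_size, tp_group):
    # Source rank is the first TP rank of this DP replica, not global rank 0.
    src = dp_rank * tp_size
    dist.broadcast(sampled_token_ids, src, group=tp_group)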

View File

@@ -778,13 +778,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
logger.info(f"- Load strategy: {load_config.load_strategy}")
if (
args.speculative_config is not None
and ("method" in args.speculative_config)
and (args.speculative_config["method"] is not None)
):
logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not support speculative decoding now.")
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if args.splitwise_role != "mixed" and args.cache_transfer_protocol != "rdma":
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if not current_platform.is_cuda() and not current_platform.is_xpu():