Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-11-01 20:32:52 +08:00)
[Speculative Decoding][MTP] Support MTP in EP+DP+TP (epdptp) mode (#4614)
* Support multiple MTP features
* Support MTP reshard in RL mode
* Fix function
* Support MTP EP
* Support MTP in hybrid DP-TP mode
* Enable scheduler_v1 by default for MTP
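How MTP is switched on at launch is not part of this diff. As a rough sketch under assumptions: the speculative settings travel to the engine as a dict; the "method" key is the one the initialize_fd_config hunk below inspects, while the other keys and the example values here are assumptions rather than something this commit shows.

    # Hypothetical sketch; only the "method" key is confirmed by this commit's diff.
    speculative_config = {
        "method": "mtp",                        # read as args.speculative_config["method"] below
        "model": "/path/to/mtp_draft_weights",  # assumed key for the MTP draft weights
        "num_speculative_tokens": 1,            # assumed key for draft tokens per step
    }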
@@ -442,8 +442,7 @@ class EngineArgs:
                 raise NotImplementedError("Only CUDA platform supports logprob.")
         if self.speculative_config is not None and self.logprobs_mode.startswith("processed"):
             raise NotImplementedError("processed_logprobs not support in speculative.")
-        if self.speculative_config is not None:
-            envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
+
         if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma":
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
         if not current_platform.is_cuda() and not current_platform.is_xpu():
@@ -92,6 +92,8 @@ class AppendAttentionBackend(AttentionBackend):
         self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
             fd_config.model_config, "use_3d_rope", False
         )
+        if fd_config.speculative_config.model_type != "main":
+            self.rope_3d = False
         self.causal: bool = getattr(fd_config.model_config, "causal", True)
         self.speculative_method: str = fd_config.speculative_config.method
         self.use_speculate: bool = self.speculative_method is not None
@@ -364,7 +366,7 @@ class AppendAttentionBackend(AttentionBackend):
             getattr(layer, "cache_v_zp", None),
             layer.linear_shift,
             layer.linear_smooth,
-            forward_meta.attn_mask_offsets,
+            None if self.use_speculate else forward_meta.attn_mask_offsets,
             metadata.kv_signal_data_list[layer.layer_id],
             getattr(layer, "q_norm_weight", None),
             getattr(layer, "k_norm_weight", None),
@@ -383,7 +385,7 @@ class AppendAttentionBackend(AttentionBackend):
             metadata.max_partition_size,
             metadata.encoder_max_partition_size,
             self.speculate_max_draft_token_num + 1,
-            self.causal,
+            self.causal or self.use_speculate,
             self.speculative_method is not None,
             sliding_window,
         )
@@ -18,7 +18,7 @@ import paddle
 from paddle import nn
 from paddle.distributed import fleet
 
-from fastdeploy.model_executor.utils import set_weight_attrs
+from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
 
 from .utils import get_tensor
@@ -53,44 +53,61 @@ class ParallelEHProjection(nn.Layer):
             self.bias_key = prefix + ".bias"
         else:
             self.bias_key = None
         self.use_ep = fd_config.parallel_config.use_ep
         self.fd_config = fd_config
+        self.tp_group = fd_config.parallel_config.tp_group
         self.column_cut = True
         self.nranks = fd_config.parallel_config.tensor_parallel_size
         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
-        if self.use_ep:
-            self.weight = self.create_parameter(
-                shape=[embedding_dim, num_embeddings],
-                dtype=paddle.get_default_dtype(),
-                is_bias=False,
-            )
-        else:
-            if self.column_cut:
-                need_gather = True
-                self.linear = ColumnParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    gather_output=need_gather,
-                    fuse_matmul_bias=False,  # False diff更小
-                )
-                if self.nranks > 1:
-                    set_weight_attrs(self.linear.weight, {"output_dim": True})
-                    if self.bias_key is not None:
-                        set_weight_attrs(self.linear.bias, {"output_dim": True})
-            else:
-                self.linear = RowParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    input_is_parallel=False,
-                    fuse_matmul_bias=False,  # False diff更小
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": False})
+        if self.column_cut:
+            need_gather = True
+            self.linear = ColumnParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                gather_output=need_gather,
+                fuse_matmul_bias=False,  # False diff更小
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.bias_key is not None:
+                set_weight_attrs(
+                    self.linear.bias,
+                    {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
+                )
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
+                if self.bias_key is not None:
+                    set_weight_attrs(self.linear.bias, {"output_dim": True})
+        else:
+            self.linear = RowParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                input_is_parallel=False,
+                fuse_matmul_bias=False,  # False diff更小
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
+        set_weight_attrs(
+            self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
+        )
     def load_state_dict(self, state_dict):
         """
@@ -100,17 +117,14 @@ class ParallelEHProjection(nn.Layer):
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """
 
-        if self.use_ep:
-            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
-        else:
-            weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
-            if self.linear.weight.shape != weight_tensor.shape:
-                weight_tensor = weight_tensor.transpose([1, 0])
-            self.linear.weight.set_value(weight_tensor)
+        weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
+        if self.linear.weight.shape != weight_tensor.shape:
+            weight_tensor = weight_tensor.transpose([1, 0])
+        self.linear.weight.set_value(weight_tensor)
 
-            if self.bias_key is not None:
-                bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
-                self.linear.bias.set_value(bias)
+        if self.bias_key is not None:
+            bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
+            self.linear.bias.set_value(bias)
 
     def forward(self, input):
         """
@@ -123,8 +137,5 @@ class ParallelEHProjection(nn.Layer):
             Tensor: The output tensor after processing through the layer.
         """
         logits = input
-        if self.use_ep:
-            logits = paddle.matmul(logits, self.weight)
-        else:
-            logits = self.linear(logits)
+        logits = self.linear(logits)
         return logits
@@ -72,6 +72,11 @@ class DefaultModelLoader(BaseModelLoader):
                 # register rl model
                 import fastdeploy.rl  # noqa
 
+                if fd_config.speculative_config.model_type != "mtp":
+                    architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+                else:
+                    architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
                 architectures = architectures + "RL"
             context = paddle.LazyGuard()
         else:
@@ -65,6 +65,11 @@ class DefaultModelLoaderV1(BaseModelLoader):
             # register rl model
             import fastdeploy.rl  # noqa
 
+            if fd_config.speculative_config.model_type != "mtp":
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+            else:
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
             architectures = architectures + "RL"
 
         enable_cache, _, weight_cache_context = is_weight_cache_enabled(fd_config)
@@ -502,7 +502,7 @@ class TokenProcessor:
 
     def _compute_speculative_status(self):
         # TODO(liuzichang): Supplement more statistics
-        interval = 10
+        interval = 1
         if self.speculative_stats_step % interval == 0:
             accept_ratio = 1 - self.total_step * 1.0 / self.number_of_output_tokens
             spec_logger.info(
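On the accept ratio above: as used here, total_step counts target-model decode steps and number_of_output_tokens counts every token emitted, so the ratio is the share of output tokens that came from accepted draft tokens. A standalone illustration with made-up counts, not taken from this diff:

    # Hypothetical counts, purely to illustrate the formula logged above.
    total_step = 40                  # target-model forward steps taken
    number_of_output_tokens = 100    # tokens emitted to the user overall
    accept_ratio = 1 - total_step * 1.0 / number_of_output_tokens
    print(accept_ratio)              # 0.6, i.e. about 60% of tokens came from accepted drafts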
@@ -593,6 +593,9 @@ class TokenProcessor:
                         + accept_num[i]
                     ].tolist()
                     if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
+                        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                            if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
+                                self.resource_manager.reschedule_preempt_task(task_id)
                         continue
                 else:
                     token_id = int(tokens[i, 0])
@@ -17,11 +17,10 @@
 import os
 import time
 from multiprocessing.shared_memory import SharedMemory
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 import numpy as np
 import paddle
-from paddle import nn
 from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig
@@ -31,7 +30,7 @@ from fastdeploy.inter_communicator import ModelWeightsStatus
 class DynamicWeightManager:
     """Manages model weights loading, updating and shared state across processes."""
 
-    def __init__(self, fd_config: FDConfig, model: nn.Layer):
+    def __init__(self, fd_config: FDConfig, models):
         """Initialize with config and model instances."""
         self.fd_config = fd_config
         self.load_config = fd_config.load_config
@@ -42,7 +41,10 @@ class DynamicWeightManager:
         self.meta_src_id = self._get_gpu_id()
         self.first_load = True
         self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
-        self.model: nn.Layer = model
+        if not isinstance(models, List):
+            self.model_list = [models]
+        else:
+            self.model_list = models
         self._capture_model_state()
         self.update_parameters()
         self.finalize_update()
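The widened constructor keeps existing single-model callers working while letting MTP register both networks with one manager. A hedged usage sketch; the variable names are hypothetical, not from this commit:

    # Both forms are accepted after this change.
    manager = DynamicWeightManager(fd_config, target_model)                # single model, as before
    manager = DynamicWeightManager(fd_config, [target_model, mtp_model])   # e.g. target plus MTP draft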
@@ -55,9 +57,10 @@ class DynamicWeightManager:
     @paddle.no_grad()
     def _capture_model_state(self):
         """Capture and store initial model parameters state."""
-        for name, param in self.model.state_dict().items():
-            logger.debug(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
-            self.state_dict[name] = param
+        for model in self.model_list:
+            for name, param in model.state_dict().items():
+                logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
+                self.state_dict[name] = param
 
     def update_parameters(self, pid: int = 0) -> None:
         """Core method to update model parameters based on strategy."""
@@ -137,8 +140,9 @@ class DynamicWeightManager:
 
         paddle.device.cuda.empty_cache()
         # step2: release model weight
-        for param in self.model.state_dict().values():
-            param._clear_data()
+        for model in self.model_list:
+            for param in model.state_dict().values():
+                param._clear_data()
 
         self._verify_parameters("clearance")
 
@@ -38,13 +38,20 @@ class Proposer(ABC):
         Init Speculative proposer
         """
         fd_config.parallel_config.tp_group = None
+        fd_config.parallel_config.ep_group = None
         self.fd_config = deepcopy(fd_config)
         fd_config.parallel_config.tp_group = dist.get_group(
             fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
+        fd_config.parallel_config.ep_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        )
         self.fd_config.parallel_config.tp_group = dist.get_group(
             fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
+        self.fd_config.parallel_config.ep_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        )
         self.parallel_config = self.fd_config.parallel_config
         self.model_config = self.fd_config.model_config
         self.speculative_config = self.fd_config.speculative_config
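The None-then-restore sequence above exists because the live tp_group/ep_group handles should not go through deepcopy: they are detached from fd_config, the config is copied, and the groups are then re-resolved on both the original and the copy. A minimal generic sketch of the same pattern, with placeholder names:

    from copy import deepcopy
    import paddle.distributed as dist
    # Strip non-copyable handles, deepcopy, then re-attach them to both objects.
    cfg.parallel_config.tp_group = None
    cfg.parallel_config.ep_group = None
    cfg_copy = deepcopy(cfg)
    for c in (cfg, cfg_copy):
        c.parallel_config.tp_group = dist.get_group(tp_gid)   # tp_gid / ep_gid are placeholder group ids
        c.parallel_config.ep_group = dist.get_group(ep_gid)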
@@ -96,7 +96,8 @@ class MTPProposer(Proposer):
         """
         Update config for MTP from global config
         """
-        self.model_config.architectures[0] = "Ernie4_5_MTPForCausalLM"
+        self.forward_meta: ForwardMeta = None
+        self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP")
         self.speculative_config.sharing_model = main_model
         self.model_config.num_hidden_layers = 1
         self.model_config.model = self.speculative_config.model
@@ -169,6 +170,9 @@ class MTPProposer(Proposer):
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
             max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
+        if kv_cache_quant_type == "block_wise_fp8":
+            kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
+        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
         if not profile and (
             self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
         ):
@@ -178,8 +182,8 @@ class MTPProposer(Proposer):
                 self.num_main_model_layers + self.model_config.num_hidden_layers,
             ):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
-                key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
-                val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
+                key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
+                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
                 key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
                 cache_kvs_list.append(key_cache)
                 value_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -199,6 +203,17 @@ class MTPProposer(Proposer):
                     fill_value=0,
                     dtype=cache_type,
                 )
+                if kv_cache_quant_type == "block_wise_fp8":
+                    self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
+                    self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
             self.model_inputs["caches"] = list(self.cache_kvs.values())
             for value in self.cache_kvs.values():
                 del value
@@ -430,11 +445,10 @@ class MTPProposer(Proposer):
         if "caches" not in self.model_inputs:
             self.initialize_kv_cache()
         req_len = len(req_dicts)
-        # has_prefill_task = False
-        # has_decode_task = False
+
         for i in range(req_len):
             request = req_dicts[i]
-            logger.info(f"{i}th request-{request.request_id}: {request}")
+            logger.debug(f"{i}th request-{request.request_id}: {request}")
             idx = request.idx
             if request.task_type.value == RequestType.PREFILL.value:  # prefill task
                 prefill_start_index = request.prefill_start_index
@@ -688,7 +702,7 @@ class MTPProposer(Proposer):
                 self.max_model_len,
                 self.model_inputs["substep"],
             )
-        if self.role == "prefill":
+        if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
             mtp_save_first_token(
                 self.model_inputs["base_model_draft_tokens"],
                 self.model_inputs["not_need_stop"],
|
||||
)
|
||||
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
paddle.distributed.broadcast(sampled_token_ids, 0)
|
||||
paddle.distributed.broadcast(
|
||||
sampled_token_ids,
|
||||
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
|
||||
group=self.parallel_config.tp_group,
|
||||
)
|
||||
|
||||
self._post_process(sampled_token_ids)
|
||||
if substep != self.num_model_steps - 1:
|
||||
self._get_self_hidden_states(hidden_states)
|
||||
else:
|
||||
if hasattr(self.model, "empty_input_forward"):
|
||||
self.model.empty_input_forward()
|
||||
|
||||
def _get_self_hidden_states(self, hidden_states):
|
||||
target_hidden_states = eagle_get_self_hidden_states(
|
||||
|
||||
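The new broadcast source follows from the rank layout in hybrid DP+TP serving: global rank 0 belongs only to the first data-parallel replica, so each replica must broadcast from the first rank of its own tensor-parallel group. An illustration with hypothetical sizes, assuming the usual contiguous rank layout per replica:

    # Sizes are hypothetical; only the root computation mirrors the hunk above.
    tensor_parallel_size = 4
    data_parallel_rank = 1                               # this replica holds global ranks 4..7
    root = data_parallel_rank * tensor_parallel_size     # 4, the first rank of this replica's TP group
    assert root == 4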
@@ -778,13 +778,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
     logger.info(f"- Load strategy: {load_config.load_strategy}")
 
-    if (
-        args.speculative_config is not None
-        and ("method" in args.speculative_config)
-        and (args.speculative_config["method"] is not None)
-    ):
-        logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not support speculative decoding now.")
-        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
     if args.splitwise_role != "mixed" and args.cache_transfer_protocol != "rdma":
         envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
     if not current_platform.is_cuda() and not current_platform.is_xpu():