Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-11-01 20:32:52 +08:00)
[Speculative Decoding][MTP] Support MTP in EP+DP+TP mode (#4614)

* Support multiple MTP features
* Support MTP resharding in RL mode
* Fix functions
* Support MTP with expert parallelism (EP)
* Support MTP in hybrid DP+TP mode
* Enable scheduler_v1 by default for MTP
@@ -442,8 +442,7 @@ class EngineArgs:
             raise NotImplementedError("Only CUDA platform supports logprob.")
         if self.speculative_config is not None and self.logprobs_mode.startswith("processed"):
             raise NotImplementedError("processed_logprobs not support in speculative.")
-        if self.speculative_config is not None:
-            envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
         if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma":
            envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
         if not current_platform.is_cuda() and not current_platform.is_xpu():
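Note: together with the initialize_fd_config hunk at the end of this diff, this removes the blanket fallback to the v0 KV-cache scheduler whenever speculative decoding is configured ("Enable scheduler_v1 by default for MTP" in the commit message). A minimal plain-Python sketch of the gating that remains (illustrative only, not FastDeploy code):

    # Illustrative only: after this hunk, a configured speculative_config alone no
    # longer forces ENABLE_V1_KVCACHE_SCHEDULER back to 0; the splitwise/non-RDMA
    # fallback shown above is kept.
    def forces_v0_scheduler(splitwise_role: str, cache_transfer_protocol: str) -> bool:
        return splitwise_role != "mixed" and cache_transfer_protocol != "rdma"

    assert forces_v0_scheduler("prefill", "ipc") is True
    assert forces_v0_scheduler("mixed", "ipc") is False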
@@ -92,6 +92,8 @@ class AppendAttentionBackend(AttentionBackend):
         self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
             fd_config.model_config, "use_3d_rope", False
         )
+        if fd_config.speculative_config.model_type != "main":
+            self.rope_3d = False
         self.causal: bool = getattr(fd_config.model_config, "causal", True)
         self.speculative_method: str = fd_config.speculative_config.method
         self.use_speculate: bool = self.speculative_method is not None
@@ -364,7 +366,7 @@ class AppendAttentionBackend(AttentionBackend):
                 getattr(layer, "cache_v_zp", None),
                 layer.linear_shift,
                 layer.linear_smooth,
-                forward_meta.attn_mask_offsets,
+                None if self.use_speculate else forward_meta.attn_mask_offsets,
                 metadata.kv_signal_data_list[layer.layer_id],
                 getattr(layer, "q_norm_weight", None),
                 getattr(layer, "k_norm_weight", None),
@@ -383,7 +385,7 @@ class AppendAttentionBackend(AttentionBackend):
                 metadata.max_partition_size,
                 metadata.encoder_max_partition_size,
                 self.speculate_max_draft_token_num + 1,
-                self.causal,
+                self.causal or self.use_speculate,
                 self.speculative_method is not None,
                 sliding_window,
             )
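Note: the three attention-backend hunks make the draft-model path explicit: 3D RoPE is switched off for any model whose speculative model_type is not "main", the attention-mask offsets are not passed while speculating, and the kernel is forced causal. A hedged sketch of that flag derivation (variable names are illustrative):

    # Illustrative flag derivation mirroring the hunks above; not FastDeploy code.
    def speculative_attention_flags(model_type, base_rope_3d, speculative_method, base_causal, attn_mask_offsets):
        rope_3d = base_rope_3d and model_type == "main"              # draft/MTP models disable 3D RoPE
        use_speculate = speculative_method is not None
        mask_offsets = None if use_speculate else attn_mask_offsets  # offsets unused when speculating
        causal = base_causal or use_speculate                        # speculation always runs causal
        return rope_3d, mask_offsets, causal

    print(speculative_attention_flags("mtp", True, "mtp", False, [0, 3, 7]))
    # -> (False, None, True)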
@@ -18,7 +18,7 @@ import paddle
 from paddle import nn
 from paddle.distributed import fleet
 
-from fastdeploy.model_executor.utils import set_weight_attrs
+from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
 
 from .utils import get_tensor
 
@@ -53,44 +53,61 @@ class ParallelEHProjection(nn.Layer):
             self.bias_key = prefix + ".bias"
         else:
             self.bias_key = None
-        self.use_ep = fd_config.parallel_config.use_ep
+        self.fd_config = fd_config
+        self.tp_group = fd_config.parallel_config.tp_group
         self.column_cut = True
+        self.nranks = fd_config.parallel_config.tensor_parallel_size
 
         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
 
-        if self.use_ep:
-            self.weight = self.create_parameter(
-                shape=[embedding_dim, num_embeddings],
-                dtype=paddle.get_default_dtype(),
-                is_bias=False,
-            )
-        else:
-            if self.column_cut:
-                need_gather = True
-                self.linear = ColumnParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    gather_output=need_gather,
-                    fuse_matmul_bias=False,  # False keeps the numerical diff smaller
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": True})
-                if self.bias_key is not None:
-                    set_weight_attrs(self.linear.bias, {"output_dim": True})
-            else:
-                self.linear = RowParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    input_is_parallel=False,
-                    fuse_matmul_bias=False,  # False keeps the numerical diff smaller
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": False})
+        if self.column_cut:
+            need_gather = True
+            self.linear = ColumnParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                gather_output=need_gather,
+                fuse_matmul_bias=False,  # False keeps the numerical diff smaller
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.bias_key is not None:
+                set_weight_attrs(
+                    self.linear.bias,
+                    {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
+                )
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
+        else:
+            self.linear = RowParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                input_is_parallel=False,
+                fuse_matmul_bias=False,  # False keeps the numerical diff smaller
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
+        set_weight_attrs(
+            self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
+        )
 
     def load_state_dict(self, state_dict):
         """
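Note: the constructor no longer special-cases expert parallelism; both branches build a fleet parallel linear on the config's tp_group and attach loader metadata via set_weight_attrs. The sketch below only shows the shape of the attribute dicts involved; the semantics of "output_dim" and "rl_tp_degree" (sharded-axis marker and TP degree for RL-mode resharding) are inferred, not confirmed by this diff:

    # Illustration of the attribute dicts attached to the projection weight (plain Python).
    tensor_parallel_size = 8  # example value

    loader_attrs = {
        "weight_loader": "default_weight_loader(fd_config)",  # placeholder string, not the real callable
        "model_format": "auto",                                # example model_format value
    }
    shard_attrs = {"output_dim": True} if tensor_parallel_size > 1 else {}
    rl_attrs = {"rl_need_attr": {"rl_tp_degree": tensor_parallel_size}}
    print({**loader_attrs, **shard_attrs, **rl_attrs})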
@@ -100,17 +117,14 @@ class ParallelEHProjection(nn.Layer):
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """
 
-        if self.use_ep:
-            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
-        else:
-            weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
-            if self.linear.weight.shape != weight_tensor.shape:
-                weight_tensor = weight_tensor.transpose([1, 0])
-            self.linear.weight.set_value(weight_tensor)
+        weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
+        if self.linear.weight.shape != weight_tensor.shape:
+            weight_tensor = weight_tensor.transpose([1, 0])
+        self.linear.weight.set_value(weight_tensor)
 
         if self.bias_key is not None:
             bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
             self.linear.bias.set_value(bias)
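Note: load_state_dict now has a single path: pop the checkpoint weight, transpose it when its layout does not match the parallel linear, and set it. A small numpy sketch of that shape adaptation (illustrative, not the FastDeploy helper):

    import numpy as np

    def adapt_projection_weight(ckpt_weight: np.ndarray, expected_shape) -> np.ndarray:
        # Checkpoints may store the EH projection in the transposed layout;
        # flip the two axes when the shapes disagree, as in the hunk above.
        if list(ckpt_weight.shape) != list(expected_shape):
            ckpt_weight = ckpt_weight.transpose([1, 0])
        return ckpt_weight

    w = np.zeros((1024, 512), dtype="float32")
    assert adapt_projection_weight(w, [512, 1024]).shape == (512, 1024)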
@@ -123,8 +137,5 @@ class ParallelEHProjection(nn.Layer):
             Tensor: The output tensor after processing through the layer.
         """
         logits = input
-        if self.use_ep:
-            logits = paddle.matmul(logits, self.weight)
-        else:
-            logits = self.linear(logits)
+        logits = self.linear(logits)
         return logits
@@ -72,6 +72,11 @@ class DefaultModelLoader(BaseModelLoader):
             # register rl model
             import fastdeploy.rl  # noqa
 
+            if fd_config.speculative_config.model_type != "mtp":
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+            else:
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
             architectures = architectures + "RL"
             context = paddle.LazyGuard()
         else:
@@ -65,6 +65,11 @@ class DefaultModelLoaderV1(BaseModelLoader):
             # register rl model
             import fastdeploy.rl  # noqa
 
+            if fd_config.speculative_config.model_type != "mtp":
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+            else:
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
             architectures = architectures + "RL"
 
         enable_cache, _, weight_cache_context = is_weight_cache_enabled(fd_config)
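Note: both loaders (DefaultModelLoader and DefaultModelLoaderV1) now pick the RL architecture from the speculative model type: the main model maps Ernie5ForCausalLM to the MoE variant, the MTP draft maps it to the MTP variant, and both get the "RL" suffix. A runnable sketch of that string rewrite:

    def rl_architecture(architectures: str, speculative_model_type: str) -> str:
        # Mirrors the two loader hunks above: choose Moe vs MTP, then append the RL suffix.
        if speculative_model_type != "mtp":
            architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
        else:
            architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
        return architectures + "RL"

    assert rl_architecture("Ernie5ForCausalLM", "main") == "Ernie5MoeForCausalLMRL"
    assert rl_architecture("Ernie5ForCausalLM", "mtp") == "Ernie5MTPForCausalLMRL"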
@@ -502,7 +502,7 @@ class TokenProcessor:
 
     def _compute_speculative_status(self):
         # TODO(liuzichang): Supplement more statistics
-        interval = 10
+        interval = 1
         if self.speculative_stats_step % interval == 0:
             accept_ratio = 1 - self.total_step * 1.0 / self.number_of_output_tokens
             spec_logger.info(
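Note: with interval reduced from 10 to 1, the acceptance ratio is logged on every call to _compute_speculative_status. The ratio itself is unchanged: each target-model step emits at least one token, so any output beyond one token per step came from accepted draft tokens.

    def speculative_accept_ratio(total_step: int, number_of_output_tokens: int) -> float:
        # Same formula as the hunk above: the fraction of output tokens that did not
        # require their own target-model step, i.e. accepted draft tokens.
        return 1 - total_step * 1.0 / number_of_output_tokens

    assert abs(speculative_accept_ratio(100, 250) - 0.6) < 1e-9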
@@ -593,6 +593,9 @@ class TokenProcessor:
                         + accept_num[i]
                     ].tolist()
                     if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
+                        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                            if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
+                                self.resource_manager.reschedule_preempt_task(task_id)
                         continue
                 else:
                     token_id = int(tokens[i, 0])
@@ -17,11 +17,10 @@
 import os
 import time
 from multiprocessing.shared_memory import SharedMemory
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 import numpy as np
 import paddle
-from paddle import nn
 from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig
@@ -31,7 +30,7 @@ from fastdeploy.inter_communicator import ModelWeightsStatus
 class DynamicWeightManager:
     """Manages model weights loading, updating and shared state across processes."""
 
-    def __init__(self, fd_config: FDConfig, model: nn.Layer):
+    def __init__(self, fd_config: FDConfig, models):
         """Initialize with config and model instances."""
         self.fd_config = fd_config
         self.load_config = fd_config.load_config
@@ -42,7 +41,10 @@ class DynamicWeightManager:
         self.meta_src_id = self._get_gpu_id()
         self.first_load = True
         self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
-        self.model: nn.Layer = model
+        if not isinstance(models, List):
+            self.model_list = [models]
+        else:
+            self.model_list = models
         self._capture_model_state()
         self.update_parameters()
         self.finalize_update()
@@ -55,9 +57,10 @@ class DynamicWeightManager:
     @paddle.no_grad()
     def _capture_model_state(self):
         """Capture and store initial model parameters state."""
-        for name, param in self.model.state_dict().items():
-            logger.debug(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
-            self.state_dict[name] = param
+        for model in self.model_list:
+            for name, param in model.state_dict().items():
+                logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
+                self.state_dict[name] = param
 
     def update_parameters(self, pid: int = 0) -> None:
         """Core method to update model parameters based on strategy."""
@@ -137,8 +140,9 @@ class DynamicWeightManager:
 
         paddle.device.cuda.empty_cache()
         # step2: release model weight
-        for param in self.model.state_dict().values():
-            param._clear_data()
+        for model in self.model_list:
+            for param in model.state_dict().values():
+                param._clear_data()
 
         self._verify_parameters("clearance")
 
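Note: DynamicWeightManager now accepts either a single model or a list (for example the main model plus the MTP draft), normalizes it to model_list, and walks that list when capturing and clearing parameters. A minimal sketch of the normalization and merged capture, using plain-Python stand-ins for nn.Layer:

    # Illustrative only: normalize one model or a list of models, then merge their
    # parameter dicts the way _capture_model_state does after this change.
    def normalize_models(models):
        return models if isinstance(models, list) else [models]

    class _FakeModel:
        def __init__(self, params):
            self._params = params
        def state_dict(self):
            return self._params

    main = _FakeModel({"lm_head.weight": 1})
    mtp = _FakeModel({"mtp.eh_proj.weight": 2})

    state_dict = {}
    for model in normalize_models([main, mtp]):
        for name, param in model.state_dict().items():
            state_dict[name] = param

    assert set(state_dict) == {"lm_head.weight", "mtp.eh_proj.weight"}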
@@ -38,13 +38,20 @@ class Proposer(ABC):
         Init Speculative proposer
         """
         fd_config.parallel_config.tp_group = None
+        fd_config.parallel_config.ep_group = None
         self.fd_config = deepcopy(fd_config)
         fd_config.parallel_config.tp_group = dist.get_group(
             fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
+        fd_config.parallel_config.ep_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        )
         self.fd_config.parallel_config.tp_group = dist.get_group(
             fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
+        self.fd_config.parallel_config.ep_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        )
         self.parallel_config = self.fd_config.parallel_config
         self.model_config = self.fd_config.model_config
         self.speculative_config = self.fd_config.speculative_config
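Note: process-group handles cannot survive deepcopy, so the base Proposer now nulls out both tp_group and ep_group before copying the config, then re-resolves the groups on the original and the copy (previously only tp_group was handled). A sketch of that pattern with dist.get_group replaced by a stub:

    from copy import deepcopy

    class _Cfg:
        def __init__(self):
            self.tp_group = object()   # stand-in for a non-copyable process-group handle
            self.ep_group = object()
            self.data_parallel_rank = 0
            self.data_parallel_size = 2

    def _get_group(gid):               # stub for dist.get_group(...)
        return f"group-{gid}"

    cfg = _Cfg()
    cfg.tp_group, cfg.ep_group = None, None      # drop handles before copying
    cfg_copy = deepcopy(cfg)
    for c in (cfg, cfg_copy):                    # re-resolve groups on both objects
        c.tp_group = _get_group(c.data_parallel_rank)
        c.ep_group = _get_group(c.data_parallel_size)

    assert cfg_copy.tp_group == "group-0" and cfg_copy.ep_group == "group-2"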
@@ -96,7 +96,8 @@ class MTPProposer(Proposer):
         """
         Update config for MTP from global config
         """
-        self.model_config.architectures[0] = "Ernie4_5_MTPForCausalLM"
+        self.forward_meta: ForwardMeta = None
+        self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP")
         self.speculative_config.sharing_model = main_model
         self.model_config.num_hidden_layers = 1
         self.model_config.model = self.speculative_config.model
@@ -169,6 +170,9 @@ class MTPProposer(Proposer):
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
             max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
+        if kv_cache_quant_type == "block_wise_fp8":
+            kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
+        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
         if not profile and (
             self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
         ):
@@ -178,8 +182,8 @@ class MTPProposer(Proposer):
                 self.num_main_model_layers + self.model_config.num_hidden_layers,
             ):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
-                key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
-                val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
+                key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
+                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
                 key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
                 cache_kvs_list.append(key_cache)
                 value_cache = paddle.empty(shape=[], dtype=cache_type)
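Note: shared KV-cache segments are now named with the rank inside the tensor-parallel group rather than the global rank, so an MTP proposer running under data parallelism attaches to the segment created by its own TP group. A tiny sketch of the derivation:

    def kv_cache_name(layer_id: int, global_rank: int, tensor_parallel_size: int, device_id: int) -> str:
        local_rank = global_rank % tensor_parallel_size   # rank within the TP group
        return f"key_caches_{layer_id}_rank{local_rank}.device{device_id}"

    # Global rank 5 with TP size 4 maps to local rank 1.
    assert kv_cache_name(0, 5, 4, 5) == "key_caches_0_rank1.device5"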
@@ -199,6 +203,17 @@ class MTPProposer(Proposer):
                     fill_value=0,
                     dtype=cache_type,
                 )
+                if kv_cache_quant_type == "block_wise_fp8":
+                    self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
+                    self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
             self.model_inputs["caches"] = list(self.cache_kvs.values())
             for value in self.cache_kvs.values():
                 del value
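Note: with block-wise FP8 KV cache, each layer's key and value caches get companion scale tensors allocated in the default dtype; their shape keeps the first three dimensions of the cache shape. The sketch below only illustrates that shape bookkeeping; the meaning of the individual axes is an assumption:

    def kv_cache_scale_shape(kv_cache_shape):
        # Per the hunk above: the scale tensor keeps the leading three cache dims
        # and drops the last axis, which is quantized block-wise.
        return [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]

    # Example cache shape, assumed to be [num_blocks, kv_heads, block_size, head_dim].
    assert kv_cache_scale_shape([1000, 8, 64, 128]) == [1000, 8, 64]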
@@ -430,11 +445,10 @@ class MTPProposer(Proposer):
         if "caches" not in self.model_inputs:
             self.initialize_kv_cache()
         req_len = len(req_dicts)
-        # has_prefill_task = False
-        # has_decode_task = False
         for i in range(req_len):
             request = req_dicts[i]
-            logger.info(f"{i}th request-{request.request_id}: {request}")
+            logger.debug(f"{i}th request-{request.request_id}: {request}")
             idx = request.idx
             if request.task_type.value == RequestType.PREFILL.value:  # prefill task
                 prefill_start_index = request.prefill_start_index
@@ -688,7 +702,7 @@ class MTPProposer(Proposer):
             self.max_model_len,
             self.model_inputs["substep"],
         )
-        if self.role == "prefill":
+        if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
             mtp_save_first_token(
                 self.model_inputs["base_model_draft_tokens"],
                 self.model_inputs["not_need_stop"],
@@ -820,11 +834,18 @@ class MTPProposer(Proposer):
             )
 
             if self.parallel_config.tensor_parallel_size > 1:
-                paddle.distributed.broadcast(sampled_token_ids, 0)
+                paddle.distributed.broadcast(
+                    sampled_token_ids,
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
 
             self._post_process(sampled_token_ids)
             if substep != self.num_model_steps - 1:
                 self._get_self_hidden_states(hidden_states)
+            else:
+                if hasattr(self.model, "empty_input_forward"):
+                    self.model.empty_input_forward()
 
     def _get_self_hidden_states(self, hidden_states):
         target_hidden_states = eagle_get_self_hidden_states(
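Note: the draft-token broadcast now names its root explicitly: TP rank 0 of this data-parallel replica, whose global rank is data_parallel_rank * tensor_parallel_size, and the collective is restricted to the replica's tp_group instead of the default group. (The new else branch running empty_input_forward on the last substep presumably keeps ranks' collectives aligned; that intent is inferred.) A quick sketch of the root-rank arithmetic:

    def tp_broadcast_src(data_parallel_rank: int, tensor_parallel_size: int) -> int:
        # Global rank of the first (TP-rank-0) worker in this DP replica's TP group.
        return data_parallel_rank * tensor_parallel_size

    # With TP=4: DP replica 0 broadcasts from global rank 0, replica 1 from global rank 4.
    assert tp_broadcast_src(0, 4) == 0
    assert tp_broadcast_src(1, 4) == 4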
@@ -778,13 +778,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
     logger.info(f"- Load strategy: {load_config.load_strategy}")
 
-    if (
-        args.speculative_config is not None
-        and ("method" in args.speculative_config)
-        and (args.speculative_config["method"] is not None)
-    ):
-        logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not support speculative decoding now.")
-        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
     if args.splitwise_role != "mixed" and args.cache_transfer_protocol != "rdma":
         envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
     if not current_platform.is_cuda() and not current_platform.is_xpu():