diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index d6b31050d..6fcbeac1e 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -442,8 +442,7 @@ class EngineArgs:
                 raise NotImplementedError("Only CUDA platform supports logprob.")
             if self.speculative_config is not None and self.logprobs_mode.startswith("processed"):
                 raise NotImplementedError("processed_logprobs not support in speculative.")
-        if self.speculative_config is not None:
-            envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
+
         if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma":
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
         if not current_platform.is_cuda() and not current_platform.is_xpu():
diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
index 3e178c92d..813fb4790 100644
--- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
@@ -92,6 +92,8 @@ class AppendAttentionBackend(AttentionBackend):
         self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
             fd_config.model_config, "use_3d_rope", False
         )
+        if fd_config.speculative_config.model_type != "main":
+            self.rope_3d = False
         self.causal: bool = getattr(fd_config.model_config, "causal", True)
         self.speculative_method: str = fd_config.speculative_config.method
         self.use_speculate: bool = self.speculative_method is not None
@@ -364,7 +366,7 @@
                 getattr(layer, "cache_v_zp", None),
                 layer.linear_shift,
                 layer.linear_smooth,
-                forward_meta.attn_mask_offsets,
+                None if self.use_speculate else forward_meta.attn_mask_offsets,
                 metadata.kv_signal_data_list[layer.layer_id],
                 getattr(layer, "q_norm_weight", None),
                 getattr(layer, "k_norm_weight", None),
@@ -383,7 +385,7 @@
                 metadata.max_partition_size,
                 metadata.encoder_max_partition_size,
                 self.speculate_max_draft_token_num + 1,
-                self.causal,
+                self.causal or self.use_speculate,
                 self.speculative_method is not None,
                 sliding_window,
             )
diff --git a/fastdeploy/model_executor/layers/mtp_linear.py b/fastdeploy/model_executor/layers/mtp_linear.py
index cd9e48a72..4250b611f 100644
--- a/fastdeploy/model_executor/layers/mtp_linear.py
+++ b/fastdeploy/model_executor/layers/mtp_linear.py
@@ -18,7 +18,7 @@
 import paddle
 from paddle import nn
 from paddle.distributed import fleet
 
-from fastdeploy.model_executor.utils import set_weight_attrs
+from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
 
 from .utils import get_tensor
@@ -53,44 +53,61 @@ class ParallelEHProjection(nn.Layer):
             self.bias_key = prefix + ".bias"
         else:
             self.bias_key = None
-        self.use_ep = fd_config.parallel_config.use_ep
+        self.fd_config = fd_config
+        self.tp_group = fd_config.parallel_config.tp_group
         self.column_cut = True
+        self.nranks = fd_config.parallel_config.tensor_parallel_size
 
         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
 
-        if self.use_ep:
-            self.weight = self.create_parameter(
-                shape=[embedding_dim, num_embeddings],
-                dtype=paddle.get_default_dtype(),
-                is_bias=False,
+        if self.column_cut:
+            need_gather = True
+            self.linear = ColumnParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                gather_output=need_gather,
+                fuse_matmul_bias=False,  # False diff更小
             )
-        else:
-            if self.column_cut:
-                need_gather = True
-                self.linear = ColumnParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    gather_output=need_gather,
-                    fuse_matmul_bias=False,  # False diff更小
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.bias_key is not None:
+                set_weight_attrs(
+                    self.linear.bias,
+                    {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
                 )
+            if self.nranks > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": True})
-                if self.bias_key is not None:
-                    set_weight_attrs(self.linear.bias, {"output_dim": True})
-            else:
-                self.linear = RowParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    input_is_parallel=False,
-                    fuse_matmul_bias=False,  # False diff更小
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": False})
+        else:
+            self.linear = RowParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                input_is_parallel=False,
+                fuse_matmul_bias=False,  # False diff更小
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
+        set_weight_attrs(
+            self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
+        )
 
     def load_state_dict(self, state_dict):
         """
@@ -100,17 +117,14 @@ class ParallelEHProjection(nn.Layer):
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """
 
-        if self.use_ep:
-            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
-        else:
-            weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
-            if self.linear.weight.shape != weight_tensor.shape:
-                weight_tensor = weight_tensor.transpose([1, 0])
-            self.linear.weight.set_value(weight_tensor)
+        weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
+        if self.linear.weight.shape != weight_tensor.shape:
+            weight_tensor = weight_tensor.transpose([1, 0])
+        self.linear.weight.set_value(weight_tensor)
 
-            if self.bias_key is not None:
-                bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
-                self.linear.bias.set_value(bias)
+        if self.bias_key is not None:
+            bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
+            self.linear.bias.set_value(bias)
 
     def forward(self, input):
         """
@@ -123,8 +137,5 @@ class ParallelEHProjection(nn.Layer):
             Tensor: The output tensor after processing through the layer.
         """
         logits = input
-        if self.use_ep:
-            logits = paddle.matmul(logits, self.weight)
-        else:
-            logits = self.linear(logits)
+        logits = self.linear(logits)
         return logits
diff --git a/fastdeploy/model_executor/model_loader/default_loader.py b/fastdeploy/model_executor/model_loader/default_loader.py
index dc6bedc91..ca0dfa84f 100644
--- a/fastdeploy/model_executor/model_loader/default_loader.py
+++ b/fastdeploy/model_executor/model_loader/default_loader.py
@@ -72,6 +72,11 @@ class DefaultModelLoader(BaseModelLoader):
             # register rl model
             import fastdeploy.rl  # noqa
 
+            if fd_config.speculative_config.model_type != "mtp":
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+            else:
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
             architectures = architectures + "RL"
             context = paddle.LazyGuard()
         else:
diff --git a/fastdeploy/model_executor/model_loader/default_loader_v1.py b/fastdeploy/model_executor/model_loader/default_loader_v1.py
index 83f13382f..0193b259d 100644
--- a/fastdeploy/model_executor/model_loader/default_loader_v1.py
+++ b/fastdeploy/model_executor/model_loader/default_loader_v1.py
@@ -65,6 +65,11 @@ class DefaultModelLoaderV1(BaseModelLoader):
             # register rl model
             import fastdeploy.rl  # noqa
 
+            if fd_config.speculative_config.model_type != "mtp":
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+            else:
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
             architectures = architectures + "RL"
 
         enable_cache, _, weight_cache_context = is_weight_cache_enabled(fd_config)
diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py
index f60967ca6..694dd8fe5 100644
--- a/fastdeploy/output/token_processor.py
+++ b/fastdeploy/output/token_processor.py
@@ -502,7 +502,7 @@ class TokenProcessor:
 
     def _compute_speculative_status(self):
         # TODO(liuzichang): Supplement more statistics
-        interval = 10
+        interval = 1
         if self.speculative_stats_step % interval == 0:
             accept_ratio = 1 - self.total_step * 1.0 / self.number_of_output_tokens
             spec_logger.info(
@@ -593,6 +593,9 @@ class TokenProcessor:
                         + accept_num[i]
                     ].tolist()
                     if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
+                        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                            if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
+                                self.resource_manager.reschedule_preempt_task(task_id)
                         continue
                 else:
                     token_id = int(tokens[i, 0])
diff --git a/fastdeploy/rl/dynamic_weight_manager.py b/fastdeploy/rl/dynamic_weight_manager.py
index 4605c879a..dba4d0f27 100644
--- a/fastdeploy/rl/dynamic_weight_manager.py
+++ b/fastdeploy/rl/dynamic_weight_manager.py
@@ -17,11 +17,10 @@
 import os
 import time
 from multiprocessing.shared_memory import SharedMemory
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 import numpy as np
 import paddle
-from paddle import nn
 from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig
@@ -31,7 +30,7 @@ from fastdeploy.inter_communicator import ModelWeightsStatus
 class DynamicWeightManager:
     """Manages model weights loading, updating and shared state across processes."""
 
-    def __init__(self, fd_config: FDConfig, model: nn.Layer):
+    def __init__(self, fd_config: FDConfig, models):
         """Initialize with config and model instances."""
         self.fd_config = fd_config
         self.load_config = fd_config.load_config
@@ -42,7 +41,10 @@ class DynamicWeightManager:
         self.meta_src_id = self._get_gpu_id()
         self.first_load = True
         self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
-        self.model: nn.Layer = model
+        if not isinstance(models, List):
+            self.model_list = [models]
+        else:
+            self.model_list = models
         self._capture_model_state()
         self.update_parameters()
         self.finalize_update()
@@ -55,9 +57,10 @@
     @paddle.no_grad()
     def _capture_model_state(self):
         """Capture and store initial model parameters state."""
-        for name, param in self.model.state_dict().items():
-            logger.debug(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
-            self.state_dict[name] = param
+        for model in self.model_list:
+            for name, param in model.state_dict().items():
+                logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
+                self.state_dict[name] = param
 
     def update_parameters(self, pid: int = 0) -> None:
         """Core method to update model parameters based on strategy."""
@@ -137,8 +140,9 @@
             paddle.device.cuda.empty_cache()
 
         # step2: release model weight
-        for param in self.model.state_dict().values():
-            param._clear_data()
+        for model in self.model_list:
+            for param in model.state_dict().values():
+                param._clear_data()
 
         self._verify_parameters("clearance")
 
diff --git a/fastdeploy/spec_decode/base.py b/fastdeploy/spec_decode/base.py
index 458f8e579..7438a0dbe 100644
--- a/fastdeploy/spec_decode/base.py
+++ b/fastdeploy/spec_decode/base.py
@@ -38,13 +38,20 @@ class Proposer(ABC):
         Init Speculative proposer
         """
         fd_config.parallel_config.tp_group = None
+        fd_config.parallel_config.ep_group = None
         self.fd_config = deepcopy(fd_config)
         fd_config.parallel_config.tp_group = dist.get_group(
             fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
+        fd_config.parallel_config.ep_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        )
         self.fd_config.parallel_config.tp_group = dist.get_group(
             fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
+        self.fd_config.parallel_config.ep_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        )
         self.parallel_config = self.fd_config.parallel_config
         self.model_config = self.fd_config.model_config
         self.speculative_config = self.fd_config.speculative_config
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index ae496d710..70127c8eb 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -96,7 +96,8 @@ class MTPProposer(Proposer):
         """
         Update config for MTP from global config
         """
-        self.model_config.architectures[0] = "Ernie4_5_MTPForCausalLM"
+        self.forward_meta: ForwardMeta = None
+        self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP")
         self.speculative_config.sharing_model = main_model
         self.model_config.num_hidden_layers = 1
         self.model_config.model = self.speculative_config.model
@@ -169,6 +170,9 @@ class MTPProposer(Proposer):
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
             max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
+        if kv_cache_quant_type == "block_wise_fp8":
+            kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
+        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
         if not profile and (
             self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
         ):
@@ -178,8 +182,8 @@ class MTPProposer(Proposer):
                 self.num_main_model_layers + self.model_config.num_hidden_layers,
             ):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
-                key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
-                val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
+                key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
+                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
                 key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
                 cache_kvs_list.append(key_cache)
                 value_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -199,6 +203,17 @@ class MTPProposer(Proposer):
                     fill_value=0,
                     dtype=cache_type,
                 )
+                if kv_cache_quant_type == "block_wise_fp8":
+                    self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
+                    self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
         self.model_inputs["caches"] = list(self.cache_kvs.values())
         for value in self.cache_kvs.values():
             del value
@@ -430,11 +445,10 @@ class MTPProposer(Proposer):
         if "caches" not in self.model_inputs:
            self.initialize_kv_cache()
         req_len = len(req_dicts)
-        # has_prefill_task = False
-        # has_decode_task = False
+
         for i in range(req_len):
             request = req_dicts[i]
-            logger.info(f"{i}th request-{request.request_id}: {request}")
+            logger.debug(f"{i}th request-{request.request_id}: {request}")
             idx = request.idx
             if request.task_type.value == RequestType.PREFILL.value:  # prefill task
                 prefill_start_index = request.prefill_start_index
@@ -688,7 +702,7 @@ class MTPProposer(Proposer):
                 self.max_model_len,
                 self.model_inputs["substep"],
             )
-            if self.role == "prefill":
+            if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
                 mtp_save_first_token(
                     self.model_inputs["base_model_draft_tokens"],
                     self.model_inputs["not_need_stop"],
@@ -820,11 +834,18 @@ class MTPProposer(Proposer):
             )
 
             if self.parallel_config.tensor_parallel_size > 1:
-                paddle.distributed.broadcast(sampled_token_ids, 0)
+                paddle.distributed.broadcast(
+                    sampled_token_ids,
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
 
             self._post_process(sampled_token_ids)
             if substep != self.num_model_steps - 1:
                 self._get_self_hidden_states(hidden_states)
+            else:
+                if hasattr(self.model, "empty_input_forward"):
+                    self.model.empty_input_forward()
 
     def _get_self_hidden_states(self, hidden_states):
         target_hidden_states = eagle_get_self_hidden_states(
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index ab03d8ce4..d69c8449a 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -778,13 +778,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
     logger.info(f"- Load strategy: {load_config.load_strategy}")
 
-    if (
-        args.speculative_config is not None
-        and ("method" in args.speculative_config)
-        and (args.speculative_config["method"] is not None)
-    ):
-        logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not support speculative decoding now.")
-        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
     if args.splitwise_role != "mixed" and args.cache_transfer_protocol != "rdma":
         envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
     if not current_platform.is_cuda() and not current_platform.is_xpu():