diff --git a/fastdeploy/model_executor/layers/mtp_linear.py b/fastdeploy/model_executor/layers/mtp_linear.py
index cd9e48a72..4250b611f 100644
--- a/fastdeploy/model_executor/layers/mtp_linear.py
+++ b/fastdeploy/model_executor/layers/mtp_linear.py
@@ -18,7 +18,7 @@
 import paddle
 from paddle import nn
 from paddle.distributed import fleet
 
-from fastdeploy.model_executor.utils import set_weight_attrs
+from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
 
 from .utils import get_tensor
@@ -53,44 +53,61 @@ class ParallelEHProjection(nn.Layer):
             self.bias_key = prefix + ".bias"
         else:
             self.bias_key = None
-        self.use_ep = fd_config.parallel_config.use_ep
+        self.fd_config = fd_config
+        self.tp_group = fd_config.parallel_config.tp_group
         self.column_cut = True
+        self.nranks = fd_config.parallel_config.tensor_parallel_size
 
         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
 
-        if self.use_ep:
-            self.weight = self.create_parameter(
-                shape=[embedding_dim, num_embeddings],
-                dtype=paddle.get_default_dtype(),
-                is_bias=False,
+        if self.column_cut:
+            need_gather = True
+            self.linear = ColumnParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                gather_output=need_gather,
+                fuse_matmul_bias=False,  # False gives a smaller diff
             )
-        else:
-            if self.column_cut:
-                need_gather = True
-                self.linear = ColumnParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    gather_output=need_gather,
-                    fuse_matmul_bias=False,  # False gives a smaller diff
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.bias_key is not None:
+                set_weight_attrs(
+                    self.linear.bias,
+                    {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
                 )
+            if self.nranks > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": True})
-                if self.bias_key is not None:
-                    set_weight_attrs(self.linear.bias, {"output_dim": True})
-            else:
-                self.linear = RowParallelLinear(
-                    embedding_dim,
-                    num_embeddings,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=None,
-                    has_bias=True if self.bias_key is not None else False,
-                    input_is_parallel=False,
-                    fuse_matmul_bias=False,  # False gives a smaller diff
-                )
-                set_weight_attrs(self.linear.weight, {"output_dim": False})
+        else:
+            self.linear = RowParallelLinear(
+                embedding_dim,
+                num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                input_is_parallel=False,
+                fuse_matmul_bias=False,  # False gives a smaller diff
+            )
+            set_weight_attrs(
+                self.linear.weight,
+                {
+                    "weight_loader": default_weight_loader(self.fd_config),
+                    "model_format": self.fd_config.model_config.model_format,
+                },
+            )
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
+        set_weight_attrs(
+            self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
+        )
 
     def load_state_dict(self, state_dict):
         """
@@ -100,17 +117,14 @@ class ParallelEHProjection(nn.Layer):
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
 
         """
-        if self.use_ep:
-            self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
-        else:
-            weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
-            if self.linear.weight.shape != weight_tensor.shape:
-                weight_tensor = weight_tensor.transpose([1, 0])
-            self.linear.weight.set_value(weight_tensor)
+        weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
+        if self.linear.weight.shape != weight_tensor.shape:
+            weight_tensor = weight_tensor.transpose([1, 0])
+        self.linear.weight.set_value(weight_tensor)
 
-            if self.bias_key is not None:
-                bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
-                self.linear.bias.set_value(bias)
+        if self.bias_key is not None:
+            bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
+            self.linear.bias.set_value(bias)
 
     def forward(self, input):
         """
@@ -123,8 +137,5 @@ class ParallelEHProjection(nn.Layer):
             Tensor: The output tensor after processing through the layer.
         """
         logits = input
-        if self.use_ep:
-            logits = paddle.matmul(logits, self.weight)
-        else:
-            logits = self.linear(logits)
+        logits = self.linear(logits)
         return logits
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index cc75ed96c..92b79b6fb 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -147,6 +147,10 @@ class MTPProposer(Proposer):
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
             max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
+        if kv_cache_quant_type == "block_wise_fp8":
+            kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
+        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
+
         if not self.parallel_config.do_profile and (
             self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
         ):
@@ -156,8 +160,8 @@ class MTPProposer(Proposer):
                 self.num_main_model_layers + self.model_config.num_hidden_layers,
             ):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
-                key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
-                val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
+                key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
+                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
                 key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
                 cache_kvs_list.append(key_cache)
                 value_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -177,6 +181,17 @@ class MTPProposer(Proposer):
                     fill_value=0,
                    dtype=cache_type,
                 )
+                if kv_cache_quant_type == "block_wise_fp8":
+                    self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
+                    self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
+                        shape=kv_cache_scale_shape,
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
             self.model_inputs["caches"] = list(self.cache_kvs.values())
             for value in self.cache_kvs.values():
                 del value
@@ -610,7 +625,7 @@ class MTPProposer(Proposer):
                 self.max_model_len,
                 self.model_inputs["substep"],
             )
-            if self.role == "prefill":
+            if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
                 mtp_save_first_token(
                     self.model_inputs["base_model_draft_tokens"],
                     self.model_inputs["not_need_stop"],
@@ -697,7 +712,11 @@ class MTPProposer(Proposer):
            )
 
         if self.parallel_config.tensor_parallel_size > 1:
-            paddle.distributed.broadcast(sampled_token_ids, 0)
+            paddle.distributed.broadcast(
+                sampled_token_ids,
+                self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                group=self.parallel_config.tp_group,
+            )
 
         self._post_process(sampled_token_ids)
 
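A note on the `weight_loader` / `output_dim` attributes attached with `set_weight_attrs` in `mtp_linear.py`: these presumably tell the checkpoint loader how to place the full EH-projection weight onto each tensor-parallel rank. The sketch below is illustrative only (it is not FastDeploy's `default_weight_loader`); it assumes a 2-D `[in_features, out_features]` weight and shows the slicing that an `output_dim`-style attribute implies for a `ColumnParallelLinear` shard.

```python
import paddle


def shard_output_dim(full_weight: paddle.Tensor, tp_rank: int, tp_size: int) -> paddle.Tensor:
    """Slice a full [in_features, out_features] weight along its output axis.

    Illustrative only: the "output_dim" attribute recorded by set_weight_attrs
    is presumably consumed by the real loader to do equivalent partitioning.
    """
    if tp_size == 1:
        return full_weight
    shard = full_weight.shape[1] // tp_size
    return full_weight[:, tp_rank * shard : (tp_rank + 1) * shard]


# Toy check: a [8, 16] projection with tp_size=2 leaves each rank an [8, 8] shard.
w_full = paddle.randn([8, 16])
assert shard_output_dim(w_full, tp_rank=1, tp_size=2).shape == [8, 8]
```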
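For the `block_wise_fp8` branch in `mtp.py`: each per-layer KV cache gets a companion scale tensor shaped like the cache with its last axis dropped (`kv_cache_shape[:3]`), allocated in the default dtype rather than the quantized cache dtype. A minimal sketch of that shape relationship, using a made-up 4-D layout since the real one comes from `get_kv_cache_shape`:

```python
import paddle

# Assumed layout for illustration only; the real shape comes from
# attn_backends[0].get_kv_cache_shape(...).
kv_cache_shape = [1024, 8, 64, 128]  # e.g. blocks, kv_heads, block_size, head_dim
kv_cache_scale_shape = kv_cache_shape[:3]  # one scale per entry of the first three axes

cache_kvs = {}
for i in range(2):  # two hypothetical draft-model layers
    cache_kvs[f"key_caches_{i}"] = paddle.full(kv_cache_shape, 0, dtype="uint8")  # stand-in for the fp8 cache dtype
    cache_kvs[f"value_caches_{i}"] = paddle.full(kv_cache_shape, 0, dtype="uint8")
    cache_kvs[f"key_cache_scales_{i}"] = paddle.full(kv_cache_scale_shape, 0, dtype=paddle.get_default_dtype())
    cache_kvs[f"value_cache_scales_{i}"] = paddle.full(kv_cache_scale_shape, 0, dtype=paddle.get_default_dtype())
```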
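Finally, the broadcast change: with data parallelism on top of tensor parallelism, rank 0 is only the right source for the first DP replica, so the patch broadcasts from `data_parallel_rank * tensor_parallel_size` within `tp_group`. Assuming the usual contiguous layout `global_rank = dp_rank * tp_size + tp_rank`, that expression is the global rank of the `tp_rank == 0` member of the caller's own TP group; a minimal sketch of the arithmetic:

```python
def tp_group_root(dp_rank: int, tp_size: int) -> int:
    """Global rank of the tp_rank == 0 member of one data-parallel replica,
    assuming ranks are laid out as global_rank = dp_rank * tp_size + tp_rank."""
    return dp_rank * tp_size


# With dp_size=2 and tp_size=4 the TP groups are [0..3] and [4..7]; a hard-coded
# src=0 would point the second replica at a rank outside its own tp_group.
assert tp_group_root(dp_rank=0, tp_size=4) == 0
assert tp_group_root(dp_rank=1, tp_size=4) == 4
```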