From b4fef2cf2967356c7a4cfaa2ba7411b87a96f6e6 Mon Sep 17 00:00:00 2001
From: Kane2011 <86709049+Kane2011@users.noreply.github.com>
Date: Wed, 13 Aug 2025 11:11:54 +0800
Subject: [PATCH] [MetaxGPU] Support FastDeploy on metax gpu (#3241)

* [MetaxGPU] Support FastDeploy on metax gpu

* Update metax_worker.py

1. change worker log;
2. remove custom allreduce, adapt it later;
3. remove cuda graph;

* Update __init__.py

1. remove metax's keyword comment

* Update __init__.py

1. remove metax's keyword comment;
2. add fused_moe_kernel_paddle import

---------

Co-authored-by: yongqiangma
---
 build.sh                                      |   10 +
 custom_ops/gpu_ops/helper.h                   |    3 +-
 custom_ops/setup_ops.py                       |   66 +
 fastdeploy/model_executor/forward_meta.py     |    6 +
 .../model_executor/layers/activation.py       |    1 +
 .../attention/base_attention_backend.py       |   21 +
 .../layers/backends/__init__.py               |    7 +
 .../layers/backends/metax/__init__.py         |   21 +
 .../backends/metax/attention/__init__.py      |   30 +
 .../attention/flash_attention_interface.py    |  104 ++
 .../metax/attention/flash_attn_backend.py     |  393 ++++
 .../layers/backends/metax/moe/__init__.py     |   19 +
 .../moe/fused_moe_triton_metax_backend.py     |  276 +++
 .../backends/metax/moe/triton_moe_kernels.py  |  187 ++
 fastdeploy/model_executor/layers/linear.py    |    1 +
 fastdeploy/model_executor/layers/moe/moe.py   |    6 +
 .../layers/quantization/weight_only.py        |   39 +-
 .../model_executor/layers/rotary_embedding.py |    4 +
 .../sample/ops/apply_penalty_multi_scores.py  |   17 +
 .../model_executor/layers/sample/sampler.py   |    1 +
 .../model_executor/pre_and_post_process.py    |   33 +
 fastdeploy/platforms/__init__.py              |    3 +
 fastdeploy/platforms/base.py                  |    6 +
 fastdeploy/platforms/maca.py                  |   65 +
 fastdeploy/worker/metax_model_runner.py       | 1664 +++++++++++++++++
 fastdeploy/worker/metax_worker.py             |  203 ++
 fastdeploy/worker/worker_process.py           |    4 +
 requirements_metaxgpu.txt                     |   39 +
 setup.py                                      |    6 +-
 29 files changed, 3224 insertions(+), 11 deletions(-)
 create mode 100644 fastdeploy/model_executor/layers/backends/metax/__init__.py
 create mode 100644 fastdeploy/model_executor/layers/backends/metax/attention/__init__.py
 create mode 100644 fastdeploy/model_executor/layers/backends/metax/attention/flash_attention_interface.py
 create mode 100644 fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py
 create mode 100644 fastdeploy/model_executor/layers/backends/metax/moe/__init__.py
 create mode 100644 fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py
 create mode 100644 fastdeploy/model_executor/layers/backends/metax/moe/triton_moe_kernels.py
 create mode 100644 fastdeploy/platforms/maca.py
 create mode 100644 fastdeploy/worker/metax_model_runner.py
 create mode 100644 fastdeploy/worker/metax_worker.py
 create mode 100644 requirements_metaxgpu.txt

diff --git a/build.sh b/build.sh
index aa7f40ef8..86ec3cedb 100644
--- a/build.sh
+++ b/build.sh
@@ -126,6 +126,16 @@ function copy_ops(){
         return
     fi
 
+    is_maca=`$python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"`
+    if [ "$is_maca" = "True" ]; then
+        DEVICE_TYPE="metax_gpu"
+        mkdir -p ../fastdeploy/model_executor/ops/base
+        cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
+        cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
+        echo -e "MACA ops have been copied to fastdeploy"
+        return
+    fi
+
     DEVICE_TYPE="cpu"
     cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
     cd ../../../../
diff --git a/custom_ops/gpu_ops/helper.h
b/custom_ops/gpu_ops/helper.h index ed4efe927..468aff1fc 100644 --- a/custom_ops/gpu_ops/helper.h +++ b/custom_ops/gpu_ops/helper.h @@ -509,6 +509,7 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) { } #ifndef PADDLE_WITH_HIP +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU __forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr, int mode = 0) { uint32_t flag; @@ -541,7 +542,7 @@ __forceinline__ __device__ void st_flag_release(uint32_t *flag_addr, "l"(flag_addr)); } } - +#endif inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { int max_shared_mem_per_block_opt_in = 0; cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 431edfb3e..de4202bc2 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -564,6 +564,72 @@ elif paddle.is_compiled_with_custom_device("gcu"): ] ), ) +elif paddle.device.is_compiled_with_custom_device("metax_gpu"): + maca_path = os.getenv("MACA_PATH", "/opt/maca") + json_dir = "third_party/nlohmann_json" + if not os.path.exists(json_dir) or not os.listdir(json_dir): + if not os.path.exists(json_dir): + os.makedirs(json_dir) + clone_git_repo("v3.11.3", "https://gitee.com/learnlov/mirrors_nlohmann_json.git", json_dir) + if not os.listdir(json_dir): + raise ValueError("Git clone nlohmann_json failed!") + sources = [ + "gpu_ops/save_with_output.cc", + "gpu_ops/set_mask_value.cu", + "gpu_ops/set_value_by_flags.cu", + "gpu_ops/ngram_mask.cu", + "gpu_ops/gather_idx.cu", + "gpu_ops/get_output_ep.cc", + "gpu_ops/token_penalty_multi_scores.cu", + "gpu_ops/token_penalty_only_once.cu", + "gpu_ops/stop_generation.cu", + "gpu_ops/stop_generation_multi_ends.cu", + "gpu_ops/set_flags.cu", + "gpu_ops/fused_get_rope.cu", + "gpu_ops/get_padding_offset.cu", + "gpu_ops/update_inputs.cu", + "gpu_ops/update_inputs_beam.cu", + "gpu_ops/beam_search_softmax.cu", + "gpu_ops/rebuild_padding.cu", + "gpu_ops/step.cu", + "gpu_ops/step_reschedule.cu", + "gpu_ops/step_system_cache.cu", + "gpu_ops/set_data_ipc.cu", + "gpu_ops/read_data_ipc.cu", + "gpu_ops/dequant_int8.cu", + "gpu_ops/share_external_data.cu", + "gpu_ops/extract_text_token_output.cu", + "gpu_ops/moe/tritonmoe_preprocess.cu", + "gpu_ops/moe/moe_topk_select.cu", + "gpu_ops/recover_decode_task.cu", + ] + + sources += find_end_files("gpu_ops/speculate_decoding", ".cu") + sources += find_end_files("gpu_ops/speculate_decoding", ".cc") + + setup( + name="fastdeploy_ops", + ext_modules=CUDAExtension( + sources=sources, + extra_compile_args={ + "cxx": ["-O3"], + "nvcc": [ + "-O3", + "-Ithird_party/nlohmann_json/include", + "-Igpu_ops", + "-DPADDLE_DEV", + "-DPADDLE_WITH_CUSTOM_DEVICE_METAX_GPU", + ], + }, + library_dirs=[os.path.join(maca_path, "lib")], + extra_link_args=["-lruntime_cu"], + include_dirs=[ + os.path.join(maca_path, "include"), + os.path.join(maca_path, "include/mcr"), + os.path.join(maca_path, "include/common"), + ], + ), + ) else: use_bf16 = envs.FD_CPU_USE_BF16 == "True" diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index be5d7f702..da57c1672 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -37,6 +37,8 @@ class ForwardMode(IntEnum): DECODE = auto() # Mixed mode MIXED = auto() + # Native mode + NATIVE = auto() def is_prefill(self): """Is Extend mode""" @@ -50,6 +52,10 @@ class ForwardMode(IntEnum): """Is Mixed mode""" return self == ForwardMode.MIXED + def is_native(self): + """Is 
Native mode""" + return self == ForwardMode.NATIVE + @dataclass class ForwardMeta: diff --git a/fastdeploy/model_executor/layers/activation.py b/fastdeploy/model_executor/layers/activation.py index 977a4f2f4..5e426e7ef 100644 --- a/fastdeploy/model_executor/layers/activation.py +++ b/fastdeploy/model_executor/layers/activation.py @@ -68,6 +68,7 @@ class SiluAndMul(nn.Layer): or current_platform.is_xpu() or current_platform.is_iluvatar() or current_platform.is_dcu() + or current_platform.is_maca() ): self.forward = self.forward_cuda elif current_platform.is_gcu(): diff --git a/fastdeploy/model_executor/layers/attention/base_attention_backend.py b/fastdeploy/model_executor/layers/attention/base_attention_backend.py index 492a5790d..c4b8e9313 100644 --- a/fastdeploy/model_executor/layers/attention/base_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/base_attention_backend.py @@ -86,6 +86,15 @@ class AttentionBackend(ABC): layer, forward_meta, ) + elif forward_meta.forward_mode.is_native(): + return self.forward_native_backend( + q, + k, + v, + qkv, + layer, + forward_meta, + ) else: return self.forward_extend( q, @@ -139,3 +148,15 @@ class AttentionBackend(ABC): ) -> paddle.Tensor: """Run a forward for extend.""" raise NotImplementedError + + def forward_native_backend( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + layer: paddle.nn.Layer, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + """Run a forward for native.""" + raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/backends/__init__.py b/fastdeploy/model_executor/layers/backends/__init__.py index 18d1fccfe..ddbe410d1 100644 --- a/fastdeploy/model_executor/layers/backends/__init__.py +++ b/fastdeploy/model_executor/layers/backends/__init__.py @@ -48,3 +48,10 @@ if current_platform.is_dcu(): if hasattr(dcu, "__all__"): globals().update({name: getattr(dcu, name) for name in dcu.__all__}) __all__.extend(dcu.__all__) + +if current_platform.is_maca(): + from . import metax + + if hasattr(metax, "__all__"): + globals().update({name: getattr(metax, name) for name in metax.__all__}) + __all__.extend(metax.__all__) diff --git a/fastdeploy/model_executor/layers/backends/metax/__init__.py b/fastdeploy/model_executor/layers/backends/metax/__init__.py new file mode 100644 index 000000000..365e50e8b --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/metax/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .attention.flash_attn_backend import FlashAttentionBackend +from .moe.fused_moe_triton_metax_backend import MetaxTritonWeightOnlyMoEMethod + +__all__ = [ + "FlashAttentionBackend", + "MetaxTritonWeightOnlyMoEMethod", +] diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/__init__.py b/fastdeploy/model_executor/layers/backends/metax/attention/__init__.py new file mode 100644 index 000000000..6874bf05f --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/metax/attention/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +metax gpu backend attention methods +""" +from .flash_attention_interface import ( + flash_attn_func, + flash_attn_kvcache_func, + flash_attn_unpadded_func, +) +from .flash_attn_backend import FlashAttentionBackend + +__all__ = [ + "FlashAttentionBackend", + "flash_attn_func", + "flash_attn_unpadded_func", + "flash_attn_kvcache_func", +] diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attention_interface.py b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attention_interface.py new file mode 100644 index 000000000..f7520d238 --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attention_interface.py @@ -0,0 +1,104 @@ +import os +from typing import Optional, Tuple, Union + +import paddle +from paddle import Tensor + +for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")): + if lib.endswith(".so"): + paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib) + + +def flash_attn_func( + q: Tensor, + k: Tensor, + v: Tensor, + fixed_seed_offset: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + dropout_prob: float = 0.0, + causal: bool = False, + return_softmax: bool = False, + is_test: bool = True, + rng_name: str = "", +) -> Union[Tensor, Tuple[Tensor, ...]]: + return paddle._C_ops.flash_attn( + q, k, v, fixed_seed_offset, attn_mask, dropout_prob, causal, return_softmax, is_test, rng_name + ) + + +def flash_attn_unpadded_func( + q: Tensor, + k: Tensor, + v: Tensor, + cu_seqlens_q: Tensor, + cu_seqlens_k: Tensor, + max_seqlen_q: Union[int, float], + max_seqlen_k: Union[int, float], + fixed_seed_offset: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + softmax_scale: float = 1.0, + dropout: float = 0.0, + causal: bool = False, + return_softmax: bool = False, + is_test: bool = True, + rng_name: str = "", +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + max_seqlen_q_t = paddle.to_tensor(max_seqlen_q, dtype="int64") + max_seqlen_k_t = paddle.to_tensor(max_seqlen_k, dtype="int64") + + outputs = paddle._C_ops.flash_attn_unpadded( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + fixed_seed_offset, + attn_mask, + max_seqlen_q_t, + max_seqlen_k_t, + softmax_scale, + dropout, + causal, + return_softmax, + is_test, + rng_name, + ) + return outputs + + +def flash_attn_kvcache_func( + q: Tensor, + k_cache: Tensor, + v_cache: Tensor, + 
seqlens_k: Tensor, + block_table: Tensor, + k: Optional[Tensor] = None, + v: Optional[Tensor] = None, + rotary_cos: Optional[Tensor] = None, + rotary_sin: Optional[Tensor] = None, + cache_batch_idx: Optional[Tensor] = None, + causal: bool = True, + is_rotary_interleaved: bool = False, + num_splits: int = 1, + dropout: float = 0.0, + return_softmax: bool = False, +) -> Tuple[Tensor, Tensor]: + out, softmax_lse = paddle._C_ops._run_custom_op( + "flash_attn_kvcache", + q, + k_cache, + v_cache, + k, + v, + seqlens_k, + rotary_cos, + rotary_sin, + cache_batch_idx, + block_table, + causal, + is_rotary_interleaved, + num_splits, + dropout, + return_softmax, + ) + return out, softmax_lse diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py new file mode 100644 index 000000000..a67ae76e2 --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py @@ -0,0 +1,393 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from __future__ import annotations + +import math +import os +from dataclasses import dataclass, field +from typing import List, Optional + +import paddle +import paddle.nn.functional as F + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.forward_meta import ForwardMeta, ForwardMode +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, + AttentionMetadata, +) +from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id +from fastdeploy.model_executor.layers.backends.metax.attention.flash_attention_interface import ( + flash_attn_kvcache_func, + flash_attn_unpadded_func, +) + + +@dataclass +class FlashAttentionMetadata(AttentionMetadata): + """ + FlashAttentionMetadata + """ + + max_len_kv: paddle.Tensor = None + set_max_lengths: int = -1 + encoder_batch_ids: paddle.Tensor = None + encoder_tile_ids_per_batch: paddle.Tensor = None + encoder_num_blocks: paddle.Tensor = None + kv_batch_ids: paddle.Tensor = None + kv_tile_ids_per_batch: paddle.Tensor = None + kv_num_blocks: paddle.Tensor = None + decoder_batch_ids: paddle.Tensor = None + decoder_tile_ids_per_batch: paddle.Tensor = None + decoder_num_blocks: paddle.Tensor = None + + _dtype: paddle.dtype = paddle.bfloat16 + encoder_max_partition_size: int = 32768 + max_partition_size: int = 32768 + block_tables: Optional[paddle.Tensor] = None + rotary_embs: Optional[paddle.Tensor] = None + attn_mask: Optional[paddle.Tensor] = None + encoder_block_shape_q: int = -1 + decoder_block_shape_q: int = -1 + _fuse_kernel_compute_dtype: str = "bf16" + + # pd_disaggregation + kv_signal_metadata: Optional[paddle.Tensor] = None + kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list) + + +class FlashAttentionBackend(AttentionBackend): + """ + FlashAttentionBackend backend implementation. 
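+
+    On Metax GPUs this backend runs in ForwardMode.NATIVE: prefill (encoder)
+    tokens go through flash_attn_unpadded_func, while decode tokens gather
+    their paged KV cache into a contiguous buffer and call
+    flash_attn_kvcache_func (see forward_native_backend).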
+ """ + + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: FlashAttentionMetadata + + def __init__( + self, + fd_config: FDConfig, + kv_num_heads: int, + num_heads: int, + head_dim: int, + encoder_block_shape_q: int = -1, + decoder_block_shape_q: int = -1, + ) -> None: + """ + FlashAttentionBackend __init__ + """ + super().__init__() + self.attention_metadata: FlashAttentionMetadata = None + self.block_size: int = fd_config.parallel_config.block_size + self.max_seq_len: int = fd_config.parallel_config.max_model_len + self.rope_theta: float = ( + 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta + ) + self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) + self.causal: bool = getattr(fd_config.model_config, "causal", True) + self.speculative_method: str = fd_config.speculative_config.method + self.use_speculate: bool = self.speculative_method is not None + self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens + self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" + self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) + self.encoder_block_shape_q: int = encoder_block_shape_q + self.decoder_block_shape_q: int = decoder_block_shape_q + + self.kv_num_heads: int = kv_num_heads + self.num_heads: int = num_heads + self.head_dim: int = fd_config.model_config.head_dim + self.num_layers: int = fd_config.model_config.num_hidden_layers + self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 32768)) + + self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode + + self.start_layer_index: int = fd_config.model_config.start_layer_index + + if fd_config.parallel_config.expert_parallel_rank is None: + fd_config.parallel_config.expert_parallel_rank = 0 + + self.rank, self.device_id = init_rank_and_device_id(fd_config) + + def init_attention_metadata(self, forward_meta: ForwardMeta): + """Initialize attntion metadata hence all layers in the forward pass can reuse it.""" + forward_meta.forward_mode = ForwardMode.NATIVE + return + + def get_attntion_meta(self) -> AttentionMetadata: + """get_attntion_meta""" + return self.attention_metadata + + def get_kv_cache_shape( + self, + max_num_blocks: int, + kv_cache_quant_type: str = None, + ): + """ + Caculate kv cache shape + """ + if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp": + return ( + max_num_blocks, + self.kv_num_heads, + self.block_size, + self.head_dim // 2, + ) + else: + return ( + max_num_blocks, + self.kv_num_heads, + self.block_size, + self.head_dim, + ) + + def split_qkv(self, qkv, num_head_q, num_head_kv, dim): + q = qkv[:, : num_head_q * dim].reshape([-1, num_head_q, dim]) + k = qkv[:, num_head_q * dim : num_head_q * dim + num_head_kv * dim].reshape([-1, num_head_kv, dim]) + v = qkv[:, num_head_q * dim + num_head_kv * dim :].reshape([-1, num_head_kv, dim]) + return q, k, v + + def flash_attn_varlen(self, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k): + num_head = q.shape[1] + dim = q.shape[2] + + q_ = q.reshape([-1, num_head, dim]) + k_ = k.reshape([-1, num_head, dim]) + v_ = v.reshape([-1, num_head, dim]) + + bsz = cu_seqlens_q.shape[0] - 1 + out = [] + for i in range(bsz): + start_q, end_q = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item() + start_k, end_k = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item() + qi = q_[start_q:end_q] # [seq_q, nh, dim] + ki = 
k_[start_k:end_k] # [seq_k, nh, dim] + vi = v_[start_k:end_k] # [seq_k, nh, dim] + qi = qi.transpose([1, 0, 2]) # [nh, seq_q, dim] + ki = ki.transpose([1, 2, 0]) # [nh, dim, seq_k] + vi = vi.transpose([1, 0, 2]) # [nh, seq_k, dim] + + score = paddle.matmul(qi, ki) / math.sqrt(dim) # [nh, seq_q, seq_k] + prob = F.softmax(score, axis=-1) + o = paddle.matmul(prob, vi) # [nh, seq_q, dim] + o = o.transpose([1, 0, 2]) # [seq_q, nh, dim] + out.append(o) + + return paddle.concat(out, axis=0) # [total_q, nh, dim] + + def flash_attn_with_kvcache(self, q, cache_k, cache_v, cache_seqlens, block_tables=None): + bs, _, nh, dim = q.shape + out = [] + for i in range(bs): + q_i = q[i] # [1, nh, dim] + k_i = cache_k[i, : cache_seqlens[i, 0]] # [seqlen, nh, dim] + v_i = cache_v[i, : cache_seqlens[i, 0]] + qi = q_i.transpose([1, 0, 2]) # [nh, 1, dim] + ki = k_i.transpose([1, 2, 0]) # [nh, dim, seqlen] + vi = v_i.transpose([1, 0, 2]) # [nh, seqlen, dim] + score = paddle.matmul(qi, ki) / math.sqrt(dim) + prob = F.softmax(score, axis=-1) + o = paddle.matmul(prob, vi).transpose([1, 0, 2]) # [1, nh, dim] + out.append(o) + return paddle.concat(out, axis=0) # [bs, nh, dim] + + def block_cache_to_naive_cache(slef, cache_k, cache_v, bsz, block_tables, cache_seq_len): + _, num_head, blocksize, dim_head = cache_k.shape + out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype) + out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype) + for i in range(bsz): + for j in range(cache_seq_len): + out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :] + out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :] + return out_cache_k, out_cache_v + + def block_cache_to_naive_cache__(self, cache_k, cache_v, bsz, block_tables, max_cache_seq_len): + _, num_head, blocksize, dim_head = cache_k.shape + out_cache_k = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_k.dtype) + out_cache_v = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_v.dtype) + for i in range(bsz): + for j in range(max_cache_seq_len): + out_cache_k[i, j, :, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :] + out_cache_v[i, j, :, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :] + return out_cache_k, out_cache_v + + def update_encoder_kv_cache(self, k, v, seq_lens_encoder, cache_k, cache_v, block_tables): + _, num_head, blocksize, dim_head = cache_k.shape + offset = 0 + for batch_idx, seq_len in enumerate(seq_lens_encoder.numpy()): + if seq_len == 0: + continue + for seq_idx in range(seq_len): + block_id = block_tables[batch_idx, seq_idx // blocksize] + assert block_id != -1 + index = offset + seq_idx + cache_k[block_id, :, seq_idx % blocksize, :] = k[index, :, :] + cache_v[block_id, :, seq_idx % blocksize, :] = v[index, :, :] + + offset += seq_len + + def update_decoder_kv_cache(self, k, v, seq_lens_decoder, cache_k, cache_v, block_tables): + _, num_head, blocksize, dim_head = cache_k.shape + for batch_idx, seq_idx in enumerate(seq_lens_decoder.numpy()): + if seq_idx == 0: + continue + block_id = block_tables[batch_idx, seq_idx // blocksize] + assert block_id != -1 + cache_k[block_id, :, seq_idx % blocksize, :] = k[batch_idx, :, :] + cache_v[block_id, :, seq_idx % blocksize, :] = v[batch_idx, :, :] + + def apply_rope(self, qk, cos, sin): + rotate_half = paddle.reshape( + paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1), + 
paddle.shape(qk),
+        )
+        out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin))
+        return paddle.cast(out, qk.dtype)
+
+    def forward_native_backend(
+        self,
+        q: paddle.Tensor,
+        k: paddle.Tensor,
+        v: paddle.Tensor,
+        qkv: paddle.Tensor,
+        layer,
+        forward_meta: ForwardMeta,
+    ):
+
+        bsz = forward_meta.seq_lens_this_time.shape[0]
+        num_head_q, num_head_kv, dim = layer.num_heads, layer.kv_num_heads, layer.head_dim
+
+        # 1. Separate encoder / decoder token indices
+        seq_lens_encoder = forward_meta.seq_lens_encoder.squeeze(-1)
+        seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
+        seq_lens_this_time = forward_meta.seq_lens_this_time.squeeze(-1)
+        encoder_indices = []
+        decoder_indices = []
+
+        offset = 0
+        for i in range(bsz):
+            length = seq_lens_this_time[i].item()
+            if seq_lens_encoder[i] > 0:
+                encoder_indices.extend(range(offset, offset + length))
+            elif seq_lens_decoder[i] > 0:
+                decoder_indices.extend(range(offset, offset + length))
+            offset += length
+
+        encoder_indices = paddle.to_tensor(encoder_indices, dtype="int32")
+        decoder_indices = paddle.to_tensor(decoder_indices, dtype="int32")
+
+        encoder_qkv = paddle.index_select(qkv, encoder_indices, axis=0)
+        decoder_qkv = paddle.index_select(qkv, decoder_indices, axis=0)
+
+        # 2. Split the encoder and decoder qkv into q / k / v
+        encoder_q, encoder_k, encoder_v = self.split_qkv(encoder_qkv, num_head_q, num_head_kv, dim)
+        decoder_q, decoder_k, decoder_v = self.split_qkv(decoder_qkv, num_head_q, num_head_kv, dim)
+        cache_k = forward_meta.caches[2 * layer.layer_id]
+        cache_v = forward_meta.caches[2 * layer.layer_id + 1]
+
+        # 3. Rotary Embedding
+        if decoder_q.numel() != 0 or encoder_q.numel() != 0:
+            for batch_idx in range(forward_meta.seq_lens_this_time.shape[0]):
+                seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
+                if seq_len_i == 0:
+                    continue
+                cached_kv_len = seq_lens_decoder[batch_idx]
+                cu_seq_start_q = forward_meta.cu_seqlens_q[batch_idx]
+                cu_seq_end_q = forward_meta.cu_seqlens_q[batch_idx + 1]
+                if forward_meta.rotary_embs is not None and cu_seq_end_q > cu_seq_start_q:
+                    cos = forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
+                    sin = forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
+
+                    def rope_func(qk):
+                        qk[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(qk[cu_seq_start_q:cu_seq_end_q], cos, sin)
+
+                    if encoder_q.numel() != 0:
+                        rope_func(encoder_q)
+                        rope_func(encoder_k)
+                    if decoder_q.numel() != 0:
+                        rope_func(decoder_q)
+                        rope_func(decoder_k)
+
+        # 4. Flash Attention for encoder
+        encoder_v = encoder_v
+        cu_seqlens_q = forward_meta.cu_seqlens_q
+        cu_seqlens_k = forward_meta.cu_seqlens_k
+        max_seqlen_q = paddle.max(seq_lens_this_time)
+        max_seqlen_k = max_seqlen_q
+
+        if encoder_q.numel() > 0:
+            encoder_out = flash_attn_unpadded_func(
+                encoder_q,
+                encoder_k,
+                encoder_v,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                attn_mask=forward_meta.attn_mask,
+                causal=self.causal,
+            )
+            self.update_encoder_kv_cache(
+                encoder_k, encoder_v, seq_lens_encoder, cache_k, cache_v, forward_meta.block_tables
+            )
+        else:
+            encoder_out = None
+
+        # 5.
decoder attention with kv cache + bs = decoder_q.shape[0] + decoder_q = decoder_q.reshape([bs, 1, num_head_q, dim]) + decoder_k_ = decoder_k.reshape([bs, 1, num_head_kv, dim]) + decoder_v_ = decoder_v.reshape([bs, 1, num_head_kv, dim]) + cache_seqlens = paddle.index_select(forward_meta.seq_lens_decoder, decoder_indices, axis=0) + + # 5.1 convert paged kv cache to continuous cache + if decoder_q.numel() > 0: + max_cache_seq_len = paddle.max(cache_seqlens) + c_cache_k, c_cache_v = self.block_cache_to_naive_cache__( + cache_k, cache_v, bs, forward_meta.block_tables, max_cache_seq_len + ) + decoder_out = flash_attn_kvcache_func( + decoder_q, + c_cache_k, + c_cache_v, + cache_seqlens.squeeze(-1), + None, + decoder_k_, + decoder_v_, + causal=self.causal, + ) + self.update_decoder_kv_cache( + decoder_k, decoder_v, seq_lens_decoder, cache_k, cache_v, forward_meta.block_tables + ) + else: + decoder_out = None + + # 6. 拼接 encoder_out 和 decoder_out + total_len = qkv.shape[0] + out = paddle.zeros([total_len, num_head_q, dim]) + if encoder_out is not None: + out = paddle.tensor.put_along_axis( + out, encoder_indices.unsqueeze(-1).unsqueeze(-1), encoder_out[0], axis=0 + ) + if decoder_out is not None: + new_decoder_out = decoder_out[0].squeeze(1) + out = paddle.tensor.put_along_axis( + out, decoder_indices.unsqueeze(-1).unsqueeze(-1), new_decoder_out, axis=0 + ) + + out.reshape_([total_len, num_head_q * dim]) + + return out diff --git a/fastdeploy/model_executor/layers/backends/metax/moe/__init__.py b/fastdeploy/model_executor/layers/backends/metax/moe/__init__.py new file mode 100644 index 000000000..0fd201bd1 --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/metax/moe/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .triton_moe_kernels import fused_moe_kernel_paddle + +__all__ = [ + "fused_moe_kernel_paddle", +] diff --git a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py new file mode 100644 index 000000000..50ceecf18 --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py @@ -0,0 +1,276 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import paddle +from paddle import nn + +import fastdeploy +from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase +from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess +from fastdeploy.utils import ceil_div + +from .triton_moe_kernels import fused_moe_kernel_paddle + + +class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase): + """ + Use Triton Group Gemm to compute Fused MoE. + """ + + def __init__(self, quant_config=None): + """ + Triton Group Gemm to compute Fused MoE. + """ + self.quant_config = quant_config + self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"] + self.added_scale_attrs = [ + "up_gate_proj_weight_scale", + "down_proj_weight_scale", + ] + + def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None: + """process_prequanted_weights""" + pass + + def create_weights(self, layer: nn.Layer, state_dict): + """ + Triton MoE create weight process. + """ + up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict) + assert len(up_gate_proj_weights) == layer.num_local_experts + assert len(down_proj_weights) == layer.num_local_experts + + if layer.quant_method.quant_config: + algo = layer.quant_method.quant_config.name() + + assert up_gate_proj_weights[0].shape == [ + layer.hidden_size, + layer.moe_intermediate_size * 2, + ] + assert down_proj_weights[0].shape == [ + layer.moe_intermediate_size, + layer.hidden_size, + ] + + up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0) + down_proj_tensor = paddle.stack(down_proj_weights, axis=0) + + if algo == "wint8": + max_bound = 127 + elif algo == "wint4": + max_bound = 7 + + for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]): + weight_name = self.added_weight_attrs[idx] + scale_name = self.added_scale_attrs[idx] + + quanted_weight_scale = weight_tensor.abs().max(axis=1) + if self.quant_config is not None: + quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound + quanted_weight = paddle.round(quanted_weight).astype("int8") + quanted_weight_scale = quanted_weight_scale / max_bound + + setattr( + layer, + weight_name, + layer.create_parameter( + shape=quanted_weight.shape, + dtype=quanted_weight.dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + getattr(layer, weight_name).set_value(quanted_weight) + + setattr( + layer, + scale_name, + layer.create_parameter( + shape=quanted_weight_scale.shape, + dtype=quanted_weight_scale.dtype, + ), + ) + getattr(layer, scale_name).set_value(quanted_weight_scale) + else: + setattr( + layer, + weight_name, + layer.create_parameter( + shape=quanted_weight.shape, + dtype=quanted_weight.dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + getattr(layer, weight_name).set_value(quanted_weight) + + setattr( + layer, + scale_name, + layer.create_parameter( + shape=quanted_weight_scale.shape, + dtype=quanted_weight_scale.dtype, + ), + ) + getattr(layer, scale_name).set_value(quanted_weight_scale) + + def apply( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate_out: paddle.Tensor, + ) -> paddle.Tensor: + """ + Triton compute Fused MoE. 
+ """ + token_num = x.shape[0] + top_k = layer.top_k + num_local_experts = layer.num_local_experts + top_k = layer.top_k + moe_intermediate_size = layer.moe_intermediate_size + hidden_size = layer.hidden_size + + topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + top_k, + True, # apply_norm_weight, + False, + ) + up_gate_proj_out = paddle.empty( + [token_num * top_k, moe_intermediate_size * 2], + dtype=x.dtype, + ) + + if self.quant_config is not None: + config = { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + } + else: + config = { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } + + sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess( + topk_ids, num_local_experts, config["BLOCK_SIZE_M"] + ) + max_possible_num_post_padded = sorted_token_ids.shape[0] + grid = ( + ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) + * ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), + ) + + fused_moe_kernel_paddle[grid]( + x, + layer.up_gate_proj_weight, + up_gate_proj_out, + None, + layer.up_gate_proj_weight_scale, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + max_possible_num_post_padded, + token_num * top_k, + N=moe_intermediate_size * 2, + K=hidden_size, + stride_am=x.strides[0], + stride_ak=x.strides[1], + stride_be=layer.up_gate_proj_weight.strides[0], + stride_bk=layer.up_gate_proj_weight.strides[1], + stride_bn=layer.up_gate_proj_weight.strides[2], + stride_cm=up_gate_proj_out.strides[0], + stride_cn=up_gate_proj_out.strides[1], + # + stride_asm=-1, + stride_ask=-1, + stride_bse=layer.up_gate_proj_weight_scale.strides[0], + stride_bsk=-1, + stride_bsn=layer.up_gate_proj_weight_scale.strides[1], + group_n=-1, + group_k=-1, + # Meta-parameters + BLOCK_SIZE_M=config["BLOCK_SIZE_M"], + BLOCK_SIZE_N=config["BLOCK_SIZE_N"], + BLOCK_SIZE_K=config["BLOCK_SIZE_K"], + GROUP_SIZE_M=config["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=False, + top_k=top_k, + compute_type_enum=1, + use_fp8_w8a8=False, + use_int8_w8a16=True, + even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0, + ) + + down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out) + + down_proj_out = paddle.empty( + (token_num * top_k, hidden_size), + dtype=x.dtype, + ) + + grid = ( + ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"]) + * ceil_div(hidden_size, config["BLOCK_SIZE_N"]), + ) + fused_moe_kernel_paddle[grid]( + down_proj_input, + layer.down_proj_weight, + down_proj_out, + None, + layer.down_proj_weight_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + max_possible_num_post_padded, + token_num * top_k, + N=hidden_size, + K=moe_intermediate_size, + stride_am=down_proj_input.strides[0], + stride_ak=down_proj_input.strides[1], + stride_be=layer.down_proj_weight.strides[0], + stride_bk=layer.down_proj_weight.strides[1], + stride_bn=layer.down_proj_weight.strides[2], + stride_cm=down_proj_out.strides[0], + stride_cn=down_proj_out.strides[1], + stride_asm=-1, + stride_ask=-1, + stride_bse=layer.down_proj_weight_scale.strides[0], + stride_bsk=-1, + stride_bsn=layer.down_proj_weight_scale.strides[1], + group_n=-1, + group_k=-1, + # Meta-parameters + BLOCK_SIZE_M=config["BLOCK_SIZE_M"], + BLOCK_SIZE_N=config["BLOCK_SIZE_N"], + BLOCK_SIZE_K=config["BLOCK_SIZE_K"], + GROUP_SIZE_M=config["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=True, + top_k=1, + compute_type_enum=1, + 
use_fp8_w8a8=False, + use_int8_w8a16=True, + even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0, + ) + + down_proj_out.reshape_([token_num, top_k, hidden_size]) + out = down_proj_out.sum(axis=1) + return out diff --git a/fastdeploy/model_executor/layers/backends/metax/moe/triton_moe_kernels.py b/fastdeploy/model_executor/layers/backends/metax/moe/triton_moe_kernels.py new file mode 100644 index 000000000..e859e7ce4 --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/metax/moe/triton_moe_kernels.py @@ -0,0 +1,187 @@ +""" +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import triton +import triton.language as tl + + +@triton.jit +def fused_moe_kernel_paddle( + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + max_possible_num_post_padded, + num_valid_tokens, + N, + K, + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Block size for block-wise fp8 quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type_enum: tl.constexpr, + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + even_Ks: tl.constexpr, +): + """ + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. 
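+    - topk_weights_ptr: The routing weight of each (token, expert) pair; it is
+      only multiplied into the output when MUL_ROUTED_WEIGHT is set (the
+      down-projection GEMM in this backend).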
+ """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + assert compute_type_enum == 1 + compute_type = tl.bfloat16 + + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) + + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + if use_int8_w8a16: + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + + if use_fp8_w8a8: + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + else: + # (Zkk): every expert has one activation scale and weight scale. + a_scale = tl.load(a_scale_ptr + off_experts) + b_scale = tl.load(b_scale_ptr + off_experts) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if even_Ks: + a = tl.load( + a_ptrs, + mask=token_mask[:, None], + other=0.0, + ) + b = tl.load(b_ptrs, cache_modifier=".ca", eviction_policy="evict_first") + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + # We accumulate along the K dimension. 
+ if use_int8_w8a16: + accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) + elif use_fp8_w8a8: + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, + mask=token_mask, + other=0.0, + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + if use_int8_w8a16: + accumulator = (accumulator * b_scale).to(compute_type) + elif use_fp8_w8a8: + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + + tl.store(c_ptrs, accumulator, mask=c_mask) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index fe8910211..eb908bc0a 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -107,6 +107,7 @@ class LinearBase(nn.Layer): or current_platform.is_iluvatar() or current_platform.is_gcu() or current_platform.is_dcu() + or current_platform.is_maca() ): self.forward = self.forward_cuda else: diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 16b75e9e2..069ee3d04 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -49,6 +49,12 @@ def get_moe_method(): from fastdeploy.model_executor.layers.backends import GCUFusedMoeMethod return GCUFusedMoeMethod(None) + elif current_platform.is_maca(): + from fastdeploy.model_executor.layers.backends import ( + MetaxTritonWeightOnlyMoEMethod, + ) + + return MetaxTritonWeightOnlyMoEMethod(None) raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py index a221dca10..4825faaf7 100644 --- a/fastdeploy/model_executor/layers/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/quantization/weight_only.py @@ -94,6 +94,16 @@ class WeightOnlyConfig(QuantConfigBase): ) return DCUWeightOnlyLinearMethod(self) + elif current_platform.is_maca(): + if isinstance(layer, FusedMoE): + from fastdeploy.model_executor.layers.backends import ( + MetaxTritonWeightOnlyMoEMethod, + ) + + return MetaxTritonWeightOnlyMoEMethod(self) + else: + + return GPUWeightOnlyLinearMethod(self) else: if isinstance(layer, FusedMoE): if layer.use_method == "cutlass": @@ -196,14 +206,24 @@ class WeightOnlyLinearMethod(QuantMethodBase): raise NotImplementedError def apply(self, layer, x): - linear_out = weight_only_linear( - x, - weight=layer.weight, - bias=layer.bias if layer.add_bias else None, - weight_scale=layer.weight_scale, - weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"), - arch=self.quant_config.weight_only_linear_arch, - ) + if current_platform.is_maca(): + 
linear_out = weight_only_linear( + x, + weight=layer.weight, + bias=layer.bias if layer.add_bias else None, + weight_scale=layer.weight_scale, + weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"), + arch=80, + ) + else: + linear_out = weight_only_linear( + x, + weight=layer.weight, + bias=layer.bias if layer.add_bias else None, + weight_scale=layer.weight_scale, + weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"), + arch=self.quant_config.weight_only_linear_arch, + ) return linear_out @@ -240,6 +260,7 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod): algo=self.quant_config.algo, arch=self.quant_config.weight_only_linear_arch, ) - + if current_platform.is_maca(): + quanted_weight_tensor = paddle.transpose(quanted_weight_tensor, [1, 0]) layer.weight.set_value(quanted_weight_tensor) layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype())) diff --git a/fastdeploy/model_executor/layers/rotary_embedding.py b/fastdeploy/model_executor/layers/rotary_embedding.py index 4c06feeab..c0e2b5a14 100644 --- a/fastdeploy/model_executor/layers/rotary_embedding.py +++ b/fastdeploy/model_executor/layers/rotary_embedding.py @@ -51,6 +51,10 @@ class ErnieRotaryEmbedding: # shape: [B, S, D] rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1) return rot_emb + elif paddle.is_compiled_with_custom_device("metax_gpu"): + # shape: [B, S, D] + rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32") + emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim)) else: # shape: [B, S, D/2] rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32") diff --git a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py index 06c7ece76..e66db93ba 100644 --- a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py +++ b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py @@ -119,6 +119,23 @@ def apply_penalty_multi_scores( min_dec_lens, eos_token_ids, ) + elif current_platform.is_maca(): + from fastdeploy.model_executor.ops.gpu import get_token_penalty_multi_scores + + logits = get_token_penalty_multi_scores( + pre_token_ids, + prompt_ids, + prompt_lens, + logits, + repetition_penalties, + frequency_penalties, + presence_penalties, + temperature, + bad_words_token_ids, + step_idx, + min_dec_lens, + eos_token_ids, + ) else: raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 6b08e02e9..cece8f870 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -177,6 +177,7 @@ class Sampler(nn.Layer): or current_platform.is_iluvatar() or current_platform.is_gcu() or current_platform.is_dcu() + or current_platform.is_maca() ): self.forward = self.forward_cuda else: diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index f622a6e39..b26746e74 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -45,6 +45,14 @@ elif current_platform.is_dcu(): step_paddle, update_inputs, ) +elif current_platform.is_maca(): + from fastdeploy.model_executor.ops.gpu import ( + get_padding_offset, + save_output, + set_stop_value_multi_ends, + step_paddle, + 
update_inputs, + ) else: from fastdeploy.model_executor.ops.gpu import ( get_padding_offset, @@ -225,6 +233,19 @@ def post_process_normal( model_output.stop_seqs_len, False, ) # multi ends + elif current_platform.is_maca(): + set_stop_value_multi_ends( + sampler_output.sampled_token_ids, + model_output.stop_flags, + model_output.seq_lens_this_time, + model_output.eos_token_id, + model_output.next_tokens, + model_output.pre_ids, + model_output.step_idx, + model_output.stop_token_ids, + model_output.stop_seqs_len, + False, + ) # multi ends else: set_stop_value_multi_ends( sampler_output.sampled_token_ids, @@ -573,6 +594,18 @@ def rebuild_padding( output_padding_offset, max_input_length, ) + elif current_platform.is_maca(): + from fastdeploy.model_executor.ops.gpu import rebuild_padding + + hidden_states = rebuild_padding( + tmp_out, + cum_offsets, + seq_len_this_time, + seq_lens_decoder, + seq_lens_encoder, + output_padding_offset, + max_input_length, + ) else: raise RuntimeError("Not supported platform") return hidden_states diff --git a/fastdeploy/platforms/__init__.py b/fastdeploy/platforms/__init__.py index 849005f48..adf5a3ad7 100644 --- a/fastdeploy/platforms/__init__.py +++ b/fastdeploy/platforms/__init__.py @@ -23,6 +23,7 @@ from .cuda import CUDAPlatform from .dcu import DCUPlatform from .gcu import GCUPlatform from .iluvatar import IluvatarPlatform +from .maca import MACAPlatform from .npu import NPUPlatform from .xpu import XPUPlatform @@ -46,6 +47,8 @@ def __getattr__(name: str): _current_platform = IluvatarPlatform() elif paddle.is_compiled_with_custom_device("gcu"): _current_platform = GCUPlatform() + elif paddle.is_compiled_with_custom_device("metax_gpu"): + _current_platform = MACAPlatform() else: _current_platform = CPUPlatform() return _current_platform diff --git a/fastdeploy/platforms/base.py b/fastdeploy/platforms/base.py index 6f4f235b8..974ab60d7 100644 --- a/fastdeploy/platforms/base.py +++ b/fastdeploy/platforms/base.py @@ -77,6 +77,12 @@ class Platform: """ return paddle.is_compiled_with_custom_device("gcu") + def is_maca(self) -> bool: + """ + whether platform is metax gpu + """ + return paddle.is_compiled_with_custom_device("metax_gpu") + @classmethod def get_attention_backend_cls(self, selected_backend): """Get the attention backend""" diff --git a/fastdeploy/platforms/maca.py b/fastdeploy/platforms/maca.py new file mode 100644 index 000000000..f695a3d01 --- /dev/null +++ b/fastdeploy/platforms/maca.py @@ -0,0 +1,65 @@ +""" +# Copyright (c) 2025 MetaX-tech Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +""" +maca platform file +""" + +import paddle +from paddleformers.utils.log import logger + +from .base import Platform, _Backend + + +class MACAPlatform(Platform): + """ + maca platform class + """ + + device_name = "metax_gpu" + + @classmethod + def available(self): + """ + Check whether MACA is available. 
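+
+        The check asserts that paddle.static.cuda_places() reports at least
+        one device; otherwise a warning is logged and False is returned.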
+ """ + try: + assert len(paddle.static.cuda_places()) > 0 + return True + except Exception as e: + logger.warning( + "You are using GPU version PaddlePaddle, but there is no GPU " + "detected on your machine. Maybe CUDA devices is not set properly." + f"\n Original Error is {e}" + ) + return False + + @classmethod + def get_attention_backend_cls(cls, selected_backend: _Backend): + """ + get_attention_backend_cls + """ + if selected_backend == _Backend.NATIVE_ATTN: + logger.info("Using NATIVE ATTN backend.") + return "fastdeploy.model_executor.layers.attention.PaddleNativeAttnBackend" + elif selected_backend == _Backend.APPEND_ATTN: + logger.info("Using FLASH ATTN backend to instead of attend attention.") + return "fastdeploy.model_executor.layers.backends.metax.attention.flash_attn_backend.FlashAttentionBackend" + else: + raise ValueError( + "Invalid attention backend you specified.\n" + "Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN] in cuda place." + ) diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py new file mode 100644 index 000000000..d0a820dbd --- /dev/null +++ b/fastdeploy/worker/metax_model_runner.py @@ -0,0 +1,1664 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import os +import time +from typing import List, Optional + +import numpy as np +import paddle +from paddle import nn +from paddleformers.utils.log import logger + +from fastdeploy.config import FDConfig +from fastdeploy.engine.request import Request, RequestType +from fastdeploy.model_executor.graph_optimization.utils import ( + profile_run_guard, + sot_warmup_guard, +) +from fastdeploy.model_executor.guided_decoding import get_guided_backend +from fastdeploy.model_executor.guided_decoding.base_guided_decoding import ( + LogitsProcessorBase, +) +from fastdeploy.model_executor.layers.attention import get_attention_backend +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, +) +from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d +from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata +from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler +from fastdeploy.model_executor.model_loader import get_model_loader +from fastdeploy.model_executor.ops.gpu import ( + recover_decode_task, + set_value_by_flags_and_idx, + share_external_data, +) +from fastdeploy.model_executor.pre_and_post_process import ( + post_process, + pre_process, + rebuild_padding, + step_cuda, +) +from fastdeploy.platforms import current_platform + +if not current_platform.is_dcu(): + from fastdeploy.spec_decode import MTPProposer, NgramProposer + +from fastdeploy import envs +from fastdeploy.input.mm_processor import DataProcessor +from fastdeploy.model_executor.forward_meta import ForwardMeta +from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp +from fastdeploy.worker.model_runner_base import ModelRunnerBase +from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput + + +class MetaxModelRunner(ModelRunnerBase): + def __init__( + self, + fd_config: FDConfig, + device: str, # logic device + device_id: int, # physical device id + rank: int, + local_rank: int, + ): + super().__init__(fd_config=fd_config, device=device) + self.enable_mm = self.model_config.enable_mm + self.rank = rank + self.local_rank = local_rank + self.device_id = device_id + self.speculative_method = self.fd_config.speculative_config.method + self.speculative_decoding = self.speculative_method is not None + self.enable_logprob = fd_config.model_config.enable_logprob + self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop + + self.guided_backend = None + if self.fd_config.parallel_config.guided_decoding_backend != "off": + self.guided_backend = get_guided_backend(fd_config=self.fd_config) + + # VL model config: + if self.enable_mm: + self._init_image_preprocess() + + self.amp_black = [ + "reduce_sum", + "c_softmax_with_cross_entropy", + "elementwise_div", + "sin", + "cos", + "sort", + "multinomial", + ] + self.amp_white = [ + "lookup_table", + "lookup_table_v2", + "flash_attn", + "matmul", + "matmul_v2", + "fused_gemm_epilogue", + ] + # Sampler + if not self.speculative_decoding: + self.sampler = Sampler(fd_config) + else: + self.sampler = SpeculativeSampler(fd_config) + + # Lazy initialize kv cache after model loading + # self.kv_caches: list[paddle.Tensor] = [] + + # Cuda Graph + self.graph_opt_level = self.graph_opt_config.graph_opt_level + self.use_cudagraph = self.graph_opt_config.use_cudagraph + self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes)) + self.sot_warmup_sizes = 
self.graph_opt_config.sot_warmup_sizes + + # Initialize share inputs + self._init_share_inputs(self.parallel_config.max_num_seqs) + self.infer_seed_increment = paddle.full( + shape=[self.parallel_config.max_num_seqs, 1], + fill_value=4, + dtype="int64", + ) + self.restore_chunked_prefill_request = dict() + + # Initialize attention Backend + # NOTE(gonshaotian): Currently, all attention layers share one attention backend instance. + # In the future, we will expand it as a list. + self.attn_backends: list[AttentionBackend] = [] + # self.attn_metadatas: list[AttentionMetadata] = [] + self.initialize_attn_backend() + + # Forward meta store the global meta information of the forward + self.forward_meta: ForwardMeta = None + + # Postprocess Env params + os.environ["INFERENCE_MSG_QUEUE_ID"] = str( + self.local_rank + int(self.parallel_config.engine_worker_queue_port) + ) + + def exist_prefill(self): + """ + check whether prefill stage exist + """ + if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0: + return 1 + else: + return 0 + + def _init_speculative_proposer(self): + """ + Init speculative proposer + """ + if self.speculative_method == "ngram": + self.proposer = NgramProposer(self.fd_config) + elif self.speculative_method == "mtp": + self.proposer = MTPProposer( + self.fd_config, + self.get_model(), + self.local_rank, + self.device_id, + self.share_inputs, + ) + else: + self.proposer = None + + def _init_logits_processor(self, request): + """ + init logits processor for guided decoding + """ + assert self.guided_backend is not None, ( + "guided_backend is None, use " "--guided-decoding-backend to specify the backend at server startup." + ) + + if request.guided_json is not None: + schemata_key = ("json", request.guided_json) + elif request.guided_regex is not None: + schemata_key = ("regex", request.guided_regex) + elif request.guided_grammar is not None: + schemata_key = ("grammar", request.guided_grammar) + elif request.structural_tag is not None: + schemata_key = ("structural_tag", request.structural_tag) + + return self.guided_backend.get_logits_processor(schemata_key=schemata_key), schemata_key + + def insert_tasks_v1(self, req_dicts: List[Request]): + """ + Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1 + """ + # NOTE(luotingdan): Lazy initialize kv cache + if "caches" not in self.share_inputs: + self.initialize_kv_cache() + + req_len = len(req_dicts) + has_prefill_task = False + for i in range(req_len): + request = req_dicts[i] + idx = request.idx + if request.task_type.value == RequestType.PREFILL.value: # prefill task + logger.debug(f"Handle prefill request {request} at idx {idx}") + prefill_start_index = request.prefill_start_index + prefill_end_index = request.prefill_end_index + length = prefill_end_index - prefill_start_index + if self.enable_mm: + inputs = request.multimodal_inputs + if request.with_image: + vision_inputs = {} + vision_inputs["input_ids"] = paddle.to_tensor( + inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64 + ) + vision_inputs["token_type_ids"] = paddle.to_tensor( + inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64 + ) + vision_inputs["image_type_ids"] = paddle.to_tensor( + inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end], + dtype=paddle.int64, + ) + vision_inputs["images"] = paddle.to_tensor( + inputs["images"][request.image_start : request.image_end], dtype="uint8" + ) + vision_inputs["grid_thw"] = paddle.to_tensor( + 
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64" + ) + self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs) + else: + self.share_inputs["image_features"] = None + + if inputs["position_ids"] is not None: + position_ids = paddle.to_tensor( + request.multimodal_inputs["position_ids"], + dtype="int64", + ).unsqueeze([0]) + else: + position_ids = None + + enable_thinking = request.get("enable_thinking", True) + enable_thinking = enable_thinking if enable_thinking is not None else True + self.share_inputs["enable_thinking"][:] = enable_thinking + self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0 + self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048) + self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d( + position_ids, request.get("max_tokens", 2048) + ) + + input_ids = request.prompt_token_ids + request.output_token_ids + self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array( + input_ids[prefill_start_index:prefill_end_index] + ) + encoder_block_num = len(request.block_tables) + self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num + self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( + request.block_tables, dtype="int32" + ) + self.share_inputs["stop_flags"][idx : idx + 1] = False + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length + self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0 + self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids) + self.share_inputs["is_block_step"][idx : idx + 1] = False + self.share_inputs["step_idx"][idx : idx + 1] = ( + len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 + ) + has_prefill_task = True + elif request.task_type.value == RequestType.DECODE.value: # decode task + logger.debug(f"Handle decode request {request} at idx {idx}") + encoder_block_num = len(request.block_tables) + self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num + self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( + request.block_tables, dtype="int32" + ) + continue + else: # preempted task + logger.debug(f"Handle preempted request {request} at idx {idx}") + self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + self.share_inputs["stop_flags"][idx : idx + 1] = True + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 0 + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 + self.share_inputs["is_block_step"][idx : idx + 1] = False + continue + + if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens: + request.eos_token_ids.append(request.eos_token_ids[0]) + self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1) + + self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7) + self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95) + self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0) + self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0) + 
self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0) + + self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) + self.share_inputs["max_dec_len"][idx : idx + 1] = request.get( + "max_tokens", self.model_config.max_model_len + ) + + self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1] + self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = length + + if request.get("seed") is not None: + self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") + + if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: + stop_seqs_num = len(request.get("stop_seqs_len")) + for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): + request.sampling_params.stop_seqs_len.append(0) + self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array( + request.sampling_params.stop_seqs_len, dtype="int32" + ) + self.share_inputs["stop_seqs"][ + idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0]) + ] = np.array(request.get("stop_token_ids"), dtype="int64") + else: + self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0 + + if has_prefill_task: + self.share_inputs["not_need_stop"][0] = True + + def insert_prefill_inputs(self, req_dicts: List[Request]): + """ + Process inputs for prefill tasks and insert it to share_inputs buffer + TODO(gongshaotian): Refactor this func + """ + + # NOTE(luotingdan): Set environment variable of prefill node + if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill": + os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1" + + req_len = len(req_dicts) + for i in range(req_len): + request = req_dicts[i] + idx = request.idx + length = len(request.prompt_token_ids) + assert length > 0, "The prompt requested must not be empty." 
+ + prefill_tokens = [] + if ( + request.guided_json is not None + or request.guided_regex is not None + or request.structural_tag is not None + or request.guided_grammar is not None + ): + logits_info, schemata_key = self._init_logits_processor(request) + request.logits_processor, request.logits_cached = logits_info + request.schemata_key = schemata_key + + # Is Decode Node + if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode": + prefill_tokens.append(request.prompt_token_ids[0]) + self.share_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1] + self.share_inputs["input_ids"][idx : idx + 1, 0] = request.prompt_token_ids[0] + self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids) + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = length + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 1 + self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = 0 + self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = length + self.share_inputs["prompt_lens"][idx : idx + 1] = length + self.share_inputs["step_idx"][idx : idx + 1] = 1 + + if self.speculative_decoding: + num_prefill_send_token = self.speculative_config.num_speculative_tokens + 1 + self.share_inputs["draft_tokens"][idx : idx + 1, 0:num_prefill_send_token] = paddle.to_tensor( + request.draft_token_ids[0:num_prefill_send_token], + dtype="int64", + ) + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = num_prefill_send_token + else: + self.share_inputs["pre_ids"][idx : idx + 1] = -1 + self.share_inputs["step_idx"][idx : idx + 1] = 0 + self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids) + self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids) + + # Use chunked prefill + if self.cache_config.enable_chunked_prefill: + request.set("chunk_idx", 1) + logger.info(f"prefill_chunk_info: {request.prefill_chunk_info}") + token_chunk_size = request.prefill_chunk_info[0] + if self.enable_mm: + inputs = self._preprocess_mm_task(token_chunk_size) + if inputs.get("images") is not None: + self.share_inputs["image_features"] = self.extract_vision_features(inputs) + else: + # Compatible with the situation that lacks images and videos + self.share_inputs["image_features"] = None + if request.multimodal_inputs["position_ids"] is not None: + position_ids = paddle.to_tensor( + request.multimodal_inputs["position_ids"], + dtype="int64", + ).unsqueeze([0]) + else: + position_ids = None + token_chunk_size = inputs["input_ids"].shape[1] + request.set("start_idx", token_chunk_size) + self.share_inputs["input_ids"][idx : idx + 1, :token_chunk_size] = inputs["input_ids"] + else: + self.share_inputs["input_ids"][idx, :token_chunk_size] = np.array( + request.prompt_token_ids[:token_chunk_size] + ) + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0) + self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0) + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size + self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = token_chunk_size + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size + self.share_inputs["prompt_lens"][idx : idx + 1] = token_chunk_size + else: + if self.enable_mm: + inputs = self._preprocess_mm_task(request.multimodal_inputs) + if inputs.get("images") is not None: + 
self.share_inputs["image_features"] = self.extract_vision_features(inputs) + else: + # Compatible with the situation that lacks images and videos + self.share_inputs["image_features"] = None + position_ids = inputs["position_ids"] + length = inputs["input_ids"].shape[1] + self.share_inputs["input_ids"][idx : idx + 1, :length] = inputs["input_ids"] + else: + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0) + self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0) + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length + self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length + self.share_inputs["prompt_lens"][idx : idx + 1] = length + + if self.enable_mm: + enable_thinking = request.get("enable_thinking", True) + enable_thinking = enable_thinking if enable_thinking is not None else True + self.share_inputs["enable_thinking"][:] = enable_thinking + self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0 + self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048) + self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d( + position_ids, request.get("max_tokens", 2048) + ) + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 + + def get_attr_from_request(request, attr, default_value=None): + res = request.get(attr, default_value) + if res is not None: + return res + else: + return default_value + + if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens: + request.eos_token_ids.append(request.eos_token_ids[0]) + self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1) + self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7) + self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0) + self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0) + + self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95) + self.share_inputs["penalty_score"][idx : idx + 1] = get_attr_from_request( + request, "repetition_penalty", 1.0 + ) + self.share_inputs["frequency_score"][idx : idx + 1] = get_attr_from_request( + request, "frequency_penalty", 0.0 + ) + self.share_inputs["presence_score"][idx : idx + 1] = get_attr_from_request( + request, "presence_penalty", 0.0 + ) + + self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) + self.share_inputs["max_dec_len"][idx : idx + 1] = request.get( + "max_tokens", self.model_config.max_model_len + ) + self.share_inputs["stop_flags"][idx : idx + 1] = False + + self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1] + self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = length + + if request.get("seed") is not None: + self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") + encoder_block_num = len(request.get("block_tables")) + self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num + self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( + request.block_tables, dtype="int32" + ) + + if request.get("bad_words_token_ids") is not None: + bad_words_len = len(request.get("bad_words_token_ids")) + if bad_words_len > 0: + self.share_inputs["bad_tokens_len"][idx : idx + 1] = 
bad_words_len + self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array( + request.get("bad_words_token_ids"), dtype="int64" + ) + + if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: + stop_seqs_num = len(request.get("stop_seqs_len")) + for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): + request.sampling_params.stop_seqs_len.append(0) + self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array( + request.sampling_params.stop_seqs_len, dtype="int32" + ) + self.share_inputs["stop_seqs"][ + idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0]) + ] = np.array(request.get("stop_token_ids"), dtype="int64") + else: + self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0 + + self.sampler.apply_logits_processor(idx, request.get("logits_processor"), prefill_tokens) + + self.share_inputs["not_need_stop"][0] = True + + if self.speculative_method in ["mtp"]: + self.proposer.insert_prefill_inputs(req_dicts) + + def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int): + """Set dummy prefill inputs to share_inputs""" + # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token + max_dec_len = expected_decode_len + 1 + full_length = min( + num_tokens // batch_size, + self.parallel_config.max_model_len - max_dec_len, + ) + input_length = int(full_length * self.cache_config.kv_cache_ratio) + block_num = ( + input_length + self.cache_config.block_size - 1 + ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num + + for i in range(batch_size): + idx = i + self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length) + self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length) + self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1) + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = input_length + self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 + self.share_inputs["prompt_lens"][idx : idx + 1] = 0 + self.share_inputs["step_idx"][idx : idx + 1] = 0 + self.share_inputs["max_dec_len"][idx : idx + 1] = max_dec_len + self.share_inputs["min_dec_len"][idx : idx + 1] = max_dec_len + self.share_inputs["stop_flags"][idx : idx + 1] = False + self.share_inputs["temperature"][idx : idx + 1] = 1 + + self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1] + self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = input_length + + self.share_inputs["encoder_block_lens"][idx : idx + 1] = block_num + self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange( + idx * block_num, (idx + 1) * block_num, 1 + ) + + def _init_share_inputs(self, max_num_seqs: int): + """ + Initialize all share buffers for model inputs. 
+ """ + self.MAX_INFER_SEED = 9223372036854775806 + self.share_inputs = {} + + self.share_inputs["pre_ids"] = paddle.full( + [max_num_seqs, self.parallel_config.max_model_len], + -1, + dtype="int64", + ) + self.share_inputs["input_ids"] = paddle.full( + [max_num_seqs, self.parallel_config.max_model_len], + self.parallel_config.pad_token_id, + dtype="int64", + ) + self.share_inputs["prompt_ids"] = paddle.full( + [max_num_seqs, self.parallel_config.max_model_len], + self.parallel_config.pad_token_id, + dtype="int64", + ) + self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64") + self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32") + self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") + self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32") + self.share_inputs["temperature"] = paddle.full( + [max_num_seqs, 1], self.model_config.temperature, dtype="float32" + ) + self.share_inputs["penalty_score"] = paddle.full( + [max_num_seqs, 1], self.model_config.penalty_score, dtype="float32" + ) + self.share_inputs["frequency_score"] = paddle.full( + [max_num_seqs, 1], + self.model_config.frequency_score, + dtype="float32", + ) + self.share_inputs["presence_score"] = paddle.full( + [max_num_seqs, 1], self.model_config.presence_score, dtype="float32" + ) + + self.share_inputs["min_dec_len"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64") + self.share_inputs["max_dec_len"] = paddle.full( + [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64" + ) + self.share_inputs["min_length"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64") + self.share_inputs["max_length"] = paddle.full( + [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64" + ) + self.share_inputs["seq_lens_this_time"] = paddle.full(max_num_seqs, 0, dtype="int32") + self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["step_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") + self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") + self.share_inputs["not_need_stop"] = paddle.full([1], False, dtype="bool").cpu() + self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool") + self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64") + + self.share_inputs["bad_tokens"] = paddle.full([max_num_seqs, self.model_config.vocab_size], -1, dtype="int64") + self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64") + self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64") + self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool") + self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32") + self.share_inputs["step_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32") + self.share_inputs["step_lens"] = paddle.full([1], 0, dtype="int32") + self.share_inputs["recover_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32") + self.share_inputs["recover_lens"] = paddle.full([1], 0, dtype="int32") 
+ self.share_inputs["need_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32") + self.share_inputs["need_block_len"] = paddle.full([1], 0, dtype="int32") + self.share_inputs["used_list_len"] = paddle.full([max_num_seqs], 0, dtype="int32") + self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") + self.share_inputs["first_token_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int64") + self.share_inputs["ori_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["system_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int32") + + self.share_inputs["ids_remove_padding"] = paddle.full( + [max_num_seqs * self.parallel_config.max_model_len], + 0, + dtype="int64", + ) + self.share_inputs["cum_offsets"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + + # Declare AttentionBackend buffers + self.share_inputs["decoder_batch_ids"] = None + self.share_inputs["decoder_tile_ids_per_batch"] = None + self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory + self.share_inputs["max_len_tensor_cpu"] = None # CPU + + # Initialize rotary position embedding + tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) + + # TODO(gongshaotian): move to models + if not self.enable_mm: + self.share_inputs["rope_emb"] = get_rope( + rotary_dim=self.model_config.head_dim, + position_ids=tmp_position_ids, + base=self.model_config.rope_theta, + model_config=self.model_config, + ) + + # Set block tables + pre_max_block_num = ( + self.parallel_config.max_model_len + self.cache_config.block_size - 1 + ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num + self.share_inputs["block_tables"] = paddle.full([max_num_seqs, pre_max_block_num], -1, dtype="int32") + + # Initialize free list + free_list = list( + range( + self.parallel_config.total_block_num - 1, + int(self.parallel_config.total_block_num * self.cache_config.kv_cache_ratio) - 1, + -1, + ) + ) + self.free_list_len = len(free_list) + self.share_inputs["free_list"] = paddle.to_tensor(free_list, dtype="int32") + self.share_inputs["free_list_len"] = paddle.full([1], self.free_list_len, dtype="int32") + + # Initialize stop seqs + self.share_inputs["stop_seqs_len"] = paddle.full( + [max_num_seqs, self.model_config.max_stop_seqs_num], 0, dtype="int32" + ) + self.share_inputs["stop_seqs"] = paddle.full( + [ + max_num_seqs, + self.model_config.max_stop_seqs_num, + self.model_config.stop_seqs_max_len, + ], + -1, + dtype="int64", + ) + if self.speculative_decoding: + max_draft_token_num = self.speculative_config.num_speculative_tokens + self.share_inputs["input_ids_cpu"] = paddle.full( + shape=[max_num_seqs, self.parallel_config.max_model_len], + fill_value=1, + dtype="int64", + ).cpu() + self.share_inputs["accept_tokens"] = paddle.full( + shape=[max_num_seqs, max_draft_token_num + 1], + fill_value=0, + dtype="int64", + ) + self.share_inputs["accept_num"] = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32") + self.share_inputs["draft_tokens"] = paddle.full( + shape=[max_num_seqs, max_draft_token_num + 1], + fill_value=0, + dtype="int64", + ) + + 
self.share_inputs["actual_draft_token_num"] = paddle.full( + shape=[max_num_seqs], + fill_value=max_draft_token_num, + dtype="int32", + ) + self.share_inputs["output_cum_offsets"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") + self.share_inputs["output_padding_offset"] = paddle.full( + shape=[max_num_seqs * (max_draft_token_num + 1)], + fill_value=0, + dtype="int32", + ) + + if self.enable_mm: + head_dim = self.model_config.head_dim + self.share_inputs["rope_emb"] = paddle.full( + shape=[ + max_num_seqs, + 2, + 1, + self.parallel_config.max_model_len, + 1, + head_dim // 2, + ], + fill_value=0, + dtype="float32", + ) + self.share_inputs["image_features"] = None + self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") + self.share_inputs["enable_thinking"] = paddle.full(shape=[1], fill_value=True, dtype="bool") + self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") + + def _prepare_inputs(self) -> None: + """Prepare the model inputs""" + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + recover_decode_task( + self.share_inputs["stop_flags"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["step_seq_lens_decoder"], + self.share_inputs["block_tables"], + self.share_inputs["is_block_step"], + self.cache_config.block_size, + ) + + # Remove padding + ( + ids_remove_padding, + cum_offsets, + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + output_cum_offsets, + output_padding_offset, + ) = pre_process( + self.share_inputs["input_ids"], + self.share_inputs["seq_lens_this_time"], + self.speculative_decoding, + (self.share_inputs["draft_tokens"] if self.speculative_decoding else None), + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], + ) + + self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) + self.share_inputs["cum_offsets"].copy_(cum_offsets, False) + self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False) + self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) + self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) + + # For speculative decoding + if self.speculative_decoding: + self.share_inputs["output_cum_offsets"].copy_(output_cum_offsets, False) + self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False) + + # Update bad tokens len + max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"]) + + # Initialize forward meta data + self.initialize_forward_meta() + + # Get sampling metadata + self.sampling_metadata = SamplingMetadata( + temperature=self.share_inputs["temperature"], + top_p=self.share_inputs["top_p"], + top_k=self.share_inputs["top_k"], + min_p=self.share_inputs["min_p"], + step_idx=self.share_inputs["step_idx"], + pre_token_ids=self.share_inputs["pre_ids"], + prompt_ids=self.share_inputs["prompt_ids"], + prompt_lens=self.share_inputs["prompt_lens"], + frequency_penalties=self.share_inputs["frequency_score"], + presence_penalties=self.share_inputs["presence_score"], + repetition_penalties=self.share_inputs["penalty_score"], + min_dec_lens=self.share_inputs["min_dec_len"], + bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len], + eos_token_ids=self.share_inputs["eos_token_id"], + max_num_logprobs=20 if self.enable_logprob else None, + enable_early_stop=self.enable_early_stop, + stop_flags=self.share_inputs["stop_flags"], + ) + + def 
load_model(self) -> None: + """load or download model""" + logger.info(f"Starting to load model {self.model_config.architectures[0]}") + # 1. Load original model + model_loader = get_model_loader(load_config=self.fd_config.load_config) + self.model = model_loader.load_model(fd_config=self.fd_config) + # 1.1 Load RL dynamic model + if self.fd_config.load_config.dynamic_load_weight: + from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager + + self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model) + + # 2. Load lora model + + # 3. Load drafter model(for speculative decoding) + + # 4. Init proposer for speculative method + self._init_speculative_proposer() + + def get_model(self) -> nn.Layer: + """Get current model""" + return self.model + + def initialize_forward_meta(self): + """ + Initialize forward meta and attention meta data + """ + # Initialize forward meta + self.forward_meta = ForwardMeta( + input_ids=self.share_inputs["input_ids"], + ids_remove_padding=self.share_inputs["ids_remove_padding"], + rotary_embs=self.share_inputs["rope_emb"], + attn_backend=self.attn_backends[0], + decoder_batch_ids=self.share_inputs["decoder_batch_ids"], + decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"], + decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"], + max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + batch_id_per_token=self.share_inputs["batch_id_per_token"], + cu_seqlens_q=self.share_inputs["cu_seqlens_q"], + cu_seqlens_k=self.share_inputs["cu_seqlens_k"], + block_tables=self.share_inputs["block_tables"], + caches=self.share_inputs["caches"], + ) + + # Update Batch type for cuda graph + only_decode_batch = True + prefill_exists = None + # mix ep in single node + if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed": + only_decode_batch_list = [] + prefill_exists = self.exist_prefill() + paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists) + only_decode_batch = all(only_decode_batch_list) + self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill" + + self.forward_meta.step_use_cudagraph = ( + self.use_cudagraph + and only_decode_batch + and not (prefill_exists if prefill_exists is not None else self.exist_prefill()) + ) + + # Initialzie attention meta data + for attn_backend in self.attn_backends: + attn_backend.init_attention_metadata(self.forward_meta) + + def initialize_kv_cache(self, profile: bool = False) -> None: + """ + Initialize kv cache + """ + cache_kvs = {} + max_block_num = self.num_gpu_blocks + + # Get kv cache dtype + cache_type = self.parallel_config.dtype + + kv_cache_quant_type = None + if ( + self.quant_config + and hasattr(self.quant_config, "kv_cache_quant_type") + and self.quant_config.kv_cache_quant_type is not None + ): + cache_type = "uint8" + kv_cache_quant_type = self.quant_config.kv_cache_quant_type + + # Get kv cache shape + kv_cache_shape = self.attn_backends[0].get_kv_cache_shape( + max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type + ) + local_rank = self.local_rank % self.parallel_config.tensor_parallel_size + + if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"): + cache_kvs_list = [] + for i in 
range(self.model_config.num_hidden_layers): + key_cache = paddle.empty(shape=[], dtype=cache_type) + key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}" + val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}" + key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape) + cache_kvs_list.append(key_cache) + value_cache = paddle.empty(shape=[], dtype=cache_type) + value_cache = share_external_data(value_cache, val_cache_name, kv_cache_shape) + cache_kvs_list.append(value_cache) + + self.share_inputs["caches"] = cache_kvs_list + + else: + for i in range(self.model_config.num_hidden_layers): + cache_kvs[f"key_caches_{i}"] = paddle.full( + shape=kv_cache_shape, + fill_value=0, + dtype=cache_type, + ) + cache_kvs[f"value_caches_{i}"] = paddle.full( + shape=kv_cache_shape, + fill_value=0, + dtype=cache_type, + ) + self.share_inputs["caches"] = list(cache_kvs.values()) + for value in cache_kvs.values(): + del value + paddle.device.cuda.empty_cache() + + def initialize_attn_backend(self) -> None: + """ + Initialize attention backends + """ + assert len(self.attn_backends) == 0 + + num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size + self.model_config.kv_num_heads = max( + 1, + int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size, + ) + head_dim = self.model_config.head_dim + + # Initialize AttentionBackend buffers + encoder_block_shape_q = 64 + decoder_block_shape_q = 16 + decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1 + decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( + (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q + ) + self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + # self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory() + # self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu() + + # Get the attention backend + attn_cls = get_attention_backend() + attn_backend = attn_cls( + self.fd_config, + kv_num_heads=self.model_config.kv_num_heads, + num_heads=num_heads, + head_dim=head_dim, + encoder_block_shape_q=encoder_block_shape_q, + decoder_block_shape_q=decoder_block_shape_q, + ) + + self.attn_backends.append(attn_backend) + + def _dummy_run( + self, + num_tokens: paddle.Tensor, + batch_size: paddle.Tensor, + expected_decode_len: int = 1, + in_capturing: bool = False, + ) -> paddle.Tensor: + """ + Use dummy inputs to run before formal execution. + Args: + num_tokens: + expected_decode_len: Expected number of tokens generated + in_capturing: Is cuda graph in capturing state + """ + self._dummy_prefill_inputs( + num_tokens=num_tokens, + batch_size=batch_size, + expected_decode_len=expected_decode_len, + ) + if self.speculative_method in ["mtp"]: + self.proposer.dummy_prefill_inputs( + num_tokens=num_tokens, + batch_size=batch_size, + expected_decode_len=expected_decode_len, + ) + while True: + + # 1. Initialize forward meta and attention meta data + self._prepare_inputs() + + # 2. Padding inputs for cuda graph + self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph + self.padding_cudagraph_inputs() + + # 3. 
Run model + if self.enable_mm: + model_output = self.model( + self.share_inputs["ids_remove_padding"], + self.share_inputs["image_features"], + self.forward_meta, + ) + hidden_states = model_output + else: + model_output = self.model( + ids_remove_padding=self.share_inputs["ids_remove_padding"], + forward_meta=self.forward_meta, + ) + + hidden_states = rebuild_padding( + model_output, + self.share_inputs["cum_offsets"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["seq_lens_encoder"], + ( + self.share_inputs["output_padding_offset"] if self.speculative_decoding else None + ), # speculative decoding requires + self.parallel_config.max_model_len, + ) + + # 4. Execute spec decode + logits = self.model.compute_logits(hidden_states) + + if not self.speculative_decoding: + set_value_by_flags_and_idx( + self.share_inputs["pre_ids"], + self.share_inputs["input_ids"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["step_idx"], + self.share_inputs["stop_flags"], + ) + sampler_output = self.sampler(logits, self.sampling_metadata) + if self.parallel_config.tensor_parallel_size > 1: + paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0) + else: + self.sampler( + logits, + self.sampling_metadata, + self.parallel_config.max_model_len, + self.share_inputs, + ) + sampler_output = None + if self.parallel_config.tensor_parallel_size > 1: + paddle.distributed.broadcast(self.share_inputs["accept_tokens"], 0) + paddle.distributed.broadcast(self.share_inputs["accept_num"], 0) + paddle.distributed.broadcast(self.share_inputs["step_idx"], 0) + paddle.distributed.broadcast(self.share_inputs["stop_flags"], 0) + + # 5. 
post process + model_output_data = ModelOutputData( + next_tokens=self.share_inputs["next_tokens"], + stop_flags=self.share_inputs["stop_flags"], + step_idx=self.share_inputs["step_idx"], + max_dec_len=self.share_inputs["max_dec_len"], + pre_ids=self.share_inputs["pre_ids"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + eos_token_id=self.share_inputs["eos_token_id"], + not_need_stop=self.share_inputs["not_need_stop"], + input_ids=self.share_inputs["input_ids"], + stop_nums=self.share_inputs["stop_nums"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + is_block_step=self.share_inputs["is_block_step"], + full_hidden_states=model_output, + msg_queue_id=self.parallel_config.msg_queue_id, + mp_rank=self.local_rank, + use_ep=self.parallel_config.use_ep, + draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None), + actual_draft_token_num=( + self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None + ), + accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), + accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), + enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None), + think_end_id=(self.model_config.think_end_id if self.enable_mm else -1), + need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None), + reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None), + stop_token_ids=self.share_inputs["stop_seqs"], + stop_seqs_len=self.share_inputs["stop_seqs_len"], + ) + + post_process( + sampler_output=sampler_output, + model_output=model_output_data, + share_inputs=self.share_inputs, + block_size=self.cache_config.block_size, + speculative_decoding=self.speculative_decoding, + skip_save_output=True, + ) + + if self.speculative_decoding: + if self.speculative_method == "mtp": + self.proposer.run(full_hidden_states=model_output) + else: + self.proposer.run(share_inputs=self.share_inputs) + + # 7. 
Updata 'infer_seed' and step_cuda() + self.share_inputs["infer_seed"].add_(self.infer_seed_increment) + self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED + step_cuda( + self.share_inputs, + self.cache_config.block_size, + self.cache_config.enc_dec_block_num, + self.speculative_config, + self.cache_config.enable_prefix_caching, + ) + + if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0: + break + + def _update_chunked_prefill(self, tasks): + """ + Update chunked prefill related parameters + """ + if not self.cache_config.enable_chunked_prefill: + return + for task in tasks: + if task.get("prefill_chunk_info", None) is None: + continue + + if task.chunk_idx > len(task.prefill_chunk_info): + continue + self.restore_chunked_prefill_request[task.request_id] = task + + for id, task in list(self.restore_chunked_prefill_request.items()): + idx = task.idx + logger.debug(f"{task.request_id} chunked prefill {task.chunk_idx}/{len(task.prefill_chunk_info)}") + if not self.enable_mm: + start_idx = sum(task.prefill_chunk_info[: task.chunk_idx]) + if task.chunk_idx == len(task.prefill_chunk_info): + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 1 + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 + self.share_inputs["step_idx"][idx : idx + 1] = 1 + if self.enable_mm: + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = task.start_idx + else: + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0) + del self.restore_chunked_prefill_request[task.request_id] + else: + token_chunk_size = task.prefill_chunk_info[task.chunk_idx] + if self.enable_mm: + inputs = self._preprocess_mm_task(task.prefill_chunk_info[task.chunk_idx]) + if inputs.get("images") is not None: + self.share_inputs["image_features"] = self.extract_vision_features(inputs) + else: + # Compatible with the situation that lacks images and videos + self.share_inputs["image_features"] = None + token_chunk_size = inputs["input_ids"].shape[1] + self.share_inputs["input_ids"][idx : idx + 1, :token_chunk_size] = inputs["input_ids"] + self.share_inputs["prompt_ids"][ + idx : idx + 1, + self.share_inputs["prompt_lens"][idx : idx + 1] : self.share_inputs["prompt_lens"][ + idx : idx + 1 + ] + + token_chunk_size, + ] = inputs["input_ids"] + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = task.start_idx + task.start_idx += token_chunk_size + else: + self.share_inputs["input_ids"][idx, :token_chunk_size] = np.array( + task.prompt_token_ids[start_idx : start_idx + token_chunk_size] + ) + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0) + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size + self.share_inputs["prompt_lens"][idx : idx + 1] += token_chunk_size + self.share_inputs["step_idx"][idx : idx + 1] = 0 + + if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled(): + self.proposer.update_task_chunk_prefill(task) + task.chunk_idx += 1 + + def capture_model(self) -> None: + """ + Trigger CUDA Graph capture for all shapes in cuda graph capture list + """ + if not self.use_cudagraph: + logger.info("Skipping CUDA graph capture. 
Please check GraphOptimizationConfig") + return + time_before_capture = time.perf_counter() + expected_decode_len = 1 + capture_sizes = self.cudagraph_capture_sizes.copy() + for batch_size in sorted(capture_sizes, reverse=True): + self._dummy_run( + num_tokens=self.parallel_config.max_num_batched_tokens, + batch_size=batch_size, + in_capturing=True, + expected_decode_len=expected_decode_len, + ) + logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}") + + time_after_capture = time.perf_counter() + logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds") + + @sot_warmup_guard(True) + def sot_warmup(self) -> None: + start_time = time.perf_counter() + for batch_size in self.sot_warmup_sizes: + self._dummy_run( + num_tokens=self.parallel_config.max_num_batched_tokens, + batch_size=batch_size, + ) + logger.info(f"SOT warmup the model with the batch size:{batch_size}") + logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds") + + def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None): + """ + Get the index of the request that needs to be skipped during execution. + Args: + model_forward_batch: A list of requests to be executed by this runner. + Returns: + A list of indices corresponding to the requests that need to be skipped. + """ + skip_idx_list = [] + if not self.cache_config.enable_chunked_prefill or self.guided_backend is None: + return skip_idx_list + + for task in model_forward_batch: + if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info): + continue + skip_idx_list.append(task.idx) + + for task in self.restore_chunked_prefill_request.values(): + if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info): + continue + skip_idx_list.append(task.idx) + + return skip_idx_list + + def execute_model( + self, + model_forward_batch: Optional[List[Request]] = None, + ) -> Optional[ModelRunnerOutput]: + """ + The Entrance of model execute. + Args: + model_forward_batch: 'Request' contains information related to prompt and is an abstract + class at the server level, which is too granular for ModelRunner. + We plan to replace it with 'ModelForwardBatch'. + intermediate_tensors: + """ + # 1. Prepare inputs of model and sampler. + skip_idx_list = self._get_skip_idx(model_forward_batch) + self._prepare_inputs() + self.sampler.pre_process(skip_idx_list) + + # NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state. + # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode, + # when there is data on other runner, the current runner is required to execute part of the model. + if not self.not_need_stop(): + self._execute_empty_input() + return None + + # 2. Padding inputs for cuda graph + self.padding_cudagraph_inputs() + + # 3. 
Execute model + if self.enable_mm: + model_output = self.model( + self.share_inputs["ids_remove_padding"], + self.share_inputs["image_features"], + self.forward_meta, + ) + hidden_states = model_output + else: + model_output = self.model( + ids_remove_padding=self.share_inputs["ids_remove_padding"], + forward_meta=self.forward_meta, + ) + hidden_states = rebuild_padding( + model_output, + self.share_inputs["cum_offsets"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["seq_lens_encoder"], + (self.share_inputs["output_padding_offset"] if self.speculative_decoding else None), + self.parallel_config.max_model_len, + ) + + # 4. Compute logits, Sample + logits = self.model.compute_logits(hidden_states) + + if not self.speculative_decoding: + set_value_by_flags_and_idx( + self.share_inputs["pre_ids"], + self.share_inputs["input_ids"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["step_idx"], + self.share_inputs["stop_flags"], + ) + sampler_output = self.sampler( + logits, + self.sampling_metadata, + skip_idx_list, + ) + if self.parallel_config.tensor_parallel_size > 1: + paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0) + + else: + self.sampler( + logits, + self.sampling_metadata, + self.parallel_config.max_model_len, + self.share_inputs, + ) + sampler_output = None + if self.parallel_config.tensor_parallel_size > 1: + paddle.distributed.broadcast(self.share_inputs["accept_tokens"], 0) + paddle.distributed.broadcast(self.share_inputs["accept_num"], 0) + paddle.distributed.broadcast(self.share_inputs["step_idx"], 0) + paddle.distributed.broadcast(self.share_inputs["stop_flags"], 0) + + # 5. 
Post Process + model_output_data = ModelOutputData( + next_tokens=self.share_inputs["next_tokens"], + stop_flags=self.share_inputs["stop_flags"], + step_idx=self.share_inputs["step_idx"], + max_dec_len=self.share_inputs["max_dec_len"], + pre_ids=self.share_inputs["pre_ids"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + eos_token_id=self.share_inputs["eos_token_id"], + not_need_stop=self.share_inputs["not_need_stop"], + input_ids=self.share_inputs["input_ids"], + stop_nums=self.share_inputs["stop_nums"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + is_block_step=self.share_inputs["is_block_step"], + full_hidden_states=model_output, + msg_queue_id=self.parallel_config.msg_queue_id, + mp_rank=self.local_rank, + use_ep=self.parallel_config.use_ep, + draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None), + actual_draft_token_num=( + self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None + ), + accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), + accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), + enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None), + think_end_id=(self.model_config.think_end_id if self.enable_mm else -1), + need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None), + reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None), + stop_token_ids=self.share_inputs["stop_seqs"], + stop_seqs_len=self.share_inputs["stop_seqs_len"], + ) + + if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill": + skip_save_output = True + else: + skip_save_output = False + post_process( + sampler_output=sampler_output, + model_output=model_output_data, + share_inputs=self.share_inputs, + block_size=self.cache_config.block_size, + save_each_rank=self.parallel_config.use_ep, + speculative_decoding=self.speculative_decoding, + skip_save_output=skip_save_output, + ) + + # 6. Speculative decode + if self.speculative_decoding: + if self.speculative_method == "mtp": + self.proposer.run(full_hidden_states=model_output) + else: + self.proposer.run(share_inputs=self.share_inputs) + + # 7. Updata 'infer_seed' and step_cuda() + self.share_inputs["infer_seed"].add_(self.infer_seed_increment) + self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: + step_cuda( + self.share_inputs, + self.cache_config.block_size, + self.cache_config.enc_dec_block_num, + self.speculative_config, + self.cache_config.enable_prefix_caching, + ) + + self._update_chunked_prefill(model_forward_batch) + self._add_cache(model_forward_batch) + return None + + def _add_cache(self, model_forward_batch) -> None: + """ + Add cache for guided decoding. 
+ """ + if self.guided_backend is None: + return + + for request in model_forward_batch: + logits_cached = request.get("logits_cached", None) + if logits_cached is None or logits_cached: + continue + + request.logits_cached = True + if isinstance(request.logits_processor, LogitsProcessorBase): + self.guided_backend.add_cache(request.schemata_key, request.logits_processor) + else: + self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result()) + + def _execute_empty_input(self) -> None: + """ + In certain scenarios, such as during EP, + the runner needs to execute partial modules of the model without input data. + This requires the model to implement the `empty_input_forward` method. + """ + if hasattr(self.model, "empty_input_forward"): + self.model.empty_input_forward() + else: + raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward") + + @profile_run_guard(True) + def profile_run(self) -> None: + """Execute a forward pass with dummy inputs to profile the memory usage of the model""" + + # Initialize kv cache for profile run. After profile run kv cache will be reset. + # TODO(gongshaotian): Optimize the management logic of kvcache + self.num_gpu_blocks = self.parallel_config.total_block_num + self.initialize_kv_cache(profile=True) + + # 1. Profile with multimodal encoder & encoder cache + + # 2. Dummy run + self._dummy_run( + num_tokens=self.parallel_config.max_num_batched_tokens, + batch_size=min(self.parallel_config.max_num_seqs, 3), + ) + + # 3. gc + self.clear_cache() + + if self.speculative_method in ["mtp"]: + self.proposer.clear_dummy_input() + + def update_share_input_block_num(self, num_gpu_blocks: int) -> None: + """ + Set a globally unified block number and update the model's shared input. + Args: + num_gpu_blocks: + """ + self.num_gpu_blocks = num_gpu_blocks + + # Reset block table and kv cache with global block num + self.initialize_kv_cache() + + # Reset free list + free_list = list( + range( + self.num_gpu_blocks - 1, + int(self.num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1, + -1, + ) + ) + self.free_list_len = len(free_list) + self.share_inputs.update( + { + "free_list": paddle.to_tensor(free_list, dtype="int32"), + "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"), + } + ) + + if self.speculative_method in ["mtp"]: + self.proposer.update_block_num(num_gpu_blocks) + + def cal_theortical_kvcache(self): + """ + Calculate the total block memory required at the model level + """ + """ + Byte of dtype: + - default(bf16): 2 + - cache_int8: 1 + - cache_int4: + """ + cache_quant_dtype = None + if ( + self.quant_config + and hasattr(self.quant_config, "kv_cache_quant_type") + and self.quant_config.kv_cache_quant_type is not None + ): + cache_quant_dtype = self.quant_config.kv_cache_quant_type + + if cache_quant_dtype is not None: # int8, int8_zp, fp8, fp8_zp + byte_of_dtype = 1 + else: # default + byte_of_dtype = 2 + + hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads + + num_layers = ( + self.model_config.num_hidden_layers + self.speculative_config.num_gpu_block_expand_ratio + if self.speculative_method in ["mtp"] + else self.model_config.num_hidden_layers + ) + required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers # k + v + return required_memory + + def not_need_stop(self) -> bool: + """Stop decoding if the tensor meets the termination condition""" + return self.share_inputs["not_need_stop"][0] + + def clear_cache(self): + """Clear cached 
data from shared inputs and forward metadata""" + self.share_inputs.pop("caches", None) + if self.forward_meta is not None: + self.forward_meta.clear_caches() + + def clear_parameters(self, pid): + """ " Dynamic model loader use to clear parameters use for RL""" + self.dynamic_weight_manager.clear_parameters(pid) + self.clear_cache() + paddle.device.cuda.empty_cache() + self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory") + + def update_parameters(self, pid): + """ " Dynamic model loader use to update parameters use for RL""" + self.dynamic_weight_manager.update_parameters(pid) + self.initialize_kv_cache() + self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory") + + def padding_cudagraph_inputs(self) -> None: + """ + Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch. + In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch. + """ + # In init_attention_metadata, the decode buffer has already been cleared + return + + def _init_image_preprocess(self) -> None: + processor = DataProcessor( + tokenizer_name=self.model_config.model, + image_preprocessor_name=str(self.model_config.model), + ) + processor.eval() + image_preprocess = processor.image_preprocessor + image_preprocess.image_mean_tensor = paddle.to_tensor(image_preprocess.image_mean, dtype="float32").reshape( + [1, 3, 1, 1] + ) + image_preprocess.image_std_tensor = paddle.to_tensor(image_preprocess.image_std, dtype="float32").reshape( + [1, 3, 1, 1] + ) + image_preprocess.rescale_factor = paddle.to_tensor(image_preprocess.rescale_factor, dtype="float32") + image_preprocess.image_mean_tensor = image_preprocess.image_mean_tensor.squeeze([-2, -1]).repeat_interleave( + self.model_config.vision_config.patch_size**2 * 1, -1 + ) + image_preprocess.image_std_tensor = image_preprocess.image_std_tensor.squeeze([-2, -1]).repeat_interleave( + self.model_config.vision_config.patch_size**2 * 1, -1 + ) + self.image_preprocess = image_preprocess + + def _preprocess_mm_task(self, one: dict) -> None: + """process batch""" + + input_ids = one["input_ids"][np.newaxis, :] + input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64) + token_type_ids = one["token_type_ids"][np.newaxis, :] + token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64) + + if one["images"] is not None: + image_type_ids = one["image_type_ids"][np.newaxis, :] + images = one["images"] + image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64) + images = paddle.to_tensor(images, dtype="uint8") + grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64") + else: + image_type_ids = None + images = None + grid_thw = None + + if one["position_ids"] is not None: + position_ids = paddle.to_tensor(one["position_ids"], dtype="int64").unsqueeze([0]) + else: + position_ids = None + + result = dict( + input_ids=input_ids, + image_type_ids=image_type_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + grid_thw=grid_thw, + images=images, + ) + return result + + @paddle.no_grad() + def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: + """extract_vision_features""" + assert inputs["images"] is not None + grid_thw = inputs["grid_thw"] + + images = inputs["images"].cast("float32") + images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor + images = images / self.image_preprocess.image_std_tensor + images = 
+
+        token_type_ids = inputs["token_type_ids"]
+        token_type_ids_w_video = token_type_ids
+        input_ids = inputs["input_ids"]
+        # convert to img patch id
+        # TODO(lulinjun): may need to check model_config and model_cfg
+        image_mask = input_ids == self.model_config.im_patch_id
+        image_type_ids = inputs["image_type_ids"]
+        with paddle.amp.auto_cast(
+            True,
+            custom_black_list=self.amp_black,
+            custom_white_list=self.amp_white,
+            level="O2",
+            dtype=self.parallel_config.dtype,
+        ):
+            image_features = self.model.vision_model.extract_feature(images, grid_thw)
+            if self.parallel_config.tensor_parallel_size > 1:
+                S, C = image_features.shape
+                image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2])
+                image_features = ScatterOp.apply(image_features, axis=-1)  # split features along the hidden dim for tensor parallelism
+                image_features = image_features.reshape([S, -1])
+            image_features = self.model.resampler_model(
+                image_features,
+                image_mask,
+                token_type_ids_w_video,
+                image_type_ids,
+                grid_thw,
+            )
+        return image_features
+
+    @paddle.no_grad()
+    def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Tensor:
+        """Prepare 3D rotary position embeddings"""
+
+        prefix_max_position_ids = paddle.max(position_ids) + 1
+        dec_pos_ids = paddle.tile(
+            paddle.arange(max_len, dtype="int64").unsqueeze(0).unsqueeze(-1),
+            [1, 1, 3],
+        )
+        dec_pos_ids = dec_pos_ids + prefix_max_position_ids
+        position_ids_3d_real = paddle.concat([position_ids, dec_pos_ids], axis=1)
+
+        rope_emb = get_rope_3d(
+            position_ids=position_ids_3d_real,
+            rotary_dim=self.model_config.head_dim,
+            partial_rotary_factor=1.0,
+            base=self.model_config.rope_theta,
+            max_position=self.parallel_config.max_model_len,
+            freq_allocation=getattr(self.model_config, "freq_allocation", 20),
+        )
+        return rope_emb
diff --git a/fastdeploy/worker/metax_worker.py b/fastdeploy/worker/metax_worker.py
new file mode 100644
index 000000000..ddf36580c
--- /dev/null
+++ b/fastdeploy/worker/metax_worker.py
@@ -0,0 +1,203 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" + +import gc +import os +import time +from typing import List, Optional + +import paddle +from paddle import nn + +from fastdeploy import envs +from fastdeploy.config import FDConfig +from fastdeploy.engine.request import Request +from fastdeploy.utils import get_logger +from fastdeploy.worker.metax_model_runner import MetaxModelRunner +from fastdeploy.worker.output import ModelRunnerOutput +from fastdeploy.worker.worker_base import WorkerBase + +logger = get_logger("metax_worker", "metax_worker.log") + + +class MetaxWorker(WorkerBase): + def __init__( + self, + fd_config: FDConfig, + local_rank: int, + rank: int, + ): + super().__init__( + fd_config=fd_config, + local_rank=local_rank, + rank=rank, + ) + pass + + def init_device(self): + """ + Initialize device and construct model runner + """ + self.max_chips_per_node = 8 + if paddle.is_compiled_with_custom_device("metax_gpu"): + # Set evironment variable + self.device_ids = self.parallel_config.device_ids.split(",") + self.device = f"metax_gpu:{self.local_rank % self.max_chips_per_node}" + paddle.device.set_device(self.device) + paddle.set_default_dtype(self.parallel_config.dtype) + + gc.collect() + paddle.device.cuda.empty_cache() + else: + raise RuntimeError(f"Not support device type: {self.device_config.device}") + + # Construct model runner + self.model_runner: MetaxModelRunner = MetaxModelRunner( + fd_config=self.fd_config, + device=self.device, + device_id=self.device_ids[self.local_rank % self.max_chips_per_node], + rank=self.rank, + local_rank=self.local_rank, + ) + + def exist_prefill(self): + """ + check whether prefill stage exist + """ + return self.model_runner.exist_prefill() + + def determine_available_memory(self) -> int: + """ + Profiles the peak memory usage of the model to determine how much + memory can be used for KV cache without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + Tip: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + """Will implement later""" + + # 1. 
+        start_time = time.perf_counter()
+        Gb = 1024**3
+
+        local_rank = self.local_rank % self.max_chips_per_node
+        paddle.device.cuda.reset_max_memory_reserved(local_rank)
+        paddle.device.cuda.reset_max_memory_allocated(local_rank)
+        # max memory for Allocator
+        paddle_reserved_mem_before_run = paddle.device.cuda.max_memory_reserved(local_rank)
+        # max memory for Tensor
+        paddle_allocated_mem_before_run = paddle.device.cuda.max_memory_allocated(local_rank)  # not reserved
+
+        device_id = int(self.device_ids[local_rank])
+        if os.getenv("MACA_VISIBLE_DEVICES") is not None:
+            device_id = int(os.getenv("MACA_VISIBLE_DEVICES").split(",")[device_id])
+
+        import pymxsml
+
+        pymxsml.mxSmlInit()
+        info = pymxsml.mxSmlGetMemoryInfo(device_id)
+        before_run_meminfo_total = info.vramTotal * 1024
+        before_run_meminfo_used = info.vramUse * 1024
+        before_run_meminfo_free = before_run_meminfo_total - before_run_meminfo_used
+
+        logger.info("Before running the profile, the memory usage info of Metax GPU is as follows:")
+        logger.info(f"Device Index: {device_id}")
+        logger.info(f"Device Total memory: {before_run_meminfo_total / Gb}")
+        logger.info(f"Device used memory: {before_run_meminfo_used / Gb}")
+        logger.info(f"Device free memory: {before_run_meminfo_free / Gb}")
+        logger.info(f"Paddle reserved memory: {paddle_reserved_mem_before_run / Gb}")
+        logger.info(f"Paddle allocated memory: {paddle_allocated_mem_before_run / Gb}")
+
+        # 2. Profile run
+        self.model_runner.profile_run()
+
+        # 3. Collect memory statistics after the profile run
+        paddle_reserved_mem_after_run = paddle.device.cuda.max_memory_reserved(local_rank)
+        paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated(local_rank)
+
+        model_block_memory_used = self.cal_theortical_kvcache()
+        paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run
+
+        paddle.device.cuda.empty_cache()
+
+        info = pymxsml.mxSmlGetMemoryInfo(device_id)
+        after_run_meminfo_total = info.vramTotal * 1024
+        after_run_meminfo_used = info.vramUse * 1024
+        after_run_meminfo_free = after_run_meminfo_total - after_run_meminfo_used
+
+        available_kv_cache_memory = (
+            after_run_meminfo_total * self.cache_config.gpu_memory_utilization
+            - after_run_meminfo_used
+            - paddle_peak_increase
+        )
+        available_kv_cache_memory += model_block_memory_used * self.parallel_config.total_block_num
+
+        end_time = time.perf_counter()
+
+        logger.info("After running the profile, the memory usage info of Metax GPU is as follows:")
+        logger.info(f"Device Index: {device_id}")
+        logger.info(f"Device Total memory: {after_run_meminfo_total / Gb}")
+        logger.info(f"Device used memory: {after_run_meminfo_used / Gb}")
+        logger.info(f"Device free memory: {after_run_meminfo_free / Gb}")
+        logger.info(f"Paddle reserved memory: {paddle_reserved_mem_after_run / Gb}")
+        logger.info(f"Paddle allocated memory: {paddle_allocated_mem_after_run / Gb}")
+        logger.info(f"Paddle available_kv_cache_memory: {available_kv_cache_memory / Gb}")
+        logger.info(f"Profile time: {end_time - start_time}")
+
+        return available_kv_cache_memory
+
+    def load_model(self) -> None:
+        """Load model"""
+        self.model_runner.load_model()
+
+    def get_model(self) -> nn.Layer:
+        """Get current model"""
+        return self.model_runner.get_model()
+
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
+        """Initialize the KV cache with the accurate num_gpu_blocks"""
+        # accurate cache size
+        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
+
+    def execute_model(
+        self,
+        model_forward_batch: Optional[List[Request]] = None,
+    ) -> Optional[ModelRunnerOutput]:
+        """Run one step of the model on the given batch"""
+        output = self.model_runner.execute_model(model_forward_batch)
+        return output
+
+    def preprocess_new_task(self, req_dicts: List[Request]) -> None:
+        """Process new requests and then start the decode loop;
+        workers and model runners should not perceive it.
+        """
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            self.model_runner.insert_tasks_v1(req_dicts=req_dicts)
+        else:
+            self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)
+
+    def check_health(self) -> bool:
+        """Health check; always reports healthy for now"""
+        return True
+
+    def cal_theortical_kvcache(self) -> int:
+        """Calculate the block memory required"""
+        return self.model_runner.cal_theortical_kvcache()
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 357fd7b85..956c89a66 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -75,6 +75,10 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
         from fastdeploy.worker.gcu_worker import GcuWorker
 
         return GcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
+    if current_platform.is_maca():
+        from fastdeploy.worker.metax_worker import MetaxWorker
+
+        return MetaxWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
 
 
 def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
diff --git a/requirements_metaxgpu.txt b/requirements_metaxgpu.txt
new file mode 100644
index 000000000..305f9825f
--- /dev/null
+++ b/requirements_metaxgpu.txt
@@ -0,0 +1,39 @@
+setuptools>=62.3.0,<80.0
+pre-commit
+yapf
+flake8
+ruamel.yaml
+zmq
+aiozmq
+openai>=1.93.0
+tqdm
+pynvml
+uvicorn
+fastapi
+paddleformers
+redis
+etcd3
+httpx
+tool_helpers
+cupy-cuda12x
+pybind11[global]
+tabulate
+gradio
+xlwt
+visualdl
+setuptools-scm>=8
+prometheus-client
+decord
+moviepy
+triton
+use-triton-in-paddle
+crcmod
+fastsafetensors==0.1.14
+msgpack
+opentelemetry-api>=1.24.0
+opentelemetry-sdk>=1.24.0
+opentelemetry-instrumentation-redis
+opentelemetry-instrumentation-mysql
+opentelemetry-distro
+opentelemetry-exporter-otlp
+opentelemetry-instrumentation-fastapi
diff --git a/setup.py b/setup.py
index 87099104b..53e5fec07 100644
--- a/setup.py
+++ b/setup.py
@@ -151,13 +151,15 @@ def load_requirements():
         requirements_file_name = "requirements_iluvatar.txt"
     elif paddle.is_compiled_with_rocm():
         requirements_file_name = "requirements_dcu.txt"
+    elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
+        requirements_file_name = "requirements_metaxgpu.txt"
     requirements_path = os.path.join(os.path.dirname(__file__), requirements_file_name)
     with open(requirements_path, "r") as f:
         return [line.strip() for line in f if line.strip() and not line.startswith("#")]
 
 
 def get_device_type():
-    """Get the device type (rocm/gpu/xpu/npu/cpu) that paddle is compiled with."""
+    """Get the device type (rocm/gpu/xpu/npu/cpu/metax-gpu) that paddle is compiled with."""
     if paddle.is_compiled_with_rocm():
         return "rocm"
     elif paddle.is_compiled_with_cuda():
@@ -170,6 +172,8 @@ def get_device_type():
         return "iluvatar-gpu"
     elif paddle.is_compiled_with_custom_device("gcu"):
         return "gcu"
+    elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
+        return "metax-gpu"
     else:
         return "cpu"
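
Note on the KV-cache sizing used above: the math implemented by cal_theortical_kvcache and determine_available_memory can be read as the standalone Python sketch below. It only restates the formulas from the patch; the function names and the example numbers are illustrative assumptions, not part of FastDeploy, and at runtime the real values come from FDConfig, paddle, and pymxsml.

    def block_bytes(block_size: int, head_dim: int, kv_num_heads: int,
                    num_layers: int, byte_of_dtype: int = 2) -> int:
        """Theoretical bytes per KV-cache block: key and value for every layer."""
        hidden_dim = head_dim * kv_num_heads
        return byte_of_dtype * 2 * (block_size * hidden_dim) * num_layers  # 2 = K + V

    def available_kv_cache_bytes(meminfo_total: int, meminfo_used_after_run: int,
                                 paddle_peak_increase: int, gpu_memory_utilization: float,
                                 profile_block_bytes: int, total_block_num: int) -> int:
        """Memory budget left for the real KV cache after the profile run.

        The cache allocated for the profile run (profile_block_bytes * total_block_num)
        is added back because it is freed and re-allocated with the final block count.
        """
        budget = meminfo_total * gpu_memory_utilization - meminfo_used_after_run - paddle_peak_increase
        return int(budget + profile_block_bytes * total_block_num)

    if __name__ == "__main__":
        # Hypothetical example: 64-token blocks, 128-dim heads, 8 KV heads, 28 layers, bf16.
        per_block = block_bytes(block_size=64, head_dim=128, kv_num_heads=8, num_layers=28)
        print(per_block)  # bytes required per KV-cache block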