[MetaxGPU] Support FastDeploy on metax gpu (#3241)

* [MetaxGPU] Support FastDeploy on metax gpu

* Update metax_worker.py

1. change worker log;
2. remove custom allreduce, adapt it later;
3. remove cuda graph;

* Update __init__.py

1. remove metax's key word comment

* Update __init__.py

1. remove metax's key word comment;
2. add fused_moe_kernel_paddle import

---------

Co-authored-by: yongqiangma <xing.wo@163.com>
Author: Kane2011
Date: 2025-08-13 11:11:54 +08:00
Committed by: GitHub
Parent: ed6bff215a
Commit: b4fef2cf29
29 changed files with 3224 additions and 11 deletions


@@ -126,6 +126,16 @@ function copy_ops(){
return
fi
is_maca=`$python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"`
if [ "$is_maca" = "True" ]; then
DEVICE_TYPE="metax_gpu"
mkdir -p ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "MACA ops have been copy to fastdeploy"
return
fi
DEVICE_TYPE="cpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cd ../../../../
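For reference, the is_maca probe above can be reproduced from a Python shell. A minimal check, assuming a Paddle build with the Metax custom-device plugin installed:

import paddle

# Prints True only when Paddle was built against the metax_gpu custom-device plugin,
# which is the same condition copy_ops() uses to select DEVICE_TYPE="metax_gpu".
print(paddle.device.is_compiled_with_custom_device("metax_gpu"))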


@@ -509,6 +509,7 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
}
#ifndef PADDLE_WITH_HIP
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
__forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr,
int mode = 0) {
uint32_t flag;
@@ -541,7 +542,7 @@ __forceinline__ __device__ void st_flag_release(uint32_t *flag_addr,
"l"(flag_addr));
}
}
#endif
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
int max_shared_mem_per_block_opt_in = 0;
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,


@@ -564,6 +564,72 @@ elif paddle.is_compiled_with_custom_device("gcu"):
]
),
)
elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
maca_path = os.getenv("MACA_PATH", "/opt/maca")
json_dir = "third_party/nlohmann_json"
if not os.path.exists(json_dir) or not os.listdir(json_dir):
if not os.path.exists(json_dir):
os.makedirs(json_dir)
clone_git_repo("v3.11.3", "https://gitee.com/learnlov/mirrors_nlohmann_json.git", json_dir)
if not os.listdir(json_dir):
raise ValueError("Git clone nlohmann_json failed!")
sources = [
"gpu_ops/save_with_output.cc",
"gpu_ops/set_mask_value.cu",
"gpu_ops/set_value_by_flags.cu",
"gpu_ops/ngram_mask.cu",
"gpu_ops/gather_idx.cu",
"gpu_ops/get_output_ep.cc",
"gpu_ops/token_penalty_multi_scores.cu",
"gpu_ops/token_penalty_only_once.cu",
"gpu_ops/stop_generation.cu",
"gpu_ops/stop_generation_multi_ends.cu",
"gpu_ops/set_flags.cu",
"gpu_ops/fused_get_rope.cu",
"gpu_ops/get_padding_offset.cu",
"gpu_ops/update_inputs.cu",
"gpu_ops/update_inputs_beam.cu",
"gpu_ops/beam_search_softmax.cu",
"gpu_ops/rebuild_padding.cu",
"gpu_ops/step.cu",
"gpu_ops/step_reschedule.cu",
"gpu_ops/step_system_cache.cu",
"gpu_ops/set_data_ipc.cu",
"gpu_ops/read_data_ipc.cu",
"gpu_ops/dequant_int8.cu",
"gpu_ops/share_external_data.cu",
"gpu_ops/extract_text_token_output.cu",
"gpu_ops/moe/tritonmoe_preprocess.cu",
"gpu_ops/moe/moe_topk_select.cu",
"gpu_ops/recover_decode_task.cu",
]
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
sources += find_end_files("gpu_ops/speculate_decoding", ".cc")
setup(
name="fastdeploy_ops",
ext_modules=CUDAExtension(
sources=sources,
extra_compile_args={
"cxx": ["-O3"],
"nvcc": [
"-O3",
"-Ithird_party/nlohmann_json/include",
"-Igpu_ops",
"-DPADDLE_DEV",
"-DPADDLE_WITH_CUSTOM_DEVICE_METAX_GPU",
],
},
library_dirs=[os.path.join(maca_path, "lib")],
extra_link_args=["-lruntime_cu"],
include_dirs=[
os.path.join(maca_path, "include"),
os.path.join(maca_path, "include/mcr"),
os.path.join(maca_path, "include/common"),
],
),
)
else:
use_bf16 = envs.FD_CPU_USE_BF16 == "True"


@@ -37,6 +37,8 @@ class ForwardMode(IntEnum):
DECODE = auto()
# Mixed mode
MIXED = auto()
# Native mode
NATIVE = auto()
def is_prefill(self):
"""Is Extend mode"""
@@ -50,6 +52,10 @@ class ForwardMode(IntEnum):
"""Is Mixed mode"""
return self == ForwardMode.MIXED
def is_native(self):
"""Is Native mode"""
return self == ForwardMode.NATIVE
@dataclass
class ForwardMeta:


@@ -68,6 +68,7 @@ class SiluAndMul(nn.Layer):
or current_platform.is_xpu()
or current_platform.is_iluvatar()
or current_platform.is_dcu()
or current_platform.is_maca()
):
self.forward = self.forward_cuda
elif current_platform.is_gcu():


@@ -86,6 +86,15 @@ class AttentionBackend(ABC):
layer,
forward_meta,
)
elif forward_meta.forward_mode.is_native():
return self.forward_native_backend(
q,
k,
v,
qkv,
layer,
forward_meta,
)
else:
return self.forward_extend(
q,
@@ -139,3 +148,15 @@ class AttentionBackend(ABC):
) -> paddle.Tensor:
"""Run a forward for extend."""
raise NotImplementedError
def forward_native_backend(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
layer: paddle.nn.Layer,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""Run a forward for native."""
raise NotImplementedError
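For illustration only (this is not the Metax implementation added later in this commit), a concrete backend hooks into the new native path by overriding forward_native_backend. A minimal sketch using a plain softmax(QK^T / sqrt(d)) V fallback, where NaiveAttentionBackend is a hypothetical example class:

import math

import paddle
import paddle.nn.functional as F

from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend,
)


class NaiveAttentionBackend(AttentionBackend):  # hypothetical, for illustration only
    def forward_native_backend(self, q, k, v, qkv, layer, forward_meta):
        # q/k/v assumed as [tokens, num_heads, head_dim] for a single sequence, no KV cache.
        qt, kt, vt = (t.transpose([1, 0, 2]) for t in (q, k, v))  # [heads, tokens, dim]
        score = paddle.matmul(qt, kt, transpose_y=True) / math.sqrt(q.shape[-1])
        prob = F.softmax(score, axis=-1)
        return paddle.matmul(prob, vt).transpose([1, 0, 2])  # back to [tokens, heads, dim]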


@@ -48,3 +48,10 @@ if current_platform.is_dcu():
if hasattr(dcu, "__all__"):
globals().update({name: getattr(dcu, name) for name in dcu.__all__})
__all__.extend(dcu.__all__)
if current_platform.is_maca():
from . import metax
if hasattr(metax, "__all__"):
globals().update({name: getattr(metax, name) for name in metax.__all__})
__all__.extend(metax.__all__)


@@ -0,0 +1,21 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .attention.flash_attn_backend import FlashAttentionBackend
from .moe.fused_moe_triton_metax_backend import MetaxTritonWeightOnlyMoEMethod
__all__ = [
"FlashAttentionBackend",
"MetaxTritonWeightOnlyMoEMethod",
]


@@ -0,0 +1,30 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
metax gpu backend attention methods
"""
from .flash_attention_interface import (
flash_attn_func,
flash_attn_kvcache_func,
flash_attn_unpadded_func,
)
from .flash_attn_backend import FlashAttentionBackend
__all__ = [
"FlashAttentionBackend",
"flash_attn_func",
"flash_attn_unpadded_func",
"flash_attn_kvcache_func",
]


@@ -0,0 +1,104 @@
import os
from typing import Optional, Tuple, Union
import paddle
from paddle import Tensor
for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
if lib.endswith(".so"):
paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
def flash_attn_func(
q: Tensor,
k: Tensor,
v: Tensor,
fixed_seed_offset: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
dropout_prob: float = 0.0,
causal: bool = False,
return_softmax: bool = False,
is_test: bool = True,
rng_name: str = "",
) -> Union[Tensor, Tuple[Tensor, ...]]:
return paddle._C_ops.flash_attn(
q, k, v, fixed_seed_offset, attn_mask, dropout_prob, causal, return_softmax, is_test, rng_name
)
def flash_attn_unpadded_func(
q: Tensor,
k: Tensor,
v: Tensor,
cu_seqlens_q: Tensor,
cu_seqlens_k: Tensor,
max_seqlen_q: Union[int, float],
max_seqlen_k: Union[int, float],
fixed_seed_offset: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
softmax_scale: float = 1.0,
dropout: float = 0.0,
causal: bool = False,
return_softmax: bool = False,
is_test: bool = True,
rng_name: str = "",
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
max_seqlen_q_t = paddle.to_tensor(max_seqlen_q, dtype="int64")
max_seqlen_k_t = paddle.to_tensor(max_seqlen_k, dtype="int64")
outputs = paddle._C_ops.flash_attn_unpadded(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
fixed_seed_offset,
attn_mask,
max_seqlen_q_t,
max_seqlen_k_t,
softmax_scale,
dropout,
causal,
return_softmax,
is_test,
rng_name,
)
return outputs
def flash_attn_kvcache_func(
q: Tensor,
k_cache: Tensor,
v_cache: Tensor,
seqlens_k: Tensor,
block_table: Tensor,
k: Optional[Tensor] = None,
v: Optional[Tensor] = None,
rotary_cos: Optional[Tensor] = None,
rotary_sin: Optional[Tensor] = None,
cache_batch_idx: Optional[Tensor] = None,
causal: bool = True,
is_rotary_interleaved: bool = False,
num_splits: int = 1,
dropout: float = 0.0,
return_softmax: bool = False,
) -> Tuple[Tensor, Tensor]:
out, softmax_lse = paddle._C_ops._run_custom_op(
"flash_attn_kvcache",
q,
k_cache,
v_cache,
k,
v,
seqlens_k,
rotary_cos,
rotary_sin,
cache_batch_idx,
block_table,
causal,
is_rotary_interleaved,
num_splits,
dropout,
return_softmax,
)
return out, softmax_lse
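A minimal smoke test for the wrappers above, assuming it runs on a Metax device where the custom flash-attention ops have been registered (shapes follow the [batch, seqlen, num_heads, head_dim] convention of Paddle's flash_attn op):

import paddle

from fastdeploy.model_executor.layers.backends.metax.attention import flash_attn_func

q = paddle.randn([2, 128, 8, 64]).astype("float16")
k = paddle.randn([2, 128, 8, 64]).astype("float16")
v = paddle.randn([2, 128, 8, 64]).astype("float16")

# Causal self-attention; per the wrapper's annotation the underlying op may return either
# a single tensor or a tuple whose first element is the attention output.
outputs = flash_attn_func(q, k, v, causal=True)
out = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
print(out.shape)  # expected: [2, 128, 8, 64]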


@@ -0,0 +1,393 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import math
import os
from dataclasses import dataclass, field
from typing import List, Optional
import paddle
import paddle.nn.functional as F
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta, ForwardMode
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.model_executor.layers.backends.metax.attention.flash_attention_interface import (
flash_attn_kvcache_func,
flash_attn_unpadded_func,
)
@dataclass
class FlashAttentionMetadata(AttentionMetadata):
"""
FlashAttentionMetadata
"""
max_len_kv: paddle.Tensor = None
set_max_lengths: int = -1
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
encoder_block_shape_q: int = -1
decoder_block_shape_q: int = -1
_fuse_kernel_compute_dtype: str = "bf16"
# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
class FlashAttentionBackend(AttentionBackend):
"""
FlashAttentionBackend backend implementation.
"""
__infer_dynamic_dims_fields__ = ["attention_metadata"]
attention_metadata: FlashAttentionMetadata
def __init__(
self,
fd_config: FDConfig,
kv_num_heads: int,
num_heads: int,
head_dim: int,
encoder_block_shape_q: int = -1,
decoder_block_shape_q: int = -1,
) -> None:
"""
FlashAttentionBackend __init__
"""
super().__init__()
self.attention_metadata: FlashAttentionMetadata = None
self.block_size: int = fd_config.parallel_config.block_size
self.max_seq_len: int = fd_config.parallel_config.max_model_len
self.rope_theta: float = (
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method: str = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
self.encoder_block_shape_q: int = encoder_block_shape_q
self.decoder_block_shape_q: int = decoder_block_shape_q
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
self.head_dim: int = fd_config.model_config.head_dim
self.num_layers: int = fd_config.model_config.num_hidden_layers
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 32768))
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
self.start_layer_index: int = fd_config.model_config.start_layer_index
if fd_config.parallel_config.expert_parallel_rank is None:
fd_config.parallel_config.expert_parallel_rank = 0
self.rank, self.device_id = init_rank_and_device_id(fd_config)
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
forward_meta.forward_mode = ForwardMode.NATIVE
return
def get_attntion_meta(self) -> AttentionMetadata:
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(
self,
max_num_blocks: int,
kv_cache_quant_type: str = None,
):
"""
Calculate kv cache shape
"""
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
return (
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
)
else:
return (
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim,
)
def split_qkv(self, qkv, num_head_q, num_head_kv, dim):
q = qkv[:, : num_head_q * dim].reshape([-1, num_head_q, dim])
k = qkv[:, num_head_q * dim : num_head_q * dim + num_head_kv * dim].reshape([-1, num_head_kv, dim])
v = qkv[:, num_head_q * dim + num_head_kv * dim :].reshape([-1, num_head_kv, dim])
return q, k, v
def flash_attn_varlen(self, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
num_head = q.shape[1]
dim = q.shape[2]
q_ = q.reshape([-1, num_head, dim])
k_ = k.reshape([-1, num_head, dim])
v_ = v.reshape([-1, num_head, dim])
bsz = cu_seqlens_q.shape[0] - 1
out = []
for i in range(bsz):
start_q, end_q = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
start_k, end_k = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()
qi = q_[start_q:end_q] # [seq_q, nh, dim]
ki = k_[start_k:end_k] # [seq_k, nh, dim]
vi = v_[start_k:end_k] # [seq_k, nh, dim]
qi = qi.transpose([1, 0, 2]) # [nh, seq_q, dim]
ki = ki.transpose([1, 2, 0]) # [nh, dim, seq_k]
vi = vi.transpose([1, 0, 2]) # [nh, seq_k, dim]
score = paddle.matmul(qi, ki) / math.sqrt(dim) # [nh, seq_q, seq_k]
prob = F.softmax(score, axis=-1)
o = paddle.matmul(prob, vi) # [nh, seq_q, dim]
o = o.transpose([1, 0, 2]) # [seq_q, nh, dim]
out.append(o)
return paddle.concat(out, axis=0) # [total_q, nh, dim]
def flash_attn_with_kvcache(self, q, cache_k, cache_v, cache_seqlens, block_tables=None):
bs, _, nh, dim = q.shape
out = []
for i in range(bs):
q_i = q[i] # [1, nh, dim]
k_i = cache_k[i, : cache_seqlens[i, 0]] # [seqlen, nh, dim]
v_i = cache_v[i, : cache_seqlens[i, 0]]
qi = q_i.transpose([1, 0, 2]) # [nh, 1, dim]
ki = k_i.transpose([1, 2, 0]) # [nh, dim, seqlen]
vi = v_i.transpose([1, 0, 2]) # [nh, seqlen, dim]
score = paddle.matmul(qi, ki) / math.sqrt(dim)
prob = F.softmax(score, axis=-1)
o = paddle.matmul(prob, vi).transpose([1, 0, 2]) # [1, nh, dim]
out.append(o)
return paddle.concat(out, axis=0) # [bs, nh, dim]
def block_cache_to_naive_cache(self, cache_k, cache_v, bsz, block_tables, cache_seq_len):
_, num_head, blocksize, dim_head = cache_k.shape
out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype)
out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype)
for i in range(bsz):
for j in range(cache_seq_len):
out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
return out_cache_k, out_cache_v
def block_cache_to_naive_cache__(self, cache_k, cache_v, bsz, block_tables, max_cache_seq_len):
_, num_head, blocksize, dim_head = cache_k.shape
out_cache_k = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_k.dtype)
out_cache_v = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_v.dtype)
for i in range(bsz):
for j in range(max_cache_seq_len):
out_cache_k[i, j, :, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
out_cache_v[i, j, :, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
return out_cache_k, out_cache_v
def update_encoder_kv_cache(self, k, v, seq_lens_encoder, cache_k, cache_v, block_tables):
_, num_head, blocksize, dim_head = cache_k.shape
offset = 0
for batch_idx, seq_len in enumerate(seq_lens_encoder.numpy()):
if seq_len == 0:
continue
for seq_idx in range(seq_len):
block_id = block_tables[batch_idx, seq_idx // blocksize]
assert block_id != -1
index = offset + seq_idx
cache_k[block_id, :, seq_idx % blocksize, :] = k[index, :, :]
cache_v[block_id, :, seq_idx % blocksize, :] = v[index, :, :]
offset += seq_len
def update_decoder_kv_cache(self, k, v, seq_lens_decoder, cache_k, cache_v, block_tables):
_, num_head, blocksize, dim_head = cache_k.shape
for batch_idx, seq_idx in enumerate(seq_lens_decoder.numpy()):
if seq_idx == 0:
continue
block_id = block_tables[batch_idx, seq_idx // blocksize]
assert block_id != -1
cache_k[block_id, :, seq_idx % blocksize, :] = k[batch_idx, :, :]
cache_v[block_id, :, seq_idx % blocksize, :] = v[batch_idx, :, :]
def apply_rope(self, qk, cos, sin):
rotate_half = paddle.reshape(
paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1),
paddle.shape(qk),
)
out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin))
return paddle.cast(out, qk.dtype)
def forward_native_backend(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
layer,
forward_meta: ForwardMeta,
):
bsz = forward_meta.seq_lens_this_time.shape[0]
num_head_q, num_head_kv, dim = layer.num_heads, layer.kv_num_heads, layer.head_dim
# 1. Separate the encoder / decoder masks
seq_lens_encoder = forward_meta.seq_lens_encoder.squeeze(-1)
seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
seq_lens_this_time = forward_meta.seq_lens_this_time.squeeze(-1)
encoder_indices = []
decoder_indices = []
offset = 0
for i in range(bsz):
length = seq_lens_this_time[i].item()
if seq_lens_encoder[i] > 0:
encoder_indices.extend(range(offset, offset + length))
elif seq_lens_decoder[i] > 0:
decoder_indices.extend(range(offset, offset + length))
offset += length
encoder_indices = paddle.to_tensor(encoder_indices, dtype="int32")
decoder_indices = paddle.to_tensor(decoder_indices, dtype="int32")
encoder_qkv = paddle.index_select(qkv, encoder_indices, axis=0)
decoder_qkv = paddle.index_select(qkv, decoder_indices, axis=0)
# 2. Split the encoder and decoder qkv
encoder_q, encoder_k, encoder_v = self.split_qkv(encoder_qkv, num_head_q, num_head_kv, dim)
decoder_q, decoder_k, decoder_v = self.split_qkv(decoder_qkv, num_head_q, num_head_kv, dim)
cache_k = forward_meta.caches[2 * layer.layer_id]
cache_v = forward_meta.caches[2 * layer.layer_id + 1]
# 3. Rotary Embedding
if decoder_q.numel() != 0 or encoder_q.numel() != 0:
for batch_idx in range(forward_meta.seq_lens_this_time.shape[0]):
seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
if seq_len_i == 0:
continue
cached_kv_len = seq_lens_decoder[batch_idx]
cu_seq_start_q = forward_meta.cu_seqlens_q[batch_idx]
cu_seq_end_q = forward_meta.cu_seqlens_q[batch_idx + 1]
if forward_meta.rotary_embs is not None and cu_seq_end_q > cu_seq_start_q:
cos = forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
sin = forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
def rope_func(qk):
qk[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(qk[cu_seq_start_q:cu_seq_end_q], cos, sin)
if encoder_q.numel() != 0:
rope_func(encoder_q)
rope_func(encoder_k)
if decoder_q.numel() != 0:
rope_func(decoder_q)
rope_func(decoder_k)
# 4. Flash Attention for encoder
encoder_v = encoder_v
cu_seqlens_q = forward_meta.cu_seqlens_q
cu_seqlens_k = forward_meta.cu_seqlens_k
max_seqlen_q = paddle.max(seq_lens_this_time)
max_seqlen_k = max_seqlen_q
if encoder_q.numel() > 0:
encoder_out = flash_attn_unpadded_func(
encoder_q,
encoder_k,
encoder_v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
attn_mask=forward_meta.attn_mask,
causal=self.causal,
)
self.update_encoder_kv_cache(
encoder_k, encoder_v, seq_lens_encoder, cache_k, cache_v, forward_meta.block_tables
)
else:
encoder_out = None
# 5. decoder attention with kv cache
bs = decoder_q.shape[0]
decoder_q = decoder_q.reshape([bs, 1, num_head_q, dim])
decoder_k_ = decoder_k.reshape([bs, 1, num_head_kv, dim])
decoder_v_ = decoder_v.reshape([bs, 1, num_head_kv, dim])
cache_seqlens = paddle.index_select(forward_meta.seq_lens_decoder, decoder_indices, axis=0)
# 5.1 convert paged kv cache to continuous cache
if decoder_q.numel() > 0:
max_cache_seq_len = paddle.max(cache_seqlens)
c_cache_k, c_cache_v = self.block_cache_to_naive_cache__(
cache_k, cache_v, bs, forward_meta.block_tables, max_cache_seq_len
)
decoder_out = flash_attn_kvcache_func(
decoder_q,
c_cache_k,
c_cache_v,
cache_seqlens.squeeze(-1),
None,
decoder_k_,
decoder_v_,
causal=self.causal,
)
self.update_decoder_kv_cache(
decoder_k, decoder_v, seq_lens_decoder, cache_k, cache_v, forward_meta.block_tables
)
else:
decoder_out = None
# 6. Concatenate encoder_out and decoder_out
total_len = qkv.shape[0]
out = paddle.zeros([total_len, num_head_q, dim])
if encoder_out is not None:
out = paddle.tensor.put_along_axis(
out, encoder_indices.unsqueeze(-1).unsqueeze(-1), encoder_out[0], axis=0
)
if decoder_out is not None:
new_decoder_out = decoder_out[0].squeeze(1)
out = paddle.tensor.put_along_axis(
out, decoder_indices.unsqueeze(-1).unsqueeze(-1), new_decoder_out, axis=0
)
out.reshape_([total_len, num_head_q * dim])
return out


@@ -0,0 +1,19 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .triton_moe_kernels import fused_moe_kernel_paddle
__all__ = [
"fused_moe_kernel_paddle",
]


@@ -0,0 +1,276 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from paddle import nn
import fastdeploy
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
from fastdeploy.utils import ceil_div
from .triton_moe_kernels import fused_moe_kernel_paddle
class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
"""
Use Triton Group Gemm to compute Fused MoE.
"""
def __init__(self, quant_config=None):
"""
Triton Group Gemm to compute Fused MoE.
"""
self.quant_config = quant_config
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"up_gate_proj_weight_scale",
"down_proj_weight_scale",
]
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
"""process_prequanted_weights"""
pass
def create_weights(self, layer: nn.Layer, state_dict):
"""
Triton MoE create weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
if layer.quant_method.quant_config:
algo = layer.quant_method.quant_config.name()
assert up_gate_proj_weights[0].shape == [
layer.hidden_size,
layer.moe_intermediate_size * 2,
]
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size,
layer.hidden_size,
]
up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_tensor = paddle.stack(down_proj_weights, axis=0)
if algo == "wint8":
max_bound = 127
elif algo == "wint4":
max_bound = 7
for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
quanted_weight_scale = weight_tensor.abs().max(axis=1)
if self.quant_config is not None:
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
quanted_weight = paddle.round(quanted_weight).astype("int8")
quanted_weight_scale = quanted_weight_scale / max_bound
setattr(
layer,
weight_name,
layer.create_parameter(
shape=quanted_weight.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).set_value(quanted_weight)
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
),
)
getattr(layer, scale_name).set_value(quanted_weight_scale)
else:
setattr(
layer,
weight_name,
layer.create_parameter(
shape=quanted_weight.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).set_value(quanted_weight)
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
),
)
getattr(layer, scale_name).set_value(quanted_weight_scale)
def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
"""
token_num = x.shape[0]
top_k = layer.top_k
num_local_experts = layer.num_local_experts
top_k = layer.top_k
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
top_k,
True, # apply_norm_weight,
False,
)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
)
if self.quant_config is not None:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
}
else:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
)
max_possible_num_post_padded = sorted_token_ids.shape[0]
grid = (
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
* ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]),
)
fused_moe_kernel_paddle[grid](
x,
layer.up_gate_proj_weight,
up_gate_proj_out,
None,
layer.up_gate_proj_weight_scale,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
max_possible_num_post_padded,
token_num * top_k,
N=moe_intermediate_size * 2,
K=hidden_size,
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.up_gate_proj_weight.strides[0],
stride_bk=layer.up_gate_proj_weight.strides[1],
stride_bn=layer.up_gate_proj_weight.strides[2],
stride_cm=up_gate_proj_out.strides[0],
stride_cn=up_gate_proj_out.strides[1],
#
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
GROUP_SIZE_M=config["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=False,
top_k=top_k,
compute_type_enum=1,
use_fp8_w8a8=False,
use_int8_w8a16=True,
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
)
down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out)
down_proj_out = paddle.empty(
(token_num * top_k, hidden_size),
dtype=x.dtype,
)
grid = (
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
* ceil_div(hidden_size, config["BLOCK_SIZE_N"]),
)
fused_moe_kernel_paddle[grid](
down_proj_input,
layer.down_proj_weight,
down_proj_out,
None,
layer.down_proj_weight_scale,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
max_possible_num_post_padded,
token_num * top_k,
N=hidden_size,
K=moe_intermediate_size,
stride_am=down_proj_input.strides[0],
stride_ak=down_proj_input.strides[1],
stride_be=layer.down_proj_weight.strides[0],
stride_bk=layer.down_proj_weight.strides[1],
stride_bn=layer.down_proj_weight.strides[2],
stride_cm=down_proj_out.strides[0],
stride_cn=down_proj_out.strides[1],
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.down_proj_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.down_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
GROUP_SIZE_M=config["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=True,
top_k=1,
compute_type_enum=1,
use_fp8_w8a8=False,
use_int8_w8a16=True,
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
)
down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
return out
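The wint8 branch in create_weights above is per-channel absmax quantization. A standalone sketch of the same arithmetic on a single expert weight, with made-up shapes and no FastDeploy dependency:

import paddle

w = paddle.randn([4, 8], dtype="float32")  # [K, N] weight of one expert
scale = w.abs().max(axis=0)  # per output channel, shape [N]
q_w = paddle.round(w / scale[None, :] * 127).astype("int8")  # wint8: max_bound = 127
dequant = q_w.astype("float32") * (scale / 127)[None, :]  # stored scale is scale / max_bound
print(float((w - dequant).abs().max()))  # small reconstruction error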


@@ -0,0 +1,187 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import triton
import triton.language as tl
@triton.jit
def fused_moe_kernel_paddle(
a_ptr,
b_ptr,
c_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
max_possible_num_post_padded,
num_valid_tokens,
N,
K,
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Block size for block-wise fp8 quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type_enum: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
even_Ks: tl.constexpr,
):
"""
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
assert compute_type_enum == 1
compute_type = tl.bfloat16
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
if use_int8_w8a16:
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8:
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
else:
# (Zkk): every expert has one activation scale and weight scale.
a_scale = tl.load(a_scale_ptr + off_experts)
b_scale = tl.load(b_scale_ptr + off_experts)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
if even_Ks:
a = tl.load(
a_ptrs,
mask=token_mask[:, None],
other=0.0,
)
b = tl.load(b_ptrs, cache_modifier=".ca", eviction_policy="evict_first")
else:
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8:
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_scale = tl.load(
a_scale_ptrs + offs_ks * stride_ask,
mask=token_mask,
other=0.0,
)
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else:
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8:
if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type)
else:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
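To make the docstring's sorted_token_ids / expert_ids layout concrete, here is a small pure-Python illustration of the group-by-expert-and-pad idea. The real layout comes from the tritonmoe_preprocess op; the padding value used here (num_valid_tokens, which the kernel masks out via token_mask) is an assumption for illustration:

# 4 tokens, top_k = 2, 3 experts, BLOCK_SIZE_M = 4.
topk_ids = [[0, 2], [1, 2], [0, 1], [2, 0]]  # expert choices per token
num_experts, block_m = 3, 4
num_valid_tokens = len(topk_ids) * len(topk_ids[0])

buckets = {e: [] for e in range(num_experts)}
for token, experts in enumerate(topk_ids):
    for slot, e in enumerate(experts):
        buckets[e].append(token * len(experts) + slot)  # flattened (token, slot) index

sorted_token_ids, expert_ids = [], []
for e, ids in buckets.items():
    ids += [num_valid_tokens] * (-len(ids) % block_m)  # pad each expert group to BLOCK_SIZE_M
    sorted_token_ids += ids
    expert_ids += [e] * (len(ids) // block_m)  # one expert id per BLOCK_SIZE_M block

print(sorted_token_ids)  # [0, 4, 7, 8, 2, 5, 8, 8, 1, 3, 6, 8]
print(expert_ids)        # [0, 1, 2]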


@@ -107,6 +107,7 @@ class LinearBase(nn.Layer):
or current_platform.is_iluvatar()
or current_platform.is_gcu()
or current_platform.is_dcu()
or current_platform.is_maca()
):
self.forward = self.forward_cuda
else:


@@ -49,6 +49,12 @@ def get_moe_method():
from fastdeploy.model_executor.layers.backends import GCUFusedMoeMethod
return GCUFusedMoeMethod(None)
elif current_platform.is_maca():
from fastdeploy.model_executor.layers.backends import (
MetaxTritonWeightOnlyMoEMethod,
)
return MetaxTritonWeightOnlyMoEMethod(None)
raise NotImplementedError


@@ -94,6 +94,16 @@ class WeightOnlyConfig(QuantConfigBase):
)
return DCUWeightOnlyLinearMethod(self)
elif current_platform.is_maca():
if isinstance(layer, FusedMoE):
from fastdeploy.model_executor.layers.backends import (
MetaxTritonWeightOnlyMoEMethod,
)
return MetaxTritonWeightOnlyMoEMethod(self)
else:
return GPUWeightOnlyLinearMethod(self)
else:
if isinstance(layer, FusedMoE):
if layer.use_method == "cutlass":
@@ -196,6 +206,16 @@ class WeightOnlyLinearMethod(QuantMethodBase):
raise NotImplementedError
def apply(self, layer, x):
if current_platform.is_maca():
linear_out = weight_only_linear(
x,
weight=layer.weight,
bias=layer.bias if layer.add_bias else None,
weight_scale=layer.weight_scale,
weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
arch=80,
)
else:
linear_out = weight_only_linear(
x,
weight=layer.weight,
@@ -240,6 +260,7 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
algo=self.quant_config.algo,
arch=self.quant_config.weight_only_linear_arch,
)
if current_platform.is_maca():
quanted_weight_tensor = paddle.transpose(quanted_weight_tensor, [1, 0])
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))
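Outside of FastDeploy, the same weight-only pattern can be exercised directly with Paddle's quant helpers. A rough sketch, assuming paddle.nn.quant.weight_quantize and weight_only_linear are available in the installed Paddle build (argument names follow Paddle's documented API, not this diff):

import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize

x = paddle.randn([16, 512]).astype("float16")
w = paddle.randn([512, 1024]).astype("float16")

# Offline: quantize the fp16 weight to int8 plus a per-channel scale.
qw, scale = weight_quantize(w, algo="weight_only_int8")

# Online: int8 weight-only GEMM; arch is the target compute capability
# (the Metax branch above hard-codes arch=80).
y = weight_only_linear(x, qw, weight_scale=scale, weight_dtype="int8", arch=80)
print(y.shape)  # [16, 1024]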


@@ -51,6 +51,10 @@ class ErnieRotaryEmbedding:
# shape: [B, S, D]
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
return rot_emb
elif paddle.is_compiled_with_custom_device("metax_gpu"):
# shape: [B, S, D]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
else:
# shape: [B, S, D/2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
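The full-rotary_dim layout in the metax_gpu branch matches apply_rope in the Metax flash-attention backend, which rotates interleaved even/odd channels. A tiny numeric sketch of that rotate-half step, simplified to one token and one head:

import paddle

qk = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0]])  # [1 token, rotary_dim = 4]
cos = paddle.full_like(qk, 0.0)  # toy angles: cos = 0, sin = 1
sin = paddle.full_like(qk, 1.0)

rotate_half = paddle.reshape(
    paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1), paddle.shape(qk)
)
out = qk * cos + rotate_half * sin
print(out.numpy())  # [[-2., 1., -4., 3.]]: each (even, odd) channel pair rotated by 90 degrees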


@@ -119,6 +119,23 @@ def apply_penalty_multi_scores(
min_dec_lens,
eos_token_ids,
)
elif current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import get_token_penalty_multi_scores
logits = get_token_penalty_multi_scores(
pre_token_ids,
prompt_ids,
prompt_lens,
logits,
repetition_penalties,
frequency_penalties,
presence_penalties,
temperature,
bad_words_token_ids,
step_idx,
min_dec_lens,
eos_token_ids,
)
else:
raise NotImplementedError


@@ -177,6 +177,7 @@ class Sampler(nn.Layer):
or current_platform.is_iluvatar()
or current_platform.is_gcu()
or current_platform.is_dcu()
or current_platform.is_maca()
):
self.forward = self.forward_cuda
else:


@@ -45,6 +45,14 @@ elif current_platform.is_dcu():
step_paddle,
update_inputs,
)
elif current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import (
get_padding_offset,
save_output,
set_stop_value_multi_ends,
step_paddle,
update_inputs,
)
else:
from fastdeploy.model_executor.ops.gpu import (
get_padding_offset,
@@ -225,6 +233,19 @@ def post_process_normal(
model_output.stop_seqs_len,
False,
) # multi ends
elif current_platform.is_maca():
set_stop_value_multi_ends(
sampler_output.sampled_token_ids,
model_output.stop_flags,
model_output.seq_lens_this_time,
model_output.eos_token_id,
model_output.next_tokens,
model_output.pre_ids,
model_output.step_idx,
model_output.stop_token_ids,
model_output.stop_seqs_len,
False,
) # multi ends
else:
set_stop_value_multi_ends(
sampler_output.sampled_token_ids,
@@ -573,6 +594,18 @@ def rebuild_padding(
output_padding_offset,
max_input_length,
)
elif current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import rebuild_padding
hidden_states = rebuild_padding(
tmp_out,
cum_offsets,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length,
)
else:
raise RuntimeError("Not supported platform")
return hidden_states


@@ -23,6 +23,7 @@ from .cuda import CUDAPlatform
from .dcu import DCUPlatform
from .gcu import GCUPlatform
from .iluvatar import IluvatarPlatform
from .maca import MACAPlatform
from .npu import NPUPlatform
from .xpu import XPUPlatform
@@ -46,6 +47,8 @@ def __getattr__(name: str):
_current_platform = IluvatarPlatform()
elif paddle.is_compiled_with_custom_device("gcu"):
_current_platform = GCUPlatform()
elif paddle.is_compiled_with_custom_device("metax_gpu"):
_current_platform = MACAPlatform()
else:
_current_platform = CPUPlatform()
return _current_platform


@@ -77,6 +77,12 @@ class Platform:
"""
return paddle.is_compiled_with_custom_device("gcu")
def is_maca(self) -> bool:
"""
whether platform is metax gpu
"""
return paddle.is_compiled_with_custom_device("metax_gpu")
@classmethod
def get_attention_backend_cls(self, selected_backend):
"""Get the attention backend"""


@@ -0,0 +1,65 @@
"""
# Copyright (c) 2025 MetaX-tech Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
"""
maca platform file
"""
import paddle
from paddleformers.utils.log import logger
from .base import Platform, _Backend
class MACAPlatform(Platform):
"""
maca platform class
"""
device_name = "metax_gpu"
@classmethod
def available(self):
"""
Check whether MACA is available.
"""
try:
assert len(paddle.static.cuda_places()) > 0
return True
except Exception as e:
logger.warning(
"You are using GPU version PaddlePaddle, but there is no GPU "
"detected on your machine. Maybe CUDA devices is not set properly."
f"\n Original Error is {e}"
)
return False
@classmethod
def get_attention_backend_cls(cls, selected_backend: _Backend):
"""
get_attention_backend_cls
"""
if selected_backend == _Backend.NATIVE_ATTN:
logger.info("Using NATIVE ATTN backend.")
return "fastdeploy.model_executor.layers.attention.PaddleNativeAttnBackend"
elif selected_backend == _Backend.APPEND_ATTN:
logger.info("Using FLASH ATTN backend to instead of attend attention.")
return "fastdeploy.model_executor.layers.backends.metax.attention.flash_attn_backend.FlashAttentionBackend"
else:
raise ValueError(
"Invalid attention backend you specified.\n"
"Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN] in cuda place."
)
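get_attention_backend_cls returns a dotted class path rather than a class object. FastDeploy resolves such strings elsewhere; a minimal sketch of that kind of resolution with plain importlib, where resolve_class is a hypothetical helper:

import importlib


def resolve_class(dotted_path: str):
    """Import 'pkg.module.ClassName' and return the class it names."""
    module_path, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)


backend_cls = resolve_class(
    "fastdeploy.model_executor.layers.backends.metax.attention.flash_attn_backend.FlashAttentionBackend"
)
print(backend_cls.__name__)  # FlashAttentionBackend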

File diff suppressed because it is too large


@@ -0,0 +1,203 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import gc
import os
import time
from typing import List, Optional
import paddle
from paddle import nn
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import get_logger
from fastdeploy.worker.metax_model_runner import MetaxModelRunner
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.worker_base import WorkerBase
logger = get_logger("metax_worker", "metax_worker.log")
class MetaxWorker(WorkerBase):
def __init__(
self,
fd_config: FDConfig,
local_rank: int,
rank: int,
):
super().__init__(
fd_config=fd_config,
local_rank=local_rank,
rank=rank,
)
pass
def init_device(self):
"""
Initialize device and construct model runner
"""
self.max_chips_per_node = 8
if paddle.is_compiled_with_custom_device("metax_gpu"):
# Set environment variable
self.device_ids = self.parallel_config.device_ids.split(",")
self.device = f"metax_gpu:{self.local_rank % self.max_chips_per_node}"
paddle.device.set_device(self.device)
paddle.set_default_dtype(self.parallel_config.dtype)
gc.collect()
paddle.device.cuda.empty_cache()
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
# Construct model runner
self.model_runner: MetaxModelRunner = MetaxModelRunner(
fd_config=self.fd_config,
device=self.device,
device_id=self.device_ids[self.local_rank % self.max_chips_per_node],
rank=self.rank,
local_rank=self.local_rank,
)
def exist_prefill(self):
"""
check whether a prefill stage exists
"""
return self.model_runner.exist_prefill()
def determine_available_memory(self) -> int:
"""
Profiles the peak memory usage of the model to determine how much
memory can be used for KV cache without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
Tip:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
"""Will implement later"""
# 1. Record memory state before profile run
start_time = time.perf_counter()
Gb = 1024**3
local_rank = self.local_rank % self.max_chips_per_node
paddle.device.cuda.reset_max_memory_reserved(local_rank)
paddle.device.cuda.reset_max_memory_allocated(local_rank)
# max memory for Allocator
paddle_reserved_mem_before_run = paddle.device.cuda.max_memory_reserved(local_rank)
# max memory for Tensor
paddle_allocated_mem_before_run = paddle.device.cuda.max_memory_allocated(local_rank) # not reserved
device_id = int(self.device_ids[local_rank])
if os.getenv("MACA_VISIBLE_DEVICES") is not None:
device_id = int(os.getenv("MACA_VISIBLE_DEVICES").split(",")[device_id])
import pymxsml
pymxsml.mxSmlInit()
info = pymxsml.mxSmlGetMemoryInfo(device_id)
before_run_meminfo_total = info.vramTotal * 1024
before_run_meminfo_used = info.vramUse * 1024
before_run_meminfo_free = before_run_meminfo_total - before_run_meminfo_used
logger.info("Before running the profile, the memory usage info of Metax GPU is as follows:")
logger.info(f"Device Index: {device_id}")
logger.info(f"Device Total memory: {before_run_meminfo_total / Gb}")
logger.info(f"Device used memory: {before_run_meminfo_used / Gb}")
logger.info(f"Device free memory: {before_run_meminfo_free / Gb}")
logger.info(f"Paddle reserved memory: {paddle_reserved_mem_before_run / Gb}")
logger.info(f"Paddle allocated memory: {paddle_allocated_mem_before_run / Gb}")
# 2. Profile run
self.model_runner.profile_run()
# 3. Statistical memory information
paddle_reserved_mem_after_run = paddle.device.cuda.max_memory_reserved(local_rank)
paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated(local_rank)
model_block_memory_used = self.cal_theortical_kvcache()
paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run
paddle.device.cuda.empty_cache()
info = pymxsml.mxSmlGetMemoryInfo(device_id)
after_run_meminfo_total = info.vramTotal * 1024
after_run_meminfo_used = info.vramUse * 1024
after_run_meminfo_free = after_run_meminfo_total - after_run_meminfo_used
available_kv_cache_memory = (
after_run_meminfo_total * self.cache_config.gpu_memory_utilization
- after_run_meminfo_used
- paddle_peak_increase
)
available_kv_cache_memory += model_block_memory_used * self.parallel_config.total_block_num
end_time = time.perf_counter()
logger.info("After running the profile, the memory usage info of Metax GPU is as follows:")
logger.info(f"Device Index: {device_id}")
logger.info(f"Device Total memory: {after_run_meminfo_total / Gb}")
logger.info(f"Device used memory: {after_run_meminfo_used / Gb}")
logger.info(f"Device free memory: {after_run_meminfo_free / Gb}")
logger.info(f"Paddle reserved memory: {paddle_reserved_mem_after_run / Gb}")
logger.info(f"Paddle allocated memory: {paddle_allocated_mem_after_run / Gb}")
logger.info(f"Paddle available_kv_cache_memory: {available_kv_cache_memory / Gb}")
logger.info(f"Profile time: {end_time - start_time}")
return available_kv_cache_memory
def load_model(self) -> None:
"""Load model"""
self.model_runner.load_model()
def get_model(self) -> nn.Layer:
"""Get current model"""
return self.model_runner.get_model()
def initialize_cache(self, num_gpu_blocks: int) -> None:
"""Initizlize the KV Cache with accurate num_gpu_blocks"""
# accurate cache size
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
def execute_model(
self,
model_forward_batch: Optional[List[Request]] = None,
) -> Optional[ModelRunnerOutput]:
""" """
output = self.model_runner.execute_model(model_forward_batch)
return output
def preprocess_new_task(self, req_dicts: List[Request]) -> None:
"""Process new requests and then start the decode loop
and workers and modelrunners should not perceive it.
"""
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.model_runner.insert_tasks_v1(req_dicts=req_dicts)
else:
self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)
def check_health(self) -> bool:
""" """
return True
def cal_theortical_kvcache(self) -> int:
"""Calculate the block memory required"""
return self.model_runner.cal_theortical_kvcache()
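The available_kv_cache_memory computed in determine_available_memory follows a simple budget formula. A worked sketch with made-up numbers (bytes), not measured on any real device:

Gb = 1024**3

# Hypothetical profiling results on one Metax GPU.
meminfo_total = 64 * Gb                # total device memory after the profile run
meminfo_used = 30 * Gb                 # device memory in use after the profile run
paddle_peak_increase = 6 * Gb          # reserved-memory growth caused by the profile run
gpu_memory_utilization = 0.9           # cache_config.gpu_memory_utilization
model_block_memory_used = 2 * 1024**2  # bytes per KV-cache block (cal_theortical_kvcache)
total_block_num = 1024                 # blocks already allocated before profiling

available_kv_cache_memory = (
    meminfo_total * gpu_memory_utilization - meminfo_used - paddle_peak_increase
)
available_kv_cache_memory += model_block_memory_used * total_block_num

print(f"{available_kv_cache_memory / Gb:.2f} GiB available for the KV cache")
# 64 * 0.9 - 30 - 6 = 21.6 GiB, plus 1024 blocks * 2 MiB = 2 GiB already held -> 23.6 GiB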


@@ -75,6 +75,10 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
from fastdeploy.worker.gcu_worker import GcuWorker
return GcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
if current_platform.is_maca():
from fastdeploy.worker.metax_worker import MetaxWorker
return MetaxWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:

requirements_metaxgpu.txt (new file, 39 lines)

@@ -0,0 +1,39 @@
setuptools>=62.3.0,<80.0
pre-commit
yapf
flake8
ruamel.yaml
zmq
aiozmq
openai>=1.93.0
tqdm
pynvml
uvicorn
fastapi
paddleformers
redis
etcd3
httpx
tool_helpers
cupy-cuda12x
pybind11[global]
tabulate
gradio
xlwt
visualdl
setuptools-scm>=8
prometheus-client
decord
moviepy
triton
use-triton-in-paddle
crcmod
fastsafetensors==0.1.14
msgpack
opentelemetry-api>=1.24.0
opentelemetry-sdk>=1.24.0
opentelemetry-instrumentation-redis
opentelemetry-instrumentation-mysql
opentelemetry-distro 
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi


@@ -151,13 +151,15 @@ def load_requirements():
requirements_file_name = "requirements_iluvatar.txt"
elif paddle.is_compiled_with_rocm():
requirements_file_name = "requirements_dcu.txt"
elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
requirements_file_name = "requirements_metaxgpu.txt"
requirements_path = os.path.join(os.path.dirname(__file__), requirements_file_name)
with open(requirements_path, "r") as f:
return [line.strip() for line in f if line.strip() and not line.startswith("#")]
def get_device_type():
"""Get the device type (rocm/gpu/xpu/npu/cpu) that paddle is compiled with."""
"""Get the device type (rocm/gpu/xpu/npu/cpu/metax-gpu) that paddle is compiled with."""
if paddle.is_compiled_with_rocm():
return "rocm"
elif paddle.is_compiled_with_cuda():
@@ -170,6 +172,8 @@ def get_device_type():
return "iluvatar-gpu"
elif paddle.is_compiled_with_custom_device("gcu"):
return "gcu"
elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
return "metax-gpu"
else:
return "cpu"