[MetaxGPU] Support FastDeploy on metax gpu (#3241)

* [MetaxGPU] Support FastDeploy on metax gpu

* Update metax_worker.py

1. Change the worker log.
2. Remove custom allreduce; adapt it later.
3. Remove CUDA graph.

* Update __init__.py

1. Remove Metax's keyword comment.

* Update __init__.py

1. Remove Metax's keyword comment.
2. Add the fused_moe_kernel_paddle import.

---------

Co-authored-by: yongqiangma <xing.wo@163.com>
Author: Kane2011
Date: 2025-08-13 11:11:54 +08:00
Committed by: GitHub
Parent: ed6bff215a
Commit: b4fef2cf29
29 changed files with 3224 additions and 11 deletions

View File

@@ -68,6 +68,7 @@ class SiluAndMul(nn.Layer):
or current_platform.is_xpu()
or current_platform.is_iluvatar()
or current_platform.is_dcu()
or current_platform.is_maca()
):
self.forward = self.forward_cuda
elif current_platform.is_gcu():
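
The hunk above adds current_platform.is_maca() so Metax (MACA) devices take the CUDA forward path of SiluAndMul. The pattern is plain constructor-time dispatch; a minimal, self-contained sketch with a hypothetical class and illustrative platform strings (not FastDeploy's actual API):

class DispatchingLayer:
    def __init__(self, platform: str):
        # MACA (Metax) reuses the CUDA code path, so it is grouped with the CUDA-like targets.
        if platform in ("cuda", "xpu", "iluvatar", "dcu", "maca"):
            self.forward = self.forward_cuda
        else:
            raise NotImplementedError(f"unsupported platform: {platform}")

    def forward_cuda(self, x):
        return x  # placeholder body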

View File

@@ -86,6 +86,15 @@ class AttentionBackend(ABC):
layer,
forward_meta,
)
elif forward_meta.forward_mode.is_native():
return self.forward_native_backend(
q,
k,
v,
qkv,
layer,
forward_meta,
)
else:
return self.forward_extend(
q,
@@ -139,3 +148,15 @@ class AttentionBackend(ABC):
) -> paddle.Tensor:
"""Run a forward for extend."""
raise NotImplementedError
def forward_native_backend(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
layer: paddle.nn.Layer,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""Run a forward for native."""
raise NotImplementedError
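
A backend opts into this new native path by setting the forward mode during metadata initialization and overriding forward_native_backend, which is exactly what the Metax backend later in this commit does. A partial sketch, assuming the AttentionBackend and ForwardMode symbols used elsewhere in this diff (a real backend also implements the other forward_* hooks):

import paddle

from fastdeploy.model_executor.forward_meta import ForwardMode
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend,
)


class MyNativeBackend(AttentionBackend):  # hypothetical example backend
    def init_attention_metadata(self, forward_meta):
        # Route forward() to forward_native_backend() for every layer in this pass.
        forward_meta.forward_mode = ForwardMode.NATIVE

    def forward_native_backend(self, q, k, v, qkv, layer, forward_meta):
        # Any pure-Paddle attention implementation can go here.
        return paddle.zeros_like(qkv)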

View File

@@ -48,3 +48,10 @@ if current_platform.is_dcu():
if hasattr(dcu, "__all__"):
globals().update({name: getattr(dcu, name) for name in dcu.__all__})
__all__.extend(dcu.__all__)
if current_platform.is_maca():
from . import metax
if hasattr(metax, "__all__"):
globals().update({name: getattr(metax, name) for name in metax.__all__})
__all__.extend(metax.__all__)
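
With this conditional re-export, code running on a MACA (Metax) build can import the new classes from the generic backends package instead of reaching into the metax submodule, e.g. (sketch, requires a Metax build of Paddle):

from fastdeploy.model_executor.layers.backends import (
    FlashAttentionBackend,
    MetaxTritonWeightOnlyMoEMethod,
)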

View File

@@ -0,0 +1,21 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .attention.flash_attn_backend import FlashAttentionBackend
from .moe.fused_moe_triton_metax_backend import MetaxTritonWeightOnlyMoEMethod
__all__ = [
"FlashAttentionBackend",
"MetaxTritonWeightOnlyMoEMethod",
]

View File

@@ -0,0 +1,30 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Metax GPU backend attention methods
"""
from .flash_attention_interface import (
flash_attn_func,
flash_attn_kvcache_func,
flash_attn_unpadded_func,
)
from .flash_attn_backend import FlashAttentionBackend
__all__ = [
"FlashAttentionBackend",
"flash_attn_func",
"flash_attn_unpadded_func",
"flash_attn_kvcache_func",
]

View File

@@ -0,0 +1,104 @@
import os
from typing import Optional, Tuple, Union
import paddle
from paddle import Tensor
for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
if lib.endswith(".so"):
paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
def flash_attn_func(
q: Tensor,
k: Tensor,
v: Tensor,
fixed_seed_offset: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
dropout_prob: float = 0.0,
causal: bool = False,
return_softmax: bool = False,
is_test: bool = True,
rng_name: str = "",
) -> Union[Tensor, Tuple[Tensor, ...]]:
return paddle._C_ops.flash_attn(
q, k, v, fixed_seed_offset, attn_mask, dropout_prob, causal, return_softmax, is_test, rng_name
)
def flash_attn_unpadded_func(
q: Tensor,
k: Tensor,
v: Tensor,
cu_seqlens_q: Tensor,
cu_seqlens_k: Tensor,
max_seqlen_q: Union[int, float],
max_seqlen_k: Union[int, float],
fixed_seed_offset: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
softmax_scale: float = 1.0,
dropout: float = 0.0,
causal: bool = False,
return_softmax: bool = False,
is_test: bool = True,
rng_name: str = "",
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
max_seqlen_q_t = paddle.to_tensor(max_seqlen_q, dtype="int64")
max_seqlen_k_t = paddle.to_tensor(max_seqlen_k, dtype="int64")
outputs = paddle._C_ops.flash_attn_unpadded(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
fixed_seed_offset,
attn_mask,
max_seqlen_q_t,
max_seqlen_k_t,
softmax_scale,
dropout,
causal,
return_softmax,
is_test,
rng_name,
)
return outputs
def flash_attn_kvcache_func(
q: Tensor,
k_cache: Tensor,
v_cache: Tensor,
seqlens_k: Tensor,
block_table: Tensor,
k: Optional[Tensor] = None,
v: Optional[Tensor] = None,
rotary_cos: Optional[Tensor] = None,
rotary_sin: Optional[Tensor] = None,
cache_batch_idx: Optional[Tensor] = None,
causal: bool = True,
is_rotary_interleaved: bool = False,
num_splits: int = 1,
dropout: float = 0.0,
return_softmax: bool = False,
) -> Tuple[Tensor, Tensor]:
out, softmax_lse = paddle._C_ops._run_custom_op(
"flash_attn_kvcache",
q,
k_cache,
v_cache,
k,
v,
seqlens_k,
rotary_cos,
rotary_sin,
cache_batch_idx,
block_table,
causal,
is_rotary_interleaved,
num_splits,
dropout,
return_softmax,
)
return out, softmax_lse
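
For reference, a toy call of the varlen wrapper above: tokens are packed as [total_tokens, num_heads, head_dim] with cu_seqlens prefix sums marking sequence boundaries. This is a sketch only; it assumes a Metax build of Paddle with the custom flash-attention ops registered, and the toy shapes are illustrative.

import paddle

num_heads, head_dim = 8, 128
seq_lens = [5, 7]                               # two variable-length sequences
total_tokens = sum(seq_lens)

q = paddle.randn([total_tokens, num_heads, head_dim])
k = paddle.randn([total_tokens, num_heads, head_dim])
v = paddle.randn([total_tokens, num_heads, head_dim])
cu_seqlens = paddle.to_tensor([0, 5, 12], dtype="int32")   # prefix sums of seq_lens

out = flash_attn_unpadded_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seq_lens),
    max_seqlen_k=max(seq_lens),
    softmax_scale=head_dim ** -0.5,
    causal=True,
)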

View File

@@ -0,0 +1,393 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import math
import os
from dataclasses import dataclass, field
from typing import List, Optional
import paddle
import paddle.nn.functional as F
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta, ForwardMode
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.model_executor.layers.backends.metax.attention.flash_attention_interface import (
flash_attn_kvcache_func,
flash_attn_unpadded_func,
)
@dataclass
class FlashAttentionMetadata(AttentionMetadata):
"""
FlashAttentionMetadata
"""
max_len_kv: paddle.Tensor = None
set_max_lengths: int = -1
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
encoder_block_shape_q: int = -1
decoder_block_shape_q: int = -1
_fuse_kernel_compute_dtype: str = "bf16"
# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
class FlashAttentionBackend(AttentionBackend):
"""
FlashAttention backend implementation for Metax GPUs.
"""
__infer_dynamic_dims_fields__ = ["attention_metadata"]
attention_metadata: FlashAttentionMetadata
def __init__(
self,
fd_config: FDConfig,
kv_num_heads: int,
num_heads: int,
head_dim: int,
encoder_block_shape_q: int = -1,
decoder_block_shape_q: int = -1,
) -> None:
"""
FlashAttentionBackend __init__
"""
super().__init__()
self.attention_metadata: FlashAttentionMetadata = None
self.block_size: int = fd_config.parallel_config.block_size
self.max_seq_len: int = fd_config.parallel_config.max_model_len
self.rope_theta: float = (
10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method: str = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
self.encoder_block_shape_q: int = encoder_block_shape_q
self.decoder_block_shape_q: int = decoder_block_shape_q
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
self.head_dim: int = fd_config.model_config.head_dim
self.num_layers: int = fd_config.model_config.num_hidden_layers
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 32768))
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
self.start_layer_index: int = fd_config.model_config.start_layer_index
if fd_config.parallel_config.expert_parallel_rank is None:
fd_config.parallel_config.expert_parallel_rank = 0
self.rank, self.device_id = init_rank_and_device_id(fd_config)
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
forward_meta.forward_mode = ForwardMode.NATIVE
return
def get_attntion_meta(self) -> AttentionMetadata:
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(
self,
max_num_blocks: int,
kv_cache_quant_type: str = None,
):
"""
Calculate the KV cache shape
"""
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
return (
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
)
else:
return (
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim,
)
def split_qkv(self, qkv, num_head_q, num_head_kv, dim):
q = qkv[:, : num_head_q * dim].reshape([-1, num_head_q, dim])
k = qkv[:, num_head_q * dim : num_head_q * dim + num_head_kv * dim].reshape([-1, num_head_kv, dim])
v = qkv[:, num_head_q * dim + num_head_kv * dim :].reshape([-1, num_head_kv, dim])
return q, k, v
def flash_attn_varlen(self, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
num_head = q.shape[1]
dim = q.shape[2]
q_ = q.reshape([-1, num_head, dim])
k_ = k.reshape([-1, num_head, dim])
v_ = v.reshape([-1, num_head, dim])
bsz = cu_seqlens_q.shape[0] - 1
out = []
for i in range(bsz):
start_q, end_q = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
start_k, end_k = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()
qi = q_[start_q:end_q] # [seq_q, nh, dim]
ki = k_[start_k:end_k] # [seq_k, nh, dim]
vi = v_[start_k:end_k] # [seq_k, nh, dim]
qi = qi.transpose([1, 0, 2]) # [nh, seq_q, dim]
ki = ki.transpose([1, 2, 0]) # [nh, dim, seq_k]
vi = vi.transpose([1, 0, 2]) # [nh, seq_k, dim]
score = paddle.matmul(qi, ki) / math.sqrt(dim) # [nh, seq_q, seq_k]
prob = F.softmax(score, axis=-1)
o = paddle.matmul(prob, vi) # [nh, seq_q, dim]
o = o.transpose([1, 0, 2]) # [seq_q, nh, dim]
out.append(o)
return paddle.concat(out, axis=0) # [total_q, nh, dim]
def flash_attn_with_kvcache(self, q, cache_k, cache_v, cache_seqlens, block_tables=None):
bs, _, nh, dim = q.shape
out = []
for i in range(bs):
q_i = q[i] # [1, nh, dim]
k_i = cache_k[i, : cache_seqlens[i, 0]] # [seqlen, nh, dim]
v_i = cache_v[i, : cache_seqlens[i, 0]]
qi = q_i.transpose([1, 0, 2]) # [nh, 1, dim]
ki = k_i.transpose([1, 2, 0]) # [nh, dim, seqlen]
vi = v_i.transpose([1, 0, 2]) # [nh, seqlen, dim]
score = paddle.matmul(qi, ki) / math.sqrt(dim)
prob = F.softmax(score, axis=-1)
o = paddle.matmul(prob, vi).transpose([1, 0, 2]) # [1, nh, dim]
out.append(o)
return paddle.concat(out, axis=0) # [bs, nh, dim]
def block_cache_to_naive_cache(self, cache_k, cache_v, bsz, block_tables, cache_seq_len):
_, num_head, blocksize, dim_head = cache_k.shape
out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype)
out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype)
for i in range(bsz):
for j in range(cache_seq_len):
out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
return out_cache_k, out_cache_v
def block_cache_to_naive_cache__(self, cache_k, cache_v, bsz, block_tables, max_cache_seq_len):
_, num_head, blocksize, dim_head = cache_k.shape
out_cache_k = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_k.dtype)
out_cache_v = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_v.dtype)
for i in range(bsz):
for j in range(max_cache_seq_len):
out_cache_k[i, j, :, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
out_cache_v[i, j, :, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
return out_cache_k, out_cache_v
def update_encoder_kv_cache(self, k, v, seq_lens_encoder, cache_k, cache_v, block_tables):
_, num_head, blocksize, dim_head = cache_k.shape
offset = 0
for batch_idx, seq_len in enumerate(seq_lens_encoder.numpy()):
if seq_len == 0:
continue
for seq_idx in range(seq_len):
block_id = block_tables[batch_idx, seq_idx // blocksize]
assert block_id != -1
index = offset + seq_idx
cache_k[block_id, :, seq_idx % blocksize, :] = k[index, :, :]
cache_v[block_id, :, seq_idx % blocksize, :] = v[index, :, :]
offset += seq_len
def update_decoder_kv_cache(self, k, v, seq_lens_decoder, cache_k, cache_v, block_tables):
_, num_head, blocksize, dim_head = cache_k.shape
for batch_idx, seq_idx in enumerate(seq_lens_decoder.numpy()):
if seq_idx == 0:
continue
block_id = block_tables[batch_idx, seq_idx // blocksize]
assert block_id != -1
cache_k[block_id, :, seq_idx % blocksize, :] = k[batch_idx, :, :]
cache_v[block_id, :, seq_idx % blocksize, :] = v[batch_idx, :, :]
def apply_rope(self, qk, cos, sin):
rotate_half = paddle.reshape(
paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1),
paddle.shape(qk),
)
out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin))
return paddle.cast(out, qk.dtype)
def forward_native_backend(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
layer,
forward_meta: ForwardMeta,
):
bsz = forward_meta.seq_lens_this_time.shape[0]
num_head_q, num_head_kv, dim = layer.num_heads, layer.kv_num_heads, layer.head_dim
# 1. Separate encoder / decoder token indices
seq_lens_encoder = forward_meta.seq_lens_encoder.squeeze(-1)
seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
seq_lens_this_time = forward_meta.seq_lens_this_time.squeeze(-1)
encoder_indices = []
decoder_indices = []
offset = 0
for i in range(bsz):
length = seq_lens_this_time[i].item()
if seq_lens_encoder[i] > 0:
encoder_indices.extend(range(offset, offset + length))
elif seq_lens_decoder[i] > 0:
decoder_indices.extend(range(offset, offset + length))
offset += length
encoder_indices = paddle.to_tensor(encoder_indices, dtype="int32")
decoder_indices = paddle.to_tensor(decoder_indices, dtype="int32")
encoder_qkv = paddle.index_select(qkv, encoder_indices, axis=0)
decoder_qkv = paddle.index_select(qkv, decoder_indices, axis=0)
# 2. Split the encoder and decoder qkv
encoder_q, encoder_k, encoder_v = self.split_qkv(encoder_qkv, num_head_q, num_head_kv, dim)
decoder_q, decoder_k, decoder_v = self.split_qkv(decoder_qkv, num_head_q, num_head_kv, dim)
cache_k = forward_meta.caches[2 * layer.layer_id]
cache_v = forward_meta.caches[2 * layer.layer_id + 1]
# 3. Rotary Embedding
if decoder_q.numel() != 0 or encoder_q.numel() != 0:
for batch_idx in range(forward_meta.seq_lens_this_time.shape[0]):
seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
if seq_len_i == 0:
continue
cached_kv_len = seq_lens_decoder[batch_idx]
cu_seq_start_q = forward_meta.cu_seqlens_q[batch_idx]
cu_seq_end_q = forward_meta.cu_seqlens_q[batch_idx + 1]
if forward_meta.rotary_embs is not None and cu_seq_end_q > cu_seq_start_q:
cos = forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
sin = forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
def rope_func(qk):
qk[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(qk[cu_seq_start_q:cu_seq_end_q], cos, sin)
if encoder_q.numel() != 0:
rope_func(encoder_q)
rope_func(encoder_k)
if decoder_q.numel() != 0:
rope_func(decoder_q)
rope_func(decoder_k)
# 4. Flash Attention for encoder
encoder_v = encoder_v
cu_seqlens_q = forward_meta.cu_seqlens_q
cu_seqlens_k = forward_meta.cu_seqlens_k
max_seqlen_q = paddle.max(seq_lens_this_time)
max_seqlen_k = max_seqlen_q
if encoder_q.numel() > 0:
encoder_out = flash_attn_unpadded_func(
encoder_q,
encoder_k,
encoder_v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
attn_mask=forward_meta.attn_mask,
causal=self.causal,
)
self.update_encoder_kv_cache(
encoder_k, encoder_v, seq_lens_encoder, cache_k, cache_v, forward_meta.block_tables
)
else:
encoder_out = None
# 5. decoder attention with kv cache
bs = decoder_q.shape[0]
decoder_q = decoder_q.reshape([bs, 1, num_head_q, dim])
decoder_k_ = decoder_k.reshape([bs, 1, num_head_kv, dim])
decoder_v_ = decoder_v.reshape([bs, 1, num_head_kv, dim])
cache_seqlens = paddle.index_select(forward_meta.seq_lens_decoder, decoder_indices, axis=0)
# 5.1 convert paged kv cache to continuous cache
if decoder_q.numel() > 0:
max_cache_seq_len = paddle.max(cache_seqlens)
c_cache_k, c_cache_v = self.block_cache_to_naive_cache__(
cache_k, cache_v, bs, forward_meta.block_tables, max_cache_seq_len
)
decoder_out = flash_attn_kvcache_func(
decoder_q,
c_cache_k,
c_cache_v,
cache_seqlens.squeeze(-1),
None,
decoder_k_,
decoder_v_,
causal=self.causal,
)
self.update_decoder_kv_cache(
decoder_k, decoder_v, seq_lens_decoder, cache_k, cache_v, forward_meta.block_tables
)
else:
decoder_out = None
# 6. Merge encoder_out and decoder_out into the final output
total_len = qkv.shape[0]
out = paddle.zeros([total_len, num_head_q, dim])
if encoder_out is not None:
out = paddle.tensor.put_along_axis(
out, encoder_indices.unsqueeze(-1).unsqueeze(-1), encoder_out[0], axis=0
)
if decoder_out is not None:
new_decoder_out = decoder_out[0].squeeze(1)
out = paddle.tensor.put_along_axis(
out, decoder_indices.unsqueeze(-1).unsqueeze(-1), new_decoder_out, axis=0
)
out.reshape_([total_len, num_head_q * dim])
return out
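
The block_cache_to_naive_cache* helpers above gather a contiguous KV view out of the paged cache. The index arithmetic is simply: logical position j of request i lives in block block_tables[i, j // blocksize] at slot j % blocksize. A toy illustration with made-up numbers:

blocksize = 4
block_tables_row = [7, 2, 9]        # blocks assigned to one request, in order
j = 6                               # logical token position within that request

block_id = block_tables_row[j // blocksize]   # -> 2 (second assigned block)
slot = j % blocksize                          # -> 2 (third slot inside that block)
print(block_id, slot)                         # 2 2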

View File

@@ -0,0 +1,19 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .triton_moe_kernels import fused_moe_kernel_paddle
__all__ = [
"fused_moe_kernel_paddle",
]

View File

@@ -0,0 +1,276 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from paddle import nn
import fastdeploy
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
from fastdeploy.utils import ceil_div
from .triton_moe_kernels import fused_moe_kernel_paddle
class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
"""
Use Triton Group Gemm to compute Fused MoE.
"""
def __init__(self, quant_config=None):
"""
Triton Group Gemm to compute Fused MoE.
"""
self.quant_config = quant_config
self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
self.added_scale_attrs = [
"up_gate_proj_weight_scale",
"down_proj_weight_scale",
]
def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
"""process_prequanted_weights"""
pass
def create_weights(self, layer: nn.Layer, state_dict):
"""
Triton MoE create weight process.
"""
up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
if layer.quant_method.quant_config:
algo = layer.quant_method.quant_config.name()
assert up_gate_proj_weights[0].shape == [
layer.hidden_size,
layer.moe_intermediate_size * 2,
]
assert down_proj_weights[0].shape == [
layer.moe_intermediate_size,
layer.hidden_size,
]
up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_tensor = paddle.stack(down_proj_weights, axis=0)
if algo == "wint8":
max_bound = 127
elif algo == "wint4":
max_bound = 7
for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
quanted_weight_scale = weight_tensor.abs().max(axis=1)
if self.quant_config is not None:
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
quanted_weight = paddle.round(quanted_weight).astype("int8")
quanted_weight_scale = quanted_weight_scale / max_bound
setattr(
layer,
weight_name,
layer.create_parameter(
shape=quanted_weight.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).set_value(quanted_weight)
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
),
)
getattr(layer, scale_name).set_value(quanted_weight_scale)
else:
setattr(
layer,
weight_name,
layer.create_parameter(
shape=quanted_weight.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).set_value(quanted_weight)
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
),
)
getattr(layer, scale_name).set_value(quanted_weight_scale)
def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
"""
token_num = x.shape[0]
top_k = layer.top_k
num_local_experts = layer.num_local_experts
top_k = layer.top_k
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
top_k,
True, # apply_norm_weight,
False,
)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
)
if self.quant_config is not None:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
}
else:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
)
max_possible_num_post_padded = sorted_token_ids.shape[0]
grid = (
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
* ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]),
)
fused_moe_kernel_paddle[grid](
x,
layer.up_gate_proj_weight,
up_gate_proj_out,
None,
layer.up_gate_proj_weight_scale,
None,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
max_possible_num_post_padded,
token_num * top_k,
N=moe_intermediate_size * 2,
K=hidden_size,
stride_am=x.strides[0],
stride_ak=x.strides[1],
stride_be=layer.up_gate_proj_weight.strides[0],
stride_bk=layer.up_gate_proj_weight.strides[1],
stride_bn=layer.up_gate_proj_weight.strides[2],
stride_cm=up_gate_proj_out.strides[0],
stride_cn=up_gate_proj_out.strides[1],
#
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.up_gate_proj_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
GROUP_SIZE_M=config["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=False,
top_k=top_k,
compute_type_enum=1,
use_fp8_w8a8=False,
use_int8_w8a16=True,
even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
)
down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out)
down_proj_out = paddle.empty(
(token_num * top_k, hidden_size),
dtype=x.dtype,
)
grid = (
ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
* ceil_div(hidden_size, config["BLOCK_SIZE_N"]),
)
fused_moe_kernel_paddle[grid](
down_proj_input,
layer.down_proj_weight,
down_proj_out,
None,
layer.down_proj_weight_scale,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
max_possible_num_post_padded,
token_num * top_k,
N=hidden_size,
K=moe_intermediate_size,
stride_am=down_proj_input.strides[0],
stride_ak=down_proj_input.strides[1],
stride_be=layer.down_proj_weight.strides[0],
stride_bk=layer.down_proj_weight.strides[1],
stride_bn=layer.down_proj_weight.strides[2],
stride_cm=down_proj_out.strides[0],
stride_cn=down_proj_out.strides[1],
stride_asm=-1,
stride_ask=-1,
stride_bse=layer.down_proj_weight_scale.strides[0],
stride_bsk=-1,
stride_bsn=layer.down_proj_weight_scale.strides[1],
group_n=-1,
group_k=-1,
# Meta-parameters
BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
GROUP_SIZE_M=config["GROUP_SIZE_M"],
MUL_ROUTED_WEIGHT=True,
top_k=1,
compute_type_enum=1,
use_fp8_w8a8=False,
use_int8_w8a16=True,
even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
)
down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
return out
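
The two kernel launches above use a 1-D grid of ceil_div(M_padded, BLOCK_SIZE_M) * ceil_div(N, BLOCK_SIZE_N) programs, where M_padded is the padded row count of sorted_token_ids. A small sketch of that arithmetic with illustrative numbers:

def ceil_div(a: int, b: int) -> int:
    # Same behaviour as fastdeploy.utils.ceil_div for positive ints.
    return (a + b - 1) // b

max_possible_num_post_padded = 4096      # illustrative padded row count
moe_intermediate_size = 3584             # illustrative expert FFN width
BLOCK_SIZE_M, BLOCK_SIZE_N = 32, 128     # wint8 config chosen above

num_m_blocks = ceil_div(max_possible_num_post_padded, BLOCK_SIZE_M)   # 128
num_n_blocks = ceil_div(moe_intermediate_size * 2, BLOCK_SIZE_N)      # 56
grid = (num_m_blocks * num_n_blocks,)                                 # (7168,) programs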

View File

@@ -0,0 +1,187 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import triton
import triton.language as tl
@triton.jit
def fused_moe_kernel_paddle(
a_ptr,
b_ptr,
c_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
max_possible_num_post_padded,
num_valid_tokens,
N,
K,
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Block size for block-wise fp8 quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type_enum: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
even_Ks: tl.constexpr,
):
"""
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
assert compute_type_enum == 1
compute_type = tl.bfloat16
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
if use_int8_w8a16:
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8:
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
else:
# (Zkk): every expert has one activation scale and weight scale.
a_scale = tl.load(a_scale_ptr + off_experts)
b_scale = tl.load(b_scale_ptr + off_experts)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
if even_Ks:
a = tl.load(
a_ptrs,
mask=token_mask[:, None],
other=0.0,
)
b = tl.load(b_ptrs, cache_modifier=".ca", eviction_policy="evict_first")
else:
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8:
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_scale = tl.load(
a_scale_ptrs + offs_ks * stride_ask,
mask=token_mask,
other=0.0,
)
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else:
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8:
if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type)
else:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
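
As a readability aid, a dense NumPy reference for the use_int8_w8a16 path taken by the Metax MoE backend: the accumulator is rescaled per output column with b_scale, and the routed weight is applied when MUL_ROUTED_WEIGHT is set. This is a sketch with toy shapes, not a drop-in replacement for the kernel:

import numpy as np

M, K, N = 32, 64, 128                                            # toy tile sizes
a = np.random.randn(M, K).astype(np.float32)                     # activation rows
b_int8 = np.random.randint(-127, 128, size=(K, N), dtype=np.int8)
b_scale = np.random.rand(N).astype(np.float32)                   # one scale per output column
moe_weight = np.random.rand(M, 1).astype(np.float32)             # routed weight per token

c = a @ b_int8.astype(np.float32)          # accumulate in floating point, as the kernel does
c *= moe_weight                            # MUL_ROUTED_WEIGHT=True branch
c *= b_scale[None, :]                      # final (accumulator * b_scale) rescale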

View File

@@ -107,6 +107,7 @@ class LinearBase(nn.Layer):
or current_platform.is_iluvatar()
or current_platform.is_gcu()
or current_platform.is_dcu()
or current_platform.is_maca()
):
self.forward = self.forward_cuda
else:

View File

@@ -49,6 +49,12 @@ def get_moe_method():
from fastdeploy.model_executor.layers.backends import GCUFusedMoeMethod
return GCUFusedMoeMethod(None)
elif current_platform.is_maca():
from fastdeploy.model_executor.layers.backends import (
MetaxTritonWeightOnlyMoEMethod,
)
return MetaxTritonWeightOnlyMoEMethod(None)
raise NotImplementedError

View File

@@ -94,6 +94,16 @@ class WeightOnlyConfig(QuantConfigBase):
)
return DCUWeightOnlyLinearMethod(self)
elif current_platform.is_maca():
if isinstance(layer, FusedMoE):
from fastdeploy.model_executor.layers.backends import (
MetaxTritonWeightOnlyMoEMethod,
)
return MetaxTritonWeightOnlyMoEMethod(self)
else:
return GPUWeightOnlyLinearMethod(self)
else:
if isinstance(layer, FusedMoE):
if layer.use_method == "cutlass":
@@ -196,14 +206,24 @@ class WeightOnlyLinearMethod(QuantMethodBase):
raise NotImplementedError
def apply(self, layer, x):
linear_out = weight_only_linear(
x,
weight=layer.weight,
bias=layer.bias if layer.add_bias else None,
weight_scale=layer.weight_scale,
weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
arch=self.quant_config.weight_only_linear_arch,
)
if current_platform.is_maca():
linear_out = weight_only_linear(
x,
weight=layer.weight,
bias=layer.bias if layer.add_bias else None,
weight_scale=layer.weight_scale,
weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
arch=80,
)
else:
linear_out = weight_only_linear(
x,
weight=layer.weight,
bias=layer.bias if layer.add_bias else None,
weight_scale=layer.weight_scale,
weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
arch=self.quant_config.weight_only_linear_arch,
)
return linear_out
@@ -240,6 +260,7 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
algo=self.quant_config.algo,
arch=self.quant_config.weight_only_linear_arch,
)
if current_platform.is_maca():
quanted_weight_tensor = paddle.transpose(quanted_weight_tensor, [1, 0])
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))

View File

@@ -51,6 +51,10 @@ class ErnieRotaryEmbedding:
# shape: [B, S, D]
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
return rot_emb
elif paddle.is_compiled_with_custom_device("metax_gpu"):
# shape: [B, S, D]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
else:
# shape: [B, S, D/2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
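
On metax_gpu the rotary table keeps the full rotary_dim (each angle duplicated for its even/odd channel pair), matching the apply_rope in the Metax flash-attention backend above, while the default branch keeps only rotary_dim // 2. A shape-only sketch with toy sizes, where freqs stands in for the angle table computed earlier in the class:

import paddle

bsz, max_seq_len, rotary_dim = 2, 8, 16
freqs = paddle.randn([bsz, max_seq_len, rotary_dim // 2])   # stand-in angle table

# Default branch: half-dim table, shape [2, B, S, 1, D/2]
rot_emb_half = paddle.zeros((2, bsz, max_seq_len, 1, rotary_dim // 2), dtype="float32")

# metax_gpu branch: duplicate each angle for its channel pair -> full-dim table [2, B, S, 1, D]
emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, rotary_dim))
rot_emb_full = paddle.zeros((2, bsz, max_seq_len, 1, rotary_dim), dtype="float32")
print(rot_emb_half.shape, rot_emb_full.shape)   # [2, 2, 8, 1, 8] [2, 2, 8, 1, 16]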

View File

@@ -119,6 +119,23 @@ def apply_penalty_multi_scores(
min_dec_lens,
eos_token_ids,
)
elif current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import get_token_penalty_multi_scores
logits = get_token_penalty_multi_scores(
pre_token_ids,
prompt_ids,
prompt_lens,
logits,
repetition_penalties,
frequency_penalties,
presence_penalties,
temperature,
bad_words_token_ids,
step_idx,
min_dec_lens,
eos_token_ids,
)
else:
raise NotImplementedError

View File

@@ -177,6 +177,7 @@ class Sampler(nn.Layer):
or current_platform.is_iluvatar()
or current_platform.is_gcu()
or current_platform.is_dcu()
or current_platform.is_maca()
):
self.forward = self.forward_cuda
else: