Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[MetaxGPU] Support FastDeploy on metax gpu (#3241)
* [MetaxGPU] Support FastDeploy on metax gpu
* Update metax_worker.py
  1. change worker log;
  2. remove custom allreduce, adapt it later;
  3. remove cuda graph;
* Update __init__.py
  1. remove metax's key word comment
* Update __init__.py
  1. remove metax's key word comment;
  2. add fused_moe_kernel_paddle import

---------

Co-authored-by: yongqiangma <xing.wo@163.com>
@@ -48,3 +48,10 @@ if current_platform.is_dcu():
    if hasattr(dcu, "__all__"):
        globals().update({name: getattr(dcu, name) for name in dcu.__all__})
        __all__.extend(dcu.__all__)

if current_platform.is_maca():
    from . import metax

    if hasattr(metax, "__all__"):
        globals().update({name: getattr(metax, name) for name in metax.__all__})
        __all__.extend(metax.__all__)
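Note: the net effect of the new is_maca() block is that everything listed in metax.__all__ is re-exported from the generic backends package on Metax (MACA) devices. A minimal sketch of what that enables, assuming current_platform.is_maca() returns True at import time:

# Hypothetical usage on a Metax device: the names below resolve because the
# block above copies metax.__all__ into the backends namespace.
from fastdeploy.model_executor.layers.backends import (
    FlashAttentionBackend,  # re-exported from backends.metax
    MetaxTritonWeightOnlyMoEMethod,  # re-exported from backends.metax
)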
fastdeploy/model_executor/layers/backends/metax/__init__.py (new file, 21 lines)
@@ -0,0 +1,21 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .attention.flash_attn_backend import FlashAttentionBackend
from .moe.fused_moe_triton_metax_backend import MetaxTritonWeightOnlyMoEMethod

__all__ = [
    "FlashAttentionBackend",
    "MetaxTritonWeightOnlyMoEMethod",
]
@@ -0,0 +1,30 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
metax gpu backend attention methods
"""
from .flash_attention_interface import (
    flash_attn_func,
    flash_attn_kvcache_func,
    flash_attn_unpadded_func,
)
from .flash_attn_backend import FlashAttentionBackend

__all__ = [
    "FlashAttentionBackend",
    "flash_attn_func",
    "flash_attn_unpadded_func",
    "flash_attn_kvcache_func",
]
@@ -0,0 +1,104 @@
import os
from typing import Optional, Tuple, Union

import paddle
from paddle import Tensor

for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
    if lib.endswith(".so"):
        paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)


def flash_attn_func(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    fixed_seed_offset: Optional[Tensor] = None,
    attn_mask: Optional[Tensor] = None,
    dropout_prob: float = 0.0,
    causal: bool = False,
    return_softmax: bool = False,
    is_test: bool = True,
    rng_name: str = "",
) -> Union[Tensor, Tuple[Tensor, ...]]:
    return paddle._C_ops.flash_attn(
        q, k, v, fixed_seed_offset, attn_mask, dropout_prob, causal, return_softmax, is_test, rng_name
    )


def flash_attn_unpadded_func(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    cu_seqlens_q: Tensor,
    cu_seqlens_k: Tensor,
    max_seqlen_q: Union[int, float],
    max_seqlen_k: Union[int, float],
    fixed_seed_offset: Optional[Tensor] = None,
    attn_mask: Optional[Tensor] = None,
    softmax_scale: float = 1.0,
    dropout: float = 0.0,
    causal: bool = False,
    return_softmax: bool = False,
    is_test: bool = True,
    rng_name: str = "",
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    max_seqlen_q_t = paddle.to_tensor(max_seqlen_q, dtype="int64")
    max_seqlen_k_t = paddle.to_tensor(max_seqlen_k, dtype="int64")

    outputs = paddle._C_ops.flash_attn_unpadded(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        fixed_seed_offset,
        attn_mask,
        max_seqlen_q_t,
        max_seqlen_k_t,
        softmax_scale,
        dropout,
        causal,
        return_softmax,
        is_test,
        rng_name,
    )
    return outputs


def flash_attn_kvcache_func(
    q: Tensor,
    k_cache: Tensor,
    v_cache: Tensor,
    seqlens_k: Tensor,
    block_table: Tensor,
    k: Optional[Tensor] = None,
    v: Optional[Tensor] = None,
    rotary_cos: Optional[Tensor] = None,
    rotary_sin: Optional[Tensor] = None,
    cache_batch_idx: Optional[Tensor] = None,
    causal: bool = True,
    is_rotary_interleaved: bool = False,
    num_splits: int = 1,
    dropout: float = 0.0,
    return_softmax: bool = False,
) -> Tuple[Tensor, Tensor]:
    out, softmax_lse = paddle._C_ops._run_custom_op(
        "flash_attn_kvcache",
        q,
        k_cache,
        v_cache,
        k,
        v,
        seqlens_k,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        block_table,
        causal,
        is_rotary_interleaved,
        num_splits,
        dropout,
        return_softmax,
    )
    return out, softmax_lse
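For orientation, here is a minimal sketch of how the variable-length wrapper above is typically driven. It is an illustration only: the packed [total_tokens, num_heads, head_dim] layout and the cu_seqlens convention are inferred from the signatures and from flash_attn_backend.py below, and it assumes the Metax custom ops are available at runtime.

import paddle

# Two sequences of lengths 3 and 5 packed into one "unpadded" batch of 8 tokens.
total_tokens, num_heads, head_dim = 8, 8, 128
q = paddle.randn([total_tokens, num_heads, head_dim]).astype("bfloat16")
k = paddle.randn([total_tokens, num_heads, head_dim]).astype("bfloat16")
v = paddle.randn([total_tokens, num_heads, head_dim]).astype("bfloat16")

# Cumulative sequence lengths: tokens 0..2 belong to sequence 0, tokens 3..7 to sequence 1.
cu_seqlens = paddle.to_tensor([0, 3, 8], dtype="int32")

outputs = flash_attn_unpadded_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=5,
    max_seqlen_k=5,
    softmax_scale=1.0 / (head_dim ** 0.5),
    causal=True,
)
attn_out = outputs[0]  # the first element is the attention output, as used in flash_attn_backend.py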
@@ -0,0 +1,393 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import annotations

import math
import os
from dataclasses import dataclass, field
from typing import List, Optional

import paddle
import paddle.nn.functional as F

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta, ForwardMode
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend,
    AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.model_executor.layers.backends.metax.attention.flash_attention_interface import (
    flash_attn_kvcache_func,
    flash_attn_unpadded_func,
)


@dataclass
class FlashAttentionMetadata(AttentionMetadata):
    """
    FlashAttentionMetadata
    """

    max_len_kv: paddle.Tensor = None
    set_max_lengths: int = -1
    encoder_batch_ids: paddle.Tensor = None
    encoder_tile_ids_per_batch: paddle.Tensor = None
    encoder_num_blocks: paddle.Tensor = None
    kv_batch_ids: paddle.Tensor = None
    kv_tile_ids_per_batch: paddle.Tensor = None
    kv_num_blocks: paddle.Tensor = None
    decoder_batch_ids: paddle.Tensor = None
    decoder_tile_ids_per_batch: paddle.Tensor = None
    decoder_num_blocks: paddle.Tensor = None

    _dtype: paddle.dtype = paddle.bfloat16
    encoder_max_partition_size: int = 32768
    max_partition_size: int = 32768
    block_tables: Optional[paddle.Tensor] = None
    rotary_embs: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None
    encoder_block_shape_q: int = -1
    decoder_block_shape_q: int = -1
    _fuse_kernel_compute_dtype: str = "bf16"

    # pd_disaggregation
    kv_signal_metadata: Optional[paddle.Tensor] = None
    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)


class FlashAttentionBackend(AttentionBackend):
    """
    FlashAttentionBackend backend implementation.
    """

    __infer_dynamic_dims_fields__ = ["attention_metadata"]
    attention_metadata: FlashAttentionMetadata

    def __init__(
        self,
        fd_config: FDConfig,
        kv_num_heads: int,
        num_heads: int,
        head_dim: int,
        encoder_block_shape_q: int = -1,
        decoder_block_shape_q: int = -1,
    ) -> None:
        """
        FlashAttentionBackend __init__
        """
        super().__init__()
        self.attention_metadata: FlashAttentionMetadata = None
        self.block_size: int = fd_config.parallel_config.block_size
        self.max_seq_len: int = fd_config.parallel_config.max_model_len
        self.rope_theta: float = (
            10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
        )
        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
        self.causal: bool = getattr(fd_config.model_config, "causal", True)
        self.speculative_method: str = fd_config.speculative_config.method
        self.use_speculate: bool = self.speculative_method is not None
        self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
        self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
        self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
        self.encoder_block_shape_q: int = encoder_block_shape_q
        self.decoder_block_shape_q: int = decoder_block_shape_q

        self.kv_num_heads: int = kv_num_heads
        self.num_heads: int = num_heads
        self.head_dim: int = fd_config.model_config.head_dim
        self.num_layers: int = fd_config.model_config.num_hidden_layers
        self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 32768))

        self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode

        self.start_layer_index: int = fd_config.model_config.start_layer_index

        if fd_config.parallel_config.expert_parallel_rank is None:
            fd_config.parallel_config.expert_parallel_rank = 0

        self.rank, self.device_id = init_rank_and_device_id(fd_config)

    def init_attention_metadata(self, forward_meta: ForwardMeta):
        """Initialize attention metadata so that all layers in the forward pass can reuse it."""
        forward_meta.forward_mode = ForwardMode.NATIVE
        return

    def get_attntion_meta(self) -> AttentionMetadata:
        """get_attntion_meta"""
        return self.attention_metadata

    def get_kv_cache_shape(
        self,
        max_num_blocks: int,
        kv_cache_quant_type: str = None,
    ):
        """
        Calculate kv cache shape
        """
        if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
            return (
                max_num_blocks,
                self.kv_num_heads,
                self.block_size,
                self.head_dim // 2,
            )
        else:
            return (
                max_num_blocks,
                self.kv_num_heads,
                self.block_size,
                self.head_dim,
            )

    def split_qkv(self, qkv, num_head_q, num_head_kv, dim):
        q = qkv[:, : num_head_q * dim].reshape([-1, num_head_q, dim])
        k = qkv[:, num_head_q * dim : num_head_q * dim + num_head_kv * dim].reshape([-1, num_head_kv, dim])
        v = qkv[:, num_head_q * dim + num_head_kv * dim :].reshape([-1, num_head_kv, dim])
        return q, k, v

    def flash_attn_varlen(self, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
        num_head = q.shape[1]
        dim = q.shape[2]

        q_ = q.reshape([-1, num_head, dim])
        k_ = k.reshape([-1, num_head, dim])
        v_ = v.reshape([-1, num_head, dim])

        bsz = cu_seqlens_q.shape[0] - 1
        out = []
        for i in range(bsz):
            start_q, end_q = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
            start_k, end_k = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()
            qi = q_[start_q:end_q]  # [seq_q, nh, dim]
            ki = k_[start_k:end_k]  # [seq_k, nh, dim]
            vi = v_[start_k:end_k]  # [seq_k, nh, dim]
            qi = qi.transpose([1, 0, 2])  # [nh, seq_q, dim]
            ki = ki.transpose([1, 2, 0])  # [nh, dim, seq_k]
            vi = vi.transpose([1, 0, 2])  # [nh, seq_k, dim]

            score = paddle.matmul(qi, ki) / math.sqrt(dim)  # [nh, seq_q, seq_k]
            prob = F.softmax(score, axis=-1)
            o = paddle.matmul(prob, vi)  # [nh, seq_q, dim]
            o = o.transpose([1, 0, 2])  # [seq_q, nh, dim]
            out.append(o)

        return paddle.concat(out, axis=0)  # [total_q, nh, dim]

    def flash_attn_with_kvcache(self, q, cache_k, cache_v, cache_seqlens, block_tables=None):
        bs, _, nh, dim = q.shape
        out = []
        for i in range(bs):
            q_i = q[i]  # [1, nh, dim]
            k_i = cache_k[i, : cache_seqlens[i, 0]]  # [seqlen, nh, dim]
            v_i = cache_v[i, : cache_seqlens[i, 0]]
            qi = q_i.transpose([1, 0, 2])  # [nh, 1, dim]
            ki = k_i.transpose([1, 2, 0])  # [nh, dim, seqlen]
            vi = v_i.transpose([1, 0, 2])  # [nh, seqlen, dim]
            score = paddle.matmul(qi, ki) / math.sqrt(dim)
            prob = F.softmax(score, axis=-1)
            o = paddle.matmul(prob, vi).transpose([1, 0, 2])  # [1, nh, dim]
            out.append(o)
        return paddle.concat(out, axis=0)  # [bs, nh, dim]

    def block_cache_to_naive_cache(self, cache_k, cache_v, bsz, block_tables, cache_seq_len):
        _, num_head, blocksize, dim_head = cache_k.shape
        out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype)
        out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype)
        for i in range(bsz):
            for j in range(cache_seq_len):
                out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
                out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
        return out_cache_k, out_cache_v

    def block_cache_to_naive_cache__(self, cache_k, cache_v, bsz, block_tables, max_cache_seq_len):
        _, num_head, blocksize, dim_head = cache_k.shape
        out_cache_k = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_k.dtype)
        out_cache_v = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_v.dtype)
        for i in range(bsz):
            for j in range(max_cache_seq_len):
                out_cache_k[i, j, :, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
                out_cache_v[i, j, :, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
        return out_cache_k, out_cache_v

    def update_encoder_kv_cache(self, k, v, seq_lens_encoder, cache_k, cache_v, block_tables):
        _, num_head, blocksize, dim_head = cache_k.shape
        offset = 0
        for batch_idx, seq_len in enumerate(seq_lens_encoder.numpy()):
            if seq_len == 0:
                continue
            for seq_idx in range(seq_len):
                block_id = block_tables[batch_idx, seq_idx // blocksize]
                assert block_id != -1
                index = offset + seq_idx
                cache_k[block_id, :, seq_idx % blocksize, :] = k[index, :, :]
                cache_v[block_id, :, seq_idx % blocksize, :] = v[index, :, :]

            offset += seq_len

    def update_decoder_kv_cache(self, k, v, seq_lens_decoder, cache_k, cache_v, block_tables):
        _, num_head, blocksize, dim_head = cache_k.shape
        for batch_idx, seq_idx in enumerate(seq_lens_decoder.numpy()):
            if seq_idx == 0:
                continue
            block_id = block_tables[batch_idx, seq_idx // blocksize]
            assert block_id != -1
            cache_k[block_id, :, seq_idx % blocksize, :] = k[batch_idx, :, :]
            cache_v[block_id, :, seq_idx % blocksize, :] = v[batch_idx, :, :]

    def apply_rope(self, qk, cos, sin):
        rotate_half = paddle.reshape(
            paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1),
            paddle.shape(qk),
        )
        out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin))
        return paddle.cast(out, qk.dtype)

    def forward_native_backend(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        qkv: paddle.Tensor,
        layer,
        forward_meta: ForwardMeta,
    ):

        bsz = forward_meta.seq_lens_this_time.shape[0]
        num_head_q, num_head_kv, dim = layer.num_heads, layer.kv_num_heads, layer.head_dim

        # 1. Separate the encoder / decoder masks
        seq_lens_encoder = forward_meta.seq_lens_encoder.squeeze(-1)
        seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
        seq_lens_this_time = forward_meta.seq_lens_this_time.squeeze(-1)
        encoder_indices = []
        decoder_indices = []

        offset = 0
        for i in range(bsz):
            length = seq_lens_this_time[i].item()
            if seq_lens_encoder[i] > 0:
                encoder_indices.extend(range(offset, offset + length))
            elif seq_lens_decoder[i] > 0:
                decoder_indices.extend(range(offset, offset + length))
            offset += length

        encoder_indices = paddle.to_tensor(encoder_indices, dtype="int32")
        decoder_indices = paddle.to_tensor(decoder_indices, dtype="int32")

        encoder_qkv = paddle.index_select(qkv, encoder_indices, axis=0)
        decoder_qkv = paddle.index_select(qkv, decoder_indices, axis=0)

        # 2. Split the encoder and decoder qkv
        encoder_q, encoder_k, encoder_v = self.split_qkv(encoder_qkv, num_head_q, num_head_kv, dim)
        decoder_q, decoder_k, decoder_v = self.split_qkv(decoder_qkv, num_head_q, num_head_kv, dim)
        cache_k = forward_meta.caches[2 * layer.layer_id]
        cache_v = forward_meta.caches[2 * layer.layer_id + 1]

        # 3. Rotary Embedding
        if decoder_q.numel() != 0 or encoder_q.numel() != 0:
            for batch_idx in range(forward_meta.seq_lens_this_time.shape[0]):
                seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
                if seq_len_i == 0:
                    continue
                cached_kv_len = seq_lens_decoder[batch_idx]
                cu_seq_start_q = forward_meta.cu_seqlens_q[batch_idx]
                cu_seq_end_q = forward_meta.cu_seqlens_q[batch_idx + 1]
                if forward_meta.rotary_embs is not None and cu_seq_end_q > cu_seq_start_q:
                    cos = forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
                    sin = forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]

                    def rope_func(qk):
                        qk[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(qk[cu_seq_start_q:cu_seq_end_q], cos, sin)

                    if encoder_q.numel() != 0:
                        rope_func(encoder_q)
                        rope_func(encoder_k)
                    if decoder_q.numel() != 0:
                        rope_func(decoder_q)
                        rope_func(decoder_k)

        # 4. Flash Attention for encoder
        encoder_v = encoder_v
        cu_seqlens_q = forward_meta.cu_seqlens_q
        cu_seqlens_k = forward_meta.cu_seqlens_k
        max_seqlen_q = paddle.max(seq_lens_this_time)
        max_seqlen_k = max_seqlen_q

        if encoder_q.numel() > 0:
            encoder_out = flash_attn_unpadded_func(
                encoder_q,
                encoder_k,
                encoder_v,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                attn_mask=forward_meta.attn_mask,
                causal=self.causal,
            )
            self.update_encoder_kv_cache(
                encoder_k, encoder_v, seq_lens_encoder, cache_k, cache_v, forward_meta.block_tables
            )
        else:
            encoder_out = None

        # 5. decoder attention with kv cache
        bs = decoder_q.shape[0]
        decoder_q = decoder_q.reshape([bs, 1, num_head_q, dim])
        decoder_k_ = decoder_k.reshape([bs, 1, num_head_kv, dim])
        decoder_v_ = decoder_v.reshape([bs, 1, num_head_kv, dim])
        cache_seqlens = paddle.index_select(forward_meta.seq_lens_decoder, decoder_indices, axis=0)

        # 5.1 convert paged kv cache to continuous cache
        if decoder_q.numel() > 0:
            max_cache_seq_len = paddle.max(cache_seqlens)
            c_cache_k, c_cache_v = self.block_cache_to_naive_cache__(
                cache_k, cache_v, bs, forward_meta.block_tables, max_cache_seq_len
            )
            decoder_out = flash_attn_kvcache_func(
                decoder_q,
                c_cache_k,
                c_cache_v,
                cache_seqlens.squeeze(-1),
                None,
                decoder_k_,
                decoder_v_,
                causal=self.causal,
            )
            self.update_decoder_kv_cache(
                decoder_k, decoder_v, seq_lens_decoder, cache_k, cache_v, forward_meta.block_tables
            )
        else:
            decoder_out = None

        # 6. Concatenate encoder_out and decoder_out
        total_len = qkv.shape[0]
        out = paddle.zeros([total_len, num_head_q, dim])
        if encoder_out is not None:
            out = paddle.tensor.put_along_axis(
                out, encoder_indices.unsqueeze(-1).unsqueeze(-1), encoder_out[0], axis=0
            )
        if decoder_out is not None:
            new_decoder_out = decoder_out[0].squeeze(1)
            out = paddle.tensor.put_along_axis(
                out, decoder_indices.unsqueeze(-1).unsqueeze(-1), new_decoder_out, axis=0
            )

        out.reshape_([total_len, num_head_q * dim])

        return out
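A quick note on the paged-cache index math used by block_cache_to_naive_cache__ and the update_*_kv_cache helpers above: a logical token position j is mapped to a physical cache slot through the request's block table. A tiny standalone illustration (the numbers are made up; only the arithmetic mirrors the code):

block_size = 4                # assumed block size for the illustration
block_table_row = [7, 2, 9]   # hypothetical physical block ids for one request, in logical order
j = 6                         # logical token position within that request

physical_block = block_table_row[j // block_size]  # -> block_table_row[1] == 2
slot_in_block = j % block_size                     # -> 2
# cache_k[physical_block, :, slot_in_block, :] is the key vector that
# block_cache_to_naive_cache__ copies into out_cache_k[i, j, :, :].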
@@ -0,0 +1,19 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .triton_moe_kernels import fused_moe_kernel_paddle

__all__ = [
    "fused_moe_kernel_paddle",
]
@@ -0,0 +1,276 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle
from paddle import nn

import fastdeploy
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
from fastdeploy.utils import ceil_div

from .triton_moe_kernels import fused_moe_kernel_paddle


class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
    """
    Use Triton Group Gemm to compute Fused MoE.
    """

    def __init__(self, quant_config=None):
        """
        Triton Group Gemm to compute Fused MoE.
        """
        self.quant_config = quant_config
        self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
        self.added_scale_attrs = [
            "up_gate_proj_weight_scale",
            "down_proj_weight_scale",
        ]

    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
        """process_prequanted_weights"""
        pass

    def create_weights(self, layer: nn.Layer, state_dict):
        """
        Triton MoE create weight process.
        """
        up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
        assert len(up_gate_proj_weights) == layer.num_local_experts
        assert len(down_proj_weights) == layer.num_local_experts

        if layer.quant_method.quant_config:
            algo = layer.quant_method.quant_config.name()

        assert up_gate_proj_weights[0].shape == [
            layer.hidden_size,
            layer.moe_intermediate_size * 2,
        ]
        assert down_proj_weights[0].shape == [
            layer.moe_intermediate_size,
            layer.hidden_size,
        ]

        up_gate_proj_tensor = paddle.stack(up_gate_proj_weights, axis=0)
        down_proj_tensor = paddle.stack(down_proj_weights, axis=0)

        if algo == "wint8":
            max_bound = 127
        elif algo == "wint4":
            max_bound = 7

        for idx, weight_tensor in enumerate([up_gate_proj_tensor, down_proj_tensor]):
            weight_name = self.added_weight_attrs[idx]
            scale_name = self.added_scale_attrs[idx]

            quanted_weight_scale = weight_tensor.abs().max(axis=1)
            if self.quant_config is not None:
                quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
                quanted_weight = paddle.round(quanted_weight).astype("int8")
                quanted_weight_scale = quanted_weight_scale / max_bound

                setattr(
                    layer,
                    weight_name,
                    layer.create_parameter(
                        shape=quanted_weight.shape,
                        dtype=quanted_weight.dtype,
                        default_initializer=paddle.nn.initializer.Constant(0),
                    ),
                )
                getattr(layer, weight_name).set_value(quanted_weight)

                setattr(
                    layer,
                    scale_name,
                    layer.create_parameter(
                        shape=quanted_weight_scale.shape,
                        dtype=quanted_weight_scale.dtype,
                    ),
                )
                getattr(layer, scale_name).set_value(quanted_weight_scale)
            else:
                # No quantization config: store the raw weights and the per-channel scales directly.
                setattr(
                    layer,
                    weight_name,
                    layer.create_parameter(
                        shape=weight_tensor.shape,
                        dtype=weight_tensor.dtype,
                        default_initializer=paddle.nn.initializer.Constant(0),
                    ),
                )
                getattr(layer, weight_name).set_value(weight_tensor)

                setattr(
                    layer,
                    scale_name,
                    layer.create_parameter(
                        shape=quanted_weight_scale.shape,
                        dtype=quanted_weight_scale.dtype,
                    ),
                )
                getattr(layer, scale_name).set_value(quanted_weight_scale)

    def apply(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.
        """
        token_num = x.shape[0]
        top_k = layer.top_k
        num_local_experts = layer.num_local_experts
        moe_intermediate_size = layer.moe_intermediate_size
        hidden_size = layer.hidden_size

        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
            gate_out,
            layer.gate_correction_bias,
            top_k,
            True,  # apply_norm_weight,
            False,
        )
        up_gate_proj_out = paddle.empty(
            [token_num * top_k, moe_intermediate_size * 2],
            dtype=x.dtype,
        )

        if self.quant_config is not None:
            config = {
                "BLOCK_SIZE_M": 32,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 128,
                "GROUP_SIZE_M": 1,
            }
        else:
            config = {
                "BLOCK_SIZE_M": 32,
                "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 1,
            }

        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
            topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
        )
        max_possible_num_post_padded = sorted_token_ids.shape[0]
        grid = (
            ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
            * ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]),
        )

        fused_moe_kernel_paddle[grid](
            x,
            layer.up_gate_proj_weight,
            up_gate_proj_out,
            None,
            layer.up_gate_proj_weight_scale,
            None,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            max_possible_num_post_padded,
            token_num * top_k,
            N=moe_intermediate_size * 2,
            K=hidden_size,
            stride_am=x.strides[0],
            stride_ak=x.strides[1],
            stride_be=layer.up_gate_proj_weight.strides[0],
            stride_bk=layer.up_gate_proj_weight.strides[1],
            stride_bn=layer.up_gate_proj_weight.strides[2],
            stride_cm=up_gate_proj_out.strides[0],
            stride_cn=up_gate_proj_out.strides[1],
            #
            stride_asm=-1,
            stride_ask=-1,
            stride_bse=layer.up_gate_proj_weight_scale.strides[0],
            stride_bsk=-1,
            stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
            group_n=-1,
            group_k=-1,
            # Meta-parameters
            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
            GROUP_SIZE_M=config["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=False,
            top_k=top_k,
            compute_type_enum=1,
            use_fp8_w8a8=False,
            use_int8_w8a16=True,
            even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
        )

        down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out)

        down_proj_out = paddle.empty(
            (token_num * top_k, hidden_size),
            dtype=x.dtype,
        )

        grid = (
            ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
            * ceil_div(hidden_size, config["BLOCK_SIZE_N"]),
        )
        fused_moe_kernel_paddle[grid](
            down_proj_input,
            layer.down_proj_weight,
            down_proj_out,
            None,
            layer.down_proj_weight_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            max_possible_num_post_padded,
            token_num * top_k,
            N=hidden_size,
            K=moe_intermediate_size,
            stride_am=down_proj_input.strides[0],
            stride_ak=down_proj_input.strides[1],
            stride_be=layer.down_proj_weight.strides[0],
            stride_bk=layer.down_proj_weight.strides[1],
            stride_bn=layer.down_proj_weight.strides[2],
            stride_cm=down_proj_out.strides[0],
            stride_cn=down_proj_out.strides[1],
            stride_asm=-1,
            stride_ask=-1,
            stride_bse=layer.down_proj_weight_scale.strides[0],
            stride_bsk=-1,
            stride_bsn=layer.down_proj_weight_scale.strides[1],
            group_n=-1,
            group_k=-1,
            # Meta-parameters
            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
            GROUP_SIZE_M=config["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=True,
            top_k=1,
            compute_type_enum=1,
            use_fp8_w8a8=False,
            use_int8_w8a16=True,
            even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
        )

        down_proj_out.reshape_([token_num, top_k, hidden_size])
        out = down_proj_out.sum(axis=1)
        return out
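To make the weight-only path above concrete, here is a small sketch of the per-channel symmetric quantization that create_weights performs for the wint8 case. It is a standalone NumPy illustration; the variable names mirror the code, but the data and shapes are made up:

import numpy as np

# Hypothetical stacked expert weight: [num_experts, K, N] = [2, 4, 3]
weight_tensor = np.random.randn(2, 4, 3).astype("float32")
max_bound = 127  # wint8

# Per (expert, output-channel) scale, reduced over the K axis, as in create_weights.
quanted_weight_scale = np.abs(weight_tensor).max(axis=1)                      # [2, 3]
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
quanted_weight = np.rint(quanted_weight).astype("int8")                       # stored weight
quanted_weight_scale = quanted_weight_scale / max_bound                       # stored scale

# Dequantizing recovers the original weights up to rounding error,
# which is what the int8_w8a16 path of the Triton kernel relies on.
recovered = quanted_weight.astype("float32") * quanted_weight_scale[:, None, :]
assert np.allclose(recovered, weight_tensor, atol=quanted_weight_scale.max())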
@@ -0,0 +1,187 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import triton
import triton.language as tl


@triton.jit
def fused_moe_kernel_paddle(
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    max_possible_num_post_padded,
    num_valid_tokens,
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_asm,
    stride_ask,
    stride_bse,
    stride_bsk,
    stride_bsn,
    # Block size for block-wise fp8 quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type_enum: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    even_Ks: tl.constexpr,
):
    """
    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(max_possible_num_post_padded, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    assert compute_type_enum == 1
    compute_type = tl.bfloat16

    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak)

    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    if use_int8_w8a16:
        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
        else:
            # (Zkk): every expert has one activation scale and weight scale.
            a_scale = tl.load(a_scale_ptr + off_experts)
            b_scale = tl.load(b_scale_ptr + off_experts)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        if even_Ks:
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None],
                other=0.0,
            )
            b = tl.load(b_ptrs, cache_modifier=".ca", eviction_policy="evict_first")
        else:
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                other=0.0,
            )
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        # We accumulate along the K dimension.
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8:
            if group_k > 0 and group_n > 0:
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(
                    a_scale_ptrs + offs_ks * stride_ask,
                    mask=token_mask,
                    other=0.0,
                )
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
            else:
                accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]
    if use_int8_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            accumulator = accumulator.to(compute_type)
        else:
            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)

    tl.store(c_ptrs, accumulator, mask=c_mask)
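As a quick illustration of the padding scheme the kernel docstring describes, here is a hypothetical layout for BLOCK_SIZE_M = 4 (the ids are invented; only the structure matters). It shows why each program block can assume all of its rows belong to one expert:

BLOCK_SIZE_M = 4
num_valid_tokens = 5
PAD = num_valid_tokens  # any id >= num_valid_tokens fails the token_mask check in the kernel

# Hypothetical routing of 5 token*top_k slots to 2 experts:
#   slot 0 -> expert 0, slot 2 -> expert 0, slots 1, 3, 4 -> expert 1
sorted_token_ids = [0, 2, PAD, PAD, 1, 3, 4, PAD]  # grouped by expert, padded per block
expert_ids = [0, 1]                                # one expert id per BLOCK_SIZE_M block
num_tokens_post_padded = len(sorted_token_ids)     # 8

# Each kernel program reads one BLOCK_SIZE_M slice of sorted_token_ids, so every row it
# touches shares the same expert, and a single weight matrix (b_ptr + off_experts * stride_be)
# can be used for the whole block.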