[Feature] refactor metax_gpu attention and moe and remove some useless code (#3688)

Co-authored-by: yongqiangma <xing.wo@163.com>
SuperNova
2025-09-12 14:40:25 +08:00
committed by GitHub
parent cab7a633fe
commit 805f29a06c
5 changed files with 389 additions and 289 deletions

View File

@@ -894,7 +894,7 @@ class CacheConfig:
             self.kv_cache_ratio = 1.0
         else:
             self.kv_cache_ratio = 0.75
-        self.enc_dec_block_num = 0 if current_platform.is_iluvatar() else 2
+        self.enc_dec_block_num = 0 if current_platform.is_iluvatar() or current_platform.is_maca() else 2
         self.prealloc_dec_block_slot_num_threshold = 12
         self.cache_dtype = "bfloat16"
         self.model_cfg = None
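Note: the hunk above only widens an existing platform gate. A minimal sketch of the selection rule, with plain booleans standing in for FastDeploy's `current_platform.is_iluvatar()` / `is_maca()` probes:

```python
# Minimal sketch of the default added above; the boolean arguments are stand-ins
# for current_platform.is_iluvatar() / current_platform.is_maca(), not the real API.
def default_enc_dec_block_num(is_iluvatar: bool, is_maca: bool) -> int:
    # Iluvatar and MetaX (MACA) backends reserve no extra encoder/decoder blocks.
    return 0 if (is_iluvatar or is_maca) else 2

assert default_enc_dec_block_num(False, False) == 2
assert default_enc_dec_block_num(False, True) == 0
```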

View File

@@ -16,13 +16,11 @@
 from __future__ import annotations

-import math
 import os
 from dataclasses import dataclass, field
 from typing import List, Optional

 import paddle
-import paddle.nn.functional as F

 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.forward_meta import ForwardMeta, ForwardMode
@@ -92,6 +90,7 @@ class FlashAttentionBackend(AttentionBackend):
""" """
super().__init__() super().__init__()
self.attention_metadata: FlashAttentionMetadata = None self.attention_metadata: FlashAttentionMetadata = None
self.record_block_table_metadata = {}
self.block_size: int = fd_config.parallel_config.block_size self.block_size: int = fd_config.parallel_config.block_size
self.max_seq_len: int = fd_config.parallel_config.max_model_len self.max_seq_len: int = fd_config.parallel_config.max_model_len
self.rope_theta: float = ( self.rope_theta: float = (
@@ -110,6 +109,9 @@ class FlashAttentionBackend(AttentionBackend):
         self.kv_num_heads: int = kv_num_heads
         self.num_heads: int = num_heads
         self.head_dim: int = fd_config.model_config.head_dim
+        self.total_num_heads = self.num_heads + 2 * self.kv_num_heads
+        self.total_hidden_dim = self.total_num_heads * self.head_dim
+        self.dtype = paddle.get_default_dtype()
         self.num_layers: int = fd_config.model_config.num_hidden_layers
         self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 32768))
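Note: `total_num_heads` / `total_hidden_dim` describe the width of the fused QKV activation the backend slices later. A quick worked example with illustrative GQA sizes (the numbers are not from any particular model):

```python
# Illustrative numbers only: 32 query heads, 4 KV heads (GQA), head_dim 128.
num_heads, kv_num_heads, head_dim = 32, 4, 128
total_num_heads = num_heads + 2 * kv_num_heads   # Q heads + K heads + V heads
total_hidden_dim = total_num_heads * head_dim    # width of one fused QKV row
assert (total_num_heads, total_hidden_dim) == (40, 5120)
```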
@@ -125,7 +127,98 @@ class FlashAttentionBackend(AttentionBackend):
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""
         forward_meta.forward_mode = ForwardMode.NATIVE
-        return
+        self.prefill_info_dict = {}
+        self.decode_info_dict = {}
+        prefill_non_zeros_ids = forward_meta.seq_lens_this_time > 1
+        decode_non_zeros_ids = forward_meta.seq_lens_this_time == 1
+        self.prefill_info_dict["batch_ids"] = paddle.where(prefill_non_zeros_ids)[0]
+        self.decode_info_dict["batch_ids"] = paddle.where(decode_non_zeros_ids)[0]
+        self.prefill_len = len(self.prefill_info_dict["batch_ids"])
+        self.decode_len = len(self.decode_info_dict["batch_ids"])
+
+        # only prefill
+        if self.decode_len == 0:
+            cu_seq_ids = list(range(self.prefill_len + 1))
+            self.prefill_info_dict["cu_seqlens_q"] = forward_meta.cu_seqlens_q[cu_seq_ids].astype("int32")
+        # only decode
+        elif self.prefill_len == 0:
+            pass
+        # both prefill and decode
+        else:
+            prefill_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[prefill_non_zeros_ids])
+            decode_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[decode_non_zeros_ids])
+
+            self.prefill_info_dict["cu_seqlens_q"] = paddle.zeros(
+                [self.prefill_len + 1], dtype=forward_meta.cu_seqlens_q.dtype
+            )
+            self.prefill_info_dict["cu_seqlens_q"][1:] = forward_meta.seq_lens_encoder[
+                self.prefill_info_dict["batch_ids"], 0
+            ]
+            self.prefill_info_dict["cu_seqlens_q"] = paddle.cumsum(self.prefill_info_dict["cu_seqlens_q"]).astype(
+                "int32"
+            )
+
+            self.prefill_qkv = paddle.zeros([prefill_num_tokens, self.total_hidden_dim], dtype=self.dtype)
+            self.decode_qkv = paddle.zeros([decode_num_tokens, self.total_hidden_dim], dtype=self.dtype)
+            self.merged_output = paddle.zeros(
+                [prefill_num_tokens + decode_num_tokens, self.num_heads, self.head_dim], dtype=self.dtype
+            )
+
+            prefill_start, decode_start, start = 0, 0, 0
+            non_zeros_ids = forward_meta.seq_lens_this_time != 0
+            non_zeros_seq_lens = forward_meta.seq_lens_this_time[non_zeros_ids]
+            end = non_zeros_seq_lens[0]
+            if end > 1:
+                last_stage = "prefill"
+                prefill_end = end
+                decode_end = 0
+            else:
+                last_stage = "decode"
+                prefill_end = 0
+                decode_end = end
+
+            self.prefill_info_dict["id_group"] = []
+            self.prefill_info_dict["reverse_id_group"] = []
+            self.decode_info_dict["id_group"] = []
+            self.decode_info_dict["reverse_id_group"] = []
+            self.record_stages = []
+            for seq_len in non_zeros_seq_lens[1:]:
+                if seq_len > 1:
+                    if last_stage == "decode":
+                        self.record_stages.append((last_stage, len(self.decode_info_dict["id_group"])))
+                        self.decode_info_dict["id_group"].append((decode_start, decode_end))
+                        self.decode_info_dict["reverse_id_group"].append((start, end))
+                        decode_start = decode_end
+                        start = end
+                        last_stage = "prefill"
+                    prefill_end += seq_len
+                    end += seq_len
+                else:
+                    if last_stage == "prefill":
+                        self.record_stages.append((last_stage, len(self.prefill_info_dict["id_group"])))
+                        self.prefill_info_dict["id_group"].append((prefill_start, prefill_end))
+                        self.prefill_info_dict["reverse_id_group"].append((start, end))
+                        prefill_start = prefill_end
+                        start = end
+                        last_stage = "decode"
+                    decode_end += seq_len
+                    end += seq_len
+
+            if prefill_start < prefill_end:
+                self.record_stages.append(("prefill", len(self.prefill_info_dict["id_group"])))
+                self.prefill_info_dict["id_group"].append((prefill_start, prefill_end))
+                self.prefill_info_dict["reverse_id_group"].append((start, end))
+            if decode_start < decode_end:
+                self.record_stages.append(("decode", len(self.decode_info_dict["id_group"])))
+                self.decode_info_dict["id_group"].append((decode_start, decode_end))
+                self.decode_info_dict["reverse_id_group"].append((start, end))
+
+        self.batch_ids_prefill = paddle.to_tensor(self.prefill_info_dict["batch_ids"])
+        self.batch_ids_decode = paddle.to_tensor(self.decode_info_dict["batch_ids"])
+        self.seq_lens_dec = forward_meta.seq_lens_decoder[self.batch_ids_decode, 0]
+        self.block_table_dec = forward_meta.block_tables[self.batch_ids_decode, :]

     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
@@ -149,106 +242,11 @@ class FlashAttentionBackend(AttentionBackend):
         else:
             return (
                 max_num_blocks,
-                self.kv_num_heads,
                 self.block_size,
+                self.kv_num_heads,
                 self.head_dim,
             )

-    def split_qkv(self, qkv, num_head_q, num_head_kv, dim):
-        q = qkv[:, : num_head_q * dim].reshape([-1, num_head_q, dim])
-        k = qkv[:, num_head_q * dim : num_head_q * dim + num_head_kv * dim].reshape([-1, num_head_kv, dim])
-        v = qkv[:, num_head_q * dim + num_head_kv * dim :].reshape([-1, num_head_kv, dim])
-        return q, k, v
-
-    def flash_attn_varlen(self, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
-        num_head = q.shape[1]
-        dim = q.shape[2]
-        q_ = q.reshape([-1, num_head, dim])
-        k_ = k.reshape([-1, num_head, dim])
-        v_ = v.reshape([-1, num_head, dim])
-        bsz = cu_seqlens_q.shape[0] - 1
-        out = []
-        for i in range(bsz):
-            start_q, end_q = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
-            start_k, end_k = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()
-            qi = q_[start_q:end_q]  # [seq_q, nh, dim]
-            ki = k_[start_k:end_k]  # [seq_k, nh, dim]
-            vi = v_[start_k:end_k]  # [seq_k, nh, dim]
-            qi = qi.transpose([1, 0, 2])  # [nh, seq_q, dim]
-            ki = ki.transpose([1, 2, 0])  # [nh, dim, seq_k]
-            vi = vi.transpose([1, 0, 2])  # [nh, seq_k, dim]
-            score = paddle.matmul(qi, ki) / math.sqrt(dim)  # [nh, seq_q, seq_k]
-            prob = F.softmax(score, axis=-1)
-            o = paddle.matmul(prob, vi)  # [nh, seq_q, dim]
-            o = o.transpose([1, 0, 2])  # [seq_q, nh, dim]
-            out.append(o)
-        return paddle.concat(out, axis=0)  # [total_q, nh, dim]
-
-    def flash_attn_with_kvcache(self, q, cache_k, cache_v, cache_seqlens, block_tables=None):
-        bs, _, nh, dim = q.shape
-        out = []
-        for i in range(bs):
-            q_i = q[i]  # [1, nh, dim]
-            k_i = cache_k[i, : cache_seqlens[i, 0]]  # [seqlen, nh, dim]
-            v_i = cache_v[i, : cache_seqlens[i, 0]]
-            qi = q_i.transpose([1, 0, 2])  # [nh, 1, dim]
-            ki = k_i.transpose([1, 2, 0])  # [nh, dim, seqlen]
-            vi = v_i.transpose([1, 0, 2])  # [nh, seqlen, dim]
-            score = paddle.matmul(qi, ki) / math.sqrt(dim)
-            prob = F.softmax(score, axis=-1)
-            o = paddle.matmul(prob, vi).transpose([1, 0, 2])  # [1, nh, dim]
-            out.append(o)
-        return paddle.concat(out, axis=0)  # [bs, nh, dim]
-
-    def block_cache_to_naive_cache(slef, cache_k, cache_v, bsz, block_tables, cache_seq_len):
-        _, num_head, blocksize, dim_head = cache_k.shape
-        out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype)
-        out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype)
-        for i in range(bsz):
-            for j in range(cache_seq_len):
-                out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
-                out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
-        return out_cache_k, out_cache_v
-
-    def block_cache_to_naive_cache__(self, cache_k, cache_v, bsz, block_tables, max_cache_seq_len):
-        _, num_head, blocksize, dim_head = cache_k.shape
-        out_cache_k = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_k.dtype)
-        out_cache_v = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_v.dtype)
-        for i in range(bsz):
-            for j in range(max_cache_seq_len):
-                out_cache_k[i, j, :, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
-                out_cache_v[i, j, :, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
-        return out_cache_k, out_cache_v
-
-    def update_encoder_kv_cache(self, k, v, seq_lens_encoder, cache_k, cache_v, block_tables):
-        _, num_head, blocksize, dim_head = cache_k.shape
-        offset = 0
-        for batch_idx, seq_len in enumerate(seq_lens_encoder.numpy()):
-            if seq_len == 0:
-                continue
-            for seq_idx in range(seq_len):
-                block_id = block_tables[batch_idx, seq_idx // blocksize]
-                assert block_id != -1
-                index = offset + seq_idx
-                cache_k[block_id, :, seq_idx % blocksize, :] = k[index, :, :]
-                cache_v[block_id, :, seq_idx % blocksize, :] = v[index, :, :]
-            offset += seq_len
-
-    def update_decoder_kv_cache(self, k, v, seq_lens_decoder, cache_k, cache_v, block_tables):
-        _, num_head, blocksize, dim_head = cache_k.shape
-        for batch_idx, seq_idx in enumerate(seq_lens_decoder.numpy()):
-            if seq_idx == 0:
-                continue
-            block_id = block_tables[batch_idx, seq_idx // blocksize]
-            assert block_id != -1
-            cache_k[block_id, :, seq_idx % blocksize, :] = k[batch_idx, :, :]
-            cache_v[block_id, :, seq_idx % blocksize, :] = v[batch_idx, :, :]
     def apply_rope(self, qk, cos, sin):
         rotate_half = paddle.reshape(
             paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1),
@@ -257,138 +255,234 @@ class FlashAttentionBackend(AttentionBackend):
         out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin))
         return paddle.cast(out, qk.dtype)

-    @paddle.no_grad()
-    def forward_native_backend(
-        self,
-        q: paddle.Tensor,
-        k: paddle.Tensor,
-        v: paddle.Tensor,
-        qkv: paddle.Tensor,
-        layer,
-        forward_meta: ForwardMeta,
-    ):
-        bsz = forward_meta.seq_lens_this_time.shape[0]
-        num_head_q, num_head_kv, dim = layer.num_heads, layer.kv_num_heads, layer.head_dim
-
-        # 1. Separate encoder / decoder masks
-        seq_lens_encoder = forward_meta.seq_lens_encoder.squeeze(-1)
-        seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
-        seq_lens_this_time = forward_meta.seq_lens_this_time
-        encoder_indices = []
-        decoder_indices = []
-
-        offset = 0
-        for i in range(bsz):
-            length = seq_lens_this_time[i].item()
-            if seq_lens_encoder[i] > 0:
-                encoder_indices.extend(range(offset, offset + length))
-            elif seq_lens_decoder[i] > 0:
-                decoder_indices.extend(range(offset, offset + length))
-            offset += length
-
-        encoder_indices = paddle.to_tensor(encoder_indices, dtype="int32")
-        decoder_indices = paddle.to_tensor(decoder_indices, dtype="int32")
-        encoder_qkv = paddle.index_select(qkv, encoder_indices, axis=0)
-        decoder_qkv = paddle.index_select(qkv, decoder_indices, axis=0)
-
-        # 2. Split encoder and decoder qkv
-        encoder_q, encoder_k, encoder_v = self.split_qkv(encoder_qkv, num_head_q, num_head_kv, dim)
-        decoder_q, decoder_k, decoder_v = self.split_qkv(decoder_qkv, num_head_q, num_head_kv, dim)
-        cache_k = forward_meta.caches[2 * layer.layer_id]
-        cache_v = forward_meta.caches[2 * layer.layer_id + 1]
-
-        # 3. Rotary Embedding
-        if decoder_q.numel() != 0 or encoder_q.numel() != 0:
-            for batch_idx in range(forward_meta.seq_lens_this_time.shape[0]):
-                seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
-                if seq_len_i == 0:
-                    continue
-                cached_kv_len = seq_lens_decoder[batch_idx]
-                cu_seq_start_q = forward_meta.cu_seqlens_q[batch_idx]
-                cu_seq_end_q = forward_meta.cu_seqlens_q[batch_idx + 1]
-                if forward_meta.rotary_embs is not None and cu_seq_end_q > cu_seq_start_q:
-                    cos = forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
-                    sin = forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
-
-                    def rope_func(qk):
-                        qk[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(qk[cu_seq_start_q:cu_seq_end_q], cos, sin)
-
-                    if encoder_q.numel() != 0:
-                        rope_func(encoder_q)
-                        rope_func(encoder_k)
-                    if decoder_q.numel() != 0:
-                        rope_func(decoder_q)
-                        rope_func(decoder_k)
-
-        # 4. Flash Attention for encoder
-        encoder_v = encoder_v
-        cu_seqlens_q = forward_meta.cu_seqlens_q
-        cu_seqlens_k = forward_meta.cu_seqlens_k
-        max_seqlen_q = paddle.max(seq_lens_this_time)
-        max_seqlen_k = max_seqlen_q
-
-        if encoder_q.numel() > 0:
-            encoder_out = flash_attn_unpadded_func(
-                encoder_q,
-                encoder_k,
-                encoder_v,
-                cu_seqlens_q,
-                cu_seqlens_k,
-                max_seqlen_q,
-                max_seqlen_k,
-                attn_mask=forward_meta.attn_mask,
-                causal=self.causal,
-            )
-            self.update_encoder_kv_cache(
-                encoder_k, encoder_v, seq_lens_encoder, cache_k, cache_v, forward_meta.block_tables
-            )
-        else:
-            encoder_out = None
-
-        # 5. decoder attention with kv cache
-        bs = decoder_q.shape[0]
-        decoder_q = decoder_q.reshape([bs, 1, num_head_q, dim])
-        decoder_k_ = decoder_k.reshape([bs, 1, num_head_kv, dim])
-        decoder_v_ = decoder_v.reshape([bs, 1, num_head_kv, dim])
-        cache_seqlens = paddle.index_select(forward_meta.seq_lens_decoder, decoder_indices, axis=0)
-
-        # 5.1 convert paged kv cache to continuous cache
-        if decoder_q.numel() > 0:
-            max_cache_seq_len = paddle.max(cache_seqlens)
-            c_cache_k, c_cache_v = self.block_cache_to_naive_cache__(
-                cache_k, cache_v, bs, forward_meta.block_tables, max_cache_seq_len
-            )
-            decoder_out = flash_attn_kvcache_func(
-                decoder_q,
-                c_cache_k,
-                c_cache_v,
-                cache_seqlens.squeeze(-1),
-                None,
-                decoder_k_,
-                decoder_v_,
-                causal=self.causal,
-            )
-            self.update_decoder_kv_cache(
-                decoder_k, decoder_v, seq_lens_decoder, cache_k, cache_v, forward_meta.block_tables
-            )
-        else:
-            decoder_out = None
-
-        # 6. Concatenate encoder_out and decoder_out
-        total_len = qkv.shape[0]
-        out = paddle.zeros([total_len, num_head_q, dim])
-        if encoder_out is not None:
-            out = paddle.tensor.put_along_axis(
-                out, encoder_indices.unsqueeze(-1).unsqueeze(-1), encoder_out[0], axis=0
-            )
-        if decoder_out is not None:
-            new_decoder_out = decoder_out[0].squeeze(1)
-            out = paddle.tensor.put_along_axis(
-                out, decoder_indices.unsqueeze(-1).unsqueeze(-1), new_decoder_out, axis=0
-            )
-        out.reshape_([total_len, num_head_q * dim])
-
-        return out
+    def get_splited_qkv(
+        self,
+        qkv: paddle.Tensor,
+        forward_meta: ForwardMeta,
+        cu_seqlens_q: paddle.Tensor,
+        batch_ids=None,
+        is_decode=False,
+    ):
+        q_end = self.num_heads * self.head_dim
+        k_end = q_end + self.kv_num_heads * self.head_dim
+        v_end = k_end + self.kv_num_heads * self.head_dim
+        assert v_end == qkv.shape[-1], f"Shape mismatch: {v_end} vs {qkv.shape[-1]}"
+        assert qkv.shape[0] == cu_seqlens_q[-1], f"Shape mismatch: {qkv.shape[0]} vs {cu_seqlens_q[-1]}"
+
+        if batch_ids is None:
+            batch_ids = list(range(forward_meta.seq_lens_this_time.shape[0]))
+
+        q = qkv[..., 0:q_end]
+        k = qkv[..., q_end:k_end]
+        v = qkv[..., k_end:v_end]
+
+        q = q.view([-1, self.num_heads, self.head_dim])
+        k = k.view([-1, self.kv_num_heads, self.head_dim])
+        v = v.view([-1, self.kv_num_heads, self.head_dim])
+
+        if is_decode:
+            return q, k, v
+
+        for idx in range(len(cu_seqlens_q) - 1):
+            batch_idx = batch_ids[idx]
+            seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
+            if seq_len_i == 0:
+                continue
+            cached_kv_len = forward_meta.seq_lens_decoder[batch_idx][0]
+            cu_seq_start_q = cu_seqlens_q[idx]
+            cu_seq_end_q = cu_seqlens_q[idx + 1]
+            # forward_meta.rotary_embs is [2, 1, S, 1, D // 2]
+            if forward_meta.rotary_embs is not None:
+                cos = paddle.repeat_interleave(
+                    forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :], repeats=2, axis=-1
+                )  # [Si, D]
+                sin = paddle.repeat_interleave(
+                    forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :], repeats=2, axis=-1
+                )  # [Si, D]
+                q[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(q[cu_seq_start_q:cu_seq_end_q], cos, sin)
+                k[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(k[cu_seq_start_q:cu_seq_end_q], cos, sin)
+
+        return q, k, v
+
+    def split_pd_qkv(self, qkv):
+        for ids, reverse_ids in zip(self.prefill_info_dict["id_group"], self.prefill_info_dict["reverse_id_group"]):
+            self.prefill_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :]
+
+        for ids, reverse_ids in zip(self.decode_info_dict["id_group"], self.decode_info_dict["reverse_id_group"]):
+            self.decode_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :]
+
+        return self.prefill_qkv, self.decode_qkv
+
+    def merge_pd_output(self, prefill_out, decode_out):
+        for stage, idx in self.record_stages:
+            if stage == "prefill":
+                ids = self.prefill_info_dict["id_group"][idx]
+                reverse_ids = self.prefill_info_dict["reverse_id_group"][idx]
+                self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = prefill_out[ids[0] : ids[1], :, :]
+            else:
+                ids = self.decode_info_dict["id_group"][idx]
+                reverse_ids = self.decode_info_dict["reverse_id_group"][idx]
+                self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = decode_out[ids[0] : ids[1], :, :]
+        return self.merged_output
+
+    def update_kv_cache(
+        self, k, v, k_cache_id, v_cache_id, layer_id, forward_meta: ForwardMeta, specific_batch_ids=None
+    ):
+        tensor_start = 0
+        for batch_idx in range(forward_meta.block_tables.shape[0]):
+            if specific_batch_ids is not None and batch_idx not in specific_batch_ids:
+                continue
+            seq_len = forward_meta.seq_lens_this_time[batch_idx]
+            if seq_len == 0:
+                continue
+            tensor_end = tensor_start + seq_len
+            slice_trans_k = k[tensor_start:tensor_end, :, :]
+            slice_trans_v = v[tensor_start:tensor_end, :, :]
+            cur_block_tables = forward_meta.block_tables[batch_idx]
+            cur_used_block_tables = cur_block_tables[cur_block_tables != -1]
+
+            # encoder prefil
+            if seq_len > 1:
+                cache_start = 0
+                cur_used_num_blocks = cur_used_block_tables.shape[0]
+                for i, block_id in enumerate(cur_used_block_tables):
+                    # last block: seq_len - cache_start <= block_size
+                    if i == cur_used_num_blocks - 1:
+                        cache_end = seq_len - cache_start
+                        assert cache_end <= self.block_size
+                        forward_meta.caches[k_cache_id][block_id, 0:cache_end, :, :] = slice_trans_k[
+                            cache_start:seq_len, :, :
+                        ]
+                        forward_meta.caches[v_cache_id][block_id, 0:cache_end, :, :] = slice_trans_v[
+                            cache_start:seq_len, :, :
+                        ]
+                        if layer_id == self.num_layers - 1:
+                            self.record_block_table_metadata[batch_idx] = {
+                                "block_id": block_id.item(),
+                                "cache_end": cache_end,
+                            }
+                    # non last block: seq_lens_this_time > block_size
+                    else:
+                        assert seq_len > self.block_size
+                        cache_end = cache_start + self.block_size
+                        forward_meta.caches[k_cache_id][block_id] = slice_trans_k[cache_start:cache_end, :, :]
+                        forward_meta.caches[v_cache_id][block_id] = slice_trans_v[cache_start:cache_end, :, :]
+                        cache_start += self.block_size
+            tensor_start = tensor_end
+
+    def merge_output(self, prefill_out, decode_out, forward_meta: ForwardMeta):
+        assert not (prefill_out is None and decode_out is None), "prefill and decode output cannot both be None"
+        if prefill_out is None:
+            return decode_out
+        elif decode_out is None:
+            return prefill_out
+        else:
+            prefill_out = prefill_out
+            decode_out = decode_out
+
+            merged_output = []
+            prefill_tensor_start = 0
+            decode_tensor_start = 0
+            for seq_lens_this_time in forward_meta.seq_lens_this_time:
+                if seq_lens_this_time == 0:
+                    continue
+                if seq_lens_this_time > 1:
+                    tensor_end = prefill_tensor_start + seq_lens_this_time
+                    merged_output.append(prefill_out[prefill_tensor_start:tensor_end, :, :])
+                    prefill_tensor_start = tensor_end
+                else:
+                    assert seq_lens_this_time == 1
+                    tensor_end = decode_tensor_start + seq_lens_this_time
+                    merged_output.append(decode_out[decode_tensor_start:tensor_end, :, :])
+                    decode_tensor_start = tensor_end
+
+            assert (
+                prefill_tensor_start == prefill_out.shape[0]
+            ), f"prefill merged unfinished: {prefill_tensor_start} vs {prefill_out.shape[0]}"
+            assert (
+                decode_tensor_start == decode_out.shape[0]
+            ), f"decode merged unfinished: {decode_tensor_start} vs {decode_out.shape[0]}"
+            merged_output = paddle.concat(merged_output, axis=0)
+            return merged_output
+
+    def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
+        prefill_q, prefill_k, prefill_v = self.get_splited_qkv(
+            prefill_qkv,
+            forward_meta,
+            self.prefill_info_dict["cu_seqlens_q"],
+            batch_ids=self.batch_ids_prefill,
+        )
+
+        prefill_out = flash_attn_unpadded_func(
+            prefill_q,
+            prefill_k,
+            prefill_v,
+            self.prefill_info_dict["cu_seqlens_q"],
+            self.prefill_info_dict["cu_seqlens_q"],
+            max_seqlen_q=self.max_seq_len,
+            max_seqlen_k=self.max_seq_len,
+            attn_mask=forward_meta.attn_mask,
+            causal=self.causal,
+        )[0]
+        self.update_kv_cache(
+            prefill_k, prefill_v, k_cache_id, v_cache_id, layer_id, forward_meta, self.batch_ids_prefill
+        )
+
+        return prefill_out
+
+    def forward_decode(self, decode_qkv, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
+        cache_k = forward_meta.caches[k_cache_id]
+        cache_v = forward_meta.caches[v_cache_id]
+
+        cu_seq_lens = list(range(self.decode_len + 1))
+        q, k, v = self.get_splited_qkv(decode_qkv, forward_meta, cu_seq_lens, self.batch_ids_decode, is_decode=True)
+
+        decoder_q = q.view([self.decode_len, 1, self.num_heads, self.head_dim])
+        decoder_k_ = k.view([self.decode_len, 1, self.kv_num_heads, self.head_dim])
+        decoder_v_ = v.view([self.decode_len, 1, self.kv_num_heads, self.head_dim])
+
+        decode_out = flash_attn_kvcache_func(
+            decoder_q,
+            cache_k,
+            cache_v,
+            self.seq_lens_dec,
+            self.block_table_dec,
+            decoder_k_,
+            decoder_v_,
+            rotary_cos=forward_meta.rotary_embs[0, 0, :, 0, :].astype("bfloat16"),
+            rotary_sin=forward_meta.rotary_embs[1, 0, :, 0, :].astype("bfloat16"),
+            causal=self.causal,
+            is_rotary_interleaved=True,
+        )[0].squeeze(1)
+
+        return decode_out
+
+    @paddle.no_grad()
+    def forward_native_backend(self, q, k, v, qkv, layer, forward_meta: ForwardMeta):
+        layer_id = layer.layer_id
+        k_cache_id = layer_id * 2
+        v_cache_id = k_cache_id + 1
+
+        if self.decode_len == 0:
+            out = self.forward_prefill(qkv, layer_id, k_cache_id, v_cache_id, forward_meta)
+        elif self.prefill_len == 0:
+            out = self.forward_decode(qkv, k_cache_id, v_cache_id, forward_meta)
+        else:
+            prefill_qkv, decode_qkv = self.split_pd_qkv(qkv)
+            prefill_output = self.forward_prefill(prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta)
+            decode_output = self.forward_decode(decode_qkv, k_cache_id, v_cache_id, forward_meta)
+            out = self.merge_pd_output(prefill_output, decode_output)
+
+        if qkv.dim() == 2:
+            out = out.view([-1, self.num_heads * self.head_dim])
+
+        return out
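Note: `apply_rope` above is the standard rotate-half formulation, and `get_splited_qkv` expands the half-width cos/sin tables (`rotary_embs` stores D/2 values per position) back to head_dim with `repeat_interleave` before applying them. A small NumPy sketch of that combination, with shapes simplified to a single head and a made-up frequency table:

```python
import numpy as np

def rotate_half_interleaved(x):
    # Pairs (x0, x1) -> (-x1, x0), matching paddle.stack([-x[..., 1::2], x[..., 0::2]], axis=-1).reshape(...)
    out = np.stack([-x[..., 1::2], x[..., 0::2]], axis=-1)
    return out.reshape(x.shape)

def apply_rope_np(x, half_cos, half_sin):
    # rotary_embs stores D/2 values per position; expand to D by repeating each value twice,
    # the NumPy analogue of paddle.repeat_interleave(..., repeats=2, axis=-1).
    cos = np.repeat(half_cos, 2, axis=-1)
    sin = np.repeat(half_sin, 2, axis=-1)
    return x * cos + rotate_half_interleaved(x) * sin

seq_len, head_dim = 4, 8
pos = np.arange(seq_len)[:, None]
inv_freq = 1.0 / (10000 ** (np.arange(0, head_dim, 2) / head_dim))  # [D/2]
freqs = pos * inv_freq                                              # [S, D/2]
x = np.random.randn(seq_len, head_dim).astype("float32")
y = apply_rope_np(x, np.cos(freqs), np.sin(freqs))
assert y.shape == x.shape
```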

View File

@@ -45,17 +45,73 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
"""process_prequanted_weights""" """process_prequanted_weights"""
pass pass
@paddle.no_grad() def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
def create_weights(self, layer: nn.Layer, state_dict):
""" """
Triton MoE create weight process. Triton MoE create weight process.
""" """
self.weight_dtype = "int8"
self.default_dtype = layer._helper.get_default_dtype()
up_gate_proj_weight_name = self.added_weight_attrs[0]
down_proj_weight_name = self.added_weight_attrs[1]
self.up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size * 2,
]
self.down_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size,
layer.hidden_size,
]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.down_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Triton MoE load weight process.
"""
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict) up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts assert len(down_proj_weights) == layer.num_local_experts
if layer.quant_method.quant_config: algo = layer.quant_method.quant_config.name()
algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
assert up_gate_proj_weights[0].shape == [ assert up_gate_proj_weights[0].shape == [
layer.hidden_size, layer.hidden_size,
@@ -79,52 +135,12 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
             scale_name = self.added_scale_attrs[idx]

             quanted_weight_scale = weight_tensor.abs().max(axis=1)
-            if self.quant_config is not None:
-                quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
-                quanted_weight = paddle.round(quanted_weight).astype("int8")
-                quanted_weight_scale = quanted_weight_scale / max_bound
-                setattr(
-                    layer,
-                    weight_name,
-                    layer.create_parameter(
-                        shape=quanted_weight.shape,
-                        dtype=quanted_weight.dtype,
-                        default_initializer=paddle.nn.initializer.Constant(0),
-                    ),
-                )
-                getattr(layer, weight_name).set_value(quanted_weight)
-                setattr(
-                    layer,
-                    scale_name,
-                    layer.create_parameter(
-                        shape=quanted_weight_scale.shape,
-                        dtype=quanted_weight_scale.dtype,
-                    ),
-                )
-                getattr(layer, scale_name).set_value(quanted_weight_scale)
-            else:
-                setattr(
-                    layer,
-                    weight_name,
-                    layer.create_parameter(
-                        shape=quanted_weight.shape,
-                        dtype=quanted_weight.dtype,
-                        default_initializer=paddle.nn.initializer.Constant(0),
-                    ),
-                )
-                getattr(layer, weight_name).set_value(quanted_weight)
-                setattr(
-                    layer,
-                    scale_name,
-                    layer.create_parameter(
-                        shape=quanted_weight_scale.shape,
-                        dtype=quanted_weight_scale.dtype,
-                    ),
-                )
-                getattr(layer, scale_name).set_value(quanted_weight_scale)
+            quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
+            quanted_weight = paddle.round(quanted_weight).astype("int8")
+            quanted_weight_scale = quanted_weight_scale / max_bound
+
+            getattr(layer, weight_name).set_value(quanted_weight)
+            getattr(layer, scale_name).set_value(quanted_weight_scale)
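Note: loading now quantizes in place and writes into the parameters pre-created by `create_weights`. Per output channel, the scale is the absolute maximum divided by the int8 bound, and the stored scale is what dequantizes the int8 weight at runtime. A NumPy sketch of that abs-max scheme (max_bound = 127 assumed for wint8; a 2-D weight stands in for the per-expert tensor):

```python
import numpy as np

def absmax_quant_int8(w, max_bound=127.0):
    # w: [k, n]; one scale per output channel (reduce over k), mirroring
    # weight_tensor.abs().max(axis=1) on the [experts, k, n] tensor above.
    scale = np.abs(w).max(axis=0)                           # [n]
    q = np.round(w / scale[None, :] * max_bound).astype(np.int8)
    return q, scale / max_bound                             # stored scale: q * scale ~ w

w = np.random.randn(64, 32).astype("float32")
q, scale = absmax_quant_int8(w)
w_hat = q.astype("float32") * scale[None, :]
assert np.max(np.abs(w - w_hat)) <= np.max(np.abs(w)) / 127.0 + 1e-6
```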
     @paddle.no_grad()
     def apply(
@@ -159,16 +175,16 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
         if self.quant_config is not None:
             config = {
                 "BLOCK_SIZE_M": 32,
-                "BLOCK_SIZE_N": 128,
-                "BLOCK_SIZE_K": 128,
-                "GROUP_SIZE_M": 1,
+                "BLOCK_SIZE_N": 64,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 8,
             }
         else:
             config = {
                 "BLOCK_SIZE_M": 32,
                 "BLOCK_SIZE_N": 64,
                 "BLOCK_SIZE_K": 64,
-                "GROUP_SIZE_M": 1,
+                "GROUP_SIZE_M": 8,
             }

         sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
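Note: the retuned block sizes only change how the fused MoE GEMM is tiled, which in turn changes how many program instances get launched. A rough sketch of the launch arithmetic; the ceil-division grid formula here is the usual tiling convention, assumed rather than taken from the kernel itself:

```python
import math

def moe_kernel_grid(num_padded_tokens, out_cols, block_m, block_n):
    # Assumed ceil-division tiling of (padded tokens) x (output columns); illustrative only.
    return math.ceil(num_padded_tokens / block_m) * math.ceil(out_cols / block_n)

# Example: 4096 padded tokens, up_gate_proj output width 2 * 3584 = 7168.
old_grid = moe_kernel_grid(4096, 7168, 32, 128)   # old BLOCK_SIZE_N = 128
new_grid = moe_kernel_grid(4096, 7168, 32, 64)    # new BLOCK_SIZE_N = 64
assert (old_grid, new_grid) == (7168, 14336)
```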

View File

@@ -313,24 +313,14 @@ class WeightOnlyLinearMethod(QuantMethodBase):
         raise NotImplementedError

     def apply(self, layer, x):
-        if current_platform.is_maca():
-            linear_out = weight_only_linear(
-                x,
-                weight=layer.weight,
-                bias=layer.bias if layer.add_bias else None,
-                weight_scale=layer.weight_scale,
-                weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
-                arch=80,
-            )
-        else:
-            linear_out = weight_only_linear(
-                x,
-                weight=layer.weight,
-                bias=layer.bias if layer.add_bias else None,
-                weight_scale=layer.weight_scale,
-                weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
-                arch=self.quant_config.weight_only_linear_arch,
-            )
+        linear_out = weight_only_linear(
+            x,
+            weight=layer.weight,
+            bias=layer.bias if layer.add_bias else None,
+            weight_scale=layer.weight_scale,
+            weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
+            arch=self.quant_config.weight_only_linear_arch,
+        )
         return linear_out
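Note: dropping the `is_maca()` branch means MACA now goes through the same `weight_only_linear` call, with the arch taken from the quant config instead of a hard-coded `arch=80`. Conceptually the op is a dequantize-then-GEMM; a NumPy sketch of the wint8 math only (the real op fuses this, and its internal weight layout may differ):

```python
import numpy as np

def weight_only_linear_ref(x, w_int8, weight_scale, bias=None):
    # x: [m, k], w_int8: [k, n] int8, weight_scale: [n] per-output-channel scale.
    w = w_int8.astype(np.float32) * weight_scale[None, :]   # dequantize
    y = x @ w
    return y + bias if bias is not None else y

x = np.random.randn(2, 16).astype("float32")
w_int8 = np.random.randint(-127, 128, size=(16, 8), dtype=np.int8)
scale = np.random.rand(8).astype("float32")
assert weight_only_linear_ref(x, w_int8, scale).shape == (2, 8)
```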

View File

@@ -52,9 +52,9 @@ class ErnieRotaryEmbedding:
             rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
             return rot_emb
         elif paddle.is_compiled_with_custom_device("metax_gpu"):
-            # shape: [B, S, D]
-            rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
-            emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
+            # shape: [B, S, D/2]
+            rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
+            emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
         else:
             # shape: [B, S, D/2]
             rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
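Note: with this change the metax_gpu branch emits the same half-width table as the default branch, which matches the `[2, 1, S, 1, D // 2]` layout the attention backend expects and re-expands with `repeat_interleave`. Presumably `rot_emb[0]` / `rot_emb[1]` are then filled with cos/sin of `emb` further down (that part is not shown in the hunk). A quick NumPy shape check under those assumptions:

```python
import numpy as np

bsz, max_seq_len, rotary_dim = 1, 16, 64
inv_freq = 1.0 / (10000 ** (np.arange(0, rotary_dim, 2) / rotary_dim))     # [D/2]
freqs = np.arange(max_seq_len)[None, :, None] * inv_freq[None, None, :]    # [B, S, D/2]

rot_emb = np.zeros((2, bsz, max_seq_len, 1, rotary_dim // 2), dtype="float32")
rot_emb[0] = np.cos(freqs)[:, :, None, :]   # cos half-table
rot_emb[1] = np.sin(freqs)[:, :, None, :]   # sin half-table
assert rot_emb.shape == (2, 1, 16, 1, 32)
```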