[Feature] refactor metax_gpu attention and moe and remove some useless code (#3688)

Co-authored-by: yongqiangma <xing.wo@163.com>
Author: SuperNova
Date: 2025-09-12 14:40:25 +08:00
Committed by: GitHub
Parent: cab7a633fe
Commit: 805f29a06c
5 changed files with 389 additions and 289 deletions


@@ -894,7 +894,7 @@ class CacheConfig:
self.kv_cache_ratio = 1.0
else:
self.kv_cache_ratio = 0.75
self.enc_dec_block_num = 0 if current_platform.is_iluvatar() else 2
self.enc_dec_block_num = 0 if current_platform.is_iluvatar() or current_platform.is_maca() else 2
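# MACA (metax) now matches Iluvatar in keeping enc_dec_block_num at 0.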
self.prealloc_dec_block_slot_num_threshold = 12
self.cache_dtype = "bfloat16"
self.model_cfg = None


@@ -16,13 +16,11 @@
from __future__ import annotations
import math
import os
from dataclasses import dataclass, field
from typing import List, Optional
import paddle
import paddle.nn.functional as F
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta, ForwardMode
@@ -92,6 +90,7 @@ class FlashAttentionBackend(AttentionBackend):
"""
super().__init__()
self.attention_metadata: FlashAttentionMetadata = None
self.record_block_table_metadata = {}
self.block_size: int = fd_config.parallel_config.block_size
self.max_seq_len: int = fd_config.parallel_config.max_model_len
self.rope_theta: float = (
@@ -110,6 +109,9 @@ class FlashAttentionBackend(AttentionBackend):
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
self.head_dim: int = fd_config.model_config.head_dim
self.total_num_heads = self.num_heads + 2 * self.kv_num_heads
self.total_hidden_dim = self.total_num_heads * self.head_dim
self.dtype = paddle.get_default_dtype()
self.num_layers: int = fd_config.model_config.num_hidden_layers
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 32768))
@@ -125,7 +127,98 @@ class FlashAttentionBackend(AttentionBackend):
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
forward_meta.forward_mode = ForwardMode.NATIVE
return
self.prefill_info_dict = {}
self.decode_info_dict = {}
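# Partition the current batch: requests running more than one token this step are in the
# prefill phase, requests running exactly one token are in the decode phase.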
prefill_non_zeros_ids = forward_meta.seq_lens_this_time > 1
decode_non_zeros_ids = forward_meta.seq_lens_this_time == 1
self.prefill_info_dict["batch_ids"] = paddle.where(prefill_non_zeros_ids)[0]
self.decode_info_dict["batch_ids"] = paddle.where(decode_non_zeros_ids)[0]
self.prefill_len = len(self.prefill_info_dict["batch_ids"])
self.decode_len = len(self.decode_info_dict["batch_ids"])
# only prefill
if self.decode_len == 0:
cu_seq_ids = list(range(self.prefill_len + 1))
self.prefill_info_dict["cu_seqlens_q"] = forward_meta.cu_seqlens_q[cu_seq_ids].astype("int32")
# only decode
elif self.prefill_len == 0:
pass
# both prefill and decode
else:
prefill_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[prefill_non_zeros_ids])
decode_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[decode_non_zeros_ids])
self.prefill_info_dict["cu_seqlens_q"] = paddle.zeros(
[self.prefill_len + 1], dtype=forward_meta.cu_seqlens_q.dtype
)
self.prefill_info_dict["cu_seqlens_q"][1:] = forward_meta.seq_lens_encoder[
self.prefill_info_dict["batch_ids"], 0
]
self.prefill_info_dict["cu_seqlens_q"] = paddle.cumsum(self.prefill_info_dict["cu_seqlens_q"]).astype(
"int32"
)
self.prefill_qkv = paddle.zeros([prefill_num_tokens, self.total_hidden_dim], dtype=self.dtype)
self.decode_qkv = paddle.zeros([decode_num_tokens, self.total_hidden_dim], dtype=self.dtype)
self.merged_output = paddle.zeros(
[prefill_num_tokens + decode_num_tokens, self.num_heads, self.head_dim], dtype=self.dtype
)
prefill_start, decode_start, start = 0, 0, 0
non_zeros_ids = forward_meta.seq_lens_this_time != 0
non_zeros_seq_lens = forward_meta.seq_lens_this_time[non_zeros_ids]
end = non_zeros_seq_lens[0]
if end > 1:
last_stage = "prefill"
prefill_end = end
decode_end = 0
else:
last_stage = "decode"
prefill_end = 0
decode_end = end
self.prefill_info_dict["id_group"] = []
self.prefill_info_dict["reverse_id_group"] = []
self.decode_info_dict["id_group"] = []
self.decode_info_dict["reverse_id_group"] = []
self.record_stages = []
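# Walk the batch in order, recording alternating prefill / decode segments. "id_group"
# stores (start, end) token ranges inside the packed per-phase buffers, "reverse_id_group"
# stores the matching ranges in the original mixed token stream, and "record_stages"
# remembers the segment order so the two outputs can be merged back later.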
for seq_len in non_zeros_seq_lens[1:]:
if seq_len > 1:
if last_stage == "decode":
self.record_stages.append((last_stage, len(self.decode_info_dict["id_group"])))
self.decode_info_dict["id_group"].append((decode_start, decode_end))
self.decode_info_dict["reverse_id_group"].append((start, end))
decode_start = decode_end
start = end
last_stage = "prefill"
prefill_end += seq_len
end += seq_len
else:
if last_stage == "prefill":
self.record_stages.append((last_stage, len(self.prefill_info_dict["id_group"])))
self.prefill_info_dict["id_group"].append((prefill_start, prefill_end))
self.prefill_info_dict["reverse_id_group"].append((start, end))
prefill_start = prefill_end
start = end
last_stage = "decode"
decode_end += seq_len
end += seq_len
if prefill_start < prefill_end:
self.record_stages.append(("prefill", len(self.prefill_info_dict["id_group"])))
self.prefill_info_dict["id_group"].append((prefill_start, prefill_end))
self.prefill_info_dict["reverse_id_group"].append((start, end))
if decode_start < decode_end:
self.record_stages.append(("decode", len(self.decode_info_dict["id_group"])))
self.decode_info_dict["id_group"].append((decode_start, decode_end))
self.decode_info_dict["reverse_id_group"].append((start, end))
self.batch_ids_prefill = paddle.to_tensor(self.prefill_info_dict["batch_ids"])
self.batch_ids_decode = paddle.to_tensor(self.decode_info_dict["batch_ids"])
self.seq_lens_dec = forward_meta.seq_lens_decoder[self.batch_ids_decode, 0]
self.block_table_dec = forward_meta.block_tables[self.batch_ids_decode, :]
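# Illustrative walk-through of the bookkeeping above, assuming a hypothetical batch with
# seq_lens_this_time = [5, 1, 3, 1]:
#   prefill batch_ids = [0, 2] (8 tokens), decode batch_ids = [1, 3] (2 tokens)
#   record_stages = [("prefill", 0), ("decode", 0), ("prefill", 1), ("decode", 1)]
#   prefill id_group = [(0, 5), (5, 8)], reverse_id_group = [(0, 5), (6, 9)]
#   decode  id_group = [(0, 1), (1, 2)], reverse_id_group = [(5, 6), (9, 10)]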
def get_attntion_meta(self) -> AttentionMetadata:
"""get_attntion_meta"""
@@ -149,106 +242,11 @@ class FlashAttentionBackend(AttentionBackend):
else:
return (
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.kv_num_heads,
self.head_dim,
)
def split_qkv(self, qkv, num_head_q, num_head_kv, dim):
q = qkv[:, : num_head_q * dim].reshape([-1, num_head_q, dim])
k = qkv[:, num_head_q * dim : num_head_q * dim + num_head_kv * dim].reshape([-1, num_head_kv, dim])
v = qkv[:, num_head_q * dim + num_head_kv * dim :].reshape([-1, num_head_kv, dim])
return q, k, v
def flash_attn_varlen(self, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
num_head = q.shape[1]
dim = q.shape[2]
q_ = q.reshape([-1, num_head, dim])
k_ = k.reshape([-1, num_head, dim])
v_ = v.reshape([-1, num_head, dim])
bsz = cu_seqlens_q.shape[0] - 1
out = []
for i in range(bsz):
start_q, end_q = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
start_k, end_k = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()
qi = q_[start_q:end_q] # [seq_q, nh, dim]
ki = k_[start_k:end_k] # [seq_k, nh, dim]
vi = v_[start_k:end_k] # [seq_k, nh, dim]
qi = qi.transpose([1, 0, 2]) # [nh, seq_q, dim]
ki = ki.transpose([1, 2, 0]) # [nh, dim, seq_k]
vi = vi.transpose([1, 0, 2]) # [nh, seq_k, dim]
score = paddle.matmul(qi, ki) / math.sqrt(dim) # [nh, seq_q, seq_k]
prob = F.softmax(score, axis=-1)
o = paddle.matmul(prob, vi) # [nh, seq_q, dim]
o = o.transpose([1, 0, 2]) # [seq_q, nh, dim]
out.append(o)
return paddle.concat(out, axis=0) # [total_q, nh, dim]
def flash_attn_with_kvcache(self, q, cache_k, cache_v, cache_seqlens, block_tables=None):
bs, _, nh, dim = q.shape
out = []
for i in range(bs):
q_i = q[i] # [1, nh, dim]
k_i = cache_k[i, : cache_seqlens[i, 0]] # [seqlen, nh, dim]
v_i = cache_v[i, : cache_seqlens[i, 0]]
qi = q_i.transpose([1, 0, 2]) # [nh, 1, dim]
ki = k_i.transpose([1, 2, 0]) # [nh, dim, seqlen]
vi = v_i.transpose([1, 0, 2]) # [nh, seqlen, dim]
score = paddle.matmul(qi, ki) / math.sqrt(dim)
prob = F.softmax(score, axis=-1)
o = paddle.matmul(prob, vi).transpose([1, 0, 2]) # [1, nh, dim]
out.append(o)
return paddle.concat(out, axis=0) # [bs, nh, dim]
def block_cache_to_naive_cache(self, cache_k, cache_v, bsz, block_tables, cache_seq_len):
_, num_head, blocksize, dim_head = cache_k.shape
out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype)
out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype)
for i in range(bsz):
for j in range(cache_seq_len):
out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
return out_cache_k, out_cache_v
def block_cache_to_naive_cache__(self, cache_k, cache_v, bsz, block_tables, max_cache_seq_len):
_, num_head, blocksize, dim_head = cache_k.shape
out_cache_k = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_k.dtype)
out_cache_v = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_v.dtype)
for i in range(bsz):
for j in range(max_cache_seq_len):
out_cache_k[i, j, :, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
out_cache_v[i, j, :, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
return out_cache_k, out_cache_v
def update_encoder_kv_cache(self, k, v, seq_lens_encoder, cache_k, cache_v, block_tables):
_, num_head, blocksize, dim_head = cache_k.shape
offset = 0
for batch_idx, seq_len in enumerate(seq_lens_encoder.numpy()):
if seq_len == 0:
continue
for seq_idx in range(seq_len):
block_id = block_tables[batch_idx, seq_idx // blocksize]
assert block_id != -1
index = offset + seq_idx
cache_k[block_id, :, seq_idx % blocksize, :] = k[index, :, :]
cache_v[block_id, :, seq_idx % blocksize, :] = v[index, :, :]
offset += seq_len
def update_decoder_kv_cache(self, k, v, seq_lens_decoder, cache_k, cache_v, block_tables):
_, num_head, blocksize, dim_head = cache_k.shape
for batch_idx, seq_idx in enumerate(seq_lens_decoder.numpy()):
if seq_idx == 0:
continue
block_id = block_tables[batch_idx, seq_idx // blocksize]
assert block_id != -1
cache_k[block_id, :, seq_idx % blocksize, :] = k[batch_idx, :, :]
cache_v[block_id, :, seq_idx % blocksize, :] = v[batch_idx, :, :]
def apply_rope(self, qk, cos, sin):
rotate_half = paddle.reshape(
paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1),
@@ -257,138 +255,234 @@ class FlashAttentionBackend(AttentionBackend):
out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin))
return paddle.cast(out, qk.dtype)
@paddle.no_grad()
def forward_native_backend(
def get_splited_qkv(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
layer,
forward_meta: ForwardMeta,
cu_seqlens_q: paddle.Tensor,
batch_ids=None,
is_decode=False,
):
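# Slice the fused qkv projection into q / k / v views and, for prefill requests, apply
# rotary embedding using each request's cached KV length as the position offset. Decode
# requests return the raw views; their rotary embedding is applied inside the
# flash_attn_kvcache_func kernel instead.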
q_end = self.num_heads * self.head_dim
k_end = q_end + self.kv_num_heads * self.head_dim
v_end = k_end + self.kv_num_heads * self.head_dim
assert v_end == qkv.shape[-1], f"Shape mismatch: {v_end} vs {qkv.shape[-1]}"
assert qkv.shape[0] == cu_seqlens_q[-1], f"Shape mismatch: {qkv.shape[0]} vs {cu_seqlens_q[-1]}"
bsz = forward_meta.seq_lens_this_time.shape[0]
num_head_q, num_head_kv, dim = layer.num_heads, layer.kv_num_heads, layer.head_dim
if batch_ids is None:
batch_ids = list(range(forward_meta.seq_lens_this_time.shape[0]))
# 1. Separate the encoder / decoder masks
seq_lens_encoder = forward_meta.seq_lens_encoder.squeeze(-1)
seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
seq_lens_this_time = forward_meta.seq_lens_this_time
encoder_indices = []
decoder_indices = []
q = qkv[..., 0:q_end]
k = qkv[..., q_end:k_end]
v = qkv[..., k_end:v_end]
offset = 0
for i in range(bsz):
length = seq_lens_this_time[i].item()
if seq_lens_encoder[i] > 0:
encoder_indices.extend(range(offset, offset + length))
elif seq_lens_decoder[i] > 0:
decoder_indices.extend(range(offset, offset + length))
offset += length
q = q.view([-1, self.num_heads, self.head_dim])
k = k.view([-1, self.kv_num_heads, self.head_dim])
v = v.view([-1, self.kv_num_heads, self.head_dim])
encoder_indices = paddle.to_tensor(encoder_indices, dtype="int32")
decoder_indices = paddle.to_tensor(decoder_indices, dtype="int32")
if is_decode:
return q, k, v
encoder_qkv = paddle.index_select(qkv, encoder_indices, axis=0)
decoder_qkv = paddle.index_select(qkv, decoder_indices, axis=0)
for idx in range(len(cu_seqlens_q) - 1):
batch_idx = batch_ids[idx]
seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
if seq_len_i == 0:
continue
cached_kv_len = forward_meta.seq_lens_decoder[batch_idx][0]
cu_seq_start_q = cu_seqlens_q[idx]
cu_seq_end_q = cu_seqlens_q[idx + 1]
# forward_meta.rotary_embs is [2, 1, S, 1, D // 2]
if forward_meta.rotary_embs is not None:
cos = paddle.repeat_interleave(
forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :], repeats=2, axis=-1
) # [Si, D]
sin = paddle.repeat_interleave(
forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :], repeats=2, axis=-1
) # [Si, D]
q[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(q[cu_seq_start_q:cu_seq_end_q], cos, sin)
k[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(k[cu_seq_start_q:cu_seq_end_q], cos, sin)
# 2. Split the encoder and decoder qkv
encoder_q, encoder_k, encoder_v = self.split_qkv(encoder_qkv, num_head_q, num_head_kv, dim)
decoder_q, decoder_k, decoder_v = self.split_qkv(decoder_qkv, num_head_q, num_head_kv, dim)
cache_k = forward_meta.caches[2 * layer.layer_id]
cache_v = forward_meta.caches[2 * layer.layer_id + 1]
return q, k, v
# 3. Rotary Embedding
if decoder_q.numel() != 0 or encoder_q.numel() != 0:
for batch_idx in range(forward_meta.seq_lens_this_time.shape[0]):
seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
if seq_len_i == 0:
def split_pd_qkv(self, qkv):
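# Scatter rows of the mixed-batch qkv into the contiguous prefill / decode buffers using
# the index ranges recorded in init_attention_metadata.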
for ids, reverse_ids in zip(self.prefill_info_dict["id_group"], self.prefill_info_dict["reverse_id_group"]):
self.prefill_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :]
for ids, reverse_ids in zip(self.decode_info_dict["id_group"], self.decode_info_dict["reverse_id_group"]):
self.decode_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :]
return self.prefill_qkv, self.decode_qkv
def merge_pd_output(self, prefill_out, decode_out):
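# Replay the recorded prefill / decode segment order to write both outputs back to their
# original token positions in the mixed batch.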
for stage, idx in self.record_stages:
if stage == "prefill":
ids = self.prefill_info_dict["id_group"][idx]
reverse_ids = self.prefill_info_dict["reverse_id_group"][idx]
self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = prefill_out[ids[0] : ids[1], :, :]
else:
ids = self.decode_info_dict["id_group"][idx]
reverse_ids = self.decode_info_dict["reverse_id_group"][idx]
self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = decode_out[ids[0] : ids[1], :, :]
return self.merged_output
def update_kv_cache(
self, k, v, k_cache_id, v_cache_id, layer_id, forward_meta: ForwardMeta, specific_batch_ids=None
):
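# Write each prefill request's freshly computed K/V into its paged cache blocks, walking
# the request's block table one block_size chunk at a time; the partially filled last
# block is recorded in record_block_table_metadata on the final layer.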
tensor_start = 0
for batch_idx in range(forward_meta.block_tables.shape[0]):
if specific_batch_ids is not None and batch_idx not in specific_batch_ids:
continue
seq_len = forward_meta.seq_lens_this_time[batch_idx]
if seq_len == 0:
continue
tensor_end = tensor_start + seq_len
slice_trans_k = k[tensor_start:tensor_end, :, :]
slice_trans_v = v[tensor_start:tensor_end, :, :]
cur_block_tables = forward_meta.block_tables[batch_idx]
cur_used_block_tables = cur_block_tables[cur_block_tables != -1]
# encoder prefill
if seq_len > 1:
cache_start = 0
cur_used_num_blocks = cur_used_block_tables.shape[0]
for i, block_id in enumerate(cur_used_block_tables):
# last block: seq_len - cache_start <= block_size
if i == cur_used_num_blocks - 1:
cache_end = seq_len - cache_start
assert cache_end <= self.block_size
forward_meta.caches[k_cache_id][block_id, 0:cache_end, :, :] = slice_trans_k[
cache_start:seq_len, :, :
]
forward_meta.caches[v_cache_id][block_id, 0:cache_end, :, :] = slice_trans_v[
cache_start:seq_len, :, :
]
if layer_id == self.num_layers - 1:
self.record_block_table_metadata[batch_idx] = {
"block_id": block_id.item(),
"cache_end": cache_end,
}
# non-last block: seq_lens_this_time > block_size
else:
assert seq_len > self.block_size
cache_end = cache_start + self.block_size
forward_meta.caches[k_cache_id][block_id] = slice_trans_k[cache_start:cache_end, :, :]
forward_meta.caches[v_cache_id][block_id] = slice_trans_v[cache_start:cache_end, :, :]
cache_start += self.block_size
tensor_start = tensor_end
def merge_output(self, prefill_out, decode_out, forward_meta: ForwardMeta):
assert not (prefill_out is None and decode_out is None), "prefill and decode output cannot both be None"
if prefill_out is None:
return decode_out
elif decode_out is None:
return prefill_out
else:
prefill_out = prefill_out
decode_out = decode_out
merged_output = []
prefill_tensor_start = 0
decode_tensor_start = 0
for seq_lens_this_time in forward_meta.seq_lens_this_time:
if seq_lens_this_time == 0:
continue
cached_kv_len = seq_lens_decoder[batch_idx]
cu_seq_start_q = forward_meta.cu_seqlens_q[batch_idx]
cu_seq_end_q = forward_meta.cu_seqlens_q[batch_idx + 1]
if forward_meta.rotary_embs is not None and cu_seq_end_q > cu_seq_start_q:
cos = forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
sin = forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
if seq_lens_this_time > 1:
tensor_end = prefill_tensor_start + seq_lens_this_time
merged_output.append(prefill_out[prefill_tensor_start:tensor_end, :, :])
prefill_tensor_start = tensor_end
else:
assert seq_lens_this_time == 1
tensor_end = decode_tensor_start + seq_lens_this_time
merged_output.append(decode_out[decode_tensor_start:tensor_end, :, :])
decode_tensor_start = tensor_end
def rope_func(qk):
qk[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(qk[cu_seq_start_q:cu_seq_end_q], cos, sin)
assert (
prefill_tensor_start == prefill_out.shape[0]
), f"prefill merged unfinished: {prefill_tensor_start} vs {prefill_out.shape[0]}"
assert (
decode_tensor_start == decode_out.shape[0]
), f"decode merged unfinished: {decode_tensor_start} vs {decode_out.shape[0]}"
merged_output = paddle.concat(merged_output, axis=0)
return merged_output
if encoder_q.numel() != 0:
rope_func(encoder_q)
rope_func(encoder_k)
if decoder_q.numel() != 0:
rope_func(decoder_q)
rope_func(decoder_k)
def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
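# Varlen flash attention over the packed prefill tokens, followed by writing their new
# K/V into the paged cache.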
# 4. Flash Attention for encoder
encoder_v = encoder_v
cu_seqlens_q = forward_meta.cu_seqlens_q
cu_seqlens_k = forward_meta.cu_seqlens_k
max_seqlen_q = paddle.max(seq_lens_this_time)
max_seqlen_k = max_seqlen_q
prefill_q, prefill_k, prefill_v = self.get_splited_qkv(
prefill_qkv,
forward_meta,
self.prefill_info_dict["cu_seqlens_q"],
batch_ids=self.batch_ids_prefill,
)
prefill_out = flash_attn_unpadded_func(
prefill_q,
prefill_k,
prefill_v,
self.prefill_info_dict["cu_seqlens_q"],
self.prefill_info_dict["cu_seqlens_q"],
max_seqlen_q=self.max_seq_len,
max_seqlen_k=self.max_seq_len,
attn_mask=forward_meta.attn_mask,
causal=self.causal,
)[0]
self.update_kv_cache(
prefill_k, prefill_v, k_cache_id, v_cache_id, layer_id, forward_meta, self.batch_ids_prefill
)
return prefill_out
def forward_decode(self, decode_qkv, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
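# Single-token decode step: flash_attn_kvcache_func attends over the paged KV cache using
# the decode requests' block tables and cached lengths, with the new K/V and the
# precomputed rotary cos/sin tables passed to the kernel.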
cache_k = forward_meta.caches[k_cache_id]
cache_v = forward_meta.caches[v_cache_id]
cu_seq_lens = list(range(self.decode_len + 1))
q, k, v = self.get_splited_qkv(decode_qkv, forward_meta, cu_seq_lens, self.batch_ids_decode, is_decode=True)
decoder_q = q.view([self.decode_len, 1, self.num_heads, self.head_dim])
decoder_k_ = k.view([self.decode_len, 1, self.kv_num_heads, self.head_dim])
decoder_v_ = v.view([self.decode_len, 1, self.kv_num_heads, self.head_dim])
decode_out = flash_attn_kvcache_func(
decoder_q,
cache_k,
cache_v,
self.seq_lens_dec,
self.block_table_dec,
decoder_k_,
decoder_v_,
rotary_cos=forward_meta.rotary_embs[0, 0, :, 0, :].astype("bfloat16"),
rotary_sin=forward_meta.rotary_embs[1, 0, :, 0, :].astype("bfloat16"),
causal=self.causal,
is_rotary_interleaved=True,
)[0].squeeze(1)
return decode_out
@paddle.no_grad()
def forward_native_backend(self, q, k, v, qkv, layer, forward_meta: ForwardMeta):
layer_id = layer.layer_id
k_cache_id = layer_id * 2
v_cache_id = k_cache_id + 1
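# Dispatch on batch composition: prefill-only and decode-only batches take the dedicated
# paths; mixed batches are split, processed separately, and merged back together.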
if self.decode_len == 0:
out = self.forward_prefill(qkv, layer_id, k_cache_id, v_cache_id, forward_meta)
elif self.prefill_len == 0:
out = self.forward_decode(qkv, k_cache_id, v_cache_id, forward_meta)
if encoder_q.numel() > 0:
encoder_out = flash_attn_unpadded_func(
encoder_q,
encoder_k,
encoder_v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
attn_mask=forward_meta.attn_mask,
causal=self.causal,
)
self.update_encoder_kv_cache(
encoder_k, encoder_v, seq_lens_encoder, cache_k, cache_v, forward_meta.block_tables
)
else:
encoder_out = None
prefill_qkv, decode_qkv = self.split_pd_qkv(qkv)
prefill_output = self.forward_prefill(prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta)
decode_output = self.forward_decode(decode_qkv, k_cache_id, v_cache_id, forward_meta)
out = self.merge_pd_output(prefill_output, decode_output)
# 5. decoder attention with kv cache
bs = decoder_q.shape[0]
decoder_q = decoder_q.reshape([bs, 1, num_head_q, dim])
decoder_k_ = decoder_k.reshape([bs, 1, num_head_kv, dim])
decoder_v_ = decoder_v.reshape([bs, 1, num_head_kv, dim])
cache_seqlens = paddle.index_select(forward_meta.seq_lens_decoder, decoder_indices, axis=0)
# 5.1 convert paged kv cache to continuous cache
if decoder_q.numel() > 0:
max_cache_seq_len = paddle.max(cache_seqlens)
c_cache_k, c_cache_v = self.block_cache_to_naive_cache__(
cache_k, cache_v, bs, forward_meta.block_tables, max_cache_seq_len
)
decoder_out = flash_attn_kvcache_func(
decoder_q,
c_cache_k,
c_cache_v,
cache_seqlens.squeeze(-1),
None,
decoder_k_,
decoder_v_,
causal=self.causal,
)
self.update_decoder_kv_cache(
decoder_k, decoder_v, seq_lens_decoder, cache_k, cache_v, forward_meta.block_tables
)
else:
decoder_out = None
# 6. Concatenate encoder_out and decoder_out
total_len = qkv.shape[0]
out = paddle.zeros([total_len, num_head_q, dim])
if encoder_out is not None:
out = paddle.tensor.put_along_axis(
out, encoder_indices.unsqueeze(-1).unsqueeze(-1), encoder_out[0], axis=0
)
if decoder_out is not None:
new_decoder_out = decoder_out[0].squeeze(1)
out = paddle.tensor.put_along_axis(
out, decoder_indices.unsqueeze(-1).unsqueeze(-1), new_decoder_out, axis=0
)
out.reshape_([total_len, num_head_q * dim])
if qkv.dim() == 2:
out = out.view([-1, self.num_heads * self.head_dim])
return out


@@ -45,17 +45,73 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
"""process_prequanted_weights"""
pass
@paddle.no_grad()
def create_weights(self, layer: nn.Layer, state_dict):
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
Triton MoE create weight process.
"""
self.weight_dtype = "int8"
self.default_dtype = layer._helper.get_default_dtype()
up_gate_proj_weight_name = self.added_weight_attrs[0]
down_proj_weight_name = self.added_weight_attrs[1]
self.up_gate_proj_weight_shape = [
layer.num_local_experts,
layer.hidden_size,
layer.moe_intermediate_size * 2,
]
self.down_proj_weight_shape = [
layer.num_local_experts,
layer.moe_intermediate_size,
layer.hidden_size,
]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.down_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Triton MoE load weight process.
"""
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
if layer.quant_method.quant_config:
algo = layer.quant_method.quant_config.name()
algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
assert up_gate_proj_weights[0].shape == [
layer.hidden_size,
@@ -79,52 +135,12 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
scale_name = self.added_scale_attrs[idx]
quanted_weight_scale = weight_tensor.abs().max(axis=1)
if self.quant_config is not None:
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
quanted_weight = paddle.round(quanted_weight).astype("int8")
quanted_weight_scale = quanted_weight_scale / max_bound
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
quanted_weight = paddle.round(quanted_weight).astype("int8")
quanted_weight_scale = quanted_weight_scale / max_bound
setattr(
layer,
weight_name,
layer.create_parameter(
shape=quanted_weight.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).set_value(quanted_weight)
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
),
)
getattr(layer, scale_name).set_value(quanted_weight_scale)
else:
setattr(
layer,
weight_name,
layer.create_parameter(
shape=quanted_weight.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).set_value(quanted_weight)
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
),
)
getattr(layer, scale_name).set_value(quanted_weight_scale)
getattr(layer, weight_name).set_value(quanted_weight)
getattr(layer, scale_name).set_value(quanted_weight_scale)
@paddle.no_grad()
def apply(
@@ -159,16 +175,16 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
if self.quant_config is not None:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
}
else:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"GROUP_SIZE_M": 8,
}
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(


@@ -313,24 +313,14 @@ class WeightOnlyLinearMethod(QuantMethodBase):
raise NotImplementedError
def apply(self, layer, x):
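# Weight-only int8/int4 GEMM; the MACA-specific arch=80 override is dropped so every
# platform uses the arch from the quantization config.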
if current_platform.is_maca():
linear_out = weight_only_linear(
x,
weight=layer.weight,
bias=layer.bias if layer.add_bias else None,
weight_scale=layer.weight_scale,
weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
arch=80,
)
else:
linear_out = weight_only_linear(
x,
weight=layer.weight,
bias=layer.bias if layer.add_bias else None,
weight_scale=layer.weight_scale,
weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
arch=self.quant_config.weight_only_linear_arch,
)
linear_out = weight_only_linear(
x,
weight=layer.weight,
bias=layer.bias if layer.add_bias else None,
weight_scale=layer.weight_scale,
weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
arch=self.quant_config.weight_only_linear_arch,
)
return linear_out


@@ -52,9 +52,9 @@ class ErnieRotaryEmbedding:
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
return rot_emb
elif paddle.is_compiled_with_custom_device("metax_gpu"):
# shape: [B, S, D]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
# shape: [B, S, D/2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
else:
# shape: [B, S, D/2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
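# With this change the metax_gpu branch emits the same half-rotary-dim
# [2, B, S, 1, D/2] layout as the default branch.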