From 805f29a06c52bfbcee14d9f02f3ea340e9ef4436 Mon Sep 17 00:00:00 2001
From: SuperNova <91192235+handsomecoderyang@users.noreply.github.com>
Date: Fri, 12 Sep 2025 14:40:25 +0800
Subject: [PATCH] [Feature] refactor metax_gpu attention and moe and remove some useless code (#3688)

Co-authored-by: yongqiangma
---
 fastdeploy/config.py                          |   2 +-
 .../metax/attention/flash_attn_backend.py     | 522 +++++++++++-------
 .../moe/fused_moe_triton_metax_backend.py     | 122 ++--
 .../layers/quantization/weight_only.py        |  26 +-
 .../model_executor/layers/rotary_embedding.py |   6 +-
 5 files changed, 389 insertions(+), 289 deletions(-)

diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 1ab4619fc..ce794435e 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -894,7 +894,7 @@ class CacheConfig:
             self.kv_cache_ratio = 1.0
         else:
             self.kv_cache_ratio = 0.75
-        self.enc_dec_block_num = 0 if current_platform.is_iluvatar() else 2
+        self.enc_dec_block_num = 0 if current_platform.is_iluvatar() or current_platform.is_maca() else 2
         self.prealloc_dec_block_slot_num_threshold = 12
         self.cache_dtype = "bfloat16"
         self.model_cfg = None
diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py
index 790e989f2..a4993d165 100644
--- a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py
+++ b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py
@@ -16,13 +16,11 @@

 from __future__ import annotations

-import math
 import os
 from dataclasses import dataclass, field
 from typing import List, Optional

 import paddle
-import paddle.nn.functional as F

 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.forward_meta import ForwardMeta, ForwardMode
@@ -92,6 +90,7 @@ class FlashAttentionBackend(AttentionBackend):
         """
         super().__init__()
         self.attention_metadata: FlashAttentionMetadata = None
+        self.record_block_table_metadata = {}
         self.block_size: int = fd_config.parallel_config.block_size
         self.max_seq_len: int = fd_config.parallel_config.max_model_len
         self.rope_theta: float = (
@@ -110,6 +109,9 @@ class FlashAttentionBackend(AttentionBackend):
         self.kv_num_heads: int = kv_num_heads
         self.num_heads: int = num_heads
         self.head_dim: int = fd_config.model_config.head_dim
+        self.total_num_heads = self.num_heads + 2 * self.kv_num_heads
+        self.total_hidden_dim = self.total_num_heads * self.head_dim
+        self.dtype = paddle.get_default_dtype()
         self.num_layers: int = fd_config.model_config.num_hidden_layers
         self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 32768))

@@ -125,7 +127,98 @@ class FlashAttentionBackend(AttentionBackend):
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""
         forward_meta.forward_mode = ForwardMode.NATIVE
-        return
+        self.prefill_info_dict = {}
+        self.decode_info_dict = {}
+
+        prefill_non_zeros_ids = forward_meta.seq_lens_this_time > 1
+        decode_non_zeros_ids = forward_meta.seq_lens_this_time == 1
+        self.prefill_info_dict["batch_ids"] = paddle.where(prefill_non_zeros_ids)[0]
+        self.decode_info_dict["batch_ids"] = paddle.where(decode_non_zeros_ids)[0]
+
+        self.prefill_len = len(self.prefill_info_dict["batch_ids"])
+        self.decode_len = len(self.decode_info_dict["batch_ids"])
+
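+        # Dispatch on batch composition: the step is either pure prefill, pure
+        # decode, or a mixed batch whose token layout has to be recorded so the
+        # packed qkv can be split per stage and merged back in request order.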
+        # only prefill
+        if self.decode_len == 0:
+            cu_seq_ids = list(range(self.prefill_len + 1))
+            self.prefill_info_dict["cu_seqlens_q"] = forward_meta.cu_seqlens_q[cu_seq_ids].astype("int32")
+        # only decode
+        elif self.prefill_len == 0:
+            pass
+        # both prefill and decode
+        else:
+            prefill_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[prefill_non_zeros_ids])
+            decode_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[decode_non_zeros_ids])
+
+            self.prefill_info_dict["cu_seqlens_q"] = paddle.zeros(
+                [self.prefill_len + 1], dtype=forward_meta.cu_seqlens_q.dtype
+            )
+            self.prefill_info_dict["cu_seqlens_q"][1:] = forward_meta.seq_lens_encoder[
+                self.prefill_info_dict["batch_ids"], 0
+            ]
+            self.prefill_info_dict["cu_seqlens_q"] = paddle.cumsum(self.prefill_info_dict["cu_seqlens_q"]).astype(
+                "int32"
+            )
+
+            self.prefill_qkv = paddle.zeros([prefill_num_tokens, self.total_hidden_dim], dtype=self.dtype)
+            self.decode_qkv = paddle.zeros([decode_num_tokens, self.total_hidden_dim], dtype=self.dtype)
+            self.merged_output = paddle.zeros(
+                [prefill_num_tokens + decode_num_tokens, self.num_heads, self.head_dim], dtype=self.dtype
+            )
+
+            prefill_start, decode_start, start = 0, 0, 0
+            non_zeros_ids = forward_meta.seq_lens_this_time != 0
+            non_zeros_seq_lens = forward_meta.seq_lens_this_time[non_zeros_ids]
+            end = non_zeros_seq_lens[0]
+            if end > 1:
+                last_stage = "prefill"
+                prefill_end = end
+                decode_end = 0
+            else:
+                last_stage = "decode"
+                prefill_end = 0
+                decode_end = end
+
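+            # id_group stores contiguous [start, end) ranges in the stage-local
+            # (prefill-only / decode-only) ordering, while reverse_id_group stores
+            # the matching ranges in the original mixed token order; split_pd_qkv
+            # and merge_pd_output use them to scatter and gather tokens.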
+            self.prefill_info_dict["id_group"] = []
+            self.prefill_info_dict["reverse_id_group"] = []
+            self.decode_info_dict["id_group"] = []
+            self.decode_info_dict["reverse_id_group"] = []
+            self.record_stages = []
+            for seq_len in non_zeros_seq_lens[1:]:
+                if seq_len > 1:
+                    if last_stage == "decode":
+                        self.record_stages.append((last_stage, len(self.decode_info_dict["id_group"])))
+                        self.decode_info_dict["id_group"].append((decode_start, decode_end))
+                        self.decode_info_dict["reverse_id_group"].append((start, end))
+                        decode_start = decode_end
+                        start = end
+                        last_stage = "prefill"
+                    prefill_end += seq_len
+                    end += seq_len
+                else:
+                    if last_stage == "prefill":
+                        self.record_stages.append((last_stage, len(self.prefill_info_dict["id_group"])))
+                        self.prefill_info_dict["id_group"].append((prefill_start, prefill_end))
+                        self.prefill_info_dict["reverse_id_group"].append((start, end))
+                        prefill_start = prefill_end
+                        start = end
+                        last_stage = "decode"
+                    decode_end += seq_len
+                    end += seq_len
+
+            if prefill_start < prefill_end:
+                self.record_stages.append(("prefill", len(self.prefill_info_dict["id_group"])))
+                self.prefill_info_dict["id_group"].append((prefill_start, prefill_end))
+                self.prefill_info_dict["reverse_id_group"].append((start, end))
+            if decode_start < decode_end:
+                self.record_stages.append(("decode", len(self.decode_info_dict["id_group"])))
+                self.decode_info_dict["id_group"].append((decode_start, decode_end))
+                self.decode_info_dict["reverse_id_group"].append((start, end))
+
+        self.batch_ids_prefill = paddle.to_tensor(self.prefill_info_dict["batch_ids"])
+        self.batch_ids_decode = paddle.to_tensor(self.decode_info_dict["batch_ids"])
+        self.seq_lens_dec = forward_meta.seq_lens_decoder[self.batch_ids_decode, 0]
+        self.block_table_dec = forward_meta.block_tables[self.batch_ids_decode, :]

     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
@@ -149,106 +242,11 @@ class FlashAttentionBackend(AttentionBackend):
         else:
             return (
                 max_num_blocks,
-                self.kv_num_heads,
                 self.block_size,
+                self.kv_num_heads,
                 self.head_dim,
             )

-    def split_qkv(self, qkv, num_head_q, num_head_kv, dim):
-        q = qkv[:, : num_head_q * dim].reshape([-1, num_head_q, dim])
-        k = qkv[:, num_head_q * dim : num_head_q * dim + num_head_kv * dim].reshape([-1, num_head_kv, dim])
-        v = qkv[:, num_head_q * dim + num_head_kv * dim :].reshape([-1, num_head_kv, dim])
-        return q, k, v
-
-    def flash_attn_varlen(self, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
-        num_head = q.shape[1]
-        dim = q.shape[2]
-
-        q_ = q.reshape([-1, num_head, dim])
-        k_ = k.reshape([-1, num_head, dim])
-        v_ = v.reshape([-1, num_head, dim])
-
-        bsz = cu_seqlens_q.shape[0] - 1
-        out = []
-        for i in range(bsz):
-            start_q, end_q = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
-            start_k, end_k = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()
-            qi = q_[start_q:end_q]  # [seq_q, nh, dim]
-            ki = k_[start_k:end_k]  # [seq_k, nh, dim]
-            vi = v_[start_k:end_k]  # [seq_k, nh, dim]
-            qi = qi.transpose([1, 0, 2])  # [nh, seq_q, dim]
-            ki = ki.transpose([1, 2, 0])  # [nh, dim, seq_k]
-            vi = vi.transpose([1, 0, 2])  # [nh, seq_k, dim]
-
-            score = paddle.matmul(qi, ki) / math.sqrt(dim)  # [nh, seq_q, seq_k]
-            prob = F.softmax(score, axis=-1)
-            o = paddle.matmul(prob, vi)  # [nh, seq_q, dim]
-            o = o.transpose([1, 0, 2])  # [seq_q, nh, dim]
-            out.append(o)
-
-        return paddle.concat(out, axis=0)  # [total_q, nh, dim]
-
-    def flash_attn_with_kvcache(self, q, cache_k, cache_v, cache_seqlens, block_tables=None):
-        bs, _, nh, dim = q.shape
-        out = []
-        for i in range(bs):
-            q_i = q[i]  # [1, nh, dim]
-            k_i = cache_k[i, : cache_seqlens[i, 0]]  # [seqlen, nh, dim]
-            v_i = cache_v[i, : cache_seqlens[i, 0]]
-            qi = q_i.transpose([1, 0, 2])  # [nh, 1, dim]
-            ki = k_i.transpose([1, 2, 0])  # [nh, dim, seqlen]
-            vi = v_i.transpose([1, 0, 2])  # [nh, seqlen, dim]
-            score = paddle.matmul(qi, ki) / math.sqrt(dim)
-            prob = F.softmax(score, axis=-1)
-            o = paddle.matmul(prob, vi).transpose([1, 0, 2])  # [1, nh, dim]
-            out.append(o)
-        return paddle.concat(out, axis=0)  # [bs, nh, dim]
-
-    def block_cache_to_naive_cache(slef, cache_k, cache_v, bsz, block_tables, cache_seq_len):
-        _, num_head, blocksize, dim_head = cache_k.shape
-        out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype)
-        out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype)
-        for i in range(bsz):
-            for j in range(cache_seq_len):
-                out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
-                out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
-        return out_cache_k, out_cache_v
-
-    def block_cache_to_naive_cache__(self, cache_k, cache_v, bsz, block_tables, max_cache_seq_len):
-        _, num_head, blocksize, dim_head = cache_k.shape
-        out_cache_k = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_k.dtype)
-        out_cache_v = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_v.dtype)
-        for i in range(bsz):
-            for j in range(max_cache_seq_len):
-                out_cache_k[i, j, :, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
-                out_cache_v[i, j, :, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
-        return out_cache_k, out_cache_v
-
-    def update_encoder_kv_cache(self, k, v, seq_lens_encoder, cache_k, cache_v, block_tables):
-        _, num_head, blocksize, dim_head = cache_k.shape
-        offset = 0
-        for batch_idx, seq_len in enumerate(seq_lens_encoder.numpy()):
-            if seq_len == 0:
-                continue
-            for seq_idx in range(seq_len):
-                block_id = block_tables[batch_idx, seq_idx // blocksize]
-                assert block_id != -1
-                index = offset + seq_idx
-                cache_k[block_id, :, seq_idx % blocksize, :] = k[index, :, :]
-                cache_v[block_id, :, seq_idx % blocksize, :] = v[index, :, :]
-
-            offset += seq_len
-
-    def update_decoder_kv_cache(self, k, v, seq_lens_decoder, cache_k, cache_v, block_tables):
-        _, num_head, blocksize, dim_head = cache_k.shape
-        for batch_idx, seq_idx in enumerate(seq_lens_decoder.numpy()):
-            if seq_idx == 0:
-                continue
-            block_id = block_tables[batch_idx, seq_idx // blocksize]
-            assert block_id != -1
-            cache_k[block_id, :, seq_idx % blocksize, :] = k[batch_idx, :, :]
-            cache_v[block_id, :, seq_idx % blocksize, :] = v[batch_idx, :, :]
-
     def apply_rope(self, qk, cos, sin):
         rotate_half = paddle.reshape(
             paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1),
@@ -257,138 +255,234 @@ class FlashAttentionBackend(AttentionBackend):
         out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin))
         return paddle.cast(out, qk.dtype)

-    @paddle.no_grad()
-    def forward_native_backend(
+    def get_splited_qkv(
         self,
-        q: paddle.Tensor,
-        k: paddle.Tensor,
-        v: paddle.Tensor,
         qkv: paddle.Tensor,
-        layer,
         forward_meta: ForwardMeta,
+        cu_seqlens_q: paddle.Tensor,
+        batch_ids=None,
+        is_decode=False,
     ):
+        q_end = self.num_heads * self.head_dim
+        k_end = q_end + self.kv_num_heads * self.head_dim
+        v_end = k_end + self.kv_num_heads * self.head_dim
+        assert v_end == qkv.shape[-1], f"Shape mismatch: {v_end} vs {qkv.shape[-1]}"
+        assert qkv.shape[0] == cu_seqlens_q[-1], f"Shape mismatch: {qkv.shape[0]} vs {cu_seqlens_q[-1]}"
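+        # qkv is packed as [num_tokens, (num_heads + 2 * kv_num_heads) * head_dim];
+        # slice it into Q, K and V and reshape each to a per-head layout.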

-        bsz = forward_meta.seq_lens_this_time.shape[0]
-        num_head_q, num_head_kv, dim = layer.num_heads, layer.kv_num_heads, layer.head_dim
+        if batch_ids is None:
+            batch_ids = list(range(forward_meta.seq_lens_this_time.shape[0]))

-        # 1. Separate the encoder / decoder masks
-        seq_lens_encoder = forward_meta.seq_lens_encoder.squeeze(-1)
-        seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
-        seq_lens_this_time = forward_meta.seq_lens_this_time
-        encoder_indices = []
-        decoder_indices = []
+        q = qkv[..., 0:q_end]
+        k = qkv[..., q_end:k_end]
+        v = qkv[..., k_end:v_end]

-        offset = 0
-        for i in range(bsz):
-            length = seq_lens_this_time[i].item()
-            if seq_lens_encoder[i] > 0:
-                encoder_indices.extend(range(offset, offset + length))
-            elif seq_lens_decoder[i] > 0:
-                decoder_indices.extend(range(offset, offset + length))
-            offset += length
+        q = q.view([-1, self.num_heads, self.head_dim])
+        k = k.view([-1, self.kv_num_heads, self.head_dim])
+        v = v.view([-1, self.kv_num_heads, self.head_dim])

-        encoder_indices = paddle.to_tensor(encoder_indices, dtype="int32")
-        decoder_indices = paddle.to_tensor(decoder_indices, dtype="int32")
+        if is_decode:
+            return q, k, v

-        encoder_qkv = paddle.index_select(qkv, encoder_indices, axis=0)
-        decoder_qkv = paddle.index_select(qkv, decoder_indices, axis=0)
+        for idx in range(len(cu_seqlens_q) - 1):
+            batch_idx = batch_ids[idx]
+            seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
+            if seq_len_i == 0:
+                continue
+            cached_kv_len = forward_meta.seq_lens_decoder[batch_idx][0]
+            cu_seq_start_q = cu_seqlens_q[idx]
+            cu_seq_end_q = cu_seqlens_q[idx + 1]
+            # forward_meta.rotary_embs is [2, 1, S, 1, D // 2]
+            if forward_meta.rotary_embs is not None:
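+                # cos/sin are stored once per channel pair (D // 2 values); repeat
+                # each value twice so they match the interleaved rotate-half layout
+                # that apply_rope expects.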
+                cos = paddle.repeat_interleave(
+                    forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :], repeats=2, axis=-1
+                )  # [Si, D]
+                sin = paddle.repeat_interleave(
+                    forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :], repeats=2, axis=-1
+                )  # [Si, D]
+                q[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(q[cu_seq_start_q:cu_seq_end_q], cos, sin)
+                k[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(k[cu_seq_start_q:cu_seq_end_q], cos, sin)

-        # 2. Split qkv into encoder and decoder parts
-        encoder_q, encoder_k, encoder_v = self.split_qkv(encoder_qkv, num_head_q, num_head_kv, dim)
-        decoder_q, decoder_k, decoder_v = self.split_qkv(decoder_qkv, num_head_q, num_head_kv, dim)
-        cache_k = forward_meta.caches[2 * layer.layer_id]
-        cache_v = forward_meta.caches[2 * layer.layer_id + 1]
+        return q, k, v

-        # 3. Rotary Embedding
-        if decoder_q.numel() != 0 or encoder_q.numel() != 0:
-            for batch_idx in range(forward_meta.seq_lens_this_time.shape[0]):
-                seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
-                if seq_len_i == 0:
+    def split_pd_qkv(self, qkv):
+
+        for ids, reverse_ids in zip(self.prefill_info_dict["id_group"], self.prefill_info_dict["reverse_id_group"]):
+            self.prefill_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :]
+
+        for ids, reverse_ids in zip(self.decode_info_dict["id_group"], self.decode_info_dict["reverse_id_group"]):
+            self.decode_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :]
+
+        return self.prefill_qkv, self.decode_qkv
+
+    def merge_pd_output(self, prefill_out, decode_out):
+        for stage, idx in self.record_stages:
+            if stage == "prefill":
+                ids = self.prefill_info_dict["id_group"][idx]
+                reverse_ids = self.prefill_info_dict["reverse_id_group"][idx]
+                self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = prefill_out[ids[0] : ids[1], :, :]
+            else:
+                ids = self.decode_info_dict["id_group"][idx]
+                reverse_ids = self.decode_info_dict["reverse_id_group"][idx]
+                self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = decode_out[ids[0] : ids[1], :, :]
+        return self.merged_output
+
+    def update_kv_cache(
+        self, k, v, k_cache_id, v_cache_id, layer_id, forward_meta: ForwardMeta, specific_batch_ids=None
+    ):
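+        # Scatter this step's K/V into the paged cache: full blocks are copied
+        # whole, the trailing partial block is copied up to cache_end, and on the
+        # last layer the partial block's id and fill length are recorded in
+        # record_block_table_metadata.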
+        tensor_start = 0
+        for batch_idx in range(forward_meta.block_tables.shape[0]):
+            if specific_batch_ids is not None and batch_idx not in specific_batch_ids:
+                continue
+            seq_len = forward_meta.seq_lens_this_time[batch_idx]
+            if seq_len == 0:
+                continue
+            tensor_end = tensor_start + seq_len
+            slice_trans_k = k[tensor_start:tensor_end, :, :]
+            slice_trans_v = v[tensor_start:tensor_end, :, :]
+
+            cur_block_tables = forward_meta.block_tables[batch_idx]
+            cur_used_block_tables = cur_block_tables[cur_block_tables != -1]
+
+            # encoder prefill
+            if seq_len > 1:
+                cache_start = 0
+                cur_used_num_blocks = cur_used_block_tables.shape[0]
+
+                for i, block_id in enumerate(cur_used_block_tables):
+
+                    # last block: seq_len - cache_start <= block_size
+                    if i == cur_used_num_blocks - 1:
+                        cache_end = seq_len - cache_start
+                        assert cache_end <= self.block_size
+
+                        forward_meta.caches[k_cache_id][block_id, 0:cache_end, :, :] = slice_trans_k[
+                            cache_start:seq_len, :, :
+                        ]
+                        forward_meta.caches[v_cache_id][block_id, 0:cache_end, :, :] = slice_trans_v[
+                            cache_start:seq_len, :, :
+                        ]
+                        if layer_id == self.num_layers - 1:
+                            self.record_block_table_metadata[batch_idx] = {
+                                "block_id": block_id.item(),
+                                "cache_end": cache_end,
+                            }
+                    # non last block: seq_lens_this_time > block_size
+                    else:
+                        assert seq_len > self.block_size
+                        cache_end = cache_start + self.block_size
+                        forward_meta.caches[k_cache_id][block_id] = slice_trans_k[cache_start:cache_end, :, :]
+                        forward_meta.caches[v_cache_id][block_id] = slice_trans_v[cache_start:cache_end, :, :]
+                        cache_start += self.block_size
+            tensor_start = tensor_end
+
+    def merge_output(self, prefill_out, decode_out, forward_meta: ForwardMeta):
+        assert not (prefill_out is None and decode_out is None), "prefill and decode output cannot both be None"
+        if prefill_out is None:
+            return decode_out
+        elif decode_out is None:
+            return prefill_out
+        else:
+            prefill_out = prefill_out
+            decode_out = decode_out
+
+            merged_output = []
+            prefill_tensor_start = 0
+            decode_tensor_start = 0
+            for seq_lens_this_time in forward_meta.seq_lens_this_time:
+                if seq_lens_this_time == 0:
                     continue
-            cached_kv_len = seq_lens_decoder[batch_idx]
-            cu_seq_start_q = forward_meta.cu_seqlens_q[batch_idx]
-            cu_seq_end_q = forward_meta.cu_seqlens_q[batch_idx + 1]
-            if forward_meta.rotary_embs is not None and cu_seq_end_q > cu_seq_start_q:
-                cos = forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
-                sin = forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
+                if seq_lens_this_time > 1:
+                    tensor_end = prefill_tensor_start + seq_lens_this_time
+                    merged_output.append(prefill_out[prefill_tensor_start:tensor_end, :, :])
+                    prefill_tensor_start = tensor_end
+                else:
+                    assert seq_lens_this_time == 1
+                    tensor_end = decode_tensor_start + seq_lens_this_time
+                    merged_output.append(decode_out[decode_tensor_start:tensor_end, :, :])
+                    decode_tensor_start = tensor_end

-            def rope_func(qk):
-                qk[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(qk[cu_seq_start_q:cu_seq_end_q], cos, sin)
+            assert (
+                prefill_tensor_start == prefill_out.shape[0]
+            ), f"prefill merged unfinished: {prefill_tensor_start} vs {prefill_out.shape[0]}"
+            assert (
+                decode_tensor_start == decode_out.shape[0]
+            ), f"decode merged unfinished: {decode_tensor_start} vs {decode_out.shape[0]}"
+            merged_output = paddle.concat(merged_output, axis=0)
+            return merged_output

-            if encoder_q.numel() != 0:
-                rope_func(encoder_q)
-                rope_func(encoder_k)
-            if decoder_q.numel() != 0:
-                rope_func(decoder_q)
-                rope_func(decoder_k)
+    def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta):

-        # 4. Flash Attention for encoder
-        encoder_v = encoder_v
-        cu_seqlens_q = forward_meta.cu_seqlens_q
-        cu_seqlens_k = forward_meta.cu_seqlens_k
-        max_seqlen_q = paddle.max(seq_lens_this_time)
-        max_seqlen_k = max_seqlen_q
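+        # Varlen FlashAttention over the packed prefill tokens; the same cumulative
+        # lengths serve as cu_seqlens for both Q and K because keys/values come from
+        # the current packed qkv rather than the paged cache.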
+        prefill_q, prefill_k, prefill_v = self.get_splited_qkv(
+            prefill_qkv,
+            forward_meta,
+            self.prefill_info_dict["cu_seqlens_q"],
+            batch_ids=self.batch_ids_prefill,
+        )
+
+        prefill_out = flash_attn_unpadded_func(
+            prefill_q,
+            prefill_k,
+            prefill_v,
+            self.prefill_info_dict["cu_seqlens_q"],
+            self.prefill_info_dict["cu_seqlens_q"],
+            max_seqlen_q=self.max_seq_len,
+            max_seqlen_k=self.max_seq_len,
+            attn_mask=forward_meta.attn_mask,
+            causal=self.causal,
+        )[0]
+
+        self.update_kv_cache(
+            prefill_k, prefill_v, k_cache_id, v_cache_id, layer_id, forward_meta, self.batch_ids_prefill
+        )
+
+        return prefill_out
+
+    def forward_decode(self, decode_qkv, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
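+        # One query token per decode request: flash_attn_kvcache_func reads the
+        # paged cache through block_table_dec and applies rotary embedding in-kernel
+        # (interleaved cos/sin), so no separate rope pass is needed here.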
+        cache_k = forward_meta.caches[k_cache_id]
+        cache_v = forward_meta.caches[v_cache_id]
+        cu_seq_lens = list(range(self.decode_len + 1))
+
+        q, k, v = self.get_splited_qkv(decode_qkv, forward_meta, cu_seq_lens, self.batch_ids_decode, is_decode=True)
+        decoder_q = q.view([self.decode_len, 1, self.num_heads, self.head_dim])
+        decoder_k_ = k.view([self.decode_len, 1, self.kv_num_heads, self.head_dim])
+        decoder_v_ = v.view([self.decode_len, 1, self.kv_num_heads, self.head_dim])
+
+        decode_out = flash_attn_kvcache_func(
+            decoder_q,
+            cache_k,
+            cache_v,
+            self.seq_lens_dec,
+            self.block_table_dec,
+            decoder_k_,
+            decoder_v_,
+            rotary_cos=forward_meta.rotary_embs[0, 0, :, 0, :].astype("bfloat16"),
+            rotary_sin=forward_meta.rotary_embs[1, 0, :, 0, :].astype("bfloat16"),
+            causal=self.causal,
+            is_rotary_interleaved=True,
+        )[0].squeeze(1)
+
+        return decode_out
+
+    @paddle.no_grad()
+    def forward_native_backend(self, q, k, v, qkv, layer, forward_meta: ForwardMeta):
+
+        layer_id = layer.layer_id
+        k_cache_id = layer_id * 2
+        v_cache_id = k_cache_id + 1
+
+        if self.decode_len == 0:
+            out = self.forward_prefill(qkv, layer_id, k_cache_id, v_cache_id, forward_meta)
+
+        elif self.prefill_len == 0:
+            out = self.forward_decode(qkv, k_cache_id, v_cache_id, forward_meta)

-        if encoder_q.numel() > 0:
-            encoder_out = flash_attn_unpadded_func(
-                encoder_q,
-                encoder_k,
-                encoder_v,
-                cu_seqlens_q,
-                cu_seqlens_k,
-                max_seqlen_q,
-                max_seqlen_k,
-                attn_mask=forward_meta.attn_mask,
-                causal=self.causal,
-            )
-            self.update_encoder_kv_cache(
-                encoder_k, encoder_v, seq_lens_encoder, cache_k, cache_v, forward_meta.block_tables
-            )
         else:
-            encoder_out = None
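+            # Mixed batch: re-pack qkv into prefill and decode halves, run each
+            # path separately, then interleave the outputs back into the original
+            # token order.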
""" + self.weight_dtype = "int8" + self.default_dtype = layer._helper.get_default_dtype() + up_gate_proj_weight_name = self.added_weight_attrs[0] + down_proj_weight_name = self.added_weight_attrs[1] + self.up_gate_proj_weight_shape = [ + layer.num_local_experts, + layer.hidden_size, + layer.moe_intermediate_size * 2, + ] + self.down_proj_weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size, + layer.hidden_size, + ] + setattr( + layer, + up_gate_proj_weight_name, + layer.create_parameter( + shape=self.up_gate_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_weight_name, + layer.create_parameter( + shape=self.down_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # weight_scale + setattr( + layer, + self.added_scale_attrs[0], + layer.create_parameter( + shape=[layer.num_local_experts, layer.moe_intermediate_size * 2], + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + self.added_scale_attrs[1], + layer.create_parameter( + shape=[layer.num_local_experts, layer.hidden_size], + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + def process_loaded_weights(self, layer: nn.Layer, state_dict): + """ + Triton MoE load weight process. + """ up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict) assert len(up_gate_proj_weights) == layer.num_local_experts assert len(down_proj_weights) == layer.num_local_experts - if layer.quant_method.quant_config: - algo = layer.quant_method.quant_config.name() + algo = layer.quant_method.quant_config.name() + + assert algo == "wint8" assert up_gate_proj_weights[0].shape == [ layer.hidden_size, @@ -79,52 +135,12 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase): scale_name = self.added_scale_attrs[idx] quanted_weight_scale = weight_tensor.abs().max(axis=1) - if self.quant_config is not None: - quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound - quanted_weight = paddle.round(quanted_weight).astype("int8") - quanted_weight_scale = quanted_weight_scale / max_bound + quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound + quanted_weight = paddle.round(quanted_weight).astype("int8") + quanted_weight_scale = quanted_weight_scale / max_bound - setattr( - layer, - weight_name, - layer.create_parameter( - shape=quanted_weight.shape, - dtype=quanted_weight.dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - getattr(layer, weight_name).set_value(quanted_weight) - - setattr( - layer, - scale_name, - layer.create_parameter( - shape=quanted_weight_scale.shape, - dtype=quanted_weight_scale.dtype, - ), - ) - getattr(layer, scale_name).set_value(quanted_weight_scale) - else: - setattr( - layer, - weight_name, - layer.create_parameter( - shape=quanted_weight.shape, - dtype=quanted_weight.dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - getattr(layer, weight_name).set_value(quanted_weight) - - setattr( - layer, - scale_name, - layer.create_parameter( - shape=quanted_weight_scale.shape, - dtype=quanted_weight_scale.dtype, - ), - ) - getattr(layer, scale_name).set_value(quanted_weight_scale) + getattr(layer, weight_name).set_value(quanted_weight) + getattr(layer, scale_name).set_value(quanted_weight_scale) @paddle.no_grad() def apply( @@ -159,16 +175,16 @@ class 
+        self.weight_dtype = "int8"
+        self.default_dtype = layer._helper.get_default_dtype()
+        up_gate_proj_weight_name = self.added_weight_attrs[0]
+        down_proj_weight_name = self.added_weight_attrs[1]
+        self.up_gate_proj_weight_shape = [
+            layer.num_local_experts,
+            layer.hidden_size,
+            layer.moe_intermediate_size * 2,
+        ]
+        self.down_proj_weight_shape = [
+            layer.num_local_experts,
+            layer.moe_intermediate_size,
+            layer.hidden_size,
+        ]
+        setattr(
+            layer,
+            up_gate_proj_weight_name,
+            layer.create_parameter(
+                shape=self.up_gate_proj_weight_shape,
+                dtype=self.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        setattr(
+            layer,
+            down_proj_weight_name,
+            layer.create_parameter(
+                shape=self.down_proj_weight_shape,
+                dtype=self.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        # weight_scale
+        setattr(
+            layer,
+            self.added_scale_attrs[0],
+            layer.create_parameter(
+                shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
+                dtype=self.default_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        setattr(
+            layer,
+            self.added_scale_attrs[1],
+            layer.create_parameter(
+                shape=[layer.num_local_experts, layer.hidden_size],
+                dtype=self.default_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+
+    def process_loaded_weights(self, layer: nn.Layer, state_dict):
+        """
+        Triton MoE load weight process.
+        """
         up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
         assert len(up_gate_proj_weights) == layer.num_local_experts
         assert len(down_proj_weights) == layer.num_local_experts
-        if layer.quant_method.quant_config:
-            algo = layer.quant_method.quant_config.name()
+        algo = layer.quant_method.quant_config.name()
+
+        assert algo == "wint8"

         assert up_gate_proj_weights[0].shape == [
             layer.hidden_size,
@@ -79,52 +135,12 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
             scale_name = self.added_scale_attrs[idx]

             quanted_weight_scale = weight_tensor.abs().max(axis=1)
-            if self.quant_config is not None:
-                quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
-                quanted_weight = paddle.round(quanted_weight).astype("int8")
-                quanted_weight_scale = quanted_weight_scale / max_bound
+            quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
+            quanted_weight = paddle.round(quanted_weight).astype("int8")
+            quanted_weight_scale = quanted_weight_scale / max_bound

-                setattr(
-                    layer,
-                    weight_name,
-                    layer.create_parameter(
-                        shape=quanted_weight.shape,
-                        dtype=quanted_weight.dtype,
-                        default_initializer=paddle.nn.initializer.Constant(0),
-                    ),
-                )
-                getattr(layer, weight_name).set_value(quanted_weight)
-
-                setattr(
-                    layer,
-                    scale_name,
-                    layer.create_parameter(
-                        shape=quanted_weight_scale.shape,
-                        dtype=quanted_weight_scale.dtype,
-                    ),
-                )
-                getattr(layer, scale_name).set_value(quanted_weight_scale)
-            else:
-                setattr(
-                    layer,
-                    weight_name,
-                    layer.create_parameter(
-                        shape=quanted_weight.shape,
-                        dtype=quanted_weight.dtype,
-                        default_initializer=paddle.nn.initializer.Constant(0),
-                    ),
-                )
-                getattr(layer, weight_name).set_value(quanted_weight)
-
-                setattr(
-                    layer,
-                    scale_name,
-                    layer.create_parameter(
-                        shape=quanted_weight_scale.shape,
-                        dtype=quanted_weight_scale.dtype,
-                    ),
-                )
-                getattr(layer, scale_name).set_value(quanted_weight_scale)
+            getattr(layer, weight_name).set_value(quanted_weight)
+            getattr(layer, scale_name).set_value(quanted_weight_scale)

     @paddle.no_grad()
     def apply(
@@ -159,16 +175,16 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
         if self.quant_config is not None:
             config = {
                 "BLOCK_SIZE_M": 32,
-                "BLOCK_SIZE_N": 128,
-                "BLOCK_SIZE_K": 128,
-                "GROUP_SIZE_M": 1,
+                "BLOCK_SIZE_N": 64,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 8,
             }
         else:
             config = {
                 "BLOCK_SIZE_M": 32,
                 "BLOCK_SIZE_N": 64,
                 "BLOCK_SIZE_K": 64,
-                "GROUP_SIZE_M": 1,
+                "GROUP_SIZE_M": 8,
             }

         sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py
index b8e929b4c..5e4c8ed52 100644
--- a/fastdeploy/model_executor/layers/quantization/weight_only.py
+++ b/fastdeploy/model_executor/layers/quantization/weight_only.py
@@ -313,24 +313,14 @@ class WeightOnlyLinearMethod(QuantMethodBase):
         raise NotImplementedError

     def apply(self, layer, x):
-        if current_platform.is_maca():
-            linear_out = weight_only_linear(
-                x,
-                weight=layer.weight,
-                bias=layer.bias if layer.add_bias else None,
-                weight_scale=layer.weight_scale,
-                weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
-                arch=80,
-            )
-        else:
-            linear_out = weight_only_linear(
-                x,
-                weight=layer.weight,
-                bias=layer.bias if layer.add_bias else None,
-                weight_scale=layer.weight_scale,
-                weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
-                arch=self.quant_config.weight_only_linear_arch,
-            )
+        linear_out = weight_only_linear(
+            x,
+            weight=layer.weight,
+            bias=layer.bias if layer.add_bias else None,
+            weight_scale=layer.weight_scale,
+            weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
+            arch=self.quant_config.weight_only_linear_arch,
+        )

         return linear_out
diff --git a/fastdeploy/model_executor/layers/rotary_embedding.py b/fastdeploy/model_executor/layers/rotary_embedding.py
index 1d405d7a9..a51b53b0d 100644
--- a/fastdeploy/model_executor/layers/rotary_embedding.py
+++ b/fastdeploy/model_executor/layers/rotary_embedding.py
@@ -52,9 +52,9 @@ class ErnieRotaryEmbedding:
             rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
             return rot_emb
         elif paddle.is_compiled_with_custom_device("metax_gpu"):
-            # shape: [B, S, D]
-            rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
-            emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
+            # shape: [B, S, D/2]
+            rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
+            emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
         else:
             # shape: [B, S, D/2]
             rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")