[Feature] refactor metax_gpu attention and moe and remove some useless code (#3688)
Co-authored-by: yongqiangma <xing.wo@163.com>

@@ -894,7 +894,7 @@ class CacheConfig:
             self.kv_cache_ratio = 1.0
         else:
             self.kv_cache_ratio = 0.75
-        self.enc_dec_block_num = 0 if current_platform.is_iluvatar() else 2
+        self.enc_dec_block_num = 0 if current_platform.is_iluvatar() or current_platform.is_maca() else 2
         self.prealloc_dec_block_slot_num_threshold = 12
         self.cache_dtype = "bfloat16"
         self.model_cfg = None

@@ -16,13 +16,11 @@

 from __future__ import annotations

-import math
 import os
 from dataclasses import dataclass, field
 from typing import List, Optional

 import paddle
-import paddle.nn.functional as F

 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.forward_meta import ForwardMeta, ForwardMode

@@ -92,6 +90,7 @@ class FlashAttentionBackend(AttentionBackend):
         """
         super().__init__()
         self.attention_metadata: FlashAttentionMetadata = None
+        self.record_block_table_metadata = {}
         self.block_size: int = fd_config.parallel_config.block_size
         self.max_seq_len: int = fd_config.parallel_config.max_model_len
         self.rope_theta: float = (

@@ -110,6 +109,9 @@ class FlashAttentionBackend(AttentionBackend):
         self.kv_num_heads: int = kv_num_heads
         self.num_heads: int = num_heads
         self.head_dim: int = fd_config.model_config.head_dim
+        self.total_num_heads = self.num_heads + 2 * self.kv_num_heads
+        self.total_hidden_dim = self.total_num_heads * self.head_dim
+        self.dtype = paddle.get_default_dtype()
         self.num_layers: int = fd_config.model_config.num_hidden_layers
         self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 32768))

@@ -125,7 +127,98 @@ class FlashAttentionBackend(AttentionBackend):
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""
         forward_meta.forward_mode = ForwardMode.NATIVE
-        return
+        self.prefill_info_dict = {}
+        self.decode_info_dict = {}
+
+        prefill_non_zeros_ids = forward_meta.seq_lens_this_time > 1
+        decode_non_zeros_ids = forward_meta.seq_lens_this_time == 1
+        self.prefill_info_dict["batch_ids"] = paddle.where(prefill_non_zeros_ids)[0]
+        self.decode_info_dict["batch_ids"] = paddle.where(decode_non_zeros_ids)[0]
+
+        self.prefill_len = len(self.prefill_info_dict["batch_ids"])
+        self.decode_len = len(self.decode_info_dict["batch_ids"])
+
+        # only prefill
+        if self.decode_len == 0:
+            cu_seq_ids = list(range(self.prefill_len + 1))
+            self.prefill_info_dict["cu_seqlens_q"] = forward_meta.cu_seqlens_q[cu_seq_ids].astype("int32")
+        # only decode
+        elif self.prefill_len == 0:
+            pass
+        # both prefill and decode
+        else:
+            prefill_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[prefill_non_zeros_ids])
+            decode_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[decode_non_zeros_ids])
+
+            self.prefill_info_dict["cu_seqlens_q"] = paddle.zeros(
+                [self.prefill_len + 1], dtype=forward_meta.cu_seqlens_q.dtype
+            )
+            self.prefill_info_dict["cu_seqlens_q"][1:] = forward_meta.seq_lens_encoder[
+                self.prefill_info_dict["batch_ids"], 0
+            ]
+            self.prefill_info_dict["cu_seqlens_q"] = paddle.cumsum(self.prefill_info_dict["cu_seqlens_q"]).astype(
+                "int32"
+            )
+
+            self.prefill_qkv = paddle.zeros([prefill_num_tokens, self.total_hidden_dim], dtype=self.dtype)
+            self.decode_qkv = paddle.zeros([decode_num_tokens, self.total_hidden_dim], dtype=self.dtype)
+            self.merged_output = paddle.zeros(
+                [prefill_num_tokens + decode_num_tokens, self.num_heads, self.head_dim], dtype=self.dtype
+            )
+
+            prefill_start, decode_start, start = 0, 0, 0
+            non_zeros_ids = forward_meta.seq_lens_this_time != 0
+            non_zeros_seq_lens = forward_meta.seq_lens_this_time[non_zeros_ids]
+            end = non_zeros_seq_lens[0]
+            if end > 1:
+                last_stage = "prefill"
+                prefill_end = end
+                decode_end = 0
+            else:
+                last_stage = "decode"
+                prefill_end = 0
+                decode_end = end
+
+            self.prefill_info_dict["id_group"] = []
+            self.prefill_info_dict["reverse_id_group"] = []
+            self.decode_info_dict["id_group"] = []
+            self.decode_info_dict["reverse_id_group"] = []
+            self.record_stages = []
+            for seq_len in non_zeros_seq_lens[1:]:
+                if seq_len > 1:
+                    if last_stage == "decode":
+                        self.record_stages.append((last_stage, len(self.decode_info_dict["id_group"])))
+                        self.decode_info_dict["id_group"].append((decode_start, decode_end))
+                        self.decode_info_dict["reverse_id_group"].append((start, end))
+                        decode_start = decode_end
+                        start = end
+                        last_stage = "prefill"
+                    prefill_end += seq_len
+                    end += seq_len
+                else:
+                    if last_stage == "prefill":
+                        self.record_stages.append((last_stage, len(self.prefill_info_dict["id_group"])))
+                        self.prefill_info_dict["id_group"].append((prefill_start, prefill_end))
+                        self.prefill_info_dict["reverse_id_group"].append((start, end))
+                        prefill_start = prefill_end
+                        start = end
+                        last_stage = "decode"
+                    decode_end += seq_len
+                    end += seq_len
+
+            if prefill_start < prefill_end:
+                self.record_stages.append(("prefill", len(self.prefill_info_dict["id_group"])))
+                self.prefill_info_dict["id_group"].append((prefill_start, prefill_end))
+                self.prefill_info_dict["reverse_id_group"].append((start, end))
+            if decode_start < decode_end:
+                self.record_stages.append(("decode", len(self.decode_info_dict["id_group"])))
+                self.decode_info_dict["id_group"].append((decode_start, decode_end))
+                self.decode_info_dict["reverse_id_group"].append((start, end))
+
+        self.batch_ids_prefill = paddle.to_tensor(self.prefill_info_dict["batch_ids"])
+        self.batch_ids_decode = paddle.to_tensor(self.decode_info_dict["batch_ids"])
+        self.seq_lens_dec = forward_meta.seq_lens_decoder[self.batch_ids_decode, 0]
+        self.block_table_dec = forward_meta.block_tables[self.batch_ids_decode, :]

     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""

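For reference, the run-partitioning bookkeeping that `init_attention_metadata` builds (`id_group`, `reverse_id_group`, `record_stages`) can be reproduced with a small standalone sketch; the sequence lengths below are hypothetical and the helper is illustrative, not part of this commit:

```python
# Illustrative sketch of the mixed-batch partitioning above: tokens are laid out
# batch-by-batch; prefill requests have seq_len > 1, decode requests have seq_len == 1.
# Contiguous runs of each kind are recorded so the packed prefill/decode streams can
# later be scattered back into merged token order (id_group = stream offsets,
# reverse_id_group = merged offsets in the backend's terms).
def partition_runs(seq_lens):
    runs = []  # (stage, merged_start, merged_end, stream_start, stream_end)
    merged = prefill = decode = 0
    i = 0
    while i < len(seq_lens):
        stage = "prefill" if seq_lens[i] > 1 else "decode"
        run = 0
        while i < len(seq_lens) and (seq_lens[i] > 1) == (stage == "prefill"):
            run += seq_lens[i]
            i += 1
        stream = prefill if stage == "prefill" else decode
        runs.append((stage, merged, merged + run, stream, stream + run))
        if stage == "prefill":
            prefill += run
        else:
            decode += run
        merged += run
    return runs

print(partition_runs([5, 1, 3, 1, 1]))
# [('prefill', 0, 5, 0, 5), ('decode', 5, 6, 0, 1),
#  ('prefill', 6, 9, 5, 8), ('decode', 9, 11, 1, 3)]
```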
@@ -149,106 +242,11 @@ class FlashAttentionBackend(AttentionBackend):
         else:
             return (
                 max_num_blocks,
-                self.kv_num_heads,
                 self.block_size,
+                self.kv_num_heads,
                 self.head_dim,
             )

-    def split_qkv(self, qkv, num_head_q, num_head_kv, dim):
-        q = qkv[:, : num_head_q * dim].reshape([-1, num_head_q, dim])
-        k = qkv[:, num_head_q * dim : num_head_q * dim + num_head_kv * dim].reshape([-1, num_head_kv, dim])
-        v = qkv[:, num_head_q * dim + num_head_kv * dim :].reshape([-1, num_head_kv, dim])
-        return q, k, v
-
-    def flash_attn_varlen(self, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
-        num_head = q.shape[1]
-        dim = q.shape[2]
-
-        q_ = q.reshape([-1, num_head, dim])
-        k_ = k.reshape([-1, num_head, dim])
-        v_ = v.reshape([-1, num_head, dim])
-
-        bsz = cu_seqlens_q.shape[0] - 1
-        out = []
-        for i in range(bsz):
-            start_q, end_q = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
-            start_k, end_k = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()
-            qi = q_[start_q:end_q]  # [seq_q, nh, dim]
-            ki = k_[start_k:end_k]  # [seq_k, nh, dim]
-            vi = v_[start_k:end_k]  # [seq_k, nh, dim]
-            qi = qi.transpose([1, 0, 2])  # [nh, seq_q, dim]
-            ki = ki.transpose([1, 2, 0])  # [nh, dim, seq_k]
-            vi = vi.transpose([1, 0, 2])  # [nh, seq_k, dim]
-
-            score = paddle.matmul(qi, ki) / math.sqrt(dim)  # [nh, seq_q, seq_k]
-            prob = F.softmax(score, axis=-1)
-            o = paddle.matmul(prob, vi)  # [nh, seq_q, dim]
-            o = o.transpose([1, 0, 2])  # [seq_q, nh, dim]
-            out.append(o)
-
-        return paddle.concat(out, axis=0)  # [total_q, nh, dim]
-
-    def flash_attn_with_kvcache(self, q, cache_k, cache_v, cache_seqlens, block_tables=None):
-        bs, _, nh, dim = q.shape
-        out = []
-        for i in range(bs):
-            q_i = q[i]  # [1, nh, dim]
-            k_i = cache_k[i, : cache_seqlens[i, 0]]  # [seqlen, nh, dim]
-            v_i = cache_v[i, : cache_seqlens[i, 0]]
-            qi = q_i.transpose([1, 0, 2])  # [nh, 1, dim]
-            ki = k_i.transpose([1, 2, 0])  # [nh, dim, seqlen]
-            vi = v_i.transpose([1, 0, 2])  # [nh, seqlen, dim]
-            score = paddle.matmul(qi, ki) / math.sqrt(dim)
-            prob = F.softmax(score, axis=-1)
-            o = paddle.matmul(prob, vi).transpose([1, 0, 2])  # [1, nh, dim]
-            out.append(o)
-        return paddle.concat(out, axis=0)  # [bs, nh, dim]
-
-    def block_cache_to_naive_cache(slef, cache_k, cache_v, bsz, block_tables, cache_seq_len):
-        _, num_head, blocksize, dim_head = cache_k.shape
-        out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype)
-        out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype)
-        for i in range(bsz):
-            for j in range(cache_seq_len):
-                out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
-                out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
-        return out_cache_k, out_cache_v
-
-    def block_cache_to_naive_cache__(self, cache_k, cache_v, bsz, block_tables, max_cache_seq_len):
-        _, num_head, blocksize, dim_head = cache_k.shape
-        out_cache_k = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_k.dtype)
-        out_cache_v = paddle.zeros(shape=[bsz, max_cache_seq_len + 1, num_head, dim_head], dtype=cache_v.dtype)
-        for i in range(bsz):
-            for j in range(max_cache_seq_len):
-                out_cache_k[i, j, :, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
-                out_cache_v[i, j, :, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
-        return out_cache_k, out_cache_v
-
-    def update_encoder_kv_cache(self, k, v, seq_lens_encoder, cache_k, cache_v, block_tables):
-        _, num_head, blocksize, dim_head = cache_k.shape
-        offset = 0
-        for batch_idx, seq_len in enumerate(seq_lens_encoder.numpy()):
-            if seq_len == 0:
-                continue
-            for seq_idx in range(seq_len):
-                block_id = block_tables[batch_idx, seq_idx // blocksize]
-                assert block_id != -1
-                index = offset + seq_idx
-                cache_k[block_id, :, seq_idx % blocksize, :] = k[index, :, :]
-                cache_v[block_id, :, seq_idx % blocksize, :] = v[index, :, :]
-
-            offset += seq_len
-
-    def update_decoder_kv_cache(self, k, v, seq_lens_decoder, cache_k, cache_v, block_tables):
-        _, num_head, blocksize, dim_head = cache_k.shape
-        for batch_idx, seq_idx in enumerate(seq_lens_decoder.numpy()):
-            if seq_idx == 0:
-                continue
-            block_id = block_tables[batch_idx, seq_idx // blocksize]
-            assert block_id != -1
-            cache_k[block_id, :, seq_idx % blocksize, :] = k[batch_idx, :, :]
-            cache_v[block_id, :, seq_idx % blocksize, :] = v[batch_idx, :, :]
-
     def apply_rope(self, qk, cos, sin):
         rotate_half = paddle.reshape(
             paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1),

@@ -257,138 +255,234 @@ class FlashAttentionBackend(AttentionBackend):
         out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin))
         return paddle.cast(out, qk.dtype)

-    @paddle.no_grad()
-    def forward_native_backend(
-        self,
-        q: paddle.Tensor,
-        k: paddle.Tensor,
-        v: paddle.Tensor,
-        qkv: paddle.Tensor,
-        layer,
-        forward_meta: ForwardMeta,
-    ):
-
-        bsz = forward_meta.seq_lens_this_time.shape[0]
-        num_head_q, num_head_kv, dim = layer.num_heads, layer.kv_num_heads, layer.head_dim
-
-        # 1. separate encoder / decoder masks
-        seq_lens_encoder = forward_meta.seq_lens_encoder.squeeze(-1)
-        seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
-        seq_lens_this_time = forward_meta.seq_lens_this_time
-        encoder_indices = []
-        decoder_indices = []
-
-        offset = 0
-        for i in range(bsz):
-            length = seq_lens_this_time[i].item()
-            if seq_lens_encoder[i] > 0:
-                encoder_indices.extend(range(offset, offset + length))
-            elif seq_lens_decoder[i] > 0:
-                decoder_indices.extend(range(offset, offset + length))
-            offset += length
-
-        encoder_indices = paddle.to_tensor(encoder_indices, dtype="int32")
-        decoder_indices = paddle.to_tensor(decoder_indices, dtype="int32")
-
-        encoder_qkv = paddle.index_select(qkv, encoder_indices, axis=0)
-        decoder_qkv = paddle.index_select(qkv, decoder_indices, axis=0)
-
-        # 2. split encoder and decoder qkv
-        encoder_q, encoder_k, encoder_v = self.split_qkv(encoder_qkv, num_head_q, num_head_kv, dim)
-        decoder_q, decoder_k, decoder_v = self.split_qkv(decoder_qkv, num_head_q, num_head_kv, dim)
-        cache_k = forward_meta.caches[2 * layer.layer_id]
-        cache_v = forward_meta.caches[2 * layer.layer_id + 1]
-
-        # 3. Rotary Embedding
-        if decoder_q.numel() != 0 or encoder_q.numel() != 0:
-            for batch_idx in range(forward_meta.seq_lens_this_time.shape[0]):
-                seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
-                if seq_len_i == 0:
-                    continue
-                cached_kv_len = seq_lens_decoder[batch_idx]
-                cu_seq_start_q = forward_meta.cu_seqlens_q[batch_idx]
-                cu_seq_end_q = forward_meta.cu_seqlens_q[batch_idx + 1]
-                if forward_meta.rotary_embs is not None and cu_seq_end_q > cu_seq_start_q:
-                    cos = forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
-                    sin = forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
-
-                    def rope_func(qk):
-                        qk[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(qk[cu_seq_start_q:cu_seq_end_q], cos, sin)
-
-                    if encoder_q.numel() != 0:
-                        rope_func(encoder_q)
-                        rope_func(encoder_k)
-                    if decoder_q.numel() != 0:
-                        rope_func(decoder_q)
-                        rope_func(decoder_k)
-
-        # 4. Flash Attention for encoder
-        encoder_v = encoder_v
-        cu_seqlens_q = forward_meta.cu_seqlens_q
-        cu_seqlens_k = forward_meta.cu_seqlens_k
-        max_seqlen_q = paddle.max(seq_lens_this_time)
-        max_seqlen_k = max_seqlen_q
-
-        if encoder_q.numel() > 0:
-            encoder_out = flash_attn_unpadded_func(
-                encoder_q,
-                encoder_k,
-                encoder_v,
-                cu_seqlens_q,
-                cu_seqlens_k,
-                max_seqlen_q,
-                max_seqlen_k,
-                attn_mask=forward_meta.attn_mask,
-                causal=self.causal,
-            )
-            self.update_encoder_kv_cache(
-                encoder_k, encoder_v, seq_lens_encoder, cache_k, cache_v, forward_meta.block_tables
-            )
-        else:
-            encoder_out = None
-
-        # 5. decoder attention with kv cache
-        bs = decoder_q.shape[0]
-        decoder_q = decoder_q.reshape([bs, 1, num_head_q, dim])
-        decoder_k_ = decoder_k.reshape([bs, 1, num_head_kv, dim])
-        decoder_v_ = decoder_v.reshape([bs, 1, num_head_kv, dim])
-        cache_seqlens = paddle.index_select(forward_meta.seq_lens_decoder, decoder_indices, axis=0)
-
-        # 5.1 convert paged kv cache to continuous cache
-        if decoder_q.numel() > 0:
-            max_cache_seq_len = paddle.max(cache_seqlens)
-            c_cache_k, c_cache_v = self.block_cache_to_naive_cache__(
-                cache_k, cache_v, bs, forward_meta.block_tables, max_cache_seq_len
-            )
-            decoder_out = flash_attn_kvcache_func(
-                decoder_q,
-                c_cache_k,
-                c_cache_v,
-                cache_seqlens.squeeze(-1),
-                None,
-                decoder_k_,
-                decoder_v_,
-                causal=self.causal,
-            )
-            self.update_decoder_kv_cache(
-                decoder_k, decoder_v, seq_lens_decoder, cache_k, cache_v, forward_meta.block_tables
-            )
-        else:
-            decoder_out = None
-
-        # 6. concatenate encoder_out and decoder_out
-        total_len = qkv.shape[0]
-        out = paddle.zeros([total_len, num_head_q, dim])
-        if encoder_out is not None:
-            out = paddle.tensor.put_along_axis(
-                out, encoder_indices.unsqueeze(-1).unsqueeze(-1), encoder_out[0], axis=0
-            )
-        if decoder_out is not None:
-            new_decoder_out = decoder_out[0].squeeze(1)
-            out = paddle.tensor.put_along_axis(
-                out, decoder_indices.unsqueeze(-1).unsqueeze(-1), new_decoder_out, axis=0
-            )
-
-        out.reshape_([total_len, num_head_q * dim])
-
+    def get_splited_qkv(
+        self,
+        qkv: paddle.Tensor,
+        forward_meta: ForwardMeta,
+        cu_seqlens_q: paddle.Tensor,
+        batch_ids=None,
+        is_decode=False,
+    ):
+        q_end = self.num_heads * self.head_dim
+        k_end = q_end + self.kv_num_heads * self.head_dim
+        v_end = k_end + self.kv_num_heads * self.head_dim
+        assert v_end == qkv.shape[-1], f"Shape mismatch: {v_end} vs {qkv.shape[-1]}"
+        assert qkv.shape[0] == cu_seqlens_q[-1], f"Shape mismatch: {qkv.shape[0]} vs {cu_seqlens_q[-1]}"
+
+        if batch_ids is None:
+            batch_ids = list(range(forward_meta.seq_lens_this_time.shape[0]))
+
+        q = qkv[..., 0:q_end]
+        k = qkv[..., q_end:k_end]
+        v = qkv[..., k_end:v_end]
+
+        q = q.view([-1, self.num_heads, self.head_dim])
+        k = k.view([-1, self.kv_num_heads, self.head_dim])
+        v = v.view([-1, self.kv_num_heads, self.head_dim])
+
+        if is_decode:
+            return q, k, v
+
+        for idx in range(len(cu_seqlens_q) - 1):
+            batch_idx = batch_ids[idx]
+            seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
+            if seq_len_i == 0:
+                continue
+            cached_kv_len = forward_meta.seq_lens_decoder[batch_idx][0]
+            cu_seq_start_q = cu_seqlens_q[idx]
+            cu_seq_end_q = cu_seqlens_q[idx + 1]
+            # forward_meta.rotary_embs is [2, 1, S, 1, D // 2]
+            if forward_meta.rotary_embs is not None:
+                cos = paddle.repeat_interleave(
+                    forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :], repeats=2, axis=-1
+                )  # [Si, D]
+                sin = paddle.repeat_interleave(
+                    forward_meta.rotary_embs[1, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :], repeats=2, axis=-1
+                )  # [Si, D]
+                q[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(q[cu_seq_start_q:cu_seq_end_q], cos, sin)
+                k[cu_seq_start_q:cu_seq_end_q] = self.apply_rope(k[cu_seq_start_q:cu_seq_end_q], cos, sin)
+
+        return q, k, v
+
+    def split_pd_qkv(self, qkv):
+
+        for ids, reverse_ids in zip(self.prefill_info_dict["id_group"], self.prefill_info_dict["reverse_id_group"]):
+            self.prefill_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :]
+
+        for ids, reverse_ids in zip(self.decode_info_dict["id_group"], self.decode_info_dict["reverse_id_group"]):
+            self.decode_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :]
+
+        return self.prefill_qkv, self.decode_qkv
+
+    def merge_pd_output(self, prefill_out, decode_out):
+        for stage, idx in self.record_stages:
+            if stage == "prefill":
+                ids = self.prefill_info_dict["id_group"][idx]
+                reverse_ids = self.prefill_info_dict["reverse_id_group"][idx]
+                self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = prefill_out[ids[0] : ids[1], :, :]
+            else:
+                ids = self.decode_info_dict["id_group"][idx]
+                reverse_ids = self.decode_info_dict["reverse_id_group"][idx]
+                self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = decode_out[ids[0] : ids[1], :, :]
+        return self.merged_output
+
+    def update_kv_cache(
+        self, k, v, k_cache_id, v_cache_id, layer_id, forward_meta: ForwardMeta, specific_batch_ids=None
+    ):
+        tensor_start = 0
+        for batch_idx in range(forward_meta.block_tables.shape[0]):
+            if specific_batch_ids is not None and batch_idx not in specific_batch_ids:
+                continue
+            seq_len = forward_meta.seq_lens_this_time[batch_idx]
+            if seq_len == 0:
+                continue
+            tensor_end = tensor_start + seq_len
+            slice_trans_k = k[tensor_start:tensor_end, :, :]
+            slice_trans_v = v[tensor_start:tensor_end, :, :]
+
+            cur_block_tables = forward_meta.block_tables[batch_idx]
+            cur_used_block_tables = cur_block_tables[cur_block_tables != -1]
+
+            # encoder prefil
+            if seq_len > 1:
+                cache_start = 0
+                cur_used_num_blocks = cur_used_block_tables.shape[0]
+
+                for i, block_id in enumerate(cur_used_block_tables):
+
+                    # last block: seq_len - cache_start <= block_size
+                    if i == cur_used_num_blocks - 1:
+                        cache_end = seq_len - cache_start
+                        assert cache_end <= self.block_size
+
+                        forward_meta.caches[k_cache_id][block_id, 0:cache_end, :, :] = slice_trans_k[
+                            cache_start:seq_len, :, :
+                        ]
+                        forward_meta.caches[v_cache_id][block_id, 0:cache_end, :, :] = slice_trans_v[
+                            cache_start:seq_len, :, :
+                        ]
+                        if layer_id == self.num_layers - 1:
+                            self.record_block_table_metadata[batch_idx] = {
+                                "block_id": block_id.item(),
+                                "cache_end": cache_end,
+                            }
+                    # non last block: seq_lens_this_time > block_size
+                    else:
+                        assert seq_len > self.block_size
+                        cache_end = cache_start + self.block_size
+                        forward_meta.caches[k_cache_id][block_id] = slice_trans_k[cache_start:cache_end, :, :]
+                        forward_meta.caches[v_cache_id][block_id] = slice_trans_v[cache_start:cache_end, :, :]
+                        cache_start += self.block_size
+            tensor_start = tensor_end
+
+    def merge_output(self, prefill_out, decode_out, forward_meta: ForwardMeta):
+        assert not (prefill_out is None and decode_out is None), "prefill and decode output cannot both be None"
+        if prefill_out is None:
+            return decode_out
+        elif decode_out is None:
+            return prefill_out
+        else:
+            prefill_out = prefill_out
+            decode_out = decode_out
+
+            merged_output = []
+            prefill_tensor_start = 0
+            decode_tensor_start = 0
+            for seq_lens_this_time in forward_meta.seq_lens_this_time:
+                if seq_lens_this_time == 0:
+                    continue
+                if seq_lens_this_time > 1:
+                    tensor_end = prefill_tensor_start + seq_lens_this_time
+                    merged_output.append(prefill_out[prefill_tensor_start:tensor_end, :, :])
+                    prefill_tensor_start = tensor_end
+                else:
+                    assert seq_lens_this_time == 1
+                    tensor_end = decode_tensor_start + seq_lens_this_time
+                    merged_output.append(decode_out[decode_tensor_start:tensor_end, :, :])
+                    decode_tensor_start = tensor_end
+
+            assert (
+                prefill_tensor_start == prefill_out.shape[0]
+            ), f"prefill merged unfinished: {prefill_tensor_start} vs {prefill_out.shape[0]}"
+            assert (
+                decode_tensor_start == decode_out.shape[0]
+            ), f"decode merged unfinished: {decode_tensor_start} vs {decode_out.shape[0]}"
+            merged_output = paddle.concat(merged_output, axis=0)
+            return merged_output
+
+    def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
+
+        prefill_q, prefill_k, prefill_v = self.get_splited_qkv(
+            prefill_qkv,
+            forward_meta,
+            self.prefill_info_dict["cu_seqlens_q"],
+            batch_ids=self.batch_ids_prefill,
+        )
+
+        prefill_out = flash_attn_unpadded_func(
+            prefill_q,
+            prefill_k,
+            prefill_v,
+            self.prefill_info_dict["cu_seqlens_q"],
+            self.prefill_info_dict["cu_seqlens_q"],
+            max_seqlen_q=self.max_seq_len,
+            max_seqlen_k=self.max_seq_len,
+            attn_mask=forward_meta.attn_mask,
+            causal=self.causal,
+        )[0]
+
+        self.update_kv_cache(
+            prefill_k, prefill_v, k_cache_id, v_cache_id, layer_id, forward_meta, self.batch_ids_prefill
+        )
+
+        return prefill_out
+
+    def forward_decode(self, decode_qkv, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
+        cache_k = forward_meta.caches[k_cache_id]
+        cache_v = forward_meta.caches[v_cache_id]
+        cu_seq_lens = list(range(self.decode_len + 1))
+
+        q, k, v = self.get_splited_qkv(decode_qkv, forward_meta, cu_seq_lens, self.batch_ids_decode, is_decode=True)
+        decoder_q = q.view([self.decode_len, 1, self.num_heads, self.head_dim])
+        decoder_k_ = k.view([self.decode_len, 1, self.kv_num_heads, self.head_dim])
+        decoder_v_ = v.view([self.decode_len, 1, self.kv_num_heads, self.head_dim])
+
+        decode_out = flash_attn_kvcache_func(
+            decoder_q,
+            cache_k,
+            cache_v,
+            self.seq_lens_dec,
+            self.block_table_dec,
+            decoder_k_,
+            decoder_v_,
+            rotary_cos=forward_meta.rotary_embs[0, 0, :, 0, :].astype("bfloat16"),
+            rotary_sin=forward_meta.rotary_embs[1, 0, :, 0, :].astype("bfloat16"),
+            causal=self.causal,
+            is_rotary_interleaved=True,
+        )[0].squeeze(1)
+
+        return decode_out
+
+    @paddle.no_grad()
+    def forward_native_backend(self, q, k, v, qkv, layer, forward_meta: ForwardMeta):
+
+        layer_id = layer.layer_id
+        k_cache_id = layer_id * 2
+        v_cache_id = k_cache_id + 1
+
+        if self.decode_len == 0:
+            out = self.forward_prefill(qkv, layer_id, k_cache_id, v_cache_id, forward_meta)
+
+        elif self.prefill_len == 0:
+            out = self.forward_decode(qkv, k_cache_id, v_cache_id, forward_meta)
+
+        else:
+            prefill_qkv, decode_qkv = self.split_pd_qkv(qkv)
+            prefill_output = self.forward_prefill(prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta)
+            decode_output = self.forward_decode(decode_qkv, k_cache_id, v_cache_id, forward_meta)
+            out = self.merge_pd_output(prefill_output, decode_output)
+
+        if qkv.dim() == 2:
+            out = out.view([-1, self.num_heads * self.head_dim])
+
         return out

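The `get_splited_qkv` path above expands the half-width rotary table with `paddle.repeat_interleave` before calling `apply_rope`. A minimal self-contained check of that interleaved formulation (shapes, frequencies, and the standalone `apply_rope` copy are illustrative, not the repo's API surface):

```python
import paddle

def apply_rope(qk, cos, sin):
    # same interleaved rotate-half construction as the backend's apply_rope
    rotate_half = paddle.reshape(
        paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1),
        qk.shape,
    )
    return (qk * cos + rotate_half * sin).cast(qk.dtype)

seq_len, num_heads, head_dim = 4, 2, 8
q = paddle.randn([seq_len, num_heads, head_dim])

# rotary table kept at half width (D // 2), as rotary_embs is after this change
inv_freq = 1.0 / (10000.0 ** (paddle.arange(0, head_dim, 2, dtype="float32") / head_dim))
freqs = paddle.arange(seq_len, dtype="float32").unsqueeze(1) * inv_freq.unsqueeze(0)  # [S, D/2]

# expand back to full width at use time, as get_splited_qkv does
cos = paddle.repeat_interleave(freqs.cos(), repeats=2, axis=-1).unsqueeze(1)  # [S, 1, D]
sin = paddle.repeat_interleave(freqs.sin(), repeats=2, axis=-1).unsqueeze(1)
out = apply_rope(q, cos, sin)
print(out.shape)  # [4, 2, 8]
```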
@@ -45,17 +45,73 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
         """process_prequanted_weights"""
         pass

-    @paddle.no_grad()
-    def create_weights(self, layer: nn.Layer, state_dict):
+    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         """
         Triton MoE create weight process.
         """
+        self.weight_dtype = "int8"
+        self.default_dtype = layer._helper.get_default_dtype()
+        up_gate_proj_weight_name = self.added_weight_attrs[0]
+        down_proj_weight_name = self.added_weight_attrs[1]
+        self.up_gate_proj_weight_shape = [
+            layer.num_local_experts,
+            layer.hidden_size,
+            layer.moe_intermediate_size * 2,
+        ]
+        self.down_proj_weight_shape = [
+            layer.num_local_experts,
+            layer.moe_intermediate_size,
+            layer.hidden_size,
+        ]
+        setattr(
+            layer,
+            up_gate_proj_weight_name,
+            layer.create_parameter(
+                shape=self.up_gate_proj_weight_shape,
+                dtype=self.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        setattr(
+            layer,
+            down_proj_weight_name,
+            layer.create_parameter(
+                shape=self.down_proj_weight_shape,
+                dtype=self.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        # weight_scale
+        setattr(
+            layer,
+            self.added_scale_attrs[0],
+            layer.create_parameter(
+                shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
+                dtype=self.default_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        setattr(
+            layer,
+            self.added_scale_attrs[1],
+            layer.create_parameter(
+                shape=[layer.num_local_experts, layer.hidden_size],
+                dtype=self.default_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+
+    def process_loaded_weights(self, layer: nn.Layer, state_dict):
+        """
+        Triton MoE load weight process.
+        """
         up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
         assert len(up_gate_proj_weights) == layer.num_local_experts
         assert len(down_proj_weights) == layer.num_local_experts

-        if layer.quant_method.quant_config:
-            algo = layer.quant_method.quant_config.name()
-
+        algo = layer.quant_method.quant_config.name()
+        assert algo == "wint8"
         assert up_gate_proj_weights[0].shape == [
             layer.hidden_size,

@@ -79,52 +135,12 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
             scale_name = self.added_scale_attrs[idx]

             quanted_weight_scale = weight_tensor.abs().max(axis=1)
-            if self.quant_config is not None:
-                quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
-                quanted_weight = paddle.round(quanted_weight).astype("int8")
-                quanted_weight_scale = quanted_weight_scale / max_bound
-
-                setattr(
-                    layer,
-                    weight_name,
-                    layer.create_parameter(
-                        shape=quanted_weight.shape,
-                        dtype=quanted_weight.dtype,
-                        default_initializer=paddle.nn.initializer.Constant(0),
-                    ),
-                )
-                getattr(layer, weight_name).set_value(quanted_weight)
-
-                setattr(
-                    layer,
-                    scale_name,
-                    layer.create_parameter(
-                        shape=quanted_weight_scale.shape,
-                        dtype=quanted_weight_scale.dtype,
-                    ),
-                )
-                getattr(layer, scale_name).set_value(quanted_weight_scale)
-            else:
-                setattr(
-                    layer,
-                    weight_name,
-                    layer.create_parameter(
-                        shape=quanted_weight.shape,
-                        dtype=quanted_weight.dtype,
-                        default_initializer=paddle.nn.initializer.Constant(0),
-                    ),
-                )
-                getattr(layer, weight_name).set_value(quanted_weight)
-
-                setattr(
-                    layer,
-                    scale_name,
-                    layer.create_parameter(
-                        shape=quanted_weight_scale.shape,
-                        dtype=quanted_weight_scale.dtype,
-                    ),
-                )
-                getattr(layer, scale_name).set_value(quanted_weight_scale)
+            quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
+            quanted_weight = paddle.round(quanted_weight).astype("int8")
+            quanted_weight_scale = quanted_weight_scale / max_bound
+
+            getattr(layer, weight_name).set_value(quanted_weight)
+            getattr(layer, scale_name).set_value(quanted_weight_scale)

     @paddle.no_grad()
     def apply(

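The simplified `process_loaded_weights` above is plain per-channel abs-max int8 quantization into parameters that now already exist from `create_weights`. A standalone sketch of the same arithmetic (the shapes and `max_bound = 127` are assumptions for illustration):

```python
import paddle

w = paddle.randn([2, 64, 128], dtype="float32")  # [num_experts, in_dim, out_dim], illustrative
max_bound = 127.0                                # assumed int8 bound

scale = w.abs().max(axis=1)                      # per expert, per output channel
q = paddle.round(w / scale[:, None, :] * max_bound).astype("int8")
scale = scale / max_bound

w_rec = q.astype("float32") * scale[:, None, :]  # dequantized approximation of w
print(float((w - w_rec).abs().max()))            # per-element error is on the order of scale / 2
```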
@@ -159,16 +175,16 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
         if self.quant_config is not None:
             config = {
                 "BLOCK_SIZE_M": 32,
-                "BLOCK_SIZE_N": 128,
-                "BLOCK_SIZE_K": 128,
-                "GROUP_SIZE_M": 1,
+                "BLOCK_SIZE_N": 64,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 8,
             }
         else:
             config = {
                 "BLOCK_SIZE_M": 32,
                 "BLOCK_SIZE_N": 64,
                 "BLOCK_SIZE_K": 64,
-                "GROUP_SIZE_M": 1,
+                "GROUP_SIZE_M": 8,
             }

         sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(

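The retuned tile sizes mainly change how many program instances the fused MoE kernel launches. A rough back-of-envelope, with hypothetical sizes and assuming the usual `ceil(M / BLOCK_M) * ceil(N / BLOCK_N)` grid for this kind of kernel:

```python
import math

M, N = 2048, 3584  # padded routed tokens and output width, purely illustrative
for block_n in (128, 64):
    grid = math.ceil(M / 32) * math.ceil(N / block_n)
    print(block_n, grid)
# 128 -> 1792 programs, 64 -> 3584 programs: smaller N/K tiles launch more, lighter programs,
# trading per-program work for occupancy on the target hardware.
```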
@@ -313,24 +313,14 @@ class WeightOnlyLinearMethod(QuantMethodBase):
         raise NotImplementedError

     def apply(self, layer, x):
-        if current_platform.is_maca():
-            linear_out = weight_only_linear(
-                x,
-                weight=layer.weight,
-                bias=layer.bias if layer.add_bias else None,
-                weight_scale=layer.weight_scale,
-                weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
-                arch=80,
-            )
-        else:
-            linear_out = weight_only_linear(
-                x,
-                weight=layer.weight,
-                bias=layer.bias if layer.add_bias else None,
-                weight_scale=layer.weight_scale,
-                weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
-                arch=self.quant_config.weight_only_linear_arch,
-            )
+        linear_out = weight_only_linear(
+            x,
+            weight=layer.weight,
+            bias=layer.bias if layer.add_bias else None,
+            weight_scale=layer.weight_scale,
+            weight_dtype=("int8" if self.quant_config.name() == "wint8" else "int4"),
+            arch=self.quant_config.weight_only_linear_arch,
+        )
         return linear_out

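With the `is_maca` special case removed, both paths reduce to one `weight_only_linear` call whose `arch` comes from the quant config. A minimal usage sketch, assuming Paddle's `paddle.nn.quant` API (the names below are Paddle's, not FastDeploy's, and the shapes are illustrative):

```python
import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize

x = paddle.randn([8, 1024], dtype="float16")
w = paddle.randn([1024, 1024], dtype="float16")

# quantize once at load time, then run the weight-only GEMM
qw, scale = weight_quantize(w, algo="weight_only_int8")
y = weight_only_linear(x, qw, weight_scale=scale, weight_dtype="int8", arch=None)
print(y.shape)  # [8, 1024]
```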
@@ -52,9 +52,9 @@ class ErnieRotaryEmbedding:
             rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
             return rot_emb
         elif paddle.is_compiled_with_custom_device("metax_gpu"):
-            # shape: [B, S, D]
-            rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
-            emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
+            # shape: [B, S, D/2]
+            rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
+            emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
         else:
             # shape: [B, S, D/2]
             rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")

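The metax_gpu branch now stores the rotary table at half width, matching the default branch; expanding it with `repeat_interleave` recovers the old duplicated layout. A small illustrative check (not part of the commit, frequencies are hypothetical):

```python
import paddle

bsz, max_seq_len, rotary_dim = 1, 16, 8
inv_freq = 1.0 / (10000.0 ** (paddle.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim))
freqs = paddle.arange(max_seq_len, dtype="float32").unsqueeze(1) * inv_freq.unsqueeze(0)
freqs = freqs.unsqueeze(0)  # [B, S, D/2]

old_emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, rotary_dim))  # [B, S, D]
new_emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, rotary_dim // 2))    # [B, S, D/2]

# consumers expand the half-width table at use time (see get_splited_qkv above)
print(bool((paddle.repeat_interleave(new_emb, repeats=2, axis=-1) == old_emb).all()))    # True
```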