mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
support c4 attn && fix cache
This commit is contained in:
@@ -1130,6 +1130,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
|||||||
LoadOutScaleT out_scale_vec;
|
LoadOutScaleT out_scale_vec;
|
||||||
LoadEmbT cos_emb_vec;
|
LoadEmbT cos_emb_vec;
|
||||||
LoadEmbT sin_emb_vec;
|
LoadEmbT sin_emb_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int v_i = 0; v_i < VecSize; v_i++) {
|
||||||
|
bias_vec[v_i] = 0;
|
||||||
|
}
|
||||||
const InT* qkv_now = quant_qkv + token_id * hidden_size;
|
const InT* qkv_now = quant_qkv + token_id * hidden_size;
|
||||||
T* qkv_out_now = qkv_out + token_id * hidden_size;
|
T* qkv_out_now = qkv_out + token_id * hidden_size;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@@ -1137,8 +1141,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
|||||||
head_bias += 32 * VecSize) {
|
head_bias += 32 * VecSize) {
|
||||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
const int bias_idx = head_idx * HeadDim + head_bias;
|
||||||
Load<InT, VecSize>(&qkv_now[bias_idx], &src_vec);
|
Load<InT, VecSize>(&qkv_now[bias_idx], &src_vec);
|
||||||
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
|
// Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
|
||||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
// Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||||
// q rope
|
// q rope
|
||||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||||
@@ -1148,10 +1152,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
|||||||
// dequant + add_bias + rope
|
// dequant + add_bias + rope
|
||||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||||
input_left = input_left * out_scale_vec[2 * i] +
|
// input_left = input_left * out_scale_vec[2 * i] +
|
||||||
static_cast<float>(bias_vec[2 * i]);
|
// static_cast<float>(bias_vec[2 * i]);
|
||||||
input_right = input_right * out_scale_vec[2 * i + 1] +
|
// input_right = input_right * out_scale_vec[2 * i + 1] +
|
||||||
static_cast<float>(bias_vec[2 * i + 1]);
|
// static_cast<float>(bias_vec[2 * i + 1]);
|
||||||
const float cos_tmp = cos_emb_vec[i];
|
const float cos_tmp = cos_emb_vec[i];
|
||||||
const float sin_tmp = sin_emb_vec[i];
|
const float sin_tmp = sin_emb_vec[i];
|
||||||
bias_vec[2 * i] =
|
bias_vec[2 * i] =
|
||||||
@@ -1167,6 +1171,35 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
|||||||
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
|
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
|
||||||
const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size;
|
const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size;
|
||||||
|
|
||||||
|
if (block_offset == 0) {
|
||||||
|
// pad zero for this kv_head_idx for this block
|
||||||
|
LoadPadKVT pad_cache_vec;
|
||||||
|
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
|
||||||
|
if (head_idx < num_heads + gqa_group_size) {
|
||||||
|
constexpr int num_vecs_per_head_dim = half_head_size / KV_VEC_SIZE;
|
||||||
|
constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
|
||||||
|
const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) *
|
||||||
|
block_size * half_head_size +
|
||||||
|
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
|
||||||
|
for (int block_i = lane_id / num_vecs_per_head_dim;
|
||||||
|
block_i < block_size;
|
||||||
|
block_i += num_token_each_time) {
|
||||||
|
Store<uint8_t, KV_VEC_SIZE>(
|
||||||
|
pad_cache_vec, &key_cache[tgt_idx + block_i * half_head_size]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const int num_vecs_per_head_dim = half_block_size / KV_VEC_SIZE;
|
||||||
|
const int num_token_each_time = 32 / num_vecs_per_head_dim;
|
||||||
|
const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) *
|
||||||
|
HeadDim * half_block_size +
|
||||||
|
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
|
||||||
|
for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
|
||||||
|
block_i += num_token_each_time) {
|
||||||
|
Store<uint8_t, KV_VEC_SIZE>(
|
||||||
|
pad_cache_vec, &value_cache[tgt_idx + block_i * half_block_size]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
constexpr int K_VEC_SIZE = 4;
|
constexpr int K_VEC_SIZE = 4;
|
||||||
constexpr int HALF_K_VEC_SIZE = 2;
|
constexpr int HALF_K_VEC_SIZE = 2;
|
||||||
using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
|
using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
|
||||||
@@ -1182,7 +1215,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
|||||||
LoadScaleT zp_vec1, zp_vec2;
|
LoadScaleT zp_vec1, zp_vec2;
|
||||||
LoadEmbT cos_emb_vec1, cos_emb_vec2;
|
LoadEmbT cos_emb_vec1, cos_emb_vec2;
|
||||||
LoadEmbT sin_emb_vec1, sin_emb_vec2;
|
LoadEmbT sin_emb_vec1, sin_emb_vec2;
|
||||||
|
#pragma unroll
|
||||||
|
for (int v_i = 0; v_i < HALF_K_VEC_SIZE; v_i++) {
|
||||||
|
bias_vec1[v_i] = 0;
|
||||||
|
bias_vec2[v_i] = 0;
|
||||||
|
}
|
||||||
const InT* qkv_now = quant_qkv + token_id * hidden_size;
|
const InT* qkv_now = quant_qkv + token_id * hidden_size;
|
||||||
const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
|
const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
|
||||||
//////////
|
//////////
|
||||||
@@ -1191,11 +1228,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
|||||||
Load<InT, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
|
Load<InT, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
|
||||||
Load<InT, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
|
Load<InT, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
|
||||||
/////
|
/////
|
||||||
Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx], &bias_vec1);
|
// Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx], &bias_vec1);
|
||||||
Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx + 8], &bias_vec2);
|
// Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx + 8], &bias_vec2);
|
||||||
Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx], &out_scale_vec1);
|
// Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx], &out_scale_vec1);
|
||||||
Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx + 8],
|
// Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx + 8],
|
||||||
&out_scale_vec2);
|
// &out_scale_vec2);
|
||||||
if (head_idx < num_heads + gqa_group_size) {
|
if (head_idx < num_heads + gqa_group_size) {
|
||||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||||
@@ -1215,10 +1252,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
|||||||
|
|
||||||
float input_left = static_cast<float>(src_vec1[0]);
|
float input_left = static_cast<float>(src_vec1[0]);
|
||||||
float input_right = static_cast<float>(src_vec1[1]);
|
float input_right = static_cast<float>(src_vec1[1]);
|
||||||
input_left =
|
// input_left =
|
||||||
input_left * out_scale_vec1[0] + static_cast<float>(bias_vec1[0]);
|
// input_left * out_scale_vec1[0] + static_cast<float>(bias_vec1[0]);
|
||||||
input_right =
|
// input_right =
|
||||||
input_right * out_scale_vec1[1] + static_cast<float>(bias_vec1[1]);
|
// input_right * out_scale_vec1[1] + static_cast<float>(bias_vec1[1]);
|
||||||
if (head_idx < num_heads + gqa_group_size) {
|
if (head_idx < num_heads + gqa_group_size) {
|
||||||
float cos_tmp = cos_emb_vec1[0];
|
float cos_tmp = cos_emb_vec1[0];
|
||||||
float sin_tmp = sin_emb_vec1[0];
|
float sin_tmp = sin_emb_vec1[0];
|
||||||
@@ -1233,10 +1270,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
|||||||
|
|
||||||
input_left = static_cast<float>(src_vec2[0]);
|
input_left = static_cast<float>(src_vec2[0]);
|
||||||
input_right = static_cast<float>(src_vec2[1]);
|
input_right = static_cast<float>(src_vec2[1]);
|
||||||
input_left =
|
// input_left =
|
||||||
input_left * out_scale_vec2[0] + static_cast<float>(bias_vec2[0]);
|
// input_left * out_scale_vec2[0] + static_cast<float>(bias_vec2[0]);
|
||||||
input_right =
|
// input_right =
|
||||||
input_right * out_scale_vec2[1] + static_cast<float>(bias_vec2[1]);
|
// input_right * out_scale_vec2[1] + static_cast<float>(bias_vec2[1]);
|
||||||
if (head_idx < num_heads + gqa_group_size) {
|
if (head_idx < num_heads + gqa_group_size) {
|
||||||
float cos_tmp = cos_emb_vec2[0];
|
float cos_tmp = cos_emb_vec2[0];
|
||||||
float sin_tmp = sin_emb_vec2[0];
|
float sin_tmp = sin_emb_vec2[0];
|
||||||
|
@@ -18,7 +18,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
|
|
||||||
@@ -191,16 +191,25 @@ class AppendAttentionBackend(AttentionBackend):
|
|||||||
def get_kv_cache_shape(
|
def get_kv_cache_shape(
|
||||||
self,
|
self,
|
||||||
max_num_blocks: int,
|
max_num_blocks: int,
|
||||||
) -> Tuple[int, int, int, int]:
|
kv_cache_quant_type: str = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Caculate kv cache shape
|
Caculate kv cache shape
|
||||||
"""
|
"""
|
||||||
return (
|
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
|
||||||
max_num_blocks,
|
return (
|
||||||
self.kv_num_heads,
|
max_num_blocks,
|
||||||
self.block_size,
|
self.kv_num_heads,
|
||||||
self.head_dim,
|
self.block_size,
|
||||||
)
|
self.head_dim // 2,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
max_num_blocks,
|
||||||
|
self.kv_num_heads,
|
||||||
|
self.block_size,
|
||||||
|
self.head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
def forward_mixed(
|
def forward_mixed(
|
||||||
self,
|
self,
|
||||||
|
@@ -116,16 +116,25 @@ class BlockAttentionBackend(AttentionBackend):
|
|||||||
def get_kv_cache_shape(
|
def get_kv_cache_shape(
|
||||||
self,
|
self,
|
||||||
max_num_blocks: int,
|
max_num_blocks: int,
|
||||||
|
kv_cache_quant_type: str = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Caculate kv cache shape
|
Caculate kv cache shape
|
||||||
"""
|
"""
|
||||||
return (
|
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
|
||||||
max_num_blocks,
|
return (
|
||||||
self.kv_num_heads,
|
max_num_blocks,
|
||||||
self.block_size,
|
self.kv_num_heads,
|
||||||
self.head_dim,
|
self.block_size,
|
||||||
)
|
self.head_dim // 2,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
max_num_blocks,
|
||||||
|
self.kv_num_heads,
|
||||||
|
self.block_size,
|
||||||
|
self.head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
def forward_mixed(
|
def forward_mixed(
|
||||||
self,
|
self,
|
||||||
|
@@ -136,16 +136,25 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
def get_kv_cache_shape(
|
def get_kv_cache_shape(
|
||||||
self,
|
self,
|
||||||
max_num_blocks: int,
|
max_num_blocks: int,
|
||||||
|
kv_cache_quant_type: str = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Caculate kv cache shape
|
Caculate kv cache shape
|
||||||
"""
|
"""
|
||||||
return (
|
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
|
||||||
max_num_blocks,
|
return (
|
||||||
self.kv_num_heads,
|
max_num_blocks,
|
||||||
self.block_size,
|
self.kv_num_heads,
|
||||||
self.head_dim,
|
self.block_size,
|
||||||
)
|
self.head_dim // 2,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
max_num_blocks,
|
||||||
|
self.kv_num_heads,
|
||||||
|
self.block_size,
|
||||||
|
self.head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||||
metadata = FlashAttentionMetadata()
|
metadata = FlashAttentionMetadata()
|
||||||
|
@@ -132,6 +132,7 @@ class IluvatarAttnBackend(AttentionBackend):
|
|||||||
def get_kv_cache_shape(
|
def get_kv_cache_shape(
|
||||||
self,
|
self,
|
||||||
max_num_blocks: int,
|
max_num_blocks: int,
|
||||||
|
kv_cache_quant_type: str = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Caculate kv cache shape
|
Caculate kv cache shape
|
||||||
|
@@ -217,14 +217,17 @@ class MLAAttentionBackend(AttentionBackend):
|
|||||||
self.attention_metadata: AttentionMetadata = metadata
|
self.attention_metadata: AttentionMetadata = metadata
|
||||||
|
|
||||||
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
|
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
|
||||||
forward_meta.decoder_tile_ids_per_batch.copy_(
|
forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)
|
||||||
metadata.decoder_tile_ids_per_batch, False)
|
|
||||||
|
|
||||||
def get_attntion_meta(self) -> AttentionMetadata:
|
def get_attntion_meta(self) -> AttentionMetadata:
|
||||||
"""get_attntion_meta"""
|
"""get_attntion_meta"""
|
||||||
return self.attention_metadata
|
return self.attention_metadata
|
||||||
|
|
||||||
def get_kv_cache_shape(self, max_num_blocks: int) -> Tuple[int, int, int, int]:
|
def get_kv_cache_shape(
|
||||||
|
self,
|
||||||
|
max_num_blocks: int,
|
||||||
|
kv_cache_quant_type: str = None,
|
||||||
|
) -> Tuple[int, int, int, int]:
|
||||||
"""
|
"""
|
||||||
Calculate kv cache shape for MLA
|
Calculate kv cache shape for MLA
|
||||||
"""
|
"""
|
||||||
|
@@ -146,6 +146,7 @@ class XPUAttentionBackend(AttentionBackend):
|
|||||||
def get_kv_cache_shape(
|
def get_kv_cache_shape(
|
||||||
self,
|
self,
|
||||||
max_num_blocks: int,
|
max_num_blocks: int,
|
||||||
|
kv_cache_quant_type: str = None,
|
||||||
) -> Tuple[int, int, int, int]:
|
) -> Tuple[int, int, int, int]:
|
||||||
"""
|
"""
|
||||||
Caculate kv cache shape
|
Caculate kv cache shape
|
||||||
|
@@ -211,6 +211,7 @@ class GCUFlashAttnBackend(AttentionBackend):
|
|||||||
def get_kv_cache_shape(
|
def get_kv_cache_shape(
|
||||||
self,
|
self,
|
||||||
max_num_blocks: int,
|
max_num_blocks: int,
|
||||||
|
kv_cache_quant_type: str = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Caculate kv cache shape
|
Caculate kv cache shape
|
||||||
|
@@ -222,6 +222,7 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
|
|||||||
def get_kv_cache_shape(
|
def get_kv_cache_shape(
|
||||||
self,
|
self,
|
||||||
max_num_blocks: int,
|
max_num_blocks: int,
|
||||||
|
kv_cache_quant_type: str = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Caculate kv cache shape
|
Caculate kv cache shape
|
||||||
|
@@ -34,6 +34,7 @@ class KvCacheQuantzationTypes(str, Enum):
|
|||||||
INT8 = "int8"
|
INT8 = "int8"
|
||||||
FP8 = "float8_e4m3fn"
|
FP8 = "float8_e4m3fn"
|
||||||
INT8_ZP = "int8_zp"
|
INT8_ZP = "int8_zp"
|
||||||
|
INT4_ZP = "int4_zp"
|
||||||
FP8_ZP = "float8_e4m3fn_zp"
|
FP8_ZP = "float8_e4m3fn_zp"
|
||||||
|
|
||||||
|
|
||||||
@@ -42,24 +43,29 @@ class KvCacheQuantConfig(QuantConfigBase):
|
|||||||
quantization config for weight 4bits and activation fp8
|
quantization config for weight 4bits and activation fp8
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, kv_cache_quant_type: str) -> None:
|
def __init__(self, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_point: bool) -> None:
|
||||||
"""
|
"""
|
||||||
__init__
|
__init__
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.kv_cache_quant_type = kv_cache_quant_type
|
self.kv_cache_quant_type = kv_cache_quant_type
|
||||||
|
self.is_channel_wise = is_channel_wise
|
||||||
|
self.has_zero_point = has_zero_point
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.quant_type = KvCacheQuantzationTypes(kv_cache_quant_type)
|
self.quant_type = KvCacheQuantzationTypes(kv_cache_quant_type)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")
|
raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")
|
||||||
|
|
||||||
self.has_zero_point = "zp" in kv_cache_quant_type
|
if "zp" in kv_cache_quant_type:
|
||||||
|
self.has_zero_point = True
|
||||||
|
|
||||||
if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP:
|
if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP:
|
||||||
self.max_bound = 127.0
|
self.max_bound = 127.0
|
||||||
elif self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP:
|
elif self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP:
|
||||||
self.max_bound = 448.0
|
self.max_bound = 448.0
|
||||||
|
elif self.quant_type == KvCacheQuantzationTypes.INT4_ZP:
|
||||||
|
self.max_bound = 7.0
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")
|
raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")
|
||||||
|
|
||||||
@@ -70,11 +76,13 @@ class KvCacheQuantConfig(QuantConfigBase):
|
|||||||
return "kvcache"
|
return "kvcache"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, kv_cache_quant_type: str) -> "KvCacheQuantConfig":
|
def from_config(
|
||||||
|
cls, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_point: bool
|
||||||
|
) -> "KvCacheQuantConfig":
|
||||||
"""
|
"""
|
||||||
from_config
|
from_config
|
||||||
"""
|
"""
|
||||||
return cls(kv_cache_quant_type)
|
return cls(kv_cache_quant_type, is_channel_wise, has_zero_point)
|
||||||
|
|
||||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||||
"""
|
"""
|
||||||
@@ -102,8 +110,8 @@ class KVCacheMethodBase(QuantMethodBase):
|
|||||||
"""
|
"""
|
||||||
load_zp
|
load_zp
|
||||||
"""
|
"""
|
||||||
cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name))
|
cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name)).cast(paddle.get_default_dtype())
|
||||||
cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name))
|
cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name)).cast(paddle.get_default_dtype())
|
||||||
|
|
||||||
create_and_set_parameter(layer, "cache_k_zp", cache_k_zeropoint)
|
create_and_set_parameter(layer, "cache_k_zp", cache_k_zeropoint)
|
||||||
create_and_set_parameter(layer, "cache_v_zp", cache_v_zeropoint)
|
create_and_set_parameter(layer, "cache_v_zp", cache_v_zeropoint)
|
||||||
@@ -112,17 +120,36 @@ class KVCacheMethodBase(QuantMethodBase):
|
|||||||
"""
|
"""
|
||||||
load_scale
|
load_scale
|
||||||
"""
|
"""
|
||||||
cache_k_scale_tensor = (
|
|
||||||
get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
|
|
||||||
)
|
|
||||||
cache_v_scale_tensor = (
|
|
||||||
get_tensor(state_dict.pop(self.cache_v_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
|
|
||||||
)
|
|
||||||
|
|
||||||
cache_k_scale = self.cache_quant_config.max_bound / cache_k_scale_tensor
|
if self.cache_quant_config.is_channel_wise:
|
||||||
cache_v_scale = self.cache_quant_config.max_bound / cache_v_scale_tensor
|
cache_k_scale_tensor = (
|
||||||
cache_k_out_scale = cache_k_scale_tensor / self.cache_quant_config.max_bound
|
get_tensor(state_dict.pop(self.cache_k_scale_name))
|
||||||
cache_v_out_scale = cache_v_scale_tensor / self.cache_quant_config.max_bound
|
.cast(paddle.get_default_dtype())
|
||||||
|
.reshape_([-1, layer.head_dim])
|
||||||
|
)
|
||||||
|
cache_v_scale_tensor = (
|
||||||
|
get_tensor(state_dict.pop(self.cache_v_scale_name))
|
||||||
|
.cast(paddle.get_default_dtype())
|
||||||
|
.reshape_([-1, layer.head_dim])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cache_k_scale_tensor = (
|
||||||
|
get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
|
||||||
|
)
|
||||||
|
cache_v_scale_tensor = (
|
||||||
|
get_tensor(state_dict.pop(self.cache_v_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.cache_quant_config.has_zero_point: # cache_int4_zp
|
||||||
|
cache_k_scale = 1.0 / cache_k_scale_tensor
|
||||||
|
cache_v_scale = 1.0 / cache_v_scale_tensor
|
||||||
|
cache_k_out_scale = cache_k_scale_tensor
|
||||||
|
cache_v_out_scale = cache_v_scale_tensor
|
||||||
|
else:
|
||||||
|
cache_k_scale = self.cache_quant_config.max_bound / cache_k_scale_tensor
|
||||||
|
cache_v_scale = self.cache_quant_config.max_bound / cache_v_scale_tensor
|
||||||
|
cache_k_out_scale = cache_k_scale_tensor / self.cache_quant_config.max_bound
|
||||||
|
cache_v_out_scale = cache_v_scale_tensor / self.cache_quant_config.max_bound
|
||||||
|
|
||||||
create_and_set_parameter(layer, "cache_k_scale", cache_k_scale)
|
create_and_set_parameter(layer, "cache_k_scale", cache_k_scale)
|
||||||
create_and_set_parameter(layer, "cache_v_scale", cache_v_scale)
|
create_and_set_parameter(layer, "cache_v_scale", cache_v_scale)
|
||||||
@@ -147,6 +174,10 @@ class KVCacheMethodBase(QuantMethodBase):
|
|||||||
layer.cache_quant_type_str = "cache_fp8"
|
layer.cache_quant_type_str = "cache_fp8"
|
||||||
layer.quant_max_bound = 448.0
|
layer.quant_max_bound = 448.0
|
||||||
layer.quant_min_bound = -448.0
|
layer.quant_min_bound = -448.0
|
||||||
|
elif self.cache_quant_config.quant_type == KvCacheQuantzationTypes.INT4_ZP:
|
||||||
|
layer.cache_quant_type_str = "cache_int4_zp"
|
||||||
|
layer.quant_max_bound = 7.0
|
||||||
|
layer.quant_min_bound = -7.0
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")
|
raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")
|
||||||
|
|
||||||
|
@@ -34,6 +34,8 @@ class MixQuantConfig(QuantConfigBase):
|
|||||||
moe_quant_type: str,
|
moe_quant_type: str,
|
||||||
kv_cache_quant_type: str = None,
|
kv_cache_quant_type: str = None,
|
||||||
image_moe_quant_type: str = None,
|
image_moe_quant_type: str = None,
|
||||||
|
is_channel_wise: bool = False,
|
||||||
|
has_zero_point: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dense_quant_type = dense_quant_type
|
self.dense_quant_type = dense_quant_type
|
||||||
@@ -43,6 +45,8 @@ class MixQuantConfig(QuantConfigBase):
|
|||||||
self.image_moe_quant_type = moe_quant_type
|
self.image_moe_quant_type = moe_quant_type
|
||||||
else:
|
else:
|
||||||
self.image_moe_quant_type = image_moe_quant_type
|
self.image_moe_quant_type = image_moe_quant_type
|
||||||
|
self.is_channel_wise = is_channel_wise
|
||||||
|
self.has_zero_point = has_zero_point
|
||||||
self.quant_max_bound = 0
|
self.quant_max_bound = 0
|
||||||
self.quant_min_bound = 0
|
self.quant_min_bound = 0
|
||||||
self.quant_round_type = 0
|
self.quant_round_type = 0
|
||||||
@@ -57,6 +61,8 @@ class MixQuantConfig(QuantConfigBase):
|
|||||||
config["moe_quant_type"],
|
config["moe_quant_type"],
|
||||||
config.get("kv_cache_quant_type", None),
|
config.get("kv_cache_quant_type", None),
|
||||||
config.get("image_moe_quant_type", None),
|
config.get("image_moe_quant_type", None),
|
||||||
|
config.get("is_channel_wise", False),
|
||||||
|
config.get("has_zero_point", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
|
||||||
@@ -67,7 +73,11 @@ class MixQuantConfig(QuantConfigBase):
|
|||||||
return get_quantization_config(self.moe_quant_type).from_config({}).get_quant_method(layer)
|
return get_quantization_config(self.moe_quant_type).from_config({}).get_quant_method(layer)
|
||||||
elif isinstance(layer, Attention):
|
elif isinstance(layer, Attention):
|
||||||
if self.kv_cache_quant_type is not None:
|
if self.kv_cache_quant_type is not None:
|
||||||
return get_quantization_config("kvcache").from_config(self.kv_cache_quant_type).get_quant_method(layer)
|
return (
|
||||||
|
get_quantization_config("kvcache")
|
||||||
|
.from_config(self.kv_cache_quant_type, self.is_channel_wise, self.has_zero_point)
|
||||||
|
.get_quant_method(layer)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
|
@@ -127,15 +127,19 @@ class MTPProposer(Proposer):
|
|||||||
|
|
||||||
cache_type = self.parallel_config.dtype
|
cache_type = self.parallel_config.dtype
|
||||||
|
|
||||||
|
kv_cache_quant_type = None
|
||||||
if (
|
if (
|
||||||
self.quant_config
|
self.quant_config
|
||||||
and hasattr(self.quant_config, "kv_cache_quant_type")
|
and hasattr(self.quant_config, "kv_cache_quant_type")
|
||||||
and self.quant_config.kv_cache_quant_type is not None
|
and self.quant_config.kv_cache_quant_type is not None
|
||||||
):
|
):
|
||||||
cache_type = "uint8"
|
cache_type = "uint8"
|
||||||
|
kv_cache_quant_type = self.quant_config.kv_cache_quant_type
|
||||||
|
|
||||||
# Get kv cache shape
|
# Get kv cache shape
|
||||||
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=self.num_gpu_blocks)
|
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
|
||||||
|
max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
|
||||||
|
)
|
||||||
if not self.parallel_config.do_profile and (
|
if not self.parallel_config.do_profile and (
|
||||||
self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
|
self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
|
||||||
):
|
):
|
||||||
|
@@ -568,15 +568,19 @@ class GCUModelRunner(ModelRunnerBase):
|
|||||||
# Get kv cache dtype
|
# Get kv cache dtype
|
||||||
cache_type = self.parallel_config.dtype
|
cache_type = self.parallel_config.dtype
|
||||||
|
|
||||||
|
kv_cache_quant_type = None
|
||||||
if (
|
if (
|
||||||
self.quant_config
|
self.quant_config
|
||||||
and hasattr(self.quant_config, "kv_cache_quant_type")
|
and hasattr(self.quant_config, "kv_cache_quant_type")
|
||||||
and self.quant_config.kv_cache_quant_type is not None
|
and self.quant_config.kv_cache_quant_type is not None
|
||||||
):
|
):
|
||||||
cache_type = "uint8"
|
cache_type = "uint8"
|
||||||
|
kv_cache_quant_type = self.quant_config.kv_cache_quant_type
|
||||||
|
|
||||||
# Get kv cache shape
|
# Get kv cache shape
|
||||||
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
|
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
|
||||||
|
max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
|
||||||
|
)
|
||||||
# local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
|
# local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
|
||||||
|
|
||||||
if not profile and (
|
if not profile and (
|
||||||
|
@@ -810,15 +810,19 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
# Get kv cache dtype
|
# Get kv cache dtype
|
||||||
cache_type = self.parallel_config.dtype
|
cache_type = self.parallel_config.dtype
|
||||||
|
|
||||||
|
kv_cache_quant_type = None
|
||||||
if (
|
if (
|
||||||
self.quant_config
|
self.quant_config
|
||||||
and hasattr(self.quant_config, "kv_cache_quant_type")
|
and hasattr(self.quant_config, "kv_cache_quant_type")
|
||||||
and self.quant_config.kv_cache_quant_type is not None
|
and self.quant_config.kv_cache_quant_type is not None
|
||||||
):
|
):
|
||||||
cache_type = "uint8"
|
cache_type = "uint8"
|
||||||
|
kv_cache_quant_type = self.quant_config.kv_cache_quant_type
|
||||||
|
|
||||||
# Get kv cache shape
|
# Get kv cache shape
|
||||||
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
|
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
|
||||||
|
max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
|
||||||
|
)
|
||||||
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
|
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
|
||||||
|
|
||||||
if not profile and (
|
if not profile and (
|
||||||
|
@@ -559,15 +559,19 @@ class IluvatarModelRunner(ModelRunnerBase):
|
|||||||
# Get kv cache dtype
|
# Get kv cache dtype
|
||||||
cache_type = self.parallel_config.dtype
|
cache_type = self.parallel_config.dtype
|
||||||
|
|
||||||
|
kv_cache_quant_type = None
|
||||||
if (
|
if (
|
||||||
self.quant_config
|
self.quant_config
|
||||||
and hasattr(self.quant_config, "kv_cache_quant_type")
|
and hasattr(self.quant_config, "kv_cache_quant_type")
|
||||||
and self.quant_config.kv_cache_quant_type is not None
|
and self.quant_config.kv_cache_quant_type is not None
|
||||||
):
|
):
|
||||||
cache_type = "uint8"
|
cache_type = "uint8"
|
||||||
|
kv_cache_quant_type = self.quant_config.kv_cache_quant_type
|
||||||
|
|
||||||
# Get kv cache shape
|
# Get kv cache shape
|
||||||
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
|
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
|
||||||
|
max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
|
||||||
|
)
|
||||||
|
|
||||||
if not self.parallel_config.do_profile and (
|
if not self.parallel_config.do_profile and (
|
||||||
self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
|
self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
|
||||||
|
@@ -520,14 +520,19 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
|
|
||||||
cache_type = self.parallel_config.dtype
|
cache_type = self.parallel_config.dtype
|
||||||
|
|
||||||
|
kv_cache_quant_type = None
|
||||||
if (
|
if (
|
||||||
self.quant_config
|
self.quant_config
|
||||||
and hasattr(self.quant_config, "kv_cache_quant_type")
|
and hasattr(self.quant_config, "kv_cache_quant_type")
|
||||||
and self.quant_config.kv_cache_quant_type is not None
|
and self.quant_config.kv_cache_quant_type is not None
|
||||||
):
|
):
|
||||||
cache_type = "uint8"
|
cache_type = "uint8"
|
||||||
|
kv_cache_quant_type = self.quant_config.kv_cache_quant_type
|
||||||
|
|
||||||
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
|
# Get kv cache shape
|
||||||
|
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
|
||||||
|
max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(self.model_config.num_hidden_layers):
|
for i in range(self.model_config.num_hidden_layers):
|
||||||
cache_kvs[f"key_caches_{i}"] = paddle.full(
|
cache_kvs[f"key_caches_{i}"] = paddle.full(
|
||||||
|
Reference in New Issue
Block a user