support c4 attn && fix cache

This commit is contained in:
lizhenyun01
2025-07-23 23:51:28 +08:00
parent 832d25334a
commit 29c3292f02
16 changed files with 198 additions and 65 deletions

View File

@@ -1130,6 +1130,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
LoadOutScaleT out_scale_vec; LoadOutScaleT out_scale_vec;
LoadEmbT cos_emb_vec; LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec; LoadEmbT sin_emb_vec;
#pragma unroll
for (int v_i = 0; v_i < VecSize; v_i++) {
bias_vec[v_i] = 0;
}
const InT* qkv_now = quant_qkv + token_id * hidden_size; const InT* qkv_now = quant_qkv + token_id * hidden_size;
T* qkv_out_now = qkv_out + token_id * hidden_size; T* qkv_out_now = qkv_out + token_id * hidden_size;
#pragma unroll #pragma unroll
@@ -1137,8 +1141,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
head_bias += 32 * VecSize) { head_bias += 32 * VecSize) {
const int bias_idx = head_idx * HeadDim + head_bias; const int bias_idx = head_idx * HeadDim + head_bias;
Load<InT, VecSize>(&qkv_now[bias_idx], &src_vec); Load<InT, VecSize>(&qkv_now[bias_idx], &src_vec);
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec); // Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec); // Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
// q rope // q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec); Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
@@ -1148,10 +1152,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
// dequant + add_bias + rope // dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]); float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]); float input_right = static_cast<float>(src_vec[2 * i + 1]);
input_left = input_left * out_scale_vec[2 * i] + // input_left = input_left * out_scale_vec[2 * i] +
static_cast<float>(bias_vec[2 * i]); // static_cast<float>(bias_vec[2 * i]);
input_right = input_right * out_scale_vec[2 * i + 1] + // input_right = input_right * out_scale_vec[2 * i + 1] +
static_cast<float>(bias_vec[2 * i + 1]); // static_cast<float>(bias_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i]; const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i]; const float sin_tmp = sin_emb_vec[i];
bias_vec[2 * i] = bias_vec[2 * i] =
@@ -1167,6 +1171,35 @@ __global__ void append_speculate_cache_int4_rope_kernel(
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>; using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size; const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size;
if (block_offset == 0) {
// pad zero for this kv_head_idx for this block
LoadPadKVT pad_cache_vec;
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
if (head_idx < num_heads + gqa_group_size) {
constexpr int num_vecs_per_head_dim = half_head_size / KV_VEC_SIZE;
constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) *
block_size * half_head_size +
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
for (int block_i = lane_id / num_vecs_per_head_dim;
block_i < block_size;
block_i += num_token_each_time) {
Store<uint8_t, KV_VEC_SIZE>(
pad_cache_vec, &key_cache[tgt_idx + block_i * half_head_size]);
}
} else {
const int num_vecs_per_head_dim = half_block_size / KV_VEC_SIZE;
const int num_token_each_time = 32 / num_vecs_per_head_dim;
const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) *
HeadDim * half_block_size +
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
block_i += num_token_each_time) {
Store<uint8_t, KV_VEC_SIZE>(
pad_cache_vec, &value_cache[tgt_idx + block_i * half_block_size]);
}
}
}
constexpr int K_VEC_SIZE = 4; constexpr int K_VEC_SIZE = 4;
constexpr int HALF_K_VEC_SIZE = 2; constexpr int HALF_K_VEC_SIZE = 2;
using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>; using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
@@ -1182,7 +1215,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
LoadScaleT zp_vec1, zp_vec2; LoadScaleT zp_vec1, zp_vec2;
LoadEmbT cos_emb_vec1, cos_emb_vec2; LoadEmbT cos_emb_vec1, cos_emb_vec2;
LoadEmbT sin_emb_vec1, sin_emb_vec2; LoadEmbT sin_emb_vec1, sin_emb_vec2;
#pragma unroll
for (int v_i = 0; v_i < HALF_K_VEC_SIZE; v_i++) {
bias_vec1[v_i] = 0;
bias_vec2[v_i] = 0;
}
const InT* qkv_now = quant_qkv + token_id * hidden_size; const InT* qkv_now = quant_qkv + token_id * hidden_size;
const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2; const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
////////// //////////
@@ -1191,11 +1228,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
Load<InT, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1); Load<InT, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
Load<InT, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2); Load<InT, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
///// /////
Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx], &bias_vec1); // Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx], &bias_vec1);
Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx + 8], &bias_vec2); // Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx + 8], &bias_vec2);
Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx], &out_scale_vec1); // Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx], &out_scale_vec1);
Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx + 8], // Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx + 8],
&out_scale_vec2); // &out_scale_vec2);
if (head_idx < num_heads + gqa_group_size) { if (head_idx < num_heads + gqa_group_size) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1); Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
@@ -1215,10 +1252,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
float input_left = static_cast<float>(src_vec1[0]); float input_left = static_cast<float>(src_vec1[0]);
float input_right = static_cast<float>(src_vec1[1]); float input_right = static_cast<float>(src_vec1[1]);
input_left = // input_left =
input_left * out_scale_vec1[0] + static_cast<float>(bias_vec1[0]); // input_left * out_scale_vec1[0] + static_cast<float>(bias_vec1[0]);
input_right = // input_right =
input_right * out_scale_vec1[1] + static_cast<float>(bias_vec1[1]); // input_right * out_scale_vec1[1] + static_cast<float>(bias_vec1[1]);
if (head_idx < num_heads + gqa_group_size) { if (head_idx < num_heads + gqa_group_size) {
float cos_tmp = cos_emb_vec1[0]; float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0];
@@ -1233,10 +1270,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
input_left = static_cast<float>(src_vec2[0]); input_left = static_cast<float>(src_vec2[0]);
input_right = static_cast<float>(src_vec2[1]); input_right = static_cast<float>(src_vec2[1]);
input_left = // input_left =
input_left * out_scale_vec2[0] + static_cast<float>(bias_vec2[0]); // input_left * out_scale_vec2[0] + static_cast<float>(bias_vec2[0]);
input_right = // input_right =
input_right * out_scale_vec2[1] + static_cast<float>(bias_vec2[1]); // input_right * out_scale_vec2[1] + static_cast<float>(bias_vec2[1]);
if (head_idx < num_heads + gqa_group_size) { if (head_idx < num_heads + gqa_group_size) {
float cos_tmp = cos_emb_vec2[0]; float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0];

View File

@@ -18,7 +18,7 @@ from __future__ import annotations
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Tuple from typing import TYPE_CHECKING, List, Optional
import paddle import paddle
@@ -191,16 +191,25 @@ class AppendAttentionBackend(AttentionBackend):
def get_kv_cache_shape( def get_kv_cache_shape(
self, self,
max_num_blocks: int, max_num_blocks: int,
) -> Tuple[int, int, int, int]: kv_cache_quant_type: str = None,
):
""" """
Caculate kv cache shape Caculate kv cache shape
""" """
return ( if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
max_num_blocks, return (
self.kv_num_heads, max_num_blocks,
self.block_size, self.kv_num_heads,
self.head_dim, self.block_size,
) self.head_dim // 2,
)
else:
return (
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim,
)
def forward_mixed( def forward_mixed(
self, self,

View File

@@ -116,16 +116,25 @@ class BlockAttentionBackend(AttentionBackend):
def get_kv_cache_shape( def get_kv_cache_shape(
self, self,
max_num_blocks: int, max_num_blocks: int,
kv_cache_quant_type: str = None,
): ):
""" """
Caculate kv cache shape Caculate kv cache shape
""" """
return ( if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
max_num_blocks, return (
self.kv_num_heads, max_num_blocks,
self.block_size, self.kv_num_heads,
self.head_dim, self.block_size,
) self.head_dim // 2,
)
else:
return (
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim,
)
def forward_mixed( def forward_mixed(
self, self,

View File

@@ -136,16 +136,25 @@ class FlashAttentionBackend(AttentionBackend):
def get_kv_cache_shape( def get_kv_cache_shape(
self, self,
max_num_blocks: int, max_num_blocks: int,
kv_cache_quant_type: str = None,
): ):
""" """
Caculate kv cache shape Caculate kv cache shape
""" """
return ( if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
max_num_blocks, return (
self.kv_num_heads, max_num_blocks,
self.block_size, self.kv_num_heads,
self.head_dim, self.block_size,
) self.head_dim // 2,
)
else:
return (
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim,
)
def init_attention_metadata(self, forward_meta: ForwardMeta): def init_attention_metadata(self, forward_meta: ForwardMeta):
metadata = FlashAttentionMetadata() metadata = FlashAttentionMetadata()

View File

@@ -132,6 +132,7 @@ class IluvatarAttnBackend(AttentionBackend):
def get_kv_cache_shape( def get_kv_cache_shape(
self, self,
max_num_blocks: int, max_num_blocks: int,
kv_cache_quant_type: str = None,
): ):
""" """
Caculate kv cache shape Caculate kv cache shape

View File

@@ -217,14 +217,17 @@ class MLAAttentionBackend(AttentionBackend):
self.attention_metadata: AttentionMetadata = metadata self.attention_metadata: AttentionMetadata = metadata
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False) forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
forward_meta.decoder_tile_ids_per_batch.copy_( forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)
metadata.decoder_tile_ids_per_batch, False)
def get_attntion_meta(self) -> AttentionMetadata: def get_attntion_meta(self) -> AttentionMetadata:
"""get_attntion_meta""" """get_attntion_meta"""
return self.attention_metadata return self.attention_metadata
def get_kv_cache_shape(self, max_num_blocks: int) -> Tuple[int, int, int, int]: def get_kv_cache_shape(
self,
max_num_blocks: int,
kv_cache_quant_type: str = None,
) -> Tuple[int, int, int, int]:
""" """
Calculate kv cache shape for MLA Calculate kv cache shape for MLA
""" """

View File

@@ -146,6 +146,7 @@ class XPUAttentionBackend(AttentionBackend):
def get_kv_cache_shape( def get_kv_cache_shape(
self, self,
max_num_blocks: int, max_num_blocks: int,
kv_cache_quant_type: str = None,
) -> Tuple[int, int, int, int]: ) -> Tuple[int, int, int, int]:
""" """
Caculate kv cache shape Caculate kv cache shape

View File

@@ -211,6 +211,7 @@ class GCUFlashAttnBackend(AttentionBackend):
def get_kv_cache_shape( def get_kv_cache_shape(
self, self,
max_num_blocks: int, max_num_blocks: int,
kv_cache_quant_type: str = None,
): ):
""" """
Caculate kv cache shape Caculate kv cache shape

View File

@@ -222,6 +222,7 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
def get_kv_cache_shape( def get_kv_cache_shape(
self, self,
max_num_blocks: int, max_num_blocks: int,
kv_cache_quant_type: str = None,
): ):
""" """
Caculate kv cache shape Caculate kv cache shape

View File

@@ -34,6 +34,7 @@ class KvCacheQuantzationTypes(str, Enum):
INT8 = "int8" INT8 = "int8"
FP8 = "float8_e4m3fn" FP8 = "float8_e4m3fn"
INT8_ZP = "int8_zp" INT8_ZP = "int8_zp"
INT4_ZP = "int4_zp"
FP8_ZP = "float8_e4m3fn_zp" FP8_ZP = "float8_e4m3fn_zp"
@@ -42,24 +43,29 @@ class KvCacheQuantConfig(QuantConfigBase):
quantization config for weight 4bits and activation fp8 quantization config for weight 4bits and activation fp8
""" """
def __init__(self, kv_cache_quant_type: str) -> None: def __init__(self, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_point: bool) -> None:
""" """
__init__ __init__
""" """
super().__init__() super().__init__()
self.kv_cache_quant_type = kv_cache_quant_type self.kv_cache_quant_type = kv_cache_quant_type
self.is_channel_wise = is_channel_wise
self.has_zero_point = has_zero_point
try: try:
self.quant_type = KvCacheQuantzationTypes(kv_cache_quant_type) self.quant_type = KvCacheQuantzationTypes(kv_cache_quant_type)
except ValueError: except ValueError:
raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}") raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")
self.has_zero_point = "zp" in kv_cache_quant_type if "zp" in kv_cache_quant_type:
self.has_zero_point = True
if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP: if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP:
self.max_bound = 127.0 self.max_bound = 127.0
elif self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP: elif self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP:
self.max_bound = 448.0 self.max_bound = 448.0
elif self.quant_type == KvCacheQuantzationTypes.INT4_ZP:
self.max_bound = 7.0
else: else:
raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}") raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")
@@ -70,11 +76,13 @@ class KvCacheQuantConfig(QuantConfigBase):
return "kvcache" return "kvcache"
@classmethod @classmethod
def from_config(cls, kv_cache_quant_type: str) -> "KvCacheQuantConfig": def from_config(
cls, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_point: bool
) -> "KvCacheQuantConfig":
""" """
from_config from_config
""" """
return cls(kv_cache_quant_type) return cls(kv_cache_quant_type, is_channel_wise, has_zero_point)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]: def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
""" """
@@ -102,8 +110,8 @@ class KVCacheMethodBase(QuantMethodBase):
""" """
load_zp load_zp
""" """
cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name)) cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name)).cast(paddle.get_default_dtype())
cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name)) cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name)).cast(paddle.get_default_dtype())
create_and_set_parameter(layer, "cache_k_zp", cache_k_zeropoint) create_and_set_parameter(layer, "cache_k_zp", cache_k_zeropoint)
create_and_set_parameter(layer, "cache_v_zp", cache_v_zeropoint) create_and_set_parameter(layer, "cache_v_zp", cache_v_zeropoint)
@@ -112,17 +120,36 @@ class KVCacheMethodBase(QuantMethodBase):
""" """
load_scale load_scale
""" """
cache_k_scale_tensor = (
get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
)
cache_v_scale_tensor = (
get_tensor(state_dict.pop(self.cache_v_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
)
cache_k_scale = self.cache_quant_config.max_bound / cache_k_scale_tensor if self.cache_quant_config.is_channel_wise:
cache_v_scale = self.cache_quant_config.max_bound / cache_v_scale_tensor cache_k_scale_tensor = (
cache_k_out_scale = cache_k_scale_tensor / self.cache_quant_config.max_bound get_tensor(state_dict.pop(self.cache_k_scale_name))
cache_v_out_scale = cache_v_scale_tensor / self.cache_quant_config.max_bound .cast(paddle.get_default_dtype())
.reshape_([-1, layer.head_dim])
)
cache_v_scale_tensor = (
get_tensor(state_dict.pop(self.cache_v_scale_name))
.cast(paddle.get_default_dtype())
.reshape_([-1, layer.head_dim])
)
else:
cache_k_scale_tensor = (
get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
)
cache_v_scale_tensor = (
get_tensor(state_dict.pop(self.cache_v_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
)
if self.cache_quant_config.has_zero_point: # cache_int4_zp
cache_k_scale = 1.0 / cache_k_scale_tensor
cache_v_scale = 1.0 / cache_v_scale_tensor
cache_k_out_scale = cache_k_scale_tensor
cache_v_out_scale = cache_v_scale_tensor
else:
cache_k_scale = self.cache_quant_config.max_bound / cache_k_scale_tensor
cache_v_scale = self.cache_quant_config.max_bound / cache_v_scale_tensor
cache_k_out_scale = cache_k_scale_tensor / self.cache_quant_config.max_bound
cache_v_out_scale = cache_v_scale_tensor / self.cache_quant_config.max_bound
create_and_set_parameter(layer, "cache_k_scale", cache_k_scale) create_and_set_parameter(layer, "cache_k_scale", cache_k_scale)
create_and_set_parameter(layer, "cache_v_scale", cache_v_scale) create_and_set_parameter(layer, "cache_v_scale", cache_v_scale)
@@ -147,6 +174,10 @@ class KVCacheMethodBase(QuantMethodBase):
layer.cache_quant_type_str = "cache_fp8" layer.cache_quant_type_str = "cache_fp8"
layer.quant_max_bound = 448.0 layer.quant_max_bound = 448.0
layer.quant_min_bound = -448.0 layer.quant_min_bound = -448.0
elif self.cache_quant_config.quant_type == KvCacheQuantzationTypes.INT4_ZP:
layer.cache_quant_type_str = "cache_int4_zp"
layer.quant_max_bound = 7.0
layer.quant_min_bound = -7.0
else: else:
raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented") raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")

View File

@@ -34,6 +34,8 @@ class MixQuantConfig(QuantConfigBase):
moe_quant_type: str, moe_quant_type: str,
kv_cache_quant_type: str = None, kv_cache_quant_type: str = None,
image_moe_quant_type: str = None, image_moe_quant_type: str = None,
is_channel_wise: bool = False,
has_zero_point: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.dense_quant_type = dense_quant_type self.dense_quant_type = dense_quant_type
@@ -43,6 +45,8 @@ class MixQuantConfig(QuantConfigBase):
self.image_moe_quant_type = moe_quant_type self.image_moe_quant_type = moe_quant_type
else: else:
self.image_moe_quant_type = image_moe_quant_type self.image_moe_quant_type = image_moe_quant_type
self.is_channel_wise = is_channel_wise
self.has_zero_point = has_zero_point
self.quant_max_bound = 0 self.quant_max_bound = 0
self.quant_min_bound = 0 self.quant_min_bound = 0
self.quant_round_type = 0 self.quant_round_type = 0
@@ -57,6 +61,8 @@ class MixQuantConfig(QuantConfigBase):
config["moe_quant_type"], config["moe_quant_type"],
config.get("kv_cache_quant_type", None), config.get("kv_cache_quant_type", None),
config.get("image_moe_quant_type", None), config.get("image_moe_quant_type", None),
config.get("is_channel_wise", False),
config.get("has_zero_point", False),
) )
def get_quant_method(self, layer) -> Optional[QuantMethodBase]: def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -67,7 +73,11 @@ class MixQuantConfig(QuantConfigBase):
return get_quantization_config(self.moe_quant_type).from_config({}).get_quant_method(layer) return get_quantization_config(self.moe_quant_type).from_config({}).get_quant_method(layer)
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
if self.kv_cache_quant_type is not None: if self.kv_cache_quant_type is not None:
return get_quantization_config("kvcache").from_config(self.kv_cache_quant_type).get_quant_method(layer) return (
get_quantization_config("kvcache")
.from_config(self.kv_cache_quant_type, self.is_channel_wise, self.has_zero_point)
.get_quant_method(layer)
)
else: else:
return None return None
else: else:

View File

@@ -127,15 +127,19 @@ class MTPProposer(Proposer):
cache_type = self.parallel_config.dtype cache_type = self.parallel_config.dtype
kv_cache_quant_type = None
if ( if (
self.quant_config self.quant_config
and hasattr(self.quant_config, "kv_cache_quant_type") and hasattr(self.quant_config, "kv_cache_quant_type")
and self.quant_config.kv_cache_quant_type is not None and self.quant_config.kv_cache_quant_type is not None
): ):
cache_type = "uint8" cache_type = "uint8"
kv_cache_quant_type = self.quant_config.kv_cache_quant_type
# Get kv cache shape # Get kv cache shape
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=self.num_gpu_blocks) kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
)
if not self.parallel_config.do_profile and ( if not self.parallel_config.do_profile and (
self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed" self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
): ):

View File

@@ -568,15 +568,19 @@ class GCUModelRunner(ModelRunnerBase):
# Get kv cache dtype # Get kv cache dtype
cache_type = self.parallel_config.dtype cache_type = self.parallel_config.dtype
kv_cache_quant_type = None
if ( if (
self.quant_config self.quant_config
and hasattr(self.quant_config, "kv_cache_quant_type") and hasattr(self.quant_config, "kv_cache_quant_type")
and self.quant_config.kv_cache_quant_type is not None and self.quant_config.kv_cache_quant_type is not None
): ):
cache_type = "uint8" cache_type = "uint8"
kv_cache_quant_type = self.quant_config.kv_cache_quant_type
# Get kv cache shape # Get kv cache shape
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num) kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
)
# local_rank = self.local_rank % self.parallel_config.tensor_parallel_size # local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
if not profile and ( if not profile and (

View File

@@ -810,15 +810,19 @@ class GPUModelRunner(ModelRunnerBase):
# Get kv cache dtype # Get kv cache dtype
cache_type = self.parallel_config.dtype cache_type = self.parallel_config.dtype
kv_cache_quant_type = None
if ( if (
self.quant_config self.quant_config
and hasattr(self.quant_config, "kv_cache_quant_type") and hasattr(self.quant_config, "kv_cache_quant_type")
and self.quant_config.kv_cache_quant_type is not None and self.quant_config.kv_cache_quant_type is not None
): ):
cache_type = "uint8" cache_type = "uint8"
kv_cache_quant_type = self.quant_config.kv_cache_quant_type
# Get kv cache shape # Get kv cache shape
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num) kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
)
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
if not profile and ( if not profile and (

View File

@@ -559,15 +559,19 @@ class IluvatarModelRunner(ModelRunnerBase):
# Get kv cache dtype # Get kv cache dtype
cache_type = self.parallel_config.dtype cache_type = self.parallel_config.dtype
kv_cache_quant_type = None
if ( if (
self.quant_config self.quant_config
and hasattr(self.quant_config, "kv_cache_quant_type") and hasattr(self.quant_config, "kv_cache_quant_type")
and self.quant_config.kv_cache_quant_type is not None and self.quant_config.kv_cache_quant_type is not None
): ):
cache_type = "uint8" cache_type = "uint8"
kv_cache_quant_type = self.quant_config.kv_cache_quant_type
# Get kv cache shape # Get kv cache shape
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num) kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
)
if not self.parallel_config.do_profile and ( if not self.parallel_config.do_profile and (
self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed" self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"

View File

@@ -520,14 +520,19 @@ class XPUModelRunner(ModelRunnerBase):
cache_type = self.parallel_config.dtype cache_type = self.parallel_config.dtype
kv_cache_quant_type = None
if ( if (
self.quant_config self.quant_config
and hasattr(self.quant_config, "kv_cache_quant_type") and hasattr(self.quant_config, "kv_cache_quant_type")
and self.quant_config.kv_cache_quant_type is not None and self.quant_config.kv_cache_quant_type is not None
): ):
cache_type = "uint8" cache_type = "uint8"
kv_cache_quant_type = self.quant_config.kv_cache_quant_type
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num) # Get kv cache shape
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
)
for i in range(self.model_config.num_hidden_layers): for i in range(self.model_config.num_hidden_layers):
cache_kvs[f"key_caches_{i}"] = paddle.full( cache_kvs[f"key_caches_{i}"] = paddle.full(