Mirror of https://github.com/PaddlePaddle/FastDeploy.git

Commit: support c4 attn && fix cache
@@ -1130,6 +1130,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
   LoadOutScaleT out_scale_vec;
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
+#pragma unroll
+  for (int v_i = 0; v_i < VecSize; v_i++) {
+    bias_vec[v_i] = 0;
+  }
   const InT* qkv_now = quant_qkv + token_id * hidden_size;
   T* qkv_out_now = qkv_out + token_id * hidden_size;
 #pragma unroll
@@ -1137,8 +1141,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
        head_bias += 32 * VecSize) {
     const int bias_idx = head_idx * HeadDim + head_bias;
     Load<InT, VecSize>(&qkv_now[bias_idx], &src_vec);
-    Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
-    Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
+    // Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
+    // Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
     // q rope
     const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
     Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
@@ -1148,10 +1152,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
       // dequant + add_bias + rope
       float input_left = static_cast<float>(src_vec[2 * i]);
       float input_right = static_cast<float>(src_vec[2 * i + 1]);
-      input_left = input_left * out_scale_vec[2 * i] +
-                   static_cast<float>(bias_vec[2 * i]);
-      input_right = input_right * out_scale_vec[2 * i + 1] +
-                    static_cast<float>(bias_vec[2 * i + 1]);
+      // input_left = input_left * out_scale_vec[2 * i] +
+      //              static_cast<float>(bias_vec[2 * i]);
+      // input_right = input_right * out_scale_vec[2 * i + 1] +
+      //               static_cast<float>(bias_vec[2 * i + 1]);
       const float cos_tmp = cos_emb_vec[i];
       const float sin_tmp = sin_emb_vec[i];
       bias_vec[2 * i] =
@@ -1167,6 +1171,35 @@ __global__ void append_speculate_cache_int4_rope_kernel(
   using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
   const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size;

+  if (block_offset == 0) {
+    // pad zero for this kv_head_idx for this block
+    LoadPadKVT pad_cache_vec;
+    *(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
+    if (head_idx < num_heads + gqa_group_size) {
+      constexpr int num_vecs_per_head_dim = half_head_size / KV_VEC_SIZE;
+      constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
+      const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) *
+                                   block_size * half_head_size +
+                               lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
+      for (int block_i = lane_id / num_vecs_per_head_dim;
+           block_i < block_size;
+           block_i += num_token_each_time) {
+        Store<uint8_t, KV_VEC_SIZE>(
+            pad_cache_vec, &key_cache[tgt_idx + block_i * half_head_size]);
+      }
+    } else {
+      const int num_vecs_per_head_dim = half_block_size / KV_VEC_SIZE;
+      const int num_token_each_time = 32 / num_vecs_per_head_dim;
+      const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) *
+                                   HeadDim * half_block_size +
+                               lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
+      for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
+           block_i += num_token_each_time) {
+        Store<uint8_t, KV_VEC_SIZE>(
+            pad_cache_vec, &value_cache[tgt_idx + block_i * half_block_size]);
+      }
+    }
+  }
   constexpr int K_VEC_SIZE = 4;
   constexpr int HALF_K_VEC_SIZE = 2;
   using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
@@ -1182,7 +1215,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
   LoadScaleT zp_vec1, zp_vec2;
   LoadEmbT cos_emb_vec1, cos_emb_vec2;
   LoadEmbT sin_emb_vec1, sin_emb_vec2;
-
+#pragma unroll
+  for (int v_i = 0; v_i < HALF_K_VEC_SIZE; v_i++) {
+    bias_vec1[v_i] = 0;
+    bias_vec2[v_i] = 0;
+  }
   const InT* qkv_now = quant_qkv + token_id * hidden_size;
   const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
   //////
@@ -1191,11 +1228,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
     Load<InT, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
     Load<InT, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
     /////
-    Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx], &bias_vec1);
-    Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx + 8], &bias_vec2);
-    Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx], &out_scale_vec1);
-    Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx + 8],
-                                 &out_scale_vec2);
+    // Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx], &bias_vec1);
+    // Load<T, HALF_K_VEC_SIZE>(&qkv_biases[bias_idx + 8], &bias_vec2);
+    // Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx], &out_scale_vec1);
+    // Load<float, HALF_K_VEC_SIZE>(&qkv_out_scales[bias_idx + 8],
+    //                              &out_scale_vec2);
     if (head_idx < num_heads + gqa_group_size) {
       const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
       Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
@@ -1215,10 +1252,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(

     float input_left = static_cast<float>(src_vec1[0]);
     float input_right = static_cast<float>(src_vec1[1]);
-    input_left =
-        input_left * out_scale_vec1[0] + static_cast<float>(bias_vec1[0]);
-    input_right =
-        input_right * out_scale_vec1[1] + static_cast<float>(bias_vec1[1]);
+    // input_left =
+    //     input_left * out_scale_vec1[0] + static_cast<float>(bias_vec1[0]);
+    // input_right =
+    //     input_right * out_scale_vec1[1] + static_cast<float>(bias_vec1[1]);
     if (head_idx < num_heads + gqa_group_size) {
       float cos_tmp = cos_emb_vec1[0];
       float sin_tmp = sin_emb_vec1[0];
@@ -1233,10 +1270,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(

     input_left = static_cast<float>(src_vec2[0]);
     input_right = static_cast<float>(src_vec2[1]);
-    input_left =
-        input_left * out_scale_vec2[0] + static_cast<float>(bias_vec2[0]);
-    input_right =
-        input_right * out_scale_vec2[1] + static_cast<float>(bias_vec2[1]);
+    // input_left =
+    //     input_left * out_scale_vec2[0] + static_cast<float>(bias_vec2[0]);
+    // input_right =
+    //     input_right * out_scale_vec2[1] + static_cast<float>(bias_vec2[1]);
     if (head_idx < num_heads + gqa_group_size) {
       float cos_tmp = cos_emb_vec2[0];
       float sin_tmp = sin_emb_vec2[0];
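
Note on the kernel hunks above: the bias/out-scale loads are commented out and bias_vec is zero-initialized instead, and the new block_offset == 0 path appears to zero-fill a freshly assigned cache block before any token writes into it, so stale bytes in a reused block cannot leak into attention. As a rough mental model only (plain NumPy with illustrative names, not the FastDeploy kernel or its API), the int4 packing that makes the key cache half_head_size bytes wide per token looks like this:

import numpy as np

def pack_int4(vals: np.ndarray) -> np.ndarray:
    # Pack adjacent pairs of signed 4-bit values (range [-8, 7]) into bytes.
    lo = (vals[0::2] & 0xF).astype(np.uint8)
    hi = (vals[1::2] & 0xF).astype(np.uint8)
    return lo | (hi << 4)

def unpack_int4(packed: np.ndarray) -> np.ndarray:
    # Split each byte back into two sign-extended 4-bit values.
    lo = (packed & 0xF).astype(np.int8)
    hi = ((packed >> 4) & 0xF).astype(np.int8)
    lo = np.where(lo > 7, lo - 16, lo).astype(np.int8)
    hi = np.where(hi > 7, hi - 16, hi).astype(np.int8)
    out = np.empty(packed.size * 2, dtype=np.int8)
    out[0::2], out[1::2] = lo, hi
    return out

vals = np.array([-3, 5, 7, -8], dtype=np.int8)
assert (unpack_int4(pack_int4(vals)) == vals).all()
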
@@ -18,7 +18,7 @@ from __future__ import annotations

 import os
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional

 import paddle

@@ -191,10 +191,19 @@ class AppendAttentionBackend(AttentionBackend):
     def get_kv_cache_shape(
         self,
         max_num_blocks: int,
-    ) -> Tuple[int, int, int, int]:
+        kv_cache_quant_type: str = None,
+    ):
         """
         Caculate kv cache shape
         """
+        if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
+            return (
+                max_num_blocks,
+                self.kv_num_heads,
+                self.block_size,
+                self.head_dim // 2,
+            )
+        else:
+            return (
                 max_num_blocks,
                 self.kv_num_heads,
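
The int4_zp branch halves the last axis because each uint8 element carries two 4-bit values. A minimal sketch of the resulting shapes (example numbers only, not defaults from the repo):

max_num_blocks, kv_num_heads, block_size, head_dim = 1024, 8, 64, 128

shape_int4_zp = (max_num_blocks, kv_num_heads, block_size, head_dim // 2)  # uint8, 2 values/byte
shape_default = (max_num_blocks, kv_num_heads, block_size, head_dim)
print(shape_int4_zp)  # (1024, 8, 64, 64)
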
@@ -116,10 +116,19 @@ class BlockAttentionBackend(AttentionBackend):
     def get_kv_cache_shape(
         self,
         max_num_blocks: int,
+        kv_cache_quant_type: str = None,
     ):
         """
         Caculate kv cache shape
         """
+        if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
+            return (
+                max_num_blocks,
+                self.kv_num_heads,
+                self.block_size,
+                self.head_dim // 2,
+            )
+        else:
+            return (
                 max_num_blocks,
                 self.kv_num_heads,
@@ -136,10 +136,19 @@ class FlashAttentionBackend(AttentionBackend):
     def get_kv_cache_shape(
         self,
         max_num_blocks: int,
+        kv_cache_quant_type: str = None,
     ):
         """
         Caculate kv cache shape
         """
+        if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
+            return (
+                max_num_blocks,
+                self.kv_num_heads,
+                self.block_size,
+                self.head_dim // 2,
+            )
+        else:
+            return (
                 max_num_blocks,
                 self.kv_num_heads,
@@ -132,6 +132,7 @@ class IluvatarAttnBackend(AttentionBackend):
     def get_kv_cache_shape(
         self,
         max_num_blocks: int,
+        kv_cache_quant_type: str = None,
     ):
         """
         Caculate kv cache shape
@@ -217,14 +217,17 @@ class MLAAttentionBackend(AttentionBackend):
         self.attention_metadata: AttentionMetadata = metadata

         forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
-        forward_meta.decoder_tile_ids_per_batch.copy_(
-            metadata.decoder_tile_ids_per_batch, False)
+        forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)

     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
         return self.attention_metadata

-    def get_kv_cache_shape(self, max_num_blocks: int) -> Tuple[int, int, int, int]:
+    def get_kv_cache_shape(
+        self,
+        max_num_blocks: int,
+        kv_cache_quant_type: str = None,
+    ) -> Tuple[int, int, int, int]:
         """
         Calculate kv cache shape for MLA
         """
@@ -146,6 +146,7 @@ class XPUAttentionBackend(AttentionBackend):
     def get_kv_cache_shape(
         self,
         max_num_blocks: int,
+        kv_cache_quant_type: str = None,
     ) -> Tuple[int, int, int, int]:
         """
         Caculate kv cache shape
@@ -211,6 +211,7 @@ class GCUFlashAttnBackend(AttentionBackend):
     def get_kv_cache_shape(
         self,
         max_num_blocks: int,
+        kv_cache_quant_type: str = None,
     ):
         """
         Caculate kv cache shape
@@ -222,6 +222,7 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
     def get_kv_cache_shape(
         self,
         max_num_blocks: int,
+        kv_cache_quant_type: str = None,
     ):
         """
         Caculate kv cache shape
@@ -34,6 +34,7 @@ class KvCacheQuantzationTypes(str, Enum):
     INT8 = "int8"
     FP8 = "float8_e4m3fn"
     INT8_ZP = "int8_zp"
+    INT4_ZP = "int4_zp"
     FP8_ZP = "float8_e4m3fn_zp"


@@ -42,24 +43,29 @@ class KvCacheQuantConfig(QuantConfigBase):
     quantization config for weight 4bits and activation fp8
     """

-    def __init__(self, kv_cache_quant_type: str) -> None:
+    def __init__(self, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_point: bool) -> None:
         """
         __init__
         """
         super().__init__()
         self.kv_cache_quant_type = kv_cache_quant_type
+        self.is_channel_wise = is_channel_wise
+        self.has_zero_point = has_zero_point

         try:
             self.quant_type = KvCacheQuantzationTypes(kv_cache_quant_type)
         except ValueError:
             raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")

-        self.has_zero_point = "zp" in kv_cache_quant_type
+        if "zp" in kv_cache_quant_type:
+            self.has_zero_point = True

         if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP:
             self.max_bound = 127.0
         elif self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP:
             self.max_bound = 448.0
+        elif self.quant_type == KvCacheQuantzationTypes.INT4_ZP:
+            self.max_bound = 7.0
         else:
             raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")

@@ -70,11 +76,13 @@ class KvCacheQuantConfig(QuantConfigBase):
         return "kvcache"

     @classmethod
-    def from_config(cls, kv_cache_quant_type: str) -> "KvCacheQuantConfig":
+    def from_config(
+        cls, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_point: bool
+    ) -> "KvCacheQuantConfig":
         """
         from_config
         """
-        return cls(kv_cache_quant_type)
+        return cls(kv_cache_quant_type, is_channel_wise, has_zero_point)

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         """
@@ -102,8 +110,8 @@ class KVCacheMethodBase(QuantMethodBase):
         """
         load_zp
         """
-        cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name))
-        cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name))
+        cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name)).cast(paddle.get_default_dtype())
+        cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name)).cast(paddle.get_default_dtype())

         create_and_set_parameter(layer, "cache_k_zp", cache_k_zeropoint)
         create_and_set_parameter(layer, "cache_v_zp", cache_v_zeropoint)
@@ -112,6 +120,19 @@ class KVCacheMethodBase(QuantMethodBase):
         """
         load_scale
         """
+        if self.cache_quant_config.is_channel_wise:
+            cache_k_scale_tensor = (
+                get_tensor(state_dict.pop(self.cache_k_scale_name))
+                .cast(paddle.get_default_dtype())
+                .reshape_([-1, layer.head_dim])
+            )
+            cache_v_scale_tensor = (
+                get_tensor(state_dict.pop(self.cache_v_scale_name))
+                .cast(paddle.get_default_dtype())
+                .reshape_([-1, layer.head_dim])
+            )
+        else:
+            cache_k_scale_tensor = (
                 get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
             )
@@ -119,6 +140,12 @@ class KVCacheMethodBase(QuantMethodBase):
             get_tensor(state_dict.pop(self.cache_v_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
         )

+        if self.cache_quant_config.has_zero_point:  # cache_int4_zp
+            cache_k_scale = 1.0 / cache_k_scale_tensor
+            cache_v_scale = 1.0 / cache_v_scale_tensor
+            cache_k_out_scale = cache_k_scale_tensor
+            cache_v_out_scale = cache_v_scale_tensor
+        else:
             cache_k_scale = self.cache_quant_config.max_bound / cache_k_scale_tensor
             cache_v_scale = self.cache_quant_config.max_bound / cache_v_scale_tensor
             cache_k_out_scale = cache_k_scale_tensor / self.cache_quant_config.max_bound
@@ -147,6 +174,10 @@ class KVCacheMethodBase(QuantMethodBase):
             layer.cache_quant_type_str = "cache_fp8"
             layer.quant_max_bound = 448.0
             layer.quant_min_bound = -448.0
+        elif self.cache_quant_config.quant_type == KvCacheQuantzationTypes.INT4_ZP:
+            layer.cache_quant_type_str = "cache_int4_zp"
+            layer.quant_max_bound = 7.0
+            layer.quant_min_bound = -7.0
         else:
             raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")
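
The load_scale split encodes two conventions: with a zero point (int4_zp) the checkpoint scale is treated as the dequantization step directly, while the symmetric int8/fp8 path treats it as an absmax and folds in max_bound. A small sketch of the round trip under that reading (scalar stand-ins for the tensors; illustrative, not the repo API):

max_bound = 7.0      # int4_zp bound from KvCacheQuantConfig
ckpt_scale = 0.05    # stand-in for the loaded scale tensor

# zero-point path (has_zero_point): checkpoint scale is the step size
quant_scale, out_scale = 1.0 / ckpt_scale, ckpt_scale

# symmetric path: checkpoint scale is the absmax
quant_scale_sym, out_scale_sym = max_bound / ckpt_scale, ckpt_scale / max_bound

x = 0.23
q = round(x * quant_scale)            # the kernel would also apply cache_k_zp here
assert abs(q * out_scale - x) <= ckpt_scale / 2
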
@@ -34,6 +34,8 @@ class MixQuantConfig(QuantConfigBase):
         moe_quant_type: str,
         kv_cache_quant_type: str = None,
         image_moe_quant_type: str = None,
+        is_channel_wise: bool = False,
+        has_zero_point: bool = False,
     ) -> None:
         super().__init__()
         self.dense_quant_type = dense_quant_type
@@ -43,6 +45,8 @@ class MixQuantConfig(QuantConfigBase):
             self.image_moe_quant_type = moe_quant_type
         else:
             self.image_moe_quant_type = image_moe_quant_type
+        self.is_channel_wise = is_channel_wise
+        self.has_zero_point = has_zero_point
         self.quant_max_bound = 0
         self.quant_min_bound = 0
         self.quant_round_type = 0
@@ -57,6 +61,8 @@ class MixQuantConfig(QuantConfigBase):
             config["moe_quant_type"],
             config.get("kv_cache_quant_type", None),
             config.get("image_moe_quant_type", None),
+            config.get("is_channel_wise", False),
+            config.get("has_zero_point", False),
         )

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -67,7 +73,11 @@ class MixQuantConfig(QuantConfigBase):
             return get_quantization_config(self.moe_quant_type).from_config({}).get_quant_method(layer)
         elif isinstance(layer, Attention):
             if self.kv_cache_quant_type is not None:
-                return get_quantization_config("kvcache").from_config(self.kv_cache_quant_type).get_quant_method(layer)
+                return (
+                    get_quantization_config("kvcache")
+                    .from_config(self.kv_cache_quant_type, self.is_channel_wise, self.has_zero_point)
+                    .get_quant_method(layer)
+                )
             else:
                 return None
         else:
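
For context, the new keys would be supplied through the mixed-quant config dict; a hedged usage sketch (key values illustrative, assuming from_config receives the parsed quantization config as the diff's config.get calls suggest):

config = {
    "dense_quant_type": "wint8",          # illustrative values
    "moe_quant_type": "wint4",
    "kv_cache_quant_type": "int4_zp",
    "is_channel_wise": True,              # new: per-channel KV scales
    "has_zero_point": True,               # new: asymmetric (zero-point) KV quant
}
mix_cfg = MixQuantConfig.from_config(config)
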
@@ -127,15 +127,19 @@ class MTPProposer(Proposer):

         cache_type = self.parallel_config.dtype

+        kv_cache_quant_type = None
         if (
             self.quant_config
             and hasattr(self.quant_config, "kv_cache_quant_type")
             and self.quant_config.kv_cache_quant_type is not None
         ):
             cache_type = "uint8"
+            kv_cache_quant_type = self.quant_config.kv_cache_quant_type

         # Get kv cache shape
-        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=self.num_gpu_blocks)
+        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
+            max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
+        )
         if not self.parallel_config.do_profile and (
             self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
         ):
@@ -568,15 +568,19 @@ class GCUModelRunner(ModelRunnerBase):
         # Get kv cache dtype
         cache_type = self.parallel_config.dtype

+        kv_cache_quant_type = None
         if (
             self.quant_config
             and hasattr(self.quant_config, "kv_cache_quant_type")
             and self.quant_config.kv_cache_quant_type is not None
         ):
             cache_type = "uint8"
+            kv_cache_quant_type = self.quant_config.kv_cache_quant_type

         # Get kv cache shape
-        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
+        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
+            max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
+        )
         # local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

         if not profile and (
@@ -810,15 +810,19 @@ class GPUModelRunner(ModelRunnerBase):
         # Get kv cache dtype
         cache_type = self.parallel_config.dtype

+        kv_cache_quant_type = None
         if (
             self.quant_config
             and hasattr(self.quant_config, "kv_cache_quant_type")
             and self.quant_config.kv_cache_quant_type is not None
         ):
             cache_type = "uint8"
+            kv_cache_quant_type = self.quant_config.kv_cache_quant_type

         # Get kv cache shape
-        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
+        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
+            max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
+        )
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

         if not profile and (
@@ -559,15 +559,19 @@ class IluvatarModelRunner(ModelRunnerBase):
         # Get kv cache dtype
         cache_type = self.parallel_config.dtype

+        kv_cache_quant_type = None
         if (
             self.quant_config
             and hasattr(self.quant_config, "kv_cache_quant_type")
             and self.quant_config.kv_cache_quant_type is not None
         ):
             cache_type = "uint8"
+            kv_cache_quant_type = self.quant_config.kv_cache_quant_type

         # Get kv cache shape
-        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
+        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
+            max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
+        )

         if not self.parallel_config.do_profile and (
             self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
@@ -520,14 +520,19 @@ class XPUModelRunner(ModelRunnerBase):

         cache_type = self.parallel_config.dtype

+        kv_cache_quant_type = None
         if (
             self.quant_config
             and hasattr(self.quant_config, "kv_cache_quant_type")
             and self.quant_config.kv_cache_quant_type is not None
         ):
             cache_type = "uint8"
+            kv_cache_quant_type = self.quant_config.kv_cache_quant_type

-        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
+        # Get kv cache shape
+        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
+            max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
+        )

         for i in range(self.model_config.num_hidden_layers):
             cache_kvs[f"key_caches_{i}"] = paddle.full(
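
With cache_type forced to "uint8" for quantized caches, the allocation this hunk truncates would create packed byte tensors; a minimal sketch of the idea (shape values illustrative, not the runner's actual arguments):

import paddle

kv_cache_shape = (1024, 8, 64, 64)  # int4_zp: last axis is head_dim // 2 bytes
key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype="uint8")
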