diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh
index 936d88e87..57612c458 100644
--- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh
@@ -1130,6 +1130,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
   LoadOutScaleT out_scale_vec;
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
+#pragma unroll
+  for (int v_i = 0; v_i < VecSize; v_i++) {
+    bias_vec[v_i] = 0;
+  }
   const InT* qkv_now = quant_qkv + token_id * hidden_size;
   T* qkv_out_now = qkv_out + token_id * hidden_size;
 #pragma unroll
@@ -1137,8 +1141,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
        head_bias += 32 * VecSize) {
     const int bias_idx = head_idx * HeadDim + head_bias;
     Load(&qkv_now[bias_idx], &src_vec);
-    Load(&qkv_biases[bias_idx], &bias_vec);
-    Load(&qkv_out_scales[bias_idx], &out_scale_vec);
+    // Load(&qkv_biases[bias_idx], &bias_vec);
+    // Load(&qkv_out_scales[bias_idx], &out_scale_vec);
     // q rope
     const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
     Load(&cos_emb[emb_idx], &cos_emb_vec);
@@ -1148,10 +1152,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
       // dequant + add_bias + rope
       float input_left = static_cast<float>(src_vec[2 * i]);
       float input_right = static_cast<float>(src_vec[2 * i + 1]);
-      input_left = input_left * out_scale_vec[2 * i] +
-                   static_cast<float>(bias_vec[2 * i]);
-      input_right = input_right * out_scale_vec[2 * i + 1] +
-                    static_cast<float>(bias_vec[2 * i + 1]);
+      // input_left = input_left * out_scale_vec[2 * i] +
+      //              static_cast<float>(bias_vec[2 * i]);
+      // input_right = input_right * out_scale_vec[2 * i + 1] +
+      //               static_cast<float>(bias_vec[2 * i + 1]);
       const float cos_tmp = cos_emb_vec[i];
       const float sin_tmp = sin_emb_vec[i];
       bias_vec[2 * i] =
@@ -1167,6 +1171,35 @@ __global__ void append_speculate_cache_int4_rope_kernel(
   using LoadPadKVT = AlignedVector;
   const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size;
+  if (block_offset == 0) {
+    // pad zero for this kv_head_idx for this block
+    LoadPadKVT pad_cache_vec;
+    *(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
+    if (head_idx < num_heads + gqa_group_size) {
+      constexpr int num_vecs_per_head_dim = half_head_size / KV_VEC_SIZE;
+      constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
+      const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) *
+                                   block_size * half_head_size +
+                               lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
+      for (int block_i = lane_id / num_vecs_per_head_dim;
+           block_i < block_size;
+           block_i += num_token_each_time) {
+        Store(
+            pad_cache_vec, &key_cache[tgt_idx + block_i * half_head_size]);
+      }
+    } else {
+      const int num_vecs_per_head_dim = half_block_size / KV_VEC_SIZE;
+      const int num_token_each_time = 32 / num_vecs_per_head_dim;
+      const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) *
+                                   HeadDim * half_block_size +
+                               lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
+      for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
+           block_i += num_token_each_time) {
+        Store(
+            pad_cache_vec, &value_cache[tgt_idx + block_i * half_block_size]);
+      }
+    }
+  }
   constexpr int K_VEC_SIZE = 4;
   constexpr int HALF_K_VEC_SIZE = 2;
   using LoadKVResT = AlignedVector;
@@ -1182,7 +1215,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
   LoadScaleT zp_vec1, zp_vec2;
   LoadEmbT cos_emb_vec1, cos_emb_vec2;
   LoadEmbT sin_emb_vec1, sin_emb_vec2;
-
+#pragma unroll
+  for (int v_i = 0; v_i < HALF_K_VEC_SIZE; v_i++) {
+    bias_vec1[v_i] = 0;
+    bias_vec2[v_i] = 0;
+  }
   const InT* qkv_now = quant_qkv + token_id * hidden_size;
   const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
   //////////
@@ -1191,11 +1228,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
   Load(&qkv_now[bias_idx], &src_vec1);
   Load(&qkv_now[bias_idx + 8], &src_vec2);
   /////
-  Load(&qkv_biases[bias_idx], &bias_vec1);
-  Load(&qkv_biases[bias_idx + 8], &bias_vec2);
-  Load(&qkv_out_scales[bias_idx], &out_scale_vec1);
-  Load(&qkv_out_scales[bias_idx + 8],
-       &out_scale_vec2);
+  // Load(&qkv_biases[bias_idx], &bias_vec1);
+  // Load(&qkv_biases[bias_idx + 8], &bias_vec2);
+  // Load(&qkv_out_scales[bias_idx], &out_scale_vec1);
+  // Load(&qkv_out_scales[bias_idx + 8],
+  //      &out_scale_vec2);
   if (head_idx < num_heads + gqa_group_size) {
     const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
     Load(&cos_emb[emb_idx], &cos_emb_vec1);
@@ -1215,10 +1252,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
   float input_left = static_cast<float>(src_vec1[0]);
   float input_right = static_cast<float>(src_vec1[1]);
-  input_left =
-      input_left * out_scale_vec1[0] + static_cast<float>(bias_vec1[0]);
-  input_right =
-      input_right * out_scale_vec1[1] + static_cast<float>(bias_vec1[1]);
+  // input_left =
+  //     input_left * out_scale_vec1[0] + static_cast<float>(bias_vec1[0]);
+  // input_right =
+  //     input_right * out_scale_vec1[1] + static_cast<float>(bias_vec1[1]);
   if (head_idx < num_heads + gqa_group_size) {
     float cos_tmp = cos_emb_vec1[0];
     float sin_tmp = sin_emb_vec1[0];
@@ -1233,10 +1270,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
   input_left = static_cast<float>(src_vec2[0]);
   input_right = static_cast<float>(src_vec2[1]);
-  input_left =
-      input_left * out_scale_vec2[0] + static_cast<float>(bias_vec2[0]);
-  input_right =
-      input_right * out_scale_vec2[1] + static_cast<float>(bias_vec2[1]);
+  // input_left =
+  //     input_left * out_scale_vec2[0] + static_cast<float>(bias_vec2[0]);
+  // input_right =
+  //     input_right * out_scale_vec2[1] + static_cast<float>(bias_vec2[1]);
   if (head_idx < num_heads + gqa_group_size) {
     float cos_tmp = cos_emb_vec2[0];
     float sin_tmp = sin_emb_vec2[0];
diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
index f9b63066d..b17c512af 100644
--- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 import os
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional
 
 import paddle
 
@@ -191,16 +191,25 @@ class AppendAttentionBackend(AttentionBackend):
     def get_kv_cache_shape(
         self,
         max_num_blocks: int,
-    ) -> Tuple[int, int, int, int]:
+        kv_cache_quant_type: str = None,
+    ):
         """
         Caculate kv cache shape
         """
-        return (
-            max_num_blocks,
-            self.kv_num_heads,
-            self.block_size,
-            self.head_dim,
-        )
+        if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
+            return (
+                max_num_blocks,
+                self.kv_num_heads,
+                self.block_size,
+                self.head_dim // 2,
+            )
+        else:
+            return (
+                max_num_blocks,
+                self.kv_num_heads,
+                self.block_size,
+                self.head_dim,
+            )
 
     def forward_mixed(
         self,
diff --git a/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py b/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py
index a04b80018..27c3c98be 100644
--- a/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py
@@ -116,16 +116,25 @@ class BlockAttentionBackend(AttentionBackend):
     def get_kv_cache_shape(
         self,
         max_num_blocks: int,
+        kv_cache_quant_type: str = None,
     ):
         """
         Caculate kv cache shape
         """
-        return (
-            max_num_blocks,
-            self.kv_num_heads,
-            self.block_size,
-            self.head_dim,
-        )
+        if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
+            return (
+                max_num_blocks,
+                self.kv_num_heads,
+                self.block_size,
+                self.head_dim // 2,
+            )
+        else:
+            return (
+                max_num_blocks,
+                self.kv_num_heads,
+                self.block_size,
+                self.head_dim,
+            )
 
     def forward_mixed(
         self,
diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
index 8290e3986..d5e367fe0 100644
--- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
@@ -136,16 +136,25 @@ class FlashAttentionBackend(AttentionBackend):
     def get_kv_cache_shape(
         self,
         max_num_blocks: int,
+        kv_cache_quant_type: str = None,
     ):
         """
         Caculate kv cache shape
         """
-        return (
-            max_num_blocks,
-            self.kv_num_heads,
-            self.block_size,
-            self.head_dim,
-        )
+        if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
+            return (
+                max_num_blocks,
+                self.kv_num_heads,
+                self.block_size,
+                self.head_dim // 2,
+            )
+        else:
+            return (
+                max_num_blocks,
+                self.kv_num_heads,
+                self.block_size,
+                self.head_dim,
+            )
 
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         metadata = FlashAttentionMetadata()
diff --git a/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py b/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py
index 5a4bf549e..5e2da4816 100644
--- a/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py
@@ -132,6 +132,7 @@ class IluvatarAttnBackend(AttentionBackend):
     def get_kv_cache_shape(
         self,
         max_num_blocks: int,
+        kv_cache_quant_type: str = None,
     ):
         """
         Caculate kv cache shape
diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
index 6f0581e57..d413eb4c2 100644
--- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
+++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
@@ -217,14 +217,17 @@ class MLAAttentionBackend(AttentionBackend):
         self.attention_metadata: AttentionMetadata = metadata
 
         forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
-        forward_meta.decoder_tile_ids_per_batch.copy_(
-            metadata.decoder_tile_ids_per_batch, False)
+        forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)
 
     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
         return self.attention_metadata
 
-    def get_kv_cache_shape(self, max_num_blocks: int) -> Tuple[int, int, int, int]:
+    def get_kv_cache_shape(
+        self,
+        max_num_blocks: int,
+        kv_cache_quant_type: str = None,
+    ) -> Tuple[int, int, int, int]:
         """
         Calculate kv cache shape for MLA
         """
diff --git a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py
index 321fa8327..52e62bf0d 100644
--- a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py
@@ -146,6 +146,7 @@ class XPUAttentionBackend(AttentionBackend):
     def get_kv_cache_shape(
         self,
         max_num_blocks: int,
+        kv_cache_quant_type: str = None,
     ) -> Tuple[int, int, int, int]:
         """
         Caculate kv cache shape
diff --git a/fastdeploy/model_executor/layers/backends/gcu/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/backends/gcu/attention/flash_attn_backend.py
index cf0899062..1c6fdbc64 100644
--- a/fastdeploy/model_executor/layers/backends/gcu/attention/flash_attn_backend.py
+++ b/fastdeploy/model_executor/layers/backends/gcu/attention/flash_attn_backend.py
@@ -211,6 +211,7 @@ class GCUFlashAttnBackend(AttentionBackend):
     def get_kv_cache_shape(
         self,
         max_num_blocks: int,
+        kv_cache_quant_type: str = None,
     ):
         """
         Caculate kv cache shape
diff --git a/fastdeploy/model_executor/layers/backends/gcu/attention/mem_efficient_attn_backend.py b/fastdeploy/model_executor/layers/backends/gcu/attention/mem_efficient_attn_backend.py
index d105a41c2..25ed5358f 100644
--- a/fastdeploy/model_executor/layers/backends/gcu/attention/mem_efficient_attn_backend.py
+++ b/fastdeploy/model_executor/layers/backends/gcu/attention/mem_efficient_attn_backend.py
@@ -222,6 +222,7 @@ class GCUMemEfficientAttnBackend(AttentionBackend):
     def get_kv_cache_shape(
         self,
         max_num_blocks: int,
+        kv_cache_quant_type: str = None,
     ):
         """
         Caculate kv cache shape
diff --git a/fastdeploy/model_executor/layers/quantization/kv_cache.py b/fastdeploy/model_executor/layers/quantization/kv_cache.py
index 8cc77ae54..d560e6122 100644
--- a/fastdeploy/model_executor/layers/quantization/kv_cache.py
+++ b/fastdeploy/model_executor/layers/quantization/kv_cache.py
@@ -34,6 +34,7 @@ class KvCacheQuantzationTypes(str, Enum):
     INT8 = "int8"
     FP8 = "float8_e4m3fn"
     INT8_ZP = "int8_zp"
+    INT4_ZP = "int4_zp"
     FP8_ZP = "float8_e4m3fn_zp"
 
 
@@ -42,24 +43,29 @@ class KvCacheQuantConfig(QuantConfigBase):
     quantization config for weight 4bits and activation fp8
     """
 
-    def __init__(self, kv_cache_quant_type: str) -> None:
+    def __init__(self, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_point: bool) -> None:
         """
         __init__
        """
         super().__init__()
         self.kv_cache_quant_type = kv_cache_quant_type
+        self.is_channel_wise = is_channel_wise
+        self.has_zero_point = has_zero_point
 
         try:
             self.quant_type = KvCacheQuantzationTypes(kv_cache_quant_type)
         except ValueError:
             raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")
 
-        self.has_zero_point = "zp" in kv_cache_quant_type
+        if "zp" in kv_cache_quant_type:
+            self.has_zero_point = True
 
         if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP:
             self.max_bound = 127.0
         elif self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP:
             self.max_bound = 448.0
+        elif self.quant_type == KvCacheQuantzationTypes.INT4_ZP:
+            self.max_bound = 7.0
         else:
             raise ValueError(f"Invalid Kvcache type: {kv_cache_quant_type}")
 
@@ -70,11 +76,13 @@ class KvCacheQuantConfig(QuantConfigBase):
         """
         return "kvcache"
 
     @classmethod
-    def from_config(cls, kv_cache_quant_type: str) -> "KvCacheQuantConfig":
+    def from_config(
+        cls, kv_cache_quant_type: str, is_channel_wise: bool, has_zero_point: bool
+    ) -> "KvCacheQuantConfig":
         """
         from_config
         """
-        return cls(kv_cache_quant_type)
+        return cls(kv_cache_quant_type, is_channel_wise, has_zero_point)
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         """
@@ -102,8 +110,8 @@ class KVCacheMethodBase(QuantMethodBase):
         """
         load_zp
         """
-        cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name))
-        cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name))
+        cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name)).cast(paddle.get_default_dtype())
+        cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name)).cast(paddle.get_default_dtype())
 
         create_and_set_parameter(layer, "cache_k_zp", cache_k_zeropoint)
         create_and_set_parameter(layer, "cache_v_zp", cache_v_zeropoint)
@@ -112,17 +120,36 @@ class KVCacheMethodBase(QuantMethodBase):
         """
         load_scale
         """
-        cache_k_scale_tensor = (
-            get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
-        )
-        cache_v_scale_tensor = (
-            get_tensor(state_dict.pop(self.cache_v_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
-        )
-        cache_k_scale = self.cache_quant_config.max_bound / cache_k_scale_tensor
-        cache_v_scale = self.cache_quant_config.max_bound / cache_v_scale_tensor
-        cache_k_out_scale = cache_k_scale_tensor / self.cache_quant_config.max_bound
-        cache_v_out_scale = cache_v_scale_tensor / self.cache_quant_config.max_bound
+        if self.cache_quant_config.is_channel_wise:
+            cache_k_scale_tensor = (
+                get_tensor(state_dict.pop(self.cache_k_scale_name))
+                .cast(paddle.get_default_dtype())
+                .reshape_([-1, layer.head_dim])
+            )
+            cache_v_scale_tensor = (
+                get_tensor(state_dict.pop(self.cache_v_scale_name))
+                .cast(paddle.get_default_dtype())
+                .reshape_([-1, layer.head_dim])
+            )
+        else:
+            cache_k_scale_tensor = (
+                get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
+            )
+            cache_v_scale_tensor = (
+                get_tensor(state_dict.pop(self.cache_v_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
+            )
+
+        if self.cache_quant_config.has_zero_point:  # cache_int4_zp
+            cache_k_scale = 1.0 / cache_k_scale_tensor
+            cache_v_scale = 1.0 / cache_v_scale_tensor
+            cache_k_out_scale = cache_k_scale_tensor
+            cache_v_out_scale = cache_v_scale_tensor
+        else:
+            cache_k_scale = self.cache_quant_config.max_bound / cache_k_scale_tensor
+            cache_v_scale = self.cache_quant_config.max_bound / cache_v_scale_tensor
+            cache_k_out_scale = cache_k_scale_tensor / self.cache_quant_config.max_bound
+            cache_v_out_scale = cache_v_scale_tensor / self.cache_quant_config.max_bound
 
         create_and_set_parameter(layer, "cache_k_scale", cache_k_scale)
         create_and_set_parameter(layer, "cache_v_scale", cache_v_scale)
@@ -147,6 +174,10 @@ class KVCacheMethodBase(QuantMethodBase):
             layer.cache_quant_type_str = "cache_fp8"
             layer.quant_max_bound = 448.0
             layer.quant_min_bound = -448.0
+        elif self.cache_quant_config.quant_type == KvCacheQuantzationTypes.INT4_ZP:
+            layer.cache_quant_type_str = "cache_int4_zp"
+            layer.quant_max_bound = 7.0
+            layer.quant_min_bound = -7.0
         else:
             raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")
 
diff --git a/fastdeploy/model_executor/layers/quantization/mix_quant.py b/fastdeploy/model_executor/layers/quantization/mix_quant.py
index 0c39cbc63..5f7c6e523 100644
--- a/fastdeploy/model_executor/layers/quantization/mix_quant.py
+++ b/fastdeploy/model_executor/layers/quantization/mix_quant.py
@@ -34,6 +34,8 @@ class MixQuantConfig(QuantConfigBase):
         moe_quant_type: str,
         kv_cache_quant_type: str = None,
         image_moe_quant_type: str = None,
+        is_channel_wise: bool = False,
+        has_zero_point: bool = False,
     ) -> None:
         super().__init__()
         self.dense_quant_type = dense_quant_type
@@ -43,6 +45,8 @@
             self.image_moe_quant_type = moe_quant_type
         else:
             self.image_moe_quant_type = image_moe_quant_type
+        self.is_channel_wise = is_channel_wise
+        self.has_zero_point = has_zero_point
         self.quant_max_bound = 0
         self.quant_min_bound = 0
         self.quant_round_type = 0
@@ -57,6 +61,8 @@
             config["moe_quant_type"],
             config.get("kv_cache_quant_type", None),
             config.get("image_moe_quant_type", None),
+            config.get("is_channel_wise", False),
+            config.get("has_zero_point", False),
         )
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
@@ -67,7 +73,11 @@
             return get_quantization_config(self.moe_quant_type).from_config({}).get_quant_method(layer)
         elif isinstance(layer, Attention):
             if self.kv_cache_quant_type is not None:
-                return get_quantization_config("kvcache").from_config(self.kv_cache_quant_type).get_quant_method(layer)
+                return (
+                    get_quantization_config("kvcache")
+                    .from_config(self.kv_cache_quant_type, self.is_channel_wise, self.has_zero_point)
+                    .get_quant_method(layer)
+                )
             else:
                 return None
         else:
diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py
index f9a550685..3ae9bf163 100644
--- a/fastdeploy/spec_decode/mtp.py
+++ b/fastdeploy/spec_decode/mtp.py
@@ -127,15 +127,19 @@ class MTPProposer(Proposer):
 
         cache_type = self.parallel_config.dtype
 
+        kv_cache_quant_type = None
         if (
             self.quant_config
             and hasattr(self.quant_config, "kv_cache_quant_type")
             and self.quant_config.kv_cache_quant_type is not None
         ):
             cache_type = "uint8"
+            kv_cache_quant_type = self.quant_config.kv_cache_quant_type
 
         # Get kv cache shape
-        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=self.num_gpu_blocks)
+        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
+            max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
+        )
         if not self.parallel_config.do_profile and (
             self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
         ):
diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py
index 751f45432..a67ac75e4 100644
--- a/fastdeploy/worker/gcu_model_runner.py
+++ b/fastdeploy/worker/gcu_model_runner.py
@@ -568,15 +568,19 @@ class GCUModelRunner(ModelRunnerBase):
         # Get kv cache dtype
         cache_type = self.parallel_config.dtype
 
+        kv_cache_quant_type = None
         if (
             self.quant_config
             and hasattr(self.quant_config, "kv_cache_quant_type")
             and self.quant_config.kv_cache_quant_type is not None
         ):
             cache_type = "uint8"
+            kv_cache_quant_type = self.quant_config.kv_cache_quant_type
 
         # Get kv cache shape
-        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
+        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
+            max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
+        )
 
         # local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
         if not profile and (
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 432a12ddf..45dd69d59 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -810,15 +810,19 @@ class GPUModelRunner(ModelRunnerBase):
         # Get kv cache dtype
         cache_type = self.parallel_config.dtype
 
+        kv_cache_quant_type = None
         if (
             self.quant_config
             and hasattr(self.quant_config, "kv_cache_quant_type")
             and self.quant_config.kv_cache_quant_type is not None
         ):
             cache_type = "uint8"
+            kv_cache_quant_type = self.quant_config.kv_cache_quant_type
 
         # Get kv cache shape
-        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
+        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
+            max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
+        )
 
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
         if not profile and (
diff --git a/fastdeploy/worker/iluvatar_model_runner.py b/fastdeploy/worker/iluvatar_model_runner.py
index 54d6600d3..f110ee64d 100644
--- a/fastdeploy/worker/iluvatar_model_runner.py
+++ b/fastdeploy/worker/iluvatar_model_runner.py
@@ -559,15 +559,19 @@ class IluvatarModelRunner(ModelRunnerBase):
         # Get kv cache dtype
         cache_type = self.parallel_config.dtype
 
+        kv_cache_quant_type = None
         if (
             self.quant_config
             and hasattr(self.quant_config, "kv_cache_quant_type")
             and self.quant_config.kv_cache_quant_type is not None
         ):
             cache_type = "uint8"
+            kv_cache_quant_type = self.quant_config.kv_cache_quant_type
 
         # Get kv cache shape
-        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
+        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
+            max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
+        )
 
         if not self.parallel_config.do_profile and (
             self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index 731990ff5..3240b217e 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -520,14 +520,19 @@ class XPUModelRunner(ModelRunnerBase):
 
         cache_type = self.parallel_config.dtype
 
+        kv_cache_quant_type = None
         if (
             self.quant_config
             and hasattr(self.quant_config, "kv_cache_quant_type")
             and self.quant_config.kv_cache_quant_type is not None
         ):
             cache_type = "uint8"
+            kv_cache_quant_type = self.quant_config.kv_cache_quant_type
 
-        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
+        # Get kv cache shape
+        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
+            max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
+        )
 
         for i in range(self.model_config.num_hidden_layers):
             cache_kvs[f"key_caches_{i}"] = paddle.full(