Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)
[BugFix]Fix load kv cache quant scale (#4077)
* fix kv cache
* fix kv_cache
* fix kv cache
@@ -63,6 +63,7 @@ class KvCacheQuantConfig(QuantConfigBase):
         if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP:
             self.max_bound = 127.0
+            self.is_channel_wise = True
         elif (
             self.quant_type == KvCacheQuantzationTypes.FP8
             or self.quant_type == KvCacheQuantzationTypes.FP8_ZP
         ):
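For context, a minimal sketch of what the two flags set here mean (illustrative names and shapes only, not FastDeploy internals): max_bound = 127.0 is the symmetric INT8 clipping bound, and is_channel_wise = True means one scale per channel (kv_num_heads * head_dim entries) rather than one per KV head:

import numpy as np

max_bound = 127.0
tokens, kv_num_heads, head_dim = 16, 8, 128

# K vectors for a few tokens, flattened to kv_num_heads * head_dim channels.
k = np.random.randn(tokens, kv_num_heads * head_dim).astype(np.float32)

# Channel-wise: one scale per channel, computed over the token axis
# (tiny epsilon only to avoid division by zero in this toy example).
scale = np.abs(k).max(axis=0) / max_bound + 1e-9

# Symmetric INT8 quantize / dequantize.
q = np.clip(np.round(k / scale), -max_bound, max_bound).astype(np.int8)
k_hat = q.astype(np.float32) * scale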
@@ -125,24 +126,12 @@ class KVCacheMethodBase(QuantMethodBase):
         load_scale
         """
 
-        if self.cache_quant_config.is_channel_wise:
-            cache_k_scale_tensor = (
-                get_tensor(state_dict.pop(self.cache_k_scale_name))
-                .cast(paddle.get_default_dtype())
-                .reshape_([-1, layer.head_dim])
-            )
-            cache_v_scale_tensor = (
-                get_tensor(state_dict.pop(self.cache_v_scale_name))
-                .cast(paddle.get_default_dtype())
-                .reshape_([-1, layer.head_dim])
-            )
-        else:
-            cache_k_scale_tensor = (
-                get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
-            )
-            cache_v_scale_tensor = (
-                get_tensor(state_dict.pop(self.cache_v_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
-            )
+        cache_k_scale_tensor = (
+            get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
+        )
+        cache_v_scale_tensor = (
+            get_tensor(state_dict.pop(self.cache_v_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
+        )
 
         if self.cache_quant_config.has_zero_point:  # cache_int4_zp
             cache_k_scale = 1.0 / cache_k_scale_tensor
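The effect of the surviving code path, in isolation: whatever layout the checkpoint uses for a scale tensor, cast followed by reshape_([-1]) yields a flat 1-D tensor, so the channel-wise branch is no longer needed at load time. (For zero-point variants such as cache_int4_zp, the unchanged context shows the loaded scale is then inverted, 1.0 / scale.) A minimal sketch with assumed shapes, using the same Paddle calls as the diff:

import paddle

num_kv_heads, head_dim = 8, 128

# A checkpoint might store the channel-wise scale as 2-D...
ckpt_scale = paddle.rand([num_kv_heads, head_dim], dtype="float32")

# ...but load_scale now always flattens it in place to 1-D.
flat = ckpt_scale.cast(paddle.get_default_dtype()).reshape_([-1])
assert flat.shape == [num_kv_heads * head_dim]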
@@ -185,7 +174,7 @@ class KVCacheMethodBase(QuantMethodBase):
 
         scale_shape = [layer.fd_config.model_config.num_key_value_heads]
         if self.cache_quant_config.is_channel_wise:
-            scale_shape = [layer.fd_config.model_config.num_key_value_heads, layer.head_dim]
+            scale_shape = [layer.kv_num_heads * layer.head_dim]
 
         layer.cache_k_scale = layer.create_parameter(
             shape=scale_shape,