[BugFix]Fix load kv cache quant scale (#4077)

* fix kv cache

* fix kv_cache

* fix kv cache
Author: YuanRisheng
Date:   2025-09-12 17:44:03 +08:00 (committed by GitHub)
Parent: c86b3357ce
Commit: 88ea565aba

2 changed files with 10 additions and 20 deletions
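At its core, the fix makes the loading of channel-wise KV cache quantization scales agree with how the scale parameters are allocated: `load_scale` no longer reshapes checkpoint scales into a 2-D `[-1, head_dim]` tensor, and `create_weights` now allocates a flat `[kv_num_heads * head_dim]` parameter instead of the 2-D `[num_key_value_heads, head_dim]` one. A minimal sketch of the resulting invariant, using hypothetical sizes and not the patched code itself:

```python
# Minimal sketch (illustrative only): the flattened checkpoint scale must
# match the 1-D parameter that create_weights preallocates.
import paddle

kv_num_heads, head_dim = 2, 4  # hypothetical sizes

# A channel-wise checkpoint scale carries one value per (head, channel) pair.
ckpt_scale = paddle.rand([kv_num_heads, head_dim])

# load_scale now always flattens to 1-D ...
flat = ckpt_scale.cast(paddle.get_default_dtype()).reshape_([-1])

# ... matching the flat parameter shape [kv_num_heads * head_dim].
holder = paddle.nn.Layer()
param = holder.create_parameter(shape=[kv_num_heads * head_dim], dtype=paddle.get_default_dtype())
param.set_value(flat)  # shapes agree: [8] == [8]
```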


@@ -63,6 +63,7 @@ class KvCacheQuantConfig(QuantConfigBase):
         if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP:
             self.max_bound = 127.0
+            self.is_channel_wise = True
         elif (
             self.quant_type == KvCacheQuantzationTypes.FP8
             or self.quant_type == KvCacheQuantzationTypes.FP8_ZP
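If the marker reconstruction above is right (the hunk counts record exactly one added line), INT8 and INT8_ZP are now flagged channel-wise, so they take the flattened per-channel path that the rest of this commit fixes.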
@@ -125,24 +126,12 @@ class KVCacheMethodBase(QuantMethodBase):
         load_scale
         """
-        if self.cache_quant_config.is_channel_wise:
-            cache_k_scale_tensor = (
-                get_tensor(state_dict.pop(self.cache_k_scale_name))
-                .cast(paddle.get_default_dtype())
-                .reshape_([-1, layer.head_dim])
-            )
-            cache_v_scale_tensor = (
-                get_tensor(state_dict.pop(self.cache_v_scale_name))
-                .cast(paddle.get_default_dtype())
-                .reshape_([-1, layer.head_dim])
-            )
-        else:
-            cache_k_scale_tensor = (
-                get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
-            )
-            cache_v_scale_tensor = (
-                get_tensor(state_dict.pop(self.cache_v_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
-            )
+        cache_k_scale_tensor = (
+            get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
+        )
+        cache_v_scale_tensor = (
+            get_tensor(state_dict.pop(self.cache_v_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
+        )
         if self.cache_quant_config.has_zero_point:  # cache_int4_zp
             cache_k_scale = 1.0 / cache_k_scale_tensor
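The `has_zero_point` branch registers the reciprocal of the checkpoint scale. One plausible reading (an assumption, not shown in the diff): the checkpoint stores the dequantization scale, while the cache-write side wants a multiplier for quantization. An illustrative round trip with hypothetical values:

```python
# Illustrative only: why load_scale registers 1.0 / scale. Values are hypothetical.
import paddle

dequant_scale = paddle.to_tensor([0.05])  # as stored in the checkpoint
quant_scale = 1.0 / dequant_scale         # what load_scale registers
zero_point = 8.0                          # e.g. midpoint of an int4 range

x = paddle.to_tensor([0.3])
q = paddle.round(x * quant_scale) + zero_point  # quantize: multiply and shift
x_hat = (q - zero_point) * dequant_scale        # dequantize recovers ~0.3
```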
@@ -185,7 +174,7 @@ class KVCacheMethodBase(QuantMethodBase):
         scale_shape = [layer.fd_config.model_config.num_key_value_heads]
         if self.cache_quant_config.is_channel_wise:
-            scale_shape = [layer.fd_config.model_config.num_key_value_heads, layer.head_dim]
+            scale_shape = [layer.kv_num_heads * layer.head_dim]
         layer.cache_k_scale = layer.create_parameter(
             shape=scale_shape,
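Note the head count also changes: the old channel-wise shape used the model-wide `fd_config.model_config.num_key_value_heads`, while the new one uses the layer-local `layer.kv_num_heads`. Assuming the usual tensor-parallel split (an assumption, not shown in the diff), that sizes the parameter to the heads this rank actually owns:

```python
# Assumed tensor-parallel arithmetic, not taken from the diff; sizes hypothetical.
num_key_value_heads, tp_size, head_dim = 8, 4, 128
kv_num_heads = num_key_value_heads // tp_size  # heads owned by this rank
scale_shape = [kv_num_heads * head_dim]
print(scale_shape)  # [256] per rank, versus the old global 2-D [8, 128]
```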


@@ -37,6 +37,7 @@ class MockLayer(nn.Layer):
         self.fd_config = get_default_test_fd_config()
         self.fd_config.model_config.num_key_value_heads = 1
         self.head_dim = 1
+        self.kv_num_heads = 1
         self.prefix = "mock_layer"
         self.cache_k_scale = None
         self.cache_v_scale = None
@@ -77,7 +78,7 @@ class TestKVCacheMethodBase(unittest.TestCase):
         method = KVCacheMethodBase(config)
         method.create_weights(self.layer)
-        self.assertEqual(self.layer.cache_k_scale.shape, [1, 1])
+        self.assertEqual(self.layer.cache_k_scale.shape, [1])

     def test_create_weights_int4_zp(self):
         # Test INT4 with zero point
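With the mock layer's `kv_num_heads` and `head_dim` both set to 1, the channel-wise scale shape becomes `[1 * 1] = [1]`, which is exactly what the updated assertion checks.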