Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[BugFix]Fix load kv cache quant scale (#4077)
* fix kv cache
* fix kv_cache
* fix kv cache
@@ -63,6 +63,7 @@ class KvCacheQuantConfig(QuantConfigBase):
 
         if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP:
             self.max_bound = 127.0
+            self.is_channel_wise = True
         elif (
             self.quant_type == KvCacheQuantzationTypes.FP8
             or self.quant_type == KvCacheQuantzationTypes.FP8_ZP
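For context, this hunk flags INT8 / INT8_ZP KV-cache quantization as channel-wise, which is what the scale-shape and load changes below rely on. A minimal, self-contained sketch of that branch, assuming placeholder enum string values, a made-up class name, and an FP8 bound that the diff itself does not show:

```python
from enum import Enum


class KvCacheQuantzationTypes(str, Enum):
    # Member names copied from the diff; the string values are placeholders.
    INT8 = "int8"
    INT8_ZP = "int8_zp"
    FP8 = "fp8"
    FP8_ZP = "fp8_zp"


class KvCacheQuantConfigSketch:
    """Minimal stand-in for the quant-type branch shown in the hunk above."""

    def __init__(self, quant_type: KvCacheQuantzationTypes):
        self.quant_type = quant_type
        self.is_channel_wise = False
        if quant_type in (KvCacheQuantzationTypes.INT8, KvCacheQuantzationTypes.INT8_ZP):
            self.max_bound = 127.0
            self.is_channel_wise = True  # the line this commit adds
        elif quant_type in (KvCacheQuantzationTypes.FP8, KvCacheQuantzationTypes.FP8_ZP):
            self.max_bound = 448.0  # assumed FP8 E4M3 bound; not taken from the diff


cfg = KvCacheQuantConfigSketch(KvCacheQuantzationTypes.INT8)
assert cfg.is_channel_wise and cfg.max_bound == 127.0
```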
@@ -125,18 +126,6 @@ class KVCacheMethodBase(QuantMethodBase):
         load_scale
         """
 
-        if self.cache_quant_config.is_channel_wise:
-            cache_k_scale_tensor = (
-                get_tensor(state_dict.pop(self.cache_k_scale_name))
-                .cast(paddle.get_default_dtype())
-                .reshape_([-1, layer.head_dim])
-            )
-            cache_v_scale_tensor = (
-                get_tensor(state_dict.pop(self.cache_v_scale_name))
-                .cast(paddle.get_default_dtype())
-                .reshape_([-1, layer.head_dim])
-            )
-        else:
-            cache_k_scale_tensor = (
-                get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
-            )
+        cache_k_scale_tensor = (
+            get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1])
+        )
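With the channel-wise branch removed, load_scale always flattens the checkpoint scales to 1-D, so a per-channel scale of shape [kv_num_heads, head_dim] simply becomes [kv_num_heads * head_dim]. A rough NumPy stand-in for that load path, with illustrative key names and dtype handling rather than FastDeploy's actual API:

```python
import numpy as np


def load_kv_scales_sketch(state_dict: dict, k_scale_key: str, v_scale_key: str):
    """Flatten K/V cache scales to 1-D, mirroring the unified path above."""
    # Whether the checkpoint entry is a scalar, [kv_num_heads], or
    # [kv_num_heads, head_dim], reshape(-1) yields one flat vector.
    cache_k_scale = np.asarray(state_dict.pop(k_scale_key), dtype=np.float32).reshape(-1)
    cache_v_scale = np.asarray(state_dict.pop(v_scale_key), dtype=np.float32).reshape(-1)
    return cache_k_scale, cache_v_scale


# A channel-wise checkpoint entry of shape [2, 128] loads as a flat
# [2 * 128] vector, matching the parameter shape created in create_weights.
demo_state = {
    "layer0.cache_k_scale": np.ones((2, 128)),
    "layer0.cache_v_scale": np.ones((2, 128)),
}
k, v = load_kv_scales_sketch(demo_state, "layer0.cache_k_scale", "layer0.cache_v_scale")
assert k.shape == (256,) and v.shape == (256,)
```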
@@ -185,7 +174,7 @@ class KVCacheMethodBase(QuantMethodBase):
 
         scale_shape = [layer.fd_config.model_config.num_key_value_heads]
         if self.cache_quant_config.is_channel_wise:
-            scale_shape = [layer.fd_config.model_config.num_key_value_heads, layer.head_dim]
+            scale_shape = [layer.kv_num_heads * layer.head_dim]
 
         layer.cache_k_scale = layer.create_parameter(
             shape=scale_shape,
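The create_weights change pairs with that load path: for channel-wise quantization the scale parameter is now created flat, with kv_num_heads * head_dim elements, instead of the old 2-D [num_key_value_heads, head_dim] shape, so the flattened checkpoint tensor fits it directly. A small sketch of the shape rule; the tensor-parallel remark in the comment is an assumption, not something the diff states:

```python
def kv_scale_shape_sketch(num_key_value_heads, kv_num_heads, head_dim, is_channel_wise):
    """Shape of the cache_k_scale / cache_v_scale parameters, per the hunk above."""
    if is_channel_wise:
        # Flat per-channel scales: one value per (KV head, channel) pair.
        # Was [num_key_value_heads, head_dim] before this fix.
        return [kv_num_heads * head_dim]
    # Per-head scales: one value per KV head.
    return [num_key_value_heads]


# kv_num_heads is presumably the per-rank KV head count (num_key_value_heads
# divided across tensor-parallel ranks); that reading is an assumption here.
assert kv_scale_shape_sketch(8, 8, 128, True) == [1024]
assert kv_scale_shape_sketch(8, 8, 128, False) == [8]
```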
@@ -37,6 +37,7 @@ class MockLayer(nn.Layer):
         self.fd_config = get_default_test_fd_config()
         self.fd_config.model_config.num_key_value_heads = 1
         self.head_dim = 1
+        self.kv_num_heads = 1
         self.prefix = "mock_layer"
         self.cache_k_scale = None
         self.cache_v_scale = None
@@ -77,7 +78,7 @@ class TestKVCacheMethodBase(unittest.TestCase):
         method = KVCacheMethodBase(config)
         method.create_weights(self.layer)
 
-        self.assertEqual(self.layer.cache_k_scale.shape, [1, 1])
+        self.assertEqual(self.layer.cache_k_scale.shape, [1])
 
     def test_create_weights_int4_zp(self):
         # Test INT4 with zero point
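The updated assertion follows from the same flattening: with the mock layer's kv_num_heads and head_dim both set to 1, the channel-wise scale parameter is now 1-D with a single element. A standalone check that mirrors the expectation, using a hypothetical helper rather than the project's test code:

```python
import unittest


def channel_wise_scale_shape(kv_num_heads, head_dim):
    # Flat per-channel scale shape, matching the create_weights change above.
    return [kv_num_heads * head_dim]


class TestKvScaleShapeSketch(unittest.TestCase):
    def test_mock_layer_shape(self):
        # MockLayer uses kv_num_heads = head_dim = 1, so the channel-wise
        # scale collapses to shape [1] rather than the previous [1, 1].
        self.assertEqual(channel_wise_scale_shape(1, 1), [1])


if __name__ == "__main__":
    unittest.main()
```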