diff --git a/fastdeploy/model_executor/layers/quantization/kv_cache.py b/fastdeploy/model_executor/layers/quantization/kv_cache.py index dd569de02..b5310f5c6 100644 --- a/fastdeploy/model_executor/layers/quantization/kv_cache.py +++ b/fastdeploy/model_executor/layers/quantization/kv_cache.py @@ -63,6 +63,7 @@ class KvCacheQuantConfig(QuantConfigBase): if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP: self.max_bound = 127.0 + self.is_channel_wise = True elif ( self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP @@ -125,24 +126,12 @@ class KVCacheMethodBase(QuantMethodBase): load_scale """ - if self.cache_quant_config.is_channel_wise: - cache_k_scale_tensor = ( - get_tensor(state_dict.pop(self.cache_k_scale_name)) - .cast(paddle.get_default_dtype()) - .reshape_([-1, layer.head_dim]) - ) - cache_v_scale_tensor = ( - get_tensor(state_dict.pop(self.cache_v_scale_name)) - .cast(paddle.get_default_dtype()) - .reshape_([-1, layer.head_dim]) - ) - else: - cache_k_scale_tensor = ( - get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1]) - ) - cache_v_scale_tensor = ( - get_tensor(state_dict.pop(self.cache_v_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1]) - ) + cache_k_scale_tensor = ( + get_tensor(state_dict.pop(self.cache_k_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1]) + ) + cache_v_scale_tensor = ( + get_tensor(state_dict.pop(self.cache_v_scale_name)).cast(paddle.get_default_dtype()).reshape_([-1]) + ) if self.cache_quant_config.has_zero_point: # cache_int4_zp cache_k_scale = 1.0 / cache_k_scale_tensor @@ -185,7 +174,7 @@ class KVCacheMethodBase(QuantMethodBase): scale_shape = [layer.fd_config.model_config.num_key_value_heads] if self.cache_quant_config.is_channel_wise: - scale_shape = [layer.fd_config.model_config.num_key_value_heads, layer.head_dim] + scale_shape = [layer.kv_num_heads * layer.head_dim] layer.cache_k_scale = layer.create_parameter( shape=scale_shape, diff --git a/tests/quantization/test_kv_cache.py b/tests/quantization/test_kv_cache.py index c33a0b7b0..744910309 100644 --- a/tests/quantization/test_kv_cache.py +++ b/tests/quantization/test_kv_cache.py @@ -37,6 +37,7 @@ class MockLayer(nn.Layer): self.fd_config = get_default_test_fd_config() self.fd_config.model_config.num_key_value_heads = 1 self.head_dim = 1 + self.kv_num_heads = 1 self.prefix = "mock_layer" self.cache_k_scale = None self.cache_v_scale = None @@ -77,7 +78,7 @@ class TestKVCacheMethodBase(unittest.TestCase): method = KVCacheMethodBase(config) method.create_weights(self.layer) - self.assertEqual(self.layer.cache_k_scale.shape, [1, 1]) + self.assertEqual(self.layer.cache_k_scale.shape, [1]) def test_create_weights_int4_zp(self): # Test INT4 with zero point