diff --git a/fastdeploy/model_executor/layers/quantization/kv_cache.py b/fastdeploy/model_executor/layers/quantization/kv_cache.py
index 18361e2b4..cd461fde7 100644
--- a/fastdeploy/model_executor/layers/quantization/kv_cache.py
+++ b/fastdeploy/model_executor/layers/quantization/kv_cache.py
@@ -180,69 +180,70 @@ class KVCacheMethodBase(QuantMethodBase):
         else:
             raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")
 
-        scale_shape = [layer.fd_config.model_config.num_key_value_heads]
-        if self.cache_quant_config.is_channel_wise:
-            scale_shape = [layer.kv_num_heads * layer.head_dim]
+        if "block_wise" not in layer.cache_quant_type_str:  # dynamic block_wise_fp8 kv cache does not need static scales
+            scale_shape = [layer.fd_config.model_config.num_key_value_heads]
+            if self.cache_quant_config.is_channel_wise:
+                scale_shape = [layer.kv_num_heads * layer.head_dim]
 
-        layer.cache_k_scale = layer.create_parameter(
-            shape=scale_shape,
-            dtype=paddle.get_default_dtype(),
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-        layer.cache_v_scale = layer.create_parameter(
-            shape=scale_shape,
-            dtype=paddle.get_default_dtype(),
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-
-        set_weight_attrs(
-            layer.cache_k_scale,
-            {
-                **extra_weight_attrs,
-            },
-        )
-        set_weight_attrs(
-            layer.cache_v_scale,
-            {
-                **extra_weight_attrs,
-            },
-        )
-
-        layer.cache_k_out_scale = layer.create_parameter(
-            shape=scale_shape,
-            dtype=paddle.get_default_dtype(),
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-        layer.cache_v_out_scale = layer.create_parameter(
-            shape=scale_shape,
-            dtype=paddle.get_default_dtype(),
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-
-        if self.cache_quant_config.has_zero_point:
-            layer.cache_k_zp = layer.create_parameter(
+            layer.cache_k_scale = layer.create_parameter(
                 shape=scale_shape,
                 dtype=paddle.get_default_dtype(),
                 default_initializer=paddle.nn.initializer.Constant(0),
             )
-            layer.cache_v_zp = layer.create_parameter(
+            layer.cache_v_scale = layer.create_parameter(
                 shape=scale_shape,
                 dtype=paddle.get_default_dtype(),
                 default_initializer=paddle.nn.initializer.Constant(0),
             )
+
             set_weight_attrs(
-                layer.cache_k_zp,
+                layer.cache_k_scale,
                 {
                     **extra_weight_attrs,
                 },
             )
             set_weight_attrs(
-                layer.cache_v_zp,
+                layer.cache_v_scale,
                 {
                     **extra_weight_attrs,
                 },
             )
 
+            layer.cache_k_out_scale = layer.create_parameter(
+                shape=scale_shape,
+                dtype=paddle.get_default_dtype(),
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            layer.cache_v_out_scale = layer.create_parameter(
+                shape=scale_shape,
+                dtype=paddle.get_default_dtype(),
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+
+            if self.cache_quant_config.has_zero_point:
+                layer.cache_k_zp = layer.create_parameter(
+                    shape=scale_shape,
+                    dtype=paddle.get_default_dtype(),
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                )
+                layer.cache_v_zp = layer.create_parameter(
+                    shape=scale_shape,
+                    dtype=paddle.get_default_dtype(),
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                )
+                set_weight_attrs(
+                    layer.cache_k_zp,
+                    {
+                        **extra_weight_attrs,
+                    },
+                )
+                set_weight_attrs(
+                    layer.cache_v_zp,
+                    {
+                        **extra_weight_attrs,
+                    },
+                )
+
     def process_loaded_weights(self, layer: nn.Layer, state_dict):
         """
         use for loader v0
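
A minimal sketch of the gating logic this hunk introduces (the helper function and example quant-type strings below are hypothetical illustrations, not FastDeploy APIs): static scale and zero-point parameters are registered only for non-block-wise cache quantization, since dynamic block_wise_fp8 derives its scales per block at runtime.

    def needs_static_cache_scales(cache_quant_type_str: str) -> bool:
        # Mirrors the new condition in the patched method: block-wise
        # quantization computes scales dynamically, so no cache_k_scale /
        # cache_v_scale (or zero-point) parameters are created for it.
        return "block_wise" not in cache_quant_type_str

    # Hypothetical quant-type strings for illustration:
    assert needs_static_cache_scales("cache_int8")          # static scales created
    assert not needs_static_cache_scales("block_wise_fp8")  # skipped for block-wise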