[V1 Loader] Ernie kv cache quant support v1 loader (#3899)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

* support c8 for ernie

* add unittest

* support vl

* fix c8
This commit is contained in:
YuanRisheng
2025-09-09 20:25:08 +08:00
committed by GitHub
parent 98bfefea02
commit b3fac5bde1
8 changed files with 497 additions and 53 deletions

View File

@@ -34,6 +34,7 @@ import os
from safetensors import safe_open
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import default_weight_loader
class Attention(nn.Layer):
@@ -77,6 +78,7 @@ class Attention(nn.Layer):
ValueError: If the `v_head_dim` is less than 0.
"""
super().__init__()
self.fd_config = fd_config
self.num_heads: int = (
fd_config.model_config.num_attention_heads // fd_config.parallel_config.tensor_parallel_size
)
@@ -101,23 +103,21 @@ class Attention(nn.Layer):
self.use_neox_rotary_style: bool = use_neox_rotary_style
if fd_config.quant_config and hasattr(fd_config.quant_config, "kv_cache_quant_type"):
self.kvcache_quant_method: QuantMethodBase = fd_config.quant_config.get_quant_method(self)
self.quant_method: QuantMethodBase = fd_config.quant_config.get_quant_method(self)
else:
self.kvcache_quant_method = None
self.quant_method = None
if self.kvcache_quant_method is None:
if self.quant_method is None:
logger.info(f"Attention is running in cache kv {self._dtype} mode")
else:
logger.info(
f"Attention is running in cache kv {self.kvcache_quant_method.cache_quant_config.quant_type} mode"
)
logger.info(f"Attention is running in cache kv {self.quant_method.cache_quant_config.quant_type} mode")
self.use_qk_norm = use_qk_norm
self.rms_norm_eps = rms_norm_eps
if self.use_qk_norm:
self.q_norm_key = f"{self.prefix}.q_norm"
self.k_norm_key = f"{self.prefix}.k_norm"
self.init_weight()
self.init_weight()
if (
fd_config.moba_attention_config is not None
and fd_config.moba_attention_config.moba_encoder_top_k_left is not None
@@ -161,6 +161,15 @@ class Attention(nn.Layer):
)
def init_weight(self):
if self.quant_method is not None:
self.quant_method.create_weights(
self,
weight_loader=(
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
),
)
if self.use_qk_norm:
self.q_norm_weight = self.create_parameter(
shape=[self.qk_head_dim],
dtype="float32",
@@ -179,14 +188,23 @@ class Attention(nn.Layer):
"""
Attention only have quant related scales not other parameters.
"""
if self.kvcache_quant_method is not None:
self.kvcache_quant_method.create_weights(self, state_dict)
if self.quant_method is not None:
self.quant_method.process_loaded_weights(self, state_dict)
if self.use_qk_norm:
q_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.q_norm_key + ".weight")))
k_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.k_norm_key + ".weight")))
self.q_norm_weight.set_value(q_norm_weight_tensor.astype("float32"))
self.k_norm_weight.set_value(k_norm_weight_tensor.astype("float32"))
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
loaded_weight = get_tensor(loaded_weight).cast(paddle.get_default_dtype())
if self.quant_method.cache_quant_config.has_zero_point: # cache_int4_zp
loaded_weight = 1.0 / loaded_weight
else:
loaded_weight = self.quant_method.cache_quant_config.max_bound / loaded_weight
param.copy_(loaded_weight, False)
def forward(
self,
q: paddle.Tensor = None,

View File

@@ -21,8 +21,8 @@ import paddle
from paddle import nn
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import set_weight_attrs
from ..utils import create_and_set_parameter
from .quant_base import QuantConfigBase, QuantMethodBase
@@ -117,9 +117,8 @@ class KVCacheMethodBase(QuantMethodBase):
"""
cache_k_zeropoint = get_tensor(state_dict.pop(self.cache_k_zp_name)).cast(paddle.get_default_dtype())
cache_v_zeropoint = get_tensor(state_dict.pop(self.cache_v_zp_name)).cast(paddle.get_default_dtype())
create_and_set_parameter(layer, "cache_k_zp", cache_k_zeropoint)
create_and_set_parameter(layer, "cache_v_zp", cache_v_zeropoint)
layer.cache_k_zp.set_value(cache_k_zeropoint)
layer.cache_v_zp.set_value(cache_v_zeropoint)
def load_scale(self, layer: nn.Layer, state_dict):
"""
@@ -156,21 +155,15 @@ class KVCacheMethodBase(QuantMethodBase):
cache_k_out_scale = cache_k_scale_tensor / self.cache_quant_config.max_bound
cache_v_out_scale = cache_v_scale_tensor / self.cache_quant_config.max_bound
create_and_set_parameter(layer, "cache_k_scale", cache_k_scale)
create_and_set_parameter(layer, "cache_v_scale", cache_v_scale)
create_and_set_parameter(layer, "cache_k_out_scale", cache_k_out_scale)
create_and_set_parameter(layer, "cache_v_out_scale", cache_v_out_scale)
layer.cache_k_scale.set_value(cache_k_scale)
layer.cache_v_scale.set_value(cache_v_scale)
layer.cache_k_out_scale.set_value(cache_k_out_scale)
layer.cache_v_out_scale.set_value(cache_v_out_scale)
def create_weights(self, layer: nn.Layer, state_dict):
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
create_weights
"""
self.prefix = layer.prefix
self.cache_k_scale_name = layer.prefix + ".cachek_matmul.activation_scale"
self.cache_v_scale_name = layer.prefix + ".cachev_matmul.activation_scale"
self.cache_k_zp_name = layer.prefix + ".cachek_matmul.activation_zero_point"
self.cache_v_zp_name = layer.prefix + ".cachev_matmul.activation_zero_point"
if self.cache_quant_config.quant_type == KvCacheQuantzationTypes.INT8:
layer.cache_quant_type_str = "cache_int8"
layer.quant_max_bound = 127.0
@@ -190,11 +183,91 @@ class KVCacheMethodBase(QuantMethodBase):
else:
raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")
scale_shape = [layer.fd_config.model_config.num_key_value_heads]
if self.cache_quant_config.is_channel_wise:
scale_shape = [layer.fd_config.model_config.num_key_value_heads, layer.head_dim]
layer.cache_k_scale = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.cache_v_scale = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.cache_k_scale,
{
**extra_weight_attrs,
},
)
set_weight_attrs(
layer.cache_v_scale,
{
**extra_weight_attrs,
},
)
layer.cache_k_out_scale = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.cache_v_out_scale = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)
if self.cache_quant_config.has_zero_point:
layer.cache_k_zp = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.cache_v_zp = layer.create_parameter(
shape=scale_shape,
dtype=paddle.get_default_dtype(),
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.cache_k_zp,
{
**extra_weight_attrs,
},
)
set_weight_attrs(
layer.cache_v_zp,
{
**extra_weight_attrs,
},
)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
use for loader v0
"""
self.prefix = layer.prefix
self.cache_k_scale_name = layer.prefix + ".cachek_matmul.activation_scale"
self.cache_v_scale_name = layer.prefix + ".cachev_matmul.activation_scale"
self.cache_k_zp_name = layer.prefix + ".cachek_matmul.activation_zero_point"
self.cache_v_zp_name = layer.prefix + ".cachev_matmul.activation_zero_point"
if "block_wise" not in layer.cache_quant_type_str:
self.load_scale(layer, state_dict)
if self.cache_quant_config.has_zero_point:
self.load_zp(layer, state_dict)
def process_weights_after_loading(self, layer: nn.Layer):
"""
use for loader v1
"""
if layer.cache_k_scale._is_initialized():
layer.cache_k_out_scale.set_value(1 / layer.cache_k_scale)
if layer.cache_v_scale._is_initialized():
layer.cache_v_out_scale.set_value(1 / layer.cache_v_scale)
def apply(self, layer):
"""
apply

View File

@@ -539,6 +539,10 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
("qkv_proj", "v_proj", None, "v"),
("up_gate_proj", "gate_proj", None, "gate"),
("up_gate_proj", "up_proj", None, "up"),
("attn.cache_k_scale", "cachek_matmul.activation_scale", None, None),
("attn.cache_v_scale", "cachev_matmul.activation_scale", None, None),
("attn.cache_k_zp", "cachek_matmul.activation_zero_point", None, None),
("attn.cache_v_zp", "cachev_matmul.activation_zero_point", None, None),
]
expert_params_mapping = []
@@ -563,6 +567,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
all_param_mapping = general_params_mapping + expert_params_mapping
params_dict = dict(self.named_parameters())
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
for loaded_weight_name, loaded_weight in weights_iterator:
@@ -591,7 +596,9 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
else:
weight_loader(param, loaded_weight, shard_id)
model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
model_sublayer_name = re.sub(
r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name
)
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:

View File

@@ -616,6 +616,10 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
("resampler_model", "ernie.resampler_model", None, None),
("vision_model", "ernie.vision_model", None, None),
("gate_correction_bias", "moe_statics.e_score_correction_bias", None, None),
("attn.cache_k_scale", "cachek_matmul.activation_scale", None, None),
("attn.cache_v_scale", "cachev_matmul.activation_scale", None, None),
("attn.cache_k_zp", "cachek_matmul.activation_zero_point", None, None),
("attn.cache_v_zp", "cachev_matmul.activation_zero_point", None, None),
# for torch model
("resampler_model", "model.resampler_model", None, None),
("qkv_proj", "q_proj", None, "q"),
@@ -679,7 +683,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
else:
weight_loader(param, loaded_weight, shard_id)
model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
model_sublayer_name = re.sub(
r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name
)
process_weights_after_loading_fn(model_sublayer_name, param)
if self.tie_word_embeddings:
# because we use lazy guard and is not initialized by default

View File

@@ -709,6 +709,10 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
if quantization_config is not None:
quant_config_name = quantization_config["quantization"]
# TODO(YuanRisheng) is_checkpoint_bf16 may need to be removed and replaced by is_quantized in future
if "kv_cache_quant_type" in quantization_config and load_config.load_choices == "default_v1":
quantization_config["is_checkpoint_bf16"] = True
elif args.quantization != "None":
quantization_config = {}
quant_config_name = args.quantization

View File

@@ -1,15 +0,0 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

View File

@@ -0,0 +1,194 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import Mock
import numpy as np
import paddle
from fastdeploy.config import CacheConfig, FDConfig, ModelConfig, ParallelConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
class MockQuantMethod:
"""Mock quantization method for testing."""
def __init__(self, has_zero_point=False, max_bound=1.0):
self.cache_quant_config = Mock()
self.cache_quant_config.has_zero_point = has_zero_point
self.cache_quant_config.max_bound = max_bound
self.create_weights_called = False
self.create_weights_args = None
def create_weights(self, layer, weight_loader):
self.create_weights_called = True
self.create_weights_args = (layer, weight_loader)
def process_loaded_weights(self, layer, state_dict):
pass
class TestAttentionInitWeight(unittest.TestCase):
"""Test cases for Attention.init_weight method."""
def setUp(self):
"""Set up test fixtures."""
# Create mock config
self.model_config = Mock(spec=ModelConfig)
self.model_config.num_attention_heads = 32
self.model_config.head_dim = 128
self.model_config.num_key_value_heads = 8
self.model_config.model = "test_model"
self.model_config.num_hidden_layers = 12
self.parallel_config = Mock(spec=ParallelConfig)
self.parallel_config.tensor_parallel_size = 1
self.parallel_config.tensor_parallel_rank = 0
self.parallel_config.max_num_seqs = 8
self.cache_config = Mock(spec=CacheConfig)
self.fd_config = Mock(spec=FDConfig)
self.fd_config.model_config = self.model_config
self.fd_config.parallel_config = self.parallel_config
self.fd_config.cache_config = self.cache_config
self.fd_config.quant_config = None
self.fd_config.moba_attention_config = None
def test_init_weight_without_quantization(self):
"""Test init_weight without quantization."""
# Test case 1: No quantization, no qk_norm
attention = Attention(fd_config=self.fd_config, layer_id=0, use_qk_norm=False)
# Check that q_norm_weight and k_norm_weight are not created
self.assertFalse(hasattr(attention, "q_norm_weight"))
self.assertFalse(hasattr(attention, "k_norm_weight"))
def test_init_weight_with_qk_norm(self):
"""Test init_weight with qk_norm enabled."""
# Test case 2: No quantization, with qk_norm
attention = Attention(fd_config=self.fd_config, layer_id=0, use_qk_norm=True, rms_norm_eps=1e-6)
# Check that q_norm_weight and k_norm_weight are created
self.assertTrue(hasattr(attention, "q_norm_weight"))
self.assertTrue(hasattr(attention, "k_norm_weight"))
# Check parameter shapes
self.assertEqual(attention.q_norm_weight.shape, [attention.qk_head_dim])
self.assertEqual(attention.k_norm_weight.shape, [attention.qk_head_dim])
# Check parameter dtype
self.assertEqual(attention.q_norm_weight.dtype, paddle.float32)
self.assertEqual(attention.k_norm_weight.dtype, paddle.float32)
# Check initial values (should be zeros)
np.testing.assert_array_equal(
attention.q_norm_weight.numpy(), np.zeros(attention.qk_head_dim, dtype=np.float32)
)
np.testing.assert_array_equal(
attention.k_norm_weight.numpy(), np.zeros(attention.qk_head_dim, dtype=np.float32)
)
def test_init_weight_with_quantization(self):
"""Test init_weight with quantization enabled."""
# Test case 3: With quantization
mock_quant_method = MockQuantMethod()
self.fd_config.quant_config = Mock()
self.fd_config.quant_config.get_quant_method = Mock(return_value=mock_quant_method)
attention = Attention(fd_config=self.fd_config, layer_id=0, use_qk_norm=False)
# Check that quant_method.create_weights was called
self.assertTrue(mock_quant_method.create_weights_called)
self.assertEqual(mock_quant_method.create_weights_args[0], attention)
# Check that weight_loader is passed correctly
self.assertIsNotNone(mock_quant_method.create_weights_args[1])
class TestAttentionWeightLoader(unittest.TestCase):
"""Test cases for Attention.weight_loader method."""
def setUp(self):
"""Set up test fixtures."""
# Create mock config
self.model_config = Mock(spec=ModelConfig)
self.model_config.num_attention_heads = 32
self.model_config.head_dim = 128
self.model_config.num_key_value_heads = 8
self.model_config.model = "test_model"
self.model_config.num_hidden_layers = 12
self.parallel_config = Mock(spec=ParallelConfig)
self.parallel_config.tensor_parallel_size = 1
self.parallel_config.tensor_parallel_rank = 0
self.parallel_config.max_num_seqs = 8
self.cache_config = Mock(spec=CacheConfig)
self.fd_config = Mock(spec=FDConfig)
self.fd_config.model_config = self.model_config
self.fd_config.parallel_config = self.parallel_config
self.fd_config.cache_config = self.cache_config
self.fd_config.moba_attention_config = None
# Create mock quant method
self.mock_quant_method = MockQuantMethod()
self.fd_config.quant_config = Mock()
self.fd_config.quant_config.get_quant_method = Mock(return_value=self.mock_quant_method)
# Create attention layer
self.attention = Attention(fd_config=self.fd_config, layer_id=0, use_qk_norm=False)
def test_weight_loader_without_zero_point(self):
"""Test weight_loader without zero point."""
# Test case 1: No zero point
mock_quant_method = MockQuantMethod(has_zero_point=False, max_bound=8.0)
self.attention.quant_method = mock_quant_method
# Create mock parameter
param = paddle.zeros([10], dtype=paddle.float32)
# Create mock loaded weight
loaded_weight = np.array([2.0, 4.0, 8.0, 1.0, 0.5, 2.0, 4.0, 8.0, 1.0, 0.5])
# Call weight_loader
self.attention.weight_loader(param, loaded_weight)
# Check that the parameter is updated correctly
expected_value = 8.0 / loaded_weight
np.testing.assert_array_almost_equal(param.numpy(), expected_value.astype(np.float32))
def test_weight_loader_with_zero_point(self):
"""Test weight_loader with zero point."""
# Test case 2: With zero point
mock_quant_method = MockQuantMethod(has_zero_point=True, max_bound=8.0)
self.attention.quant_method = mock_quant_method
# Create mock parameter
param = paddle.zeros([10], dtype=paddle.float32)
# Create mock loaded weight
loaded_weight = np.array([2.0, 4.0, 8.0, 1.0, 0.5, 2.0, 4.0, 8.0, 1.0, 0.5])
# Call weight_loader
self.attention.weight_loader(param, loaded_weight)
# Check that the parameter is updated correctly
expected_value = 1.0 / loaded_weight
np.testing.assert_array_almost_equal(param.numpy(), expected_value.astype(np.float32))
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,157 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle
from paddle import nn
from fastdeploy.model_executor.layers.quantization.kv_cache import (
KVCacheMethodBase,
KvCacheQuantConfig,
KvCacheQuantzationTypes,
)
sys.path.append("../")
from tests.utils import get_default_test_fd_config
class MockLayer(nn.Layer):
def __init__(
self,
) -> None:
super().__init__()
self.fd_config = get_default_test_fd_config()
self.fd_config.model_config.num_key_value_heads = 1
self.head_dim = 1
self.prefix = "mock_layer"
self.cache_k_scale = None
self.cache_v_scale = None
self.cache_k_out_scale = None
self.cache_v_out_scale = None
self.cache_k_zp = None
self.cache_v_zp = None
class TestKVCacheMethodBase(unittest.TestCase):
def setUp(self):
self.layer = MockLayer()
def test_create_weights_int8(self):
# Test INT8 without zero point
config = KvCacheQuantConfig(
kv_cache_quant_type=KvCacheQuantzationTypes.INT8, is_channel_wise=False, has_zero_point=False
)
method = KVCacheMethodBase(config)
method.create_weights(self.layer)
self.assertEqual(self.layer.cache_quant_type_str, "cache_int8")
self.assertEqual(self.layer.quant_max_bound, 127.0)
self.assertEqual(self.layer.quant_min_bound, -127.0)
self.assertIsNotNone(self.layer.cache_k_scale)
self.assertIsNotNone(self.layer.cache_v_scale)
self.assertIsNotNone(self.layer.cache_k_out_scale)
self.assertIsNotNone(self.layer.cache_v_out_scale)
self.assertIsNone(self.layer.cache_k_zp)
self.assertIsNone(self.layer.cache_v_zp)
self.assertEqual(self.layer.cache_k_scale.shape, [1])
def test_create_weights_int8_channel_wise(self):
# Test INT8 with channel wise
config = KvCacheQuantConfig(
kv_cache_quant_type=KvCacheQuantzationTypes.INT8, is_channel_wise=True, has_zero_point=False
)
method = KVCacheMethodBase(config)
method.create_weights(self.layer)
self.assertEqual(self.layer.cache_k_scale.shape, [1, 1])
def test_create_weights_int4_zp(self):
# Test INT4 with zero point
config = KvCacheQuantConfig(
kv_cache_quant_type=KvCacheQuantzationTypes.INT4_ZP, is_channel_wise=False, has_zero_point=True
)
method = KVCacheMethodBase(config)
method.create_weights(self.layer)
self.assertEqual(self.layer.cache_quant_type_str, "cache_int4_zp")
self.assertEqual(self.layer.quant_max_bound, 7.0)
self.assertEqual(self.layer.quant_min_bound, -7.0)
self.assertIsNotNone(self.layer.cache_k_zp)
self.assertIsNotNone(self.layer.cache_v_zp)
def test_process_loaded_weights_int8(self):
# Test process INT8 weights
config = KvCacheQuantConfig(
kv_cache_quant_type=KvCacheQuantzationTypes.INT8, is_channel_wise=False, has_zero_point=False
)
method = KVCacheMethodBase(config)
method.create_weights(self.layer)
state_dict = {
"mock_layer.cachek_matmul.activation_scale": np.array([2.0], dtype=np.float32),
"mock_layer.cachev_matmul.activation_scale": np.array([3.0], dtype=np.float32),
}
method.process_loaded_weights(self.layer, state_dict)
self.assertAlmostEqual(self.layer.cache_k_scale.numpy()[0], 127.0 / 2.0, places=3)
self.assertAlmostEqual(self.layer.cache_v_scale.numpy()[0], 127.0 / 3.0, places=3)
self.assertAlmostEqual(self.layer.cache_k_out_scale.numpy()[0], 2.0 / 127.0, places=3)
self.assertAlmostEqual(self.layer.cache_v_out_scale.numpy()[0], 3.0 / 127.0, places=3)
def test_process_loaded_weights_int4_zp(self):
# Test process INT4 with zero point weights
config = KvCacheQuantConfig(
kv_cache_quant_type=KvCacheQuantzationTypes.INT4_ZP, is_channel_wise=False, has_zero_point=True
)
method = KVCacheMethodBase(config)
method.create_weights(self.layer)
state_dict = {
"mock_layer.cachek_matmul.activation_scale": np.array([2.0], dtype=np.float32),
"mock_layer.cachev_matmul.activation_scale": np.array([3.0], dtype=np.float32),
"mock_layer.cachek_matmul.activation_zero_point": np.array([1.0], dtype=np.float32),
"mock_layer.cachev_matmul.activation_zero_point": np.array([2.0], dtype=np.float32),
}
method.process_loaded_weights(self.layer, state_dict)
self.assertAlmostEqual(self.layer.cache_k_scale.numpy()[0], 1.0 / 2.0, places=3)
self.assertAlmostEqual(self.layer.cache_v_scale.numpy()[0], 1.0 / 3.0, places=3)
self.assertAlmostEqual(self.layer.cache_k_out_scale.numpy()[0], 2.0)
self.assertAlmostEqual(self.layer.cache_v_out_scale.numpy()[0], 3.0)
self.assertAlmostEqual(self.layer.cache_k_zp.numpy()[0], 1.0)
self.assertAlmostEqual(self.layer.cache_v_zp.numpy()[0], 2.0)
def test_process_weights_after_loading_initialized(self):
# Test process weights after loading when scale is initialized
config = KvCacheQuantConfig(
kv_cache_quant_type=KvCacheQuantzationTypes.INT8, is_channel_wise=False, has_zero_point=False
)
method = KVCacheMethodBase(config)
method.create_weights(self.layer)
# Simulate initialized scale
self.layer.cache_k_scale.set_value(paddle.to_tensor([2.0], dtype="float32"))
self.layer.cache_v_scale.set_value(paddle.to_tensor([3.0], dtype="float32"))
method.process_weights_after_loading(self.layer)
self.assertAlmostEqual(self.layer.cache_k_out_scale.numpy()[0], 0.5)
self.assertAlmostEqual(self.layer.cache_v_out_scale.numpy()[0], 1.0 / 3.0, places=3)
if __name__ == "__main__":
unittest.main()