mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
load cache scale (#4623)
This commit is contained in:
@@ -318,6 +318,29 @@ def deal_state_dict(state_dict):
|
||||
src_tensor._share_data_with(dst_tensor)
|
||||
|
||||
|
||||
def load_cache_scale(model_path, fd_config, state_dict):
    """Load per-layer KV-cache activation scales from ``kv_cache_scale.json``.

    Looks for ``kv_cache_scale.json`` under ``model_path``. If present, for
    every hidden layer the cache-K and cache-V activation scales are read,
    converted to Paddle tensors in the default dtype, multiplied by 448.0
    (the maximum finite value of float8_e4m3fn, so the recorded activation
    range maps onto the full fp8 range), and stored into ``state_dict`` in
    place. If the file is absent, a warning is logged and nothing changes.

    Args:
        model_path: Directory expected to contain ``kv_cache_scale.json``.
        fd_config: Config object; only ``model_config.num_hidden_layers``
            is read here.
        state_dict: Mutable mapping of parameter name -> tensor, updated
            in place with the per-layer scale tensors.

    Raises:
        KeyError: if the JSON file lacks an expected per-layer scale key.
    """
    # Maximum finite value of float8_e4m3fn; scales are scaled up to it.
    FP8_E4M3_MAX = 448.0

    file_path = os.path.join(model_path, "kv_cache_scale.json")
    if not os.path.exists(file_path):
        # Best-effort: missing scale file is not an error, just skip.
        logger.warning(f"No kv_cache_scale.json found at {file_path}, skipping...")
        return

    with open(file_path, "r") as f:
        data = json.load(f)

    for i in range(fd_config.model_config.num_hidden_layers):
        # K-cache and V-cache scales are handled identically; iterate the
        # pair instead of duplicating the conversion code per tensor.
        for scale_name in (
            f"ernie.layers.{i}.self_attn.cachek_matmul.activation_scale",
            f"ernie.layers.{i}.self_attn.cachev_matmul.activation_scale",
        ):
            scale_tensor = paddle.to_tensor(
                data[scale_name], dtype=paddle.get_default_dtype()
            )
            state_dict[scale_name] = scale_tensor * FP8_E4M3_MAX
        logger.info(f"Loaded kv cache scales for layer {i}.")
||||
def load_composite_checkpoint(
|
||||
model_path: str,
|
||||
cls: PretrainedModel,
|
||||
@@ -361,4 +384,10 @@ def load_composite_checkpoint(
|
||||
)
|
||||
if not state_dict:
|
||||
raise ValueError("weight not found in state_dict !")
|
||||
|
||||
if hasattr(fd_config.quant_config, "kv_cache_quant_type"):
|
||||
kv_cache_quant_type = fd_config.quant_config.kv_cache_quant_type
|
||||
if kv_cache_quant_type == "float8_e4m3fn":
|
||||
load_cache_scale(model_path, fd_config, state_dict)
|
||||
|
||||
return state_dict
|
||||
|
||||
Reference in New Issue
Block a user