load cache scale (#4623)

Sunny-bot1
2025-10-31 11:57:57 +08:00
committed by GitHub
parent 71135d58a0
commit 3f15e6fa15


@@ -318,6 +318,29 @@ def deal_state_dict(state_dict):
            src_tensor._share_data_with(dst_tensor)


def load_cache_scale(model_path, fd_config, state_dict):
    file_path = os.path.join(model_path, "kv_cache_scale.json")
    if os.path.exists(file_path):
        with open(file_path, "r") as f:
            data = json.load(f)

        for i in range(fd_config.model_config.num_hidden_layers):
            k_scale_name = f"ernie.layers.{i}.self_attn.cachek_matmul.activation_scale"
            v_scale_name = f"ernie.layers.{i}.self_attn.cachev_matmul.activation_scale"
            k_scale = data[k_scale_name]
            k_scale_tensor = paddle.to_tensor(k_scale, dtype=paddle.get_default_dtype())
            state_dict[k_scale_name] = k_scale_tensor * 448.0

            v_scale = data[v_scale_name]
            v_scale_tensor = paddle.to_tensor(v_scale, dtype=paddle.get_default_dtype())
            state_dict[v_scale_name] = v_scale_tensor * 448.0
            logger.info(f"Loaded kv cache scales for layer {i}.")
    else:
        logger.warning(f"No kv_cache_scale.json found at {file_path}, skipping...")


def load_composite_checkpoint(
    model_path: str,
    cls: PretrainedModel,
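For context on the new helper: 448.0 is the maximum finite value representable in the float8_e4m3fn format, so multiplying the stored activation scales by it converts them into dequantization scales for the FP8 KV cache. Below is a minimal sketch of the kv_cache_scale.json layout that load_cache_scale expects; the layer count and scale values are hypothetical stand-ins, since the real file is produced by the calibration toolchain.

import json

# Hypothetical per-layer activation scales for a 2-layer model; real values
# come from calibration, not from this sketch.
scales = {}
for i in range(2):  # stands in for fd_config.model_config.num_hidden_layers
    scales[f"ernie.layers.{i}.self_attn.cachek_matmul.activation_scale"] = 0.012
    scales[f"ernie.layers.{i}.self_attn.cachev_matmul.activation_scale"] = 0.010

with open("kv_cache_scale.json", "w") as f:
    json.dump(scales, f, indent=2)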
@@ -361,4 +384,10 @@ def load_composite_checkpoint(
        )
    if not state_dict:
        raise ValueError("weight not found in state_dict !")

    if hasattr(fd_config.quant_config, "kv_cache_quant_type"):
        kv_cache_quant_type = fd_config.quant_config.kv_cache_quant_type
        if kv_cache_quant_type == "float8_e4m3fn":
            load_cache_scale(model_path, fd_config, state_dict)

    return state_dict
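The gate at the end of load_composite_checkpoint only triggers the scale load for FP8 KV caches; note that hasattr(None, "kv_cache_quant_type") evaluates to False, so configs without quantization fall through safely. A minimal standalone sketch of that behavior, using SimpleNamespace objects as hypothetical stand-ins for fd_config.quant_config:

from types import SimpleNamespace

def should_load_cache_scale(quant_config) -> bool:
    # Mirrors the check in load_composite_checkpoint: hasattr(None, ...) is
    # False, so a model without a quant config skips the scale load.
    return (
        hasattr(quant_config, "kv_cache_quant_type")
        and quant_config.kv_cache_quant_type == "float8_e4m3fn"
    )

print(should_load_cache_scale(None))  # False
print(should_load_cache_scale(SimpleNamespace(kv_cache_quant_type="int8")))  # False
print(should_load_cache_scale(SimpleNamespace(kv_cache_quant_type="float8_e4m3fn")))  # True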