mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 17:17:14 +08:00
supports dynamic Cfp8 (#3767)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* supports dynamic Cfp8 * add unittest
This commit is contained in:
@@ -231,6 +231,17 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
metadata.kv_signal_metadata,
|
||||
layer.layer_id + self.start_layer_index,
|
||||
)
|
||||
cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none")
|
||||
if cache_quant_type_str == "block_wise_fp8":
|
||||
cache_k = forward_meta.caches[4 * layer.layer_id]
|
||||
cache_v = forward_meta.caches[4 * layer.layer_id + 1]
|
||||
cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2]
|
||||
cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3]
|
||||
else:
|
||||
cache_k = forward_meta.caches[2 * layer.layer_id]
|
||||
cache_v = forward_meta.caches[2 * layer.layer_id + 1]
|
||||
cache_k_scales = getattr(layer, "cache_k_scale", None)
|
||||
cache_v_scales = getattr(layer, "cache_v_scale", None)
|
||||
|
||||
if self.use_output:
|
||||
quant_max_bound = getattr(layer, "quant_max_bound", 0.0)
|
||||
@@ -269,8 +280,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
|
||||
append_attention_with_output(
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
cache_k,
|
||||
cache_v,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
@@ -293,8 +304,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
metadata.attn_mask,
|
||||
layer.qkv_bias,
|
||||
layer.qkv_scale,
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
cache_k_scales,
|
||||
cache_v_scales,
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
getattr(layer, "cache_v_out_scale", None),
|
||||
getattr(layer, "cache_k_zp", None),
|
||||
@@ -325,8 +336,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
else:
|
||||
res = append_attention(
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
cache_k,
|
||||
cache_v,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
@@ -348,8 +359,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
metadata.attn_mask,
|
||||
layer.qkv_bias,
|
||||
layer.qkv_scale,
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
cache_k_scales,
|
||||
cache_v_scales,
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
getattr(layer, "cache_v_out_scale", None),
|
||||
getattr(layer, "cache_k_zp", None),
|
||||
|
Reference in New Issue
Block a user