diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
index 1de39507a..8290e3986 100644
--- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
@@ -110,7 +110,7 @@ class FlashAttentionBackend(AttentionBackend):
         self.kv_num_heads = kv_num_heads
         self.num_heads = num_heads
         self.head_dim = fd_config.model_config.head_dim
-        self.hidden_size = fd_config.model_config.hidden_size
+        self.hidden_size = self.num_heads * self.head_dim
         self.block_size = fd_config.parallel_config.block_size
         self.num_layers: int = fd_config.model_config.num_hidden_layers