mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
support c4 attn && fix cache
This commit is contained in:
@@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import paddle
|
||||
|
||||
@@ -191,16 +191,25 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
def get_kv_cache_shape(
    self,
    max_num_blocks: int,
    kv_cache_quant_type: Optional[str] = None,
) -> tuple[int, int, int, int]:
    """
    Calculate the KV cache shape.

    Args:
        max_num_blocks: Number of cache blocks; used as the first dimension.
        kv_cache_quant_type: Name of the KV cache quantization scheme, or
            ``None`` when no scheme is given. Only ``"int4_zp"`` changes the
            result; every other value yields the unquantized shape.

    Returns:
        ``(max_num_blocks, kv_num_heads, block_size, last_dim)`` where
        ``last_dim`` is ``head_dim // 2`` for ``"int4_zp"`` and ``head_dim``
        otherwise.
    """
    # NOTE(review): presumably int4 packs two values into one stored element,
    # which is why the last dimension is halved -- confirm against the kernel.
    if kv_cache_quant_type == "int4_zp":
        last_dim = self.head_dim // 2
    else:
        last_dim = self.head_dim

    return (
        max_num_blocks,
        self.kv_num_heads,
        self.block_size,
        last_dim,
    )
|
||||
|
||||
def forward_mixed(
|
||||
self,
|
||||
|
Reference in New Issue
Block a user