[Intel HPU] fix bugs caused by other commits (#5074)

* [Intel HPU] fix bugs caused by other commits

* update code by copilot
This commit is contained in:
fmiao2372
2025-11-17 15:28:55 +08:00
committed by GitHub
parent 33f96ff93a
commit 74f33efdbf
2 changed files with 13 additions and 3 deletions

View File

@@ -186,7 +186,15 @@ class HPUAttentionBackend(AttentionBackend_HPU):
HPUAttentionBackend backend implementation.
"""
def __init__(self, llm_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int):
def __init__(
self,
llm_config: FDConfig,
kv_num_heads: int,
num_heads: int,
head_dim: int,
encoder_block_shape_q: int = -1,
decoder_block_shape_q: int = -1,
):
"""
HPUAttentionBackend __init__
"""
@@ -239,11 +247,13 @@ class HPUAttentionBackend(AttentionBackend_HPU):
def get_kv_cache_shape(
self,
max_num_blocks: int,
kv_cache_quant_type: Optional[str] = None,
):
"""
Caculate kv cache shape
"""
return (max_num_blocks, self.block_size, self.kv_num_heads, self.head_dim)
key_cache_shape = value_cache_shape = [max_num_blocks, self.block_size, self.kv_num_heads, self.head_dim]
return key_cache_shape, value_cache_shape
def forward_extend(
self, src, qkv_proj: QKVParallelLinear, o_proj: RowParallelLinear, layer: Attention, forward_meta

View File

@@ -328,7 +328,7 @@ class HPUModelRunner(ModelRunnerBase):
# Sampler
if not self.speculative_decoding:
self.sampler = Sampler()
self.sampler = Sampler(fd_config)
else:
self.sampler = SpeculativeSampler(fd_config)