[XPU] refactor of block_attn param 'pos_emb_type' (#5511)

This commit is contained in:
Lucas
2025-12-12 14:30:09 +08:00
committed by GitHub
parent 4eb55332f6
commit 888c4b992d
6 changed files with 25 additions and 19 deletions

View File

@@ -822,10 +822,8 @@ class XPUModelRunner(ModelRunnerBase):
head_dim = self.model_config.head_dim
if "paddleocr" in self.model_config.model_type: # neox style = True
rope_head_dim = head_dim
self.share_inputs["pos_emb_type"] = "NEOX"
else: # neox style = False
rope_head_dim = head_dim // 2
self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM"
self.share_inputs["rope_emb"] = paddle.full(
shape=[
@@ -918,8 +916,6 @@ class XPUModelRunner(ModelRunnerBase):
# Update bad tokens len
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
if self.enable_mm:
self.forward_meta.pos_emb_type = self.share_inputs["pos_emb_type"]
self.forward_meta.attn_backend = self.attn_backends[0]
self.initialize_attention_backend()