[XPU] fix pos_emb_type bug (#4638)

This commit is contained in:
Lucas
2025-10-29 17:14:32 +08:00
committed by GitHub
parent d68345cb7e
commit 8f40dfa9bf

View File

@@ -850,7 +850,7 @@ class XPUModelRunner(ModelRunnerBase):
else: # neox style = False
rope_head_dim = head_dim // 2
if head_dim == self.model_config.head_dim:
if rope_head_dim == self.model_config.head_dim:
self.share_inputs["pos_emb_type"] = "NORMAL"
else:
self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM"