mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] refactor of block_attn param 'pos_emb_type' (#5511)
This commit is contained in:
@@ -822,10 +822,8 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
head_dim = self.model_config.head_dim
|
||||
if "paddleocr" in self.model_config.model_type: # neox style = True
|
||||
rope_head_dim = head_dim
|
||||
self.share_inputs["pos_emb_type"] = "NEOX"
|
||||
else: # neox style = False
|
||||
rope_head_dim = head_dim // 2
|
||||
self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM"
|
||||
|
||||
self.share_inputs["rope_emb"] = paddle.full(
|
||||
shape=[
|
||||
@@ -918,8 +916,6 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
# Update bad tokens len
|
||||
max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])
|
||||
|
||||
if self.enable_mm:
|
||||
self.forward_meta.pos_emb_type = self.share_inputs["pos_emb_type"]
|
||||
self.forward_meta.attn_backend = self.attn_backends[0]
|
||||
self.initialize_attention_backend()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user