[Feature] refactor metax_gpu attention and moe and remove some useless code (#3688)

Co-authored-by: yongqiangma <xing.wo@163.com>
This commit is contained in:
SuperNova
2025-09-12 14:40:25 +08:00
committed by GitHub
parent cab7a633fe
commit 805f29a06c
5 changed files with 389 additions and 289 deletions

View File

@@ -52,9 +52,9 @@ class ErnieRotaryEmbedding:
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
return rot_emb
elif paddle.is_compiled_with_custom_device("metax_gpu"):
-# shape: [B, S, D]
-rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
-emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
+# shape: [B, S, D/2]
+rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
+emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
else:
# shape: [B, S, D/2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")