[GCU] Support gcu platform (#2702)

baseline: e7fa57ebae

Co-authored-by: yongqiangma <xing.wo@163.com>
This commit is contained in:
EnflameGCU
2025-07-08 13:00:52 +08:00
committed by GitHub
parent 26d5d737dd
commit d0f4d6ba3a
33 changed files with 2988 additions and 85 deletions

View File

@@ -55,6 +55,10 @@ class ErnieRotaryEmbedding:
dtype="float32")
emb = paddle.stack([freqs, freqs], axis=-1).reshape(
(bsz, max_seq_len, self.rotary_dim))
elif current_platform.is_gcu():
# shape: [B, S, D]
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
return rot_emb
else:
# shape: [B, S, D/2]
rot_emb = paddle.zeros(
@@ -95,6 +99,10 @@ class QwenRotaryEmbedding:
# shape: [B, S, D/2]
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"),
inv_freq)
if current_platform.is_gcu():
# shape: [B, S, D]
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
return rot_emb
# shape: [B, S, 1, D]
emb = paddle.concat([freqs, freqs], axis=-1).reshape(
(bsz, max_seq_len, 1, self.rotary_dim))