mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[GCU] Support gcu platform (#2702)
baseline: e7fa57ebae
Co-authored-by: yongqiangma <xing.wo@163.com>
This commit is contained in:
@@ -55,6 +55,10 @@ class ErnieRotaryEmbedding:
|
||||
dtype="float32")
|
||||
emb = paddle.stack([freqs, freqs], axis=-1).reshape(
|
||||
(bsz, max_seq_len, self.rotary_dim))
|
||||
elif current_platform.is_gcu():
|
||||
# shape: [B, S, D]
|
||||
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
|
||||
return rot_emb
|
||||
else:
|
||||
# shape: [B, S, D/2]
|
||||
rot_emb = paddle.zeros(
|
||||
@@ -95,6 +99,10 @@ class QwenRotaryEmbedding:
|
||||
# shape: [B, S, D/2]
|
||||
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"),
|
||||
inv_freq)
|
||||
if current_platform.is_gcu():
|
||||
# shape: [B, S, D]
|
||||
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
|
||||
return rot_emb
|
||||
# shape: [B, S, 1, D]
|
||||
emb = paddle.concat([freqs, freqs], axis=-1).reshape(
|
||||
(bsz, max_seq_len, 1, self.rotary_dim))
|
||||
|
Reference in New Issue
Block a user