enable dcu ci (#3402)

This commit is contained in:
lifulll
2025-08-29 10:23:08 +08:00
committed by GitHub
parent 73d60fe64d
commit 72094d4d82
11 changed files with 295 additions and 5 deletions

View File

@@ -101,11 +101,12 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
gate: nn.Layer,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
"""
gate_out = gate(x.cast("float32"))
token_num = x.shape[0]
top_k = layer.top_k
num_local_experts = layer.num_local_experts
@@ -113,7 +114,6 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
scores = paddle.nn.functional.softmax(gate_out, axis=-1)
scores += layer.gate_correction_bias
topk_weights, topk_ids = paddle.topk(scores, k=top_k, axis=-1, sorted=False)

View File

@@ -21,6 +21,8 @@ def native_top_p_sampling(probs: paddle.Tensor, top_p: paddle.Tensor) -> tuple[p
sorted_indices = paddle.argsort(probs, descending=True)
sorted_probs = paddle.sort(probs, descending=True)
cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)
if probs.shape[0] != top_p.shape[0]:
top_p = paddle.slice(top_p, [0], [0], [probs.shape[0]])
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove = paddle.cast(sorted_indices_to_remove, dtype="int64")
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()