[OPs] MoE support wfp8afp8(channelwise) and improve per_token_quant_fp8 (#4238)

chen
2025-09-24 16:39:51 +08:00
committed by GitHub
parent 8b0ce8e3ab
commit 7c1fd19f0f
7 changed files with 683 additions and 33 deletions


@@ -85,6 +85,17 @@ def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Ten
)
def per_token_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Per-token cast to float8_e4m3fn, used in wfp8afp8.
    """
    # One scale per token (row): max(|x|) along the last axis, clipped to
    # avoid division by zero, then mapped onto the fp8 e4m3 range [-448, 448].
    x_abs = paddle.abs(x).astype(paddle.float32)
    x_max = x_abs.max(axis=-1, keepdim=True).clip_(min=1e-4)
    x_s = x_max / 448.0
    x_q = paddle.clip(x / x_s, -448.0, 448.0).astype(paddle.float8_e4m3fn)
    return x_q, x_s
# for distributed tensor model parallel
def _set_var_distributed(var: Tensor, split_axis: int):
"""