[BugFix]Fix wfp8afp8 triton moe group_topk renormalized=True (#4449)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* fix group_topk renormalized=True

* check test
This commit is contained in:
chen
2025-10-16 23:17:48 +08:00
committed by GitHub
parent dbca63f862
commit db82e9a022
3 changed files with 4 additions and 5 deletions

View File

@@ -32,6 +32,7 @@ from fastdeploy.model_executor.layers.quantization.quant_base import (
QuantConfigBase,
QuantMethodBase,
)
from fastdeploy.model_executor.layers.utils import per_token_cast_to_fp8
from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
@@ -143,10 +144,7 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
return
weight_tensor = layer.weight.transpose([1, 0]).contiguous()
assert self.quant_config.weight_block_size == [-1, 1]
qweight, weight_scale = scaled_fp8_quant(
weight_tensor,
use_per_token_if_dynamic=True,
)
qweight, weight_scale = per_token_cast_to_fp8(weight_tensor)
if hasattr(layer.weight, "tensor_track"):
layer.weight.tensor_track = None