【Fix】fix deepep dispatch (#5036)

* fix dispatch

* fix dispatch

---------

Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com>
Author: yangjianfengo1
Date: 2025-11-17 10:34:01 +08:00
Committed by: GitHub
Parent: 3b80a799ab
Commit: 3afb717995
3 changed files with 21 additions and 2 deletions


@@ -228,3 +228,22 @@ template void MoeFastHardamardWrapper<phi::dtype::bfloat16, int8_t>(
 const int hadamard_block_size,
 int8_t *out,
 cudaStream_t &stream);
+template void
+MoeFastHardamardWrapper<phi::dtype::bfloat16, phi::dtype::float8_e4m3fn>(
+    const phi::dtype::bfloat16 *x_data,
+    const int64_t *expert_idx_per_token,
+    const int64_t *recv_expert_count,
+    const phi::dtype::bfloat16 *shift,
+    const phi::dtype::bfloat16 *smooth,
+    const float *quant_scales,
+    const int quant_round_type,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    const int64_t token_num,
+    const int64_t dim,
+    const int num_max_tokens_per_expert,
+    bool used_in_ep_low_latency,
+    const int hadamard_block_size,
+    phi::dtype::float8_e4m3fn *out,
+    cudaStream_t &stream);


@@ -174,7 +174,7 @@ for type in dtype:
 template_head_file.write(
 """ } else { \\
-PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d experts=%d token_padding_size=%d kBlockN=%d groupsize=%d, please add [%d, %d, %d, %d, %d, %d] to the gemm_case array in the custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py file and recompile it\\n", _M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE, _M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE)); \\
+PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d experts=%d token_padding_size=%d kBlockN=%d groupsize=%d, please add [%d, %d, %d, %d, %d] to the gemm_case array in the custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py file and recompile it\\n", _M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE, _M, _K, _EXPERTS, _TokenPaddingSize, _GROUPSIZE)); \\
 } \\
 }"""
 )
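
Note on this hunk: the generated error message previously suggested adding a six-value entry (including `_kBlockN`) to `gemm_case`; the fix drops `_kBlockN` from the suggested entry, so both the bracketed placeholder list and the trailing arguments shrink to five values, keeping placeholders and arguments in step. A minimal sketch of how the corrected hint renders, using made-up shape values (not taken from the repo):

```python
# Hypothetical shape values, only to show that the corrected format string has
# 11 "%d" placeholders (6 in the prefix, 5 in the suggested gemm_case entry)
# matched by 11 arguments.
_M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE = 256, 7168, 8, 0, 16, 128
hint = (
    "W4aFp8 not supported m=%d k=%d experts=%d token_padding_size=%d kBlockN=%d "
    "groupsize=%d, please add [%d, %d, %d, %d, %d] to the gemm_case array"
    % (_M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE,
       _M, _K, _EXPERTS, _TokenPaddingSize, _GROUPSIZE)
)
print(hint)
```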


@@ -295,7 +295,7 @@ class DeepEPEngine:
 use_fp8=use_fp8,
 async_finish=False,
 return_recv_hook=True,
-num_per_channel=quant_group_size,
+# num_per_channel=quant_group_size,
 )
 return packed_recv_x, recv_expert_count, handle, dispatch_hook
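
Note on this hunk: the fix simply comments out the `num_per_channel` keyword, presumably because the DeepEP dispatch entry point being called does not accept it. As a hedged illustration only (not what the commit does; `dispatch_fn` stands in for whatever DeepEP callable the engine wraps), a defensive alternative would be to pass the keyword only when the target callable actually supports it:

```python
import inspect

def build_dispatch_kwargs(dispatch_fn, use_fp8, quant_group_size):
    # Base keyword arguments taken from the call site in the diff above.
    kwargs = dict(use_fp8=use_fp8, async_finish=False, return_recv_hook=True)
    try:
        params = inspect.signature(dispatch_fn).parameters
    except (TypeError, ValueError):
        # A C++-bound callable may not expose an inspectable signature; fall
        # back to omitting the optional keyword, matching the commit's effect.
        return kwargs
    if "num_per_channel" in params:
        kwargs["num_per_channel"] = quant_group_size
    return kwargs
```

Either way, the effect at this call site is the same as the commit: the keyword is dropped when the installed DeepEP build does not support it.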