【bug fix】修复w4a8编译慢 (#3510)

* 修复w4a8编译

* code style

* 修复tma copy
This commit is contained in:
yangjianfengo1
2025-08-21 18:50:14 +08:00
committed by GitHub
parent a5692e8b7d
commit e5aa7087db
3 changed files with 9 additions and 54 deletions

View File

@@ -77,16 +77,7 @@ void DisPatchW4AFp8Gemm(
max_tokens,
stream)
} else {
GEMM_SWITCH_FP16(
M, K, batch_size, token_padding_size, kBlockN, TailN,
weight,
input,
out,
weight_scale,
input_row_sum,
tokens,
max_tokens,
stream)
PD_THROW("Only supported dtype in ['BFLOAT16'].");
}
}
@@ -128,22 +119,7 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
input.stream());
return {out};
} else {
paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::FLOAT16, input.place());
phi::dtype::float16 *out_data = out.data<phi::dtype::float16>();
DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
tokens.data<int>(),
input_row_sum.data<float>(),
weight_scale.data<float>(),
reinterpret_cast<cutlass::half_t*>(out_data),
token_padding_size,
max_tokens,
batch_size,
M,
K,
input.stream());
return {out};
PD_THROW("Only supported dtype in ['BFLOAT16'].");
}
} else {
if (is_bflot16) {
@@ -164,23 +140,7 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
input.stream());
return {out};
} else {
paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::FLOAT16, input.place());
phi::dtype::float16 * out_data = out.data<phi::dtype::float16>();
DisPatchW4AFp8Gemm(
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
tokens.data<int>(),
input_row_sum.data<float>(),
weight_scale.data<float>(),
reinterpret_cast<cutlass::half_t*>(out_data),
token_padding_size,
max_tokens,
batch_size,
M,
K,
input.stream());
return {out};
PD_THROW("Only supported dtype in ['BFLOAT16'].");
}
}
}