mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
@@ -77,16 +77,7 @@ void DisPatchW4AFp8Gemm(
|
||||
max_tokens,
|
||||
stream)
|
||||
} else {
|
||||
GEMM_SWITCH_FP16(
|
||||
M, K, batch_size, token_padding_size, kBlockN, TailN,
|
||||
weight,
|
||||
input,
|
||||
out,
|
||||
weight_scale,
|
||||
input_row_sum,
|
||||
tokens,
|
||||
max_tokens,
|
||||
stream)
|
||||
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,22 +119,7 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
|
||||
input.stream());
|
||||
return {out};
|
||||
} else {
|
||||
paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::FLOAT16, input.place());
|
||||
phi::dtype::float16 *out_data = out.data<phi::dtype::float16>();
|
||||
DisPatchW4AFp8Gemm(
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
||||
tokens.data<int>(),
|
||||
input_row_sum.data<float>(),
|
||||
weight_scale.data<float>(),
|
||||
reinterpret_cast<cutlass::half_t*>(out_data),
|
||||
token_padding_size,
|
||||
max_tokens,
|
||||
batch_size,
|
||||
M,
|
||||
K,
|
||||
input.stream());
|
||||
return {out};
|
||||
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
||||
}
|
||||
} else {
|
||||
if (is_bflot16) {
|
||||
@@ -164,23 +140,7 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
|
||||
input.stream());
|
||||
return {out};
|
||||
} else {
|
||||
paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::FLOAT16, input.place());
|
||||
phi::dtype::float16 * out_data = out.data<phi::dtype::float16>();
|
||||
|
||||
DisPatchW4AFp8Gemm(
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
|
||||
reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
|
||||
tokens.data<int>(),
|
||||
input_row_sum.data<float>(),
|
||||
weight_scale.data<float>(),
|
||||
reinterpret_cast<cutlass::half_t*>(out_data),
|
||||
token_padding_size,
|
||||
max_tokens,
|
||||
batch_size,
|
||||
M,
|
||||
K,
|
||||
input.stream());
|
||||
return {out};
|
||||
PD_THROW("Only supported dtype in ['BFLOAT16'].");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user