add w4afp8 offline script (#3636)

This commit is contained in:
Yuan Xiaolan
2025-08-29 17:56:05 +08:00
committed by GitHub
parent f677c032c0
commit c71ee0831c
12 changed files with 163 additions and 37 deletions

View File

@@ -226,8 +226,8 @@ __global__ void permute_scale_kernel(
}
void W4AFp8GemmScalePermute(const paddle::Tensor& scale) {
const int row = scale.dims()[0];
const int col = scale.dims()[1];
const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1;
const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0];
if (col % 16 != 0) {
PD_THROW("Only supported when col is divisible by 16.");
}