mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 12:31:27 +08:00

* support machete weight only gemm * add generate * update * fix * change file location * add sm_version limit * fix * fix * fix ci * fix coverage * fix xpu
86 lines
2.5 KiB
Python
86 lines
2.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import enum
|
|
from typing import Union
|
|
|
|
from cutlass_library import (
|
|
DataType,
|
|
DataTypeNames,
|
|
DataTypeSize,
|
|
DataTypeTag,
|
|
KernelScheduleTag,
|
|
KernelScheduleType,
|
|
enum_auto,
|
|
)
|
|
|
|
#
|
|
# Extend cutlass library with custom types, and missing values
|
|
#
|
|
|
|
|
|
class MACHETEDataType(enum.Enum):
|
|
u4b8 = enum_auto()
|
|
u8b128 = enum_auto()
|
|
|
|
|
|
class MixedInputKernelScheduleType(enum.Enum):
|
|
TmaWarpSpecialized = enum_auto()
|
|
TmaWarpSpecializedPingpong = enum_auto()
|
|
TmaWarpSpecializedCooperative = enum_auto()
|
|
|
|
|
|
MACHETEDataTypeNames: dict[Union[MACHETEDataType, DataType], str] = {
|
|
**DataTypeNames, # type: ignore
|
|
**{
|
|
MACHETEDataType.u4b8: "u4b8",
|
|
MACHETEDataType.u8b128: "u8b128",
|
|
},
|
|
}
|
|
|
|
MACHETEDataTypeTag: dict[Union[MACHETEDataType, DataType], str] = {
|
|
**DataTypeTag, # type: ignore
|
|
**{
|
|
MACHETEDataType.u4b8: "cutlass::machete_uint4b8_t",
|
|
MACHETEDataType.u8b128: "cutlass::machete_uint8b128_t",
|
|
},
|
|
}
|
|
|
|
MACHETEDataTypeSize: dict[Union[MACHETEDataType, DataType], int] = {
|
|
**DataTypeSize, # type: ignore
|
|
**{
|
|
MACHETEDataType.u4b8: 4,
|
|
MACHETEDataType.u8b128: 8,
|
|
},
|
|
}
|
|
|
|
MACHETEDataTypeMACHETEScalarTypeTag: dict[Union[MACHETEDataType, DataType], str] = {
|
|
MACHETEDataType.u4b8: "machete::kU4B8",
|
|
MACHETEDataType.u8b128: "machete::kU8B128",
|
|
DataType.u4: "machete::kU4",
|
|
DataType.u8: "machete::kU8",
|
|
DataType.s4: "machete::kS4",
|
|
DataType.s8: "machete::kS8",
|
|
DataType.f16: "machete::kFloat16",
|
|
DataType.bf16: "machete::kBfloat16",
|
|
}
|
|
|
|
MACHETEDataTypePaddleDataTypeTag: dict[Union[MACHETEDataType, DataType], str] = {
|
|
DataType.u8: "paddle::DataType::UINT8",
|
|
DataType.s8: "paddle::DataType::INT8",
|
|
DataType.e4m3: "paddle::DataType::FLOAT8_E4M3FN",
|
|
DataType.s32: "paddle::DataType::INT32",
|
|
DataType.f16: "paddle::DataType::FLOAT16",
|
|
DataType.bf16: "paddle::DataType::BFLOAT16",
|
|
DataType.f32: "paddle::DataType::FLOAT32",
|
|
}
|
|
|
|
MACHETEKernelScheduleTag: dict[Union[MixedInputKernelScheduleType, KernelScheduleType], str] = {
|
|
**KernelScheduleTag, # type: ignore
|
|
**{
|
|
MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized",
|
|
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
|
|
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
|
|
},
|
|
}
|