mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 20:41:53 +08:00

* support machete weight only gemm * add generate * update * fix * change file location * add sm_version limit * fix * fix * fix ci * fix coverage * fix xpu
32 lines
1.3 KiB
Plaintext
32 lines
1.3 KiB
Plaintext
#pragma once
|
|
|
|
#include "utils/machete_collective_builder.cuh"
|
|
#include "machete_mainloop.cuh"
|
|
|
|
namespace cutlass::gemm::collective {

using namespace cute;

// Tag type. It is passed as the first template argument of
// MacheteCollectiveBuilder to select the specialization defined below.
struct MacheteKernelTag {};

// Partial specialization of MacheteCollectiveBuilder for the combination
// <MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp>.
//
// The trailing cute::enable_if_t argument restricts this specialization to the
// cases where KernelScheduleType is exactly one of:
//   - KernelTmaWarpSpecialized
//   - KernelTmaWarpSpecializedPingpong
//   - KernelTmaWarpSpecializedCooperative
// For any other schedule type the specialization is discarded (SFINAE).
//
// The only member is the alias CollectiveOp, which forwards every remaining
// template parameter, unchanged and in the same order, to
// machete::MacheteCollectiveMma.
//
// NOTE(review): the primary MacheteCollectiveBuilder template and
// machete::MacheteCollectiveMma are presumably declared in
// "utils/machete_collective_builder.cuh" and "machete_mainloop.cuh"
// respectively -- confirm against those headers.
template <class ElementPairA_,
          class GmemLayoutA_,
          int AlignmentA,
          class ElementPairB_,
          class GmemLayoutB_,
          int AlignmentB,
          class ElementAccumulator,
          class TileShape_MNK,
          class ClusterShape_MNK,
          class StageCountType,
          class KernelScheduleType>
struct MacheteCollectiveBuilder<
    MacheteKernelTag,
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementPairA_,
    GmemLayoutA_,
    AlignmentA,
    ElementPairB_,
    GmemLayoutB_,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelScheduleType,
    cute::enable_if_t<(
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
        cute::is_same_v<KernelScheduleType,
                        KernelTmaWarpSpecializedCooperative>)>> {
  // Collective mainloop type produced by this builder.
  using CollectiveOp = machete::MacheteCollectiveMma<ElementPairA_,
                                                     GmemLayoutA_,
                                                     AlignmentA,
                                                     ElementPairB_,
                                                     GmemLayoutB_,
                                                     AlignmentB,
                                                     ElementAccumulator,
                                                     TileShape_MNK,
                                                     ClusterShape_MNK,
                                                     StageCountType,
                                                     KernelScheduleType>;
};

}  // namespace cutlass::gemm::collective
|