This commit is contained in:
co63oc
2025-09-01 17:50:17 +08:00
committed by GitHub
parent 0513a78ecc
commit d6369b4d51
67 changed files with 85 additions and 85 deletions

View File

@@ -219,7 +219,7 @@ class EpilogueVisitorPerRowPerColNf4 {
iterator_C_.clear_mask();
}
// NOTE(wangbojun) Currently, this kernel don't hanve implantention for
// adding elementwise beta, we keep this here for future useage beta_ =
// adding elementwise beta, we keep this here for future usage beta_ =
// (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr :
// params.elementwise.beta); if (beta_ == ElementAccumulator()) {
// iterator_C_.clear_mask();

View File

@@ -176,7 +176,7 @@ struct Nf4DefaultIteratorsTensorOp<cutlass::bfloat16_t,
///
/// Satisfies: ReadableTileIterator
///
template <typename ThreadMap_ ///< Thread map (conept: OutputTileThreadMap)
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
>
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
public:

View File

@@ -64,7 +64,7 @@ template <
typename InstructionShape_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation perfomed by GEMM
/// Operation performed by GEMM
typename Operator,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.

View File

@@ -133,7 +133,7 @@ public:
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM oeprations
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,"");
static constexpr int kNumKIterationsPerWarpBLoad =

View File

@@ -509,7 +509,7 @@ public:
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
// TOOD(wangbojun) lds_converter can be remove for int8 B input
// TODO(wangbojun) lds_converter can be remove for int8 B input
typename TransformBAfterLDS::result_type converted_frag_B =
lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);

View File

@@ -96,7 +96,7 @@ public:
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM oeprations
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);
static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,"");
static constexpr int kNumKIterationsPerWarpBLoad =

View File

@@ -646,7 +646,7 @@ public:
// );
// }
}
// TOOD(wangbojun) lds_converter can be remove for int8 B input
// TODO(wangbojun) lds_converter can be remove for int8 B input
// int4
// typename TransformBAfterLDS::result_type converted_frag_B =
// lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);