[Optimize] support machete weight-only gemm (#3561)

* support machete weight only gemm

* add generate

* update

* fix

* change file location

* add sm_version limit

* fix

* fix

* fix ci

* fix coverage

* fix xpu
Author: Sunny-bot1
Date: 2025-08-28 09:49:58 +08:00
Committed by: GitHub
Parent: e37e86b3b8
Commit: 479c8b85d3
29 changed files with 5436 additions and 0 deletions

.gitignore

@@ -159,6 +159,9 @@ custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cute
#marlin_kernel
custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_*.cu
#machete_kernel
custom_ops/gpu_ops/machete/generated
# buff
custom_ops/tmp*


@@ -129,6 +129,24 @@ paddle::Tensor FusedExpertMoeFunc(
const std::string &quant_method, const int moe_topk,
const bool norm_topk_prob, const bool group_moe);
std::vector<paddle::Tensor> MacheteMMKernel(
paddle::Tensor const& A, paddle::Tensor const& B,
paddle::optional<paddle::Tensor> const& maybe_group_scales,
paddle::optional<paddle::Tensor> const& maybe_group_zeros,
paddle::optional<paddle::Tensor> const& maybe_channel_scales,
paddle::optional<paddle::Tensor> const& maybe_token_scales,
std::string const& b_type_str,
std::string const& maybe_out_type_str,
int64_t const& maybe_group_size,
std::string const& maybe_schedule);
std::vector<paddle::Tensor> MachetePrepackBKernel(
paddle::Tensor const& B, std::string const& a_type_str, std::string const& b_type_str,
std::string const& maybe_group_scales_type_str);
std::vector<std::string> MacheteSupportedSchedules(
std::string const& a_type_str, std::string const& b_type_str);
std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::Tensor &input, const paddle::Tensor &gating_output,
const paddle::optional<paddle::Tensor> &gating_correction_bias,
@@ -924,6 +942,25 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("recv_expert_count"), py::arg("block_size"),
"per token per block quant");
/**
* machete/machete_mm.cu
* machete_mm
*/
m.def("machete_mm", &MacheteMMKernel, py::arg("A"), py::arg("B"), py::arg("maybe_group_scale"),
py::arg("maybe_group_zeros"), py::arg("maybe_channel_scales"), py::arg("maybe_token_scales"),
py::arg("b_type_str"), py::arg("maybe_out_type_str"), py::arg("maybe_group_size"),
py::arg("maybe_schedule"),
"machete mm function");
/**
* machete/machete_prepack_B.cu
* machete_prepack_B
*/
m.def("machete_prepack_B", &MachetePrepackBKernel, "machete prepacked B function");
/**
* machete/machete_supported_schedules.cu
* machete_supported_schedules
*/
m.def("machete_supported_schedules", &MacheteSupportedSchedules, "machete supported schedules function");
/**
* moe/fused_moe/moe_topk_select.cu
* moe_topk_select

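For orientation, here is a minimal sketch of how the three bindings registered above might be driven from Python. The ops handle, the dtypes, and the packed-weight format are assumptions of this sketch, not something the diff defines; only the argument order mirrors the declarations above.

def machete_weight_only_gemm(ops, a, b_packed, group_scales, group_size=128):
    # `ops` is assumed to be the compiled fastdeploy_ops module; `b_packed` is
    # assumed to hold uint4b8 weights in the storage format the prepack kernel
    # expects. Both are hypothetical inputs used purely for illustration.
    # One-time repack of the quantized weight into Machete's tile layout.
    b_prepacked = ops.machete_prepack_B(b_packed, "bfloat16", "uint4b8", "bfloat16")[0]
    # Schedules supported for this (a_type, b_type) pair (not used further here).
    schedules = ops.machete_supported_schedules("bfloat16", "uint4b8")
    # Weight-only GEMM; an empty schedule string lets the built-in heuristic choose.
    return ops.machete_mm(
        a, b_prepacked,
        group_scales, None, None, None,         # group scales / zeros / channel / token scales
        "uint4b8", "bfloat16", group_size, "",  # b_type, out_type, group_size, schedule
    )[0]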

@@ -6,6 +6,8 @@
// clang-format off
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
#include "helper.h"
// clang-format on
/*


@@ -0,0 +1,574 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import math
import os
import shutil
import sys
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import dataclass, fields
from functools import reduce
from typing import Optional, Union
import jinja2
cur_dir = os.path.dirname(os.path.abspath(__file__))
p = os.path.abspath(os.path.join(cur_dir, "../../third_party/cutlass/python"))
sys.path.insert(0, p)
from cutlass_library import (
EpilogueScheduleTag,
EpilogueScheduleType,
TileSchedulerTag,
TileSchedulerType,
)
# yapf conflicts with isort for this block
# yapf: disable
from machete_cutlass_library_extension import (
DataType,
MACHETEDataType,
MACHETEDataTypeMACHETEScalarTypeTag,
MACHETEDataTypeNames,
MACHETEDataTypePaddleDataTypeTag,
MACHETEDataTypeSize,
MACHETEDataTypeTag,
MACHETEKernelScheduleTag,
MixedInputKernelScheduleType,
)
# yapf: enable
#
# Generator templating
#
DISPATCH_TEMPLATE = """
#include "../machete_mm_launcher.cuh"
namespace machete {
{% for impl_config in impl_configs %}
{% set type_sig = gen_type_sig(impl_config.types) -%}
{% for s in impl_config.schedules %}
extern paddle::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs);
{%- endfor %}
paddle::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
[[maybe_unused]] auto M = args.A.shape()[0];
[[maybe_unused]] auto N = args.B.shape()[1];
[[maybe_unused]] auto K = args.A.shape()[1];
if (!args.maybe_schedule) {
{%- for cond, s in impl_config.heuristic %}
{%if cond is not none%}if ({{cond}})
{%- else %}else
{%- endif %}
return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %}
}
{%- for s in impl_config.schedules %}
if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}")
return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);
{%- endfor %}
PADDLE_ENFORCE(false, "machete_gemm(..) is not implemented ");
}
{%- endfor %}
static inline std::optional<paddle::DataType> maybe_scalartype(
std::optional<paddle::Tensor> const& t) {
if (!t) {
return std::nullopt;
} else {
return t->dtype();
};
}
paddle::Tensor mm_dispatch(MMArgs args) {
auto out_type = args.maybe_out_type.value_or(args.A.dtype());
auto a_type = args.A.dtype();
auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales);
auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros);
auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales);
auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales);
{% for impl_config in impl_configs %}
{% set t = impl_config.types -%}
{% set type_sig = gen_type_sig(t) -%}
if (args.b_type == {{MACHETEScalarTypeTag[t.b]}}
&& a_type == {{PaddleTypeTag[t.a]}}
&& out_type == {{PaddleTypeTag[t.out]}}
&& {%if t.b_group_scale != void -%}
maybe_g_scales_type == {{PaddleTypeTag[t.b_group_scale]}}
{%- else %}!maybe_g_scales_type{%endif%}
&& {%if t.b_group_zeropoint != void -%}
maybe_g_zeros_type == {{PaddleTypeTag[t.b_group_zeropoint]}}
{%- else %}!maybe_g_zeros_type{%endif%}
&& {%if t.b_channel_scale != void -%}
maybe_ch_scales_type == {{PaddleTypeTag[t.b_channel_scale]}}
{%- else %}!maybe_ch_scales_type{%endif%}
&& {%if t.a_token_scale != void -%}
maybe_tok_scales_type == {{PaddleTypeTag[t.a_token_scale]}}
{%- else %}!maybe_tok_scales_type{%endif%}
) {
return mm_dispatch_{{type_sig}}(args);
}
{%- endfor %}
PADDLE_ENFORCE(
false, "machete_mm(..) is not implemented "
"; implemented types are: \\n",
{%- for impl_config in impl_configs %}
{% set t = impl_config.types -%}
"\\t{{gen_type_option_name(t)}}\\n",
{%- endfor %}
"");
}
std::vector<std::string> supported_schedules_dispatch(
SupportedSchedulesArgs args) {
auto out_type = args.maybe_out_type.value_or(args.a_type);
{% for impl_config in impl_configs %}
{% set t = impl_config.types -%}
{% set schs = impl_config.schedules -%}
if (args.b_type == {{MACHETEScalarTypeTag[t.b]}}
&& args.a_type == {{PaddleTypeTag[t.a]}}
&& out_type == {{PaddleTypeTag[t.out]}}
&& {%if t.b_group_scale != void -%}
args.maybe_group_scales_type == {{PaddleTypeTag[t.b_group_scale]}}
{%- else %}!args.maybe_group_scales_type{%endif%}
&& {%if t.b_group_zeropoint != void-%}
args.maybe_group_zeros_type == {{PaddleTypeTag[t.b_group_zeropoint]}}
{%- else %}!args.maybe_group_zeros_type{%endif%}
) {
return {
{%- for s in impl_config.schedules %}
"{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %}
{%- endfor %}
};
}
{%- endfor %}
return {};
};
}; // namespace machete
"""
IMPL_TEMPLATE = """
#include "../machete_mm_launcher.cuh"
namespace machete {
{% for sch in unique_schedules(impl_configs) %}
{% set sch_sig = gen_sch_sig(sch) -%}
struct sch_{{sch_sig}} {
using TileShapeNM = Shape<{{
to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
using ClusterShape = Shape<{{
to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
// TODO: Reimplement
// using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
};
{% endfor %}
{% for impl_config in impl_configs %}
{% set t = impl_config.types -%}
{% set schs = impl_config.schedules -%}
{% set type_sig = gen_type_sig(t) -%}
template<typename Sch>
using Kernel_{{type_sig}} = MacheteKernelTemplate<
{{DataTypeTag[t.a]}}, // ElementA
{{DataTypeTag[t.b]}}, // ElementB
{{DataTypeTag[t.out]}}, // ElementD
{{DataTypeTag[t.accumulator]}}, // Accumulator
{{DataTypeTag[t.b_group_scale]}}, // GroupScaleT
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
Sch>;
{% for sch in schs %}
{% set sch_sig = gen_sch_sig(sch) -%}
paddle::Tensor
impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) {
return run_impl<Kernel_{{type_sig}}<sch_{{sch_sig}}>>(args);
}
{%- endfor %}
{%- endfor %}
}; // namespace machete
"""
PREPACK_TEMPLATE = """
#include "../machete_prepack_launcher.cuh"
namespace machete {
paddle::Tensor prepack_B_dispatch(PrepackBArgs args) {
auto convert_type = args.maybe_group_scales_type.value_or(args.a_type);
{%- for t in types %}
{% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %}
if (args.a_type == {{PaddleTypeTag[t.a]}}
&& args.b_type.size_bits() == {{t.b_num_bits}}
&& convert_type == {{PaddleTypeTag[t.convert]}}) {
return prepack_impl<
PrepackedLayoutBTemplate<
{{DataTypeTag[t.a]}}, // ElementA
{{DataTypeTag[b_type]}}, // ElementB
{{DataTypeTag[t.convert]}}, // ElementConvert
{{DataTypeTag[t.accumulator]}}, // Accumulator
cutlass::layout::ColumnMajor,
cutlass::gemm::KernelTmaWarpSpecializedCooperative>
>(args.B);
}
{%- endfor %}
PADDLE_ENFORCE(false,
"prepack_B_dispatch(..) is not implemented");
}
}; // namespace machete
"""
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
@dataclass(frozen=True)
class ScheduleConfig:
tile_shape_mn: tuple[int, int]
cluster_shape_mnk: tuple[int, int, int]
kernel_schedule: MixedInputKernelScheduleType
epilogue_schedule: EpilogueScheduleType
tile_scheduler: TileSchedulerType
@dataclass(frozen=True)
class TypeConfig:
a: DataType
b: Union[DataType, MACHETEDataType]
b_group_scale: DataType
b_group_zeropoint: DataType
b_channel_scale: DataType
a_token_scale: DataType
out: DataType
accumulator: DataType
@dataclass(frozen=True)
class PrepackTypeConfig:
a: DataType
b_num_bits: int
convert: DataType
accumulator: DataType
@dataclass
class ImplConfig:
types: TypeConfig
schedules: list[ScheduleConfig]
heuristic: list[tuple[Optional[str], ScheduleConfig]]
def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
tile_shape = f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
cluster_shape = (
f"{schedule_config.cluster_shape_mnk[0]}"
+ f"x{schedule_config.cluster_shape_mnk[1]}"
+ f"x{schedule_config.cluster_shape_mnk[2]}"
)
kernel_schedule = MACHETEKernelScheduleTag[schedule_config.kernel_schedule].split("::")[-1]
epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split("::")[-1]
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1]
return f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + f"_{epilogue_schedule}_{tile_scheduler}"
# mostly unique shorter sch_sig
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
kernel_terse_names_replace = {
"KernelTmaWarpSpecializedCooperative": "TmaMI_",
"TmaWarpSpecializedCooperative_": "TmaCoop_",
"StreamKScheduler": "streamK",
}
sch_sig = generate_sch_sig(schedule_config)
for orig, terse in kernel_terse_names_replace.items():
sch_sig = sch_sig.replace(orig, terse)
return sch_sig
# unique type_name
def generate_type_signature(kernel_types: TypeConfig):
return str("".join([MACHETEDataTypeNames[getattr(kernel_types, field.name)] for field in fields(TypeConfig)]))
def generate_type_option_name(kernel_types: TypeConfig):
return ", ".join(
[
f"{field.name.replace('b_', 'with_')+'_type'}=" + MACHETEDataTypeNames[getattr(kernel_types, field.name)]
for field in fields(TypeConfig)
]
)
def is_power_of_two(n):
return (n != 0) and (n & (n - 1) == 0)
def to_cute_constant(value: list[int]):
def _to_cute_constant(value: int):
if is_power_of_two(value):
return f"_{value}"
else:
return f"Int<{value}>"
if isinstance(value, Iterable):
return [_to_cute_constant(value) for value in value]
else:
return _to_cute_constant(value)
def unique_schedules(impl_configs: list[ImplConfig]):
return list(set(sch for impl_config in impl_configs for sch in impl_config.schedules))
def unsigned_type_with_bitwidth(num_bits):
return {
4: DataType.u4,
8: DataType.u8,
16: DataType.u16,
32: DataType.u32,
64: DataType.u64,
}[num_bits]
template_globals = {
"void": DataType.void,
"DataTypeTag": MACHETEDataTypeTag,
"MACHETEScalarTypeTag": MACHETEDataTypeMACHETEScalarTypeTag,
"PaddleTypeTag": MACHETEDataTypePaddleDataTypeTag,
"KernelScheduleTag": MACHETEKernelScheduleTag,
"EpilogueScheduleTag": EpilogueScheduleTag,
"TileSchedulerTag": TileSchedulerTag,
"to_cute_constant": to_cute_constant,
"gen_sch_sig": generate_terse_sch_sig,
"gen_type_sig": generate_type_signature,
"unique_schedules": unique_schedules,
"unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
"gen_type_option_name": generate_type_option_name,
}
def create_template(template_str):
template = jinja2.Template(template_str)
template.globals.update(template_globals)
return template
mm_dispatch_template = create_template(DISPATCH_TEMPLATE)
mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
sources = []
sources.append(
(
"machete_mm_dispatch",
mm_dispatch_template.render(impl_configs=impl_configs),
)
)
prepack_types = []
for impl_config in impl_configs:
convert_type = (
impl_config.types.a
if impl_config.types.b_group_scale == DataType.void
else impl_config.types.b_group_scale
)
prepack_types.append(
PrepackTypeConfig(
a=impl_config.types.a,
b_num_bits=MACHETEDataTypeSize[impl_config.types.b],
convert=convert_type,
accumulator=impl_config.types.accumulator,
)
)
def prepacked_type_key(prepack_type: PrepackTypeConfig):
# For now we can just use the first accumulator type seen, since
# the tensor core shapes/layouts don't vary based on accumulator
# type, so we can generate less code this way
return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert)
unique_prepack_types = []
prepack_types_seen = set()
for prepack_type in prepack_types:
key = prepacked_type_key(prepack_type)
if key not in prepack_types_seen:
unique_prepack_types.append(prepack_type)
prepack_types_seen.add(key)
sources.append(
(
"machete_prepack",
prepack_dispatch_template.render(
types=unique_prepack_types,
),
)
)
# Split up impls across files
num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
num_impls_per_file = math.ceil(num_impls / num_impl_files)
files_impls: list[list[ImplConfig]] = [[]]
curr_num_impls_assigned = 0
curr_impl_in_file = 0
curr_impl_configs = deepcopy(list(reversed(impl_configs)))
while curr_num_impls_assigned < num_impls:
room_left_in_file = num_impls_per_file - curr_impl_in_file
if room_left_in_file == 0:
files_impls.append([])
room_left_in_file = num_impls_per_file
curr_impl_in_file = 0
curr_ic = curr_impl_configs[-1]
if len(curr_ic.schedules) >= room_left_in_file:
# Break apart the current impl config
tmp_ic = deepcopy(curr_ic)
tmp_ic.schedules = curr_ic.schedules[:room_left_in_file]
curr_ic.schedules = curr_ic.schedules[room_left_in_file:]
files_impls[-1].append(tmp_ic)
else:
files_impls[-1].append(curr_ic)
curr_impl_configs.pop()
curr_num_impls_assigned += len(files_impls[-1][-1].schedules)
curr_impl_in_file += len(files_impls[-1][-1].schedules)
for part, file_impls in enumerate(files_impls):
sources.append(
(
f"machete_mm_impl_part{part+1}",
mm_impl_template.render(impl_configs=file_impls),
)
)
return sources
def generate():
# See csrc/quantization/machete/Readme.md, the Codegeneration section, for more
# info about how this works
SCRIPT_DIR = os.path.dirname(__file__)
sch_common_params = dict(
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK,
)
# Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
default_tile_heuristic_config = {
# M = 257+
"M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
"M > 256": ((128, 256), (2, 1, 1)),
# M = 129-256
"M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
"M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
"M > 128": ((128, 256), (2, 1, 1)),
# M = 65-128
"M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
"M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
"M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
"M > 64": ((128, 128), (2, 1, 1)),
# M = 33-64
"M > 40 && K <= 6144 && N <= 6144": ((128, 32), (2, 1, 1)),
"M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
"M > 32": ((128, 64), (2, 1, 1)),
# M = 17-32
"M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
"M > 16": ((256, 32), (2, 1, 1)),
# M = 1-16
"N >= 26624": ((256, 16), (1, 1, 1)),
None: ((128, 16), (1, 1, 1)),
}
# For now we use the same heuristic for all types
# Heuristic is currently tuned for H100s
default_heuristic = [
(cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore
for cond, tile_config in default_tile_heuristic_config.items()
]
def get_unique_schedules(heuristic: dict[str, ScheduleConfig]):
# Do not use schedules = list(set(...)) because we need to make sure
# the output list is deterministic; otherwise the generated kernel file
# will be non-deterministic and cause ccache misses.
schedules = []
for _, schedule_config in heuristic:
if schedule_config not in schedules:
schedules.append(schedule_config)
return schedules
impl_configs = []
GPTQ_kernel_type_configs = list(
TypeConfig(
a=a,
b=b,
b_group_scale=a,
b_group_zeropoint=DataType.void,
b_channel_scale=DataType.void,
a_token_scale=DataType.void,
out=a,
accumulator=DataType.f32,
)
for b in (MACHETEDataType.u4b8, MACHETEDataType.u8b128)
for a in (DataType.f16, DataType.bf16)
)
impl_configs += [
ImplConfig(x[0], x[1], x[2])
for x in zip(
GPTQ_kernel_type_configs,
itertools.repeat(get_unique_schedules(default_heuristic)),
itertools.repeat(default_heuristic),
)
]
output_dir = os.path.join(SCRIPT_DIR, "generated")
# Delete the "generated" directory if it exists
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
# Create the "generated" directory
os.makedirs(output_dir)
# Render each group of configurations into separate files
for filename, code in create_sources(impl_configs):
filepath = os.path.join(output_dir, f"{filename}.cu")
with open(filepath, "w") as output_file:
output_file.write(code)
print(f"Rendered template to {filepath}")
if __name__ == "__main__":
generate()

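To make the heuristic table above easier to scan, here is a small stand-alone sketch (not part of the generated code) that walks the same condition strings top to bottom, the way the Jinja dispatch template emits them. It assumes you hand it the default_tile_heuristic_config dict, which in the script above is local to generate().

def pick_tile_config(M, N, K, heuristic):
    # Returns the (tile_shape_mn, cluster_shape_mnk) bucket the generated
    # if/else chain would select for this problem size.
    for cond, tile_config in heuristic.items():
        if cond is None:
            return tile_config  # trailing `else` branch in the generated C++
        # The condition strings use C++ syntax; `&&` maps directly to `and`.
        if eval(cond.replace("&&", "and"), {}, {"M": M, "N": N, "K": K}):
            return tile_config
    raise AssertionError("heuristic must end with a catch-all None entry")

# e.g. a small-batch decode shape: M=8, N=8192, K=4096 -> ((128, 16), (1, 1, 1))
# print(pick_tile_config(8, 8192, 4096, default_tile_heuristic_config))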

@@ -0,0 +1,31 @@
#pragma once
#include "utils/machete_collective_builder.cuh"
#include "machete_mainloop.cuh"
namespace cutlass::gemm::collective {
using namespace cute;
struct MacheteKernelTag {};
template <class ElementPairA_, class GmemLayoutA_, int AlignmentA,
class ElementPairB_, class GmemLayoutB_, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType>
struct MacheteCollectiveBuilder<
MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_,
GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB,
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
KernelScheduleType,
cute::enable_if_t<(
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperative>)>> {
using CollectiveOp = machete::MacheteCollectiveMma<
ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
StageCountType, KernelScheduleType>;
};
}; // namespace cutlass::gemm::collective


@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from typing import Union
from cutlass_library import (
DataType,
DataTypeNames,
DataTypeSize,
DataTypeTag,
KernelScheduleTag,
KernelScheduleType,
enum_auto,
)
#
# Extend cutlass library with custom types, and missing values
#
class MACHETEDataType(enum.Enum):
u4b8 = enum_auto()
u8b128 = enum_auto()
class MixedInputKernelScheduleType(enum.Enum):
TmaWarpSpecialized = enum_auto()
TmaWarpSpecializedPingpong = enum_auto()
TmaWarpSpecializedCooperative = enum_auto()
MACHETEDataTypeNames: dict[Union[MACHETEDataType, DataType], str] = {
**DataTypeNames, # type: ignore
**{
MACHETEDataType.u4b8: "u4b8",
MACHETEDataType.u8b128: "u8b128",
},
}
MACHETEDataTypeTag: dict[Union[MACHETEDataType, DataType], str] = {
**DataTypeTag, # type: ignore
**{
MACHETEDataType.u4b8: "cutlass::machete_uint4b8_t",
MACHETEDataType.u8b128: "cutlass::machete_uint8b128_t",
},
}
MACHETEDataTypeSize: dict[Union[MACHETEDataType, DataType], int] = {
**DataTypeSize, # type: ignore
**{
MACHETEDataType.u4b8: 4,
MACHETEDataType.u8b128: 8,
},
}
MACHETEDataTypeMACHETEScalarTypeTag: dict[Union[MACHETEDataType, DataType], str] = {
MACHETEDataType.u4b8: "machete::kU4B8",
MACHETEDataType.u8b128: "machete::kU8B128",
DataType.u4: "machete::kU4",
DataType.u8: "machete::kU8",
DataType.s4: "machete::kS4",
DataType.s8: "machete::kS8",
DataType.f16: "machete::kFloat16",
DataType.bf16: "machete::kBfloat16",
}
MACHETEDataTypePaddleDataTypeTag: dict[Union[MACHETEDataType, DataType], str] = {
DataType.u8: "paddle::DataType::UINT8",
DataType.s8: "paddle::DataType::INT8",
DataType.e4m3: "paddle::DataType::FLOAT8_E4M3FN",
DataType.s32: "paddle::DataType::INT32",
DataType.f16: "paddle::DataType::FLOAT16",
DataType.bf16: "paddle::DataType::BFLOAT16",
DataType.f32: "paddle::DataType::FLOAT32",
}
MACHETEKernelScheduleTag: dict[Union[MixedInputKernelScheduleType, KernelScheduleType], str] = {
**KernelScheduleTag, # type: ignore
**{
MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized",
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
},
}


@@ -0,0 +1,35 @@
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
namespace machete {
using namespace cute;
// get an interleaved block layout where each consecutive element has a
// stride of bit_stride and the block width is blk_bit_width,
// examples:
// size_bits<T> = 8, bit_stride = 8, blk_bit_width = 32 -> 4:1
// size_bits<T> = 8, bit_stride = 16, blk_bit_width = 32 -> (2, 2):(2, 1)
// size_bits<T> = 4, bit_stride = 8, blk_bit_width = 32 -> (4, 2):(2, 1)
// size_bits<T> = 4, bit_stride = 16, blk_bit_width = 32 -> (2, 4):(4, 1)
template <typename T, int bit_stride, int blk_bit_width>
CUTE_HOST_DEVICE static constexpr auto get_interleaved_blk_layout() {
static_assert(blk_bit_width % bit_stride == 0);
static_assert(bit_stride % cute::sizeof_bits_v<T> == 0);
constexpr auto elems_per_blk = blk_bit_width / cute::sizeof_bits_v<T>;
if constexpr (cute::sizeof_bits_v<T> == bit_stride) {
// identity layout
return Layout<Shape<Int<elems_per_blk>>>{};
} else {
constexpr auto elems_per_stride = bit_stride / cute::sizeof_bits_v<T>;
constexpr auto num_strides = elems_per_blk / elems_per_stride;
return Layout<Shape<Int<num_strides>, Int<elems_per_stride>>,
Stride<Int<elems_per_stride>, Int<1>>>{};
}
}
}; // namespace machete

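As a quick cross-check of the layout examples in the comment above, a small Python sketch (illustrative, not code from this header) that enumerates the element order those interleaved block layouts describe:

def interleaved_blk_order(size_bits, bit_stride, blk_bit_width=32):
    # Position i in the interleaved block -> index of the source element,
    # reading the CuTe layout (num_strides, elems_per_stride):(elems_per_stride, 1)
    # in column-major coordinate order.
    elems_per_blk = blk_bit_width // size_bits
    if size_bits == bit_stride:
        return list(range(elems_per_blk))  # identity layout
    elems_per_stride = bit_stride // size_bits
    num_strides = elems_per_blk // elems_per_stride
    return [i * elems_per_stride + j
            for j in range(elems_per_stride)
            for i in range(num_strides)]

print(interleaved_blk_order(4, 8))   # [0, 2, 4, 6, 1, 3, 5, 7]  -> layout (4, 2):(2, 1)
print(interleaved_blk_order(4, 16))  # [0, 4, 1, 5, 2, 6, 3, 7]  -> layout (2, 4):(4, 1)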
File diff suppressed because it is too large.


@@ -0,0 +1,84 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"
template <typename T>
std::optional<T> ConvertToStdOptional(const paddle::optional<T>& paddle_opt) {
return paddle_opt ? std::optional<T>(paddle_opt.get()) : std::nullopt;
}
paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
int64_t b_type_id,
std::optional<paddle::DataType> const& maybe_out_type,
std::optional<paddle::Tensor> const& maybe_group_scales,
std::optional<paddle::Tensor> const& maybe_group_zeros,
int64_t maybe_group_size,
std::optional<paddle::Tensor> const& maybe_channel_scales,
std::optional<paddle::Tensor> const& maybe_token_scales,
std::string maybe_schedule) {
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
std::optional<int64_t> maybe_group_size_opt;
std::optional<std::string> maybe_schedule_opt;
if (maybe_group_size > 0) {
maybe_group_size_opt = maybe_group_size;
}
if (!maybe_schedule.empty()) {
maybe_schedule_opt = maybe_schedule;
}
return machete::mm_dispatch({.A = A,
.B = B,
.b_type = b_type,
.maybe_out_type = maybe_out_type,
.maybe_group_scales = maybe_group_scales,
.maybe_group_zeros = maybe_group_zeros,
.maybe_group_size = maybe_group_size_opt,
.maybe_channel_scales = maybe_channel_scales,
.maybe_token_scales = maybe_token_scales,
.maybe_schedule = maybe_schedule_opt});
}
std::vector<paddle::Tensor> MacheteMMKernel(
paddle::Tensor const& A, paddle::Tensor const& B,
paddle::optional<paddle::Tensor> const& maybe_group_scales,
paddle::optional<paddle::Tensor> const& maybe_group_zeros,
paddle::optional<paddle::Tensor> const& maybe_channel_scales,
paddle::optional<paddle::Tensor> const& maybe_token_scales,
std::string const& b_type_str,
std::string const& maybe_out_type_str,
int64_t const& maybe_group_size,
std::string const& maybe_schedule
) {
machete::ScalarTypeId b_type_id;
paddle::DataType maybe_out_type;
if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}
if (maybe_out_type_str == "float16") {
maybe_out_type = paddle::DataType::FLOAT16;
} else if (maybe_out_type_str == "bfloat16") {
maybe_out_type = paddle::DataType::BFLOAT16;
} else {
maybe_out_type = A.dtype();
}
auto out = mm(A, B, b_type_id, maybe_out_type,
ConvertToStdOptional<paddle::Tensor>(maybe_group_scales),
ConvertToStdOptional<paddle::Tensor>(maybe_group_zeros),
maybe_group_size,
ConvertToStdOptional<paddle::Tensor>(maybe_channel_scales),
ConvertToStdOptional<paddle::Tensor>(maybe_token_scales),
maybe_schedule);
return {out};
}


@@ -0,0 +1,305 @@
#pragma once
// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
#include "utils/cute_utils.cuh"
#include "utils/machete_numeric_conversion.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "utils/paddle_utils.hpp"
#include "machete_collective_builder.cuh"
#include "machete_prepacked_layout.cuh"
#include "machete_interleaving_utils.cuh"
namespace machete {
using namespace cute;
// NOTE This kernel computes D = alpha * A * B + beta * C by computing
// D^t = alpha * B^t * A^t + beta * C^t. This is because the wgmma
// instructions only support sourcing the left-hand operand from registers,
// and we want to upconvert/decompress the quantized operand in registers.
// Since the primary use case we want to support is Y = XW^t where W is
// quantized, our right-hand operand is quantized, so we compute the
// transpose to move it to the left-hand side.
template <typename ElementA_, typename ElementB_, typename ElementD_,
typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
typename ChannelScaleT, typename TokenScaleT, class KernelSchedule,
typename ScheduleConfig>
struct MacheteKernelTemplate {
static constexpr bool with_C = false; // not ever used
static constexpr bool with_group_scales = !std::is_same_v<GroupScaleT, void>;
static constexpr bool with_group_zeropoints =
!std::is_same_v<GroupZeroT, void>;
static constexpr bool with_channel_scales =
!std::is_same_v<ChannelScaleT, void>;
static constexpr bool with_token_scales = !std::is_same_v<TokenScaleT, void>;
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementD = ElementD_;
using ElementC = cute::conditional_t<with_C, ElementD, void>;
using ElementAccumulator = AccumulatorT;
using ElementCompute = AccumulatorT; // For Epilogue
// Use dummy values when we don't have scales or zeropoints
using ElementZGroup =
cute::conditional_t<with_group_zeropoints, GroupZeroT, MmaType>;
using ElementSGroup =
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
using ElementConvertGroup =
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
using ElementSChannel =
cute::conditional_t<with_channel_scales, ChannelScaleT, AccumulatorT>;
using ElementSToken =
cute::conditional_t<with_token_scales, TokenScaleT, AccumulatorT>;
using BTypeTuple = cute::conditional_t<
with_group_scales,
cute::conditional_t<with_group_zeropoints,
cute::tuple<ElementB, ElementSGroup, ElementZGroup>,
cute::tuple<ElementB, ElementSGroup>>,
ElementB>;
using LayoutA = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
using LayoutScale = cutlass::layout::RowMajor;
// not actually used since B has the prepacked layout, but required by cutlass
using _LayoutB = cutlass::layout::ColumnMajor;
// Interface strides expected by create_arguments (will get transposed)
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
using StrideSGroup = cutlass::detail::TagToStrideA_t<LayoutScale>;
using StrideZGroup = StrideSGroup;
using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutC_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using PrepackedLayoutB =
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementConvertGroup,
AccumulatorT, LayoutA_Transpose, KernelSchedule>;
static int constexpr TileShapeK =
128 * 8 / cutlass::sizeof_bits<MmaType>::value;
static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v<ElementA>;
static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v<ElementB>;
static int constexpr AlignmentC =
(with_C) ? 128 / cutlass::sizeof_bits_v<ElementC> : 0;
static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v<ElementD>;
using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{},
cute::Int<TileShapeK>{}));
using ClusterShape = typename ScheduleConfig::ClusterShape;
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
using TileScheduler = typename ScheduleConfig::TileScheduler;
static_assert(
(!with_channel_scales && !with_token_scales) ||
((with_channel_scales && with_token_scales) &&
std::is_same_v<ElementSChannel, ElementSToken>),
"Currently token and channel scales (if present) must be the same type");
// Currently only supports float scales
using ChTokScalesEpilogue =
typename fastdeploy::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
TileShape>;
static_assert((with_channel_scales || with_token_scales) ||
(std::is_same_v<ElementSChannel, float> &&
std::is_same_v<ElementSToken, float>),
"Currently token and channel scales (if present) must be float "
"(and if one is present the other must be too)");
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90AccFetch>;
using EVTCompute =
std::conditional_t<with_channel_scales || with_token_scales,
typename ChTokScalesEpilogue::EVTCompute,
StoreEpilogueCompute>;
// EVTCompute
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose,
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule,
EVTCompute>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::MacheteCollectiveBuilder<
cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose,
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// stride_B is unused (since B is prepacked), but still required by cutlass
using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>;
using Arguments = typename Gemm::Arguments;
using MainloopArguments = typename GemmKernel::MainloopArguments;
using EpilogueArguments = typename GemmKernel::EpilogueArguments;
static Arguments create_arguments(
cudaStream_t stream,
paddle::Tensor const& A, // MxK matrix
paddle::Tensor const& B, // KxN prepacked matrix
paddle::Tensor& D, // MxN matrix
std::optional<paddle::Tensor> const& maybe_g_scales, // scale_KxN matrix
std::optional<paddle::Tensor> const& maybe_g_zeros, // scale_KxN matrix
std::optional<int64_t> maybe_group_size,
std::optional<paddle::Tensor> const& maybe_ch_scales, // len N vector
std::optional<paddle::Tensor> const& maybe_tok_scales) // len M vector
{
static_assert(!with_group_zeropoints || with_group_scales);
int M = A.shape()[0], N = B.shape()[1], K = A.shape()[1];
PD_CHECK(D.shape()[0] == M && D.shape()[1] == N);
auto layout_A = make_cute_layout<StrideA>(A, "A");
auto layout_D = make_cute_layout<StrideD>(D, "D");
auto layout_S_group =
maybe_make_cute_layout<StrideSGroup>(maybe_g_scales, "group_scales");
auto layout_Z_group =
maybe_make_cute_layout<StrideZGroup>(maybe_g_zeros, "group_zeros");
int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0;
int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0;
auto unwrap = [](auto const& t) {
return t ? t->data() : nullptr;
};
auto A_ptr = static_cast<ElementA const*>(A.data());
auto B_ptr = static_cast<ElementB const*>(B.data());
auto D_ptr = static_cast<ElementD*>(D.data());
auto S_group_ptr =
static_cast<ElementSGroup const*>(unwrap(maybe_g_scales));
auto Z_group_ptr = static_cast<ElementZGroup const*>(unwrap(maybe_g_zeros));
auto S_channel_ptr =
static_cast<ElementSChannel const*>(unwrap(maybe_ch_scales));
auto S_token_ptr =
static_cast<ElementSToken const*>(unwrap(maybe_tok_scales));
int const group_size =
maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
int const scale_k = (K + group_size - 1) / group_size;
PD_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
PD_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
if constexpr (with_group_scales) {
PD_CHECK(S_group_ptr && layout_S_group);
PD_CHECK((size<0>(*layout_S_group) == scale_k &&
size<1>(*layout_S_group) == N));
} else {
PD_CHECK(!S_group_ptr, "Scales not supported");
}
if constexpr (with_group_zeropoints) {
PD_CHECK(Z_group_ptr && layout_Z_group);
PD_CHECK((size<0>(*layout_Z_group) == scale_k &&
size<1>(*layout_Z_group) == N));
PD_CHECK(layout_S_group && *layout_Z_group == *layout_S_group,
"Scales and zeros must have the same layout");
} else {
PD_CHECK(!Z_group_ptr, "Zeropoints not supported");
}
if constexpr (with_channel_scales || with_token_scales) {
PD_CHECK(
(maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) &&
(maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1));
}
// Transpose A and D
// A doesn't need to be transposed since cutlass expects a NxK matrix
// for B (which is At)
auto stride_At = layout_A.stride();
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
MainloopArguments mainloop_arguments{};
// {Accum, C, C_layout, D, D_layout}
EpilogueArguments epilogue_arguments{};
if constexpr (with_channel_scales || with_token_scales) {
epilogue_arguments =
EpilogueArguments{ChTokScalesEpilogue::prepare_args(
*maybe_ch_scales, *maybe_tok_scales),
nullptr,
{},
D_ptr,
stride_Dt};
} else {
epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt};
}
if constexpr (with_group_scales && with_group_zeropoints) {
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
mainloop_arguments = MainloopArguments{
B_ptr, _StrideB{}, A_ptr, stride_At,
S_group_ptr, stride_S_group, group_size, Z_group_ptr};
} else if constexpr (with_group_scales) {
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
S_group_ptr, stride_S_group, group_size};
} else {
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
}
return Arguments{cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K, 1},
mainloop_arguments,
epilogue_arguments};
};
static size_t get_workspace_size(Arguments const& args) {
return Gemm::get_workspace_size(args);
}
static bool can_implement(Arguments const& args) {
return Gemm::can_implement(args) == cutlass::Status::kSuccess;
}
static void run(Arguments const& args, void* workspace, cudaStream_t stream) {
Gemm gemm_op;
cutlass::Status status = gemm_op.initialize(args, workspace, stream);
PD_CHECK(status == cutlass::Status::kSuccess,
"Machete kernel failed to initialize workspace");
status = gemm_op.run(stream);
PD_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed");
}
};
}; // namespace machete

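Two quick numerical checks of the NOTE above, illustrative only: the transpose identity the kernel relies on, and the group-scale shape create_arguments() expects (the concrete sizes are assumptions).

import numpy as np

# D = A @ B is computed as D^t = B^t @ A^t so the quantized operand (B) can be
# the register-sourced left-hand wgmma operand; the identity being relied on:
M, K, N, group_size = 16, 4096, 8192, 128
A = np.random.rand(M, K)  # float64 stand-ins; only the shapes matter here
B = np.random.rand(K, N)  # stand-in for the dequantized weights
assert np.allclose((B.T @ A.T).T, A @ B)

# Group scales (and zeros) are expected as (scale_k, N) with
# scale_k = ceil(K / group_size), matching `(K + group_size - 1) / group_size`.
scale_k = (K + group_size - 1) // group_size
print(scale_k, N)  # 32 8192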

@@ -0,0 +1,78 @@
#pragma once
#include <Python.h>
#include "machete_mm_kernel.cuh"
#include "utils/paddle_utils.hpp"
#include "utils/scalar_type.h"
namespace machete {
struct MMArgs {
paddle::Tensor const& A;
paddle::Tensor const& B;
machete::ScalarType const& b_type;
std::optional<paddle::DataType> const& maybe_out_type;
std::optional<paddle::Tensor> const& maybe_group_scales;
std::optional<paddle::Tensor> const& maybe_group_zeros;
std::optional<int64_t> maybe_group_size;
std::optional<paddle::Tensor> const& maybe_channel_scales;
std::optional<paddle::Tensor> const& maybe_token_scales;
std::optional<std::string> maybe_schedule;
};
struct SupportedSchedulesArgs {
paddle::DataType a_type;
machete::ScalarType b_type;
std::optional<paddle::DataType> maybe_group_scales_type;
std::optional<paddle::DataType> maybe_group_zeros_type;
std::optional<paddle::DataType> maybe_channel_scales_type;
std::optional<paddle::DataType> maybe_token_scales_type;
std::optional<paddle::DataType> maybe_out_type;
};
paddle::Tensor mm_dispatch(MMArgs args);
std::vector<std::string> supported_schedules_dispatch(
SupportedSchedulesArgs args);
template <typename MacheteKernel>
paddle::Tensor run_impl(MMArgs args) {
// const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));
// auto device = args.A.device();
// auto stream = at::cuda::getCurrentCUDAStream(device.index());
auto place = args.A.place();
cudaStream_t stream = args.A.stream();
int M = args.A.shape()[0];
int N = args.B.shape()[1];
int K = args.A.shape()[1];
// Allocate output
paddle::Tensor D = paddle::empty(
{M, N},
equivalent_scalar_type_v<typename MacheteKernel::ElementD>,
place);
auto arguments = MacheteKernel::create_arguments(
stream, //
args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros,
args.maybe_group_size, args.maybe_channel_scales,
args.maybe_token_scales);
PD_CHECK(MacheteKernel::can_implement(arguments),
"Machete kernel cannot be run with these arguments");
size_t workspace_size = MacheteKernel::get_workspace_size(arguments);
int S = static_cast<int>(workspace_size);
// phi::Allocator* allocator = paddle::GetAllocator(place);
// auto workspace = allocator->Allocate(workspace_size);
// MacheteKernel::run(arguments, workspace->ptr(), stream);
// paddle::Tensor workspace = paddle::empty({S}, paddle::DataType::UINT8, place);
paddle::Tensor workspace = GetEmptyTensor({S}, paddle::DataType::UINT8, place);
MacheteKernel::run(arguments, workspace.data(), stream);
return D;
};
}; // namespace machete


@@ -0,0 +1,71 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"
paddle::Tensor prepack_B(
paddle::Tensor const& B, paddle::DataType const& a_type, int64_t b_type_id,
std::string const& maybe_group_scales_type_str) {
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
std::optional<paddle::DataType> maybe_group_scales_type;
if (maybe_group_scales_type_str == "float16") {
maybe_group_scales_type = paddle::DataType::FLOAT16;
}
else if (maybe_group_scales_type_str == "bfloat16") {
maybe_group_scales_type = paddle::DataType::BFLOAT16;
}
else if (maybe_group_scales_type_str == "float32") {
maybe_group_scales_type = paddle::DataType::FLOAT32;
}
else if (maybe_group_scales_type_str == "") {
maybe_group_scales_type = std::nullopt;
}
else {
PADDLE_ENFORCE(false, "maybe_group_scales_type_str not supported!");
}
return machete::prepack_B_dispatch(
{.B = B,
.a_type = a_type,
.b_type = b_type,
.maybe_group_scales_type = maybe_group_scales_type});
}
std::vector<paddle::Tensor> MachetePrepackBKernel(
paddle::Tensor const& B, std::string const& a_type_str, std::string const& b_type_str,
std::string const& maybe_group_scales_type_str) {
machete::ScalarTypeId b_type_id;
paddle::DataType a_type, maybe_group_scales_type;
if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}
if (a_type_str == "float16") {
a_type = paddle::DataType::FLOAT16;
}
else if (a_type_str == "bfloat16") {
a_type = paddle::DataType::BFLOAT16;
}
else {
PADDLE_ENFORCE(false, "a_type_str not supported!");
}
auto Bt = paddle::experimental::transpose(B, {1, 0});
paddle::Tensor B_prepacked = prepack_B(Bt, a_type, b_type_id, maybe_group_scales_type_str);
return {B_prepacked};
}


@@ -0,0 +1,76 @@
#pragma once
#include "machete_mm_kernel.cuh"
#include "utils/cute_utils.cuh"
#include "utils/paddle_utils.hpp"
namespace machete {
template <int threads, typename PrepackedLayoutB, typename BInTensor,
typename ElementB>
static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) {
auto constexpr block_size =
Int<size(typename PrepackedLayoutB::PPBlockShape_NK{})>{};
auto constexpr eles_per_thread = Int<block_size / threads>{};
static_assert(block_size % threads == 0,
"block_size must be divisible by the number of threads");
// Which pre-packed block are we responsible for
auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z);
auto tB_in = local_tile(
B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}),
blk_coord);
// Find the start offset in the output for this pre-packed block
auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in));
// Tensor representing a 1:1 mapping to the output space in 1D
auto tB_out_linear =
make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord),
make_layout(make_shape(block_size)));
// Mapping from output space (1D) to input space
auto tB_in_linear = make_tensor(
tB_in.data(),
tB_in.layout()
.compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset()))
.with_shape(make_shape(block_size)));
// Tile for this specific thread (could have used a TiledCopy, but those work
// best with 2d layouts; this is a simple 1d layout, so local_tile is enough,
// and we are not that concerned with performance for this kernel)
auto thr_tB_in_linear =
local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x);
auto thr_tB_out_linear =
local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x);
// Construct a register-backed Tensor with the same shape as each thread's
// partition
auto fragment = make_tensor<ElementB>(shape(thr_tB_in_linear));
copy(thr_tB_in_linear, fragment);
copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tB_out_linear);
}
template <typename PrepackedLayoutB, typename InLayout>
static void prepack_B_template(
cudaStream_t stream, typename PrepackedLayoutB::ElementB const* B_in_ptr,
InLayout B_layout, typename PrepackedLayoutB::ElementB* B_out_ptr) {
using TileShapeNKL =
decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));
auto ilvd_NKbNbKL_to_offset =
PrepackedLayoutB::ilvd_NKbNbKL_to_offset(shape(B_layout));
PD_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
PD_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);
auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
auto L_tiles = size<2>(B_layout);
auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
prepack_B_kernel<128, PrepackedLayoutB>
<<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_out_ptr);
}
}; // namespace machete


@@ -0,0 +1,77 @@
#pragma once
#include "machete_prepack_kernel.cuh"
#include "utils/paddle_utils.hpp"
#include "utils/scalar_type.h"
namespace machete {
struct PrepackBArgs {
paddle::Tensor const& B;
paddle::DataType a_type;
machete::ScalarType b_type;
std::optional<paddle::DataType> maybe_group_scales_type;
};
template <typename PrepackedLayoutB>
paddle::Tensor prepack_impl(paddle::Tensor const B) {
// const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
using ElementB = typename PrepackedLayoutB::ElementB;
using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK;
// auto device = B.device();
// auto stream = at::cuda::getCurrentCUDAStream(device.index());
cudaStream_t stream = B.stream();
auto B_ptr = static_cast<ElementB const*>(B.data());
// elements per storage item for B
auto eles_per_storage =
(SizeOf(B.dtype()) * 8) / cute::sizeof_bits_v<ElementB>;
// paddle B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to
// match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L)
// auto Bt_packed = B.transpose();
auto Bt_packed = paddle::experimental::transpose(B, {1, 0});
PD_CHECK(
(B.shape()[0] * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0,
"B.shape[0] (in terms of unpacked elements) must be a multiple of ",
size<1>(PPBlockShape_NK{}));
PD_CHECK(B.shape()[1] % size<0>(PPBlockShape_NK{}) == 0,
"B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{}));
using StrideB = cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>;
auto const l_Bt_packed = make_cute_layout<StrideB>(Bt_packed, "B");
// auto const l_Bt_packed = make_cute_layout<StrideB>(B, "B");
// convert (N,packed_K,L) layout to (N,K,L) layout
// in effect we want to do: blocked_product(layout_Bt_packed,
// make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}),
// Step<_1, _0, _2>{}));
// but blocked_product does not support dynamic strides so we implement the
// equivalent manually,
// new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L)
// new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage)
// when s1 == 1
PD_CHECK(stride<1>(l_Bt_packed) == 1, "stride<1>(l_Bt_packed) must be 1");
// clang-format off
auto const layout_Bt = make_layout(
transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) {
return idx == 1 ? ele * eles_per_storage : ele;
}),
transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) {
return idx != 1 ? ele * eles_per_storage : ele;
}));
// clang-format on
// Allocate output
paddle::Tensor D = paddle::empty_like(B);
prepack_B_template<PrepackedLayoutB>(
stream, B_ptr, layout_Bt, static_cast<ElementB*>(D.data()));
return D;
};
paddle::Tensor prepack_B_dispatch(PrepackBArgs args);
}; // namespace machete

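The shape/stride arithmetic described in the comments of prepack_impl above, restated as a tiny Python sketch (illustrative; the concrete sizes and the uint8 storage type are assumptions):

def unpack_layout(shape_npkl, stride_npkl, eles_per_storage):
    # Expand the packed-K mode by eles_per_storage while keeping its stride 1,
    # i.e. new_shape = shape * (1, eles_per_storage, 1) and
    # new_stride = stride * (eles_per_storage, 1, eles_per_storage).
    assert stride_npkl[1] == 1, "packed-K mode must be contiguous (stride 1)"
    new_shape = tuple(e * eles_per_storage if i == 1 else e
                      for i, e in enumerate(shape_npkl))
    new_stride = tuple(s if i == 1 else s * eles_per_storage
                       for i, s in enumerate(stride_npkl))
    return new_shape, new_stride

# e.g. 4-bit weights in uint8 storage (2 elements per byte), Bt packed as
# (N, packed_K, L) = (8192, 2048, 1):
print(unpack_layout((8192, 2048, 1), (2048, 1, 8192 * 2048), 2))
# -> ((8192, 4096, 1), (4096, 1, 33554432))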

@@ -0,0 +1,249 @@
#pragma once
// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
#include "utils/cute_utils.cuh"
#include "machete_collective_builder.cuh"
#include "machete_interleaving_utils.cuh"
namespace machete {
using namespace cute;
struct IlvBlkLayoutAuto {};
// This defines a prepacked layout for the B matrix, where the matrix is broken
// up into PPBlockShape_NK blocks. The data within each block is then compactly
// stored in memory such that when performing a TiledMMA operation with the same
// shape as the prepacked block, all the data for a given thread is contiguous in
// memory. This allows us to use wider shared memory loads when loading B from
// shared memory. The values within a thread are also potentially interleaved
// in order to allow for more efficient upconverting.
//
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel. (this is also why the other element types are
// required along with the kernel schedule)
template <typename ElementA_, typename ElementB_, typename ElementConvert_,
typename AccumulatorT, class LayoutB, class KernelSchedule,
typename IlvBlkLayout_ = IlvBlkLayoutAuto>
// clang-format on
struct PrepackedLayoutBTemplate {
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementAccumulator = AccumulatorT;
using ElementMma = MmaType;
// Interleave for 4bit types when we are not upconverting to fp8 or int8; in
// those cases we use a LUT built with prmt instructions to upconvert, which
// is more efficient if the data is not interleaved. For 8bit+ types, prmt
// instructions make non-interleaved layouts efficient enough that we don't
// need interleaved layouts (and can reuse more of the existing cutlass converts)
static constexpr bool should_interleave =
sizeof_bits_v<ElementB> <= 4 &&
!std::is_same_v<ElementConvert_, cutlass::float_e4m3_t> &&
!std::is_same_v<ElementConvert_, int8_t>;
// Only use interleaved layouts for subbyte weights,
using IlvdBlkLayout = std::conditional_t<
std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
std::conditional_t<
should_interleave,
decltype(get_interleaved_blk_layout<
ElementB, sizeof_bits_v<ElementConvert_>, 32>()),
void>,
IlvBlkLayout_>;
// TODO (LucasWilkinson): compare the performance for other sizes
// Prepacked block shape, smallest layout atom for loading into registers
// (can contain multiple wgmma instructions worth of data in one block)
// We ideally want this to be configured such that a thread can perform 128bit
// loads, i.e. the amount of data associated with each thread within a
// prepacked block is a multiple of 128bits. When using a cooperative schedule
// we have 256 threads working on a single block at a time, which means each
// thread works on `sizeof_bits_v<ElementB> * (128*64) / 256` bits of data;
// for a 4bit type this would be 128bits
using PPBlockShape_NK = Shape<_128, _64>;
// Create the shape of the tile anticipated to be used by the GEMM kernel;
// when the kernel executes we will compute `Ct = Bt * At` since the
// quantized weights (B) must be the lhs operand so that they flow through
// registers.
// The _128 here doesn't actually impact the shape of the stored tile directly
// but may impact the op selected by rs_op_selector
using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
size<1>(PPBlockShape_NK{})));
static constexpr cute::GMMA::Major GmmaMajorB =
gmma_rs_tag_to_major_B<LayoutB>();
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
AtomLayoutMNK{}));
// Prepacked block, (athrid, val) -> (N,K)
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
}
// Prepacked block, (N,K) -> (athrid, val)
// i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
}
// Prepacked block, (athrid, val) -> (storage_offset)
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
// Return ordered layout (no interleaving)
return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
}
// Prepacked block, (athrid, val) -> (storage_offset)
// i.e. ((ThrV,(ThrM,ThrK)),(IlvdFrgV,(RestM,RestK,...))) -> (storage_idx)
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() {
auto layout_no_interleave =
make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
if constexpr (std::is_same_v<IlvdBlkLayout, void>) {
return layout_no_interleave;
} else {
// interleave by transforming FrgV into interleaved blocks where each
// block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is
// (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4)
// if FrgV is {A, B, C, D, E, F, G, H}
// then ((IlvBlk), FrgB) is {A, C, B, D, E, G, F, H}
auto frgV = get<1, 0>(layout_no_interleave);
auto ilvdBlk = IlvdBlkLayout{};
static_assert(size(frgV) % size(ilvdBlk) == 0,
"FrgV must be divisible by size(ilvdBlk)");
auto ilvd_FrgV = make_layout(
make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
make_stride(stride(ilvdBlk), size(ilvdBlk)));
// Return interleaved layout
return make_layout(
get<0>(layout_no_interleave),
make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave)));
}
}
// Prepacked block, (M,K) -> (storage_offset)
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() {
// do (M,K) -> (athrid, val) -> (storage_idx)
return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV());
}
// ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset(
Shape_NKL shape_mkl) {
constexpr auto block_layout = ppblock_TV_to_offset();
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
auto result = make_layout(
block_layout,
make_layout(blocks_shape,
compact_col_major(blocks_shape, size(block_layout))));
// ((athrid, val), (BlocksN, BlocksK, L))
// => ((athrid, val), (BlocksN, BlocksK), L)
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
}
// ((athrid_val), (BlocksN, BlocksK, L)) -> (N, K, L)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy(
Shape_NKL shape_mkl) {
auto layout = TVbNbKL_to_offset(shape_mkl);
// for 4-bit elements, having >= 64 values per column
// allows TMA to load full 32-byte sectors
auto inner_layout =
make_layout(make_shape(_256{}, size<0>(layout) / _256{}));
return make_layout(inner_layout, get<1>(layout), get<2>(layout));
}
// ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
Shape_NKL shape_mkl) {
constexpr auto block_layout = ppblock_ilvd_NK_to_offset();
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
auto result = make_layout(
block_layout,
make_layout(blocks_shape,
compact_col_major(blocks_shape, size(block_layout))));
// ((athrid, val), (BlocksN, BlocksK, L)) => ((athrid, val), (BlocksN,
// BlocksK), L)
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
}
// (BlocksN, BlocksK, L) -> (storage_idx)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) {
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
auto stride = size(PPBlockShape_NK{});
// (BlocksN, BlocksK, L) -> (storage_idx)
return make_layout(blocks_shape, compact_col_major(blocks_shape, stride));
}
// ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
template <class Shape_NKL>
CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})),
make_layout(size<1>(PPBlockShape_NK{})));
// ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L)
auto tiled_A = zipped_divide(make_layout(shape_mkl), tile);
return tiled_A.compose(ppblock_TV_to_NK(), _);
}
// (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
template <class Shape_NKL>
CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) {
auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl);
return blocked_product(ppblock_NK_to_TV(),
make_layout(shape<1>(TVbNbK_to_NKL_layout)));
}
};
}; // namespace machete
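As an aside, the interleave transform in ppblock_ilvd_TV_to_offset above can be checked with a small standalone CuTe program. A minimal sketch (illustrative only; it assumes the CuTe/CUDA headers are on the include path and a C++17 host compiler) that builds the same ilvd_FrgV layout for IlvdBlkLayout = (2, 2):(2, 1) and an 8-element FrgV, and prints the offset order {A, C, B, D, E, G, F, H}:

// sketch_ilvd_frgv.cpp -- hypothetical file name, illustration only
#include <cute/layout.hpp>
#include <cstdio>
int main() {
  using namespace cute;
  auto ilvd_blk = Layout<Shape<_2, _2>, Stride<_2, _1>>{};  // IlvdBlkLayout = (2,2):(2,1)
  auto frgV     = make_layout(Int<8>{});                    // FrgV = {A..H}, identity layout
  // Same construction as in ppblock_ilvd_TV_to_offset:
  auto ilvd_FrgV = make_layout(
      make_shape(shape(ilvd_blk), Int<size(frgV) / size(ilvd_blk)>{}),
      make_stride(stride(ilvd_blk), size(ilvd_blk)));
  // Enumerating the domain in column-major order visits offsets 0,2,1,3,4,6,5,7,
  // i.e. {A, C, B, D, E, G, F, H}.
  for (int i = 0; i < size(ilvd_FrgV); ++i) {
    std::printf("%d ", int(ilvd_FrgV(i)));
  }
  std::printf("\n");
  return 0;
}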

View File

@@ -0,0 +1,72 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"
template <typename T>
std::optional<T> ConvertToStdOptional(const paddle::optional<T>& paddle_opt) {
return paddle_opt ? std::optional<T>(paddle_opt.get()) : std::nullopt;
}
std::vector<std::string> supported_schedules(
paddle::DataType a_type, int64_t b_type_id,
std::optional<paddle::DataType> maybe_group_scales_type,
std::optional<paddle::DataType> maybe_group_zeros_type,
std::optional<paddle::DataType> maybe_channel_scales_type,
std::optional<paddle::DataType> maybe_token_scales_type,
std::optional<paddle::DataType> maybe_out_type) {
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
auto schedules = machete::supported_schedules_dispatch({
.a_type = a_type,
.b_type = b_type,
.maybe_group_scales_type = maybe_group_scales_type,
.maybe_group_zeros_type = maybe_group_zeros_type,
.maybe_channel_scales_type = maybe_channel_scales_type,
.maybe_token_scales_type = maybe_token_scales_type,
.maybe_out_type = maybe_out_type
});
return schedules;
}
std::vector<std::string> MacheteSupportedSchedules(
std::string const& a_type_str, std::string const& b_type_str) {
machete::ScalarTypeId b_type_id;
paddle::DataType a_type;
if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}
if (a_type_str == "bfloat16") {
a_type = paddle::DataType::BFLOAT16;
} else if (a_type_str == "float16") {
a_type = paddle::DataType::FLOAT16;
} else {
PADDLE_ENFORCE(false, "a_type_str not supported!");
}
std::optional<paddle::DataType> maybe_group_scales_type = std::optional<paddle::DataType>(a_type);
std::optional<paddle::DataType> maybe_out_type = std::optional<paddle::DataType>(a_type);
std::optional<paddle::DataType> maybe_group_zeros_type = std::nullopt;
std::optional<paddle::DataType> maybe_channel_scales_type = std::nullopt;
std::optional<paddle::DataType> maybe_token_scales_type = std::nullopt;
auto schedules = supported_schedules(a_type, b_type_id,
maybe_group_scales_type,
maybe_group_zeros_type,
maybe_channel_scales_type,
maybe_token_scales_type,
maybe_out_type);
return schedules;
}
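A minimal host-side sketch (illustrative only; it assumes the translation unit links against the code above) of how this helper might be exercised directly from C++. "bfloat16"/"float16" and "uint4b8" are the only type strings accepted above:

// Illustrative only.
#include <iostream>
#include <string>
#include <vector>
std::vector<std::string> MacheteSupportedSchedules(std::string const& a_type_str,
                                                   std::string const& b_type_str);
int main() {
  // bfloat16 activations with uint4b8 ("GPTQ"-style) 4-bit weights.
  for (auto const& schedule : MacheteSupportedSchedules("bfloat16", "uint4b8")) {
    std::cout << schedule << "\n";
  }
  return 0;
}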

View File

@@ -0,0 +1,69 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/cute_utils.cuh
#pragma once
#include <cute/tensor.hpp>
namespace cute {
////////////////////////////////////////////////////////////////////
// layout utils
////////////////////////////////////////////////////////////////////
// Permute layout based on indices, example:
// permute_layout<1, 0>(layout) will swap the two dimensions
// permute_layout<0, 2, 1>(layout) will swap the last two dimensions
template <size_t... I, typename Layout>
CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch");
return cute::make_layout(cute::get<I>(l)...);
}
// is the layout f(x) = x
template <typename Layout>
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
if constexpr (std::is_same_v<Layout, void>) {
return true;
} else {
constexpr auto coalesced_layout = coalesce(Layout{});
if constexpr (rank(coalesced_layout) == 1 &&
stride<0>(coalesced_layout) == 1) {
return true;
}
return false;
}
}
////////////////////////////////////////////////////////////////////
// Pointer utils
////////////////////////////////////////////////////////////////////
template <class PointerType>
static constexpr auto get_logical_ptr(PointerType* ptr) {
if constexpr (cute::sizeof_bits_v<PointerType> < 8) {
return cute::subbyte_iterator<PointerType>(ptr);
} else {
return ptr;
}
}
////////////////////////////////////////////////////////////////////
// Misc utils
////////////////////////////////////////////////////////////////////
template <typename T, typename Elements>
CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() {
constexpr auto bits = sizeof_bits_v<T> * Elements{};
if constexpr (bits % 128 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<128>{};
} else if constexpr (bits % 64 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<64>{};
} else if constexpr (bits % 32 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<32>{};
} else if constexpr (bits % 16 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<16>{};
} else {
return AutoVectorizingCopyWithAssumedAlignment<8>{};
}
}
}; // namespace cute
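A small usage sketch for the helpers above (illustrative only; it assumes this header, which already pulls in <cute/tensor.hpp>, is included in the same translation unit):

// Illustrative only.
inline void cute_utils_example() {
  using namespace cute;
  auto l  = make_layout(make_shape(Int<4>{}, Int<8>{}), make_stride(Int<8>{}, Int<1>{}));  // (4,8):(8,1)
  auto lt = permute_layout<1, 0>(l);                               // modes swapped -> (8,4):(1,8)
  static_assert(is_identity_layout<decltype(make_layout(Int<16>{}))>());  // (16):(1) is f(x) = x
  auto copy_op = create_auto_vectorizing_copy<float, Int<4>>();    // 4 x 32 bits -> 128-bit aligned copy atom
  (void)lt;
  (void)copy_op;
}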

View File

@@ -0,0 +1,44 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_collective_builder.cuh
#pragma once
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
namespace cutlass::gemm::collective {
using namespace cute;
//
// MacheteCollectiveBuilder is a wrapper around CollectiveBuilder that allows
// for custom kernel tags, letting you build custom collectives without
// touching the cutlass library headers; using `CutlassKernelTag` falls back
// to the standard cutlass collective builder.
//
// Use the default Cutlass collective builder, i.e. use an unmodified cutlass
// collective
struct CutlassKernelTag {};
template <class KernelTag, class ArchTag, class OpClass, class ElementA,
class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB,
int AlignmentB, class ElementAccumulator, class TileShape_MNK,
class ClusterShape_MNK, class StageCountType,
class KernelScheduleType, class Enable = void>
struct MacheteCollectiveBuilder {
static_assert(sizeof(ElementA) == 0,
"Could not build a collective for given parameters.");
};
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA,
int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType>
struct MacheteCollectiveBuilder<
CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA,
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
ClusterShape_MNK, StageCountType, KernelScheduleType> {
using CollectiveOp = typename CollectiveBuilder<
ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB,
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp;
};
}; // namespace cutlass::gemm::collective

View File

@@ -0,0 +1,51 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_custom_types.cuh
#pragma once
#include "cutlass/integer_subbyte.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
template <int Bits, int Bias, bool Signed = false>
struct machete_biased_integer_subbyte : public integer_subbyte<Bits, Signed> {
using Base = integer_subbyte<Bits, Signed>;
using Storage = typename Base::Storage;
using xint_t = typename Base::xint_t;
using Base::bits_mask_;
using Base::sign_mask_;
using Base::storage;
//
// Methods
//
/// No operation
machete_biased_integer_subbyte() = default;
/// Conversion from integer type
CUTLASS_HOST_DEVICE explicit machete_biased_integer_subbyte(int value)
: Base(value) {}
CUTLASS_HOST_DEVICE explicit machete_biased_integer_subbyte(unsigned value)
: Base(value) {}
CUTLASS_HOST_DEVICE explicit machete_biased_integer_subbyte(double value)
: Base(value) {}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// "GPTQ" types, i.e. symmetric quantization
using machete_uint4b8_t = machete_biased_integer_subbyte<4, 8>; // u4b8
using machete_uint8b128_t = machete_biased_integer_subbyte<8, 128>; // u8b128
///////////////////////////////////////////////////////////////////////////////////////////////////
template <int Bits, int Bias, bool Signed>
struct sizeof_bits<machete_biased_integer_subbyte<Bits, Bias, Signed>> {
static constexpr int value = Bits;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
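These biased types follow the "stored value = logical value + bias" convention (see scalar_type.h further down in this diff), which is exactly what the converters in the next file undo. A tiny host-side decode sketch, illustrative only (the helper names are hypothetical):

#include <cstdint>
// Hypothetical helpers: decode the raw bit patterns of the biased "GPTQ" types above.
constexpr int decode_u4b8(uint8_t raw_nibble) { return int(raw_nibble & 0xF) - 8; }  // machete_uint4b8_t
constexpr int decode_u8b128(uint8_t raw_byte) { return int(raw_byte) - 128; }        // machete_uint8b128_t
static_assert(decode_u4b8(0x0) == -8 && decode_u4b8(0x8) == 0 && decode_u4b8(0xF) == 7);
static_assert(decode_u8b128(0) == -128 && decode_u8b128(128) == 0 && decode_u8b128(255) == 127);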

View File

@@ -0,0 +1,993 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
#pragma once
#include "cutlass/numeric_conversion.h"
#include "machete_custom_types.cuh"
#include "cute_utils.cuh"
#include "machete_type_utils.cuh"
// this file extends:
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
// with vllm specific type conversions, namely: machete_uint4b8_t, machete_uint8b128_t
// as well as adds interleaved numeric array converters for specific types.
// (interleaved numeric array converters can be more efficient for subbyte
// types)
namespace cutlass {
// InterleavedNumericArrayConverter is like NumericArrayConverter but also
// deinterleaves converted elements based on IlvBlkLayout, interleaving can
// make subbyte converts more efficient by allowing for efficient extraction
// of subbyte elements from a 32bit register.
template <typename IlvBlkLayout, typename T, typename S, int N,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
class Enable = void>
struct InterleavedNumericArrayConverter {
using Converter = NumericArrayConverter<T, S, N, Round>;
using result_type = typename Converter::result_type;
using source_type = typename Converter::source_type;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
if (cute::elect_one_sync()) {
if constexpr (std::is_same_v<IlvBlkLayout, void>) {
printf(
"Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n",
nameof_v<T>, nameof_v<S>, N);
} else {
printf(
"Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not "
"implemented\n",
nameof_v<T>, nameof_v<S>, N, size(IlvBlkLayout{}));
}
__brkpt();
}
return {};
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
template <typename IlvBlkLayout, typename T, typename S, int N,
FloatRoundStyle Round>
struct InterleavedNumericArrayConverter<
IlvBlkLayout, T, S, N, Round,
std::enable_if_t<is_identity_layout<IlvBlkLayout>()>> {
using Converter = NumericArrayConverter<T, S, N, Round>;
using result_type = typename Converter::result_type;
using source_type = typename Converter::source_type;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return Converter::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
template <typename RegConvert32bit, typename T, typename S, int N>
struct ArrayConverterPacked32Bit {
using result_type = Array<T, N>;
using source_type = Array<S, N>;
using result_packed_8_t = Array<T, 8>;
using result_packed_4_t = Array<T, 4>;
using result_packed_2_t = Array<T, 2>;
using src_packed_8_t = Array<S, 8>;
using src_packed_4_t = Array<S, 4>;
using src_packed_2_t = Array<S, 2>;
static_assert(N % 2 == 0, "N must be a multiple of 2");
static_assert(cutlass::sizeof_bits_v<S> >= 4); // TODO: add 16 packed sources
static_assert(32 % cutlass::sizeof_bits_v<S> == 0);
static constexpr auto src_elems_per_32bit_reg =
32 / cutlass::sizeof_bits_v<S>;
// Maybe not valid. ScalarConverter will not actually work unless
// NumericConverter<T, S, Round> is implemented. However, it won't be used
// anyway since we assert N % 2 == 0; it is just here for compliance with
// VectorizedConverter.
using ScalarConverter = NumericConverter<T, S>;
template <typename PackedSrc>
CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) {
if constexpr (sizeof(PackedSrc) == 1) {
return Array<uint32_t, 1>{reinterpret_cast<uint8_t const&>(src)};
} else if constexpr (sizeof(PackedSrc) == 2) {
return Array<uint32_t, 1>{reinterpret_cast<uint16_t const&>(src)};
} else if constexpr (sizeof(PackedSrc) == 4) {
return Array<uint32_t, 1>{reinterpret_cast<uint32_t const&>(src)};
} else {
static_assert(sizeof(PackedSrc) == 8);
return reinterpret_cast<Array<uint32_t, 2> const&>(src);
}
}
// The core converter uses bit tricks to construct a known FP16 number, then
// does a subtraction in FP16 for the final result.
template <typename PackedResultType, typename PackedSrcType>
CUTLASS_DEVICE static PackedResultType packed_convert(
PackedSrcType const& source) {
static_assert(PackedSrcType::kElements == PackedResultType::kElements);
static_assert(PackedResultType::kElements == 2 ||
PackedResultType::kElements == 4 ||
PackedResultType::kElements == 8,
"Invalid PackedResultType must be 2, 4 or 8.");
static_assert(std::is_same_v<typename PackedSrcType::Element, S>);
static_assert(std::is_same_v<typename PackedResultType::Element, T>);
return RegConvert32bit::template convert<PackedResultType>(to_regs(source));
}
friend class detail::VectorizedConverter;
public:
CUTLASS_DEVICE static result_type convert(source_type const& source) {
result_type result;
using ConverterType =
ArrayConverterPacked32Bit<RegConvert32bit,
typename result_type::Element,
typename source_type::Element, N>;
if constexpr (src_elems_per_32bit_reg >= 8) {
detail::VectorizedConverter::convert<
ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t,
src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source);
} else if constexpr (src_elems_per_32bit_reg >= 4) {
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
src_packed_4_t, result_packed_2_t,
src_packed_2_t>(result, source);
} else {
detail::VectorizedConverter::convert<ConverterType, result_packed_2_t,
src_packed_2_t>(result, source);
}
return result;
}
};
// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed
// into 2 32bit register.
template <uint8_t LUT0, uint8_t LUT1, uint8_t LUT2, uint8_t LUT3, //
uint8_t LUT4, uint8_t LUT5, uint8_t LUT6, uint8_t LUT7, //
uint8_t LUT8, uint8_t LUT9, uint8_t LUT10, uint8_t LUT11, //
uint8_t LUT12, uint8_t LUT13, uint8_t LUT14, uint8_t LUT15>
CUTLASS_DEVICE cutlass::AlignedArray<uint32_t, 2> lut_4bit_to_8bit_convert(
uint32_t src) {
cutlass::AlignedArray<uint32_t, 2> r;
// Determine whether each 4-bit value should index the top half of the LUT
// (i.e. LUT[8:15]) when its high bit is set, or the bottom half
// (i.e. LUT[0:7]) when it is not. Then move that bit into position 0x4 of
// each nibble so that final_prmt_idx = final_prmt_base | high_bit selects the
// correct candidate: selector nibbles >= 0x4 pick the high candidate
// (i.e. LUT[8:15]), nibbles < 0x4 pick the low candidate (i.e. LUT[0:7]).
uint32_t high_bit = (src & 0x88888888) >> 1;
// `high_bit` is OR'd with final_prmt_base (0x32103210) below to select the
// correct high or low candidate for each output byte
const uint32_t final_prmt_base = 0x32103210;
// Ignore the high bit when indexing into LUT, for each 4bit value
// we index into both the high and low candidates then use
// high_bit | final_prmt_base to select the correct candidate
uint32_t lut_idx = (src & 0x77777777);
auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) {
return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) |
(uint32_t(d) << 24);
};
static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3);
static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7);
static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11);
static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) {
uint32_t final_prmt_idx = final_prmt_base | high_bit;
// This uses a look up table to convert packed int4s to packed int8s,
// using the int4 value as the index to prmt. It first select both the
// high and low candidates, then uses the high bit (i.e. `high_bit`) to
// select the correct candidate.
asm volatile(
"{\n"
" .reg .b32 low, high;\n"
" prmt.b32 low, %1, %2, %5;\n"
" prmt.b32 high, %3, %4, %5;\n"
" prmt.b32 %0, low, high, %6;\n"
"}\n"
: "=r"(r[ii])
: "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx),
"r"(final_prmt_idx));
}
return r;
};
// for Array<int8_t, N> <= Array<machete_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<int8_t, machete_uint4b8_t, N, Round> {
using result_type = Array<int8_t, N>;
using source_type = Array<machete_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s
auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, //
0xFC, 0xFD, 0xFE, 0xFF, //
0x00, 0x01, 0x02, 0x03, //
0x04, 0x05, 0x06, 0x07>(src_[0]);
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::float_e4m3_t, N> <= Array<machete_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::float_e4m3_t, machete_uint4b8_t, N, Round> {
using result_type = Array<cutlass::float_e4m3_t, N>;
using source_type = Array<machete_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s
auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, //
0xC8, 0xC4, 0xC0, 0xB8, //
0x00, 0x38, 0x40, 0x44, //
0x48, 0x4A, 0x4C, 0x4E>(src_[0]);
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<machete_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, machete_uint4b8_t, N, Round> {
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<machete_uint4b8_t, N>;
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
// Below constructs the following temporary:
// fp16s_01 = {0x00, i4_01, 0x00, i4_01}
// fp16s_23 = {0x00, i4_23, 0x00, i4_23}
// fp16s_45 = {0x00, i4_45, 0x00, i4_45}
// fp16s_67 = {0x00, i4_67, 0x00, i4_67}
// We use inline asm instead of __byte_perm intrinsic since we don't want
// the documented (& 0x7) on the index. NVCC might be able to optimize it
// out since the index is a constexpr, but we choose to be safe about it
// here.
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
static_assert(RegArray::kElements <= 4,
"Too many inputs for F16 -> I4 vector converter");
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" prmt.b32 %0, %1, %2, %3;\n"
"}\n"
: "=r"(r[ii])
: "r"(src), "n"(0), "r"(prmt_indices[ii]));
}
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
// we are trying to construct x as an fp16 value
// The below XOR does the following:
// 1) Sets the exponent bits of the FP16 to the correct value for the
// FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)},
// where x1 is the high nibble and x0 is the low nibble, then using hfma
// to subtract 1032 from that
// The AND does the following:
// 1) Clear the set bits for the int4 we will ignore.
// We use lop3 so that we can use 1 instruction for AND and XOR.
static constexpr uint32_t xor_mask = 0x64006400;
static constexpr uint32_t and_mask = 0xFFF0FF0F;
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
// For each operand, computes:
// r[i] = (r[i] & and_mask) ^ xor_mask
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii])
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
}
// We will issue 2 hfmas that do the following:
// {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032}
//          = {16*x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032}
static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {72, 1032}
static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1}
const half2& hfma_bias = reinterpret_cast<const half2&>(hfma_bias_rep);
const half2& hfma_scale = reinterpret_cast<const half2&>(hfma_scale_rep);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias);
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<machete_uint4b8_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::half_t, machete_uint4b8_t, N,
Round, void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<machete_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t xor_mask = 0x64006400;
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
auto src_ = src >> (4 * (ii));
r[ii + 0] = src_;
r[ii + 1] = src_;
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
static constexpr uint32_t high_nib_mask = 0x00F000F0;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 0])
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 1])
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
// For low nibble:
// {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032}
// For high nibble:
// {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16}
// - {72, 72}
static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032}
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
fp16x2_val =
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
fp16x2_val = __hfma2(fp16x2_val,
reinterpret_cast<const half2&>(high_nib_scale),
reinterpret_cast<const half2&>(high_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<uint4_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::half_t, uint4_t, N, Round,
void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<uint4_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t xor_mask = 0x64006400;
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
auto src_ = src >> (4 * (ii));
r[ii + 0] = src_;
r[ii + 1] = src_;
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
static constexpr uint32_t high_nib_mask = 0x00F000F0;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 0])
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 1])
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
// For low nibble:
// {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024}
// For high nibble:
// {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64}
static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024}
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
fp16x2_val =
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
fp16x2_val = __hfma2(fp16x2_val,
reinterpret_cast<const half2&>(high_nib_scale),
reinterpret_cast<const half2&>(high_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<machete_uint8b128_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, machete_uint8b128_t, N, Round> {
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<machete_uint8b128_t, N>;
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
// Hold output FP16s in reg. We need 1 reg for every 2 elements
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
uint32_t const prmt_indices[2] = {0x5150, 0x5352};
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
: "=r"(r[ii])
: "r"(src), "n"(start_byte_for_fp16),
"r"(prmt_indices[ii]));
}
// -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes
static constexpr uint32_t bias_rep = 0x64806480;
const half2& bias = reinterpret_cast<const half2&>(bias_rep);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
fp16x2_val = __hsub2(fp16x2_val, bias);
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<float, N> <= Array<machete_uint8b128_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<float, machete_uint8b128_t, N, Round> {
using result_type = Array<float, N>;
using source_type = Array<machete_uint8b128_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
PackedResultType r;
// __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
// u8x4 source and stores the result in r (without introducing extra
// cvt.u32.u8 instruction)
uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653};
uint32_t* result_as_int = reinterpret_cast<uint32_t*>(&r);
for (int ii = 0; ii < PackedResultType::kElements; ++ii) {
result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]);
// Subtract the magic number 0x4B000000 from tmp in floating-point
// arithmetic to obtain final result
r[ii] -= (8388608.f + 128.f); // fold in -128 bias
}
return r;
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// for Array<cutlass::bfloat16_t, N> <= Array<machete_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::bfloat16_t, machete_uint4b8_t, N, Round> {
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<machete_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src_reg = src_[0];
// Hold output BF16s in reg. We need 1 reg for every 2 elements
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
uint32_t src_reg_shifted = src_reg >> 4;
// Below pairs up the 4-bit values so that each 32-bit register holds two
// consecutive elements, one in the low bits of each 16-bit lane:
//   r[0] -> {i4_1, i4_0}, r[1] -> {i4_3, i4_2}, ...
// (any extra bits selected by prmt are cleared by the masking lop3 below)
uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
static_assert(RegArray::kElements <= 4,
"Too many inputs for uint4b8_t -> BF16 vector converter");
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" prmt.b32 %0, %1, %2, %3;\n"
"}\n"
: "=r"(r[ii])
: "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii]));
}
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
// we are trying to construct x as a BF16 value
// The below XOR does the following:
// 1) Sets the exponent bits of the BF16 to the correct value for the
// BF16 magic_num. We will be constructing {128 + (x1+8), 128 + (x0+8)}
// and subtracting 136 to get {x1, x0}
static constexpr uint32_t xor_mask = 0x43004300;
static constexpr uint32_t and_mask = 0x000F000F;
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
// For each operand, computes:
// r[i] = (r[i] & and_mask) ^ xor_mask
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii])
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
}
// We now subtract the bias from each lane to recover x:
//   {x1, x0} = {128 + (x1+8), 128 + (x0+8)} - {136, 136}
// This is the BF16 {136, 136} represented as an integer.
static constexpr uint32_t bias_rep = 0x43084308;
const __nv_bfloat162& bias =
reinterpret_cast<const __nv_bfloat162&>(bias_rep);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
bf16x2_val = __hsub2(bf16x2_val, bias);
}
return reinterpret_cast<PackedResultType&>(r);
}
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::bfloat16_t, N> <= Array<machete_uint4b8_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::bfloat16_t, machete_uint4b8_t, N,
Round, void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<machete_uint4b8_t, N>;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t or_mask = 0x43004300;
// Unlike float16 where the mantissa is large enough to contain 2
// nibbles, bfloat16 can only fit one, so we can only convert one
// nibble at a time
for (int ii = 0; ii < RegArray::kElements; ++ii) {
r[ii] = src >> (4 * ii);
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 0])
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
// For low nibble:
// {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136}
static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136}
{
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
fp16x2_val =
__hsub2(fp16x2_val,
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::bfloat16_t, N> <= Array<uint4_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::bfloat16_t, uint4_t, N, Round,
void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<uint4_t, N>;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
uint32_t src = src_[0];
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t or_mask = 0x43004300;
// Unlike float16 where the mantissa is large enough to contain 2
// nibbles, bfloat16 can only fit one, so we can only convert one
// nibble at a time
for (int ii = 0; ii < RegArray::kElements; ++ii) {
r[ii] = src >> (4 * ii);
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii])
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
// For low nibble:
// {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128}
static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128}
{
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
fp16x2_val =
__hsub2(fp16x2_val,
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::bfloat16_t, N> <= Array<machete_uint8b128_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::bfloat16_t, machete_uint8b128_t, N, Round> {
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<machete_uint8b128_t, N>;
static FloatRoundStyle const round_style = Round;
private:
using result_packed_4_t = Array<cutlass::bfloat16_t, 4>;
using result_packed_2_t = Array<cutlass::bfloat16_t, 2>;
using src_packed_4_t = Array<machete_uint8b128_t, 4>;
using src_packed_2_t = Array<machete_uint8b128_t, 2>;
// Not Valid, not supported, only here to satisfy the interface and to avoid
// a compile error. ScalarConverter will not actually work until
// NumericConverter<cutlass::bfloat16_t, machete_uint8b128_t, Round> is
// implemented
using ScalarConverter =
NumericConverter<cutlass::bfloat16_t, machete_uint8b128_t, Round>;
template <typename PackedResultType, typename PackedSrcType>
CUTLASS_DEVICE static PackedResultType packed_convert(
PackedSrcType const& source) {
static_assert(
(platform::is_same<PackedSrcType, src_packed_2_t>::value &&
platform::is_same<PackedResultType, result_packed_2_t>::value) ||
(platform::is_same<PackedSrcType, src_packed_4_t>::value &&
platform::is_same<PackedResultType, result_packed_4_t>::value),
"Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private "
"convert dispatch.");
NumericArrayConverter<float, machete_uint8b128_t, PackedResultType::kElements,
Round>
convert_uint8_to_f32;
Array<float, PackedResultType::kElements> tmp =
convert_uint8_to_f32(source);
NumericArrayConverter<cutlass::bfloat16_t, float,
PackedResultType::kElements, Round>
convert_f32_to_bf16_;
return convert_f32_to_bf16_(tmp);
}
friend class detail::VectorizedConverter;
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
result_type result;
using ConverterType =
NumericArrayConverter<typename result_type::Element,
typename source_type::Element, N, Round>;
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
src_packed_4_t, result_packed_2_t,
src_packed_2_t>(result, source);
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
#endif
// for Array<int8_t, N> <= Array<cutlass::half_t, N>
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<int8_t, cutlass::half_t, N, Round> {
using result_type = Array<int8_t, N>;
using source_type = Array<cutlass::half_t, N>;
struct RegConvert {
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
template <typename PackedResultType, int src_regs>
CUTLASS_DEVICE static PackedResultType convert(
Array<uint32_t, src_regs> src) {
// Hold output int8s in reg. We need 1 reg for every 4 elements
using RegArray = cutlass::AlignedArray<
uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>;
RegArray r;
static constexpr uint32_t MAGIC_BIAS_ = 0x64806480;
auto MAGIC_BIAS = *reinterpret_cast<const half2*>(&MAGIC_BIAS_);
*reinterpret_cast<half2*>(&src[0]) =
__hadd2(*reinterpret_cast<half2*>(&src[0]), MAGIC_BIAS);
if constexpr (src_regs > 1) {
*reinterpret_cast<half2*>(&src[1]) =
__hadd2(*reinterpret_cast<half2*>(&src[1]), MAGIC_BIAS);
}
static_assert(PackedResultType::kElements <= 4);
uint32_t uint8s;
static constexpr uint32_t MASK_0246 = 0x6420;
static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080;
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
: "=r"(uint8s)
: "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]),
"n"(MASK_0246));
uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK);
return reinterpret_cast<PackedResultType&>(int8s);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
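The fp16 magic-number trick used by the uint4b8 -> half converters above can be sanity-checked with plain host arithmetic; a minimal sketch, illustrative only:

#include <cstdio>
// 0x6400 is the half 1024.0; OR-ing a 4-bit stored value v = x + 8 into its low
// mantissa bits yields the half (1024 + v), so subtracting 1032 (= 1024 + 8)
// recovers x. The same reasoning, with 128 and 136, applies to the bf16 path.
int main() {
  for (int x = -8; x <= 7; ++x) {
    int v = x + 8;                      // biased uint4b8 storage
    double as_half = 1024.0 + v;        // value represented by (0x6400 | v)
    double recovered = as_half - 1032.0;
    std::printf("x = %2d -> recovered = %g\n", x, recovered);
  }
  return 0;
}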

View File

@@ -0,0 +1,43 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
#include "cuda_bf16.h"
#include "machete_custom_types.cuh"
namespace cutlass {
template <typename T>
struct nameof {
static constexpr char const* value = "unknown";
};
template <typename T>
inline constexpr auto nameof_v = nameof<T>::value;
#define NAMEOF_TYPE(T) \
template <> \
struct nameof<T> { \
static constexpr char const* value = #T; \
};
NAMEOF_TYPE(float_e4m3_t)
NAMEOF_TYPE(float_e5m2_t)
NAMEOF_TYPE(half_t)
NAMEOF_TYPE(nv_bfloat16)
NAMEOF_TYPE(bfloat16_t)
NAMEOF_TYPE(float)
NAMEOF_TYPE(int4b_t)
NAMEOF_TYPE(int8_t)
NAMEOF_TYPE(int32_t)
NAMEOF_TYPE(int64_t)
NAMEOF_TYPE(machete_uint4b8_t)
NAMEOF_TYPE(uint4b_t)
NAMEOF_TYPE(uint8_t)
NAMEOF_TYPE(machete_uint8b128_t)
NAMEOF_TYPE(uint32_t)
NAMEOF_TYPE(uint64_t)
}; // namespace cutlass

View File

@@ -0,0 +1,161 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/torch_utils.hpp
#pragma once
#include "helper.h"
#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using RowMajor = typename cutlass::layout::RowMajor;
namespace cute {
namespace detail {
template <class T, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
seq<I...>) {
return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
}
template <class F, int... I>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq<I...>) {
return make_shape(f(I)...);
}
}; // namespace detail
template <class T, class F>
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
if constexpr (cute::is_tuple<T>::value) {
return detail::tapply_with_idx(
t, f, [](auto const&... a) { return cute::make_tuple(a...); },
tuple_seq<T>{});
} else {
return f(t);
}
CUTE_GCC_UNREACHABLE;
}
// calls: make_shape(f(0), f(1), ..., f(N-1))
template <int N, class F>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
return detail::make_shape_from_idx(f, make_seq<N>{});
}
}; // namespace cute
// Make a layout from a tensor with `rank(Stride{})`, where the shape is the
// shape of the passed in tensor and the strides are of type `Stride` and
// contain the strides of the passed in tensor, checking that any static strides
// in `Stride{}` match the strides of the passed in tensor.
// If `tensor.shape().size() < rank(Stride{})`, the shape is padded with 1s and the extra
// strides are set to be 0 or 1.
template <typename Stride>
static inline auto make_cute_layout(paddle::Tensor const& tensor,
std::string_view name = "tensor") {
PD_CHECK(tensor.shape().size() <= rank(Stride{}));
auto stride = cute::transform_with_idx(
Stride{}, [&](auto const& stride_ele, auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
if (idx < tensor.shape().size()) {
if constexpr (cute::is_static_v<StrideEle>) {
PD_CHECK(StrideEle::value == tensor.strides()[idx], "Expected ",
name, ".strides()[", idx, "] to be ", StrideEle::value, ", but got ", tensor.strides()[idx], ". ");
return StrideEle{};
} else {
if (tensor.shape()[idx] == 1) {
// use 0 stride for dims with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
return tensor.strides()[idx];
}
}
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
}
});
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
if (idx < tensor.shape().size())
return tensor.shape()[idx];
else
return int64_t(1);
});
return make_layout(shape, stride);
}
template <typename Stride>
static inline auto maybe_make_cute_layout(
std::optional<paddle::Tensor> const& tensor,
std::string_view name = "tensor") {
using Layout = decltype(make_cute_layout<Stride>(*tensor));
if (tensor) {
return std::optional<Layout>{make_cute_layout<Stride>(*tensor, name)};
} else {
return std::optional<Layout>{};
}
}
//
// Paddle dtype to Cutlass Type (equivalent_cutlass_type)
//
template <typename T>
struct equivalent_cutlass_type {
using type = T;
};
template <typename T>
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
template <>
struct equivalent_cutlass_type<phi::dtype::float16> {
using type = cutlass::half_t;
};
template <>
struct equivalent_cutlass_type<phi::dtype::bfloat16> {
using type = cutlass::bfloat16_t;
};
//
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
//
// Return a `phi::CppTypeToDataType<T>`-compatible type, i.e. get the Paddle C++
// type that is equivalent to T, e.g.: `cutlass::half_t -> phi::dtype::float16`
template <typename T>
struct equivalent_scalar_type {
using type = T;
};
template <typename T>
using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
template <>
struct equivalent_scalar_type<cutlass::half_t> {
using type = phi::dtype::float16;
};
template <>
struct equivalent_scalar_type<cutlass::bfloat16_t> {
using type = phi::dtype::bfloat16;
};
// get equivalent c10::ScalarType tag from compile time type
template <typename T>
static inline constexpr paddle::DataType equivalent_scalar_type_v =
phi::CppTypeToDataType<equivalent_scalar_type_t<T>>::Type();
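A brief usage sketch for make_cute_layout / maybe_make_cute_layout (illustrative only; it assumes this header has been included and that `B` is some contiguous row-major [N, K] tensor supplied by the caller):

// Illustrative only.
#include <optional>
void layout_example(paddle::Tensor const& B) {
  // Row-major [N, K]: dynamic leading stride, static unit inner stride; the
  // static cute::Int<1> is checked against B.strides()[1] at runtime.
  using StrideB = cute::Stride<int64_t, cute::Int<1>>;
  auto layout_B  = make_cute_layout<StrideB>(B, "B");                    // (N, K) : (stride_N, _1)
  auto no_layout = maybe_make_cute_layout<StrideB>(std::nullopt, "C");   // empty std::optional
  (void)layout_B;
  (void)no_layout;
}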

View File

@@ -0,0 +1,372 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include <optional>
#include <variant>
namespace machete {
//
// ScalarType can represent a wide range of floating point and integer types,
// in particular it can be used to represent sub-byte data types (something
// that torch.dtype currently does not support).
//
// The type definitions on the Python side can be found in vllm/scalar_type.py;
// these type definitions should be kept up to date with any Python API changes
// here.
//
class ScalarType {
public:
enum NanRepr : uint8_t {
NAN_NONE = 0, // nans are not supported
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
NAN_REPR_ID_MAX
};
constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
int32_t bias, bool finite_values_only = false,
NanRepr nan_repr = NAN_IEEE_754)
: exponent(exponent),
mantissa(mantissa),
signed_(signed_),
bias(bias),
finite_values_only(finite_values_only),
nan_repr(nan_repr) {};
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
return ScalarType(0, size_bits - 1, true, bias);
}
static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
return ScalarType(0, size_bits, false, bias);
}
// IEEE 754 compliant floating point type
static constexpr ScalarType float_IEEE754(uint8_t exponent,
uint8_t mantissa) {
// PADDLE_ENFORCE(mantissa > 0 && exponent > 0);
return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
}
// IEEE 754 non-compliant floating point type
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
bool finite_values_only,
NanRepr nan_repr) {
// PADDLE_ENFORCE(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
// PADDLE_ENFORCE(mantissa > 0 && exponent > 0);
// PADDLE_ENFORCE(nan_repr != NAN_IEEE_754,
// "use `float_IEEE754` constructor for floating point types that "
// "follow IEEE 754 conventions");
return ScalarType(exponent, mantissa, true, 0, finite_values_only,
nan_repr);
}
uint8_t const exponent; // size of the exponent field (0 for integer types)
uint8_t const mantissa; // size of the mantissa field (size of the integer
// excluding the sign bit for integer types)
bool const signed_; // flag if the type supports negative numbers (i.e. has a
// sign bit)
int32_t const bias; // stored values equal value + bias,
// used for quantized type
// Extra Floating point info
bool const finite_values_only; // i.e. no +/-inf if true
NanRepr const nan_repr; // how NaNs are represented
// (not applicable for integer types)
using Id = int64_t;
private:
// Field size in id
template <typename T_>
static constexpr size_t member_id_field_width() {
using T = std::decay_t<T_>;
return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
}
template <typename Fn, typename Init, typename Member, typename... Rest>
static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
Rest... rest) {
auto new_val = f(val, member);
if constexpr (sizeof...(rest) > 0) {
return reduce_members_helper(f, new_val, rest...);
} else {
return new_val;
};
}
template <typename Fn, typename Init>
constexpr auto reduce_members(Fn f, Init init) const {
// Should be in constructor order for `from_id`
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
finite_values_only, nan_repr);
};
template <typename Fn, typename Init>
static constexpr auto reduce_member_types(Fn f, Init init) {
constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
return dummy_type.reduce_members(f, init);
};
static constexpr auto id_size_bits() {
return reduce_member_types(
[](int acc, auto member) -> int {
return acc + member_id_field_width<decltype(member)>();
},
0);
}
public:
// unique id for this scalar type that can be computed at compile time for
// c++17 template specialization this is not needed once we migrate to
// c++20 and can pass literal classes as template parameters
constexpr Id id() const {
static_assert(id_size_bits() <= sizeof(Id) * 8,
"ScalarType id is too large to be stored");
auto or_and_advance = [](std::pair<Id, uint32_t> result,
auto member) -> std::pair<Id, uint32_t> {
auto [id, bit_offset] = result;
auto constexpr bits = member_id_field_width<decltype(member)>();
return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
<< bit_offset,
bit_offset + bits};
};
return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
}
// create a ScalarType from an id, for c++17 template specialization,
// this is not needed once we migrate to c++20 and can pass literal
// classes as template parameters
static constexpr ScalarType from_id(Id id) {
auto extract_and_advance = [id](auto result, auto member) {
using T = decltype(member);
auto [tuple, bit_offset] = result;
auto constexpr bits = member_id_field_width<T>();
auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
((uint64_t(1) << bits) - 1));
auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
};
auto [tuple_args, _] = reduce_member_types(extract_and_advance,
std::pair<std::tuple<>, int>{});
return std::apply([](auto... args) { return ScalarType(args...); },
tuple_args);
}
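// Round-trip example (illustrative only): this is how the binding code earlier
// in this diff moves a ScalarType across the pybind boundary as a plain int64_t,
//   constexpr ScalarType::Id u4b8_id = ScalarType::uint(4, 8).id();   // kU4B8.id()
//   constexpr ScalarType u4b8        = ScalarType::from_id(u4b8_id);  // == kU4B8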
constexpr int64_t size_bits() const {
return mantissa + exponent + is_signed();
}
constexpr bool is_signed() const { return signed_; }
constexpr bool is_integer() const { return exponent == 0; }
constexpr bool is_floating_point() const { return exponent > 0; }
constexpr bool is_ieee_754() const {
return is_floating_point() && finite_values_only == false &&
nan_repr == NAN_IEEE_754;
}
constexpr bool has_nans() const {
return is_floating_point() && nan_repr != NAN_NONE;
}
constexpr bool has_infs() const {
return is_floating_point() && finite_values_only == false;
}
constexpr bool has_bias() const { return bias != 0; }
private:
double _floating_point_max() const {
PADDLE_ENFORCE(mantissa <= 52 && exponent <= 11,
"Cannot represent max/min as a double for type ", str());
uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
max_mantissa -= 1;
}
uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
PADDLE_ENFORCE(exponent < 11,
"Cannot represent max/min as a double for type ", str());
max_exponent += 1;
}
// adjust the exponent to match that of a double
// for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
// is the exponent bits), there is some precedent for non-standard biases,
// example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
// but to avoid premature over complication we are just assuming the
// standard exponent bias until there is a need to support non-standard
// biases
uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11
uint64_t max_exponent_double =
max_exponent - exponent_bias + exponent_bias_double;
// shift the mantissa into the position for a double and
// the exponent
uint64_t double_raw =
(max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);
return *reinterpret_cast<double*>(&double_raw);
}
constexpr std::variant<int64_t, double> _raw_max() const {
if (is_floating_point()) {
return {_floating_point_max()};
} else {
// PADDLE_ENFORCE(size_bits() < 64 || size_bits() == 64 && is_signed(),
// "Cannot represent max as a int64_t");
return {(int64_t(1) << mantissa) - 1};
}
}
constexpr std::variant<int64_t, double> _raw_min() const {
if (is_floating_point()) {
// PADDLE_ENFORCE(is_signed(),
// "We currently assume all floating point types are signed");
constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);
double max = _floating_point_max();
uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
uint64_t min_raw = max_raw | sign_bit_double;
return {*reinterpret_cast<double*>(&min_raw)};
} else {
// PADDLE_ENFORCE(!is_signed() || size_bits() <= 64,
// "Cannot represent min as a int64_t");
if (is_signed()) {
// set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
// then perform an arithmetic shift right to set all the bits above
// (size_bits() - 1) to 1
return {INT64_MIN >> (64 - size_bits())};
} else {
return {int64_t(0)};
}
}
}
public:
// Max representable value for this scalar type.
// (accounting for bias if there is one)
constexpr std::variant<int64_t, double> max() const {
return std::visit(
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
_raw_max());
}
// Min representable value for this scalar type.
// (accounting for bias if there is one)
constexpr std::variant<int64_t, double> min() const {
return std::visit(
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
_raw_min());
}
std::string str() const {
/* naming generally follows: https://github.com/jax-ml/ml_dtypes
* for floating point types (leading f) the scheme is:
* `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
* flags:
* - no-flags: means it follows IEEE 754 conventions
* - f: means finite values only (no infinities)
* - n: means nans are supported (non-standard encoding)
* for integer types the scheme is:
* `[u]int<size_bits>[b<bias>]`
 * - if bias is not present it means it is zero
*/
if (is_floating_point()) {
auto ret = "float" + std::to_string(size_bits()) + "_e" +
std::to_string(exponent) + "m" + std::to_string(mantissa);
if (!is_ieee_754()) {
if (finite_values_only) {
ret += "f";
}
if (nan_repr != NAN_NONE) {
ret += "n";
}
}
return ret;
} else {
auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
if (has_bias()) {
ret += "b" + std::to_string(bias);
}
return ret;
}
}
constexpr bool operator==(ScalarType const& other) const {
return mantissa == other.mantissa && exponent == other.exponent &&
bias == other.bias && signed_ == other.signed_ &&
finite_values_only == other.finite_values_only &&
nan_repr == other.nan_repr;
}
};
using ScalarTypeId = machete::ScalarType::Id;
// "rust style" names generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
static inline constexpr auto kS4 = machete::ScalarType::int_(4);
static inline constexpr auto kU4 = machete::ScalarType::uint(4);
static inline constexpr auto kU4B8 = machete::ScalarType::uint(4, 8);
static inline constexpr auto kS8 = machete::ScalarType::int_(8);
static inline constexpr auto kU8 = machete::ScalarType::uint(8);
static inline constexpr auto kU8B128 = machete::ScalarType::uint(8, 128);
static inline constexpr auto kFE2M1f =
machete::ScalarType::float_(2, 1, true, machete::ScalarType::NAN_NONE);
static inline constexpr auto kFE3M2f =
machete::ScalarType::float_(3, 2, true, machete::ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn =
machete::ScalarType::float_(4, 3, true, machete::ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = machete::ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = machete::ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = machete::ScalarType::float_IEEE754(5, 10);
// Fixed width style names, generally following:
//   https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
constexpr auto kInt4 = kS4;
constexpr auto kUint4 = kU4;
constexpr auto kUint4b8 = kU4B8;
constexpr auto kInt8 = kS8;
constexpr auto kUint8 = kU8;
constexpr auto kUint8b128 = kU8B128;
constexpr auto kFloat4_e2m1f = kFE2M1f;
constexpr auto kFloat6_e3m2f = kFE3M2f;
constexpr auto kFloat8_e5m2 = kFE5M2;
constexpr auto kFloat16_e8m7 = kFE8M7;
constexpr auto kFloat16_e5m10 = kFE5M10;
// colloquial names
constexpr auto kHalf = kFE5M10;
constexpr auto kFloat16 = kHalf;
constexpr auto kFloat16Id = kFloat16.id();
constexpr auto kInt32 = phi::DataType::INT32;
constexpr auto kInt64 = phi::DataType::INT64;
constexpr auto kBool = phi::DataType::BOOL;
constexpr auto kFloat8_e4m3fn = phi::DataType::FLOAT8_E4M3FN;
constexpr auto kBFloat16 = phi::DataType::BFLOAT16;
constexpr auto kFloat32 = phi::DataType::FLOAT32;
constexpr auto kByte = phi::DataType::INT8;
}; // namespace machete
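For readers scanning the type table above: the `b<bias>` suffix (e.g. `kU4B8 = uint(4, 8)`, rendered as `uint4b8` by `str()`) means the stored unsigned field is offset by the bias, so `uint4b8` covers [-8, 7] with max() == 7 and min() == -8. A minimal NumPy sketch of that convention (illustration only, not part of the kernel sources; all names are local to this example):

# Sketch: the "b8" in uint4b8 means stored nibble = signed value + 8.
import numpy as np

bias = 8
signed_vals = np.arange(-8, 8)                   # the representable uint4b8 range
stored = (signed_vals + bias).astype(np.uint8)   # what lands in the 4-bit field
assert stored.min() == 0 and stored.max() == 15
decoded = stored.astype(np.int32) - bias         # subtracting the bias recovers the value
assert (decoded == signed_vals).all()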

View File

@@ -512,6 +512,8 @@ elif paddle.is_compiled_with_cuda():
sources += find_end_files("gpu_ops/w4afp8_gemm", ".cu")
os.system("python utils/auto_gen_wfp8afp8_sparse_gemm_kernel.py")
sources += find_end_files("gpu_ops/wfp8afp8_sparse_gemm", ".cu")
os.system("python gpu_ops/machete/generate.py")
sources += find_end_files("gpu_ops/machete", ".cu")
setup(
name="fastdeploy_ops",

View File

@@ -54,6 +54,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
# Set the MoE backend. "cutlass", "marlin" and "triton" are currently supported.
"FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
# Whether to use Machete for wint4 dense gemm.
"FD_USE_MACHETE": lambda: os.getenv("FD_USE_MACHETE", "0"),
# Set whether to disable recompute the request when the KV cache is full.
"FD_DISABLED_RECOVER": lambda: os.getenv("FD_DISABLED_RECOVER", "0"),
# Set triton kernel JIT compilation directory.

View File

@@ -15,9 +15,12 @@
"""
from .cutlass_scaled_mm import cutlass_scaled_mm
from .machete_mm import machete_quantize_and_pack, machete_wint_mm
from .scaled_fp8_quant import scaled_fp8_quant
__all__ = [
"cutlass_scaled_mm",
"scaled_fp8_quant",
"machete_wint_mm",
"machete_quantize_and_pack",
]

View File

@@ -0,0 +1,185 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import numpy as np
import paddle
from fastdeploy.platforms import current_platform
def get_sm_version():
prop = paddle.device.cuda.get_device_properties()
cc = prop.major * 10 + prop.minor
return cc
if current_platform.is_cuda() and get_sm_version() == 90:
from fastdeploy.model_executor.ops.gpu import machete_mm, machete_prepack_B
def get_pack_factor(num_bits):
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
def pack_rows(
q_w: paddle.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == [size_k, size_n]
pack_factor = get_pack_factor(num_bits)
assert size_k % pack_factor == 0
orig_device = q_w.place
q_w_np = q_w.numpy().astype(np.uint32)
q_res = np.zeros((size_k // pack_factor, size_n), dtype=np.uint32)
for i in range(pack_factor):
q_res |= q_w_np[i::pack_factor, :] << num_bits * i
q_res = paddle.to_tensor(q_res.astype(np.int32), place=orig_device)
return q_res
def quantize_weights(
w: paddle.Tensor,
group_size: Optional[int],
quant_type: str = "uint4b8",
):
"""
Quantize weights for Machete weight-only GEMM.
Args:
    w: Input weight tensor of shape [size_k, size_n] (must be a float dtype).
    group_size: Group size for quantization. If `-1`, channel-wise quantization is used (group_size = size_k).
    quant_type: Target quantization type (`uint4` or `uint4b8`).
Returns:
    w_q: Quantized weights (int32 tensor holding the quantized values).
    w_s: Quantization scales of shape [size_k // group_size, size_n].
"""
assert paddle.is_floating_point(w), "w must be float type"
assert quant_type in ["uint4", "uint4b8"], "only support quant_type = uint4, uint4b8"
orig_device = w.place
size_k, size_n = w.shape
if group_size == -1:
group_size = size_k
# Reshape to [group_size, -1]
if group_size is not None and group_size < size_k:
w = w.reshape([-1, group_size, size_n])
w = w.transpose([1, 0, 2])
w = w.reshape([group_size, -1])
# Compute scale for each group
max_val = paddle.max(w, axis=0, keepdim=True)
min_val = paddle.min(w, axis=0, keepdim=True)
max_q_val = float(7.0)
min_q_val = float(-8.0)
w_s = paddle.ones([1], dtype=paddle.float32) # unscaled case
if group_size is not None:
# Avoid division by zero
max_scale = paddle.maximum(
paddle.abs(max_val / (max_q_val if max_q_val != 0 else float("inf"))),
paddle.abs(min_val / (min_q_val if min_q_val != 0 else float("inf"))),
)
w_s = max_scale
# Quantize
w_q = paddle.round(w / w_s).astype(paddle.int32)
w_q = paddle.clip(w_q, min_q_val, max_q_val)
# if hasattr(quant_type, 'bias'): # Custom quantization bias (if applicable)
# w_q += quant_type.bias
if quant_type == "uint4b8":
w_q += 8
# Restore original shapes
if group_size is not None and group_size < size_k:
def reshape_w(w_tensor):
w_tensor = w_tensor.reshape([group_size, -1, size_n])
w_tensor = w_tensor.transpose([1, 0, 2])
w_tensor = w_tensor.reshape([size_k, size_n])
return w_tensor
w_q = reshape_w(w_q)
w_s = w_s.reshape([-1, size_n])
# Move tensors back to original device
w_q = w_q.to(orig_device)
if w_s is not None:
w_s = w_s.to(orig_device)
return w_q, w_s
def machete_quantize_and_pack(
w: paddle.Tensor,
atype: paddle.dtype,
quant_type: str = "uint4b8",
scale_type: str = "",
group_size: int = -1,
):
w_q, w_s = quantize_weights(w, group_size, quant_type=quant_type)
w_q = pack_rows(w_q, 4, *w_q.shape)
w_q_col = w_q.transpose([1, 0]).contiguous() # convert to col major
w_q_prepack = machete_prepack_B(
w_q_col,
atype,
quant_type,
scale_type,
)[0]
return w_q_prepack, w_s
def machete_wint_mm(
x: paddle.Tensor,
w_prepack: paddle.Tensor,
w_g_s: paddle.Tensor,
w_g_zp: Optional[paddle.Tensor] = None,
w_ch_s: Optional[paddle.Tensor] = None,
w_tok_s: Optional[paddle.Tensor] = None,
weight_dtype: str = "uint4b8",
group_size: int = -1,
out_dtype: str = "",
scheduler: str = "",
):
out = machete_mm(
x,
w_prepack,
w_g_s, # group scales
w_g_zp, # group zeros
w_ch_s, # per-channel scale
w_tok_s, # per-token scale
weight_dtype, # weight_dtype
out_dtype, # out_dtype
group_size, # group_size
scheduler, # scheduler
)[0]
return out
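pack_rows above stores eight 4-bit values per int32 along the K dimension (pack_factor = 32 // num_bits), with element j of each group of eight placed at bit offset 4*j. A NumPy sketch of the same layout plus the inverse unpacking, for illustration only (names are local to this example):

# Sketch of the row-packing layout used by pack_rows() for num_bits=4:
# eight consecutive K values share one 32-bit word, value j at bits [4*j, 4*j + 4).
import numpy as np

num_bits = 4
pack_factor = 32 // num_bits
q = np.random.randint(0, 16, size=(16, 4)).astype(np.uint32)  # toy [size_k, size_n] values in [0, 15]

packed = np.zeros((q.shape[0] // pack_factor, q.shape[1]), dtype=np.uint32)
for j in range(pack_factor):
    packed |= q[j::pack_factor, :] << np.uint32(num_bits * j)

# unpack to verify the layout round-trips
unpacked = np.empty_like(q)
for j in range(pack_factor):
    unpacked[j::pack_factor, :] = (packed >> np.uint32(num_bits * j)) & np.uint32(0xF)
assert (unpacked == q).all()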

View File

@@ -21,6 +21,7 @@ from typing import Optional
import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize
from fastdeploy import envs
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -33,6 +34,12 @@ from ..utils import get_tensor
from .quant_base import QuantConfigBase, QuantMethodBase
def get_sm_version():
prop = paddle.device.cuda.get_device_properties()
cc = prop.major * 10 + prop.minor
return cc
class WeightOnlyConfig(QuantConfigBase):
"""
Quantization config for weight only
@@ -132,6 +139,14 @@ class WeightOnlyConfig(QuantConfigBase):
else:
raise ValueError(f"Unsupported MOE backend {layer.use_method}")
else:
if (
self.name() == "wint4"
and envs.FD_USE_MACHETE == "1"
and get_sm_version() == 90
and layer.weight_shape[1]
and layer.weight_shape[1] % 128 == 0
):
return MacheteWeightOnlyLinearMethod(self)
return GPUWeightOnlyLinearMethod(self)
@@ -329,3 +344,73 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
quanted_weight_tensor = paddle.transpose(quanted_weight_tensor, [1, 0])
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))
class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
Weight-only quantization method for linear layers on GPU using Machete.
Weights are loaded in bf16/fp16; after loading, per-channel scales are computed
and the weights are quantized to 4-bit (uint4b8) and prepacked for the Machete kernel.
"""
def __init__(
self,
quant_config: WeightOnlyConfig,
) -> None:
super().__init__(quant_config)
def create_weights(self, layer, **extra_weight_attrs):
assert layer.bias is None, "Machete weight only linear method does not support bias."
assert self.quant_config.name() == "wint4", "Machete weight only linear method only supports wint4."
# With per-channel quantization, the scale has one entry per output channel of the weight.
weight_scale_shape = [1, layer.weight_shape[1]]
# layer.weight_shape.reverse()
if self.quant_config.name() == "wint4":
layer.weight_shape[0] //= 8
layer.weight_dtype = "int32"
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale = layer.create_parameter(
shape=weight_scale_shape,
dtype=layer._dtype,
is_bias=False,
)
def process_prequanted_weights(self, layer, state_dict) -> None:
pass
def process_loaded_weights(self, layer, weight) -> None:
from fastdeploy.model_executor.layers.quantization.ops import (
machete_quantize_and_pack,
)
quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
w=weight,
atype=layer._dtype,
quant_type="uint4b8",
)
layer.weight.set_value(quanted_weight_tensor)
layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))
def apply(self, layer, x):
assert layer.bias is None, "Machete weight only linear method does not support bias."
assert self.quant_config.name() == "wint4", "Machete weight only linear method only supports wint4."
from fastdeploy.model_executor.layers.quantization.ops import machete_wint_mm
linear_out = machete_wint_mm(
x,
w_prepack=layer.weight,
w_g_s=layer.weight_scale,
weight_dtype="uint4b8",
)
return linear_out
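To summarize the dispatch and shape bookkeeping introduced above: the Machete method is chosen only for wint4 with FD_USE_MACHETE=1 on SM90 when the output dim is divisible by 128; when chosen, the stored weight's K dim shrinks by 8x (eight int4 values per int32) and one per-channel scale is kept per output feature. A compact sketch under those assumptions (the helper name is hypothetical, not part of FastDeploy's API):

# Illustrative helper mirroring the selection rule and packed shapes above.
def machete_wint4_layout(weight_shape, use_machete=True, sm=90):
    in_features, out_features = weight_shape
    eligible = use_machete and sm == 90 and out_features % 128 == 0
    if not eligible:
        return None
    packed_weight_shape = [in_features // 8, out_features]  # eight int4 values per int32
    weight_scale_shape = [1, out_features]                  # one per-channel scale per output feature
    return packed_weight_shape, weight_scale_shape

# e.g. a [7168, 1024] bf16 weight is stored as an int32 tensor of shape [896, 1024]
assert machete_wint4_layout([7168, 1024]) == ([896, 1024], [1, 1024])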

View File

@@ -0,0 +1,174 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import struct
import unittest
import numpy as np
import paddle
import paddle.nn.quant as Q
from paddle import base
from paddle.base import core
from paddle.framework import set_default_dtype
from fastdeploy.model_executor.layers.quantization.ops import (
machete_quantize_and_pack,
machete_wint_mm,
)
np.random.seed(123)
paddle.seed(123)
def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r"release (\S+),"
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split(".")
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
def get_sm_version():
prop = paddle.device.cuda.get_device_properties()
cc = prop.major * 10 + prop.minor
return cc
def convert_uint16_to_float(in_list):
in_list = np.asarray(in_list)
out = np.vectorize(
lambda x: struct.unpack("<f", struct.pack("<I", np.uint32(x) << np.uint32(16)))[0],
otypes=[np.float32],
)(in_list.flat)
return np.reshape(out, in_list.shape)
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_sm_version() < 90,
"Machete only supports SM90.",
)
class WeightOnlyLinearTestCase(unittest.TestCase):
def config(self):
self.dtype = "float16"
self.rtol = 1e-5
self.atol = 1e-2
self.bias = False
self.batch = 1
self.token = 512
self.in_features = 7168
self.out_features = 1024
self.weight_dtype = "int4"
self.static = False
self.group_size = -1
def setUp(self):
self.config()
if self.dtype == "bfloat16" or self.weight_dtype == "int4":
self.atol = 1.3e-1
x = np.random.random((self.token, self.in_features))
self.x = paddle.to_tensor(x, dtype=self.dtype)
if self.bias:
bias_attr = base.ParamAttr(
trainable=False,
regularizer=None,
initializer=paddle.nn.initializer.Constant(value=1.0),
)
else:
bias_attr = None
set_default_dtype(self.dtype)
self.linear = paddle.nn.Linear(self.in_features, self.out_features, bias_attr=bias_attr)
self.bias = self.linear.bias
self.weight = self.linear.weight
self.float_weight = self.linear.weight
self.weight_scale = None
self.weight, self.weight_scale = Q.weight_quantize(
(self.float_weight.cuda() if self.weight_dtype == "int8" else self.weight.cpu()),
algo=("weight_only_int8" if self.weight_dtype == "int8" else "weight_only_int4"),
group_size=self.group_size,
)
def get_linear_out(self):
out = self.linear(self.x)
return out.numpy()
def get_weight_only_linear_out(self):
out = Q.weight_only_linear(
self.x,
self.weight,
bias=self.bias,
weight_scale=self.weight_scale,
weight_dtype=self.weight_dtype,
group_size=self.group_size,
)
return out.numpy()
def get_machete_weight_only_linear_out(self):
w_q, w_s = machete_quantize_and_pack(
w=self.float_weight.cuda(),
atype=self.dtype,
quant_type="uint4b8",
)
out = machete_wint_mm(
self.x,
w_prepack=w_q,
w_g_s=w_s, # group scales
weight_dtype="uint4b8", # weight_dtype
)
return out.numpy()
def test_weight_only_linear(self):
# out_expect = self.get_linear_out()
out_paddle = self.get_weight_only_linear_out()
out_machete = self.get_machete_weight_only_linear_out()
if self.dtype == "bfloat16":
out_paddle = convert_uint16_to_float(out_paddle)
# out_expect = convert_uint16_to_float(out_expect)
out_machete = convert_uint16_to_float(out_machete)
np.testing.assert_allclose(out_paddle, out_machete, rtol=self.rtol, atol=self.atol)
M = [32, 128]
K_N = [[2048, 4096]]
def make_case(m, k, n):
class Case(WeightOnlyLinearTestCase):
def config(self, _m=m, _k=k, _n=n):
super().config()
self.token = m
self.in_features = k
self.out_features = n
Case.name = f"WeightOnlyLinearTestCase{m}{k}{n}"
return Case
for k, n in K_N:
for m in M:
cls = make_case(m, k, n)
globals()[cls.name] = cls
if __name__ == "__main__":
unittest.main()