[Optimize] support machete weight only gemm (#3561)

* support machete weight only gemm
* add generate
* update
* fix
* change file location
* add sm_version limit
* fix
* fix
* fix ci
* fix coverage
* fix xpu
3
.gitignore
vendored
@@ -159,6 +159,9 @@ custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cute
#marlin_kernel
custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_*.cu

#machete_kernel
custom_ops/gpu_ops/machete/generated

# buff
custom_ops/tmp*
@@ -129,6 +129,24 @@ paddle::Tensor FusedExpertMoeFunc(
    const std::string &quant_method, const int moe_topk,
    const bool norm_topk_prob, const bool group_moe);

std::vector<paddle::Tensor> MacheteMMKernel(
    paddle::Tensor const& A, paddle::Tensor const& B,
    paddle::optional<paddle::Tensor> const& maybe_group_scales,
    paddle::optional<paddle::Tensor> const& maybe_group_zeros,
    paddle::optional<paddle::Tensor> const& maybe_channel_scales,
    paddle::optional<paddle::Tensor> const& maybe_token_scales,
    std::string const& b_type_str,
    std::string const& maybe_out_type_str,
    int64_t const& maybe_group_size,
    std::string const& maybe_schedule);

std::vector<paddle::Tensor> MachetePrepackBKernel(
    paddle::Tensor const& B, std::string const& a_type_str, std::string const& b_type_str,
    std::string const& maybe_group_scales_type_str);

std::vector<std::string> MacheteSupportedSchedules(
    std::string const& a_type_str, std::string const& b_type_str);

std::vector<paddle::Tensor> MoeExpertDispatch(
    const paddle::Tensor &input, const paddle::Tensor &gating_output,
    const paddle::optional<paddle::Tensor> &gating_correction_bias,
@@ -924,6 +942,25 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
        py::arg("recv_expert_count"), py::arg("block_size"),
        "per token per block quant");

  /* machete/machete_mm.cu
   * machete_mm
   */
  m.def("machete_mm", &MacheteMMKernel, py::arg("A"), py::arg("B"), py::arg("maybe_group_scale"),
        py::arg("maybe_group_zeros"), py::arg("maybe_channel_scales"), py::arg("maybe_token_scales"),
        py::arg("b_type_str"), py::arg("maybe_out_type_str"), py::arg("maybe_group_size"),
        py::arg("maybe_schedule"),
        "machete mm function");

  /* machete/machete_prepack_B.cu
   * machete_prepack_B
   */
  m.def("machete_prepack_B", &MachetePrepackBKernel, "machete prepacked B function");

  /* machete/machete_supported_schedules.cu
   * machete_supported_schedules
   */
  m.def("machete_supported_schedules", &MacheteSupportedSchedules, "machete supported schedules function");

  /**
   * moe/fused_moe/moe_topk_select.cu
   * moe_topk_select
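For orientation, here is a rough sketch of how the three bindings registered above could be driven from Python, written against the argument names in the m.def calls. The import path, tensor shapes, int4 packing convention, and the quantization step are illustrative assumptions, not part of this diff.

# Hypothetical usage of the new machete bindings (illustrative only).
import paddle
import fastdeploy_ops  # the compiled custom-op module; exact import path is an assumption

M, K, N, group = 16, 4096, 4096, 128
A = paddle.randn([M, K], dtype="float16")
B_packed = paddle.cast(paddle.randint(0, 255, [K // 2, N]), "uint8")  # 2 x int4 per byte (assumed layout)
group_scales = paddle.randn([K // group, N], dtype="float16")

# One-time reformat of the quantized weight into machete's prepacked layout.
B_prepacked = fastdeploy_ops.machete_prepack_B(B_packed, "float16", "uint4b8", "float16")[0]

# Optional: list the schedules the generated kernels expose for this type combination.
print(fastdeploy_ops.machete_supported_schedules("float16", "uint4b8"))

# Weight-only GEMM, C = A @ dequant(B); an empty schedule string means "use the built-in heuristic".
C = fastdeploy_ops.machete_mm(
    A, B_prepacked,
    maybe_group_scale=group_scales, maybe_group_zeros=None,
    maybe_channel_scales=None, maybe_token_scales=None,
    b_type_str="uint4b8", maybe_out_type_str="float16",
    maybe_group_size=group, maybe_schedule="")[0]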
@@ -6,6 +6,8 @@
// clang-format off
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"

#include "helper.h"
// clang-format on

/*
574
custom_ops/gpu_ops/machete/generate.py
Normal file
@@ -0,0 +1,574 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from collections.abc import Iterable
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, fields
|
||||
from functools import reduce
|
||||
from typing import Optional, Union
|
||||
|
||||
import jinja2
|
||||
|
||||
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
p = os.path.abspath(os.path.join(cur_dir, "../../third_party/cutlass/python"))
|
||||
sys.path.insert(0, p)
|
||||
|
||||
from cutlass_library import (
|
||||
EpilogueScheduleTag,
|
||||
EpilogueScheduleType,
|
||||
TileSchedulerTag,
|
||||
TileSchedulerType,
|
||||
)
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from machete_cutlass_library_extension import (
|
||||
DataType,
|
||||
MACHETEDataType,
|
||||
MACHETEDataTypeMACHETEScalarTypeTag,
|
||||
MACHETEDataTypeNames,
|
||||
MACHETEDataTypePaddleDataTypeTag,
|
||||
MACHETEDataTypeSize,
|
||||
MACHETEDataTypeTag,
|
||||
MACHETEKernelScheduleTag,
|
||||
MixedInputKernelScheduleType,
|
||||
)
|
||||
|
||||
# yapf: enable
|
||||
|
||||
#
|
||||
# Generator templating
|
||||
#
|
||||
|
||||
DISPATCH_TEMPLATE = """
|
||||
#include "../machete_mm_launcher.cuh"
|
||||
|
||||
namespace machete {
|
||||
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set type_sig = gen_type_sig(impl_config.types) -%}
|
||||
{% for s in impl_config.schedules %}
|
||||
extern paddle::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs);
|
||||
{%- endfor %}
|
||||
|
||||
paddle::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
|
||||
[[maybe_unused]] auto M = args.A.shape()[0];
|
||||
[[maybe_unused]] auto N = args.B.shape()[1];
|
||||
[[maybe_unused]] auto K = args.A.shape()[1];
|
||||
|
||||
if (!args.maybe_schedule) {
|
||||
{%- for cond, s in impl_config.heuristic %}
|
||||
{%if cond is not none%}if ({{cond}})
|
||||
{%- else %}else
|
||||
{%- endif %}
|
||||
return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %}
|
||||
}
|
||||
|
||||
{%- for s in impl_config.schedules %}
|
||||
if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}")
|
||||
return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);
|
||||
{%- endfor %}
|
||||
PADDLE_ENFORCE(false, "machete_gemm(..) is not implemented ");
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
|
||||
static inline std::optional<paddle::DataType> maybe_scalartype(
|
||||
std::optional<paddle::Tensor> const& t) {
|
||||
if (!t) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return t->dtype();
|
||||
};
|
||||
}
|
||||
|
||||
paddle::Tensor mm_dispatch(MMArgs args) {
|
||||
auto out_type = args.maybe_out_type.value_or(args.A.dtype());
|
||||
auto a_type = args.A.dtype();
|
||||
auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales);
|
||||
auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros);
|
||||
auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales);
|
||||
auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales);
|
||||
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
{% set type_sig = gen_type_sig(t) -%}
|
||||
if (args.b_type == {{MACHETEScalarTypeTag[t.b]}}
|
||||
&& a_type == {{PaddleTypeTag[t.a]}}
|
||||
&& out_type == {{PaddleTypeTag[t.out]}}
|
||||
&& {%if t.b_group_scale != void -%}
|
||||
maybe_g_scales_type == {{PaddleTypeTag[t.b_group_scale]}}
|
||||
{%- else %}!maybe_g_scales_type{%endif%}
|
||||
&& {%if t.b_group_zeropoint != void -%}
|
||||
maybe_g_zeros_type == {{PaddleTypeTag[t.b_group_zeropoint]}}
|
||||
{%- else %}!maybe_g_zeros_type{%endif%}
|
||||
&& {%if t.b_channel_scale != void -%}
|
||||
maybe_ch_scales_type == {{PaddleTypeTag[t.b_channel_scale]}}
|
||||
{%- else %}!maybe_ch_scales_type{%endif%}
|
||||
&& {%if t.a_token_scale != void -%}
|
||||
maybe_tok_scales_type == {{PaddleTypeTag[t.a_token_scale]}}
|
||||
{%- else %}!maybe_tok_scales_type{%endif%}
|
||||
) {
|
||||
return mm_dispatch_{{type_sig}}(args);
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
PADDLE_ENFORCE(
|
||||
false, "machete_mm(..) is not implemented "
|
||||
"; implemented types are: \\n",
|
||||
{%- for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
"\\t{{gen_type_option_name(t)}}\\n",
|
||||
{%- endfor %}
|
||||
"");
|
||||
}
|
||||
|
||||
std::vector<std::string> supported_schedules_dispatch(
|
||||
SupportedSchedulesArgs args) {
|
||||
auto out_type = args.maybe_out_type.value_or(args.a_type);
|
||||
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
{% set schs = impl_config.schedules -%}
|
||||
if (args.b_type == {{MACHETEScalarTypeTag[t.b]}}
|
||||
&& args.a_type == {{PaddleTypeTag[t.a]}}
|
||||
&& out_type == {{PaddleTypeTag[t.out]}}
|
||||
&& {%if t.b_group_scale != void -%}
|
||||
args.maybe_group_scales_type == {{PaddleTypeTag[t.b_group_scale]}}
|
||||
{%- else %}!args.maybe_group_scales_type{%endif%}
|
||||
&& {%if t.b_group_zeropoint != void-%}
|
||||
args.maybe_group_zeros_type == {{PaddleTypeTag[t.b_group_zeropoint]}}
|
||||
{%- else %}!args.maybe_group_zeros_type{%endif%}
|
||||
) {
|
||||
return {
|
||||
{%- for s in impl_config.schedules %}
|
||||
"{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %}
|
||||
{%- endfor %}
|
||||
};
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
return {};
|
||||
};
|
||||
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
IMPL_TEMPLATE = """
|
||||
#include "../machete_mm_launcher.cuh"
|
||||
|
||||
namespace machete {
|
||||
|
||||
{% for sch in unique_schedules(impl_configs) %}
|
||||
{% set sch_sig = gen_sch_sig(sch) -%}
|
||||
struct sch_{{sch_sig}} {
|
||||
using TileShapeNM = Shape<{{
|
||||
to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
|
||||
using ClusterShape = Shape<{{
|
||||
to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
|
||||
// TODO: Reimplement
|
||||
// using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
|
||||
using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
|
||||
using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
};
|
||||
{% endfor %}
|
||||
|
||||
{% for impl_config in impl_configs %}
|
||||
{% set t = impl_config.types -%}
|
||||
{% set schs = impl_config.schedules -%}
|
||||
{% set type_sig = gen_type_sig(t) -%}
|
||||
|
||||
template<typename Sch>
|
||||
using Kernel_{{type_sig}} = MacheteKernelTemplate<
|
||||
{{DataTypeTag[t.a]}}, // ElementA
|
||||
{{DataTypeTag[t.b]}}, // ElementB
|
||||
{{DataTypeTag[t.out]}}, // ElementD
|
||||
{{DataTypeTag[t.accumulator]}}, // Accumulator
|
||||
{{DataTypeTag[t.b_group_scale]}}, // GroupScaleT
|
||||
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
|
||||
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
|
||||
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
|
||||
Sch>;
|
||||
|
||||
{% for sch in schs %}
|
||||
{% set sch_sig = gen_sch_sig(sch) -%}
|
||||
paddle::Tensor
|
||||
impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) {
|
||||
return run_impl<Kernel_{{type_sig}}<sch_{{sch_sig}}>>(args);
|
||||
}
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
PREPACK_TEMPLATE = """
|
||||
#include "../machete_prepack_launcher.cuh"
|
||||
|
||||
namespace machete {
|
||||
|
||||
paddle::Tensor prepack_B_dispatch(PrepackBArgs args) {
|
||||
auto convert_type = args.maybe_group_scales_type.value_or(args.a_type);
|
||||
{%- for t in types %}
|
||||
{% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %}
|
||||
if (args.a_type == {{PaddleTypeTag[t.a]}}
|
||||
&& args.b_type.size_bits() == {{t.b_num_bits}}
|
||||
&& convert_type == {{PaddleTypeTag[t.convert]}}) {
|
||||
return prepack_impl<
|
||||
PrepackedLayoutBTemplate<
|
||||
{{DataTypeTag[t.a]}}, // ElementA
|
||||
{{DataTypeTag[b_type]}}, // ElementB
|
||||
{{DataTypeTag[t.convert]}}, // ElementConvert
|
||||
{{DataTypeTag[t.accumulator]}}, // Accumulator
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative>
|
||||
>(args.B);
|
||||
}
|
||||
{%- endfor %}
|
||||
|
||||
PADDLE_ENFORCE(false,
|
||||
"prepack_B_dispatch(..) is not implemented");
|
||||
}
|
||||
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
|
||||
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleConfig:
|
||||
tile_shape_mn: tuple[int, int]
|
||||
cluster_shape_mnk: tuple[int, int, int]
|
||||
kernel_schedule: MixedInputKernelScheduleType
|
||||
epilogue_schedule: EpilogueScheduleType
|
||||
tile_scheduler: TileSchedulerType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TypeConfig:
|
||||
a: DataType
|
||||
b: Union[DataType, MACHETEDataType]
|
||||
b_group_scale: DataType
|
||||
b_group_zeropoint: DataType
|
||||
b_channel_scale: DataType
|
||||
a_token_scale: DataType
|
||||
out: DataType
|
||||
accumulator: DataType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PrepackTypeConfig:
|
||||
a: DataType
|
||||
b_num_bits: int
|
||||
convert: DataType
|
||||
accumulator: DataType
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImplConfig:
|
||||
types: TypeConfig
|
||||
schedules: list[ScheduleConfig]
|
||||
heuristic: list[tuple[Optional[str], ScheduleConfig]]
|
||||
|
||||
|
||||
def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||
tile_shape = f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
|
||||
cluster_shape = (
|
||||
f"{schedule_config.cluster_shape_mnk[0]}"
|
||||
+ f"x{schedule_config.cluster_shape_mnk[1]}"
|
||||
+ f"x{schedule_config.cluster_shape_mnk[2]}"
|
||||
)
|
||||
kernel_schedule = MACHETEKernelScheduleTag[schedule_config.kernel_schedule].split("::")[-1]
|
||||
epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split("::")[-1]
|
||||
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1]
|
||||
|
||||
return f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + f"_{epilogue_schedule}_{tile_scheduler}"
|
||||
|
||||
|
||||
# mostly unique shorter sch_sig
|
||||
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||
kernel_terse_names_replace = {
|
||||
"KernelTmaWarpSpecializedCooperative": "TmaMI_",
|
||||
"TmaWarpSpecializedCooperative_": "TmaCoop_",
|
||||
"StreamKScheduler": "streamK",
|
||||
}
|
||||
|
||||
sch_sig = generate_sch_sig(schedule_config)
|
||||
for orig, terse in kernel_terse_names_replace.items():
|
||||
sch_sig = sch_sig.replace(orig, terse)
|
||||
return sch_sig
|
||||
|
||||
|
||||
# unique type_name
|
||||
def generate_type_signature(kernel_types: TypeConfig):
|
||||
return str("".join([MACHETEDataTypeNames[getattr(kernel_types, field.name)] for field in fields(TypeConfig)]))
|
||||
|
||||
|
||||
def generate_type_option_name(kernel_types: TypeConfig):
|
||||
return ", ".join(
|
||||
[
|
||||
f"{field.name.replace('b_', 'with_')+'_type'}=" + MACHETEDataTypeNames[getattr(kernel_types, field.name)]
|
||||
for field in fields(TypeConfig)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def is_power_of_two(n):
|
||||
return (n != 0) and (n & (n - 1) == 0)
|
||||
|
||||
|
||||
def to_cute_constant(value: list[int]):
|
||||
|
||||
def _to_cute_constant(value: int):
|
||||
if is_power_of_two(value):
|
||||
return f"_{value}"
|
||||
else:
|
||||
return f"Int<{value}>"
|
||||
|
||||
if isinstance(value, Iterable):
|
||||
return [_to_cute_constant(value) for value in value]
|
||||
else:
|
||||
return _to_cute_constant(value)
|
||||
|
||||
|
||||
def unique_schedules(impl_configs: list[ImplConfig]):
|
||||
return list(set(sch for impl_config in impl_configs for sch in impl_config.schedules))
|
||||
|
||||
|
||||
def unsigned_type_with_bitwidth(num_bits):
|
||||
return {
|
||||
4: DataType.u4,
|
||||
8: DataType.u8,
|
||||
16: DataType.u16,
|
||||
32: DataType.u32,
|
||||
64: DataType.u64,
|
||||
}[num_bits]
|
||||
|
||||
|
||||
template_globals = {
|
||||
"void": DataType.void,
|
||||
"DataTypeTag": MACHETEDataTypeTag,
|
||||
"MACHETEScalarTypeTag": MACHETEDataTypeMACHETEScalarTypeTag,
|
||||
"PaddleTypeTag": MACHETEDataTypePaddleDataTypeTag,
|
||||
"KernelScheduleTag": MACHETEKernelScheduleTag,
|
||||
"EpilogueScheduleTag": EpilogueScheduleTag,
|
||||
"TileSchedulerTag": TileSchedulerTag,
|
||||
"to_cute_constant": to_cute_constant,
|
||||
"gen_sch_sig": generate_terse_sch_sig,
|
||||
"gen_type_sig": generate_type_signature,
|
||||
"unique_schedules": unique_schedules,
|
||||
"unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
|
||||
"gen_type_option_name": generate_type_option_name,
|
||||
}
|
||||
|
||||
|
||||
def create_template(template_str):
|
||||
template = jinja2.Template(template_str)
|
||||
template.globals.update(template_globals)
|
||||
return template
|
||||
|
||||
|
||||
mm_dispatch_template = create_template(DISPATCH_TEMPLATE)
|
||||
mm_impl_template = create_template(IMPL_TEMPLATE)
|
||||
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
|
||||
|
||||
|
||||
def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
|
||||
sources = []
|
||||
|
||||
sources.append(
|
||||
(
|
||||
"machete_mm_dispatch",
|
||||
mm_dispatch_template.render(impl_configs=impl_configs),
|
||||
)
|
||||
)
|
||||
|
||||
prepack_types = []
|
||||
for impl_config in impl_configs:
|
||||
convert_type = (
|
||||
impl_config.types.a
|
||||
if impl_config.types.b_group_scale == DataType.void
|
||||
else impl_config.types.b_group_scale
|
||||
)
|
||||
prepack_types.append(
|
||||
PrepackTypeConfig(
|
||||
a=impl_config.types.a,
|
||||
b_num_bits=MACHETEDataTypeSize[impl_config.types.b],
|
||||
convert=convert_type,
|
||||
accumulator=impl_config.types.accumulator,
|
||||
)
|
||||
)
|
||||
|
||||
def prepacked_type_key(prepack_type: PrepackTypeConfig):
|
||||
# For now we can just use the first accumulator type seen since
|
||||
# the tensor core shapes/layouts don't vary based on accumulator
|
||||
# type so we can generate less code this way
|
||||
return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert)
|
||||
|
||||
unique_prepack_types = []
|
||||
prepack_types_seen = set()
|
||||
for prepack_type in prepack_types:
|
||||
key = prepacked_type_key(prepack_type)
|
||||
if key not in prepack_types_seen:
|
||||
unique_prepack_types.append(prepack_type)
|
||||
prepack_types_seen.add(key)
|
||||
|
||||
sources.append(
|
||||
(
|
||||
"machete_prepack",
|
||||
prepack_dispatch_template.render(
|
||||
types=unique_prepack_types,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Split up impls across files
|
||||
num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
|
||||
num_impls_per_file = math.ceil(num_impls / num_impl_files)
|
||||
|
||||
files_impls: list[list[ImplConfig]] = [[]]
|
||||
|
||||
curr_num_impls_assigned = 0
|
||||
curr_impl_in_file = 0
|
||||
curr_impl_configs = deepcopy(list(reversed(impl_configs)))
|
||||
|
||||
while curr_num_impls_assigned < num_impls:
|
||||
room_left_in_file = num_impls_per_file - curr_impl_in_file
|
||||
if room_left_in_file == 0:
|
||||
files_impls.append([])
|
||||
room_left_in_file = num_impls_per_file
|
||||
curr_impl_in_file = 0
|
||||
|
||||
curr_ic = curr_impl_configs[-1]
|
||||
if len(curr_ic.schedules) >= room_left_in_file:
|
||||
# Break apart the current impl config
|
||||
tmp_ic = deepcopy(curr_ic)
|
||||
tmp_ic.schedules = curr_ic.schedules[:room_left_in_file]
|
||||
curr_ic.schedules = curr_ic.schedules[room_left_in_file:]
|
||||
files_impls[-1].append(tmp_ic)
|
||||
else:
|
||||
files_impls[-1].append(curr_ic)
|
||||
curr_impl_configs.pop()
|
||||
curr_num_impls_assigned += len(files_impls[-1][-1].schedules)
|
||||
curr_impl_in_file += len(files_impls[-1][-1].schedules)
|
||||
|
||||
for part, file_impls in enumerate(files_impls):
|
||||
sources.append(
|
||||
(
|
||||
f"machete_mm_impl_part{part+1}",
|
||||
mm_impl_template.render(impl_configs=file_impls),
|
||||
)
|
||||
)
|
||||
|
||||
return sources
|
||||
|
||||
|
||||
def generate():
|
||||
# See csrc/quantization/machete/Readme.md, the "Codegeneration" section, for more info
|
||||
# about how this works
|
||||
SCRIPT_DIR = os.path.dirname(__file__)
|
||||
|
||||
sch_common_params = dict(
|
||||
kernel_schedule=TmaMI,
|
||||
epilogue_schedule=TmaCoop,
|
||||
tile_scheduler=TileSchedulerType.StreamK,
|
||||
)
|
||||
|
||||
# Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
|
||||
default_tile_heuristic_config = {
|
||||
# M = 257+
|
||||
"M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
|
||||
"M > 256": ((128, 256), (2, 1, 1)),
|
||||
# M = 129-256
|
||||
"M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
|
||||
"M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
|
||||
"M > 128": ((128, 256), (2, 1, 1)),
|
||||
# M = 65-128
|
||||
"M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
|
||||
"M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
|
||||
"M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
|
||||
"M > 64": ((128, 128), (2, 1, 1)),
|
||||
# M = 33-64
|
||||
"M > 40 && K <= 6144 && N <= 6144": ((128, 32), (2, 1, 1)),
|
||||
"M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
|
||||
"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
|
||||
"M > 32": ((128, 64), (2, 1, 1)),
|
||||
# M = 17-32
|
||||
"M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
|
||||
"M > 16": ((256, 32), (2, 1, 1)),
|
||||
# M = 1-16
|
||||
"N >= 26624": ((256, 16), (1, 1, 1)),
|
||||
None: ((128, 16), (1, 1, 1)),
|
||||
}
|
||||
|
||||
# For now we use the same heuristic for all types
|
||||
# Heuristic is currently tuned for H100s
|
||||
default_heuristic = [
|
||||
(cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore
|
||||
for cond, tile_config in default_tile_heuristic_config.items()
|
||||
]
|
||||
|
||||
def get_unique_schedules(heuristic: dict[str, ScheduleConfig]):
|
||||
# Do not use schedules = list(set(...)) because we need to make sure
|
||||
# the output list is deterministic; otherwise the generated kernel file
|
||||
# will be non-deterministic and cause ccache misses.
|
||||
schedules = []
|
||||
for _, schedule_config in heuristic:
|
||||
if schedule_config not in schedules:
|
||||
schedules.append(schedule_config)
|
||||
return schedules
|
||||
|
||||
impl_configs = []
|
||||
|
||||
GPTQ_kernel_type_configs = list(
|
||||
TypeConfig(
|
||||
a=a,
|
||||
b=b,
|
||||
b_group_scale=a,
|
||||
b_group_zeropoint=DataType.void,
|
||||
b_channel_scale=DataType.void,
|
||||
a_token_scale=DataType.void,
|
||||
out=a,
|
||||
accumulator=DataType.f32,
|
||||
)
|
||||
for b in (MACHETEDataType.u4b8, MACHETEDataType.u8b128)
|
||||
for a in (DataType.f16, DataType.bf16)
|
||||
)
|
||||
|
||||
impl_configs += [
|
||||
ImplConfig(x[0], x[1], x[2])
|
||||
for x in zip(
|
||||
GPTQ_kernel_type_configs,
|
||||
itertools.repeat(get_unique_schedules(default_heuristic)),
|
||||
itertools.repeat(default_heuristic),
|
||||
)
|
||||
]
|
||||
|
||||
output_dir = os.path.join(SCRIPT_DIR, "generated")
|
||||
|
||||
# Delete the "generated" directory if it exists
|
||||
if os.path.exists(output_dir):
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
# Create the "generated" directory
|
||||
os.makedirs(output_dir)
|
||||
|
||||
# Render each group of configurations into separate files
|
||||
for filename, code in create_sources(impl_configs):
|
||||
filepath = os.path.join(output_dir, f"{filename}.cu")
|
||||
with open(filepath, "w") as output_file:
|
||||
output_file.write(code)
|
||||
print(f"Rendered template to {filepath}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate()
|
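As a reading aid for the heuristic table in generate() above: the generated mm_dispatch_* functions walk the (condition, schedule) pairs in order and take the first condition that holds, with the None entry acting as the final else. A small Python sketch of that selection logic, illustrative only, using the default_tile_heuristic_config dict defined above:

def pick_tile_config(M, N, K, heuristic):
    # heuristic: the default_tile_heuristic_config dict from generate().
    # Mirrors the generated C++: first matching condition wins; None is the trailing else.
    for cond, tile_config in heuristic.items():
        if cond is None or eval(cond.replace("&&", " and ")):
            return tile_config

# A decode-like shape (M=8, N=K=4096) falls through to the final ((128, 16), (1, 1, 1)) entry,
# while a large prefill (M=512, N=K=8192) hits "M > 256" -> ((128, 256), (2, 1, 1)).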
31
custom_ops/gpu_ops/machete/machete_collective_builder.cuh
Normal file
@@ -0,0 +1,31 @@
#pragma once

#include "utils/machete_collective_builder.cuh"
#include "machete_mainloop.cuh"

namespace cutlass::gemm::collective {
using namespace cute;

struct MacheteKernelTag {};

template <class ElementPairA_, class GmemLayoutA_, int AlignmentA,
          class ElementPairB_, class GmemLayoutB_, int AlignmentB,
          class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
          class StageCountType, class KernelScheduleType>
struct MacheteCollectiveBuilder<
    MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_,
    GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB,
    ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
    KernelScheduleType,
    cute::enable_if_t<(
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
        cute::is_same_v<KernelScheduleType,
                        KernelTmaWarpSpecializedCooperative>)>> {
  using CollectiveOp = machete::MacheteCollectiveMma<
      ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
      AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
      StageCountType, KernelScheduleType>;
};

}; // namespace cutlass::gemm::collective
@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import enum
from typing import Union

from cutlass_library import (
    DataType,
    DataTypeNames,
    DataTypeSize,
    DataTypeTag,
    KernelScheduleTag,
    KernelScheduleType,
    enum_auto,
)

#
# Extend cutlass library with custom types, and missing values
#


class MACHETEDataType(enum.Enum):
    u4b8 = enum_auto()
    u8b128 = enum_auto()


class MixedInputKernelScheduleType(enum.Enum):
    TmaWarpSpecialized = enum_auto()
    TmaWarpSpecializedPingpong = enum_auto()
    TmaWarpSpecializedCooperative = enum_auto()


MACHETEDataTypeNames: dict[Union[MACHETEDataType, DataType], str] = {
    **DataTypeNames,  # type: ignore
    **{
        MACHETEDataType.u4b8: "u4b8",
        MACHETEDataType.u8b128: "u8b128",
    },
}

MACHETEDataTypeTag: dict[Union[MACHETEDataType, DataType], str] = {
    **DataTypeTag,  # type: ignore
    **{
        MACHETEDataType.u4b8: "cutlass::machete_uint4b8_t",
        MACHETEDataType.u8b128: "cutlass::machete_uint8b128_t",
    },
}

MACHETEDataTypeSize: dict[Union[MACHETEDataType, DataType], int] = {
    **DataTypeSize,  # type: ignore
    **{
        MACHETEDataType.u4b8: 4,
        MACHETEDataType.u8b128: 8,
    },
}

MACHETEDataTypeMACHETEScalarTypeTag: dict[Union[MACHETEDataType, DataType], str] = {
    MACHETEDataType.u4b8: "machete::kU4B8",
    MACHETEDataType.u8b128: "machete::kU8B128",
    DataType.u4: "machete::kU4",
    DataType.u8: "machete::kU8",
    DataType.s4: "machete::kS4",
    DataType.s8: "machete::kS8",
    DataType.f16: "machete::kFloat16",
    DataType.bf16: "machete::kBfloat16",
}

MACHETEDataTypePaddleDataTypeTag: dict[Union[MACHETEDataType, DataType], str] = {
    DataType.u8: "paddle::DataType::UINT8",
    DataType.s8: "paddle::DataType::INT8",
    DataType.e4m3: "paddle::DataType::FLOAT8_E4M3FN",
    DataType.s32: "paddle::DataType::INT32",
    DataType.f16: "paddle::DataType::FLOAT16",
    DataType.bf16: "paddle::DataType::BFLOAT16",
    DataType.f32: "paddle::DataType::FLOAT32",
}

MACHETEKernelScheduleTag: dict[Union[MixedInputKernelScheduleType, KernelScheduleType], str] = {
    **KernelScheduleTag,  # type: ignore
    **{
        MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized",
        MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
        MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
    },
}
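These tables are what generate.py stitches into kernel names. As a sketch of the type signature generate_type_signature() would produce for the GPTQ-style f16 x u4b8 configuration (assuming this module and cutlass_library are importable, and that DataType.void maps to the name "void"):

from cutlass_library import DataType
from machete_cutlass_library_extension import MACHETEDataType, MACHETEDataTypeNames

# TypeConfig field order: a, b, b_group_scale, b_group_zeropoint, b_channel_scale,
# a_token_scale, out, accumulator
gptq_f16_u4b8 = [DataType.f16, MACHETEDataType.u4b8, DataType.f16, DataType.void,
                 DataType.void, DataType.void, DataType.f16, DataType.f32]
print("".join(MACHETEDataTypeNames[t] for t in gptq_f16_u4b8))
# -> a suffix like "f16u4b8f16voidvoidvoidf16f32", used for the generated mm_dispatch_<sig> symbols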
35
custom_ops/gpu_ops/machete/machete_interleaving_utils.cuh
Normal file
@@ -0,0 +1,35 @@
#pragma once

#include "cutlass/cutlass.h"
#include "cute/layout.hpp"

namespace machete {

using namespace cute;

// get an interleaved block layout where each consecutive element has a
// stride of bit_stride and the block width is blk_bit_width,
// examples:
//   size_bits<T> = 8, bit_stride = 8, blk_bit_width = 32 -> 4:1
//   size_bits<T> = 8, bit_stride = 16, blk_bit_width = 32 -> (2, 2):(2, 1)
//   size_bits<T> = 4, bit_stride = 8, blk_bit_width = 32 -> (4, 2):(2, 1)
//   size_bits<T> = 4, bit_stride = 16, blk_bit_width = 32 -> (2, 4):(4, 1)
template <typename T, int bit_stride, int blk_bit_width>
CUTE_HOST_DEVICE static constexpr auto get_interleaved_blk_layout() {
  static_assert(blk_bit_width % bit_stride == 0);
  static_assert(bit_stride % cute::sizeof_bits_v<T> == 0);

  constexpr auto elems_per_blk = blk_bit_width / cute::sizeof_bits_v<T>;

  if constexpr (cute::sizeof_bits_v<T> == bit_stride) {
    // identity layout
    return Layout<Shape<Int<elems_per_blk>>>{};
  } else {
    constexpr auto elems_per_stride = bit_stride / cute::sizeof_bits_v<T>;
    constexpr auto num_strides = elems_per_blk / elems_per_stride;
    return Layout<Shape<Int<num_strides>, Int<elems_per_stride>>,
                  Stride<Int<elems_per_stride>, Int<1>>>{};
  }
}

}; // namespace machete
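For intuition, the index-to-offset mapping this helper describes can be reproduced with a short Python sketch; it is illustrative only and simply mirrors CuTe's column-major shape/stride convention for the documented cases.

def interleaved_offsets(size_bits, bit_stride, blk_bit_width):
    elems_per_blk = blk_bit_width // size_bits
    elems_per_stride = bit_stride // size_bits
    num_strides = elems_per_blk // elems_per_stride
    # Layout<(num_strides, elems_per_stride), (elems_per_stride, 1)> applied to a linear index
    return [(i % num_strides) * elems_per_stride + (i // num_strides)
            for i in range(elems_per_blk)]

print(interleaved_offsets(4, 8, 32))   # (4, 2):(2, 1) -> [0, 2, 4, 6, 1, 3, 5, 7]
print(interleaved_offsets(8, 16, 32))  # (2, 2):(2, 1) -> [0, 2, 1, 3]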
1473
custom_ops/gpu_ops/machete/machete_mainloop.cuh
Normal file
File diff suppressed because it is too large
84
custom_ops/gpu_ops/machete/machete_mm.cu
Normal file
@@ -0,0 +1,84 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"

template <typename T>
std::optional<T> ConvertToStdOptional(const paddle::optional<T>& paddle_opt) {
  return paddle_opt ? std::optional<T>(paddle_opt.get()) : std::nullopt;
}

paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
                  int64_t b_type_id,
                  std::optional<paddle::DataType> const& maybe_out_type,
                  std::optional<paddle::Tensor> const& maybe_group_scales,
                  std::optional<paddle::Tensor> const& maybe_group_zeros,
                  int64_t maybe_group_size,
                  std::optional<paddle::Tensor> const& maybe_channel_scales,
                  std::optional<paddle::Tensor> const& maybe_token_scales,
                  std::string maybe_schedule) {
  machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
  // Note: maybe_group_size is not copied into maybe_group_size_opt below, so
  // mm_dispatch currently sees std::nullopt and the kernel falls back to a
  // group size of K (see create_arguments in machete_mm_kernel.cuh).
  std::optional<int64_t> maybe_group_size_opt;
  std::optional<std::string> maybe_schedule_opt;
  if (maybe_schedule == "") {
    maybe_schedule_opt = std::nullopt;
  }
  return machete::mm_dispatch({.A = A,
                               .B = B,
                               .b_type = b_type,
                               .maybe_out_type = maybe_out_type,
                               .maybe_group_scales = maybe_group_scales,
                               .maybe_group_zeros = maybe_group_zeros,
                               .maybe_group_size = maybe_group_size_opt,
                               .maybe_channel_scales = maybe_channel_scales,
                               .maybe_token_scales = maybe_token_scales,
                               .maybe_schedule = maybe_schedule_opt});
}

std::vector<paddle::Tensor> MacheteMMKernel(
    paddle::Tensor const& A, paddle::Tensor const& B,
    paddle::optional<paddle::Tensor> const& maybe_group_scales,
    paddle::optional<paddle::Tensor> const& maybe_group_zeros,
    paddle::optional<paddle::Tensor> const& maybe_channel_scales,
    paddle::optional<paddle::Tensor> const& maybe_token_scales,
    std::string const& b_type_str,
    std::string const& maybe_out_type_str,
    int64_t const& maybe_group_size,
    std::string const& maybe_schedule
) {

  machete::ScalarTypeId b_type_id;
  paddle::DataType maybe_out_type;
  if (b_type_str == "uint4b8") {
    b_type_id = machete::kU4B8.id();
  } else {
    PADDLE_ENFORCE(false, "b_type_str not supported!");
  }
  if (maybe_out_type_str == "float16") {
    maybe_out_type = paddle::DataType::FLOAT16;
  } else if (maybe_out_type_str == "bfloat16") {
    maybe_out_type = paddle::DataType::BFLOAT16;
  } else {
    maybe_out_type = A.dtype();
  }
  auto out = mm(A, B, b_type_id, maybe_out_type,
                ConvertToStdOptional<paddle::Tensor>(maybe_group_scales),
                ConvertToStdOptional<paddle::Tensor>(maybe_group_zeros),
                maybe_group_size,
                ConvertToStdOptional<paddle::Tensor>(maybe_channel_scales),
                ConvertToStdOptional<paddle::Tensor>(maybe_token_scales),
                maybe_schedule);
  return {out};
}
305
custom_ops/gpu_ops/machete/machete_mm_kernel.cuh
Normal file
@@ -0,0 +1,305 @@
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
// The cutlass include order matters (annoyingly)
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
// clang-format on
|
||||
|
||||
#include "utils/cute_utils.cuh"
|
||||
#include "utils/machete_numeric_conversion.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "utils/paddle_utils.hpp"
|
||||
#include "machete_collective_builder.cuh"
|
||||
#include "machete_prepacked_layout.cuh"
|
||||
#include "machete_interleaving_utils.cuh"
|
||||
|
||||
namespace machete {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// NOTE This kernel computes D = alpha * A * B + beta * C by computing
|
||||
// D^t = alpha * B^t * A^t + beta * C^t, this is because the wgmma
|
||||
// instructions only support sourcing from registers for the left-hand
|
||||
// operand, we want to upconvert/decompress the quantized operand in
|
||||
// register. Since the primary use case we want to support is Y = XW^t where
|
||||
// W is quantized, in this situation our right-hand operand is quantized so
|
||||
// we compute the transpose to move it to the left-hand side.
|
||||
template <typename ElementA_, typename ElementB_, typename ElementD_,
|
||||
typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
|
||||
typename ChannelScaleT, typename TokenScaleT, class KernelSchedule,
|
||||
typename ScheduleConfig>
|
||||
struct MacheteKernelTemplate {
|
||||
static constexpr bool with_C = false; // not ever used
|
||||
static constexpr bool with_group_scales = !std::is_same_v<GroupScaleT, void>;
|
||||
static constexpr bool with_group_zeropoints =
|
||||
!std::is_same_v<GroupZeroT, void>;
|
||||
static constexpr bool with_channel_scales =
|
||||
!std::is_same_v<ChannelScaleT, void>;
|
||||
static constexpr bool with_token_scales = !std::is_same_v<TokenScaleT, void>;
|
||||
|
||||
using MmaType = ElementA_;
|
||||
using ElementA = ElementA_;
|
||||
using ElementB = ElementB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementC = cute::conditional_t<with_C, ElementD, void>;
|
||||
using ElementAccumulator = AccumulatorT;
|
||||
using ElementCompute = AccumulatorT; // For Epilogue
|
||||
// Use dummy values when we don't have scales or zeropoints
|
||||
using ElementZGroup =
|
||||
cute::conditional_t<with_group_zeropoints, GroupZeroT, MmaType>;
|
||||
using ElementSGroup =
|
||||
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
|
||||
using ElementConvertGroup =
|
||||
cute::conditional_t<with_group_scales, GroupScaleT, MmaType>;
|
||||
using ElementSChannel =
|
||||
cute::conditional_t<with_channel_scales, ChannelScaleT, AccumulatorT>;
|
||||
using ElementSToken =
|
||||
cute::conditional_t<with_token_scales, TokenScaleT, AccumulatorT>;
|
||||
|
||||
using BTypeTuple = cute::conditional_t<
|
||||
with_group_scales,
|
||||
cute::conditional_t<with_group_zeropoints,
|
||||
cute::tuple<ElementB, ElementSGroup, ElementZGroup>,
|
||||
cute::tuple<ElementB, ElementSGroup>>,
|
||||
ElementB>;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = LayoutC;
|
||||
using LayoutScale = cutlass::layout::RowMajor;
|
||||
// not actually used since B has the prepacked layout, but required by cutlass
|
||||
using _LayoutB = cutlass::layout::ColumnMajor;
|
||||
|
||||
// Interface strides expected by create_arguments (will get transposed)
|
||||
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
|
||||
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
|
||||
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
|
||||
using StrideSGroup = cutlass::detail::TagToStrideA_t<LayoutScale>;
|
||||
using StrideZGroup = StrideSGroup;
|
||||
|
||||
using LayoutA_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
using LayoutC_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutC>::type;
|
||||
using LayoutD_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using PrepackedLayoutB =
|
||||
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementConvertGroup,
|
||||
AccumulatorT, LayoutA_Transpose, KernelSchedule>;
|
||||
|
||||
static int constexpr TileShapeK =
|
||||
128 * 8 / cutlass::sizeof_bits<MmaType>::value;
|
||||
static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v<ElementA>;
|
||||
static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v<ElementB>;
|
||||
static int constexpr AlignmentC =
|
||||
(with_C) ? 128 / cutlass::sizeof_bits_v<ElementC> : 0;
|
||||
static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v<ElementD>;
|
||||
|
||||
using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{},
|
||||
cute::Int<TileShapeK>{}));
|
||||
using ClusterShape = typename ScheduleConfig::ClusterShape;
|
||||
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
|
||||
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
|
||||
using TileScheduler = typename ScheduleConfig::TileScheduler;
|
||||
|
||||
static_assert(
|
||||
(!with_channel_scales && !with_token_scales) ||
|
||||
((with_channel_scales && with_token_scales) &&
|
||||
std::is_same_v<ElementSChannel, ElementSToken>),
|
||||
"Currently token and channel scales (if present) must be the same type");
|
||||
|
||||
// Currently only supports float scales
|
||||
using ChTokScalesEpilogue =
|
||||
typename fastdeploy::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
|
||||
TileShape>;
|
||||
static_assert((with_channel_scales || with_token_scales) ||
|
||||
(std::is_same_v<ElementSChannel, float> &&
|
||||
std::is_same_v<ElementSToken, float>),
|
||||
"Currently token and channel scales (if present) must be float "
|
||||
"(and if one is present the other must be too)");
|
||||
|
||||
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
|
||||
cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||
|
||||
using EVTCompute =
|
||||
std::conditional_t<with_channel_scales || with_token_scales,
|
||||
typename ChTokScalesEpilogue::EVTCompute,
|
||||
StoreEpilogueCompute>;
|
||||
|
||||
// EVTCompute
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
|
||||
ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose,
|
||||
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::MacheteCollectiveBuilder<
|
||||
cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
|
||||
BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose,
|
||||
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// stride_B is unused (since B is prepacked), but still required by cutlass
|
||||
using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>;
|
||||
|
||||
using Arguments = typename Gemm::Arguments;
|
||||
using MainloopArguments = typename GemmKernel::MainloopArguments;
|
||||
using EpilogueArguments = typename GemmKernel::EpilogueArguments;
|
||||
|
||||
static Arguments create_arguments(
|
||||
cudaStream_t stream,
|
||||
paddle::Tensor const& A, // MxK matrix
|
||||
paddle::Tensor const& B, // KxN prepacked matrix
|
||||
paddle::Tensor& D, // MxN matrix
|
||||
std::optional<paddle::Tensor> const& maybe_g_scales, // scale_KxN matrix
|
||||
std::optional<paddle::Tensor> const& maybe_g_zeros, // scale_KxN matrix
|
||||
std::optional<int64_t> maybe_group_size,
|
||||
std::optional<paddle::Tensor> const& maybe_ch_scales, // len N vector
|
||||
std::optional<paddle::Tensor> const& maybe_tok_scales) // len M vector
|
||||
{
|
||||
static_assert(!with_group_zeropoints || with_group_scales);
|
||||
|
||||
int M = A.shape()[0], N = B.shape()[1], K = A.shape()[1];
|
||||
PD_CHECK(D.shape()[0] == M && D.shape()[1] == N);
|
||||
|
||||
auto layout_A = make_cute_layout<StrideA>(A, "A");
|
||||
auto layout_D = make_cute_layout<StrideD>(D, "D");
|
||||
auto layout_S_group =
|
||||
maybe_make_cute_layout<StrideSGroup>(maybe_g_scales, "group_scales");
|
||||
auto layout_Z_group =
|
||||
maybe_make_cute_layout<StrideZGroup>(maybe_g_zeros, "group_zeros");
|
||||
int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0;
|
||||
int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0;
|
||||
|
||||
auto unwrap = [](auto const& t) {
|
||||
return t ? t->data() : nullptr;
|
||||
};
|
||||
auto A_ptr = static_cast<ElementA const*>(A.data());
|
||||
auto B_ptr = static_cast<ElementB const*>(B.data());
|
||||
auto D_ptr = static_cast<ElementD*>(D.data());
|
||||
auto S_group_ptr =
|
||||
static_cast<ElementSGroup const*>(unwrap(maybe_g_scales));
|
||||
auto Z_group_ptr = static_cast<ElementZGroup const*>(unwrap(maybe_g_zeros));
|
||||
auto S_channel_ptr =
|
||||
static_cast<ElementSChannel const*>(unwrap(maybe_ch_scales));
|
||||
auto S_token_ptr =
|
||||
static_cast<ElementSToken const*>(unwrap(maybe_tok_scales));
|
||||
|
||||
int const group_size =
|
||||
maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
|
||||
int const scale_k = (K + group_size - 1) / group_size;
|
||||
|
||||
PD_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
|
||||
PD_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
|
||||
|
||||
if constexpr (with_group_scales) {
|
||||
PD_CHECK(S_group_ptr && layout_S_group);
|
||||
PD_CHECK((size<0>(*layout_S_group) == scale_k &&
|
||||
size<1>(*layout_S_group) == N));
|
||||
} else {
|
||||
PD_CHECK(!S_group_ptr, "Scales not supported");
|
||||
}
|
||||
|
||||
if constexpr (with_group_zeropoints) {
|
||||
PD_CHECK(Z_group_ptr && layout_Z_group);
|
||||
PD_CHECK((size<0>(*layout_Z_group) == scale_k &&
|
||||
size<1>(*layout_Z_group) == N));
|
||||
PD_CHECK(layout_S_group && *layout_Z_group == *layout_S_group,
|
||||
"Scales and zeros must have the same layout");
|
||||
} else {
|
||||
PD_CHECK(!Z_group_ptr, "Zeropoints not supported");
|
||||
}
|
||||
|
||||
if constexpr (with_channel_scales || with_token_scales) {
|
||||
PD_CHECK(
|
||||
(maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) &&
|
||||
(maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1));
|
||||
}
|
||||
|
||||
// Transpose A and D
|
||||
// A doesn't need to be transposed since cutlass expects a NxK matrix
|
||||
// for B (which is At)
|
||||
auto stride_At = layout_A.stride();
|
||||
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
|
||||
|
||||
MainloopArguments mainloop_arguments{};
|
||||
// {Accum, C, C_layout, D, D}
|
||||
EpilogueArguments epilogue_arguments{};
|
||||
|
||||
if constexpr (with_channel_scales || with_token_scales) {
|
||||
epilogue_arguments =
|
||||
EpilogueArguments{ChTokScalesEpilogue::prepare_args(
|
||||
*maybe_ch_scales, *maybe_tok_scales),
|
||||
nullptr,
|
||||
{},
|
||||
D_ptr,
|
||||
stride_Dt};
|
||||
} else {
|
||||
epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt};
|
||||
}
|
||||
|
||||
if constexpr (with_group_scales && with_group_zeropoints) {
|
||||
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
|
||||
mainloop_arguments = MainloopArguments{
|
||||
B_ptr, _StrideB{}, A_ptr, stride_At,
|
||||
S_group_ptr, stride_S_group, group_size, Z_group_ptr};
|
||||
} else if constexpr (with_group_scales) {
|
||||
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
|
||||
mainloop_arguments =
|
||||
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
|
||||
S_group_ptr, stride_S_group, group_size};
|
||||
} else {
|
||||
mainloop_arguments =
|
||||
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
|
||||
}
|
||||
|
||||
return Arguments{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{N, M, K, 1},
|
||||
mainloop_arguments,
|
||||
epilogue_arguments};
|
||||
};
|
||||
|
||||
static size_t get_workspace_size(Arguments const& args) {
|
||||
return Gemm::get_workspace_size(args);
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
return Gemm::can_implement(args) == cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
static void run(Arguments const& args, void* workspace, cudaStream_t stream) {
|
||||
Gemm gemm_op;
|
||||
|
||||
cutlass::Status status = gemm_op.initialize(args, workspace, stream);
|
||||
PD_CHECK(status == cutlass::Status::kSuccess,
|
||||
"Machete kernel failed to initialize workspace");
|
||||
|
||||
status = gemm_op.run(stream);
|
||||
PD_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed");
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace machete
|
78
custom_ops/gpu_ops/machete/machete_mm_launcher.cuh
Normal file
@@ -0,0 +1,78 @@
#pragma once

#include <Python.h>

#include "machete_mm_kernel.cuh"
#include "utils/paddle_utils.hpp"
#include "utils/scalar_type.h"

namespace machete {

struct MMArgs {
  paddle::Tensor const& A;
  paddle::Tensor const& B;
  machete::ScalarType const& b_type;
  std::optional<paddle::DataType> const& maybe_out_type;
  std::optional<paddle::Tensor> const& maybe_group_scales;
  std::optional<paddle::Tensor> const& maybe_group_zeros;
  std::optional<int64_t> maybe_group_size;
  std::optional<paddle::Tensor> const& maybe_channel_scales;
  std::optional<paddle::Tensor> const& maybe_token_scales;
  std::optional<std::string> maybe_schedule;
};

struct SupportedSchedulesArgs {
  paddle::DataType a_type;
  machete::ScalarType b_type;
  std::optional<paddle::DataType> maybe_group_scales_type;
  std::optional<paddle::DataType> maybe_group_zeros_type;
  std::optional<paddle::DataType> maybe_channel_scales_type;
  std::optional<paddle::DataType> maybe_token_scales_type;
  std::optional<paddle::DataType> maybe_out_type;
};

paddle::Tensor mm_dispatch(MMArgs args);

std::vector<std::string> supported_schedules_dispatch(
    SupportedSchedulesArgs args);

template <typename MacheteKernel>
paddle::Tensor run_impl(MMArgs args) {
  // const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));

  // auto device = args.A.device();
  // auto stream = at::cuda::getCurrentCUDAStream(device.index());
  auto place = args.A.place();
  cudaStream_t stream = args.A.stream();

  int M = args.A.shape()[0];
  int N = args.B.shape()[1];
  int K = args.A.shape()[1];

  // Allocate output
  paddle::Tensor D = paddle::empty(
      {M, N},
      equivalent_scalar_type_v<typename MacheteKernel::ElementD>,
      place);

  auto arguments = MacheteKernel::create_arguments(
      stream,  //
      args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros,
      args.maybe_group_size, args.maybe_channel_scales,
      args.maybe_token_scales);
  PD_CHECK(MacheteKernel::can_implement(arguments),
           "Machete kernel cannot be run with these arguments");

  size_t workspace_size = MacheteKernel::get_workspace_size(arguments);
  int S = static_cast<int>(workspace_size);
  // phi::Allocator* allocator = paddle::GetAllocator(place);
  // auto workspace = allocator->Allocate(workspace_size);
  // MacheteKernel::run(arguments, workspace->ptr(), stream);
  // paddle::Tensor workspace = paddle::empty({S}, paddle::DataType::UINT8, place);
  paddle::Tensor workspace = GetEmptyTensor({S}, paddle::DataType::UINT8, place);
  MacheteKernel::run(arguments, workspace.data(), stream);

  return D;
};

}; // namespace machete
71
custom_ops/gpu_ops/machete/machete_prepack_B.cu
Normal file
@@ -0,0 +1,71 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"

paddle::Tensor prepack_B(
    paddle::Tensor const& B, paddle::DataType const& a_type, int64_t b_type_id,
    std::string const& maybe_group_scales_type_str) {
  machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
  std::optional<paddle::DataType> maybe_group_scales_type;
  if (maybe_group_scales_type_str == "float16") {
    maybe_group_scales_type = paddle::DataType::FLOAT16;
  }
  else if (maybe_group_scales_type_str == "bfloat16") {
    maybe_group_scales_type = paddle::DataType::BFLOAT16;
  }
  else if (maybe_group_scales_type_str == "float32") {
    maybe_group_scales_type = paddle::DataType::FLOAT32;
  }
  else if (maybe_group_scales_type_str == "") {
    maybe_group_scales_type = std::nullopt;
  }
  else {
    PADDLE_ENFORCE(false, "maybe_group_scales_type_str not supported!");
  }
  return machete::prepack_B_dispatch(
      {.B = B,
       .a_type = a_type,
       .b_type = b_type,
       .maybe_group_scales_type = maybe_group_scales_type});
}

std::vector<paddle::Tensor> MachetePrepackBKernel(
    paddle::Tensor const& B, std::string const& a_type_str, std::string const& b_type_str,
    std::string const& maybe_group_scales_type_str) {

  machete::ScalarTypeId b_type_id;
  paddle::DataType a_type, maybe_group_scales_type;

  if (b_type_str == "uint4b8") {
    b_type_id = machete::kU4B8.id();
  } else {
    PADDLE_ENFORCE(false, "b_type_str not supported!");
  }

  if (a_type_str == "float16") {
    a_type = paddle::DataType::FLOAT16;
  }
  else if (a_type_str == "bfloat16") {
    a_type = paddle::DataType::BFLOAT16;
  }
  else {
    PADDLE_ENFORCE(false, "a_type_str not supported!");
  }
  auto Bt = paddle::experimental::transpose(B, {1, 0});
  paddle::Tensor B_prepacked = prepack_B(Bt, a_type, b_type_id, maybe_group_scales_type_str);
  return {B_prepacked};

}
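Callers hand this entry point a weight that is already bit-packed. As a reference point, a plain-NumPy sketch of one possible int4-in-uint8 packing follows; the low-nibble-first convention and the packing axis are assumptions for illustration, not something this diff defines.

# Illustrative nibble packing: a (K, N) matrix of 4-bit values stored as (K // 2, N) uint8,
# matching eles_per_storage = (8 * sizeof(uint8)) / 4 = 2 in machete_prepack_launcher.cuh.
import numpy as np

def pack_int4(w):                      # w: (K, N) with values in [0, 16)
    return (w[0::2] | (w[1::2] << 4)).astype(np.uint8)

def unpack_int4(p):                    # p: (K // 2, N) uint8
    out = np.empty((p.shape[0] * 2, p.shape[1]), dtype=np.uint8)
    out[0::2] = p & 0xF
    out[1::2] = p >> 4
    return out

w = np.random.randint(0, 16, (8, 4), dtype=np.uint8)
assert (unpack_int4(pack_int4(w)) == w).all()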
76
custom_ops/gpu_ops/machete/machete_prepack_kernel.cuh
Normal file
@@ -0,0 +1,76 @@
|
||||
#pragma once
|
||||
|
||||
#include "machete_mm_kernel.cuh"
|
||||
#include "utils/cute_utils.cuh"
|
||||
#include "utils/paddle_utils.hpp"
|
||||
|
||||
namespace machete {
|
||||
|
||||
template <int threads, typename PrepackedLayoutB, typename BInTensor,
|
||||
typename ElementB>
|
||||
static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) {
|
||||
auto constexpr block_size =
|
||||
Int<size(typename PrepackedLayoutB::PPBlockShape_NK{})>{};
|
||||
auto constexpr eles_per_thread = Int<block_size / threads>{};
|
||||
static_assert(block_size % threads == 0,
|
||||
"block_size must be divisible by the number of threads");
|
||||
|
||||
// Which pre-packed block are we responsible for
|
||||
auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z);
|
||||
auto tB_in = local_tile(
|
||||
B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}),
|
||||
blk_coord);
|
||||
|
||||
// Find the start offset in the output for this pre-packed block
|
||||
auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in));
|
||||
|
||||
// Tensor representing a 1:1 mapping to the output space in 1D
|
||||
auto tB_out_linear =
|
||||
make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord),
|
||||
make_layout(make_shape(block_size)));
|
||||
// Mapping from output space (1D) to input space
|
||||
auto tB_in_linear = make_tensor(
|
||||
tB_in.data(),
|
||||
tB_in.layout()
|
||||
.compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset()))
|
||||
.with_shape(make_shape(block_size)));
|
||||
|
||||
// Tile for this specific thread (could have used a TiledCopy but these work
|
||||
// best with 2d layouts, this is a simple 1d layout so local_tile is enough,
|
||||
// we are also not that concerned with performance for this kernel)
|
||||
auto thr_tB_in_linear =
|
||||
local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x);
|
||||
auto thr_tB_out_linear =
|
||||
local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x);
|
||||
|
||||
// Construct a register-backed Tensor with the same shape as each thread's
|
||||
// partition
|
||||
auto fragment = make_tensor<ElementB>(shape(thr_tB_in_linear));
|
||||
|
||||
copy(thr_tB_in_linear, fragment);
|
||||
copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tB_out_linear);
|
||||
}
|
||||
|
||||
template <typename PrepackedLayoutB, typename InLayout>
|
||||
static void prepack_B_template(
|
||||
cudaStream_t stream, typename PrepackedLayoutB::ElementB const* B_in_ptr,
|
||||
InLayout B_layout, typename PrepackedLayoutB::ElementB* B_out_ptr) {
|
||||
using TileShapeNKL =
|
||||
decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));
|
||||
auto ilvd_NKbNbKL_to_offset =
|
||||
PrepackedLayoutB::ilvd_NKbNbKL_to_offset(shape(B_layout));
|
||||
|
||||
PD_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
|
||||
PD_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);
|
||||
|
||||
auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
|
||||
auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
|
||||
auto L_tiles = size<2>(B_layout);
|
||||
|
||||
auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
|
||||
|
||||
prepack_B_kernel<128, PrepackedLayoutB>
|
||||
<<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_out_ptr);
|
||||
}
|
||||
|
||||
}; // namespace machete
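A back-of-the-envelope check (not part of the diff) for the kernel above: with the PPBlockShape_NK of (128, 64) used by the prepacked layout (see machete_prepacked_layout.cuh) and the 128-thread launch in prepack_B_template, each thread copies 64 elements per prepacked block, i.e. 32 bytes when ElementB is a 4-bit type.

#include <cstdio>

int main() {
  constexpr int kBlockN = 128, kBlockK = 64;     // PPBlockShape_NK
  constexpr int kThreads = 128;                  // launch config in prepack_B_template
  constexpr int kBlockSize = kBlockN * kBlockK;  // elements per prepacked block
  static_assert(kBlockSize % kThreads == 0, "mirrors the kernel's static_assert");
  constexpr int kElesPerThread = kBlockSize / kThreads;    // 64
  constexpr int kBytesPerThread = kElesPerThread * 4 / 8;  // 32 for 4-bit weights
  std::printf("%d elements, %d bytes per thread\n", kElesPerThread, kBytesPerThread);
  return 0;
}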
77
custom_ops/gpu_ops/machete/machete_prepack_launcher.cuh
Normal file
@@ -0,0 +1,77 @@
|
||||
#pragma once
|
||||
|
||||
#include "machete_prepack_kernel.cuh"
|
||||
#include "utils/paddle_utils.hpp"
|
||||
#include "utils/scalar_type.h"
|
||||
|
||||
namespace machete {
|
||||
|
||||
struct PrepackBArgs {
|
||||
paddle::Tensor const& B;
|
||||
paddle::DataType a_type;
|
||||
machete::ScalarType b_type;
|
||||
std::optional<paddle::DataType> maybe_group_scales_type;
|
||||
};
|
||||
|
||||
template <typename PrepackedLayoutB>
|
||||
paddle::Tensor prepack_impl(paddle::Tensor const B) {
|
||||
// const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
|
||||
using ElementB = typename PrepackedLayoutB::ElementB;
|
||||
using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK;
|
||||
|
||||
// auto device = B.device();
|
||||
// auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
cudaStream_t stream = B.stream();
|
||||
auto B_ptr = static_cast<ElementB const*>(B.data());
|
||||
// elements per storage item for B
|
||||
auto eles_per_storage =
|
||||
(SizeOf(B.dtype()) * 8) / cute::sizeof_bits_v<ElementB>;
|
||||
|
||||
// paddle B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to
|
||||
// match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L)
|
||||
// auto Bt_packed = B.transpose();
|
||||
auto Bt_packed = paddle::experimental::transpose(B, {1, 0});
|
||||
|
||||
PD_CHECK(
|
||||
(B.shape()[0] * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0,
|
||||
"B.shape[0] (in terms of unpacked elements) must be a multiple of ",
|
||||
size<1>(PPBlockShape_NK{}));
|
||||
PD_CHECK(B.shape()[1] % size<0>(PPBlockShape_NK{}) == 0,
|
||||
"B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{}));
|
||||
|
||||
using StrideB = cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>;
|
||||
auto const l_Bt_packed = make_cute_layout<StrideB>(Bt_packed, "B");
|
||||
// auto const l_Bt_packed = make_cute_layout<StrideB>(B, "B");
|
||||
|
||||
// convert (N,packed_K,L) layout to (N,K,L) layout
|
||||
// in effect we want to do: blocked_product(layout_Bt_packed,
|
||||
// make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}),
|
||||
// Step<_1, _0, _2>{}));
|
||||
// but blocked_product does not support dynamic strides so we implement the
|
||||
// equivalent manually,
|
||||
// new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L)
|
||||
// new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage)
|
||||
// when s1 == 1
|
||||
PD_CHECK(stride<1>(l_Bt_packed) == 1, "stride<1>(l_Bt_packed) must be 1");
|
||||
// clang-format off
|
||||
auto const layout_Bt = make_layout(
|
||||
transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) {
|
||||
return idx == 1 ? ele * eles_per_storage : ele;
|
||||
}),
|
||||
transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) {
|
||||
return idx != 1 ? ele * eles_per_storage : ele;
|
||||
}));
|
||||
// clang-format on
|
||||
|
||||
// Allocate output
|
||||
paddle::Tensor D = paddle::empty_like(B);
|
||||
|
||||
prepack_B_template<PrepackedLayoutB>(
|
||||
stream, B_ptr, layout_Bt, static_cast<ElementB*>(D.data()));
|
||||
|
||||
return D;
|
||||
};
|
||||
|
||||
paddle::Tensor prepack_B_dispatch(PrepackBArgs args);
|
||||
|
||||
}; // namespace machete
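An illustration (not part of the diff) of the packed-K expansion performed above, written with plain integers instead of cute layouts. It assumes int4 weights stored in int32 words, i.e. eles_per_storage = 32 / 4 = 8; the N and packed_K values are arbitrary.

#include <cstdio>

int main() {
  const long N = 4096, packed_K = 512, eles_per_storage = 8;

  // Bt_packed is (N, packed_K) with strides (packed_K, 1), mirroring
  // l_Bt_packed above, where stride<1> is required to be 1.
  long shape[2]  = {N, packed_K};
  long stride[2] = {packed_K, 1};

  // new_shape  = (N, packed_K) * (1, eles_per_storage)  -> (N, K)
  // new_stride = (s0, s1)      * (eles_per_storage, 1)
  shape[1]  *= eles_per_storage;  // K = 4096 unpacked 4-bit elements
  stride[0] *= eles_per_storage;  // stride along N, now in unpacked elements
  std::printf("layout_Bt: (%ld, %ld) : (%ld, %ld)\n",
              shape[0], shape[1], stride[0], stride[1]);
  return 0;
}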
249
custom_ops/gpu_ops/machete/machete_prepacked_layout.cuh
Normal file
@@ -0,0 +1,249 @@
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
// The cutlass include order matters (annoyingly)
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
// clang-format on
|
||||
|
||||
#include "utils/cute_utils.cuh"
|
||||
#include "machete_collective_builder.cuh"
|
||||
#include "machete_interleaving_utils.cuh"
|
||||
|
||||
namespace machete {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
struct IlvBlkLayoutAuto {};
|
||||
|
||||
// This defines a prepacked layout for the B matrix, where the matrix is broken
|
||||
// up into PPBlockShape_NK blocks. The data within each block is then compactly
|
||||
// stored in memory such that when performing a TiledMMA operation with the same
|
||||
// shape as prepacked block, all the data for a given thread is contiguous in
|
||||
// memory. This allows us to use wider shared memory loads when loading B from
|
||||
// shared memory. The values within a thread are also potentially interleaved
// in order to allow for more efficient upconverting.
|
||||
//
|
||||
// The contract here is that the `TiledMma` determined below matches the one
|
||||
// ultimately used in the kernel. (this is also why the other element types are
|
||||
// required along with the kernel schedule)
|
||||
template <typename ElementA_, typename ElementB_, typename ElementConvert_,
|
||||
typename AccumulatorT, class LayoutB, class KernelSchedule,
|
||||
typename IlvBlkLayout_ = IlvBlkLayoutAuto>
|
||||
// clang-format on
|
||||
struct PrepackedLayoutBTemplate {
|
||||
using MmaType = ElementA_;
|
||||
using ElementA = ElementA_;
|
||||
using ElementB = ElementB_;
|
||||
using ElementAccumulator = AccumulatorT;
|
||||
using ElementMma = MmaType;
|
||||
|
||||
// Interleave for 4-bit types when we are not upconverting to fp8 or int8;
// in those cases we use a LUT with prmt instructions to upconvert, which
// is more efficient if the data is not interleaved. For 8-bit and wider
// types prmt instructions make non-interleaved layouts efficient enough
// that we don't need interleaved layouts (and can reuse more of the
// existing cutlass converters).
|
||||
static constexpr bool should_interleave =
|
||||
sizeof_bits_v<ElementB> <= 4 &&
|
||||
!std::is_same_v<ElementConvert_, cutlass::float_e4m3_t> &&
|
||||
!std::is_same_v<ElementConvert_, int8_t>;
|
||||
|
||||
// Only use interleaved layouts for subbyte weights,
|
||||
using IlvdBlkLayout = std::conditional_t<
|
||||
std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
|
||||
std::conditional_t<
|
||||
should_interleave,
|
||||
decltype(get_interleaved_blk_layout<
|
||||
ElementB, sizeof_bits_v<ElementConvert_>, 32>()),
|
||||
void>,
|
||||
IlvBlkLayout_>;
|
||||
|
||||
// TODO (LucasWilkinson): compare the performance for other sizes
|
||||
// Prepacked block shape, smallest layout atom for loading into registers
|
||||
// (can contain multiple wgmma instructions worth of data in one block)
|
||||
// We ideally want this to be configured such that a thread can perform 128bit
// loads, i.e. the amount of data associated with each thread within a
// prepacked block is a multiple of 128bits. When using a cooperative schedule
// we have 256 threads working on a single block at a time, meaning each
// thread works on `sizeof_bits_v<ElementB> * (128*64) / 256` bits of data;
// for a 4bit type this is 128bits.
|
||||
using PPBlockShape_NK = Shape<_128, _64>;
|
||||
|
||||
// Create the shape of the tile anticipated to be used by the GEMM kernel,
|
||||
// when the kernel executes we will compute `Ct = Bt * At` since the
// quantized weights (B) must be the lhs operand so that they flow through
// registers.
|
||||
// The _128 here doesn't actually impact the shape of the stored tile directly
|
||||
// but may impact the op selected by rs_op_selector
|
||||
using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
|
||||
size<1>(PPBlockShape_NK{})));
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorB =
|
||||
gmma_rs_tag_to_major_B<LayoutB>();
|
||||
|
||||
// For coop schedules we have two warp groups cooperatively issuing wgmma
|
||||
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
|
||||
GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
// Prepacked block, (athrid, val) -> (N,K)
|
||||
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
|
||||
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
|
||||
return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
|
||||
}
|
||||
|
||||
// Prepacked block, (N,K) -> (athrid, val)
|
||||
// i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
|
||||
CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
|
||||
return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
|
||||
}
|
||||
|
||||
// Prepacked block, (athrid, val) -> (storage_offset)
|
||||
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
|
||||
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
|
||||
// Return interleaved layout
|
||||
return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
|
||||
}
|
||||
|
||||
// Prepacked block, (athrid, val) -> (storage_offset)
|
||||
// i.e. ((ThrV,(ThrM,ThrK)),(IlvdFrgV,(RestM,RestK,...))) -> (storage_idx)
|
||||
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() {
|
||||
auto layout_no_interleave =
|
||||
make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
|
||||
|
||||
if constexpr (std::is_same_v<IlvdBlkLayout, void>) {
|
||||
return layout_no_interleave;
|
||||
} else {
|
||||
// interleave by transforming FrgV into interleaved blocks where each
|
||||
// block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is
|
||||
// (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4)
|
||||
// if FrgV is {A, B, C, D, E, F, G, H}
// then ((IlvdBlk), FrgV) is {A, C, B, D, E, G, F, H}
|
||||
auto frgV = get<1, 0>(layout_no_interleave);
|
||||
auto ilvdBlk = IlvdBlkLayout{};
|
||||
static_assert(size(frgV) % size(ilvdBlk) == 0,
|
||||
"FrgV must be divisible by size(ilvdBlk)");
|
||||
auto ilvd_FrgV = make_layout(
|
||||
make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
|
||||
make_stride(stride(ilvdBlk), size(ilvdBlk)));
|
||||
|
||||
// Return interleaved layout
|
||||
return make_layout(
|
||||
get<0>(layout_no_interleave),
|
||||
make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave)));
|
||||
}
|
||||
}
|
||||
|
||||
// Prepacked block, (M,K) -> (storage_offset)
|
||||
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() {
|
||||
// do (M,K) -> (athrid, val) -> (storage_idx)
|
||||
return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV());
|
||||
}
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
|
||||
template <typename Shape_NKL>
|
||||
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset(
|
||||
Shape_NKL shape_mkl) {
|
||||
constexpr auto block_layout = ppblock_TV_to_offset();
|
||||
|
||||
// (BlocksN, BlocksK, L)
|
||||
auto blocks_shape =
|
||||
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
|
||||
auto result = make_layout(
|
||||
block_layout,
|
||||
make_layout(blocks_shape,
|
||||
compact_col_major(blocks_shape, size(block_layout))));
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L))
|
||||
// => ((athrid, val), (BlocksN, BlocksK), L)
|
||||
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
|
||||
}
|
||||
|
||||
// ((athrid_val), (BlocksN, BlocksK, L)) -> (N, K, L)
|
||||
template <typename Shape_NKL>
|
||||
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy(
|
||||
Shape_NKL shape_mkl) {
|
||||
auto layout = TVbNbKL_to_offset(shape_mkl);
|
||||
// for 4-bit elements, having >= 64 values per column
|
||||
// allows TMA to load full 32-byte sectors
|
||||
auto inner_layout =
|
||||
make_layout(make_shape(_256{}, size<0>(layout) / _256{}));
|
||||
|
||||
return make_layout(inner_layout, get<1>(layout), get<2>(layout));
|
||||
}
|
||||
|
||||
// ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
|
||||
template <typename Shape_NKL>
|
||||
CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
|
||||
Shape_NKL shape_mkl) {
|
||||
constexpr auto block_layout = ppblock_ilvd_NK_to_offset();
|
||||
|
||||
// (BlocksN, BlocksK, L)
|
||||
auto blocks_shape =
|
||||
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
|
||||
auto result = make_layout(
|
||||
block_layout,
|
||||
make_layout(blocks_shape,
|
||||
compact_col_major(blocks_shape, size(block_layout))));
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L)) => ((athrid, val), (BlocksN,
|
||||
// BlocksK), L)
|
||||
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
|
||||
}
|
||||
|
||||
// (BlocksN, BlocksK, L) -> (storage_idx)
|
||||
template <typename Shape_NKL>
|
||||
CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) {
|
||||
// (BlocksN, BlocksK, L)
|
||||
auto blocks_shape =
|
||||
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
auto stride = size(PPBlockShape_NK{});
|
||||
|
||||
// (BlocksN, BlocksK, L) -> (storage_idx)
|
||||
return make_layout(blocks_shape, compact_col_major(blocks_shape, stride));
|
||||
}
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
|
||||
template <class Shape_NKL>
|
||||
CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
|
||||
auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})),
|
||||
make_layout(size<1>(PPBlockShape_NK{})));
|
||||
|
||||
// ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L)
|
||||
auto tiled_A = zipped_divide(make_layout(shape_mkl), tile);
|
||||
return tiled_A.compose(ppblock_TV_to_NK(), _);
|
||||
}
|
||||
|
||||
// (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
|
||||
template <class Shape_NKL>
|
||||
CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) {
|
||||
auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl);
|
||||
return blocked_product(ppblock_NK_to_TV(),
|
||||
make_layout(shape<1>(TVbNbK_to_NKL_layout)));
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace machete
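A quick check (not part of the diff) of the sizing comment above: with PPBlockShape_NK = (128, 64) and a cooperative schedule (two warp groups, 256 threads per block), each thread owns sizeof_bits<ElementB> * (128*64) / 256 bits, i.e. exactly one 128-bit load for 4-bit weights.

#include <cstdio>

int main() {
  constexpr int kBlockN = 128, kBlockK = 64;  // PPBlockShape_NK
  constexpr int kCoopThreads = 256;           // two warp groups
  constexpr int kWeightBits = 4;              // uint4b8 weights
  constexpr int kBitsPerThread = kWeightBits * kBlockN * kBlockK / kCoopThreads;
  std::printf("bits per thread: %d\n", kBitsPerThread);  // 128
  return 0;
}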
72
custom_ops/gpu_ops/machete/machete_supported_schedules.cu
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "machete_mm_launcher.cuh"
|
||||
#include "machete_prepack_launcher.cuh"
|
||||
|
||||
template <typename T>
|
||||
std::optional<T> ConvertToStdOptional(const paddle::optional<T>& paddle_opt) {
|
||||
return paddle_opt ? std::optional<T>(paddle_opt.get()) : std::nullopt;
|
||||
}
|
||||
|
||||
std::vector<std::string> supported_schedules(
|
||||
paddle::DataType a_type, int64_t b_type_id,
|
||||
std::optional<paddle::DataType> maybe_group_scales_type,
|
||||
std::optional<paddle::DataType> maybe_group_zeros_type,
|
||||
std::optional<paddle::DataType> maybe_channel_scales_type,
|
||||
std::optional<paddle::DataType> maybe_token_scales_type,
|
||||
std::optional<paddle::DataType> maybe_out_type) {
|
||||
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
|
||||
auto schedules = machete::supported_schedules_dispatch({
|
||||
.a_type = a_type,
|
||||
.b_type = b_type,
|
||||
.maybe_group_scales_type = maybe_group_scales_type,
|
||||
.maybe_group_zeros_type = maybe_group_zeros_type,
|
||||
.maybe_channel_scales_type = maybe_channel_scales_type,
|
||||
.maybe_token_scales_type = maybe_token_scales_type,
|
||||
.maybe_out_type = maybe_out_type
|
||||
});
|
||||
return schedules;
|
||||
}
|
||||
|
||||
std::vector<std::string> MacheteSupportedSchedules(
|
||||
std::string const& a_type_str, std::string const& b_type_str) {
|
||||
machete::ScalarTypeId b_type_id;
|
||||
paddle::DataType a_type;
|
||||
if (b_type_str == "uint4b8") {
|
||||
b_type_id = machete::kU4B8.id();
|
||||
} else {
|
||||
PADDLE_ENFORCE(false, "b_type_str not supported!");
|
||||
}
|
||||
if (a_type_str == "bfloat16") {
|
||||
a_type = paddle::DataType::BFLOAT16;
|
||||
} else if (a_type_str == "float16") {
|
||||
a_type = paddle::DataType::FLOAT16;
|
||||
} else {
|
||||
PADDLE_ENFORCE(false, "a_type_str not supported!");
|
||||
}
|
||||
std::optional<paddle::DataType> maybe_group_scales_type = std::optional<paddle::DataType>(a_type);
|
||||
std::optional<paddle::DataType> maybe_out_type = std::optional<paddle::DataType>(a_type);
|
||||
std::optional<paddle::DataType> maybe_group_zeros_type = std::nullopt;
|
||||
std::optional<paddle::DataType> maybe_channel_scales_type = std::nullopt;
|
||||
std::optional<paddle::DataType> maybe_token_scales_type = std::nullopt;
|
||||
|
||||
auto schedules = supported_schedules(a_type, b_type_id,
|
||||
maybe_group_scales_type,
|
||||
maybe_group_zeros_type,
|
||||
maybe_channel_scales_type,
|
||||
maybe_token_scales_type,
|
||||
maybe_out_type);
|
||||
return schedules;
|
||||
}
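A hypothetical caller (not part of the diff) that lists the specialized schedules compiled for bf16 activations with uint4b8 weights, e.g. to expose them for offline tuning. Only the type combinations accepted by the wrapper above (float16/bfloat16 activations, uint4b8 weights) can be queried this way.

#include <cstdio>
#include <string>
#include <vector>

// Declared in this file.
std::vector<std::string> MacheteSupportedSchedules(
    std::string const& a_type_str, std::string const& b_type_str);

int main() {
  for (auto const& s : MacheteSupportedSchedules("bfloat16", "uint4b8")) {
    std::printf("%s\n", s.c_str());
  }
  return 0;
}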
69
custom_ops/gpu_ops/machete/utils/cute_utils.cuh
Normal file
@@ -0,0 +1,69 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/cute_utils.cuh
|
||||
#pragma once
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
namespace cute {
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// layout utils
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Permute layout based on indices, example:
|
||||
// permute_layout<1, 0>(layout) will swap the two dimensions
|
||||
// permute_layout<0, 2, 1>(layout) will swap the last two dimensions
|
||||
template <size_t... I, typename Layout>
|
||||
CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
|
||||
static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch");
|
||||
return cute::make_layout(cute::get<I>(l)...);
|
||||
}
|
||||
|
||||
// is the layout f(x) = x
|
||||
template <typename Layout>
|
||||
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
|
||||
if constexpr (std::is_same_v<Layout, void>) {
|
||||
return true;
|
||||
} else {
|
||||
constexpr auto coalesced_layout = coalesce(Layout{});
|
||||
if constexpr (rank(coalesced_layout) == 1 &&
|
||||
stride<0>(coalesced_layout) == 1) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// Pointer utils
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class PointerType>
|
||||
static constexpr auto get_logical_ptr(PointerType* ptr) {
|
||||
if constexpr (cute::sizeof_bits_v<PointerType> < 8) {
|
||||
return cute::subbyte_iterator<PointerType>(ptr);
|
||||
} else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// Misc utils
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename Elements>
|
||||
CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() {
|
||||
constexpr auto bits = sizeof_bits_v<T> * Elements{};
|
||||
if constexpr (bits % 128 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<128>{};
|
||||
} else if constexpr (bits % 64 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<64>{};
|
||||
} else if constexpr (bits % 32 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<32>{};
|
||||
} else if constexpr (bits % 16 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<16>{};
|
||||
} else {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<8>{};
|
||||
}
|
||||
}
|
||||
|
||||
}; // namespace cute
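A host-side demo (not part of the diff) of the two layout helpers above. It requires the CuTe headers this PR already depends on; the include path for the header and the layout values are illustrative.

#include <cstdio>
#include "cute/tensor.hpp"
#include "cute_utils.cuh"  // the file above; path is illustrative

int main() {
  using namespace cute;
  auto l  = make_layout(make_shape(Int<4>{}, Int<8>{}),
                        make_stride(Int<8>{}, Int<1>{}));      // (4,8):(8,1)
  auto lt = permute_layout<1, 0>(l);                           // (8,4):(1,8)
  std::printf("rank %d, size %d\n", int(rank(lt)), int(size(lt)));  // rank 2, size 32

  // A contiguous vector coalesces to (N):(1), i.e. the identity map.
  static_assert(is_identity_layout<Layout<Shape<_32>, Stride<_1>>>());
  static_assert(!is_identity_layout<decltype(l)>());
  return 0;
}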
|
@@ -0,0 +1,44 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_collective_builder.cuh
#pragma once

#include "cutlass_extensions/gemm/collective/collective_builder.hpp"

namespace cutlass::gemm::collective {
using namespace cute;

//
// MacheteCollectiveBuilder is a wrapper around CollectiveBuilder that allows
// for custom kernel tags, so custom collectives can be built without touching
// the cutlass library headers. Using `CutlassKernelTag` falls back to the
// standard cutlass collective builder.
//

// Use the default Cutlass collective builder, i.e. use an unmodified cutlass
// collective
struct CutlassKernelTag {};

template <class KernelTag, class ArchTag, class OpClass, class ElementA,
          class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB,
          int AlignmentB, class ElementAccumulator, class TileShape_MNK,
          class ClusterShape_MNK, class StageCountType,
          class KernelScheduleType, class Enable = void>
struct MacheteCollectiveBuilder {
  static_assert(sizeof(ElementA) == 0,
                "Could not build a collective for given parameters.");
};

template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA,
          int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
          class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
          class StageCountType, class KernelScheduleType>
struct MacheteCollectiveBuilder<
    CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA,
    ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
    ClusterShape_MNK, StageCountType, KernelScheduleType> {
  using CollectiveOp = typename CollectiveBuilder<
      ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB,
      GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
      ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp;
};

};  // namespace cutlass::gemm::collective
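A stand-alone illustration (not part of the diff) of the builder-extension pattern above: a primary template that refuses to compile, plus one specialization per kernel tag, with the default tag forwarding to the stock implementation. All names here are invented for the example.

#include <cstdio>

struct DefaultTag {};  // plays the role of CutlassKernelTag
struct FancyTag {};    // plays the role of a machete-specific tag

template <class Tag, class T, class Enable = void>
struct Builder {
  static_assert(sizeof(T) == 0, "Could not build for the given parameters.");
};

template <class T>
struct Builder<DefaultTag, T> {  // forward to the "stock" path
  static int run() { return 1; }
};

template <class T>
struct Builder<FancyTag, T> {    // custom path enabled by the tag
  static int run() { return 2; }
};

int main() {
  std::printf("%d %d\n", Builder<DefaultTag, float>::run(),
              Builder<FancyTag, float>::run());  // prints "1 2"
  return 0;
}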
51
custom_ops/gpu_ops/machete/utils/machete_custom_types.cuh
Normal file
@@ -0,0 +1,51 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_custom_types.cuh
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/integer_subbyte.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int Bits, int Bias, bool Signed = false>
|
||||
struct machete_biased_integer_subbyte : public integer_subbyte<Bits, Signed> {
|
||||
using Base = integer_subbyte<Bits, Signed>;
|
||||
|
||||
using Storage = typename Base::Storage;
|
||||
using xint_t = typename Base::xint_t;
|
||||
|
||||
using Base::bits_mask_;
|
||||
using Base::sign_mask_;
|
||||
using Base::storage;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// No operation
|
||||
machete_biased_integer_subbyte() = default;
|
||||
|
||||
/// Conversion from integer type
|
||||
CUTLASS_HOST_DEVICE explicit machete_biased_integer_subbyte(int value)
|
||||
: Base(value) {}
|
||||
CUTLASS_HOST_DEVICE explicit machete_biased_integer_subbyte(unsigned value)
|
||||
: Base(value) {}
|
||||
CUTLASS_HOST_DEVICE explicit machete_biased_integer_subbyte(double value)
|
||||
: Base(value) {}
|
||||
};
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// "GPTQ" types, i.e. symmetric quantization
|
||||
using machete_uint4b8_t = machete_biased_integer_subbyte<4, 8>; // u4b8
|
||||
using machete_uint8b128_t = machete_biased_integer_subbyte<8, 128>; // u8b128
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int Bits, int Bias, bool Signed>
|
||||
struct sizeof_bits<machete_biased_integer_subbyte<Bits, Bias, Signed>> {
|
||||
static constexpr int value = Bits;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
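A plain-C++ check (not part of the diff) of the bias convention encoded by these types: machete_uint4b8_t stores x + 8 in 4 unsigned bits and machete_uint8b128_t stores x + 128 in 8 unsigned bits, so symmetric GPTQ-style weights can be kept as plain unsigned integers.

#include <cassert>

int main() {
  for (int stored = 0; stored < 16; ++stored) {
    int x = stored - 8;                 // decode u4b8: stored 0..15 -> -8..7
    assert(((x + 8) & 0xF) == stored);  // re-encode round-trips
  }
  for (int stored = 0; stored < 256; ++stored) {
    int x = stored - 128;               // decode u8b128: stored 0..255 -> -128..127
    assert(((x + 128) & 0xFF) == stored);
  }
  return 0;
}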
993
custom_ops/gpu_ops/machete/utils/machete_numeric_conversion.cuh
Normal file
@@ -0,0 +1,993 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "machete_custom_types.cuh"
|
||||
#include "cute_utils.cuh"
|
||||
#include "machete_type_utils.cuh"
|
||||
|
||||
// this file extends:
|
||||
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
|
||||
// with machete-specific type conversions, namely machete_uint4b8_t and machete_uint8b128_t,
|
||||
// as well as adds interleaved numeric array converters for specific types.
|
||||
// (interleaved numeric array converters can be more efficient for subbyte
|
||||
// types)
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
// InterleavedNumericArrayConverter is like NumericArrayConverter but also
|
||||
// deinterleaves converted elements based on IlvBlkLayout, interleaving can
|
||||
// make subbyte converts more efficient by allowing for efficient extraction
|
||||
// of subbyte elements from a 32bit register.
|
||||
template <typename IlvBlkLayout, typename T, typename S, int N,
|
||||
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
|
||||
class Enable = void>
|
||||
struct InterleavedNumericArrayConverter {
|
||||
using Converter = NumericArrayConverter<T, S, N, Round>;
|
||||
|
||||
using result_type = typename Converter::result_type;
|
||||
using source_type = typename Converter::source_type;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
if (cute::elect_one_sync()) {
|
||||
if constexpr (std::is_same_v<IlvBlkLayout, void>) {
|
||||
printf(
|
||||
"Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n",
|
||||
nameof_v<T>, nameof_v<S>, N);
|
||||
} else {
|
||||
printf(
|
||||
"Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not "
|
||||
"implemented\n",
|
||||
nameof_v<T>, nameof_v<S>, N, size(IlvBlkLayout{}));
|
||||
}
|
||||
__brkpt();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
template <typename IlvBlkLayout, typename T, typename S, int N,
|
||||
FloatRoundStyle Round>
|
||||
struct InterleavedNumericArrayConverter<
|
||||
IlvBlkLayout, T, S, N, Round,
|
||||
std::enable_if_t<is_identity_layout<IlvBlkLayout>()>> {
|
||||
using Converter = NumericArrayConverter<T, S, N, Round>;
|
||||
|
||||
using result_type = typename Converter::result_type;
|
||||
using source_type = typename Converter::source_type;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return Converter::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
template <typename RegConvert32bit, typename T, typename S, int N>
|
||||
struct ArrayConverterPacked32Bit {
|
||||
using result_type = Array<T, N>;
|
||||
using source_type = Array<S, N>;
|
||||
|
||||
using result_packed_8_t = Array<T, 8>;
|
||||
using result_packed_4_t = Array<T, 4>;
|
||||
using result_packed_2_t = Array<T, 2>;
|
||||
using src_packed_8_t = Array<S, 8>;
|
||||
using src_packed_4_t = Array<S, 4>;
|
||||
using src_packed_2_t = Array<S, 2>;
|
||||
|
||||
static_assert(N % 2 == 0, "N must be a multiple of 2");
|
||||
static_assert(cutlass::sizeof_bits_v<S> >= 4); // TODO: add 16 packed sources
|
||||
static_assert(32 % cutlass::sizeof_bits_v<S> == 0);
|
||||
static constexpr auto src_elems_per_32bit_reg =
|
||||
32 / cutlass::sizeof_bits_v<S>;
|
||||
|
||||
// Maybe not Valid. ScalarConverter will not actually work unless
|
||||
// NumericConverter<T, S, Round> is implemented. However it won't be used
|
||||
// anyways since we assert N % 2 == 0, just here for compliance with
|
||||
// VectorizedConverter.
|
||||
using ScalarConverter = NumericConverter<T, S>;
|
||||
|
||||
template <typename PackedSrc>
|
||||
CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) {
|
||||
if constexpr (sizeof(PackedSrc) == 1) {
|
||||
return Array<uint32_t, 1>{reinterpret_cast<uint8_t const&>(src)};
|
||||
} else if constexpr (sizeof(PackedSrc) == 2) {
|
||||
return Array<uint32_t, 1>{reinterpret_cast<uint16_t const&>(src)};
|
||||
} else if constexpr (sizeof(PackedSrc) == 4) {
|
||||
return Array<uint32_t, 1>{reinterpret_cast<uint32_t const&>(src)};
|
||||
} else {
|
||||
static_assert(sizeof(PackedSrc) == 8);
|
||||
return reinterpret_cast<Array<uint32_t, 2> const&>(src);
|
||||
}
|
||||
}
|
||||
|
||||
// The core converter uses bit tricks to construct a known FP16 number, then
|
||||
// does a subtraction in FP16 for the final result.
|
||||
template <typename PackedResultType, typename PackedSrcType>
|
||||
CUTLASS_DEVICE static PackedResultType packed_convert(
|
||||
PackedSrcType const& source) {
|
||||
static_assert(PackedSrcType::kElements == PackedResultType::kElements);
|
||||
static_assert(PackedResultType::kElements == 2 ||
|
||||
PackedResultType::kElements == 4 ||
|
||||
PackedResultType::kElements == 8,
|
||||
"Invalid PackedResultType must be 2, 4 or 8.");
|
||||
static_assert(std::is_same_v<typename PackedSrcType::Element, S>);
|
||||
static_assert(std::is_same_v<typename PackedResultType::Element, T>);
|
||||
|
||||
return RegConvert32bit::template convert<PackedResultType>(to_regs(source));
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE static result_type convert(source_type const& source) {
|
||||
result_type result;
|
||||
using ConverterType =
|
||||
ArrayConverterPacked32Bit<RegConvert32bit,
|
||||
typename result_type::Element,
|
||||
typename source_type::Element, N>;
|
||||
|
||||
if constexpr (src_elems_per_32bit_reg >= 8) {
|
||||
detail::VectorizedConverter::convert<
|
||||
ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t,
|
||||
src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source);
|
||||
} else if constexpr (src_elems_per_32bit_reg >= 4) {
|
||||
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
|
||||
src_packed_4_t, result_packed_2_t,
|
||||
src_packed_2_t>(result, source);
|
||||
} else {
|
||||
detail::VectorizedConverter::convert<ConverterType, result_packed_2_t,
|
||||
src_packed_2_t>(result, source);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed
|
||||
// into 2 32bit register.
|
||||
template <uint8_t LUT0, uint8_t LUT1, uint8_t LUT2, uint8_t LUT3, //
|
||||
uint8_t LUT4, uint8_t LUT5, uint8_t LUT6, uint8_t LUT7, //
|
||||
uint8_t LUT8, uint8_t LUT9, uint8_t LUT10, uint8_t LUT11, //
|
||||
uint8_t LUT12, uint8_t LUT13, uint8_t LUT14, uint8_t LUT15>
|
||||
CUTLASS_DEVICE cutlass::AlignedArray<uint32_t, 2> lut_4bit_to_8bit_convert(
|
||||
uint32_t src) {
|
||||
cutlass::AlignedArray<uint32_t, 2> r;
|
||||
// Determines if the value is in the top half of the LUT (i.e. LUT[8:15]) if
// set, or in the bottom half (i.e. LUT[0:7]) if not set. Then move
|
||||
// into bit position 0x4 of each nibble so when or'd with final_prmt_base it
|
||||
// selects the correct candidate. When elements in final_prmt_base
|
||||
// are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements
|
||||
// are < 0x4, the low candidate is selected (i.e. LUT[0:7])
|
||||
uint32_t high_bit = (src & 0x88888888) >> 1;
|
||||
|
||||
// `high_bit` is OR'd with 0x32103210 to find the correct value in the LUT
|
||||
// (selects correct high or low candidate)
|
||||
const uint32_t final_prmt_base = 0x32103210;
|
||||
|
||||
// Ignore the high bit when indexing into LUT, for each 4bit value
|
||||
// we index into both the high and low candidates then use
|
||||
// high_bit | final_prmt_base to select the correct candidate
|
||||
uint32_t lut_idx = (src & 0x77777777);
|
||||
|
||||
auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) {
|
||||
return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) |
|
||||
(uint32_t(d) << 24);
|
||||
};
|
||||
|
||||
static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3);
|
||||
static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7);
|
||||
static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11);
|
||||
static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) {
|
||||
uint32_t final_prmt_idx = final_prmt_base | high_bit;
|
||||
|
||||
// This uses a look up table to convert packed int4s to packed int8s,
|
||||
// using the int4 value as the index to prmt. It first select both the
|
||||
// high and low candidates, then uses the high bit (i.e. `high_bit`) to
|
||||
// select the correct candidate.
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .b32 low, high;\n"
|
||||
" prmt.b32 low, %1, %2, %5;\n"
|
||||
" prmt.b32 high, %3, %4, %5;\n"
|
||||
" prmt.b32 %0, low, high, %6;\n"
|
||||
"}\n"
|
||||
: "=r"(r[ii])
|
||||
: "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx),
|
||||
"r"(final_prmt_idx));
|
||||
}
|
||||
|
||||
return r;
|
||||
};
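A scalar reference (not part of the diff) for the prmt-based routine above: it applies the same 16-entry LUT to each packed nibble with ordinary C++, which can be handy for unit-testing the PTX path on the host. The helper name is invented for the example.

#include <array>
#include <cstdint>
#include <cstdio>

std::array<uint32_t, 2> lut_4bit_to_8bit_reference(
    uint32_t src, const std::array<uint8_t, 16>& lut) {
  std::array<uint32_t, 2> out{0, 0};
  for (int i = 0; i < 8; ++i) {
    uint32_t nibble = (src >> (4 * i)) & 0xF;  // i-th packed 4-bit value
    out[i / 4] |= uint32_t(lut[nibble]) << (8 * (i % 4));
  }
  return out;
}

int main() {
  // The int8 LUT used by the uint4b8 -> int8 converter below: stored value
  // v represents v - 8, i.e. {-8, ..., 7} as int8s.
  std::array<uint8_t, 16> lut = {0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF,
                                 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07};
  auto r = lut_4bit_to_8bit_reference(0x76543210u, lut);
  std::printf("%08x %08x\n", r[0], r[1]);  // fbfaf9f8 fffefdfc
  return 0;
}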
|
||||
|
||||
// for Array<int8_t, N> <= Array<machete_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<int8_t, machete_uint4b8_t, N, Round> {
|
||||
using result_type = Array<int8_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s
|
||||
auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, //
|
||||
0xFC, 0xFD, 0xFE, 0xFF, //
|
||||
0x00, 0x01, 0x02, 0x03, //
|
||||
0x04, 0x05, 0x06, 0x07>(src_[0]);
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::float_e4m3_t, N> <= Array<machete_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::float_e4m3_t, machete_uint4b8_t, N, Round> {
|
||||
using result_type = Array<cutlass::float_e4m3_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
// [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s
|
||||
auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, //
|
||||
0xC8, 0xC4, 0xC0, 0xB8, //
|
||||
0x00, 0x38, 0x40, 0x44, //
|
||||
0x48, 0x4A, 0x4C, 0x4E>(src_[0]);
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<machete_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::half_t, machete_uint4b8_t, N, Round> {
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
// Below constructs the following temporary:
|
||||
// fp16s_01 = {0x00, i4_01, 0x00, i4_01}
|
||||
// fp16s_23 = {0x00, i4_23, 0x00, i4_23}
|
||||
// fp16s_45 = {0x00, i4_45, 0x00, i4_45}
|
||||
// fp16s_67 = {0x00, i4_67, 0x00, i4_67}
|
||||
// We use inline asm instead of __byte_perm intrinsic since we don't want
|
||||
// the documented (& 0x7) on the index. NVCC might be able to optimize it
|
||||
// out since the index is a constexpr, but we choose to be safe about it
|
||||
// here.
|
||||
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
|
||||
static_assert(RegArray::kElements <= 4,
|
||||
"Too many inputs for F16 -> I4 vector converter");
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" prmt.b32 %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src), "n"(0), "r"(prmt_indices[ii]));
|
||||
}
|
||||
|
||||
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
|
||||
// we are trying to construct x and a fp16 value
|
||||
// The below XOR does the following:
|
||||
// 1) Sets the exponent bits of the FP16 to the correct value for the
|
||||
// FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)},
|
||||
// where x1 in the high nibble and x0 is the low nibble then using hfma
|
||||
// to subtract 1032 from that
|
||||
// The AND does the following:
|
||||
// 1) Clear the set bits for the int4 we will ignore.
|
||||
// We use lop3 so that we can use 1 instruction for AND and XOR.
|
||||
static constexpr uint32_t xor_mask = 0x64006400;
|
||||
static constexpr uint32_t and_mask = 0xFFF0FF0F;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
// For each operand, computes:
|
||||
// r[i] = (r[i] & and_mask) ^ xor_mask
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
|
||||
}
|
||||
|
||||
// We will issue 2 hfmas that do the following:
|
||||
// {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032}
|
||||
//            = {16*x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032}
|
||||
static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {72, 1032}
|
||||
static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1}
|
||||
|
||||
const half2& hfma_bias = reinterpret_cast<const half2&>(hfma_bias_rep);
|
||||
const half2& hfma_scale = reinterpret_cast<const half2&>(hfma_scale_rep);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
|
||||
fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias);
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
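A host-side arithmetic check (not part of the diff) of the fp16 magic-number trick used above. The half_bits_to_float helper is written for this example and only handles normal fp16 values, which is all this construction produces.

#include <cassert>
#include <cstdint>
#include <cstdio>

float half_bits_to_float(uint16_t h) {
  int sign = (h >> 15) & 0x1, exp = (h >> 10) & 0x1F, mant = h & 0x3FF;
  float v = (1.0f + mant / 1024.0f) * float(1 << exp) / 32768.0f;  // 2^(exp-15)
  return sign ? -v : v;
}

int main() {
  for (int x = -8; x <= 7; ++x) {
    uint16_t stored = uint16_t(x + 8);  // uint4b8 encoding, 0..15
    // Low-nibble lane: (0x6400 | stored) is the fp16 value 1024 + stored,
    // so 1 * (1024 + stored) - 1032 recovers x.
    float lo = half_bits_to_float(0x6400 | stored) * 1.0f - 1032.0f;
    // High-nibble lane: (0x6400 | stored << 4) is 1024 + 16 * stored,
    // so (1/16) * (1024 + 16 * stored) - 72 recovers x.
    float hi = half_bits_to_float(0x6400 | (stored << 4)) * (1.0f / 16.0f) - 72.0f;
    assert(lo == float(x) && hi == float(x));
  }
  std::printf("fp16 magic-number decode checks out for all 16 encodings\n");
  return 0;
}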
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<machete_uint4b8_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::half_t, machete_uint4b8_t, N,
|
||||
Round, void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t xor_mask = 0x64006400;
|
||||
|
||||
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
|
||||
auto src_ = src >> (4 * (ii));
|
||||
r[ii + 0] = src_;
|
||||
r[ii + 1] = src_;
|
||||
|
||||
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
static constexpr uint32_t high_nib_mask = 0x00F000F0;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 0])
|
||||
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 1])
|
||||
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032}
|
||||
// For high nibble:
|
||||
// {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16}
|
||||
// - {72, 72}
|
||||
static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032}
|
||||
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
|
||||
static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72}
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
|
||||
}
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
|
||||
fp16x2_val = __hfma2(fp16x2_val,
|
||||
reinterpret_cast<const half2&>(high_nib_scale),
|
||||
reinterpret_cast<const half2&>(high_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
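A plain-C++ trace (not part of the diff) of the (2, 4):(4, 1) IlvdLayout used by this converter: output element i is taken from packed nibble (i % 2) * 4 + (i / 2), i.e. the order 0 4 1 5 2 6 3 7, which is what the shift by 4*ii plus the low/high nibble masks above exploit.

#include <cstdio>

int main() {
  for (int i = 0; i < 8; ++i) {
    int coord0 = i % 2, coord1 = i / 2;       // coordinates in Shape<_2, _4>
    std::printf("%d ", coord0 * 4 + coord1);  // apply Stride<_4, _1>
  }
  std::printf("\n");                          // 0 4 1 5 2 6 3 7
  return 0;
}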
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<uint4_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::half_t, uint4_t, N, Round,
|
||||
void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<uint4_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t xor_mask = 0x64006400;
|
||||
|
||||
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
|
||||
auto src_ = src >> (4 * (ii));
|
||||
r[ii + 0] = src_;
|
||||
r[ii + 1] = src_;
|
||||
|
||||
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
static constexpr uint32_t high_nib_mask = 0x00F000F0;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 0])
|
||||
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 1])
|
||||
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024}
|
||||
// For high nibble:
|
||||
// {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64}
|
||||
static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024}
|
||||
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
|
||||
static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64}
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
|
||||
}
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
|
||||
fp16x2_val = __hfma2(fp16x2_val,
|
||||
reinterpret_cast<const half2&>(high_nib_scale),
|
||||
reinterpret_cast<const half2&>(high_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<machete_uint8b128_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::half_t, machete_uint8b128_t, N, Round> {
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<machete_uint8b128_t, N>;
|
||||
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
// Hold output FP16s in reg. We need 1 reg for every 2 elements
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
uint32_t const prmt_indices[2] = {0x5150, 0x5352};
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src), "n"(start_byte_for_fp16),
|
||||
"r"(prmt_indices[ii]));
|
||||
}
|
||||
|
||||
// -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes
|
||||
static constexpr uint32_t bias_rep = 0x64806480;
|
||||
const half2& bias = reinterpret_cast<const half2&>(bias_rep);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
|
||||
fp16x2_val = __hsub2(fp16x2_val, bias);
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::float, N> <= Array<machete_uint8b128_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<float, machete_uint8b128_t, N, Round> {
|
||||
using result_type = Array<float, N>;
|
||||
using source_type = Array<machete_uint8b128_t, N>;
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
PackedResultType r;
|
||||
|
||||
// __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
|
||||
// u8x4 source and stores the result in r (without introducing extra
|
||||
// cvt.u32.u8 instruction)
|
||||
uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653};
|
||||
uint32_t* result_as_int = reinterpret_cast<uint32_t*>(&r);
|
||||
for (int ii = 0; ii < PackedResultType::kElements; ++ii) {
|
||||
result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]);
|
||||
// Subtract the magic number 0x4B000000 from tmp in floating-point
|
||||
// arithmetic to obtain final result
|
||||
r[ii] -= (8388608.f + 128.f); // fold in -128 bias
|
||||
}
|
||||
|
||||
return r;
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
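A host-side check (not part of the diff) of the fp32 magic number used above: placing a byte b into the low byte of 0x4B000000 yields the float 8388608.0f + b exactly, so subtracting 8388608.f + 128.f recovers b - 128.

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  for (int b = 0; b < 256; ++b) {
    uint32_t bits = 0x4B000000u | uint32_t(b);
    float f;
    std::memcpy(&f, &bits, sizeof(f));  // same value __byte_perm builds in-register
    assert(f - (8388608.f + 128.f) == float(b - 128));
  }
  return 0;
}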
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<machete_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::bfloat16_t, machete_uint4b8_t, N, Round> {
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src_reg = src_[0];
|
||||
// Hold output BF16s in reg. We need 1 reg for every 2 elements
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
uint32_t src_reg_shifted = src_reg >> 4;
|
||||
|
||||
// Below constructs the following temporary:
|
||||
uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
|
||||
static_assert(RegArray::kElements <= 4,
|
||||
"Too many inputs for uint4b8_t -> BF16 vector converter");
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" prmt.b32 %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii]));
|
||||
}
|
||||
|
||||
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
|
||||
// we are trying to construct x and a BF16 value
|
||||
// The below XOR does the following:
|
||||
// 1) Sets the exponent bits of the BF16 to the correct value for the
|
||||
// BF16 magic_num. We will be constructing {128 + (x1+8), 128 + (x0+8)}
|
||||
// and subtracting 136 to get {x1, x0}
|
||||
static constexpr uint32_t xor_mask = 0x43004300;
|
||||
static constexpr uint32_t and_mask = 0x000F000F;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
// For each operand, computes:
|
||||
// r[i] = (r[i] & and_mask) ^ xor_mask
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
|
||||
}
|
||||
|
||||
// Subtract the bias of 136 to recover x from 128 + (x + 8) in both lanes:
// hi_bf16 - 136, lo_bf16 - 136
|
||||
|
||||
// This is the BF16 {136, 136} represented as an integer.
|
||||
static constexpr uint32_t bias_rep = 0x43084308;
|
||||
const __nv_bfloat162& bias =
|
||||
reinterpret_cast<const __nv_bfloat162&>(bias_rep);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
bf16x2_val = __hsub2(bf16x2_val, bias);
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<machete_uint4b8_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::bfloat16_t, machete_uint4b8_t, N,
|
||||
Round, void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<machete_uint4b8_t, N>;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t or_mask = 0x43004300;
|
||||
|
||||
// Unlike float16 where the mantissa is large enough to contain 2
|
||||
// nibbles, bfloat16 can only fit one, so we can only convert one
|
||||
// nibble at a time
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
r[ii] = src >> (4 * ii);
|
||||
|
||||
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 0])
|
||||
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136}
|
||||
static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136}
|
||||
|
||||
{
|
||||
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val,
|
||||
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<uint4_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::bfloat16_t, uint4_t, N, Round,
|
||||
void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<uint4_t, N>;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(Array<uint32_t, 1> src_) {
|
||||
uint32_t src = src_[0];
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t or_mask = 0x43004300;
|
||||
|
||||
// Unlike float16 where the mantissa is large enough to contain 2
|
||||
// nibbles, bfloat16 can only fit one, so we can only convert one
|
||||
// nibble at a time
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
r[ii] = src >> (4 * ii);
|
||||
|
||||
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128}
|
||||
static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128}
|
||||
|
||||
{
|
||||
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val,
|
||||
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<machete_uint8b128_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::bfloat16_t, machete_uint8b128_t, N, Round> {
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<machete_uint8b128_t, N>;
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
using result_packed_4_t = Array<cutlass::bfloat16_t, 4>;
|
||||
using result_packed_2_t = Array<cutlass::bfloat16_t, 2>;
|
||||
using src_packed_4_t = Array<machete_uint8b128_t, 4>;
|
||||
using src_packed_2_t = Array<machete_uint8b128_t, 2>;
|
||||
|
||||
// Not Valid, not supported, only here to satisfy the interface and to avoid
|
||||
// a compile error. ScalarConverter will not actually work until
|
||||
// NumericConverter<cutlass::bfloat16_t, machete_uint8b128_t, Round> is
|
||||
// implemented
|
||||
using ScalarConverter =
|
||||
NumericConverter<cutlass::bfloat16_t, machete_uint8b128_t, Round>;
|
||||
|
||||
template <typename PackedResultType, typename PackedSrcType>
|
||||
CUTLASS_DEVICE static PackedResultType packed_convert(
|
||||
PackedSrcType const& source) {
|
||||
static_assert(
|
||||
(platform::is_same<PackedSrcType, src_packed_2_t>::value &&
|
||||
platform::is_same<PackedResultType, result_packed_2_t>::value) ||
|
||||
(platform::is_same<PackedSrcType, src_packed_4_t>::value &&
|
||||
platform::is_same<PackedResultType, result_packed_4_t>::value),
|
||||
"Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private "
|
||||
"convert dispatch.");
|
||||
|
||||
NumericArrayConverter<float, machete_uint8b128_t, PackedResultType::kElements,
|
||||
Round>
|
||||
convert_uint8_to_f32;
|
||||
Array<float, PackedResultType::kElements> tmp =
|
||||
convert_uint8_to_f32(source);
|
||||
NumericArrayConverter<cutlass::bfloat16_t, float,
|
||||
PackedResultType::kElements, Round>
|
||||
convert_f32_to_bf16_;
|
||||
return convert_f32_to_bf16_(tmp);
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
result_type result;
|
||||
using ConverterType =
|
||||
NumericArrayConverter<typename result_type::Element,
|
||||
typename source_type::Element, N, Round>;
|
||||
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
|
||||
src_packed_4_t, result_packed_2_t,
|
||||
src_packed_2_t>(result, source);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
// for Array<int8_t, N> <= Array<cutlass::half_t, N>
|
||||
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<int8_t, cutlass::half_t, N, Round> {
|
||||
using result_type = Array<int8_t, N>;
|
||||
using source_type = Array<cutlass::half_t, N>;
|
||||
|
||||
struct RegConvert {
|
||||
// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904
|
||||
template <typename PackedResultType, int src_regs>
|
||||
CUTLASS_DEVICE static PackedResultType convert(
|
||||
Array<uint32_t, src_regs> src) {
|
||||
// Hold output int8s in reg. We need 1 reg for every 4 elements
|
||||
using RegArray = cutlass::AlignedArray<
|
||||
uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>;
|
||||
RegArray r;
|
||||
|
||||
static constexpr uint32_t MAGIC_BIAS_ = 0x64806480;
|
||||
auto MAGIC_BIAS = *reinterpret_cast<const half2*>(&MAGIC_BIAS_);
|
||||
|
||||
*reinterpret_cast<half2*>(&src[0]) =
|
||||
__hadd2(*reinterpret_cast<half2*>(&src[0]), MAGIC_BIAS);
|
||||
|
||||
if constexpr (src_regs > 1) {
|
||||
*reinterpret_cast<half2*>(&src[1]) =
|
||||
__hadd2(*reinterpret_cast<half2*>(&src[1]), MAGIC_BIAS);
|
||||
}
|
||||
|
||||
static_assert(PackedResultType::kElements <= 4);
|
||||
uint32_t uint8s;
|
||||
static constexpr uint32_t MASK_0246 = 0x6420;
|
||||
static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080;
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
|
||||
: "=r"(uint8s)
|
||||
: "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]),
|
||||
"n"(MASK_0246));
|
||||
|
||||
uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK);
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(int8s);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
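The uint4b8-to-bfloat16 converters above all rely on one identity: OR-ing a stored nibble s = x + 8 into the low mantissa bits of the bf16 bit pattern 0x4300 produces the value 128 + s, so subtracting the bias 136 recovers x. A minimal NumPy sketch (illustrative only, not part of the kernels) that checks the arithmetic:

import numpy as np

def uint4b8_to_bf16(nibbles):
    # nibbles: stored 4-bit values s = x + 8, with x in [-8, 7]
    bits = np.uint16(0x4300) | nibbles.astype(np.uint16)  # bf16 bit pattern for 128 + s
    widened = (bits.astype(np.uint32) << np.uint32(16)).view(np.float32)  # widen bf16 -> f32
    return widened - 136.0  # (128 + x + 8) - 136 = x

print(uint4b8_to_bf16(np.arange(16, dtype=np.uint16)))  # [-8. -7. ... 7.]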
43
custom_ops/gpu_ops/machete/utils/machete_type_utils.cuh
Normal file
@@ -0,0 +1,43 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
#include "cuda_bf16.h"

#include "machete_custom_types.cuh"

namespace cutlass {

template <typename T>
struct nameof {
  static constexpr char const* value = "unknown";
};

template <typename T>
inline constexpr auto nameof_v = nameof<T>::value;

#define NAMEOF_TYPE(T)                          \
  template <>                                   \
  struct nameof<T> {                            \
    static constexpr char const* value = #T;    \
  };

NAMEOF_TYPE(float_e4m3_t)
NAMEOF_TYPE(float_e5m2_t)
NAMEOF_TYPE(half_t)
NAMEOF_TYPE(nv_bfloat16)
NAMEOF_TYPE(bfloat16_t)
NAMEOF_TYPE(float)

NAMEOF_TYPE(int4b_t)
NAMEOF_TYPE(int8_t)
NAMEOF_TYPE(int32_t)
NAMEOF_TYPE(int64_t)

NAMEOF_TYPE(machete_uint4b8_t)
NAMEOF_TYPE(uint4b_t)
NAMEOF_TYPE(uint8_t)
NAMEOF_TYPE(machete_uint8b128_t)
NAMEOF_TYPE(uint32_t)
NAMEOF_TYPE(uint64_t)

};  // namespace cutlass
161
custom_ops/gpu_ops/machete/utils/paddle_utils.hpp
Normal file
@@ -0,0 +1,161 @@
// adapted from: https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/torch_utils.hpp
#pragma once

#include "helper.h"

#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"

using ColumnMajor = typename cutlass::layout::ColumnMajor;
using RowMajor = typename cutlass::layout::RowMajor;

namespace cute {

namespace detail {

template <class T, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
                                                seq<I...>) {
  return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
}

template <class F, int... I>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq<I...>) {
  return make_shape(f(I)...);
}

};  // namespace detail

template <class T, class F>
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
  if constexpr (cute::is_tuple<T>::value) {
    return detail::tapply_with_idx(
        t, f, [](auto const&... a) { return cute::make_tuple(a...); },
        tuple_seq<T>{});
  } else {
    return f(t);
  }

  CUTE_GCC_UNREACHABLE;
}

// calls: make_shape(f(0), f(1), ..., f(N-1))
template <int N, class F>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
  return detail::make_shape_from_idx(f, make_seq<N>{});
}

};  // namespace cute

// Make a layout from a tensor with `rank(Stride{})`, where the shape is the
// shape of the passed-in tensor and the strides are of type `Stride`, containing
// the strides of the passed-in tensor; any static strides in `Stride{}` are
// checked against the strides of the passed-in tensor.
// If `tensor.shape().size() < rank(Stride{})`, the shape is padded with 1s and
// the extra strides are set to be 0 or 1.
template <typename Stride>
static inline auto make_cute_layout(paddle::Tensor const& tensor,
                                    std::string_view name = "tensor") {
  PD_CHECK(tensor.shape().size() <= rank(Stride{}));
  auto stride = cute::transform_with_idx(
      Stride{}, [&](auto const& stride_ele, auto const& idx) {
        using StrideEle = std::decay_t<decltype(stride_ele)>;

        if (idx < tensor.shape().size()) {
          if constexpr (cute::is_static_v<StrideEle>) {
            PD_CHECK(StrideEle::value == tensor.strides()[idx], "Expected ",
                     name, ".strides()[", idx, "] to be ", StrideEle::value,
                     ", but got ", tensor.strides()[idx], ". ");
            return StrideEle{};
          } else {
            if (tensor.shape()[idx] == 1) {
              // use 0 stride for dims with size 1, this is easier for
              // cute/cutlass to optimize (helps the TMA code flatten dims)
              return StrideEle{0};
            } else {
              return tensor.strides()[idx];
            }
          }
        } else {
          // Extra strides are assumed to be 0 or 1
          if constexpr (cute::is_static_v<StrideEle>) {
            static_assert(StrideEle::value == 0 || StrideEle::value == 1);
          }
          return StrideEle{};
        }
      });

  auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
    if (idx < tensor.shape().size())
      return tensor.shape()[idx];
    else
      return int64_t(1);
  });

  return make_layout(shape, stride);
}

template <typename Stride>
static inline auto maybe_make_cute_layout(
    std::optional<paddle::Tensor> const& tensor,
    std::string_view name = "tensor") {
  using Layout = decltype(make_cute_layout<Stride>(*tensor));

  if (tensor) {
    return std::optional<Layout>{make_cute_layout<Stride>(*tensor, name)};
  } else {
    return std::optional<Layout>{};
  }
}

//
//  Paddle dtype to Cutlass Type (equivalent_cutlass_type)
//

template <typename T>
struct equivalent_cutlass_type {
  using type = T;
};

template <typename T>
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;

template <>
struct equivalent_cutlass_type<phi::dtype::float16> {
  using type = cutlass::half_t;
};

template <>
struct equivalent_cutlass_type<phi::dtype::bfloat16> {
  using type = cutlass::bfloat16_t;
};

//
// equivalent_scalar_t (basically the inverse of equivalent_cutlass_type)
//

// Return the Paddle scalar C++ type that is equivalent to T,
// e.g.: `cutlass::half_t -> phi::dtype::float16`
template <typename T>
struct equivalent_scalar_type {
  using type = T;
};

template <typename T>
using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;

template <>
struct equivalent_scalar_type<cutlass::half_t> {
  using type = phi::dtype::float16;
};

template <>
struct equivalent_scalar_type<cutlass::bfloat16_t> {
  using type = phi::dtype::bfloat16;
};

// get the equivalent paddle::DataType tag from a compile-time type
template <typename T>
static inline constexpr paddle::DataType equivalent_scalar_type_v =
    phi::CppTypeToDataType<equivalent_scalar_type_t<T>>::Type();
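make_cute_layout above pads a tensor whose rank is lower than rank(Stride{}): missing dimensions get shape 1 and a trivial stride, and size-1 dimensions get stride 0 so cute/cutlass can flatten them. A rough Python rendition of that padding rule (a hypothetical helper for illustration, not the C++ API):

def pad_layout(shape, strides, target_rank):
    # pad shape with 1s and strides with 0s, then zero the stride of any size-1 dim
    shape = list(shape) + [1] * (target_rank - len(shape))
    strides = list(strides) + [0] * (target_rank - len(strides))
    strides = [0 if dim == 1 else stride for dim, stride in zip(shape, strides)]
    return shape, strides

print(pad_layout([512, 1024], [1024, 1], 3))  # ([512, 1024, 1], [1024, 1, 0])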
372
custom_ops/gpu_ops/machete/utils/scalar_type.h
Normal file
@@ -0,0 +1,372 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/phi/common/data_type.h"
|
||||
#include "paddle/phi/core/enforce.h"
|
||||
|
||||
#include <optional>
|
||||
#include <variant>
|
||||
|
||||
namespace machete {
|
||||
|
||||
//
|
||||
// ScalarType can represent a wide range of floating point and integer types,
|
||||
// in particular it can be used to represent sub-byte data types (something
|
||||
// that torch.dtype currently does not support).
|
||||
//
|
||||
// The type definitions on the Python side can be found in: vllm/scalar_type.py
|
||||
// these type definitions should be kept up to date with any Python API changes
|
||||
// here.
|
||||
//
|
||||
class ScalarType {
|
||||
public:
|
||||
enum NanRepr : uint8_t {
|
||||
NAN_NONE = 0, // nans are not supported
|
||||
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
|
||||
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
|
||||
|
||||
NAN_REPR_ID_MAX
|
||||
};
|
||||
|
||||
constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
|
||||
int32_t bias, bool finite_values_only = false,
|
||||
NanRepr nan_repr = NAN_IEEE_754)
|
||||
: exponent(exponent),
|
||||
mantissa(mantissa),
|
||||
signed_(signed_),
|
||||
bias(bias),
|
||||
finite_values_only(finite_values_only),
|
||||
nan_repr(nan_repr) {};
|
||||
|
||||
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits - 1, true, bias);
|
||||
}
|
||||
|
||||
static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits, false, bias);
|
||||
}
|
||||
|
||||
// IEEE 754 compliant floating point type
|
||||
static constexpr ScalarType float_IEEE754(uint8_t exponent,
|
||||
uint8_t mantissa) {
|
||||
// PADDLE_ENFORCE(mantissa > 0 && exponent > 0);
|
||||
return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
|
||||
}
|
||||
|
||||
// IEEE 754 non-compliant floating point type
|
||||
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
|
||||
bool finite_values_only,
|
||||
NanRepr nan_repr) {
|
||||
// PADDLE_ENFORCE(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
|
||||
// PADDLE_ENFORCE(mantissa > 0 && exponent > 0);
|
||||
// PADDLE_ENFORCE(nan_repr != NAN_IEEE_754,
|
||||
// "use `float_IEEE754` constructor for floating point types that "
|
||||
// "follow IEEE 754 conventions");
|
||||
return ScalarType(exponent, mantissa, true, 0, finite_values_only,
|
||||
nan_repr);
|
||||
}
|
||||
|
||||
uint8_t const exponent; // size of the exponent field (0 for integer types)
|
||||
uint8_t const mantissa; // size of the mantissa field (size of the integer
|
||||
// excluding the sign bit for integer types)
|
||||
bool const signed_; // flag if the type supports negative numbers (i.e. has a
|
||||
// sign bit)
|
||||
int32_t const bias; // stored values equal value + bias,
|
||||
// used for quantized type
|
||||
|
||||
// Extra Floating point info
|
||||
bool const finite_values_only; // i.e. no +/-inf if true
|
||||
NanRepr const nan_repr; // how NaNs are represented
|
||||
// (not applicable for integer types)
|
||||
|
||||
using Id = int64_t;
|
||||
|
||||
private:
|
||||
// Field size in id
|
||||
template <typename T_>
|
||||
static constexpr size_t member_id_field_width() {
|
||||
using T = std::decay_t<T_>;
|
||||
return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
|
||||
}
|
||||
|
||||
template <typename Fn, typename Init, typename Member, typename... Rest>
|
||||
static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
|
||||
Rest... rest) {
|
||||
auto new_val = f(val, member);
|
||||
if constexpr (sizeof...(rest) > 0) {
|
||||
return reduce_members_helper(f, new_val, rest...);
|
||||
} else {
|
||||
return new_val;
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Fn, typename Init>
|
||||
constexpr auto reduce_members(Fn f, Init init) const {
|
||||
// Should be in constructor order for `from_id`
|
||||
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
|
||||
finite_values_only, nan_repr);
|
||||
};
|
||||
|
||||
template <typename Fn, typename Init>
|
||||
static constexpr auto reduce_member_types(Fn f, Init init) {
|
||||
constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
|
||||
return dummy_type.reduce_members(f, init);
|
||||
};
|
||||
|
||||
static constexpr auto id_size_bits() {
|
||||
return reduce_member_types(
|
||||
[](int acc, auto member) -> int {
|
||||
return acc + member_id_field_width<decltype(member)>();
|
||||
},
|
||||
0);
|
||||
}
|
||||
|
||||
public:
|
||||
// unique id for this scalar type that can be computed at compile time for
|
||||
// c++17 template specialization this is not needed once we migrate to
|
||||
// c++20 and can pass literal classes as template parameters
|
||||
constexpr Id id() const {
|
||||
static_assert(id_size_bits() <= sizeof(Id) * 8,
|
||||
"ScalarType id is too large to be stored");
|
||||
|
||||
auto or_and_advance = [](std::pair<Id, uint32_t> result,
|
||||
auto member) -> std::pair<Id, uint32_t> {
|
||||
auto [id, bit_offset] = result;
|
||||
auto constexpr bits = member_id_field_width<decltype(member)>();
|
||||
return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
|
||||
<< bit_offset,
|
||||
bit_offset + bits};
|
||||
};
|
||||
return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
|
||||
}
|
||||
|
||||
// create a ScalarType from an id, for c++17 template specialization,
|
||||
// this is not needed once we migrate to c++20 and can pass literal
|
||||
// classes as template parameters
|
||||
static constexpr ScalarType from_id(Id id) {
|
||||
auto extract_and_advance = [id](auto result, auto member) {
|
||||
using T = decltype(member);
|
||||
auto [tuple, bit_offset] = result;
|
||||
auto constexpr bits = member_id_field_width<T>();
|
||||
auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
|
||||
((uint64_t(1) << bits) - 1));
|
||||
auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
|
||||
return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
|
||||
};
|
||||
|
||||
auto [tuple_args, _] = reduce_member_types(extract_and_advance,
|
||||
std::pair<std::tuple<>, int>{});
|
||||
return std::apply([](auto... args) { return ScalarType(args...); },
|
||||
tuple_args);
|
||||
}
|
||||
|
||||
constexpr int64_t size_bits() const {
|
||||
return mantissa + exponent + is_signed();
|
||||
}
|
||||
constexpr bool is_signed() const { return signed_; }
|
||||
constexpr bool is_integer() const { return exponent == 0; }
|
||||
constexpr bool is_floating_point() const { return exponent > 0; }
|
||||
constexpr bool is_ieee_754() const {
|
||||
return is_floating_point() && finite_values_only == false &&
|
||||
nan_repr == NAN_IEEE_754;
|
||||
}
|
||||
constexpr bool has_nans() const {
|
||||
return is_floating_point() && nan_repr != NAN_NONE;
|
||||
}
|
||||
constexpr bool has_infs() const {
|
||||
return is_floating_point() && finite_values_only == false;
|
||||
}
|
||||
constexpr bool has_bias() const { return bias != 0; }
|
||||
|
||||
private:
|
||||
double _floating_point_max() const {
|
||||
PADDLE_ENFORCE(mantissa <= 52 && exponent <= 11,
|
||||
"Cannot represent max/min as a double for type ", str());
|
||||
|
||||
uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
|
||||
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
|
||||
max_mantissa -= 1;
|
||||
}
|
||||
|
||||
uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
|
||||
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
|
||||
PADDLE_ENFORCE(exponent < 11,
|
||||
"Cannot represent max/min as a double for type ", str());
|
||||
max_exponent += 1;
|
||||
}
|
||||
|
||||
// adjust the exponent to match that of a double
|
||||
// for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
|
||||
// is the exponent bits), there is some precedent for non-standard biases,
|
||||
// example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
|
||||
// but to avoid premature over complication we are just assuming the
|
||||
// standard exponent bias until there is a need to support non-standard
|
||||
// biases
|
||||
uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
|
||||
uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11
|
||||
|
||||
uint64_t max_exponent_double =
|
||||
max_exponent - exponent_bias + exponent_bias_double;
|
||||
|
||||
// shift the mantissa into the position for a double and
|
||||
// the exponent
|
||||
uint64_t double_raw =
|
||||
(max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);
|
||||
|
||||
return *reinterpret_cast<double*>(&double_raw);
|
||||
}
|
||||
|
||||
constexpr std::variant<int64_t, double> _raw_max() const {
|
||||
if (is_floating_point()) {
|
||||
return {_floating_point_max()};
|
||||
} else {
|
||||
// PADDLE_ENFORCE(size_bits() < 64 || size_bits() == 64 && is_signed(),
|
||||
// "Cannot represent max as a int64_t");
|
||||
return {(int64_t(1) << mantissa) - 1};
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::variant<int64_t, double> _raw_min() const {
|
||||
if (is_floating_point()) {
|
||||
// PADDLE_ENFORCE(is_signed(),
|
||||
// "We currently assume all floating point types are signed");
|
||||
constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);
|
||||
|
||||
double max = _floating_point_max();
|
||||
uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
|
||||
uint64_t min_raw = max_raw | sign_bit_double;
|
||||
return {*reinterpret_cast<double*>(&min_raw)};
|
||||
} else {
|
||||
// PADDLE_ENFORCE(!is_signed() || size_bits() <= 64,
|
||||
// "Cannot represent min as a int64_t");
|
||||
if (is_signed()) {
|
||||
// set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
|
||||
// then perform an arithmetic shift right to set all the bits above
|
||||
// (size_bits() - 1) to 1
|
||||
return {INT64_MIN >> (64 - size_bits())};
|
||||
} else {
|
||||
return {int64_t(0)};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
// Max representable value for this scalar type.
|
||||
// (accounting for bias if there is one)
|
||||
constexpr std::variant<int64_t, double> max() const {
|
||||
return std::visit(
|
||||
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
|
||||
_raw_max());
|
||||
}
|
||||
|
||||
// Min representable value for this scalar type.
|
||||
// (accounting for bias if there is one)
|
||||
constexpr std::variant<int64_t, double> min() const {
|
||||
return std::visit(
|
||||
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
|
||||
_raw_min());
|
||||
}
|
||||
|
||||
std::string str() const {
|
||||
/* naming generally follows: https://github.com/jax-ml/ml_dtypes
|
||||
* for floating point types (leading f) the scheme is:
|
||||
* `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
||||
* flags:
|
||||
* - no-flags: means it follows IEEE 754 conventions
|
||||
* - f: means finite values only (no infinities)
|
||||
* - n: means nans are supported (non-standard encoding)
|
||||
* for integer types the scheme is:
|
||||
* `[u]int<size_bits>[b<bias>]`
|
||||
* - if bias is not present it means its zero
|
||||
*/
|
||||
if (is_floating_point()) {
|
||||
auto ret = "float" + std::to_string(size_bits()) + "_e" +
|
||||
std::to_string(exponent) + "m" + std::to_string(mantissa);
|
||||
if (!is_ieee_754()) {
|
||||
if (finite_values_only) {
|
||||
ret += "f";
|
||||
}
|
||||
if (nan_repr != NAN_NONE) {
|
||||
ret += "n";
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
} else {
|
||||
auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
|
||||
if (has_bias()) {
|
||||
ret += "b" + std::to_string(bias);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr bool operator==(ScalarType const& other) const {
|
||||
return mantissa == other.mantissa && exponent == other.exponent &&
|
||||
bias == other.bias && signed_ == other.signed_ &&
|
||||
finite_values_only == other.finite_values_only &&
|
||||
nan_repr == other.nan_repr;
|
||||
}
|
||||
};
|
||||
|
||||
using ScalarTypeId = machete::ScalarType::Id;
|
||||
|
||||
// "rust style" names generally following:
|
||||
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
|
||||
static inline constexpr auto kS4 = machete::ScalarType::int_(4);
|
||||
static inline constexpr auto kU4 = machete::ScalarType::uint(4);
|
||||
static inline constexpr auto kU4B8 = machete::ScalarType::uint(4, 8);
|
||||
static inline constexpr auto kS8 = machete::ScalarType::int_(8);
|
||||
static inline constexpr auto kU8 = machete::ScalarType::uint(8);
|
||||
static inline constexpr auto kU8B128 = machete::ScalarType::uint(8, 128);
|
||||
|
||||
static inline constexpr auto kFE2M1f =
|
||||
machete::ScalarType::float_(2, 1, true, machete::ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE3M2f =
|
||||
machete::ScalarType::float_(3, 2, true, machete::ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE4M3fn =
|
||||
machete::ScalarType::float_(4, 3, true, machete::ScalarType::NAN_EXTD_RANGE_MAX_MIN);
|
||||
static inline constexpr auto kFE5M2 = machete::ScalarType::float_IEEE754(5, 2);
|
||||
static inline constexpr auto kFE8M7 = machete::ScalarType::float_IEEE754(8, 7);
|
||||
static inline constexpr auto kFE5M10 = machete::ScalarType::float_IEEE754(5, 10);
|
||||
|
||||
// // Fixed width style names, generally following:
|
||||
// // https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
|
||||
constexpr auto kInt4 = kS4;
|
||||
constexpr auto kUint4 = kU4;
|
||||
constexpr auto kUint4b8 = kU4B8;
|
||||
constexpr auto kInt8 = kS8;
|
||||
constexpr auto kUint8 = kU8;
|
||||
constexpr auto kUint8b128 = kU8B128;
|
||||
constexpr auto kFloat4_e2m1f = kFE2M1f;
|
||||
constexpr auto kFloat6_e3m2f = kFE3M2f;
|
||||
constexpr auto kFloat8_e5m2 = kFE5M2;
|
||||
constexpr auto kFloat16_e8m7 = kFE8M7;
|
||||
constexpr auto kFloat16_e5m10 = kFE5M10;
|
||||
|
||||
// colloquial names
|
||||
constexpr auto kHalf = kFE5M10;
|
||||
constexpr auto kFloat16 = kHalf;
|
||||
constexpr auto kFloat16Id = kFloat16.id();
|
||||
|
||||
constexpr auto kInt32 = phi::DataType::INT32;
|
||||
constexpr auto kInt64 = phi::DataType::INT64;
|
||||
constexpr auto kBool = phi::DataType::BOOL;
|
||||
constexpr auto kFloat8_e4m3fn = phi::DataType::FLOAT8_E4M3FN;
|
||||
constexpr auto kBFloat16 = phi::DataType::BFLOAT16;
|
||||
constexpr auto kFloat32 = phi::DataType::FLOAT32;
|
||||
constexpr auto kByte = phi::DataType::INT8;
|
||||
|
||||
}; // namespace machete
|
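The ScalarType::str() naming scheme above is what yields "uint4b8", the 4-bit weight type used by the Machete wint4 path (values stored with a bias of 8), as well as names like "uint8b128" and "float16_e5m10". A rough Python rendition of the rule for the common cases (illustrative only; the flags for non-IEEE float types are omitted):

def scalar_type_name(exponent, mantissa, signed, bias):
    if exponent > 0:  # floating point: float<size_bits>_e<exponent>m<mantissa>
        return f"float{exponent + mantissa + 1}_e{exponent}m{mantissa}"
    name = ("int" if signed else "uint") + str(mantissa + (1 if signed else 0))
    return name + (f"b{bias}" if bias else "")

print(scalar_type_name(0, 4, False, 8))    # uint4b8
print(scalar_type_name(0, 8, False, 128))  # uint8b128
print(scalar_type_name(5, 10, True, 0))    # float16_e5m10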
@@ -512,6 +512,8 @@ elif paddle.is_compiled_with_cuda():
    sources += find_end_files("gpu_ops/w4afp8_gemm", ".cu")
    os.system("python utils/auto_gen_wfp8afp8_sparse_gemm_kernel.py")
    sources += find_end_files("gpu_ops/wfp8afp8_sparse_gemm", ".cu")
    os.system("python gpu_ops/machete/generate.py")
    sources += find_end_files("gpu_ops/machete", ".cu")

    setup(
        name="fastdeploy_ops",
@@ -54,6 +54,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
    # Set the MoE backend. "cutlass", "marlin" and "triton" can be set currently.
    "FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
    # Whether to use Machete for the wint4 dense gemm.
    "FD_USE_MACHETE": lambda: os.getenv("FD_USE_MACHETE", "0"),
    # Set whether to disable recomputing the request when the KV cache is full.
    "FD_DISABLED_RECOVER": lambda: os.getenv("FD_DISABLED_RECOVER", "0"),
    # Set the triton kernel JIT compilation directory.
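The Machete path is opt-in. One way to enable it (assuming an SM90 GPU and wint4 layers whose output dim is a multiple of 128) is to set the flag before FastDeploy reads its environment:

import os

os.environ["FD_USE_MACHETE"] = "1"  # opt in to the Machete wint4 dense GEMM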
@@ -15,9 +15,12 @@
"""

from .cutlass_scaled_mm import cutlass_scaled_mm
from .machete_mm import machete_quantize_and_pack, machete_wint_mm
from .scaled_fp8_quant import scaled_fp8_quant

__all__ = [
    "cutlass_scaled_mm",
    "scaled_fp8_quant",
    "machete_wint_mm",
    "machete_quantize_and_pack",
]
185
fastdeploy/model_executor/layers/quantization/ops/machete_mm.py
Normal file
@@ -0,0 +1,185 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import numpy as np
import paddle

from fastdeploy.platforms import current_platform


def get_sm_version():
    prop = paddle.device.cuda.get_device_properties()
    cc = prop.major * 10 + prop.minor
    return cc


if current_platform.is_cuda() and get_sm_version() == 90:
    from fastdeploy.model_executor.ops.gpu import machete_mm, machete_prepack_B


def get_pack_factor(num_bits):
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def pack_rows(
    q_w: paddle.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == [size_k, size_n]

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    orig_device = q_w.place
    q_w_np = q_w.numpy().astype(np.uint32)

    q_res = np.zeros((size_k // pack_factor, size_n), dtype=np.uint32)

    for i in range(pack_factor):
        q_res |= q_w_np[i::pack_factor, :] << num_bits * i

    q_res = paddle.to_tensor(q_res.astype(np.int32), place=orig_device)
    return q_res

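# Example (illustrative): with num_bits=4, eight consecutive 4-bit rows are packed
# into one int32 row (32 / 4 = 8), so a [16, 4] tensor becomes [2, 4]:
#
#   q = paddle.randint(0, 16, [16, 4], dtype="int32")
#   packed = pack_rows(q, num_bits=4, size_k=16, size_n=4)
#   assert packed.shape == [2, 4]
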
def quantize_weights(
    w: paddle.Tensor,
    group_size: Optional[int],
    quant_type: str = "uint4b8",
):
    """
    Quantize weights in PaddlePaddle, following the PyTorch reference implementation.

    Args:
        w: Input weight tensor (must be a float type).
        group_size: Group size for quantization. If `-1`, use channel-wise quantization.
        quant_type: Target quantization type (`uint4` or `uint4b8`).

    Returns:
        w_q: Quantized weights.
        w_s: Quantization scales.
    """
    assert paddle.is_floating_point(w), "w must be float type"
    assert quant_type in ["uint4", "uint4b8"], "only support quant_type = uint4, uint4b8"

    orig_device = w.place
    size_k, size_n = w.shape

    if group_size == -1:
        group_size = size_k

    # Reshape to [group_size, -1]
    if group_size is not None and group_size < size_k:
        w = w.reshape([-1, group_size, size_n])
        w = w.transpose([1, 0, 2])
        w = w.reshape([group_size, -1])

    # Compute scale for each group
    max_val = paddle.max(w, axis=0, keepdim=True)
    min_val = paddle.min(w, axis=0, keepdim=True)

    max_q_val = float(7.0)
    min_q_val = float(-8.0)

    w_s = paddle.ones([1], dtype=paddle.float32)  # unscaled case

    if group_size is not None:
        # Avoid division by zero
        max_scale = paddle.maximum(
            paddle.abs(max_val / (max_q_val if max_q_val != 0 else float("inf"))),
            paddle.abs(min_val / (min_q_val if min_q_val != 0 else float("inf"))),
        )
        w_s = max_scale

    # Quantize
    w_q = paddle.round(w / w_s).astype(paddle.int32)
    w_q = paddle.clip(w_q, min_q_val, max_q_val)

    # if hasattr(quant_type, 'bias'):  # Custom quantization bias (if applicable)
    #     w_q += quant_type.bias
    if quant_type == "uint4b8":
        w_q += 8

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w_tensor):
            w_tensor = w_tensor.reshape([group_size, -1, size_n])
            w_tensor = w_tensor.transpose([1, 0, 2])
            w_tensor = w_tensor.reshape([size_k, size_n])
            return w_tensor

        w_q = reshape_w(w_q)
    w_s = w_s.reshape([-1, size_n])

    # Move tensors back to original device
    w_q = w_q.to(orig_device)
    if w_s is not None:
        w_s = w_s.to(orig_device)

    return w_q, w_s

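# Example (illustrative): round-trip check for channel-wise uint4b8 quantization.
# Dequantization is (w_q - 8) * w_s, so the per-element error is bounded by about w_s / 2:
#
#   w = paddle.randn([128, 64], dtype="float16")
#   w_q, w_s = quantize_weights(w, group_size=-1, quant_type="uint4b8")
#   w_dq = (w_q.astype("float16") - 8) * w_s.astype("float16")
#   print(float(paddle.abs(w - w_dq).max()))  # roughly w_s / 2 at most
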
def machete_quantize_and_pack(
    w: paddle.Tensor,
    atype: paddle.dtype,
    quant_type: str = "uint4b8",
    scale_type: str = "",
    group_size: int = -1,
):
    w_q, w_s = quantize_weights(w, group_size, quant_type=quant_type)
    w_q = pack_rows(w_q, 4, *w_q.shape)
    w_q_col = w_q.transpose([1, 0]).contiguous()  # convert to col major
    w_q_prepack = machete_prepack_B(
        w_q_col,
        atype,
        quant_type,
        scale_type,
    )[0]
    return w_q_prepack, w_s


def machete_wint_mm(
    x: paddle.Tensor,
    w_prepack: paddle.Tensor,
    w_g_s: paddle.Tensor,
    w_g_zp: Optional[paddle.Tensor] = None,
    w_ch_s: Optional[paddle.Tensor] = None,
    w_tok_s: Optional[paddle.Tensor] = None,
    weight_dtype: str = "uint4b8",
    group_size: int = -1,
    out_dtype: str = "",
    scheduler: str = "",
):
    out = machete_mm(
        x,
        w_prepack,
        w_g_s,  # group scales
        w_g_zp,  # group zeros
        w_ch_s,  # per-channel scale
        w_tok_s,  # per-token scale
        weight_dtype,  # weight_dtype
        out_dtype,  # out_dtype
        group_size,  # group_size
        scheduler,  # scheduler
    )[0]
    return out
@@ -21,6 +21,7 @@ from typing import Optional
import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize

from fastdeploy import envs
from fastdeploy.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
@@ -33,6 +34,12 @@ from ..utils import get_tensor
from .quant_base import QuantConfigBase, QuantMethodBase


def get_sm_version():
    prop = paddle.device.cuda.get_device_properties()
    cc = prop.major * 10 + prop.minor
    return cc


class WeightOnlyConfig(QuantConfigBase):
    """
    Quantization config for weight only
@@ -132,6 +139,14 @@ class WeightOnlyConfig(QuantConfigBase):
            else:
                raise ValueError(f"Unsupported MOE backend {layer.use_method}")
        else:
            if (
                self.name() == "wint4"
                and envs.FD_USE_MACHETE == "1"
                and get_sm_version() == 90
                and layer.weight_shape[1]
                and layer.weight_shape[1] % 128 == 0
            ):
                return MacheteWeightOnlyLinearMethod(self)
            return GPUWeightOnlyLinearMethod(self)


@@ -329,3 +344,73 @@ class GPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
        quanted_weight_tensor = paddle.transpose(quanted_weight_tensor, [1, 0])
        layer.weight.set_value(quanted_weight_tensor)
        layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))


class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod):
    """
    Weight only quantization method for linear layers on GPU using Machete.
    The weights are loaded in half precision (float16/bfloat16). After loading, the per-channel
    scales are computed and the weights are quantized to int4 (uint4b8) and prepacked for Machete.
    """

    def __init__(
        self,
        quant_config: WeightOnlyConfig,
    ) -> None:
        super().__init__(quant_config)

    def create_weights(self, layer, **extra_weight_attrs):

        assert layer.bias is None, "Machete weight only linear method does not support bias."
        assert self.quant_config.name() == "wint4", "Machete weight only linear method only supports wint4."

        # The scale shape should equal the output dim of the weight, using per-channel quantization.
        weight_scale_shape = [1, layer.weight_shape[1]]

        # layer.weight_shape.reverse()
        if self.quant_config.name() == "wint4":
            layer.weight_shape[0] //= 8
        layer.weight_dtype = "int32"

        layer.weight = layer.create_parameter(
            shape=layer.weight_shape,
            dtype=layer.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        layer.weight_scale = layer.create_parameter(
            shape=weight_scale_shape,
            dtype=layer._dtype,
            is_bias=False,
        )

    def process_prequanted_weights(self, layer, state_dict) -> None:
        pass

    def process_loaded_weights(self, layer, weight) -> None:
        from fastdeploy.model_executor.layers.quantization.ops import (
            machete_quantize_and_pack,
        )

        quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack(
            w=weight,
            atype=layer._dtype,
            quant_type="uint4b8",
        )
        layer.weight.set_value(quanted_weight_tensor)
        layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype()))

    def apply(self, layer, x):
        assert layer.bias is None, "Machete weight only linear method does not support bias."
        assert self.quant_config.name() == "wint4", "Machete weight only linear method only supports wint4."
        from fastdeploy.model_executor.layers.quantization.ops import machete_wint_mm

        linear_out = machete_wint_mm(
            x,
            w_prepack=layer.weight,
            w_g_s=layer.weight_scale,
            weight_dtype="uint4b8",
        )

        return linear_out
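For reference, the Machete wint4 apply() above is numerically a matmul against the dequantized weight; a hedged sketch of the equivalent computation (hypothetical helper, handy when debugging accuracy):

import paddle

def machete_wint4_reference(x, w_q, w_s):
    # w_q: [K, N] uint4b8 values in [0, 15] (weight + 8), before machete_prepack_B
    # w_s: [1, N] per-channel scales, as produced by quantize_weights
    w = (w_q.astype(x.dtype) - 8.0) * w_s.astype(x.dtype)
    return paddle.matmul(x, w)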
174
tests/operators/test_machete_mm.py
Normal file
@@ -0,0 +1,174 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import struct
import unittest

import numpy as np
import paddle
import paddle.nn.quant as Q
from paddle import base
from paddle.base import core
from paddle.framework import set_default_dtype

from fastdeploy.model_executor.layers.quantization.ops import (
    machete_quantize_and_pack,
    machete_wint_mm,
)

np.random.seed(123)
paddle.seed(123)


def get_cuda_version():
    result = os.popen("nvcc --version").read()
    regex = r"release (\S+),"
    match = re.search(regex, result)
    if match:
        num = str(match.group(1))
        integer, decimal = num.split(".")
        return int(integer) * 1000 + int(float(decimal) * 10)
    else:
        return -1


def get_sm_version():
    prop = paddle.device.cuda.get_device_properties()
    cc = prop.major * 10 + prop.minor
    return cc


def convert_uint16_to_float(in_list):
    in_list = np.asarray(in_list)
    out = np.vectorize(
        lambda x: struct.unpack("<f", struct.pack("<I", np.uint32(x) << np.uint32(16)))[0],
        otypes=[np.float32],
    )(in_list.flat)
    return np.reshape(out, in_list.shape)


@unittest.skipIf(
    not core.is_compiled_with_cuda() or get_sm_version() < 90,
    "machete only supports sm90.",
)
class WeightOnlyLinearTestCase(unittest.TestCase):
    def config(self):
        self.dtype = "float16"
        self.rtol = 1e-5
        self.atol = 1e-2
        self.bias = False
        self.batch = 1
        self.token = 512
        self.in_features = 7168
        self.out_features = 1024
        self.weight_dtype = "int4"
        self.static = False
        self.group_size = -1

    def setUp(self):
        self.config()
        if self.dtype == "bfloat16" or self.weight_dtype == "int4":
            self.atol = 1.3e-1
        x = np.random.random((self.token, self.in_features))
        self.x = paddle.to_tensor(x, dtype=self.dtype)
        if self.bias:
            bias_attr = base.ParamAttr(
                trainable=False,
                regularizer=None,
                initializer=paddle.nn.initializer.Constant(value=1.0),
            )
        else:
            bias_attr = None
        set_default_dtype(self.dtype)
        self.linear = paddle.nn.Linear(self.in_features, self.out_features, bias_attr=bias_attr)

        self.bias = self.linear.bias
        self.weight = self.linear.weight
        self.float_weight = self.linear.weight
        self.weight_scale = None

        self.weight, self.weight_scale = Q.weight_quantize(
            (self.float_weight.cuda() if self.weight_dtype == "int8" else self.weight.cpu()),
            algo=("weight_only_int8" if self.weight_dtype == "int8" else "weight_only_int4"),
            group_size=self.group_size,
        )

    def get_linear_out(self):
        out = self.linear(self.x)
        return out.numpy()

    def get_weight_only_linear_out(self):
        for i in range(10):
            out = Q.weight_only_linear(
                self.x,
                self.weight,
                bias=self.bias,
                weight_scale=self.weight_scale,
                weight_dtype=self.weight_dtype,
                group_size=self.group_size,
            )
        return out.numpy()

    def get_machete_weight_only_linear_out(self):
        w_q, w_s = machete_quantize_and_pack(
            w=self.float_weight.cuda(),
            atype=self.dtype,
            quant_type="uint4b8",
        )

        out = machete_wint_mm(
            self.x,
            w_prepack=w_q,
            w_g_s=w_s,  # group scales
            weight_dtype="uint4b8",  # weight_dtype
        )
        return out.numpy()

    def test_weight_only_linear(self):
        # out_expect = self.get_linear_out()
        out_paddle = self.get_weight_only_linear_out()
        out_machete = self.get_machete_weight_only_linear_out()

        if self.dtype == "bfloat16":
            out_paddle = convert_uint16_to_float(out_paddle)
            # out_expect = convert_uint16_to_float(out_expect)
            out_machete = convert_uint16_to_float(out_machete)
        np.testing.assert_allclose(out_paddle, out_machete, rtol=self.rtol, atol=self.atol)


M = [32, 128]
K_N = [[2048, 4096]]


def make_case(m, k, n):
    class Case(WeightOnlyLinearTestCase):
        def config(self, _m=m, _k=k, _n=n):
            super().config()
            self.token = m
            self.in_features = k
            self.out_features = n

    Case.name = f"WeightOnlyLinearTestCase{m}{k}{n}"
    return Case


for k, n in K_N:
    for m in M:
        cls = make_case(m, k, n)
        globals()[cls.name] = cls

if __name__ == "__main__":
    unittest.main()