# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
import math
import os
import shutil
import sys
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import dataclass, fields
from functools import reduce
from typing import Optional, Union

import jinja2

cur_dir = os.path.dirname(os.path.abspath(__file__))
p = os.path.abspath(os.path.join(cur_dir, "../../third_party/cutlass/python"))
sys.path.insert(0, p)

from cutlass_library import (
    EpilogueScheduleTag,
    EpilogueScheduleType,
    TileSchedulerTag,
    TileSchedulerType,
)

# yapf conflicts with isort for this block
# yapf: disable
from machete_cutlass_library_extension import (
    DataType,
    MACHETEDataType,
    MACHETEDataTypeMACHETEScalarTypeTag,
    MACHETEDataTypeNames,
    MACHETEDataTypePaddleDataTypeTag,
    MACHETEDataTypeSize,
    MACHETEDataTypeTag,
    MACHETEKernelScheduleTag,
    MixedInputKernelScheduleType,
)

# yapf: enable

#
# Generator templating
#

DISPATCH_TEMPLATE = """
#include "../machete_mm_launcher.cuh"

namespace machete {

{% for impl_config in impl_configs %}
{% set type_sig = gen_type_sig(impl_config.types) -%}
{% for s in impl_config.schedules %}
extern paddle::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs);
{%- endfor %}

paddle::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
  [[maybe_unused]] auto M = args.A.shape()[0];
  [[maybe_unused]] auto N = args.B.shape()[1];
  [[maybe_unused]] auto K = args.A.shape()[1];

  if (!args.maybe_schedule) {
    {%- for cond, s in impl_config.heuristic %}
    {%if cond is not none%}if ({{cond}})
    {%- else %}else
    {%- endif %}
        return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %}
  }

  {%- for s in impl_config.schedules %}
  if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}")
    return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);
  {%- endfor %}
  PADDLE_ENFORCE(false, "machete_gemm(..) is not implemented ");
}
{%- endfor %}


static inline std::optional<paddle::DataType> maybe_scalartype(
    std::optional<paddle::Tensor> const& t) {
  if (!t) {
    return std::nullopt;
  } else {
    return t->dtype();
  }
}

paddle::Tensor mm_dispatch(MMArgs args) {
  auto out_type = args.maybe_out_type.value_or(args.A.dtype());
  auto a_type = args.A.dtype();
  auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales);
  auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros);
  auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales);
  auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales);

  {% for impl_config in impl_configs %}
  {% set t = impl_config.types -%}
  {% set type_sig = gen_type_sig(t) -%}
  if (args.b_type == {{MACHETEScalarTypeTag[t.b]}}
      && a_type == {{PaddleTypeTag[t.a]}}
      && out_type == {{PaddleTypeTag[t.out]}}
      && {%if t.b_group_scale != void -%}
      maybe_g_scales_type == {{PaddleTypeTag[t.b_group_scale]}}
      {%- else %}!maybe_g_scales_type{%endif%}
      && {%if t.b_group_zeropoint != void -%}
      maybe_g_zeros_type == {{PaddleTypeTag[t.b_group_zeropoint]}}
      {%- else %}!maybe_g_zeros_type{%endif%}
      && {%if t.b_channel_scale != void -%}
      maybe_ch_scales_type == {{PaddleTypeTag[t.b_channel_scale]}}
      {%- else %}!maybe_ch_scales_type{%endif%}
      && {%if t.a_token_scale != void -%}
      maybe_tok_scales_type == {{PaddleTypeTag[t.a_token_scale]}}
      {%- else %}!maybe_tok_scales_type{%endif%}
  ) {
    return mm_dispatch_{{type_sig}}(args);
  }
  {%- endfor %}

  PADDLE_ENFORCE(
      false, "machete_mm(..) is not implemented "
      "; implemented types are: \\n",
      {%- for impl_config in impl_configs %}
      {% set t = impl_config.types -%}
      "\\t{{gen_type_option_name(t)}}\\n",
      {%- endfor %}
      "");
}

std::vector<std::string> supported_schedules_dispatch(
    SupportedSchedulesArgs args) {
  auto out_type = args.maybe_out_type.value_or(args.a_type);

  {% for impl_config in impl_configs %}
  {% set t = impl_config.types -%}
  {% set schs = impl_config.schedules -%}
  if (args.b_type == {{MACHETEScalarTypeTag[t.b]}}
      && args.a_type == {{PaddleTypeTag[t.a]}}
      && out_type == {{PaddleTypeTag[t.out]}}
      && {%if t.b_group_scale != void -%}
      args.maybe_group_scales_type == {{PaddleTypeTag[t.b_group_scale]}}
      {%- else %}!args.maybe_group_scales_type{%endif%}
      && {%if t.b_group_zeropoint != void-%}
      args.maybe_group_zeros_type == {{PaddleTypeTag[t.b_group_zeropoint]}}
      {%- else %}!args.maybe_group_zeros_type{%endif%}
  ) {
    return {
      {%- for s in impl_config.schedules %}
      "{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %}
      {%- endfor %}
    };
  }
  {%- endfor %}

  return {};
}

}  // namespace machete
"""
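
# Illustrative only (a sketch, not literal rendered output): the template above
# expands to one mm_dispatch_<type_sig>() per TypeConfig. If the caller supplies
# args.maybe_schedule, dispatch matches the schedule by its terse signature
# string (e.g. "128x16_1x1x1_TmaMI__TmaCoop_streamK"); otherwise it falls
# through the generated M/N/K heuristic chain. The top-level mm_dispatch()
# then routes to the right mm_dispatch_<type_sig>() by comparing the runtime
# dtypes of A, the output, and the optional scale/zero tensors against each
# TypeConfig.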

IMPL_TEMPLATE = """
#include "../machete_mm_launcher.cuh"

namespace machete {

{% for sch in unique_schedules(impl_configs) %}
{% set sch_sig = gen_sch_sig(sch) -%}
struct sch_{{sch_sig}} {
  using TileShapeNM = Shape<{{
      to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
  using ClusterShape = Shape<{{
      to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
  // TODO: Reimplement
  // using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
  using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
  using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
};
{% endfor %}

{% for impl_config in impl_configs %}
{% set t = impl_config.types -%}
{% set schs = impl_config.schedules -%}
{% set type_sig = gen_type_sig(t) -%}

template<typename Sch>
using Kernel_{{type_sig}} = MacheteKernelTemplate<
    {{DataTypeTag[t.a]}},  // ElementA
    {{DataTypeTag[t.b]}},  // ElementB
    {{DataTypeTag[t.out]}},  // ElementD
    {{DataTypeTag[t.accumulator]}},  // Accumulator
    {{DataTypeTag[t.b_group_scale]}},  // GroupScaleT
    {{DataTypeTag[t.b_group_zeropoint]}},  // GroupZeroT
    {{DataTypeTag[t.b_channel_scale]}},  // ChannelScaleT
    {{DataTypeTag[t.a_token_scale]}},  // TokenScaleT
    cutlass::gemm::KernelTmaWarpSpecializedCooperative,
    Sch>;

{% for sch in schs %}
{% set sch_sig = gen_sch_sig(sch) -%}
paddle::Tensor
impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) {
  return run_impl<Kernel_{{type_sig}}<sch_{{sch_sig}}>>(args);
}
{%- endfor %}
{%- endfor %}

}  // namespace machete
"""
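
# Illustrative only (a sketch, not literal rendered output): each
# (TypeConfig, ScheduleConfig) pair above becomes one concrete entry point,
# roughly
#
#   struct sch_128x16_1x1x1_TmaMI__TmaCoop_streamK {
#     using TileShapeNM = Shape<_128, _16>;
#     using ClusterShape = Shape<_1, _1, _1>;
#     ...
#   };
#   paddle::Tensor impl_<type_sig>_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(
#       MMArgs args) { ... }
#
# so the dispatch file can forward-declare and call these by name.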

PREPACK_TEMPLATE = """
#include "../machete_prepack_launcher.cuh"

namespace machete {

paddle::Tensor prepack_B_dispatch(PrepackBArgs args) {
  auto convert_type = args.maybe_group_scales_type.value_or(args.a_type);
  {%- for t in types %}
  {% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %}
  if (args.a_type == {{PaddleTypeTag[t.a]}}
      && args.b_type.size_bits() == {{t.b_num_bits}}
      && convert_type == {{PaddleTypeTag[t.convert]}}) {
    return prepack_impl<
        PrepackedLayoutBTemplate<
            {{DataTypeTag[t.a]}},  // ElementA
            {{DataTypeTag[b_type]}},  // ElementB
            {{DataTypeTag[t.convert]}},  // ElementConvert
            {{DataTypeTag[t.accumulator]}},  // Accumulator
            cutlass::layout::ColumnMajor,
            cutlass::gemm::KernelTmaWarpSpecializedCooperative>
    >(args.B);
  }
  {%- endfor %}

  PADDLE_ENFORCE(false,
      "prepack_B_dispatch(..) is not implemented");
}

}  // namespace machete
"""

TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative


@dataclass(frozen=True)
class ScheduleConfig:
    tile_shape_mn: tuple[int, int]
    cluster_shape_mnk: tuple[int, int, int]
    kernel_schedule: MixedInputKernelScheduleType
    epilogue_schedule: EpilogueScheduleType
    tile_scheduler: TileSchedulerType


@dataclass(frozen=True)
class TypeConfig:
    a: DataType
    b: Union[DataType, MACHETEDataType]
    b_group_scale: DataType
    b_group_zeropoint: DataType
    b_channel_scale: DataType
    a_token_scale: DataType
    out: DataType
    accumulator: DataType


@dataclass(frozen=True)
class PrepackTypeConfig:
    a: DataType
    b_num_bits: int
    convert: DataType
    accumulator: DataType


@dataclass
class ImplConfig:
    types: TypeConfig
    schedules: list[ScheduleConfig]
    heuristic: list[tuple[Optional[str], ScheduleConfig]]
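
# Illustrative only: a GPTQ-style W4A16 config, as built in generate() below,
# looks roughly like
#   TypeConfig(a=DataType.f16, b=MACHETEDataType.u4b8,
#              b_group_scale=DataType.f16, b_group_zeropoint=DataType.void,
#              b_channel_scale=DataType.void, a_token_scale=DataType.void,
#              out=DataType.f16, accumulator=DataType.f32)
# paired with schedules such as
#   ScheduleConfig(tile_shape_mn=(128, 16), cluster_shape_mnk=(1, 1, 1),
#                  kernel_schedule=TmaMI, epilogue_schedule=TmaCoop,
#                  tile_scheduler=TileSchedulerType.StreamK)
# (void fields mean the corresponding scale/zero tensor is absent).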


def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
    tile_shape = f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
    cluster_shape = (
        f"{schedule_config.cluster_shape_mnk[0]}"
        + f"x{schedule_config.cluster_shape_mnk[1]}"
        + f"x{schedule_config.cluster_shape_mnk[2]}"
    )
    kernel_schedule = MACHETEKernelScheduleTag[schedule_config.kernel_schedule].split("::")[-1]
    epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split("::")[-1]
    tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1]

    return f"{tile_shape}_{cluster_shape}_{kernel_schedule}_{epilogue_schedule}_{tile_scheduler}"


# mostly unique shorter sch_sig
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
    kernel_terse_names_replace = {
        "KernelTmaWarpSpecializedCooperative": "TmaMI_",
        "TmaWarpSpecializedCooperative_": "TmaCoop_",
        "StreamKScheduler": "streamK",
    }

    sch_sig = generate_sch_sig(schedule_config)
    for orig, terse in kernel_terse_names_replace.items():
        sch_sig = sch_sig.replace(orig, terse)
    return sch_sig
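
# Illustrative only, assuming the standard CUTLASS tag strings: for tile
# (128, 16), cluster (1, 1, 1) and the common TmaMI/TmaCoop/StreamK schedule
# parameters, generate_sch_sig() yields
# "128x16_1x1x1_KernelTmaWarpSpecializedCooperative_TmaWarpSpecializedCooperative_StreamKScheduler"
# and generate_terse_sch_sig() shortens it to
# "128x16_1x1x1_TmaMI__TmaCoop_streamK".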


# unique type_name
def generate_type_signature(kernel_types: TypeConfig):
    return "".join([MACHETEDataTypeNames[getattr(kernel_types, field.name)] for field in fields(TypeConfig)])


def generate_type_option_name(kernel_types: TypeConfig):
    return ", ".join(
        [
            f"{field.name.replace('b_', 'with_') + '_type'}=" + MACHETEDataTypeNames[getattr(kernel_types, field.name)]
            for field in fields(TypeConfig)
        ]
    )


def is_power_of_two(n):
    return (n != 0) and (n & (n - 1) == 0)


def to_cute_constant(value: list[int]):

    def _to_cute_constant(value: int):
        if is_power_of_two(value):
            return f"_{value}"
        else:
            return f"Int<{value}>"

    if isinstance(value, Iterable):
        return [_to_cute_constant(v) for v in value]
    else:
        return _to_cute_constant(value)
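
# Illustrative only: to_cute_constant([128, 96]) -> ["_128", "Int<96>"], since
# cute pre-defines _N aliases for power-of-two sizes while other sizes are
# spelled Int<N>.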


def unique_schedules(impl_configs: list[ImplConfig]):
    # Deduplicate with dict.fromkeys rather than set() so the result is
    # deterministic; a set would make the generated kernel files
    # non-deterministic and cause ccache misses (see the note in
    # get_unique_schedules() inside generate()).
    return list(dict.fromkeys(sch for impl_config in impl_configs for sch in impl_config.schedules))


def unsigned_type_with_bitwidth(num_bits):
    return {
        4: DataType.u4,
        8: DataType.u8,
        16: DataType.u16,
        32: DataType.u32,
        64: DataType.u64,
    }[num_bits]
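
# Illustrative only: unsigned_type_with_bitwidth(4) -> DataType.u4; any other
# width raises KeyError, i.e. only these storage widths are supported for the
# packed B operand.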


template_globals = {
    "void": DataType.void,
    "DataTypeTag": MACHETEDataTypeTag,
    "MACHETEScalarTypeTag": MACHETEDataTypeMACHETEScalarTypeTag,
    "PaddleTypeTag": MACHETEDataTypePaddleDataTypeTag,
    "KernelScheduleTag": MACHETEKernelScheduleTag,
    "EpilogueScheduleTag": EpilogueScheduleTag,
    "TileSchedulerTag": TileSchedulerTag,
    "to_cute_constant": to_cute_constant,
    "gen_sch_sig": generate_terse_sch_sig,
    "gen_type_sig": generate_type_signature,
    "unique_schedules": unique_schedules,
    "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
    "gen_type_option_name": generate_type_option_name,
}


def create_template(template_str):
    template = jinja2.Template(template_str)
    template.globals.update(template_globals)
    return template


mm_dispatch_template = create_template(DISPATCH_TEMPLATE)
mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)


def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
    sources = []

    sources.append(
        (
            "machete_mm_dispatch",
            mm_dispatch_template.render(impl_configs=impl_configs),
        )
    )

    prepack_types = []
    for impl_config in impl_configs:
        convert_type = (
            impl_config.types.a
            if impl_config.types.b_group_scale == DataType.void
            else impl_config.types.b_group_scale
        )
        prepack_types.append(
            PrepackTypeConfig(
                a=impl_config.types.a,
                b_num_bits=MACHETEDataTypeSize[impl_config.types.b],
                convert=convert_type,
                accumulator=impl_config.types.accumulator,
            )
        )

    def prepacked_type_key(prepack_type: PrepackTypeConfig):
        # For now we can just use the first accumulator type seen since
        # the tensor core shapes/layouts don't vary based on accumulator
        # type, so we can generate less code this way
        return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert)

    unique_prepack_types = []
    prepack_types_seen = set()
    for prepack_type in prepack_types:
        key = prepacked_type_key(prepack_type)
        if key not in prepack_types_seen:
            unique_prepack_types.append(prepack_type)
            prepack_types_seen.add(key)

    sources.append(
        (
            "machete_prepack",
            prepack_dispatch_template.render(
                types=unique_prepack_types,
            ),
        )
    )

    # Split up impls across files
    num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
    num_impls_per_file = math.ceil(num_impls / num_impl_files)

    files_impls: list[list[ImplConfig]] = [[]]

    curr_num_impls_assigned = 0
    curr_impl_in_file = 0
    curr_impl_configs = deepcopy(list(reversed(impl_configs)))

    while curr_num_impls_assigned < num_impls:
        room_left_in_file = num_impls_per_file - curr_impl_in_file
        if room_left_in_file == 0:
            files_impls.append([])
            room_left_in_file = num_impls_per_file
            curr_impl_in_file = 0

        curr_ic = curr_impl_configs[-1]
        if len(curr_ic.schedules) >= room_left_in_file:
            # Break apart the current impl config
            tmp_ic = deepcopy(curr_ic)
            tmp_ic.schedules = curr_ic.schedules[:room_left_in_file]
            curr_ic.schedules = curr_ic.schedules[room_left_in_file:]
            files_impls[-1].append(tmp_ic)
        else:
            files_impls[-1].append(curr_ic)
            curr_impl_configs.pop()
        curr_num_impls_assigned += len(files_impls[-1][-1].schedules)
        curr_impl_in_file += len(files_impls[-1][-1].schedules)

    for part, file_impls in enumerate(files_impls):
        sources.append(
            (
                f"machete_mm_impl_part{part + 1}",
                mm_impl_template.render(impl_configs=file_impls),
            )
        )

    return sources
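
# Illustrative only: with the four GPTQ TypeConfigs and nine unique schedules
# built in generate() below, num_impls is 4 * 9 = 36, so with the default
# num_impl_files=8 each machete_mm_impl_part<N>.cu holds at most
# ceil(36 / 8) = 5 impls and the last file holds the remainder.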


def generate():
    # See csrc/quantization/machete/Readme.md, the "Code Generation" section,
    # for more info about how this works
    SCRIPT_DIR = os.path.dirname(__file__)

    sch_common_params = dict(
        kernel_schedule=TmaMI,
        epilogue_schedule=TmaCoop,
        tile_scheduler=TileSchedulerType.StreamK,
    )

    # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
    default_tile_heuristic_config = {
        # M = 257+
        "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
        "M > 256": ((128, 256), (2, 1, 1)),
        # M = 129-256
        "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
        "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
        "M > 128": ((128, 256), (2, 1, 1)),
        # M = 65-128
        "M > 64 && K <= 4096 && N <= 4096": ((128, 32), (2, 1, 1)),
        "M > 64 && K <= 4096 && N <= 8192": ((128, 64), (2, 1, 1)),
        "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
        "M > 64": ((128, 128), (2, 1, 1)),
        # M = 33-64
        "M > 40 && K <= 6144 && N <= 6144": ((128, 32), (2, 1, 1)),
        "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
        "M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
        "M > 32": ((128, 64), (2, 1, 1)),
        # M = 17-32
        "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
        "M > 16": ((256, 32), (2, 1, 1)),
        # M = 1-16
        "N >= 26624": ((256, 16), (1, 1, 1)),
        None: ((128, 16), (1, 1, 1)),
    }

    # For now we use the same heuristic for all types
    # Heuristic is currently tuned for H100s
    default_heuristic = [
        (cond, ScheduleConfig(*tile_config, **sch_common_params))  # type: ignore
        for cond, tile_config in default_tile_heuristic_config.items()
    ]
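
    # Illustrative only: after template rendering, this heuristic becomes an
    # if/else chain inside each mm_dispatch_<type_sig>(), roughly
    #   if (M > 256 && K <= 16384 && N <= 4096)
    #     return impl_<type_sig>_sch_128x128_2x1x1_TmaMI__TmaCoop_streamK(args);
    #   ...
    #   else
    #     return impl_<type_sig>_sch_128x16_1x1x1_TmaMI__TmaCoop_streamK(args);
    # with the None entry supplying the final else branch.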

    def get_unique_schedules(heuristic: list[tuple[Optional[str], ScheduleConfig]]):
        # Do not use schedules = list(set(...)) because we need to make sure
        # the output list is deterministic; otherwise the generated kernel file
        # will be non-deterministic and cause ccache misses.
        schedules = []
        for _, schedule_config in heuristic:
            if schedule_config not in schedules:
                schedules.append(schedule_config)
        return schedules

    impl_configs = []

    GPTQ_kernel_type_configs = list(
        TypeConfig(
            a=a,
            b=b,
            b_group_scale=a,
            b_group_zeropoint=DataType.void,
            b_channel_scale=DataType.void,
            a_token_scale=DataType.void,
            out=a,
            accumulator=DataType.f32,
        )
        for b in (MACHETEDataType.u4b8, MACHETEDataType.u8b128)
        for a in (DataType.f16, DataType.bf16)
    )

    impl_configs += [
        ImplConfig(x[0], x[1], x[2])
        for x in zip(
            GPTQ_kernel_type_configs,
            itertools.repeat(get_unique_schedules(default_heuristic)),
            itertools.repeat(default_heuristic),
        )
    ]

    output_dir = os.path.join(SCRIPT_DIR, "generated")

    # Delete the "generated" directory if it exists
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)

    # Create the "generated" directory
    os.makedirs(output_dir)

    # Render each group of configurations into separate files
    for filename, code in create_sources(impl_configs):
        filepath = os.path.join(output_dir, f"{filename}.cu")
        with open(filepath, "w") as output_file:
            output_file.write(code)
        print(f"Rendered template to {filepath}")


if __name__ == "__main__":
    generate()
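
# Usage sketch (assuming this script sits next to the machete csrc sources and
# the cutlass python package is available at ../../third_party/cutlass/python):
#
#   python generate.py
#
# This wipes and recreates ./generated, then writes machete_mm_dispatch.cu,
# machete_prepack.cu, and machete_mm_impl_part<N>.cu for the build to compile.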