diff --git a/custom_ops/gpu_ops/machete/machete_mm.cu b/custom_ops/gpu_ops/machete/machete_mm.cu index c6f56d1c9..920c2e5ad 100644 --- a/custom_ops/gpu_ops/machete/machete_mm.cu +++ b/custom_ops/gpu_ops/machete/machete_mm.cu @@ -86,3 +86,52 @@ std::vector MacheteMMKernel( maybe_schedule); return {out}; } + +std::vector> MacheteMMKernelInferShape( + std::vector const& A_shape, + std::vector const& B_shape, + paddle::optional> const& maybe_group_scales_shape, + paddle::optional> const& maybe_group_zeros_shape, + paddle::optional> const& maybe_channel_scales_shape, + paddle::optional> const& maybe_token_scales_shape, + std::string const& b_type_str, + std::string const& maybe_out_type_str, + int64_t const& maybe_group_size, + std::string const& maybe_schedule) { + return {{A_shape[0], B_shape[1]}}; +} + +std::vector MacheteMMKernelInferDtype( + paddle::DataType const& A_dtype, + paddle::DataType const& B_dtype, + paddle::optional const& maybe_group_scales_dtype, + paddle::optional const& maybe_group_zeros_dtype, + paddle::optional const& maybe_channel_scales_dtype, + paddle::optional const& maybe_token_scales_dtype, + std::string const& b_type_str, + std::string const& maybe_out_type_str, + int64_t const& maybe_group_size, + std::string const& maybe_schedule) { + + paddle::DataType maybe_out_type; + if (maybe_out_type_str == "float16") { + maybe_out_type = paddle::DataType::FLOAT16; + } else if (maybe_out_type_str == "bfloat16") { + maybe_out_type = paddle::DataType::BFLOAT16; + } else { + maybe_out_type = A_dtype; + } + return {maybe_out_type}; +} + +PD_BUILD_STATIC_OP(machete_mm) + .Inputs({"A", "B", + paddle::Optional("maybe_group_scales"), + paddle::Optional("maybe_group_zeros"), + paddle::Optional("maybe_channel_scales"), + paddle::Optional("maybe_token_scales")}) + .Outputs({"out"}) + .Attrs({"b_type_str:std::string", "maybe_out_type_str:std::string", "maybe_group_size:int64_t", "maybe_schedule:std::string"}) + .SetKernelFn(PD_KERNEL(MacheteMMKernel)) + .SetInferShapeFn(PD_INFER_SHAPE(MacheteMMKernelInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(MacheteMMKernelInferDtype)); diff --git a/custom_ops/gpu_ops/machete/machete_prepack_B.cu b/custom_ops/gpu_ops/machete/machete_prepack_B.cu index 34bd1c705..223a3e654 100644 --- a/custom_ops/gpu_ops/machete/machete_prepack_B.cu +++ b/custom_ops/gpu_ops/machete/machete_prepack_B.cu @@ -71,3 +71,23 @@ std::vector MachetePrepackBKernel( return {B_prepacked}; } + +std::vector> MachetePrepackBKernelInferShape( + std::vector const& B_shape, std::string const& a_type_str, std::string const& b_type_str, + std::string const& maybe_group_scales_type_str) { + return {{B_shape[1], B_shape[0]}}; +} + +std::vector MachetePrepackBKernelInferDtype( + paddle::DataType const& B_dtype, std::string const& a_type_str, std::string const& b_type_str, + std::string const& maybe_group_scales_type_str) { + return {B_dtype}; +} + +PD_BUILD_STATIC_OP(machete_prepack_B) + .Inputs({"B"}) + .Outputs({"B_prepacked"}) + .Attrs({"a_type_str:std::string", "b_type_str:std::string", "maybe_group_scales_type_str:std::string"}) + .SetKernelFn(PD_KERNEL(MachetePrepackBKernel)) + .SetInferShapeFn(PD_INFER_SHAPE(MachetePrepackBKernelInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(MachetePrepackBKernelInferDtype)); diff --git a/docs/usage/environment_variables.md b/docs/usage/environment_variables.md index dc5d472f5..9f3b6becb 100644 --- a/docs/usage/environment_variables.md +++ b/docs/usage/environment_variables.md @@ -78,7 +78,7 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))), # Whether to use Machete for wint4 dense GEMM. - "FD_USE_MACHETE": lambda: os.getenv("FD_USE_MACHETE", "0"), + "FD_USE_MACHETE": lambda: os.getenv("FD_USE_MACHETE", "1"), # Used to truncate the string inserted during thinking when reasoning in a model. ( for ernie4_5_vl, \n\n\n for ernie_x1) "FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR": lambda: os.getenv("FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR", ""), diff --git a/docs/zh/usage/environment_variables.md b/docs/zh/usage/environment_variables.md index 1be359102..1cb9482c5 100644 --- a/docs/zh/usage/environment_variables.md +++ b/docs/zh/usage/environment_variables.md @@ -78,7 +78,7 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))), # 是否使用 Machete 后端的 wint4 GEMM. - "FD_USE_MACHETE": lambda: os.getenv("FD_USE_MACHETE", "0"), + "FD_USE_MACHETE": lambda: os.getenv("FD_USE_MACHETE", "1"), # Used to truncate the string inserted during thinking when reasoning in a model. ( for ernie4_5_vl, \n\n\n for ernie_x1) "FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR": lambda: os.getenv("FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR", ""), @@ -87,6 +87,5 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_CACHE_PROC_EXIT_TIMEOUT": lambda: int(os.getenv("FD_CACHE_PROC_EXIT_TIMEOUT", "600")), # cache_transfer_manager 进程残留时连续错误阈值 - "FD_CACHE_PROC_ERROR_COUNT": lambda: int(os.getenv("FD_CACHE_PROC_ERROR_COUNT", "10")), -} + "FD_CACHE_PROC_ERROR_COUNT": lambda: int(os.getenv("FD_CACHE_PROC_ERROR_COUNT", "10")),} ``` diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 3b0be3df9..f6f6ff6de 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -55,7 +55,7 @@ environment_variables: dict[str, Callable[[], Any]] = { # Set moe backend."cutlass","marlin" and "triton" can be set currently. "FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"), # Whether to use Machete for wint4 dense gemm. - "FD_USE_MACHETE": lambda: os.getenv("FD_USE_MACHETE", "0"), + "FD_USE_MACHETE": lambda: os.getenv("FD_USE_MACHETE", "1"), # Set whether to disable recompute the request when the KV cache is full. "FD_DISABLED_RECOVER": lambda: os.getenv("FD_DISABLED_RECOVER", "0"), # Set triton kernel JIT compilation directory. diff --git a/fastdeploy/model_executor/layers/quantization/ops/machete_mm.py b/fastdeploy/model_executor/layers/quantization/ops/machete_mm.py index b080bb627..ea49809d1 100644 --- a/fastdeploy/model_executor/layers/quantization/ops/machete_mm.py +++ b/fastdeploy/model_executor/layers/quantization/ops/machete_mm.py @@ -167,7 +167,7 @@ def machete_quantize_and_pack( atype, quant_type, scale_type, - )[0] + ) return w_q_prepack, w_s @@ -194,5 +194,5 @@ def machete_wint_mm( out_dtype, # out_dtype group_size, # group_size scheduler, # scheduler - )[0] + ) return out diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py index b448afa12..5cd7ec79e 100644 --- a/fastdeploy/model_executor/layers/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/quantization/weight_only.py @@ -38,10 +38,18 @@ if current_platform.is_xpu(): else: from paddle.nn.quant import weight_only_linear +from fastdeploy.model_executor.layers.quantization.ops.machete_mm import _ENABLE_MACHETE + from ..moe import FusedMoE from ..utils import get_tensor from .quant_base import QuantConfigBase, QuantMethodBase +if _ENABLE_MACHETE: + from fastdeploy.model_executor.layers.quantization.ops import ( + machete_quantize_and_pack, + machete_wint_mm, + ) + class WeightOnlyConfig(QuantConfigBase): """ @@ -154,14 +162,11 @@ class WeightOnlyConfig(QuantConfigBase): else: raise ValueError(f"Unsupported MOE backend {layer.use_method}") else: - from fastdeploy.model_executor.layers.quantization.ops.machete_mm import ( - _ENABLE_MACHETE, - ) - if ( _ENABLE_MACHETE and envs.FD_USE_MACHETE == "1" and not layer.is_quantized + and not layer.fd_config.load_config.dynamic_load_weight and layer.weight_shape[1] and layer.weight_shape[1] % 128 == 0 ): @@ -406,9 +411,6 @@ class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod): raise NotImplementedError("Machete kernel doesn't support prequant. Please set FD_USE_MACHETE to 0.") def process_loaded_weights(self, layer, weight) -> None: - from fastdeploy.model_executor.layers.quantization.ops import ( - machete_quantize_and_pack, - ) # Using group scale for machete, group size is 128 quanted_weight_tensor, weight_scale_tensor = machete_quantize_and_pack( @@ -421,7 +423,6 @@ class MacheteWeightOnlyLinearMethod(WeightOnlyLinearMethod): layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype())) def apply(self, layer, x): - from fastdeploy.model_executor.layers.quantization.ops import machete_wint_mm # Using group scale for machete, group size is 128 linear_out = machete_wint_mm( diff --git a/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py b/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py index 4cf730667..5da6de8e0 100644 --- a/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py +++ b/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py @@ -34,6 +34,8 @@ FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8234)) # List of ports to clean before and after tests PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT] +os.environ["FD_USE_MACHETE"] = "0" + def is_port_open(host: str, port: int, timeout=1.0): """ diff --git a/tests/e2e/test_EB_VL_Lite_serving.py b/tests/e2e/test_EB_VL_Lite_serving.py index f3d275071..5916ae301 100644 --- a/tests/e2e/test_EB_VL_Lite_serving.py +++ b/tests/e2e/test_EB_VL_Lite_serving.py @@ -35,6 +35,8 @@ FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333)) # List of ports to clean before and after tests PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT] +os.environ["FD_USE_MACHETE"] = "0" + def is_port_open(host: str, port: int, timeout=1.0): """ diff --git a/tests/layers/test_weight_only_linear.py b/tests/layers/test_weight_only_linear.py new file mode 100644 index 000000000..4d02592d7 --- /dev/null +++ b/tests/layers/test_weight_only_linear.py @@ -0,0 +1,218 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import unittest + +import numpy as np +import paddle +import paddle.device.cuda.graphs as graphs + +from fastdeploy.config import ( + CacheConfig, + FDConfig, + GraphOptimizationConfig, + LoadConfig, + ModelConfig, + ParallelConfig, +) +from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear +from fastdeploy.model_executor.layers.quantization.weight_only import ( + WINT4Config, + WINT8Config, +) +from fastdeploy.scheduler import SchedulerConfig + +paddle.set_default_dtype("bfloat16") +paddle.seed(1024) + + +class WeightOnlyLinearWrapper(paddle.nn.Layer): + def __init__( + self, + model_config: ModelConfig, + tp_size: int = 1, + prefix: str = "layer0", + quant_type: str = "wint4", + ): + super().__init__() + self.model_config = model_config + + self.tp_size = tp_size + self.prefix = prefix + self.fd_config = FDConfig( + model_config=self.model_config, + parallel_config=ParallelConfig({"tensor_parallel_size": self.tp_size}), + quant_config=WINT8Config({}) if quant_type == "wint8" else WINT4Config({}), + load_config=LoadConfig({}), + graph_opt_config=GraphOptimizationConfig({}), + scheduler_config=SchedulerConfig({}), + cache_config=CacheConfig({}), + ) + + self.fd_config.parallel_config.tp_group = None + + self.qkv_proj = QKVParallelLinear( + self.fd_config, + prefix=f"{prefix}.qkv_proj", + with_bias=False, + ) + + self.o_proj = RowParallelLinear( + self.fd_config, + prefix=f"{prefix}.o_proj", + input_size=self.fd_config.model_config.head_dim * self.fd_config.model_config.num_attention_heads, + output_size=self.fd_config.model_config.hidden_size, + ) + + qkv_proj_weight_shape = [ + self.qkv_proj.input_size, + self.qkv_proj.output_size, + ] + + o_proj_weight_shape = [ + self.o_proj.input_size, + self.o_proj.output_size, + ] + + state_dict = {} + state_dict[f"{prefix}.qkv_proj.weight"] = paddle.randn(qkv_proj_weight_shape, paddle.bfloat16) + state_dict[f"{prefix}.o_proj.weight"] = paddle.randn(o_proj_weight_shape, paddle.bfloat16) + self.qkv_proj.load_state_dict(state_dict) + self.o_proj.load_state_dict(state_dict) + + self.input_size = self.o_proj.input_size + self.output_size = self.qkv_proj.output_size + + def forward(self, x): + x = self.o_proj(x) + x = self.qkv_proj(x) + return x + + +class TestWeightOnlyLinear(unittest.TestCase): + def setUp(self) -> None: + self.model_name_or_path = None + self.model_config = self.build_model_config() + + def build_model_config(self) -> ModelConfig: + model_path = os.getenv("TEST_MODEL_PATH") + if model_path: + model_cofig_path = model_path + else: + model_cofig_path = self.build_config_json() + return ModelConfig( + { + "model": model_cofig_path, + "max_model_len": 2048, + } + ) + + def build_config_json(self) -> str: + config_dict = { + "architectures": ["Qwen3MoeForCausalLM"], + "hidden_size": 2048, + "head_dim": 128, + "num_attention_heads": 32, + "num_key_value_heads": 4, + "dtype": "bfloat16", + } + + tmp_dir = "./tmp_wint" + os.makedirs(tmp_dir, exist_ok=True) + with open(f"./{tmp_dir}/config.json", "w") as f: + json.dump(config_dict, f) + self.model_name_or_path = os.path.join(os.getcwd(), tmp_dir) + return self.model_name_or_path + + def run_wint_linear(self, type="qkv_proj", quant_type="wint4"): + weight_only_linear = WeightOnlyLinearWrapper(self.model_config, quant_type=quant_type) + if type == "qkv_proj": + input_size = weight_only_linear.qkv_proj.input_size + mm = weight_only_linear.qkv_proj + elif type == "o_proj": + input_size = weight_only_linear.o_proj.input_size + mm = weight_only_linear.o_proj + else: + input_size = weight_only_linear.input_size + mm = weight_only_linear + + print(type, quant_type) + print("{:<15} {:<40} {:<15}".format("Batch Size", "Last 5 Times (us)", "Last Time (us)")) + + linear_cuda_graphs = [None] * 100 + input = [None] * 100 + for idx, bsz in enumerate([10, 20, 40, 50, 60, 100, 200, 1000, 2000]): + + input[idx] = paddle.rand((bsz, input_size), dtype=paddle.bfloat16) + + num_warmups = 10 + for _ in range(num_warmups): + output = mm(input[idx]) + + num_layers = 10 + linear_cuda_graphs[idx] = graphs.CUDAGraph() + linear_cuda_graphs[idx].capture_begin() + + for _ in range(num_layers): + output = mm(input[idx]) + + linear_cuda_graphs[idx].capture_end() + + num_tests = 10 + start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)] + end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)] + for i in range(num_tests): + start_events[i].record() + + linear_cuda_graphs[idx].replay() + + end_events[i].record() + paddle.device.synchronize() + + times = np.array([round(s.elapsed_time(e), 2) for s, e in zip(start_events, end_events)])[1:] + times = times * 1e3 / num_layers + last_5_times = times[-5:] + last_time = times[-1] + print("{:<15} {:<40} {:<15}".format(bsz, str(last_5_times), last_time)) + return output + + def test_qkv_linear(self): + print("===============Test QKV Quantized Linear Layer================") + for use_machete in ["0", "1"]: + os.environ["FD_USE_MACHETE"] = use_machete + self.run_wint_linear("qkv_proj") + + def test_out_linear(self): + print("================Test OUT Quantized Linear Layer================") + for use_machete in ["0", "1"]: + os.environ["FD_USE_MACHETE"] = use_machete + self.run_wint_linear("o_proj") + + def test_both_linear(self): + print("===========Test both OUT and QKV Quantized Linear Layer=========") + for use_machete in ["0", "1"]: + os.environ["FD_USE_MACHETE"] = use_machete + self.run_wint_linear("out_proj+qkv_proj") + + def tearDown(self) -> None: + if self.model_name_or_path: + print("Remove tmp model config file") + shutil.rmtree(self.model_name_or_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/model_loader/test_model_cache.py b/tests/model_loader/test_model_cache.py index f347b26b7..ff924a6d0 100644 --- a/tests/model_loader/test_model_cache.py +++ b/tests/model_loader/test_model_cache.py @@ -22,6 +22,8 @@ project_root = os.path.abspath(os.path.join(current_dir, "..")) if project_root not in sys.path: sys.path.insert(0, project_root) +os.environ["FD_USE_MACHETE"] = "0" + from tests.model_loader.utils import ( check_tokens_id_and_text_close, form_model_get_output_topp0, diff --git a/tests/model_loader/test_torch_model.py b/tests/model_loader/test_torch_model.py index 841391b7a..e10c1376d 100644 --- a/tests/model_loader/test_torch_model.py +++ b/tests/model_loader/test_torch_model.py @@ -22,6 +22,8 @@ project_root = os.path.abspath(os.path.join(current_dir, "..")) if project_root not in sys.path: sys.path.insert(0, project_root) +os.environ["FD_USE_MACHETE"] = "0" + from tests.model_loader.utils import ( calculate_diff_rate, form_model_get_output_topp0,