# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import numpy as np
import paddle

from fastdeploy.platforms import current_platform


def get_sm_version():
    """Return the CUDA compute capability of the current device as an integer (e.g. 90 for SM 9.0)."""
    prop = paddle.device.cuda.get_device_properties()
    cc = prop.major * 10 + prop.minor
    return cc


# The Machete kernels are only built for SM 9.0 (Hopper) GPUs.
if current_platform.is_cuda() and get_sm_version() == 90:
    from fastdeploy.model_executor.ops.gpu import machete_mm, machete_prepack_B


def get_pack_factor(num_bits):
    """Number of quantized values packed into a single 32-bit word."""
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def pack_rows(
    q_w: paddle.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack every `32 // num_bits` consecutive rows of `q_w` into one row of 32-bit words."""
    assert q_w.shape == [size_k, size_n]

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    orig_device = q_w.place
    q_w_np = q_w.numpy().astype(np.uint32)

    q_res = np.zeros((size_k // pack_factor, size_n), dtype=np.uint32)

    for i in range(pack_factor):
        q_res |= q_w_np[i::pack_factor, :] << num_bits * i

    q_res = paddle.to_tensor(q_res.astype(np.int32), place=orig_device)
    return q_res

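# Worked example: with num_bits = 4 the pack factor is 8, so rows 0..7 of each
# column fold into a single 32-bit word, where row i contributes (value << 4 * i).
# A column slice [1, 2, 3, 4, 5, 6, 7, 8] therefore packs into 0x87654321.
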
def quantize_weights(
    w: paddle.Tensor,
    group_size: Optional[int],
    quant_type: str = "uint4b8",
):
    """
    Quantize weights in PaddlePaddle, similar to the PyTorch implementation.

    Args:
        w: Input weight tensor of shape [size_k, size_n] (must be a float type).
        group_size: Group size for quantization along `size_k`. If `-1`, use
            channel-wise quantization (one group covering all of `size_k`).
        quant_type: Target quantization type (`uint4` or `uint4b8`).

    Returns:
        w_q: Quantized weights (int32 values in the 4-bit range).
        w_s: Scales (a scalar placeholder of 1.0 if `group_size` is None).
    """
    assert paddle.is_floating_point(w), "w must be float type"
    assert quant_type in ["uint4", "uint4b8"], "only support quant_type = uint4, uint4b8"

    orig_device = w.place
    size_k, size_n = w.shape

    if group_size == -1:
        group_size = size_k

    # Reshape to [group_size, -1] so that each column is one quantization group
    if group_size is not None and group_size < size_k:
        w = w.reshape([-1, group_size, size_n])
        w = w.transpose([1, 0, 2])
        w = w.reshape([group_size, -1])

    # Compute scale for each group
    max_val = paddle.max(w, axis=0, keepdim=True)
    min_val = paddle.min(w, axis=0, keepdim=True)

    # Signed 4-bit range before the uint4b8 bias is applied
    max_q_val = float(7.0)
    min_q_val = float(-8.0)

    w_s = paddle.ones([1], dtype=paddle.float32)  # unscaled case

    if group_size is not None:
        # Avoid division by zero
        max_scale = paddle.maximum(
            paddle.abs(max_val / (max_q_val if max_q_val != 0 else float("inf"))),
            paddle.abs(min_val / (min_q_val if min_q_val != 0 else float("inf"))),
        )
        w_s = max_scale

    # Quantize
    w_q = paddle.round(w / w_s).astype(paddle.int32)
    w_q = paddle.clip(w_q, min_q_val, max_q_val)

    # uint4b8 stores the signed values with a bias of 8, mapping [-8, 7] to [0, 15]
    if quant_type == "uint4b8":
        w_q += 8

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w_tensor):
            w_tensor = w_tensor.reshape([group_size, -1, size_n])
            w_tensor = w_tensor.transpose([1, 0, 2])
            w_tensor = w_tensor.reshape([size_k, size_n])
            return w_tensor

        w_q = reshape_w(w_q)
        w_s = w_s.reshape([-1, size_n])

    # Move tensors back to original device
    w_q = w_q.to(orig_device)
    if w_s is not None:
        w_s = w_s.to(orig_device)

    return w_q, w_s

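# Worked example: for a quantization group spanning [-1.0, 0.5], the symmetric
# scale is max(0.5 / 7, 1.0 / 8) = 0.125, so -1.0 maps to round(-1.0 / 0.125) = -8
# and is stored as 0 after the uint4b8 bias of +8, while 0.5 maps to 4 and is
# stored as 12.
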
def machete_quantize_and_pack(
    w: paddle.Tensor,
    atype: paddle.dtype,
    quant_type: str = "uint4b8",
    scale_type: str = "",
    group_size: int = -1,
):
    """Quantize `w` to 4-bit, pack it along rows, and pre-pack it for the Machete kernel."""
    w_q, w_s = quantize_weights(w, group_size, quant_type=quant_type)
    w_q = pack_rows(w_q, 4, *w_q.shape)
    w_q_col = w_q.transpose([1, 0]).contiguous()  # convert to col major
    w_q_prepack = machete_prepack_B(
        w_q_col,
        atype,
        quant_type,
        scale_type,
    )[0]
    return w_q_prepack, w_s


def machete_wint_mm(
    x: paddle.Tensor,
    w_prepack: paddle.Tensor,
    w_g_s: paddle.Tensor,
    w_g_zp: Optional[paddle.Tensor] = None,
    w_ch_s: Optional[paddle.Tensor] = None,
    w_tok_s: Optional[paddle.Tensor] = None,
    weight_dtype: str = "uint4b8",
    group_size: int = -1,
    out_dtype: str = "",
    scheduler: str = "",
):
    """Weight-only integer matmul of `x` with the pre-packed Machete weight."""
    out = machete_mm(
        x,
        w_prepack,
        w_g_s,  # group scales
        w_g_zp,  # group zeros
        w_ch_s,  # per-channel scale
        w_tok_s,  # per-token scale
        weight_dtype,  # weight_dtype
        out_dtype,  # out_dtype
        group_size,  # group_size
        scheduler,  # scheduler
    )[0]
    return out
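

# Minimal usage sketch. It assumes an SM 9.0 (Hopper) GPU with the FastDeploy
# Machete ops available; the shapes, dtypes, and group size below are
# illustrative choices, not values prescribed by this module.
if __name__ == "__main__":
    size_k, size_n = 4096, 8192
    group_size = 128

    x = paddle.randn([16, size_k]).astype("bfloat16")
    w = paddle.randn([size_k, size_n]).astype("bfloat16")

    # Quantize to uint4b8, pack rows, and pre-pack the weight for the Machete kernel.
    w_prepack, w_scales = machete_quantize_and_pack(
        w,
        atype=x.dtype,
        quant_type="uint4b8",
        group_size=group_size,
    )

    # Weight-only int4 GEMM with group-wise scales.
    out = machete_wint_mm(
        x,
        w_prepack,
        w_scales,
        weight_dtype="uint4b8",
        group_size=group_size,
    )
    print(out.shape)  # expected shape: [16, 8192]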