delete useless code (#4544)

Co-authored-by: root <root@yqlcc01-sys-rpm12rzmwjd.yqlcc01.baidu.com>
YuanRisheng
2025-10-23 13:40:34 +08:00
committed by GitHub
parent 8a02ab43a8
commit ac4f5ca272
3 changed files with 0 additions and 1144 deletions

View File

@@ -17,16 +17,13 @@
import base64
import datetime
import hashlib
import io
import os
import threading
import uuid
from pathlib import Path
import numpy as np
import requests
from PIL import Image
from PIL.ExifTags import TAGS
RAW_VIDEO_DIR = "./download_tmp/raw_video/"
RAW_IMAGE_DIR = "./download_tmp/raw_images/"
@@ -110,155 +107,3 @@ def get_downloadable(
retry_interval=retry_interval,
)
return downloaded_path
def get_downloadable_image(download_path, need_exif_info, retry_max_time=0, retry_interval=3):
"""
get_downloadable with EXIF info and image processing
"""
def get_image_exif(image):
exif_data = image._getexif()
exif_info = {}
if exif_data is not None:
for tag, value in exif_data.items():
tag_name = TAGS.get(tag, tag)
exif_info[tag_name] = value.strip()
return exif_info
def has_transparent_background(img):
"""判断图片是否有背景"""
if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
# Check for any pixel with alpha channel less than 255 (fully opaque)
alpha = img.convert("RGBA").split()[-1]
if alpha.getextrema()[0] < 255:
return True
return False
def add_white_background(img):
"""
Add a white background to an image with a transparent background
"""
if img.mode != "RGBA":
img = img.convert("RGBA")
# Create a white-background image with the same size as the original
img_white_background = Image.new("RGBA", img.size, (255, 255, 255))
# Paste the original image onto the white background
img_white_background.paste(img, (0, 0), img)
return img_white_background
def change_I16_to_L(img):
"""
Convert the image from I;16 mode to L mode
"""
# The point function in I mode only supports add/subtract/multiply, so the * (1 / 256) below cannot be replaced with division
return img.point(lambda i: i * (1 / 256)).convert("L")
image = get_downloadable(
download_path,
save_to_disk=False,
retry=retry_max_time,
retry_interval=retry_interval,
)
if isinstance(image, Image.Image):
pil_image = image
else:
pil_image = Image.open(io.BytesIO(image))
if need_exif_info:
try:
exif_info = get_image_exif(pil_image)
except Exception:
exif_info = {}
else:
exif_info = {}
try:
if pil_image.mode == "I;16":
pil_image = change_I16_to_L(pil_image)
if has_transparent_background(pil_image):
pil_image = add_white_background(pil_image)
except Exception:
pass
return pil_image.convert("RGB"), exif_info
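For reference, the removed normalization above boils down to: convert I;16 to L, flatten a transparent background onto white, then convert to RGB. A minimal self-contained sketch of the flattening step (the image content and size are illustrative assumptions):
from PIL import Image

img = Image.new("RGBA", (64, 64), (255, 0, 0, 0))  # fully transparent red, illustrative only
alpha_min = img.convert("RGBA").split()[-1].getextrema()[0]
if alpha_min < 255:  # same transparency test as has_transparent_background
    white = Image.new("RGBA", img.size, (255, 255, 255, 255))
    white.paste(img, (0, 0), img)  # paste using the alpha channel as mask
    img = white
print(img.convert("RGB").mode)  # RGB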
def str2hash(url):
"""
Compute a SHA-256 hash from a str (e.g. a URL)
"""
return hashlib.sha256(url.encode()).hexdigest()
def pil2hash(pil):
"""
Compute a hash from a PIL.Image
"""
byte_io = io.BytesIO()
pil.save(byte_io, format="PNG")  # use a lossless format so compression does not affect the hash
image_bytes = byte_io.getvalue()
return hashlib.sha256(image_bytes).hexdigest()
def imagepath_to_base64(image_path):
"""imagepath_to_base64"""
image = Image.open(image_path).convert("RGB")
buffer = io.BytesIO()
image.save(buffer, format="JPEG")
image_bytes = buffer.getvalue()
base64_encoded = base64.b64encode(image_bytes).decode("utf-8")
return base64_encoded
def pil_image_to_base64(image):
"""pil_image_to_base64"""
buffer = io.BytesIO()
image.save(buffer, format="JPEG")
image_bytes = buffer.getvalue()
base64_encoded = base64.b64encode(image_bytes).decode("utf-8")
return base64_encoded
def http_to_pil_image(url):
"""http_to_pil_image"""
response = requests.get(url)
image_data = io.BytesIO(response.content)
pil_image = Image.open(image_data).convert("RGB")
return pil_image
def http_to_image_base64(url):
"""http_to_image_base64"""
response = requests.get(url)
image_data = io.BytesIO(response.content)
return base64.b64encode(image_data.getvalue()).decode("utf-8")
def base64_to_pil_image(base64_string):
""" " base64_to_pil_image"""
image_bytes = base64.b64decode(base64_string)
buffer = io.BytesIO(image_bytes)
image = Image.open(buffer)
return image
def get_hashable(to_be_hashed):
"""get hashable"""
if isinstance(to_be_hashed, bytes):
return to_be_hashed
elif isinstance(to_be_hashed, Image.Image):
return to_be_hashed.tobytes()
elif isinstance(to_be_hashed, str):
return to_be_hashed.encode("utf-8")
else:
raise ValueError(f"not support type: {type(to_be_hashed)}")
def load_dict_from_npz(npzfile):
"""从npz文件读取数据"""
with np.load(npzfile, allow_pickle=True) as data:
loaded_dict = {key: data[key] for key in data.files}
return loaded_dict
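As a sanity check on the removed base64 helpers, a self-contained round trip that repeats the same PIL/base64 calls (it does not import the deleted module; the toy image is an assumption):
import base64
import io

from PIL import Image

image = Image.new("RGB", (32, 32), (0, 128, 255))  # toy image
buffer = io.BytesIO()
image.save(buffer, format="JPEG")
encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")   # as in pil_image_to_base64
decoded = Image.open(io.BytesIO(base64.b64decode(encoded)))     # as in base64_to_pil_image
assert decoded.size == image.size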

View File

@@ -17,29 +17,8 @@
from __future__ import annotations
import enum
import hashlib
import json
import os
import random
import re
import struct
from functools import partial
from typing import NamedTuple, Optional
import numpy as np
import paddle
from paddle.common_ops_import import convert_dtype
from paddleformers.transformers.model_utils import _add_variant
from paddleformers.transformers.utils import paddleformers_load
from paddleformers.utils.env import (
PADDLE_WEIGHTS_INDEX_NAME,
SAFE_MASTER_WEIGHTS_INDEX_NAME,
SAFE_PEFT_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_INDEX_NAME,
)
from paddleformers.utils.log import logger
from tqdm import tqdm
MAX_BSZ = 512
MAX_DRAFT_TOKENS = 6
@@ -67,428 +46,3 @@ class WeightMeta(NamedTuple):
weight_name: str
is_column: bool
extra: Optional[str] = None
class UniqueIDGenerator:
"""
Generator for the exported model id
"""
def __init__(self):
pass
def generate_unique_id(self, state_dict):
"""
Generate a unique model id from the first parameter of the state dict plus a random suffix
"""
keys = state_dict.keys()
sorted_keys = sorted(keys)
first_key = sorted_keys[0]
first_parameter = state_dict[first_key].cast("float32")
# Assuming the model parameters are unique, use the first key's parameter to compute the md5sum
model_md5 = hashlib.md5(str(first_parameter.sum()).encode("utf-8")).hexdigest()
unique_id = f"{model_md5}-{random.randint(10000, 99999)}"
return unique_id
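What generate_unique_id did, reproduced inline as a hedged sketch with a toy state dict standing in for real weights (assumes a working paddle install):
import hashlib
import random

import paddle

state_dict = {"linear.weight": paddle.ones([4, 4], dtype="float32")}  # toy state dict
first_parameter = state_dict[sorted(state_dict.keys())[0]].cast("float32")
model_md5 = hashlib.md5(str(first_parameter.sum()).encode("utf-8")).hexdigest()
print(f"{model_md5}-{random.randint(10000, 99999)}")  # e.g. "<md5>-54321"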
def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
"""
This load is performed efficiently: each checkpoint shard is loaded into RAM one at a time and freed after being
loaded into the model.
Args:
folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
variant (`str`): The model variant.
"""
# Load the index
pdparams_file = os.path.join(folder, _add_variant("model_state.pdparams", variant))
lora_pdparams_file = os.path.join(folder, _add_variant("lora_model_state.pdparams", variant))
safetensors_file = os.path.join(folder, _add_variant("model.safetensors", variant))
if os.path.isfile(pdparams_file):
return paddle.load(pdparams_file, return_numpy=return_numpy)
if os.path.isfile(lora_pdparams_file):
return paddle.load(lora_pdparams_file, return_numpy=return_numpy)
if os.path.isfile(safetensors_file):
try:
from paddleformers.utils.safetensors import fast_load_file as safe_load_file
except ImportError:
from safetensors.numpy import load_file as safe_load_file
state_dict = safe_load_file(safetensors_file)
if not return_numpy:
for key in list(state_dict.keys()):
if isinstance(state_dict[key], np.ndarray):
state_dict[key] = paddle.Tensor(state_dict.pop(key), zero_copy=True)
return state_dict
index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant))
safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant))
index_present = os.path.isfile(index_file)
safe_index_present = os.path.isfile(safe_index_file)
safe_master_present = os.path.isfile(safe_master_file)
safe_peft_present = os.path.isfile(safe_peft_file)
load_safe = False
load_index = None
if safe_index_present:
load_safe = True # load safe due to preference
load_index = safe_index_file
elif safe_master_present:
load_safe = True
load_index = safe_master_file
elif index_present:
load_index = index_file
elif safe_peft_present:
load_safe = True
load_index = safe_peft_file
else:
raise ValueError(f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}")
if load_safe:
try:
from paddleformers.utils.safetensors import fast_load_file as safe_load_file
except ImportError:
from safetensors.numpy import load_file as safe_load_file
with open(load_index, "r", encoding="utf-8") as f:
index = json.load(f)
shard_files = list(set(index["weight_map"].values()))
loader = safe_load_file if load_safe else partial(paddleformers_load, map_location="np" if return_numpy else "cpu")
ret = {}
for shard_file in tqdm(shard_files):
state_dict = loader(os.path.join(folder, shard_file))
ret.update(state_dict)
if not return_numpy:
for key in list(ret.keys()):
if isinstance(ret[key], np.ndarray):
ret[key] = paddle.Tensor(ret.pop(key), zero_copy=True)
return ret
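A hedged usage sketch of load_sharded_checkpoint as defined above; the checkpoint directory is a placeholder and must already contain model_state.pdparams, model.safetensors, or an index file plus shards:
folder = "./checkpoints/my_model"  # placeholder path, not part of the original code
state_dict = load_sharded_checkpoint(folder, variant=None, return_numpy=True)
for name, value in list(state_dict.items())[:3]:
    print(name, value.shape)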
def convert_ndarray_dtype(np_array: np.ndarray, target_dtype: str) -> np.ndarray:
"""convert ndarray
Args:
np_array (np.ndarray): numpy ndarray instance
target_dtype (str): the target dtype
Returns:
np.ndarray: converted numpy ndarray instance
"""
source_dtype = convert_dtype(np_array.dtype)
if (
source_dtype == "uint16"
and target_dtype == "bfloat16"
and paddle.is_compiled_with_custom_device("iluvatar_gpu")
):
return np_array.view(dtype=target_dtype)
if source_dtype == "uint16" or target_dtype == "bfloat16":
if paddle.is_compiled_with_xpu():
# xpu not support bf16.
tensor = paddle.to_tensor(np_array, place=paddle.CPUPlace())
else:
tensor = paddle.to_tensor(np_array)
tensor = paddle.cast(tensor, target_dtype)
return tensor.numpy()
# TODO(wj-Mcat): device_guard will slow the converting
# with device_guard("cpu"):
# tensor = paddle.to_tensor(np_array)
# tensor = paddle.cast(tensor, target_dtype)
# return tensor.numpy()
if target_dtype == "bfloat16":
target_dtype = "uint16"
return np_array.astype(target_dtype)
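A short hedged example of convert_ndarray_dtype above, converting float32 to bfloat16 (paddle exposes bfloat16 arrays as uint16 on the numpy side):
import numpy as np

x = np.random.rand(2, 3).astype("float32")
y = convert_ndarray_dtype(x, "bfloat16")
print(y.dtype, y.shape)  # typically uint16 (numpy view of bfloat16), (2, 3)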
def set_seed(seed: int):
"""
set random seed for all random modules
"""
paddle.seed(seed)
random.seed(seed)
np.random.seed(seed)
def pad_batch_data(insts, pad_id=0, return_seq_len=False, pad_style="right"):
"""Pad the instances to the max sequence length in batch."""
# pad to the max input length in the batch
max_len = max(map(len, insts))
# pad to max input len
# max_len = args.max_len
if pad_style == "left":
inst_data = np.array([[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts])
else:
inst_data = np.array([list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts])
if return_seq_len:
seq_len = np.array([len(inst) for inst in insts])
return inst_data.astype("int64").reshape([-1, max_len]), seq_len
else:
return inst_data.astype("int64").reshape([-1, max_len])
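pad_batch_data above is a plain right/left padder; a quick illustration with toy token id lists:
insts = [[1, 2, 3], [4, 5], [6]]
padded, seq_len = pad_batch_data(insts, pad_id=0, return_seq_len=True)
# padded -> [[1 2 3] [4 5 0] [6 0 0]], seq_len -> [3 2 1]
left_padded = pad_batch_data(insts, pad_id=0, pad_style="left")
# left_padded -> [[1 2 3] [0 4 5] [0 0 6]]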
def load_prefix_weights(
prefix_path: str,
inference: bool = False,
batch_size: int = 1,
dtype: str = "bfloat16",
) -> np.ndarray | list[paddle.Tensor]:
"""load prefix weight by path
Args:
prefix_path (str): the path of prefix weight
"""
past_key_values = paddle.to_tensor(np.load(f"{prefix_path}/pre_caches.npy")).unsqueeze(2)
if batch_size > 1:
past_key_values = paddle.concat([past_key_values] * batch_size, axis=2)
# chatglm static model requires one tensor, otherwise a list of tensors
past_key_values = past_key_values.astype(dtype)
if inference:
return past_key_values.numpy()
return past_key_values
def w4a8_weight_convert(state_dict):
"""W4A8 权重转换函数
Args:
state_dict (dict): state_dict of model
"""
def w4_weight_squash(value, name, w4a8_weight_bites_name_map):
weight_dq = value
# The absmax of a W4 weight stored in W8 form is 112, so use +/-112 to determine the weight type
if weight_dq.max() == 112 or weight_dq.min() == -112:
weight_dq = weight_dq.cast("int8")
np_weight_dq = np.array(weight_dq, dtype="int8").view("uint8")
np_weight_dq_left_div_16 = (np_weight_dq / 16).astype("int8")
# weight_q = (weight_dq/16).cast('int8')
weight_q = paddle.to_tensor(np_weight_dq_left_div_16, dtype="int8")
logger.debug(f"int4 weight:{name}")
w4a8_weight_bites_name_map[name] = 4
return weight_q.cast("int8")
elif weight_dq.max() == 127 or weight_dq.min() == -128:
logger.debug(f"int8 weight:{name}")
w4a8_weight_bites_name_map[name] = 8
return weight_dq.cast("int8")
else:
logger.debug(f"fp16/bf16/float weight:{name}")
return weight_dq
w4a8_weight_bites_name_map = {}
for name, value in state_dict.items():
if value.dtype == "uint16":
weight_q = w4_weight_squash(
paddle.to_tensor(value).cast("float32"),
name,
w4a8_weight_bites_name_map,
)
state_dict[name] = weight_q.numpy() if weight_q is not None else value
del weight_q
w4a8_weight_bites_layers_map = {}
w4a8_weight_bites_layers_map["qkv_gemm_bits_map"] = []
w4a8_weight_bites_layers_map["out_gemm_bits_map"] = []
w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"] = []
w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"] = []
for name_keys, gemm_bits in w4a8_weight_bites_name_map.items():
if "qkv_proj" in name_keys:
w4a8_weight_bites_layers_map["qkv_gemm_bits_map"].append(gemm_bits)
elif "out_proj" in name_keys:
w4a8_weight_bites_layers_map["out_gemm_bits_map"].append(gemm_bits)
elif "linear1" in name_keys:
w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"].append(gemm_bits)
elif "linear2" in name_keys:
w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"].append(gemm_bits)
logger.debug(f"w4a8_weight_bites_layers_map:{w4a8_weight_bites_layers_map}")
return state_dict, w4a8_weight_bites_layers_map
def _vocab_size_with_padding(vocab_size, div_unit, mp_degree):
padded_size = vocab_size
multiple = div_unit * mp_degree
while (padded_size % multiple) != 0:
padded_size += 1
# logger.warning(
# " > padded vocab (size: {}) with {} dummy tokens "
# "(new size: {})".format(vocab_size, padded_size - vocab_size, padded_size)
# )
return padded_size
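_vocab_size_with_padding rounds the vocabulary size up to the next multiple of div_unit * mp_degree; for example:
print(_vocab_size_with_padding(50257, 128, 8))  # 51200, the next multiple of 1024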
def save_test_case(cases: list[list[dict]], file: str):
"""save test to result file
Args:
cases (list[list[dict]]): the content of case
file (str): the path of saved file
"""
with open(file, "w+", encoding="utf-8") as f:
for case in cases:
raw = json.dumps(case, ensure_ascii=False)
f.write(raw + "\n")
def infer_save_test_case(cases: list[list[dict]], file: str):
"""save test to result file
Args:
cases (list[list[dict]]): the content of case
file (str): the path of saved file
"""
with open(file, "a+", encoding="utf-8") as f:
for case in cases:
raw = json.dumps(case, ensure_ascii=False)
f.write(raw + "\n")
def deserialize_from_file(fp):
"""
deserialize a binary file into an array
"""
x_type = fp.read(1)
x_type_out = struct.unpack("c", x_type)[0]
# data
data_list = []
if x_type_out == b"0":
data = fp.read(4)
data_out = struct.unpack("f", data)[0]
while data:
data_out = struct.unpack("f", data)[0]
data_list.append(data_out)
data = fp.read(4)
elif x_type_out == b"1":
data = fp.read(8)
while data:
data_out = struct.unpack("l", data)[0]
data_list.append(data_out)
data = fp.read(8)
elif x_type_out == b"2":
data = fp.read(4)
while data:
data_out = struct.unpack("i", data)[0]
data_list.append(data_out)
data = fp.read(4)
else:
print("type error")
data_arr = np.array(data_list)
return data_arr
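deserialize_from_file above expects a one-byte type tag (b"0" float32, b"1" int64, b"2" int32) followed by packed records; a hedged round-trip sketch using an in-memory buffer:
import io
import struct

buf = io.BytesIO()
buf.write(b"2")                    # type tag: int32 records
for v in (7, 8, 9):
    buf.write(struct.pack("i", v))
buf.seek(0)
print(deserialize_from_file(buf))  # array([7, 8, 9])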
def calculate_effective_tokens(training_args, train_dataset, max_seq_len):
"""
Calculate the effective tokens during training.
"""
total_effective_tokens = 0
try:
data_parallel_degree = training_args.data_parallel_degree
except Exception:
data_parallel_degree = 1
if training_args.sharding_parallel_degree > 1:
sharding_parallel_degree = training_args.sharding_parallel_degree
else:
sharding_parallel_degree = 1
total_batch = (
training_args.max_steps
* training_args.per_device_train_batch_size
* training_args.gradient_accumulation_steps
* sharding_parallel_degree
* data_parallel_degree
)
for i, data in enumerate(train_dataset):
if i == total_batch:
break
for dd in data:
total_effective_tokens += len(dd.token_ids)
total_tokens = total_batch * max_seq_len
return total_effective_tokens, total_tokens
def parser_quant_type(quant_type):
"""
Parse the quantization type string and return the corresponding quantization types for weights,
activations, and cache.
Args:
quant_type (str): The quantization type string. It can be one of the following formats:
- "weight_only_int8" or "wint8": Only weights are quantized to int8.
- "weight_only_int4" or "wint4": Only weights are quantized to int4.
- A combined string such as "w4afp8c8", where the tokens following the 'w', 'a' and 'c' prefixes
give the quantization types for weights, activations, and cache respectively
(e.g., '4' for int4, '8' for int8, 'fp8' for 8-bit floating point).
If a prefix is missing, the default dtype is used for that component.
Returns:
tuple: A tuple of three strings with the quantization types for weights, activations,
and cache respectively.
If the input is "weight_only_int8" or "wint8", returns ("int8", default_type, cache_type).
If the input is "weight_only_int4" or "wint4", returns ("int4", default_type, cache_type).
For combined strings, returns the parsed quantization types based on the input format.
Raises:
AssertionError: If the custom quantization type string format is incorrect.
"""
default_type = paddle.get_default_dtype()
if quant_type == "default" or quant_type is None:
return default_type, default_type, default_type
conver_dict = {
"8": "int8",
"4": "int4",
"16": paddle.get_default_dtype(),
"fp8": "float8_e4m3fn",
"fp16": "float16",
"bf16": "bfloat16",
"fp32": "float32",
}
cache_type = default_type
if "c8" in quant_type:
cache_type = "int8"
elif "cfp8" in quant_type:
cache_type = "fp8"
elif "c4" in quant_type:
cache_type = "int4"
if "weight_only_int8" in quant_type or "wint8" in quant_type:
return "int8", default_type, cache_type
elif "weight_only_int4" in quant_type or "wint4" in quant_type:
return "int4", default_type, cache_type
else:
# split quant type, eg. w4afp8c8 -> ['w', '4', 'a', 'fp8', 'c', '8']
pattern = f"({'|'.join(map(re.escape, ['w', 'a', 'c']))})"
splited_type = re.split(pattern, quant_type)
splited_type = [tmp_type for tmp_type in splited_type if tmp_type]
assert len(splited_type) % 2 == 0 and len(splited_type) <= 6, f"Quant type[{quant_type}] format error."
quant_type_list = []
if "w" in splited_type:
w_idx = splited_type.index("w")
quant_type_list.append(conver_dict[splited_type[w_idx + 1]])
else:
quant_type_list.append(default_type)
if "a" in splited_type:
a_idx = splited_type.index("a")
quant_type_list.append(conver_dict[splited_type[a_idx + 1]])
else:
quant_type_list.append(default_type)
if "c" in splited_type:
c_idx = splited_type.index("c")
quant_type_list.append(conver_dict[splited_type[c_idx + 1]])
else:
quant_type_list.append(default_type)
return quant_type_list[0], quant_type_list[1], quant_type_list[2]
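A hedged usage sketch of parser_quant_type above (default_type is paddle.get_default_dtype(), typically "float16" or "bfloat16" depending on configuration):
print(parser_quant_type("wint8"))      # ('int8', default_type, default_type)
print(parser_quant_type("w4afp8c8"))   # ('int4', 'float8_e4m3fn', 'int8')
print(parser_quant_type("default"))    # (default_type, default_type, default_type)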

View File

@@ -1,543 +0,0 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
import triton.language as tl
from paddle import _C_ops
from paddle.base.framework import OpProtoHolder
from paddle.framework import in_dynamic_or_pir_mode
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
get_dtype_str,
paddle_use_triton,
rendering_common_template,
)
BLOCK_SIZE_M = 16
def invoke_fused_moe_kernel(
A,
B,
C,
B_scale,
B_super_scale,
B_code_scale,
B_code_zp,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
mul_routed_weight=False,
top_k=-1,
group_size=-1,
):
"""
Invoke Fused Moe Kernel
"""
KK = A.shape[-1]
NN = B.shape[-1]
sstride_am, sstride_ak = A.shape[1], 1
sstride_be, sstride_bk, sstride_bn = B.shape[1] * B.shape[2], B.shape[2], 1
sstride_cm, sstride_cn = C.shape[-1], 1
sstride_bse, sstride_bsk, sstride_bsn = (
B_scale.shape[1] * B_scale.shape[2],
B_scale.shape[2],
1,
)
sstride_bce, sstride_bck, sstride_bcn = B_code_scale.shape[1], 1, 1
ddouble_quant = B_super_scale is not None
prepare_attr_for_triton_kernel = """
auto N = B.shape()[2];
auto K = A.shape()[1];
auto EM = sorted_token_ids.shape()[0];
auto num_valid_tokens = (topk_ids.shape()[0]) * (topk_ids.shape()[1]);
auto stride_am = A.strides()[0];
auto stride_ak = A.strides()[1];
auto stride_be = B.strides()[0];
auto stride_bk = B.strides()[1];
auto stride_bn = B.strides()[2];
auto stride_cm = C.strides()[1];
auto stride_cn = C.strides()[2];
auto stride_bse = B_scale.strides()[0];
auto stride_bsk = B_scale.strides()[1];
auto stride_bsn = 1;
auto stride_bce = B_code_scale.strides()[0];
auto stride_bck = 1;
auto stride_bcn = 1;
auto double_quant = true;
"""
if mul_routed_weight:
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 8,
}
else:
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 512,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 12,
}
configs = []
configs.append(dict(config))
op_name = "wint2_moe_ffn"
op_name += f"{get_dtype_str(A.dtype)}"
op_name += f"{B.shape[0]}"
op_name += f"{B.shape[1]}"
op_name += f"{B.shape[2]}"
if op_name not in OpProtoHolder.instance().op_proto_map.keys():
prepare_ptr_for_triton_kernel = """
CUdeviceptr input_ptrs[11] = {
get_tensor_ptr(A),
get_tensor_ptr(B),
get_tensor_ptr(C),
get_tensor_ptr(B_scale),
get_tensor_ptr(B_super_scale),
get_tensor_ptr(B_code_scale),
get_tensor_ptr(B_code_zp),
get_tensor_ptr(topk_weights),
get_tensor_ptr(sorted_token_ids),
get_tensor_ptr(expert_ids),
get_tensor_ptr(num_tokens_post_padded),
};
"""
template_used = rendering_common_template(
invoke_fused_moe_kernel,
prepare_attr_for_triton_kernel,
prepare_ptr_for_triton_kernel,
)
grid = ("(EM+BLOCK_SIZE_M-1)/BLOCK_SIZE_M * ((N+BLOCK_SIZE_N-1)/BLOCK_SIZE_N)",)
moe_wint2_ffn_kernel[(op_name, template_used, grid, configs)](
A,
B,
C,
B_scale,
B_super_scale,
B_code_scale,
B_code_zp,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
NN,
KK,
-1, # EEM,
-1, # nnum_valid_tokens,
sstride_am,
sstride_ak,
sstride_be,
sstride_bk,
sstride_bn,
sstride_cm,
sstride_cn,
sstride_bse,
sstride_bsk,
sstride_bsn,
sstride_bce,
sstride_bck,
sstride_bcn,
MUL_ROUTED_WEIGHT=(int)(mul_routed_weight),
USE_DOUBLE_QUANT=(int)(ddouble_quant),
top_k=top_k,
BLOCK_SIZE_K=group_size,
)
if in_dynamic_or_pir_mode():
outs = _C_ops._run_custom_op(
op_name,
A,
B,
C,
B_scale,
B_super_scale,
B_code_scale,
B_code_zp,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
mul_routed_weight,
top_k,
group_size,
)
return outs[0]
@paddle_use_triton(
key=["1"],
)
def moe_wint2_ffn_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
bs_ptr,
superbs_ptr,
codebs_ptr,
codebzp_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_bse,
stride_bsk,
stride_bsn,
stride_bce,
stride_bck,
stride_bcn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
USE_DOUBLE_QUANT: tl.constexpr,
top_k: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
if USE_DOUBLE_QUANT:
# INT4 scale
s_packnums: tl.constexpr = 2
bzp: tl.constexpr = 32
w_mask: tl.constexpr = 0x3F
pack_num: tl.constexpr = 4
real_k_size: tl.constexpr = (BLOCK_SIZE_K - 1) // pack_num + 1
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
compute_type = c_ptr.dtype.element_ty
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
# offs_k = tl.arange(0, BLOCK_SIZE_K)
offs_bk = tl.arange(0, real_k_size)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_bk[None, :] * pack_num * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
bs_ptrs = bs_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn # group-wise, need advanced
off_set = off_experts * stride_bce + offs_bn[None, :] * stride_bcn
# load channel-wise scale & zero-point
if USE_DOUBLE_QUANT:
superbs_ptrs = superbs_ptr + off_set # channel-wise
super_bs = tl.load(superbs_ptrs) # super scale
codebs_ptrs = codebs_ptr + off_set # channel-wise
code_bs = tl.load(codebs_ptrs) # code scale
codebzp_ptrs = codebzp_ptr + off_set # channel-wise
code_bzp = tl.load(codebzp_ptrs) # code zp
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
b = tl.load(b_ptrs)
bs = tl.load(bs_ptrs)
if USE_DOUBLE_QUANT:
s_shift_bits = (1 - k % s_packnums) * 4
bs = ((bs >> s_shift_bits) & 0xF) * super_bs
# reverse to int16
b = tl.floor((b.to(tl.float32) * code_bs + code_bzp) + 0.5).to(tl.int16)
# dequant
b1 = (((b >> 9) & w_mask) - bzp) * bs
a = tl.load(
a_ptrs,
mask=token_mask[:, None],
other=0.0,
)
accumulator += tl.dot(a, b1.to(a.dtype))
b1 = (((b >> 6) & w_mask) - bzp) * bs
a = tl.load(
a_ptrs + 1,
mask=token_mask[:, None],
other=0.0,
)
accumulator += tl.dot(a, b1.to(a.dtype))
b1 = (((b >> 3) & w_mask) - bzp) * bs
a = tl.load(
a_ptrs + 2,
mask=token_mask[:, None],
other=0.0,
)
accumulator += tl.dot(a, b1.to(a.dtype))
b = ((b & w_mask) - bzp) * bs
a = tl.load(
a_ptrs + 3,
mask=token_mask[:, None],
other=0.0,
)
accumulator += tl.dot(a, b.to(a.dtype))
b_ptrs += real_k_size * stride_bk
a_ptrs += BLOCK_SIZE_K * stride_ak
# advance scale ptr
if USE_DOUBLE_QUANT:
bs_ptrs += stride_bsk * (k % s_packnums)
else:
bs_ptrs += stride_bsk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
def fused_moe_wint2_impl(
hidden_states,
up_gate_proj_quant_weight,
down_proj_quant_weight,
topk_weights,
topk_ids,
# inplace: bool = False,
up_gate_proj_weight_scale=None,
down_proj_weight_scale=None,
up_gate_proj_super_scales=None,
down_proj_super_scales=None,
up_gate_proj_code_scale=None,
down_proj_code_scale=None,
up_gate_proj_code_zp=None,
down_proj_code_zp=None,
group_size=64,
bit="wint2",
):
"""
Implementation of Fused MoE kernels on GPU.
"""
# Check constraints.
# A: [M, K]
# B: [E, K, N]
# assert hidden_states.shape[1] == up_gate_proj_weight_scale.shape[1],
# f"Hidden size mismatch, {hidden_states.shape[1]} != {up_gate_proj_quant_weight.shape[1]}"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert up_gate_proj_quant_weight.is_contiguous(), "Expert weights1 must be contiguous"
assert down_proj_quant_weight.is_contiguous(), "Expert weights2 must be contiguous"
assert group_size > 0, "Group size must be greater than 0"
num_tokens, K = hidden_states.shape
E, _, N = up_gate_proj_quant_weight.shape
M = num_tokens
if group_size < 0:
group_size = K // up_gate_proj_weight_scale.shape[1]
top_k = topk_ids.shape[1]
intermediate_cache1 = paddle.empty(
[M, top_k, N],
dtype=hidden_states.dtype,
)
intermediate_cache2 = paddle.empty(
(M * top_k, N // 2),
dtype=hidden_states.dtype,
)
intermediate_cache3 = paddle.empty(
(M, top_k, K),
dtype=hidden_states.dtype,
)
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(topk_ids, E, BLOCK_SIZE_M)
invoke_fused_moe_kernel(
A=hidden_states,
B=up_gate_proj_quant_weight,
C=intermediate_cache1,
B_scale=up_gate_proj_weight_scale,
B_super_scale=up_gate_proj_super_scales,
B_code_scale=up_gate_proj_code_scale,
B_code_zp=up_gate_proj_code_zp,
topk_weights=topk_weights,
topk_ids=topk_ids,
sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded,
mul_routed_weight=False,
top_k=top_k,
group_size=group_size,
)
intermediate_cache2 = paddle.incubate.nn.functional.swiglu(intermediate_cache1.reshape([-1, N]))
invoke_fused_moe_kernel(
A=intermediate_cache2,
B=down_proj_quant_weight,
C=intermediate_cache3,
B_scale=down_proj_weight_scale,
B_super_scale=down_proj_super_scales,
B_code_scale=down_proj_code_scale,
B_code_zp=down_proj_code_zp,
topk_weights=topk_weights,
topk_ids=topk_ids,
sorted_token_ids=sorted_token_ids,
expert_ids=expert_ids,
num_tokens_post_padded=num_tokens_post_padded,
mul_routed_weight=True,
top_k=1,
group_size=group_size,
)
out_hidden_states = paddle.sum(intermediate_cache3, axis=1)
return out_hidden_states
def fused_moe_wint2_triton(
hidden_states,
up_gate_proj_quant_weight,
down_proj_quant_weight,
scores,
gate_correction_bias,
topk,
up_gate_proj_weight_scale,
down_proj_weight_scale,
up_gate_proj_super_scales,
down_proj_super_scales,
up_gate_proj_code_scale,
down_proj_code_scale,
up_gate_proj_code_zp,
down_proj_code_zp,
):
"""
Fuse MoE with WINT2 quantization scheme and Triton backend.
Args:
hidden_states: input tensor.
up_gate_proj_quant_weight: up_gate_proj weight matrix for experts.
down_proj_quant_weight: down_proj weight matrix for experts.
scores: gate scores.
gate_correction_bias: bias correction for gates.
topk: number of experts to use.
up_gate_proj_weight_scale: scaling factor for up_gate_proj_quant_weight.
down_proj_weight_scale: scaling factor for down_proj_quant_weight.
up_gate_proj_super_scales: super scaling factor for up_gate_proj_scale.
down_proj_super_scales: super scaling factor for down_proj_weight_scale.
up_gate_proj_code_scale: code scaling factor for up_gate_proj_quant_weight.
down_proj_code_scale: code scaling factor for down_proj_quant_weight.
up_gate_proj_code_zp: code zero point for up_gate_proj_quant_weight.
down_proj_code_zp: code zero point for down_proj_quant_weight.
Returns:
output tensor.
"""
score = gate_correction_bias + scores
_, topk_ids = paddle.topk(score, k=topk, axis=-1)
topk_weights, _ = paddle.topk(scores, k=topk, axis=-1)
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
return fused_moe_wint2_impl(
hidden_states,
up_gate_proj_quant_weight,
down_proj_quant_weight,
topk_weights,
topk_ids,
up_gate_proj_weight_scale,
down_proj_weight_scale,
up_gate_proj_super_scales,
down_proj_super_scales,
up_gate_proj_code_scale,
down_proj_code_scale,
up_gate_proj_code_zp,
down_proj_code_zp,
bit="wint2",
)