Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-11-01 12:22:53 +08:00
delete useless code (#4544)
Co-authored-by: root <root@yqlcc01-sys-rpm12rzmwjd.yqlcc01.baidu.com>
@@ -17,16 +17,13 @@
import base64
import datetime
import hashlib
import io
import os
import threading
import uuid
from pathlib import Path

import numpy as np
import requests
from PIL import Image
from PIL.ExifTags import TAGS

RAW_VIDEO_DIR = "./download_tmp/raw_video/"
RAW_IMAGE_DIR = "./download_tmp/raw_images/"
@@ -110,155 +107,3 @@ def get_downloadable(
        retry_interval=retry_interval,
    )
    return downloaded_path


def get_downloadable_image(download_path, need_exif_info, retry_max_time=0, retry_interval=3):
    """
    get_downloadable with EXIF extraction and image post-processing.
    """

    def get_image_exif(image):
        exif_data = image._getexif()
        exif_info = {}
        if exif_data is not None:
            for tag, value in exif_data.items():
                tag_name = TAGS.get(tag, tag)
                exif_info[tag_name] = value.strip()
        return exif_info

    def has_transparent_background(img):
        """Check whether the image has a transparent background."""
        if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
            # Check for any pixel with alpha channel less than 255 (fully opaque)
            alpha = img.convert("RGBA").split()[-1]
            if alpha.getextrema()[0] < 255:
                return True
        return False

    def add_white_background(img):
        """
        Add a white background to an image with a transparent background.
        """
        if img.mode != "RGBA":
            img = img.convert("RGBA")
        # Create a white-background image with the same size as the original
        img_white_background = Image.new("RGBA", img.size, (255, 255, 255))

        # Paste the original image onto the white background
        img_white_background.paste(img, (0, 0), img)

        return img_white_background

    def change_I16_to_L(img):
        """
        Convert the image from I;16 mode to L mode.
        """
        # The point() function in I mode only supports add/subtract/multiply,
        # so the `* (1 / 256)` below must not be rewritten as a division.
        return img.point(lambda i: i * (1 / 256)).convert("L")

    image = get_downloadable(
        download_path,
        save_to_disk=False,
        retry=retry_max_time,
        retry_interval=retry_interval,
    )
    if isinstance(image, Image.Image):
        pil_image = image
    else:
        pil_image = Image.open(io.BytesIO(image))
    if need_exif_info:
        try:
            exif_info = get_image_exif(pil_image)
        except Exception:
            exif_info = {}
    else:
        exif_info = {}

    try:
        if pil_image.mode == "I;16":
            pil_image = change_I16_to_L(pil_image)
        if has_transparent_background(pil_image):
            pil_image = add_white_background(pil_image)
    except Exception:
        pass

    return pil_image.convert("RGB"), exif_info
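
# Illustration of the white-background flattening performed by add_white_background():
#
#     rgba = Image.new("RGBA", (4, 4), (255, 0, 0, 0))           # fully transparent red
#     white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
#     white.paste(rgba, (0, 0), rgba)                            # the alpha channel acts as the paste mask
#     white.convert("RGB").getpixel((0, 0))                      # -> (255, 255, 255)
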
def str2hash(url):
    """
    Compute the sha256 hash of a string.
    """
    return hashlib.sha256(url.encode()).hexdigest()


def pil2hash(pil):
    """
    Compute the sha256 hash of a PIL.Image.
    """
    byte_io = io.BytesIO()
    pil.save(byte_io, format="PNG")  # use a lossless format so compression does not affect the hash
    image_bytes = byte_io.getvalue()

    return hashlib.sha256(image_bytes).hexdigest()


def imagepath_to_base64(image_path):
    """imagepath_to_base64"""
    image = Image.open(image_path).convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    image_bytes = buffer.getvalue()
    base64_encoded = base64.b64encode(image_bytes).decode("utf-8")
    return base64_encoded


def pil_image_to_base64(image):
    """pil_image_to_base64"""
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    image_bytes = buffer.getvalue()
    base64_encoded = base64.b64encode(image_bytes).decode("utf-8")
    return base64_encoded


def http_to_pil_image(url):
    """http_to_pil_image"""
    response = requests.get(url)
    image_data = io.BytesIO(response.content)
    pil_image = Image.open(image_data).convert("RGB")
    return pil_image


def http_to_image_base64(url):
    """http_to_image_base64"""
    response = requests.get(url)
    image_data = io.BytesIO(response.content)
    return base64.b64encode(image_data.getvalue()).decode("utf-8")


def base64_to_pil_image(base64_string):
    """base64_to_pil_image"""
    image_bytes = base64.b64decode(base64_string)
    buffer = io.BytesIO(image_bytes)
    image = Image.open(buffer)
    return image


def get_hashable(to_be_hashed):
    """get hashable"""
    if isinstance(to_be_hashed, bytes):
        return to_be_hashed
    elif isinstance(to_be_hashed, Image.Image):
        return to_be_hashed.tobytes()
    elif isinstance(to_be_hashed, str):
        return to_be_hashed.encode("utf-8")
    else:
        raise ValueError(f"unsupported type: {type(to_be_hashed)}")


def load_dict_from_npz(npzfile):
    """Load a dict of arrays from an npz file."""
    with np.load(npzfile, allow_pickle=True) as data:
        loaded_dict = {key: data[key] for key in data.files}
    return loaded_dict
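
# Round-trip illustration for the helpers above (the image here is a synthetic placeholder):
#
#     img = Image.new("RGB", (8, 8), (0, 128, 255))
#     b64 = pil_image_to_base64(img)        # JPEG bytes, base64-encoded as str
#     restored = base64_to_pil_image(b64)   # back to a PIL.Image
#     restored.size                         # -> (8, 8)
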
@@ -17,29 +17,8 @@
from __future__ import annotations

import enum
import hashlib
import json
import os
import random
import re
import struct
from functools import partial
from typing import NamedTuple, Optional

import numpy as np
import paddle
from paddle.common_ops_import import convert_dtype
from paddleformers.transformers.model_utils import _add_variant
from paddleformers.transformers.utils import paddleformers_load
from paddleformers.utils.env import (
    PADDLE_WEIGHTS_INDEX_NAME,
    SAFE_MASTER_WEIGHTS_INDEX_NAME,
    SAFE_PEFT_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
)
from paddleformers.utils.log import logger
from tqdm import tqdm

MAX_BSZ = 512
MAX_DRAFT_TOKENS = 6
@@ -67,428 +46,3 @@ class WeightMeta(NamedTuple):
    weight_name: str
    is_column: bool
    extra: Optional[str] = None


class UniqueIDGenerator:
    """
    The generator for the exported model id.
    """

    def __init__(self):
        pass

    def generate_unique_id(self, state_dict):
        """
        Generate a unique model id from the state dict.
        """
        keys = state_dict.keys()
        sorted_keys = sorted(keys)
        first_key = sorted_keys[0]
        first_parameter = state_dict[first_key].cast("float32")
        # Assume the model parameters are unique and derive the md5sum from the first key.
        model_md5 = hashlib.md5(str(first_parameter.sum()).encode("utf-8")).hexdigest()
        unique_id = f"{model_md5}-{random.randint(10000, 99999)}"
        return unique_id
def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
    """
    Load a (possibly sharded) checkpoint from a folder.

    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
    loaded in the model.

    Args:
        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
        variant (`str`): The model variant.

    """
    # Load the index
    pdparams_file = os.path.join(folder, _add_variant("model_state.pdparams", variant))
    lora_pdparams_file = os.path.join(folder, _add_variant("lora_model_state.pdparams", variant))
    safetensors_file = os.path.join(folder, _add_variant("model.safetensors", variant))
    if os.path.isfile(pdparams_file):
        return paddle.load(pdparams_file, return_numpy=return_numpy)
    if os.path.isfile(lora_pdparams_file):
        return paddle.load(lora_pdparams_file, return_numpy=return_numpy)
    if os.path.isfile(safetensors_file):
        try:
            from paddleformers.utils.safetensors import fast_load_file as safe_load_file
        except ImportError:
            from safetensors.numpy import load_file as safe_load_file

        state_dict = safe_load_file(safetensors_file)
        if not return_numpy:
            for key in list(state_dict.keys()):
                if isinstance(state_dict[key], np.ndarray):
                    state_dict[key] = paddle.Tensor(state_dict.pop(key), zero_copy=True)
        return state_dict

    index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
    safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
    safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant))
    safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant))

    index_present = os.path.isfile(index_file)
    safe_index_present = os.path.isfile(safe_index_file)
    safe_master_present = os.path.isfile(safe_master_file)
    safe_peft_present = os.path.isfile(safe_peft_file)

    load_safe = False
    load_index = None
    if safe_index_present:
        load_safe = True  # load safe due to preference
        load_index = safe_index_file
    elif safe_master_present:
        load_safe = True
        load_index = safe_master_file
    elif index_present:
        load_index = index_file
    elif safe_peft_present:
        load_safe = True
        load_index = safe_peft_file
    else:
        raise ValueError(f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}")

    if load_safe:
        try:
            from paddleformers.utils.safetensors import fast_load_file as safe_load_file
        except ImportError:
            from safetensors.numpy import load_file as safe_load_file

    with open(load_index, "r", encoding="utf-8") as f:
        index = json.load(f)

    shard_files = list(set(index["weight_map"].values()))
    loader = safe_load_file if load_safe else partial(paddleformers_load, map_location="np" if return_numpy else "cpu")

    ret = {}
    for shard_file in tqdm(shard_files):
        state_dict = loader(os.path.join(folder, shard_file))
        ret.update(state_dict)

    if not return_numpy:
        for key in list(ret.keys()):
            if isinstance(ret[key], np.ndarray):
                ret[key] = paddle.Tensor(ret.pop(key), zero_copy=True)

    return ret
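
# For reference, the index consumed above is a JSON file whose "weight_map" maps parameter
# names to shard files; the entries below are hypothetical:
#
#     {
#         "metadata": {"total_size": 16777216},
#         "weight_map": {
#             "embed_tokens.weight": "model-00001-of-00002.safetensors",
#             "lm_head.weight": "model-00002-of-00002.safetensors"
#         }
#     }
#
# load_sharded_checkpoint() loads each file in set(weight_map.values()) once and merges the
# per-shard state dicts into a single dict.
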
def convert_ndarray_dtype(np_array: np.ndarray, target_dtype: str) -> np.ndarray:
    """Convert an ndarray to the target dtype.

    Args:
        np_array (np.ndarray): numpy ndarray instance
        target_dtype (str): the target dtype

    Returns:
        np.ndarray: converted numpy ndarray instance
    """
    source_dtype = convert_dtype(np_array.dtype)
    if (
        source_dtype == "uint16"
        and target_dtype == "bfloat16"
        and paddle.is_compiled_with_custom_device("iluvatar_gpu")
    ):
        return np_array.view(dtype=target_dtype)
    if source_dtype == "uint16" or target_dtype == "bfloat16":
        if paddle.is_compiled_with_xpu():
            # XPU does not support bf16, so cast on CPU.
            tensor = paddle.to_tensor(np_array, place=paddle.CPUPlace())
        else:
            tensor = paddle.to_tensor(np_array)
        tensor = paddle.cast(tensor, target_dtype)
        return tensor.numpy()

        # TODO(wj-Mcat): device_guard will slow down the conversion
        # with device_guard("cpu"):
        #     tensor = paddle.to_tensor(np_array)
        #     tensor = paddle.cast(tensor, target_dtype)
        #     return tensor.numpy()

    if target_dtype == "bfloat16":
        target_dtype = "uint16"

    return np_array.astype(target_dtype)
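
# Usage sketch (requires a Paddle install): bfloat16 has no native numpy dtype, so Paddle
# exposes it as uint16 on the numpy side, which is why the uint16/bfloat16 pair is special-cased.
#
#     x = np.arange(4, dtype="float32")
#     y = convert_ndarray_dtype(x, "bfloat16")   # bfloat16 payload, carried in a uint16-typed array
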
def set_seed(seed: int):
    """
    Set the random seed for all random modules.
    """
    paddle.seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def pad_batch_data(insts, pad_id=0, return_seq_len=False, pad_style="right"):
    """Pad the instances to the max sequence length in batch."""
    # pad to the max input length in the batch
    max_len = max(map(len, insts))
    # pad to max input len
    # max_len = args.max_len
    if pad_style == "left":
        inst_data = np.array([[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts])
    else:
        inst_data = np.array([list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts])
    if return_seq_len:
        seq_len = np.array([len(inst) for inst in insts])
        return inst_data.astype("int64").reshape([-1, max_len]), seq_len
    else:
        return inst_data.astype("int64").reshape([-1, max_len])
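
# Concrete padding example (right padding is the default):
#
#     batch = [[1, 2, 3], [4, 5]]
#     pad_batch_data(batch, pad_id=0, return_seq_len=True)
#     # -> (array([[1, 2, 3], [4, 5, 0]]), array([3, 2]))
#     pad_batch_data(batch, pad_id=0, pad_style="left")
#     # -> array([[1, 2, 3], [0, 4, 5]])
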
def load_prefix_weights(
    prefix_path: str,
    inference: bool = False,
    batch_size: int = 1,
    dtype: str = "bfloat16",
) -> np.ndarray | list[paddle.Tensor]:
    """Load prefix weights from a path.

    Args:
        prefix_path (str): the path of the prefix weights
    """
    past_key_values = paddle.to_tensor(np.load(f"{prefix_path}/pre_caches.npy")).unsqueeze(2)

    if batch_size > 1:
        past_key_values = paddle.concat([past_key_values] * batch_size, axis=2)

    # chatglm static models require a single tensor, otherwise a list of tensors
    past_key_values = past_key_values.astype(dtype)
    if inference:
        return past_key_values.numpy()
    return past_key_values
def w4a8_weight_convert(state_dict):
    """W4A8 weight conversion.
    Args:
        state_dict (dict): state_dict of model
    """

    def w4_weight_squash(value, name, w4a8_weight_bites_name_map):
        weight_dq = value
        # W4 weights stored in the W8 representation have an absmax of 112,
        # so +/-112 is used to detect the weight type.
        if weight_dq.max() == 112 or weight_dq.min() == -112:
            weight_dq = weight_dq.cast("int8")
            np_weight_dq = np.array(weight_dq, dtype="int8").view("uint8")
            np_weight_dq_left_div_16 = (np_weight_dq / 16).astype("int8")
            # weight_q = (weight_dq/16).cast('int8')
            weight_q = paddle.to_tensor(np_weight_dq_left_div_16, dtype="int8")
            logger.debug(f"int4 weight:{name}")
            w4a8_weight_bites_name_map[name] = 4
            return weight_q.cast("int8")
        elif weight_dq.max() == 127 or weight_dq.min() == -128:
            logger.debug(f"int8 weight:{name}")
            w4a8_weight_bites_name_map[name] = 8
            return weight_dq.cast("int8")
        else:
            logger.debug(f"fp16/bf16/float weight:{name}")
            return weight_dq

    w4a8_weight_bites_name_map = {}
    for name, value in state_dict.items():
        if value.dtype == "uint16":
            weight_q = w4_weight_squash(
                paddle.to_tensor(value).cast("float32"),
                name,
                w4a8_weight_bites_name_map,
            )
            state_dict[name] = weight_q.numpy() if weight_q is not None else value
            del weight_q
    w4a8_weight_bites_layers_map = {}
    w4a8_weight_bites_layers_map["qkv_gemm_bits_map"] = []
    w4a8_weight_bites_layers_map["out_gemm_bits_map"] = []
    w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"] = []
    w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"] = []
    for name_keys, gemm_bits in w4a8_weight_bites_name_map.items():
        if "qkv_proj" in name_keys:
            w4a8_weight_bites_layers_map["qkv_gemm_bits_map"].append(gemm_bits)
        elif "out_proj" in name_keys:
            w4a8_weight_bites_layers_map["out_gemm_bits_map"].append(gemm_bits)
        elif "linear1" in name_keys:
            w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"].append(gemm_bits)
        elif "linear2" in name_keys:
            w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"].append(gemm_bits)
    logger.debug(f"w4a8_weight_bites_layers_map:{w4a8_weight_bites_layers_map}")
    return state_dict, w4a8_weight_bites_layers_map
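
# Why +/-112: an int4 code v in [-7, 7] stored in the int8 "W8 representation" is v * 16,
# so its absolute maximum is 7 * 16 = 112, while a genuine int8 weight reaches 127 / -128.
# The squash keeps the low nibble, i.e. the two's-complement 4-bit pattern of v:
#
#     w4_as_w8 = np.array([-112, -16, 0, 16, 112], dtype="int8")
#     (w4_as_w8.view("uint8") / 16).astype("int8")   # -> [9, 15, 0, 1, 7]
#                                                    #    (4-bit patterns of -7, -1, 0, 1, 7)
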
def _vocab_size_with_padding(vocab_size, div_unit, mp_degree):
    padded_size = vocab_size
    multiple = div_unit * mp_degree
    while (padded_size % multiple) != 0:
        padded_size += 1
    # logger.warning(
    #     " > padded vocab (size: {}) with {} dummy tokens "
    #     "(new size: {})".format(vocab_size, padded_size - vocab_size, padded_size)
    # )
    return padded_size
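
# Example: with div_unit=128 and mp_degree=8 the vocabulary is padded to a multiple of 1024,
# so _vocab_size_with_padding(50257, 128, 8) returns 51200 (943 dummy tokens added).
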
def save_test_case(cases: list[list[dict]], file: str):
    """Save test cases to a result file.

    Args:
        cases (list[list[dict]]): the content of the cases
        file (str): the path of the saved file
    """
    with open(file, "w+", encoding="utf-8") as f:
        for case in cases:
            raw = json.dumps(case, ensure_ascii=False)
            f.write(raw + "\n")


def infer_save_test_case(cases: list[list[dict]], file: str):
    """Append test cases to a result file.

    Args:
        cases (list[list[dict]]): the content of the cases
        file (str): the path of the saved file
    """
    with open(file, "a+", encoding="utf-8") as f:
        for case in cases:
            raw = json.dumps(case, ensure_ascii=False)
            f.write(raw + "\n")
def deserialize_from_file(fp):
    """
    Deserialize a binary file into an array.
    """

    x_type = fp.read(1)
    x_type_out = struct.unpack("c", x_type)[0]
    # data
    data_list = []
    if x_type_out == b"0":
        data = fp.read(4)
        data_out = struct.unpack("f", data)[0]
        while data:
            data_out = struct.unpack("f", data)[0]
            data_list.append(data_out)
            data = fp.read(4)
    elif x_type_out == b"1":
        data = fp.read(8)
        while data:
            data_out = struct.unpack("l", data)[0]
            data_list.append(data_out)
            data = fp.read(8)
    elif x_type_out == b"2":
        data = fp.read(4)
        while data:
            data_out = struct.unpack("i", data)[0]
            data_list.append(data_out)
            data = fp.read(4)
    else:
        print("type error")
    data_arr = np.array(data_list)
    return data_arr
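
# The reader above implies a simple framing: one type byte (b"0" float32, b"1" long, b"2" int32)
# followed by packed values. A matching writer could look like the sketch below; it is a
# hypothetical helper, not part of the original module, and uses native struct sizing
# (so "l" is 8 bytes on 64-bit Linux, matching the fp.read(8) above).
def serialize_to_file(fp, values, kind=b"0"):
    """Sketch of a writer producing the format deserialize_from_file() expects."""
    fmt = {b"0": "f", b"1": "l", b"2": "i"}[kind]
    fp.write(struct.pack("c", kind))
    for v in values:
        fp.write(struct.pack(fmt, v))
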
def calculate_effective_tokens(training_args, train_dataset, max_seq_len):
    """
    Calculate the effective tokens during training.
    """
    total_effective_tokens = 0
    try:
        data_parallel_degree = training_args.data_parallel_degree
    except Exception:
        data_parallel_degree = 1
    if training_args.sharding_parallel_degree > 1:
        sharding_parallel_degree = training_args.sharding_parallel_degree
    else:
        sharding_parallel_degree = 1

    total_batch = (
        training_args.max_steps
        * training_args.per_device_train_batch_size
        * training_args.gradient_accumulation_steps
        * sharding_parallel_degree
        * data_parallel_degree
    )
    for i, data in enumerate(train_dataset):
        if i == total_batch:
            break
        for dd in data:
            total_effective_tokens += len(dd.token_ids)
    total_tokens = total_batch * max_seq_len

    return total_effective_tokens, total_tokens
def parser_quant_type(quant_type):
    """
    Parse the quantization type string and return the corresponding quantization types for weights,
    activations, and cache.

    Args:
        quant_type (str): The quantization type string. It can be one of the following formats:
            - "weight_only_int8" or "wint8": Only weights are quantized to int8.
            - "weight_only_int4" or "wint4": Only weights are quantized to int4.
            - A custom string such as "w4afp8c8", where the values after the 'w', 'a', and 'c' prefixes
              give the quantization bitwidths (or float formats, e.g. 'fp8' for floating-point 8-bit)
              for weights, activations, and cache respectively.
              If a prefix is missing, the default quantization type is used.

    Returns:
        tuple: A tuple of three strings representing the quantization types for weights, activations,
            and cache respectively.
            If the input is "weight_only_int8" or "wint8", returns ("int8", default_type, default_type).
            If the input is "weight_only_int4" or "wint4", returns ("int4", default_type, default_type).
            For custom strings, returns the parsed quantization types based on the input format.

    Raises:
        AssertionError: If the custom quantization type string format is incorrect.
    """
    default_type = paddle.get_default_dtype()
    if quant_type == "default" or quant_type is None:
        return default_type, default_type, default_type
    conver_dict = {
        "8": "int8",
        "4": "int4",
        "16": paddle.get_default_dtype(),
        "fp8": "float8_e4m3fn",
        "fp16": "float16",
        "bf16": "bfloat16",
        "fp32": "float32",
    }
    cache_type = default_type
    if "c8" in quant_type:
        cache_type = "int8"
    elif "cfp8" in quant_type:
        cache_type = "fp8"
    elif "c4" in quant_type:
        cache_type = "int4"

    if "weight_only_int8" in quant_type or "wint8" in quant_type:
        return "int8", default_type, cache_type
    elif "weight_only_int4" in quant_type or "wint4" in quant_type:
        return "int4", default_type, cache_type
    else:
        # split the quant type, e.g. w4afp8c8 -> ['w', '4', 'a', 'fp8', 'c', '8']
        pattern = f"({'|'.join(map(re.escape, ['w', 'a', 'c']))})"
        splited_type = re.split(pattern, quant_type)
        splited_type = [tmp_type for tmp_type in splited_type if tmp_type]
        assert len(splited_type) % 2 == 0 and len(splited_type) <= 6, f"Quant type[{quant_type}] format error."

        quant_type_list = []
        if "w" in splited_type:
            w_idx = splited_type.index("w")
            quant_type_list.append(conver_dict[splited_type[w_idx + 1]])
        else:
            quant_type_list.append(default_type)
        if "a" in splited_type:
            a_idx = splited_type.index("a")
            quant_type_list.append(conver_dict[splited_type[a_idx + 1]])
        else:
            quant_type_list.append(default_type)
        if "c" in splited_type:
            c_idx = splited_type.index("c")
            quant_type_list.append(conver_dict[splited_type[c_idx + 1]])
        else:
            quant_type_list.append(default_type)

        return quant_type_list[0], quant_type_list[1], quant_type_list[2]
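
# A few concrete parses, assuming the default dtype is bfloat16:
#
#     parser_quant_type("wint8")     # -> ("int8", "bfloat16", "bfloat16")
#     parser_quant_type("wint4c8")   # -> ("int4", "bfloat16", "int8")
#     parser_quant_type("w4afp8c8")  # -> ("int4", "float8_e4m3fn", "int8")
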
@@ -1,543 +0,0 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle
import triton.language as tl
from paddle import _C_ops
from paddle.base.framework import OpProtoHolder
from paddle.framework import in_dynamic_or_pir_mode

from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
    get_dtype_str,
    paddle_use_triton,
    rendering_common_template,
)

BLOCK_SIZE_M = 16
def invoke_fused_moe_kernel(
    A,
    B,
    C,
    B_scale,
    B_super_scale,
    B_code_scale,
    B_code_zp,
    topk_weights,
    topk_ids,
    sorted_token_ids,
    expert_ids,
    num_tokens_post_padded,
    mul_routed_weight=False,
    top_k=-1,
    group_size=-1,
):
    """
    Invoke Fused Moe Kernel
    """
    KK = A.shape[-1]
    NN = B.shape[-1]
    sstride_am, sstride_ak = A.shape[1], 1
    sstride_be, sstride_bk, sstride_bn = B.shape[1] * B.shape[2], B.shape[2], 1
    sstride_cm, sstride_cn = C.shape[-1], 1
    sstride_bse, sstride_bsk, sstride_bsn = (
        B_scale.shape[1] * B_scale.shape[2],
        B_scale.shape[2],
        1,
    )
    sstride_bce, sstride_bck, sstride_bcn = B_code_scale.shape[1], 1, 1

    ddouble_quant = B_super_scale is not None

    prepare_attr_for_triton_kernel = """
    auto N = B.shape()[2];
    auto K = A.shape()[1];
    auto EM = sorted_token_ids.shape()[0];
    auto num_valid_tokens = (topk_ids.shape()[0]) * (topk_ids.shape()[1]);
    auto stride_am = A.strides()[0];
    auto stride_ak = A.strides()[1];
    auto stride_be = B.strides()[0];
    auto stride_bk = B.strides()[1];
    auto stride_bn = B.strides()[2];
    auto stride_cm = C.strides()[1];
    auto stride_cn = C.strides()[2];
    auto stride_bse = B_scale.strides()[0];
    auto stride_bsk = B_scale.strides()[1];
    auto stride_bsn = 1;
    auto stride_bce = B_code_scale.strides()[0];
    auto stride_bck = 1;
    auto stride_bcn = 1;
    auto double_quant = true;
    """
    if mul_routed_weight:
        config = {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 128,
            "GROUP_SIZE_M": 2,
            "num_warps": 4,
            "num_stages": 8,
        }
    else:
        config = {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 512,
            "GROUP_SIZE_M": 1,
            "num_warps": 8,
            "num_stages": 12,
        }
    configs = []

    configs.append(dict(config))

    op_name = "wint2_moe_ffn"
    op_name += f"{get_dtype_str(A.dtype)}"
    op_name += f"{B.shape[0]}"
    op_name += f"{B.shape[1]}"
    op_name += f"{B.shape[2]}"

    if op_name not in OpProtoHolder.instance().op_proto_map.keys():
        prepare_ptr_for_triton_kernel = """
        CUdeviceptr input_ptrs[11] = {
            get_tensor_ptr(A),
            get_tensor_ptr(B),
            get_tensor_ptr(C),
            get_tensor_ptr(B_scale),
            get_tensor_ptr(B_super_scale),
            get_tensor_ptr(B_code_scale),
            get_tensor_ptr(B_code_zp),
            get_tensor_ptr(topk_weights),
            get_tensor_ptr(sorted_token_ids),
            get_tensor_ptr(expert_ids),
            get_tensor_ptr(num_tokens_post_padded),
        };
        """
        template_used = rendering_common_template(
            invoke_fused_moe_kernel,
            prepare_attr_for_triton_kernel,
            prepare_ptr_for_triton_kernel,
        )
        grid = ("(EM+BLOCK_SIZE_M-1)/BLOCK_SIZE_M * ((N+BLOCK_SIZE_N-1)/BLOCK_SIZE_N)",)

        moe_wint2_ffn_kernel[(op_name, template_used, grid, configs)](
            A,
            B,
            C,
            B_scale,
            B_super_scale,
            B_code_scale,
            B_code_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            NN,
            KK,
            -1,  # EEM,
            -1,  # nnum_valid_tokens,
            sstride_am,
            sstride_ak,
            sstride_be,
            sstride_bk,
            sstride_bn,
            sstride_cm,
            sstride_cn,
            sstride_bse,
            sstride_bsk,
            sstride_bsn,
            sstride_bce,
            sstride_bck,
            sstride_bcn,
            MUL_ROUTED_WEIGHT=(int)(mul_routed_weight),
            USE_DOUBLE_QUANT=(int)(ddouble_quant),
            top_k=top_k,
            BLOCK_SIZE_K=group_size,
        )
    if in_dynamic_or_pir_mode():

        outs = _C_ops._run_custom_op(
            op_name,
            A,
            B,
            C,
            B_scale,
            B_super_scale,
            B_code_scale,
            B_code_zp,
            topk_weights,
            topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            group_size,
        )
        return outs[0]
@paddle_use_triton(
    key=["1"],
)
def moe_wint2_ffn_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    bs_ptr,
    superbs_ptr,
    codebs_ptr,
    codebzp_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bce,
    stride_bck,
    stride_bcn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    USE_DOUBLE_QUANT: tl.constexpr,
    top_k: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    if USE_DOUBLE_QUANT:
        # INT4 scale
        s_packnums: tl.constexpr = 2
    bzp: tl.constexpr = 32
    w_mask: tl.constexpr = 0x3F
    pack_num: tl.constexpr = 4
    real_k_size: tl.constexpr = (BLOCK_SIZE_K - 1) // pack_num + 1

    pid = tl.program_id(axis=0)

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    compute_type = c_ptr.dtype.element_ty

    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)

    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    # offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_bk = tl.arange(0, real_k_size)

    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_bk[None, :] * pack_num * stride_ak)

    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    bs_ptrs = bs_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn  # group-wise, need advanced

    off_set = off_experts * stride_bce + offs_bn[None, :] * stride_bcn
    # load channel-wise scale & zero-point
    if USE_DOUBLE_QUANT:
        superbs_ptrs = superbs_ptr + off_set  # channel-wise
        super_bs = tl.load(superbs_ptrs)  # super scale

    codebs_ptrs = codebs_ptr + off_set  # channel-wise
    code_bs = tl.load(codebs_ptrs)  # code scale
    codebzp_ptrs = codebzp_ptr + off_set  # channel-wise
    code_bzp = tl.load(codebzp_ptrs)  # code zp

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):

        b = tl.load(b_ptrs)

        bs = tl.load(bs_ptrs)
        if USE_DOUBLE_QUANT:
            s_shift_bits = (1 - k % s_packnums) * 4
            bs = ((bs >> s_shift_bits) & 0xF) * super_bs

        # reverse to int16
        b = tl.floor((b.to(tl.float32) * code_bs + code_bzp) + 0.5).to(tl.int16)
        # dequant
        b1 = (((b >> 9) & w_mask) - bzp) * bs
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None],
            other=0.0,
        )
        accumulator += tl.dot(a, b1.to(a.dtype))

        b1 = (((b >> 6) & w_mask) - bzp) * bs
        a = tl.load(
            a_ptrs + 1,
            mask=token_mask[:, None],
            other=0.0,
        )
        accumulator += tl.dot(a, b1.to(a.dtype))

        b1 = (((b >> 3) & w_mask) - bzp) * bs
        a = tl.load(
            a_ptrs + 2,
            mask=token_mask[:, None],
            other=0.0,
        )
        accumulator += tl.dot(a, b1.to(a.dtype))

        b = ((b & w_mask) - bzp) * bs
        a = tl.load(
            a_ptrs + 3,
            mask=token_mask[:, None],
            other=0.0,
        )
        accumulator += tl.dot(a, b.to(a.dtype))

        b_ptrs += real_k_size * stride_bk
        a_ptrs += BLOCK_SIZE_K * stride_ak

        # advance scale ptr
        if USE_DOUBLE_QUANT:
            bs_ptrs += stride_bsk * (k % s_packnums)
        else:
            bs_ptrs += stride_bsk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
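
# To make the layout above concrete, a toy, hypothetical arrangement with 2 tokens, top_k=2
# and BLOCK_SIZE_M shortened to 4 (P marks a padding slot, >= num_valid_tokens, masked out):
#
#     topk_ids            : token0 -> experts [0, 2], token1 -> experts [2, 1]
#     flattened slots     : slot 0 -> expert 0, slot 1 -> expert 2, slot 2 -> expert 2, slot 3 -> expert 1
#     sorted_token_ids    : [0, P, P, P,  3, P, P, P,  1, 2, P, P]
#     expert_ids          : [0, 1, 2]        (one expert id per BLOCK_SIZE_M block)
#     num_tokens_post_padded : 12
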
def fused_moe_wint2_impl(
    hidden_states,
    up_gate_proj_quant_weight,
    down_proj_quant_weight,
    topk_weights,
    topk_ids,
    # inplace: bool = False,
    up_gate_proj_weight_scale=None,
    down_proj_weight_scale=None,
    up_gate_proj_super_scales=None,
    down_proj_super_scales=None,
    up_gate_proj_code_scale=None,
    down_proj_code_scale=None,
    up_gate_proj_code_zp=None,
    down_proj_code_zp=None,
    group_size=64,
    bit="wint2",
):
    """
    Implementation of Fused MoE kernels on GPU.
    """
    # Check constraints.
    # A: [M, K]
    # B: [E, K, N]
    # assert hidden_states.shape[1] == up_gate_proj_weight_scale.shape[1],
    # f"Hidden size mismatch, {hidden_states.shape[1]} != {up_gate_proj_quant_weight.shape[1]}"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert up_gate_proj_quant_weight.is_contiguous(), "Expert weights1 must be contiguous"
    assert down_proj_quant_weight.is_contiguous(), "Expert weights2 must be contiguous"
    assert group_size > 0, "Group size must be greater than 0"

    num_tokens, K = hidden_states.shape
    E, _, N = up_gate_proj_quant_weight.shape
    M = num_tokens

    if group_size < 0:
        group_size = K // up_gate_proj_weight_scale.shape[1]

    top_k = topk_ids.shape[1]

    intermediate_cache1 = paddle.empty(
        [M, top_k, N],
        dtype=hidden_states.dtype,
    )
    intermediate_cache2 = paddle.empty(
        (M * top_k, N // 2),
        dtype=hidden_states.dtype,
    )
    intermediate_cache3 = paddle.empty(
        (M, top_k, K),
        dtype=hidden_states.dtype,
    )

    from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess

    sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(topk_ids, E, BLOCK_SIZE_M)

    invoke_fused_moe_kernel(
        A=hidden_states,
        B=up_gate_proj_quant_weight,
        C=intermediate_cache1,
        B_scale=up_gate_proj_weight_scale,
        B_super_scale=up_gate_proj_super_scales,
        B_code_scale=up_gate_proj_code_scale,
        B_code_zp=up_gate_proj_code_zp,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        mul_routed_weight=False,
        top_k=top_k,
        group_size=group_size,
    )

    intermediate_cache2 = paddle.incubate.nn.functional.swiglu(intermediate_cache1.reshape([-1, N]))

    invoke_fused_moe_kernel(
        A=intermediate_cache2,
        B=down_proj_quant_weight,
        C=intermediate_cache3,
        B_scale=down_proj_weight_scale,
        B_super_scale=down_proj_super_scales,
        B_code_scale=down_proj_code_scale,
        B_code_zp=down_proj_code_zp,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        mul_routed_weight=True,
        top_k=1,
        group_size=group_size,
    )

    out_hidden_states = paddle.sum(intermediate_cache3, axis=1)
    return out_hidden_states
def fused_moe_wint2_triton(
    hidden_states,
    up_gate_proj_quant_weight,
    down_proj_quant_weight,
    scores,
    gate_correction_bias,
    topk,
    up_gate_proj_weight_scale,
    down_proj_weight_scale,
    up_gate_proj_super_scales,
    down_proj_super_scales,
    up_gate_proj_code_scale,
    down_proj_code_scale,
    up_gate_proj_code_zp,
    down_proj_code_zp,
):
    """
    Fuse MoE with the WINT2 quantization scheme and the Triton backend.
    Args:
        hidden_states: input tensor.
        up_gate_proj_quant_weight: up_gate_proj weight matrix for experts.
        down_proj_quant_weight: down_proj weight matrix for experts.
        scores: gate scores.
        gate_correction_bias: bias correction for gates.
        topk: number of experts to use.
        up_gate_proj_weight_scale: scaling factor for up_gate_proj_quant_weight.
        down_proj_weight_scale: scaling factor for down_proj_quant_weight.
        up_gate_proj_super_scales: super scaling factor for up_gate_proj_weight_scale.
        down_proj_super_scales: super scaling factor for down_proj_weight_scale.
        up_gate_proj_code_scale: code scaling factor for up_gate_proj_quant_weight.
        down_proj_code_scale: code scaling factor for down_proj_quant_weight.
        up_gate_proj_code_zp: code zero point for up_gate_proj_quant_weight.
        down_proj_code_zp: code zero point for down_proj_quant_weight.
    Returns:
        output tensor.
    """

    score = gate_correction_bias + scores
    _, topk_ids = paddle.topk(score, k=topk, axis=-1)
    topk_weights, _ = paddle.topk(scores, k=topk, axis=-1)
    topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)

    return fused_moe_wint2_impl(
        hidden_states,
        up_gate_proj_quant_weight,
        down_proj_quant_weight,
        topk_weights,
        topk_ids,
        up_gate_proj_weight_scale,
        down_proj_weight_scale,
        up_gate_proj_super_scales,
        down_proj_super_scales,
        up_gate_proj_code_scale,
        down_proj_code_scale,
        up_gate_proj_code_zp,
        down_proj_code_zp,
        bit="wint2",
    )
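
# Toy sketch of the routing step above (requires Paddle): expert ids come from the
# bias-corrected scores, while the routing weights are the top-k of the raw scores,
# normalised to sum to 1.
#
#     scores = paddle.to_tensor([[0.1, 0.5, 0.4]])
#     bias = paddle.to_tensor([[0.0, -1.0, 0.0]])
#     paddle.topk(bias + scores, k=2, axis=-1)[1]        # -> [[2, 0]]
#     w = paddle.topk(scores, k=2, axis=-1)[0]           # -> [[0.5, 0.4]]
#     w / w.sum(axis=-1, keepdim=True)                   # -> [[0.5556, 0.4444]]
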