diff --git a/fastdeploy/input/ernie4_5_vl_processor/utils/io_utils.py b/fastdeploy/input/ernie4_5_vl_processor/utils/io_utils.py
index 43bf05d08..1535b64d4 100644
--- a/fastdeploy/input/ernie4_5_vl_processor/utils/io_utils.py
+++ b/fastdeploy/input/ernie4_5_vl_processor/utils/io_utils.py
@@ -17,16 +17,13 @@
 import base64
 import datetime
 import hashlib
-import io
 import os
 import threading
 import uuid
 from pathlib import Path

-import numpy as np
 import requests
 from PIL import Image
-from PIL.ExifTags import TAGS

 RAW_VIDEO_DIR = "./download_tmp/raw_video/"
 RAW_IMAGE_DIR = "./download_tmp/raw_images/"
@@ -110,155 +107,3 @@ def get_downloadable(
         retry_interval=retry_interval,
     )
     return downloaded_path
-
-
-def get_downloadable_image(download_path, need_exif_info, retry_max_time=0, retry_interval=3):
-    """
-    带上exif info和图像处理的get downloadable
-    """
-
-    def get_image_exif(image):
-        exif_data = image._getexif()
-        exif_info = {}
-        if exif_data is not None:
-            for tag, value in exif_data.items():
-                tag_name = TAGS.get(tag, tag)
-                exif_info[tag_name] = value.strip()
-        return exif_info
-
-    def has_transparent_background(img):
-        """判断图片是否有背景"""
-        if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
-            # Check for any pixel with alpha channel less than 255 (fully opaque)
-            alpha = img.convert("RGBA").split()[-1]
-            if alpha.getextrema()[0] < 255:
-                return True
-        return False
-
-    def add_white_background(img):
-        """
-        给透明背景的图,加个白色背景
-        """
-        if img.mode != "RGBA":
-            img = img.convert("RGBA")
-        # 创建一个白色背景的图像,尺寸与原图一致
-        img_white_background = Image.new("RGBA", img.size, (255, 255, 255))
-
-        # 将原图粘贴到白色背景上
-        img_white_background.paste(img, (0, 0), img)
-
-        return img_white_background
-
-    def change_I16_to_L(img):
-        """
-        将图片从I;16模式转换为L模式
-        """
-        # 由于I模式的point函数只支持加减乘,所以下面的* (1 / 256)不能改成除法
-        return img.point(lambda i: i * (1 / 256)).convert("L")
-
-    image = get_downloadable(
-        download_path,
-        save_to_disk=False,
-        retry=retry_max_time,
-        retry_interval=retry_interval,
-    )
-    if isinstance(image, Image.Image):
-        pil_image = image
-    else:
-        pil_image = Image.open(io.BytesIO(image))
-    if need_exif_info:
-        try:
-            exif_info = get_image_exif(pil_image)
-        except Exception:
-            exif_info = {}
-    else:
-        exif_info = {}
-
-    try:
-        if pil_image.mode == "I;16":
-            pil_image = change_I16_to_L(pil_image)
-        if has_transparent_background(pil_image):
-            pil_image = add_white_background(pil_image)
-    except Exception:
-        pass
-
-    return pil_image.convert("RGB"), exif_info
-
-
-def str2hash(url):
-    """
-    从一个str的到url
-    """
-    return hashlib.sha256(url.encode()).hexdigest()
-
-
-def pil2hash(pil):
-    """
-    从一个PIL.Image到hash
-    """
-    byte_io = io.BytesIO()
-    pil.save(byte_io, format="PNG")  # 选择无损格式,避免压缩影响
-    image_bytes = byte_io.getvalue()
-
-    return hashlib.sha256(image_bytes).hexdigest()
-
-
-def imagepath_to_base64(image_path):
-    """imagepath_to_base64"""
-    image = Image.open(image_path).convert("RGB")
-    buffer = io.BytesIO()
-    image.save(buffer, format="JPEG")
-    image_bytes = buffer.getvalue()
-    base64_encoded = base64.b64encode(image_bytes).decode("utf-8")
-    return base64_encoded
-
-
-def pil_image_to_base64(image):
-    """pil_image_to_base64"""
-    buffer = io.BytesIO()
-    image.save(buffer, format="JPEG")
-    image_bytes = buffer.getvalue()
-    base64_encoded = base64.b64encode(image_bytes).decode("utf-8")
-    return base64_encoded
-
-
-def http_to_pil_image(url):
-    """http_to_pil_image"""
-    response = requests.get(url)
-    image_data = io.BytesIO(response.content)
-    pil_image = Image.open(image_data).convert("RGB")
return pil_image - - -def http_to_image_base64(url): - """http_to_image_base64""" - response = requests.get(url) - image_data = io.BytesIO(response.content) - return base64.b64encode(image_data.getvalue()).decode("utf-8") - - -def base64_to_pil_image(base64_string): - """ " base64_to_pil_image""" - image_bytes = base64.b64decode(base64_string) - buffer = io.BytesIO(image_bytes) - image = Image.open(buffer) - return image - - -def get_hashable(to_be_hashed): - """get hashable""" - if isinstance(to_be_hashed, bytes): - return to_be_hashed - elif isinstance(to_be_hashed, Image.Image): - return to_be_hashed.tobytes() - elif isinstance(to_be_hashed, str): - return to_be_hashed.encode("utf-8") - else: - raise ValueError(f"not support type: {type(to_be_hashed)}") - - -def load_dict_from_npz(npzfile): - """从npz文件读取数据""" - with np.load(npzfile, allow_pickle=True) as data: - loaded_dict = {key: data[key] for key in data.files} - return loaded_dict diff --git a/fastdeploy/model_executor/models/utils.py b/fastdeploy/model_executor/models/utils.py index 063344d19..35475e16c 100644 --- a/fastdeploy/model_executor/models/utils.py +++ b/fastdeploy/model_executor/models/utils.py @@ -17,29 +17,8 @@ from __future__ import annotations import enum -import hashlib -import json -import os -import random -import re -import struct -from functools import partial from typing import NamedTuple, Optional -import numpy as np -import paddle -from paddle.common_ops_import import convert_dtype -from paddleformers.transformers.model_utils import _add_variant -from paddleformers.transformers.utils import paddleformers_load -from paddleformers.utils.env import ( - PADDLE_WEIGHTS_INDEX_NAME, - SAFE_MASTER_WEIGHTS_INDEX_NAME, - SAFE_PEFT_WEIGHTS_INDEX_NAME, - SAFE_WEIGHTS_INDEX_NAME, -) -from paddleformers.utils.log import logger -from tqdm import tqdm - MAX_BSZ = 512 MAX_DRAFT_TOKENS = 6 @@ -67,428 +46,3 @@ class WeightMeta(NamedTuple): weight_name: str is_column: bool extra: Optional[str] = None - - -class UniqueIDGenerator: - """ - The generator for the export model id - """ - - def __init__(self): - pass - - def generate_unique_id(self, state_dict): - """ - Generate the model id from the timestamp - """ - keys = state_dict.keys() - sorted_keys = sorted(keys) - first_key = sorted_keys[0] - first_parameter = state_dict[first_key].cast("float32") - # 假设模型参数是唯一的,通过第一个key来获取md5sum - model_md5 = hashlib.md5(str(first_parameter.sum()).encode("utf-8")).hexdigest() - unique_id = f"{model_md5}-{random.randint(10000, 99999)}" - return unique_id - - -def load_sharded_checkpoint(folder, variant=None, return_numpy=False): - """ - - This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being - loaded in the model. - - Args: - folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint. - variant (`str`): The model variant. 
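For downstream code that still needs the image-conversion utilities deleted from io_utils.py above, a minimal self-contained sketch of equivalent functionality (hypothetical names, not part of this patch; assumes Pillow and requests remain available):

# Illustrative only: inline equivalents of the removed http_to_pil_image,
# pil_image_to_base64, and pil2hash helpers.
import base64
import hashlib
import io

import requests
from PIL import Image


def url_to_pil(url: str) -> Image.Image:
    # Download the bytes and decode them into an RGB PIL image.
    resp = requests.get(url, timeout=10)
    resp.raise_for_status()
    return Image.open(io.BytesIO(resp.content)).convert("RGB")


def pil_to_base64_jpeg(image: Image.Image) -> str:
    # Re-encode as JPEG and return a UTF-8 base64 string.
    buf = io.BytesIO()
    image.save(buf, format="JPEG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def pil_sha256(image: Image.Image) -> str:
    # Hash a lossless PNG encoding so compression does not change the digest.
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    return hashlib.sha256(buf.getvalue()).hexdigest()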
- - """ - # Load the index - pdparams_file = os.path.join(folder, _add_variant("model_state.pdparams", variant)) - lora_pdparams_file = os.path.join(folder, _add_variant("lora_model_state.pdparams", variant)) - safetensors_file = os.path.join(folder, _add_variant("model.safetensors", variant)) - if os.path.isfile(pdparams_file): - return paddle.load(pdparams_file, return_numpy=return_numpy) - if os.path.isfile(lora_pdparams_file): - return paddle.load(lora_pdparams_file, return_numpy=return_numpy) - if os.path.isfile(safetensors_file): - try: - from paddleformers.utils.safetensors import fast_load_file as safe_load_file - except ImportError: - from safetensors.numpy import load_file as safe_load_file - - state_dict = safe_load_file(safetensors_file) - if not return_numpy: - for key in list(state_dict.keys()): - if isinstance(state_dict[key], np.ndarray): - state_dict[key] = paddle.Tensor(state_dict.pop(key), zero_copy=True) - return state_dict - - index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant)) - safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)) - safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant)) - safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant)) - - index_present = os.path.isfile(index_file) - safe_index_present = os.path.isfile(safe_index_file) - safe_master_present = os.path.isfile(safe_master_file) - safe_peft_present = os.path.isfile(safe_peft_file) - - load_safe = False - load_index = None - if safe_index_present: - load_safe = True # load safe due to preference - load_index = safe_index_file - elif safe_master_present: - load_safe = True - load_index = safe_master_file - elif index_present: - load_index = index_file - elif safe_peft_present: - load_safe = True - load_index = safe_peft_file - else: - raise ValueError(f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}") - - if load_safe: - try: - from paddleformers.utils.safetensors import fast_load_file as safe_load_file - except ImportError: - from safetensors.numpy import load_file as safe_load_file - - with open(load_index, "r", encoding="utf-8") as f: - index = json.load(f) - - shard_files = list(set(index["weight_map"].values())) - loader = safe_load_file if load_safe else partial(paddleformers_load, map_location="np" if return_numpy else "cpu") - - ret = {} - for shard_file in tqdm(shard_files): - state_dict = loader(os.path.join(folder, shard_file)) - ret.update(state_dict) - - if not return_numpy: - for key in list(ret.keys()): - if isinstance(ret[key], np.ndarray): - ret[key] = paddle.Tensor(ret.pop(key), zero_copy=True) - - return ret - - -def convert_ndarray_dtype(np_array: np.ndarray, target_dtype: str) -> np.ndarray: - """convert ndarray - - Args: - np_array (np.ndarray): numpy ndarray instance - target_dtype (str): the target dtype - - Returns: - np.ndarray: converted numpy ndarray instance - """ - source_dtype = convert_dtype(np_array.dtype) - if ( - source_dtype == "uint16" - and target_dtype == "bfloat16" - and paddle.is_compiled_with_custom_device("iluvatar_gpu") - ): - return np_array.view(dtype=target_dtype) - if source_dtype == "uint16" or target_dtype == "bfloat16": - if paddle.is_compiled_with_xpu(): - # xpu not support bf16. 
- tensor = paddle.to_tensor(np_array, place=paddle.CPUPlace()) - else: - tensor = paddle.to_tensor(np_array) - tensor = paddle.cast(tensor, target_dtype) - return tensor.numpy() - - # TODO(wj-Mcat): device_guard will slow the converting - # with device_guard("cpu"): - # tensor = paddle.to_tensor(np_array) - # tensor = paddle.cast(tensor, target_dtype) - # return tensor.numpy() - - if target_dtype == "bfloat16": - target_dtype = "uint16" - - return np_array.astype(target_dtype) - - -def set_seed(seed: int): - """ - set random seed for all random modules - """ - paddle.seed(seed) - random.seed(seed) - np.random.seed(seed) - - -def pad_batch_data(insts, pad_id=0, return_seq_len=False, pad_style="right"): - """Pad the instances to the max sequence length in batch.""" - # pad to max input len i bsz - max_len = max(map(len, insts)) - # pad to max input len - # max_len = args.max_len - if pad_style == "left": - inst_data = np.array([[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts]) - else: - inst_data = np.array([list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]) - if return_seq_len: - seq_len = np.array([len(inst) for inst in insts]) - return inst_data.astype("int64").reshape([-1, max_len]), seq_len - else: - return inst_data.astype("int64").reshape([-1, max_len]) - - -def load_prefix_weights( - prefix_path: str, - inference: bool = False, - batch_size: int = 1, - dtype: str = "bfloat16", -) -> np.ndarray | list[paddle.Tensor]: - """load prefix weight by path - - Args: - prefix_path (str): the path of prefix weight - """ - past_key_values = paddle.to_tensor(np.load(f"{prefix_path}/pre_caches.npy")).unsqueeze(2) - - if batch_size > 1: - past_key_values = paddle.concat([past_key_values] * batch_size, axis=2) - - # .chatglm static model require one tensor, otherwise list of tensor - past_key_values = past_key_values.astype(dtype) - if inference: - return past_key_values.numpy() - return past_key_values - - -def w4a8_weight_convert(state_dict): - """W4A8 权重转换函数 - Args: - state_dict (dict): state_dict of model - """ - - def w4_weight_squash(value, name, w4a8_weight_bites_name_map): - weight_dq = value - # W8表象下的W4权重的absmax值为112,使用正负112进行权重类型判断 - if weight_dq.max() == 112 or weight_dq.min() == -112: - weight_dq = weight_dq.cast("int8") - np_weight_dq = np.array(weight_dq, dtype="int8").view("uint8") - np_weight_dq_left_div_16 = (np_weight_dq / 16).astype("int8") - # weight_q = (weight_dq/16).cast('int8') - weight_q = paddle.to_tensor(np_weight_dq_left_div_16, dtype="int8") - logger.debug(f"int4 weight:{name}") - w4a8_weight_bites_name_map[name] = 4 - return weight_q.cast("int8") - elif weight_dq.max() == 127 or weight_dq.min() == -128: - logger.debug(f"int8 weight:{name}") - w4a8_weight_bites_name_map[name] = 8 - return weight_dq.cast("int8") - else: - logger.debug(f"fp16/bf16/float weight:{name}") - return weight_dq - - w4a8_weight_bites_name_map = {} - for name, value in state_dict.items(): - if value.dtype == "uint16": - weight_q = w4_weight_squash( - paddle.to_tensor(value).cast("float32"), - name, - w4a8_weight_bites_name_map, - ) - state_dict[name] = weight_q.numpy() if weight_q is not None else value - del weight_q - w4a8_weight_bites_layers_map = {} - w4a8_weight_bites_layers_map["qkv_gemm_bits_map"] = [] - w4a8_weight_bites_layers_map["out_gemm_bits_map"] = [] - w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"] = [] - w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"] = [] - for name_keys, gemm_bits in w4a8_weight_bites_name_map.items(): - if 
"qkv_proj" in name_keys: - w4a8_weight_bites_layers_map["qkv_gemm_bits_map"].append(gemm_bits) - elif "out_proj" in name_keys: - w4a8_weight_bites_layers_map["out_gemm_bits_map"].append(gemm_bits) - elif "linear1" in name_keys: - w4a8_weight_bites_layers_map["up_gate_proj_gemm_bits_map"].append(gemm_bits) - elif "linear2" in name_keys: - w4a8_weight_bites_layers_map["down_proj_gemm_bits_map"].append(gemm_bits) - logger.debug(f"w4a8_weight_bites_layers_map:{w4a8_weight_bites_layers_map}") - return state_dict, w4a8_weight_bites_layers_map - - -def _vocab_size_with_padding(vocab_size, div_unit, mp_degree): - padded_size = vocab_size - multiple = div_unit * mp_degree - while (padded_size % multiple) != 0: - padded_size += 1 - # logger.warning( - # " > padded vocab (size: {}) with {} dummy tokens " - # "(new size: {})".format(vocab_size, padded_size - vocab_size, padded_size) - # ) - return padded_size - - -def save_test_case(cases: list[list[dict]], file: str): - """save test to result file - - Args: - cases (list[list[dict]]): the content of case - file (str): the path of saved file - """ - with open(file, "w+", encoding="utf-8") as f: - for case in cases: - raw = json.dumps(case, ensure_ascii=False) - f.write(raw + "\n") - - -def infer_save_test_case(cases: list[list[dict]], file: str): - """save test to result file - - Args: - cases (list[list[dict]]): the content of case - file (str): the path of saved file - """ - with open(file, "a+", encoding="utf-8") as f: - for case in cases: - raw = json.dumps(case, ensure_ascii=False) - f.write(raw + "\n") - - -def deserialize_from_file(fp): - """ - deserialize a binary file into an array - """ - - x_type = fp.read(1) - x_type_out = struct.unpack("c", x_type)[0] - # data - data_list = [] - if x_type_out == b"0": - data = fp.read(4) - data_out = struct.unpack("f", data)[0] - while data: - data_out = struct.unpack("f", data)[0] - data_list.append(data_out) - data = fp.read(4) - elif x_type_out == b"1": - data = fp.read(8) - while data: - data_out = struct.unpack("l", data)[0] - data_list.append(data_out) - data = fp.read(8) - elif x_type_out == b"2": - data = fp.read(4) - while data: - data_out = struct.unpack("i", data)[0] - data_list.append(data_out) - data = fp.read(4) - else: - print("type error") - data_arr = np.array(data_list) - return data_arr - - -def calculate_effective_tokens(training_args, train_dataset, max_seq_len): - """ - Calculate the effective tokens during training. - """ - total_effective_tokens = 0 - try: - data_parallel_degree = training_args.data_parallel_degree - except Exception: - data_parallel_degree = 1 - if training_args.sharding_parallel_degree > 1: - sharding_parallel_degree = training_args.sharding_parallel_degree - else: - sharding_parallel_degree = 1 - - total_batch = ( - training_args.max_steps - * training_args.per_device_train_batch_size - * training_args.gradient_accumulation_steps - * sharding_parallel_degree - * data_parallel_degree - ) - for i, data in enumerate(train_dataset): - if i == total_batch: - break - for dd in data: - total_effective_tokens += len(dd.token_ids) - total_tokens = total_batch * max_seq_len - - return total_effective_tokens, total_tokens - - -def parser_quant_type(quant_type): - """ - Parse the quantization type string and return the corresponding quantization types for weights, - activations, and custom. - - Args: - quant_type (str): The quantization type string. It can be one of the following formats: - - "weight_only_int8" or "wint8": Only weights are quantized to int8. 
- - "weight_only_int4" or "wint4": Only weights are quantized to int4. - - A custom string in the format of "wxaybzcfp8", where 'x', 'y', 'z' are the quantization bitwidths - for weights, activations, and custom respectively, - and 'a', 'b', 'c' are the prefixes indicating the quantization types - (e.g., 'fp8' for floating-point 8-bit). - If a prefix is missing, the default quantization type will be used. - - Returns: - tuple: A tuple of three strings representing the quantization types for weights, activations, - and custom respectively. - If the input is "weight_only_int8" or "wint8", returns ("int8", default_type, default_type). - If the input is "weight_only_int4" or "wint4", returns ("int4", default_type, default_type). - For custom strings, returns the parsed quantization types based on the input format. - - Raises: - AssertionError: If the custom quantization type string format is incorrect. - """ - default_type = paddle.get_default_dtype() - if quant_type == "default" or quant_type is None: - return default_type, default_type, default_type - conver_dict = { - "8": "int8", - "4": "int4", - "16": paddle.get_default_dtype(), - "fp8": "float8_e4m3fn", - "fp16": "float16", - "bf16": "bfloat16", - "fp32": "float32", - } - cache_type = default_type - if "c8" in quant_type: - cache_type = "int8" - elif "cfp8" in quant_type: - cache_type = "fp8" - elif "c4" in quant_type: - cache_type = "int4" - - if "weight_only_int8" in quant_type or "wint8" in quant_type: - return "int8", default_type, cache_type - elif "weight_only_int4" in quant_type or "wint4" in quant_type: - return "int4", default_type, cache_type - else: - # split quant type, eg. w4afp8c8 -> ['w', '4', 'a', 'fp8', 'c', '8'] - pattern = f"({'|'.join(map(re.escape, ['w', 'a', 'c']))})" - splited_type = re.split(pattern, quant_type) - splited_type = [tmp_type for tmp_type in splited_type if tmp_type] - assert len(splited_type) % 2 == 0 and len(splited_type) <= 6, f"Quant type[{quant_type}] format error." - - quant_type_list = [] - if "w" in splited_type: - w_idx = splited_type.index("w") - quant_type_list.append(conver_dict[splited_type[w_idx + 1]]) - else: - quant_type_list.append(default_type) - if "a" in splited_type: - a_idx = splited_type.index("a") - quant_type_list.append(conver_dict[splited_type[a_idx + 1]]) - else: - quant_type_list.append(default_type) - if "c" in splited_type: - c_idx = splited_type.index("c") - quant_type_list.append(conver_dict[splited_type[c_idx + 1]]) - else: - quant_type_list.append(default_type) - - return quant_type_list[0], quant_type_list[1], quant_type_list[2] diff --git a/fastdeploy/model_executor/ops/triton_ops/wint2_fused_moe.py b/fastdeploy/model_executor/ops/triton_ops/wint2_fused_moe.py deleted file mode 100644 index e69c34a21..000000000 --- a/fastdeploy/model_executor/ops/triton_ops/wint2_fused_moe.py +++ /dev/null @@ -1,543 +0,0 @@ -""" -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
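For reference, a simplified standalone sketch of the quant-type string parsing that the deleted parser_quant_type implemented, showing how a string such as "w4afp8c8" maps to (weight, activation, cache) dtypes. Illustrative only: the real function also special-cased the "weight_only_int8"/"wint8" and "weight_only_int4"/"wint4" aliases and fell back to paddle.get_default_dtype(); "bfloat16" is assumed as the default here.

# Illustrative sketch, not the removed implementation.
import re

_DTYPE_MAP = {"8": "int8", "4": "int4", "fp8": "float8_e4m3fn", "16": "bfloat16"}


def parse_quant_type(quant_type: str, default: str = "bfloat16") -> tuple[str, str, str]:
    # Split e.g. "w4afp8c8" into ["w", "4", "a", "fp8", "c", "8"], then pair markers with values.
    parts = [p for p in re.split(r"(w|a|c)", quant_type) if p]
    spec = dict(zip(parts[0::2], parts[1::2]))
    return (
        _DTYPE_MAP.get(spec.get("w", ""), default),
        _DTYPE_MAP.get(spec.get("a", ""), default),
        _DTYPE_MAP.get(spec.get("c", ""), default),
    )


# parse_quant_type("w4afp8c8") -> ("int4", "float8_e4m3fn", "int8")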
-""" - -import paddle -import triton.language as tl -from paddle import _C_ops -from paddle.base.framework import OpProtoHolder -from paddle.framework import in_dynamic_or_pir_mode - -from fastdeploy.model_executor.ops.triton_ops.triton_utils import ( - get_dtype_str, - paddle_use_triton, - rendering_common_template, -) - -BLOCK_SIZE_M = 16 - - -def invoke_fused_moe_kernel( - A, - B, - C, - B_scale, - B_super_scale, - B_code_scale, - B_code_zp, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - mul_routed_weight=False, - top_k=-1, - group_size=-1, -): - """ - Invoke Fused Moe Kernel - """ - KK = A.shape[-1] - NN = B.shape[-1] - sstride_am, sstride_ak = A.shape[1], 1 - sstride_be, sstride_bk, sstride_bn = B.shape[1] * B.shape[2], B.shape[2], 1 - sstride_cm, sstride_cn = C.shape[-1], 1 - sstride_bse, sstride_bsk, sstride_bsn = ( - B_scale.shape[1] * B_scale.shape[2], - B_scale.shape[2], - 1, - ) - sstride_bce, sstride_bck, sstride_bcn = B_code_scale.shape[1], 1, 1 - - ddouble_quant = B_super_scale is not None - - prepare_attr_for_triton_kernel = """ - auto N = B.shape()[2]; - auto K = A.shape()[1]; - auto EM = sorted_token_ids.shape()[0]; - auto num_valid_tokens = (topk_ids.shape()[0]) * (topk_ids.shape()[1]); - auto stride_am = A.strides()[0]; - auto stride_ak = A.strides()[1]; - auto stride_be = B.strides()[0]; - auto stride_bk = B.strides()[1]; - auto stride_bn = B.strides()[2]; - auto stride_cm = C.strides()[1]; - auto stride_cn = C.strides()[2]; - auto stride_bse = B_scale.strides()[0]; - auto stride_bsk = B_scale.strides()[1]; - auto stride_bsn = 1; - auto stride_bce = B_code_scale.strides()[0]; - auto stride_bck = 1; - auto stride_bcn = 1; - auto double_quant = true; - """ - if mul_routed_weight: - config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "GROUP_SIZE_M": 2, - "num_warps": 4, - "num_stages": 8, - } - else: - config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 512, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 12, - } - configs = [] - - configs.append(dict(config)) - - op_name = "wint2_moe_ffn" - op_name += f"{get_dtype_str(A.dtype)}" - op_name += f"{B.shape[0]}" - op_name += f"{B.shape[1]}" - op_name += f"{B.shape[2]}" - - if op_name not in OpProtoHolder.instance().op_proto_map.keys(): - prepare_ptr_for_triton_kernel = """ - CUdeviceptr input_ptrs[11] = { - get_tensor_ptr(A), - get_tensor_ptr(B), - get_tensor_ptr(C), - get_tensor_ptr(B_scale), - get_tensor_ptr(B_super_scale), - get_tensor_ptr(B_code_scale), - get_tensor_ptr(B_code_zp), - get_tensor_ptr(topk_weights), - get_tensor_ptr(sorted_token_ids), - get_tensor_ptr(expert_ids), - get_tensor_ptr(num_tokens_post_padded), - }; - """ - template_used = rendering_common_template( - invoke_fused_moe_kernel, - prepare_attr_for_triton_kernel, - prepare_ptr_for_triton_kernel, - ) - grid = ("(EM+BLOCK_SIZE_M-1)/BLOCK_SIZE_M * ((N+BLOCK_SIZE_N-1)/BLOCK_SIZE_N)",) - - moe_wint2_ffn_kernel[(op_name, template_used, grid, configs)]( - A, - B, - C, - B_scale, - B_super_scale, - B_code_scale, - B_code_zp, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - NN, - KK, - -1, # EEM, - -1, # nnum_valid_tokens, - sstride_am, - sstride_ak, - sstride_be, - sstride_bk, - sstride_bn, - sstride_cm, - sstride_cn, - sstride_bse, - sstride_bsk, - sstride_bsn, - sstride_bce, - sstride_bck, - sstride_bcn, - MUL_ROUTED_WEIGHT=(int)(mul_routed_weight), - USE_DOUBLE_QUANT=(int)(ddouble_quant), - top_k=top_k, - BLOCK_SIZE_K=group_size, - ) - if in_dynamic_or_pir_mode(): - - outs = 
_C_ops._run_custom_op( - op_name, - A, - B, - C, - B_scale, - B_super_scale, - B_code_scale, - B_code_zp, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - mul_routed_weight, - top_k, - group_size, - ) - return outs[0] - - -@paddle_use_triton( - key=["1"], -) -def moe_wint2_ffn_kernel( - # Pointers to matrices - a_ptr, - b_ptr, - c_ptr, - bs_ptr, - superbs_ptr, - codebs_ptr, - codebzp_ptr, - topk_weights_ptr, - sorted_token_ids_ptr, - expert_ids_ptr, - num_tokens_post_padded_ptr, - # Matrix dimensions - N, - K, - EM, - num_valid_tokens, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_bse, - stride_bsk, - stride_bsn, - stride_bce, - stride_bck, - stride_bcn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - MUL_ROUTED_WEIGHT: tl.constexpr, - USE_DOUBLE_QUANT: tl.constexpr, - top_k: tl.constexpr, -): - """ - Implements the fused computation for a Mixture of Experts (MOE) using - token and expert matrices. - - Key Parameters: - - A: The input tensor representing tokens with shape (*, K), where '*' can - be any shape representing batches and K is the feature dimension of - each token. - - B: The stacked MOE weight tensor with shape (E, N, K), where E is - the number of experts, K is the input feature dimension, and N is - the output feature dimension. - - C: The output cache tensor with shape (M, topk, N), where M is the - total number of tokens post padding, topk is the number of times - each token is repeated, and N is the output feature dimension. - - sorted_token_ids: A tensor containing the sorted indices of tokens, - repeated topk times and arranged by the expert index they are - assigned to. - - expert_ids: A tensor containing the indices of the expert for each - block. It determines which expert matrix from B should be used for - each block in A. - This kernel performs the multiplication of a token by its corresponding - expert matrix as determined by `expert_ids`. The sorting of - `sorted_token_ids` by expert index and padding ensures divisibility by - BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix - multiplication across different blocks processed by the same expert. 
- """ - - if USE_DOUBLE_QUANT: - # INT4 scale - s_packnums: tl.constexpr = 2 - bzp: tl.constexpr = 32 - w_mask: tl.constexpr = 0x3F - pack_num: tl.constexpr = 4 - real_k_size: tl.constexpr = (BLOCK_SIZE_K - 1) // pack_num + 1 - - pid = tl.program_id(axis=0) - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - compute_type = c_ptr.dtype.element_ty - - num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) - if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: - return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) - - token_mask = offs_token < num_valid_tokens - - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - # offs_k = tl.arange(0, BLOCK_SIZE_K) - offs_bk = tl.arange(0, real_k_size) - - a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_bk[None, :] * pack_num * stride_ak) - - off_experts = tl.load(expert_ids_ptr + pid_m) - b_ptrs = b_ptr + off_experts * stride_be + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - bs_ptrs = bs_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn # group-wise, need advanced - - off_set = off_experts * stride_bce + offs_bn[None, :] * stride_bcn - # load channel-wise scale & zero-point - if USE_DOUBLE_QUANT: - superbs_ptrs = superbs_ptr + off_set # channel-wise - super_bs = tl.load(superbs_ptrs) # super scale - - codebs_ptrs = codebs_ptr + off_set # channel-wise - code_bs = tl.load(codebs_ptrs) # code scale - codebzp_ptrs = codebzp_ptr + off_set # channel-wise - code_bzp = tl.load(codebzp_ptrs) # code zp - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - - b = tl.load(b_ptrs) - - bs = tl.load(bs_ptrs) - if USE_DOUBLE_QUANT: - s_shift_bits = (1 - k % s_packnums) * 4 - bs = ((bs >> s_shift_bits) & 0xF) * super_bs - - # reverse to int16 - b = tl.floor((b.to(tl.float32) * code_bs + code_bzp) + 0.5).to(tl.int16) - # dequant - b1 = (((b >> 9) & w_mask) - bzp) * bs - a = tl.load( - a_ptrs, - mask=token_mask[:, None], - other=0.0, - ) - accumulator += tl.dot(a, b1.to(a.dtype)) - - b1 = (((b >> 6) & w_mask) - bzp) * bs - a = tl.load( - a_ptrs + 1, - mask=token_mask[:, None], - other=0.0, - ) - accumulator += tl.dot(a, b1.to(a.dtype)) - - b1 = (((b >> 3) & w_mask) - bzp) * bs - a = tl.load( - a_ptrs + 2, - mask=token_mask[:, None], - other=0.0, - ) - accumulator += tl.dot(a, b1.to(a.dtype)) - - b = ((b & w_mask) - bzp) * bs - a = tl.load( - a_ptrs + 3, - mask=token_mask[:, None], - other=0.0, - ) - accumulator += tl.dot(a, b.to(a.dtype)) - - b_ptrs += real_k_size * stride_bk - a_ptrs += BLOCK_SIZE_K * stride_ak - - # advance scale ptr - if USE_DOUBLE_QUANT: - bs_ptrs += stride_bsk * (k % s_packnums) - else: - bs_ptrs += stride_bsk - - if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) - accumulator = accumulator * moe_weight[:, None] - - accumulator = accumulator.to(compute_type) - # ----------------------------------------------------------- - # Write back the block of the output - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - 
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] - c_mask = token_mask[:, None] & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - -def fused_moe_wint2_impl( - hidden_states, - up_gate_proj_quant_weight, - down_proj_quant_weight, - topk_weights, - topk_ids, - # inplace: bool = False, - up_gate_proj_weight_scale=None, - down_proj_weight_scale=None, - up_gate_proj_super_scales=None, - down_proj_super_scales=None, - up_gate_proj_code_scale=None, - down_proj_code_scale=None, - up_gate_proj_code_zp=None, - down_proj_code_zp=None, - group_size=64, - bit="wint2", -): - """ - Implementation of Fused MoE kernels on GPU. - """ - # Check constraints. - # A: [M, K] - # B: [E, K, N] - # assert hidden_states.shape[1] == up_gate_proj_weight_scale.shape[1], - # f"Hidden size mismatch, {hidden_states.shape[1]} != {up_gate_proj_quant_weight.shape[1]}" - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert up_gate_proj_quant_weight.is_contiguous(), "Expert weights1 must be contiguous" - assert down_proj_quant_weight.is_contiguous(), "Expert weights2 must be contiguous" - assert group_size > 0, "Group size must be greater than 0" - - num_tokens, K = hidden_states.shape - E, _, N = up_gate_proj_quant_weight.shape - M = num_tokens - - if group_size < 0: - group_size = K // up_gate_proj_weight_scale.shape[1] - - top_k = topk_ids.shape[1] - - intermediate_cache1 = paddle.empty( - [M, top_k, N], - dtype=hidden_states.dtype, - ) - intermediate_cache2 = paddle.empty( - (M * top_k, N // 2), - dtype=hidden_states.dtype, - ) - intermediate_cache3 = paddle.empty( - (M, top_k, K), - dtype=hidden_states.dtype, - ) - - from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess - - sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(topk_ids, E, BLOCK_SIZE_M) - - invoke_fused_moe_kernel( - A=hidden_states, - B=up_gate_proj_quant_weight, - C=intermediate_cache1, - B_scale=up_gate_proj_weight_scale, - B_super_scale=up_gate_proj_super_scales, - B_code_scale=up_gate_proj_code_scale, - B_code_zp=up_gate_proj_code_zp, - topk_weights=topk_weights, - topk_ids=topk_ids, - sorted_token_ids=sorted_token_ids, - expert_ids=expert_ids, - num_tokens_post_padded=num_tokens_post_padded, - mul_routed_weight=False, - top_k=top_k, - group_size=group_size, - ) - - intermediate_cache2 = paddle.incubate.nn.functional.swiglu(intermediate_cache1.reshape([-1, N])) - - invoke_fused_moe_kernel( - A=intermediate_cache2, - B=down_proj_quant_weight, - C=intermediate_cache3, - B_scale=down_proj_weight_scale, - B_super_scale=down_proj_super_scales, - B_code_scale=down_proj_code_scale, - B_code_zp=down_proj_code_zp, - topk_weights=topk_weights, - topk_ids=topk_ids, - sorted_token_ids=sorted_token_ids, - expert_ids=expert_ids, - num_tokens_post_padded=num_tokens_post_padded, - mul_routed_weight=True, - top_k=1, - group_size=group_size, - ) - - out_hidden_states = paddle.sum(intermediate_cache3, axis=1) - return out_hidden_states - - -def fused_moe_wint2_triton( - hidden_states, - up_gate_proj_quant_weight, - down_proj_quant_weight, - scores, - gate_correction_bias, - topk, - up_gate_proj_weight_scale, - down_proj_weight_scale, - up_gate_proj_super_scales, - down_proj_super_scales, - up_gate_proj_code_scale, - down_proj_code_scale, - up_gate_proj_code_zp, - down_proj_code_zp, -): - """ - Fuse MoE with WINT2 quantization scheme and Triton backend. 
-    Args:
-        hidden_states: input tensor.
-        up_gate_proj_quant_weight: up_gate_proj weight matrix for experts.
-        down_proj_quant_weight: down_proj weight matrix for experts.
-        scores: gate scores.
-        gate_correction_bias: bias correction for gates.
-        topk: number of experts to use.
-        up_gate_proj_weight_scale: scaling factor for up_gate_proj_quant_weight.
-        down_proj_weight_scale: scaling factor for down_proj_quant_weight.
-        up_gate_proj_super_scales: super scaling factor for up_gate_proj_scale.
-        down_proj_super_scales: super scaling factor for down_proj_weight_scale.
-        up_gate_proj_code_scale: code scaling factor for up_gate_proj_quant_weight.
-        down_proj_code_scale: code scaling factor for down_proj_quant_weight.
-        up_gate_proj_code_zp: code zero point for up_gate_proj_quant_weight.
-        down_proj_code_zp: code zero point for down_proj_quant_weight.
-    Returns:
-        output tensor.
-    """
-
-    score = gate_correction_bias + scores
-    _, topk_ids = paddle.topk(score, k=topk, axis=-1)
-    topk_weights, _ = paddle.topk(scores, k=topk, axis=-1)
-    topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
-
-    return fused_moe_wint2_impl(
-        hidden_states,
-        up_gate_proj_quant_weight,
-        down_proj_quant_weight,
-        topk_weights,
-        topk_ids,
-        up_gate_proj_weight_scale,
-        down_proj_weight_scale,
-        up_gate_proj_super_scales,
-        down_proj_super_scales,
-        up_gate_proj_code_scale,
-        down_proj_code_scale,
-        up_gate_proj_code_zp,
-        down_proj_code_zp,
-        bit="wint2",
-    )
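For reference, the expert-routing math used by the deleted fused_moe_wint2_triton wrapper, shown on a tiny worked example (illustrative only; shapes and values are hypothetical, paddle is assumed to be installed):

# Token-to-expert routing as in the deleted wrapper: experts are selected with
# bias-corrected scores, while the mixing weights come from the raw scores.
import paddle

scores = paddle.to_tensor([[0.10, 0.40, 0.30, 0.20]])              # [num_tokens, num_experts]
gate_correction_bias = paddle.to_tensor([[0.50, 0.00, 0.00, 0.00]])
topk = 2

corrected = gate_correction_bias + scores                           # used only to pick expert ids
_, topk_ids = paddle.topk(corrected, k=topk, axis=-1)               # -> [[0, 1]]
topk_weights, _ = paddle.topk(scores, k=topk, axis=-1)              # top-k of raw scores (the removed code did not gather by topk_ids)
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)  # -> [[0.571, 0.429]]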