""" # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ from __future__ import annotations import contextlib import json import os import sys import threading import paddle import paddle.distributed as dist from paddle.common_ops_import import convert_dtype from fastdeploy.model_executor.models.utils import convert_ndarray_dtype from paddlenlp.trainer import RuntimeTimer from fastdeploy.inference_args import GenerationPhase from .utils import ( _vocab_size_with_padding, generate_rank_mapping, get_infer_model_path, model_convert_fp8, ) from paddlenlp.transformers import AutoTokenizer from paddle.distributed import fleet from paddlenlp.utils.env import USE_FAST_TOKENIZER from paddlenlp.utils.log import logger from fastdeploy.model_executor.models.utils import load_checkpoint from fastdeploy.config import (AdditionalConfig, DecodingConfig, DeviceConfig, LLMConfig, LoadConfig, ModelConfig, MoEConfig, ParallelConfig, SpeculativeConfig, TmpConfig) from fastdeploy.inference_args import GenerationPhase from ..layers.quantization import get_quantization_config from .model_base import ModelRegistry from .qwen2 import Qwen2PretrainedModel from .utils import (_vocab_size_with_padding, convert_ndarray_dtype, load_checkpoint, parser_quant_type) from paddlenlp.transformers.configuration_utils import PretrainedConfig from paddlenlp.trl import llm_utils model_classes_mapping = { "Qwen2ForCausalLM": Qwen2PretrainedModel, } current_dir = os.path.dirname(os.path.abspath(__file__)) grandparent_dir = os.path.abspath( os.path.join(current_dir, os.pardir, os.pardir)) sys.path.append(grandparent_dir) def offload_model(model): """ Offload the model to CUDAPinnedPlace. """ device = paddle.CUDAPinnedPlace() for name, src in model.named_parameters(): if src._is_initialized() and not isinstance(src.place, paddle.CUDAPinnedPlace): dst = src._copy_to(device, True) dst_tensor = dst.value().get_tensor() src_tensor = src.value().get_tensor() src_tensor._clear() src_tensor._share_data_with(dst_tensor) def reload_model(model): """ Reload the model from CUDAPinnedPlace to GPU. """ model.to(paddle.device.get_device()) def reconstruct_memory(model): """ reconstruct_memory to avoid memory chunks """ offload_model(model) paddle.device.cuda.empty_cache() reload_model(model) def load_tensor_from_ipc_meta(state_dict): """ convert ipc_meta to tensor, but keep keys unchanged { 'key': ipc_meta } --> { 'key': tensor } example: state_dict = load_tensor_from_ipc_meta(state_dict) """ for k, v in state_dict.items(): # for pickling, we have to convert bytes object before save v[0] = v[0].encode("latin-1") state_dict[k] = paddle.to_tensor( paddle.base.core.LoDTensor._new_shared_cuda(tuple(v))) return state_dict def build_stream_line_model( config_path, model_path, dtype, block_size, max_len, stage_flag, min_dec_len=1, max_dec_len=128, temperature=1, top_k=8, top_p=0.8, pre_caches_length=0, export_model_type="default", use_stop_seqs=False, use_fake_parameter=False, show_topk: int = 0, msg_queue_id=None, pad_vocab=True, tokenizer=None, cache_quant_dtype="default", use_beam_search: bool = False, enf_gen: bool = False, speculate_method=None, speculate_max_draft_token_num: int = 1, speculate_max_candidate_len: int = 5, speculate_verify_window: int = 2, return_all_hidden_states: bool = False, draft_type: str = "None", start_layer_index: int = 0, moe_quant_type: str = "default", use_ep: bool = False, ep_just_for_test: bool = False, generation_phase: GenerationPhase = GenerationPhase.PREFILL, use_micro_batch: bool = False, fake_server_p: bool = False, scale_dir: str = "None", output_via_mq: bool = True, use_safetensors: bool = False, enable_redundant_experts: bool = False, redundant_experts_num: int = 0, max_batch_size: int = 128, use_offline_quant: bool = False, return_state_dicts: bool = False, sharing_model=None, sharing_state_dicts=None, ): """ Build a fused inference model Args: config_path (str): Path to the configuration file model_path (str): Path to the model file dtype (str): Data type of the model block_size (int): Block size max_len (int): Maximum sequence length stage_flag (str): Qianfan requirement, stage flag, used to identify different stages in \ time-consuming statistics logs, such as prediction ("msgid-1 predict") or export ("convert"). min_dec_len (int, optional): Minimum decoding length. Default is 1. max_dec_len (int, optional): Maximum decoding length. Default is 128. temperature (float, optional): Temperature coefficient. Default is 1. top_k (int, optional): k value in top-k sampling. Default is 0. top_p (float, optional): p value in top-p sampling. Default is 0.8. pre_caches_length (int, optional): Pre-cache length. Default is 0. export_model_type (str, optional): Type of model to export. Default is "default". use_stop_seqs (bool, optional): Whether to use stop sequences. Default is False. use_fake_parameter (bool, optional): Whether to use fake parameters. Default is False. show_topk (int, optional): Whether to show top-k results. Default is 0. msg_queue_id (int, optional): Message queue ID. Default is None. pad_vocab (bool, optional): Whether to pad the vocabulary. Default is True. cache_quant_dtype (str, optional): Cache quantization data type. Default is "default". use_beam_search (bool, optional): Whether to use beam search . Defaults is False. enf_gen (bool, optional): Whether to use enforce generation. Defaults is False. Returns: tuple[dict, Tokenizer, CausalLM]: A tuple containing the configuration, tokenizer, and model. """ runtime_timer = RuntimeTimer("build_model") runtime_timer.start(f"{stage_flag} stage model loading time") # config_path = os.path.join(model_path,"config.json") with open(config_path, "r") as fin: config = json.load(fin) architectures = config.get("architectures") if tokenizer is None: tokenizer = AutoTokenizer.from_pretrained( model_path, padding_side="left", use_fast=USE_FAST_TOKENIZER, ) config, _ = PretrainedConfig.get_config_dict(model_path) model_config = ModelConfig.from_dict(config) parallel_config = ParallelConfig() speculative_config = SpeculativeConfig() device_config = DeviceConfig() additional_config = AdditionalConfig() load_config = LoadConfig() tmp_config = TmpConfig() moe_config = MoEConfig() decoding_config = DecodingConfig() tensor_parallel_rank, tensor_parallel_degree = llm_utils.init_dist_env() parallel_config.tensor_parallel_rank = tensor_parallel_rank parallel_config.tensor_parallel_degree = tensor_parallel_degree parallel_config.mp_size = tensor_parallel_degree parallel_config.ep_size = 1 parallel_config.column_cut = False speculative_config.is_mtp = draft_type in ["eagle", "mtp"] speculative_config.draft_type = draft_type # Note(tangbinhan): used for load_checkpoint model_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank model_config.tensor_parallel_degree = parallel_config.tensor_parallel_degree model_config.use_ep = use_ep model_config.is_mtp = speculative_config.is_mtp additional_config.use_fake_parameter = use_fake_parameter additional_config.ep_just_for_test = ep_just_for_test tmp_config.use_offline_quant = use_offline_quant if use_ep: if isinstance(model_config.moe_num_experts, list): model_config.has_multimodality = True moe_config.num_experts = model_config.moe_num_experts[0] else: moe_config.num_experts = model_config.moe_num_experts moe_config.num_experts_per_rank = ( moe_config.num_experts // parallel_config.tensor_parallel_degree ) moe_config.num_experts_start_offset = ( moe_config.num_experts_per_rank * parallel_config.tensor_parallel_rank ) # use the length of tokenizer as the origin vocab size ori_vocab_size = len(tokenizer) moe_intermediate_size = (config.get("moe_intermediate_size", None),) if isinstance(moe_intermediate_size, list) or isinstance( moe_intermediate_size, tuple ): moe_intermediate_size = moe_intermediate_size[0] if not use_ep and pad_vocab: config["vocab_size"] = _vocab_size_with_padding( config.get("vocab_size", tokenizer.vocab_size), config.pop("vocab_size_divisible_unit", 128), paddle.distributed.get_world_size(), ) group_size = config.get("group_size", -1) num_key_value_heads = config.get("num_key_value_heads", -1) if num_key_value_heads is None: num_key_value_heads = -1 if config.get("ffn_hidden_size", None) is not None: ffn_hidden_size = config["ffn_hidden_size"] elif config.get("intermediate_size", None) is not None: ffn_hidden_size = config["intermediate_size"] else: ffn_hidden_size = 4 * config["hidden_size"] if config["hidden_act"].lower() == "swiglu": if paddle.distributed.get_world_size() > 1: multiple_of = 8 * config["num_attention_heads"] else: multiple_of = 4 * config["num_attention_heads"] ffn_hidden_size = multiple_of * ( (int(2 * ffn_hidden_size / 3) + multiple_of - 1) // multiple_of) if draft_type in ["mtp", "eagle"]: num_layers = 1 else: num_layers = config.get("num_layers", None) or config.get( "num_hidden_layers", None ) if num_layers is None: raise ValueError(f"num_layers<{num_layers}> is invalid") use_moe = config.get( "moe_layer_start_index", num_layers ) < num_layers or draft_type in ["mtp", "eagle"] if not sharing_state_dicts: if use_fake_parameter: context = contextlib.nullcontext() elif use_safetensors: context = paddle.LazyGuard() model_class = model_classes_mapping[architectures[0]] state_dict = load_checkpoint(model_path, model_class, model_config, return_numpy=True) elif use_moe: tensor_parallel_degree = dist.get_world_size() if tensor_parallel_degree > 1: hcg = fleet.get_hybrid_communicate_group() mp_id = hcg.get_model_parallel_rank() # 统计文件子目录数量 subdir_count = 0 for entry in os.listdir(model_path): if "pp" in entry: full_path = os.path.join(model_path, entry) if os.path.isdir(full_path): subdir_count += 1 pp_num = subdir_count rank_model_paths = [ os.path.join(model_path, f"pp{i}/model_state.tp0{mp_id}.pdparams") for i in range(pp_num) ] context = paddle.LazyGuard() if not use_ep: logger.info(f"start to loading weight: {rank_model_paths}") state_dicts = [None for _ in rank_model_paths] def load_ckpt(i): state_dicts[i] = paddle.load(rank_model_paths[i], return_numpy=True) threads = [] for i in range(len(rank_model_paths)): thread = threading.Thread(target=load_ckpt, args=(i,)) threads.append(thread) thread.start() for t in threads: t.join() logger.info("Loading finished") else: # for EP loading state_dicts import glob state_dicts = [] files = glob.glob(model_path + "/merged_tp1_state_split/*") for file_name in files: try: state_dicts += [ {file_name.split("/")[-1]: file_name} ] # save {layer_name: weight_file_name} except Exception: pass need_reset_moe_intermediate_size = False if not use_ep: logger.info(f"moe_intermediate_size is: {moe_intermediate_size}") need_reset_moe_intermediate_size = ( (not use_ep) and (moe_quant_type == "fp8") and (moe_intermediate_size // 8 % 128 != 0) ) ori_up_size = moe_intermediate_size // 8 * 2 ori_down_size = ori_up_size // 2 if need_reset_moe_intermediate_size: moe_intermediate_size = ( 128 - moe_intermediate_size // 8 % 128 ) * 8 + moe_intermediate_size logger.info( f"moe_intermediate_size reset to {moe_intermediate_size}!" ) up_size = moe_intermediate_size // 8 * 2 down_size = up_size // 2 new_state_dict = {} def padding(key, value): import numpy as np # logger.info(f"deal {key}") if ("experts" in key) and ("up_gate_proj" in key): # logger.info("up_gate_proj") v_new = np.zeros(shape=[value.shape[0], up_size], dtype=value.dtype) v_new[:, :ori_down_size] = value[:, :ori_down_size] v_new[:, down_size : (down_size + ori_down_size)] = value[ :, ori_down_size: ] elif ("experts" in key) and ("down_proj" in key): # logger.info("down_proj") v_new = np.zeros( shape=[down_size, value.shape[1]], dtype=value.dtype ) v_new[:ori_down_size, :] = value else: v_new = value new_state_dict[key] = v_new if ("experts" in key) and ("up_gate_proj" in key or "down_proj" in key): pass # logger.info(f"padding {key}: {value.shape}->{v_new.shape}") threads = [] for state_dict in state_dicts: for key, value in state_dict.items(): if need_reset_moe_intermediate_size: thread = threading.Thread(target=padding, args=(key, value)) threads.append(thread) thread.start() else: new_state_dict[key] = value for t in threads: t.join() logger.info("Finish padding") state_dict = new_state_dict elif config.get("quant_type", None) is not None: # TODO(@wangbojun) currently, we use paddle.load for ptq model. tensor_parallel_degree = dist.get_world_size() if tensor_parallel_degree > 1: hcg = fleet.get_hybrid_communicate_group() mp_id = hcg.get_model_parallel_rank() rank_model_path = os.path.join( model_path, f"model_state.tp0{mp_id}.pdparams" ) if not os.path.exists(rank_model_path): full_model_path = os.path.join(model_path, "model_state.pdparams") if not os.path.exists(full_model_path): raise ValueError( f"can not find " + f"and model_state.pdparams under dir<{model_path}>" ) raise ValueError( "please run `split_weights.py` to gen weights for multi-gpu inference." ) if not os.path.exists(rank_model_path): full_model_path = os.path.join(model_path, "model_state.pdparams") if not os.path.exists(full_model_path): raise ValueError( f"can not find " + f"and model_state.pdparams under dir<{model_path}>" ) raise ValueError( "please run `split_weights.py` to gen weights for multi-gpu inference." ) model_state_path = rank_model_path if num_key_value_heads > 0: assert ( num_key_value_heads % tensor_parallel_degree == 0 ), "num_key_value_heads must be an integer multiple of tensor_parallel_degree" else: model_state_path = os.path.join(model_path, "model_state.pdparams") context = paddle.LazyGuard() logger.info(f"start to loading weight: {model_state_path}") if os.path.exists(model_state_path): state_dict = paddle.load(model_state_path, return_numpy=True) else: state_dict = sharing_state_dicts context = paddle.LazyGuard() use_rmsnorm = config.get("use_rmsnorm", True) if use_beam_search: decode_strategy = "beam_search" elif speculate_method is not None: if draft_type in ["draft_model", "eagle", "mtp"]: decode_strategy = "draft_model_sampling" else: decode_strategy = "speculate_decoding" else: decode_strategy = "sampling" logger.info(f"{runtime_timer.log()}") runtime_timer.start(f"{stage_flag} stage set parameters time") if config["hidden_act"].lower() == "swiglu": model_config.hidden_act = "swiglu" model_config.ffn_hidden_size = ffn_hidden_size model_config.max_seq_len = max_len model_config.num_layers = num_layers model_config.dtype = dtype model_config.export_model_type = export_model_type parallel_config.block_size = block_size model_config.group_size = group_size load_config.model_path = model_path model_config.use_rmsnorm = use_rmsnorm parallel_config.msg_queue_id = msg_queue_id additional_config.use_fake_parameter = use_fake_parameter model_config.num_key_value_heads = num_key_value_heads model_config.use_stop_seqs = use_stop_seqs tmp_config.cache_quant_dtype = cache_quant_dtype tmp_config.has_zero_point = config.get("has_zero_point", False) tmp_config.is_channel_wise = config.get("is_channel_wise", False), speculative_config.speculate_method = speculate_method speculative_config.speculate_max_draft_token_num = speculate_max_draft_token_num model_config.return_all_hidden_states = return_all_hidden_states speculative_config.draft_type = draft_type model_config.start_layer_index = start_layer_index model_config.use_moe = use_moe if use_moe: moe_config.use_moe = use_moe moe_config.num_experts = config.get("moe_num_experts", None) moe_config.moe_intermediate_size = config.get("moe_intermediate_size", None) moe_config.moe_use_gate_correction_bias = config.get( "moe_use_gate_correction_bias", True) moe_config.moe_every2 = config.get("moe_every2", False) moe_config.moe_topk = config.get("moe_topk", 8) moe_config.moe_num_shared_experts = config.get("moe_num_shared_experts", 0) moe_config.moe_layer_start_index = config.get("moe_layer_start_index", 0) moe_config.moe_use_ffn_shared_weight_and_bias = config.get( "moe_use_ffn_shared_weight_and_bias", False) moe_config.use_moe = use_moe moe_config.moe_group = config.get("moe_group", False) moe_config.moe_quant_type = moe_quant_type if top_k > 0: moe_config.top_k = top_k parallel_config.use_ep = use_ep additional_config.ep_just_for_test = ep_just_for_test model_config.generation_phase = generation_phase parallel_config.use_micro_batch = use_micro_batch tmp_config.weight_block_size = config.get("weight_block_size", [-1, -1]) load_config.scale_dir = scale_dir model_config.output_via_mq = output_via_mq decoding_config.bos_token_id = tokenizer.bos_token_id decoding_config.pad_token_id = tokenizer.pad_token_id decoding_config.temperature = temperature decoding_config.forced_eos_token_id = tokenizer.eos_token_id model_config.ori_vocab_size = ori_vocab_size decoding_config.max_dec_len = max_dec_len decoding_config.min_dec_len = min_dec_len additional_config.fake_server_p = fake_server_p decoding_config.decode_strategy = decode_strategy speculative_config.speculate_max_candidate_len = speculate_max_candidate_len speculative_config.speculate_verify_window = speculate_verify_window weight_dtype, act_dtype, cachekv_dtype = parser_quant_type( export_model_type) logger.info( f"quant_type: weight[{weight_dtype}], act[{act_dtype}], cachekv[{cachekv_dtype}]" ) model_config.weight_dtype = weight_dtype model_config.act_dtype = act_dtype if weight_dtype == "int8" and act_dtype in ["bfloat16", "float16"]: quant_cls = get_quantization_config("weight_only") quant_config = quant_cls.from_config({ "weight_only_linear_arch": None, "algo": "weight_only_int8" }) quant_config.quant_max_bound = 0 quant_config.quant_min_bound = 0 quant_config.quant_round_type = 0 model_config.use_smooth_quant = False elif weight_dtype == "int4" and act_dtype in ["bfloat16", "float16"]: quant_cls = get_quantization_config("weight_only") quant_config = quant_cls.from_config({ "weight_only_linear_arch": None, "algo": "weight_only_int4" }) quant_config.quant_max_bound = 0 quant_config.quant_min_bound = 0 quant_config.quant_round_type = 0 model_config.use_smooth_quant = False elif tmp_config.weight_block_size[0] != -1: quant_cls = get_quantization_config("block_wise") quant_config = quant_cls.from_config( {"weight_block_size": tmp_config.weight_block_size}) quant_config.quant_max_bound = 448 quant_config.quant_min_bound = -448 quant_config.quant_round_type = 1 model_config.use_smooth_quant = False elif weight_dtype == "int4" and act_dtype == "float8_e4m3fn": quant_cls = get_quantization_config("w4afp8") quant_config = quant_cls.from_config({ "weight_scale_dict": {}, "act_scale_dict": {} }) quant_config.quant_max_bound = 448 quant_config.quant_min_bound = -448 quant_config.quant_round_type = 1 model_config.use_smooth_quant = False elif weight_dtype == "int8" and act_dtype == weight_dtype: quant_cls = get_quantization_config("w8a8") quant_config = quant_cls.from_config({ "weight_scale_dict": {}, "act_scale_dict": {}, "use_gemm_dequant": False }) quant_config.quant_max_bound = 127 quant_config.quant_min_bound = -127 quant_config.quant_round_type = 0 model_config.use_smooth_quant = True elif weight_dtype == "float8_e4m3fn" and act_dtype == weight_dtype: quant_cls = get_quantization_config("wfp8afp8") quant_config = quant_cls.from_config({ "weight_scale_dict": {}, "act_scale_dict": {} }) quant_config.quant_max_bound = 448 quant_config.quant_min_bound = -448 quant_config.quant_round_type = 1 model_config.use_smooth_quant = False else: quant_config = None llm_config = LLMConfig( model_config=model_config, parallel_config=parallel_config, speculative_config=speculative_config, device_config=device_config, additional_config=additional_config, load_config=load_config, tmp_config=tmp_config, moe_config=moe_config, decoding_config=decoding_config, quant_config=quant_config, ) with context: model_cls = ModelRegistry.get_class(model_config.architectures[0]) model = model_cls(llm_config) model.eval() if use_fake_parameter: return config, tokenizer, model elif not use_moe: for k, v in state_dict.items(): if convert_dtype(v.dtype) == dtype: continue elif convert_dtype(v.dtype) == "float32": continue state_dict[k] = convert_ndarray_dtype(v, dtype) paddle.device.cuda.empty_cache() assert state_dict is not None model.set_state_dict(state_dict) if use_ep and generation_phase == GenerationPhase.DECODER: logger.info("Reloading model...") reconstruct_memory(model) logger.info(f"{runtime_timer.log()}") if sharing_state_dicts is not None: for k in list(sharing_state_dicts): sharing_state_dicts.pop(k) possible_state_dict = state_dict if return_state_dicts else None return config, tokenizer, model, possible_state_dict