mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[vl]remove duplicated load logic (#2744)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
@@ -14,39 +14,26 @@
 # limitations under the License.
 """
 import argparse
 import json
 import os
 import random
-from typing import Optional
 
 import numpy as np
 import paddle
 import paddle.distributed.fleet as fleet
-from paddleformers.transformers.model_utils import load_tp_checkpoint
-from safetensors import safe_open
 
-from fastdeploy.config import (DeviceConfig, FDConfig, GraphOptimizationConfig,
-                               KVCacheConfig, LoadConfig, ModelConfig,
-                               MoEConfig, MoEPhase, ParallelConfig,
-                               SpeculativeConfig)
+from fastdeploy.config import ModelConfig
 from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
 from fastdeploy.input.mm_processor import DataProcessor
-from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.layers.attention import get_attention_backend
 from fastdeploy.model_executor.layers.rotary_embedding import get_rope_3d
 from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.sampler import Sampler
-from fastdeploy.model_executor.models.ernie4_5_moe import \
-    Ernie4_5_PretrainedModel
 from fastdeploy.model_executor.models.ernie4_5_vl.configuration import \
     Ernie4_5_VLMoeConfig
-from fastdeploy.model_executor.models.ernie4_5_vl.dfnrope import \
-    DFNRopeVisionTransformerConfig
-from fastdeploy.model_executor.models.ernie4_5_vl.dfnrope.modeling import \
-    DFNRopeVisionTransformerPretrainedModel
-from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import (
-    ScatterOp, VariableResolutionResamplerModel)
+from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import \
+    ScatterOp
 from fastdeploy.platforms import current_platform
+from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.worker.output import SamplerOutput
-from fastdeploy.worker.utils import check_safetensors_model
 from fastdeploy.worker.vl_model_runner_base import VLModelRunnerBase
@@ -169,6 +156,34 @@ class GPUVLModelRunner(VLModelRunnerBase):
             task.start_idx += token_chunk_size
             task.chunk_idx += 1
 
+    def _init_image_preprocess(self, vision_config):
+        processor = DataProcessor(
+            tokenizer_name=self.args.tokenizer,
+            image_preprocessor_name=str(self.args.image_preprocessor),
+        )
+        processor.eval()
+        image_preprocess = processor.image_preprocessor
+        image_preprocess.image_mean_tensor = paddle.to_tensor(
+            image_preprocess.image_mean, dtype="float32"
+        ).reshape([1, 3, 1, 1])
+        image_preprocess.image_std_tensor = paddle.to_tensor(
+            image_preprocess.image_std, dtype="float32"
+        ).reshape([1, 3, 1, 1])
+        image_preprocess.rescale_factor = paddle.to_tensor(
+            image_preprocess.rescale_factor, dtype="float32"
+        )
+        image_preprocess.image_mean_tensor = (
+            image_preprocess.image_mean_tensor.squeeze(
+                [-2, -1]
+            ).repeat_interleave(vision_config.patch_size**2 * 1, -1)
+        )
+        image_preprocess.image_std_tensor = (
+            image_preprocess.image_std_tensor.squeeze(
+                [-2, -1]
+            ).repeat_interleave(vision_config.patch_size**2 * 1, -1)
+        )
+        return image_preprocess
+
     def _load_model(
         self,
         model_name: str,
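Note: the reshape/repeat_interleave sequence above prepares the per-channel mean/std for two layouts. A minimal sketch of the shape arithmetic, with placeholder values (the patch size and mean values are illustrative, not taken from the model config):

import numpy as np

patch_size = 14                                           # hypothetical ViT patch size
image_mean = np.array([0.5, 0.5, 0.5], dtype="float32")   # per-channel mean, placeholder

# [1, 3, 1, 1] broadcasts against NCHW image batches.
mean_nchw = image_mean.reshape([1, 3, 1, 1])

# Squeezing the trailing axes and repeating each channel value patch_size**2
# times yields a vector that lines up with a flattened patch, whose values
# are grouped channel by channel.
mean_flat = np.repeat(image_mean, patch_size**2)
assert mean_flat.shape == (3 * patch_size**2,)            # 588 values for a 14x14 patch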
@@ -198,98 +213,41 @@ class GPUVLModelRunner(VLModelRunnerBase):
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.unk_token
 
-        config = Ernie4_5_VLMoeConfig.from_pretrained(
-            self.args.llm_model_name_or_path,
-            tensor_parallel_degree=self.tensor_parallel_degree,
-            tensor_parallel_rank=self.tensor_parallel_rank,
-            moe_group="dummy",
-        )
-        self.model_cfg = config
-        if self.is_safetensors_model:
-            meta_json = os.path.join(self.args.model_name_or_path,
-                                     "model.safetensors.index.json")
-            if os.path.exists(meta_json):
-                with open(
-                        os.path.join(self.args.model_name_or_path,
-                                     "model.safetensors.index.json"),
-                        "r") as f:
-                    self.weight_map = json.load(f)["weight_map"]
-            else:
-                self.weight_map = {}
-                with safe_open(os.path.join(self.args.model_name_or_path,
-                                            "model.safetensors"),
-                               framework="np") as f:
-                    keys = f.keys()
-                    for k in keys:
-                        self.weight_map[k] = "model.safetensors"
-
-        if self.is_safetensors_model:
-            vision_config = config.vision_config
-            vision_config.tensor_parallel_degree = self.tensor_parallel_degree
-            vision_config.tensor_parallel_rank = self.tensor_parallel_rank
-            vision_config.attn_sep = False
-            vision_config.dtype = "bfloat16"
-        else:
-            vision_config = DFNRopeVisionTransformerConfig.from_pretrained(
-                self.args.vision_model_name_or_path,
-                tensor_parallel_degree=self.tensor_parallel_degree,
-                tensor_parallel_rank=self.tensor_parallel_rank,
-                attn_sep=False,
-                dtype="bfloat16",
-            )
-        config.vision_config = vision_config
-        self.vision_config = vision_config
-        config.pixel_hidden_size = config.vision_config.hidden_size
-        config.im_patch_id = tokenizer.get_vocab()["<|IMAGE_PLACEHOLDER|>"]
-        config.think_end_id = tokenizer.get_vocab()["</think>"]
-        config.max_text_id = config.im_patch_id
-
-        config.sequence_parallel = False
 
         self.dtype = self.args.dtype
         paddle.set_default_dtype(self.dtype)
 
-        self.vision_model, self.resampler_model = self.inject_pp_vision_model(
-            self.args, config)
-
-        processor = DataProcessor(
-            tokenizer_name=self.args.tokenizer,
-            image_preprocessor_name=str(self.args.image_preprocessor),
-        )
-        processor.eval()
-        image_preprocess = processor.image_preprocessor
-        image_preprocess.image_mean_tensor = paddle.to_tensor(
-            image_preprocess.image_mean, dtype="float32").reshape([1, 3, 1, 1])
-        image_preprocess.image_std_tensor = paddle.to_tensor(
-            image_preprocess.image_std, dtype="float32").reshape([1, 3, 1, 1])
-        image_preprocess.rescale_factor = paddle.to_tensor(
-            image_preprocess.rescale_factor, dtype="float32")
-        image_preprocess.image_mean_tensor = image_preprocess.image_mean_tensor.squeeze(
-            [-2, -1]).repeat_interleave(config.vision_config.patch_size**2 * 1,
-                                        -1)
-        image_preprocess.image_std_tensor = image_preprocess.image_std_tensor.squeeze(
-            [-2, -1]).repeat_interleave(config.vision_config.patch_size**2 * 1,
-                                        -1)
-        self.image_preprocess = image_preprocess
-
-        graph_opt_config = GraphOptimizationConfig(
-            self.args.enable_static_graph_inference, self.args.use_cudagraph,
-            self.args.max_capture_batch_size)
-
-        fd_config, self.model = build_stream_line_model(
-            self.args.model_name_or_path,
-            self.args.dtype,
-            self.args.block_size,
-            max_model_len=self.args.max_model_len,
-            tokenizer=tokenizer,
-            quantization=self.args.quantization,
-            graph_opt_config=graph_opt_config,
-        )
-        self.model.eval()
-        self.set_state_dict(self.args)
+        from fastdeploy.worker.worker_process import initialize_fd_config
+
+        fd_config = initialize_fd_config(
+            self.args, self.tensor_parallel_degree, self.tensor_parallel_rank
+        )
+        fd_config.model_config = Ernie4_5_VLMoeConfig(
+            **fd_config.model_config.__dict__
+        )
+        fd_config.parallel_config.max_model_len = fd_config.model_config.max_seq_len
+        fd_config.parallel_config.column_cut = False
+        vision_config = fd_config.model_config.vision_config
+        vision_config.attn_sep = False
+        vision_config.dtype = "bfloat16"
+        vision_config.tensor_parallel_degree = self.tensor_parallel_degree
+        vision_config.tensor_parallel_rank = self.tensor_parallel_rank
+        fd_config.model_config.pixel_hidden_size = vision_config.hidden_size
+        fd_config.model_config.im_patch_id = tokenizer.get_vocab()[
+            "<|IMAGE_PLACEHOLDER|>"
+        ]
+        fd_config.model_config.think_end_id = tokenizer.get_vocab()["</think>"]
+        fd_config.model_config.max_text_id = fd_config.model_config.im_patch_id
+        fd_config.model_config.sequence_parallel = False
+        # TODO (bukejiyu): Remove the assignment
+        fd_config.moe_config.top_k = 8
+        self.fd_config = fd_config
+        self.model_cfg = self.fd_config.model_config
+        self.image_preprocess = self._init_image_preprocess(
+            self.fd_config.model_config.vision_config
+        )
+        from fastdeploy.model_executor.model_loader import \
+            get_model_from_loader
+
+        self.model = get_model_from_loader(self.fd_config)
 
         attn_backend_cls = get_attention_backend()
         num_heads = self.fd_config.model_config.num_attention_heads // \
             self.fd_config.parallel_config.tensor_parallel_degree
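This hunk is the heart of the commit: the runner no longer hand-assembles configs and weights but delegates to the shared loader. A minimal sketch of the resulting flow, reusing the two entry points named in the diff (the wrapper function itself is hypothetical, not part of the codebase):

from fastdeploy.worker.worker_process import initialize_fd_config
from fastdeploy.model_executor.model_loader import get_model_from_loader

def load_vl_model(args, tp_degree: int, tp_rank: int):
    # One FDConfig shared by all runners; replaces build_stream_line_model.
    fd_config = initialize_fd_config(args, tp_degree, tp_rank)
    # The loader owns checkpoint discovery, TP sharding and set_state_dict.
    model = get_model_from_loader(fd_config)
    return fd_config, model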
@@ -401,188 +359,6 @@ class GPUVLModelRunner(VLModelRunnerBase):
         self._init_kvcache()
         self.model.log_memory_usage("update all memory")
 
-    @paddle.no_grad()
-    def set_state_dict(self, args: argparse.Namespace) -> None:
-        """set_state_dict"""
-        if not self.is_safetensors_model:
-            rank_model_paths = []
-            for root, dirs, files in os.walk(self.args.llm_model_name_or_path):
-                for file in files:
-                    if file == f"model_state.tp0{self.tensor_parallel_rank}.pdparams":
-                        rank_model_paths.append(os.path.join(root, file))
-                    elif file == "model_state.pdparams":
-                        rank_model_paths.append(os.path.join(root, file))
-            state_dict = {}
-            for path in rank_model_paths:
-                loaded_dict = paddle.load(path, return_numpy=True)
-                state_dict.update(loaded_dict)
-
-            resampler_state = {}
-            for key in list(state_dict.keys()):
-                if "vision" in key:
-                    state_dict.pop(key)
-                if key.startswith("ernie.resampler_model."):
-                    value = state_dict.pop(key)
-                    value = paddle.to_tensor(value).cast("bfloat16")
-                    value = value.numpy()
-                    resampler_state[
-                        key[len("ernie.resampler_model."):]] = value
-                elif key.startswith("resampler_model."):
-                    value = state_dict.pop(key)
-                    value = paddle.to_tensor(value).cast("bfloat16")
-                    value = value.numpy()
-                    resampler_state[key[len("resampler_model."):]] = value
-            self.model.set_state_dict(state_dict)
-            self.resampler_model.set_state_dict(resampler_state)
-        else:
-            state_dict = load_tp_checkpoint(
-                args.model_name_or_path,
-                Ernie4_5_PretrainedModel,
-                self.model_cfg,
-                return_numpy=True,
-            )
-            for key in list(state_dict.keys()):
-                if key.startswith("vision_model.") or key.startswith(
-                        "ernie.resampler_model."):
-                    state_dict.pop(key)
-            self.model.set_state_dict(state_dict)
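The prefix stripping deleted above ("ernie.resampler_model." / "resampler_model.") is part of the duplicated load logic this commit removes; the model loader now handles it. For reference, a standalone sketch of that re-keying idiom (the helper name and keys are hypothetical, written for illustration only):

def split_resampler_state(state_dict: dict) -> tuple[dict, dict]:
    """Drop vision weights; re-key resampler weights by stripping their prefix."""
    resampler_state = {}
    for key in list(state_dict.keys()):
        if "vision" in key:
            state_dict.pop(key)
            continue
        for prefix in ("ernie.resampler_model.", "resampler_model."):
            if key.startswith(prefix):
                resampler_state[key[len(prefix):]] = state_dict.pop(key)
                break
    return state_dict, resampler_state

# llm_state, resampler_state = split_resampler_state(paddle.load(path, return_numpy=True))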
-
-    @paddle.no_grad()
-    def vit_load(
-        self,
-        model_path: str,
-        tensor_parallel_degree: int,
-        tensor_parallel_rank: int,
-    ) -> None:
-        """
-        Load vit tp weight
-        """
-        if tensor_parallel_degree == 1:
-            rank_model_path = os.path.join(model_path, "model_state.pdparams")
-        else:
-            rank_model_path = os.path.join(
-                model_path, f"model_state_tp0{tensor_parallel_rank}.pdparams")
-        if os.path.exists(rank_model_path):
-            return paddle.load(rank_model_path, return_numpy=True)
-        else:
-            raise ValueError(f"No such a file {rank_model_path}")
-
-    @paddle.no_grad()
-    def inject_pp_vision_model(self, args: argparse.Namespace, cfg: Ernie4_5_VLMoeConfig):
-        """
-        Inject pp vision model
-        """
-
-        def set_vision_state_dict(model,
-                                  tensor_parallel_degree: int = 8,
-                                  tensor_parallel_rank: int = 0,
-                                  name: str = ""):
-            """
-            Set vision model weight
-            """
-            model_state_dict = model.state_dict()
-            compat_keys = [name + k for k in model_state_dict.keys()]
-            model_files = set()
-            for k in compat_keys:
-                if k in self.weight_map.keys():
-                    model_files.add(
-                        os.path.join(args.model_name_or_path,
-                                     self.weight_map[k]))
-            state_dict = {}
-            for model_file in model_files:
-                with safe_open(model_file, framework="np") as f:
-                    for k in f.keys():
-                        if k in compat_keys:
-                            new_k = k.replace(name, "")
-                            tensor = f.get_tensor(k)
-                            if tensor_parallel_degree > 1:
-                                if "resampler_model" in name and new_k == "spatial_linear.0.weight":
-                                    tensor = np.split(
-                                        tensor, tensor_parallel_degree,
-                                        axis=0)[tensor_parallel_rank]
-                                elif name == "vision_model.":
-                                    if "attn.proj.weight" in new_k or "fc2.weight" in new_k:
-                                        tensor = np.split(
-                                            tensor,
-                                            tensor_parallel_degree,
-                                            axis=0)[tensor_parallel_rank]
-                                    elif "fc1.weight" in new_k or "fc1.bias" in new_k:
-                                        tensor = np.split(
-                                            tensor,
-                                            tensor_parallel_degree,
-                                            axis=-1)[tensor_parallel_rank]
-                                    elif "qkv.weight" in new_k:
-                                        head_dim = self.vision_config.hidden_size // self.vision_config.num_heads
-                                        tensor = tensor.reshape([
-                                            self.vision_config.hidden_size, 3,
-                                            self.vision_config.num_heads,
-                                            head_dim
-                                        ])
-                                        tensor = np.split(
-                                            tensor,
-                                            tensor_parallel_degree,
-                                            axis=-2
-                                        )[tensor_parallel_rank].reshape([
-                                            self.vision_config.hidden_size, -1
-                                        ])
-                                    elif "qkv.bias" in new_k:
-                                        head_dim = self.vision_config.hidden_size // self.vision_config.num_heads
-                                        tensor = tensor.reshape([
-                                            3, self.vision_config.num_heads,
-                                            head_dim
-                                        ])
-                                        tensor = np.split(
-                                            tensor,
-                                            tensor_parallel_degree,
-                                            axis=-2
-                                        )[tensor_parallel_rank].reshape([-1])
-                            state_dict[new_k] = tensor
-            model.set_state_dict(state_dict)
-
-        vision_model = DFNRopeVisionTransformerPretrainedModel(
-            cfg.vision_config)
-        vision_model = paddle.amp.decorate(models=vision_model,
-                                           level="O2",
-                                           dtype="bfloat16")
-        vision_model.eval()
-        if not self.is_safetensors_model:
-            vit_state_dict = self.vit_load(args.vision_model_name_or_path,
-                                           self.tensor_parallel_degree,
-                                           self.tensor_parallel_rank)
-            vision_model.set_state_dict(vit_state_dict)
-        else:
-            set_vision_state_dict(
-                vision_model,
-                tensor_parallel_degree=self.tensor_parallel_degree,
-                tensor_parallel_rank=self.tensor_parallel_rank,
-                name="vision_model.",
-            )
-
-        resampler_model = VariableResolutionResamplerModel(
-            cfg.pixel_hidden_size,
-            cfg.hidden_size,
-            cfg.spatial_conv_size,
-            cfg.temporal_conv_size,
-            config=cfg,
-        )
-        resampler_model = paddle.amp.decorate(models=resampler_model,
-                                              level="O2",
-                                              dtype="bfloat16")
-        resampler_model.eval()
-        if self.is_safetensors_model:
-            is_ernie_begin = False
-            for k in self.weight_map.keys():
-                if k.startswith("ernie.resampler_model."):
-                    is_ernie_begin = True
-            set_vision_state_dict(
-                resampler_model,
-                tensor_parallel_degree=self.tensor_parallel_degree,
-                tensor_parallel_rank=self.tensor_parallel_rank,
-                name="ernie.resampler_model."
-                if is_ernie_begin else "resampler_model.",
-            )
-        return vision_model, resampler_model
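The deleted set_vision_state_dict above shards fused qkv weights by first reshaping to expose the head axis, splitting on it, then re-fusing. A self-contained numpy sketch of that split (toy sizes; hidden_size=8, num_heads=4, tp=2 are illustrative, not model values):

import numpy as np

hidden_size, num_heads, tp_degree, tp_rank = 8, 4, 2, 1
head_dim = hidden_size // num_heads

qkv = np.arange(hidden_size * 3 * hidden_size, dtype="float32").reshape(
    [hidden_size, 3 * hidden_size])              # fused [in, 3 * out] projection

# Expose (3, heads, head_dim), split on the head axis, then re-fuse:
shard = np.split(
    qkv.reshape([hidden_size, 3, num_heads, head_dim]),
    tp_degree, axis=-2,
)[tp_rank].reshape([hidden_size, -1])

# Each rank keeps q, k and v for its own heads only.
assert shard.shape == (hidden_size, 3 * hidden_size // tp_degree)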
 
     @paddle.no_grad()
     def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
         """extract_vision_features"""
@@ -607,7 +383,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
                 level="O2",
                 dtype=self.dtype,
         ):
-            image_features = self.vision_model.extract_feature(
+            image_features = self.model.vision_model.extract_feature(
                 images, grid_thw)
         if self.tensor_parallel_degree > 1:
             S, C = image_features.shape
@@ -616,7 +392,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
             image_features = ScatterOp.apply(image_features,
                                              axis=-1)  # split features across mp ranks
             image_features = image_features.reshape([S, -1])
-        image_features = self.resampler_model(
+        image_features = self.model.resampler_model(
             image_features,
             image_mask,
             token_type_ids_w_video,
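ScatterOp here splits the feature (hidden) dimension across model-parallel ranks before the resampler. A hedged numpy illustration of that layout (ScatterOp's real implementation lives in modeling_resampler; this only mimics the assumed shape contract, with toy sizes):

import numpy as np

S, C, tp_degree, tp_rank = 4, 6, 2, 0           # toy sizes, not model values
feats = np.random.rand(S, C).astype("float32")

# Each rank keeps a contiguous slice of the last (feature) axis,
# then flattens back to [S, -1] for the resampler input.
local = np.split(feats, tp_degree, axis=-1)[tp_rank]
local = local.reshape([S, -1])                  # [4, 3] on each of the 2 ranks
assert local.shape == (S, C // tp_degree)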
@@ -1074,195 +850,3 @@ class GPUVLModelRunner(VLModelRunnerBase):
             images=images,
         )
         return result
-
-
-def build_stream_line_model(
-    model_path: str,
-    dtype: str,
-    block_size: int,
-    max_model_len: int,
-    tokenizer: ErnieBotTokenizer,
-    quantization: str = "None",
-    graph_opt_config: Optional[GraphOptimizationConfig] = None
-) -> tuple[FDConfig, paddle.nn.Layer]:
-    """
-    build model
-    """
-    import contextlib
-
-    from paddleformers.transformers.configuration_utils import PretrainedConfig
-    from paddleformers.trl import llm_utils
-    from paddleformers.utils.log import logger
-
-    from fastdeploy.model_executor.layers.quantization import \
-        get_quantization_config
-    from fastdeploy.model_executor.models.model_base import ModelRegistry
-
-    config, _ = PretrainedConfig.get_config_dict(model_path)
-    config["head_dim"] = config.get(
-        "head_dim", config["hidden_size"] // config["num_attention_heads"])
-    config["rope_theta"] = config.get("rope_theta", 10000.0)
-    rope_theta = config["rope_theta"]
-    model_config = ModelConfig.from_dict(config)
-    model_config.head_dim = config["head_dim"]
-
-    parallel_config = ParallelConfig()
-    speculative_config = SpeculativeConfig()
-    device_config = DeviceConfig()
-    load_config = LoadConfig()
-    moe_config = MoEConfig()
-    kv_cache_config = KVCacheConfig()
-    kv_cache_config.cache_quant_dtype = "none"
-
-    tensor_parallel_rank, tensor_parallel_degree = llm_utils.init_dist_env()
-    parallel_config.tensor_parallel_rank = tensor_parallel_rank
-    parallel_config.tensor_parallel_degree = tensor_parallel_degree
-    parallel_config.expert_parallel_degree = 1
-    parallel_config.expert_parallel_rank = int(tensor_parallel_rank /
-                                               tensor_parallel_degree)
-    parallel_config.column_cut = False
-
-    speculative_config.is_mtp = False
-    speculative_config.draft_type = "None"
-
-    # Note(tangbinhan): used for load_checkpoint
-    model_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
-    model_config.tensor_parallel_degree = parallel_config.tensor_parallel_degree
-    model_config.is_mtp = speculative_config.is_mtp
-    moe_config.num_experts = None
-
-    # use the length of tokenizer as the origin vocab size
-    ori_vocab_size = len(tokenizer)
-    moe_intermediate_size = (config.get("moe_intermediate_size", None), )
-    if isinstance(moe_intermediate_size, list) or isinstance(
-            moe_intermediate_size, tuple):
-        moe_intermediate_size = moe_intermediate_size[0]
-
-    num_key_value_heads = config.get("num_key_value_heads", -1)
-    if num_key_value_heads is None:
-        num_key_value_heads = -1
-
-    # Needed for RL: some models have fewer num_key_value_heads than
-    # tensor_parallel_degree, so replicate the kv heads across ranks.
-    if num_key_value_heads < tensor_parallel_degree:
-        logger.warning(
-            f"key value heads num is {num_key_value_heads}, tensor parallel degree is {tensor_parallel_degree}"
-        )
-        num_key_value_heads = tensor_parallel_degree
-
-    if config.get("ffn_hidden_size", None) is not None:
-        ffn_hidden_size = config["ffn_hidden_size"]
-    elif config.get("intermediate_size", None) is not None:
-        ffn_hidden_size = config["intermediate_size"]
-    else:
-        ffn_hidden_size = 4 * config["hidden_size"]
-    if config["hidden_act"].lower() == "swiglu":
-        if paddle.distributed.get_world_size() > 1:
-            multiple_of = 8 * config["num_attention_heads"]
-        else:
-            multiple_of = 4 * config["num_attention_heads"]
-        ffn_hidden_size = multiple_of * (
-            (int(2 * ffn_hidden_size / 3) + multiple_of - 1) //
-            multiple_of)
-
-    num_layers = config.get("num_layers", None) or config.get(
-        "num_hidden_layers", None)
-    if num_layers is None:
-        raise ValueError(f"num_layers<{num_layers}> is invalid")
-
-    remove_tail_layer = config.get("remove_tail_layer")
-    if remove_tail_layer is True:
-        num_layers -= 1
-    elif isinstance(remove_tail_layer, int):
-        num_layers -= remove_tail_layer
-
-    moe_num_experts = config.get("moe_num_experts", 0)
-    if isinstance(moe_num_experts, list):
-        moe_num_experts = max(moe_num_experts)
-    use_moe = moe_num_experts > 0
-
-    context = contextlib.nullcontext()
-
-    if config["hidden_act"].lower() == "swiglu":
-        model_config.hidden_act = "swiglu"
-    model_config.ffn_hidden_size = ffn_hidden_size
-    model_config.max_seq_len = max_model_len
-    model_config.num_layers = num_layers
-    model_config.dtype = dtype
-    parallel_config.block_size = block_size
-
-    parallel_config.msg_queue_id = None
-    model_config.num_key_value_heads = num_key_value_heads
-    model_config.return_all_hidden_states = False
-    speculative_config.draft_type = "None"
-    model_config.start_layer_index = 0
-    if use_moe:
-        moe_config.num_experts = config.get("moe_num_experts", None)
-        moe_config.moe_intermediate_size = config.get("moe_intermediate_size",
-                                                      None)
-        moe_config.top_k = config.get("moe_topk", 8)
-        moe_config.moe_num_shared_experts = config.get(
-            "moe_num_shared_experts", 0)
-        moe_config.moe_layer_start_index = config.get("moe_layer_start_index",
-                                                      None)
-        moe_config.moe_layer_end_index = config.get("moe_layer_end_index",
-                                                    None)
-
-    model_config.moe_phase = MoEPhase.PREFILL
-    model_config.ori_vocab_size = ori_vocab_size
-
-    quantization_config = config.get("quantization_config", None)
-
-    quant_config_name = None
-    if quantization_config is not None and quantization_config.get(
-            "quantization", None) is None:
-        raise ValueError(
-            "quantization_config should have a key named 'quantization' to specify the quant config."
-        )
-
-    if quantization_config is not None:
-        quant_config_name = quantization_config["quantization"]
-        quant_cls = get_quantization_config(quant_config_name)
-        quant_config = quant_cls.from_config(quantization_config)
-    elif quantization != "None":
-        quantization_config = {}
-        if use_moe and quantization == "wint4":
-            quantization_config["dense_quant_type"] = "wint8"
-            quantization_config["moe_quant_type"] = "wint4"
-            quant_config_name = "mix_quant"
-        else:
-            quant_config_name = quantization
-        quant_cls = get_quantization_config(quant_config_name)
-        quant_config = quant_cls.from_config(quantization_config)
-    else:
-        quant_config = None
-
-    logger.info("===========quantization_config==============")
-    if quant_config is not None:
-        logger.info(f"{quantization_config}")
-    else:
-        logger.info(
-            "No quantization config found; using original weight and act dtype."
-        )
-    logger.info("============================================")
-
-    fd_config = FDConfig(
-        model_config=model_config,
-        parallel_config=parallel_config,
-        speculative_config=speculative_config,
-        device_config=device_config,
-        load_config=load_config,
-        moe_config=moe_config,
-        quant_config=quant_config,
-        kv_cache_config=kv_cache_config,
-        graph_opt_config=graph_opt_config,
-    )
-    fd_config.parallel_config.max_model_len = max_model_len
-    fd_config.model_config.rope_theta = rope_theta
-
-    with context:
-        model_cls = ModelRegistry.get_class(model_config.architectures[0])
-        model = model_cls(fd_config)
-
-    model.eval()
-    return fd_config, model
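One step in the deleted builder deserves a worked example: the swiglu branch shrinks the FFN width to 2/3 and rounds it up to a multiple of `multiple_of`. A quick check of that arithmetic as a standalone function (the helper name and the numbers are illustrative, not from a real config):

def round_ffn_for_swiglu(ffn_hidden_size: int, num_attention_heads: int,
                         world_size: int = 1) -> int:
    # Mirrors the deleted logic: 2/3 width, rounded up to a hardware-friendly multiple.
    multiple_of = (8 if world_size > 1 else 4) * num_attention_heads
    return multiple_of * ((int(2 * ffn_hidden_size / 3) + multiple_of - 1) // multiple_of)

# e.g. ffn=16384, heads=32, single card: 2/3 * 16384 = 10922 -> rounded up to 11008
assert round_ffn_for_swiglu(16384, 32) == 11008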