[Sync] Update to latest code (#2679)

* [Sync] Update to latest code

* Add new code files

* Add new code files

* update code

* Try to fix build.sh

* Try to fix build.sh

* Update code

* Update requirements.txt

* Update code

---------

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
Jiang-Jia-Jun authored on 2025-07-03 15:43:53 +08:00, committed by GitHub
parent d222248d00
commit 05c670e593
95 changed files with 9916 additions and 1312 deletions


@@ -16,10 +16,12 @@
import json
import os
import random
import argparse
from typing import Any
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
from paddleformers.transformers.model_utils import load_tp_checkpoint
from safetensors import safe_open
from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
@@ -38,11 +40,13 @@ from fastdeploy.model_executor.models.ernie4_5_vl.dfnrope.modeling import \
DFNRopeVisionTransformerPretrainedModel
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import (
ScatterOp, VariableResolutionResamplerModel)
from fastdeploy.model_executor.models.utils import load_checkpoint
from fastdeploy.platforms import current_platform
from fastdeploy.worker.forward_meta import ForwardMeta
from fastdeploy.worker.utils import check_safetensors_model
from fastdeploy.worker.vl_model_runner_base import VLModelRunnerBase
from fastdeploy.config import (DeviceConfig, FDConfig, KVCacheConfig,
LoadConfig, ModelConfig, MoEConfig,
MoEPhase, ParallelConfig, SpeculativeConfig)
if current_platform.is_cuda() and current_platform.available():
from fastdeploy.model_executor.layers.utils import (
@@ -55,8 +59,20 @@ from fastdeploy.model_executor.ops.gpu import (save_output,
class GPUVLModelRunner(VLModelRunnerBase):
"""
The GPUVLModelRunner class for vision-language tasks on GPU.
"""
def __init__(self, config, args, nranks, rank):
def __init__(
self,
config: ModelConfig,
args: argparse.Namespace,
nranks: int,
rank: int,
) -> None:
"""
GPUVLModelRunner init
"""
self.nranks = nranks
self.rank = rank
@@ -104,14 +120,11 @@ class GPUVLModelRunner(VLModelRunnerBase):
self.sampler = Sampler()
def _reset_paddle_env(self):
#FLAGS_gqa_use_tensorcore
#FLAGS_ffn2_use_hardamard
# Set GQA and related Paddle flags here
pass
def update_chunked_prefill(self, tasks):
def update_chunked_prefill(self, tasks: list[Any]) -> None:
"""
Update chunked prefill related parameters
Update chunked prefill
"""
if not self.args.enable_chunked_prefill:
return
@@ -135,7 +148,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
"image_features"] = self.extract_vision_features(
inputs)
else:
# Handle requests that have no images or videos
# Compatible with inputs that lack images and videos
self.share_inputs["image_features"] = None
token_chunk_size = inputs["input_ids"].shape[1]
@@ -152,7 +165,14 @@ class GPUVLModelRunner(VLModelRunnerBase):
task.start_idx += token_chunk_size
task.chunk_idx += 1
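As an aside, a minimal sketch of the chunked-prefill bookkeeping this loop performs (the Task class and its defaults here are hypothetical; only start_idx and chunk_idx mirror the diff): each step consumes up to token_chunk_size prompt tokens until the whole prompt is prefilled.

```python
from dataclasses import dataclass, field

@dataclass
class Task:
    prompt_len: int       # total prompt tokens to prefill
    start_idx: int = 0    # tokens prefilled so far
    chunk_idx: int = 0    # chunks processed so far
    chunks: list = field(default_factory=list)

def prefill_step(task: Task, token_chunk_size: int) -> bool:
    """Prefill one chunk; return True while more prompt remains."""
    size = min(token_chunk_size, task.prompt_len - task.start_idx)
    task.chunks.append((task.start_idx, task.start_idx + size))
    task.start_idx += size
    task.chunk_idx += 1
    return task.start_idx < task.prompt_len

task = Task(prompt_len=10)
while prefill_step(task, token_chunk_size=4):
    pass
print(task.chunks)  # [(0, 4), (4, 8), (8, 10)]
```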
def _load_model(self, model_name, dynamic_load_weight):
def _load_model(
self,
model_name: str,
dynamic_load_weight: int = 0,
) -> None:
"""
Load the model from the given model name.
"""
vocab_file_names = [
"tokenizer.model", "spm.model", "ernie_token_100k.model"
@@ -261,7 +281,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
fd_config.parallel_config.max_model_len = fd_config.model_config.max_seq_len
self.fd_config = fd_config
attn_backend_cls = get_attention_backend(self.args.attention_backend)
attn_backend_cls = get_attention_backend()
num_heads = self.fd_config.model_config.num_attention_heads // \
self.fd_config.parallel_config.tensor_parallel_degree
self.fd_config.model_config.kv_num_heads = int(
@@ -275,7 +295,10 @@ class GPUVLModelRunner(VLModelRunnerBase):
head_dim=head_dim)
self._init_kvcache()
def init_extra_input(self, config, args):
def init_extra_input(self, config: ModelConfig, args: argparse.Namespace) -> None:
"""
Initialize extra input tensors.
"""
head_dim = self.model_cfg.head_dim
self.share_inputs.update({
"rope_emb":
@@ -287,29 +310,31 @@ class GPUVLModelRunner(VLModelRunnerBase):
})
self.share_inputs.update({"image_features": None})
self.share_inputs.update({
"need_think_end": paddle.full(shape=[
args.max_num_seqs, 1],
fill_value=0,
dtype="int32")
"need_think_end":
paddle.full(shape=[args.max_num_seqs, 1],
fill_value=0,
dtype="int32")
})
self.share_inputs.update({
"enable_thinking": paddle.full(shape=[1],
fill_value=True,
dtype="bool")
"enable_thinking":
paddle.full(shape=[1], fill_value=True, dtype="bool")
})
self.share_inputs.update({
"reasoning_index": paddle.full(shape=[
args.max_num_seqs, 1],
fill_value=0,
dtype="int32")
"reasoning_index":
paddle.full(shape=[args.max_num_seqs, 1],
fill_value=0,
dtype="int32")
})
def init_rotary_position_embedding(self, max_model_len):
def init_rotary_position_embedding(self, max_model_len: int) -> None:
"""
Init rotary position embedding
"""
pass
def _init_kvcache(self):
"""
Share data without copying
Init kv cache
"""
cache_kvs = {}
total_block_num = self.num_gpu_blocks
@@ -352,7 +377,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
del value
paddle.device.cuda.empty_cache()
def clear_parameters(self, pid):
def clear_parameters(self, pid: int) -> None:
""" clear_parameters """
if "caches" in self.share_inputs:
self.model.clear_parameters(pid)
@@ -360,7 +385,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
paddle.device.cuda.empty_cache()
self.model.log_memory_usage("clear all memory")
def update_parameters(self, pid):
def update_parameters(self, pid: int) -> None:
""" update_parameters """
if "caches" not in self.share_inputs:
self.model.update_parameters(pid)
@@ -368,7 +393,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
self.model.log_memory_usage("update all memory")
@paddle.no_grad()
def set_state_dict(self, args):
def set_state_dict(self, args: argparse.Namespace) -> None:
"""set_state_dict"""
if not self.is_safetensors_model:
rank_model_paths = []
@@ -401,7 +426,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
self.model.set_state_dict(state_dict)
self.resampler_model.set_state_dict(resampler_state)
else:
state_dict = load_checkpoint(
state_dict = load_tp_checkpoint(
args.model_name_or_path,
Ernie4_5_PretrainedModel,
self.model_cfg,
@@ -414,10 +439,14 @@ class GPUVLModelRunner(VLModelRunnerBase):
self.model.set_state_dict(state_dict)
@paddle.no_grad()
def vit_load(self, model_path, tensor_parallel_degree,
tensor_parallel_rank):
def vit_load(
self,
model_path: str,
tensor_parallel_degree: int,
tensor_parallel_rank: int,
) -> None:
"""
Load ViT tensor-parallel parameters
Load vit tp weights
"""
if tensor_parallel_degree == 1:
rank_model_path = os.path.join(model_path, "model_state.pdparams")
@@ -430,15 +459,18 @@ class GPUVLModelRunner(VLModelRunnerBase):
raise ValueError(f"No such a file {rank_model_path}")
@paddle.no_grad()
def inject_pp_vision_model(self, args, cfg):
def inject_pp_vision_model(self, args: argparse.Namespace, cfg: Ernie4_5_VLMoeConfig):
"""
Inject vision model parameters
Inject pp vision model
"""
def set_vision_state_dict(model,
tensor_parallel_degree=8,
tensor_parallel_rank=0,
name=""):
tensor_parallel_degree: int=8,
tensor_parallel_rank: int=0,
name: str=""):
"""
Set vision model weight
"""
model_state_dict = model.state_dict()
compat_keys = [name + k for k in model_state_dict.keys()]
model_files = set()
@@ -543,7 +575,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
return vision_model, resampler_model
@paddle.no_grad()
def extract_vision_features(self, inputs):
def extract_vision_features(self, inputs: dict[str, paddle.Tensor]) -> paddle.Tensor:
"""extract_vision_features"""
assert inputs["images"] is not None
grid_thw = inputs["grid_thw"]
@@ -585,7 +617,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
return image_features
@paddle.no_grad()
def prepare_rope3d(self, position_ids, **kwargs):
def prepare_rope3d(self, position_ids: paddle.Tensor, **kwargs) -> paddle.Tensor:
"""prepare_rope3d"""
prefix_max_position_ids = paddle.max(position_ids) + 1
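prepare_rope3d's body falls mostly outside this hunk. As background only, a generic 1-D NeoX-style rotary table is sketched below; the runner presumably extends this to 3-D (time/height/width) position ids for vision inputs, which this sketch does not attempt.

```python
import numpy as np

def rope_table(max_pos: int, head_dim: int, base: float = 10000.0):
    """cos/sin lookup tables for rotary position embedding."""
    inv_freq = 1.0 / base ** (np.arange(0, head_dim, 2) / head_dim)
    angles = np.outer(np.arange(max_pos), inv_freq)  # [max_pos, head_dim // 2]
    return np.cos(angles), np.sin(angles)

cos, sin = rope_table(max_pos=16, head_dim=8)
print(cos.shape, sin.shape)  # (16, 4) (16, 4)
```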
@@ -608,13 +640,13 @@ class GPUVLModelRunner(VLModelRunnerBase):
def prefill_finished(self):
"""
Check whether the prefill operation has finished
Verify prefill operation completion
"""
prefill_status = (self.share_inputs["seq_lens_this_time"] != 0) & (
self.share_inputs["seq_lens_this_time"] != 1)
return not paddle.any(prefill_status).numpy()
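The check above relies on a slot convention the surrounding code implies (assumed here, not stated in the diff): seq_lens_this_time is 0 for an idle slot, 1 for a sequence in decode, and greater than 1 while a prompt chunk is still prefilling. The same test in plain numpy:

```python
import numpy as np

def prefill_finished(seq_lens_this_time: np.ndarray) -> bool:
    prefilling = (seq_lens_this_time != 0) & (seq_lens_this_time != 1)
    return not bool(np.any(prefilling))

print(prefill_finished(np.array([0, 1, 1])))    # True: only idle/decode slots
print(prefill_finished(np.array([0, 1, 128])))  # False: one slot still prefilling
```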
def dy_input_preprocess(self, tasks):
def dy_input_preprocess(self, tasks: list[Any]) -> None:
"""
dynamic insertion
"""
@@ -662,7 +694,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
"image_features"] = self.extract_vision_features(
inputs)
else:
# Handle requests that have no images or videos
# Compatible with inputs that lack images and videos
self.share_inputs["image_features"] = None
if task.multimodal_inputs["position_ids"] is not None:
position_ids = paddle.to_tensor(
@@ -688,7 +720,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
"image_features"] = self.extract_vision_features(
inputs)
else:
# Handle requests that have no images or videos
# Compatible with inputs that lack images and videos
self.share_inputs["image_features"] = None
position_ids = inputs["position_ids"]
@@ -702,10 +734,11 @@ class GPUVLModelRunner(VLModelRunnerBase):
# force </think>
self.share_inputs["enable_thinking"][:] = kwargs["enable_thinking"]
self.share_inputs["need_think_end"][idx:idx +
1, :] = 1 if kwargs["enable_thinking"] else 0
self.share_inputs["need_think_end"][
idx:idx + 1, :] = 1 if kwargs["enable_thinking"] else 0
self.share_inputs["reasoning_index"][idx:idx + 1, :] = kwargs["reasoning_max_tokens"]
self.share_inputs["reasoning_index"][
idx:idx + 1, :] = kwargs["reasoning_max_tokens"]
self.share_inputs["rope_emb"][idx:idx +
1, :] = self.prepare_rope3d(
@@ -737,7 +770,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
idx:idx + 1, :encoder_block_num] = np.array(task.block_tables,
dtype="int32")
def pre_process(self):
def pre_process(self) -> None:
"""
pre_process
"""
@@ -794,7 +827,10 @@ class GPUVLModelRunner(VLModelRunnerBase):
eos_token_ids=self.share_inputs["eos_token_id"],
)
def generate(self):
def generate(self) -> None:
"""
generate
"""
self.pre_process()
hidden_states = self.model(self.share_inputs["ids_remove_padding"],
self.share_inputs["image_features"],
@@ -815,7 +851,10 @@ class GPUVLModelRunner(VLModelRunnerBase):
paddle.distributed.broadcast(next_tokens, 0)
self.post_process(next_tokens)
def post_process(self, next_tokens):
def post_process(self, next_tokens: paddle.Tensor) -> None:
"""
post_process
"""
if self.share_inputs["enable_thinking"]:
exists_think_end = next_tokens == self.model_cfg.think_end_id
paddle.assign(
@@ -823,37 +862,28 @@ class GPUVLModelRunner(VLModelRunnerBase):
exists_think_end,
self.share_inputs["need_think_end"] - 1,
self.share_inputs["need_think_end"],
),
self.share_inputs["need_think_end"]
)
), self.share_inputs["need_think_end"])
paddle.assign(
paddle.where(
self.share_inputs["need_think_end"].cast("bool"),
self.share_inputs["reasoning_index"] - 1,
self.share_inputs["reasoning_index"],
),
self.share_inputs["reasoning_index"]
)
), self.share_inputs["reasoning_index"])
stop_wo_think = (
(
next_tokens == self.share_inputs["eos_token_id"]
) | (
self.share_inputs["reasoning_index"] == 0
)
) & (
self.share_inputs["need_think_end"] > 0
)
next_tokens = paddle.where(stop_wo_think, self.model_cfg.think_end_id, next_tokens)
(next_tokens == self.share_inputs["eos_token_id"]) |
(self.share_inputs["reasoning_index"] == 0)) & (
self.share_inputs["need_think_end"] > 0)
next_tokens = paddle.where(stop_wo_think,
self.model_cfg.think_end_id,
next_tokens)
paddle.assign(
paddle.where(
stop_wo_think,
self.share_inputs["need_think_end"] - 1,
self.share_inputs["need_think_end"],
),
self.share_inputs["need_think_end"]
)
), self.share_inputs["need_think_end"])
paddle.assign(
paddle.where(
self.share_inputs["stop_flags"],
@@ -899,14 +929,13 @@ class GPUVLModelRunner(VLModelRunnerBase):
def _cal_theortical_kvcache(self):
"""
Compute the theoretical KV cache size
Calculate the theoretical size of the KV cache
"""
num_layers = self.model_cfg.get("num_layers",
None) or self.model_cfg.get(
"num_hidden_layers", None)
byte_of_cache = 2
# TODO
# support c8 and c4 quantized kv cache
hidden_dim = self.model_cfg.head_dim * self.model_cfg.kv_num_head
theoretical_kv_cache_memory = (2 * byte_of_cache *
@@ -915,6 +944,9 @@ class GPUVLModelRunner(VLModelRunnerBase):
return theoretical_kv_cache_memory
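The returned figure follows the usual back-of-envelope KV cache formula; since the full expression is cut off by the hunk, here is a sketch in the spirit of _cal_theortical_kvcache: two caches (K and V) × bytes per element × layers × cached tokens × per-token KV width.

```python
def kv_cache_bytes(num_layers, total_block_num, block_size,
                   head_dim, kv_num_heads, byte_of_cache=2):
    hidden_dim = head_dim * kv_num_heads
    # Leading 2 accounts for the separate K and V caches.
    return 2 * byte_of_cache * num_layers * total_block_num * block_size * hidden_dim

# e.g. 32 layers, 1000 blocks of 64 tokens, 8 KV heads of dim 128, fp16:
print(kv_cache_bytes(32, 1000, 64, 128, 8) / 2**30, "GiB")  # ~7.8 GiB
```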
def _update_share_input_block_num(self):
"""
Update share_inputs['block_tables'] and share_inputs['free_list']
"""
num_gpu_blocks = self.num_gpu_blocks
del self.share_inputs["caches"]
@@ -924,7 +956,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
self.share_inputs["block_tables"] = paddle.full(
[self.args.max_num_seqs, num_gpu_blocks], -1, dtype="int32")
# Initialize the free list
# Init free list
free_list = list(
range(num_gpu_blocks - 1,
int(num_gpu_blocks * self.args.kv_cache_ratio) - 1, -1))
@@ -936,7 +968,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
paddle.full([1], self.free_list_len, dtype="int32"),
})
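For concreteness, the free list built above hands out block ids from the top of the pool downward, leaving roughly a (1 - kv_cache_ratio) share free; a toy run with made-up numbers:

```python
num_gpu_blocks, kv_cache_ratio = 10, 0.7
free_list = list(range(num_gpu_blocks - 1,
                       int(num_gpu_blocks * kv_cache_ratio) - 1, -1))
print(free_list)       # [9, 8, 7]
print(len(free_list))  # 3 == num_gpu_blocks * (1 - kv_cache_ratio)
```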
def dummy_input(self, num_total_tokens, number_of_tasks):
def dummy_input(self, num_total_tokens: int, number_of_tasks: int) -> None:
"""
Fake inputs used for profiling
"""
@@ -974,7 +1006,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(idx * block_num, \
(idx + 1) * block_num, 1)
def _preprocess_task(self, one):
def _preprocess_task(self, one: dict) -> None:
"""process batch"""
input_ids = one["input_ids"][np.newaxis, :]
@@ -1012,13 +1044,13 @@ class GPUVLModelRunner(VLModelRunnerBase):
def build_stream_line_model(
model_path,
dtype,
block_size,
max_model_len,
tokenizer,
model_path: str,
dtype: str,
block_size: int,
max_model_len: int,
tokenizer: ErnieBotTokenizer,
quantization: str = "None",
):
) -> tuple[FDConfig, paddle.nn.Layer]:
"""
build model
"""
@@ -1028,9 +1060,6 @@ def build_stream_line_model(
from paddleformers.trl import llm_utils
from paddleformers.utils.log import logger
from fastdeploy.config import (DeviceConfig, FDConfig, KVCacheConfig,
LoadConfig, ModelConfig, MoEConfig,
MoEPhase, ParallelConfig, SpeculativeConfig)
from fastdeploy.model_executor.layers.quantization import \
get_quantization_config
from fastdeploy.model_executor.models.model_base import ModelRegistry