mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Model]support qwen2_5_vl (#3557)
* adapt qwen_2_5_vl model * adapt qwen_2_5_vl VIT model * adapt qwen2_5_vl images_embeds * adapt qwen2_5_vl 3D rope * adapt qwen2_5_vl 3D rope v2 * adapt qwen2_5_vl processor * adapt qwen2_5_vl bypass resampler_model * adapt qwen2_5_vl 绕过部分ernie逻辑 * adapt qwen2_5_vl 绕过部分ernie逻辑 v2 * adapt qwen2_5_vl 权重加载与命名修改 * adapt qwen2_5_vl 非必须think_end_id * adapt qwen2_5_vl 区分多种模型的extract_vision_features * fix:adapt qwen2_5_vl model * adapt qwen2_5_vl norm * adapt qwen2_5_vl processor 更新 * adapt qwen2_5_vl image and video success * adapt qwen2_5_vl 部分整理代码 * adapt qwen2_5_vl 支持多卡 * adapt qwen2_5_vl on latest develop * adapt qwen2_5_vl RL * adapt qwen2_5_vl 整理代码 * support noex rope3d * adapt qwen2_5_vl add init.py * adapt qwen2_5_vl add init.py v2 * adapt qwen2_5_vl remove space * adapt qwen2_5_vl remove space v2 * adapt qwen2_5_vl pre-commit * adapt qwen2_5_vl update * adapt qwen2_5_vl pre-commit v2 * adapt qwen2_5_vl modify comments * adapt qwen2_5_vl fix indentation * adapt qwen2_5_vl fix indentation v2 --------- Co-authored-by: wangyafeng <wangyafeng@baidu.com> Co-authored-by: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Co-authored-by: CSWYF3634076 <58356743+CSWYF3634076@users.noreply.github.com>
This commit is contained in:
@@ -103,7 +103,8 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
# VL model config:
|
||||
if self.enable_mm:
|
||||
self._init_image_preprocess()
|
||||
if "ernie" in self.fd_config.model_config.model_type:
|
||||
self._init_image_preprocess()
|
||||
|
||||
self.amp_black = [
|
||||
"reduce_sum",
|
||||
@@ -242,7 +243,8 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
dtype=paddle.int64,
|
||||
)
|
||||
vision_inputs["images"] = paddle.to_tensor(
|
||||
inputs["images"][request.image_start : request.image_end], dtype="uint8"
|
||||
inputs["images"][request.image_start : request.image_end],
|
||||
dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16",
|
||||
)
|
||||
vision_inputs["grid_thw"] = paddle.to_tensor(
|
||||
inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
|
||||
@@ -797,6 +799,11 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
if self.enable_mm:
|
||||
head_dim = self.model_config.head_dim
|
||||
if "qwen" in self.model_config.model_type: # neox style = True
|
||||
rope_head_dim = head_dim
|
||||
else: # neox style = False
|
||||
rope_head_dim = head_dim // 2
|
||||
|
||||
self.share_inputs["rope_emb"] = paddle.full(
|
||||
shape=[
|
||||
max_num_seqs,
|
||||
@@ -804,14 +811,16 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
1,
|
||||
self.parallel_config.max_model_len,
|
||||
1,
|
||||
head_dim // 2,
|
||||
rope_head_dim,
|
||||
],
|
||||
fill_value=0,
|
||||
dtype="float32",
|
||||
)
|
||||
self.share_inputs["image_features"] = None
|
||||
self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
|
||||
self.share_inputs["enable_thinking"] = paddle.full(shape=[1], fill_value=True, dtype="bool")
|
||||
self.share_inputs["enable_thinking"] = paddle.full(
|
||||
shape=[1], fill_value=("ernie" in self.model_config.model_type), dtype="bool"
|
||||
)
|
||||
self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
|
||||
|
||||
def _prepare_inputs(self) -> None:
|
||||
@@ -1186,7 +1195,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
|
||||
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
|
||||
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
|
||||
think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
|
||||
think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
|
||||
need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
|
||||
reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
|
||||
stop_token_ids=self.share_inputs["stop_seqs"],
|
||||
@@ -1476,7 +1485,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
|
||||
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
|
||||
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
|
||||
think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
|
||||
think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
|
||||
need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
|
||||
reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
|
||||
stop_token_ids=self.share_inputs["stop_seqs"],
|
||||
@@ -1720,7 +1729,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
image_type_ids = one["image_type_ids"][np.newaxis, :]
|
||||
images = one["images"]
|
||||
image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
|
||||
images = paddle.to_tensor(images, dtype="uint8")
|
||||
images = paddle.to_tensor(images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16")
|
||||
grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64")
|
||||
else:
|
||||
image_type_ids = None
|
||||
@@ -1742,12 +1751,10 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
)
|
||||
return result
|
||||
|
||||
@paddle.no_grad()
|
||||
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
|
||||
"""extract_vision_features"""
|
||||
def extract_vision_features_ernie(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
|
||||
assert inputs["images"] is not None
|
||||
grid_thw = inputs["grid_thw"]
|
||||
|
||||
# ernie-vl has images norm
|
||||
images = inputs["images"].cast("float32")
|
||||
images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor
|
||||
images = images / self.image_preprocess.image_std_tensor
|
||||
@@ -1772,6 +1779,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2])
|
||||
image_features = ScatterOp.apply(image_features, axis=-1) # mp 切 Fea
|
||||
image_features = image_features.reshape([S, -1])
|
||||
# ernie-vl has resampler_model
|
||||
image_features = self.model.resampler_model(
|
||||
image_features,
|
||||
image_mask,
|
||||
@@ -1781,6 +1789,31 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
)
|
||||
return image_features
|
||||
|
||||
def extract_vision_features_qwen(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
    """Compute image features for Qwen-style models via the model's visual module.

    Args:
        inputs: Container holding the vision inputs. It is indexed with the
            string keys ``"images"`` and ``"grid_thw"`` (NOTE(review): the
            annotation says ``list`` but the access pattern is mapping-like;
            confirm against the caller).

    Returns:
        Whatever ``self.model.visual.extract_feature(images, grid_thw)``
        produces, evaluated under Paddle automatic mixed precision.

    Raises:
        AssertionError: If ``inputs["images"]`` is ``None``.
    """
    pixel_values = inputs["images"]
    assert pixel_values is not None
    thw = inputs["grid_thw"]

    # Run the visual module under AMP level O2 using the runner's op
    # black/white lists and the dtype taken from the parallel config.
    autocast_ctx = paddle.amp.auto_cast(
        enable=True,
        custom_white_list=self.amp_white,
        custom_black_list=self.amp_black,
        level="O2",
        dtype=self.parallel_config.dtype,
    )
    with autocast_ctx:
        features = self.model.visual.extract_feature(pixel_values, thw)

    return features
|
||||
|
||||
@paddle.no_grad()
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
    """Dispatch vision feature extraction based on the configured model type.

    The model type string is checked for the substring ``"ernie"`` first and
    ``"qwen"`` second; the matching family-specific extractor is called with
    ``inputs`` unchanged and its result is returned.

    Args:
        inputs: Vision inputs forwarded as-is to the selected extractor.

    Returns:
        The value returned by ``extract_vision_features_ernie`` or
        ``extract_vision_features_qwen``.

    Raises:
        ValueError: If the model type contains neither ``"ernie"`` nor
            ``"qwen"``.
    """
    model_type = self.model_config.model_type

    # "ernie" takes precedence over "qwen", matching the original branch order.
    if "ernie" in model_type:
        return self.extract_vision_features_ernie(inputs)
    if "qwen" in model_type:
        return self.extract_vision_features_qwen(inputs)

    raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
|
||||
|
||||
@paddle.no_grad()
|
||||
def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Tensor:
|
||||
"""prepare_rope3d"""
|
||||
@@ -1800,5 +1833,6 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
base=self.model_config.rope_theta,
|
||||
max_position=self.parallel_config.max_model_len,
|
||||
freq_allocation=getattr(self.model_config, "freq_allocation", 20),
|
||||
model_type=self.model_config.model_type,
|
||||
)
|
||||
return rope_emb
|
||||
|
Reference in New Issue
Block a user