diff --git a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py index 3e48ab81f..0073ea3b8 100644 --- a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py @@ -83,7 +83,9 @@ class XPUAttentionBackend(AttentionBackend): self.rope_theta: float = ( 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta ) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) + self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr( + fd_config.model_config, "use_3d_rope", False + ) self.causal: bool = getattr(fd_config.model_config, "causal", True) self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" self.rank: int = fd_config.parallel_config.tensor_parallel_rank diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index ae965e98a..3a480c8df 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -420,6 +420,7 @@ class XPUModelRunner(ModelRunnerBase): req_len = len(req_dicts) has_prefill_task = False has_decode_task = False + multi_vision_inputs = {"images_lst": [], "grid_thw_lst": [], "vit_position_ids_lst": [], "cu_seqlens": [0]} rope_3d_position_ids = { "position_ids_idx": [], "position_ids_lst": [], @@ -436,24 +437,39 @@ class XPUModelRunner(ModelRunnerBase): if self.enable_mm: inputs = request.multimodal_inputs if request.with_image: - vision_inputs = {} - vision_inputs["input_ids"] = paddle.to_tensor( - inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64 - ) - vision_inputs["token_type_ids"] = paddle.to_tensor( - inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64 - ) - vision_inputs["image_type_ids"] = paddle.to_tensor( - inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end], - dtype=paddle.int64, - ) - vision_inputs["images"] = paddle.to_tensor( - inputs["images"][request.image_start : request.image_end], dtype="uint8" - ) - vision_inputs["grid_thw"] = paddle.to_tensor( - inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64" - ) - self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs) + if envs.FD_ENABLE_MAX_PREFILL: + multi_vision_inputs["images_lst"].append( + paddle.to_tensor(inputs["images"][request.image_start : request.image_end]) + ) + multi_vision_inputs["grid_thw_lst"].extend( + inputs["grid_thw"][request.num_image_start : request.num_image_end] + ) + multi_vision_inputs["cu_seqlens"].extend( + inputs["vit_seqlen"][request.num_image_start : request.num_image_end] + ) + multi_vision_inputs["vit_position_ids_lst"].extend( + inputs["vit_position_ids"][request.num_image_start : request.num_image_end] + ) + else: + vision_inputs = {} + vision_inputs["input_ids"] = paddle.to_tensor( + inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64 + ) + vision_inputs["token_type_ids"] = paddle.to_tensor( + inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64 + ) + vision_inputs["image_type_ids"] = paddle.to_tensor( + inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end], + dtype=paddle.int64, + ) + vision_inputs["images"] = paddle.to_tensor( + inputs["images"][request.image_start : request.image_end], + dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16", + ) + vision_inputs["grid_thw"] = paddle.to_tensor( + inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64" + ) + self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs) else: self.share_inputs["image_features"] = None @@ -570,6 +586,9 @@ class XPUModelRunner(ModelRunnerBase): else: self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0 + if len(multi_vision_inputs["images_lst"]) > 0: + self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs) + if len(rope_3d_position_ids["position_ids_idx"]) > 0: packed_position_ids = paddle.to_tensor( np.concatenate(rope_3d_position_ids["position_ids_lst"]), dtype="int64" @@ -826,6 +845,16 @@ class XPUModelRunner(ModelRunnerBase): if self.enable_mm: head_dim = self.model_config.head_dim + if "paddleocr" in self.model_config.model_type: # neox style = True + rope_head_dim = head_dim + else: # neox style = False + rope_head_dim = head_dim // 2 + + if head_dim == self.model_config.head_dim: + self.share_inputs["pos_emb_type"] = "NORMAL" + else: + self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM" + self.share_inputs["rope_emb"] = paddle.full( shape=[ max_num_seqs, @@ -833,7 +862,7 @@ class XPUModelRunner(ModelRunnerBase): 1, self.model_config.max_model_len, 1, - head_dim // 2, + rope_head_dim, ], fill_value=0, dtype="float32", @@ -866,8 +895,8 @@ class XPUModelRunner(ModelRunnerBase): # Update bad tokens len max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"]) - if self.enable_mm: # pos_emb_type is different in EB and VL - self.forward_meta.pos_emb_type = "HALF_HEAD_DIM" + if self.enable_mm: + self.forward_meta.pos_emb_type = self.share_inputs["pos_emb_type"] self.forward_meta.attn_backend = self.attn_backends[0] self.initialize_attention_backend() @@ -1338,12 +1367,10 @@ class XPUModelRunner(ModelRunnerBase): ) return result - @paddle.no_grad() - def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: - """extract_vision_features""" + def extract_vision_features_ernie(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: assert inputs["images"] is not None grid_thw = inputs["grid_thw"] - + # ernie-vl has images norm images = inputs["images"].cast("float32") images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor images = images / self.image_preprocess.image_std_tensor @@ -1353,7 +1380,6 @@ class XPUModelRunner(ModelRunnerBase): token_type_ids_w_video = token_type_ids input_ids = inputs["input_ids"] # convert to img patch id - # TODO(lulinjun): may need to check model_config and model_cfg image_mask = input_ids == self.model_config.im_patch_id image_type_ids = inputs["image_type_ids"] with paddle.amp.auto_cast( @@ -1369,6 +1395,7 @@ class XPUModelRunner(ModelRunnerBase): image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2]) image_features = ScatterOp.apply(image_features, axis=-1) # mp 切 Fea image_features = image_features.reshape([S, -1]) + # ernie-vl has resampler_model image_features = self.model.resampler_model( image_features, image_mask, @@ -1378,6 +1405,59 @@ class XPUModelRunner(ModelRunnerBase): ) return image_features + def extract_vision_features_paddleocr(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: + if envs.FD_ENABLE_MAX_PREFILL: + inputs["vit_position_ids_lst"] = np.concatenate(inputs["vit_position_ids_lst"]) + images = paddle.concat(inputs["images_lst"]).cast("bfloat16") + grid_thw = paddle.to_tensor(inputs["grid_thw_lst"], dtype="int64") + position_ids = paddle.to_tensor(inputs["vit_position_ids_lst"], dtype="int64") + cu_seqlens = paddle.cumsum(paddle.to_tensor(inputs["cu_seqlens"])).cast("int32") + else: + assert inputs["images"] is not None + grid_thw = inputs["grid_thw"] + images = inputs["images"] + + position_ids = [] + cu_seqlens = [0] + for idx, thw in enumerate(grid_thw): + numel = np.prod(np.array(thw)) + position_ids.append(paddle.arange(numel) % np.prod(thw[1:])) + cu_seqlens.append(cu_seqlens[-1] + numel) + + position_ids = paddle.concat(position_ids, axis=0).to(images.place) + cu_seqlens = paddle.to_tensor(cu_seqlens, dtype=paddle.int32).to(images.place) + + with paddle.amp.auto_cast( + True, + custom_black_list=self.amp_black, + custom_white_list=self.amp_white, + level="O2", + dtype=self.model_config.dtype, + ): + image_features = self.model.visual( + pixel_values=images, + image_grid_thw=grid_thw, + position_ids=position_ids, + interpolate_pos_encoding=True, + cu_seqlens=cu_seqlens, + use_rope=True, + window_size=-1, + ) + image_features = self.model.projector(image_features, grid_thw) + image_features = paddle.concat(image_features, axis=0) + + return image_features + + @paddle.no_grad() + def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: + """extract_vision_features""" + if "ernie" in self.model_config.model_type: + return self.extract_vision_features_ernie(inputs) + elif "paddleocr" in self.model_config.model_type: + return self.extract_vision_features_paddleocr(inputs) + else: + raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported") + @paddle.no_grad() def prepare_rope3d( self, position_ids: paddle.Tensor, max_len_lst: list[int], cumsum_seqlens: list[int]