From e1c4a12e3486e3516857de5637ca975e1cdfb1bf Mon Sep 17 00:00:00 2001
From: Nyakku Shigure
Date: Tue, 9 Dec 2025 14:37:00 +0800
Subject: [PATCH] [Graph Optimization][CINN] Use CINN in PaddleOCR-VL ViT part (#5223)

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 fastdeploy/engine/common_engine.py     |  1 +
 fastdeploy/engine/engine.py            |  1 +
 .../models/paddleocr_vl/siglip.py      | 36 ++++++++++--
 .../models/paddleocr_vl/siglip_ops.py  | 13 ++++-
 fastdeploy/worker/gpu_model_runner.py  | 58 +++++++++++++++++++
 fastdeploy/worker/gpu_worker.py        |  2 +
 fastdeploy/worker/model_runner_base.py |  6 ++
 tests/e2e/test_paddleocr_vl_serving.py | 13 +++--
 8 files changed, 120 insertions(+), 10 deletions(-)

diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py
index 9a370f443..dc9741270 100644
--- a/fastdeploy/engine/common_engine.py
+++ b/fastdeploy/engine/common_engine.py
@@ -1585,6 +1585,7 @@ class EngineService:
             "SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
             "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
             "SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
+            "SOT_ENABLE_COMPILE_TIME_LIMIT": os.getenv("SOT_ENABLE_COMPILE_TIME_LIMIT", default="0"),
             "FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
             "FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
             "FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index e1209872e..92b91da8a 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -464,6 +464,7 @@ class LLMEngine:
             "SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
             "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
             "SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
+            "SOT_ENABLE_COMPILE_TIME_LIMIT": os.getenv("SOT_ENABLE_COMPILE_TIME_LIMIT", default="0"),
             "FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
             "FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
             "FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(
diff --git a/fastdeploy/model_executor/models/paddleocr_vl/siglip.py b/fastdeploy/model_executor/models/paddleocr_vl/siglip.py
index 452d8dd1f..65b72d359 100644
--- a/fastdeploy/model_executor/models/paddleocr_vl/siglip.py
+++ b/fastdeploy/model_executor/models/paddleocr_vl/siglip.py
@@ -281,7 +281,6 @@ class SiglipMLP(nn.Layer):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.activation_fn = get_activation_fn(config.hidden_act)
         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
         self.fc1.weight.weight_loader = self.weight_loader
         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -304,7 +303,7 @@ class SiglipMLP(nn.Layer):

     def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         hidden_states = self.fc1(hidden_states)
-        hidden_states = self.activation_fn(hidden_states[0])
+        hidden_states = get_activation_fn(self.config.hidden_act)(hidden_states[0])
         hidden_states = self.fc2(hidden_states)
         return hidden_states

@@ -318,7 +317,6 @@ class SiglipEncoderLayer(paddle.nn.Layer):
         self.layer_norm2 = paddle.nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
         self.mlp = SiglipMLP(config)

-    # @paddle.jit.to_static
     def forward(
         self,
         hidden_states,
@@ -527,7 +525,37 @@ class SiglipEncoder(nn.Layer):
         else:
             attn_cu_seqlens = cu_seqlens

-        max_seqlen = (attn_cu_seqlens[1:] - attn_cu_seqlens[:-1]).max().item()
+        return self._run_encoder_layer(
+            encoder_states=encoder_states,
+            all_attentions=all_attentions,
+            attn_cu_seqlens=attn_cu_seqlens,
+            output_hidden_states=output_hidden_states,
+            reversed_window_indices=reversed_window_indices if output_hidden_states else None,
+            use_window_attn=use_window_attn,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            cos_emb=cos_emb,
+            sin_emb=sin_emb,
+        )
+
+    # This function will be compiled with CINN when graph_opt_level >= 2
+    # TODO(SigureMo): Use a new decorator to mark the function for CINN compilation
+    def _run_encoder_layer(
+        self,
+        encoder_states: Optional[Tuple[()]],
+        all_attentions: Optional[Tuple[()]],
+        attn_cu_seqlens: Optional[paddle.Tensor],
+        output_hidden_states: Optional[bool],
+        reversed_window_indices: paddle.Tensor,
+        use_window_attn: bool,
+        hidden_states: paddle.Tensor,
+        attention_mask: Optional[paddle.Tensor],
+        output_attentions: bool,
+        cos_emb: Optional[paddle.Tensor],
+        sin_emb: Optional[paddle.Tensor],
+    ) -> paddle.Tensor:
+        max_seqlen = (attn_cu_seqlens[1:] - attn_cu_seqlens[:-1]).max().cpu()

         for encoder_layer in self.layers:
             if output_hidden_states:
diff --git a/fastdeploy/model_executor/models/paddleocr_vl/siglip_ops.py b/fastdeploy/model_executor/models/paddleocr_vl/siglip_ops.py
index d84898f65..33cd502a5 100644
--- a/fastdeploy/model_executor/models/paddleocr_vl/siglip_ops.py
+++ b/fastdeploy/model_executor/models/paddleocr_vl/siglip_ops.py
@@ -27,6 +27,8 @@ if current_platform.is_cuda():

 def rotate_half(x):
     Dh = x.shape[-1]
+    if Dh == -1:
+        Dh = paddle.shape(x)[-1]
     x1 = x[..., : Dh // 2]
     x2 = x[..., Dh // 2 :]
     return paddle.concat([-x2, x1], axis=-1)
@@ -41,6 +43,8 @@ def apply_rotary_pos_emb_vision(x, cos, sin):

 def native_neox_rope_embedding(qkv, cos, sin, num_heads):
     B, seq_length, D = qkv.shape
+    if seq_length == -1:
+        _, seq_length, _ = paddle.shape(qkv)
     qkv = qkv.reshape(
         [
             seq_length,
@@ -55,18 +59,23 @@
     return q, k, v


+jit_unified_marker = paddle.jit.marker.unified if hasattr(paddle.jit.marker, "unified") else lambda fn: fn
+
+
+@jit_unified_marker
 def neox_rope_embedding(
     qkv: paddle.Tensor, cos_emb: paddle.Tensor, sin_emb: paddle.Tensor, num_heads: int, head_dim: int
 ) -> List[paddle.Tensor]:
-    if current_platform.is_cuda():
+    if current_platform.is_cuda() and paddle.in_dynamic_mode():
         return fused_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads, head_dim)
     else:
         return native_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads)


+@jit_unified_marker
 def get_activation_fn(hidden_act: str):
     if hidden_act == "gelu_pytorch_tanh":
-        if current_platform.is_cuda():
+        if current_platform.is_cuda() and paddle.in_dynamic_mode():
             return gelu_tanh
         else:
             return ACT2FN["gelu_new"]
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index ad6be643d..c5cad5bf8 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -2183,6 +2183,30 @@ class GPUModelRunner(ModelRunnerBase):
         time_after_capture = time.perf_counter()
         logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
+
+    def vision_encoder_compile(self):
+        if self.graph_opt_config.graph_opt_level == 0:
+            return
+        # Currently only PaddleOCR-VL model is supported for vision encoder layer
+        if self.model_config.model_type != "paddleocr_vl":
+            return
+
+        # Compile for paddleocr_vl vision encoder layer
+        def apply_compile(fn):
+            backend = "CINN" if self.graph_opt_config.graph_opt_level >= 2 else None
+            return paddle.jit.to_static(
+                fn,
+                full_graph=False,
+                backend=backend,
+            )
+
+        from fastdeploy.model_executor.models.paddleocr_vl.siglip import SiglipEncoder
+
+        SiglipEncoder._run_encoder_layer = apply_compile(SiglipEncoder._run_encoder_layer)
+
+        # Warmup for paddleocr_vl vision encoder layer
+        logger.info(f"Warmup for {self.model_config.model_type} compile...")
+        self._dummy_run_extract_vision_features()

     @sot_warmup_guard(True)
     def sot_warmup(self) -> None:
         start_time = time.perf_counter()
@@ -2891,6 +2915,40 @@ class GPUModelRunner(ModelRunnerBase):
         else:
             raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
+
+    @paddle.no_grad()
+    def _dummy_run_extract_vision_features(self):
+        grid_thw_list = ([(1, 10, 88), (1, 10, 80)], [(1, 14, 62), (1, 20, 42), (1, 14, 60)])
+        for grid_thw in grid_thw_list:
+            images = []
+            position_ids = []
+            cu_seqlens = [0]
+            for idx, thw in enumerate(grid_thw):
+                numel = np.prod(np.array(thw))
+                images.append(paddle.uniform(shape=[numel, 3, 14, 14], dtype="float32", min=0.0, max=1.0))
+                position_ids.append(paddle.arange(numel) % np.prod(thw[1:]))
+                cu_seqlens.append(cu_seqlens[-1] + numel)
+
+            images = paddle.concat(images, axis=0)
+            position_ids = paddle.concat(position_ids, axis=0).to(images.place)
+            cu_seqlens = paddle.to_tensor(cu_seqlens, dtype=paddle.int32).to(images.place)
+
+            with paddle.amp.auto_cast(
+                True,
+                custom_black_list=self.amp_black,
+                custom_white_list=self.amp_white,
+                level="O2",
+                dtype=self.model_config.dtype,
+            ):
+                self.model.visual(
+                    pixel_values=images,
+                    image_grid_thw=grid_thw,
+                    position_ids=position_ids,
+                    interpolate_pos_encoding=True,
+                    cu_seqlens=cu_seqlens,
+                    use_rope=True,
+                    window_size=-1,
+                )
+
     @paddle.no_grad()
     def prepare_rope3d(
         self, position_ids: paddle.Tensor, max_len_lst: list[int], cumsum_seqlens: list[int]
diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py
index 9fcf9efcc..0d57ccf25 100644
--- a/fastdeploy/worker/gpu_worker.py
+++ b/fastdeploy/worker/gpu_worker.py
@@ -209,6 +209,8 @@ class GpuWorker(WorkerBase):
         """
         if self.fd_config.graph_opt_config.graph_opt_level >= 1 and not self.model_runner.use_cudagraph:
             self.model_runner.sot_warmup()
+        if self.fd_config.graph_opt_config.graph_opt_level >= 1:
+            self.model_runner.vision_encoder_compile()

         # Trigger cuda graph capture
         self.model_runner.capture_model()
diff --git a/fastdeploy/worker/model_runner_base.py b/fastdeploy/worker/model_runner_base.py
index ece670530..a884131b0 100644
--- a/fastdeploy/worker/model_runner_base.py
+++ b/fastdeploy/worker/model_runner_base.py
@@ -93,3 +93,9 @@ class ModelRunnerBase(ABC):
         Execute a forward pass with dummy inputs to profile the memory usage of the model."
         """
         raise NotImplementedError
+
+    def vision_encoder_compile(self):
+        """
+        Compile the vision encoder if applicable
+        """
+        logger.info(f"No vision encoder compilation for base {self.__class__.__name__}")
diff --git a/tests/e2e/test_paddleocr_vl_serving.py b/tests/e2e/test_paddleocr_vl_serving.py
index dc12f21bc..6bfaf0a31 100644
--- a/tests/e2e/test_paddleocr_vl_serving.py
+++ b/tests/e2e/test_paddleocr_vl_serving.py
@@ -33,10 +33,14 @@ from utils.serving_utils import (

 os.environ["FD_USE_MACHETE"] = "0"

-@pytest.fixture(scope="session", autouse=True)
-def setup_and_run_server():
+@pytest.fixture(scope="session", autouse=True, params=[0, 2])
+def setup_and_run_server(request):
     """
-    Pytest fixture that runs once per test session:
+    Pytest fixture that runs once per test session, parameterized by `graph_opt_level`:
+    - Runs tests with graph_opt_level=0 (dynamic graph with fused ops)
+    - Runs tests with graph_opt_level=2 (CINN compilation)
+
+    This ensures the API server is tested under both graph optimization configurations.
     - Cleans ports before tests
     - Starts the API server as a subprocess
     - Waits for server port to open (up to 30 seconds)
@@ -55,6 +59,7 @@

     model_path = "./PaddleOCR-VL-0.9B"
     log_path = "server.log"
+    graph_opt_level = request.param

     cmd = [
         sys.executable,
@@ -80,7 +85,7 @@
         "--gpu-memory-utilization",
         "0.9",
         "--graph-optimization-config",
-        '{"graph_opt_level":0, "use_cudagraph":true}',
+        f'{{"graph_opt_level":{graph_opt_level}, "use_cudagraph":true}}',
     ]

     # Start subprocess in new process group