[Graph Optimization][CINN] Use CINN in PaddleOCR-VL ViT part (#5223)

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Nyakku Shigure
2025-12-09 14:37:00 +08:00
committed by GitHub
parent 8d99bac532
commit e1c4a12e34
8 changed files with 120 additions and 10 deletions

View File

@@ -1585,6 +1585,7 @@ class EngineService:
"SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
"SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
"SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
"SOT_ENABLE_COMPILE_TIME_LIMIT": os.getenv("SOT_ENABLE_COMPILE_TIME_LIMIT", default="0"),
"FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
"FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
"FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(

View File

@@ -464,6 +464,7 @@ class LLMEngine:
"SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
"SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
"SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
"SOT_ENABLE_COMPILE_TIME_LIMIT": os.getenv("SOT_ENABLE_COMPILE_TIME_LIMIT", default="0"),
"FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
"FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
"FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(

View File

@@ -281,7 +281,6 @@ class SiglipMLP(nn.Layer):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = get_activation_fn(config.hidden_act)
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc1.weight.weight_loader = self.weight_loader
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -304,7 +303,7 @@ class SiglipMLP(nn.Layer):
def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states[0])
hidden_states = get_activation_fn(self.config.hidden_act)(hidden_states[0])
hidden_states = self.fc2(hidden_states)
return hidden_states
@@ -318,7 +317,6 @@ class SiglipEncoderLayer(paddle.nn.Layer):
self.layer_norm2 = paddle.nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
# @paddle.jit.to_static
def forward(
self,
hidden_states,
@@ -527,7 +525,37 @@ class SiglipEncoder(nn.Layer):
else:
attn_cu_seqlens = cu_seqlens
max_seqlen = (attn_cu_seqlens[1:] - attn_cu_seqlens[:-1]).max().item()
return self._run_encoder_layer(
encoder_states=encoder_states,
all_attentions=all_attentions,
attn_cu_seqlens=attn_cu_seqlens,
output_hidden_states=output_hidden_states,
reversed_window_indices=reversed_window_indices if output_hidden_states else None,
use_window_attn=use_window_attn,
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
cos_emb=cos_emb,
sin_emb=sin_emb,
)
# This function will be compiled with CINN when graph_opt_level >= 2
# TODO(SigureMo): Use a new decorator to mark the function for CINN compilation
def _run_encoder_layer(
self,
encoder_states: Optional[Tuple[()]],
all_attentions: Optional[Tuple[()]],
attn_cu_seqlens: Optional[paddle.Tensor],
output_hidden_states: Optional[bool],
reversed_window_indices: paddle.Tensor,
use_window_attn: bool,
hidden_states: paddle.Tensor,
attention_mask: Optional[paddle.Tensor],
output_attentions: bool,
cos_emb: Optional[paddle.Tensor],
sin_emb: Optional[paddle.Tensor],
) -> paddle.Tensor:
max_seqlen = (attn_cu_seqlens[1:] - attn_cu_seqlens[:-1]).max().cpu()
for encoder_layer in self.layers:
if output_hidden_states:

View File

@@ -27,6 +27,8 @@ if current_platform.is_cuda():
def rotate_half(x):
Dh = x.shape[-1]
if Dh == -1:
Dh = paddle.shape(x)[-1]
x1 = x[..., : Dh // 2]
x2 = x[..., Dh // 2 :]
return paddle.concat([-x2, x1], axis=-1)
@@ -41,6 +43,8 @@ def apply_rotary_pos_emb_vision(x, cos, sin):
def native_neox_rope_embedding(qkv, cos, sin, num_heads):
B, seq_length, D = qkv.shape
if seq_length == -1:
_, seq_length, _ = paddle.shape(qkv)
qkv = qkv.reshape(
[
seq_length,
@@ -55,18 +59,23 @@ def native_neox_rope_embedding(qkv, cos, sin, num_heads):
return q, k, v
jit_unified_marker = paddle.jit.marker.unified if hasattr(paddle.jit.marker, "unified") else lambda fn: fn
@jit_unified_marker
def neox_rope_embedding(
qkv: paddle.Tensor, cos_emb: paddle.Tensor, sin_emb: paddle.Tensor, num_heads: int, head_dim: int
) -> List[paddle.Tensor]:
if current_platform.is_cuda():
if current_platform.is_cuda() and paddle.in_dynamic_mode():
return fused_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads, head_dim)
else:
return native_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads)
@jit_unified_marker
def get_activation_fn(hidden_act: str):
if hidden_act == "gelu_pytorch_tanh":
if current_platform.is_cuda():
if current_platform.is_cuda() and paddle.in_dynamic_mode():
return gelu_tanh
else:
return ACT2FN["gelu_new"]

View File

@@ -2183,6 +2183,30 @@ class GPUModelRunner(ModelRunnerBase):
time_after_capture = time.perf_counter()
logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
def vision_encoder_compile(self):
if self.graph_opt_config.graph_opt_level == 0:
return
# Currently only PaddleOCR-VL model is supported for vision encoder layer
if self.model_config.model_type != "paddleocr_vl":
return
# Compile for paddleocr_vl vision encoder layer
def apply_compile(fn):
backend = "CINN" if self.graph_opt_config.graph_opt_level >= 2 else None
return paddle.jit.to_static(
fn,
full_graph=False,
backend=backend,
)
from fastdeploy.model_executor.models.paddleocr_vl.siglip import SiglipEncoder
SiglipEncoder._run_encoder_layer = apply_compile(SiglipEncoder._run_encoder_layer)
# Warmup for paddleocr_vl vision encoder layer
logger.info(f"Warmup for {self.model_config.model_type} compile...")
self._dummy_run_extract_vision_features()
@sot_warmup_guard(True)
def sot_warmup(self) -> None:
start_time = time.perf_counter()
@@ -2891,6 +2915,40 @@ class GPUModelRunner(ModelRunnerBase):
else:
raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
@paddle.no_grad()
def _dummy_run_extract_vision_features(self):
grid_thw_list = ([(1, 10, 88), (1, 10, 80)], [(1, 14, 62), (1, 20, 42), (1, 14, 60)])
for grid_thw in grid_thw_list:
images = []
position_ids = []
cu_seqlens = [0]
for idx, thw in enumerate(grid_thw):
numel = np.prod(np.array(thw))
images.append(paddle.uniform(shape=[numel, 3, 14, 14], dtype="float32", min=0.0, max=1.0))
position_ids.append(paddle.arange(numel) % np.prod(thw[1:]))
cu_seqlens.append(cu_seqlens[-1] + numel)
images = paddle.concat(images, axis=0)
position_ids = paddle.concat(position_ids, axis=0).to(images.place)
cu_seqlens = paddle.to_tensor(cu_seqlens, dtype=paddle.int32).to(images.place)
with paddle.amp.auto_cast(
True,
custom_black_list=self.amp_black,
custom_white_list=self.amp_white,
level="O2",
dtype=self.model_config.dtype,
):
self.model.visual(
pixel_values=images,
image_grid_thw=grid_thw,
position_ids=position_ids,
interpolate_pos_encoding=True,
cu_seqlens=cu_seqlens,
use_rope=True,
window_size=-1,
)
@paddle.no_grad()
def prepare_rope3d(
self, position_ids: paddle.Tensor, max_len_lst: list[int], cumsum_seqlens: list[int]

View File

@@ -209,6 +209,8 @@ class GpuWorker(WorkerBase):
"""
if self.fd_config.graph_opt_config.graph_opt_level >= 1 and not self.model_runner.use_cudagraph:
self.model_runner.sot_warmup()
if self.fd_config.graph_opt_config.graph_opt_level >= 1:
self.model_runner.vision_encoder_compile()
# Trigger cuda graph capture
self.model_runner.capture_model()

View File

@@ -93,3 +93,9 @@ class ModelRunnerBase(ABC):
Execute a forward pass with dummy inputs to profile the memory usage of the model."
"""
raise NotImplementedError
def vision_encoder_compile(self):
"""
Compile the vision encoder if applicable
"""
logger.info(f"No vision encoder compilation for base {self.__class__.__name__}")

View File

@@ -33,10 +33,14 @@ from utils.serving_utils import (
os.environ["FD_USE_MACHETE"] = "0"
@pytest.fixture(scope="session", autouse=True)
def setup_and_run_server():
@pytest.fixture(scope="session", autouse=True, params=[0, 2])
def setup_and_run_server(request):
"""
Pytest fixture that runs once per test session:
Pytest fixture that runs once per test session, parameterized by `graph_opt_level`:
- Runs tests with graph_opt_level=0 (dynamic graph with fused ops)
- Runs tests with graph_opt_level=2 (CINN compilation)
This ensures the API server is tested under both graph optimization configurations.
- Cleans ports before tests
- Starts the API server as a subprocess
- Waits for server port to open (up to 30 seconds)
@@ -55,6 +59,7 @@ def setup_and_run_server():
model_path = "./PaddleOCR-VL-0.9B"
log_path = "server.log"
graph_opt_level = request.param
cmd = [
sys.executable,
@@ -80,7 +85,7 @@ def setup_and_run_server():
"--gpu-memory-utilization",
"0.9",
"--graph-optimization-config",
'{"graph_opt_level":0, "use_cudagraph":true}',
f'{{"graph_opt_level":{graph_opt_level}, "use_cudagraph":true}}',
]
# Start subprocess in new process group