mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Graph Optimization][CINN] Use CINN in PaddleOCR-VL ViT part (#5223)
--------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@@ -1585,6 +1585,7 @@ class EngineService:
    "SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
    "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
    "SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
    "SOT_ENABLE_COMPILE_TIME_LIMIT": os.getenv("SOT_ENABLE_COMPILE_TIME_LIMIT", default="0"),
    "FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
    "FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
    "FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(
@@ -464,6 +464,7 @@ class LLMEngine:
    "SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
    "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
    "SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
    "SOT_ENABLE_COMPILE_TIME_LIMIT": os.getenv("SOT_ENABLE_COMPILE_TIME_LIMIT", default="0"),
    "FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
    "FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
    "FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(
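
Both hunks forward the same set of dynamic-to-static (SOT) and PIR flags from the engine to its worker process, always via `os.getenv` with a conservative default so an exported variable wins over the built-in value. A minimal sketch of the pattern, with illustrative `worker_env`/`child_env` names that are not FastDeploy's actual launch code:

```python
import os

# Each flag is read from the parent environment with a default, so users can
# override behavior by exporting the variable before starting the engine.
worker_env = {
    "SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
    "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
}

# Merge into the environment handed to the spawned worker subprocess.
child_env = {**os.environ, **worker_env}
```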
@@ -281,7 +281,6 @@ class SiglipMLP(nn.Layer):
    def __init__(self, config):
        super().__init__()
        self.config = config
-        self.activation_fn = get_activation_fn(config.hidden_act)
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc1.weight.weight_loader = self.weight_loader
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -304,7 +303,7 @@ class SiglipMLP(nn.Layer):

    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
        hidden_states = self.fc1(hidden_states)
-        hidden_states = self.activation_fn(hidden_states[0])
+        hidden_states = get_activation_fn(self.config.hidden_act)(hidden_states[0])
        hidden_states = self.fc2(hidden_states)
        return hidden_states
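
The change in `forward` drops the activation function cached at construction time in favor of looking it up on every call. Since `get_activation_fn` (see the hunk further below) branches on `paddle.in_dynamic_mode()`, the lookup has to happen at call time: a function captured in `__init__` would pin the dynamic-mode fused kernel even after the method is traced by `to_static`/CINN. A minimal sketch of the idea, with stand-in kernels:

```python
import paddle

# Sketch of dispatch-at-call-time; fused_gelu_tanh / composable_gelu_tanh are
# stand-ins for the real kernels, not FastDeploy's actual implementations.
def fused_gelu_tanh(x):
    # Assumed fused kernel path, valid only in eager (dynamic) mode.
    return paddle.nn.functional.gelu(x, approximate=True)

def composable_gelu_tanh(x):
    # Compiler-friendly formulation that to_static/CINN can trace.
    return 0.5 * x * (1.0 + paddle.tanh(0.7978845608 * (x + 0.044715 * x**3)))

def get_act(name: str):
    # Deciding here, on every call, lets a traced graph capture the
    # composable branch while eager execution keeps the fused kernel.
    if name == "gelu_pytorch_tanh" and paddle.in_dynamic_mode():
        return fused_gelu_tanh
    return composable_gelu_tanh

y = get_act("gelu_pytorch_tanh")(paddle.randn([2, 4]))
```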
@@ -318,7 +317,6 @@ class SiglipEncoderLayer(paddle.nn.Layer):
        self.layer_norm2 = paddle.nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

-    # @paddle.jit.to_static
    def forward(
        self,
        hidden_states,
@@ -527,7 +525,37 @@ class SiglipEncoder(nn.Layer):
        else:
            attn_cu_seqlens = cu_seqlens

-        max_seqlen = (attn_cu_seqlens[1:] - attn_cu_seqlens[:-1]).max().item()
+        return self._run_encoder_layer(
+            encoder_states=encoder_states,
+            all_attentions=all_attentions,
+            attn_cu_seqlens=attn_cu_seqlens,
+            output_hidden_states=output_hidden_states,
+            reversed_window_indices=reversed_window_indices if output_hidden_states else None,
+            use_window_attn=use_window_attn,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            cos_emb=cos_emb,
+            sin_emb=sin_emb,
+        )
+
+    # This function will be compiled with CINN when graph_opt_level >= 2
+    # TODO(SigureMo): Use a new decorator to mark the function for CINN compilation
+    def _run_encoder_layer(
+        self,
+        encoder_states: Optional[Tuple[()]],
+        all_attentions: Optional[Tuple[()]],
+        attn_cu_seqlens: Optional[paddle.Tensor],
+        output_hidden_states: Optional[bool],
+        reversed_window_indices: paddle.Tensor,
+        use_window_attn: bool,
+        hidden_states: paddle.Tensor,
+        attention_mask: Optional[paddle.Tensor],
+        output_attentions: bool,
+        cos_emb: Optional[paddle.Tensor],
+        sin_emb: Optional[paddle.Tensor],
+    ) -> paddle.Tensor:
+        max_seqlen = (attn_cu_seqlens[1:] - attn_cu_seqlens[:-1]).max().cpu()

        for encoder_layer in self.layers:
            if output_hidden_states:
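
Note the subtle change to `max_seqlen`: the old inline code used `.max().item()`, which synchronizes with the device and yields a Python int, while the extracted `_run_encoder_layer` uses `.max().cpu()`, keeping a tensor. A plausible reading is that a Python int would be frozen into the traced graph as a constant, whereas a tensor stays data the compiled graph can vary over. Illustrated in eager mode:

```python
import paddle

cu_seqlens = paddle.to_tensor([0, 880, 1680], dtype="int32")

# .item() syncs with the device and returns a Python int; under to_static the
# int would be baked into the traced graph as a constant.
max_seqlen_eager = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # 880

# .cpu() keeps a tensor, so the compiled graph treats the value as data and
# does not need to retrace when sequence lengths change.
max_seqlen_traced = (cu_seqlens[1:] - cu_seqlens[:-1]).max().cpu()
```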
@@ -27,6 +27,8 @@ if current_platform.is_cuda():


def rotate_half(x):
    Dh = x.shape[-1]
+    if Dh == -1:
+        Dh = paddle.shape(x)[-1]
    x1 = x[..., : Dh // 2]
    x2 = x[..., Dh // 2 :]
    return paddle.concat([-x2, x1], axis=-1)
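
The two added lines are the standard dynamic-dimension fallback: under `to_static`, `x.shape[-1]` reports an unknown dimension as the Python int -1, so the runtime value has to be fetched as a tensor via `paddle.shape(x)`. Consolidated, the function reads (runnable in eager mode, where the guard is simply never taken):

```python
import paddle

def rotate_half(x):
    # In a traced graph, unknown dims read as -1 from x.shape (Python ints);
    # paddle.shape(x) returns the runtime sizes as a tensor instead.
    Dh = x.shape[-1]
    if Dh == -1:
        Dh = paddle.shape(x)[-1]
    x1 = x[..., : Dh // 2]
    x2 = x[..., Dh // 2 :]
    return paddle.concat([-x2, x1], axis=-1)

# Eager check: [a, b, c, d] -> [-c, -d, a, b]
print(rotate_half(paddle.to_tensor([1.0, 2.0, 3.0, 4.0])))
```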
@@ -41,6 +43,8 @@ def apply_rotary_pos_emb_vision(x, cos, sin):


def native_neox_rope_embedding(qkv, cos, sin, num_heads):
    B, seq_length, D = qkv.shape
+    if seq_length == -1:
+        _, seq_length, _ = paddle.shape(qkv)
    qkv = qkv.reshape(
        [
            seq_length,
@@ -55,18 +59,23 @@ def native_neox_rope_embedding(qkv, cos, sin, num_heads):
    return q, k, v


+jit_unified_marker = paddle.jit.marker.unified if hasattr(paddle.jit.marker, "unified") else lambda fn: fn
+
+
+@jit_unified_marker
def neox_rope_embedding(
    qkv: paddle.Tensor, cos_emb: paddle.Tensor, sin_emb: paddle.Tensor, num_heads: int, head_dim: int
) -> List[paddle.Tensor]:
-    if current_platform.is_cuda():
+    if current_platform.is_cuda() and paddle.in_dynamic_mode():
        return fused_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads, head_dim)
    else:
        return native_neox_rope_embedding(qkv, cos_emb, sin_emb, num_heads)


+@jit_unified_marker
def get_activation_fn(hidden_act: str):
    if hidden_act == "gelu_pytorch_tanh":
-        if current_platform.is_cuda():
+        if current_platform.is_cuda() and paddle.in_dynamic_mode():
            return gelu_tanh
        else:
            return ACT2FN["gelu_new"]
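
`jit_unified_marker` feature-detects `paddle.jit.marker.unified` and degrades to an identity decorator when it is absent, so the same source works on Paddle builds without the marker. The added `paddle.in_dynamic_mode()` checks then route eager execution to the fused CUDA kernels while traced execution falls through to the composable implementations that CINN can compile. The fallback pattern in isolation:

```python
import paddle

# Feature-detect the optional marker; fall back to an identity decorator on
# Paddle builds whose paddle.jit.marker lacks `unified`.
jit_unified_marker = (
    paddle.jit.marker.unified if hasattr(paddle.jit.marker, "unified") else lambda fn: fn
)

@jit_unified_marker
def double(x: paddle.Tensor) -> paddle.Tensor:
    return x * 2
```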
@@ -2183,6 +2183,30 @@ class GPUModelRunner(ModelRunnerBase):
        time_after_capture = time.perf_counter()
        logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")

+    def vision_encoder_compile(self):
+        if self.graph_opt_config.graph_opt_level == 0:
+            return
+        # Currently only PaddleOCR-VL model is supported for vision encoder layer
+        if self.model_config.model_type != "paddleocr_vl":
+            return
+
+        # Compile for paddleocr_vl vision encoder layer
+        def apply_compile(fn):
+            backend = "CINN" if self.graph_opt_config.graph_opt_level >= 2 else None
+            return paddle.jit.to_static(
+                fn,
+                full_graph=False,
+                backend=backend,
+            )
+
+        from fastdeploy.model_executor.models.paddleocr_vl.siglip import SiglipEncoder
+
+        SiglipEncoder._run_encoder_layer = apply_compile(SiglipEncoder._run_encoder_layer)
+
+        # Warmup for paddleocr_vl vision encoder layer
+        logger.info(f"Warmup for {self.model_config.model_type} compile...")
+        self._dummy_run_extract_vision_features()
+
    @sot_warmup_guard(True)
    def sot_warmup(self) -> None:
        start_time = time.perf_counter()
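
`apply_compile` maps the config to a backend (`graph_opt_level >= 2` selects CINN, level 1 keeps plain dynamic-to-static) and rebinds the unbound method on the class, so every `SiglipEncoder` instance picks up the compiled version. Compilation itself is lazy, which is why the dummy warmup call follows. A reduced sketch of the same monkey-patch; `MyLayer` stands in for `SiglipEncoder`, and `backend="CINN"` assumes a CINN-enabled Paddle build:

```python
import paddle

class MyLayer(paddle.nn.Layer):
    def _block(self, x):
        return paddle.nn.functional.relu(x) * 2.0

    def forward(self, x):
        return self._block(x)

graph_opt_level = 2  # assumed config value
backend = "CINN" if graph_opt_level >= 2 else None
# Rebind the method on the class so all instances get the compiled version.
MyLayer._block = paddle.jit.to_static(MyLayer._block, full_graph=False, backend=backend)

layer = MyLayer()
out = layer(paddle.randn([4, 8]))  # first call triggers the lazy compilation
```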
@@ -2891,6 +2915,40 @@ class GPUModelRunner(ModelRunnerBase):
        else:
            raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")

+    @paddle.no_grad()
+    def _dummy_run_extract_vision_features(self):
+        grid_thw_list = ([(1, 10, 88), (1, 10, 80)], [(1, 14, 62), (1, 20, 42), (1, 14, 60)])
+        for grid_thw in grid_thw_list:
+            images = []
+            position_ids = []
+            cu_seqlens = [0]
+            for idx, thw in enumerate(grid_thw):
+                numel = np.prod(np.array(thw))
+                images.append(paddle.uniform(shape=[numel, 3, 14, 14], dtype="float32", min=0.0, max=1.0))
+                position_ids.append(paddle.arange(numel) % np.prod(thw[1:]))
+                cu_seqlens.append(cu_seqlens[-1] + numel)
+
+            images = paddle.concat(images, axis=0)
+            position_ids = paddle.concat(position_ids, axis=0).to(images.place)
+            cu_seqlens = paddle.to_tensor(cu_seqlens, dtype=paddle.int32).to(images.place)
+
+            with paddle.amp.auto_cast(
+                True,
+                custom_black_list=self.amp_black,
+                custom_white_list=self.amp_white,
+                level="O2",
+                dtype=self.model_config.dtype,
+            ):
+                self.model.visual(
+                    pixel_values=images,
+                    image_grid_thw=grid_thw,
+                    position_ids=position_ids,
+                    interpolate_pos_encoding=True,
+                    cu_seqlens=cu_seqlens,
+                    use_rope=True,
+                    window_size=-1,
+                )
+
    @paddle.no_grad()
    def prepare_rope3d(
        self, position_ids: paddle.Tensor, max_len_lst: list[int], cumsum_seqlens: list[int]
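
The dummy inputs pack several images of different patch-grid sizes into one flat batch; `cu_seqlens` holds the cumulative patch counts that variable-length attention uses as sequence boundaries. Worked through for the first tuple of grids:

```python
import numpy as np

# Worked example for the first tuple above: two images with
# patch grids (t, h, w) = (1, 10, 88) and (1, 10, 80).
grid_thw = [(1, 10, 88), (1, 10, 80)]

cu_seqlens = [0]
for thw in grid_thw:
    numel = int(np.prod(thw))  # 880 patches, then 800 patches
    cu_seqlens.append(cu_seqlens[-1] + numel)

print(cu_seqlens)  # [0, 880, 1680] -- boundaries for variable-length attention
# max_seqlen is the largest adjacent difference: 880
```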
@@ -209,6 +209,8 @@ class GpuWorker(WorkerBase):
        """
        if self.fd_config.graph_opt_config.graph_opt_level >= 1 and not self.model_runner.use_cudagraph:
            self.model_runner.sot_warmup()
+        if self.fd_config.graph_opt_config.graph_opt_level >= 1:
+            self.model_runner.vision_encoder_compile()
        # Trigger cuda graph capture
        self.model_runner.capture_model()
@@ -93,3 +93,9 @@ class ModelRunnerBase(ABC):
        Execute a forward pass with dummy inputs to profile the memory usage of the model.
        """
        raise NotImplementedError

+    def vision_encoder_compile(self):
+        """
+        Compile the vision encoder if applicable
+        """
+        logger.info(f"No vision encoder compilation for base {self.__class__.__name__}")
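
The base class gains a do-nothing default so callers like `GpuWorker` can invoke `vision_encoder_compile()` unconditionally; only runners that support it override the hook. The pattern in miniature, with illustrative class names:

```python
# Default-hook pattern: the base provides a harmless no-op so callers never
# need to know which runners can compile.
class RunnerBase:
    def vision_encoder_compile(self):
        print(f"No vision encoder compilation for base {self.__class__.__name__}")

class GPURunner(RunnerBase):
    def vision_encoder_compile(self):
        print("Compiling vision encoder...")

for runner in (RunnerBase(), GPURunner()):
    runner.vision_encoder_compile()  # safe to call unconditionally
```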
@@ -33,10 +33,14 @@ from utils.serving_utils import (
os.environ["FD_USE_MACHETE"] = "0"


-@pytest.fixture(scope="session", autouse=True)
-def setup_and_run_server():
+@pytest.fixture(scope="session", autouse=True, params=[0, 2])
+def setup_and_run_server(request):
    """
-    Pytest fixture that runs once per test session:
+    Pytest fixture that runs once per test session, parameterized by `graph_opt_level`:
+    - Runs tests with graph_opt_level=0 (dynamic graph with fused ops)
+    - Runs tests with graph_opt_level=2 (CINN compilation)
+
+    This ensures the API server is tested under both graph optimization configurations.
    - Cleans ports before tests
    - Starts the API server as a subprocess
    - Waits for server port to open (up to 30 seconds)
@@ -55,6 +59,7 @@ def setup_and_run_server():
    model_path = "./PaddleOCR-VL-0.9B"

    log_path = "server.log"
+    graph_opt_level = request.param

    cmd = [
        sys.executable,
@@ -80,7 +85,7 @@ def setup_and_run_server():
        "--gpu-memory-utilization",
        "0.9",
        "--graph-optimization-config",
-        '{"graph_opt_level":0, "use_cudagraph":true}',
+        f'{{"graph_opt_level":{graph_opt_level}, "use_cudagraph":true}}',
    ]

    # Start subprocess in new process group
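
Parameterizing the session fixture with `params=[0, 2]` makes pytest run the entire dependent suite twice, starting one server per `graph_opt_level`; `request.param` carries the current value into the launch command. The mechanics in isolation:

```python
import pytest

# A session-scoped fixture with params runs every dependent test once per
# parameter value, here once per graph_opt_level.
@pytest.fixture(scope="session", params=[0, 2])
def graph_opt_level(request):
    # request.param is 0 on the first session pass and 2 on the second.
    return request.param

def test_config_string(graph_opt_level):
    cfg = f'{{"graph_opt_level":{graph_opt_level}, "use_cudagraph":true}}'
    assert f'"graph_opt_level":{graph_opt_level}' in cfg
```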