[Executor] Fix bug of import paddle with RLHF (#3781)

commit 205b706ef8 (parent 306c024ff3)
Author: RAM
Date:   2025-09-02 17:32:13 +08:00
Committed by: GitHub


@@ -20,7 +20,6 @@ from typing import Callable, Dict, Optional
 
 import paddle.nn.layer
 from paddle.device.cuda import graphs
-from paddle.jit.dy2static.utils import CUDAGraphState
 
 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import capture_custom_allreduce
@@ -52,11 +51,16 @@ class ConcreteSizeEntry:
 
 class Dy2StCudaGraphManager:
     def __init__(self):
+        # NOTE(gongshaotian): Use local import to avoid RLHF version problems
+        from paddle.jit.dy2static.utils import CUDAGraphState
+
         self.state = CUDAGraphState.DISABLE
         self.captured_batch_size = set()
         self.batch_size = -1
 
     def run_impl(self, original_run_impl, inputs, parameters, attrs):
+        from paddle.jit.dy2static.utils import CUDAGraphState
+
         run_state = self.state
         prog_attrs, cuda_graph_attrs = attrs
         if run_state == CUDAGraphState.REPLAY:
@@ -100,6 +104,8 @@ class CudaGraphPiecewiseBackend:
         self.cuda_graph_manager = Dy2StCudaGraphManager()
 
     def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
+        from paddle.jit.dy2static.utils import CUDAGraphState
+
         if not entry.captured:
            # Warmup the model
            for n in range(entry.num_finished_warmup, self.warm_up_size):
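The change is the standard deferred-import pattern: presumably the Paddle build used in RLHF pipelines does not expose `CUDAGraphState` in `paddle.jit.dy2static.utils`, so a module-level import breaks importing this module even when CUDA graphs are never used. Moving the import into the method bodies postpones the symbol lookup until a manager is actually constructed or run. Below is a minimal runnable sketch of the effect; `dep_utils` and `GraphManager` are hypothetical stand-ins for the Paddle and FastDeploy names, not the real code:

```python
import sys
import types

# Hypothetical stand-in for a dependency build that lacks CUDAGraphState
# (in the real fix: paddle.jit.dy2static.utils on the RLHF Paddle version).
old_build = types.ModuleType("dep_utils")
sys.modules["dep_utils"] = old_build  # module exists, attribute does not

class GraphManager:
    """Stand-in for Dy2StCudaGraphManager."""

    def __init__(self):
        # Local import: the symbol is resolved only when a manager is
        # constructed, so merely importing the defining module is safe.
        from dep_utils import CUDAGraphState

        self.state = CUDAGraphState.DISABLE

# Defining the class (analogous to importing the module) succeeds even
# though dep_utils has no CUDAGraphState:
print("module imported fine")

# The failure is deferred to the code path that actually needs the symbol:
try:
    GraphManager()
except ImportError as exc:
    print(f"deferred failure: {exc}")
```

With the module-level import, the `ImportError` fires at import time and takes the whole module down; with the local import, only the CUDA-graph code paths are affected, which the failing RLHF setup never reaches.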