Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-09-26 20:41:53 +08:00)
fix the bug for prefilled_step_idx signal of cache_messager in cudagraph and PD (#4235)
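As the diff below reads: in splitwise prefill/decode (PD) deployments, CUDA graph warm-up runs real forward steps on the prefill worker, advancing the splitwise_complete_prefilled_step_* shared signal that the cache messager polls. The fix wraps warm-up at the worker-proc level and resets that signal to -1 afterwards, so warm-up steps are not mistaken for completed prefills; a per-layer logger.info in the messager's hot polling loop is also removed.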
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
@@ -267,7 +267,6 @@ class CacheMessager:
                     self.cache_info[info["request_id"]] = info
             prefilled_layer_idx = layer_shm_value.value[0]
             prefilled_step_idx = step_shm_value.value[0]
-            logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
             if prefilled_layer_idx == self.num_layers - 1:
                 time.sleep(0.001)
                 prefilled_layer_idx = layer_shm_value.value[0]
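For context, the loop above is the cache messager polling two shared-memory counters that the prefill worker advances: one per layer written, one per completed step. A minimal self-contained sketch of that handshake, in which numpy arrays and a thread stand in for FastDeploy's IPCSignal buffers and the prefill worker (all names here are illustrative, not the real ones):

import threading
import time

import numpy as np

num_layers = 4
layer_shm_value = np.zeros(1, dtype=np.int32)    # last KV-cache layer written
step_shm_value = np.full(1, -1, dtype=np.int32)  # last fully prefilled step (-1: none)

def fake_prefill_worker() -> None:
    """Stand-in producer: publish per-layer progress, then bump the step."""
    for step in range(3):
        for layer in range(num_layers):
            time.sleep(0.002)           # pretend to compute one layer
            layer_shm_value[0] = layer  # per-layer signal
        step_shm_value[0] = step        # per-step signal

threading.Thread(target=fake_prefill_worker, daemon=True).start()

# Consumer (cache messager) side: poll the counters, backing off briefly
# when the last layer is reported but the step counter has not moved yet.
while step_shm_value[0] < 2:
    prefilled_layer_idx = layer_shm_value[0]
    prefilled_step_idx = step_shm_value[0]
    if prefilled_layer_idx == num_layers - 1:
        time.sleep(0.001)
        prefilled_layer_idx = layer_shm_value[0]
print("last prefilled step:", step_shm_value[0])

Because this loop spins at sub-millisecond granularity, a logger.info on every iteration is pure hot-path overhead, which is presumably why the diff drops it.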
@@ -442,6 +442,23 @@ class PaddleDisWorkerProc:
 
     def graph_optimize_and_warm_up_model(self) -> None:
         self.worker.graph_optimize_and_warm_up_model()
+        # reset cache_messager prefilled_step signal
+        if self.scheduler_config.splitwise_role == "prefill":
+            dp_rank_id = (
+                self.local_rank
+                + self.parallel_config.local_data_parallel_id * self.parallel_config.tensor_parallel_size
+            )
+            gpu_id = self.worker.model_runner.device_id
+            prefilled_step_name = f"splitwise_complete_prefilled_step_{dp_rank_id}"
+            prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
+            step_shm_value = IPCSignal(
+                name=prefilled_step_name,
+                array=prefilled_step_idx_data,
+                dtype=np.int32,
+                suffix=gpu_id,
+                create=False,
+            )
+            step_shm_value.value[0] = -1
 
     def init_device(self) -> None:
         """Initialize device and Construct model runner"""
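The block added above attaches to an already-existing signal (create=False) and rewinds it, because the warm-up forward passes just executed would otherwise leave the counter pointing past steps that never shipped any KV cache. A hedged sketch of the same reset pattern using the standard library's shared_memory in place of FastDeploy's IPCSignal (the segment name scheme and function name here are illustrative assumptions):

import numpy as np
from multiprocessing import shared_memory

def reset_prefilled_step_signal(dp_rank_id: int, gpu_id: int) -> None:
    """Attach to an existing step signal and rewind it to -1."""
    # Illustrative name scheme; IPCSignal builds its own from name + suffix.
    name = f"splitwise_complete_prefilled_step_{dp_rank_id}_{gpu_id}"
    shm = shared_memory.SharedMemory(name=name, create=False)
    step = np.ndarray((1,), dtype=np.int32, buffer=shm.buf)
    step[0] = -1  # warm-up steps must not count as completed prefills
    shm.close()   # detach; the owning process keeps the segment alive

if __name__ == "__main__":
    # Demo: create the segment as its owner would, simulate warm-up
    # residue, then reset it the way the added worker code does.
    owner = shared_memory.SharedMemory(
        name="splitwise_complete_prefilled_step_0_0", create=True, size=4
    )
    counter = np.ndarray((1,), dtype=np.int32, buffer=owner.buf)
    counter[0] = 7  # counter advanced by warm-up forward steps
    reset_prefilled_step_signal(dp_rank_id=0, gpu_id=0)
    print(counter[0])  # -> -1
    owner.close()
    owner.unlink()

Note that -1 matches the sentinel the cache messager treats as "no step prefilled yet", so after the reset the decode side waits for genuine prefill progress instead of replaying warm-up steps.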
@@ -842,7 +859,7 @@ def run_worker_proc() -> None:
     worker_proc.initialize_kv_cache()
 
     # Trigger CUDAGraph capture
-    worker_proc.worker.graph_optimize_and_warm_up_model()
+    worker_proc.graph_optimize_and_warm_up_model()
 
     # Initialize health status
     worker_proc.init_health_status()
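With this call-site change, run_worker_proc() goes through the proc-level wrapper rather than the inner worker directly, so CUDA graph capture and the signal reset it performs (see the hunk at line 442) always happen together.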