[Bug fix] Fix bug for running ep (#4245)

* fix bug for ep

* fix bug
chenjian
2025-09-28 14:56:18 +08:00
committed by GitHub
parent 17e00d9f5d
commit 3cef851468
5 changed files with 54 additions and 18 deletions

View File

@@ -218,7 +218,7 @@ class CacheMessager:
         try:
             prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
             prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
-            prefilled_layer_name = f"splitwise_complete_prefilled_step_{self.rank_id}.{self.gpu_id}"
+            prefilled_layer_name = f"splitwise_complete_prefilled_layer_{self.rank_id}.{self.gpu_id}"
             prefilled_step_name = f"splitwise_complete_prefilled_step_{self.rank_id}.{self.gpu_id}"
             step_shm_value = IPCSignal(
                 name=f"splitwise_complete_prefilled_step_{self.rank_id}",

View File

@@ -30,7 +30,7 @@ import paddle
 import zmq
 from opentelemetry import trace

-from fastdeploy.engine.request import Request, RequestOutput
+from fastdeploy.engine.request import Request, RequestOutput, RequestType
 from fastdeploy.engine.resource_manager import ResourceManager
 from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
 from fastdeploy.inter_communicator import (
@@ -77,6 +77,7 @@ class EngineService:
         self.llm_logger = llm_logger

         self.scheduler = cfg.scheduler_config.scheduler()
+        self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"

         if envs.ENABLE_V1_KVCACHE_SCHEDULER:
             self.resource_manager = ResourceManagerV1(
@@ -623,7 +624,7 @@ class EngineService:
                     for tmp_task in need_delete_tasks:
                         tasks.remove(tmp_task)
                         # release resource in P
-                        self.resource_manager.prerelease_resource(task)
+                        self.resource_manager.prerelease_resource(tmp_task)
             if self.cfg.scheduler_config.splitwise_role == "prefill":
                 # to send cache info to cache messager
                 if tasks:
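The change above fixes a stale loop variable: the loop iterates over need_delete_tasks as tmp_task, but the old code released task, a name left over from an earlier loop, so the wrong request's resources were released on every iteration. A runnable sketch of the pitfall with illustrative names (not FastDeploy's surrounding code):

released = []

def prerelease_resource(req_id: str) -> None:
    released.append(req_id)

tasks = ["req-0", "req-1", "req-2"]
need_delete_tasks = ["req-1", "req-2"]

task = tasks[-1]  # leftover binding from a previous loop over `tasks`
for tmp_task in need_delete_tasks:
    tasks.remove(tmp_task)
    prerelease_resource(task)        # buggy: always frees "req-2"
    # prerelease_resource(tmp_task)  # fixed: frees the task being deleted

print(released)  # ['req-2', 'req-2']; "req-1" is never released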
@@ -673,6 +674,21 @@ class EngineService:
                 tasks = self.resource_manager.schedule()
                 # 3. Send to engine
                 if tasks:
+                    if self.cfg.scheduler_config.splitwise_role == "decode":
+                        for task in tasks:
+                            if task.task_type == RequestType.PREEMPTED:
+                                msg = f"{task.request_id} decode not enough blocks, need to be rescheduled."
+                                self.llm_logger.error(msg)
+                                self.scheduler.put_results(
+                                    [
+                                        RequestOutput(
+                                            request_id=task.request_id,
+                                            finished=True,
+                                            error_code=500,
+                                            error_msg=msg,
+                                        )
+                                    ]
+                                )
                     self.resource_manager.get_real_bsz()
                     self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
                 else:
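The new decode-only branch turns a preemption into a terminal error instead of silently re-queuing it: when a decode instance runs out of KV blocks, the request is finished with error_code 500 and the message above, and the caller is expected to resubmit it. A small sketch of how a consumer of these results might react; the RequestOutput stand-in and the retry policy are assumptions for illustration, not part of this commit:

from dataclasses import dataclass
from typing import Optional

@dataclass
class RequestOutput:  # stand-in carrying only the fields used below
    request_id: str
    finished: bool = False
    error_code: Optional[int] = None
    error_msg: Optional[str] = None

def handle_result(result: RequestOutput, resubmit) -> None:
    # A preempted decode request comes back finished with error_code 500;
    # the simplest recovery is to resubmit it so it can land on an instance
    # that still has free KV blocks.
    if result.finished and result.error_code == 500:
        print(f"rescheduling {result.request_id}: {result.error_msg}")
        resubmit(result.request_id)
    elif result.finished:
        print(f"{result.request_id} completed")

handle_result(
    RequestOutput(
        request_id="req-3",
        finished=True,
        error_code=500,
        error_msg="req-3 decode not enough blocks, need to be rescheduled.",
    ),
    resubmit=lambda request_id: None,
)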

View File

@@ -651,6 +651,8 @@ class LLMEngine:
         role = self.cfg.scheduler_config.splitwise_role
         host_ip = self.cfg.host_ip
         disaggregate = self.cfg.disaggregate_info
+        request_queues_for_dp_ipc = None
+        result_queue_for_dp_ipc = None
         if self.cfg.scheduler_config.name == "splitwise":
             self.engine.scheduler.start(role, host_ip, disaggregate)
         elif self.cfg.scheduler_config.name == "dp":
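Both queue handles are now initialized to None before the scheduler branch, which avoids an UnboundLocalError in the non-"dp" branches, assuming the handles are passed along unconditionally further down (that code is not shown in this hunk). A minimal sketch of the failure mode under that assumption:

def start_scheduler(scheduler_name: str):
    # Without these defaults, the "splitwise" branch leaves both names
    # unbound and the return below raises UnboundLocalError.
    request_queues_for_dp_ipc = None
    result_queue_for_dp_ipc = None
    if scheduler_name == "dp":
        request_queues_for_dp_ipc = []      # placeholder for per-rank queues
        result_queue_for_dp_ipc = object()  # placeholder for the result queue
    return request_queues_for_dp_ipc, result_queue_for_dp_ipc

print(start_scheduler("splitwise"))  # (None, None) rather than a crash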

View File

@@ -137,13 +137,23 @@ class ResourceManagerV1(ResourceManager):
                     preempted_req = self.running.pop()
                     preempted_req.status = RequestStatus.PREEMPTED
                     preempted_req.num_computed_tokens = 0
-                    self._free_blocks(preempted_req)
-                    preempted_req.cached_block_num = 0
-                    self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
+                    if self.config.scheduler_config.splitwise_role == "decode":
+                        self.tasks_list[preempted_req.idx] = None
+                        self.stop_flags[preempted_req.idx] = True
+                        if preempted_req.request_id in self.requests:
+                            del self.requests[preempted_req.request_id]
+                        if preempted_req.request_id in self.req_dict:
+                            del self.req_dict[preempted_req.request_id]
+                        self._free_blocks(preempted_req)
+                        main_process_metrics.num_requests_running.dec(1)
+                    else:
+                        self._free_blocks(preempted_req)
+                        preempted_req.cached_block_num = 0
+                        self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
+                        main_process_metrics.num_requests_waiting.inc(1)
+                        main_process_metrics.num_requests_running.dec(1)
                     preempted_reqs.append(preempted_req)
                     scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
-                    main_process_metrics.num_requests_waiting.inc(1)
-                    main_process_metrics.num_requests_running.dec(1)
                     if preempted_req == request:
                         # No more request to preempt.
                         can_schedule = False
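The preemption path now splits by role: a decode instance tears the preempted request down completely (slot cleared, bookkeeping entries removed, KV blocks freed) and only decrements the running gauge, presumably because a decode instance cannot redo the prefill locally and the engine now reports the request as failed so the client resubmits it; every other role keeps the previous behavior of freeing the blocks and parking the request for rescheduling. A condensed, runnable sketch with a stand-in manager (attribute names simplified from the diff):

from types import SimpleNamespace

def handle_preempted(mgr, req, role: str) -> None:
    if role == "decode":
        mgr.tasks_list[req.idx] = None
        mgr.stop_flags[req.idx] = True
        mgr.requests.pop(req.request_id, None)
        mgr.req_dict.pop(req.request_id, None)
        mgr.free_blocks(req)   # KV blocks go straight back to the pool
        mgr.running -= 1       # never re-enters the local waiting queue,
                               # so the waiting gauge is left untouched
    else:
        mgr.free_blocks(req)
        req.cached_block_num = 0
        mgr.to_be_rescheduled.add(req.request_id)
        mgr.waiting += 1
        mgr.running -= 1

req = SimpleNamespace(request_id="req-5", idx=0, cached_block_num=4)
mgr = SimpleNamespace(
    tasks_list=[req], stop_flags=[False],
    requests={"req-5": req}, req_dict={"req-5": req},
    to_be_rescheduled=set(), free_blocks=lambda r: None,
    running=1, waiting=0,
)
handle_preempted(mgr, req, role="decode")
print(mgr.requests, mgr.running, mgr.waiting)  # {} 0 0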
@@ -588,8 +598,10 @@ class ResourceManagerV1(ResourceManager):
         with self.lock:
             self.tasks_list[request.idx] = None
             self.stop_flags[request.idx] = True
-            del self.requests[request.request_id]
-            del self.req_dict[request.request_id]
+            if request.request_id in self.requests:
+                del self.requests[request.request_id]
+            if request.request_id in self.req_dict:
+                del self.req_dict[request.request_id]
             self._free_blocks(request)

    def add_request_in_p(self, requests: list[Request]):
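The membership guards make this cleanup idempotent: with the decode-side preemption path above deleting request entries eagerly, the same request can reach this code after its entries are already gone, and an unguarded del would raise KeyError. A tiny illustration:

requests = {"req-0": object()}
request_id = "req-0"

del requests[request_id]       # first cleanup, e.g. during preemption
# del requests[request_id]     # unguarded second delete raises KeyError
if request_id in requests:     # guarded form used in the fix
    del requests[request_id]   # safely skipped on the second pass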

View File

@@ -387,14 +387,20 @@ class SplitwiseConnector:
                         f"{tasks[i].disaggregate_info['cache_info']['rdma']['ip']}:"
                         + f"{tasks[i].disaggregate_info['cache_info']['rdma']['port']}"
                     )
-                    cache_info = {
-                        "request_id": tasks[i].request_id,
-                        "device_ids": self.cfg.device_ids.split(","),
-                        "ip": self.cfg.host_ip,
-                        "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
-                        "transfer_protocol": "rdma",
-                        "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
-                    }
+                    if tasks[i].get("error_msg", None) is not None:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "error_msg": tasks[i].get("error_msg"),
+                        }
+                    else:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "device_ids": self.cfg.device_ids.split(","),
+                            "ip": self.cfg.host_ip,
+                            "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
+                            "transfer_protocol": "rdma",
+                            "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
+                        }
                     if addr not in temp_cache_info:
                         temp_cache_info[addr] = []
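The connector now distinguishes failed tasks from healthy ones: a task that arrives with an error_msg attached (for example, one rejected upstream in this commit's decode path) no longer needs RDMA routing details, so only its id and the message are forwarded, while normal tasks keep the full transfer descriptor. Placeholder examples of the two payload shapes (values are illustrative, not from a real deployment):

error_payload = {
    "request_id": "req-7",
    "error_msg": "req-7 decode not enough blocks, need to be rescheduled.",
}
normal_payload = {
    "request_id": "req-8",
    "device_ids": ["0", "1"],
    "ip": "10.0.0.2",
    "rdma_ports": ["8765", "8766"],
    "transfer_protocol": "rdma",
    "dest_block_ids": [104, 105, 106],
}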