From f72be7a2c82ef1c73e0a8c05230e30bf097ec442 Mon Sep 17 00:00:00 2001 From: kevin Date: Thu, 16 Oct 2025 16:46:40 +0800 Subject: [PATCH] [BUG] fix ep bug (#4275) * fix ep bug * update code * update code * update code * [BugFix] fix config bugs (#4370) * Update expert_service.py * Update common_engine.py * Update expert_service.py * Update expert_service.py * Update expert_service.py --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> * update code --------- Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com> Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> --- fastdeploy/scheduler/splitwise_scheduler.py | 38 ++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/fastdeploy/scheduler/splitwise_scheduler.py b/fastdeploy/scheduler/splitwise_scheduler.py index ab1799f44..7c404c891 100644 --- a/fastdeploy/scheduler/splitwise_scheduler.py +++ b/fastdeploy/scheduler/splitwise_scheduler.py @@ -79,6 +79,19 @@ class SplitWiseSchedulerConfig: self.writer_parallel = writer_parallel self.writer_batch_size = writer_batch_size + self.max_model_len = kwargs.get("max_model_len") + self.enable_chunked_prefill = kwargs.get("enable_chunked_prefill") + self.max_num_partial_prefills = kwargs.get("max_num_partial_prefills") + self.max_long_partial_prefills = kwargs.get("max_long_partial_prefills") + self.long_prefill_token_threshold = kwargs.get("long_prefill_token_threshold") + + assert self.enable_chunked_prefill is not None, "enable_chunked_prefill must be set" + assert self.max_num_partial_prefills is not None, "max_num_partial_prefills must be set" + assert self.max_long_partial_prefills is not None, "max_long_partial_prefills must be set" + if self.long_prefill_token_threshold is None or self.long_prefill_token_threshold == 0: + assert self.max_model_len is not None, "max_model_len must be set" + self.long_prefill_token_threshold = int(self.max_model_len * 0.04) + def check(self): """check argument""" pass @@ -674,6 +687,7 @@ class InferScheduler: """ def __init__(self, config): + self.config = config self.nodeid = config.nodeid self.writer_parallel = config.writer_parallel self.writer_batch_size = config.writer_batch_size @@ -792,9 +806,13 @@ class InferScheduler: reqs = [] required_blocks = 0 current_prefill_tokens = 0 + long_partial_requests, short_partial_requests = 0, 0 cur_time = time.time() for i in range(batch): try: + if len(self.reqs_queue) == 0: + break + req = self.reqs_queue.popleft() if cur_time - req.arrival_time > self.ttl: logger.error(f"req({req.request_id}) is expired({self.ttl}) when InferScheduler Get Requests") @@ -803,9 +821,27 @@ class InferScheduler: current_prefill_tokens += req.prompt_token_ids_len required_input_blocks = (req.prompt_token_ids_len + block_size - 1) // block_size required_blocks += required_input_blocks + reserved_output_blocks - if required_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens: + if required_blocks > available_blocks: self.reqs_queue.appendleft(req) return reqs + + if self.config.enable_chunked_prefill: + if req.prompt_token_ids_len > self.config.long_prefill_token_threshold: + # long partial requests + long_partial_requests += 1 + if long_partial_requests > self.config.max_long_partial_prefills: + self.reqs_queue.appendleft(req) + break + else: + short_partial_requests += 1 + + if short_partial_requests + long_partial_requests > self.config.max_num_partial_prefills: + self.reqs_queue.appendleft(req) + break + else: + if current_prefill_tokens > max_num_batched_tokens: + self.reqs_queue.appendleft(req) + break # logger.info(f"Get Requests from Scheduler: {req.request_id}") reqs.append(req) except Exception as e: