Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[BUG] fix ep bug (#4275)
* fix ep bug
* update code
* update code
* update code
* [BugFix] fix config bugs (#4370)
  * Update expert_service.py
  * Update common_engine.py
  * Update expert_service.py
  * Update expert_service.py
  * Update expert_service.py
  ---------
  Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
* update code
---------
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
@@ -79,6 +79,19 @@ class SplitWiseSchedulerConfig:
        self.writer_parallel = writer_parallel
        self.writer_batch_size = writer_batch_size

        self.max_model_len = kwargs.get("max_model_len")
        self.enable_chunked_prefill = kwargs.get("enable_chunked_prefill")
        self.max_num_partial_prefills = kwargs.get("max_num_partial_prefills")
        self.max_long_partial_prefills = kwargs.get("max_long_partial_prefills")
        self.long_prefill_token_threshold = kwargs.get("long_prefill_token_threshold")

        assert self.enable_chunked_prefill is not None, "enable_chunked_prefill must be set"
        assert self.max_num_partial_prefills is not None, "max_num_partial_prefills must be set"
        assert self.max_long_partial_prefills is not None, "max_long_partial_prefills must be set"
        if self.long_prefill_token_threshold is None or self.long_prefill_token_threshold == 0:
            assert self.max_model_len is not None, "max_model_len must be set"
            self.long_prefill_token_threshold = int(self.max_model_len * 0.04)

    def check(self):
        """check argument"""
        pass
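The fallback in this hunk pins the long-prefill cutoff at 4% of the model's context window when no explicit threshold is configured. A quick worked example with an assumed max_model_len (the concrete value is hypothetical, not from the diff):

    # Assumed context length of 8192 tokens; the diff supplies no concrete value.
    max_model_len = 8192
    long_prefill_token_threshold = int(max_model_len * 0.04)  # -> 327
    # Prompts longer than 327 tokens are then counted as "long" partial prefills.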
@@ -674,6 +687,7 @@ class InferScheduler:
    """

    def __init__(self, config):
        self.config = config
        self.nodeid = config.nodeid
        self.writer_parallel = config.writer_parallel
        self.writer_batch_size = config.writer_batch_size
@@ -792,9 +806,13 @@ class InferScheduler:
        reqs = []
        required_blocks = 0
        current_prefill_tokens = 0
        long_partial_requests, short_partial_requests = 0, 0
        cur_time = time.time()
        for i in range(batch):
            try:
                if len(self.reqs_queue) == 0:
                    break

                req = self.reqs_queue.popleft()
                if cur_time - req.arrival_time > self.ttl:
                    logger.error(f"req({req.request_id}) is expired({self.ttl}) when InferScheduler Get Requests")
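The loop above drops requests that have waited longer than the scheduler's TTL before they are ever batched. A standalone sketch of that pattern, with hypothetical names and a made-up 900 s TTL (FastDeploy's real request object is richer than this stand-in):

    import time
    from collections import deque

    class Req:  # hypothetical stand-in for the scheduler's request object
        def __init__(self, request_id):
            self.request_id = request_id
            self.arrival_time = time.time()

    reqs_queue = deque([Req("req-0")])
    ttl = 900.0  # assumed time-to-live in seconds
    req = reqs_queue.popleft()
    if time.time() - req.arrival_time > ttl:
        print(f"req({req.request_id}) is expired({ttl})")  # mirrors the logger.error branch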
@@ -803,9 +821,27 @@ class InferScheduler:
                current_prefill_tokens += req.prompt_token_ids_len
                required_input_blocks = (req.prompt_token_ids_len + block_size - 1) // block_size
                required_blocks += required_input_blocks + reserved_output_blocks
-                if required_blocks > available_blocks or current_prefill_tokens > max_num_batched_tokens:
+                if required_blocks > available_blocks:
                    self.reqs_queue.appendleft(req)
                    return reqs

                if self.config.enable_chunked_prefill:
                    if req.prompt_token_ids_len > self.config.long_prefill_token_threshold:
                        # long partial requests
                        long_partial_requests += 1
                        if long_partial_requests > self.config.max_long_partial_prefills:
                            self.reqs_queue.appendleft(req)
                            break
                    else:
                        short_partial_requests += 1

                    if short_partial_requests + long_partial_requests > self.config.max_num_partial_prefills:
                        self.reqs_queue.appendleft(req)
                        break
                else:
                    if current_prefill_tokens > max_num_batched_tokens:
                        self.reqs_queue.appendleft(req)
                        break
                # logger.info(f"Get Requests from Scheduler: {req.request_id}")
                reqs.append(req)
            except Exception as e:
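The block accounting in this last hunk is a ceiling division over the KV-cache block size, plus a reserve for decode output. A worked example with assumed numbers (block_size, reserved_output_blocks, and the prompt length are all hypothetical):

    block_size = 64               # assumed tokens per KV-cache block
    reserved_output_blocks = 4    # assumed blocks held back for decode output
    prompt_token_ids_len = 1000   # hypothetical prompt length

    # Ceiling division: 1000 tokens at 64 tokens/block -> 16 blocks
    required_input_blocks = (prompt_token_ids_len + block_size - 1) // block_size
    required_blocks = required_input_blocks + reserved_output_blocks  # 16 + 4 = 20
    # If required_blocks exceeded available_blocks, the request would be pushed
    # back onto the queue and the partial batch returned as-is.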