[Feature] Support batched tokens for EP (#3415)
* Support batched tokens for EP
* Support batched tokens for EP
* Support batched tokens for EP
* Support batched tokens for EP
* Support batched tokens for EP and fix bug
* Support batched tokens for EP and fix bug
* Support batched tokens for EP and fix bug
* Support batched tokens for EP and fix bug
@@ -14,6 +14,7 @@
# limitations under the License.
"""

import logging
import threading
import time
from multiprocessing import Queue
@@ -22,7 +23,7 @@ from typing import Dict, List, Optional
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledResponse
from fastdeploy.scheduler.local_scheduler import LocalScheduler
from fastdeploy.utils import scheduler_logger
from fastdeploy.utils import envs, get_logger


class DPLocalScheduler(LocalScheduler):
@@ -45,6 +46,7 @@ class DPLocalScheduler(LocalScheduler):
            long_prefill_token_threshold,
        )
        self.splitwise_role = splitwise_role
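        # Fall back to the stdlib logging module here; DPScheduler swaps in a
        # per-rank file logger right after construction (see below).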
        self.scheduler_logger = logging

    def put_results(self, results: List[RequestOutput]):
        """
@@ -56,7 +58,7 @@ class DPLocalScheduler(LocalScheduler):

        finished_responses = [response.request_id for response in responses if response.finished]
        if len(finished_responses) > 0:
            scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")
            self.scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")

        with self.mutex:
            for response in responses:
@@ -107,6 +109,80 @@ class DPLocalScheduler(LocalScheduler):
            else:
                self.ids_read_cursor -= len(expired_ids)

    def get_requests(
        self,
        available_blocks,
        block_size,
        reserved_output_blocks,
        max_num_batched_tokens,
        batch=1,
    ) -> List[Request]:
        """
        Retrieve requests from the scheduler based on available resources.

        Args:
            available_blocks: Number of available processing blocks
            block_size: Size of each processing block
            reserved_output_blocks: Blocks reserved for output
            max_num_batched_tokens: Maximum tokens that can be batched
            batch: Preferred batch size

        Returns:
            List of Request objects ready for processing
        """
        if available_blocks <= reserved_output_blocks or batch < 1:
            self.scheduler_logger.debug(
                f"Scheduler's resources are insufficient: available_blocks={available_blocks} "
                f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
                f"max_num_batched_tokens={max_num_batched_tokens}"
            )
            return []
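        # Greedily admit queued requests until the KV-block budget, the
        # max_num_batched_tokens budget, the preferred batch size, or the
        # EP batching timeout is exhausted.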
        required_total_blocks = 0
        current_prefill_tokens = 0
        start_batch_time = time.time()
        requests: List[Request] = []
        with self.requests_not_empty:
            while True:
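                # Wait up to 5 ms for as many as `batch` request ids to become visible.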
                batch_ids = self.requests_not_empty.wait_for(
                    lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
                    0.005,
                )
                if batch_ids:
                    for request_id in batch_ids:
                        request = self.requests[request_id]
                        required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
                        current_prefill_tokens += request.prompt_tokens_ids_len
                        required_total_blocks += required_input_blocks + reserved_output_blocks
                        if required_total_blocks > available_blocks:
                            break
                        if current_prefill_tokens > max_num_batched_tokens:
                            break

                        requests.append(request.raw)
                        self.ids_read_cursor += 1
                        start_batch_time = time.time()
                        if len(requests) >= batch:
                            break
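                # Leave the pull loop once the token budget or the preferred batch size is
                # hit, or once FD_EP_BATCHED_TOKEN_TIMEOUT seconds pass without a new admission.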
                if (
                    (current_prefill_tokens > max_num_batched_tokens)
                    or (len(requests) >= batch)
                    or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
                ):
                    break
        if batch_ids:
            if len(batch_ids) > 0 and len(requests) == 0:
                self.scheduler_logger.debug(
                    f"Scheduler has put all just-pulled requests into the queue: {len(batch_ids)}"
                )

        if len(requests) > 0:
            self.scheduler_logger.info(
                f"Scheduler has pulled some requests: {[request.request_id for request in requests]}"
            )

        return requests


class DPScheduler:
    def __init__(
@@ -133,6 +209,8 @@ class DPScheduler:
        self.dp_rank = dp_rank
        self.request_queues = request_queues
        self.result_queue = result_queue
        self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log")
        self._scheduler.scheduler_logger = self.scheduler_logger
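        # Two background threads bridge this rank's multiprocessing queues and the
        # local scheduler: one feeds requests in, the other drains responses.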
        threading.Thread(target=self._put_requests_to_local).start()
        threading.Thread(target=self._get_response_from_local).start()

@@ -148,6 +226,7 @@ class DPScheduler:
    def _put_requests_to_local(self):
        while True:
            request = self.request_queues[self.dp_rank].get()
            self.scheduler_logger.info(f"Receive request from puller, request_id: {request.request_id}")
            self._scheduler.put_requests([request])

    def _get_response_from_local(self):
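
A minimal sketch of how a worker might drive the new batched pull, assuming a DPLocalScheduler wired up as in this diff; the budget numbers and the make_request helper below are illustrative assumptions, not part of the commit:

    # `scheduler` is a DPLocalScheduler instance; make_request is a hypothetical
    # helper that builds a fastdeploy.engine.request.Request.
    scheduler.put_requests([make_request("req-0"), make_request("req-1")])

    pulled = scheduler.get_requests(
        available_blocks=512,         # free KV-cache blocks on this rank (example value)
        block_size=64,                # tokens per block (example value)
        reserved_output_blocks=32,    # blocks held back for decode output (example value)
        max_num_batched_tokens=8192,  # prefill-token budget for one pull (example value)
        batch=8,                      # preferred number of requests per pull
    )
    # `pulled` holds at most 8 raw Request objects whose prompt tokens and block
    # requirements fit the budgets above; if nothing can be admitted before
    # FD_EP_BATCHED_TOKEN_TIMEOUT seconds elapse, the call returns an empty list.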