[Feature] Support batched tokens for EP (#3415)

* Support batched tokens for EP

* Support batched tokens for EP and fix bug
Author: chenjian
Date: 2025-08-18 11:43:36 +08:00
Committed by: GitHub
Parent: 3f86ae0007
Commit: aba94169dc
9 changed files with 235 additions and 97 deletions
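
In short: where the expert-parallel (EP) path previously pulled one request at a time, the scheduler can now pull up to `batch` requests per call, bounded by free cache blocks, a `max_num_batched_tokens` prefill budget, and a batching timeout read from `envs.FD_EP_BATCHED_TOKEN_TIMEOUT`. A minimal sketch of tuning that timeout, assuming the knob is read from the process environment as the `envs` module name suggests (the value shown is illustrative, not a documented default):

import os

# Hypothetical tuning: allow up to 0.2s of idle time while accumulating a
# batch before get_requests returns whatever it has gathered (illustrative).
os.environ["FD_EP_BATCHED_TOKEN_TIMEOUT"] = "0.2"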


@@ -14,6 +14,7 @@
# limitations under the License.
"""
import logging
import threading
import time
from multiprocessing import Queue
@@ -22,7 +23,7 @@ from typing import Dict, List, Optional
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledResponse
from fastdeploy.scheduler.local_scheduler import LocalScheduler
-from fastdeploy.utils import scheduler_logger
+from fastdeploy.utils import envs, get_logger

class DPLocalScheduler(LocalScheduler):
@@ -45,6 +46,7 @@ class DPLocalScheduler(LocalScheduler):
            long_prefill_token_threshold,
        )
        self.splitwise_role = splitwise_role
        # Fallback logger: the bare `logging` module exposes info/debug and so
        # works here; DPScheduler installs a per-rank logger after construction.
        self.scheduler_logger = logging

    def put_results(self, results: List[RequestOutput]):
        """
@@ -56,7 +58,7 @@ class DPLocalScheduler(LocalScheduler):
        finished_responses = [response.request_id for response in responses if response.finished]
        if len(finished_responses) > 0:
-            scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")
+            self.scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")

        with self.mutex:
            for response in responses:
@@ -107,6 +109,80 @@ class DPLocalScheduler(LocalScheduler):
            else:
                self.ids_read_cursor -= len(expired_ids)
    def get_requests(
        self,
        available_blocks,
        block_size,
        reserved_output_blocks,
        max_num_batched_tokens,
        batch=1,
    ) -> List[Request]:
        """
        Retrieve requests from the scheduler based on available resources.

        Args:
            available_blocks: Number of available processing blocks
            block_size: Size of each processing block
            reserved_output_blocks: Blocks reserved for output
            max_num_batched_tokens: Maximum tokens that can be batched
            batch: Preferred batch size

        Returns:
            List of Request objects ready for processing
        """
        if available_blocks <= reserved_output_blocks or batch < 1:
            self.scheduler_logger.debug(
                f"Scheduler's resources are insufficient: available_blocks={available_blocks} "
                f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
                f"max_num_batched_tokens={max_num_batched_tokens}"
            )
            return []
        required_total_blocks = 0
        current_prefill_tokens = 0
        start_batch_time = time.time()
        requests: List[Request] = []
        with self.requests_not_empty:
            while True:
                batch_ids = self.requests_not_empty.wait_for(
                    lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
                    0.005,
                )
                if batch_ids:
                    for request_id in batch_ids:
                        request = self.requests[request_id]
                        required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
                        current_prefill_tokens += request.prompt_tokens_ids_len
                        required_total_blocks += required_input_blocks + reserved_output_blocks
                        if required_total_blocks > available_blocks:
                            break
                        if current_prefill_tokens > max_num_batched_tokens:
                            break
                        requests.append(request.raw)
                        self.ids_read_cursor += 1
                        # Admitting a request resets the batching timer, so the
                        # timeout below measures idle time since the last admit.
                        start_batch_time = time.time()
                        if len(requests) >= batch:
                            break
                if (
                    (current_prefill_tokens > max_num_batched_tokens)
                    or (len(requests) >= batch)
                    or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
                ):
                    break
        if batch_ids:
            if len(batch_ids) > 0 and len(requests) == 0:
                self.scheduler_logger.debug(
                    f"Scheduler has put all just-pulled requests into the queue: {len(batch_ids)}"
                )
        if len(requests) > 0:
            self.scheduler_logger.info(
                f"Scheduler has pulled some requests: {[request.request_id for request in requests]}"
            )

        return requests
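
For orientation, a minimal sketch of a call site for the batched pull; the scheduler instance and all capacity numbers below are illustrative, not taken from this commit:

# Hypothetical call site: pull up to 8 requests under block/token budgets.
requests = scheduler.get_requests(
    available_blocks=1024,        # free KV-cache blocks right now
    block_size=64,                # tokens per block
    reserved_output_blocks=128,   # blocks held back for decode output
    max_num_batched_tokens=8192,  # prefill-token budget for one pull
    batch=8,                      # preferred batch size
)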

class DPScheduler:
    def __init__(
@@ -133,6 +209,8 @@ class DPScheduler:
        self.dp_rank = dp_rank
        self.request_queues = request_queues
        self.result_queue = result_queue
        self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log")
        self._scheduler.scheduler_logger = self.scheduler_logger

        threading.Thread(target=self._put_requests_to_local).start()
        threading.Thread(target=self._get_response_from_local).start()
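
The wiring around these two threads (established elsewhere) pairs one multiprocessing request queue per DP rank with a shared result queue. A hedged sketch of the producer side, assuming `request_queues`, `result_queue`, and `dp_rank` are the objects handed to DPScheduler above:

# Hypothetical producer side of the DP queues (names assumed from the
# attributes above; not part of this diff).
request_queues[dp_rank].put(request)   # consumed by _put_requests_to_local
response = result_queue.get()          # filled by the response thread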
@@ -148,6 +226,7 @@ class DPScheduler:
    def _put_requests_to_local(self):
        while True:
            request = self.request_queues[self.dp_rank].get()
            self.scheduler_logger.info(f"Received request from puller, request_id: {request.request_id}")
            self._scheduler.put_requests([request])

    def _get_response_from_local(self):