FastDeploy/fastdeploy/engine/expert_service.py
chenjian aba94169dc [Feature] Support batched tokens for EP (#3415), 2025-08-18 11:43:36 +08:00

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import copy
import os
import signal
import threading
import time
import traceback
import weakref
from collections import deque
import numpy as np
from fastdeploy.engine.request import RequestOutput
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.inter_communicator import EngineWorkerQueue
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.output.token_processor import TokenProcessor
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger
class ExpertService:
"""
    Service class managing Large Language Model (LLM) operations for a single local data-parallel rank.
Attributes:
cfg (Config): Configuration object containing all the parameters.
local_data_parallel_id (int): Local data parallel ID.
"""
def __init__(self, cfg, local_data_parallel_id):
"""
        Initializes the ExpertService with the provided configuration.
Args:
cfg (Config): Config object containing all the configuration parameters.
"""
self.cfg = cfg
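        # Map this data-parallel rank to its slice of device slots on the local node
        # (tensor_parallel_size consecutive slots, modulo worker_num_per_node).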
start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % cfg.worker_num_per_node
end_pos = start_pos + self.cfg.tensor_parallel_size
self.waiting_requests = []
self.disaggregate_queue = deque()
self.llm_logger = get_logger("expert_service", f"expert_service_{local_data_parallel_id}.log")
if cfg.splitwise_role != "mixed":
self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
self.cfg.disaggregate_info = None
self.scheduler = cfg.scheduler_config.scheduler()
if self.cfg.scheduler_config.name == "splitwise":
self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
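        # Each data-parallel rank attaches to its own engine worker queue,
        # at the base port offset by its local rank.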
address = (cfg.master_ip, cfg.engine_worker_queue_port + local_data_parallel_id)
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,
client_id=0,
num_client=cfg.tensor_parallel_size,
local_data_parallel_id=local_data_parallel_id,
)
self.resource_manager = ResourceManager(
cfg.max_num_seqs,
cfg,
cfg.tensor_parallel_size,
cfg.splitwise_role,
local_data_parallel_id,
)
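        # In splitwise deployments, derive this rank's prefill/decode comm port: offset a single
        # base port by the local rank, or pick the rank-specific entry from the list.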
if cfg.splitwise_role != "mixed":
if len(self.cfg.cache_config.pd_comm_port) == 1:
self.cfg.cache_config.pd_comm_port[0] = (
int(self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
)
else:
self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]
self.split_connector = SplitwiseConnector(
self.cfg, self.scheduler, self.engine_worker_queue, self.resource_manager, self.disaggregate_queue
)
self.token_processor = TokenProcessor(
cfg=cfg,
cached_generated_tokens=self.scheduler,
engine_worker_queue=self.engine_worker_queue,
split_connector=self.split_connector,
)
self.token_processor.set_resource_manager(self.resource_manager)
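        # Pre-compute the chunked-prefill token budget: with `idx` concurrent partial prefills,
        # each request gets an equal share of max_num_batched_tokens, rounded down to a
        # multiple of the cache block size.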
self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
for idx in range(1, self.cfg.max_num_partial_prefills + 1):
self.partial_chunked_tokens[idx] = (
(self.cfg.max_num_batched_tokens // idx)
// self.cfg.cache_config.block_size
* self.cfg.cache_config.block_size
)
self._finalizer = weakref.finalize(self, self._exit_sub_services)
if envs.FD_ENABLE_INTERNAL_ADAPTER:
self.external_adapter = InternalAdapter(cfg=self.cfg, engine=self, dp_rank=local_data_parallel_id)
def start(
self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
):
"""
Initializes the engine and starts its sub-services.
        If `api_server_pid` is defined, a thread will be launched
        to keep pulling requests from the zmq_server.
"""
# assert not self.is_started, "The engine is already started."
start_time = time.time()
self.llm_logger.info(f"start expert service {local_data_parallel_id}")
if self.cfg.splitwise_role != "mixed":
self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
cache_config=self.cfg.cache_config,
tensor_parallel_size=self.cfg.tensor_parallel_size,
device_ids=self.cfg.local_device_ids,
pod_ip=self.cfg.master_ip,
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
)
self.split_mode_get_tasks()
self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, args=())
self.insert_task_to_worker_thread.daemon = True
self.insert_task_to_worker_thread.start()
# Start TokenProcessor thread
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK
self.token_processor.run()
self.cfg.init_cache_info()
role = self.cfg.splitwise_role
host_ip = self.cfg.host_ip
disaggregate = self.cfg.disaggregate_info
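        # The "dp" scheduler consumes requests/results through the IPC queues passed in,
        # while the "splitwise" scheduler is started with this node's role, host IP and disaggregation info.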
if self.cfg.scheduler_config.name == "dp":
assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)
self.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc)
elif self.cfg.scheduler_config.name == "splitwise":
self.scheduler.start(role, host_ip, disaggregate)
self.cfg.print()
console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
return True
def _insert_task_to_worker(self):
"""
        Worker thread: monitor the scheduler request queue and,
        whenever the engine has free resources, insert tasks into the engine.
"""
current_id = -1
while True:
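            # Back off briefly while no batch slot is free or the worker queue
            # still holds unconsumed tasks.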
try:
if self.resource_manager.available_batch() == 0:
time.sleep(0.001)
continue
if self.engine_worker_queue.num_tasks() > 0:
time.sleep(0.001)
continue
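                # Cap this round's prefill pull by the free batch slots and max_prefill_batch
                # (the internal-adapter path uses all free slots).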
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch,
)
if envs.FD_ENABLE_INTERNAL_ADAPTER:
num_prefill_batch = int(self.resource_manager.available_batch())
self.resource_manager.check_and_free_block_tables()
tasks = self.scheduler.get_requests(
available_blocks=self.resource_manager.available_block_num(),
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
max_num_batched_tokens=self.cfg.max_num_batched_tokens,
batch=num_prefill_batch,
)
if len(tasks) == 0:
time.sleep(0.001)
continue
if self.cfg.splitwise_role != "mixed":
self.llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks(tasks, current_id)
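                # Advance the rotating batch id used to tag the splitwise/cache hand-off for this batch.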
current_id = (current_id + 1) % 100003
self.insert_tasks(tasks, current_id)
main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks))
except Exception as e:
err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
self.llm_logger.error(err_msg)
def split_mode_get_tasks(self):
"""
        Splitwise mode: receive prefill/decode hand-off tasks in a background loop.
"""
def receiver_loop():
while True:
try:
processed_indices = []
for idx, task in enumerate(self.waiting_requests):
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
processed_indices.append(idx)
else:
self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
break
for idx in sorted(processed_indices, reverse=True):
self.waiting_requests.pop(idx)
if len(self.disaggregate_queue) > 0:
items = self.disaggregate_queue.pop()
role = items[0]
tasks = items[1]
if role == "prefill":
for task in tasks:
task.max_tokens = task.min_tokens = 2
self.insert_tasks(tasks)
elif role == "decode":
if hasattr(tasks[0], "finished"):
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
task.finished = False
self.insert_tasks(tasks, allocated=True)
if self.cfg.innode_prefill_ports is not None:
self.scheduler.put_results(tasks)
else:
if len(self.waiting_requests):
self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
self.waiting_requests.extend(tasks)
else:
new_waiting = []
for task in tasks:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
else:
if not self.enable_decode_cache_task:
task.error_msg = "Not enough resources"
new_waiting.append(task)
if new_waiting:
if not self.enable_decode_cache_task:
self.split_connector.send_cache_infos(new_waiting, -1)
else:
self.waiting_requests.extend(new_waiting)
self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
else:
time.sleep(0.001)
except Exception as e:
self.llm_logger.error(f"Error in main loop: {e} {str(traceback.format_exc())}")
time.sleep(0.1)
threading.Thread(target=receiver_loop, daemon=True).start()
def insert_tasks(self, tasks, current_id=-1, allocated=False):
"""
Insert tasks to engine.
"""
if allocated:
current_tasks = []
for task in tasks:
cur_task_idx = self.resource_manager.req_dict[task.request_id]
del self.resource_manager.req_dict[task.request_id]
cur_task = self.resource_manager.tasks_list[cur_task_idx]
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if not task.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue
self.resource_manager.stop_flags[cur_task_idx] = True
self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.llm_logger.warning(f"{task.request_id} need not decode after first token")
continue
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode":
cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids)
if task.error_code != 200:
self.resource_manager.stop_flags[cur_task_idx] = True
self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
continue
self.token_processor.tokens_counter[task.request_id] = 1
current_tasks.append(cur_task)
if current_tasks:
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
return True
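        # Normal path: new tasks must pass the decode-allocation check (splitwise), fit into
        # the free batch slots, and obtain block resources before being queued to the workers.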
self.resource_manager.check_and_free_block_tables()
if not isinstance(tasks, list):
tasks = [tasks]
        for task in list(tasks):  # iterate over a copy so failed tasks can be removed from `tasks`
if self.cfg.splitwise_role != "mixed":
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
tasks.remove(task)
continue
task.schedule_start_time = time.time()
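        # stop_flags marks free (stopped) batch slots, so its sum is the currently available batch size.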
available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch:
            self.llm_logger.error(
                "Inserting batch: {} exceeds the available batch: {}.".format(len(tasks), available_batch)
            )
            self.llm_logger.error("The excess tasks will be ignored!")
tasks = tasks[:available_batch]
req_ids = [t.request_id for t in tasks]
tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
if not tasks:
error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
self.llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=500)
self.token_processor.number_of_tasks += len(tasks)
is_decode = False
is_prefill = False
for i in range(len(tasks)):
if tasks[i].disaggregate_info is not None:
if tasks[i].disaggregate_info["role"] == "decode":
is_decode = True
else:
is_prefill = True
self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len
if is_decode or is_prefill:
self.split_connector.send_cache_infos(tasks, current_id)
for task in tasks:
task.infer_start_time = time.time()
if not is_decode:
self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
if not is_prefill and self.cfg.cache_config.enable_chunked_prefill:
if not self.cfg.enable_mm:
self.update_requests_chunk_size(tasks)
else:
self.update_mm_requests_chunk_size(tasks)
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
return True
def _exit_sub_services(self):
"""
exit sub services
"""
if hasattr(self, "cache_manager_processes"):
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
self.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes:
self.llm_logger.info(f"Killing cache manager process {p.pid}")
try:
os.killpg(p.pid, signal.SIGTERM)
                except Exception:
pass
if hasattr(self, "zmq_server") and self.zmq_server is not None:
self.zmq_server.close()
def start_expert_service(
cfg, local_data_parallel_id, ipc_signal_suffix, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
):
"""
    Create an ExpertService for the given local data-parallel rank and start it.
"""
expert_service = ExpertService(cfg, local_data_parallel_id)
try:
expert_service.start(
ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc
)
expert_service.split_connector.start_receiver()
except Exception as e:
llm_logger.exception(f"Expert service failed to start: {e}")
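
# Usage sketch (illustrative; `cfg` and `ipc_signal_suffix` are assumed to be provided by
# the parent engine): this entry point is typically run in its own process per local
# data-parallel rank, for example:
#
#     from multiprocessing import Process
#     proc = Process(target=start_expert_service, args=(cfg, 0, ipc_signal_suffix))
#     proc.start()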