mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 01:22:59 +08:00
807 lines
31 KiB
Python
807 lines
31 KiB
Python
"""
|
||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
from typing import List, Tuple, Dict, Optional
|
||
import os
|
||
import re
|
||
import signal
|
||
import subprocess
|
||
import sys
|
||
import threading
|
||
import time
|
||
import traceback
|
||
import uuid
|
||
import weakref
|
||
|
||
import numpy as np
|
||
import zmq
|
||
from tqdm import tqdm
|
||
|
||
from fastdeploy.engine.args_utils import EngineArgs
|
||
from fastdeploy.engine.request import Request, RequestOutput
|
||
from fastdeploy.engine.resource_manager import ResourceManager
|
||
from fastdeploy.input.preprocess import InputPreprocessor
|
||
from fastdeploy.inter_communicator import (EngineWorkerQueue, IPCSignal,
|
||
ZmqClient)
|
||
from fastdeploy.output.token_processor import (TokenProcessor,
|
||
WarmUpTokenProcessor)
|
||
from fastdeploy.utils import EngineError, console_logger, llm_logger
|
||
|
||
|
||
class LLMEngine(object):
    """
    Main engine class for managing Large Language Model (LLM) inference operations.

    This class handles the complete lifecycle of LLM inference including:
        - Initialization and configuration
        - Request preprocessing and scheduling
        - Resource management (free batch slots and cache blocks)
        - Launching of, and communication with, worker processes
        - Token generation and output handling

    Key Components:
        - Scheduler: queues incoming requests and hands batches to the engine
        - ResourceManager: tracks free batch slots and cache blocks
        - TokenProcessor: processes generated tokens and handles streaming output
        - EngineWorkerQueue: task channel between the engine and worker processes

    Attributes:
        cfg: Engine configuration (as built by ``EngineArgs.create_engine_config``).
        scheduler: Scheduler created by ``cfg.scheduler_config.scheduler()``.
        input_processor (InputPreprocessor): Factory for the request/response processor.
        data_processor: Request/response processor, created in ``start()``.
        resource_manager (ResourceManager): Batch-slot and cache-block bookkeeping.
        token_processor (TokenProcessor): Handles generated tokens.
        engine_worker_queue (EngineWorkerQueue): Server side of the worker task queue.
        is_started (bool): Set once ``start()`` has completed successfully.
        do_profile (int): 1 while the GPU block count still has to be obtained from
            the workers (no ``num_gpu_blocks_override`` configured), otherwise 0.
        worker_proc (subprocess.Popen): Handle of the launched worker process group.
        zmq_server (ZmqClient): ZMQ server; only created when ``start()`` receives
            an ``api_server_pid``.
    """

    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs):
        """
        Create an LLM engine from the provided engine arguments.

        Args:
            engine_args (EngineArgs): Engine arguments object.

        Returns:
            LLMEngine: Engine built from ``engine_args.create_engine_config()``.
        """
        # Create the engine configs.
        config = engine_args.create_engine_config()
        # Create the LLMEngine.
        return cls(cfg=config)

    def __init__(self, cfg):
        """
        Initialize the LLM engine with the given configuration.

        Note: prefer ``from_engine_args()``, which builds ``cfg`` from EngineArgs.

        Sets up the scheduler, input preprocessing, resource manager, token
        processor and the server side of the engine/worker task queue. Worker
        processes are not launched here; that happens in ``start()``.

        Args:
            cfg: Complete engine configuration (model, parallelism, memory and
                scheduling settings).
        """
        self.cfg = cfg
        self.scheduler = cfg.scheduler_config.scheduler()

        self.input_processor = InputPreprocessor(cfg.tokenizer, cfg.enable_mm)
        self.resource_manager = ResourceManager(
            cfg.max_num_seqs, cfg.cache_config)

        # The scheduler is handed to the processor as its `cached_generated_tokens` store.
        self.token_processor = TokenProcessor(
            cfg=self.cfg, cached_generated_tokens=self.scheduler)
        self.token_processor.set_resource_manager(self.resource_manager)
        time.sleep(1)  # TODO: Investigate the purpose of this sleep.

        address = ('0.0.0.0', self.cfg.engine_worker_queue_port)
        self.engine_worker_queue = EngineWorkerQueue(
            address=address,
            is_server=True,
            num_client=self.cfg.tensor_parallel_size)

        self.is_started = False

        # Without an explicit block-count override, the number of GPU blocks is
        # collected from the workers after launch (see _stop_profile).
        if self.cfg.cache_config.num_gpu_blocks_override is None:
            self.do_profile = 1
        else:
            self.do_profile = 0

        # NOTE(review): the callback is a bound method of `self`, so the finalizer
        # holds a strong reference to the engine. The engine is therefore never
        # garbage-collected, and this cleanup effectively runs only at interpreter
        # exit (weakref.finalize's atexit hook).
        self._finalizer = weakref.finalize(self, self._exit_sub_services)

    def start(self, api_server_pid=None):
        """
        Initialize the engine and start its sub-services.

        Creates the IPC signals and the data processor, launches the worker
        processes and waits for them to initialize, and optionally runs warmup.
        Then it starts the background threads. If ``api_server_pid`` is given, a
        ZMQ server is created, and threads are started that receive requests
        from it and send generated results back.

        Args:
            api_server_pid: PID of the API server process, or None for offline
                use. When given, it is also the suffix of the IPC signal names.

        Returns:
            bool: True on success, False if the worker processes failed to start.
        """
        assert not self.is_started, "The engine is already started."
        start_time = time.time()

        self.api_server_pid = api_server_pid
        self.engine_pid = os.getpid()
        self.ipc_signal_suffix = self.engine_pid if self.api_server_pid is None else self.api_server_pid
        self._init_worker_signals()

        self.data_processor = self.input_processor.create_processor()

        if api_server_pid is not None:
            self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL)
            self.zmq_server.start_server()
            self.zmq_server.create_router()
            time.sleep(3)

        self.worker_proc = self._start_worker_service()
        console_logger.info("Waitting worker processes ready...")
        time.sleep(5)
        self.worker_init_status = dict()
        if not self.check_worker_initialize_status():
            console_logger.error(
                "Failed to launch worker processes, check log/workerlog.* for more details."
            )
            return False

        # Start warmup if enabled
        if self.cfg.use_warmup:
            console_logger.info("Starting warmup")
            self._set_warmup_token_processor()
            self.warmup()
            self._del_warmup_token_processor()
            console_logger.info("Warmup finished")

        self.token_processor.tasks_queue = self.engine_worker_queue

        self.insert_task_to_worker_thread = threading.Thread(
            target=self._insert_task_to_worker, args=())
        self.insert_task_to_worker_thread.daemon = True
        self.insert_task_to_worker_thread.start()

        if self.api_server_pid is not None:
            self.insert_task_to_scheduler_thread = threading.Thread(
                target=self._insert_zmq_task_to_scheduler, args=())
            self.insert_task_to_scheduler_thread.daemon = True
            self.insert_task_to_scheduler_thread.start()

            self.receive_output_thread = threading.Thread(
                target=self._zmq_send_generated_tokens, args=())
            self.receive_output_thread.daemon = True
            self.receive_output_thread.start()

        # Start the TokenProcessor
        self.token_processor.run()

        if self.do_profile:
            self._stop_profile()

        # Mark the engine as running so the assertion above rejects a second start().
        self.is_started = True
        console_logger.info(
            "Worker processes are launched with {} seconds.".format(
                time.time() - start_time))
        return True

    def _zmq_send_generated_tokens(self):
        """
        Loop forever, sending generated results to the API server over ZMQ.

        ``zmq_server.send_multipart2`` receives a handler. The handler fetches
        the results for the requested ids from the scheduler and converts every
        output to a dict. Errors are logged and the loop continues.
        """
        assert self.api_server_pid is not None

        # Defined once, outside the loop: it does not depend on any loop state.
        def get_results_handler(request_ids):
            results = dict()
            try:
                results = self.scheduler.get_results(request_ids)
                for req_id, contents in results.items():
                    results[req_id] = [data.to_dict() for data in contents]
            except Exception as e:
                llm_logger.error(f"Get results handler error: {e}")
            return results

        while True:
            try:
                self.zmq_server.send_multipart2(get_results_handler)
            except Exception as e:
                llm_logger.error("Unexcepted error happend: {}, {}".format(
                    e, str(traceback.format_exc())))

    def _get_generated_result(self, request_id):
        """
        Yield results for ``request_id`` from the scheduler until it finishes.

        This generator is used by ``generate()`` (offline inference only).

        Behavior:
            - Each partial result is yielded as it arrives.
            - Later results are merged into the first one via ``.add()``.
            - On the finished result, the accumulated result is yielded and the
              generator stops.
            - Errors are logged and end the generator.
        """
        try:
            acc = None
            while True:
                results = self.scheduler.get_results([request_id])
                received = False
                for contents in results.values():
                    for result in contents:
                        received = True
                        if acc is None:
                            acc = result
                        else:
                            acc.add(result)

                        if result.finished:
                            yield acc
                            return

                        yield result
                if not received:
                    # Nothing new yet: back off briefly instead of busy-spinning
                    # on the scheduler (same 1 ms back-off as _insert_task_to_worker).
                    time.sleep(0.001)
        except Exception as e:
            llm_logger.error("Unexcepted error happend: {}, {}".format(
                e, str(traceback.format_exc())))

    def _insert_task_to_worker(self):
        """
        Background loop that moves requests from the scheduler to the worker queue.

        Each iteration backs off for 1 ms while there is no free batch slot or
        while the worker queue still holds tasks. Otherwise it asks the
        scheduler for up to ``min(available_batch, cfg.max_prefill_batch)``
        requests that fit the free cache blocks and the token budget, and passes
        them to ``insert_tasks()``. Errors are logged and the loop continues.
        """
        while True:
            try:
                if self.resource_manager.available_batch() == 0:
                    time.sleep(0.001)
                    continue
                if self.engine_worker_queue.num_tasks() > 0:
                    time.sleep(0.001)
                    continue

                num_prefill_batch = min(
                    int(self.resource_manager.available_batch()),
                    self.cfg.max_prefill_batch)

                # With chunked prefill the token budget scales with the batch
                # (max_model_len per request); otherwise the configured cap applies.
                if self.cfg.enable_chunked_prefill:
                    cur_max_num_batched_tokens = self.cfg.max_model_len * num_prefill_batch
                else:
                    cur_max_num_batched_tokens = self.cfg.max_num_batched_tokens

                tasks = self.scheduler.get_requests(
                    available_blocks=self.resource_manager.available_block_num(),
                    block_size=self.cfg.cache_config.block_size,
                    reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
                    max_num_batched_tokens=cur_max_num_batched_tokens,
                    batch=num_prefill_batch)

                if len(tasks) == 0:
                    time.sleep(0.001)
                    continue

                self.insert_tasks(tasks)
            except Exception as e:
                err_msg = "Error happend while insert task to engine: {}, {}.".format(
                    e, str(traceback.format_exc()))
                llm_logger.error(err_msg)

    def _insert_zmq_task_to_scheduler(self):
        """
        Background loop that pulls requests from the ZMQ server into the scheduler.

        Behavior:
            - Returns immediately when the engine was started without an API
              server.
            - Requests arrive as JSON, or as Python objects when ``cfg.enable_mm``
              is set.
            - Every ``(request_id, failed)`` pair returned by
              ``scheduler.put_requests`` is counted off, and failures are
              answered over ZMQ with an error ``RequestOutput`` (code 500).
            - The loop ends when receiving reports an error.
        """
        if self.api_server_pid is None:
            return

        # request_id -> number of submissions still awaiting a put_requests() result.
        added_requests: Dict[str, int] = dict()
        while True:
            try:
                # Block only when nothing is outstanding. Otherwise poll, so that
                # put_requests() keeps being called (possibly with an empty list)
                # and can report results for earlier submissions.
                block = not added_requests
                if not self.cfg.enable_mm:
                    err, data = self.zmq_server.receive_json_once(block)
                else:
                    err, data = self.zmq_server.receive_pyobj_once(block)
                if err is not None:
                    llm_logger.error(
                        "Engine stops inserting zmq task into scheduler")
                    break

                request = None
                if data:
                    request = Request.from_dict(data)
                    llm_logger.info(f"Receive request: {request}")

                results: List[Tuple[str, Optional[str]]] = self.scheduler.put_requests(
                    [] if request is None else [request])

                if request:
                    added_requests[request.request_id] = added_requests.get(
                        request.request_id, 0) + 1

                for request_id, failed in results:
                    added_requests[request_id] -= 1
                    if added_requests[request_id] == 0:
                        added_requests.pop(request_id)

                    if failed is None:
                        continue

                    error_result = RequestOutput(request_id=request_id,
                                                 finished=True,
                                                 error_code=500,
                                                 error_msg=failed)
                    # The request is not in the scheduler, so reply over ZMQ
                    # directly. Use the id from the result, not `request`: the
                    # failure may belong to an earlier submission, and `request`
                    # is None when nothing new was received in this iteration.
                    self.zmq_server.send_multipart(request_id, error_result)
            except Exception as e:
                llm_logger.error(
                    f"Error happend while receving new request from zmq, details={e}"
                )

    def add_requests(self, task, sampling_params=None):
        """
        Preprocess a request, validate its length and put it into the scheduler.

        Args:
            task (dict): Request dict, converted with ``Request.from_dict``.
            sampling_params: Optional sampling parameters. When given, they
                replace the request's own ``sampling_params``.

        Returns:
            None

        Raises:
            EngineError: (error_code=400) if the prompt is longer than
                ``cfg.max_model_len``, or if prompt length + ``min_tokens``
                reaches it.
        """
        # TODO: confirm input/output length handling
        request = Request.from_dict(task)
        if sampling_params is not None:
            request.sampling_params = sampling_params
        request.preprocess_start_time = time.time()
        request = self.data_processor.process_request(request, self.cfg.max_model_len)

        request.prompt_token_ids_len = len(request.prompt_token_ids)
        input_ids_len = request.prompt_token_ids_len

        # Check the hard prompt limit before the min_tokens check. For
        # min_tokens >= 0 that check also fires for every over-long prompt, which
        # would leave this more specific error unreachable.
        if input_ids_len > self.cfg.max_model_len:
            error_msg = (
                f"Length of input token({input_ids_len}) exceeds the limit max_model_len({self.cfg.max_model_len})."
            )
            llm_logger.error(error_msg)
            raise EngineError(error_msg, error_code=400)

        min_tokens = request.get("min_tokens")
        if input_ids_len + min_tokens >= self.cfg.max_model_len:
            error_msg = (
                f"Input text is too long, length of prompt token({input_ids_len}) "
                f"+ min_dec_len ({min_tokens}) >= max_model_len ")
            llm_logger.error(error_msg)
            raise EngineError(error_msg, error_code=400)

        # Clamp max_tokens to the room left after the prompt, once the request is known to fit.
        request.set(
            "max_tokens",
            min(self.cfg.max_model_len - input_ids_len,
                request.get("max_tokens")))

        request.preprocess_end_time = time.time()
        self.scheduler.put_requests([request])
        llm_logger.info(
            f"Cache task with request_id ({request.get('request_id')})")
        llm_logger.debug(f"cache task: {request}")

    def warmup(self):
        """
        Warmup hook, intended to run test tasks so that out-of-memory problems in
        the worker processes surface before serving. Currently a no-op.
        """
        pass

    def insert_tasks(self, tasks):
        """
        Allocate resources for tasks and send them to the worker queue.

        Tasks beyond the number of free batch slots (the sum of
        ``resource_manager.stop_flags``) are logged and dropped. Each task that
        is sent gets a ``token_chunk_size``: an equal, block-aligned share of
        ``cfg.max_num_batched_tokens``.

        Args:
            tasks: A single task or a list of tasks.

        Returns:
            bool: True once the tasks have been queued.

        Raises:
            EngineError: (error_code=500) if no task could be allocated resources.
        """
        if not isinstance(tasks, list):
            tasks = [tasks]

        for item in tasks:
            item.schedule_start_time = time.time()

        available_batch = np.sum(self.resource_manager.stop_flags)
        if len(tasks) > available_batch:
            llm_logger.error(
                "Inserting batch:{} exceeds the available batch:{}.".format(
                    len(tasks), available_batch))
            llm_logger.error("The exceeded part will be ignored!")
            tasks = tasks[:available_batch]

        req_ids = [t.request_id for t in tasks]

        tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
        if not tasks:
            error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
            llm_logger.error(error_msg)
            raise EngineError(error_msg, error_code=500)

        self.token_processor.number_of_tasks += len(tasks)
        block_size = self.cfg.cache_config.block_size
        # NOTE(review): evaluates to 0 when max_num_batched_tokens // len(tasks)
        # is smaller than block_size -- confirm the workers handle that.
        token_chunk_size = (self.cfg.max_num_batched_tokens // len(tasks)) // block_size * block_size
        for task in tasks:
            self.token_processor.number_of_input_tokens += task.prompt_token_ids_len
            task.set("token_chunk_size", token_chunk_size)

        llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
        self.engine_worker_queue.put_tasks(
            (tasks, self.resource_manager.real_bsz))
        return True

    def task_is_finished(self, index):
        """
        Return the stop flag of batch slot ``index`` (truthy when finished).
        """
        assert index < len(self.resource_manager.stop_flags)
        return self.resource_manager.stop_flags[index]

    def all_tasks_finished(self):
        """
        Return whether every batch slot's stop flag is set.
        """
        return np.sum(self.resource_manager.stop_flags) == len(
            self.resource_manager.stop_flags)

    def _set_warmup_token_processor(self):
        """
        Swap in a WarmUpTokenProcessor for warmup and start it.

        The regular processor is kept in ``token_processor_backup`` so that
        ``_del_warmup_token_processor`` can restore it.
        """
        self.token_processor_backup = self.token_processor
        self.token_processor = WarmUpTokenProcessor(self.cfg)
        self.token_processor.set_resource_manager(self.resource_manager)
        self.token_processor.tasks_queue = self.engine_worker_queue

        # Start the warmup TokenProcessor
        self.token_processor.run()

    def _del_warmup_token_processor(self):
        """
        Stop the warmup token processor and restore the regular one.
        """
        self.token_processor.stop()
        del self.token_processor

        # Restore the processor saved by _set_warmup_token_processor.
        self.token_processor = self.token_processor_backup
        del self.token_processor_backup

    def _worker_processes_ready(self):
        """
        Return True when the sum of ``worker_ready_signal`` (one slot per
        worker) equals ``cfg.tp_num_per_node``.
        """
        return bool(np.sum(self.worker_ready_signal.value) == self.cfg.tp_num_per_node)

    def _init_worker_signals(self):
        """
        Create the shared-memory IPC signals used to coordinate with the workers.

        All signals are int32 arrays created with ``create=True`` and named with
        ``self.ipc_signal_suffix``.
        """
        # worker_ready_signal: lets the engine know whether each worker process is ready.
        # NOTE: the misspelled name "worker_ready_singnal" is kept as-is;
        # presumably the workers attach to the signal by this exact name.
        worker_ready_signal_data = np.zeros(
            shape=[self.cfg.tensor_parallel_size], dtype=np.int32)
        self.worker_ready_signal = IPCSignal(name="worker_ready_singnal",
                                             array=worker_ready_signal_data,
                                             dtype=np.int32,
                                             suffix=self.ipc_signal_suffix,
                                             create=True)

        # exist_task_signal: lets the worker processes know whether there are new tasks to process.
        exist_task_signal_data = np.zeros([1], dtype=np.int32)
        self.exist_task_signal = IPCSignal(name="exist_task_signal",
                                           array=exist_task_signal_data,
                                           dtype=np.int32,
                                           suffix=self.ipc_signal_suffix,
                                           create=True)

        # exist_swapped_task_signal: lets the engine know whether the workers hold swapped tasks.
        exist_swapped_task_signal_data = np.zeros([1], dtype=np.int32)
        self.exist_swapped_task_signal = IPCSignal(
            name="exist_swapped_task_signal",
            array=exist_swapped_task_signal_data,
            dtype=np.int32,
            suffix=self.ipc_signal_suffix,
            create=True)

        # worker_healthy_live_signal: lets the engine know whether each worker
        # is alive; records the time of each step.
        worker_healthy_live_recorded_time_array = np.zeros(
            shape=[self.cfg.tensor_parallel_size], dtype=np.int32)
        self.worker_healthy_live_signal = IPCSignal(
            name="worker_healthy_live_signal",
            array=worker_healthy_live_recorded_time_array,
            dtype=np.int32,
            suffix=self.ipc_signal_suffix,
            create=True)

        # Per-worker block counts reported back during profiling (read by _stop_profile).
        if self.do_profile:
            get_profile_block_num = np.zeros([self.cfg.tensor_parallel_size],
                                             dtype=np.int32)
            self.get_profile_block_num_signal = IPCSignal(
                name="get_profile_block_num",
                array=get_profile_block_num,
                dtype=np.int32,
                suffix=self.ipc_signal_suffix,
                create=True)

        model_weights_status = np.zeros([1], dtype=np.int32)
        self.model_weights_status_signal = IPCSignal(
            name="model_weights_status",
            array=model_weights_status,
            dtype=np.int32,
            suffix=self.ipc_signal_suffix,
            create=True)

    def _exit_sub_services(self):
        """
        Release the IPC signals, terminate the worker process group and close
        the ZMQ server.

        This is registered as a weakref finalizer in ``__init__``, so it can run
        even when ``start()`` was never called or failed part-way. Each resource
        is therefore looked up defensively, and missing ones are skipped.
        """
        for attr in ("worker_ready_signal", "exist_task_signal",
                     "exist_swapped_task_signal", "worker_healthy_live_signal",
                     "get_profile_block_num_signal",
                     "model_weights_status_signal"):
            ipc_signal = getattr(self, attr, None)
            if ipc_signal is not None:
                ipc_signal.clear()

        worker_proc = getattr(self, "worker_proc", None)
        if worker_proc is not None:
            try:
                # _start_worker_service launches the workers with
                # preexec_fn=os.setsid, so the pid is also the process-group id.
                os.killpg(worker_proc.pid, signal.SIGTERM)
            except OSError:
                # Best effort: the process group may already be gone.
                pass

        zmq_server = getattr(self, "zmq_server", None)
        if zmq_server is not None:
            zmq_server.close()

    def _setting_environ_variables(self):
        """
        Build the ``KEY=value `` environment-variable prefix for the worker
        launch command.
        """
        variables = {
            "PADDLE_TRAINER_ID": 0,
            "PADDLE_TRAINERS_NUM": 1,
            "TRAINER_INSTANCES_NUM": 1,
            "TRAINER_INSTANCES": "0.0.0.0",
            "ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY": 0,
            "LOAD_STATE_DICT_THREAD_NUM": len(self.cfg.device_ids.split(',')),
            "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
            "FLAGS_use_append_attn": 1,
            "NCCL_ALGO": "Ring",
            "ELLM_DYNAMIC_MODE": 1,
        }
        command_prefix = ""
        for k, v in variables.items():
            command_prefix += f"{k}={v} "
        return command_prefix

    def _start_worker_service(self):
        """
        Launch the GPU worker processes via ``paddle.distributed.launch``.

        The command runs through the shell in a new session (``os.setsid``).
        Its stderr goes to ``$FD_LOG_DIR/launch_worker.log`` (default ``log``),
        and its stdout is a PIPE that check_worker_initialize_status reads.

        Returns:
            subprocess.Popen: Handle of the launched process.
        """
        command_prefix = self._setting_environ_variables()
        current_file_path = os.path.abspath(__file__)
        current_dir_path = os.path.split(current_file_path)[0]
        # NOTE(review): os.getenv returns a str, so `== 1` is always False and
        # "-u" (unbuffered) is always used. Presumably unbuffered stdout matters
        # for the progress parsing in check_worker_initialize_status. This is
        # left unchanged because the intended meaning of
        # UNCACHE_WORKER_STDOUT=1 is unclear -- confirm before fixing.
        uncache_worker_stdout = "" if os.getenv("UNCACHE_WORKER_STDOUT",
                                                "0") == 1 else "-u"
        pd_cmd = f"{command_prefix} {sys.executable} {uncache_worker_stdout} -m paddle.distributed.launch "
        py_script = os.path.join(current_dir_path, "../worker/worker.py")
        arguments = (
            f" --nnodes {str(self.cfg.nnode)}"
            f" --devices {self.cfg.device_ids} {py_script}"
            f" --max_num_seqs {self.cfg.max_num_seqs} --max_model_len {self.cfg.max_model_len}"
            f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}"
            f" --model_name_or_path {str(self.cfg.model_name_or_path)}"
            f" --device_ids {self.cfg.device_ids}"
            f" --engine_worker_queue_port {str(self.cfg.engine_worker_queue_port)}"
            f" --total_block_num {self.cfg.cache_config.total_block_num}"
            f" --block_size {self.cfg.cache_config.block_size}"
            f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}"
            f" --eos_tokens_lens {self.data_processor.eos_token_id_len}"
            f" --pad_token_id {self.data_processor.pad_token_id}"
            f" --engine_pid {self.engine_pid}"
            f" --do_profile {self.do_profile}"
            f" --dynamic_load_weight {self.cfg.model_config.dynamic_load_weight}"
            f" --max_num_batched_tokens {self.cfg.max_num_batched_tokens}"
            f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio} --dtype {self.cfg.cache_config.cache_dtype}"
        )

        # Boolean worker flags, appended as bare "--flag" when truthy.
        worker_append_flag = {
            "enable_chunked_prefill": self.cfg.enable_chunked_prefill,
        }
        for worker_flag, value in worker_append_flag.items():
            if value:
                arguments = arguments + f" --{worker_flag}"

        if self.cfg.nnode > 1:
            pd_cmd = pd_cmd + f" --ips {self.cfg.ips}"
        log_dir = os.getenv("FD_LOG_DIR", default="log")
        pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log"
        llm_logger.info("Launch worker service command: {}".format(pd_cmd))
        p = subprocess.Popen(
            pd_cmd,
            stdout=subprocess.PIPE,
            shell=True,
            preexec_fn=os.setsid,
        )
        return p

    def _format_and_add_data(self, prompts: dict):
        """
        Normalize an offline prompt dict in place and submit it via add_requests().

        Normalization:
            - Assigns a random UUID ``request_id`` when none is given.
            - If a ``context`` list is present, the last "system" utterance
              becomes ``system``, and the "user"/"assistant" utterances, in
              order, become ``prompt``.
            - Defaults ``max_tokens`` to ``cfg.max_model_len``.

        Args:
            prompts (dict): Request dict; mutated in place.

        Returns:
            str: The request id.
        """
        if "request_id" not in prompts:
            prompts["request_id"] = str(uuid.uuid4())

        if "context" in prompts:
            query_list = []
            for item in prompts["context"]:
                if item["role"] == "system":
                    prompts["system"] = item["utterance"]
                elif item["role"] in ["user", "assistant"]:
                    query_list.append(item["utterance"])
            # NOTE(review): `prompt` is only replaced when at least one
            # user/assistant turn exists -- confirm this is intended for
            # system-only contexts.
            if query_list:
                prompts["prompt"] = query_list

        if "max_tokens" not in prompts:
            prompts["max_tokens"] = self.cfg.max_model_len

        self.add_requests(prompts)
        return prompts["request_id"]

    def generate(self, prompts, stream):
        """
        Generate a response for the given prompt (offline inference).

        Args:
            prompts (dict): Request dict (see _format_and_add_data).
            stream (bool): If True, yield each processed partial output followed
                by a final output whose ``text`` and ``reasoning_content`` are
                empty. Otherwise yield only the final output.

        Yields:
            dict: ``data_processor.process_response(result).to_dict()``.

        Raises:
            EngineError: (error_code=400) if the request could not be added.
        """
        llm_logger.info(f"Starting generation for prompt: {prompts}")
        try:
            req_id = self._format_and_add_data(prompts)
        except Exception as e:
            llm_logger.error(
                f"Error happend while adding request, details={e}")
            raise EngineError(str(e), error_code=400) from e

        # Fetch the results of the current request (the generator is
        # _get_generated_result; there is no _get_generated_tokens).
        for result in self._get_generated_result(req_id):
            is_end = result.finished
            if stream and not is_end:
                processed = self.data_processor.process_response(result)
                if processed is None:
                    continue
                output = processed.to_dict()
                yield output

            # Final (finished) result
            if is_end:
                processed = self.data_processor.process_response(result)
                output = processed.to_dict()
                llm_logger.debug(f"Generate result: {output}")
                if not stream:
                    yield output
                else:
                    output["outputs"]["text"] = ""
                    output["outputs"]["reasoning_content"] = ""
                    yield output

    def _stop_profile(self):
        """
        Finish profiling and apply the resulting GPU block count.

        Waits (polling once per second) until every worker's slot in
        ``get_profile_block_num_signal`` is non-zero. It then takes the minimum
        over the workers, resets the cache config with it and pushes it to the
        resource manager.
        """
        self.do_profile = 0
        num_gpu_blocks = -1
        for i in range(self.cfg.tensor_parallel_size):
            while self.get_profile_block_num_signal.value[i] == 0:
                time.sleep(1)
            if num_gpu_blocks < 0:
                num_gpu_blocks = self.get_profile_block_num_signal.value[i]
            else:
                num_gpu_blocks = min(
                    num_gpu_blocks, self.get_profile_block_num_signal.value[i])

        console_logger.info(f"Stop profile, num_gpu_blocks: {num_gpu_blocks}")
        self.cfg.cache_config.reset(num_gpu_blocks)
        self.resource_manager.reset_cache_config(self.cfg.cache_config)

    def check_health(self, time_interval_threashold=30):
        """
        Check whether the workers look alive.

        Only worker 0's entry in ``worker_healthy_live_signal`` is inspected.
        If it is non-zero and more than ``time_interval_threashold`` seconds
        older than ``time.time()``, the service is reported unhealthy. This
        assumes the entry is a Unix timestamp in seconds -- confirm.

        Returns:
            tuple[bool, str]: ``(True, "")`` or
            ``(False, "Worker Service Not Healthy")``.
        """
        if self.worker_healthy_live_signal.value[0]:
            elapsed_time = time.time() - \
                self.worker_healthy_live_signal.value[0]
            if elapsed_time > time_interval_threashold:
                return False, "Worker Service Not Healthy"

        return True, ""

    def check_worker_initialize_status(self):
        """
        Wait for the worker processes to initialize, displaying progress bars.

        A daemon thread parses the worker launcher's stdout (the PIPE from
        _start_worker_service) and records progress in
        ``self.worker_init_status``:
            - weight loading from "Loading checkpoint shards: N";
            - layer loading from "Start load layer N" / "set state for layer N".

        Two tqdm bars are driven from that state until
        ``_worker_processes_ready()`` reports all workers ready.

        Returns:
            bool: True when the workers are ready, False if the worker process
            exits first.
        """

        def detect_thread():
            for line in self.worker_proc.stdout:
                if self.worker_init_status.get("finished", False):
                    # Keep draining instead of breaking: nothing else reads this
                    # pipe, and a full pipe buffer can block the child's writes.
                    continue
                line = line.decode('utf-8', errors='ignore')
                if match := re.search(r'Loading checkpoint shards:\s*(\d+)',
                                      line):
                    # int() rather than eval(): safe, and accepts leading zeros.
                    self.worker_init_status["weight_loadding"] = int(
                        match.group(1)) * 1.0 / 100
                elif (match := re.search(r'Start load layer (\d+)',
                                         line)) or (match := re.search(
                                             r'set state for layer (\d+)',
                                             line)):
                    num_layers = self.cfg.model_config.num_layers
                    layer_id = int(match.group(1))
                    self.worker_init_status["layer_loadding"] = layer_id * 1.0 / num_layers
                    # Compare the layer index (not the progress fraction) with
                    # the index of the last layer.
                    if layer_id == num_layers - 1:
                        self.worker_init_status["finished"] = True

        self.checking_worker_status_thread = threading.Thread(
            target=detect_thread, args=())
        self.checking_worker_status_thread.daemon = True
        self.checking_worker_status_thread.start()

        # display weight loading progress
        with tqdm(total=100, desc="Loading Weights") as pbar:
            progress = 0
            while progress < 100:
                progress = int(
                    self.worker_init_status.get("weight_loadding", 0) * 100)
                if self.worker_init_status.get(
                        "layer_loadding",
                        0) > 0 or self._worker_processes_ready():
                    progress = 100
                pbar.update(progress - pbar.n)
                pbar.refresh()
                time.sleep(0.5)
                if self.worker_proc.poll() is not None:
                    return False

        # display layer loading progress
        with tqdm(total=100, desc="Loading Layers") as pbar:
            progress = 0
            while progress < 100:
                progress = int(
                    self.worker_init_status.get("layer_loadding", 0) * 100)
                if self._worker_processes_ready():
                    progress = 100
                pbar.update(progress - pbar.n)
                pbar.refresh()
                time.sleep(0.5)
                if self.worker_proc.poll() is not None:
                    return False

        # Tell the reader thread to stop parsing. It keeps draining stdout in the
        # background (daemon thread), so it is intentionally not joined.
        self.worker_init_status["finished"] = True
        return True