FastDeploy/fastdeploy/entrypoints/llm.py
luukunn 132a8ef425 Release/2.1 (#3414)
2025-08-14 20:53:47 +08:00

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations

import logging
import threading
import time
import traceback
import uuid
from typing import Any, Optional, Union

from tqdm import tqdm

from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.engine.sampling_params import SamplingParams

# from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
from fastdeploy.utils import llm_logger, retrive_model_from_server
from fastdeploy.worker.output import Logprob, LogprobsLists

root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
    if isinstance(handler, logging.StreamHandler):
        root_logger.removeHandler(handler)


class LLM:
    """
    Initializes a Language Model instance.

    Args:
        model (str):
            The name of the language model to use. Supported models are listed in
            `LLMEngine.SUPPORTED_MODELS`.
        tokenizer (Optional[str], optional):
            The name of the tokenizer to use. Defaults to None. If not specified, the
            default tokenizer for the selected model will be used.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause
            out-of-memory (OOM) errors.
        **kwargs (optional):
            Additional keyword arguments to pass to the `EngineArgs` constructor. See
            `EngineArgs.__init__` for details. Defaults to {}.

    Raises:
        ValueError:
            If `model` is not in `LLMEngine.SUPPORTED_MODELS`.
    """
    def __init__(
        self,
        model: str,
        revision: Optional[str] = "master",
        tokenizer: Optional[str] = None,
        enable_logprob: Optional[bool] = False,
        **kwargs,
    ):
        model = retrive_model_from_server(model, revision)
        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            enable_logprob=enable_logprob,
            **kwargs,
        )

        # Create the Engine
        self.llm_engine = LLMEngine.from_engine_args(engine_args=engine_args)
        self.default_sampling_params = SamplingParams(max_tokens=self.llm_engine.cfg.max_model_len)

        self.llm_engine.start()
        self.mutex = threading.Lock()
        self.req_output = dict()
        self.master_node_ip = self.llm_engine.cfg.master_ip
        self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
        self._receive_output_thread.start()

    def _check_master(self):
        """
        Check if the current node is the master node.
        """
        return self.llm_engine.cfg._check_master()

    def _receive_output(self):
        """
        Receive output from the token processor and store it in the cache.
        """
        while True:
            try:
                results = self.llm_engine._get_generated_result()
                for request_id, contents in results.items():
                    with self.mutex:
                        for result in contents:
                            if request_id not in self.req_output:
                                self.req_output[request_id] = result
                                continue
                            self.req_output[request_id].add(result)
            except Exception as e:
                llm_logger.error(f"Unexpected error happened: {e}, {traceback.format_exc()!s}")
    def generate(
        self,
        prompts: Union[
            str,
            list[str],
            list[int],
            list[list[int]],
            dict[str, Any],
            list[dict[str, Any]],
        ],
        sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
        use_tqdm: bool = True,
    ):
        """
        Generate function for the LLM class.

        Args:
            prompts (Union[str, list[str], list[int], list[list[int]], dict[str, Any], list[dict[str, Any]]]):
                The prompt(s) to use for generating the response.
            sampling_params (Optional[Union[SamplingParams, list[SamplingParams]]], optional):
                The sampling parameters to use for generating the response. Defaults to None.
            use_tqdm (bool, optional): Whether to use tqdm for the progress bar. Defaults to True.

        Returns:
            list: The generated outputs, one entry per prompt.
        """
        if not self._check_master():
            err_msg = f"Only the master node can accept completion requests, please send the request to the master node: {self.master_node_ip}"
            raise ValueError(err_msg)

        if sampling_params is None:
            sampling_params = self.default_sampling_params

        if isinstance(sampling_params, SamplingParams):
            sampling_params_len = 1
        else:
            sampling_params_len = len(sampling_params)

        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(prompts, list) and isinstance(prompts[0], int):
            prompts = [prompts]
        if isinstance(prompts, dict):
            if "prompt" not in prompts:
                raise ValueError("prompts must be an input dict with a 'prompt' key.")
            prompts = [prompts]

        # sampling_params = None
        if sampling_params_len != 1 and len(prompts) != sampling_params_len:
            raise ValueError("prompts and sampling_params must be the same length.")

        req_ids = self._add_request(prompts=prompts, sampling_params=sampling_params)

        topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs

        # get output
        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
        for i in range(len(outputs)):
            outputs[i].prompt = prompts[i]
        return outputs
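
    # Usage sketch for generate() (illustrative, not part of the original file).
    # Prompts may be a single string, a list of strings, a token-id list, a list of
    # token-id lists, or dict(s) containing at least a "prompt" key:
    #
    #     llm.generate(prompts="who are you")
    #     llm.generate(prompts=["who are you", "what can you do"])
    #     llm.generate(prompts=[[1, 2, 3, 4]])  # pre-tokenized input
    #     llm.generate(prompts={"prompt": "who are you"},
    #                  sampling_params=SamplingParams(max_tokens=32))
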
    def chat(
        self,
        messages: Union[list[Any], list[list[Any]]],
        sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Args:
            messages (Union[list[ChatCompletionMessageParam], list[list[ChatCompletionMessageParam]]]):
                Single conversation or a list of conversations.
            sampling_params (Optional[Union[SamplingParams, list[SamplingParams]]], optional):
                The sampling parameters to use for generating the response. Defaults to None.
            use_tqdm (bool, optional): Whether to use tqdm for the progress bar. Defaults to True.
            chat_template_kwargs (Optional[dict[str, Any]]): Additional kwargs to pass to the chat
                template.

        Returns:
            list: The generated outputs, one entry per conversation.
        """
        if not self._check_master():
            err_msg = f"Only the master node can accept completion requests, please send the request to the master node: {self.master_node_ip}"
            raise ValueError(err_msg)

        if sampling_params is None:
            sampling_params = self.default_sampling_params

        if isinstance(sampling_params, SamplingParams):
            sampling_params_len = 1
        else:
            sampling_params_len = len(sampling_params)

        if isinstance(messages, list) and isinstance(messages[0], dict):
            messages = [messages]

        if sampling_params_len != 1 and len(messages) != sampling_params_len:
            raise ValueError("messages and sampling_params must be the same length.")

        messages_len = len(messages)
        for i in range(messages_len):
            messages[i] = {"messages": messages[i]}

        req_ids = self._add_request(
            prompts=messages,
            sampling_params=sampling_params,
            chat_template_kwargs=chat_template_kwargs,
        )

        topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs

        # get output
        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
        return outputs
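
    # Usage sketch for chat() (illustrative, not part of the original file). A single
    # conversation is a list of role/content dicts and a batch is a list of such
    # conversations; chat_template_kwargs is forwarded to the chat template
    # (e.g. {"enable_thinking": False}):
    #
    #     llm.chat(messages=[{"role": "user", "content": "who are you"}])
    #     llm.chat(
    #         messages=[[{"role": "user", "content": "hi"}],
    #                   [{"role": "user", "content": "bye"}]],
    #         sampling_params=SamplingParams(temperature=0.8, max_tokens=64),
    #         chat_template_kwargs={"enable_thinking": False},
    #     )
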
    def _add_request(
        self,
        prompts,
        sampling_params,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Add requests to the LLM Engine and return their request IDs.

        Args:
            prompts: The prompts to process; each entry may be a string, a list of
                token IDs, or a task dict.
            sampling_params: A single SamplingParams instance, or a list with one
                instance per prompt.
            chat_template_kwargs (Optional[dict[str, Any]]): Additional kwargs to pass
                to the chat template.

        Returns:
            list[str]: The request IDs of the added requests.
        """
        if prompts is None:
            raise ValueError("prompts and prompt_ids cannot be both None.")

        prompts_len = len(prompts)
        req_ids = []
        for i in range(prompts_len):
            request_id = str(uuid.uuid4())
            if isinstance(prompts[i], str):
                tasks = {
                    "prompt": prompts[i],
                    "request_id": request_id,
                }
            elif isinstance(prompts[i], list) and isinstance(prompts[i][0], int):
                tasks = {
                    "prompt_token_ids": prompts[i],
                    "request_id": request_id,
                }
            elif isinstance(prompts[i], dict):
                tasks = prompts[i]
                tasks["request_id"] = request_id
            else:
                raise TypeError(
                    f"Invalid type for 'prompt': {type(prompts[i])}, expected one of ['str', 'list', 'dict']."
                )
            req_ids.append(request_id)
            if isinstance(sampling_params, list):
                current_sampling_params = sampling_params[i]
            else:
                current_sampling_params = sampling_params
            enable_thinking = None
            if chat_template_kwargs is not None:
                enable_thinking = chat_template_kwargs.get("enable_thinking", None)
            self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
        return req_ids
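
    # Sketch of the task dicts built above (one per prompt, based on the branches in
    # _add_request; the ids shown are illustrative):
    #
    #     {"prompt": "who are you", "request_id": "<uuid4>"}        # str prompt
    #     {"prompt_token_ids": [1, 2, 3], "request_id": "<uuid4>"}  # token ids
    #     {"messages": [...], "request_id": "<uuid4>"}              # chat() input
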
    def _decode_token(self, token_id: int) -> str:
        """Decodes a single token ID into its string representation."""
        return self.llm_engine.data_processor.process_logprob_response([token_id], clean_up_tokenization_spaces=False)

    def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: int) -> list[dict[int, Logprob]]:
        """
        Constructs a list of dictionaries mapping token IDs to Logprob objects,
        based on sliced LogprobsLists data (excluding the sampled token at index 0).

        Args:
            logprobs_lists (LogprobsLists): Contains top-k token IDs, logprobs, and sampled ranks.
            topk_logprobs (int): Maximum number of top logprobs to include (excluding the sampled token at index 0).

        Returns:
            list[dict[int, Logprob]]: One dict per request, mapping token ID to Logprob.
        """
        try:
            llm_logger.info(f"filter logprobs, topk_logprobs: {topk_logprobs}")
            if not logprobs_lists.logprob_token_ids:
                llm_logger.warning("Empty logprob_token_ids in LogprobsLists")
                return None

            # exclude sampled token at index 0
            available_topk = len(logprobs_lists.logprob_token_ids[0]) - 1
            effective_topk_logprobs = min(topk_logprobs, available_topk)

            if effective_topk_logprobs <= 0:
                llm_logger.warning(
                    f"Invalid effective_topk_logprobs={effective_topk_logprobs}, "
                    f"available_topk={available_topk}, topk_logprobs={topk_logprobs}; returning empty result."
                )
                return None

            # slice columns 1 ~ (1 + effective_topk_logprobs)
            sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs)
            result = []
            for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs):
                logprob_dict = {
                    token_id: Logprob(logprob=logprob, rank=i + 1, decoded_token=self._decode_token(token_id))
                    for i, (token_id, logprob) in enumerate(zip(token_ids, logprobs))
                }
                result.append(logprob_dict)
            return result

        except Exception as e:
            llm_logger.error(f"Error building sample logprobs from LogprobsLists: {e}")
    def _run_engine(self, req_ids: list[str], use_tqdm: bool, topk_logprobs: Optional[int] = None):
        """
        Run the engine and collect the results for the given requests.

        Args:
            req_ids (list[str]): IDs of the requests to wait for.
            use_tqdm (bool): Whether to show a tqdm progress bar.
            topk_logprobs (Optional[int]): Number of top logprobs to keep per output, if set.

        Returns:
            list: The processed result of each request, in the same order as `req_ids`.
        """
        # Initialize tqdm.
        num_requests = len(req_ids)
        if use_tqdm:
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, " f"output: {0:.2f} toks/s"),
            )

        output = [None] * num_requests
        req_ids = [(pos, req_id) for pos, req_id in enumerate(req_ids)]
        while num_requests:
            finished = []
            for i, (pos, req_id) in enumerate(req_ids):
                with self.mutex:
                    if req_id not in self.req_output:
                        time.sleep(0.01)
                        continue
                    if not self.req_output[req_id].finished:
                        time.sleep(0.01)
                        continue

                    result = self.req_output.pop(req_id)
                    result = self.llm_engine.data_processor.process_response(result)

                    # filter logprobs
                    if result.outputs.top_logprobs and topk_logprobs:
                        result.outputs.logprobs = self._build_sample_logprobs(
                            result.outputs.top_logprobs, topk_logprobs
                        )

                    output[pos] = result
                    finished.append(i)
                    llm_logger.debug(f"Request id: {req_id} has been completed.")

                    if use_tqdm:
                        pbar.update(1)

            num_requests -= len(finished)
            for i in reversed(finished):
                req_ids.pop(i)

        if use_tqdm:
            pbar.close()
        return output


if __name__ == "__main__":
# llm = LLM(model="llama_model")
# output = llm.generate(prompts="who are you", use_tqdm=True)
# print(output)
llm = LLM(
model="/opt/baidu/paddle_internal/FastDeploy/Qwen2.5-7B",
tensor_parallel_size=2,
)
sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
output = llm.generate(prompts="who are you", use_tqdm=True, sampling_params=sampling_params)
print(output)
output = llm.generate(
prompts=["who are you", "what can you do"],
sampling_params=SamplingParams(temperature=1, max_tokens=50),
use_tqdm=True,
)
print(output)
output = llm.generate(
prompts=["who are you", "I miss you"],
sampling_params=[
SamplingParams(temperature=1, max_tokens=50),
SamplingParams(temperature=1, max_tokens=20),
],
use_tqdm=True,
)
print(output)
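
    # Additional illustrative sketch (not part of the original file): the same LLM
    # instance can also serve chat-style requests; a single conversation is passed
    # as a list of role/content dicts.
    chat_output = llm.chat(
        messages=[{"role": "user", "content": "who are you"}],
        sampling_params=SamplingParams(temperature=0.1, max_tokens=30),
        use_tqdm=True,
    )
    print(chat_output)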