Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
Sync v2.0 version of code to github repo
@@ -15,22 +15,22 @@
"""

from __future__ import annotations
import sys

import logging
import threading
import time
import traceback
import uuid
import time
from typing import Optional, Dict, List, Any, Union, overload
from typing import Any, Optional, Union

from tqdm import tqdm

from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
from fastdeploy.utils import llm_logger
# from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
from fastdeploy.utils import llm_logger, retrive_model_from_server


import logging
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
    if isinstance(handler, logging.StreamHandler):
@@ -39,23 +39,27 @@ for handler in root_logger.handlers[:]:

class LLM:
    """
    Language Model wrapper class providing high-level interfaces for text generation.

    This class manages the LLMEngine instance and provides convenient methods for
    generating text and chat completions.

    Attributes:
        llm_engine: Underlying LLMEngine instance
        default_sampling_params: Default sampling parameters for generation

    Initializes a Language Model instance.

    Args:
        model: Name of the language model to use
        tokenizer: Name of the tokenizer to use (defaults to model's tokenizer)
        **kwargs: Additional arguments passed to EngineArgs constructor

        model (str):
            The name of the language model to use. Supported models are listed in
            `LLMEngine.SUPPORTED_MODELS`.
        tokenizer (Optional[str], optional):
            The name of the tokenizer to use. Defaults to None. If not specified, the
            default tokenizer for the selected model will be used.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        **kwargs (optional):
            Additional keyword arguments to pass to the `EngineArgs` constructor. See
            `EngineArgs.__init__` for details. Defaults to {}.

    Raises:
        ValueError: If model is not supported
        RuntimeError: If engine fails to start
        ValueError:
            If `model` is not in `LLMEngine.SUPPORTED_MODELS`.
    """

    def __init__(
@@ -64,7 +68,7 @@ class LLM:
        tokenizer: Optional[str] = None,
        **kwargs,
    ):

        model = retrive_model_from_server(model)
        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
@@ -72,14 +76,38 @@ class LLM:
        )

        # Create the Engine
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args)
        self.llm_engine = LLMEngine.from_engine_args(engine_args=engine_args)

        self.default_sampling_params = SamplingParams(
            max_tokens=self.llm_engine.cfg.max_model_len)

        self.llm_engine.start()

        self.mutex = threading.Lock()
        self.req_output = dict()

        self._receive_output_thread = threading.Thread(
            target=self._receive_output, daemon=True)
        self._receive_output_thread.start()

    def _receive_output(self):
        """
        Receive output from the token processor and store it in the cache.
        """
        while True:
            try:
                results = self.llm_engine._get_generated_result()
                for request_id, contents in results.items():
                    with self.mutex:
                        for result in contents:
                            if request_id not in self.req_output:
                                self.req_output[request_id] = result
                                continue
                            self.req_output[request_id].add(result)
            except Exception as e:
                llm_logger.error("Unexpected error happened: {}, {}".format(
                    e, str(traceback.format_exc())))

    def generate(
        self,
        prompts: Union[str, list[str], list[int], list[list[int]],
@@ -89,26 +117,17 @@ class LLM:
        use_tqdm: bool = True,
    ):
        """
        Generate text based on input prompts.

        Supports various input formats including:
        - Single prompt string
        - List of prompt strings
        - Token IDs (single or batched)
        - Dictionary with additional parameters
        - List of parameter dictionaries

        Generate function for the LLM class.

        Args:
            prompts: Input prompts in various formats
            sampling_params: Sampling parameters for generation
            use_tqdm: Whether to show progress bar

            prompts (Union[str, list[str], list[int], list[list[int]], dict[str, Any], list[dict[str, Any]]]):
                The prompt to use for generating the response.
            sampling_params (Optional[Union[SamplingParams, list[SamplingParams]]], optional):
                The sampling parameters to use for generating the response. Defaults to None.
            use_tqdm (bool, optional): Whether to use tqdm for the progress bar. Defaults to True.

        Returns:
            Generated text output(s)

        Raises:
            ValueError: If prompts and sampling_params length mismatch
            TypeError: If prompts format is invalid
            Union[str, list[str]]: The generated response.
        """

        if sampling_params is None:
@@ -126,19 +145,17 @@ class LLM:
            prompts = [prompts]

        if isinstance(prompts, dict):
            if "prompts" not in prompts:
            if "prompt" not in prompts:
                raise ValueError("prompts must be an input dict")
            prompts = [prompts]
            sampling_params = None
            # sampling_params = None

        if sampling_params_len != 1 and len(prompts) != sampling_params_len:
            raise ValueError(
                "prompts and sampling_params must be the same length.")

        req_ids = self._add_request(
            prompts=prompts,
            sampling_params=sampling_params
        )
        req_ids = self._add_request(prompts=prompts,
                                    sampling_params=sampling_params)

        # get output
        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
@@ -146,29 +163,28 @@ class LLM:

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        messages: Union[list[Any], list[list[Any]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Generate chat completions based on conversation messages.

        Args:
            messages: Single conversation or list of conversations
            sampling_params: Sampling parameters for generation
            use_tqdm: Whether to show progress bar

            messages (Union[list[ChatCompletionMessageParam], list[list[ChatCompletionMessageParam]]]):
                Single conversation or a list of conversations.
            sampling_params (Optional[Union[SamplingParams, list[SamplingParams]]], optional):
                The sampling parameters to use for generating the response. Defaults to None.
            use_tqdm (bool, optional): Whether to use tqdm for the progress bar. Defaults to True.
            chat_template_kwargs (Optional[dict[str, Any]]): Additional kwargs to pass to the chat
                template.

        Returns:
            Generated chat response(s)

        Raises:
            ValueError: If messages and sampling_params length mismatch
            Union[str, list[str]]: The generated response.
        """
        if sampling_params is None:
            sampling_params = self.default_sampling_params

        if isinstance(sampling_params, SamplingParams):
            sampling_params_len = 1
        else:
@@ -183,13 +199,10 @@ class LLM:

        messages_len = len(messages)
        for i in range(messages_len):
            messages[i] = {
                "messages": messages[i]
            }
        req_ids = self._add_request(
            prompts=messages,
            sampling_params=sampling_params
        )
            messages[i] = {"messages": messages[i]}
        req_ids = self._add_request(prompts=messages,
                                    sampling_params=sampling_params,
                                    chat_template_kwargs=chat_template_kwargs)

        # get output
        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
@@ -199,20 +212,17 @@ class LLM:
        self,
        prompts,
        sampling_params,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Add generation requests to the LLM engine.

        Add a request to the LLM Engine and return its request ID.
        If the request already exists in the LLM Engine, it will not be added again.

        Args:
            prompts: Input prompts to process
            sampling_params: Sampling parameters for generation

            prompts (str): The text content to process, as a string.

        Returns:
            list: List of generated request IDs

        Raises:
            ValueError: If prompts is None
            TypeError: If prompts format is invalid
            None: No return value; the LLM Engine's state is modified directly.
        """
        if prompts is None:
            raise ValueError("prompts and prompt_ids cannot be both None.")
@@ -226,7 +236,8 @@ class LLM:
                    "prompt": prompts[i],
                    "request_id": request_id,
                }
            elif isinstance(prompts[i], list) and isinstance(prompts[i][0], int):
            elif isinstance(prompts[i], list) and isinstance(
                    prompts[i][0], int):
                tasks = {
                    "prompt_token_ids": prompts[i],
                    "request_id": request_id,
@@ -241,24 +252,29 @@ class LLM:
            req_ids.append(request_id)
            if isinstance(sampling_params, list):
                sampling_params = sampling_params[i]
            self.llm_engine.add_requests(tasks, sampling_params)
            enable_thinking = None
            if chat_template_kwargs is not None:
                enable_thinking = chat_template_kwargs.get(
                    "enable_thinking", None)
            self.llm_engine.add_requests(tasks,
                                         sampling_params,
                                         enable_thinking=enable_thinking)
        return req_ids

    def _run_engine(
        self, req_ids: list[str], use_tqdm: bool
    ):
    def _run_engine(self, req_ids: list[str], use_tqdm: bool):
        """
        Run the engine and collect results for given request IDs.

        Run the engine and return the list of results.

        Args:
            req_ids: List of request IDs to process
            use_tqdm: Whether to show progress bar

            use_tqdm (bool, optional): Whether to show a tqdm progress bar. Defaults to False.

        Returns:
            list: List of generation results

        Note:
            This method blocks until all requests are completed
            list[Dict[str, Any]]: A list of result dicts, one per request, each containing:
                - "text": str, the generated text;
                - "score": float, the score (optional).

        Raises:
            None.
        """
        # Initialize tqdm.

@@ -272,25 +288,31 @@ class LLM:
                 f"output: {0:.2f} toks/s"),
        )

        output = []
        output = [None] * num_requests
        req_ids = [(pos, req_id) for pos, req_id in enumerate(req_ids)]
        while num_requests:
            finished = []
            for i, req_id in enumerate(req_ids):
                try:
                    for result in self.llm_engine._get_generated_result(req_id):
                        result = self.llm_engine.data_processor.process_response(
                            result)
                        llm_logger.debug(
                            f"Send result to client under push mode: {result}")
                        if result.finished:
                            output.append(result)
                            finished.append(i)
                            llm_logger.debug(
                                "Request id: {} has been completed.".format(req_id))
                            if use_tqdm:
                                pbar.update(1)
                except Exception as e:
                    llm_logger.error("Unexpected error happened: {}".format(e))
            for i, (pos, req_id) in enumerate(req_ids):
                with self.mutex:
                    if req_id not in self.req_output:
                        time.sleep(0.01)
                        continue

                    if not self.req_output[req_id].finished:
                        time.sleep(0.01)
                        continue

                    result = self.req_output.pop(req_id)
                    result = self.llm_engine.data_processor.process_response(
                        result)
                    output[pos] = result
                    finished.append(i)

                    llm_logger.debug(
                        "Request id: {} has been completed.".format(req_id))

                    if use_tqdm:
                        pbar.update(1)

            num_requests -= len(finished)
            for i in reversed(finished):
@@ -302,21 +324,27 @@ class LLM:


if __name__ == "__main__":
    # Example usage:
    # llm = LLM(model="llama_model")
    # output = llm.generate(prompts="who are you?", use_tqdm=True)
    # output = llm.generate(prompts="who are you?", use_tqdm=True)
    # print(output)
    llm = LLM(model="/opt/baidu/paddle_internal/FastDeploy/Qwen2.5-7B",
              tensor_parallel_size=2)
    sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
    output = llm.generate(prompts="who are you?",
                          use_tqdm=True, sampling_params=sampling_params)
                          use_tqdm=True,
                          sampling_params=sampling_params)
    print(output)

    output = llm.generate(prompts=["who are you?", "what can you do?"], sampling_params=SamplingParams(
        temperature=1, max_tokens=50), use_tqdm=True)
    output = llm.generate(prompts=["who are you?", "what can you do?"],
                          sampling_params=SamplingParams(temperature=1,
                                                         max_tokens=50),
                          use_tqdm=True)
    print(output)

    output = llm.generate(prompts=["who are you?", "I miss you"], sampling_params=[SamplingParams(
        temperature=1, max_tokens=50), SamplingParams(temperature=1, max_tokens=20)], use_tqdm=True)
    output = llm.generate(prompts=["who are you?", "I miss you"],
                          sampling_params=[
                              SamplingParams(temperature=1, max_tokens=50),
                              SamplingParams(temperature=1, max_tokens=20)
                          ],
                          use_tqdm=True)
    print(output)
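    # A chat() call would follow the same pattern; the message below is an assumption
    # (role/content dicts per ChatCompletionMessageParam) and is not part of this commit:
    #
    # output = llm.chat(messages=[{"role": "user", "content": "who are you?"}],
    #                   sampling_params=SamplingParams(temperature=0.1, max_tokens=30),
    #                   use_tqdm=True)
    # print(output)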