Files
FastDeploy/fastdeploy/entrypoints/llm.py
2025-06-16 00:04:48 +08:00

323 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import sys
import traceback
import uuid
import time
from typing import Optional, Dict, List, Any, Union, overload
from tqdm import tqdm
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
from fastdeploy.utils import llm_logger
import logging
# Import-time side effect: detach every StreamHandler currently attached to
# the root logger. The candidates are collected into a separate list first so
# the logger's handler list is not modified while it is being iterated.
# NOTE(review): logging.FileHandler is a subclass of logging.StreamHandler, so
# file handlers on the root logger are detached as well - confirm that this is
# intended and not only console output.
root_logger = logging.getLogger()
for handler in [
        candidate for candidate in root_logger.handlers
        if isinstance(candidate, logging.StreamHandler)
]:
    root_logger.removeHandler(handler)
class LLM:
    """
    Language Model wrapper class providing high-level interfaces for text generation.

    This class builds and starts an LLMEngine instance and provides convenient
    blocking methods for generating text and chat completions.

    Attributes:
        llm_engine: Underlying LLMEngine instance (already started).
        default_sampling_params: SamplingParams used when a call does not
            supply its own; its ``max_tokens`` is the engine's
            ``cfg.max_model_len``.

    Args:
        model: Name of the language model to use
        tokenizer: Name of the tokenizer to use, forwarded to EngineArgs
        **kwargs: Additional arguments passed to EngineArgs constructor

    Raises:
        Any exception raised by ``EngineArgs``, ``LLMEngine.from_engine_args``
        or ``LLMEngine.start`` propagates unchanged.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        **kwargs,
    ):
        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            **kwargs,
        )

        # Create the Engine
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args)

        self.default_sampling_params = SamplingParams(
            max_tokens=self.llm_engine.cfg.max_model_len)

        self.llm_engine.start()

    def generate(
        self,
        prompts: Union[str, list[str], list[int], list[list[int]],
                       dict[str, Any], list[dict[str, Any]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
    ):
        """
        Generate text based on input prompts.

        Supports various input formats including:
        - Single prompt string
        - List of prompt strings
        - Token IDs (single or batched)
        - Dictionary with additional parameters
        - List of parameter dictionaries

        Args:
            prompts: Input prompts in various formats
            sampling_params: One SamplingParams shared by all prompts, or a
                list with one entry per prompt. Defaults to
                ``self.default_sampling_params``.
            use_tqdm: Whether to show progress bar

        Returns:
            list: Finished results as collected by ``_run_engine``. An empty
            prompt list yields an empty list.

        Raises:
            ValueError: If prompts and sampling_params length mismatch, or a
                single dict prompt has no "prompts" key
            TypeError: If an individual prompt has an unsupported type
        """
        if sampling_params is None:
            sampling_params = self.default_sampling_params

        if isinstance(sampling_params, SamplingParams):
            sampling_params_len = 1
        else:
            sampling_params_len = len(sampling_params)

        # Normalise every accepted input shape to a list of prompts.
        if isinstance(prompts, str):
            prompts = [prompts]

        # A flat list of ints is one tokenised prompt. The emptiness guard
        # keeps an empty list from raising IndexError on prompts[0].
        if isinstance(prompts, list) and prompts and isinstance(prompts[0], int):
            prompts = [prompts]

        if isinstance(prompts, dict):
            # NOTE(review): the key required here is "prompts", while string
            # prompts are submitted to the engine under "prompt" (see
            # _add_request) - confirm which key the engine expects.
            if "prompts" not in prompts:
                raise ValueError("prompts must be a input dict")
            prompts = [prompts]
            # The engine receives None as sampling params for a single dict
            # prompt; presumably the dict carries its own parameters - verify.
            sampling_params = None

        if sampling_params_len != 1 and len(prompts) != sampling_params_len:
            raise ValueError(
                "prompts and sampling_params must be the same length.")

        req_ids = self._add_request(
            prompts=prompts,
            sampling_params=sampling_params
        )

        # get output
        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
        return outputs

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
    ):
        """
        Generate chat completions based on conversation messages.

        Args:
            messages: Single conversation (list of message dicts) or list of
                conversations. The argument is not modified.
            sampling_params: One SamplingParams shared by all conversations,
                or a list with one entry per conversation. Defaults to
                ``self.default_sampling_params``.
            use_tqdm: Whether to show progress bar

        Returns:
            list: Finished results as collected by ``_run_engine``. An empty
            message list yields an empty list.

        Raises:
            ValueError: If messages and sampling_params length mismatch
        """
        if sampling_params is None:
            sampling_params = self.default_sampling_params

        if isinstance(sampling_params, SamplingParams):
            sampling_params_len = 1
        else:
            sampling_params_len = len(sampling_params)

        # A list whose first element is a dict is a single conversation. The
        # emptiness guard keeps an empty list from raising IndexError.
        if isinstance(messages, list) and messages and isinstance(messages[0], dict):
            messages = [messages]

        if sampling_params_len != 1 and len(messages) != sampling_params_len:
            raise ValueError(
                "messages and sampling_params must be the same length.")

        # Build a fresh list of request dicts instead of overwriting the
        # entries of the caller's list, so the same batch can be reused.
        chat_tasks = [{"messages": conversation} for conversation in messages]

        req_ids = self._add_request(
            prompts=chat_tasks,
            sampling_params=sampling_params
        )

        # get output
        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
        return outputs

    def _add_request(
        self,
        prompts,
        sampling_params,
    ):
        """
        Add generation requests to the LLM engine.

        Each prompt gets a fresh UUID4 request id and is submitted through
        ``self.llm_engine.add_requests(task, sampling_params)``.

        Args:
            prompts: Sequence of prompts; each item is a string, a list of
                token ids, or a dict. A dict item is submitted as-is after
                its "request_id" key is set (the dict itself is updated).
            sampling_params: A single value applied to every prompt, or a
                list indexed per prompt. A one-element list is applied to
                every prompt.

        Returns:
            list: List of generated request IDs, in submission order

        Raises:
            ValueError: If prompts is None
            TypeError: If prompts format is invalid
        """
        if prompts is None:
            raise ValueError("prompts and prompt_ids cannot be both None.")

        per_prompt_params = isinstance(sampling_params, list)
        req_ids = []
        for i, prompt in enumerate(prompts):
            request_id = str(uuid.uuid4())
            if isinstance(prompt, str):
                tasks = {
                    "prompt": prompt,
                    "request_id": request_id,
                }
            elif isinstance(prompt, list) and isinstance(prompt[0], int):
                tasks = {
                    "prompt_token_ids": prompt,
                    "request_id": request_id,
                }
            elif isinstance(prompt, dict):
                tasks = prompt
                tasks["request_id"] = request_id
            else:
                raise TypeError(
                    f"Invalid type for 'prompt': {type(prompt)}, expected one of ['str', 'list', 'dict']."
                )
            req_ids.append(request_id)

            # Select this prompt's params without rebinding `sampling_params`:
            # rebinding would make every later prompt reuse the first entry.
            if per_prompt_params:
                if len(sampling_params) == 1:
                    current_params = sampling_params[0]
                else:
                    current_params = sampling_params[i]
            else:
                current_params = sampling_params
            self.llm_engine.add_requests(tasks, current_params)
        return req_ids

    def _run_engine(
        self, req_ids: list[str], use_tqdm: bool
    ):
        """
        Run the engine and collect results for given request IDs.

        Args:
            req_ids: List of request IDs to process (not modified)
            use_tqdm: Whether to show progress bar

        Returns:
            list: Finished results, appended in the order they are observed
            as finished, which is not necessarily the submission order

        Note:
            This method blocks until all requests are completed. An exception
            raised while polling a request is logged and the request is
            polled again on the next pass.
        """
        # Work on a copy so the caller's list is left untouched.
        pending = list(req_ids)
        # Must be set regardless of use_tqdm: the polling loop depends on it.
        num_requests = len(pending)

        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        output = []
        while num_requests:
            finished = []
            for i, req_id in enumerate(pending):
                try:
                    for result in self.llm_engine._get_generated_result(req_id):
                        result = self.llm_engine.data_processor.process_response(
                            result)
                        llm_logger.debug(
                            f"Send result to client under push mode: {result}")
                        if result.finished:
                            output.append(result)
                            finished.append(i)
                            llm_logger.debug(
                                "Request id: {} has been completed.".format(req_id))
                            if pbar is not None:
                                pbar.update(1)
                except Exception as e:
                    # NOTE(review): a request that fails on every poll is
                    # retried forever; consider a retry limit.
                    llm_logger.error("Unexcepted error happend: {}".format(e))

            num_requests -= len(finished)
            # Pop from the back so earlier indices stay valid.
            for i in reversed(finished):
                pending.pop(i)

        if pbar is not None:
            pbar.close()
        return output
if __name__ == "__main__":
    # Manual smoke test: start an engine and run three generate() variants.
    # Example usage:
    # llm = LLM(model="llama_model")
    # output = llm.generate(prompts="who are you?", use_tqdm=True)
    # print(output)
    llm = LLM(
        model="/opt/baidu/paddle_internal/FastDeploy/Qwen2.5-7B",
        tensor_parallel_size=2,
    )

    # 1) Single string prompt with one SamplingParams.
    sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
    output = llm.generate(
        prompts="who are you",
        use_tqdm=True,
        sampling_params=sampling_params,
    )
    print(output)

    # 2) Batch of prompts sharing one SamplingParams.
    shared_params = SamplingParams(temperature=1, max_tokens=50)
    output = llm.generate(
        prompts=["who are you", "what can you do"],
        sampling_params=shared_params,
        use_tqdm=True,
    )
    print(output)

    # 3) Batch of prompts with one SamplingParams per prompt.
    per_prompt_params = [
        SamplingParams(temperature=1, max_tokens=50),
        SamplingParams(temperature=1, max_tokens=20),
    ]
    output = llm.generate(
        prompts=["who are you", "I miss you"],
        sampling_params=per_prompt_params,
        use_tqdm=True,
    )
    print(output)