Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 08:16:42 +08:00)
polish code with new pre-commit rule (#2923)
@@ -28,6 +28,7 @@ from tqdm import tqdm
 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.engine.engine import LLMEngine
 from fastdeploy.engine.sampling_params import SamplingParams

 # from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
 from fastdeploy.utils import llm_logger, retrive_model_from_server

@@ -78,18 +79,16 @@ class LLM:
         # Create the Engine
         self.llm_engine = LLMEngine.from_engine_args(engine_args=engine_args)

-        self.default_sampling_params = SamplingParams(
-            max_tokens=self.llm_engine.cfg.max_model_len)
+        self.default_sampling_params = SamplingParams(max_tokens=self.llm_engine.cfg.max_model_len)

         self.llm_engine.start()

         self.mutex = threading.Lock()
         self.req_output = dict()
         self.master_node_ip = self.llm_engine.cfg.master_ip
-        self._receive_output_thread = threading.Thread(
-            target=self._receive_output, daemon=True)
+        self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
         self._receive_output_thread.start()

     def _check_master(self):
         """
         Check if the current node is the master node.
@@ -111,15 +110,19 @@ class LLM:
                        continue
                    self.req_output[request_id].add(result)
        except Exception as e:
-            llm_logger.error("Unexcepted error happend: {}, {}".format(
-                e, str(traceback.format_exc())))
+            llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")

     def generate(
         self,
-        prompts: Union[str, list[str], list[int], list[list[int]],
-                       dict[str, Any], list[dict[str, Any]]],
-        sampling_params: Optional[Union[SamplingParams,
-                                        list[SamplingParams]]] = None,
+        prompts: Union[
+            str,
+            list[str],
+            list[int],
+            list[list[int]],
+            dict[str, Any],
+            list[dict[str, Any]],
+        ],
+        sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
         use_tqdm: bool = True,
     ):
         """
@@ -161,11 +164,9 @@ class LLM:
         # sampling_params = None

         if sampling_params_len != 1 and len(prompts) != sampling_params_len:
-            raise ValueError(
-                "prompts and sampling_params must be the same length.")
+            raise ValueError("prompts and sampling_params must be the same length.")

-        req_ids = self._add_request(prompts=prompts,
-                                    sampling_params=sampling_params)
+        req_ids = self._add_request(prompts=prompts, sampling_params=sampling_params)

         # get output
         outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
@@ -176,8 +177,7 @@ class LLM:
     def chat(
         self,
         messages: Union[list[Any], list[list[Any]]],
-        sampling_params: Optional[Union[SamplingParams,
-                                        list[SamplingParams]]] = None,
+        sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
         use_tqdm: bool = True,
         chat_template_kwargs: Optional[dict[str, Any]] = None,
     ):
@@ -198,7 +198,7 @@ class LLM:
         if not self._check_master():
             err_msg = f"Only master node can accept completion request, please send request to master node: {self.master_node_ip}"
             raise ValueError(err_msg)

         if sampling_params is None:
             sampling_params = self.default_sampling_params

@@ -211,15 +211,16 @@ class LLM:
             messages = [messages]

         if sampling_params_len != 1 and len(messages) != sampling_params_len:
-            raise ValueError(
-                "messages and sampling_params must be the same length.")
+            raise ValueError("messages and sampling_params must be the same length.")

         messages_len = len(messages)
         for i in range(messages_len):
             messages[i] = {"messages": messages[i]}
-        req_ids = self._add_request(prompts=messages,
-                                    sampling_params=sampling_params,
-                                    chat_template_kwargs=chat_template_kwargs)
+        req_ids = self._add_request(
+            prompts=messages,
+            sampling_params=sampling_params,
+            chat_template_kwargs=chat_template_kwargs,
+        )

         # get output
         outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
@@ -253,8 +254,7 @@ class LLM:
                     "prompt": prompts[i],
                     "request_id": request_id,
                 }
-            elif isinstance(prompts[i], list) and isinstance(
-                    prompts[i][0], int):
+            elif isinstance(prompts[i], list) and isinstance(prompts[i][0], int):
                 tasks = {
                     "prompt_token_ids": prompts[i],
                     "request_id": request_id,
@@ -273,11 +273,8 @@ class LLM:
                 current_sampling_params = sampling_params
             enable_thinking = None
             if chat_template_kwargs is not None:
-                enable_thinking = chat_template_kwargs.get(
-                    "enable_thinking", None)
-            self.llm_engine.add_requests(tasks,
-                                         current_sampling_params,
-                                         enable_thinking=enable_thinking)
+                enable_thinking = chat_template_kwargs.get("enable_thinking", None)
+            self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
         return req_ids

     def _run_engine(self, req_ids: list[str], use_tqdm: bool):
@@ -303,8 +300,7 @@ class LLM:
                 total=num_requests,
                 desc="Processed prompts",
                 dynamic_ncols=True,
-                postfix=(f"est. speed input: {0:.2f} toks/s, "
-                         f"output: {0:.2f} toks/s"),
+                postfix=(f"est. speed input: {0:.2f} toks/s, " f"output: {0:.2f} toks/s"),
             )

         output = [None] * num_requests
@@ -322,13 +318,11 @@ class LLM:
                     continue

                 result = self.req_output.pop(req_id)
-                result = self.llm_engine.data_processor.process_response(
-                    result)
+                result = self.llm_engine.data_processor.process_response(result)
                 output[pos] = result
                 finished.append(i)

-                llm_logger.debug(
-                    "Request id: {} has been completed.".format(req_id))
+                llm_logger.debug(f"Request id: {req_id} has been completed.")

                 if use_tqdm:
                     pbar.update(1)
@@ -346,24 +340,27 @@ if __name__ == "__main__":
     # llm = LLM(model="llama_model")
     # output = llm.generate(prompts="who are you?", use_tqdm=True)
     # print(output)
-    llm = LLM(model="/opt/baidu/paddle_internal/FastDeploy/Qwen2.5-7B",
-              tensor_parallel_size=2)
+    llm = LLM(
+        model="/opt/baidu/paddle_internal/FastDeploy/Qwen2.5-7B",
+        tensor_parallel_size=2,
+    )
     sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
-    output = llm.generate(prompts="who are you?",
-                          use_tqdm=True,
-                          sampling_params=sampling_params)
+    output = llm.generate(prompts="who are you?", use_tqdm=True, sampling_params=sampling_params)
     print(output)

-    output = llm.generate(prompts=["who are you?", "what can you do?"],
-                          sampling_params=SamplingParams(temperature=1,
-                                                         max_tokens=50),
-                          use_tqdm=True)
+    output = llm.generate(
+        prompts=["who are you?", "what can you do?"],
+        sampling_params=SamplingParams(temperature=1, max_tokens=50),
+        use_tqdm=True,
+    )
     print(output)

-    output = llm.generate(prompts=["who are you?", "I miss you"],
-                          sampling_params=[
-                              SamplingParams(temperature=1, max_tokens=50),
-                              SamplingParams(temperature=1, max_tokens=20)
-                          ],
-                          use_tqdm=True)
+    output = llm.generate(
+        prompts=["who are you?", "I miss you"],
+        sampling_params=[
+            SamplingParams(temperature=1, max_tokens=50),
+            SamplingParams(temperature=1, max_tokens=20),
+        ],
+        use_tqdm=True,
+    )
     print(output)
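
For orientation, a minimal usage sketch of the chat() entry point whose signature is reformatted above. The keyword arguments (sampling_params, use_tqdm, chat_template_kwargs with enable_thinking) come from the signature in this diff; the import path, the model directory, and the role/content message schema are illustrative assumptions, not something this commit confirms.

# Hedged sketch only: import path, model path, and message schema below are assumptions.
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM  # assumed module path for the LLM class shown in the diff

# Hypothetical local checkpoint directory; replace with a real model path.
llm = LLM(model="./Qwen2.5-7B", tensor_parallel_size=2)

sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
output = llm.chat(
    messages=[{"role": "user", "content": "who are you?"}],  # assumed chat message format
    sampling_params=sampling_params,
    use_tqdm=True,
    chat_template_kwargs={"enable_thinking": False},  # forwarded to add_requests(), per the diff
)
print(output)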