polish code with new pre-commit rule (#2923)

Zero Rains authored 2025-07-19 23:19:27 +08:00, committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions


@@ -28,6 +28,7 @@ from tqdm import tqdm
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.engine.sampling_params import SamplingParams
# from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
from fastdeploy.utils import llm_logger, retrive_model_from_server
@@ -78,18 +79,16 @@ class LLM:
# Create the Engine
self.llm_engine = LLMEngine.from_engine_args(engine_args=engine_args)
self.default_sampling_params = SamplingParams(
max_tokens=self.llm_engine.cfg.max_model_len)
self.default_sampling_params = SamplingParams(max_tokens=self.llm_engine.cfg.max_model_len)
self.llm_engine.start()
self.mutex = threading.Lock()
self.req_output = dict()
self.master_node_ip = self.llm_engine.cfg.master_ip
self._receive_output_thread = threading.Thread(
target=self._receive_output, daemon=True)
self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
self._receive_output_thread.start()
def _check_master(self):
"""
Check if the current node is the master node.
@@ -111,15 +110,19 @@ class LLM:
continue
self.req_output[request_id].add(result)
except Exception as e:
llm_logger.error("Unexcepted error happend: {}, {}".format(
e, str(traceback.format_exc())))
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
def generate(
self,
prompts: Union[str, list[str], list[int], list[list[int]],
dict[str, Any], list[dict[str, Any]]],
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
prompts: Union[
str,
list[str],
list[int],
list[list[int]],
dict[str, Any],
list[dict[str, Any]],
],
sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
use_tqdm: bool = True,
):
"""
@@ -161,11 +164,9 @@ class LLM:
# sampling_params = None
if sampling_params_len != 1 and len(prompts) != sampling_params_len:
raise ValueError(
"prompts and sampling_params must be the same length.")
raise ValueError("prompts and sampling_params must be the same length.")
req_ids = self._add_request(prompts=prompts,
sampling_params=sampling_params)
req_ids = self._add_request(prompts=prompts, sampling_params=sampling_params)
# get output
outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
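For readers skimming the hunk above, here is a minimal usage sketch of the reformatted generate() signature, covering the plain-string and pre-tokenized prompt forms. The model path and token ids are placeholders, and the fastdeploy.entrypoints.llm import path is inferred from this file rather than confirmed by the diff.
# Sketch only: paths, token ids, and the LLM import location are assumptions.
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM  # inferred module path for the class in this diff

llm = LLM(model="/path/to/model", tensor_parallel_size=1)

# 1) A single string prompt; sampling defaults come from the engine config.
outputs = llm.generate(prompts="who are you", use_tqdm=True)

# 2) Pre-tokenized prompts (list[list[int]]) are routed to "prompt_token_ids"
#    by _add_request (see the hunk further down), so no tokenization happens here.
outputs = llm.generate(
    prompts=[[1, 2, 3, 4], [5, 6, 7]],  # placeholder token ids
    sampling_params=SamplingParams(temperature=1, max_tokens=50),
    use_tqdm=True,
)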
@@ -176,8 +177,7 @@ class LLM:
def chat(
self,
messages: Union[list[Any], list[list[Any]]],
sampling_params: Optional[Union[SamplingParams,
list[SamplingParams]]] = None,
sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
use_tqdm: bool = True,
chat_template_kwargs: Optional[dict[str, Any]] = None,
):
@@ -198,7 +198,7 @@ class LLM:
if not self._check_master():
err_msg = f"Only master node can accept completion request, please send request to master node: {self.master_node_ip}"
raise ValueError(err_msg)
if sampling_params is None:
sampling_params = self.default_sampling_params
@@ -211,15 +211,16 @@ class LLM:
messages = [messages]
if sampling_params_len != 1 and len(messages) != sampling_params_len:
raise ValueError(
"messages and sampling_params must be the same length.")
raise ValueError("messages and sampling_params must be the same length.")
messages_len = len(messages)
for i in range(messages_len):
messages[i] = {"messages": messages[i]}
req_ids = self._add_request(prompts=messages,
sampling_params=sampling_params,
chat_template_kwargs=chat_template_kwargs)
req_ids = self._add_request(
prompts=messages,
sampling_params=sampling_params,
chat_template_kwargs=chat_template_kwargs,
)
# get output
outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
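Similarly, a hedged sketch of the chat() path touched above, reusing llm and SamplingParams from the generate() sketch. The OpenAI-style role/content message shape is an assumption; "enable_thinking" is the one chat_template_kwargs key this file actually reads (see the _add_request hunk below).
# Sketch only: the message format is assumed, not taken from this diff.
outputs = llm.chat(
    messages=[{"role": "user", "content": "who are you"}],
    sampling_params=SamplingParams(temperature=0.1, max_tokens=30),
    chat_template_kwargs={"enable_thinking": False},
)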
@@ -253,8 +254,7 @@ class LLM:
"prompt": prompts[i],
"request_id": request_id,
}
elif isinstance(prompts[i], list) and isinstance(
prompts[i][0], int):
elif isinstance(prompts[i], list) and isinstance(prompts[i][0], int):
tasks = {
"prompt_token_ids": prompts[i],
"request_id": request_id,
@@ -273,11 +273,8 @@ class LLM:
current_sampling_params = sampling_params
enable_thinking = None
if chat_template_kwargs is not None:
enable_thinking = chat_template_kwargs.get(
"enable_thinking", None)
self.llm_engine.add_requests(tasks,
current_sampling_params,
enable_thinking=enable_thinking)
enable_thinking = chat_template_kwargs.get("enable_thinking", None)
self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
return req_ids
def _run_engine(self, req_ids: list[str], use_tqdm: bool):
@@ -303,8 +300,7 @@ class LLM:
total=num_requests,
desc="Processed prompts",
dynamic_ncols=True,
postfix=(f"est. speed input: {0:.2f} toks/s, "
f"output: {0:.2f} toks/s"),
postfix=(f"est. speed input: {0:.2f} toks/s, " f"output: {0:.2f} toks/s"),
)
output = [None] * num_requests
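As background on the hunk above, a standalone sketch of the tqdm progress-bar pattern being reformatted: the bar is built with a static postfix string and advanced with update() as requests finish. Everything except the tqdm keyword arguments shown in the diff is illustrative.
# Sketch only: num_requests and the sleep are stand-ins for real engine output.
import time

from tqdm import tqdm

num_requests = 4  # placeholder
pbar = tqdm(
    total=num_requests,
    desc="Processed prompts",
    dynamic_ncols=True,
    postfix=(f"est. speed input: {0:.2f} toks/s, " f"output: {0:.2f} toks/s"),
)
for _ in range(num_requests):
    time.sleep(0.1)  # stand-in for waiting on a finished request
    pbar.update(1)
pbar.close()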
@@ -322,13 +318,11 @@ class LLM:
continue
result = self.req_output.pop(req_id)
result = self.llm_engine.data_processor.process_response(
result)
result = self.llm_engine.data_processor.process_response(result)
output[pos] = result
finished.append(i)
llm_logger.debug(
"Request id: {} has been completed.".format(req_id))
llm_logger.debug(f"Request id: {req_id} has been completed.")
if use_tqdm:
pbar.update(1)
@@ -346,24 +340,27 @@ if __name__ == "__main__":
# llm = LLM(model="llama_model")
# output = llm.generate(prompts="who are you", use_tqdm=True)
# print(output)
llm = LLM(model="/opt/baidu/paddle_internal/FastDeploy/Qwen2.5-7B",
tensor_parallel_size=2)
llm = LLM(
model="/opt/baidu/paddle_internal/FastDeploy/Qwen2.5-7B",
tensor_parallel_size=2,
)
sampling_params = SamplingParams(temperature=0.1, max_tokens=30)
output = llm.generate(prompts="who are you",
use_tqdm=True,
sampling_params=sampling_params)
output = llm.generate(prompts="who are you", use_tqdm=True, sampling_params=sampling_params)
print(output)
output = llm.generate(prompts=["who are you", "what can you do"],
sampling_params=SamplingParams(temperature=1,
max_tokens=50),
use_tqdm=True)
output = llm.generate(
prompts=["who are you", "what can you do"],
sampling_params=SamplingParams(temperature=1, max_tokens=50),
use_tqdm=True,
)
print(output)
output = llm.generate(prompts=["who are you", "I miss you"],
sampling_params=[
SamplingParams(temperature=1, max_tokens=50),
SamplingParams(temperature=1, max_tokens=20)
],
use_tqdm=True)
output = llm.generate(
prompts=["who are you", "I miss you"],
sampling_params=[
SamplingParams(temperature=1, max_tokens=50),
SamplingParams(temperature=1, max_tokens=20),
],
use_tqdm=True,
)
print(output)