mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -14,17 +14,15 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import zmq
|
||||
import time
|
||||
from random import randint
|
||||
import uuid
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastdeploy.input.preprocess import InputPreprocessor
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy.inter_communicator import ZmqClient, IPCSignal
|
||||
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
|
||||
from fastdeploy.metrics.work_metrics import work_process_metrics
|
||||
from fastdeploy.utils import api_server_logger, EngineError
|
||||
from fastdeploy.utils import EngineError, api_server_logger
|
||||
|
||||
|
||||
class EngineClient:
|
||||
@@ -32,23 +30,36 @@ class EngineClient:
|
||||
EngineClient is a class that handles the communication between the client and the server.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, max_model_len, tensor_parallel_size, pid, limit_mm_per_prompt, mm_processor_kwargs,
|
||||
enable_mm=False, reasoning_parser=None):
|
||||
input_processor = InputPreprocessor(tokenizer,
|
||||
reasoning_parser,
|
||||
limit_mm_per_prompt,
|
||||
mm_processor_kwargs,
|
||||
enable_mm)
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
max_model_len,
|
||||
tensor_parallel_size,
|
||||
pid,
|
||||
limit_mm_per_prompt,
|
||||
mm_processor_kwargs,
|
||||
enable_mm=False,
|
||||
reasoning_parser=None,
|
||||
):
|
||||
input_processor = InputPreprocessor(
|
||||
tokenizer,
|
||||
reasoning_parser,
|
||||
limit_mm_per_prompt,
|
||||
mm_processor_kwargs,
|
||||
enable_mm,
|
||||
)
|
||||
self.enable_mm = enable_mm
|
||||
self.reasoning_parser = reasoning_parser
|
||||
self.data_processor = input_processor.create_processor()
|
||||
self.max_model_len = max_model_len
|
||||
self.worker_healthy_live_recorded_time_array = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
|
||||
self.worker_healthy_live_signal = IPCSignal(name="worker_healthy_live_signal",
|
||||
array=self.worker_healthy_live_recorded_time_array,
|
||||
dtype=np.int32,
|
||||
suffix=pid,
|
||||
create=False)
|
||||
self.worker_healthy_live_signal = IPCSignal(
|
||||
name="worker_healthy_live_signal",
|
||||
array=self.worker_healthy_live_recorded_time_array,
|
||||
dtype=np.int32,
|
||||
suffix=pid,
|
||||
create=False,
|
||||
)
|
||||
|
||||
model_weights_status = np.zeros([1], dtype=np.int32)
|
||||
self.model_weights_status_signal = IPCSignal(
|
||||
@@ -56,7 +67,8 @@ class EngineClient:
|
||||
array=model_weights_status,
|
||||
dtype=np.int32,
|
||||
suffix=pid,
|
||||
create=False)
|
||||
create=False,
|
||||
)
|
||||
|
||||
def create_zmq_client(self, model, mode):
|
||||
"""
|
||||
@@ -75,7 +87,6 @@ class EngineClient:
|
||||
if "request_id" not in prompts:
|
||||
request_id = str(uuid.uuid4())
|
||||
prompts["request_id"] = request_id
|
||||
query_list = []
|
||||
|
||||
if "max_tokens" not in prompts:
|
||||
prompts["max_tokens"] = self.max_model_len - 1
|
||||
@@ -101,12 +112,12 @@ class EngineClient:
|
||||
|
||||
task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
|
||||
input_ids_len = task["prompt_token_ids_len"]
|
||||
task["max_tokens"] = min(self.max_model_len - input_ids_len , task.get("max_tokens"))
|
||||
task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
|
||||
if task.get("reasoning_max_tokens", None) is None:
|
||||
task["reasoning_max_tokens"] = max(int(task["max_tokens"] * 0.8), 1)
|
||||
min_tokens = task.get("min_tokens", 1)
|
||||
if 'messages' in task:
|
||||
del task['messages']
|
||||
if "messages" in task:
|
||||
del task["messages"]
|
||||
api_server_logger.info(f"task['max_tokens']:{task['max_tokens']}")
|
||||
work_process_metrics.request_params_max_tokens.observe(task["max_tokens"])
|
||||
work_process_metrics.prompt_tokens_total.inc(input_ids_len)
|
||||
@@ -133,8 +144,7 @@ class EngineClient:
|
||||
task["preprocess_end_time"] = time.time()
|
||||
preprocess_cost_time = task["preprocess_end_time"] - task["preprocess_start_time"]
|
||||
api_server_logger.info(
|
||||
f"Cache request with request_id ({task.get('request_id')}), "
|
||||
f"cost {time.time() - preprocess_cost_time}"
|
||||
f"Cache request with request_id ({task.get('request_id')}), " f"cost {time.time() - preprocess_cost_time}"
|
||||
)
|
||||
|
||||
self.vaild_parameters(task)
|
||||
@@ -153,7 +163,6 @@ class EngineClient:
|
||||
Validate stream options
|
||||
"""
|
||||
|
||||
|
||||
if data.get("n"):
|
||||
if data["n"] != 1:
|
||||
raise ValueError("n only support 1.")
|
||||
@@ -168,34 +177,26 @@ class EngineClient:
|
||||
|
||||
if data.get("top_p"):
|
||||
if data["top_p"] > 1 or data["top_p"] < 0:
|
||||
raise ValueError(
|
||||
"top_p value can only be defined [0, 1].")
|
||||
|
||||
raise ValueError("top_p value can only be defined [0, 1].")
|
||||
|
||||
if data.get("frequency_penalty"):
|
||||
if not -2.0 <= data["frequency_penalty"] <= 2.0:
|
||||
if not -2.0 <= data["frequency_penalty"] <= 2.0:
|
||||
raise ValueError("frequency_penalty must be in [-2, 2]")
|
||||
|
||||
if data.get("temperature"):
|
||||
if data["temperature"] < 0:
|
||||
raise ValueError(f"temperature must be non-negative")
|
||||
|
||||
raise ValueError("temperature must be non-negative")
|
||||
|
||||
if data.get("presence_penalty"):
|
||||
if not -2.0 <= data["presence_penalty"] <= 2.0:
|
||||
if not -2.0 <= data["presence_penalty"] <= 2.0:
|
||||
raise ValueError("presence_penalty must be in [-2, 2]")
|
||||
|
||||
|
||||
|
||||
if data.get("seed"):
|
||||
if not 0 <= data["seed"] <= 922337203685477580:
|
||||
raise ValueError("seed must be in [0, 922337203685477580]")
|
||||
|
||||
if data.get("stream_options") and not data.get("stream"):
|
||||
raise ValueError(
|
||||
"Stream options can only be defined when `stream=True`.")
|
||||
|
||||
|
||||
raise ValueError("Stream options can only be defined when `stream=True`.")
|
||||
|
||||
def check_health(self, time_interval_threashold=30):
|
||||
"""
|
||||
@@ -209,7 +210,6 @@ class EngineClient:
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def is_workers_alive(self):
|
||||
"""
|
||||
Check the health of the model server by checking whether all workers are alive.
|
||||
@@ -220,9 +220,7 @@ class EngineClient:
|
||||
else:
|
||||
return False, "No model weight enabled"
|
||||
|
||||
|
||||
|
||||
def update_model_weight(self, timeout = 300):
|
||||
def update_model_weight(self, timeout=300):
|
||||
"""
|
||||
Update the model weight by sending a signal to the server.
|
||||
1 : worker receive the signal and start to update model weight
|
||||
@@ -235,7 +233,7 @@ class EngineClient:
|
||||
|
||||
self.model_weights_status_signal.value[0] = 1
|
||||
api_server_logger.info(f"start update model weight {self.model_weights_status_signal.value}")
|
||||
while self.model_weights_status_signal.value[0] != 0 and timeout != 0:
|
||||
while self.model_weights_status_signal.value[0] != 0 and timeout != 0:
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
continue
|
||||
@@ -244,9 +242,7 @@ class EngineClient:
|
||||
time.sleep(1)
|
||||
return True, ""
|
||||
|
||||
|
||||
|
||||
def clear_load_weight(self, timeout = 300):
|
||||
def clear_load_weight(self, timeout=300):
|
||||
"""
|
||||
Clear the load weight status.
|
||||
-1 : worker receive the signal and start to clear model weight
|
||||
@@ -260,7 +256,7 @@ class EngineClient:
|
||||
self.model_weights_status_signal.value[0] = -1
|
||||
|
||||
api_server_logger.info(f"start clear model weight {self.model_weights_status_signal.value}")
|
||||
while self.model_weights_status_signal.value[0] != -2 and timeout != 0:
|
||||
while self.model_weights_status_signal.value[0] != -2 and timeout != 0:
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
continue
|
||||
|
Reference in New Issue
Block a user