Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[Sync Code] develop to release/2.0.3 (#2873)
* [LLM] support send batch data and aggregate data (#2860)
* [LLM] support send batch data and aggregate data
* [LLM] fix ci bugs
* [LLM] fix ci bugs
* [LLM] fix ci bugs
* [LLM] fix ci bugs
* [LLM] update
* [LLM] Update Multinode Deployment (#2830)
* [LLM] fix multinode bugs
* [LLM] update multinode deployment
* [LLM] update multinode deployment
* [LLM] update multinode deployment
* [LLM] update multinode deployment
* [LLM] update multinode deployment
* [LLM] fix ci bugs
* Update fastdeploy/engine/args_utils.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* [LLM] update random port
* [LLM] update random port
* [LLM] fix ci bugs
* fix ci bugs

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@@ -124,9 +124,19 @@ class EngineArgs:
Ratio of tokens to process in a block.
"""
pod_ips: Optional[List[str]] = None
dist_init_ip: Optional[str] = None
"""
List of IP addresses for nodes in the cluster.
The master node ip of multinode deployment
"""
nnodes: int = 1
"""
The number of nodes in multinode deployment
"""
node_rank: int = 0
"""
The rank of the current node in multinode deployment
"""
swap_space: float = None

@@ -485,11 +495,25 @@ class EngineArgs:
# Cluster system parameters group
system_group = parser.add_argument_group("System Configuration")
system_group.add_argument(
"--pod-ips",
type=lambda s: s.split(",") if s else None,
default=EngineArgs.pod_ips,
"--dist-init-ip",
default=EngineArgs.dist_init_ip,
help=
"List of IP addresses for nodes in the cluster (comma-separated).")
"IP addresses of master node.")
system_group.add_argument(
"--nnodes",
type=int,
default=EngineArgs.nnodes,
help=
"The number of all nodes.")
system_group.add_argument(
"--node-rank",
type=int,
default=EngineArgs.node_rank,
help=
"node rank id (range [0, nnodes)).")
# Performance tuning parameters group

@@ -789,7 +813,9 @@ class EngineArgs:
max_num_seqs=self.max_num_seqs,
speculative_config=speculative_cfg,
max_num_batched_tokens=self.max_num_batched_tokens,
pod_ips=self.pod_ips,
dist_init_ip=self.dist_init_ip,
nnodes=self.nnodes,
node_rank=self.node_rank,
use_warmup=self.use_warmup,
engine_worker_queue_port=self.engine_worker_queue_port,
limit_mm_per_prompt=self.limit_mm_per_prompt,
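Taken together, these hunks replace the comma-separated `--pod-ips` list with three explicit settings: `--dist-init-ip` (the master node's IP), `--nnodes` (total node count) and `--node-rank` (this node's rank, in `[0, nnodes)`). A minimal sketch of how the new fields might be filled in programmatically; the IP value and the omission of every other `EngineArgs` field are illustrative, not part of this diff:

    # master node (rank 0) of a hypothetical two-node deployment
    args = EngineArgs(dist_init_ip="192.168.0.1", nnodes=2, node_rank=0)
    # the second node only changes its rank
    args = EngineArgs(dist_init_ip="192.168.0.1", nnodes=2, node_rank=1)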
@@ -6,7 +6,7 @@
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#dist_init_ip
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

@@ -24,7 +24,7 @@ from fastdeploy import envs
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.utils import (ceil_div, check_unified_ckpt, get_host_ip,
is_port_available, llm_logger)
is_port_available, get_random_port, llm_logger)
TaskOption = Literal["generate"]

@@ -642,7 +642,9 @@ class Config:
max_model_len: int = 8192,
max_num_seqs: int = 8,
max_num_batched_tokens: Optional[int] = None,
pod_ips: Optional[List[str]] = None,
dist_init_ip: str = None,
nnodes: int = 1,
node_rank: int = 0,
speculative_config: Optional[Dict[str, Any]] = None,
graph_optimization_config: Optional[Dict[str, Any]] = None,
use_warmup: bool = False,

@@ -675,7 +677,6 @@ class Config:
max_model_len (int): Maximum model length. Default is 8192.
max_num_seqs (int): Maximum number of sequences. Default is 8.
max_num_batched_tokens (Optional[int]): Maximum number of batched tokens. Default is None.
pod_ips (Optional[List[str]]): List of POD IPs. Default is None.
mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor. Default is None.
speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration. Default is None.
graph_optimization_config (Optional[Dict[str, Any]]): Graph optimizaion backend execution configuration. Default is None.

@@ -699,7 +700,16 @@ class Config:
self.tokenizer = tokenizer
self.max_num_batched_tokens = max_num_batched_tokens
self.tensor_parallel_size = tensor_parallel_size
self.pod_ips = pod_ips
self.dist_init_ip = dist_init_ip
self.nnode = nnodes
self.node_rank = node_rank
if self.dist_init_ip is None:
self.master_ip = "0.0.0.0"
else:
self.master_ip = self.dist_init_ip
self.dist_init_addr = f"{self.dist_init_ip}:{get_random_port()}"
self.max_model_len = max_model_len
self.max_num_seqs = max_num_seqs
self.limit_mm_per_prompt = limit_mm_per_prompt

@@ -716,14 +726,8 @@ class Config:
self.graph_optimization_config = graph_optimization_config
self.guided_decoding_backend = guided_decoding_backend
self.disable_any_whitespace = disable_any_whitespace
self.is_master = True
self._str_to_list("innode_prefill_ports", int)
self._str_to_list("pod_ips", str)
if self.pod_ips is None:
self.nnode = 1
else:
self.nnode = len(self.pod_ips)
assert self.splitwise_role in ["mixed", "prefill", "decode"]

@@ -778,9 +782,9 @@ class Config:
self.host_ip = get_host_ip()
if self.pod_ips is None:
self.pod_ips = ["0.0.0.0"]
elif self.host_ip != self.pod_ips[0]:
if self.dist_init_ip is None or self.host_ip == self.master_ip:
self.is_master = True
else:
self.is_master = False
import paddle
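Net effect of the Config changes: the old `pod_ips` bookkeeping (deriving `nnode` from the list length and comparing against `pod_ips[0]`) is gone. With `dist_init_ip` unset, the engine binds to `0.0.0.0` and always treats itself as master; with, say, `dist_init_ip="192.168.0.1"` (an illustrative value), it ends up with `master_ip == "192.168.0.1"`, a `dist_init_addr` of the form `"192.168.0.1:<random port>"` built via the new `get_random_port()` helper, and `is_master` set only on the host whose IP matches the master's.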
@@ -174,7 +174,7 @@ class LLMEngine(object):
cache_config=self.cfg.cache_config,
tensor_parallel_size=self.cfg.tensor_parallel_size,
device_ids=device_ids,
pod_ip=self.cfg.pod_ips[0],
pod_ip=self.cfg.master_ip,
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=self.ipc_signal_suffix)

@@ -239,11 +239,12 @@ class LLMEngine(object):
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
self.dp_processed = []
for i in range(1, self.cfg.parallel_config.data_parallel_size):
for i in range(1, self.cfg.parallel_config.data_parallel_size // self.cfg.nnode):
time.sleep(1)
self.dp_processed.append(
multiprocessing.Process(target=start_expert_service,
args=(self.cfg, i,
args=(self.cfg,
i + self.cfg.node_rank * self.cfg.worker_num_per_node,
self.ipc_signal_suffix)))
llm_logger.info(f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}" \
+ " data parallel id {}".format(i))
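Each node now spawns only its own share of expert services and offsets their ids by its rank. With illustrative numbers not taken from the diff, say `data_parallel_size=8`, `nnodes=2` and `worker_num_per_node=4`: node 0 launches local ids 1-3 as global ids 1-3, node 1 launches local ids 1-3 as global ids 5-7, and the id each node skips (0 and 4 here) is presumably handled by that node's engine process itself.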
@@ -263,10 +264,11 @@ class LLMEngine(object):
try:
results = self.scheduler.get_results()
if len(results) == 0:
time.sleep(0.001)
time.sleep(0.005)
continue
for request_id, contents in results.items():
for result in contents:
self.zmq_server.send_multipart(request_id, result)
self.zmq_server.send_multipart(request_id, contents)
except Exception as e:
llm_logger.error("Unexcepted error happend: {}, {}".format(
e, str(traceback.format_exc())))

@@ -1006,8 +1008,6 @@ class LLMEngine(object):
)
arguments = (
f" --nnodes {str(self.cfg.nnode)}"
f" --ips {','.join(self.cfg.pod_ips)}"
f" --devices {self.cfg.device_ids} {py_script}"
f" --max_num_seqs {self.cfg.max_num_seqs} --max_model_len {self.cfg.max_model_len}"
f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}"

@@ -1015,7 +1015,7 @@ class LLMEngine(object):
f" --device_ids {self.cfg.device_ids}"
f" --tensor_parallel_size {self.cfg.tensor_parallel_size}"
f" --engine_worker_queue_port {str(self.cfg.engine_worker_queue_port)}"
f" --pod_ip {self.cfg.pod_ips[0]}"
f" --pod_ip {self.cfg.master_ip}"
f" --total_block_num {self.cfg.cache_config.total_block_num}"
f" --block_size {self.cfg.cache_config.block_size}"
f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}"

@@ -1056,7 +1056,11 @@ class LLMEngine(object):
if value:
arguments = arguments + f" --{worker_flag}"
if self.cfg.nnode > 1:
pd_cmd = pd_cmd + f" --ips {self.cfg.ips}"
pd_cmd = pd_cmd + (
f" --master {self.cfg.dist_init_addr}"
f" --nnodes {str(self.cfg.nnode)}"
f" --rank {str(self.cfg.node_rank)}"
)
pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log"
llm_logger.info("Launch worker service command: {}".format(pd_cmd))
p = subprocess.Popen(

@@ -1157,7 +1161,7 @@ class LLMEngine(object):
cache_config=self.cfg.cache_config,
tensor_parallel_size=self.cfg.tensor_parallel_size,
device_ids=device_ids,
pod_ip=self.cfg.pod_ips[0],
pod_ip=self.cfg.master_ip,
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=self.ipc_signal_suffix)
def check_health(self, time_interval_threashold=30):

@@ -1244,8 +1248,9 @@ class LLMEngine(object):
"""
start queue service for engine worker communication
"""
address = (self.cfg.pod_ips[0], self.cfg.engine_worker_queue_port)
if self.cfg.host_ip == self.cfg.pod_ips[0] or self.cfg.pod_ips[0] == "0.0.0.0":
address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
llm_logger.info(f"Starting engine worker queue server service at {address}")
self.engine_worker_queue_server = EngineWorkerQueue(
address=address,
is_server=True,

@@ -1255,7 +1260,7 @@ class LLMEngine(object):
if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != 'mixed':
self.cache_task_queue = EngineCacheQueue(
address=(self.cfg.pod_ips[0], self.cfg.cache_config.cache_queue_port),
address=(self.cfg.master_ip, self.cfg.cache_config.cache_queue_port),
authkey=b'cache_queue_service',
is_server=True,
num_client=self.cfg.tensor_parallel_size,

@@ -1269,4 +1274,6 @@ class LLMEngine(object):
is_server=False,
num_client=self.cfg.tensor_parallel_size,
client_id=0,
local_data_parallel_id=0)
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
local_data_parallel_id= min(self.cfg.worker_num_per_node * self.cfg.node_rank,
self.cfg.parallel_config.data_parallel_size - 1))
@@ -49,8 +49,8 @@ class ExpertService(object):
cfg (Config): Config object containing all the configuration parameters.
"""
self.cfg = cfg
start_pos = local_data_parallel_id * self.cfg.tensor_parallel_size
end_pos = (local_data_parallel_id + 1) * self.cfg.tensor_parallel_size
start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % self.cfg.worker_num_per_node
end_pos = ((local_data_parallel_id + 1) * self.cfg.tensor_parallel_size) % self.cfg.worker_num_per_node
self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[
start_pos:end_pos]
self.cfg.local_device_ids = self.cfg.device_ids.split(

@@ -65,7 +65,7 @@ class ExpertService(object):
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
address = (cfg.pod_ips[0], cfg.engine_worker_queue_port)
address = (cfg.master_ip, cfg.engine_worker_queue_port)
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,

@@ -118,7 +118,7 @@ class ExpertService(object):
cache_config=self.cfg.cache_config,
tensor_parallel_size=self.cfg.tensor_parallel_size,
device_ids=self.cfg.local_device_ids,
pod_ip=self.cfg.pod_ips[0],
pod_ip=self.cfg.master_ip,
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}"
)
@@ -20,7 +20,7 @@ import time
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, Optional, Union
import numpy
import numpy as np
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.utils import data_processor_logger

@@ -181,7 +181,7 @@ class Request:
f"sampling_params={self.sampling_params})")
@dataclass
@dataclass(slots=True)
class CompletionOutput:
"""The output data of one completion output of a request.

@@ -235,7 +235,7 @@ class CompletionOutput:
f"reasoning_content={self.reasoning_content!r}")
@dataclass
@dataclass(slots=True)
class RequestMetrics:
"""Metrics associated with a request.

@@ -310,6 +310,10 @@ class RequestOutput:
None if decoder-only.
num_cached_tokens: The number of tokens with prefix cache hit.
"""
__slots__ = (
'request_id', 'prompt', 'prompt_token_ids', 'outputs',
'finished', 'metrics', 'num_cached_tokens', 'error_code', 'error_msg'
)
def __init__(
self,

@@ -333,6 +337,12 @@ class RequestOutput:
self.error_code = error_code
self.error_msg = error_msg
if prompt_token_ids is None:
self.prompt_token_ids = []
elif isinstance(self.prompt_token_ids, np.ndarray):
self.prompt_token_ids = self.prompt_token_ids.tolist()
def add(self, next_output: "RequestOutput") -> None:
"""Merge RequestOutput into this one"""

@@ -365,11 +375,6 @@ class RequestOutput:
def to_dict(self):
"""convert RequestOutput into a serializable dict """
if self.prompt_token_ids is None:
self.prompt_token_ids = []
if type(self.prompt_token_ids) is numpy.ndarray:
self.prompt_token_ids = self.prompt_token_ids.tolist()
return {
"request_id": self.request_id,
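The request dataclasses also get a memory-focused touch-up: `CompletionOutput` and `RequestMetrics` move to `@dataclass(slots=True)` and `RequestOutput` declares an explicit `__slots__` tuple, trading dynamic attribute assignment for smaller instances and faster attribute access, which matters when many streamed outputs are alive at once. A generic sketch of the behaviour (not FastDeploy code; `slots=True` requires Python 3.10+):

    from dataclasses import dataclass

    @dataclass(slots=True)
    class Point:
        x: int
        y: int

    p = Point(1, 2)
    p.x = 3    # fine: x is a declared slot
    # p.z = 4  # would raise AttributeError: instances have no __dict__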
@@ -85,7 +85,7 @@ class LLM:
self.mutex = threading.Lock()
self.req_output = dict()
self.master_node_ip = self.llm_engine.cfg.pod_ips[0]
self.master_node_ip = self.llm_engine.cfg.master_ip
self._receive_output_thread = threading.Thread(
target=self._receive_output, daemon=True)
self._receive_output_thread.start()

@@ -169,6 +169,8 @@ class LLM:
# get output
outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
for i in range(len(outputs)):
outputs[i].prompt = prompts[i]
return outputs
def chat(
@@ -122,8 +122,8 @@ async def lifespan(app: FastAPI):
args.mm_processor_kwargs, args.enable_mm,
args.reasoning_parser)
app.state.dynamic_load_weight = args.dynamic_load_weight
chat_handler = OpenAIServingChat(engine_client, pid, args.pod_ips)
completion_handler = OpenAIServingCompletion(engine_client, pid, args.pod_ips)
chat_handler = OpenAIServingChat(engine_client, pid, args.dist_init_ip)
completion_handler = OpenAIServingCompletion(engine_client, pid, args.dist_init_ip)
engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
engine_client.pid = pid
app.state.engine_client = engine_client
@@ -21,6 +21,7 @@ import traceback
import uuid
from typing import List, Optional
import msgpack
import aiozmq
from aiozmq import zmq

@@ -39,16 +40,16 @@ class OpenAIServingChat:
OpenAI-style chat completions serving
"""
def __init__(self, engine_client, pid, pod_ips):
def __init__(self, engine_client, pid, dist_init_ip):
self.engine_client = engine_client
self.pid = pid
self.pod_ips = pod_ips
self.master_ip = dist_init_ip
self.host_ip = get_host_ip()
def _check_master(self):
if self.pod_ips is None:
if self.master_ip is None:
return True
if self.host_ip == self.pod_ips[0]:
if self.host_ip == self.master_ip:
return True
return False
@@ -143,6 +144,8 @@ class OpenAIServingChat:
dealer.write([b"", request_id.encode('utf-8')])
choices = []
current_waiting_time = 0
if request.metadata is not None:
enable_thinking = request.metadata.get("enable_thinking")
while num_choices > 0:
try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
@@ -158,102 +161,106 @@ class OpenAIServingChat:
raise ValueError(f"Engine is not healthy: {msg}")
else:
current_waiting_time = 0
await asyncio.sleep(0.1)
await asyncio.sleep(0.01)
continue
response = msgpack.unpackb(raw_data[-1])
for res in response:
if res.get("error_code", 200) != 200:
raise ValueError("{}".format(res["error_msg"]))
res = json.loads(raw_data[-1].decode('utf-8'))
if res.get("error_code", 200) != 200:
raise ValueError("{}".format(res["error_msg"]))
if request.metadata is not None:
enable_thinking = request.metadata.get("enable_thinking")
self.engine_client.data_processor.process_response_dict(
res, stream=True, enable_thinking=enable_thinking)
self.engine_client.data_processor.process_response_dict(
res, stream=True, enable_thinking=enable_thinking)
if res['metrics']['first_token_time'] is not None:
arrival_time = res['metrics']['first_token_time']
inference_start_time = res['metrics']['inference_start_time']
else:
arrival_time = res['metrics']['arrival_time'] - inference_start_time
if first_iteration:
num_prompt_tokens = len(prompt_token_ids)
num_cached_tokens = res.get("num_cached_tokens", 0)
for i in range(num_choices):
choice = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role="assistant", content="", reasoning_content="", tool_calls=None)
)
if request.metadata is not None and request.metadata.get("training", False):
choice.delta.token_ids = prompt_token_ids
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice],
model=model_name
)
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens,
prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens)
)
yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n"
first_iteration = False
output = res["outputs"]
delta_text = output["text"]
raw_top_logprobs = output["top_logprobs"]
logprobs_res = None
if raw_top_logprobs is not None:
top_logprobs = LogprobsLists(
logprob_token_ids=raw_top_logprobs[0],
logprobs=raw_top_logprobs[1],
sampled_token_ranks=raw_top_logprobs[2],
)
logprobs_res = self.build_logprobs_response(
request_logprobs=request.logprobs,
response_logprobs=top_logprobs,
request_top_logprobs=request.top_logprobs,
)
previous_num_tokens += len(output["token_ids"])
delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \
token_ids=output.get("token_ids"), tool_calls=output.get("tool_call_content", []))
choice = ChatCompletionResponseStreamChoice(
index=0,
delta=delta_message,
logprobs=logprobs_res,
arrival_time=arrival_time
)
if res["finished"]:
num_choices -= 1
work_process_metrics.e2e_request_latency.observe(time.time() - res["metrics"]["request_start_time"])
has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
max_tokens = request.max_completion_tokens or request.max_tokens
if has_no_token_limit or previous_num_tokens != max_tokens:
choice.finish_reason = "stop"
if self.engine_client.reasoning_parser == "ernie_x1" and \
output.get("finish_reason", "") == "tool_calls":
choice.finish_reason = "tool_calls"
if res['metrics']['first_token_time'] is not None:
arrival_time = res['metrics']['first_token_time']
inference_start_time = res['metrics']['inference_start_time']
else:
choice.finish_reason = "length"
arrival_time = res['metrics']['arrival_time'] - inference_start_time
if first_iteration:
num_prompt_tokens = len(prompt_token_ids)
num_cached_tokens = res.get("num_cached_tokens", 0)
for i in range(num_choices):
choice = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role="assistant", content="", reasoning_content="", tool_calls=None)
)
if request.metadata is not None and request.metadata.get("training", False):
choice.delta.token_ids = prompt_token_ids
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice],
model=model_name
)
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens,
prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens)
)
yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n"
first_iteration = False
if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
choice.finish_reason = "recover_stop"
output = res["outputs"]
delta_text = output["text"]
raw_top_logprobs = output["top_logprobs"]
logprobs_res = None
if raw_top_logprobs is not None:
top_logprobs = LogprobsLists(
logprob_token_ids=raw_top_logprobs[0],
logprobs=raw_top_logprobs[1],
sampled_token_ranks=raw_top_logprobs[2],
)
logprobs_res = self.build_logprobs_response(
request_logprobs=request.logprobs,
response_logprobs=top_logprobs,
request_top_logprobs=request.top_logprobs,
)
if request.metadata is not None and request.metadata.get("training", False) and delta_text != "":
choice.delta.token_ids = output["token_ids"]
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=previous_num_tokens,
total_tokens=num_prompt_tokens + previous_num_tokens
previous_num_tokens += len(output["token_ids"])
delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \
token_ids=output.get("token_ids"), tool_calls=output.get("tool_call_content", []))
choice = ChatCompletionResponseStreamChoice(
index=0,
delta=delta_message,
logprobs=logprobs_res,
arrival_time=arrival_time
)
choices.append(choice)
if res["finished"]:
num_choices -= 1
work_process_metrics.e2e_request_latency.observe(time.time() - res["metrics"]["request_start_time"])
has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
max_tokens = request.max_completion_tokens or request.max_tokens
if has_no_token_limit or previous_num_tokens != max_tokens:
choice.finish_reason = "stop"
if self.engine_client.reasoning_parser == "ernie_x1" and \
output.get("finish_reason", "") == "tool_calls":
choice.finish_reason = "tool_calls"
else:
choice.finish_reason = "length"
if len(choices) == max_streaming_response_tokens or res["finished"]:
if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
choice.finish_reason = "recover_stop"
if request.metadata is not None and request.metadata.get("training", False) and delta_text != "":
choice.delta.token_ids = output["token_ids"]
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=previous_num_tokens,
total_tokens=num_prompt_tokens + previous_num_tokens
)
choices.append(choice)
if len(choices) == max_streaming_response_tokens or res["finished"]:
chunk.choices = choices
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
choices = []
if choices:
chunk.choices = choices
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
choices = []
@@ -321,33 +328,38 @@ class OpenAIServingChat:
await asyncio.sleep(0.1)
continue
data = json.loads(raw_data[-1].decode('utf-8'))
if data.get("error_code", 200) != 200:
raise ValueError("{}".format(data["error_msg"]))
if request.metadata is not None:
enable_thinking = request.metadata.get("enable_thinking")
data = self.engine_client.data_processor.process_response_dict(
data, stream=False, enable_thinking=enable_thinking)
# api_server_logger.debug(f"Client {request_id} received: {data}")
previous_num_tokens += len(data["outputs"]["token_ids"])
# The logprob for handling the response
output = data["outputs"]
raw_top_logprobs = output["top_logprobs"]
if raw_top_logprobs is not None:
top_logprobs = LogprobsLists(
logprob_token_ids=raw_top_logprobs[0],
logprobs=raw_top_logprobs[1],
sampled_token_ranks=raw_top_logprobs[2],
)
logprobs_res = self.build_logprobs_response(
request_logprobs=request.logprobs,
response_logprobs=top_logprobs,
request_top_logprobs=request.top_logprobs,
)
if logprobs_res and logprobs_res.content is not None:
logprob_contents.extend(logprobs_res.content)
if data["finished"]:
final_res = data
response = msgpack.unpackb(raw_data[-1])
task_is_finished = False
for data in response:
if data.get("error_code", 200) != 200:
raise ValueError("{}".format(data["error_msg"]))
if request.metadata is not None:
enable_thinking = request.metadata.get("enable_thinking")
data = self.engine_client.data_processor.process_response_dict(
data, stream=False, enable_thinking=enable_thinking)
# api_server_logger.debug(f"Client {request_id} received: {data}")
previous_num_tokens += len(data["outputs"]["token_ids"])
# The logprob for handling the response
output = data["outputs"]
raw_top_logprobs = output["top_logprobs"]
if raw_top_logprobs is not None:
top_logprobs = LogprobsLists(
logprob_token_ids=raw_top_logprobs[0],
logprobs=raw_top_logprobs[1],
sampled_token_ranks=raw_top_logprobs[2],
)
logprobs_res = self.build_logprobs_response(
request_logprobs=request.logprobs,
response_logprobs=top_logprobs,
request_top_logprobs=request.top_logprobs,
)
if logprobs_res and logprobs_res.content is not None:
logprob_contents.extend(logprobs_res.content)
if data["finished"]:
final_res = data
task_is_finished = True
break
if task_is_finished:
break
finally:
dealer.close()
@@ -17,6 +17,7 @@
import asyncio
import aiozmq
import json
import msgpack
from aiozmq import zmq
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
import time

@@ -44,16 +45,16 @@ from fastdeploy.engine.request import RequestOutput
class OpenAIServingCompletion:
def __init__(self, engine_client, pid, pod_ips):
def __init__(self, engine_client, pid, dist_init_ip):
self.engine_client = engine_client
self.pid = pid
self.pod_ips = pod_ips
self.master_ip = dist_init_ip
self.host_ip = get_host_ip()
def _check_master(self):
if self.pod_ips is None:
if self.master_ip is None:
return True
if self.host_ip == self.pod_ips[0]:
if self.host_ip == self.master_ip:
return True
return False
@@ -179,18 +180,20 @@ class OpenAIServingCompletion:
current_waiting_time = 0
await asyncio.sleep(0.1)
continue
data = json.loads(raw_data[-1].decode("utf-8"))
rid = int(data["request_id"].split("-")[-1])
if data.get("error_code", 200) != 200:
raise ValueError("{}".format(data["error_msg"]))
response = msgpack.unpackb(raw_data[-1])
for data in response:
rid = int(data["request_id"].split("-")[-1])
if data.get("error_code", 200) != 200:
raise ValueError("{}".format(data["error_msg"]))
self.engine_client.data_processor.process_response_dict(
data, stream=False)
output_tokens[rid] += len(data["outputs"]["token_ids"])
if data.get("finished", False):
data["output_token_ids"] = output_tokens[rid]
valid_results[rid] = data
num_choices -= 1
self.engine_client.data_processor.process_response_dict(
data, stream=False)
output_tokens[rid] += len(data["outputs"]["token_ids"])
if data.get("finished", False):
data["output_token_ids"] = output_tokens[rid]
valid_results[rid] = data
num_choices -= 1
break
return self.request_output_to_completion_response(
final_res_batch=valid_results,
@@ -238,6 +241,12 @@ class OpenAIServingCompletion:
if request.suffix is not None and request.suffix.get("max_streaming_response_tokens", 1) > 1:
max_streaming_response_tokens = request.suffix["max_streaming_response_tokens"]
choices = []
chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices
)
current_waiting_time = 0
while num_choices > 0:
@@ -256,82 +265,86 @@ class OpenAIServingCompletion:
continue
res = json.loads(raw_data[-1].decode('utf-8'))
idx = int(res["request_id"].split("-")[-1])
if res.get("error_code", 200) != 200:
raise ValueError("{}".format(res["error_msg"]))
response = msgpack.unpackb(raw_data[-1])
for res in response:
idx = int(res["request_id"].split("-")[-1])
if res.get("error_code", 200) != 200:
raise ValueError("{}".format(res["error_msg"]))
if first_iteration[idx]:
if request.suffix is not None and request.suffix.get("training", False):
if first_iteration[idx]:
if request.suffix is not None and request.suffix.get("training", False):
chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[CompletionResponseStreamChoice(
index=idx,
text="",
token_ids=list(prompt_batched_token_ids[idx])
)]
)
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
first_iteration[idx] = False
self.engine_client.data_processor.process_response_dict(
res, stream=True)
if res['metrics'].get('first_token_time') is not None:
arrival_time = res['metrics']['first_token_time']
inference_start_time[idx] = res['metrics']['inference_start_time']
else:
arrival_time = res['metrics']['arrival_time'] - inference_start_time[idx]
output = res["outputs"]
choices.append(CompletionResponseStreamChoice(
index=idx,
text=output["text"],
token_ids=output.get("token_ids"),
tool_calls=output.get("tool_call_content"),
reasoning_content=output.get("reasoning_content"),
arrival_time=arrival_time
))
if res["finished"]:
if request.max_tokens is None or output_tokens[idx] + 1 != request.max_tokens:
chunk.choices[0].finish_reason = "stop"
if self.engine_client.reasoning_parser == "ernie_x1" and \
output.get("finish_reason", "") == "tool_calls":
chunk.choices[0].finish_reason = "tool_calls"
else:
chunk.choices[0].finish_reason = "length"
output_tokens[idx] += 1
if len(choices) == max_streaming_response_tokens or res["finished"]:
chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[CompletionResponseStreamChoice(
index=idx,
text="",
token_ids=list(prompt_batched_token_ids[idx])
)]
choices=choices
)
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
first_iteration[idx] = False
choices = []
self.engine_client.data_processor.process_response_dict(
res, stream=True)
if res['metrics'].get('first_token_time') is not None:
arrival_time = res['metrics']['first_token_time']
inference_start_time[idx] = res['metrics']['inference_start_time']
else:
arrival_time = res['metrics']['arrival_time'] - inference_start_time[idx]
# api_server_logger.info(f"{arrival_time}")
output = res["outputs"]
choices.append(CompletionResponseStreamChoice(
index=idx,
text=output["text"],
token_ids=output.get("token_ids"),
tool_calls=output.get("tool_call_content"),
reasoning_content=output.get("reasoning_content"),
arrival_time=arrival_time
))
if res["finished"]:
if request.max_tokens is None or output_tokens[idx] + 1 != request.max_tokens:
chunk.choices[0].finish_reason = "stop"
if self.engine_client.reasoning_parser == "ernie_x1" and \
output.get("finish_reason", "") == "tool_calls":
chunk.choices[0].finish_reason = "tool_calls"
else:
chunk.choices[0].finish_reason = "length"
output_tokens[idx] += 1
if len(choices) == max_streaming_response_tokens or res["finished"]:
chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices
)
choices = []
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
if res["finished"]:
num_choices -= 1
if getattr(request, "stream_options", None) and request.stream_options.include_usage:
usage_chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[],
usage=UsageInfo(
prompt_tokens=len(prompt_batched_token_ids[idx]),
completion_tokens=output_tokens[idx]
if res["finished"]:
num_choices -= 1
if getattr(request, "stream_options", None) and request.stream_options.include_usage:
usage_chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[],
usage=UsageInfo(
prompt_tokens=len(prompt_batched_token_ids[idx]),
completion_tokens=output_tokens[idx]
)
)
)
yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
if choices:
chunk.choices = choices
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
choices = []
except Exception as e:
@@ -101,6 +101,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Whether to use DeepGemm for FP8 blockwise MoE.
"FD_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
# Whether to use aggregate send.
"FD_USE_AGGREGATE_SEND":
lambda: bool(int(os.getenv("FD_USE_AGGREGATE_SEND", "0"))),
}
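The new switch defaults to off (`"0"`). Setting, for example, `FD_USE_AGGREGATE_SEND=1` in the environment before launching the server makes the ZMQ server in the next hunks merge all pending responses for a request into a single message instead of sending them one by one.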
@@ -20,6 +20,7 @@ import threading
import time
import zmq
import msgpack
from fastdeploy import envs
from fastdeploy.utils import llm_logger

@@ -37,6 +38,7 @@ class ZmqClient:
self.router_path = f"/dev/shm/router_{name}.ipc"
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
self.mutex = threading.Lock()
self.req_dict = dict()

@@ -93,6 +95,16 @@ class ZmqClient:
"""
return self.socket.recv_pyobj()
def pack_aggregated_data(self, data):
"""
Aggregate multiple responses into one and send them to the client.
"""
result = data[0]
if len(data) > 1:
for response in data[1:]:
result.add(response)
result = msgpack.packb([result.to_dict()])
return result
def send_multipart(self, req_id, data):
"""
Send a multipart message to the router socket.

@@ -116,14 +128,22 @@ class ZmqClient:
break
try:
result = json.dumps(data.to_dict()).encode('utf-8')
start_send = time.time()
if self.aggregate_send:
result = self.pack_aggregated_data(data)
else:
result = msgpack.packb([response.to_dict() for response in data])
self.router.send_multipart([self.req_dict[req_id], b'', result])
llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
except Exception as e:
llm_logger.error(f"Send result to zmq client failed: {e}")
if data.finished:
if data[-1].finished:
with self.mutex:
self.req_dict.pop(data.request_id, None)
self.req_dict.pop(req_id, None)
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
def receive_json_once(self, block=False):
"""
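With these changes the ZMQ payload is always a msgpack-encoded list of response dicts rather than a single JSON object, and with aggregate send enabled the list is first collapsed into one merged `RequestOutput` via `add()`. A minimal, self-contained sketch of the wire format the serving handlers now expect; the dict fields shown are illustrative stand-ins for `RequestOutput.to_dict()`:

    import msgpack

    # what ZmqClient.send_multipart now puts on the wire for one request
    batch = [{"request_id": "req-0", "outputs": {"text": "Hel"}, "finished": False},
             {"request_id": "req-0", "outputs": {"text": "lo"}, "finished": True}]
    payload = msgpack.packb(batch)

    # what the OpenAI serving handlers now do with it: decode and iterate
    for res in msgpack.unpackb(payload):
        print(res["outputs"]["text"], res["finished"])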
@@ -505,8 +505,6 @@ class TokenProcessor(object):
result.outputs.token_ids.append(token_id)
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
result.finished = True
result.prompt = task.prompt
result.prompt_token_ids = task.prompt_token_ids
if recovery_stop:
result.error_msg = "Recover is not supported, the result is incomplete!"
llm_logger.info(
@@ -27,7 +27,8 @@ from datetime import datetime
from logging.handlers import BaseRotatingHandler
from pathlib import Path
from typing import Literal, TypeVar, Union
import random
import socket
import requests
import yaml
from aistudio_sdk.snapshot_download import snapshot_download

@@ -421,6 +422,19 @@ def get_host_ip():
return ip
def get_random_port():
while True:
port = random.randint(49152, 65535)
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind(("0.0.0.0", port))
return port
except OSError:
continue
def is_port_available(host, port):
"""
Check the port is available
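`get_random_port()` keeps drawing from the dynamic range 49152-65535 until a bind on `0.0.0.0` succeeds, so the port it returns was free at probe time (it could, in principle, be taken again before the caller binds it). `Config` uses it above to append a port to `dist_init_ip` when building `dist_init_addr`.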
@@ -23,6 +23,7 @@ import pynvml
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.platforms import current_platform
from fastdeploy.utils import get_logger
from fastdeploy.worker.gpu_model_runner import GPUModelRunner
from fastdeploy.worker.output import ModelRunnerOutput

@@ -50,11 +51,12 @@ class GpuWorker(WorkerBase):
"""
Initialize device and construct model runner
"""
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
if self.device_config.device_type == "cuda" and paddle.device.is_compiled_with_cuda(
):
# Set evironment variable
self.device_ids = self.parallel_config.device_ids.split(",")
self.device = f"gpu:{self.local_rank}"
self.device = f"gpu:{self.local_rank % self.max_chips_per_node}"
paddle.device.set_device(self.device)
paddle.set_default_dtype(self.parallel_config.dtype)

@@ -72,7 +74,7 @@ class GpuWorker(WorkerBase):
self.model_runner: GPUModelRunner = GPUModelRunner(
fd_config=self.fd_config,
device=self.device,
device_id=self.device_ids[self.local_rank],
device_id=self.device_ids[self.local_rank % self.max_chips_per_node],
rank=self.rank,
local_rank=self.local_rank)
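The modulo maps a rank that can exceed one machine's device count back onto local devices: with 8 chips per node (16 on Iluvatar hardware, per `max_chips_per_node`), ranks 0-7 select `gpu:0` through `gpu:7`, while a rank of 10, to pick an illustrative value, selects `gpu:2` on the node it runs on.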
@@ -136,9 +136,9 @@ class PaddleDisWorkerProc():
model_weights_status:
"""
# init worker_ready_signal
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
array_size = min(
max_chips_per_node, self.parallel_config.tensor_parallel_size *
self.max_chips_per_node, self.parallel_config.tensor_parallel_size *
self.parallel_config.expert_parallel_size)
workers_ready = np.zeros(shape=[array_size], dtype=np.int32)
self.worker_ready_signal = IPCSignal(

@@ -148,10 +148,10 @@ class PaddleDisWorkerProc():
suffix=self.parallel_config.engine_pid,
create=False)
self.worker_ready_signal.value[self.local_rank %
max_chips_per_node] = 1
self.max_chips_per_node] = 1
# init worker_healthy_live_signal
workers_alive = np.zeros(shape=[self.ranks], dtype=np.int32)
workers_alive = np.zeros(shape=[array_size], dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal(
name="worker_healthy_live_signal",
array=workers_alive,

@@ -205,7 +205,7 @@ class PaddleDisWorkerProc():
Tmp loop function for ep utill DP is supported
"""
while True:
self.worker_healthy_live_signal.value[self.local_rank] = int(
self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int(
time.time())
if self.fd_config.parallel_config.tensor_parallel_rank == 0 and self.task_queue.num_tasks(
@@ -29,9 +29,11 @@ triton==3.3
use-triton-in-paddle
crcmod
fastsafetensors==0.1.14
msgpack
opentelemetry-api>=1.24.0
opentelemetry-sdk>=1.24.0
opentelemetry-instrumentation-redis
opentelemetry-instrumentation-mysql
opentelemetry-distro
opentelemetry-exporter-otlp

@@ -27,3 +27,4 @@ moviepy
use-triton-in-paddle
crcmod
fastsafetensors==0.1.14
msgpack

@@ -27,3 +27,4 @@ moviepy
use-triton-in-paddle
crcmod
fastsafetensors==0.1.14
msgpack