[NewFeature] Support DP multi API server && fix some bugs in mixed EP && merge develop (#3598)

* [Feature] update ep

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix queue ports idx

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* Update engine.py

* fix ci

* fix some bugs in mixed EP

* add server fix and op fix

* rm some log

* fix code style

* ltd fix

* fix

* fix

* fix some bug

* fix bug

* fix bug

* fix style

* Update config.py

* Update splitwise_connector.py

* Update cache_messager.py

* Update __init__.py

* merge and fix

* Update engine.py

* Update common_engine.py

* Update run_ci_xpu.sh

* Update ernie_processor.py

* Update ernie_processor.py

---------

Co-authored-by: ltd0924 <ltd0924@sina.com>
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
Author: gaoziyuan
Date: 2025-08-26 19:59:02 +08:00
Committed by: GitHub
Parent: cbce94a00e
Commit: 82e64b13e1
24 changed files with 1244 additions and 1200 deletions


@@ -28,6 +28,16 @@
#define DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, ...) \
switch (num_experts_per_rank) { \
case 2: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 2; \
__VA_ARGS__ \
break; \
} \
case 6: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 6; \
__VA_ARGS__ \
break; \
} \
case 8: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 8; \
__VA_ARGS__ \


@@ -23,7 +23,11 @@ import numpy as np
import paddle
from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
from fastdeploy.inter_communicator import (
EngineWorkerQueue,
IPCSignal,
shared_memory_exists,
)
from fastdeploy.utils import get_logger
logger = get_logger("cache_messager", "cache_messager.log")
@@ -159,36 +163,23 @@ class CacheMessager:
try:
prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
try:
prefilled_layer_name = f"splitwise_complete_prefilled_layer_{self.dp_rank_id}.{self.gpu_id}"
prefilled_step_name = f"splitwise_complete_prefilled_step_{self.dp_rank_id}.{self.gpu_id}"
step_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
array=prefilled_step_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=True,
create=not shared_memory_exists(prefilled_step_name),
)
layer_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
array=prefilled_layer_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=True,
)
except:
step_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
array=prefilled_step_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=False,
)
layer_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
array=prefilled_layer_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=False,
create=not shared_memory_exists(prefilled_layer_name),
)
logger.info(f"splitwise_complete_prefilled_step_{self.dp_rank_id}, gpu_id: {self.gpu_id}")
step_shm_value.value[0] = -1
layer_shm_value.value[0] = -1
@@ -220,6 +211,7 @@ class CacheMessager:
self.cache_info[info["request_id"]] = info
prefilled_layer_idx = layer_shm_value.value[0]
prefilled_step_idx = step_shm_value.value[0]
logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
if prefilled_layer_idx == self.num_layers - 1:
time.sleep(0.001)
prefilled_layer_idx = layer_shm_value.value[0]
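The reuse-or-create pattern above boils down to a small standalone sketch (illustrative only; the helper name and its dp_rank_id/gpu_id arguments are assumed, and the "<name>.<suffix>" naming mirrors IPCSignal as shown in ipc_signal.py further down):

import numpy as np

from fastdeploy.inter_communicator import IPCSignal, shared_memory_exists


def open_prefill_step_signal(dp_rank_id: int, gpu_id: int) -> IPCSignal:
    """Attach to the prefilled-step signal if it already exists, otherwise create it."""
    name = f"splitwise_complete_prefilled_step_{dp_rank_id}"
    full_name = f"{name}.{gpu_id}"  # IPCSignal stores the block as "<name>.<suffix>"
    return IPCSignal(
        name=name,
        array=np.zeros(shape=[1], dtype=np.int32),
        dtype=np.int32,
        suffix=gpu_id,
        create=not shared_memory_exists(full_name),
    )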


@@ -95,7 +95,7 @@ PRETRAINED_INIT_CONFIGURATION = {
"start_layer_index": 0,
"moe_num_shared_experts": 0,
"moe_layer_start_index": 0,
"num_max_dispatch_tokens_per_rank": 256,
"num_max_dispatch_tokens_per_rank": 128,
"moe_use_aux_free": False,
"vocab_size": -1,
"hidden_dropout_prob": 0.0,
@@ -278,7 +278,7 @@ class ParallelConfig:
# block size
self.block_size: int = 64
# Engine worker queue port
self.engine_worker_queue_port: int = 9923
self.engine_worker_queue_port: str = "9923"
# Max model len
self.max_model_len: int = 3072 # max_seq_len
# cuda visible devices
@@ -307,7 +307,11 @@ class ParallelConfig:
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
if isinstance(self.engine_worker_queue_port, str):
self.engine_worker_queue_port = [int(port) for port in self.engine_worker_queue_port.split(",")]
logger.info(f"engine_worker_queue_port: {self.engine_worker_queue_port}")
elif isinstance(self.engine_worker_queue_port, int):
self.engine_worker_queue_port = [self.engine_worker_queue_port]
# currently, the expert parallel size is equal data parallel size
if self.enable_expert_parallel:
self.expert_parallel_size = self.data_parallel_size * self.tensor_parallel_size
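A minimal sketch of the normalization above (hypothetical helper name): the comma-separated port string becomes one integer port per data-parallel rank, and a bare int is wrapped into a single-element list.

def normalize_engine_worker_queue_port(port_cfg):
    """Hypothetical mirror of the ParallelConfig logic shown above."""
    if isinstance(port_cfg, str):
        return [int(port) for port in port_cfg.split(",")]
    if isinstance(port_cfg, int):
        return [port_cfg]
    return port_cfg  # already a list of ports


assert normalize_engine_worker_queue_port("9923,9924") == [9923, 9924]
assert normalize_engine_worker_queue_port(8002) == [8002]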
@@ -1038,7 +1042,7 @@ class FDConfig:
max_num_batched_tokens: Optional[int] = None,
ips: str = None,
use_warmup: bool = False,
engine_worker_queue_port: int = 8002,
engine_worker_queue_port: str = "8002",
limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
splitwise_role: str = "mixed",
@@ -1082,11 +1086,10 @@ class FDConfig:
if self.ips is None:
self.master_ip = "0.0.0.0"
elif isinstance(self.ips, list):
self.master_ip = self.ips[0]
else:
elif isinstance(self.ips, str):
self.ips = self.ips.split(",")
self.master_ip = self.ips[0]
self.host_ip = get_host_ip()
if self.ips is None:
self.nnode = 1
@@ -1095,7 +1098,7 @@ class FDConfig:
self.nnode = len(self.ips)
for idx, ip in enumerate(self.ips):
if ip == self.master_ip:
if ip == self.host_ip:
self.node_rank = idx
self.max_model_len = max_model_len
@@ -1111,7 +1114,11 @@ class FDConfig:
self.reasoning_parser = reasoning_parser
self.guided_decoding_backend = guided_decoding_backend
self.disable_any_whitespace = disable_any_whitespace
self.engine_worker_queue_port = engine_worker_queue_port
self._str_to_list("innode_prefill_ports", int)
if isinstance(engine_worker_queue_port, int):
self.engine_worker_queue_port = str(engine_worker_queue_port)
self._str_to_list("engine_worker_queue_port", str)
if envs.FD_FOR_TORCH_MODEL_FORMAT:
self.model_config.model_format = "torch"
@@ -1129,10 +1136,11 @@ class FDConfig:
self.worker_num_per_node = self.max_chips_per_node
nnode = ceil_div(num_ranks, self.worker_num_per_node)
assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
# assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
else:
self.worker_num_per_node = num_ranks
self.engine_worker_queue_port = engine_worker_queue_port
self.device_ids = ",".join([str(i) for i in range(self.worker_num_per_node)])
self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)
if current_platform.is_xpu():
@@ -1155,15 +1163,12 @@ class FDConfig:
self.local_device_ids = self.device_ids.split(",")[: self.parallel_config.tensor_parallel_size]
self.host_ip = get_host_ip()
if self.ips is None or self.host_ip == self.master_ip:
self.is_master = True
else:
self.is_master = False
if self.parallel_config.tensor_parallel_size <= self.worker_num_per_node:
self.is_master = True
self.master_ip = "0.0.0.0"
else:
self.is_master = False
self.master_ip = self.ips[0]
self.paddle_commit_id = paddle.version.commit
@@ -1345,10 +1350,12 @@ class FDConfig:
def _str_to_list(self, attr_name, default_type):
if hasattr(self, attr_name):
val = getattr(self, attr_name)
if val is None:
return
if type(val) is str:
setattr(self, attr_name, [default_type(i) for i in val.split(",")])
else:
setattr(self, attr_name, val)
setattr(self, attr_name, [default_type(i) for i in val])
def __str__(self) -> str:
return json.dumps(self.__dict__, indent=4)
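For reference, a self-contained sketch of the _str_to_list behaviour above (assumed standalone equivalent of the method): None passes through, a string is split on commas, and any other iterable is coerced element-wise, which is how engine_worker_queue_port ends up as a list of strings and innode_prefill_ports as a list of ints.

def str_to_list(val, default_type):
    """Assumed standalone equivalent of FDConfig._str_to_list."""
    if val is None:
        return None
    if type(val) is str:
        return [default_type(i) for i in val.split(",")]
    return [default_type(i) for i in val]


assert str_to_list("8002,8004", str) == ["8002", "8004"]
assert str_to_list("179,180", int) == [179, 180]
assert str_to_list(None, int) is None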


@@ -193,7 +193,7 @@ class EngineArgs:
Flag to enable the custom all-reduce kernel.
"""
engine_worker_queue_port: int = 8002
engine_worker_queue_port: str = "8002"
"""
Port for worker queue communication.
"""
@@ -208,6 +208,11 @@ class EngineArgs:
Number of data parallelism.
"""
local_data_parallel_id: int = 0
"""
Local data parallel id.
"""
enable_expert_parallel: bool = False
"""
Enable expert parallelism.
@@ -498,7 +503,7 @@ class EngineArgs:
)
model_group.add_argument(
"--engine-worker-queue-port",
type=int,
type=lambda s: s.split(",") if s else None,
default=EngineArgs.engine_worker_queue_port,
help="port for engine worker queue",
)
@@ -607,6 +612,13 @@ class EngineArgs:
default=EngineArgs.data_parallel_size,
help="Degree of data parallelism.",
)
parallel_group.add_argument(
"--local-data-parallel-id",
type=int,
default=EngineArgs.local_data_parallel_id,
help="the rank of data parallelism.",
)
parallel_group.add_argument(
"--enable-expert-parallel",
action="store_true",
@@ -947,8 +959,13 @@ class EngineArgs:
early_stop_cfg = self.create_early_stop_config()
early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
if isinstance(self.engine_worker_queue_port, int):
self.engine_worker_queue_port = str(self.engine_worker_queue_port)
if isinstance(self.engine_worker_queue_port, str):
self.engine_worker_queue_port = self.engine_worker_queue_port.split(",")
assert is_port_available(
"0.0.0.0", self.engine_worker_queue_port
"0.0.0.0", int(self.engine_worker_queue_port[parallel_cfg.local_data_parallel_id])
), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
return FDConfig(


@@ -0,0 +1,754 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import copy
import os
import threading
import time
import traceback
import weakref
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple
import numpy as np
import paddle
import zmq
from opentelemetry import trace
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
from fastdeploy.inter_communicator import (
EngineCacheQueue,
EngineWorkerQueue,
IPCSignal,
ZmqClient,
)
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.metrics.trace_util import start_span, start_span_request
from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.output.token_processor import TokenProcessor
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, envs, llm_logger
class EngineSevice:
"""
Base class containing common engine functionality
"""
def __init__(self, cfg):
"""
Initializes the LLMEngine with the provided configuration.
Args:
cfg (Config): Config object containing all the configuration parameters.
"""
self.cfg = cfg
self.scheduler = cfg.scheduler_config.scheduler()
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.resource_manager = ResourceManagerV1(
cfg.max_num_seqs,
cfg,
cfg.parallel_config.tensor_parallel_size,
cfg.splitwise_role,
cfg.parallel_config.local_data_parallel_id,
)
if cfg.splitwise_role != "mixed":
raise NotImplementedError(
"Currently ENABLE_V1_KVCACHE_SCHEDULER=1 only supported in mixed sampling now."
)
else:
self.resource_manager = ResourceManager(
cfg.max_num_seqs,
cfg,
cfg.parallel_config.tensor_parallel_size,
cfg.splitwise_role,
cfg.parallel_config.local_data_parallel_id,
)
self.start_worker_queue_service()
os.environ["INFERENCE_MSG_QUEUE_ID"] = self.cfg.engine_worker_queue_port[
self.cfg.parallel_config.local_data_parallel_id
]
self.split_connector = SplitwiseConnector(cfg, self.engine_worker_queue, self.resource_manager)
self.waiting_requests = []
self.token_processor = TokenProcessor(
cfg=cfg,
cached_generated_tokens=self.scheduler,
engine_worker_queue=self.engine_worker_queue,
split_connector=self.split_connector,
)
self.token_processor.set_resource_manager(self.resource_manager)
self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
for idx in range(1, self.cfg.max_num_partial_prefills + 1):
self.partial_chunked_tokens[idx] = (
(self.cfg.max_num_batched_tokens // idx)
// self.cfg.cache_config.block_size
* self.cfg.cache_config.block_size
)
self.guided_decoding_checker = None
if self.cfg.guided_decoding_backend != "off":
self.guided_decoding_checker = schema_checker(
self.cfg.guided_decoding_backend,
disable_any_whitespace=self.cfg.disable_any_whitespace,
)
self._init_worker_monitor_signals()
self._finalizer = weakref.finalize(self, self._exit_sub_services)
def start(self):
self.running = True
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True)
else:
self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True)
self.insert_task_to_worker_thread.start()
self.token_processor.tasks_queue = self.engine_worker_queue
self.token_processor.run()
def _init_worker_monitor_signals(self): # exist_task_signal 用于各worker进程感知是否有新Task需要处理
current_suffix = int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id])
llm_logger.info(f"current_suffix: {current_suffix}")
exist_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_task_signal = IPCSignal(
name="exist_task_signal",
array=exist_task_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
# exist_swapped_task_signal 用于engine感知worker中是否存在swapped task
exist_swapped_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_swapped_task_signal = IPCSignal(
name="exist_swapped_task_signal",
array=exist_swapped_task_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
# exist_prefill_task_signal 用于各worker进程感知是否进行prefill
exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_prefill_task_signal = IPCSignal(
name="exist_prefill_task_signal",
array=exist_prefill_task_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
# worker_live_signal 用于engine感知各worker进程是否存活记录每个step 时间
worker_healthy_live_recorded_time_array = np.zeros(
shape=[min(self.cfg.worker_num_per_node, self.cfg.parallel_config.tensor_parallel_size)], dtype=np.int32
)
self.worker_healthy_live_signal = IPCSignal(
name="worker_healthy_live_signal",
array=worker_healthy_live_recorded_time_array,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
model_weights_status = np.zeros([1], dtype=np.int32)
self.model_weights_status_signal = IPCSignal(
name="model_weights_status",
array=model_weights_status,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
def start_worker_queue_service(self):
"""
start queue service for engine worker communication
"""
address = (
self.cfg.master_ip,
int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]),
)
if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
llm_logger.info(f"Starting engine worker queue server service at {address}")
self.engine_worker_queue_server = EngineWorkerQueue(
address=address,
is_server=True,
num_client=self.cfg.parallel_config.tensor_parallel_size,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
if (
self.cfg.cache_config.enable_prefix_caching
or self.cfg.splitwise_role != "mixed"
and self.cfg.parallel_config.local_data_parallel_id == 0
):
self.cache_task_queue = EngineCacheQueue(
address=(
self.cfg.master_ip,
self.cfg.cache_config.cache_queue_port,
),
authkey=b"cache_queue_service",
is_server=True,
num_client=self.cfg.parallel_config.tensor_parallel_size,
client_id=-1,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
llm_logger.info(
f"local {min(self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id,self.cfg.parallel_config.data_parallel_size - 1)}"
)
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,
num_client=self.cfg.parallel_config.tensor_parallel_size,
client_id=0,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
local_data_parallel_id=min(
self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id,
self.cfg.parallel_config.data_parallel_size - 1,
),
)
def insert_tasks(self, tasks, current_id=-1, allocated=False):
"""
Insert tasks to engine.
"""
for task in tasks:
start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
# TODO 返回至 scheduler
if allocated:
current_tasks = []
for task in tasks:
cur_task_idx = self.resource_manager.req_dict[task.request_id]
del self.resource_manager.req_dict[task.request_id]
cur_task = self.resource_manager.tasks_list[cur_task_idx]
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode":
cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids)
if task.error_code != 200:
self.resource_manager.stop_flags[cur_task_idx] = True
self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.scheduler.put_results([task])
llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
continue
self.token_processor.tokens_counter[task.request_id] = 1
current_tasks.append(cur_task)
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
return True
self.resource_manager.check_and_free_block_tables()
if not isinstance(tasks, list):
tasks = [tasks]
for item in tasks:
item.schedule_start_time = time.time()
available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch:
llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
llm_logger.error("The exceeded part will be ignored!")
tasks = tasks[:available_batch]
req_ids = [t.request_id for t in tasks]
tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
if not tasks:
error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=500)
return False
self.token_processor.number_of_tasks += len(tasks)
is_decode = False
is_prefill = False
for i in range(len(tasks)):
if tasks[i].disaggregate_info is not None:
if tasks[i].disaggregate_info["role"] == "decode":
is_decode = True
else:
is_prefill = True
self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len
self.split_connector.send_cache_infos(tasks, current_id)
if not is_decode:
llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
for task in tasks:
task.inference_start_time = time.time()
if not is_prefill:
if not self.cfg.model_config.enable_mm:
self.update_requests_chunk_size(tasks)
else:
self.update_mm_requests_chunk_size(tasks)
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
if is_prefill and self.cfg.scheduler_config.name != "splitwise":
self.engine_worker_queue.available_prefill_instances.put(1)
return True
def task_is_finished(self, index):
"""
judge if the task is finished
"""
assert index < len(self.resource_manager.stop_flags)
return self.resource_manager.stop_flags[index]
def all_tasks_finished(self):
"""
judge if all tasks are finished
"""
return np.sum(self.resource_manager.stop_flags) == len(self.resource_manager.stop_flags)
def update_requests_chunk_size(self, requests):
"""
update each request's chunk size info
"""
def update_tokens(idx, chunk_size, update_chunk=False):
nonlocal remain_batched_tokens, chunk_request_num
if update_chunk:
requests_chunk[idx][-1] += chunk_size
else:
requests_chunk[idx].append(chunk_size)
remain_batched_tokens -= chunk_size
current_request_size[idx] -= chunk_size
if current_request_size[idx] <= 0:
chunk_request_num -= 1
if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0:
return
current_request_size = [request.prompt_token_ids_len for request in requests]
requests_chunk = [[] for _ in range(len(requests))]
chunk_request_num = len(current_request_size)
while chunk_request_num >= 1:
remain_batched_tokens = self.cfg.max_num_batched_tokens
for idx in range(len(current_request_size)):
if current_request_size[idx] <= 0:
continue
chunk_size = min(
current_request_size[idx],
self.partial_chunked_tokens[chunk_request_num],
)
update_tokens(idx, chunk_size)
while remain_batched_tokens >= self.cfg.cache_config.block_size:
# 当前 max_num_batched_tokens 还有剩余时,优先分配给较短的请求
waiting_requests = [input_lens for input_lens in current_request_size if input_lens > 0]
if len(waiting_requests) == 0:
break
available_tokens = (
remain_batched_tokens // self.cfg.cache_config.block_size * self.cfg.cache_config.block_size
)
append_idx = current_request_size.index(min(waiting_requests))
chunk_size = min(
current_request_size[append_idx],
self.partial_chunked_tokens[chunk_request_num],
available_tokens,
)
update_tokens(append_idx, chunk_size, update_chunk=True)
for idx in range(len(requests)):
requests[idx].set("prefill_chunk_info", requests_chunk[idx])
def update_mm_requests_chunk_size(self, requests):
"""
update each multimodal request's chunk size info
"""
if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0:
return
for request in requests:
inputs = request.multimodal_inputs
# 兼容没有图片和视频的情况
if inputs["images"] is None:
inputs["image_type_ids"] = np.array([], dtype="int32")
inputs["grid_thw"] = np.array([], dtype="int64")
inputs["images"] = np.array([], dtype="uint8")
input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64")
image_type_ids = paddle.to_tensor(inputs["image_type_ids"], dtype="int32")
image_mask = input_ids == self.data_processor.image_patch_id
image_token_sum = paddle.full(shape=[len(input_ids) + 1], fill_value=0, dtype="int32")
image_token_sum[1:] = paddle.cumsum(image_mask.cast("int32"))
grid_thw = []
for one in inputs["grid_thw"]:
if one[0] == 1:
grid_thw.append(one)
else:
grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2))
grid_thw = paddle.to_tensor(grid_thw, dtype="int64")
from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse
chunk_image_num, chunk_seq_len = get_mm_split_fuse(
input_ids,
image_type_ids,
image_token_sum,
grid_thw,
self.data_processor.image_patch_id,
len(grid_thw),
0,
len(input_ids),
0,
self.partial_chunked_tokens[1],
2048,
)
grid_thw = grid_thw.numpy().reshape([-1, 3])
num_chunks = len(chunk_image_num)
chunks_info = []
input_ids_st, image_type_ids_st, grid_thw_st, patch_st = 0, 0, 0, 0
for idx in range(num_chunks):
chunk_input_ids = inputs["input_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]]
chunk_token_type_ids = inputs["token_type_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]]
actual_image_num = np.sum(grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx], 0])
chunk_image_type_ids = inputs["image_type_ids"][
image_type_ids_st : image_type_ids_st + actual_image_num
]
chunk_grid_thw = grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx]]
chunk_patch_num = np.sum(np.prod(chunk_grid_thw, axis=1))
chunk_images = inputs["images"][patch_st : patch_st + chunk_patch_num]
chunks_info.append(
{
"input_ids": chunk_input_ids,
"token_type_ids": chunk_token_type_ids,
"image_type_ids": (chunk_image_type_ids if chunk_image_type_ids.shape[0] else None),
"grid_thw": (chunk_grid_thw if chunk_grid_thw.shape[0] else None),
"images": (chunk_images if chunk_images.shape[0] else None),
"position_ids": None,
}
)
input_ids_st += chunk_seq_len[idx]
image_type_ids_st += actual_image_num
grid_thw_st += chunk_image_num[idx]
patch_st += chunk_patch_num
request.set("prefill_chunk_info", chunks_info)
def _insert_task_to_worker(self):
"""
Insert task to engine thread, monitor scheduler request queue.
if the engine has resource, insert task to engine
"""
current_id = -1
while getattr(self, "running", True):
try:
if self.resource_manager.available_batch() == 0:
time.sleep(0.001)
continue
if self.engine_worker_queue.num_tasks() > 0:
time.sleep(0.001)
continue
if hasattr(self, "exist_prefill_task_signal") and self.exist_prefill_task_signal.value[0] > 0:
if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks():
time.sleep(0.005)
continue
if self.engine_worker_queue.num_cache_infos() > 0:
time.sleep(0.001)
continue
if len(self.split_connector.current_request_ids) > 0:
time.sleep(0.001)
continue
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch,
)
self.resource_manager.check_and_free_block_tables()
tasks = self.scheduler.get_requests(
available_blocks=self.resource_manager.available_block_num(),
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
max_num_batched_tokens=self.cfg.max_num_batched_tokens,
batch=num_prefill_batch,
)
if len(tasks) == 0:
time.sleep(0.001)
continue
current_id = (current_id + 1) % 100003
if self.cfg.splitwise_role != "mixed":
llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks(tasks, current_id)
self.insert_tasks(tasks, current_id)
main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks))
except Exception as e:
err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
llm_logger.error(err_msg)
def _scheduler_task_to_worker_v1(self):
"""
Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1).
"""
get_request_pool = ThreadPoolExecutor(max_workers=1)
is_fetching = False
def _fetch_request():
nonlocal is_fetching
is_fetching = True
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch,
)
self.resource_manager.check_and_free_block_tables()
tasks = self.scheduler.get_requests(
available_blocks=self.resource_manager.available_block_num(),
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
max_num_batched_tokens=self.cfg.max_model_len,
batch=num_prefill_batch,
)
# Fetch requests and add them to the scheduling queue
for task in tasks:
self.resource_manager.add_request(task)
is_fetching = False
while self.running:
try:
if self.engine_worker_queue.num_tasks() > 0:
time.sleep(0.001)
continue
if (
len(self.resource_manager.waiting) == 0
and (not is_fetching)
and self.exist_prefill_task_signal.value[0] == 0
):
get_request_pool.submit(_fetch_request)
# 2. Schedule requests
tasks = self.resource_manager.schedule()
# 3. Send to engine
if tasks:
self.resource_manager.get_real_bsz()
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
else:
time.sleep(0.005)
except Exception as e:
err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
llm_logger.error(err_msg)
def start_zmq_service(self, api_server_pid=None):
if api_server_pid is None:
return
self.api_server_pid = api_server_pid
self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL)
self.zmq_server.start_server()
self.zmq_server.create_router()
time.sleep(3)
self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True)
self.insert_task_to_scheduler_thread.start()
self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True)
self.receive_output_thread.start()
def _insert_zmq_task_to_scheduler(self):
added_requests: Dict[str, int] = dict()
while self.running:
try:
block = True if len(added_requests) == 0 else False
if not self.cfg.model_config.enable_mm:
err, data = self.zmq_server.receive_json_once(block)
else:
err, data = self.zmq_server.receive_pyobj_once(block)
if err is not None:
llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
break
request, insert_task = None, []
results: List[Tuple[str, Optional[str]]] = list()
if data:
request = Request.from_dict(data)
start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
llm_logger.debug(f"Receive request: {request}")
err_msg = None
if self.guided_decoding_checker is not None:
request, err_msg = self.guided_decoding_checker.schema_format(request)
if err_msg is not None:
llm_logger.error(err_msg)
results.append((request.request_id, err_msg))
else:
insert_task.append(request)
response = self.scheduler.put_requests(insert_task)
results.extend(response)
if request:
if request.request_id not in added_requests:
added_requests[request.request_id] = 0
added_requests[request.request_id] += 1
for request_id, failed in results:
added_requests[request_id] -= 1
if added_requests[request_id] == 0:
added_requests.pop(request_id)
if failed is None:
main_process_metrics.num_requests_waiting.inc(1)
continue
error_result = RequestOutput(
request_id=request_id,
finished=True,
error_code=500,
error_msg=failed,
)
# Since the request is not in scheduler
# Send result by zmq directly
self.zmq_server.send_multipart(request_id, error_result)
except Exception as e:
llm_logger.error(
f"Error happend while receving new request from zmq, details={e}, "
f"traceback={traceback.format_exc()}"
)
def _zmq_send_generated_tokens(self):
"""
Recieve output for zmq
"""
while self.running:
try:
results = self.scheduler.get_results()
if len(results) == 0:
time.sleep(0.005)
continue
for request_id, contents in results.items():
llm_logger.info(f"Send results: {request_id} {contents}")
self.zmq_server.send_multipart(request_id, contents)
except Exception as e:
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
def split_mode_get_tasks(self):
"""
Split mode get tasks
"""
def receiver_loop():
while self.running:
try:
processed_indices = []
for idx, task in enumerate(self.waiting_requests):
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
llm_logger.info(f"Resource available, processing task {task.request_id}")
processed_indices.append(idx)
else:
llm_logger.debug(f"Still waiting for resources {task.request_id}")
break
for idx in sorted(processed_indices, reverse=True):
self.waiting_requests.pop(idx)
if not self.engine_worker_queue.disaggregate_queue_empty():
items = self.engine_worker_queue.get_disaggregated_tasks()
for item in items:
role = item[0]
tasks = item[1]
if role == "prefill":
for task in tasks:
task.max_tokens = task.min_tokens = 2
self.insert_tasks(tasks)
elif role == "decode":
if hasattr(tasks[0], "finished"):
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
task.finished = False
self.insert_tasks(tasks, allocated=True)
if self.cfg.innode_prefill_ports is not None:
self.scheduler.put_results(tasks)
else:
if len(self.waiting_requests):
llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
self.waiting_requests.extend(tasks)
else:
new_waiting = []
for task in tasks:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
else:
new_waiting.append(task)
if new_waiting:
self.waiting_requests.extend(new_waiting)
llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
else:
time.sleep(0.001)
except Exception as e:
llm_logger.error(f"Error in main loop: {e}")
time.sleep(0.1)
threading.Thread(target=receiver_loop, daemon=True).start()
def start_cache_service(self, device_ids, ipc_signal_suffix):
return self.resource_manager.cache_manager.launch_cache_manager(
cache_config=self.cfg.cache_config,
tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size,
device_ids=device_ids,
pod_ip=self.cfg.master_ip,
engine_worker_queue_port=int(
self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
),
pid_suffix=ipc_signal_suffix,
)
def check_and_free_block_tables(self):
self.resource_manager.check_and_free_block_tables()
def _exit_sub_services(self):
"""
exit sub services
"""
self.running = False
self.engine_worker_queue_server.cleanup()
self.exist_task_signal.clear()
self.exist_swapped_task_signal.clear()
self.worker_healthy_live_signal.clear()
self.exist_prefill_task_signal.clear()
self.model_weights_status_signal.clear()
if hasattr(self, "zmq_server") and self.zmq_server is not None:
self.zmq_server.close()
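The per-rank wiring used throughout EngineSevice above can be condensed into a hedged sketch (illustrative function, not part of the commit): each data-parallel rank indexes the shared engine_worker_queue_port list with its local_data_parallel_id, and the selected port doubles as the suffix of that rank's worker-monitor IPC signals.

def rank_queue_endpoint(cfg):
    """Illustrative only: derive one DP rank's worker-queue address and IPC suffix."""
    dp_id = cfg.parallel_config.local_data_parallel_id
    port = int(cfg.engine_worker_queue_port[dp_id])
    address = (cfg.master_ip, port)  # passed to EngineWorkerQueue(address=...)
    ipc_signal_suffix = port         # suffix for exist_task_signal, worker_healthy_live_signal, ...
    return address, ipc_signal_suffix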

File diff suppressed because it is too large.


@@ -25,12 +25,9 @@ import weakref
import numpy as np
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.output.token_processor import TokenProcessor
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, console_logger, llm_logger
from fastdeploy.engine.common_engine import EngineSevice
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.utils import console_logger, envs, llm_logger
class ExpertService:
@@ -49,36 +46,16 @@ class ExpertService:
Args:
cfg (Config): Config object containing all the configuration parameters.
"""
self.cfg = cfg
start_pos = (local_data_parallel_id * self.cfg.parallel_config.tensor_parallel_size) % cfg.worker_num_per_node
end_pos = start_pos + self.cfg.parallel_config.tensor_parallel_size
if cfg.splitwise_role != "mixed":
self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}")
self.cfg.disaggregate_info = None
self.scheduler = cfg.scheduler_config.scheduler()
if cfg.splitwise_role != "mixed":
self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
address = (cfg.master_ip, cfg.engine_worker_queue_port)
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,
client_id=0,
num_client=cfg.parallel_config.tensor_parallel_size,
local_data_parallel_id=local_data_parallel_id,
)
self.resource_manager = ResourceManager(
cfg.max_num_seqs,
cfg,
cfg.parallel_config.tensor_parallel_size,
cfg.splitwise_role,
local_data_parallel_id,
)
if cfg.splitwise_role != "mixed":
if len(self.cfg.cache_config.pd_comm_port) == 1:
self.cfg.cache_config.pd_comm_port[0] = (
@@ -86,29 +63,11 @@ class ExpertService:
)
else:
self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
self.split_connector = SplitwiseConnector(
self.cfg,
self.scheduler,
self.engine_worker_queue,
self.resource_manager,
)
self.token_processor = TokenProcessor(
cfg=cfg,
cached_generated_tokens=self.scheduler,
engine_worker_queue=self.engine_worker_queue,
split_connector=self.split_connector,
)
self.token_processor.set_resource_manager(self.resource_manager)
self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
for idx in range(1, self.cfg.max_num_partial_prefills + 1):
self.partial_chunked_tokens[idx] = (
(self.cfg.max_num_batched_tokens // idx)
// self.cfg.cache_config.block_size
* self.cfg.cache_config.block_size
)
self.engine = EngineSevice(self.cfg)
if self.cfg.scheduler_config.name == "splitwise":
self.engine.scheduler.reset_nodeid(f"{self.engine.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
self._finalizer = weakref.finalize(self, self._exit_sub_services)
@@ -119,35 +78,37 @@ class ExpertService:
to keep getting request from zmq_server.
"""
# assert not self.is_started, "The engine is already started."
start_time = time.time()
self.engine.start()
if ipc_signal_suffix is not None:
self.api_server_pid = ipc_signal_suffix
self.engine.start_zmq_service(ipc_signal_suffix)
else:
ipc_signal_suffix = self.cfg.engine_worker_queue_port[0]
llm_logger.info(f"start expert service {local_data_parallel_id}")
if self.cfg.splitwise_role != "mixed":
self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
cache_config=self.cfg.cache_config,
tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size,
device_ids=self.cfg.local_device_ids,
pod_ip=self.cfg.master_ip,
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
)
self.split_mode_get_tasks()
self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix)
self.engine.split_mode_get_tasks()
self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, args=())
self.insert_task_to_worker_thread.daemon = True
self.insert_task_to_worker_thread.start()
# Start TokenProcessor thread
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
self.token_processor.run()
if self.cfg.scheduler_config.name == "splitwise":
self.cfg.init_cache_info()
role = self.cfg.splitwise_role
host_ip = self.cfg.host_ip
disaggregate = self.cfg.disaggregate_info
self.scheduler.start(role, host_ip, disaggregate)
self.cfg.print()
self.engine.scheduler.start(role, host_ip, disaggregate)
if self.cfg.splitwise_role != "mixed":
self.splitwise_receive_thread = threading.Thread(
target=self.engine.split_connector.start_receiver, args=()
)
self.splitwise_receive_thread.daemon = True
self.splitwise_receive_thread.start()
self.cfg.print()
local_rank = local_data_parallel_id % self.cfg.worker_num_per_node
if not envs.FD_ENABLE_MULTI_API_SERVER:
launched_expert_service_signal_data = np.zeros(
shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
)
@@ -158,7 +119,6 @@ class ExpertService:
suffix=ipc_signal_suffix,
create=False,
)
local_rank = local_data_parallel_id % self.cfg.worker_num_per_node
self.launched_expert_service_signal.value[local_rank] = 1
console_logger.info(
@@ -166,198 +126,14 @@ class ExpertService:
)
return True
def _insert_task_to_worker(self):
"""
Insert task to engine thread, monitor scheduler request queue.
if the engine has resource, insert task to engine
"""
current_id = -1
while True:
try:
if self.resource_manager.available_batch() == 0:
time.sleep(0.001)
continue
if self.engine_worker_queue.num_tasks() > 0:
time.sleep(0.001)
continue
if len(self.split_connector.current_request_ids) > 0:
time.sleep(0.001)
continue
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch,
)
self.resource_manager.check_and_free_block_tables()
tasks = self.scheduler.get_requests(
available_blocks=self.resource_manager.available_block_num(),
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
max_num_batched_tokens=self.cfg.max_num_batched_tokens,
batch=num_prefill_batch,
)
if len(tasks) == 0:
time.sleep(0.001)
continue
if self.cfg.splitwise_role != "mixed":
llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks(tasks, current_id)
current_id = (current_id + 1) % 100003
self.insert_tasks(tasks, current_id)
main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks))
except Exception as e:
err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
llm_logger.error(err_msg)
def split_mode_get_tasks(self):
"""
Split mode get tasks
"""
waiting_requests = []
def receiver_loop():
while True:
try:
if len(waiting_requests) > 0:
for task in waiting_requests:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
waiting_requests.remove(task)
else:
break
if not self.engine_worker_queue.disaggregate_queue_empty():
items = self.engine_worker_queue.get_disaggregated_tasks()
for item in items:
role = item[0]
tasks = item[1]
if role == "prefill":
llm_logger.info("get prefill tasks")
for task in tasks:
task.max_tokens = task.min_tokens = 2
self.insert_tasks(tasks)
elif role == "decode":
llm_logger.info(f"get decode tasks {tasks}")
if hasattr(tasks[0], "finished"):
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
task.finished = False
# self.scheduler.put_results(tasks)
self.insert_tasks(tasks, allocated=True)
else:
if len(waiting_requests):
for task in tasks:
waiting_requests.append(task)
else:
for task in tasks:
if not self.resource_manager.is_resource_sufficient(
task.prompt_token_ids_len
):
waiting_requests.append(task)
else:
self.insert_tasks([task])
else:
time.sleep(0.001)
continue
except Exception as e:
llm_logger.error(f"get decode tasks error: {e}, {str(traceback.format_exc())}")
threading.Thread(target=receiver_loop, daemon=True).start()
def insert_tasks(self, tasks, current_id=-1, allocated=False):
"""
Insert tasks to engine.
"""
if allocated:
current_tasks = []
for task in tasks:
cur_task_idx = self.resource_manager.req_dict[task.request_id]
del self.resource_manager.req_dict[task.request_id]
cur_task = self.resource_manager.tasks_list[cur_task_idx]
if task.error_code != 200:
self.resource_manager.stop_flags[cur_task_idx] = True
self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.scheduler.put_results([task])
llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
continue
llm_logger.info(f"{cur_task_idx} {task.request_id}")
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
self.token_processor.tokens_counter[task.request_id] = 1
current_tasks.append(cur_task)
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
return True
self.resource_manager.check_and_free_block_tables()
if not isinstance(tasks, list):
tasks = [tasks]
for item in tasks:
item.schedule_start_time = time.time()
available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch:
llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
llm_logger.error("The exceeded part will be ignored!")
tasks = tasks[:available_batch]
req_ids = [t.request_id for t in tasks]
tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
if not tasks:
error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=500)
return False
self.token_processor.number_of_tasks += len(tasks)
is_decode = False
is_prefill = False
for i in range(len(tasks)):
if tasks[i].disaggregate_info is not None:
if tasks[i].disaggregate_info["role"] == "decode":
is_decode = True
else:
is_prefill = True
self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len
if is_decode or is_prefill:
self.split_connector.send_cache_infos(tasks, current_id)
for task in tasks:
task.infer_start_time = time.time()
if not is_decode:
llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
if not is_prefill and self.cfg.cache_config.enable_chunked_prefill:
if not self.cfg.model_config.enable_mm:
self.update_requests_chunk_size(tasks)
else:
self.update_mm_requests_chunk_size(tasks)
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
return True
def _exit_sub_services(self):
"""
exit sub services
"""
if hasattr(self, "cache_manager_processes"):
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
self.resource_manager.cache_manager.cache_ready_signal.clear()
self.engine.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
self.engine.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}")
try:
@@ -369,13 +145,16 @@ class ExpertService:
self.zmq_server.close()
def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix):
def start_data_parallel_service(cfg, local_data_parallel_id, ipc_signal_suffix=None):
"""
Start expert service
"""
expert_service = ExpertService(cfg, local_data_parallel_id)
try:
expert_service.start(ipc_signal_suffix, local_data_parallel_id)
expert_service.split_connector.start_receiver()
while True:
time.sleep(1000)
except Exception as e:
llm_logger.exception(f"Expert service failed to start: {e}, {str(traceback.format_exc())}")


@@ -45,6 +45,7 @@ class EngineClient:
max_model_len,
tensor_parallel_size,
pid,
port,
limit_mm_per_prompt,
mm_processor_kwargs,
# enable_mm=False,
@@ -75,13 +76,19 @@ class EngineClient:
self.data_processor = input_processor.create_processor()
self.max_model_len = max_model_len
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
array_size = min(max_chips_per_node, tensor_parallel_size * data_parallel_size)
if tensor_parallel_size < max_chips_per_node:
self.is_master = True
else:
self.is_master = False
array_size = min(max_chips_per_node, tensor_parallel_size)
self.worker_healthy_live_recorded_time_array = np.zeros(shape=[array_size], dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal(
name="worker_healthy_live_signal",
array=self.worker_healthy_live_recorded_time_array,
dtype=np.int32,
suffix=pid,
suffix=port,
create=False,
)
self.semaphore = StatefulSemaphore((FD_SUPPORT_MAX_CONNECTIONS + workers - 1) // workers)
@@ -90,7 +97,7 @@ class EngineClient:
name="model_weights_status",
array=model_weights_status,
dtype=np.int32,
suffix=pid,
suffix=port,
create=False,
)
self.connection_manager = DealerConnectionManager(


@@ -31,6 +31,7 @@ from prometheus_client import CONTENT_TYPE_LATEST
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.engine.expert_service import ExpertService
from fastdeploy.entrypoints.chat_utils import load_chat_template
from fastdeploy.entrypoints.engine_client import EngineClient
from fastdeploy.entrypoints.openai.protocol import (
@@ -60,6 +61,7 @@ from fastdeploy.utils import (
FlexibleArgumentParser,
StatefulSemaphore,
api_server_logger,
configure_uvicorn_logging,
console_logger,
is_port_available,
retrive_model_from_server,
@@ -98,15 +100,10 @@ def load_engine():
api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}")
engine_args = EngineArgs.from_cli_args(args)
engine = LLMEngine.from_engine_args(engine_args)
if not engine.start(api_server_pid=os.getpid()):
api_server_logger.error("Failed to initialize FastDeploy LLM engine, service exit now!")
return None
api_server_logger.info("FastDeploy LLM engine initialized!\n")
console_logger.info(f"Launching metrics service at http://{args.host}:{args.metrics_port}/metrics")
console_logger.info(f"Launching chat completion service at http://{args.host}:{args.port}/v1/chat/completions")
console_logger.info(f"Launching completion service at http://{args.host}:{args.port}/v1/completions")
llm_engine = engine
return engine
@@ -117,6 +114,25 @@ MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.w
connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS)
def load_data_service():
"""
load data service
"""
global llm_engine
if llm_engine is not None:
return llm_engine
api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}")
engine_args = EngineArgs.from_cli_args(args)
config = engine_args.create_engine_config()
api_server_logger.info(f"local_data_parallel_id: {config.parallel_config}")
expert_service = ExpertService(config, config.parallel_config.local_data_parallel_id)
if not expert_service.start(os.getpid(), config.parallel_config.local_data_parallel_id):
api_server_logger.error("Failed to initialize FastDeploy LLM expert service, service exit now!")
return None
llm_engine = expert_service
return expert_service
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
@@ -140,19 +156,20 @@ async def lifespan(app: FastAPI):
model_paths = [ModelPath(name=served_model_names, model_path=args.model, verification=verification)]
engine_client = EngineClient(
args.model,
args.tokenizer,
args.max_model_len,
args.tensor_parallel_size,
pid,
args.limit_mm_per_prompt,
args.mm_processor_kwargs,
model_name_or_path=args.model,
tokenizer=args.tokenizer,
max_model_len=args.max_model_len,
tensor_parallel_size=args.tensor_parallel_size,
pid=pid,
port=int(args.engine_worker_queue_port[args.local_data_parallel_id]),
limit_mm_per_prompt=args.limit_mm_per_prompt,
mm_processor_kwargs=args.mm_processor_kwargs,
# args.enable_mm,
args.reasoning_parser,
args.data_parallel_size,
args.enable_logprob,
args.workers,
args.tool_call_parser,
reasoning_parser=args.reasoning_parser,
data_parallel_size=args.data_parallel_size,
enable_logprob=args.enable_logprob,
workers=args.workers,
tool_parser=args.tool_call_parser,
)
app.state.dynamic_load_weight = args.dynamic_load_weight
model_handler = OpenAIServingModels(
@@ -176,6 +193,9 @@ async def lifespan(app: FastAPI):
app.state.engine_client = engine_client
app.state.chat_handler = chat_handler
app.state.completion_handler = completion_handler
global llm_engine
if llm_engine is not None:
llm_engine.engine.data_processor = engine_client.data_processor
yield
# close zmq
try:
@@ -510,8 +530,18 @@ def launch_controller_server():
def main():
"""main函数"""
if load_engine() is None:
configure_uvicorn_logging()
load_model_register_plugins()
if args.local_data_parallel_id == 0:
if not load_engine():
return
else:
if not load_data_service():
return
api_server_logger.info("FastDeploy LLM engine initialized!\n")
console_logger.info(f"Launching metrics service at http://{args.host}:{args.metrics_port}/metrics")
console_logger.info(f"Launching chat completion service at http://{args.host}:{args.port}/v1/chat/completions")
console_logger.info(f"Launching completion service at http://{args.host}:{args.port}/v1/completions")
launch_controller_server()
launch_metrics_server()


@@ -0,0 +1,107 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import argparse
import os
import subprocess
import sys
import time
from fastdeploy.utils import get_logger, is_port_available
logger = get_logger("multi_api_server", "multi_api_server.log")
def start_servers(server_count, server_args, ports, metrics_ports):
processes = []
logger.info(f"Starting servers on ports: {ports} with args: {server_args} and metrics ports: {metrics_ports}")
for i in range(len(server_args)):
if server_args[i] == "--engine-worker-queue-port":
engine_worker_queue_port = server_args[i + 1].split(",")
break
check_param(ports, server_count)
check_param(metrics_ports, server_count)
check_param(engine_worker_queue_port, server_count)
# check_param(server_args, server_count)
for i in range(server_count):
port = int(ports[i])
metrics_port = int(metrics_ports[i])
env = os.environ.copy()
env["FD_LOG_DIR"] = f"log_{i}"
cmd = [
sys.executable,
"-m",
"fastdeploy.entrypoints.openai.api_server",
*server_args,
"--port",
str(port),
"--metrics-port",
str(metrics_port),
"--local-data-parallel-id",
str(i),
]
# 启动子进程
proc = subprocess.Popen(cmd, env=env)
processes.append(proc)
logger.info(f"Starting servers #{i+1} (PID: {proc.pid}) port: {port} | command: {' '.join(cmd)}")
return processes
def check_param(ports, num_servers):
logger.info(f"check param {ports}, {num_servers}")
assert len(ports) == num_servers, "Number of ports must match num-servers"
for port in ports:
logger.info(f"check port {port}")
if not is_port_available("0.0.0.0", int(port)):
raise ValueError(f"Port {port} is already in use.")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--ports", default="8000,8002", type=str, help="ports to the http server")
parser.add_argument("--num-servers", default=2, type=int, help="number of workers")
parser.add_argument("--metrics-ports", default="8800,8802", type=str, help="ports for metrics server")
parser.add_argument("--args", nargs=argparse.REMAINDER, help="remaining arguments are passed to api_server.py")
args = parser.parse_args()
logger.info(f"Starting {args.num_servers} servers on ports: {args.ports} with args: {args.args}")
# check_param(args.ports, args.num_servers)
# check_param(args.metrics_ports, args.num_servers)
# check_param(args.args.engine_worker_queue_port, args.num_servers)
processes = start_servers(
server_count=args.num_servers,
server_args=args.args,
ports=args.ports.split(","),
metrics_ports=args.metrics_ports.split(","),
)
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
for proc in processes:
proc.terminate()
for proc in processes:
proc.wait()
logger.info("All servers stopped.")
if __name__ == "__main__":
main()
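A hypothetical local invocation of the launcher above (model path and port numbers are placeholders): each of the two spawned api_server processes receives its own HTTP port, metrics port, log directory (FD_LOG_DIR=log_{i}), --local-data-parallel-id, and a distinct entry from the shared --engine-worker-queue-port list.

server_args = [
    "--model", "/path/to/model",                # placeholder model path
    "--data-parallel-size", "2",
    "--engine-worker-queue-port", "9923,9924",  # one queue port per DP rank
]
processes = start_servers(
    server_count=2,
    server_args=server_args,
    ports=["8000", "8002"],
    metrics_ports=["8800", "8802"],
)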


@@ -37,7 +37,7 @@ from fastdeploy.entrypoints.openai.protocol import (
UsageInfo,
)
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.utils import api_server_logger, get_host_ip
from fastdeploy.utils import api_server_logger
from fastdeploy.worker.output import LogprobsLists
@@ -50,15 +50,16 @@ class OpenAIServingChat:
self.engine_client = engine_client
self.models = models
self.pid = pid
self.master_ip = ips
self.max_waiting_time = max_waiting_time
self.host_ip = get_host_ip()
self.chat_template = chat_template
if self.master_ip is not None:
if isinstance(self.master_ip, list):
self.master_ip = self.master_ip[0]
if ips is not None:
if isinstance(ips, list):
self.master_ip = ips[0]
else:
self.master_ip = self.master_ip.split(",")[0]
self.master_ip = ips.split(",")[0]
else:
self.master_ip = "0.0.0.0"
api_server_logger.info(f"master ip: {self.master_ip}")
async def _ensure_connection_manager(self):
"""ensure connection manager initialized"""
@@ -67,19 +68,16 @@ class OpenAIServingChat:
self.engine_client.connection_initialized = True
def _check_master(self):
if self.master_ip is None:
return True
if self.host_ip == self.master_ip:
return True
return False
return self.engine_client.is_master
async def create_chat_completion(self, request: ChatCompletionRequest):
"""
Create a new chat completion using the specified parameters.
"""
if not self._check_master():
err_msg = f"Only master node can accept completion request, please send request to master node: {self.pod_ips[0]}"
err_msg = (
f"Only master node can accept completion request, please send request to master node: {self.master_ip}"
)
api_server_logger.error(err_msg)
return ErrorResponse(message=err_msg, code=400)
@@ -117,7 +115,6 @@ class OpenAIServingChat:
api_server_logger.error(error_msg)
self.engine_client.semaphore.release()
return ErrorResponse(code=400, message=error_msg)
del current_req_dict
if request.stream:
@@ -193,6 +190,7 @@ class OpenAIServingChat:
choices=[],
model=model_name,
)
api_server_logger.info(f"create chat completion request: {request_id}")
try:
await self._ensure_connection_manager()
@@ -388,7 +386,6 @@ class OpenAIServingChat:
enable_thinking = request.metadata.get("enable_thinking") if request.metadata else None
include_stop_str_in_output = request.include_stop_str_in_output
try:
await self._ensure_connection_manager()
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)


@@ -33,7 +33,7 @@ from fastdeploy.entrypoints.openai.protocol import (
ErrorResponse,
UsageInfo,
)
from fastdeploy.utils import api_server_logger, get_host_ip
from fastdeploy.utils import api_server_logger
from fastdeploy.worker.output import LogprobsLists
@@ -42,14 +42,14 @@ class OpenAIServingCompletion:
self.engine_client = engine_client
self.models = models
self.pid = pid
self.master_ip = ips
self.host_ip = get_host_ip()
self.max_waiting_time = max_waiting_time
if self.master_ip is not None:
if isinstance(self.master_ip, list):
self.master_ip = self.master_ip[0]
if ips is not None:
if isinstance(ips, list):
self.master_ip = ips[0]
else:
self.master_ip = self.master_ip.split(",")[0]
self.master_ip = ips.split(",")[0]
else:
self.master_ip = "0.0.0.0"
async def _ensure_connection_manager(self):
"""ensure connection manager initialized"""
@@ -58,18 +58,16 @@ class OpenAIServingCompletion:
self.engine_client.connection_initialized = True
def _check_master(self):
if self.master_ip is None:
return True
if self.host_ip == self.master_ip:
return True
return False
return self.engine_client.is_master
async def create_completion(self, request: CompletionRequest):
"""
Create a completion for the given prompt.
"""
if not self._check_master():
err_msg = f"Only master node can accept completion request, please send request to master node: {self.pod_ips[0]}"
err_msg = (
f"Only master node can accept completion request, please send request to master node: {self.master_ip}"
)
api_server_logger.error(err_msg)
return ErrorResponse(message=err_msg, code=400)
if self.models:

View File

@@ -47,7 +47,7 @@ class DealerConnectionManager:
self.running = True
for index in range(self.max_connections):
await self._add_connection(index)
api_server_logger.info(f"Started {self.max_connections} connections")
api_server_logger.info(f"Started {self.max_connections} connections, pid {self.pid}")
async def _add_connection(self, index):
"""create a new connection and start listening task"""

View File

@@ -86,6 +86,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"),
# support max connections
"FD_SUPPORT_MAX_CONNECTIONS": lambda: int(os.getenv("FD_SUPPORT_MAX_CONNECTIONS", "1024")),
# enable multi api server
"FD_ENABLE_MULTI_API_SERVER": lambda: bool(int(os.getenv("FD_ENABLE_MULTI_API_SERVER", "0"))),
"FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))),
}
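The new `FD_ENABLE_MULTI_API_SERVER` entry follows the table's lazy-lambda convention: the variable is read only when the lambda is invoked, and `bool(int(...))` maps the "0"/"1" strings to booleans. A self-contained sketch of that convention; the dictionary below is a stand-in, not FastDeploy's `envs` module:

```python
# Minimal reproduction of the lazy env-var table pattern; keys mirror the diff, the dict itself is a stand-in.
import os
from typing import Any, Callable, Dict

environment_variables: Dict[str, Callable[[], Any]] = {
    # enable multi api server ("0" -> False, "1" -> True)
    "FD_ENABLE_MULTI_API_SERVER": lambda: bool(int(os.getenv("FD_ENABLE_MULTI_API_SERVER", "0"))),
    # support max connections
    "FD_SUPPORT_MAX_CONNECTIONS": lambda: int(os.getenv("FD_SUPPORT_MAX_CONNECTIONS", "1024")),
}

os.environ["FD_ENABLE_MULTI_API_SERVER"] = "1"
assert environment_variables["FD_ENABLE_MULTI_API_SERVER"]() is True
assert environment_variables["FD_SUPPORT_MAX_CONNECTIONS"]() == 1024
```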

View File

@@ -16,7 +16,7 @@
from .engine_cache_queue import EngineCacheQueue
from .engine_worker_queue import EngineWorkerQueue
from .ipc_signal import IPCSignal
from .ipc_signal import IPCSignal, shared_memory_exists
from .zmq_client import ZmqClient
__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue"]
__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "shared_memory_exists"]

View File

@@ -18,6 +18,8 @@ from multiprocessing.shared_memory import SharedMemory
import numpy as np
from fastdeploy.utils import llm_logger
def shared_memory_exists(name: str) -> bool:
"""Check if a shared memory block with the given name exists.
@@ -35,7 +37,7 @@ def shared_memory_exists(name: str) -> bool:
except FileNotFoundError:
return False
except Exception as e:
print(f"Unexpected error: {e}")
llm_logger.error(f"Unexpected error: {e}")
return False
@@ -78,7 +80,9 @@ class IPCSignal:
name = name + f".{suffix}"
if create:
assert not shared_memory_exists(name), f"ShareMemory: {name} already exists"
if shared_memory_exists(name):
llm_logger.warning(f"ShareMemory: {name} already exists, delete it")
SharedMemory(name=name, create=False).unlink()
self.shm = SharedMemory(create=True, size=array.nbytes, name=name)
self.value: np.ndarray = np.ndarray(array.shape, dtype=array.dtype, buffer=self.shm.buf)
self.value[:] = array # Initialize with input array data
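Combined with the relaxed create path above (warn and unlink instead of asserting), the visible pieces of `shared_memory_exists` suggest roughly the following probe-then-create shape. This is a standalone sketch, not the exact FastDeploy implementation, and it uses `print` where the real code logs via `llm_logger`:

```python
# Sketch of probe-then-create semantics for named shared memory.
from multiprocessing.shared_memory import SharedMemory


def shared_memory_exists(name: str) -> bool:
    """Return True if a shared memory block with this name can be attached."""
    try:
        shm = SharedMemory(name=name, create=False)
        shm.close()
        return True
    except FileNotFoundError:
        return False
    except Exception as e:  # permissions, races, ...
        print(f"Unexpected error: {e}")
        return False


def create_or_replace(name: str, size: int) -> SharedMemory:
    """Create the block, unlinking a stale one left over from a previous run."""
    if shared_memory_exists(name):
        print(f"ShareMemory: {name} already exists, delete it")
        SharedMemory(name=name, create=False).unlink()
    return SharedMemory(create=True, size=size, name=name)


block = create_or_replace("demo_signal", size=4)
assert shared_memory_exists("demo_signal")
block.close()
block.unlink()
```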

View File

@@ -71,6 +71,7 @@ class ZmqClient:
self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router.setsockopt(zmq.SNDTIMEO, -1)
self.router.bind(f"ipc://{self.router_path}")
llm_logger.info(f"router path: {self.router_path}")
def send_json(self, data):
"""
@@ -126,7 +127,6 @@ class ZmqClient:
continue
else:
break
if self.req_dict[req_id] == -1:
if data[-1].finished:
with self.mutex:
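For context, the router side shown here is a ROUTER socket bound to an IPC endpoint with mandatory routing and an unlimited send timeout. A minimal pyzmq sketch of that setup; the IPC path is made up:

```python
# Standalone ROUTER setup mirroring the socket options visible in the diff; endpoint is illustrative.
import zmq

router_path = "/tmp/fd_demo_router.ipc"

ctx = zmq.Context()
router = ctx.socket(zmq.ROUTER)
router.setsockopt(zmq.ROUTER_MANDATORY, 1)  # error out instead of silently dropping unroutable messages
router.setsockopt(zmq.SNDTIMEO, -1)         # block indefinitely on send
router.bind(f"ipc://{router_path}")
print(f"router path: {router_path}")

router.close()
ctx.term()
```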

View File

@@ -49,6 +49,7 @@ def get_moe_scores(
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
scores_with_bias = scores + e_score_correction_bias
scores, topk_values, topk_idx = noaux_tc(
scores,
@@ -104,11 +105,12 @@ class DeepEPEngine:
# In mixed EP mode on a single node, we dynamically switch between
# high throughput and low latency modes.
if splitwise_role == "mixed":
self.deepep_engine = deep_ep.Buffer(
self.group,
int(2e9),
int(5e9),
int(6e9),
low_latency_mode=True,
num_qps_per_rank=24,
)
@@ -387,6 +389,7 @@ class EPPrefillRunner(EPRunner):
*args,
**kwargs,
):
(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
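The new assertion makes the gating precondition explicit: sigmoid scores are corrected by `e_score_correction_bias` before top-k expert selection (done by the custom `noaux_tc` kernel in the real code). A small Paddle sketch of just the scoring step, with plain `paddle.topk` standing in for `noaux_tc`:

```python
# Bias-corrected MoE gating sketch; paddle.topk stands in for the custom noaux_tc kernel.
import paddle

num_tokens, num_experts, top_k = 4, 8, 2
gating_output = paddle.randn([num_tokens, num_experts])
e_score_correction_bias = paddle.zeros([num_experts])  # learned bias in the real model

assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
scores = paddle.nn.functional.sigmoid(gating_output)
scores_with_bias = scores + e_score_correction_bias
topk_values, topk_idx = paddle.topk(scores_with_bias, k=top_k, axis=-1)
print(topk_idx.shape)  # [4, 2] -> chosen experts per token
```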

View File

@@ -35,23 +35,22 @@ class SplitwiseConnector:
SplitwiseConnector class for managing and scheduling Splitwise tasks.
"""
def __init__(self, cfg, scheduler, worker_queue, resource_manager):
def __init__(self, cfg, worker_queue, resource_manager):
"""
Initialize the SplitwiseConnector instance.
Parameters:
cfg (dict): Configuration information.
scheduler (object): Scheduler object.
worker_queue (object): Worker queue object.
resource_manager (object): Resource manager object.
"""
self.cfg = cfg
self.scheduler = scheduler
self.engine_worker_queue = worker_queue
self.resource_manager = resource_manager
self.connect_innode_instances = {}
self.temp_cache_info = dict()
self.current_request_ids = dict()
self.idx = self.cfg.parallel_config.local_data_parallel_id
if self.cfg.cache_config.pd_comm_port is not None:
self.zmq_ctx = zmq.Context()
@@ -85,6 +84,7 @@ class SplitwiseConnector:
"""
while True:
try:
if hasattr(self, "poller"):
socks = dict(self.poller.poll(100))
if not socks:
continue
@@ -96,7 +96,8 @@ class SplitwiseConnector:
message = frames[-1]
self.io_executor.submit(self._process_message, message)
time.sleep(0.001)
else:
time.sleep(5)
except Exception as e:
logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}")
time.sleep(1)
@@ -183,7 +184,7 @@ class SplitwiseConnector:
def dispatch_innode_splitwise_tasks(self, tasks, current_id):
"""
Dispatch splitwise tasks to the scheduler.
Dispatch splitwise tasks.
Parameters:
tasks (list): List of tasks.
@@ -203,7 +204,7 @@ class SplitwiseConnector:
"cache_info": {
"ipc": {
"ip": "0.0.0.0",
"port": self.cfg.engine_worker_queue_port,
"port": self.cfg.engine_worker_queue_port[self.idx],
"current_id": current_id,
},
},
@@ -286,7 +287,7 @@ class SplitwiseConnector:
if port not in self.connect_innode_instances:
self.create_connection(port)
for task in tasks:
task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.engine_worker_queue_port
task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.engine_worker_queue_port[self.idx]
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
for task in tasks:
task.disaggregate_info["cache_info"]["ipc"]["port"] = port
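The recurring change in this file is that `engine_worker_queue_port` is now a per-data-parallel-rank list indexed by `local_data_parallel_id`. A tiny sketch of that selection with a made-up config object:

```python
# Illustrative only: picking the worker-queue port for this data-parallel rank.
from dataclasses import dataclass, field
from typing import List


@dataclass
class ParallelConfig:
    local_data_parallel_id: int = 1


@dataclass
class Config:
    engine_worker_queue_port: List[int] = field(default_factory=lambda: [9923, 9924, 9925, 9926])
    parallel_config: ParallelConfig = field(default_factory=ParallelConfig)


cfg = Config()
idx = cfg.parallel_config.local_data_parallel_id
port = cfg.engine_worker_queue_port[idx]
print(port)  # 9924 -> each DP rank gets its own queue port
```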

View File

@@ -38,6 +38,7 @@ import yaml
from aistudio_sdk.snapshot_download import snapshot_download as aistudio_download
from tqdm import tqdm
from typing_extensions import TypeIs, assert_never
from uvicorn.config import LOGGING_CONFIG
from fastdeploy import envs
from fastdeploy.logger.logger import FastDeployLogger
@@ -76,6 +77,35 @@ class ColoredFormatter(logging.Formatter):
return message
def configure_uvicorn_logging():
"""
Configure uvicorn's default and access loggers: timestamped format, stdout handlers.
"""
# add timestamp to log
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
date_format = "%Y-%m-%d %H:%M:%S"
LOGGING_CONFIG["formatters"]["default"]["fmt"] = log_format
LOGGING_CONFIG["formatters"]["default"]["datefmt"] = date_format
LOGGING_CONFIG["formatters"]["access"]["fmt"] = log_format
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = date_format
uvicorn_error_logger = logging.getLogger("")
uvicorn_access_logger = logging.getLogger("uvicorn.access")
for handler in uvicorn_error_logger.handlers[:]:
uvicorn_error_logger.removeHandler(handler)
for handler in uvicorn_access_logger.handlers[:]:
uvicorn_access_logger.removeHandler(handler)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(logging.Formatter(log_format, date_format))
uvicorn_error_logger.addHandler(console_handler)
uvicorn_access_logger.addHandler(console_handler)
uvicorn_error_logger.setLevel(logging.INFO)
uvicorn_access_logger.setLevel(logging.INFO)
uvicorn_error_logger.propagate = False
uvicorn_access_logger.propagate = False
class DailyRotatingFileHandler(BaseRotatingHandler):
"""
like `logging.TimedRotatingFileHandler`, but this class supports multi-process
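A usage note on `configure_uvicorn_logging` added above: it mutates uvicorn's `LOGGING_CONFIG` formatters and re-homes the loggers onto a stdout handler, so it should run before `uvicorn.run(...)` sets up its logging. A hedged sketch of the wiring, assuming the helper is importable from `fastdeploy.utils` (where this diff adds it); the FastAPI app and port are placeholders:

```python
# Hypothetical wiring; the app and port are placeholders, not FastDeploy's API server.
import uvicorn
from fastapi import FastAPI

from fastdeploy.utils import configure_uvicorn_logging

app = FastAPI()


@app.get("/ping")
def ping():
    return {"status": "ok"}


if __name__ == "__main__":
    configure_uvicorn_logging()  # must run before uvicorn builds its loggers
    uvicorn.run(app, host="0.0.0.0", port=8000)
```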

View File

@@ -106,9 +106,7 @@ class GCUModelRunner(ModelRunnerBase):
self.forward_meta: ForwardMeta = None
# Postprocess Env params
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(
self.local_rank + int(self.parallel_config.engine_worker_queue_port)
)
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
def exist_prefill(self):
"""

View File

@@ -153,9 +153,8 @@ class GPUModelRunner(ModelRunnerBase):
self.forward_meta: ForwardMeta = None
# Postprocess Env params
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(
self.local_rank + int(self.parallel_config.engine_worker_queue_port)
)
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")
def exist_prefill(self):
"""

View File

@@ -152,19 +152,7 @@ class PaddleDisWorkerProc:
# TODO(gongshaotian): Use worker factory to get worker
self.worker = get_worker(fd_config=fd_config, local_rank=self.local_rank, rank=self.ranks)
# Initialize task queue
task_address = (
self.parallel_config.pod_ip,
self.parallel_config.engine_worker_queue_port,
)
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
self.task_queue = TaskQueue(
address=task_address,
is_server=False,
num_client=self.parallel_config.tensor_parallel_size,
client_id=self.parallel_config.tensor_parallel_rank,
local_data_parallel_id=self.parallel_config.data_parallel_rank,
)
def init_health_status(self) -> None:
"""
@@ -193,15 +181,16 @@ class PaddleDisWorkerProc:
self.worker_ready_signal.value[self.local_rank % self.max_chips_per_node] = 1
# init worker_healthy_live_signal
workers_alive = np.zeros(shape=[array_size], dtype=np.int32)
workers_alive = np.zeros(shape=[min(array_size, self.parallel_config.tensor_parallel_size)], dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal(
name="worker_healthy_live_signal",
array=workers_alive,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int(time.time())
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())
# init model_weights_status
workers_model_weights = np.zeros(shape=[1], dtype=np.int32)
@@ -209,27 +198,27 @@ class PaddleDisWorkerProc:
name="model_weights_status",
array=workers_model_weights,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
# init exist_task_signal
workers_exist_task = np.zeros([self.parallel_config.data_parallel_size], dtype=np.int32)
workers_exist_task = np.zeros([1], dtype=np.int32)
self.exist_task_signal = IPCSignal(
name="exist_task_signal",
array=workers_exist_task,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
# init exist_swapped_task_signal
workers_swapped_task = np.zeros(shape=[self.parallel_config.data_parallel_size], dtype=np.int32)
workers_swapped_task = np.zeros(shape=[1], dtype=np.int32)
self.exist_swapped_task_signal = IPCSignal(
name="exist_swapped_task_signal",
array=workers_swapped_task,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
@@ -239,9 +228,10 @@ class PaddleDisWorkerProc:
name="exist_prefill_task_signal",
array=exist_prefill_task_signal_data,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
suffix=self.parallel_config.engine_worker_queue_port,
create=False,
)
logger.info("gaoziyuan test init_health_status")
def event_loop_normal(self) -> None:
"""Main event loop for Paddle Distrubuted Workers.
@@ -411,6 +401,21 @@ class PaddleDisWorkerProc:
"""Initialize device and Construct model runner"""
self.worker.init_device()
def start_task_queue_service(self):
# Initialize task queue
task_address = (
self.parallel_config.pod_ip,
self.parallel_config.engine_worker_queue_port,
)
logger.info(f"connect task queue address {task_address}")
self.task_queue = TaskQueue(
address=task_address,
is_server=False,
num_client=self.parallel_config.tensor_parallel_size,
client_id=self.parallel_config.tensor_parallel_rank,
local_data_parallel_id=self.parallel_config.expert_parallel_rank,
)
def load_model(self) -> None:
"""Load weights and create model"""
@@ -444,7 +449,7 @@ def parse_args():
parser.add_argument("--total_block_num", type=int, default=2000)
parser.add_argument("--block_size", type=int, default=64)
parser.add_argument("--pod_ip", type=str, default="127.0.0.1")
parser.add_argument("--engine_worker_queue_port", type=int, default=9923)
parser.add_argument("--engine_worker_queue_port", type=str, default="9923")
parser.add_argument("--max_model_len", type=int, default=3072, help="max model len")
parser.add_argument("--device_ids", type=str, default="0", help="cuda visible devices")
parser.add_argument("--dtype", type=str, default="bfloat16", help="input dtype")
@@ -619,10 +624,16 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
num_experts_per_rank = num_experts // parallel_config.expert_parallel_size
num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
parallel_config.local_data_parallel_id = expert_parallel_rank % max_chips_per_node
parallel_config.expert_parallel_rank = expert_parallel_rank
parallel_config.num_experts_per_rank = num_experts_per_rank
parallel_config.num_experts_start_offset = num_experts_start_offset
parallel_config.engine_worker_queue_port = parallel_config.engine_worker_queue_port[
parallel_config.local_data_parallel_id
]
parallel_config.set_tp_group()
load_config = LoadConfig(vars(args))
@@ -640,6 +651,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
logger.info(f"parallel_config.use_ep {parallel_config.use_ep}")
logger.info(f"parallel_config.tensor_parallel_size {parallel_config.tensor_parallel_size}")
logger.info(f"parallel_config.tensor_parallel_rank {parallel_config.tensor_parallel_rank}")
logger.info(f"parallel_config.engine_worker_queue_port {parallel_config.engine_worker_queue_port}")
if getattr(model_config, "num_hidden_layers", None) is None:
raise ValueError("num_hidden_layers is None")
@@ -705,6 +717,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
graph_opt_config=graph_opt_config,
early_stop_config=early_stop_config,
cache_config=cache_config,
engine_worker_queue_port=args.engine_worker_queue_port,
ips=args.ips,
)
update_fd_config_for_mm(fd_config)
@@ -746,6 +759,8 @@ def run_worker_proc() -> None:
# Initialize health status
worker_proc.init_health_status()
worker_proc.start_task_queue_service()
# Start event loop
worker_proc.event_loop_normal()
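The ordering at the end matters: `init_health_status` attaches to the IPC signals (now suffixed with `engine_worker_queue_port` rather than `engine_pid`), and only afterwards does `start_task_queue_service` connect the task queue client to `(pod_ip, engine_worker_queue_port)`. Since the port argument is now a string while the config indexes it per data-parallel rank, a plausible resolution step looks like the sketch below; the comma-separated format and both helper names are assumptions, not FastDeploy code:

```python
# Hypothetical helpers: turn the --engine_worker_queue_port string into a per-DP-rank port.
from typing import List


def parse_queue_ports(raw: str) -> List[int]:
    """Assumes a comma-separated list, e.g. "9923,9924"; a single value yields a one-element list."""
    return [int(p) for p in raw.split(",")]


def select_queue_port(raw: str, local_data_parallel_id: int) -> int:
    """Each data-parallel rank picks its own entry, mirroring the indexing seen in the diff."""
    return parse_queue_ports(raw)[local_data_parallel_id]


assert select_queue_port("9923", 0) == 9923
assert select_queue_port("9923,9924", 1) == 9924
```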

View File

@@ -92,6 +92,8 @@ if [ ${exit_code} -ne 0 ]; then
exit 1
fi
sleep 5
# 0731: added tests for centralized kv block management; enable the corresponding environment variable when starting the service: export ENABLE_V1_KVCACHE_SCHEDULER=True
# Start the service
rm -rf log/*