Compare commits

...

16 Commits

Author SHA1 Message Date
chenjian
7b09611d6b [Bug fix] Fix batched token condition (#3565) 2025-08-23 11:55:53 +08:00
chenjian
606d9e9c2c [Feature] Support mixed deployment with adapter (#3517) 2025-08-21 18:19:01 +08:00
李泳桦
d18a637a17 [feat] add metrics for yiyan adapter (#3219)
* [feat] add metrics for yiyan adapter

* [fix] fix metrics num_requests_waiting and num_requests_running

* [fix] fix metrics gpu_cache_usage_perc

* [refactor] change where requests_number increases

* [chore] rename xxx_block_num as xxx_gpu_block_num, and update their values accordingly

* [chore] delete useless code
2025-08-21 16:58:10 +08:00
chenjian
6854506533 [Bug fix] Fix bug for d blocks not enough (#3479)
* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Fix bug for memory allocation

* Fix bug for D blocks not enough

* fix bug when d blocks not enough

* fix bug when d blocks not enough

* fix cache message recycle step

* fix cache message recycle step

* Fix step_idx recycle
2025-08-21 11:36:16 +08:00
chenjian
c487b62ee0 [Bug fix] Fix memory allocation (#3475)
* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Fix bug for memory allocation
2025-08-19 19:48:24 +08:00
chenjian
d2f6c3b998 [Bug fix] Fix bug for seq_len_encoder is 1 (#3467) 2025-08-19 15:21:32 +08:00
chenjian
aba94169dc [Feature] Support batched tokens for EP (#3415)
* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug
2025-08-18 11:43:36 +08:00
chenjian
3f86ae0007 fix cache messager bug when d restart (#3386) 2025-08-14 11:43:59 +08:00
chenjian
89177d881c [Bug fix] Fix zmq core bug (#3357)
* [Bug fix] Fix zmq core bug due to concurrently used by threads

* Fix zmq core bug due to concurrently used by threads
2025-08-13 20:24:39 +08:00
chenjian
7573802a88 [Feature] Support mtp ep in fd (#3340)
* [Optimize] Add metrics for analysing perf

* Fix bug in mtp
2025-08-11 21:49:44 +08:00
chenjian
110f33a530 [Bug fix] Test td cache messager (#3242)
* support disable cache task in decode node

* fix busg

* Update engine.py

* Update expert_service.py

* Update splitwise_connector.py

* Optimize log for debug

* Optimize log for debug

* fix bug

---------

Co-authored-by: ltd0924 <ltd0924@sina.com>
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
2025-08-06 15:52:45 +08:00
chenjian
a4572a5e5d fix bug for pd step signal (#3230) 2025-08-06 10:41:52 +08:00
chenjian
a9d231c900 Fix bug for concurrently visit zmq (#3233) 2025-08-06 10:41:10 +08:00
ltd0924
b20ffe3697 [Feature] optimize expert parallel (#3196)
* optimize

* Update expert_service.py

* Update worker_process.py

* optimize
2025-08-05 17:34:24 +08:00
ltd0924
dcf9c2daff [Feature] Optimize prefix cache (#3208)
* [LLM] support ep

* Update worker_process.py

* Update expert_service.py

* Update worker_process.py

* format files

* optimize prefix cache

* optimize prefix cache

* optimize prefix cache

* pre commit format

* pre commit format

* pre commit format

* Update cache_messager.py
2025-08-05 17:13:11 +08:00
chenjian
9f9971844f [Feature] Support ep pd with external module (#3194)
* Support external module

* Support external module

* Support external module

* Support external module

* refactor code to make it more clear

* refactor code to make it more clear

* refactor code to make it more clear

* refactor code to make it more clear

* fix according to review

* fix according to review

* fix according to review

* fix according to review

* fix according to review

* fix according to review

* fix bug

* fix bug

* fix bug

* merge

---------

Co-authored-by: root <root@tjdm-inf-sci-k8s-hzz2-h12ni8-0202.tjdm.baidu.com>
2025-08-04 20:32:41 +08:00
30 changed files with 2052 additions and 700 deletions

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
import argparse
import json
import math
import threading
import time
@@ -22,10 +24,63 @@ import numpy as np
import paddle
from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
from fastdeploy.model_executor.ops.gpu import set_data_ipc
from fastdeploy.utils import get_logger
logger = get_logger("cache_messager", "cache_messager.log")
def parse_args():
"""
Parse arguments from the command line.
"""
parser = argparse.ArgumentParser("Cache Messager")
parser.add_argument(
"--splitwise_role",
type=str,
default="mixed",
help="splitwise role, can be decode, prefill or mixed",
)
parser.add_argument("--rank", type=int, default=0, help="current rank")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers")
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
parser.add_argument(
"--protocol",
type=str,
default="ipc",
help="cache transfer protocol, only surport ipc now",
)
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
parser.add_argument(
"--engine_worker_queue_port",
type=int,
default=9923,
help="engine worker queue port",
)
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
parser.add_argument(
"--cache_dtype",
type=str,
default="bfloat16",
choices=["uint8", "bfloat16"],
help="cache dtype",
)
parser.add_argument(
"--speculative_config",
type=json.loads,
default="{}",
help="speculative config",
)
parser.add_argument("--local_data_parallel_id", type=int, default=0)
args = parser.parse_args()
return args
class CacheMessager:
@@ -43,7 +98,7 @@ class CacheMessager:
gpu_cache_kvs,
rank,
nranks,
num_layers,
num_hidden_layers,
gpu_id=0,
rdma_port=None,
):
@@ -57,7 +112,7 @@ class CacheMessager:
gpu_cache_kvs (dict): GPU kv cache
rank (int): current rank
nranks (int): global rank number
num_layers (int): model layer number
num_hidden_layers (int): model layer number
gpu_id (int, optional): GPU ID
rdma_port (int, optional): RDMA port
@@ -73,7 +128,7 @@ class CacheMessager:
self.gpu_cache_kvs = gpu_cache_kvs
self.rank = rank
self.nranks = nranks
address = (pod_ip, engine_worker_queue_port)
address = (pod_ip, engine_worker_queue_port + local_data_parallel_id)
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,
@@ -86,13 +141,13 @@ class CacheMessager:
logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
# 1. initialize the cache_k_ptr_list and cache_v_ptr_list
self.num_layers = num_layers
self.num_hidden_layers = num_hidden_layers
cache_k_ptr_list = []
cache_v_ptr_list = []
cache_k = []
cache_v = []
self.messager = {}
for layer_idx in range(self.num_layers):
for layer_idx in range(self.num_hidden_layers):
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
cache_k.append(key_cache)
@@ -109,7 +164,7 @@ class CacheMessager:
if key_cache.dtype == paddle.bfloat16:
block_bytes *= 2
logger.info(
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
f"layers {num_hidden_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
)
self.block_bytes = block_bytes
@@ -142,15 +197,17 @@ class CacheMessager:
self.gpu_id = gpu_id
self.cache_info = dict()
self.dp_rank_id = self.rank + local_data_parallel_id * self.nranks
self.rank_id = (
self.rank + local_data_parallel_id * self.nranks
) # align with engine worker rank (paddle.distributed.launch)
layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
layerwise_send_cache_thread.daemon = True
layerwise_send_cache_thread.start()
connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
connect_rdma_thread.daemon = True
connect_rdma_thread.start()
logger.info(f"cache messager init finished, use {transfer_protocol}")
def _prefill_layerwise_send_cache_thread(self):
def prefill_layerwise_send_cache_thread(self):
"""
layerwise_send_cache_thread:
send cache to other instance
@@ -160,14 +217,14 @@ class CacheMessager:
prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
try:
step_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
name=f"splitwise_complete_prefilled_step_{self.rank_id}",
array=prefilled_step_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=True,
)
layer_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
array=prefilled_layer_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
@@ -175,14 +232,14 @@ class CacheMessager:
)
except:
step_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
name=f"splitwise_complete_prefilled_step_{self.rank_id}",
array=prefilled_step_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
create=False,
)
layer_shm_value = IPCSignal(
name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
array=prefilled_layer_idx_data,
dtype=np.int32,
suffix=self.gpu_id,
@@ -195,12 +252,15 @@ class CacheMessager:
self.last_step_idx = -1
self.last_layer_idx = -1 # int32
max_step_idx = 100003
engine_recycled_count = 0
while True:
cache_info = self.engine_worker_queue.get_cache_info()
if cache_info:
logger.debug(f"cache info {cache_info}")
logger.info(f"cache info {cache_info}")
for info in cache_info:
if info["request_id"] in self.cache_info:
self.cache_info[info["request_id"]].update(info)
@@ -214,12 +274,11 @@ class CacheMessager:
current_info["status"] = "init"
logger.info(f"start cache_infos: {current_info}")
self.cache_info[info["request_id"]] = current_info
self.last_step_idx = min(self.last_step_idx, current_info["current_id"])
else:
self.cache_info[info["request_id"]] = info
prefilled_layer_idx = layer_shm_value.value[0]
prefilled_step_idx = step_shm_value.value[0]
if prefilled_layer_idx == self.num_layers - 1:
if prefilled_layer_idx == self.num_hidden_layers - 1:
time.sleep(0.001)
prefilled_layer_idx = layer_shm_value.value[0]
prefilled_step_idx = step_shm_value.value[0]
@@ -230,7 +289,18 @@ class CacheMessager:
if not self.cache_info:
time.sleep(0.001)
continue
logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
if self.last_step_idx > prefilled_step_idx:
engine_recycled_count += 1
self.last_step_idx = prefilled_step_idx # only copy value read from shm memory
prefilled_step_idx = (
prefilled_step_idx + max_step_idx * engine_recycled_count
) # remap prefilled_step_idx for comparison
logger.debug(
f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx in shm: {self.last_step_idx},"
f"prefilled_step_idx: {prefilled_step_idx} engine_recycled_count {engine_recycled_count}"
)
for req_id, item in list(self.cache_info.items()):
if "status" not in item:
continue
@@ -247,7 +317,7 @@ class CacheMessager:
target_id = int(item["rdma_ports"][self.rank])
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
if not status:
logger.error(f"connect to {target_ip}:{target_id} failed")
logger.info(f"connect to {target_ip}:{target_id} failed")
item["status"] = "error"
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
@@ -259,9 +329,10 @@ class CacheMessager:
src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
if item["current_id"] < prefilled_step_idx:
current_layer_idx = self.num_layers
current_layer_idx = self.num_hidden_layers
else:
current_layer_idx = prefilled_layer_idx + 1
if item["current_id"] == prefilled_step_idx:
current_layer_idx = prefilled_layer_idx + 1
for layer_idx in range(item["layer_idx"], current_layer_idx):
tic = time.time()
@@ -277,7 +348,7 @@ class CacheMessager:
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
logger.error(
logger.info(
f"write cache failed, layer_idx: {layer_idx}, "
f"req_id: {item['request_id']}, dest_ip: {target_ip}"
)
@@ -288,14 +359,14 @@ class CacheMessager:
block_num = len(src_block_ids)
avg_time_per_block = cost_time * 1000 / block_num # ms
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
logger.debug(
logger.info(
f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
f" {current_transfer_protocol}"
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
)
item["layer_idx"] = current_layer_idx
if item["layer_idx"] == self.num_layers:
if item["layer_idx"] == self.num_hidden_layers:
if item["transfer_protocol"] == "ipc":
self.messager["ipc"].write_block_by_sync(target_id)
logger.info(f"finish write cache {item['request_id']}")
@@ -304,9 +375,114 @@ class CacheMessager:
self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
logger.info(f"put write cache {item['request_id']}")
del self.cache_info[req_id]
self.last_step_idx = prefilled_step_idx
self.last_layer_idx = prefilled_layer_idx
self.last_layer_idx = prefilled_layer_idx
except Exception as e:
logger.error(f"prefill layerwise send cache thread has exception: {e}")
logger.info(f"prefill layerwise send cache thread has exception: {e}")
def _handle_connect_task(self):
while True:
try:
task = self.engine_worker_queue.get_connect_rdma_task()
if task is None:
time.sleep(0.001)
continue
logger.info(f"_handle_connect_task recv task: {task}")
task_id = task["task_id"]
ip, rdma_port = task["ip"], task["rdma_port"]
status = self.messager["rdma"].connect(ip, rdma_port)
if not status:
response = {"task_id": task_id, "success": False}
else:
response = {"task_id": task_id, "success": True}
self.engine_worker_queue.put_connect_rdma_task_response(response)
except Exception as e:
logger.error(f"handle_connect_task has exception: {e}")
def main():
device = args.device_id
rank = args.rank
paddle.set_device(f"gpu:{device}")
cache_type = args.cache_dtype
speculative_config = SpeculativeConfig(args.speculative_config)
num_extra_layers = speculative_config.num_extra_cache_layer
num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
gpu_cache_kvs = {}
gpu_cache_k_tensors = []
gpu_cache_v_tensors = []
for i in range(args.num_hidden_layers + num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else num_extra_layer_gpu_blocks
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
set_data_ipc(
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
f"key_caches_{i}_rank{rank}.device{device}",
)
set_data_ipc(
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
f"value_caches_{i}_rank{rank}.device{device}",
)
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
logger.info(f"device :{device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
cache_messager = CacheMessager(
splitwise_role=args.splitwise_role,
transfer_protocol=args.protocol,
pod_ip=args.pod_ip,
engine_worker_queue_port=args.engine_worker_queue_port,
local_data_parallel_id=args.local_data_parallel_id,
gpu_cache_kvs=gpu_cache_kvs,
rank=rank,
nranks=args.mp_num,
num_hidden_layers=args.num_hidden_layers + num_extra_layers,
gpu_id=device,
rdma_port=args.rdma_port,
)
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=args.engine_pid,
create=False,
)
cache_ready_signal.value[rank] = 1
cache_messager.prefill_layerwise_send_cache_thread()
if __name__ == "__main__":
args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
logger.info("create cache messager...")
logger.info(f"{args}")
main()
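
The hunks above repair the prefill sender's handling of the engine's step counter: the engine recycles current_id modulo 100003 (max_step_idx), so the messager now counts wraparounds and remaps the raw value read from shared memory onto a monotonic axis before comparing it with per-request ids. A minimal sketch of that remapping, assuming the producer wraps at MAX_STEP_IDX; StepRemapper is an illustrative name, not FastDeploy API:

MAX_STEP_IDX = 100003  # same wrap point as max_step_idx in the diff

class StepRemapper:
    def __init__(self):
        self.last_raw = -1        # only ever holds values read from shm
        self.recycled_count = 0   # number of observed wraparounds

    def remap(self, raw_step: int) -> int:
        """Map the wrapping counter onto a monotonically increasing axis."""
        if self.last_raw > raw_step:  # counter moved backwards: it wrapped
            self.recycled_count += 1
        self.last_raw = raw_step
        return raw_step + MAX_STEP_IDX * self.recycled_count

remapper = StepRemapper()
assert remapper.remap(100001) == 100001
assert remapper.remap(2) == 2 + MAX_STEP_IDX  # first read after a wrap

Keeping only the raw shm reading in last_raw mirrors the "only copy value read from shm memory" comment in the hunk, so wrap detection always compares raw readings with raw readings.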

View File

@@ -28,7 +28,7 @@ from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
from fastdeploy.model_executor.ops.gpu import (
cuda_host_alloc,
set_data_ipc,
share_external_data,
swap_cache_all_layers,
)
from fastdeploy.utils import get_logger
@@ -39,26 +39,12 @@ def parse_args():
Parse arguments from the command line.
"""
parser = argparse.ArgumentParser("Cache transfer manager")
parser.add_argument(
"--splitwise_role",
type=str,
default="mixed",
help="splitwise role, can be decode, prefill or mixed",
)
parser.add_argument("--rank", type=int, default=0, help="current rank")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers")
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
parser.add_argument(
"--protocol",
type=str,
default="ipc",
help="cache transfer protocol, only surport ipc now",
)
parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ")
parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
parser.add_argument(
@@ -68,7 +54,6 @@ def parse_args():
help="engine worker queue port",
)
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
@@ -109,7 +94,6 @@ class CacheTransferManager:
device = args.device_id
rank = args.rank
paddle.set_device(f"gpu:{device}")
self.gpu_cache_kvs = {}
self.cpu_cache_kvs = {}
self.gpu_cache_k_tensors = []
@@ -138,40 +122,27 @@ class CacheTransferManager:
self.num_cpu_blocks = args.num_cpu_blocks
cache_type = args.cache_dtype
for i in range(args.num_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
cache_shape = [
args.num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
]
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[
num_gpu_blocks,
args.kv_num_head,
args.block_size,
args.head_dim,
],
fill_value=0,
dtype=cache_type,
)
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
for i in range(args.num_hidden_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else self.num_extra_layer_gpu_blocks
cache_shape[0] = num_gpu_blocks
key_name = f"key_caches_{i}_rank{rank}.device{device}"
value_name = f"value_caches_{i}_rank{rank}.device{device}"
key_cache = paddle.empty(shape=[], dtype=cache_type)
value_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache = share_external_data(key_cache, key_name, cache_shape)
value_cache = share_external_data(value_cache, value_name, cache_shape)
self.gpu_cache_kvs[key_name] = key_cache
self.gpu_cache_kvs[value_name] = value_cache
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[value_name])
set_data_ipc(
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
f"key_caches_{i}_rank{rank}.device{device}",
)
set_data_ipc(
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
f"value_caches_{i}_rank{rank}.device{device}",
)
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
logger.info(f"device :{self.device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
@@ -180,7 +151,7 @@ class CacheTransferManager:
paddle.set_device("cpu")
self.k_dst_ptrs = []
self.v_dst_ptrs = []
for i in range(args.num_layers + self.num_extra_layers):
for i in range(args.num_hidden_layers + self.num_extra_layers):
self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
args.num_cpu_blocks * args.bytes_per_layer_per_block
)
@@ -190,38 +161,6 @@ class CacheTransferManager:
)
self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=args.engine_pid,
create=False,
)
self.cache_ready_signal.value[self.rank] = 1
paddle.set_device(f"gpu:{device}")
if args.enable_splitwise:
logger.debug("create cache messager...")
logger.info(f"{args}")
from fastdeploy.cache_manager.cache_messager import CacheMessager
self.cache_messager = CacheMessager(
splitwise_role=args.splitwise_role,
transfer_protocol=args.protocol,
pod_ip=args.pod_ip,
engine_worker_queue_port=args.engine_worker_queue_port,
local_data_parallel_id=args.local_data_parallel_id,
gpu_cache_kvs=self.gpu_cache_kvs,
rank=self.rank,
nranks=args.mp_num,
num_layers=args.num_layers + self.num_extra_layers,
gpu_id=self.device,
rdma_port=args.rdma_port,
)
logger.info("successfully create cache messager")
logger.info(f"done init CacheMessager gmem alloc : {paddle.device.cuda.memory_allocated()}")
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
self.cache_task_broadcast_signal = IPCSignal(
name="cache_task_broadcast_signal",
@@ -443,4 +382,5 @@ if __name__ == "__main__":
args = parse_args()
logger = get_logger("cache_transfer_manager", "cache_transfer_manager.log")
paddle.set_device(f"gpu:{args.device_id}")
main()
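
The hunks above change where the KV tensors live: the transfer manager no longer allocates its own blocks and exports them with set_data_ipc; it attaches to the tensors the cache messager already exported, via share_external_data with the same name and shape. A sketch of that handoff using the call patterns visible in the two files (toy sizes; this only runs inside a FastDeploy GPU environment):

import paddle
from fastdeploy.model_executor.ops.gpu import set_data_ipc, share_external_data

# Toy sizes; the real values come from parse_args() above.
num_gpu_blocks, kv_num_head, block_size, head_dim = 4, 8, 64, 128
layer, rank, device = 0, 0, 0
cache_shape = [num_gpu_blocks, kv_num_head, block_size, head_dim]
name = f"key_caches_{layer}_rank{rank}.device{device}"

# Producer side (cache_messager.py): allocate, zero-fill, export over IPC.
key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype="bfloat16")
set_data_ipc(key_cache, name)

# Consumer side (cache_transfer_manager.py): attach by name instead of
# allocating, so both processes address the same GPU memory.
attached = paddle.empty(shape=[], dtype="bfloat16")
attached = share_external_data(attached, name, cache_shape)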

View File

@@ -31,6 +31,7 @@ from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
@@ -106,6 +107,10 @@ class PrefixCacheManager:
+ f"{self.num_cpu_blocks}, bytes_per_layer_per_block {self.cache_config.bytes_per_layer_per_block}"
)
@property
def available_gpu_resource(self):
return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0
def launch_cache_manager(
self,
cache_config,
@@ -141,6 +146,76 @@ class PrefixCacheManager:
filename = "cache_transfer_manager.py"
py_path = os.path.join(current_dir_path, filename)
cache_messager_processes = []
if self.splitwise_role != "mixed":
cache_messager_processes = self.launch_cache_messager(
cache_config,
tensor_parallel_size,
device_ids,
pod_ip,
engine_worker_queue_port,
pid_suffix,
)
if cache_messager_processes is None:
raise RuntimeError("Launch cache messager failed")
return []
if (
hasattr(cache_config.model_cfg, "num_key_value_heads")
and cache_config.model_cfg.num_key_value_heads is not None
and int(cache_config.model_cfg.num_key_value_heads) > 0
):
kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
else:
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
log_dir = envs.FD_LOG_DIR
cache_manager_processes = []
for i in range(tensor_parallel_size):
launch_cmd = (
f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
+ f" --num_hidden_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --head_dim {cache_config.model_cfg.head_dim}"
+ f" --kv_num_head {kv_num_head}"
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --cache_queue_port {cache_config.cache_queue_port}"
+ f" --pod_ip {pod_ip}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_gpu_blocks {cache_config.total_block_num}"
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
+ f" --block_size {cache_config.block_size}"
+ f" --engine_pid {pid_suffix}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
exit_code = cache_manager_processes[-1].poll()
if exit_code is None:
logger.info("Launch cache transfer manager successful")
else:
logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
logger.info("Enable hierarchical cache.")
self._enable_cpu_cache()
cache_manager_processes.extend(cache_messager_processes)
return cache_manager_processes
def launch_cache_messager(
self, cache_config, tensor_parallel_size, device_ids, pod_ip, engine_worker_queue_port, pid_suffix
):
"""
launch_cache_messager function used to initialize the cache messager.
"""
current_dir_path = os.path.split(os.path.abspath(__file__))[0]
filename = "cache_messager.py"
if (
hasattr(cache_config.model_cfg, "num_key_value_heads")
@@ -159,8 +234,10 @@ class PrefixCacheManager:
suffix=pid_suffix,
create=True,
)
py_path = os.path.join(current_dir_path, filename)
log_dir = envs.FD_LOG_DIR
cache_manager_processes = []
cache_messager_processes = []
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
@@ -169,42 +246,34 @@ class PrefixCacheManager:
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
+ f" --splitwise_role {self.splitwise_role}"
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --num_hidden_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --head_dim {cache_config.model_cfg.head_dim}"
+ f" --kv_num_head {kv_num_head}"
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --cache_queue_port {cache_config.cache_queue_port}"
+ f" --enable_splitwise {int(self.enable_splitwise)}"
+ f" --pod_ip {pod_ip}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_gpu_blocks {cache_config.total_block_num}"
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
+ f" --block_size {cache_config.block_size}"
+ f" --engine_pid {pid_suffix}"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --engine_pid {pid_suffix}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" >{log_dir}/launch_cache_manager_{int(device_ids[i])}.log 2>&1"
+ f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
# Wait for the cache to finish initializing
logger.info("Waiting for cache transfer manager ready...")
logger.info(f"Launch cache messager, command:{launch_cmd}")
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
logger.info("Waiting for cache ready...")
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
time.sleep(1)
exit_code = cache_manager_processes[-1].poll()
exit_code = cache_messager_processes[-1].poll()
if exit_code is None:
logger.info("Launch cache transfer manager successful")
logger.info("Launch cache messager successful")
else:
logger.info("Launch cache transfer manager failed, see launch_cache_manager.log for more information")
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
logger.info("Enable hierarchical cache.")
self._enable_cpu_cache()
return cache_manager_processes
logger.info("Launch cache messager failed, see launch_cache_messager.log for more information")
cache_messager_processes = None
return cache_messager_processes
def update_cache_config(self, cache_config):
"""
@@ -225,6 +294,9 @@ class PrefixCacheManager:
heapq.heapify(self.gpu_free_block_list)
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks)
main_process_metrics.available_gpu_resource.set(1.0)
def _enable_cpu_cache(self):
"""
_enable_cpu_cache function used to enable cpu cache.
@@ -260,6 +332,8 @@ class PrefixCacheManager:
logger.info(
f"allocate_gpu_blocks: {allocated_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
)
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
return allocated_block_ids
def recycle_gpu_blocks(self, gpu_block_ids):
@@ -274,6 +348,8 @@ class PrefixCacheManager:
heapq.heappush(self.gpu_free_block_list, gpu_block_id)
else:
heapq.heappush(self.gpu_free_block_list, gpu_block_ids)
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
def allocate_cpu_blocks(self, num_blocks):
"""

View File

@@ -61,18 +61,12 @@ class RDMACommManager:
Connect to remote gpu and write cache.
"""
assert self.splitwise_role == "prefill", "only prefill can call this method"
addr = f"{ip}:{port!s}"
if addr in self.connected_rdma:
return True
ret = self.messager.is_connected(ip, str(port))
if ret:
self.connected_rdma.add(addr)
return True
ret = self.messager.connect(ip, str(port))
logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}")
if ret == 0:
self.connected_rdma.add(addr)
return ret == 0
def write_cache(self, ip, port, local_block_ids, remote_block_ids, layer_idx):
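
Reading the hunk above as dropping the local connected_rdma cache, connect() now always asks the transport whether the link is up and only dials when it is not. A hedged sketch of that flow with a toy backend; FakeMessager and ensure_connected are illustrative, and connect() returning 0 on success follows the code above:

class FakeMessager:
    """Toy stand-in for the RDMA backend behind RDMACommManager."""

    def __init__(self):
        self.links = set()

    def is_connected(self, ip: str, port: str) -> bool:
        return (ip, port) in self.links

    def connect(self, ip: str, port: str) -> int:
        self.links.add((ip, port))
        return 0  # 0 means success, as in the code above

def ensure_connected(messager, ip: str, port: int) -> bool:
    """Query the transport every time; dial only when the link is down."""
    if messager.is_connected(ip, str(port)):
        return True
    return messager.connect(ip, str(port)) == 0

assert ensure_connected(FakeMessager(), "10.0.0.1", 8000)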

View File

@@ -820,6 +820,7 @@ class EngineArgs:
"max_num_partial_prefills",
"max_long_partial_prefills",
"long_prefill_token_threshold",
"splitwise_role"
]
all = asdict(self)

View File

@@ -293,10 +293,11 @@ class Config:
)
if not self.cache_config.enable_chunked_prefill:
assert self.max_num_batched_tokens >= self.max_model_len, (
f"max_num_batched_tokens: {self.max_num_batched_tokens} "
f"should be larger than or equal to max_model_len: {self.max_model_len}"
)
if not int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")):
assert self.max_num_batched_tokens >= self.max_model_len, (
f"max_num_batched_tokens: {self.max_num_batched_tokens} "
f"should be larger than or equal to max_model_len: {self.max_model_len}"
)
else:
assert self.max_num_batched_tokens >= self.cache_config.block_size, (
f"max_num_batched_tokens: {self.max_num_batched_tokens} "

View File

@@ -28,6 +28,7 @@ import time
import traceback
import uuid
import weakref
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple
@@ -47,12 +48,14 @@ from fastdeploy.inter_communicator import (
EngineCacheQueue,
EngineWorkerQueue,
IPCSignal,
ZmqClient,
ZmqIpcServer,
ZmqTcpServer,
)
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.metrics.trace_util import start_span, start_span_request
from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
@@ -110,6 +113,8 @@ class LLMEngine:
self.start_queue_service()
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.resource_manager = ResourceManagerV1(
cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
@@ -123,9 +128,17 @@ class LLMEngine:
cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
)
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port)
self.split_connector = SplitwiseConnector(cfg, self.scheduler, self.engine_worker_queue, self.resource_manager)
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(
self.cfg.engine_worker_queue_port + self.cfg.parallel_config.local_data_parallel_id
)
self.splitwise_queue = deque()
self.split_connector = SplitwiseConnector(
cfg,
self.scheduler,
self.engine_worker_queue,
self.resource_manager,
self.splitwise_queue,
)
self.token_processor = TokenProcessor(
cfg=self.cfg,
@@ -177,13 +190,71 @@ class LLMEngine:
self._init_worker_signals()
self.data_processor = self.input_processor.create_processor()
self.response_lock = threading.Lock() # prevent to call send_multipart in zmq concurrently
if api_server_pid is not None:
self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL)
self.zmq_server.start_server()
self.zmq_server.create_router()
if envs.FD_ENABLE_INTERNAL_ADAPTER:
self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL)
self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER)
self.external_adapter = InternalAdapter(
cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node
)
else:
self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
self.recv_result_handle_thread = threading.Thread(
target=self.send_response_server.recv_result_handle, daemon=True
)
self.recv_result_handle_thread.start()
time.sleep(3)
self.cfg.init_cache_info()
role = self.cfg.splitwise_role
host_ip = self.cfg.host_ip
disaggregate = self.cfg.disaggregate_info
request_queues_for_dp_ipc = (
None  # Each dp rank runs in its own process; use multiprocessing.Queue to deliver requests to each dp
)
result_queue_for_dp_ipc = None
if self.cfg.scheduler_config.name == "splitwise":
self.scheduler.start(role, host_ip, disaggregate)
elif self.cfg.scheduler_config.name == "dp":
request_queues_for_dp_ipc = []
result_queue_for_dp_ipc = multiprocessing.Queue()
for i in range(self.cfg.parallel_config.data_parallel_size):
request_queues_for_dp_ipc.append(multiprocessing.Queue())
self.scheduler.start(
self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc
)
time.sleep(1)
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
self.dp_processed = []
for i in range(
1,
self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
):
time.sleep(1)
self.dp_processed.append(
multiprocessing.Process(
target=start_expert_service,
args=(
self.cfg,
i + self.cfg.node_rank * self.cfg.worker_num_per_node,
self.ipc_signal_suffix,
request_queues_for_dp_ipc,
result_queue_for_dp_ipc,
),
)
)
llm_logger.info(
f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
+ f" data parallel id {i}"
)
self.dp_processed[-1].start()
if self.do_profile == 0 and (
self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed"
):
@@ -238,44 +309,11 @@ class LLMEngine:
# Single-node logic
self.engine_worker_queue.available_prefill_instances.put(1)
self.split_mode_get_tasks()
if self.cfg.scheduler_config.name == "splitwise":
if self.cfg.scheduler_config.name == "splitwise" or self.cfg.scheduler_config.name == "dp":
self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
self.splitwise_receive_thread.daemon = True
self.splitwise_receive_thread.start()
self.cfg.init_cache_info()
role = self.cfg.splitwise_role
host_ip = self.cfg.host_ip
disaggregate = self.cfg.disaggregate_info
if self.cfg.scheduler_config.name == "splitwise":
self.scheduler.start(role, host_ip, disaggregate)
time.sleep(1)
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
self.dp_processed = []
for i in range(
1,
self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
):
time.sleep(1)
self.dp_processed.append(
multiprocessing.Process(
target=start_expert_service,
args=(
self.cfg,
i + self.cfg.node_rank * self.cfg.worker_num_per_node,
self.ipc_signal_suffix,
),
)
)
llm_logger.info(
f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
+ f" data parallel id {i}"
)
self.dp_processed[-1].start()
console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
return True
@@ -290,8 +328,9 @@ class LLMEngine:
if len(results) == 0:
time.sleep(0.005)
continue
for request_id, contents in results.items():
self.zmq_server.send_multipart(request_id, contents)
with self.response_lock:
for request_id, contents in results.items():
self.send_response_server.send_response(request_id, contents)
except Exception as e:
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
@@ -308,7 +347,7 @@ class LLMEngine:
Insert task to engine thread, monitor scheduler request queue.
if the engine has resource, insert task to engine
"""
current_id = -1
current_id = 0
while self.running:
try:
if self.resource_manager.available_batch() == 0:
@@ -321,18 +360,15 @@ class LLMEngine:
if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks():
time.sleep(0.005)
continue
if self.engine_worker_queue.num_cache_infos() > 0:
time.sleep(0.001)
continue
if len(self.split_connector.current_request_ids) > 0:
time.sleep(0.001)
continue
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch,
)
if envs.FD_ENABLE_INTERNAL_ADAPTER:
num_prefill_batch = int(self.resource_manager.available_batch())
self.resource_manager.check_and_free_block_tables()
tasks = self.scheduler.get_requests(
available_blocks=self.resource_manager.available_block_num(),
@@ -346,12 +382,15 @@ class LLMEngine:
time.sleep(0.001)
continue
current_id = (current_id + 1) % 100003
if self.cfg.splitwise_role != "mixed":
llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks(tasks, current_id)
self.insert_tasks(tasks, current_id)
insert_successful = self.insert_tasks(tasks, current_id)
if insert_successful:
current_id = current_id + 1
else:
continue
main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks))
@@ -400,6 +439,8 @@ class LLMEngine:
get_request_pool.submit(_fetch_request)
# 2. Schedule requests
tasks = self.resource_manager.schedule()
main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks))
# 3. Send to engine
if tasks:
self.resource_manager.get_real_bsz()
@@ -415,14 +456,18 @@ class LLMEngine:
if self.api_server_pid is None:
return
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if self.cfg.splitwise_role == "decode":
return
added_requests: Dict[str, int] = dict()
while self.running:
try:
block = True if len(added_requests) == 0 else False
if not self.cfg.enable_mm:
err, data = self.zmq_server.receive_json_once(block)
err, data = self.recv_request_server.receive_json_once(block)
else:
err, data = self.zmq_server.receive_pyobj_once(block)
err, data = self.recv_request_server.receive_pyobj_once(block)
if err is not None:
llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}")
break
@@ -433,6 +478,7 @@ class LLMEngine:
request = Request.from_dict(data)
start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
main_process_metrics.requests_number.inc()
llm_logger.debug(f"Receive request: {request}")
err_msg = None
@@ -461,7 +507,7 @@ class LLMEngine:
if failed is None:
main_process_metrics.num_requests_waiting.inc(1)
continue
llm_logger.error(f"request {request_id} insert to scheduler failed: {failed}")
error_result = RequestOutput(
request_id=request_id,
finished=True,
@@ -470,7 +516,8 @@ class LLMEngine:
)
# Since the request is not in scheduler
# Send result by zmq directly
self.zmq_server.send_multipart(request_id, error_result)
with self.response_lock:
self.send_response_server.send_response(request_id, [error_result])
except Exception as e:
llm_logger.error(
f"Error happend while receving new request from zmq, details={e}, "
@@ -570,41 +617,44 @@ class LLMEngine:
for idx in sorted(processed_indices, reverse=True):
self.waiting_requests.pop(idx)
if not self.engine_worker_queue.disaggregate_queue_empty():
items = self.engine_worker_queue.get_disaggregated_tasks()
for item in items:
role = item[0]
tasks = item[1]
if len(self.splitwise_queue) > 0:
items = self.splitwise_queue.pop()
role = items[0]
tasks = items[1]
if role == "prefill":
if role == "prefill":
for task in tasks:
task.max_tokens = task.min_tokens = 2
self.insert_tasks(tasks)
elif role == "decode":
if hasattr(tasks[0], "finished"):
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
task.max_tokens = task.min_tokens = 2
self.insert_tasks(tasks)
task.finished = False
self.insert_tasks(tasks, allocated=True)
elif role == "decode":
if hasattr(tasks[0], "finished"):
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
task.finished = False
self.insert_tasks(tasks, allocated=True)
if self.cfg.innode_prefill_ports is not None:
self.scheduler.put_results(tasks)
if self.cfg.innode_prefill_ports is not None:
self.scheduler.put_results(tasks)
else:
if len(self.waiting_requests):
llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
self.waiting_requests.extend(tasks)
else:
if len(self.waiting_requests):
llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
self.waiting_requests.extend(tasks)
else:
new_waiting = []
for task in tasks:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
else:
new_waiting.append(task)
if new_waiting:
new_waiting = []
for task in tasks:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
else:
if not self.enable_decode_cache_task:
task.error_msg = "Not enough resources"
new_waiting.append(task)
if new_waiting:
if not self.enable_decode_cache_task:
self.split_connector.send_cache_infos(new_waiting, -1)
else:
self.waiting_requests.extend(new_waiting)
llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
@@ -749,10 +799,6 @@ class LLMEngine:
"""
Insert tasks to engine.
"""
for task in tasks:
start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
if task.sampling_params.bad_words is not None:
task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
# TODO: return this to the scheduler
if allocated:
current_tasks = []
@@ -760,6 +806,15 @@ class LLMEngine:
cur_task_idx = self.resource_manager.req_dict[task.request_id]
del self.resource_manager.req_dict[task.request_id]
cur_task = self.resource_manager.tasks_list[cur_task_idx]
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if not task.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue
self.resource_manager.stop_flags[cur_task_idx] = True
self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
llm_logger.warning(f"{task.request_id} need not decode after first token")
continue
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode":
cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids)
@@ -769,32 +824,58 @@ class LLMEngine:
self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.scheduler.put_results([task])
llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
continue
self.token_processor.tokens_counter[task.request_id] = 1
current_tasks.append(cur_task)
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
if current_tasks:
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
return True
self.resource_manager.check_and_free_block_tables()
if not isinstance(tasks, list):
tasks = [tasks]
need_delete_tasks = []
for task in tasks:
start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
if self.cfg.splitwise_role != "mixed":
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
need_delete_tasks.append(task)
continue
if task.sampling_params.bad_words is not None:
task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
self.resource_manager.check_and_free_block_tables()
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
for item in tasks:
item.schedule_start_time = time.time()
req_ids = [t.request_id for t in tasks]
if len(tasks) == 0:
return False
available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch:
llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
llm_logger.error("The exceeded part will be ignored!")
tasks = tasks[:available_batch]
req_ids = [t.request_id for t in tasks]
tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
if not tasks:
@@ -815,19 +896,19 @@ class LLMEngine:
is_prefill = True
self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len
self.split_connector.send_cache_infos(tasks, current_id)
for task in tasks:
task.inference_start_time = time.time()
if not is_decode:
llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
for task in tasks:
task.inference_start_time = time.time()
if not is_prefill:
if not self.cfg.enable_mm:
self.update_requests_chunk_size(tasks)
else:
self.update_mm_requests_chunk_size(tasks)
if not self.cfg.enable_mm:
self.update_requests_chunk_size(tasks)
else:
self.update_mm_requests_chunk_size(tasks)
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
if is_prefill and self.cfg.scheduler_config.name != "splitwise":
self.engine_worker_queue.available_prefill_instances.put(1)
self.split_connector.send_cache_infos(tasks, current_id)
return True
def task_is_finished(self, index):
@@ -966,14 +1047,17 @@ class LLMEngine:
self.running = False
if hasattr(self, "cache_manager_processes"):
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
self.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}")
try:
os.killpg(p.pid, signal.SIGTERM)
except Exception as e:
print(f"Error extracting file: {e}")
if hasattr(self.resource_manager.cache_manager, "cache_ready_signal"):
self.resource_manager.cache_manager.cache_ready_signal.clear()
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
if hasattr(self, "zmq_server") and self.zmq_server is not None:
self.zmq_server.close()
self.worker_ready_signal.clear()
self.exist_task_signal.clear()
self.exist_swapped_task_signal.clear()
@@ -988,12 +1072,19 @@ class LLMEngine:
except Exception as e:
print(f"Error extracting sub services: {e}")
self.engine_worker_queue.cleanup()
if hasattr(self, "zmq_server") and self.zmq_server is not None:
self.zmq_server.close()
for worker_queue in self.engine_worker_queue_server:
worker_queue.cleanup()
if hasattr(self, "send_response_server") and self.send_response_server is not None:
self.send_response_server.close()
if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
self.recv_request_server.close()
if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
self.recv_control_cmd_server.close()
if hasattr(self, "dp_processed"):
for p in self.dp_processed:
p.join()
self.engine_worker_queue_server.cleanup()
def _setting_environ_variables(self):
"""
@@ -1291,15 +1382,20 @@ class LLMEngine:
"""
start queue service for engine worker communication
"""
address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
self.engine_worker_queue_server = list()
if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
llm_logger.info(f"Starting engine worker queue server service at {address}")
self.engine_worker_queue_server = EngineWorkerQueue(
address=address,
is_server=True,
num_client=self.cfg.tensor_parallel_size,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
for i in range(self.cfg.parallel_config.data_parallel_size // self.cfg.nnode):
address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port + i)
llm_logger.info(f"Starting engine worker queue service at {address}")
self.engine_worker_queue_server.append(
EngineWorkerQueue(
address=address,
is_server=True,
num_client=self.cfg.tensor_parallel_size,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
)
if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
self.cache_task_queue = EngineCacheQueue(
@@ -1314,6 +1410,7 @@ class LLMEngine:
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,

View File

@@ -16,21 +16,25 @@
from __future__ import annotations
import copy
import os
import signal
import threading
import time
import traceback
import weakref
from collections import deque
import numpy as np
from fastdeploy.engine.request import RequestOutput
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.inter_communicator import EngineWorkerQueue
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.output.token_processor import TokenProcessor
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, console_logger, llm_logger
from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger
class ExpertService:
@@ -52,6 +56,10 @@ class ExpertService:
self.cfg = cfg
start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % cfg.worker_num_per_node
end_pos = start_pos + self.cfg.tensor_parallel_size
self.waiting_requests = []
self.disaggregate_queue = deque()
self.llm_logger = get_logger("expert_service", f"expert_service_{local_data_parallel_id}.log")
if cfg.splitwise_role != "mixed":
self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
@@ -60,11 +68,12 @@ class ExpertService:
self.scheduler = cfg.scheduler_config.scheduler()
self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
if self.cfg.scheduler_config.name == "splitwise":
self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
address = (cfg.master_ip, cfg.engine_worker_queue_port)
address = (cfg.master_ip, cfg.engine_worker_queue_port + local_data_parallel_id)
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,
@@ -88,10 +97,7 @@ class ExpertService:
self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]
self.split_connector = SplitwiseConnector(
self.cfg,
self.scheduler,
self.engine_worker_queue,
self.resource_manager,
self.cfg, self.scheduler, self.engine_worker_queue, self.resource_manager, self.disaggregate_queue
)
self.token_processor = TokenProcessor(
@@ -111,8 +117,12 @@ class ExpertService:
)
self._finalizer = weakref.finalize(self, self._exit_sub_services)
if envs.FD_ENABLE_INTERNAL_ADAPTER:
self.external_adapter = InternalAdapter(cfg=self.cfg, engine=self, dp_rank=local_data_parallel_id)
def start(self, ipc_signal_suffix, local_data_parallel_id):
def start(
self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
):
"""
Initializes the engine and starts its sub-services.
If `api_server_pid` is defined, will launch a thread
@@ -121,13 +131,13 @@ class ExpertService:
# assert not self.is_started, "The engine is already started."
start_time = time.time()
llm_logger.info(f"start expert service {local_data_parallel_id}")
self.llm_logger.info(f"start expert service {local_data_parallel_id}")
if self.cfg.splitwise_role != "mixed":
self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
cache_config=self.cfg.cache_config,
tensor_parallel_size=self.cfg.tensor_parallel_size,
device_ids=self.cfg.local_device_ids,
pod_ip=self.cfg.pod_ips[0],
pod_ip=self.cfg.master_ip,
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
)
@@ -139,7 +149,7 @@ class ExpertService:
# Start TokenProcessor thread
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK
self.token_processor.run()
self.cfg.init_cache_info()
@@ -147,7 +157,11 @@ class ExpertService:
role = self.cfg.splitwise_role
host_ip = self.cfg.host_ip
disaggregate = self.cfg.disaggregate_info
self.scheduler.start(role, host_ip, disaggregate)
if self.cfg.scheduler_config.name == "dp":
assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)
self.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc)
elif self.cfg.scheduler_config.name == "splitwise":
self.scheduler.start(role, host_ip, disaggregate)
self.cfg.print()
console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
@@ -158,7 +172,7 @@ class ExpertService:
Insert task to engine thread, monitor scheduler request queue.
if the engine has resource, insert task to engine
"""
current_id = -1
current_id = 0
while True:
try:
if self.resource_manager.available_batch() == 0:
@@ -167,15 +181,13 @@ class ExpertService:
if self.engine_worker_queue.num_tasks() > 0:
time.sleep(0.001)
continue
if len(self.split_connector.current_request_ids) > 0:
time.sleep(0.001)
continue
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch,
)
if envs.FD_ENABLE_INTERNAL_ADAPTER:
num_prefill_batch = int(self.resource_manager.available_batch())
self.resource_manager.check_and_free_block_tables()
tasks = self.scheduler.get_requests(
available_blocks=self.resource_manager.available_block_num(),
@@ -190,73 +202,88 @@ class ExpertService:
continue
if self.cfg.splitwise_role != "mixed":
llm_logger.info("Inserting splitwise tasks")
self.llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks(tasks, current_id)
current_id = (current_id + 1) % 100003
self.insert_tasks(tasks, current_id)
insert_successful = self.insert_tasks(tasks, current_id)
if insert_successful:
current_id = current_id + 1
else:
continue
main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks))
except Exception as e:
err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}."
llm_logger.error(err_msg)
err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
self.llm_logger.error(err_msg)
def split_mode_get_tasks(self):
"""
Get tasks in split (prefill/decode disaggregation) mode.
"""
waiting_requests = []
def receiver_loop():
while True:
try:
if len(waiting_requests) > 0:
for task in waiting_requests:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
waiting_requests.remove(task)
else:
break
if not self.engine_worker_queue.disaggregate_queue_empty():
items = self.engine_worker_queue.get_disaggregated_tasks()
for item in items:
role = item[0]
tasks = item[1]
if role == "prefill":
llm_logger.info("get prefill tasks")
for task in tasks:
task.max_tokens = task.min_tokens = 2
self.insert_tasks(tasks)
elif role == "decode":
llm_logger.info(f"get decode tasks {tasks}")
if hasattr(tasks[0], "finished"):
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
task.finished = False
# self.scheduler.put_results(tasks)
self.insert_tasks(tasks, allocated=True)
processed_indices = []
for idx, task in enumerate(self.waiting_requests):
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
processed_indices.append(idx)
else:
self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
break
for idx in sorted(processed_indices, reverse=True):
self.waiting_requests.pop(idx)
if len(self.disaggregate_queue) > 0:
items = self.disaggregate_queue.pop()
role = items[0]
tasks = items[1]
if role == "prefill":
for task in tasks:
task.max_tokens = task.min_tokens = 2
self.insert_tasks(tasks)
elif role == "decode":
if hasattr(tasks[0], "finished"):
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
task.finished = False
self.insert_tasks(tasks, allocated=True)
if self.cfg.innode_prefill_ports is not None:
self.scheduler.put_results(tasks)
else:
if len(self.waiting_requests):
self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
self.waiting_requests.extend(tasks)
else:
if len(waiting_requests):
for task in tasks:
waiting_requests.append(task)
else:
for task in tasks:
if not self.resource_manager.is_resource_sufficient(
task.prompt_token_ids_len
):
waiting_requests.append(task)
else:
self.insert_tasks([task])
new_waiting = []
for task in tasks:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
else:
if not self.enable_decode_cache_task:
task.error_msg = "Not enough resources"
new_waiting.append(task)
if new_waiting:
if not self.enable_decode_cache_task:
self.split_connector.send_cache_infos(new_waiting, -1)
else:
self.waiting_requests.extend(new_waiting)
self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
else:
time.sleep(0.001)
continue
except Exception as e:
llm_logger.error(f"get decode tasks error: {e}")
self.llm_logger.error(f"Error in main loop: {e} {str(traceback.format_exc())}")
time.sleep(0.1)
threading.Thread(target=receiver_loop, daemon=True).start()
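Note: the admission logic in receiver_loop reduces to a small pattern. The sketch below is illustrative only, not code from this changeset; the attribute names mirror those used above:

    def admit_or_queue(self, tasks):
        """Insert each decode task if blocks are available; otherwise queue it,
        or fail it immediately when decode-side cache tasks are disabled."""
        for task in tasks:
            if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
                self.insert_tasks([task])
            elif self.enable_decode_cache_task:
                self.waiting_requests.append(task)  # retried on the next loop pass
            else:
                task.error_msg = "Not enough resources"
                self.split_connector.send_cache_infos([task], -1)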
@@ -270,22 +297,32 @@ class ExpertService:
cur_task_idx = self.resource_manager.req_dict[task.request_id]
del self.resource_manager.req_dict[task.request_id]
cur_task = self.resource_manager.tasks_list[cur_task_idx]
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if not task.outputs.token_ids: # first token is EOS in prefill, just recycle the resource and continue
self.resource_manager.stop_flags[cur_task_idx] = True
self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.llm_logger.warning(f"{task.request_id} need not decode after first token")
continue
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode":
cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids)
if task.error_code != 200:
self.resource_manager.stop_flags[cur_task_idx] = True
self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.scheduler.put_results([task])
llm_logger.warning(
self.llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
continue
llm_logger.info(f"{cur_task_idx} {task.request_id}")
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
self.token_processor.tokens_counter[task.request_id] = 1
current_tasks.append(cur_task)
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
if current_tasks:
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
return True
self.resource_manager.check_and_free_block_tables()
@@ -293,22 +330,48 @@ class ExpertService:
if not isinstance(tasks, list):
tasks = [tasks]
need_delete_tasks = []
for task in tasks:
if self.cfg.splitwise_role != "mixed":
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
need_delete_tasks.append(task)
continue
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
for item in tasks:
item.schedule_start_time = time.time()
req_ids = [t.request_id for t in tasks]
if len(tasks) == 0:
return False
available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch:
llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
llm_logger.error("The exceeded part will be ignored!")
self.llm_logger.error(
"Inserting batch:{} exceeds the available batch:{}.".format(len(tasks), available_batch)
)
self.llm_logger.error("The exceeded part will be ignored!")
tasks = tasks[:available_batch]
req_ids = [t.request_id for t in tasks]
tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
if not tasks:
error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
llm_logger.error(error_msg)
self.llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=500)
return False
@@ -328,7 +391,7 @@ class ExpertService:
for task in tasks:
task.infer_start_time = time.time()
if not is_decode:
llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
if not is_prefill and self.cfg.cache_config.enable_chunked_prefill:
if not self.cfg.enable_mm:
self.update_requests_chunk_size(tasks)
@@ -346,7 +409,7 @@ class ExpertService:
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
self.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}")
self.llm_logger.info(f"Killing cache manager process {p.pid}")
try:
os.killpg(p.pid, signal.SIGTERM)
except:
@@ -356,13 +419,17 @@ class ExpertService:
self.zmq_server.close()
def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix):
def start_expert_service(
cfg, local_data_parallel_id, ipc_signal_suffix, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
):
"""
Start expert service
"""
expert_service = ExpertService(cfg, local_data_parallel_id)
try:
expert_service.start(ipc_signal_suffix, local_data_parallel_id)
expert_service.start(
ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc
)
expert_service.split_connector.start_receiver()
except Exception as e:
llm_logger.exception(f"Expert service failed to start: {e}")

View File

@@ -71,6 +71,7 @@ class Request:
guided_json_object: Optional[bool] = None,
enable_thinking: Optional[bool] = True,
trace_carrier: dict = dict(),
dp_rank: Optional[int] = None
) -> None:
self.request_id = request_id
self.prompt = prompt
@@ -119,6 +120,7 @@ class Request:
self.task_type = RequestType.PREFILL
self.idx = None
self.need_prefill_tokens = self.prompt_token_ids_len
self.dp_rank = dp_rank
@classmethod
def from_dict(cls, d: dict):
@@ -151,6 +153,7 @@ class Request:
guided_json_object=d.get("guided_json_object", None),
enable_thinking=d.get("enable_thinking", True),
trace_carrier=d.get("trace_carrier", {}),
dp_rank=d.get("dp_rank", None)
)
@property

View File

@@ -22,7 +22,7 @@ import numpy as np
from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import llm_logger
from fastdeploy.utils import get_logger, llm_logger
class ResourceManager:
@@ -49,16 +49,23 @@ class ResourceManager:
Initializes the engine with the given configuration and sets up necessary
data structures to manage tasks and blocks.
"""
if local_data_parallel_id > 0:
self.logger = get_logger(
f"expert_service_{local_data_parallel_id}", f"expert_service_{local_data_parallel_id}.log"
)
else:
self.logger = llm_logger
self.cfg = config.cache_config
self.max_num_seqs = max_num_seqs
self.stop_flags = [True] * max_num_seqs
self.stop_flags = [True] * max_num_seqs # flag set to true if the slot has not been taken
self.enable_prefix_cache = config.cache_config.enable_prefix_caching
self.cache_manager = PrefixCacheManager(config, tensor_parallel_size, splitwise_role, local_data_parallel_id)
self.tasks_list = [None] * max_num_seqs
self.tasks_list = [None] * max_num_seqs # task slots
self.req_dict = dict()
# current batch status of the engine
self.real_bsz = 0
llm_logger.info(f"{self.info()}")
self.logger.info(f"{self.info()}")
main_process_metrics.max_batch_size.set(max_num_seqs)
def reset_cache_config(self, cfg):
"""
@@ -134,10 +141,10 @@ class ResourceManager:
block_list = list()
current_block_num = self.available_block_num()
if block_num > current_block_num:
llm_logger.error(f"block_num:{block_num} > free_list len:{current_block_num}")
self.logger.error("block_num:{0} > free_list len:{1}".format(block_num, current_block_num))
return block_list
block_list = self.cache_manager.allocate_gpu_blocks(block_num)
llm_logger.debug(f"dispatch {len(block_list)} blocks.")
self.logger.debug(f"dispatch {len(block_list)} blocks.")
return block_list
def check_and_free_block_tables(self):
@@ -169,7 +176,7 @@ class ResourceManager:
self.cache_manager.recycle_gpu_blocks(block_tables)
cur_number = self.available_block_num()
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
llm_logger.info(f"recycle {req_id} {cur_number - ori_number} blocks.")
self.logger.info(f"recycle {req_id} {cur_number - ori_number} blocks.")
def available_batch(self):
"""
@@ -222,47 +229,47 @@ class ResourceManager:
Returns:
list: processed task list
"""
allocated_position = 0
processing_task_index = 0
llm_logger.debug(f"Allocating resources for a batch of new tasks: {tasks}")
allocated_position = 0 # number of tasks that have been allocated, also the position in request slots
processing_task_index = 0 # current task
processed_tasks = list()
while allocated_position < self.max_num_seqs:
if processing_task_index >= len(tasks):
while allocated_position < self.max_num_seqs: # loop until all tasks have been allocated resources
if processing_task_index >= len(tasks): # if all tasks have been tried, don't give a second chance
break
can_insert = False
while allocated_position + 1 <= self.max_num_seqs:
if sum(self.stop_flags[allocated_position : allocated_position + 1]) == 1:
can_insert = True
can_insert = True # if there is an empty slot, try to allocate resources for the current task
break
allocated_position += 1
if can_insert:
if self.stop_flags[allocated_position]:
task = tasks[processing_task_index]
task = tasks[processing_task_index] # retrieve current task
if task.get("seed") is None:
task.set("seed", random.randint(0, 9223372036854775807))
task.idx = allocated_position
if self.enable_prefix_cache:
if self.enable_prefix_cache: # if prefix caching is enabled
# 1. request for enough blocks for current task
cache_prepare_time = time.time()
common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids(
task,
self.cfg.block_size,
self.cfg.dec_token_num,
task, self.cfg.block_size, self.cfg.dec_token_num
)
if unique_block_ids is None:
llm_logger.warning("req_id: {0} not enough blocks available".format(task["req_id"]))
self.logger.warning("req_id: {0} not enough blocks available".format(task["req_id"]))
return
# 2. record cache hit information, and return the number of tokens already in cache
cached_len = self._record_request_cache_info(
task, common_block_ids, unique_block_ids, hit_info
)
task.cache_prepare_time = time.time() - cache_prepare_time
# 3. if prefill/decode disaggregation is enabled
if task.disaggregate_info is not None:
if task.disaggregate_info["role"] == "prefill":
# record the slot position for current task, indexed by request id
self.req_dict[task.request_id] = allocated_position
task.disaggregate_info["block_tables"] = task.block_tables
self._delete_cached_data(task, cached_len)
@@ -270,17 +277,19 @@ class ResourceManager:
self.req_dict[task.request_id] = allocated_position
task.disaggregate_info["block_tables"] = task.need_block_tables
else:
# remove cached tokens from prompt token ids to avoid kv recomputation
self._delete_cached_data(task, cached_len)
else:
else: # if prefix caching is disabled
# 1. directly allocate empty block from the cache, if there is any
block_tables = self._get_block_tables(task.prompt_token_ids_len)
if not block_tables:
llm_logger.error(f"req_id: {task.request_id} block_tables is empty")
continue
continue # retry
else:
task.block_tables = block_tables
task.need_block_tables = task.block_tables
# 2. if prefill/decode disaggregation is enabled
if task.disaggregate_info is not None:
task.disaggregate_info["block_tables"] = block_tables
if task.disaggregate_info["role"] == "prefill":
@@ -288,13 +297,13 @@ class ResourceManager:
elif task.disaggregate_info["role"] == "decode":
self.req_dict[task.request_id] = allocated_position
processed_tasks.append(task)
self.stop_flags[allocated_position] = False
processed_tasks.append(task) # add current task
self.stop_flags[allocated_position] = False # mark the slot as occupied
task.inference_start_time = time.time()
task.inference_time_cost = -1.0
task.tokens_all_num = 0
self.tasks_list[allocated_position] = task
llm_logger.info(
self.logger.info(
f"Allocate request: {task.request_id}, "
f"allocated_position:{allocated_position}, "
f"length of prompt token: {task.prompt_token_ids_len}"
@@ -303,15 +312,22 @@ class ResourceManager:
processing_task_index += 1
# batch size statistics while the engine is inferring
# determine batch size by index of the first slot that is not occupied
for i in range(self.max_num_seqs - 1, -1, -1):
if not self.stop_flags[i]:
self.real_bsz = i + 1
break
llm_logger.info(
# record batch size here
task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
self.logger.info(
f"Number of allocated requests: {len(tasks)}, number of " f"running requests in worker: {self.real_bsz}"
)
llm_logger.info(f"{self.info()}")
self.logger.info(f"{self.info()}")
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
return processed_tasks
@@ -321,8 +337,8 @@ class ResourceManager:
Delete cached data from the task's prompt token ids based on the cached length.
"""
if cached_len == len(task.prompt_token_ids):
task.prompt_token_ids = task.prompt_token_ids[cached_len - 1 :]
task.seq_lens_decoder = cached_len - 1
task.prompt_token_ids = task.prompt_token_ids[cached_len - self.cfg.block_size :]
task.seq_lens_decoder = cached_len - self.cfg.block_size
else:
task.prompt_token_ids = task.prompt_token_ids[cached_len:]
task.seq_lens_decoder = cached_len
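Note: the full-hit branch above now replays the last cache block instead of the last token. With block_size = 64 and a 128-token prompt that is entirely cached, the numbers work out as follows (a worked example, not code from this changeset):

    block_size = 64
    cached_len = 128                  # == len(task.prompt_token_ids), i.e. every block hit
    # before: replay 1 token,   seq_lens_decoder = cached_len - 1          -> 127
    # after:  replay one block, seq_lens_decoder = cached_len - block_size -> 64
    replayed_tokens = cached_len - (cached_len - block_size)               # == 64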
@@ -339,11 +355,16 @@ class ResourceManager:
task.cpu_cache_token_num = hit_info["cpu_cache_blocks"] * self.cfg.block_size
task.cache_info = (cache_block_num, no_cache_block_num)
# Report the number of cached tokens to Prometheus metrics
main_process_metrics.prefix_cache_token_num.inc(task.num_cached_tokens)
main_process_metrics.prefix_gpu_cache_token_num.inc(task.gpu_cache_token_num)
main_process_metrics.prefix_cpu_cache_token_num.inc(task.cpu_cache_token_num)
cached_len = len(common_block_ids) * self.cfg.block_size
task.block_tables = common_block_ids + unique_block_ids
task.need_block_tables = unique_block_ids
llm_logger.debug(f"common: {common_block_ids} ")
llm_logger.debug(f"unique: {unique_block_ids} ")
self.logger.debug(f"common: {common_block_ids} ")
self.logger.debug(f"unique: {unique_block_ids} ")
return cached_len
def info(self):

View File

@@ -27,6 +27,7 @@ import paddle
from fastdeploy.engine.request import Request, RequestStatus, RequestType
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import llm_logger
@@ -75,6 +76,7 @@ class ResourceManagerV1(ResourceManager):
self.running: list[Request] = []
self.finish_execution_pool = ThreadPoolExecutor(max_workers=1)
self.lock = threading.Lock()
main_process_metrics.max_batch_size.set(max_num_seqs)
def allocated_slots(self, request: Request):
return len(request.block_tables) * self.config.cache_config.block_size
@@ -98,6 +100,9 @@ class ResourceManagerV1(ResourceManager):
return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)
def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
"""
If the request cannot be scheduled, preempt running requests one by one until it can be scheduled, last in, first out.
"""
can_schedule = True
while True:
if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
@@ -201,6 +206,9 @@ class ResourceManagerV1(ResourceManager):
return False
def schedule(self):
"""
Try to pull a batch of requests from the waiting queue and schedule them.
"""
with self.lock:
scheduled_reqs: list[Request] = []
preempted_reqs: list[Request] = []
@@ -262,7 +270,7 @@ class ResourceManagerV1(ResourceManager):
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
# Prepare prefill task
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
else:
else: # Not enough blocks to allocate, trigger preemption
can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs)
if not can_schedule:
break
@@ -328,6 +336,10 @@ class ResourceManagerV1(ResourceManager):
else:
llm_logger.error("Unknown request status type")
if scheduled_reqs:
task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
llm_logger.debug(f"schedued_reqs: {scheduled_reqs}")
return scheduled_reqs
@@ -369,6 +381,11 @@ class ResourceManagerV1(ResourceManager):
request.block_tables = common_block_ids
request.skip_allocate = False
# Report the number of cached tokens to Prometheus metrics
main_process_metrics.prefix_cache_token_num.inc(matched_token_num)
main_process_metrics.prefix_gpu_cache_token_num.inc(request.gpu_cache_token_num)
main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)
if matched_token_num == request.prompt_token_ids_len:
request.num_computed_tokens = matched_token_num - 1
request.skip_allocate = True

View File

@@ -21,7 +21,7 @@ import numpy as np
from fastdeploy.engine.config import ModelConfig
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.multimodal.registry import MultimodalRegistry
from fastdeploy.platforms import current_platform
@@ -90,7 +90,7 @@ class EngineClient:
"""
Create a ZMQ client.
"""
self.zmq_client = ZmqClient(model, mode)
self.zmq_client = ZmqIpcClient(model, mode)
self.zmq_client.connect()
def format_and_add_data(self, prompts: dict):

View File

@@ -177,6 +177,8 @@ class OpenAIServingChat:
for res in response:
if res.get("error_code", 200) != 200:
raise ValueError("{}".format(res["error_msg"]))
if res["finished"]:
api_server_logger.info(f"chat completion finished: {request_id}")
self.engine_client.data_processor.process_response_dict(
res,

View File

@@ -80,8 +80,20 @@ environment_variables: dict[str, Callable[[], Any]] = {
"EXPORTER_OTLP_HEADERS": lambda: os.getenv("EXPORTER_OTLP_HEADERS"),
# enable kv cache block scheduler v1 (no need for kv_cache_ratio)
"ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")),
# Enable internal modules to access the LLMEngine.
"FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")),
# Port on which LLMEngine receives requests, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"),
# Port on which LLMEngine sends responses, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
# Port on which LLMEngine receives control commands, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
# Batched token timeout in EP
"FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")),
# Whether to use PLUGINS.
"FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","),
# Whether to enable cache task in decode node
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"),
}
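Note: following the lambda-table convention above, each variable is read from the process environment on access. A minimal sketch of a deployment that enables the internal adapter (values are illustrative, and attribute-style access assumes the module's usual __getattr__ indirection over this dict):

    import os

    os.environ["FD_ENABLE_INTERNAL_ADAPTER"] = "1"
    os.environ["FD_ZMQ_RECV_REQUEST_SERVER_PORT"] = "8200"
    os.environ["FD_ZMQ_SEND_RESPONSE_SERVER_PORT"] = "8201"
    os.environ["FD_EP_BATCHED_TOKEN_TIMEOUT"] = "0.05"  # seconds

    from fastdeploy import envs
    if envs.FD_ENABLE_INTERNAL_ADAPTER:
        print("EP batching window:", envs.FD_EP_BATCHED_TOKEN_TIMEOUT, "s")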

View File

@@ -17,6 +17,7 @@
from .engine_cache_queue import EngineCacheQueue
from .engine_worker_queue import EngineWorkerQueue
from .ipc_signal import IPCSignal
from .zmq_client import ZmqClient
from .zmq_client import ZmqIpcClient
from .zmq_server import ZmqIpcServer, ZmqTcpServer
__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue"]
__all__ = ["ZmqIpcClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "ZmqTcpServer", "ZmqIpcServer"]

View File

@@ -85,12 +85,15 @@ class EngineWorkerQueue:
]
self.finished_req_queue = [Queue() for _ in range(self.local_data_parallel_size)]
self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)]
self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)]
self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)]
self.client_read_info_flag_init: List[List[int]] = [
[1] * self.num_client for _ in range(self.local_data_parallel_size)
]
self.lock_info_init: List[threading.Lock] = [
threading.Lock() for _ in range(self.local_data_parallel_size)
]
self.connect_task_lock_init: List[threading.Lock] = [threading.Lock() for _ in range(self.local_data_parallel_size)]
self.finish_request_barrier = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
@@ -112,11 +115,26 @@ class EngineWorkerQueue:
callable=lambda idx: self.lock_init[idx],
proxytype=AcquirerProxy,
)
QueueManager.register(
"get_connect_task_lock",
callable=lambda idx: self.connect_task_lock_init[idx],
proxytype=AcquirerProxy,
)
QueueManager.register(
"get_read_finish_flag",
callable=lambda idx: self.read_finish_flag_init[idx],
proxytype=ValueProxy,
)
QueueManager.register(
"get_connect_rdma_tasks",
callable=lambda idx: self.connect_rdma_tasks_list[idx],
proxytype=ListProxy
)
QueueManager.register(
"get_connect_rdma_tasks_responses",
callable=lambda idx: self.connect_rdma_tasks_response_list[idx],
proxytype=ListProxy
)
QueueManager.register(
"get_connected_client_counter",
callable=lambda idx: self.connected_client_counter_init[idx],
@@ -180,6 +198,9 @@ class EngineWorkerQueue:
QueueManager.register("get_disaggregate_requests")
QueueManager.register("get_available_prefill_instances")
QueueManager.register("get_finish_request_barrier")
QueueManager.register("get_connect_rdma_tasks")
QueueManager.register("get_connect_rdma_tasks_responses")
QueueManager.register("get_connect_task_lock")
self.manager = QueueManager(address=self.address, authkey=self.authkey)
self._connect_with_retry()
@@ -200,6 +221,13 @@ class EngineWorkerQueue:
self.available_prefill_instances = self.manager.get_available_prefill_instances()
self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id)
self.finished_req_queue = self.manager.get_finish_request_queue(self.local_data_parallel_id)
# P/D interconnection
self.connect_rdma_task_queue = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id)
self.connect_rdma_task_response_queue = self.manager.get_connect_rdma_tasks_responses(
self.local_data_parallel_id
)
self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id)
assert self.num_client == len(self.client_read_flag)
if is_server:
@@ -280,6 +308,45 @@ class EngineWorkerQueue:
total_num: int = len(self.tasks)
self.lock.release()
return total_num
def put_connect_rdma_task(self, connect_rdma_task):
self.connect_task_lock.acquire()
self.connect_rdma_task_queue.append(connect_rdma_task)
self.connect_task_lock.release()
def get_connect_rdma_task(self):
result = None
self.connect_task_lock.acquire()
if len(self.connect_rdma_task_queue) == 0:
self.connect_task_lock.release()
return result
try:
result = self.connect_rdma_task_queue.pop(0)
except Exception as e:
llm_logger.info(f"get_connect_rdma_task got exception: {e}")
finally:
self.connect_task_lock.release()
return result
def put_connect_rdma_task_response(self, connect_rdma_task_response):
self.connect_task_lock.acquire()
self.connect_rdma_task_response_queue.append(connect_rdma_task_response)
self.connect_task_lock.release()
def get_connect_rdma_task_response(self):
result = None
self.connect_task_lock.acquire()
if len(self.connect_rdma_task_response_queue) == 0:
self.connect_task_lock.release()
return result
try:
result = self.connect_rdma_task_response_queue.pop(0)
except Exception as e:
llm_logger.info(f"get_connect_rdma_task_response got exception: {e}")
finally:
self.connect_task_lock.release()
return result
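Note: the four helpers above share one shape — acquire the connect-task lock, append to or pop from a manager-backed list proxy, release. Stripped down to plain threading primitives the pattern is as follows (a local sketch only; the real queues are multiprocessing proxies shared between engine and workers):

    import threading

    class LockedTaskList:
        """FIFO over a plain list guarded by a lock, mirroring the
        connect-RDMA task and response queues."""

        def __init__(self):
            self._items = []
            self._lock = threading.Lock()

        def put(self, item):
            with self._lock:
                self._items.append(item)

        def get(self):
            with self._lock:
                return self._items.pop(0) if self._items else None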
def get_prefill_instances(self):
"""

View File

@@ -14,200 +14,78 @@
# limitations under the License.
"""
import os
import threading
import time
from abc import ABC, abstractmethod
import msgpack
import zmq
from fastdeploy import envs
from fastdeploy.utils import llm_logger
class ZmqClient:
class ZmqClientBase(ABC):
"""
ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ.
ZmqClientBase is a base class that provides a client-side interface for sending and receiving messages using ZeroMQ.
"""
def __init__(self, name, mode):
self.context = zmq.Context()
self.socket = self.context.socket(mode)
self.file_name = f"/dev/shm/{name}.socket"
self.router_path = f"/dev/shm/router_{name}.ipc"
def __init__(self):
pass
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
@abstractmethod
def _create_socket(self):
"""Abstract method to create and return a ZeroMQ socket."""
pass
self.mutex = threading.Lock()
self.req_dict = dict()
self.router = None
self.poller = None
self.running = True
def _ensure_socket(self):
"""Ensure the socket is created before use."""
if self.socket is None:
self.socket = self._create_socket()
@abstractmethod
def connect(self):
"""
Connect to the server using the file name specified in the constructor.
"""
self.socket.connect(f"ipc://{self.file_name}")
def start_server(self):
"""
Start the server using the file name specified in the constructor.
"""
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
self.socket.setsockopt(zmq.SNDTIMEO, -1)
self.socket.bind(f"ipc://{self.file_name}")
self.poller = zmq.Poller()
self.poller.register(self.socket, zmq.POLLIN)
def create_router(self):
"""
Create a ROUTER socket and bind it to the specified router path.
"""
self.router = self.context.socket(zmq.ROUTER)
self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
self.router.setsockopt(zmq.SNDTIMEO, -1)
self.router.bind(f"ipc://{self.router_path}")
pass
def send_json(self, data):
"""
Send a JSON-serializable object over the socket.
"""
self._ensure_socket()
self.socket.send_json(data)
def recv_json(self):
"""
Receive a JSON-serializable object from the socket.
"""
self._ensure_socket()
return self.socket.recv_json()
def send_pyobj(self, data):
"""
Send a Pickle-serializable object over the socket.
"""
self._ensure_socket()
self.socket.send_pyobj(data)
def recv_pyobj(self):
"""
Receive a Pickle-serializable object from the socket.
"""
self._ensure_socket()
return self.socket.recv_pyobj()
def pack_aggregated_data(self, data):
"""
Aggregate multiple responses into a single packed message.
"""
result = data[0]
if len(data) > 1:
for response in data[1:]:
result.add(response)
result = msgpack.packb([result.to_dict()])
return result
def send_multipart(self, req_id, data):
"""
Send a multipart message to the router socket.
"""
if self.router is None:
raise RuntimeError("Router socket not created. Call create_router() first.")
class ZmqIpcClient(ZmqClientBase):
def __init__(self, name, mode):
self.name = name
self.mode = mode
self.file_name = f"/dev/shm/{name}.socket"
self.context = zmq.Context()
self.socket = self.context.socket(self.mode)
while self.running:
with self.mutex:
if req_id not in self.req_dict:
try:
client, _, request_id = self.router.recv_multipart(flags=zmq.NOBLOCK)
req_id_str = request_id.decode("utf-8")
self.req_dict[req_id_str] = client
except zmq.Again:
time.sleep(0.001)
continue
else:
break
def _create_socket(self):
"""create and return a ZeroMQ socket."""
self.context = zmq.Context()
return self.context.socket(self.mode)
try:
start_send = time.time()
if self.aggregate_send:
result = self.pack_aggregated_data(data)
else:
result = msgpack.packb([response.to_dict() for response in data])
self.router.send_multipart([self.req_dict[req_id], b"", result])
llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
except Exception as e:
llm_logger.error(f"Send result to zmq client failed: {e}")
if data[-1].finished:
with self.mutex:
self.req_dict.pop(req_id, None)
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
def receive_json_once(self, block=False):
"""
Receive a single message from the socket.
"""
if self.socket is None or self.socket.closed:
return "zmp socket has closed", None
try:
flags = zmq.NOBLOCK if not block else 0
return None, self.socket.recv_json(flags=flags)
except zmq.Again:
return None, None
except Exception as e:
self.close()
llm_logger.warning(f"{e}")
return str(e), None
def receive_pyobj_once(self, block=False):
"""
Receive a single message from the socket.
"""
if self.socket is None or self.socket.closed:
return "zmp socket has closed", None
try:
flags = zmq.NOBLOCK if not block else 0
return None, self.socket.recv_pyobj(flags=flags)
except zmq.Again:
return None, None
except Exception as e:
self.close()
llm_logger.warning(f"{e}")
return str(e), None
def _clear_ipc(self, name):
"""
Remove the IPC file with the given name.
"""
if os.path.exists(name):
try:
os.remove(name)
except OSError as e:
llm_logger.warning(f"Failed to remove IPC file {name} - {e}")
def close(self):
"""
Close the socket and context, and remove the IPC files.
"""
if not self.running:
return
self.running = False
llm_logger.info("Closing ZMQ connection...")
try:
if hasattr(self, "socket") and not self.socket.closed:
self.socket.close()
if self.router is not None and not self.router.closed:
self.router.close()
if not self.context.closed:
self.context.term()
self._clear_ipc(self.file_name)
self._clear_ipc(self.router_path)
except Exception as e:
llm_logger.warning(f"Failed to close ZMQ connection - {e}")
return
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def connect(self):
self._ensure_socket()
self.socket.connect(f"ipc://{self.file_name}")

View File

@@ -0,0 +1,303 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import threading
import time
from abc import ABC, abstractmethod
from collections import defaultdict
import msgpack
import zmq
from fastdeploy import envs
from fastdeploy.utils import llm_logger
class ZmqServerBase(ABC):
"""
ZmqServerBase
"""
def __init__(self):
self.cached_results = defaultdict(list)
self.response_token_lock = threading.Lock()
@abstractmethod
def _create_socket(self):
"""Abstract method to create and return a ZeroMQ socket."""
pass
def _ensure_socket(self):
"""Ensure the socket is created before use."""
if self.socket is None:
self.socket = self._create_socket()
def pack_aggregated_data(self, data):
"""
Aggregate multiple responses into a single packed message.
"""
result = data[0]
if len(data) > 1:
for response in data[1:]:
result.add(response)
result = msgpack.packb([result.to_dict()])
return result
def receive_json_once(self, block=False):
"""
Receive a single message from the socket.
"""
self._ensure_socket()
if self.socket is None or self.socket.closed:
return "zmp socket has closed", None
try:
flags = zmq.NOBLOCK if not block else 0
return None, self.socket.recv_json(flags=flags)
except zmq.Again:
return None, None
except Exception as e:
self.close()
llm_logger.warning(f"{e}")
return str(e), None
def receive_pyobj_once(self, block=False):
"""
Receive a single message from the socket.
"""
self._ensure_socket()
if self.socket is None or self.socket.closed:
return "zmp socket has closed", None
try:
flags = zmq.NOBLOCK if not block else 0
return None, self.socket.recv_pyobj(flags=flags)
except zmq.Again:
return None, None
except Exception as e:
self.close()
llm_logger.warning(f"{e}")
return str(e), None
def recv_result_handle(self):
while True:
try:
with self.response_token_lock:
client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK)
req_id_str = request_id.decode("utf-8")
with self.mutex:
self.req_dict[req_id_str] = client
except zmq.Again:
time.sleep(0.001)
continue
except Exception as e:
llm_logger.error(f"recv_result_handle get unknown exception: {e}")
continue
def send_response(self, req_id, data):
"""
Send generated token result to client.
"""
self._ensure_socket()
if self.socket is None:
raise RuntimeError("Router socket not created. Call create_router() first.")
new_data = []
has_result_handle = False
with self.mutex:
if req_id not in self.req_dict:
self.cached_results[req_id].append(data)
else:
has_result_handle = True
if req_id in self.cached_results:
for history_data in self.cached_results[req_id]:
new_data.extend(history_data)
llm_logger.info(
f"get request {req_id} result handle after cached result, total cached length {len(self.cached_results[req_id])}"
)
del self.cached_results[req_id]
if has_result_handle:
try:
new_data.extend(data)
start_send = time.time()
if self.aggregate_send:
result = self.pack_aggregated_data(new_data)
else:
result = msgpack.packb([response.to_dict() for response in new_data])
with self.response_token_lock:
self.socket.send_multipart([self.req_dict[req_id], b"", result])
llm_logger.debug(
f"send_multipart result: {req_id} len {len(new_data)} elapse: {time.time()-start_send}"
)
except Exception as e:
llm_logger.error(f"Send result to zmq client failed: {e}")
if data[-1].finished:
with self.mutex:
if req_id not in self.req_dict:
llm_logger.warning(f"req_id {req_id} finished but no result handle, drop it")
if req_id in self.cached_results:
del self.cached_results[req_id]
else:
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
self.req_dict.pop(req_id, None)
@abstractmethod
def close(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
class ZmqIpcServer(ZmqServerBase):
"""
ZmqIpcServer, used when FD_ENABLE_INTERNAL_ADAPTER=0
"""
def __init__(self, name, mode):
self.name = name
self.mode = mode
self.cached_results = defaultdict(list)
if mode == zmq.PULL:
self.file_name = f"/dev/shm/{name}.socket"
elif mode == zmq.ROUTER:
self.file_name = f"/dev/shm/router_{name}.ipc"
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
self.mutex = threading.Lock()
self.response_token_lock = threading.Lock()
self.req_dict = dict()
self.running = True
self.context = zmq.Context()
self._create_socket()
def _create_socket(self):
"""create and return a ZeroMQ socket."""
self.socket = self.context.socket(self.mode)
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
self.socket.setsockopt(zmq.SNDTIMEO, -1)
self.socket.bind(f"ipc://{self.file_name}")
return self.socket
def _clear_ipc(self, name):
"""
Remove the IPC file with the given name.
"""
if os.path.exists(name):
try:
os.remove(name)
except OSError as e:
llm_logger.warning(f"Failed to remove IPC file {name} - {e}")
def close(self):
"""
Close the socket and context, and remove the IPC files.
"""
if not self.running:
return
self.running = False
llm_logger.info("Closing ZMQ connection...")
try:
if self.socket is not None and not self.socket.closed:
self.socket.close()
if not self.context.closed:
self.context.term()
self._clear_ipc(self.file_name)
except Exception as e:
llm_logger.warning(f"Failed to close ZMQ connection - {e}")
return
class ZmqTcpServer(ZmqServerBase):
"""
ZmqTcpServer, used when FD_ENABLE_INTERNAL_ADAPTER=1
"""
def __init__(self, port, mode):
self.mode = mode
self.port = port
self.cached_results = defaultdict(list)
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
self.mutex = threading.Lock()
self.req_dict = dict()
self.running = True
self.context = zmq.Context()
self._create_socket()
self.mutex = threading.Lock()
self.response_token_lock = threading.Lock()
def _create_socket(self):
"""create and return a ZeroMQ socket."""
self.socket = self.context.socket(self.mode)
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
self.socket.setsockopt(zmq.SNDTIMEO, -1)
self.socket.bind(f"tcp://*:{self.port}")
return self.socket
def recv_control_cmd(self):
"""
Receive control command from client
"""
self._ensure_socket()
try:
client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK)
task = msgpack.unpackb(task_data)
task_id_str = task["task_id"]
except zmq.Again:
return None
with self.mutex:
self.req_dict[task_id_str] = client
return task
def response_for_control_cmd(self, task_id, result):
"""
Send command result back to client.
"""
self._ensure_socket()
if self.socket is None:
raise RuntimeError("Router socket not created.")
try:
result = msgpack.packb(result)
self.socket.send_multipart([self.req_dict[task_id], b"", result])
except Exception as e:
llm_logger.error(f"Send result to zmq client failed: {e}")
with self.mutex:
self.req_dict.pop(task_id, None)
llm_logger.debug(f"response control cmd finished, task_id: {task_id}")
def close(self):
"""
Close the socket and context.
"""
if not self.running:
return
self.running = False
llm_logger.info("Closing ZMQ connection...")
try:
if self.socket is not None and not self.socket.closed:
self.socket.close()
if not self.context.closed:
self.context.term()
except Exception as e:
llm_logger.warning(f"Failed to close ZMQ connection - {e}")
return

View File

@@ -154,6 +154,22 @@ class MetricsManager:
spec_decode_num_emitted_tokens_total: "Counter"
spec_decode_draft_single_head_acceptance_rate: "list[Gauge]"
# for YIYAN Adapter
prefix_cache_token_num: "Gauge"
prefix_gpu_cache_token_num: "Gauge"
prefix_cpu_cache_token_num: "Gauge"
prefix_ssd_cache_token_num: "Gauge"
batch_size: "Gauge"
max_batch_size: "Gauge"
available_gpu_block_num: "Gauge"
free_gpu_block_num: "Gauge"
max_gpu_block_num: "Gauge"
available_gpu_resource: "Gauge"
requests_number: "Counter"
send_cache_failed_num: "Counter"
first_token_latency: "Gauge"
infer_latency: "Gauge"
# Define all metric configurations
METRICS = {
"num_requests_running": {
@@ -258,6 +274,91 @@ class MetricsManager:
"description": "Total number of successfully processed requests",
"kwargs": {},
},
# for YIYAN Adapter
"prefix_cache_token_num": {
"type": Counter,
"name": "fastdeploy:prefix_cache_token_num",
"description": "Total number of cached tokens",
"kwargs": {},
},
"prefix_gpu_cache_token_num": {
"type": Counter,
"name": "fastdeploy:prefix_gpu_cache_token_num",
"description": "Total number of cached tokens on GPU",
"kwargs": {},
},
"prefix_cpu_cache_token_num": {
"type": Counter,
"name": "fastdeploy:prefix_cpu_cache_token_num",
"description": "Total number of cached tokens on CPU",
"kwargs": {},
},
"prefix_ssd_cache_token_num": {
"type": Counter,
"name": "fastdeploy:prefix_ssd_cache_token_num",
"description": "Total number of cached tokens on SSD",
"kwargs": {},
},
"batch_size": {
"type": Gauge,
"name": "fastdeploy:batch_size",
"description": "Real batch size during inference",
"kwargs": {},
},
"max_batch_size": {
"type": Gauge,
"name": "fastdeploy:max_batch_size",
"description": "Maximum batch size determined when service started",
"kwargs": {},
},
"available_gpu_block_num": {
"type": Gauge,
"name": "fastdeploy:available_gpu_block_num",
"description": "Number of available gpu blocks in cache, including prefix caching blocks that are not officially released",
"kwargs": {},
},
"free_gpu_block_num": {
"type": Gauge,
"name": "fastdeploy:free_gpu_block_num",
"description": "Number of free blocks in cache",
"kwargs": {},
},
"max_gpu_block_num": {
"type": Gauge,
"name": "fastdeploy:max_gpu_block_num",
"description": "Number of total blocks determined when service started",
"kwargs": {},
},
"available_gpu_resource": {
"type": Gauge,
"name": "fastdeploy:available_gpu_resource",
"description": "Available blocks percentage, i.e. available_gpu_block_num / max_gpu_block_num",
"kwargs": {},
},
"requests_number": {
"type": Counter,
"name": "fastdeploy:requests_number",
"description": "Total number of requests received",
"kwargs": {},
},
"send_cache_failed_num": {
"type": Counter,
"name": "fastdeploy:send_cache_failed_num",
"description": "Total number of failures of sending cache",
"kwargs": {},
},
"first_token_latency": {
"type": Gauge,
"name": "fastdeploy:first_token_latency",
"description": "Latest time to first token in seconds",
"kwargs": {},
},
"infer_latency": {
"type": Gauge,
"name": "fastdeploy:infer_latency",
"description": "Latest time to generate one token in seconds",
"kwargs": {},
},
}
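Note: several hunks in this changeset recompute the same block-usage statistics before refreshing these gauges (see ResourceManager.allocate_resources_for_new_tasks, ResourceManagerV1.schedule, and the TokenProcessor recycle path). A helper capturing that repeated pattern might look like this (a sketch, not code from this changeset):

    def refresh_block_gauges(resource_manager, metrics):
        """Recompute block usage from the task slots and update the gauges."""
        used = sum(len(t.block_tables) if t else 0 for t in resource_manager.tasks_list)
        metrics.available_gpu_block_num.set(resource_manager.total_block_number() - used)
        metrics.batch_size.set(resource_manager.max_num_seqs - resource_manager.available_batch())
        metrics.gpu_cache_usage_perc.set(resource_manager.get_gpu_cache_usage_perc())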
SPECULATIVE_METRICS = {}

View File

@@ -445,8 +445,8 @@ class MTPSampler(nn.Layer):
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["output_padding_offset"],
share_inputs["output_cum_offsets"],
max_model_len,
)
probs = F.softmax(logits)

View File

@@ -65,6 +65,7 @@ else:
update_inputs,
step_reschedule,
update_inputs_v1,
speculate_step_reschedule,
)
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
@@ -355,12 +356,11 @@ def step_cuda(
"""
if speculative_config.method is not None:
if enable_prefix_caching:
speculate_step_system_cache(
if DISABLE_RECOVER:
speculate_step_reschedule(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["step_seq_lens_decoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
@@ -386,64 +386,67 @@ def step_cuda(
speculative_config.num_speculative_tokens,
)
else:
speculate_step_paddle(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
share_inputs["accept_num"],
block_size,
enc_dec_block_num,
speculative_config.num_speculative_tokens,
)
if enable_prefix_caching:
speculate_step_system_cache(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["step_seq_lens_decoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
share_inputs["accept_num"],
block_size,
enc_dec_block_num,
speculative_config.num_speculative_tokens,
)
else:
speculate_step_paddle(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
share_inputs["accept_num"],
block_size,
enc_dec_block_num,
speculative_config.num_speculative_tokens,
)
else:
if enable_prefix_caching:
step_system_cache(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["step_seq_lens_decoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
block_size,
enc_dec_block_num,
)
elif DISABLE_RECOVER:
if DISABLE_RECOVER:
step_reschedule(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
@@ -471,32 +474,61 @@ def step_cuda(
enc_dec_block_num,
)
else:
step_paddle(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
block_size,
enc_dec_block_num,
)
if enable_prefix_caching:
step_system_cache(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["step_seq_lens_decoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
block_size,
enc_dec_block_num,
)
else:
step_paddle(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
block_size,
enc_dec_block_num,
)
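Note: after this restructuring, DISABLE_RECOVER is checked before prefix caching in both the speculative and non-speculative paths. The dispatch order condenses to the following summary sketch (not the actual kernel calls):

    def choose_step_op(speculative, disable_recover, prefix_caching):
        if speculative:
            if disable_recover:
                return "speculate_step_reschedule"
            return "speculate_step_system_cache" if prefix_caching else "speculate_step_paddle"
        if disable_recover:
            return "step_reschedule"
        return "step_system_cache" if prefix_caching else "step_paddle"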
def rebuild_padding(

View File

@@ -195,7 +195,14 @@ class TokenProcessor:
try:
is_blocking = True
if self.speculative_decoding:
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
if (
self.cfg.parallel_config.enable_expert_parallel
and self.cfg.parallel_config.data_parallel_size > 1
):
speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
else:
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
if self.output_tokens[0] == -2:
continue
@@ -258,13 +265,13 @@ class TokenProcessor:
llm_logger.info(f"finished_task_id: {finished_task_id}")
self.prefill_result_status[finished_task_id[0]] = finished_task_id[1]
if task_id in self.prefill_result_status:
self.split_connector.send_first_token(task.disaggregate_info, [result])
self.resource_manager.stop_flags[index] = True
self.resource_manager.tasks_list[index] = None
self.resource_manager._recycle_block_tables(task)
if self.prefill_result_status[task_id] != "finished":
result.error_code = 400
result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}"
result.error_msg = f"{task_id} failed to {self.prefill_result_status[task_id]}"
self.split_connector.send_first_token(task.disaggregate_info, [result])
del self.resource_manager.req_dict[task_id]
break
else:
@@ -276,6 +283,15 @@ class TokenProcessor:
self.resource_manager.stop_flags[index] = True
self.resource_manager.tasks_list[index] = None
self.resource_manager._recycle_block_tables(task)
task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.resource_manager.tasks_list])
main_process_metrics.available_gpu_block_num.set(
self.resource_manager.total_block_number() - task_used_block_num
)
main_process_metrics.batch_size.set(
self.resource_manager.max_num_seqs - self.resource_manager.available_batch()
)
if task_id in self.tokens_counter:
del self.tokens_counter[task_id]
@@ -412,7 +428,11 @@ class TokenProcessor:
self._record_completion_metrics(task, current_time)
self._recycle_resources(task_id, i, task, result, is_prefill)
break
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
if (
not is_prefill
or self.cfg.scheduler_config.name == "splitwise"
or self.cfg.scheduler_config.name == "dp"
):
batch_result.append(result)
self.postprocess(batch_result)
@@ -427,6 +447,7 @@ class TokenProcessor:
batch = self.output_tokens[1]
accept_num = tokens[2 : batch + 2]
self._record_speculative_decoding_mertics(accept_num)
else:
batch = self.output_tokens[1, 0]
tokens = tokens[2 : batch + 2]
@@ -441,16 +462,22 @@ class TokenProcessor:
task_id = task.request_id
if self.cfg.speculative_config.method:
token_ids = tokens[
2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS : 2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS
+ accept_num[i]
].tolist()
if len(token_ids) == 0 or token_ids[-1] <= 0:
continue
if accept_num[i] == -3:
recovery_stop = True
if recovery_stop:
llm_logger.info(f"recovery stop signal found at task {task_id}")
token_ids = [RECOVERY_STOP_SIGNAL]
else:
token_ids = tokens[
2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS : 2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS
+ accept_num[i]
].tolist()
if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
continue
else:
token_id = int(tokens[i, 0])
token_ids = [token_id]
@@ -474,6 +501,7 @@ class TokenProcessor:
arrival_time=task.arrival_time,
inference_start_time=task.inference_start_time,
first_token_time=time.time() - task.inference_start_time,
model_execute_time=time.time() - task.inference_start_time,
time_in_queue=task.schedule_start_time - task.preprocess_end_time,
preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
request_start_time=task.arrival_time,
@@ -485,6 +513,7 @@ class TokenProcessor:
metrics = RequestMetrics(
arrival_time=time.time(),
request_start_time=task.arrival_time,
model_execute_time=time.time() - task.inference_start_time,
)
self.number_of_output_tokens += len(token_ids)
self._record_metrics(task, current_time, token_ids)
@@ -502,7 +531,7 @@ class TokenProcessor:
if self.tokens_counter[task_id] == 0:
if task.messages is not None:
result.prompt = task.messages
result.num_cached_tokens = task.num_cached_tokens
result.num_cached_tokens = task.num_cached_tokens
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
@@ -512,7 +541,8 @@ class TokenProcessor:
for token_id in token_ids:
self.tokens_counter[task_id] += 1
if token_id != RECOVERY_STOP_SIGNAL:
result.outputs.token_ids.append(token_id)
if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids):
result.outputs.token_ids.append(token_id)
task.output_token_ids.append(token_id)
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
result.finished = True
@@ -531,7 +561,11 @@ class TokenProcessor:
self._record_completion_metrics(task, current_time)
self._recycle_resources(task_id, i, task, result, is_prefill)
break
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
if (
not is_prefill
or self.cfg.scheduler_config.name == "splitwise"
or self.cfg.scheduler_config.name == "dp"
):
batch_result.append(result)
self.postprocess(batch_result)
@@ -549,6 +583,7 @@ class TokenProcessor:
def _record_first_token_metrics(self, task, current_time):
"""Record metrics for first token"""
task.first_token_time = current_time
main_process_metrics.first_token_latency.set(current_time - task.inference_start_time)
main_process_metrics.time_to_first_token.observe(current_time - task.inference_start_time)
main_process_metrics.request_queue_time.observe(task.schedule_start_time - task.preprocess_end_time)
@@ -560,6 +595,7 @@ class TokenProcessor:
main_process_metrics.num_requests_running.dec(1)
main_process_metrics.request_success_total.inc()
main_process_metrics.infer_latency.set(current_time - task.inference_start_time)
main_process_metrics.request_inference_time.observe(current_time - task.inference_start_time)
main_process_metrics.request_generation_tokens.observe(self.tokens_counter[task.request_id])
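The new metrics pair a gauge (`.set`, last value wins, cheap for an external poller to read) with the existing histogram (`.observe`, which accumulates a distribution). A minimal sketch of that distinction using prometheus_client; the metric names here are illustrative, not FastDeploy's registry:

from prometheus_client import Gauge, Histogram

# Gauge: exposes only the most recent value (what an adapter's poller reads).
first_token_latency = Gauge("first_token_latency_seconds", "Latency of the most recent first token")
# Histogram: accumulates bucket counts plus a sum, so quantiles and rates can be derived.
time_to_first_token = Histogram("time_to_first_token_seconds", "Time-to-first-token distribution")

def record_first_token(elapsed: float) -> None:
    first_token_latency.set(elapsed)      # overwrite with the current value
    time_to_first_token.observe(elapsed)  # add a sample to the distribution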
@@ -571,7 +607,7 @@ class TokenProcessor:
self.cfg.speculative_config.num_speculative_tokens,
)
- real_accept_num = [x for x in accept_num if x != 0]
+ real_accept_num = [x for x in accept_num if x > 0]
num_accepted_tokens = sum([x - 1 for x in real_accept_num])
self.num_accepted_tokens += num_accepted_tokens
num_emitted_tokens = sum(real_accept_num)
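A worked example of this bookkeeping, under the common convention that a positive accept_num[x] counts one guaranteed target-model token plus x - 1 accepted draft tokens:

accept_num = [3, 1, 0, -3, 2]  # per-slot accept counts; 0 = idle slot, -3 = recovery stop
real_accept_num = [x for x in accept_num if x > 0]         # [3, 1, 2]; the > 0 filter now also drops -3
num_accepted_tokens = sum(x - 1 for x in real_accept_num)  # 2 + 0 + 1 = 3 accepted draft tokens
num_emitted_tokens = sum(real_accept_num)                  # 3 + 1 + 2 = 6 tokens emitted in total
accept_ratio = num_accepted_tokens / num_emitted_tokens    # 0.5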

View File

@@ -18,6 +18,7 @@ import redis
from fastdeploy.utils import llm_logger
from .dp_scheduler import DPScheduler
from .global_scheduler import GlobalScheduler
from .local_scheduler import LocalScheduler
from .splitwise_scheduler import SplitWiseScheduler, SplitWiseSchedulerConfig
@@ -89,6 +90,57 @@ class LocalSchedulerConfig:
llm_logger.info("=============================================================")
class DPLocalSchedulerConfig(LocalSchedulerConfig):
"""
Configuration class for DPLocalScheduler.
Attributes:
max_size: Maximum number of concurrent requests (-1 for unlimited)
ttl: Time-to-live in seconds for request expiration
"""
def __init__(
self,
max_size: int = -1,
ttl: int = 900,
max_model_len: int = 8192,
enable_chunked_prefill: bool = False,
max_num_partial_prefills: int = 1,
max_long_partial_prefills: int = 1,
long_prefill_token_threshold: int = 0,
splitwise_role: str = "prefill",
**kwargs,
):
"""
Initialize DPLocalScheduler configuration.
Args:
max_size: Maximum concurrent requests (-1 for unlimited, 0 for disabled)
ttl: Time-to-live in seconds for request expiration (default 900s)
max_model_len: Maximum model context length in tokens
enable_chunked_prefill: Whether to enable chunked prefill processing
max_num_partial_prefills: Max partial prefill operations allowed
max_long_partial_prefills: Max long-running partial prefill ops
long_prefill_token_threshold: Token count threshold for long prefill
**kwargs: Additional unused arguments (for forward compatibility)
Note:
- If long_prefill_token_threshold is 0, it's auto-calculated as 4% of max_model_len
- See LocalScheduler class for implementation details
"""
self.max_size = max_size
self.ttl = ttl
self.max_model_len = max_model_len
self.enable_chunked_prefill = enable_chunked_prefill
self.max_num_partial_prefills = max_num_partial_prefills
self.max_long_partial_prefills = max_long_partial_prefills
self.long_prefill_token_threshold = long_prefill_token_threshold
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
self.splitwise_role = splitwise_role
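A quick sanity check of the auto-calculated threshold with the default context length (arithmetic only, no new behavior):

max_model_len = 8192
long_prefill_token_threshold = int(max_model_len * 0.04)  # 327: prompts above this count as long prefills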
class GlobalSchedulerConfig:
"""
Configuration class for GlobalScheduler (Redis-based).
@@ -229,6 +281,9 @@ class SchedulerConfig:
if name == "splitwise":
self.config = SplitWiseSchedulerConfig(**kwargs)
if name == "dp":
self.config = DPLocalSchedulerConfig(**kwargs)
def check(self):
"""
Validate the configuration.
@@ -236,7 +291,7 @@ class SchedulerConfig:
Raises:
Exception: If invalid scheduler type is specified
"""
- if self.name not in ["local", "global", "splitwise"]:
+ if self.name not in ["local", "global", "splitwise", "dp"]:
raise Exception(f"Unknown scheduler type {self.name}")
self.config.check()
@@ -274,6 +329,17 @@ class SchedulerConfig:
if self.name == "splitwise":
return SplitWiseScheduler(self.config)
if self.name == "dp":
return DPScheduler(
max_size=self.config.max_size,
ttl=self.config.ttl,
enable_chunked_prefill=self.config.enable_chunked_prefill,
max_num_partial_prefills=self.config.max_num_partial_prefills,
max_long_partial_prefills=self.config.max_long_partial_prefills,
long_prefill_token_threshold=self.config.long_prefill_token_threshold,
splitwise_role=self.config.splitwise_role,
)
return LocalScheduler(
max_size=self.config.max_size,
ttl=self.config.ttl,
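A minimal sketch of how the new branch is exercised, assuming SchedulerConfig forwards its keyword arguments to the selected config class as shown above; the parameter values are illustrative:

# Hypothetical wiring: name="dp" selects DPLocalSchedulerConfig, and
# scheduler() then builds a DPScheduler around a DPLocalScheduler.
cfg = SchedulerConfig(
    name="dp",            # now accepted alongside "local", "global", "splitwise"
    max_size=-1,          # unlimited queued requests
    ttl=900,              # expire queued requests after 15 minutes
    splitwise_role="prefill",
)
cfg.check()               # raises for unknown scheduler names
scheduler = cfg.scheduler()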

View File

@@ -0,0 +1,258 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import logging
import threading
import time
from multiprocessing import Queue
from typing import Dict, List, Optional
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.scheduler.data import ScheduledResponse
from fastdeploy.scheduler.local_scheduler import LocalScheduler
from fastdeploy.utils import envs, get_logger
class DPLocalScheduler(LocalScheduler):
def __init__(
self,
max_size: int,
ttl: int,
enable_chunked_prefill: bool,
max_num_partial_prefills: int,
max_long_partial_prefills: int,
long_prefill_token_threshold: int,
splitwise_role: str = "prefill",
):
super().__init__(
max_size,
ttl,
enable_chunked_prefill,
max_num_partial_prefills,
max_long_partial_prefills,
long_prefill_token_threshold,
)
self.splitwise_role = splitwise_role
self.scheduler_logger = logging
def put_results(self, results: List[RequestOutput]):
"""
Add processing results back to the scheduler.
Args:
results: List of RequestOutput objects containing results
"""
responses: List[ScheduledResponse] = [ScheduledResponse(result) for result in results]
finished_responses = [response.request_id for response in responses if response.finished]
if len(finished_responses) > 0:
self.scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")
with self.mutex:
for response in responses:
if response.request_id not in self.responses:
self.responses[response.request_id] = [response]
continue
self.responses[response.request_id].append(response)
self.responses_not_empty.notify_all()
def _recycle(self, request_id: Optional[str] = None):
"""
Clean up expired or completed requests to free memory.
Args:
request_id: Optional specific request ID to remove.
If None, removes all expired requests.
"""
if request_id is not None:
self.requests.pop(request_id, None)
self.responses.pop(request_id, None)
if self.splitwise_role == "decode":
return
self.ids.pop(self.ids.index(request_id))
self.ids_read_cursor -= 1
return
if self.max_size <= 0:
return
if len(self.requests) <= self.max_size:
return
now = time.time()
expired_ids = []
for request_id in self.ids:
request = self.requests[request_id]
if now - request.schedule_time < self.ttl:
break
expired_ids.append(request.request_id)
for expired_id in expired_ids:
self.requests.pop(expired_id, None)
self.responses.pop(expired_id, None)
self.ids.pop(0)  # expired ids are the oldest entries, so always pop the head
if len(expired_ids) > 0:
if len(expired_ids) - 1 >= self.ids_read_cursor:
self.ids_read_cursor = 0
else:
self.ids_read_cursor -= len(expired_ids)
def get_requests(
self,
available_blocks,
block_size,
reserved_output_blocks,
max_num_batched_tokens,
batch=1,
) -> List[Request]:
"""
Retrieve requests from the scheduler based on available resources.
Args:
available_blocks: Number of available processing blocks
block_size: Size of each processing block
reserved_output_blocks: Blocks reserved for output
max_num_batched_tokens: Maximum tokens that can be batched
batch: Preferred batch size
Returns:
List of Request objects ready for processing
"""
if available_blocks <= reserved_output_blocks or batch < 1:
self.scheduler_logger.debug(
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
f"max_num_batched_tokens={max_num_batched_tokens}"
)
return []
required_total_blocks = 0
current_prefill_tokens = 0
start_batch_time = time.time()
requests: List[Request] = []
with self.requests_not_empty:
while True:
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
0.005,
)
if batch_ids:
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
current_prefill_tokens += request.prompt_tokens_ids_len
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break
requests.append(request.raw)
self.ids_read_cursor += 1
start_batch_time = time.time()
if current_prefill_tokens > max_num_batched_tokens:
break
if len(requests) >= batch:
break
if (
(current_prefill_tokens > max_num_batched_tokens)
or (len(requests) >= batch)
or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT)
):
break
if batch_ids:
if len(batch_ids) > 0 and len(requests) == 0:
self.scheduler_logger.debug(
f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}"
)
if len(requests) > 0:
self.scheduler_logger.info(
f"Scheduler has pulled some request: {[request.request_id for request in requests]}"
)
return requests
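The admission test above reduces to ceiling-division block accounting: each request needs ceil(prompt_tokens / block_size) input blocks plus the reserved output blocks, and the loop stops once the running totals exceed available_blocks or max_num_batched_tokens (or the FD_EP_BATCHED_TOKEN_TIMEOUT batching window closes). The same arithmetic standalone, with illustrative numbers:

def calc_required_blocks(token_num: int, block_size: int) -> int:
    # Ceiling division: a partially filled block still occupies a whole block.
    return (token_num + block_size - 1) // block_size

available_blocks, block_size = 100, 64
reserved_output_blocks = 10   # held back for tokens generated later
prompt_tokens = 1000

required = calc_required_blocks(prompt_tokens, block_size) + reserved_output_blocks
assert required == 26                 # ceil(1000 / 64) = 16 input blocks + 10 reserved
assert required <= available_blocks   # so this request would be admitted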
class DPScheduler:
def __init__(
self,
max_size: int,
ttl: int,
enable_chunked_prefill: bool,
max_num_partial_prefills: int,
max_long_partial_prefills: int,
long_prefill_token_threshold: int,
splitwise_role: str = "prefill",
):
self._scheduler = DPLocalScheduler(
max_size,
ttl,
enable_chunked_prefill,
max_num_partial_prefills,
max_long_partial_prefills,
long_prefill_token_threshold,
splitwise_role,
)
def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue):
self.dp_rank = dp_rank
self.request_queues = request_queues
self.result_queue = result_queue
self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log")
self._scheduler.scheduler_logger = self.scheduler_logger
threading.Thread(target=self._put_requests_to_local).start()
threading.Thread(target=self._get_response_from_local).start()
def put_requests(self, requests: List[Dict]):
results = []
for request in requests:
if not hasattr(request, "dp_rank"):
raise ValueError(f"Request object is missing the 'dp_rank' attribute: {request}")
self.request_queues[request.dp_rank].put(request)
results.append((request.request_id, None))
return results
def _put_requests_to_local(self):
while True:
request = self.request_queues[self.dp_rank].get()
self.scheduler_logger.info(f"Recieve request from puller, request_id: {request.request_id}")
self._scheduler.put_requests([request])
def _get_response_from_local(self):
while True:
results = self._scheduler.get_results()
if len(results) == 0:
continue
self.result_queue.put(results)
def get_requests(
self,
available_blocks,
block_size,
reserved_output_blocks,
max_num_batched_tokens,
batch=1,
) -> List[Request]:
return self._scheduler.get_requests(
available_blocks, block_size, reserved_output_blocks, max_num_batched_tokens, batch
)
def get_unhandled_request_num(self):
return len(self._scheduler.requests)
def put_results(self, results: List[RequestOutput]):
self._scheduler.put_results(results)
def get_results(self) -> Dict[str, List[RequestOutput]]:
return self.result_queue.get()
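DPScheduler's threading model is implicit in start(): every DP rank owns one inbox Queue, all ranks share one result queue, and put_requests routes by each request's dp_rank attribute. A minimal sketch of that wiring under those assumptions (the rank count is illustrative):

from multiprocessing import Queue

DP_SIZE = 4
request_queues = [Queue() for _ in range(DP_SIZE)]  # one inbox per DP rank
result_queue = Queue()                              # shared outbox for finished outputs

# Inside each engine process, rank r would run something like:
#   scheduler = DPScheduler(max_size=-1, ttl=900, enable_chunked_prefill=False,
#                           max_num_partial_prefills=1, max_long_partial_prefills=1,
#                           long_prefill_token_threshold=0)
#   scheduler.start(r, request_queues, result_queue)
# after which put_requests() on any rank routes each request to its owner:
#   request_queues[request.dp_rank].put(request)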

View File

@@ -208,6 +208,9 @@ class LocalScheduler:
"""
return (token_num + block_size - 1) // block_size
def get_unhandled_request_num(self):
return len(self.requests)
def get_requests(
self,
available_blocks,

View File

@@ -37,6 +37,7 @@ from fastdeploy.model_executor.ops.gpu import (
eagle_get_self_hidden_states,
mtp_save_first_token,
mtp_step_paddle,
set_data_ipc,
share_external_data,
)
from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding
@@ -75,6 +76,7 @@ class MTPProposer(Proposer):
self.model_config.num_hidden_layers = 1
self.model_config.model = self.speculative_config.model
self.model_config.pretrained_config.prefix_name = "ernie.mtp_block"
self.model_config.is_quantized = False
if self.speculative_config.quantization != "":
self.model_config.quantization = self.speculative_config.quantization
self.model_config.start_layer_index = self.num_main_model_layers
@@ -141,17 +143,16 @@ class MTPProposer(Proposer):
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
)
- if not self.parallel_config.do_profile and (
- self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
- ):
+ local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
+ if not self.parallel_config.do_profile and self.parallel_config.splitwise_role != "mixed":
cache_kvs_list = []
for i in range(
self.num_main_model_layers,
self.num_main_model_layers + self.model_config.num_hidden_layers,
):
key_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
cache_kvs_list.append(key_cache)
value_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -160,7 +161,10 @@ class MTPProposer(Proposer):
self.model_inputs["caches"] = cache_kvs_list
else:
- for i in range(self.model_config.num_hidden_layers):
+ for i in range(
+ self.num_main_model_layers,
+ self.num_main_model_layers + self.model_config.num_hidden_layers,
+ ):
self.cache_kvs[f"key_caches_{i}"] = paddle.full(
shape=kv_cache_shape,
fill_value=0,
@@ -171,6 +175,15 @@ class MTPProposer(Proposer):
fill_value=0,
dtype=cache_type,
)
if self.cache_config.enable_prefix_caching:
set_data_ipc(
self.cache_kvs[f"key_caches_{i}"],
f"key_caches_{i}_rank{local_rank}.device{self.device_id}",
)
set_data_ipc(
self.cache_kvs[f"value_caches_{i}"],
f"value_caches_{i}_rank{local_rank}.device{self.device_id}",
)
self.model_inputs["caches"] = list(self.cache_kvs.values())
for value in self.cache_kvs.values():
del value
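Both cache paths hinge on one IPC naming convention: the producer publishes each tensor with set_data_ipc under key_caches_{layer}_rank{rank}.device{dev} (likewise value_caches_...), and a consumer in another process maps the same name with share_external_data. A sketch of just the naming helper; set_data_ipc/share_external_data are the custom ops imported above, and the layer counts are illustrative:

def kv_cache_ipc_name(kind: str, layer_idx: int, local_rank: int, device_id: int) -> str:
    """Shared-memory name that producer and consumer must agree on."""
    assert kind in ("key", "value")
    return f"{kind}_caches_{layer_idx}_rank{local_rank}.device{device_id}"

# MTP layers sit after the main model's layers, hence the offset range above:
num_main_model_layers, num_mtp_layers = 48, 1
names = [kv_cache_ipc_name("key", i, local_rank=0, device_id=0)
         for i in range(num_main_model_layers, num_main_model_layers + num_mtp_layers)]
# -> ['key_caches_48_rank0.device0']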
@@ -235,7 +248,7 @@ class MTPProposer(Proposer):
self.main_model_num_gpu_blocks = num_gpu_blocks
self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
- if not (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+ if self.parallel_config.splitwise_role == "mixed":
self.initialize_kv_cache()
# Reset free list

View File

@@ -0,0 +1,117 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import threading
import time
import traceback
# **Note**: Just for internal use
import zmq
from fastdeploy.inter_communicator import ZmqTcpServer
from fastdeploy.metrics.metrics import get_filtered_metrics, main_process_metrics
from fastdeploy.utils import envs, get_logger
logger = get_logger("internal_adapter_utils", "internal_adapter_utils.log")
class InternalAdapter:
def __init__(self, cfg, engine, dp_rank):
self.cfg = cfg
self.engine = engine
self.dp_rank = dp_rank
recv_control_cmd_ports = envs.FD_ZMQ_CONTROL_CMD_SERVER_PORTS.split(",")
self.response_lock = threading.Lock()  # prevent concurrent send_multipart calls on the zmq socket
self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[dp_rank], mode=zmq.ROUTER)
self.recv_external_instruct_thread = threading.Thread(
target=self._recv_external_module_control_instruct, daemon=True
)
self.recv_external_instruct_thread.start()
self.response_external_instruct_thread = threading.Thread(
target=self._response_external_module_control_instruct, daemon=True
)
self.response_external_instruct_thread.start()
def _get_current_server_info(self):
"""
Get resource information
"""
available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch())
available_block_num = self.engine.resource_manager.available_block_num()
server_info = {
"splitwise_role": self.cfg.splitwise_role,
"block_size": int(self.cfg.cache_config.block_size),
"block_num": int(available_block_num),
"max_block_num": int(self.cfg.cache_config.total_block_num),
"dec_token_num": int(self.cfg.cache_config.dec_token_num),
"available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num),
"max_batch_size": int(available_batch_size),
"max_input_token_num": self.cfg.max_num_batched_tokens,
"unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(),
"available_batch": int(self.engine.resource_manager.available_batch()),
}
return server_info
def _recv_external_module_control_instruct(self):
"""
Receive a multipart message from the control cmd socket.
"""
while True:
try:
with self.response_lock:
task = self.recv_control_cmd_server.recv_control_cmd()
if task is None:
time.sleep(0.001)
continue
logger.info(f"Recieve control task: {task}")
task_id_str = task["task_id"]
if task["cmd"] == "get_payload":
payload_info = self._get_current_server_info()
result = {"task_id": task_id_str, "result": payload_info}
logger.debug(f"Response for task: {task_id_str}")
with self.response_lock:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
elif task["cmd"] == "get_metrics":
metrics_text = get_filtered_metrics(
[],
extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1),
)
result = {"task_id": task_id_str, "result": metrics_text}
logger.debug(f"Response for task: {task_id_str}")
with self.response_lock:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
elif task["cmd"] == "connect_rdma":
self.engine.engine_worker_queue.put_connect_rdma_task(task)
except Exception as e:
logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}")
def _response_external_module_control_instruct(self):
while True:
try:
result_data = self.engine.engine_worker_queue.get_connect_rdma_task_response()
if result_data:
task_id_str = result_data["task_id"]
result = {"task_id": task_id_str, "result": result_data}
logger.info(f"Response for task: {task_id_str}")
with self.response_lock:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
else:
time.sleep(0.001)
except Exception as e:
logger.error(f"_handle_connect_rdma_results got error: {e}, {traceback.format_exc() !s}")

View File

@@ -14,27 +14,26 @@
# limitations under the License.
"""
import json
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict
import msgpack
import zmq
from fastdeploy import envs
from fastdeploy.engine.request import CompletionOutput, Request, RequestOutput
from fastdeploy.inter_communicator import EngineWorkerQueue
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
logger = get_logger("splitwise_connector", "splitwise_connector.log")
class SplitwiseConnector:
"""
SplitwiseConnector class for managing and scheduling Splitwise tasks.
"""
- def __init__(self, cfg, scheduler, worker_queue, resource_manager):
+ def __init__(self, cfg, scheduler, worker_queue, resource_manager, splitwise_queue):
"""
Initialize the SplitwiseConnector instance.
@@ -45,12 +44,20 @@ class SplitwiseConnector:
resource_manager (object): Resource manager object.
"""
self.cfg = cfg
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
self.logger = get_logger(
"splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log"
)
else:
self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
self.scheduler = scheduler
self.engine_worker_queue = worker_queue
self.resource_manager = resource_manager
self.connect_innode_instances = {}
self.temp_cache_info = dict()
self.current_request_ids = dict()
self.splitwise_queue = splitwise_queue
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
if self.cfg.cache_config.pd_comm_port is not None:
self.zmq_ctx = zmq.Context()
@@ -69,7 +76,7 @@ class SplitwiseConnector:
self.router_socket.setsockopt(zmq.SNDHWM, 1000)
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port[0]}")
self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN)
@@ -88,16 +95,16 @@ class SplitwiseConnector:
if not socks:
continue
else:
logger.debug(f"receive {socks}")
self.logger.debug(f"receive {socks}")
frames = self.router_socket.recv_multipart()
logger.debug(f"frames: {frames}")
self.logger.debug(f"frames: {frames}")
message = frames[-1]
self.io_executor.submit(self._process_message, message)
time.sleep(0.001)
except Exception as e:
logger.error(f"Receiver error: {e}")
self.logger.error(f"Receiver error: {e}")
time.sleep(1)
def _get_push_socket(self, addr):
@@ -109,7 +116,7 @@ class SplitwiseConnector:
return sock
try:
logger.info(f"Establishing new connection to {addr}")
self.logger.info(f"Establishing new connection to {addr}")
sock = self.zmq_ctx.socket(zmq.DEALER)
# Set connection parameters
@@ -128,7 +135,7 @@ class SplitwiseConnector:
return sock
except zmq.ZMQError as e:
logger.error(f"Connection to {addr} failed: {e}")
self.logger.error(f"Connection to {addr} failed: {e}")
raise ConnectionError(f"Failed to connect to {addr}") from e
@@ -137,7 +144,7 @@ class SplitwiseConnector:
return
try:
logger.info(f"Sent {msg_type} to {addr}")
self.logger.info(f"Sent {msg_type} to {addr}")
message = self._serialize_message(msg_type, payload)
try:
@@ -145,18 +152,19 @@ class SplitwiseConnector:
sock = self._get_push_socket(addr)
sock.send_multipart([b"", message])
logger.info(f"Sent {msg_type} to {addr}")
self.logger.info(f"Sent {msg_type} to {addr}")
except ConnectionError:
logger.warning(f"Connection to {addr} not established")
self.logger.warning(f"Connection to {addr} not established")
except zmq.Again:
logger.warning(f"Send queue full for {addr}")
self.logger.warning(f"Send queue full for {addr}")
except Exception as e:
logger.error(f"Send to {addr} failed: {e}")
main_process_metrics.send_cache_failed_num.inc()
self.logger.error(f"Send to {addr} failed: {e}")
self._close_connection(addr)
except Exception as e:
logger.error(f"Message preparation failed: {e}")
self.logger.error(f"Message preparation failed: {e}")
def _close_connection(self, addr):
"""
@@ -261,7 +269,7 @@ class SplitwiseConnector:
f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}"
)
logger.info(f"send splitwise tasks to port {addr} decode")
self.logger.info(f"send splitwise tasks to port {addr} decode")
self.current_request_ids[task.request_id] = "init"
decode_diagg = task.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
@@ -289,7 +297,7 @@ class SplitwiseConnector:
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
for task in tasks:
task.disaggregate_info["cache_info"]["ipc"]["port"] = port
logger.info(f"send splitwise tasks to port {port} decode")
self.logger.info(f"send splitwise tasks to port {port} decode")
current_port = port
return current_port
@@ -299,7 +307,7 @@ class SplitwiseConnector:
"""
if not isinstance(tasks_list, list):
tasks_list = [tasks_list]
logger.info("send first token to port decode")
self.logger.info("send first token to port decode")
if prefill_msg["transfer_protocol"] == "ipc":
port = prefill_msg["cache_info"]["ipc"]["port"]
if port not in self.connect_innode_instances:
@@ -307,7 +315,7 @@ class SplitwiseConnector:
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
else:
node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
logger.info(f"send first token to port {node} decode")
self.logger.info(f"send first token to port {node} decode")
self._send_message(node, "decode", tasks_list)
def create_connection(self, port):
@@ -323,6 +331,22 @@ class SplitwiseConnector:
client_id=0,
)
def check_decode_allocated(self, task):
if task.disaggregate_info is None:
return True, ""
if self.enable_decode_cache_task:
return True, ""
if task.disaggregate_info["role"] != "prefill":
return True, ""
while self.current_request_ids[task.request_id] == "init":
time.sleep(0.001)
msg = self.current_request_ids[task.request_id]
del self.current_request_ids[task.request_id]
if msg == "finished":
return True, ""
self.logger.error(f"Receive_decode_allocated error: {msg}")
return False, msg
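check_decode_allocated is a blocking handshake on the prefill side: the entry is seeded as "init" when cache info is sent, and the cache_sync handler below later overwrites it with "finished" or the decode node's error_msg. A compressed sketch of that state machine, with a timer standing in for the ZMQ callback:

import threading
import time

current_request_ids = {"req-1": "init"}  # seeded when cache info is sent to decode

def on_cache_sync(task: dict) -> None:
    # What the cache_sync branch does: record success or the error message.
    current_request_ids[task["request_id"]] = task.get("error_msg", "finished")

def check_decode_allocated(request_id: str):
    while current_request_ids[request_id] == "init":  # spin until decode answers
        time.sleep(0.001)
    msg = current_request_ids.pop(request_id)
    return (True, "") if msg == "finished" else (False, msg)

threading.Timer(0.05, on_cache_sync, args=({"request_id": "req-1"},)).start()
print(check_decode_allocated("req-1"))  # -> (True, '')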
def send_cache_infos(self, tasks, current_id):
"""
Send cache information to specific port.
@@ -339,15 +363,21 @@ class SplitwiseConnector:
for i in range(len(tasks)):
if tasks[i].disaggregate_info is None:
continue
logger.info(f"{tasks[i].disaggregate_info}")
self.logger.info(f"{tasks[i].disaggregate_info}")
if tasks[i].disaggregate_info["role"] == "decode":
if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
- cache_info = {
- "request_id": tasks[i].request_id,
- "device_ids": self.cfg.device_ids.split(","),
- "transfer_protocol": "ipc",
- "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
- }
+ if tasks[i].get("error_msg", None) is not None:
+ cache_info = {
+ "request_id": tasks[i].request_id,
+ "error_msg": tasks[i].get("error_msg"),
+ }
+ else:
+ cache_info = {
+ "request_id": tasks[i].request_id,
+ "device_ids": self.cfg.device_ids.split(","),
+ "transfer_protocol": "ipc",
+ "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
+ }
if tasks[i].disaggregate_info["cache_info"]["ipc"]["port"] not in temp_cache_info:
temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]] = []
temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]].append(cache_info)
@@ -356,14 +386,20 @@ class SplitwiseConnector:
f"{tasks[i].disaggregate_info['cache_info']['rdma']['ip']}:"
+ f"{tasks[i].disaggregate_info['cache_info']['rdma']['port']}"
)
- cache_info = {
- "request_id": tasks[i].request_id,
- "device_ids": self.cfg.device_ids.split(","),
- "ip": self.cfg.host_ip,
- "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
- "transfer_protocol": "rdma",
- "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
- }
+ if tasks[i].get("error_msg", None) is not None:
+ cache_info = {
+ "request_id": tasks[i].request_id,
+ "error_msg": tasks[i].get("error_msg"),
+ }
+ else:
+ cache_info = {
+ "request_id": tasks[i].request_id,
+ "device_ids": self.cfg.device_ids.split(","),
+ "ip": self.cfg.host_ip,
+ "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
+ "transfer_protocol": "rdma",
+ "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
+ }
if addr not in temp_cache_info:
temp_cache_info[addr] = []
@@ -390,7 +426,7 @@ class SplitwiseConnector:
else:
if len(temp_cache_info):
for k, v in temp_cache_info.items():
logger.info(f"{k} {v}")
self.logger.info(f"{k} {v}")
if ":" in str(k):
self._send_message(k, "cache_sync", v)
else:
@@ -406,13 +442,19 @@ class SplitwiseConnector:
if msg_type == "decode" or msg_type == "prefill":
payload = [output.to_dict() for output in payload]
json_data = json.dumps({"type": msg_type, "payload": payload}).encode("utf-8")
req_ids = [task["request_id"] for task in payload]
self.logger.info(f"send message {msg_type} {req_ids}")
json_data = msgpack.packb({"type": msg_type, "payload": payload})
return json_data
def _deserialize_message(self, data: bytes):
# Deserialize the message
- message = json.loads(data.decode("utf-8"))
+ message = msgpack.unpackb(data)
+ req_ids = [task["request_id"] for task in message["payload"]]
+ self.logger.info(f"recv message type {message['type']} for {req_ids}")
return message["type"], message["payload"]
def _process_message(self, message: bytes):
@@ -421,7 +463,7 @@ class SplitwiseConnector:
"""
try:
msg_type, payload = self._deserialize_message(message)
logger.info(f"{msg_type}")
self.logger.info(f"{msg_type}")
if msg_type == "prefill":
self._handle_prefill(payload)
@@ -429,11 +471,16 @@ class SplitwiseConnector:
self._handle_decode(payload)
elif msg_type == "cache_sync":
for task in payload:
del self.current_request_ids[task["request_id"]]
self.engine_worker_queue.put_cache_info(payload)
self.logger.info(f"cache_sync task: {task}")
current_status = task.get("error_msg", "finished")
self.current_request_ids[task["request_id"]] = current_status
if self.enable_decode_cache_task:
del self.current_request_ids[task["request_id"]]
if current_status == "finished":
self.engine_worker_queue.put_cache_info(payload)
except Exception as e:
logger.error(f"Message processing failed: {e}")
self.logger.error(f"Message processing failed: {e}")
def _handle_prefill(self, tasks):
"""
@@ -441,7 +488,9 @@ class SplitwiseConnector:
"""
tasks_data = [Request.from_dict(task) for task in tasks]
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks_data))
req_ids = [task["request_id"] for task in tasks]
self.splitwise_queue.append(("decode", tasks_data))
self.logger.info(f"{req_ids} received prefill data")
def _handle_decode(self, payload):
"""
@@ -456,8 +505,13 @@ class SplitwiseConnector:
index=task["outputs"]["index"],
send_idx=0,
token_ids=task["outputs"]["token_ids"],
draft_token_ids=task["outputs"]["draft_token_ids"],
),
finished=True,
error_code=task["error_code"],
error_msg=task["error_msg"],
)
)
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))
req_ids = [task["request_id"] for task in payload]
self.splitwise_queue.append(("decode", tasks))
self.logger.info(f"{req_ids} received decode data")

View File

@@ -43,6 +43,7 @@ from fastdeploy.model_executor.layers.sample.sampler import Sampler, Speculative
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.ops.gpu import (
recover_decode_task,
set_data_ipc,
set_value_by_flags_and_idx,
share_external_data,
)
@@ -904,7 +905,7 @@ class GPUModelRunner(ModelRunnerBase):
)
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
- if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
+ if not profile and self.parallel_config.splitwise_role != "mixed":
cache_kvs_list = []
for i in range(self.model_config.num_hidden_layers):
key_cache = paddle.empty(shape=[], dtype=cache_type)
@@ -930,6 +931,15 @@ class GPUModelRunner(ModelRunnerBase):
fill_value=0,
dtype=cache_type,
)
if self.cache_config.enable_prefix_caching:
set_data_ipc(
cache_kvs[f"key_caches_{i}"],
f"key_caches_{i}_rank{local_rank}.device{self.device_id}",
)
set_data_ipc(
cache_kvs[f"value_caches_{i}"],
f"value_caches_{i}_rank{local_rank}.device{self.device_id}",
)
self.share_inputs["caches"] = list(cache_kvs.values())
for value in cache_kvs.values():
del value
@@ -1138,6 +1148,8 @@ class GPUModelRunner(ModelRunnerBase):
if task.chunk_idx > len(task.prefill_chunk_info):
continue
self.restore_chunked_prefill_request[task.request_id] = task
if len(self.restore_chunked_prefill_request) > 0:
self.share_inputs["not_need_stop"][0] = True
for id, task in list(self.restore_chunked_prefill_request.items()):
idx = task.idx
@@ -1182,7 +1194,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
self.share_inputs["prompt_lens"][idx : idx + 1] += token_chunk_size
self.share_inputs["step_idx"][idx : idx + 1] = 0
self.share_inputs["stop_flags"][idx : idx + 1] = False
if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled():
self.proposer.update_task_chunk_prefill(task)
task.chunk_idx += 1
@@ -1256,17 +1268,17 @@ class GPUModelRunner(ModelRunnerBase):
We plan to replace it with 'ModelForwardBatch'.
intermediate_tensors:
"""
- # 1. Prepare inputs of model and sampler.
- skip_idx_list = self._get_skip_idx(model_forward_batch)
- self._prepare_inputs()
- self.sampler.pre_process(skip_idx_list)
# NOTE(wufeisheng): If `not_need_stop` is False, it means the current worker is in an idle state.
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runners, the current runner is required to execute part of the model.
if not self.not_need_stop():
self._execute_empty_input()
return None
+ start_time = time.time()
+ # 1. Prepare inputs of model and sampler.
+ skip_idx_list = self._get_skip_idx(model_forward_batch)
+ self._prepare_inputs()
+ self.sampler.pre_process(skip_idx_list)
# 2. Padding inputs for cuda graph
self.padding_cudagraph_inputs()
@@ -1397,6 +1409,8 @@ class GPUModelRunner(ModelRunnerBase):
self._update_chunked_prefill(model_forward_batch)
self._add_cache(model_forward_batch)
end_time = time.time()
logger.debug(f"execute one step cost time: {end_time-start_time:.3f} s")
return None
def _add_cache(self, model_forward_batch) -> None:
@@ -1507,12 +1521,12 @@ class GPUModelRunner(ModelRunnerBase):
hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads
# NOTE(liuzichang): Implement multi-layer MTP architecture in the future
- num_layers = (
+ num_hidden_layers = (
self.model_config.num_hidden_layers + self.speculative_config.num_gpu_block_expand_ratio
if self.speculative_method in ["mtp"]
else self.model_config.num_hidden_layers
)
- required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers # k + v
+ required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_hidden_layers # k + v
return required_memory
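The renamed variable is easiest to sanity-check with numbers. A worked example assuming bfloat16 caches (2 bytes per element) and illustrative model dimensions:

byte_of_dtype = 2                     # bfloat16
block_size = 64                       # tokens per KV block
head_dim, kv_num_heads = 128, 8
hidden_dim = head_dim * kv_num_heads  # 1024
num_hidden_layers = 48

# k + v, hence the extra factor of 2:
required_memory = byte_of_dtype * 2 * (block_size * hidden_dim) * num_hidden_layers
assert required_memory == 12_582_912  # exactly 12 MiB per cache block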
def not_need_stop(self) -> bool:

View File

@@ -150,7 +150,7 @@ class PaddleDisWorkerProc:
# Initialize task queue
task_address = (
self.parallel_config.pod_ip,
- self.parallel_config.engine_worker_queue_port,
+ self.parallel_config.engine_worker_queue_port + self.parallel_config.expert_parallel_rank,
)
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
self.task_queue = TaskQueue(
@@ -252,9 +252,11 @@ class PaddleDisWorkerProc:
for req_dict, bsz in tasks:
num_running_requests = int(bsz)
req_dicts.extend(req_dict)
req_ids = [req.request_id for req in req_dicts]
logger.info(
f"Rank: {self.local_rank}, num_running_requests: {num_running_requests}, "
f"num_insert_requests: {len(req_dicts)}"
f"num_insert_requests: {len(req_dicts)}, req_ids: {req_ids}"
)
# Process prefill inputs
self.worker.preprocess_new_task(req_dicts)
@@ -408,7 +410,7 @@ class PaddleDisWorkerProc:
logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
# wait engine launch cache_manager
- if self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
+ if self.parallel_config.splitwise_role != "mixed":
launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
self.launched_cache_manager_signal = IPCSignal(
name="launched_cache_manager_signal",