[Bug fix] Test td cache messager (#3242)

* support disabling cache task in decode node

* fix bugs

* Update engine.py

* Update expert_service.py

* Update splitwise_connector.py

* Optimize log for debug

* Optimize log for debug

* fix bug

---------

Co-authored-by: ltd0924 <ltd0924@sina.com>
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
Author: chenjian
Date: 2025-08-06 15:52:45 +08:00
Committed by: GitHub
Parent: a4572a5e5d
Commit: 110f33a530
5 changed files with 144 additions and 57 deletions


@@ -17,8 +17,9 @@
 import argparse
 import json
 import math
-import time
+import threading
+import time

 import numpy as np
 import paddle
@@ -196,7 +197,9 @@ class CacheMessager:
         self.gpu_id = gpu_id
         self.cache_info = dict()
-        self.rank_id = self.rank + local_data_parallel_id * self.nranks  # align with engine worker rank (paddle.distributed.launch)
+        self.rank_id = (
+            self.rank + local_data_parallel_id * self.nranks
+        )  # align with engine worker rank (paddle.distributed.launch)

         connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
         connect_rdma_thread.daemon = True
@@ -284,7 +287,7 @@ class CacheMessager:
                 if not self.cache_info:
                     time.sleep(0.001)
                     continue
-                logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
+                logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
                 for req_id, item in list(self.cache_info.items()):
                     if "status" not in item:
                         continue
@@ -364,7 +367,7 @@ class CacheMessager:
         except Exception as e:
             logger.info(f"prefill layerwise send cache thread has exception: {e}")

     def _handle_connect_task(self):
         while True:
             try:
@@ -465,7 +468,8 @@ def main():
 if __name__ == "__main__":
     args = parse_args()
-    logger = get_logger("cache_messager", "cache_messager.log")
+    rank_id = args.rank + args.local_data_parallel_id * args.mp_num
+    logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
     logger.info("create cache messager...")
     logger.info(f"{args}")


@@ -113,6 +113,8 @@ class LLMEngine:
         self.start_queue_service()

+        self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
+
         if envs.ENABLE_V1_KVCACHE_SCHEDULER:
             self.resource_manager = ResourceManagerV1(
                 cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
@@ -630,11 +632,15 @@ class LLMEngine:
                     if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
                         self.insert_tasks([task])
                     else:
+                        if not self.enable_decode_cache_task:
+                            task.error_msg = "Not enough resources"
                         new_waiting.append(task)

                 if new_waiting:
-                    self.waiting_requests.extend(new_waiting)
-                    llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
+                    if not self.enable_decode_cache_task:
+                        self.split_connector.send_cache_infos(new_waiting, -1)
+                    else:
+                        self.waiting_requests.extend(new_waiting)
+                        llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")

             else:
                 time.sleep(0.001)
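Reviewer note: with `FD_ENABLE_CACHE_TASK` set to "0", a decode node that cannot admit a request no longer parks it in `waiting_requests`; it stamps `task.error_msg` and reports straight back to the prefill node via `send_cache_infos(new_waiting, -1)`. The two reply shapes, as built by `send_cache_infos` later in this PR (all values illustrative):

```python
# Failure reply: only the request id and the error travel back to prefill.
cache_info_error = {"request_id": "req-0", "error_msg": "Not enough resources"}

# Normal ipc-protocol reply for comparison.
cache_info_ok = {
    "request_id": "req-1",
    "device_ids": ["0", "1"],      # from cfg.device_ids.split(",")
    "transfer_protocol": "ipc",
    "dest_block_ids": [4, 5, 6],   # KV blocks allocated on the decode node
}
```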
@@ -805,6 +811,22 @@ class LLMEngine:
         for task in tasks:
             start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
+            if self.cfg.splitwise_role != "mixed":
+                status, msg = self.split_connector.check_decode_allocated(task)
+                if not status:
+                    llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
+                    self.scheduler.put_results(
+                        [
+                            RequestOutput(
+                                request_id=task.request_id,
+                                finished=True,
+                                error_code=500,
+                                error_msg=msg,
+                            )
+                        ]
+                    )
+                    tasks.remove(task)
+                    continue
             if task.sampling_params.bad_words is not None:
                 task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
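Reviewer note, a general Python caveat rather than something this PR changes: `tasks.remove(task)` inside `for task in tasks` shifts the list under the iterator, so when two rejected tasks are adjacent the second one is skipped. A runnable illustration in isolation:

```python
# Removing from a list while iterating it skips the element that follows
# each removal; this is standard Python list-iterator behavior.
tasks = ["ok-1", "bad-1", "bad-2", "ok-2"]
for task in tasks:
    if task.startswith("bad"):
        tasks.remove(task)
print(tasks)  # ['ok-1', 'bad-2', 'ok-2'] -- 'bad-2' was skipped, not removed

# A filter avoids the issue when several rejected tasks can be adjacent.
tasks = ["ok-1", "bad-1", "bad-2", "ok-2"]
tasks = [t for t in tasks if not t.startswith("bad")]
print(tasks)  # ['ok-1', 'ok-2']
```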
@@ -1020,7 +1042,6 @@ class LLMEngine:
         except Exception as e:
             print(f"Error extracting sub services: {e}")
-
         for worker_queue in self.engine_worker_queue_server:
             worker_queue.cleanup()

         if hasattr(self, "send_response_server") and self.send_response_server is not None:


@@ -26,6 +26,7 @@ from collections import deque
 import numpy as np

+from fastdeploy.engine.request import RequestOutput
 from fastdeploy.engine.resource_manager import ResourceManager
 from fastdeploy.inter_communicator import EngineWorkerQueue
 from fastdeploy.metrics.metrics import main_process_metrics
@@ -34,6 +35,7 @@ from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
 from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger

+
 class ExpertService:
     """
     Engine class responsible for managing the Large Language Model (LLM) operations.
@@ -146,7 +148,7 @@ class ExpertService:
         # Start TokenProcessor thread
         os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
-
+        self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK
         self.token_processor.run()

         self.cfg.init_cache_info()
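Reviewer note: unlike `LLMEngine`, which compares `envs.FD_ENABLE_CACHE_TASK == "1"`, this line stores the raw string. Any non-empty string is truthy in Python, so `if not self.enable_decode_cache_task:` can never fire here, even with the variable set to "0":

```python
# Truthiness pitfall with string-valued flags: "0" is a non-empty string.
flag = "0"
print(bool(flag))   # True  -- `not flag` is False, so the branch is skipped
print(flag == "1")  # False -- the explicit comparison LLMEngine uses
```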
@@ -262,11 +264,15 @@ class ExpertService:
                     if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
                         self.insert_tasks([task])
                     else:
+                        if not self.enable_decode_cache_task:
+                            task.error_msg = "Not enough resources"
                         new_waiting.append(task)

                 if new_waiting:
-                    self.waiting_requests.extend(new_waiting)
-                    self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
+                    if not self.enable_decode_cache_task:
+                        self.split_connector.send_cache_infos(new_waiting, -1)
+                    else:
+                        self.waiting_requests.extend(new_waiting)
+                        self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")

             else:
                 time.sleep(0.001)
@@ -310,8 +316,24 @@ class ExpertService:
         if not isinstance(tasks, list):
             tasks = [tasks]

-        for item in tasks:
-            item.schedule_start_time = time.time()
+        for task in tasks:
+            if self.cfg.splitwise_role != "mixed":
+                status, msg = self.split_connector.check_decode_allocated(task)
+                if not status:
+                    self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
+                    self.scheduler.put_results(
+                        [
+                            RequestOutput(
+                                request_id=task.request_id,
+                                finished=True,
+                                error_code=500,
+                                error_msg=msg,
+                            )
+                        ]
+                    )
+                    tasks.remove(task)
+                    continue
+            task.schedule_start_time = time.time()

         available_batch = np.sum(self.resource_manager.stop_flags)
         if len(tasks) > available_batch:


@@ -90,6 +90,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
     # Whether to use PLUGINS.
     "FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","),
+    # Whether to enable cache task in decode node
+    "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"),
 }
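For context (not part of the diff): the flag defaults to "1" (cache task enabled on the decode node) and is delivered as a raw string. A minimal sketch of disabling it, assuming the accessor pattern in this table and the `from fastdeploy.utils import envs` import seen elsewhere in this PR:

```python
import os

# Illustrative deployment snippet: disable the decode-side cache task before
# the engine processes start; unset, the default "1" keeps it enabled.
os.environ["FD_ENABLE_CACHE_TASK"] = "0"

from fastdeploy.utils import envs  # import path as used elsewhere in this PR

# The accessor returns the raw string, so compare explicitly, as LLMEngine does.
enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
print(enable_decode_cache_task)  # False
```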


@@ -26,8 +26,6 @@ from fastdeploy.engine.request import CompletionOutput, Request, RequestOutput
 from fastdeploy.inter_communicator import EngineWorkerQueue
 from fastdeploy.utils import get_logger

-logger = get_logger("splitwise_connector", "splitwise_connector.log")
-

 class SplitwiseConnector:
     """
@@ -45,6 +43,12 @@ class SplitwiseConnector:
             resource_manager (object): Resource manager object.
         """
         self.cfg = cfg
+        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
+            self.logger = get_logger(
+                "splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log"
+            )
+        else:
+            self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
         self.scheduler = scheduler
         self.engine_worker_queue = worker_queue
         self.resource_manager = resource_manager
@@ -52,6 +56,7 @@ class SplitwiseConnector:
         self.temp_cache_info = dict()
         self.current_request_ids = dict()
         self.splitwise_queue = splitwise_queue
+        self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"

         if self.cfg.cache_config.pd_comm_port is not None:
             self.zmq_ctx = zmq.Context()
@@ -70,7 +75,7 @@ class SplitwiseConnector:
             self.router_socket.setsockopt(zmq.SNDHWM, 1000)
             self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
             self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
-            logger.info(f"bind {self.cfg.cache_config.pd_comm_port}")
+            self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port[0]}")

             self.poller = zmq.Poller()
             self.poller.register(self.router_socket, zmq.POLLIN)
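Background sketch (illustrative, not FastDeploy code): the connector binds a ROUTER socket here, and `_get_push_socket` peers connect with DEALER sockets that prepend an empty delimiter frame (`send_multipart([b"", message])`). A minimal self-contained reproduction, with an arbitrary port standing in for `cfg.cache_config.pd_comm_port[0]`:

```python
import zmq

# Minimal ROUTER/DEALER wiring in the style used above; port 5555 is assumed.
ctx = zmq.Context()

router = ctx.socket(zmq.ROUTER)
router.setsockopt(zmq.SNDHWM, 1000)
router.setsockopt(zmq.ROUTER_MANDATORY, 1)
router.bind("tcp://*:5555")

dealer = ctx.socket(zmq.DEALER)
dealer.connect("tcp://127.0.0.1:5555")
# The empty first frame mirrors sock.send_multipart([b"", message]) in
# _send_message, keeping the framing compatible with REQ-style peers.
dealer.send_multipart([b"", b"hello"])

frames = router.recv_multipart()  # [dealer identity, b"", payload]
print(frames[-1])                 # the payload is always the last frame
```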
@@ -89,16 +94,16 @@ class SplitwiseConnector:
                 if not socks:
                     continue
                 else:
-                    logger.debug(f"receive {socks}")
+                    self.logger.debug(f"receive {socks}")

                 frames = self.router_socket.recv_multipart()
-                logger.debug(f"frames: {frames}")
+                self.logger.debug(f"frames: {frames}")
                 message = frames[-1]
                 self.io_executor.submit(self._process_message, message)
                 time.sleep(0.001)

             except Exception as e:
-                logger.error(f"Receiver error: {e}")
+                self.logger.error(f"Receiver error: {e}")
                 time.sleep(1)

     def _get_push_socket(self, addr):
@@ -110,7 +115,7 @@ class SplitwiseConnector:
             return sock

         try:
-            logger.info(f"Establishing new connection to {addr}")
+            self.logger.info(f"Establishing new connection to {addr}")
             sock = self.zmq_ctx.socket(zmq.DEALER)

             # Set connection parameters
@@ -129,7 +134,7 @@ class SplitwiseConnector:
             return sock

         except zmq.ZMQError as e:
-            logger.error(f"Connection to {addr} failed: {e}")
+            self.logger.error(f"Connection to {addr} failed: {e}")
             raise ConnectionError(f"Failed to connect to {addr}") from e
@@ -138,7 +143,7 @@ class SplitwiseConnector:
             return

         try:
-            logger.info(f"Sent {msg_type} to {addr}")
+            self.logger.info(f"Sent {msg_type} to {addr}")
             message = self._serialize_message(msg_type, payload)

             try:
@@ -146,18 +151,18 @@ class SplitwiseConnector:
                 sock = self._get_push_socket(addr)
                 sock.send_multipart([b"", message])
-                logger.info(f"Sent {msg_type} to {addr}")
+                self.logger.info(f"Sent {msg_type} to {addr}")

             except ConnectionError:
-                logger.warning(f"Connection to {addr} not established")
+                self.logger.warning(f"Connection to {addr} not established")
             except zmq.Again:
-                logger.warning(f"Send queue full for {addr}")
+                self.logger.warning(f"Send queue full for {addr}")
             except Exception as e:
-                logger.error(f"Send to {addr} failed: {e}")
+                self.logger.error(f"Send to {addr} failed: {e}")
                 self._close_connection(addr)

         except Exception as e:
-            logger.error(f"Message preparation failed: {e}")
+            self.logger.error(f"Message preparation failed: {e}")

     def _close_connection(self, addr):
         """
@@ -262,7 +267,7 @@ class SplitwiseConnector:
                     f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
                     + f"{task.disaggregate_info['cache_info']['rdma']['port']}"
                 )
-                logger.info(f"send splitwise tasks to port {addr} decode")
+                self.logger.info(f"send splitwise tasks to port {addr} decode")
                 self.current_request_ids[task.request_id] = "init"
                 decode_diagg = task.disaggregate_info["cache_info"]
                 task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
@@ -290,7 +295,7 @@ class SplitwiseConnector:
             self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
             for task in tasks:
                 task.disaggregate_info["cache_info"]["ipc"]["port"] = port
-            logger.info(f"send splitwise tasks to port {port} decode")
+            self.logger.info(f"send splitwise tasks to port {port} decode")
             current_port = port
         return current_port
@@ -300,7 +305,7 @@ class SplitwiseConnector:
         """
         if not isinstance(tasks_list, list):
             tasks_list = [tasks_list]
-        logger.info("send first token to port decode")
+        self.logger.info("send first token to port decode")
         if prefill_msg["transfer_protocol"] == "ipc":
             port = prefill_msg["cache_info"]["ipc"]["port"]
             if port not in self.connect_innode_instances:
@@ -308,7 +313,7 @@ class SplitwiseConnector:
             self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
         else:
             node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
-            logger.info(f"send first token to port {node} decode")
+            self.logger.info(f"send first token to port {node} decode")
             self._send_message(node, "decode", tasks_list)

     def create_connection(self, port):
@@ -324,6 +329,22 @@ class SplitwiseConnector:
             client_id=0,
         )

+    def check_decode_allocated(self, task):
+        if task.disaggregate_info is None:
+            return True, ""
+        if self.enable_decode_cache_task:
+            return True, ""
+        if task.disaggregate_info["role"] != "prefill":
+            return True, ""
+        while self.current_request_ids[task.request_id] == "init":
+            time.sleep(0.001)
+        msg = self.current_request_ids[task.request_id]
+        del self.current_request_ids[task.request_id]
+        if msg == "finished":
+            return True, ""
+        self.logger.error(f"Receive_decode_allocated error: {msg}")
+        return False, msg
+
     def send_cache_infos(self, tasks, current_id):
         """
         Send cache information to specific port.
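Reviewer note: `check_decode_allocated` is the prefill side of a small handshake. `send_splitwise_tasks` marks each request "init"; the decode node's `cache_sync` reply (handled in `_process_message` below) flips it to "finished" or to an error message. A runnable sketch under assumed names (`wait_for_decode` and `on_cache_sync` are illustrative, not FastDeploy API; a plain dict stands in for `current_request_ids`):

```python
import threading
import time

request_states = {}

def on_cache_sync(payload):
    # The decode node answers with a "cache_sync" message; an "error_msg"
    # field marks a failed allocation, otherwise the request is "finished".
    for task in payload:
        request_states[task["request_id"]] = task.get("error_msg", "finished")

def wait_for_decode(request_id, poll_s=0.001):
    # Mirrors check_decode_allocated: spin while the state is still "init".
    while request_states[request_id] == "init":
        time.sleep(poll_s)
    status = request_states.pop(request_id)
    return (True, "") if status == "finished" else (False, status)

request_states["req-0"] = "init"  # set by the prefill side at dispatch time
# Simulate a decode node rejecting the request for lack of resources.
threading.Timer(
    0.05, on_cache_sync, args=([{"request_id": "req-0", "error_msg": "Not enough resources"}],)
).start()
print(wait_for_decode("req-0"))  # -> (False, 'Not enough resources')
```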
@@ -340,15 +361,21 @@ class SplitwiseConnector:
         for i in range(len(tasks)):
             if tasks[i].disaggregate_info is None:
                 continue
-            logger.info(f"{tasks[i].disaggregate_info}")
+            self.logger.info(f"{tasks[i].disaggregate_info}")
             if tasks[i].disaggregate_info["role"] == "decode":
                 if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
-                    cache_info = {
-                        "request_id": tasks[i].request_id,
-                        "device_ids": self.cfg.device_ids.split(","),
-                        "transfer_protocol": "ipc",
-                        "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
-                    }
+                    if tasks[i].get("error_msg", None) is not None:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "error_msg": tasks[i].get("error_msg"),
+                        }
+                    else:
+                        cache_info = {
+                            "request_id": tasks[i].request_id,
+                            "device_ids": self.cfg.device_ids.split(","),
+                            "transfer_protocol": "ipc",
+                            "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
+                        }
                     if tasks[i].disaggregate_info["cache_info"]["ipc"]["port"] not in temp_cache_info:
                         temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]] = []
                     temp_cache_info[tasks[i].disaggregate_info["cache_info"]["ipc"]["port"]].append(cache_info)
@@ -357,14 +384,20 @@ class SplitwiseConnector:
                     f"{tasks[i].disaggregate_info['cache_info']['rdma']['ip']}:"
                     + f"{tasks[i].disaggregate_info['cache_info']['rdma']['port']}"
                 )
-                cache_info = {
-                    "request_id": tasks[i].request_id,
-                    "device_ids": self.cfg.device_ids.split(","),
-                    "ip": self.cfg.host_ip,
-                    "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
-                    "transfer_protocol": "rdma",
-                    "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
-                }
+                if tasks[i].get("error_msg", None) is not None:
+                    cache_info = {
+                        "request_id": tasks[i].request_id,
+                        "error_msg": tasks[i].get("error_msg"),
+                    }
+                else:
+                    cache_info = {
+                        "request_id": tasks[i].request_id,
+                        "device_ids": self.cfg.device_ids.split(","),
+                        "ip": self.cfg.host_ip,
+                        "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
+                        "transfer_protocol": "rdma",
+                        "dest_block_ids": tasks[i].disaggregate_info["block_tables"],
+                    }
                 if addr not in temp_cache_info:
                     temp_cache_info[addr] = []
@@ -391,7 +424,7 @@ class SplitwiseConnector:
         else:
             if len(temp_cache_info):
                 for k, v in temp_cache_info.items():
-                    logger.info(f"{k} {v}")
+                    self.logger.info(f"{k} {v}")
                     if ":" in str(k):
                         self._send_message(k, "cache_sync", v)
                     else:
@@ -408,7 +441,7 @@ class SplitwiseConnector:
             payload = [output.to_dict() for output in payload]

         req_ids = [task["request_id"] for task in payload]
-        logger.info(f"send message {msg_type} {req_ids}")
+        self.logger.info(f"send message {msg_type} {req_ids}")

         json_data = msgpack.packb({"type": msg_type, "payload": payload})
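For reference (not part of the diff): the wire format on both sides of `_serialize_message`/`_deserialize_message` is msgpack, despite the legacy "JSON" comment in the next hunk. A round-trip of an assumed `cache_sync` payload:

```python
import msgpack  # third-party dependency already used by SplitwiseConnector

# Round-trip of the message envelope: {"type": ..., "payload": [...]}.
wire = msgpack.packb({"type": "cache_sync", "payload": [{"request_id": "req-0"}]})
message = msgpack.unpackb(wire)
print(message["type"], [task["request_id"] for task in message["payload"]])
# -> cache_sync ['req-0']
```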
@@ -419,7 +452,7 @@ class SplitwiseConnector:
         # JSON deserialization
         message = msgpack.unpackb(data)
         req_ids = [task["request_id"] for task in message["payload"]]
-        logger.info(f"send message {message['type']} {req_ids}")
+        self.logger.info(f"recv message type {message['type']} for {req_ids}")
         return message["type"], message["payload"]

     def _process_message(self, message: bytes):
@@ -428,7 +461,7 @@ class SplitwiseConnector:
         """
         try:
             msg_type, payload = self._deserialize_message(message)
-            logger.info(f"{msg_type}")
+            self.logger.info(f"{msg_type}")

             if msg_type == "prefill":
                 self._handle_prefill(payload)
@@ -436,11 +469,16 @@ class SplitwiseConnector:
                 self._handle_decode(payload)
             elif msg_type == "cache_sync":
                 for task in payload:
-                    del self.current_request_ids[task["request_id"]]
-                self.engine_worker_queue.put_cache_info(payload)
+                    self.logger.info(f"cache_sync task: {task}")
+                    current_status = task.get("error_msg", "finished")
+                    self.current_request_ids[task["request_id"]] = current_status
+                    if self.enable_decode_cache_task:
+                        del self.current_request_ids[task["request_id"]]
+                if current_status == "finished":
+                    self.engine_worker_queue.put_cache_info(payload)

         except Exception as e:
-            logger.error(f"Message processing failed: {e}")
+            self.logger.error(f"Message processing failed: {e}")

     def _handle_prefill(self, tasks):
         """
@@ -450,7 +488,7 @@ class SplitwiseConnector:
         tasks_data = [Request.from_dict(task) for task in tasks]
         req_ids = [task["request_id"] for task in tasks]
         self.splitwise_queue.append(("decode", tasks_data))
-        logger.debug(f"{req_ids} received prefill data")
+        self.logger.info(f"{req_ids} received prefill data")

     def _handle_decode(self, payload):
         """
@@ -471,4 +509,4 @@ class SplitwiseConnector:
         )
         req_ids = [task["request_id"] for task in payload]
         self.splitwise_queue.append(("decode", tasks))
-        logger.debug(f"{req_ids} received decode data")
+        self.logger.info(f"{req_ids} received decode data")