[Bug fix] Test td cache messager (#3242)

* support disable cache task in decode node

* fix bugs

* Update engine.py

* Update expert_service.py

* Update splitwise_connector.py

* Optimize log for debug

* Optimize log for debug

* fix bug

---------

Co-authored-by: ltd0924 <ltd0924@sina.com>
Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
This commit is contained in:
chenjian
2025-08-06 15:52:45 +08:00
committed by GitHub
parent a4572a5e5d
commit 110f33a530
5 changed files with 144 additions and 57 deletions

View File

@@ -17,8 +17,9 @@
import argparse import argparse
import json import json
import math import math
import time
import threading import threading
import time
import numpy as np import numpy as np
import paddle import paddle
@@ -196,7 +197,9 @@ class CacheMessager:
self.gpu_id = gpu_id self.gpu_id = gpu_id
self.cache_info = dict() self.cache_info = dict()
self.rank_id = self.rank + local_data_parallel_id * self.nranks # align with engine worker rank (paddle.distributed.launch) self.rank_id = (
self.rank + local_data_parallel_id * self.nranks
) # align with engine worker rank (paddle.distributed.launch)
connect_rdma_thread = threading.Thread(target=self._handle_connect_task) connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
connect_rdma_thread.daemon = True connect_rdma_thread.daemon = True
@@ -284,7 +287,7 @@ class CacheMessager:
if not self.cache_info: if not self.cache_info:
time.sleep(0.001) time.sleep(0.001)
continue continue
logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}") logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
for req_id, item in list(self.cache_info.items()): for req_id, item in list(self.cache_info.items()):
if "status" not in item: if "status" not in item:
continue continue
@@ -465,7 +468,8 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
logger = get_logger("cache_messager", "cache_messager.log") rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")
logger.info("create cache messager...") logger.info("create cache messager...")
logger.info(f"{args}") logger.info(f"{args}")

View File

@@ -113,6 +113,8 @@ class LLMEngine:
self.start_queue_service() self.start_queue_service()
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
if envs.ENABLE_V1_KVCACHE_SCHEDULER: if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.resource_manager = ResourceManagerV1( self.resource_manager = ResourceManagerV1(
cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
@@ -630,9 +632,13 @@ class LLMEngine:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task]) self.insert_tasks([task])
else: else:
if not self.enable_decode_cache_task:
task.error_msg = "Not enough resources"
new_waiting.append(task) new_waiting.append(task)
if new_waiting: if new_waiting:
if not self.enable_decode_cache_task:
self.split_connector.send_cache_infos(new_waiting, -1)
else:
self.waiting_requests.extend(new_waiting) self.waiting_requests.extend(new_waiting)
llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue") llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
@@ -805,6 +811,22 @@ class LLMEngine:
for task in tasks: for task in tasks:
start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER) start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
if self.cfg.splitwise_role != "mixed":
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
tasks.remove(task)
continue
if task.sampling_params.bad_words is not None: if task.sampling_params.bad_words is not None:
task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer) task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
@@ -1020,7 +1042,6 @@ class LLMEngine:
except Exception as e: except Exception as e:
print(f"Error extracting sub services: {e}") print(f"Error extracting sub services: {e}")
for worker_queue in self.engine_worker_queue_server: for worker_queue in self.engine_worker_queue_server:
worker_queue.cleanup() worker_queue.cleanup()
if hasattr(self, "send_response_server") and self.send_response_server is not None: if hasattr(self, "send_response_server") and self.send_response_server is not None:

View File

@@ -26,6 +26,7 @@ from collections import deque
import numpy as np import numpy as np
from fastdeploy.engine.request import RequestOutput
from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.inter_communicator import EngineWorkerQueue from fastdeploy.inter_communicator import EngineWorkerQueue
from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.metrics.metrics import main_process_metrics
@@ -34,6 +35,7 @@ from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger
class ExpertService: class ExpertService:
""" """
Engine class responsible for managing the Large Language Model (LLM) operations. Engine class responsible for managing the Large Language Model (LLM) operations.
@@ -146,7 +148,7 @@ class ExpertService:
# Start TokenProcessor thread # Start TokenProcessor thread
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port)) os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK
self.token_processor.run() self.token_processor.run()
self.cfg.init_cache_info() self.cfg.init_cache_info()
@@ -262,9 +264,13 @@ class ExpertService:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task]) self.insert_tasks([task])
else: else:
if not self.enable_decode_cache_task:
task.error_msg = "Not enough resources"
new_waiting.append(task) new_waiting.append(task)
if new_waiting: if new_waiting:
if not self.enable_decode_cache_task:
self.split_connector.send_cache_infos(new_waiting, -1)
else:
self.waiting_requests.extend(new_waiting) self.waiting_requests.extend(new_waiting)
self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue") self.llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
@@ -310,8 +316,24 @@ class ExpertService:
if not isinstance(tasks, list): if not isinstance(tasks, list):
tasks = [tasks] tasks = [tasks]
for item in tasks: for task in tasks:
item.schedule_start_time = time.time() if self.cfg.splitwise_role != "mixed":
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
tasks.remove(task)
continue
task.schedule_start_time = time.time()
available_batch = np.sum(self.resource_manager.stop_flags) available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch: if len(tasks) > available_batch:

View File

@@ -90,6 +90,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"), "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
# Whether to use PLUGINS. # Whether to use PLUGINS.
"FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","), "FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","),
# Whether to enable cache task in decode node
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"),
} }

View File

@@ -26,8 +26,6 @@ from fastdeploy.engine.request import CompletionOutput, Request, RequestOutput
from fastdeploy.inter_communicator import EngineWorkerQueue from fastdeploy.inter_communicator import EngineWorkerQueue
from fastdeploy.utils import get_logger from fastdeploy.utils import get_logger
logger = get_logger("splitwise_connector", "splitwise_connector.log")
class SplitwiseConnector: class SplitwiseConnector:
""" """
@@ -45,6 +43,12 @@ class SplitwiseConnector:
resource_manager (object): Resource manager object. resource_manager (object): Resource manager object.
""" """
self.cfg = cfg self.cfg = cfg
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
self.logger = get_logger(
"splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log"
)
else:
self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
self.scheduler = scheduler self.scheduler = scheduler
self.engine_worker_queue = worker_queue self.engine_worker_queue = worker_queue
self.resource_manager = resource_manager self.resource_manager = resource_manager
@@ -52,6 +56,7 @@ class SplitwiseConnector:
self.temp_cache_info = dict() self.temp_cache_info = dict()
self.current_request_ids = dict() self.current_request_ids = dict()
self.splitwise_queue = splitwise_queue self.splitwise_queue = splitwise_queue
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
if self.cfg.cache_config.pd_comm_port is not None: if self.cfg.cache_config.pd_comm_port is not None:
self.zmq_ctx = zmq.Context() self.zmq_ctx = zmq.Context()
@@ -70,7 +75,7 @@ class SplitwiseConnector:
self.router_socket.setsockopt(zmq.SNDHWM, 1000) self.router_socket.setsockopt(zmq.SNDHWM, 1000)
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1) self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}") self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
logger.info(f"bind {self.cfg.cache_config.pd_comm_port}") self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port[0]}")
self.poller = zmq.Poller() self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN) self.poller.register(self.router_socket, zmq.POLLIN)
@@ -89,16 +94,16 @@ class SplitwiseConnector:
if not socks: if not socks:
continue continue
else: else:
logger.debug(f"receive {socks}") self.logger.debug(f"receive {socks}")
frames = self.router_socket.recv_multipart() frames = self.router_socket.recv_multipart()
logger.debug(f"frames: {frames}") self.logger.debug(f"frames: {frames}")
message = frames[-1] message = frames[-1]
self.io_executor.submit(self._process_message, message) self.io_executor.submit(self._process_message, message)
time.sleep(0.001) time.sleep(0.001)
except Exception as e: except Exception as e:
logger.error(f"Receiver error: {e}") self.logger.error(f"Receiver error: {e}")
time.sleep(1) time.sleep(1)
def _get_push_socket(self, addr): def _get_push_socket(self, addr):
@@ -110,7 +115,7 @@ class SplitwiseConnector:
return sock return sock
try: try:
logger.info(f"Establishing new connection to {addr}") self.logger.info(f"Establishing new connection to {addr}")
sock = self.zmq_ctx.socket(zmq.DEALER) sock = self.zmq_ctx.socket(zmq.DEALER)
# 设置连接参数 # 设置连接参数
@@ -129,7 +134,7 @@ class SplitwiseConnector:
return sock return sock
except zmq.ZMQError as e: except zmq.ZMQError as e:
logger.error(f"Connection to {addr} failed: {e}") self.logger.error(f"Connection to {addr} failed: {e}")
raise ConnectionError(f"Failed to connect to {addr}") from e raise ConnectionError(f"Failed to connect to {addr}") from e
@@ -138,7 +143,7 @@ class SplitwiseConnector:
return return
try: try:
logger.info(f"Sent {msg_type} to {addr}") self.logger.info(f"Sent {msg_type} to {addr}")
message = self._serialize_message(msg_type, payload) message = self._serialize_message(msg_type, payload)
try: try:
@@ -146,18 +151,18 @@ class SplitwiseConnector:
sock = self._get_push_socket(addr) sock = self._get_push_socket(addr)
sock.send_multipart([b"", message]) sock.send_multipart([b"", message])
logger.info(f"Sent {msg_type} to {addr}") self.logger.info(f"Sent {msg_type} to {addr}")
except ConnectionError: except ConnectionError:
logger.warning(f"Connection to {addr} not established") self.logger.warning(f"Connection to {addr} not established")
except zmq.Again: except zmq.Again:
logger.warning(f"Send queue full for {addr}") self.logger.warning(f"Send queue full for {addr}")
except Exception as e: except Exception as e:
logger.error(f"Send to {addr} failed: {e}") self.logger.error(f"Send to {addr} failed: {e}")
self._close_connection(addr) self._close_connection(addr)
except Exception as e: except Exception as e:
logger.error(f"Message preparation failed: {e}") self.logger.error(f"Message preparation failed: {e}")
def _close_connection(self, addr): def _close_connection(self, addr):
""" """
@@ -262,7 +267,7 @@ class SplitwiseConnector:
f"{task.disaggregate_info['cache_info']['rdma']['ip']}:" f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}" + f"{task.disaggregate_info['cache_info']['rdma']['port']}"
) )
logger.info(f"send splitwise tasks to port {addr} decode") self.logger.info(f"send splitwise tasks to port {addr} decode")
self.current_request_ids[task.request_id] = "init" self.current_request_ids[task.request_id] = "init"
decode_diagg = task.disaggregate_info["cache_info"] decode_diagg = task.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"] task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
@@ -290,7 +295,7 @@ class SplitwiseConnector:
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks)) self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
for task in tasks: for task in tasks:
task.disaggregate_info["cache_info"]["ipc"]["port"] = port task.disaggregate_info["cache_info"]["ipc"]["port"] = port
logger.info(f"send splitwise tasks to port {port} decode") self.logger.info(f"send splitwise tasks to port {port} decode")
current_port = port current_port = port
return current_port return current_port
@@ -300,7 +305,7 @@ class SplitwiseConnector:
""" """
if not isinstance(tasks_list, list): if not isinstance(tasks_list, list):
tasks_list = [tasks_list] tasks_list = [tasks_list]
logger.info("send first token to port decode") self.logger.info("send first token to port decode")
if prefill_msg["transfer_protocol"] == "ipc": if prefill_msg["transfer_protocol"] == "ipc":
port = prefill_msg["cache_info"]["ipc"]["port"] port = prefill_msg["cache_info"]["ipc"]["port"]
if port not in self.connect_innode_instances: if port not in self.connect_innode_instances:
@@ -308,7 +313,7 @@ class SplitwiseConnector:
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list)) self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
else: else:
node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}" node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
logger.info(f"send first token to port {node} decode") self.logger.info(f"send first token to port {node} decode")
self._send_message(node, "decode", tasks_list) self._send_message(node, "decode", tasks_list)
def create_connection(self, port): def create_connection(self, port):
@@ -324,6 +329,22 @@ class SplitwiseConnector:
client_id=0, client_id=0,
) )
def check_decode_allocated(self, task):
if task.disaggregate_info is None:
return True, ""
if self.enable_decode_cache_task:
return True, ""
if task.disaggregate_info["role"] != "prefill":
return True, ""
while self.current_request_ids[task.request_id] == "init":
time.sleep(0.001)
msg = self.current_request_ids[task.request_id]
del self.current_request_ids[task.request_id]
if msg == "finished":
return True, ""
self.logger.error(f"Receive_decode_allocated error: {msg}")
return False, msg
def send_cache_infos(self, tasks, current_id): def send_cache_infos(self, tasks, current_id):
""" """
Send cache information to specific port. Send cache information to specific port.
@@ -340,9 +361,15 @@ class SplitwiseConnector:
for i in range(len(tasks)): for i in range(len(tasks)):
if tasks[i].disaggregate_info is None: if tasks[i].disaggregate_info is None:
continue continue
logger.info(f"{tasks[i].disaggregate_info}") self.logger.info(f"{tasks[i].disaggregate_info}")
if tasks[i].disaggregate_info["role"] == "decode": if tasks[i].disaggregate_info["role"] == "decode":
if tasks[i].disaggregate_info["transfer_protocol"] == "ipc": if tasks[i].disaggregate_info["transfer_protocol"] == "ipc":
if tasks[i].get("error_msg", None) is not None:
cache_info = {
"request_id": tasks[i].request_id,
"error_msg": tasks[i].get("error_msg"),
}
else:
cache_info = { cache_info = {
"request_id": tasks[i].request_id, "request_id": tasks[i].request_id,
"device_ids": self.cfg.device_ids.split(","), "device_ids": self.cfg.device_ids.split(","),
@@ -357,6 +384,12 @@ class SplitwiseConnector:
f"{tasks[i].disaggregate_info['cache_info']['rdma']['ip']}:" f"{tasks[i].disaggregate_info['cache_info']['rdma']['ip']}:"
+ f"{tasks[i].disaggregate_info['cache_info']['rdma']['port']}" + f"{tasks[i].disaggregate_info['cache_info']['rdma']['port']}"
) )
if tasks[i].get("error_msg", None) is not None:
cache_info = {
"request_id": tasks[i].request_id,
"error_msg": tasks[i].get("error_msg"),
}
else:
cache_info = { cache_info = {
"request_id": tasks[i].request_id, "request_id": tasks[i].request_id,
"device_ids": self.cfg.device_ids.split(","), "device_ids": self.cfg.device_ids.split(","),
@@ -391,7 +424,7 @@ class SplitwiseConnector:
else: else:
if len(temp_cache_info): if len(temp_cache_info):
for k, v in temp_cache_info.items(): for k, v in temp_cache_info.items():
logger.info(f"{k} {v}") self.logger.info(f"{k} {v}")
if ":" in str(k): if ":" in str(k):
self._send_message(k, "cache_sync", v) self._send_message(k, "cache_sync", v)
else: else:
@@ -408,7 +441,7 @@ class SplitwiseConnector:
payload = [output.to_dict() for output in payload] payload = [output.to_dict() for output in payload]
req_ids = [task["request_id"] for task in payload] req_ids = [task["request_id"] for task in payload]
logger.info(f"send message {msg_type} {req_ids}") self.logger.info(f"send message {msg_type} {req_ids}")
json_data = msgpack.packb({"type": msg_type, "payload": payload}) json_data = msgpack.packb({"type": msg_type, "payload": payload})
@@ -419,7 +452,7 @@ class SplitwiseConnector:
# JSON反序列化 # JSON反序列化
message = msgpack.unpackb(data) message = msgpack.unpackb(data)
req_ids = [task["request_id"] for task in message["payload"]] req_ids = [task["request_id"] for task in message["payload"]]
logger.info(f"send message {message['type']} {req_ids}") self.logger.info(f"recv message type {message['type']} for {req_ids}")
return message["type"], message["payload"] return message["type"], message["payload"]
def _process_message(self, message: bytes): def _process_message(self, message: bytes):
@@ -428,7 +461,7 @@ class SplitwiseConnector:
""" """
try: try:
msg_type, payload = self._deserialize_message(message) msg_type, payload = self._deserialize_message(message)
logger.info(f"{msg_type}") self.logger.info(f"{msg_type}")
if msg_type == "prefill": if msg_type == "prefill":
self._handle_prefill(payload) self._handle_prefill(payload)
@@ -436,11 +469,16 @@ class SplitwiseConnector:
self._handle_decode(payload) self._handle_decode(payload)
elif msg_type == "cache_sync": elif msg_type == "cache_sync":
for task in payload: for task in payload:
self.logger.info(f"cache_sync task: {task}")
current_status = task.get("error_msg", "finished")
self.current_request_ids[task["request_id"]] = current_status
if self.enable_decode_cache_task:
del self.current_request_ids[task["request_id"]] del self.current_request_ids[task["request_id"]]
if current_status == "finished":
self.engine_worker_queue.put_cache_info(payload) self.engine_worker_queue.put_cache_info(payload)
except Exception as e: except Exception as e:
logger.error(f"Message processing failed: {e}") self.logger.error(f"Message processing failed: {e}")
def _handle_prefill(self, tasks): def _handle_prefill(self, tasks):
""" """
@@ -450,7 +488,7 @@ class SplitwiseConnector:
tasks_data = [Request.from_dict(task) for task in tasks] tasks_data = [Request.from_dict(task) for task in tasks]
req_ids = [task["request_id"] for task in tasks] req_ids = [task["request_id"] for task in tasks]
self.splitwise_queue.append(("decode", tasks_data)) self.splitwise_queue.append(("decode", tasks_data))
logger.debug(f"{req_ids} received prefill data") self.logger.info(f"{req_ids} received prefill data")
def _handle_decode(self, payload): def _handle_decode(self, payload):
""" """
@@ -471,4 +509,4 @@ class SplitwiseConnector:
) )
req_ids = [task["request_id"] for task in payload] req_ids = [task["request_id"] for task in payload]
self.splitwise_queue.append(("decode", tasks)) self.splitwise_queue.append(("decode", tasks))
logger.debug(f"{req_ids} received decode data") self.logger.info(f"{req_ids} received decode data")