[Bug fix] Fix bug for d blocks not enough (#3479)

* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Support batched tokens for EP and fix bug

* Fix bug for memory allocation

* Fix bug for D blocks not enough

* Fix bug when D blocks not enough

* Fix bug when D blocks not enough

* Fix cache message recycle step

* Fix cache message recycle step

* Fix step_idx recycle
This commit is contained in:
chenjian
2025-08-21 11:36:16 +08:00
committed by GitHub
parent c487b62ee0
commit 6854506533
4 changed files with 120 additions and 51 deletions

View File

@@ -252,6 +252,9 @@ class CacheMessager:
self.last_step_idx = -1 self.last_step_idx = -1
self.last_layer_idx = -1 # int32 self.last_layer_idx = -1 # int32
max_step_idx = 100003
engine_recycled_count = 0
while True: while True:
cache_info = self.engine_worker_queue.get_cache_info() cache_info = self.engine_worker_queue.get_cache_info()
@@ -271,7 +274,6 @@ class CacheMessager:
current_info["status"] = "init" current_info["status"] = "init"
logger.info(f"start cache_infos: {current_info}") logger.info(f"start cache_infos: {current_info}")
self.cache_info[info["request_id"]] = current_info self.cache_info[info["request_id"]] = current_info
self.last_step_idx = min(self.last_step_idx, current_info["current_id"])
else: else:
self.cache_info[info["request_id"]] = info self.cache_info[info["request_id"]] = info
prefilled_layer_idx = layer_shm_value.value[0] prefilled_layer_idx = layer_shm_value.value[0]
@@ -287,7 +289,18 @@ class CacheMessager:
if not self.cache_info: if not self.cache_info:
time.sleep(0.001) time.sleep(0.001)
continue continue
logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}") if self.last_step_idx > prefilled_step_idx:
engine_recycled_count += 1
self.last_step_idx = prefilled_step_idx # only copy value read from shm memory
prefilled_step_idx = (
prefilled_step_idx + max_step_idx * engine_recycled_count
) # remap prefilled_step_idx for comparison
logger.debug(
f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx in shm: {self.last_step_idx},"
f"prefilled_step_idx: {prefilled_step_idx} engine_recycled_count {engine_recycled_count}"
)
for req_id, item in list(self.cache_info.items()): for req_id, item in list(self.cache_info.items()):
if "status" not in item: if "status" not in item:
continue continue
@@ -318,7 +331,8 @@ class CacheMessager:
if item["current_id"] < prefilled_step_idx: if item["current_id"] < prefilled_step_idx:
current_layer_idx = self.num_hidden_layers current_layer_idx = self.num_hidden_layers
else: else:
current_layer_idx = prefilled_layer_idx + 1 if item["current_id"] == prefilled_step_idx:
current_layer_idx = prefilled_layer_idx + 1
for layer_idx in range(item["layer_idx"], current_layer_idx): for layer_idx in range(item["layer_idx"], current_layer_idx):
tic = time.time() tic = time.time()
@@ -361,9 +375,7 @@ class CacheMessager:
self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")]) self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
logger.info(f"put write cache {item['request_id']}") logger.info(f"put write cache {item['request_id']}")
del self.cache_info[req_id] del self.cache_info[req_id]
self.last_layer_idx = prefilled_layer_idx
self.last_step_idx = prefilled_step_idx
self.last_layer_idx = prefilled_layer_idx
except Exception as e: except Exception as e:
logger.info(f"prefill layerwise send cache thread has exception: {e}") logger.info(f"prefill layerwise send cache thread has exception: {e}")

View File

@@ -190,6 +190,7 @@ class LLMEngine:
self._init_worker_signals() self._init_worker_signals()
self.data_processor = self.input_processor.create_processor() self.data_processor = self.input_processor.create_processor()
self.response_lock = threading.Lock() # prevent to call send_multipart in zmq concurrently
if api_server_pid is not None: if api_server_pid is not None:
if envs.FD_ENABLE_INTERNAL_ADAPTER: if envs.FD_ENABLE_INTERNAL_ADAPTER:
@@ -201,6 +202,10 @@ class LLMEngine:
else: else:
self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL) self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER) self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
self.recv_result_handle_thread = threading.Thread(
target=self.send_response_server.recv_result_handle, daemon=True
)
self.recv_result_handle_thread.start()
time.sleep(3) time.sleep(3)
self.cfg.init_cache_info() self.cfg.init_cache_info()
@@ -323,8 +328,9 @@ class LLMEngine:
if len(results) == 0: if len(results) == 0:
time.sleep(0.005) time.sleep(0.005)
continue continue
for request_id, contents in results.items(): with self.response_lock:
self.send_response_server.send_response(request_id, contents) for request_id, contents in results.items():
self.send_response_server.send_response(request_id, contents)
except Exception as e: except Exception as e:
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}") llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
@@ -341,7 +347,7 @@ class LLMEngine:
Insert task to engine thread, monitor scheduler request queue. Insert task to engine thread, monitor scheduler request queue.
if the engine has resource, insert task to engine if the engine has resource, insert task to engine
""" """
current_id = -1 current_id = 0
while self.running: while self.running:
try: try:
if self.resource_manager.available_batch() == 0: if self.resource_manager.available_batch() == 0:
@@ -376,12 +382,15 @@ class LLMEngine:
time.sleep(0.001) time.sleep(0.001)
continue continue
current_id = (current_id + 1) % 100003
if self.cfg.splitwise_role != "mixed": if self.cfg.splitwise_role != "mixed":
llm_logger.info("Inserting splitwise tasks") llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks(tasks, current_id) self.split_connector.send_splitwise_tasks(tasks, current_id)
self.insert_tasks(tasks, current_id) insert_successful = self.insert_tasks(tasks, current_id)
if insert_successful:
current_id = current_id + 1
else:
continue
main_process_metrics.num_requests_waiting.dec(len(tasks)) main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks)) main_process_metrics.num_requests_running.inc(len(tasks))
@@ -495,7 +504,7 @@ class LLMEngine:
if failed is None: if failed is None:
main_process_metrics.num_requests_waiting.inc(1) main_process_metrics.num_requests_waiting.inc(1)
continue continue
llm_logger.error(f"request {request_id} insert to scheduler failed: {failed}")
error_result = RequestOutput( error_result = RequestOutput(
request_id=request_id, request_id=request_id,
finished=True, finished=True,
@@ -504,7 +513,8 @@ class LLMEngine:
) )
# Since the request is not in scheduler # Since the request is not in scheduler
# Send result by zmq directly # Send result by zmq directly
self.send_response_server.send_response(request_id, error_result) with self.response_lock:
self.send_response_server.send_response(request_id, [error_result])
except Exception as e: except Exception as e:
llm_logger.error( llm_logger.error(
f"Error happend while receving new request from zmq, details={e}, " f"Error happend while receving new request from zmq, details={e}, "
@@ -821,6 +831,9 @@ class LLMEngine:
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
return True return True
if not isinstance(tasks, list):
tasks = [tasks]
need_delete_tasks = []
for task in tasks: for task in tasks:
start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER) start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
if self.cfg.splitwise_role != "mixed": if self.cfg.splitwise_role != "mixed":
@@ -837,27 +850,29 @@ class LLMEngine:
) )
] ]
) )
tasks.remove(task) need_delete_tasks.append(task)
continue continue
if task.sampling_params.bad_words is not None: if task.sampling_params.bad_words is not None:
task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer) task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer)
self.resource_manager.check_and_free_block_tables() self.resource_manager.check_and_free_block_tables()
if not isinstance(tasks, list): for tmp_task in need_delete_tasks:
tasks = [tasks] tasks.remove(tmp_task)
for item in tasks: for item in tasks:
item.schedule_start_time = time.time() item.schedule_start_time = time.time()
req_ids = [t.request_id for t in tasks]
if len(tasks) == 0:
return False
available_batch = np.sum(self.resource_manager.stop_flags) available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch: if len(tasks) > available_batch:
llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.") llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
llm_logger.error("The exceeded part will be ignored!") llm_logger.error("The exceeded part will be ignored!")
tasks = tasks[:available_batch] tasks = tasks[:available_batch]
req_ids = [t.request_id for t in tasks]
tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks) tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
if not tasks: if not tasks:

View File

@@ -172,7 +172,7 @@ class ExpertService:
Insert task to engine thread, monitor scheduler request queue. Insert task to engine thread, monitor scheduler request queue.
if the engine has resource, insert task to engine if the engine has resource, insert task to engine
""" """
current_id = -1 current_id = 0
while True: while True:
try: try:
if self.resource_manager.available_batch() == 0: if self.resource_manager.available_batch() == 0:
@@ -205,9 +205,11 @@ class ExpertService:
self.llm_logger.info("Inserting splitwise tasks") self.llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks(tasks, current_id) self.split_connector.send_splitwise_tasks(tasks, current_id)
current_id = (current_id + 1) % 100003 insert_successful = self.insert_tasks(tasks, current_id)
if insert_successful:
self.insert_tasks(tasks, current_id) current_id = current_id + 1
else:
continue
main_process_metrics.num_requests_waiting.dec(len(tasks)) main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks)) main_process_metrics.num_requests_running.inc(len(tasks))
@@ -328,6 +330,7 @@ class ExpertService:
if not isinstance(tasks, list): if not isinstance(tasks, list):
tasks = [tasks] tasks = [tasks]
need_delete_tasks = []
for task in tasks: for task in tasks:
if self.cfg.splitwise_role != "mixed": if self.cfg.splitwise_role != "mixed":
status, msg = self.split_connector.check_decode_allocated(task) status, msg = self.split_connector.check_decode_allocated(task)
@@ -343,10 +346,19 @@ class ExpertService:
) )
] ]
) )
tasks.remove(task) need_delete_tasks.append(task)
continue continue
task.schedule_start_time = time.time()
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
for item in tasks:
item.schedule_start_time = time.time()
req_ids = [t.request_id for t in tasks]
if len(tasks) == 0:
return False
available_batch = np.sum(self.resource_manager.stop_flags) available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch: if len(tasks) > available_batch:
self.llm_logger.error( self.llm_logger.error(
@@ -355,8 +367,6 @@ class ExpertService:
self.llm_logger.error("The exceeded part will be ignored!") self.llm_logger.error("The exceeded part will be ignored!")
tasks = tasks[:available_batch] tasks = tasks[:available_batch]
req_ids = [t.request_id for t in tasks]
tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks) tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
if not tasks: if not tasks:

View File

@@ -18,6 +18,7 @@ import os
import threading import threading
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict
import msgpack import msgpack
import zmq import zmq
@@ -32,7 +33,8 @@ class ZmqServerBase(ABC):
""" """
def __init__(self): def __init__(self):
pass self.cached_results = defaultdict(list)
self.response_token_lock = threading.Lock()
@abstractmethod @abstractmethod
def _create_socket(self): def _create_socket(self):
@@ -89,6 +91,21 @@ class ZmqServerBase(ABC):
llm_logger.warning(f"{e}") llm_logger.warning(f"{e}")
return str(e), None return str(e), None
def recv_result_handle(self):
while True:
try:
with self.response_token_lock:
client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK)
req_id_str = request_id.decode("utf-8")
with self.mutex:
self.req_dict[req_id_str] = client
except zmq.Again:
time.sleep(0.001)
continue
except Exception as e:
llm_logger.error(f"recv_result_handle get unknown exception: {e}")
continue
def send_response(self, req_id, data): def send_response(self, req_id, data):
""" """
Send generated token result to client. Send generated token result to client.
@@ -96,36 +113,46 @@ class ZmqServerBase(ABC):
self._ensure_socket() self._ensure_socket()
if self.socket is None: if self.socket is None:
raise RuntimeError("Router socket not created. Call create_router() first.") raise RuntimeError("Router socket not created. Call create_router() first.")
new_data = []
while self.running: has_result_handle = False
with self.mutex: with self.mutex:
if req_id not in self.req_dict: if req_id not in self.req_dict:
try: self.cached_results[req_id].append(data)
client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK)
req_id_str = request_id.decode("utf-8")
self.req_dict[req_id_str] = client
except zmq.Again:
time.sleep(0.001)
continue
else:
break
try:
start_send = time.time()
if self.aggregate_send:
result = self.pack_aggregated_data(data)
else: else:
result = msgpack.packb([response.to_dict() for response in data]) has_result_handle = True
self.socket.send_multipart([self.req_dict[req_id], b"", result]) if req_id in self.cached_results:
llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}") for history_data in self.cached_results[req_id]:
new_data.extend(history_data)
llm_logger.info(
f"get request {req_id} result handle after cached result, total cached length {len(self.cached_results[req_id])}"
)
del self.cached_results[req_id]
if has_result_handle:
try:
new_data.extend(data)
start_send = time.time()
if self.aggregate_send:
result = self.pack_aggregated_data(new_data)
else:
result = msgpack.packb([response.to_dict() for response in new_data])
with self.response_token_lock:
self.socket.send_multipart([self.req_dict[req_id], b"", result])
llm_logger.debug(
f"send_multipart result: {req_id} len {len(new_data)} elapse: {time.time()-start_send}"
)
except Exception as e: except Exception as e:
llm_logger.error(f"Send result to zmq client failed: {e}") llm_logger.error(f"Send result to zmq client failed: {e}")
if data[-1].finished: if data[-1].finished:
with self.mutex: with self.mutex:
if req_id not in self.req_dict:
llm_logger.warning(f"req_id {req_id} finished but no result handle, drop it")
if req_id in self.cached_results:
del self.cached_results[req_id]
else:
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
self.req_dict.pop(req_id, None) self.req_dict.pop(req_id, None)
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
@abstractmethod @abstractmethod
def close(self): def close(self):
@@ -143,6 +170,7 @@ class ZmqIpcServer(ZmqServerBase):
def __init__(self, name, mode): def __init__(self, name, mode):
self.name = name self.name = name
self.mode = mode self.mode = mode
self.cached_results = defaultdict(list)
if mode == zmq.PULL: if mode == zmq.PULL:
self.file_name = f"/dev/shm/{name}.socket" self.file_name = f"/dev/shm/{name}.socket"
elif mode == zmq.ROUTER: elif mode == zmq.ROUTER:
@@ -150,6 +178,7 @@ class ZmqIpcServer(ZmqServerBase):
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
self.mutex = threading.Lock() self.mutex = threading.Lock()
self.response_token_lock = threading.Lock()
self.req_dict = dict() self.req_dict = dict()
self.running = True self.running = True
self.context = zmq.Context() self.context = zmq.Context()
@@ -201,6 +230,7 @@ class ZmqTcpServer(ZmqServerBase):
def __init__(self, port, mode): def __init__(self, port, mode):
self.mode = mode self.mode = mode
self.port = port self.port = port
self.cached_results = defaultdict(list)
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
@@ -209,6 +239,8 @@ class ZmqTcpServer(ZmqServerBase):
self.running = True self.running = True
self.context = zmq.Context() self.context = zmq.Context()
self._create_socket() self._create_socket()
self.mutex = threading.Lock()
self.response_token_lock = threading.Lock()
def _create_socket(self): def _create_socket(self):
"""create and return a ZeroMQ socket.""" """create and return a ZeroMQ socket."""