diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 1c3462913..a089ed01a 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -252,6 +252,9 @@ class CacheMessager: self.last_step_idx = -1 self.last_layer_idx = -1 # int32 + max_step_idx = 100003 + engine_recycled_count = 0 + while True: cache_info = self.engine_worker_queue.get_cache_info() @@ -271,7 +274,6 @@ class CacheMessager: current_info["status"] = "init" logger.info(f"start cache_infos: {current_info}") self.cache_info[info["request_id"]] = current_info - self.last_step_idx = min(self.last_step_idx, current_info["current_id"]) else: self.cache_info[info["request_id"]] = info prefilled_layer_idx = layer_shm_value.value[0] @@ -287,7 +289,18 @@ class CacheMessager: if not self.cache_info: time.sleep(0.001) continue - logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}") + if self.last_step_idx > prefilled_step_idx: + engine_recycled_count += 1 + self.last_step_idx = prefilled_step_idx # only copy value read from shm memory + prefilled_step_idx = ( + prefilled_step_idx + max_step_idx * engine_recycled_count + ) # remap prefilled_step_idx for comparison + + logger.debug( + f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx in shm: {self.last_step_idx}," + f"prefilled_step_idx: {prefilled_step_idx} engine_recycled_count {engine_recycled_count}" + ) + for req_id, item in list(self.cache_info.items()): if "status" not in item: continue @@ -318,7 +331,8 @@ class CacheMessager: if item["current_id"] < prefilled_step_idx: current_layer_idx = self.num_hidden_layers else: - current_layer_idx = prefilled_layer_idx + 1 + if item["current_id"] == prefilled_step_idx: + current_layer_idx = prefilled_layer_idx + 1 for layer_idx in range(item["layer_idx"], current_layer_idx): tic = time.time() @@ -361,9 +375,7 @@ class CacheMessager: self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")]) logger.info(f"put write cache {item['request_id']}") del self.cache_info[req_id] - - self.last_step_idx = prefilled_step_idx - self.last_layer_idx = prefilled_layer_idx + self.last_layer_idx = prefilled_layer_idx except Exception as e: logger.info(f"prefill layerwise send cache thread has exception: {e}") diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 8fb9858b6..8d34bade0 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -190,6 +190,7 @@ class LLMEngine: self._init_worker_signals() self.data_processor = self.input_processor.create_processor() + self.response_lock = threading.Lock() # prevent to call send_multipart in zmq concurrently if api_server_pid is not None: if envs.FD_ENABLE_INTERNAL_ADAPTER: @@ -201,6 +202,10 @@ class LLMEngine: else: self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL) self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER) + self.recv_result_handle_thread = threading.Thread( + target=self.send_response_server.recv_result_handle, daemon=True + ) + self.recv_result_handle_thread.start() time.sleep(3) self.cfg.init_cache_info() @@ -323,8 +328,9 @@ class LLMEngine: if len(results) == 0: time.sleep(0.005) continue - for request_id, contents in results.items(): - self.send_response_server.send_response(request_id, contents) + with self.response_lock: + for request_id, contents in results.items(): + self.send_response_server.send_response(request_id, contents) except Exception as e: llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}") @@ -341,7 +347,7 @@ class LLMEngine: Insert task to engine thread, monitor scheduler request queue. if the engine has resource, insert task to engine """ - current_id = -1 + current_id = 0 while self.running: try: if self.resource_manager.available_batch() == 0: @@ -376,12 +382,15 @@ class LLMEngine: time.sleep(0.001) continue - current_id = (current_id + 1) % 100003 if self.cfg.splitwise_role != "mixed": llm_logger.info("Inserting splitwise tasks") self.split_connector.send_splitwise_tasks(tasks, current_id) - self.insert_tasks(tasks, current_id) + insert_successful = self.insert_tasks(tasks, current_id) + if insert_successful: + current_id = current_id + 1 + else: + continue main_process_metrics.num_requests_waiting.dec(len(tasks)) main_process_metrics.num_requests_running.inc(len(tasks)) @@ -495,7 +504,7 @@ class LLMEngine: if failed is None: main_process_metrics.num_requests_waiting.inc(1) continue - + llm_logger.error(f"request {request_id} insert to scheduler failed: {failed}") error_result = RequestOutput( request_id=request_id, finished=True, @@ -504,7 +513,8 @@ class LLMEngine: ) # Since the request is not in scheduler # Send result by zmq directly - self.send_response_server.send_response(request_id, error_result) + with self.response_lock: + self.send_response_server.send_response(request_id, [error_result]) except Exception as e: llm_logger.error( f"Error happend while receving new request from zmq, details={e}, " @@ -821,6 +831,9 @@ class LLMEngine: self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz)) return True + if not isinstance(tasks, list): + tasks = [tasks] + need_delete_tasks = [] for task in tasks: start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER) if self.cfg.splitwise_role != "mixed": @@ -837,27 +850,29 @@ class LLMEngine: ) ] ) - tasks.remove(task) + need_delete_tasks.append(task) continue if task.sampling_params.bad_words is not None: task.sampling_params.update_from_tokenizer(self.data_processor.tokenizer) self.resource_manager.check_and_free_block_tables() - if not isinstance(tasks, list): - tasks = [tasks] + for tmp_task in need_delete_tasks: + tasks.remove(tmp_task) for item in tasks: item.schedule_start_time = time.time() + req_ids = [t.request_id for t in tasks] + + if len(tasks) == 0: + return False available_batch = np.sum(self.resource_manager.stop_flags) if len(tasks) > available_batch: llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.") llm_logger.error("The exceeded part will be ignored!") tasks = tasks[:available_batch] - req_ids = [t.request_id for t in tasks] - tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks) if not tasks: diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 048b9e7d3..c23ba2f58 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -172,7 +172,7 @@ class ExpertService: Insert task to engine thread, monitor scheduler request queue. if the engine has resource, insert task to engine """ - current_id = -1 + current_id = 0 while True: try: if self.resource_manager.available_batch() == 0: @@ -205,9 +205,11 @@ class ExpertService: self.llm_logger.info("Inserting splitwise tasks") self.split_connector.send_splitwise_tasks(tasks, current_id) - current_id = (current_id + 1) % 100003 - - self.insert_tasks(tasks, current_id) + insert_successful = self.insert_tasks(tasks, current_id) + if insert_successful: + current_id = current_id + 1 + else: + continue main_process_metrics.num_requests_waiting.dec(len(tasks)) main_process_metrics.num_requests_running.inc(len(tasks)) @@ -328,6 +330,7 @@ class ExpertService: if not isinstance(tasks, list): tasks = [tasks] + need_delete_tasks = [] for task in tasks: if self.cfg.splitwise_role != "mixed": status, msg = self.split_connector.check_decode_allocated(task) @@ -343,10 +346,19 @@ class ExpertService: ) ] ) - tasks.remove(task) + need_delete_tasks.append(task) continue - task.schedule_start_time = time.time() + for tmp_task in need_delete_tasks: + tasks.remove(tmp_task) + + for item in tasks: + item.schedule_start_time = time.time() + + req_ids = [t.request_id for t in tasks] + + if len(tasks) == 0: + return False available_batch = np.sum(self.resource_manager.stop_flags) if len(tasks) > available_batch: self.llm_logger.error( @@ -355,8 +367,6 @@ class ExpertService: self.llm_logger.error("The exceeded part will be ignored!") tasks = tasks[:available_batch] - req_ids = [t.request_id for t in tasks] - tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks) if not tasks: diff --git a/fastdeploy/inter_communicator/zmq_server.py b/fastdeploy/inter_communicator/zmq_server.py index 32b148755..ab97e3bbd 100644 --- a/fastdeploy/inter_communicator/zmq_server.py +++ b/fastdeploy/inter_communicator/zmq_server.py @@ -18,6 +18,7 @@ import os import threading import time from abc import ABC, abstractmethod +from collections import defaultdict import msgpack import zmq @@ -32,7 +33,8 @@ class ZmqServerBase(ABC): """ def __init__(self): - pass + self.cached_results = defaultdict(list) + self.response_token_lock = threading.Lock() @abstractmethod def _create_socket(self): @@ -89,6 +91,21 @@ class ZmqServerBase(ABC): llm_logger.warning(f"{e}") return str(e), None + def recv_result_handle(self): + while True: + try: + with self.response_token_lock: + client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK) + req_id_str = request_id.decode("utf-8") + with self.mutex: + self.req_dict[req_id_str] = client + except zmq.Again: + time.sleep(0.001) + continue + except Exception as e: + llm_logger.error(f"recv_result_handle get unknown exception: {e}") + continue + def send_response(self, req_id, data): """ Send generated token result to client. @@ -96,36 +113,46 @@ class ZmqServerBase(ABC): self._ensure_socket() if self.socket is None: raise RuntimeError("Router socket not created. Call create_router() first.") - - while self.running: - with self.mutex: - if req_id not in self.req_dict: - try: - client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK) - req_id_str = request_id.decode("utf-8") - self.req_dict[req_id_str] = client - except zmq.Again: - time.sleep(0.001) - continue - else: - break - - try: - start_send = time.time() - if self.aggregate_send: - result = self.pack_aggregated_data(data) + new_data = [] + has_result_handle = False + with self.mutex: + if req_id not in self.req_dict: + self.cached_results[req_id].append(data) else: - result = msgpack.packb([response.to_dict() for response in data]) - self.socket.send_multipart([self.req_dict[req_id], b"", result]) - llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}") + has_result_handle = True + if req_id in self.cached_results: + for history_data in self.cached_results[req_id]: + new_data.extend(history_data) + llm_logger.info( + f"get request {req_id} result handle after cached result, total cached length {len(self.cached_results[req_id])}" + ) + del self.cached_results[req_id] + if has_result_handle: + try: + new_data.extend(data) + start_send = time.time() + if self.aggregate_send: + result = self.pack_aggregated_data(new_data) + else: + result = msgpack.packb([response.to_dict() for response in new_data]) + with self.response_token_lock: + self.socket.send_multipart([self.req_dict[req_id], b"", result]) + llm_logger.debug( + f"send_multipart result: {req_id} len {len(new_data)} elapse: {time.time()-start_send}" + ) - except Exception as e: - llm_logger.error(f"Send result to zmq client failed: {e}") + except Exception as e: + llm_logger.error(f"Send result to zmq client failed: {e}") if data[-1].finished: with self.mutex: + if req_id not in self.req_dict: + llm_logger.warning(f"req_id {req_id} finished but no result handle, drop it") + if req_id in self.cached_results: + del self.cached_results[req_id] + else: + llm_logger.info(f"send_multipart finished, req_id: {req_id}") self.req_dict.pop(req_id, None) - llm_logger.info(f"send_multipart finished, req_id: {req_id}") @abstractmethod def close(self): @@ -143,6 +170,7 @@ class ZmqIpcServer(ZmqServerBase): def __init__(self, name, mode): self.name = name self.mode = mode + self.cached_results = defaultdict(list) if mode == zmq.PULL: self.file_name = f"/dev/shm/{name}.socket" elif mode == zmq.ROUTER: @@ -150,6 +178,7 @@ class ZmqIpcServer(ZmqServerBase): self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) self.aggregate_send = envs.FD_USE_AGGREGATE_SEND self.mutex = threading.Lock() + self.response_token_lock = threading.Lock() self.req_dict = dict() self.running = True self.context = zmq.Context() @@ -201,6 +230,7 @@ class ZmqTcpServer(ZmqServerBase): def __init__(self, port, mode): self.mode = mode self.port = port + self.cached_results = defaultdict(list) self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM) self.aggregate_send = envs.FD_USE_AGGREGATE_SEND @@ -209,6 +239,8 @@ class ZmqTcpServer(ZmqServerBase): self.running = True self.context = zmq.Context() self._create_socket() + self.mutex = threading.Lock() + self.response_token_lock = threading.Lock() def _create_socket(self): """create and return a ZeroMQ socket."""