[Optimization] Improve perf for fd response token with internal adapter (#4992)

* [Optimize] Improve perf for fd response token with internal adapter

* fix

* fix bug

* fix ci

* fix ci

* fix ci

* fix ci
chenjian
2025-11-21 19:02:03 +08:00
committed by GitHub
parent 5bcf79d780
commit 3ea1b44a58
15 changed files with 202 additions and 67 deletions

View File

@@ -958,7 +958,10 @@ class EngineService:
)
# Since the request is not in scheduler
# Send result by zmq directly
self.send_response_server.send_response(request_id, [error_result])
if envs.FD_ENABLE_INTERNAL_ADAPTER:
self.send_response_server.send_response(None, [[error_result]])
else:
self.send_response_server.send_response(request_id, [error_result])
def _decode_token(self, token_ids, req_id, is_end):
delta_text = ""
@@ -984,33 +987,67 @@ class EngineService:
if len(results) == 0:
time.sleep(0.005)
continue
for request_id, contents in results.items():
    new_contents = []
    for content in contents:
        if isinstance(content, RequestOutput) and content.outputs is not None:
            decode_type = content.outputs.decode_type
            delta_text = ""
            if decode_type == 0:
                delta_text, token_ids = self._decode_token(
                    token_ids=content.outputs.token_ids, req_id=request_id, is_end=content.finished
                )
            else:
                token_ids = content.outputs.token_ids
            if len(token_ids):
                content.outputs.token_ids = token_ids
                content.outputs.text = delta_text
                new_contents.append(content)
            elif content.finished:
                new_contents.append(content)
            else:
                llm_logger.warning(
                    f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
                )
        else:
            new_contents.append(content)
    if len(new_contents):
        llm_logger.debug(f"Send response for request id: {request_id}, {new_contents}")
        self.send_response_server.send_response(request_id, new_contents)
if envs.FD_ENABLE_INTERNAL_ADAPTER:
    new_contents = []
    for step_batch_results in results:
        new_step_contents = []
        for content in step_batch_results:
            if isinstance(content, RequestOutput) and content.outputs is not None:
                decode_type = content.outputs.decode_type
                delta_text = ""
                if decode_type == 0:
                    delta_text, token_ids = self._decode_token(
                        token_ids=content.outputs.token_ids,
                        req_id=content.request_id,
                        is_end=content.finished,
                    )
                else:
                    token_ids = content.outputs.token_ids
                if len(token_ids):
                    content.outputs.token_ids = token_ids
                    content.outputs.text = delta_text
                    new_step_contents.append(content)
                elif content.finished:
                    new_step_contents.append(content)
                else:
                    llm_logger.warning(
                        f"current tokens need to accumulate, req_id: {content.request_id} {content.outputs.token_ids}"
                    )
            else:
                new_step_contents.append(content)
        if new_step_contents:
            new_contents.append(new_step_contents)
    if new_contents:
        self.send_response_server.send_response(None, new_contents)
else:
    for request_id, contents in results.items():
        new_contents = []
        for content in contents:
            if isinstance(content, RequestOutput) and content.outputs is not None:
                decode_type = content.outputs.decode_type
                delta_text = ""
                if decode_type == 0:
                    delta_text, token_ids = self._decode_token(
                        token_ids=content.outputs.token_ids, req_id=request_id, is_end=content.finished
                    )
                else:
                    token_ids = content.outputs.token_ids
                if len(token_ids):
                    content.outputs.token_ids = token_ids
                    content.outputs.text = delta_text
                    new_contents.append(content)
                elif content.finished:
                    new_contents.append(content)
                else:
                    llm_logger.warning(
                        f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
                    )
            else:
                new_contents.append(content)
        if len(new_contents):
            llm_logger.debug(f"Send response for request id: {request_id}")
            self.send_response_server.send_response(request_id, new_contents)
except Exception as e:
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")

View File

@@ -180,6 +180,10 @@ class LLMEngine:
if self.cfg.scheduler_config.splitwise_role != "mixed":
self.launched_cache_manager_signal.value[0] = 1
if self.cfg.scheduler_config.splitwise_role != "mixed" and envs.FD_ENABLE_INTERNAL_ADAPTER:
envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT = envs.FD_ZMQ_RECV_REQUEST_SERVER_PORTS.split(",")[0]
envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT = envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORTS.split(",")[0]
if api_server_pid is not None:
llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}")
self.engine.start_zmq_service(api_server_pid)
@@ -707,18 +711,19 @@ class LLMEngine:
host_ip = self.cfg.host_ip
disaggregate = self.cfg.disaggregate_info
request_queues_for_dp_ipc = None
result_queue_for_dp_ipc = None
result_queues_for_dp_ipc = None
if self.cfg.scheduler_config.name == "splitwise":
self.engine.scheduler.start(role, host_ip, disaggregate)
elif self.cfg.scheduler_config.name == "dp":
request_queues_for_dp_ipc = []
result_queue_for_dp_ipc = multiprocessing.Queue()
result_queues_for_dp_ipc = []
for i in range(self.cfg.parallel_config.data_parallel_size):
request_queues_for_dp_ipc.append(multiprocessing.Queue())
result_queues_for_dp_ipc.append(multiprocessing.Queue())
self.engine.scheduler.start(
self.cfg.node_rank * self.cfg.worker_num_per_node % self.cfg.worker_num_per_node,
request_queues_for_dp_ipc,
result_queue_for_dp_ipc,
result_queues_for_dp_ipc,
)
if not envs.FD_ENABLE_MULTI_API_SERVER:
@@ -755,7 +760,7 @@ class LLMEngine:
i,
None,
request_queues_for_dp_ipc,
result_queue_for_dp_ipc,
result_queues_for_dp_ipc,
),
)
)
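
Replacing the single shared result queue with one multiprocessing.Queue per data-parallel rank lets each DPScheduler read only its own slot (see the DPScheduler hunk further down), avoiding cross-rank contention on one queue. A minimal sketch of the fan-out, with data_parallel_size assumed to be 2 for illustration:

    import multiprocessing

    data_parallel_size = 2  # assumed value
    request_queues_for_dp_ipc = []
    result_queues_for_dp_ipc = []
    for _ in range(data_parallel_size):
        request_queues_for_dp_ipc.append(multiprocessing.Queue())
        result_queues_for_dp_ipc.append(multiprocessing.Queue())

    # rank i touches only its own queue
    rank = 1
    result_queues_for_dp_ipc[rank].put({"req-1": []})
    print(result_queues_for_dp_ipc[rank].get())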

View File

@@ -27,7 +27,6 @@ import numpy as np
from fastdeploy.engine.common_engine import EngineService
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.utils import console_logger, envs, llm_logger
@@ -53,6 +52,13 @@ class ExpertService:
end_pos = start_pos + self.cfg.parallel_config.tensor_parallel_size
if cfg.scheduler_config.splitwise_role != "mixed":
self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
if envs.FD_ENABLE_INTERNAL_ADAPTER:
envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT = envs.FD_ZMQ_RECV_REQUEST_SERVER_PORTS.split(",")[
local_data_parallel_id
]
envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT = envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORTS.split(",")[
local_data_parallel_id
]
self.cfg.local_device_ids = self.cfg.parallel_config.device_ids.split(",")[start_pos:end_pos]
llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}")
self.cfg.disaggregate_info = None
@@ -77,7 +83,7 @@ class ExpertService:
self._finalizer = weakref.finalize(self, self._exit_sub_services)
def start(
self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queues_for_dp_ipc=None
):
"""
Initializes the engine and starts its sub-services.
@@ -92,18 +98,15 @@ class ExpertService:
self.engine.create_data_processor()
if self.cfg.scheduler_config.name == "dp":
self.cfg.init_cache_info()
assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)
self.engine.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc)
assert (request_queues_for_dp_ipc is not None) and (result_queues_for_dp_ipc is not None)
self.engine.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queues_for_dp_ipc)
if ipc_signal_suffix is not None:
self.api_server_pid = ipc_signal_suffix
self.engine.start_zmq_service(ipc_signal_suffix)
else:
ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
if envs.FD_ENABLE_INTERNAL_ADAPTER:
self.internal_adapter = InternalAdapter(
cfg=self.cfg, engine=self.engine, dp_rank=self.cfg.parallel_config.local_data_parallel_id
)
self.engine.start_zmq_service(self.cfg.parallel_config.engine_worker_queue_port[local_data_parallel_id])
llm_logger.info(f"start expert service {local_data_parallel_id}")
@@ -189,7 +192,7 @@ class ExpertService:
def start_data_parallel_service(
cfg, local_data_parallel_id, ipc_signal_suffix=None, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
cfg, local_data_parallel_id, ipc_signal_suffix=None, request_queues_for_dp_ipc=None, result_queues_for_dp_ipc=None
):
"""
Start expert service
@@ -198,7 +201,7 @@ def start_data_parallel_service(
try:
expert_service.start(
ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc
ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queues_for_dp_ipc
)
def deamon_thread():
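
Each expert service resolves its own ZMQ ports by indexing the comma-separated *_PORTS lists with its local_data_parallel_id, as the hunk above shows. A self-contained sketch of that selection; only the env var names come from the diff, the port values and rank are assumptions:

    import os

    os.environ["FD_ZMQ_RECV_REQUEST_SERVER_PORTS"] = "8200,8210,8220"   # assumed values
    os.environ["FD_ZMQ_SEND_RESPONSE_SERVER_PORTS"] = "8201,8211,8221"  # assumed values

    local_data_parallel_id = 1  # assumed rank
    recv_port = os.environ["FD_ZMQ_RECV_REQUEST_SERVER_PORTS"].split(",")[local_data_parallel_id]
    send_port = os.environ["FD_ZMQ_SEND_RESPONSE_SERVER_PORTS"].split(",")[local_data_parallel_id]
    print(recv_port, send_port)  # 8210 8211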

View File

@@ -102,6 +102,8 @@ class Request:
prefill_start_index: int = 0,
prefill_end_index: int = 0,
num_computed_tokens: int = 0,
# for internal adapter
ic_req_data: Optional[dict] = None,
) -> None:
self.request_id = request_id
self.prompt = prompt
@@ -172,6 +174,8 @@ class Request:
self.extend_block_tables = []
# dp
self.dp_rank = dp_rank
self.llm_engine_recv_req_timestamp = time.time()
self.ic_req_data = ic_req_data
self.async_process_futures = []
self.error_message = None
@@ -226,6 +230,7 @@ class Request:
video_end=d.get("video_end", 0),
audio_end=d.get("audio_end", 0),
dp_rank=d.get("dp_rank", None),
ic_req_data=d.get("ic_req_data", None),
inference_start_time=d.get("inference_start_time"),
llm_engine_recv_req_timestamp=d.get("llm_engine_recv_req_timestamp"),
)
@@ -278,6 +283,7 @@ class Request:
"image_end": self.image_end,
"video_end": self.video_end,
"audio_end": self.audio_end,
"ic_req_data": self.ic_req_data,
}
add_params = [
"guided_json",
@@ -478,6 +484,9 @@ class RequestOutput:
num_input_video_tokens: Optional[int] = 0,
error_code: Optional[int] = 200,
error_msg: Optional[str] = None,
# for internal adapter
ic_req_data: Optional[dict] = None,
prompt_token_ids_len: Optional[int] = 0,
) -> None:
self.request_id = request_id
self.prompt = prompt
@@ -493,6 +502,8 @@ class RequestOutput:
self.num_input_video_tokens = num_input_video_tokens
self.error_code = error_code
self.error_msg = error_msg
self.ic_req_data = ic_req_data
self.prompt_token_ids_len = prompt_token_ids_len
if prompt_token_ids is None:
self.prompt_token_ids = []
@@ -565,6 +576,8 @@ class RequestOutput:
"num_input_video_tokens": self.num_input_video_tokens,
"error_code": self.error_code,
"error_msg": self.error_msg,
"ic_req_data": self.ic_req_data,
"prompt_token_ids_len": self.prompt_token_ids_len,
}
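
ic_req_data is opaque caller metadata: it enters on the Request, is copied onto every RequestOutput, and survives the to_dict()/from_dict() round trip used on the ZMQ path, alongside the new prompt_token_ids_len scalar. A stripped-down sketch of the pattern; MiniRequestOutput is a simplified stand-in, not the real class:

    from typing import Optional

    class MiniRequestOutput:
        def __init__(self, request_id: str, ic_req_data: Optional[dict] = None, prompt_token_ids_len: int = 0):
            self.request_id = request_id
            self.ic_req_data = ic_req_data
            self.prompt_token_ids_len = prompt_token_ids_len

        def to_dict(self):
            return {
                "request_id": self.request_id,
                "ic_req_data": self.ic_req_data,
                "prompt_token_ids_len": self.prompt_token_ids_len,
            }

        @classmethod
        def from_dict(cls, d):
            return cls(d["request_id"], d.get("ic_req_data"), d.get("prompt_token_ids_len", 0))

    out = MiniRequestOutput("req-0", ic_req_data={"trace_id": "abc"}, prompt_token_ids_len=7)
    assert MiniRequestOutput.from_dict(out.to_dict()).ic_req_data == {"trace_id": "abc"}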

View File

@@ -105,11 +105,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_ENABLE_MODEL_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_CACHE", "0"))),
# enable internal module to access LLMEngine.
"FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")),
# LLMEngine recieve requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1
# LLMEngine receive requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"),
# LLMEngine send response port, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
# LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1
# LLMEngine receive requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_RECV_REQUEST_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORTS", "8200"),
# LLMEngine send response port, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_SEND_RESPONSE_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORTS", "8201"),
# LLMEngine receive control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
# Whether to enable cache task in decode node
"FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "1"),

View File

@@ -36,6 +36,9 @@ class ZmqServerBase(ABC):
def __init__(self):
self.cached_results = defaultdict(list)
self.response_token_lock = threading.Lock()
self.response_handle_per_step = None
self.response_handle_name_per_step = None
self.batch_id_per_step = 0
@abstractmethod
def _create_socket(self):
@@ -126,16 +129,20 @@ class ZmqServerBase(ABC):
with self.response_token_lock:
client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK)
req_id_str = request_id.decode("utf-8")
need_send_after_finished_inference = False
with self.mutex:
self.req_dict[req_id_str] = client
if req_id_str in self.cached_results:
if self.cached_results[req_id_str][-1][-1].finished:
need_send_after_finished_inference = True
if need_send_after_finished_inference:
self.send_response(req_id_str, [])
llm_logger.info(f"send_multipart finished, req_id: {req_id_str}")
self.req_dict.pop(req_id_str, None)
if envs.FD_ENABLE_INTERNAL_ADAPTER:
with self.mutex:
self.response_handle_per_step = client
else:
need_send_after_finished_inference = False
with self.mutex:
self.req_dict[req_id_str] = client
if req_id_str in self.cached_results:
if self.cached_results[req_id_str][-1][-1].finished:
need_send_after_finished_inference = True
if need_send_after_finished_inference:
self.send_response(req_id_str, [])
llm_logger.info(f"send_multipart finished, req_id: {req_id_str}")
self.req_dict.pop(req_id_str, None)
except zmq.Again:
time.sleep(0.001)
@@ -144,7 +151,39 @@ class ZmqServerBase(ABC):
llm_logger.error(f"recv_result_handle get unknown exception: {e}")
continue
def send_response(self, req_id, data):
def _send_response_per_step(self, batch_id, data):
"""
Send generated token result to client.
"""
self._ensure_socket()
if self.socket is None:
raise RuntimeError("Router socket not created. Call create_router() first.")
need_send_data = []
with self.mutex:
if self.response_handle_per_step is None:
self.cached_results["data"].extend(data)
else:
need_send_data = self.cached_results["data"]
self.cached_results["data"] = []
if self.response_handle_per_step is not None:
try:
if data:
need_send_data.extend(data)
start_send = time.time()
result = msgpack.packb(
[[response.to_dict() for response in send_data_list] for send_data_list in need_send_data]
)
with self.response_token_lock:
self.socket.send_multipart([self.response_handle_per_step, b"", result])
llm_logger.info(
f"send_multipart result: step {self.batch_id_per_step} lens {len(need_send_data)} elapse: {time.time()-start_send}"
)
self.batch_id_per_step += 1
except Exception as e:
llm_logger.error(f"Send result to zmq client failed: {e}")
def _send_response_per_query(self, req_id, data):
"""
Send generated token result to client.
"""
@@ -188,6 +227,12 @@ class ZmqServerBase(ABC):
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
self.req_dict.pop(req_id, None)
def send_response(self, req_id, data):
if envs.FD_ENABLE_INTERNAL_ADAPTER:
self._send_response_per_step(req_id, data)
else:
self._send_response_per_query(req_id, data)
@abstractmethod
def close(self):
pass
@@ -202,6 +247,7 @@ class ZmqIpcServer(ZmqServerBase):
"""
def __init__(self, name, mode):
super(ZmqIpcServer, self).__init__()
self.name = name
self.mode = mode
self.cached_results = defaultdict(list)
@@ -262,6 +308,7 @@ class ZmqTcpServer(ZmqServerBase):
"""
def __init__(self, port, mode):
super(ZmqTcpServer, self).__init__()
self.mode = mode
self.port = port
self.cached_results = defaultdict(list)
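
In per-step mode the server keeps one router identity (response_handle_per_step) for all traffic instead of one per request id, buffers step batches under the fixed "data" key until a client has connected, and flushes everything pending in a single msgpack-packed multipart message. A self-contained sketch of that buffer-then-flush idea, assuming msgpack is installed; StepBuffer is a stand-in, not the real class:

    import threading
    import msgpack

    class StepBuffer:
        def __init__(self):
            self.mutex = threading.Lock()
            self.handle = None   # router identity of the connected client
            self.pending = []    # step batches waiting for a client

        def send(self, send_fn, batches):
            with self.mutex:
                if self.handle is None:
                    self.pending.extend(batches)  # no client yet: buffer
                    return
                to_send = self.pending + batches
                self.pending = []
            send_fn(self.handle, msgpack.packb(to_send))

    buf = StepBuffer()
    buf.send(lambda h, b: None, [[{"request_id": "a"}]])  # buffered
    buf.handle = b"client-0"
    buf.send(lambda h, b: print(len(msgpack.unpackb(b))), [[{"request_id": "b"}]])  # prints 2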

View File

@@ -291,6 +291,7 @@ class TokenProcessor:
),
finished=False,
metrics=metrics,
ic_req_data=task.ic_req_data,
)
if self.use_logprobs:
if getattr(stream_data, "logprobs", None) is not None:
@@ -684,6 +685,8 @@ class TokenProcessor:
),
finished=False,
metrics=metrics,
ic_req_data=task.ic_req_data,
prompt_token_ids_len=task.prompt_token_ids_len,
)
if self.tokens_counter[task_id] == 0:
if task.messages is not None:
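
Copying ic_req_data and prompt_token_ids_len onto every per-token RequestOutput means the per-step ZMQ sender can serialize them via to_dict() without looking the originating task up again. A minimal sketch of that passthrough; build_stream_output is a hypothetical helper, not the TokenProcessor API:

    def build_stream_output(task, token_id):
        # per-token output carries the adapter fields from the originating task
        return {
            "request_id": task["request_id"],
            "token_ids": [token_id],
            "ic_req_data": task["ic_req_data"],
            "prompt_token_ids_len": task["prompt_token_ids_len"],
        }

    task = {"request_id": "a", "ic_req_data": {"trace_id": "abc"}, "prompt_token_ids_len": 3}
    print(build_stream_output(task, 42))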

View File

@@ -61,6 +61,7 @@ class DPLocalScheduler(LocalScheduler):
self.scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")
with self.mutex:
self.batch_responses_per_step.append([response.raw for response in responses])
for response in responses:
if response.request_id not in self.responses:
self.responses[response.request_id] = [response]
@@ -206,10 +207,10 @@ class DPScheduler:
splitwise_role,
)
def start(self, dp_rank: int, request_queues: List[Queue], result_queue: Queue):
def start(self, dp_rank: int, request_queues: List[Queue], result_queues: List[Queue]):
self.dp_rank = dp_rank
self.request_queues = request_queues
self.result_queue = result_queue
self.result_queues = result_queues
self.scheduler_logger = get_logger("dpscheduler", f"dp_scheduler_rank{self.dp_rank}.log")
self._scheduler.scheduler_logger = self.scheduler_logger
threading.Thread(target=self._put_requests_to_local).start()
@@ -235,7 +236,7 @@ class DPScheduler:
results = self._scheduler.get_results()
if len(results) == 0:
continue
self.result_queue.put(results)
self.result_queues[self.dp_rank].put(results)
def get_requests(
self,
@@ -256,4 +257,4 @@ class DPScheduler:
self._scheduler.put_results(results)
def get_results(self) -> Dict[str, List[RequestOutput]]:
return self.result_queue.get()
return self.result_queues[self.dp_rank].get()
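
DPScheduler.start now takes the full list of result queues but only ever touches index dp_rank, both when publishing and when consuming; passing the whole list keeps the spawn signature identical across ranks. A runnable sketch of that pattern; MiniDPScheduler is a stand-in:

    import multiprocessing

    class MiniDPScheduler:
        def start(self, dp_rank, request_queues, result_queues):
            self.dp_rank = dp_rank
            self.result_queues = result_queues

        def put_results(self, results):
            self.result_queues[self.dp_rank].put(results)

        def get_results(self):
            return self.result_queues[self.dp_rank].get()

    queues = [multiprocessing.Queue() for _ in range(2)]
    s = MiniDPScheduler()
    s.start(1, [], queues)
    s.put_results({"req-1": "done"})
    print(s.get_results())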

View File

@@ -79,6 +79,7 @@ class LocalScheduler:
self.requests: Dict[str, ScheduledRequest] = dict()
self.responses: Dict[str, List[ScheduledResponse]] = dict()
self.batch_responses_per_step: List[List[ScheduledResponse]] = list()
self.wait_request_timeout = 10
self.wait_response_timeout = 0.001
@@ -314,6 +315,7 @@ class LocalScheduler:
scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}")
with self.mutex:
self.batch_responses_per_step.append([response.raw for response in responses])
for response in responses:
if response.request_id not in self.requests:
scheduler_logger.warning(f"Scheduler has received a expired response: {[response.request_id]}")
@@ -352,12 +354,19 @@ class LocalScheduler:
def _get_results():
responses = self.responses
batch_responses_per_step = self.batch_responses_per_step
self.responses = dict()
return responses
self.batch_responses_per_step = list()
if not responses:
return None # No response yet
return responses, batch_responses_per_step
with self.responses_not_empty:
responses = self.responses_not_empty.wait_for(_get_results, self.wait_response_timeout)
wait_response_result = self.responses_not_empty.wait_for(_get_results, self.wait_response_timeout)
if wait_response_result is not None:
responses, batch_responses_per_step = wait_response_result
else:
responses, batch_responses_per_step = dict(), list()
results = dict()
for request_id, resps in responses.items():
finished = False
@@ -373,4 +382,7 @@ class LocalScheduler:
if results:
scheduler_logger.debug(f"get responses, {results}")
return results
if envs.FD_ENABLE_INTERNAL_ADAPTER:
return batch_responses_per_step
else:
return results
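
_get_results now drains both channels under one condition wait: the per-request responses dict and the per-step batch_responses_per_step list. Returning None from the predicate keeps wait_for blocking until data arrives or the timeout fires, and with the adapter enabled only the per-step list is handed back. A condensed sketch mirroring the diff; MiniScheduler is a stand-in:

    import threading

    class MiniScheduler:
        def __init__(self):
            self.mutex = threading.Lock()
            self.responses_not_empty = threading.Condition(self.mutex)
            self.responses = {}
            self.batch_responses_per_step = []

        def get_results(self, timeout=0.001):
            def _get_results():
                responses = self.responses
                batch_responses_per_step = self.batch_responses_per_step
                self.responses = {}
                self.batch_responses_per_step = []
                if not responses:
                    return None  # nothing yet: keep waiting
                return responses, batch_responses_per_step

            with self.responses_not_empty:
                got = self.responses_not_empty.wait_for(_get_results, timeout)
            return got if got is not None else ({}, [])

    s = MiniScheduler()
    with s.mutex:
        s.responses["a"] = [1]
        s.batch_responses_per_step.append([[1]])
    print(s.get_results())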

View File

@@ -436,6 +436,5 @@ class SplitwiseConnector:
self.logger.debug(f"_handle_decode function receive {payload}")
tasks = []
for task in payload:
output = RequestOutput.from_dict(task)
tasks.append(output)
tasks.append(RequestOutput.from_dict(task))
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks))

View File

@@ -42,12 +42,14 @@ FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8234))
FD_ENABLE_INTERNAL_ADAPTER = int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "1"))
FD_ZMQ_RECV_REQUEST_SERVER_PORT = int(os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8204"))
FD_ZMQ_SEND_RESPONSE_SERVER_PORT = int(os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8205"))
FD_ZMQ_RECV_REQUEST_SERVER_PORTS = str(os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORTS", "8204"))
FD_ZMQ_SEND_RESPONSE_SERVER_PORTS = str(os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORTS", "8205"))
FD_ZMQ_CONTROL_CMD_SERVER_PORTS = int(os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8206"))
FD_ZMQ_CONTROL_CMD_SERVER_PORT = FD_ZMQ_CONTROL_CMD_SERVER_PORTS
env["FD_ENABLE_INTERNAL_ADAPTER"] = str(FD_ENABLE_INTERNAL_ADAPTER)
env["FD_ZMQ_RECV_REQUEST_SERVER_PORT"] = str(FD_ZMQ_RECV_REQUEST_SERVER_PORT)
env["FD_ZMQ_SEND_RESPONSE_SERVER_PORT"] = str(FD_ZMQ_SEND_RESPONSE_SERVER_PORT)
env["FD_ZMQ_RECV_REQUEST_SERVER_PORTS"] = str(FD_ZMQ_RECV_REQUEST_SERVER_PORTS)
env["FD_ZMQ_SEND_RESPONSE_SERVER_PORTS"] = str(FD_ZMQ_SEND_RESPONSE_SERVER_PORTS)
env["FD_ZMQ_CONTROL_CMD_SERVER_PORTS"] = str(FD_ZMQ_CONTROL_CMD_SERVER_PORTS)
env["FD_ZMQ_CONTROL_CMD_SERVER_PORT"] = str(FD_ZMQ_CONTROL_CMD_SERVER_PORT)
@@ -57,8 +59,8 @@ PORTS_TO_CLEAN = [
FD_ENGINE_QUEUE_PORT,
FD_METRICS_PORT,
FD_CACHE_QUEUE_PORT,
FD_ZMQ_RECV_REQUEST_SERVER_PORT,
FD_ZMQ_SEND_RESPONSE_SERVER_PORT,
FD_ZMQ_RECV_REQUEST_SERVER_PORTS,
FD_ZMQ_SEND_RESPONSE_SERVER_PORTS,
FD_ZMQ_CONTROL_CMD_SERVER_PORT,
]
@@ -271,7 +273,7 @@ def test_request_and_response(zmq_req_client):
has_is_end_result = False
while True:
result = result_queue.get()
if result[-1]["finished"]:
if result[0][-1]["finished"]:
has_is_end_result = True
break
assert has_is_end_result is True
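
The test change reflects the new payload nesting: with per-step batching the decoded result is a list of step batches, so the finished flag sits one level deeper than before. In miniature:

    result = [[{"request_id": "a", "finished": True}]]  # one step, one output
    assert result[0][-1]["finished"]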

View File

@@ -69,6 +69,8 @@ class MockTask:
self.prefill_chunk_num = 0
self.pooling_params = None
self.llm_engine_recv_req_timestamp = time.time()
self.ic_req_data = {}
self.prompt_token_ids_len = 0
def get(self, key: str, default_value=None):
if hasattr(self, key):

View File

@@ -63,6 +63,8 @@ class MockTask:
self.prefill_chunk_info = None
self.prefill_chunk_num = 0
self.llm_engine_recv_req_timestamp = time.time()
self.ic_req_data = {}
self.prompt_token_ids_len = 0
def get(self, key: str, default_value=None):
if hasattr(self, key):

View File

@@ -38,6 +38,9 @@ class TestTokenProcessorLogprobs(unittest.TestCase):
self.task_mock.preprocess_end_time = 95.0
self.task_mock.preprocess_start_time = 90.0
self.task_mock.schedule_start_time = 95.0
self.task_mock.llm_engine_recv_req_timestamp = 95.0
self.task_mock.ic_req_data = {}
self.task_mock.prompt_token_ids_len = 0
self.processor.resource_manager.tasks_list = [self.task_mock]

View File

@@ -71,6 +71,7 @@ class MockScheduledResponse:
def __init__(self, request_output):
self.request_id = request_output.request_id
self.finished = request_output.finished
self.raw = self
# Mock LocalScheduler base class
@@ -93,6 +94,7 @@ class MockLocalScheduler:
self.ids_read_cursor = 0
self.requests_not_empty = threading.Condition()
self.responses_not_empty = threading.Condition()
self.batch_responses_per_step = list()
def calc_required_blocks(self, token_len, block_size):
return (token_len + block_size - 1) // block_size