Reconstruct streaming data transfer with zmq (#3836)

* reconstruct USE_GET_SAVE_OUTPUT_V1

* fix ut

* use dp rank

* fix ci
Author: RichardWooSJTU
Committed: 2025-09-17 14:30:39 +08:00 (via GitHub)
Parent: f9766f917b / Commit: 2adca04f1f
4 changed files with 201 additions and 111 deletions
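For orientation: the commit replaces the old path, which packed one StreamTransferData into a shared output tensor, with a zmq PUSH/PULL pipe carrying a pickled list of per-sample messages. A minimal sketch of that pattern, assuming plain pyzmq (the endpoint and payload here are illustrative, not FastDeploy's ZmqClient internals):

    import zmq

    ctx = zmq.Context.instance()

    # Consumer (token-processor side): PULL socket bound to a per-rank endpoint.
    pull = ctx.socket(zmq.PULL)
    pull.bind("ipc:///tmp/get_save_output_rank0")  # illustrative endpoint

    # Producer (model-runner side): PUSH socket connected to the same endpoint.
    push = ctx.socket(zmq.PUSH)
    push.connect("ipc:///tmp/get_save_output_rank0")

    # send_pyobj/recv_pyobj pickle and unpickle the Python object transparently.
    push.send_pyobj([{"batch_id": 0, "tokens": [42]}])
    print(pull.recv_pyobj())  # [{'batch_id': 0, 'tokens': [42]}]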


@@ -14,8 +14,10 @@
 # limitations under the License.
 """
+import threading
 from typing import Dict, Optional
+import numpy as np
 import paddle
 from fastdeploy import envs
@@ -77,11 +79,7 @@ else:
     )
 from fastdeploy.inter_communicator import ZmqClient
-from fastdeploy.output.stream_transfer_data import (
-    DecoderState,
-    StreamTransferData,
-    TextData,
-)
+from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
 from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput

 DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
@@ -162,6 +160,26 @@ def pre_process(
     )

+
+def _zmq_send_text_outputs(zmq_client: ZmqClient, output_tokens: paddle.Tensor, save_each_rank: bool, mp_rank: int):
+    """Split output_tokens per sample and send the resulting messages over zmq."""
+    assert zmq_client is not None, "zmq_client should not be None"
+    output_tokens = output_tokens.reshape([-1]).numpy()
+    output_tokens_lists = np.split(output_tokens, output_tokens.shape[0])
+
+    stream_transfer_datas = []
+    for bid, output_token_per_sample in enumerate(output_tokens_lists):
+        stream_transfer_data = StreamTransferData(
+            decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid
+        )
+        stream_transfer_datas.append(stream_transfer_data)
+
+    if save_each_rank or mp_rank == 0:
+        try:
+            zmq_client.send_pyobj(stream_transfer_datas)
+        except Exception as e:
+            print(f"Send message error: {e}")
+
+
 def post_process_normal(
     sampler_output: SamplerOutput,
     model_output: ModelOutputData,
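The split above turns the [batch, 1] sampled-token tensor into one length-1 numpy array per batch slot, so each sample travels as its own StreamTransferData. A self-contained illustration of that reshape-and-split step (numpy only; values are made up):

    import numpy as np

    sampled = np.array([[11], [22], [33]], dtype=np.int64)  # shape [batch, 1]
    flat = sampled.reshape([-1])                            # shape [batch]
    per_sample = np.split(flat, flat.shape[0])              # batch arrays of shape [1]

    assert [a.tolist() for a in per_sample] == [[11], [22], [33]]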
@@ -297,22 +315,11 @@ def post_process_normal(
     if not skip_save_output:
         if sampler_output.logprobs_tensors is None:
             if envs.FD_USE_GET_SAVE_OUTPUT_V1:
-                # TODO(Wanglongzhi2001): adapt more type of message.
-                stream_transfer_data = StreamTransferData(
-                    decoder_state=DecoderState.TEXT,
-                    data=TextData(
-                        tokens=sampler_output.sampled_token_ids.numpy(),
-                        not_need_stop=model_output.not_need_stop.numpy().item(),
-                        batch=sampler_output.sampled_token_ids.shape[0],
-                        speculaive_decoding=False,
-                    ),
-                )
-                if not (not save_each_rank and model_output.mp_rank > 0):
-                    try:
-                        zmq_client.send_pyobj(stream_transfer_data)
-                    except Exception as e:
-                        print(f"Send message error: {e}")
+                t = threading.Thread(
+                    target=_zmq_send_text_outputs,
+                    args=(zmq_client, sampler_output.sampled_token_ids, save_each_rank, model_output.mp_rank),
+                )
+                t.start()
             else:
                 save_output(
                     sampler_output.sampled_token_ids,
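Handing the send to a short-lived thread keeps pickling and a potentially blocking send (bounded by the socket's SNDTIMEO, set in the model-runner hunk below) off the sampling hot path. One caveat worth making concrete: plain zmq sockets are not thread-safe, so a common alternative design (a hypothetical variant, not what this commit does) is a single long-lived sender thread fed by a queue:

    import queue
    import threading

    def start_sender(zmq_client):
        """Hypothetical variant: one dedicated thread owns the socket, and
        producers only pay for an in-process queue.put() per step."""
        q = queue.Queue()

        def loop():
            while True:
                payload = q.get()
                try:
                    zmq_client.send_pyobj(payload)
                except Exception as e:  # e.g. the send blocked past SNDTIMEO
                    print(f"Send message error: {e}")

        threading.Thread(target=loop, daemon=True).start()
        return q  # callers do q.put(stream_transfer_datas) each step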


@@ -16,7 +16,7 @@
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional, Union
+from typing import Optional

 import numpy as np
@@ -25,48 +25,19 @@ class DecoderState(Enum):
     """DecoderState"""

     TEXT = "text"
-    VISION = "vision"
-    VEDIO = "vedio"
+    IMAGE = "image"
+    VIDEO = "video"
     AUDIO = "audio"


 @dataclass
-class TextData:
-    """TextData"""
-
-    tokens: np.array
-    not_need_stop: bool
-    batch: int
-    speculaive_decoding: bool
-    logprobs: Optional[np.array] = None
-    accept_tokens: Optional[np.array] = None
-    accept_num: Optional[np.array] = None
-
-
-@dataclass
-class VisionData:
-    """VisionData"""
-
-    tokens: np.array
-
-
-@dataclass
-class VedioData:
-    """VedioData"""
-
-    tokens: np.array
-
-
-@dataclass
-class AudioData:
-    """AudioData"""
-
-    tokens: np.array
-
-
-@dataclass
 class StreamTransferData:
     """StreamTransferData"""

     decoder_state: DecoderState
-    data: Union[TextData, VisionData, VedioData, AudioData]
+    tokens: np.array
+    batch_id: int
+    speculaive_decoding: bool = False
+    logprobs: Optional[np.array] = None
+    accept_tokens: Optional[np.array] = None
+    accept_num: Optional[np.array] = None
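With the Union payload gone, a text message is just the flattened dataclass, and the speculative/logprobs fields stay at their defaults. A usage sketch against the module as it now stands (field names as spelled in the diff, including the speculaive_decoding spelling):

    import numpy as np
    from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData

    msg = StreamTransferData(
        decoder_state=DecoderState.TEXT,
        tokens=np.array([42], dtype=np.int64),  # this step's token for one sample
        batch_id=3,                             # slot in the running batch
    )
    assert msg.speculaive_decoding is False and msg.logprobs is None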


@@ -31,7 +31,6 @@ from fastdeploy import envs
 from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput
 from fastdeploy.inter_communicator import IPCSignal, ZmqClient
 from fastdeploy.metrics.metrics import main_process_metrics
-from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
 from fastdeploy.platforms import current_platform
 from fastdeploy.utils import llm_logger, spec_logger
 from fastdeploy.worker.output import LogprobsLists
@@ -49,7 +48,6 @@ class TokenProcessor:
     """

     def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_connector):
-        import paddle

         paddle.device.set_device("cpu")
         self.cfg = cfg
@@ -60,7 +58,10 @@ class TokenProcessor:
         self.split_connector = split_connector

         if envs.FD_USE_GET_SAVE_OUTPUT_V1:
-            self.zmq_server = ZmqClient(name=f"get_save_output_rank{self.cfg.local_device_ids[0]}", mode=zmq.PULL)
+            llm_logger.debug(f"create zmq get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}")
+            self.zmq_server = ZmqClient(
+                name=f"get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}", mode=zmq.PULL
+            )
             self.zmq_server.start_server()
             self.zmq_server.create_router()
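Keying the channel name by local_data_parallel_id (instead of the first local device id) gives each data-parallel group its own pipe; producer and consumer only need to agree on the same rank-derived name. A sketch of that pairing, assuming a generic ipc endpoint rather than ZmqClient's real transport scheme:

    import zmq

    def endpoint(local_data_parallel_id: int) -> str:
        # Hypothetical naming scheme; the real one lives inside ZmqClient.
        return f"ipc:///tmp/get_save_output_rank{local_data_parallel_id}"

    ctx = zmq.Context.instance()
    pull = ctx.socket(zmq.PULL)   # token processor of DP group 1
    pull.bind(endpoint(1))
    push = ctx.socket(zmq.PUSH)   # worker of DP group 1
    push.connect(endpoint(1))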
@@ -135,11 +136,145 @@ class TokenProcessor:
         if self.worker is not None:
             raise Exception("Worker is already running!")

-        self.worker = threading.Thread(target=self.process_sampling_results)
+        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            self.worker = threading.Thread(target=self.process_sampling_results_use_zmq)
+        else:
+            self.worker = threading.Thread(target=self.process_sampling_results)
         self.worker.daemon = True
         self.worker.start()

+    def _reschedule_preempt_task(self, batch_size):
+        """Reschedule when the real batch size is smaller than the insert position of a preempted task."""
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            need_to_be_reschedule_req_ids = list(self.resource_manager.to_be_rescheduled_request_id_set)
+            for request_id in need_to_be_reschedule_req_ids:
+                if self.resource_manager.requests[request_id].idx >= (
+                    batch_size - 1
+                ):  # No more tokens generated for the preempted request
+                    self.resource_manager.reschedule_preempt_task(request_id)
+
+    def _process_per_token(self, task, batch_id: int, token_ids: np.ndarray, result: RequestOutput, is_prefill: bool):
+        """
+        Process output token by token.
+        """
+        current_time = time.time()
+        task_id = task.request_id
+        token_id_list = token_ids.tolist()
+        self._record_metrics(task, current_time, token_id_list)
+        for token_id in token_id_list:
+            recovery_stop = token_id == RECOVERY_STOP_SIGNAL
+            if recovery_stop:
+                llm_logger.info(f"recovery stop signal found at task {task_id}")
+            self.tokens_counter[task_id] += 1
+            if token_id != RECOVERY_STOP_SIGNAL:
+                result.outputs.token_ids.append(token_id)
+                task.output_token_ids.append(token_id)
+            if token_id in task.eos_token_ids or is_prefill or recovery_stop:
+                result.finished = True
+                if recovery_stop:
+                    result.error_msg = "Recover is not supported, the result is incomplete!"
+                llm_logger.info(
+                    f"Request: {task_id} finished, number of generated tokens: {self.tokens_counter[task_id]}."
+                )
+                llm_logger.info(
+                    f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
+                )
+                llm_logger.info(f"{self.resource_manager.info()}")
+                if self.cfg.speculative_config.method:
+                    self._compute_speculative_status()
+                if not is_prefill:
+                    self._record_completion_metrics(task, current_time)
+                self._recycle_resources(task_id, batch_id, task, result, is_prefill)
+                break
+        return result
+
+    def _process_batch_output_use_zmq(self, receive_datas):
+        """
+        Process output sample by sample.
+        """
+        batch_result = list()
+        for _, stream_data in enumerate(receive_datas):
+            i = stream_data.batch_id
+            if self.resource_manager.stop_flags[i]:
+                continue
+            task = self.resource_manager.tasks_list[i]
+            task_id = task.request_id
+            token_ids = stream_data.tokens  # numpy.array
+            current_time = time.time()
+            if self.tokens_counter[task_id] == 0:
+                metrics = RequestMetrics(
+                    arrival_time=task.arrival_time,
+                    inference_start_time=task.inference_start_time,
+                    first_token_time=time.time() - task.inference_start_time,
+                    time_in_queue=task.schedule_start_time - task.preprocess_end_time,
+                    preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
+                    request_start_time=task.arrival_time,
+                )
+                self._record_first_token_metrics(task, current_time)
+            else:
+                metrics = RequestMetrics(
+                    arrival_time=time.time(),
+                    request_start_time=task.arrival_time,
+                )
+            result = RequestOutput(
+                request_id=task_id,
+                outputs=CompletionOutput(
+                    index=i,
+                    send_idx=self.tokens_counter[task_id],
+                    token_ids=[],
+                    draft_token_ids=[],
+                ),
+                finished=False,
+                metrics=metrics,
+            )
+            if self.tokens_counter[task_id] == 0:
+                if task.messages is not None:
+                    result.prompt = task.messages
+                result.num_cached_tokens = task.num_cached_tokens
+            is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
+            result = self._process_per_token(task, i, token_ids, result, is_prefill)
+            if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
+                batch_result.append(result)
+        return batch_result
+
+    def process_sampling_results_use_zmq(self):
+        """
+        Use zmq to receive outputs from the worker and process them.
+        """
+        if self.speculative_decoding:
+            raise NotImplementedError("GET_SAVE_OUTPUT_V1 does not support speculative decoding")
+        if self.use_logprobs:
+            raise NotImplementedError("GET_SAVE_OUTPUT_V1 does not support use_logprobs")
+        rank_id = self.cfg.parallel_config.local_data_parallel_id
+        while True:
+            try:
+                if (
+                    self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1
+                ) or (rank_id == 0):
+                    receive_datas = self.zmq_server.recv_pyobj()
+                    assert isinstance(receive_datas, list)
+                    llm_logger.debug(f"token_processor receive_data {receive_datas}")
+                    batch_size = len(receive_datas)
+                    self._reschedule_preempt_task(batch_size)
+                    batch_result = self._process_batch_output_use_zmq(receive_datas)
+                    self.postprocess(batch_result)
+            except Exception as e:
+                llm_logger.error(f"Receive message error: {e}")
+                continue
+
     def process_sampling_results(self):
         """
         read tokens from paddle inference engine and process
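The receive path above is keyed by the batch_id carried in each message rather than by position in a shared tensor. Stripped to its control flow (with hypothetical stand-ins for the resource-manager structures), the dispatch looks like:

    def dispatch(receive_datas, stop_flags, tasks_list):
        """receive_datas: list of StreamTransferData, one per running sample."""
        results = []
        for stream_data in receive_datas:
            i = stream_data.batch_id    # slot in the running batch
            if stop_flags[i]:           # slot already finished or recycled
                continue
            task = tasks_list[i]
            results.append((task.request_id, stream_data.tokens.tolist()))
        return results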
@@ -162,25 +297,6 @@ class TokenProcessor:
         while True:
             try:
-                if envs.FD_USE_GET_SAVE_OUTPUT_V1:
-                    try:
-                        receive_data = self.zmq_server.recv_pyobj()
-                        assert isinstance(receive_data, StreamTransferData)
-                        if receive_data is not None:
-                            # TODO(Wanglongzhi2001): adapt more type of message.
-                            if receive_data.decoder_state == DecoderState.TEXT:
-                                self.output_tokens[0, 0] = paddle.to_tensor(
-                                    receive_data.data.not_need_stop, dtype="int64"
-                                )
-                                self.output_tokens[1, 0] = paddle.to_tensor(receive_data.data.batch, dtype="int64")
-                                self.output_tokens[2 : 2 + receive_data.data.batch, 0] = paddle.to_tensor(
-                                    receive_data.data.tokens[:, 0], dtype="int64"
-                                )
-                    except Exception as e:
-                        print(f"Receive message error: {e}")
-                        continue
-                else:
-                    is_blocking = True
+                is_blocking = True
                 if self.speculative_decoding:
                     speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
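For contrast, the deleted branch decoded an in-band layout where control fields and tokens shared one buffer. Reconstructed as a numpy sketch from the removed lines (shapes and values illustrative):

    import numpy as np

    # Legacy layout: [0,0] = not_need_stop, [1,0] = batch, [2:2+batch, 0] = tokens.
    max_num_seqs = 4
    buf = np.zeros([2 + max_num_seqs, 1], dtype=np.int64)
    buf[0, 0] = 1                      # not_need_stop flag
    buf[1, 0] = 3                      # batch size
    buf[2 : 2 + 3, 0] = [11, 22, 33]   # one token per sample

    # The new path ships typed per-sample StreamTransferData objects instead,
    # so the consumer no longer decodes positional slots out of a shared buffer.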
@@ -336,13 +452,8 @@ class TokenProcessor:
             tokens = tokens[2 : batch + 2]

         batch_result = list()
-        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
-            need_to_be_reschedule_req_ids = list(self.resource_manager.to_be_rescheduled_request_id_set)
-            for request_id in need_to_be_reschedule_req_ids:
-                if self.resource_manager.requests[request_id].idx >= (
-                    batch - 1
-                ):  # No more token generated for preempted request
-                    self.resource_manager.reschedule_preempt_task(request_id)
+        # reschedule
+        self._reschedule_preempt_task(batch)
         for i in range(batch):
             if self.resource_manager.stop_flags[i]:
                 continue
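The extracted helper makes the preemption rule reusable from both receive paths: a request preempted into slot idx can only make progress while idx stays below the live batch size. A toy check of the same condition:

    batch_size = 4
    preempted_idx = {"req-a": 2, "req-b": 5}  # hypothetical request_id -> slot idx

    to_reschedule = [rid for rid, idx in preempted_idx.items() if idx >= batch_size - 1]
    assert to_reschedule == ["req-b"]  # req-a (idx 2 < 3) can still receive tokens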


@@ -170,6 +170,7 @@ class GPUModelRunner(ModelRunnerBase):
         self.zmq_client = None
         if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            logger.info(f"zmq client get_save_output_rank{local_rank}")
             self.zmq_client = ZmqClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
             self.zmq_client.connect()
             self.zmq_client.socket.SNDTIMEO = 3000
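SNDTIMEO = 3000 bounds how long a send may block: pyzmq raises zmq.Again once the timeout expires (for example when the PULL side has died and queues are full), so the sender thread logs an error instead of hanging the inference loop. A small demonstration of the option, with an illustrative endpoint:

    import zmq

    ctx = zmq.Context.instance()
    push = ctx.socket(zmq.PUSH)
    push.SNDTIMEO = 3000  # milliseconds; pyzmq exposes socket options as attributes
    push.connect("ipc:///tmp/some_endpoint")  # illustrative

    try:
        push.send_pyobj({"tokens": [1, 2, 3]})
    except zmq.Again:
        # Send could not complete within 3 s (peer gone, queue full): log and move on.
        print("Send message error: timed out")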