Reconstruct streaming data transfer with zmq (#3836)

* reconstruct USE_GET_SAVE_OUTPUT_V1

* fix ut

* use dp rank

* fix ci
Author: RichardWooSJTU
Date: 2025-09-17 14:30:39 +08:00 (committed by GitHub)
Commit: 2adca04f1f (parent: f9766f917b)
4 changed files with 201 additions and 111 deletions

View File

@@ -14,8 +14,10 @@
 # limitations under the License.
 """
+import threading
 from typing import Dict, Optional

+import numpy as np
 import paddle

 from fastdeploy import envs
@@ -77,11 +79,7 @@
 )
 from fastdeploy.inter_communicator import ZmqClient
-from fastdeploy.output.stream_transfer_data import (
-    DecoderState,
-    StreamTransferData,
-    TextData,
-)
+from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
 from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput

 DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
@@ -162,6 +160,26 @@ def pre_process(
     )


+def _zmq_send_text_outputs(zmq_client: ZmqClient, output_tokens: np.ndarray, save_each_rank: bool, mp_rank: int):
+    """Split output_tokens and output"""
+    assert zmq_client is not None, "zmq_client should not be None"
+    output_tokens = output_tokens.reshape([-1]).numpy()
+    output_tokens_lists = np.split(output_tokens, output_tokens.shape[0])
+
+    stream_transfer_datas = []
+    for bid, output_token_per_sample in enumerate(output_tokens_lists):
+        stream_transfer_data = StreamTransferData(
+            decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid
+        )
+        stream_transfer_datas.append(stream_transfer_data)
+
+    if save_each_rank or mp_rank == 0:
+        try:
+            zmq_client.send_pyobj(stream_transfer_datas)
+        except Exception as e:
+            print(f"Send message error: {e}")
+
+
 def post_process_normal(
     sampler_output: SamplerOutput,
     model_output: ModelOutputData,
@@ -297,22 +315,11 @@ def post_process_normal(
     if not skip_save_output:
         if sampler_output.logprobs_tensors is None:
             if envs.FD_USE_GET_SAVE_OUTPUT_V1:
-                # TODO(Wanglongzhi2001): adapt more type of message.
-                stream_transfer_data = StreamTransferData(
-                    decoder_state=DecoderState.TEXT,
-                    data=TextData(
-                        tokens=sampler_output.sampled_token_ids.numpy(),
-                        not_need_stop=model_output.not_need_stop.numpy().item(),
-                        batch=sampler_output.sampled_token_ids.shape[0],
-                        speculaive_decoding=False,
-                    ),
-                )
-                if not (not save_each_rank and model_output.mp_rank > 0):
-                    try:
-                        zmq_client.send_pyobj(stream_transfer_data)
-                    except Exception as e:
-                        print(f"Send message error: {e}")
+                t = threading.Thread(
+                    target=_zmq_send_text_outputs,
+                    args=(zmq_client, sampler_output.sampled_token_ids, save_each_rank, model_output.mp_rank),
+                )
+                t.start()
             else:
                 save_output(
                     sampler_output.sampled_token_ids,

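Note for readers of the new path: the sender now packages one StreamTransferData per running sample and ships the whole list in a single ZMQ message, from a short-lived thread so post-processing is not blocked on serialization. Below is a minimal numpy-only sketch of the per-sample split performed by _zmq_send_text_outputs; the paddle tensor and the actual ZmqClient send are stubbed out, and the values are illustrative.

# Minimal sketch; the real helper also calls .numpy() on a paddle tensor
# and hands the resulting list to zmq_client.send_pyobj().
import numpy as np

# sampled_token_ids arrives shaped [batch, 1]; reshape([-1]) flattens it.
sampled = np.array([[101], [17], [42]], dtype=np.int64)
flat = sampled.reshape([-1])

# np.split(x, x.shape[0]) yields one length-1 array per sample, which becomes
# StreamTransferData(tokens=..., batch_id=bid) in the real code.
per_sample = np.split(flat, flat.shape[0])
assert [t.tolist() for t in per_sample] == [[101], [17], [42]]

# Only rank 0 sends unless save_each_rank is set, mirroring the guard above.
save_each_rank, mp_rank = False, 0
assert save_each_rank or mp_rank == 0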
View File

@@ -16,7 +16,7 @@
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional, Union
+from typing import Optional

 import numpy as np
@@ -25,48 +25,19 @@ class DecoderState(Enum):
     """DecoderState"""

     TEXT = "text"
-    VISION = "vision"
-    VEDIO = "vedio"
+    IMAGE = "image"
+    VIDEO = "video"
     AUDIO = "audio"


-@dataclass
-class TextData:
-    """TextData"""
-
-    tokens: np.array
-    not_need_stop: bool
-    batch: int
-    speculaive_decoding: bool
-    logprobs: Optional[np.array] = None
-    accept_tokens: Optional[np.array] = None
-    accept_num: Optional[np.array] = None
-
-
-@dataclass
-class VisionData:
-    """VisionData"""
-
-    tokens: np.array
-
-
-@dataclass
-class VedioData:
-    """VedioData"""
-
-    tokens: np.array
-
-
-@dataclass
-class AudioData:
-    """AudioData"""
-
-    tokens: np.array
-
-
 @dataclass
 class StreamTransferData:
     """StreamTransferData"""

     decoder_state: DecoderState
-    data: Union[TextData, VisionData, VedioData, AudioData]
+    tokens: np.array
+    batch_id: int
+    speculaive_decoding: bool = False
+    logprobs: Optional[np.array] = None
+    accept_tokens: Optional[np.array] = None
+    accept_num: Optional[np.array] = None

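With the nested TextData/VisionData/VedioData/AudioData payloads removed, StreamTransferData itself now carries the tokens and the batch slot, and the speculative-decoding fields are optional. A quick usage sketch, assuming only the names visible in the diff above (note the field is spelled speculaive_decoding in the source):

# Usage sketch of the flattened dataclass; import path follows the diff above.
import numpy as np
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData

data = StreamTransferData(
    decoder_state=DecoderState.TEXT,        # IMAGE / VIDEO / AUDIO are reserved for later
    tokens=np.array([42], dtype=np.int64),  # token(s) sampled this step
    batch_id=0,                             # slot index inside the running batch
)
# The speculative/logprob fields keep their defaults on the plain text path.
assert data.speculaive_decoding is False and data.logprobs is None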
View File

@@ -31,7 +31,6 @@ from fastdeploy import envs
 from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput
 from fastdeploy.inter_communicator import IPCSignal, ZmqClient
 from fastdeploy.metrics.metrics import main_process_metrics
-from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
 from fastdeploy.platforms import current_platform
 from fastdeploy.utils import llm_logger, spec_logger
 from fastdeploy.worker.output import LogprobsLists
@@ -49,7 +48,6 @@ class TokenProcessor:
     """

     def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_connector):
-        import paddle

         paddle.device.set_device("cpu")
         self.cfg = cfg
@@ -60,7 +58,10 @@ class TokenProcessor:
         self.split_connector = split_connector

         if envs.FD_USE_GET_SAVE_OUTPUT_V1:
-            self.zmq_server = ZmqClient(name=f"get_save_output_rank{self.cfg.local_device_ids[0]}", mode=zmq.PULL)
+            llm_logger.debug(f"create zmq get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}")
+            self.zmq_server = ZmqClient(
+                name=f"get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}", mode=zmq.PULL
+            )
             self.zmq_server.start_server()
             self.zmq_server.create_router()
@@ -135,11 +136,145 @@ class TokenProcessor:
         if self.worker is not None:
             raise Exception("Worker is already running!")

-        self.worker = threading.Thread(target=self.process_sampling_results)
+        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            self.worker = threading.Thread(target=self.process_sampling_results_use_zmq)
+        else:
+            self.worker = threading.Thread(target=self.process_sampling_results)
         self.worker.daemon = True
         self.worker.start()

+    def _reschedule_preempt_task(self, batch_size):
+        """reschedule when real batch size is smaller than the insert position of preemted_task"""
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            need_to_be_reschedule_req_ids = list(self.resource_manager.to_be_rescheduled_request_id_set)
+            for request_id in need_to_be_reschedule_req_ids:
+                if self.resource_manager.requests[request_id].idx >= (
+                    batch_size - 1
+                ):  # No more token generated for preempted request
+                    self.resource_manager.reschedule_preempt_task(request_id)
+
+    def _process_per_token(self, task, batch_id: int, token_ids: np.ndarray, result: RequestOutput, is_prefill: bool):
+        """
+        process output token by token
+        """
+        current_time = time.time()
+        task_id = task.request_id
+        token_id_list = token_ids.tolist()
+        self._record_metrics(task, current_time, token_id_list)
+        for token_id in token_id_list:
+            recovery_stop = token_id == RECOVERY_STOP_SIGNAL
+            if recovery_stop:
+                llm_logger.info(f"recovery stop signal found at task {task_id}")
+            self.tokens_counter[task_id] += 1
+            if token_id != RECOVERY_STOP_SIGNAL:
+                result.outputs.token_ids.append(token_id)
+                task.output_token_ids.append(token_id)
+            if token_id in task.eos_token_ids or is_prefill or recovery_stop:
+                result.finished = True
+                if recovery_stop:
+                    result.error_msg = "Recover is not supported, the result is incomplete!"
+                llm_logger.info(
+                    f"Request: {task_id} finished, number of " f"generated tokens: {self.tokens_counter[task_id]}."
+                )
+                llm_logger.info(
+                    f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
+                )
+                llm_logger.info(f"{self.resource_manager.info()}")
+                if self.cfg.speculative_config.method:
+                    self._compute_speculative_status()
+                if not is_prefill:
+                    self._record_completion_metrics(task, current_time)
+                self._recycle_resources(task_id, batch_id, task, result, is_prefill)
+                break
+        return result
+
+    def _process_batch_output_use_zmq(self, receive_datas):
+        """
+        process output sample by sample
+        """
+        batch_result = list()
+        for _, stream_data in enumerate(receive_datas):
+            i = stream_data.batch_id
+            if self.resource_manager.stop_flags[i]:
+                continue
+            task = self.resource_manager.tasks_list[i]
+            task_id = task.request_id
+            token_ids = stream_data.tokens  # numpy.array
+            current_time = time.time()
+            if self.tokens_counter[task_id] == 0:
+                metrics = RequestMetrics(
+                    arrival_time=task.arrival_time,
+                    inference_start_time=task.inference_start_time,
+                    first_token_time=time.time() - task.inference_start_time,
+                    time_in_queue=task.schedule_start_time - task.preprocess_end_time,
+                    preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
+                    request_start_time=task.arrival_time,
+                )
+                self._record_first_token_metrics(task, current_time)
+            else:
+                metrics = RequestMetrics(
+                    arrival_time=time.time(),
+                    request_start_time=task.arrival_time,
+                )
+            result = RequestOutput(
+                request_id=task_id,
+                outputs=CompletionOutput(
+                    index=i,
+                    send_idx=self.tokens_counter[task_id],
+                    token_ids=[],
+                    draft_token_ids=[],
+                ),
+                finished=False,
+                metrics=metrics,
+            )
+            if self.tokens_counter[task_id] == 0:
+                if task.messages is not None:
+                    result.prompt = task.messages
+                result.num_cached_tokens = task.num_cached_tokens
+
+            is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
+            result = self._process_per_token(task, i, token_ids, result, is_prefill)
+            if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
+                batch_result.append(result)
+
+        return batch_result
+
+    def process_sampling_results_use_zmq(self):
+        """
+        use zmq to receive outputs from worker and process them
+        """
+        if self.speculative_decoding:
+            raise NotImplementedError("GET_SAVE_OUTPUT_V1 does not support speculative decoding")
+        if self.use_logprobs:
+            raise NotImplementedError("GET_SAVE_OUTPUT_V1 does not support use_logprobs")
+        rank_id = self.cfg.parallel_config.local_data_parallel_id
+        while True:
+            try:
+                if (
+                    self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1
+                ) or (rank_id == 0):
+                    receive_datas = self.zmq_server.recv_pyobj()
+                    assert isinstance(receive_datas, list)
+                    llm_logger.debug(f"token_processor receive_data {receive_datas}")
+                    batch_size = len(receive_datas)
+                    self._reschedule_preempt_task(batch_size)
+                    batch_result = self._process_batch_output_use_zmq(receive_datas)
+                    self.postprocess(batch_result)
+            except Exception as e:
+                llm_logger.error(f"Recieve message error: {e}")
+                continue
+
     def process_sampling_results(self):
         """
         read tokens from paddle inference engine and process
@@ -162,54 +297,35 @@ class TokenProcessor:
         while True:
             try:
-                if envs.FD_USE_GET_SAVE_OUTPUT_V1:
-                    try:
-                        receive_data = self.zmq_server.recv_pyobj()
-                        assert isinstance(receive_data, StreamTransferData)
-                        if receive_data is not None:
-                            # TODO(Wanglongzhi2001): adapt more type of message.
-                            if receive_data.decoder_state == DecoderState.TEXT:
-                                self.output_tokens[0, 0] = paddle.to_tensor(
-                                    receive_data.data.not_need_stop, dtype="int64"
-                                )
-                                self.output_tokens[1, 0] = paddle.to_tensor(receive_data.data.batch, dtype="int64")
-                                self.output_tokens[2 : 2 + receive_data.data.batch, 0] = paddle.to_tensor(
-                                    receive_data.data.tokens[:, 0], dtype="int64"
-                                )
-                    except Exception as e:
-                        print(f"Receive message error: {e}")
-                        continue
-                else:
-                    is_blocking = True
-                    if self.speculative_decoding:
-                        speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
-                        if self.output_tokens[0] == -2:
-                            continue
-                    else:
-                        if (
-                            self.cfg.parallel_config.enable_expert_parallel
-                            and self.cfg.parallel_config.data_parallel_size > 1
-                        ):
-                            get_output_ep(self.output_tokens, rank_id, is_blocking)
-                        else:
-                            if self.use_logprobs:
-                                get_output_topk(
-                                    self.output_tokens,
-                                    self.output_scores,
-                                    self.output_ranks,
-                                    K,
-                                    rank_id,
-                                    is_blocking,
-                                )
-                            else:
-                                get_output(self.output_tokens, rank_id, is_blocking)
+                is_blocking = True
+                if self.speculative_decoding:
+                    speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
+                    if self.output_tokens[0] == -2:
+                        continue
+                else:
+                    if (
+                        self.cfg.parallel_config.enable_expert_parallel
+                        and self.cfg.parallel_config.data_parallel_size > 1
+                    ):
+                        get_output_ep(self.output_tokens, rank_id, is_blocking)
+                    else:
+                        if self.use_logprobs:
+                            get_output_topk(
+                                self.output_tokens,
+                                self.output_scores,
+                                self.output_ranks,
+                                K,
+                                rank_id,
+                                is_blocking,
+                            )
+                        else:
+                            get_output(self.output_tokens, rank_id, is_blocking)
                 if self.output_tokens[0, 0] == -2:
                     continue
                 llm_logger.debug(f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}")
                 self._process_prefill_metrics()
                 self._process_batch_output()
             except Exception as e:
@@ -336,13 +452,8 @@ class TokenProcessor:
         tokens = tokens[2 : batch + 2]
         batch_result = list()
-        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
-            need_to_be_reschedule_req_ids = list(self.resource_manager.to_be_rescheduled_request_id_set)
-            for request_id in need_to_be_reschedule_req_ids:
-                if self.resource_manager.requests[request_id].idx >= (
-                    batch - 1
-                ):  # No more token generated for preempted request
-                    self.resource_manager.reschedule_preempt_task(request_id)
+        # reschedule
+        self._reschedule_preempt_task(batch)
         for i in range(batch):
             if self.resource_manager.stop_flags[i]:
                 continue

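On the consumer side, process_sampling_results_use_zmq receives a list with one entry per running sample, and _process_batch_output_use_zmq routes each entry to its batch slot via batch_id. A simplified, self-contained sketch of that dispatch follows; the ResourceManager/Task bookkeeping is reduced to plain dicts and every name here is illustrative, not the FastDeploy API:

# Simplified dispatch sketch: route each received entry by batch_id, skip
# slots that are already stopped, and mark a slot finished on an EOS token.
import numpy as np
from dataclasses import dataclass

@dataclass
class _Entry:            # stand-in for StreamTransferData
    tokens: np.ndarray
    batch_id: int

stop_flags = {0: False, 1: True}   # slot 1 already finished earlier
accumulated = {0: [], 1: []}
eos_token_id = 2

def dispatch(received):
    finished = []
    for entry in received:
        i = entry.batch_id
        if stop_flags[i]:
            continue               # nothing running in this slot
        for token_id in entry.tokens.tolist():
            accumulated[i].append(token_id)
            if token_id == eos_token_id:
                stop_flags[i] = True
                finished.append(i)
                break
    return finished

done = dispatch([_Entry(np.array([7]), 0), _Entry(np.array([9]), 1)])
assert done == [] and accumulated == {0: [7], 1: []}
done = dispatch([_Entry(np.array([2]), 0)])
assert done == [0] and stop_flags[0] is True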
View File

@@ -170,6 +170,7 @@ class GPUModelRunner(ModelRunnerBase):
         self.zmq_client = None
         if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            logger.info(f"zmq client get_save_output_rank{local_rank}")
             self.zmq_client = ZmqClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
             self.zmq_client.connect()
             self.zmq_client.socket.SNDTIMEO = 3000
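The worker end of the transport is a plain ZMQ PUSH that connects to the TokenProcessor's PULL server, with a 3-second send timeout. Below is a raw-pyzmq sketch of that hop; it bypasses FastDeploy's ZmqClient wrapper, and the IPC endpoint path is made up for illustration.

# Raw pyzmq sketch of the PUSH -> PULL hop used when FD_USE_GET_SAVE_OUTPUT_V1
# is enabled; endpoint and payload are illustrative only.
import zmq

ctx = zmq.Context.instance()
endpoint = "ipc:///tmp/get_save_output_rank0"   # hypothetical endpoint

pull = ctx.socket(zmq.PULL)     # TokenProcessor side: receives batches
pull.bind(endpoint)

push = ctx.socket(zmq.PUSH)     # GPUModelRunner side: sends batches
push.SNDTIMEO = 3000            # mirrors zmq_client.socket.SNDTIMEO = 3000
push.connect(endpoint)

push.send_pyobj([{"batch_id": 0, "tokens": [42]}])   # stand-in payload
received = pull.recv_pyobj()
assert received[0]["batch_id"] == 0

# With SNDTIMEO set, a send that cannot be queued within 3 s raises zmq.Again
# instead of blocking forever, which is one reason the sending helper wraps
# send_pyobj in try/except.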