[Feat] Support streaming transfer data using ZMQ (#3521)

* Support streaming transfer data using ZMQ

* fix typo

* fix typo

* support tp

* add unittest

* update

* update

* fix typo

* fix typo

* fix tp_num in ci machine

---------

Co-authored-by: Wanglongzhi2001 <>
Author: Longzhi Wang (committed by GitHub)
Date: 2025-09-02 19:52:19 +08:00
Parent: 0fe1d62232
Commit: e0c9a6c76c
6 changed files with 314 additions and 32 deletions


@@ -24,11 +24,14 @@ from collections import Counter
 from concurrent.futures import ThreadPoolExecutor
 import numpy as np
 import paddle
+import zmq
 from fastdeploy import envs
 from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput
-from fastdeploy.inter_communicator import IPCSignal
+from fastdeploy.inter_communicator import IPCSignal, ZmqClient
 from fastdeploy.metrics.metrics import main_process_metrics
+from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
 from fastdeploy.platforms import current_platform
 from fastdeploy.utils import llm_logger, spec_logger
 from fastdeploy.worker.output import LogprobsLists
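
The new stream_transfer_data import comes from one of the files this commit adds. The receive loop further down only touches a handful of its fields; a rough sketch of the shape implied by that usage (class and field names below are inferred from this diff, not taken from the actual module) looks like:

```python
# Hypothetical sketch of the payload types used by the ZMQ receive path below.
# Field names mirror how the diff accesses them; the real
# fastdeploy/output/stream_transfer_data.py may differ.
from dataclasses import dataclass
from enum import Enum, auto

import numpy as np


class DecoderState(Enum):
    TEXT = auto()  # the only state the receive loop handles so far


@dataclass
class TextData:
    not_need_stop: int   # copied into output_tokens[0, 0]
    batch: int           # copied into output_tokens[1, 0]
    tokens: np.ndarray   # shape [batch, 1]; column 0 holds the sampled token ids


@dataclass
class StreamTransferData:
    decoder_state: DecoderState
    data: TextData
```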
@@ -56,6 +59,11 @@ class TokenProcessor:
         self.tokens_counter = Counter()
         self.split_connector = split_connector
+        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            self.zmq_server = ZmqClient(name=f"get_save_output_rank{self.cfg.local_device_ids[0]}", mode=zmq.PULL)
+            self.zmq_server.start_server()
+            self.zmq_server.create_router()
         self.speculative_decoding = self.cfg.speculative_config.method is not None
         self.use_logprobs = self.cfg.model_config.enable_logprob
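
ZmqClient, start_server, and create_router belong to FastDeploy's own inter_communicator wrapper; under the hood the receive side is the standard pyzmq PULL socket carrying pickled Python objects. A standalone sketch of that pattern, with a made-up IPC endpoint and payload (not FastDeploy's actual addressing scheme):

```python
# Illustration of the PUSH/PULL + pyobj pattern the TokenProcessor relies on.
# The endpoint and payload below are invented for the example.
import zmq

ctx = zmq.Context.instance()

# Receiver side (the role TokenProcessor plays): a PULL socket bound to an IPC endpoint.
puller = ctx.socket(zmq.PULL)
puller.bind("ipc:///tmp/example_get_save_output_rank0")

# Sender side (the role the worker plays): a PUSH socket connected to the same endpoint.
pusher = ctx.socket(zmq.PUSH)
pusher.connect("ipc:///tmp/example_get_save_output_rank0")

# pyzmq pickles arbitrary Python objects, which is how a StreamTransferData-like payload travels.
pusher.send_pyobj({"decoder_state": "TEXT", "tokens": [1, 2, 3]})
print(puller.recv_pyobj())  # {'decoder_state': 'TEXT', 'tokens': [1, 2, 3]}
```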
@@ -154,35 +162,54 @@ class TokenProcessor:
         while True:
             try:
                 is_blocking = True
-                if self.speculative_decoding:
-                    speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
-                    if self.output_tokens[0] == -2:
-                        continue
+                if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+                    try:
+                        receive_data = self.zmq_server.recv_pyobj()
+                        assert isinstance(receive_data, StreamTransferData)
+                        if receive_data is not None:
+                            # TODO(Wanglongzhi2001): adapt more type of message.
+                            if receive_data.decoder_state == DecoderState.TEXT:
+                                self.output_tokens[0, 0] = paddle.to_tensor(
+                                    receive_data.data.not_need_stop, dtype="int64"
+                                )
+                                self.output_tokens[1, 0] = paddle.to_tensor(receive_data.data.batch, dtype="int64")
+                                self.output_tokens[2 : 2 + receive_data.data.batch, 0] = paddle.to_tensor(
+                                    receive_data.data.tokens[:, 0], dtype="int64"
+                                )
+                    except Exception as e:
+                        print(f"Receive message error: {e}")
+                        continue
                 else:
-                    if (
-                        self.cfg.parallel_config.enable_expert_parallel
-                        and self.cfg.parallel_config.data_parallel_size > 1
-                    ):
-                        get_output_ep(self.output_tokens, rank_id, is_blocking)
+                    is_blocking = True
+                    if self.speculative_decoding:
+                        speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
+                        if self.output_tokens[0] == -2:
+                            continue
                     else:
-                        if self.use_logprobs:
-                            get_output_topk(
-                                self.output_tokens,
-                                self.output_scores,
-                                self.output_ranks,
-                                K,
-                                rank_id,
-                                is_blocking,
-                            )
-                        else:
-                            get_output(self.output_tokens, rank_id, is_blocking)
-                    if self.output_tokens[0, 0] == -2:
-                        continue
-                    llm_logger.debug(f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}")
+                        if (
+                            self.cfg.parallel_config.enable_expert_parallel
+                            and self.cfg.parallel_config.data_parallel_size > 1
+                        ):
+                            get_output_ep(self.output_tokens, rank_id, is_blocking)
+                        else:
+                            if self.use_logprobs:
+                                get_output_topk(
+                                    self.output_tokens,
+                                    self.output_scores,
+                                    self.output_ranks,
+                                    K,
+                                    rank_id,
+                                    is_blocking,
+                                )
+                            else:
+                                get_output(self.output_tokens, rank_id, is_blocking)
+                        if self.output_tokens[0, 0] == -2:
+                            continue
+                        llm_logger.debug(f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}")
                 self._process_prefill_metrics()
                 self._process_batch_output()
             except Exception as e:
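
When FD_USE_GET_SAVE_OUTPUT_V1 is enabled, the new branch above unpacks each TEXT-state message into the same output_tokens buffer the existing IPC path fills: row 0 carries the not_need_stop flag, row 1 the batch size, and rows 2..2+batch the sampled token ids. A small numpy mirror of those three writes (the helper name is mine, for illustration only; the diff does this with paddle.to_tensor on the shared tensor):

```python
# Minimal numpy mirror of what the FD_USE_GET_SAVE_OUTPUT_V1 branch writes into
# self.output_tokens. Layout, as used in the diff above:
#   row 0: not_need_stop flag, row 1: batch size, rows 2..2+batch: token ids.
import numpy as np


def fill_output_tokens(output_tokens: np.ndarray, not_need_stop: int, tokens: np.ndarray) -> None:
    """Pack one TEXT-state message into the shared output buffer (hypothetical helper)."""
    batch = tokens.shape[0]
    output_tokens[0, 0] = not_need_stop             # stop flag for the whole step
    output_tokens[1, 0] = batch                     # number of requests that produced a token
    output_tokens[2 : 2 + batch, 0] = tokens[:, 0]  # one sampled token id per request


buf = np.full((10, 1), -1, dtype=np.int64)
fill_output_tokens(buf, not_need_stop=1, tokens=np.array([[7], [42], [13]], dtype=np.int64))
print(buf[:6, 0])  # [ 1  3  7 42 13 -1]
```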