[Feat] Support streaming data transfer using ZMQ (#3521)

* Support streaming data transfer over ZMQ

* fix typo

* fix typo

* support tp

* add unittest

* update

* update

* fix typo

* fix typo

* fix tp_num in ci machine

---------

Co-authored-by: Wanglongzhi2001 <>
This commit is contained in:
Longzhi Wang
2025-09-02 19:52:19 +08:00
committed by GitHub
parent 0fe1d62232
commit e0c9a6c76c
6 changed files with 314 additions and 32 deletions

View File

@@ -70,8 +70,11 @@ from fastdeploy.model_executor.pre_and_post_process import (
if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
from fastdeploy.spec_decode import MTPProposer, NgramProposer
import zmq
from fastdeploy import envs
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import ZmqClient
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.worker.model_runner_base import ModelRunnerBase
@@ -163,6 +166,12 @@ class GPUModelRunner(ModelRunnerBase):
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")
self.zmq_client = None
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
self.zmq_client = ZmqClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
self.zmq_client.connect()
self.zmq_client.socket.SNDTIMEO = 3000
def exist_prefill(self):
"""
check whether prefill stage exist
@@ -1219,6 +1228,7 @@ class GPUModelRunner(ModelRunnerBase):
block_size=self.cache_config.block_size,
speculative_decoding=self.speculative_decoding,
skip_save_output=True,
zmq_client=self.zmq_client,
)
if self.speculative_decoding:
@@ -1514,6 +1524,7 @@ class GPUModelRunner(ModelRunnerBase):
save_each_rank=self.parallel_config.use_ep,
speculative_decoding=self.speculative_decoding,
skip_save_output=skip_save_output,
zmq_client=self.zmq_client,
)
if self.guided_backend is not None and sampler_output is not None:
self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)