[Feat] Support streaming data transfer using ZMQ (#3521)

* Support streaming data transfer using ZMQ

* fix typo

* fix typo

* support tp

* add unittest

* update

* update

* fix typo

* fix typo

* fix tp_num in ci machine

---------

Co-authored-by: Wanglongzhi2001 <>
This commit is contained in:
Longzhi Wang
2025-09-02 19:52:19 +08:00
committed by GitHub
parent 0fe1d62232
commit e0c9a6c76c
6 changed files with 314 additions and 32 deletions

View File

@@ -76,6 +76,12 @@ else:
update_inputs_v1,
)
from fastdeploy.inter_communicator import ZmqClient
from fastdeploy.output.stream_transfer_data import (
DecoderState,
StreamTransferData,
TextData,
)
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
@@ -163,6 +169,7 @@ def post_process_normal(
block_size: int = 64,
save_each_rank: bool = False,
skip_save_output: bool = False,
zmq_client: ZmqClient = None,
) -> ModelRunnerOutput:
"""Post-processing steps after completing a single token generation."""
# handle vl:
@@ -289,12 +296,30 @@ def post_process_normal(
# In the future, we will abandon this approach.
if not skip_save_output:
if sampler_output.logprobs_tensors is None:
save_output(
sampler_output.sampled_token_ids,
model_output.not_need_stop,
model_output.mp_rank,
save_each_rank, # save_each_rank
)
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
# TODO(Wanglongzhi2001): adapt more type of message.
stream_transfer_data = StreamTransferData(
decoder_state=DecoderState.TEXT,
data=TextData(
tokens=sampler_output.sampled_token_ids.numpy(),
not_need_stop=model_output.not_need_stop.numpy().item(),
batch=sampler_output.sampled_token_ids.shape[0],
speculaive_decoding=False,
),
)
if not (not save_each_rank and model_output.mp_rank > 0):
try:
zmq_client.send_pyobj(stream_transfer_data)
except Exception as e:
print(f"Send message error: {e}")
else:
save_output(
sampler_output.sampled_token_ids,
model_output.not_need_stop,
model_output.mp_rank,
save_each_rank,
)
else:
save_output_topk(
sampler_output.sampled_token_ids,
@@ -355,12 +380,15 @@ def post_process(
save_each_rank: bool = False,
speculative_decoding: bool = False,
skip_save_output: bool = False,
zmq_client: ZmqClient = None,
) -> None:
"""Post-processing steps after completing a single token generation."""
if speculative_decoding:
post_process_specualate(model_output, save_each_rank, skip_save_output)
else:
post_process_normal(sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output)
post_process_normal(
sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output, zmq_client
)
def step_cuda(