Reconstruct streaming data transfer with zmq (#3836)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* reconstruct USE_GET_SAVE_OUTPUT_V1

* fix ut

* use dp rank

* fix ci
This commit is contained in:
RichardWooSJTU
2025-09-17 14:30:39 +08:00
committed by GitHub
parent f9766f917b
commit 2adca04f1f
4 changed files with 201 additions and 111 deletions

View File

@@ -14,8 +14,10 @@
# limitations under the License.
"""
import threading
from typing import Dict, Optional
import numpy as np
import paddle
from fastdeploy import envs
@@ -77,11 +79,7 @@ else:
)
from fastdeploy.inter_communicator import ZmqClient
from fastdeploy.output.stream_transfer_data import (
DecoderState,
StreamTransferData,
TextData,
)
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
@@ -162,6 +160,26 @@ def pre_process(
)
def _zmq_send_text_outputs(zmq_client: ZmqClient, output_tokens: np.ndarray, save_each_rank: bool, mp_rank: int):
"""Split output_tokens and output"""
assert zmq_client is not None, "zmq_client should not be None"
output_tokens = output_tokens.reshape([-1]).numpy()
output_tokens_lists = np.split(output_tokens, output_tokens.shape[0])
stream_transfer_datas = []
for bid, output_token_per_sample in enumerate(output_tokens_lists):
stream_transfer_data = StreamTransferData(
decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid
)
stream_transfer_datas.append(stream_transfer_data)
if save_each_rank or mp_rank == 0:
try:
zmq_client.send_pyobj(stream_transfer_datas)
except Exception as e:
print(f"Send message error: {e}")
def post_process_normal(
sampler_output: SamplerOutput,
model_output: ModelOutputData,
@@ -297,22 +315,11 @@ def post_process_normal(
if not skip_save_output:
if sampler_output.logprobs_tensors is None:
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
# TODO(Wanglongzhi2001): adapt more type of message.
stream_transfer_data = StreamTransferData(
decoder_state=DecoderState.TEXT,
data=TextData(
tokens=sampler_output.sampled_token_ids.numpy(),
not_need_stop=model_output.not_need_stop.numpy().item(),
batch=sampler_output.sampled_token_ids.shape[0],
speculaive_decoding=False,
),
t = threading.Thread(
target=_zmq_send_text_outputs,
args=(zmq_client, sampler_output.sampled_token_ids, save_each_rank, model_output.mp_rank),
)
if not (not save_each_rank and model_output.mp_rank > 0):
try:
zmq_client.send_pyobj(stream_transfer_data)
except Exception as e:
print(f"Send message error: {e}")
t.start()
else:
save_output(
sampler_output.sampled_token_ids,