perf: optimize ZMQ communication with async queue and single-threaded model (#4444)

* perf: optimize ZMQ communication with async queue and single-threaded model

* perf: _async_output_busy_loop

* fix: async_output_queue init
This commit is contained in:
SunLei
2025-10-16 15:46:26 +08:00
committed by GitHub
parent 98f8c3703a
commit 5abf59715d
2 changed files with 36 additions and 20 deletions

View File

@@ -14,7 +14,7 @@
# limitations under the License.
"""
import threading
import queue
from typing import Dict, Optional
import numpy as np
@@ -81,7 +81,6 @@ else:
)
from fastdeploy.inter_communicator import ZmqIpcClient
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
@@ -163,9 +162,8 @@ def pre_process(
)
def _zmq_send_text_outputs(zmq_client: ZmqIpcClient, output_tokens: np.ndarray, save_each_rank: bool, mp_rank: int):
def _build_stream_transfer_data(output_tokens: np.ndarray):
"""Split output_tokens and output"""
assert zmq_client is not None, "zmq_client should not be None"
output_tokens = output_tokens.reshape([-1]).numpy()
output_tokens_lists = np.split(output_tokens, output_tokens.shape[0])
@@ -175,12 +173,7 @@ def _zmq_send_text_outputs(zmq_client: ZmqIpcClient, output_tokens: np.ndarray,
decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid
)
stream_transfer_datas.append(stream_transfer_data)
if save_each_rank or mp_rank == 0:
try:
zmq_client.send_pyobj(stream_transfer_datas)
except Exception as e:
print(f"Send message error: {e}")
return stream_transfer_datas
def post_process_normal(
@@ -190,7 +183,7 @@ def post_process_normal(
block_size: int = 64,
save_each_rank: bool = False,
skip_save_output: bool = False,
zmq_client: ZmqIpcClient = None,
async_output_queue: queue.Queue = None,
) -> ModelRunnerOutput:
"""Post-processing steps after completing a single token generation."""
# handle vl:
@@ -319,11 +312,9 @@ def post_process_normal(
if not skip_save_output:
if sampler_output.logprobs_tensors is None:
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
t = threading.Thread(
target=_zmq_send_text_outputs,
args=(zmq_client, sampler_output.sampled_token_ids, save_each_rank, model_output.mp_rank),
)
t.start()
if save_each_rank or model_output.mp_rank == 0:
output = _build_stream_transfer_data(sampler_output.sampled_token_ids)
async_output_queue.put(output)
else:
save_output(
sampler_output.sampled_token_ids,
@@ -394,14 +385,20 @@ def post_process(
save_each_rank: bool = False,
speculative_decoding: bool = False,
skip_save_output: bool = False,
zmq_client: ZmqIpcClient = None,
async_output_queue: queue.Queue = None,
) -> None:
"""Post-processing steps after completing a single token generation."""
if speculative_decoding:
post_process_specualate(model_output, save_each_rank, skip_save_output)
else:
post_process_normal(
sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output, zmq_client
sampler_output,
model_output,
share_inputs,
block_size,
save_each_rank,
skip_save_output,
async_output_queue,
)