mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
perf: optimize ZMQ communication with async queue and single-threaded… (#4444)
* perf: optimize ZMQ communication with async queue and single-threaded model * perf: _async_output_busy_loop * fix: async_output_queue init
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import queue
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
@@ -81,7 +81,6 @@ else:
|
||||
)
|
||||
|
||||
|
||||
from fastdeploy.inter_communicator import ZmqIpcClient
|
||||
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
|
||||
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
|
||||
|
||||
@@ -163,9 +162,8 @@ def pre_process(
|
||||
)
|
||||
|
||||
|
||||
def _zmq_send_text_outputs(zmq_client: ZmqIpcClient, output_tokens: np.ndarray, save_each_rank: bool, mp_rank: int):
|
||||
def _build_stream_transfer_data(output_tokens: np.ndarray):
|
||||
"""Split output_tokens and output"""
|
||||
assert zmq_client is not None, "zmq_client should not be None"
|
||||
output_tokens = output_tokens.reshape([-1]).numpy()
|
||||
output_tokens_lists = np.split(output_tokens, output_tokens.shape[0])
|
||||
|
||||
@@ -175,12 +173,7 @@ def _zmq_send_text_outputs(zmq_client: ZmqIpcClient, output_tokens: np.ndarray,
|
||||
decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid
|
||||
)
|
||||
stream_transfer_datas.append(stream_transfer_data)
|
||||
|
||||
if save_each_rank or mp_rank == 0:
|
||||
try:
|
||||
zmq_client.send_pyobj(stream_transfer_datas)
|
||||
except Exception as e:
|
||||
print(f"Send message error: {e}")
|
||||
return stream_transfer_datas
|
||||
|
||||
|
||||
def post_process_normal(
|
||||
@@ -190,7 +183,7 @@ def post_process_normal(
|
||||
block_size: int = 64,
|
||||
save_each_rank: bool = False,
|
||||
skip_save_output: bool = False,
|
||||
zmq_client: ZmqIpcClient = None,
|
||||
async_output_queue: queue.Queue = None,
|
||||
) -> ModelRunnerOutput:
|
||||
"""Post-processing steps after completing a single token generation."""
|
||||
# handle vl:
|
||||
@@ -319,11 +312,9 @@ def post_process_normal(
|
||||
if not skip_save_output:
|
||||
if sampler_output.logprobs_tensors is None:
|
||||
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||
t = threading.Thread(
|
||||
target=_zmq_send_text_outputs,
|
||||
args=(zmq_client, sampler_output.sampled_token_ids, save_each_rank, model_output.mp_rank),
|
||||
)
|
||||
t.start()
|
||||
if save_each_rank or model_output.mp_rank == 0:
|
||||
output = _build_stream_transfer_data(sampler_output.sampled_token_ids)
|
||||
async_output_queue.put(output)
|
||||
else:
|
||||
save_output(
|
||||
sampler_output.sampled_token_ids,
|
||||
@@ -394,14 +385,20 @@ def post_process(
|
||||
save_each_rank: bool = False,
|
||||
speculative_decoding: bool = False,
|
||||
skip_save_output: bool = False,
|
||||
zmq_client: ZmqIpcClient = None,
|
||||
async_output_queue: queue.Queue = None,
|
||||
) -> None:
|
||||
"""Post-processing steps after completing a single token generation."""
|
||||
if speculative_decoding:
|
||||
post_process_specualate(model_output, save_each_rank, skip_save_output)
|
||||
else:
|
||||
post_process_normal(
|
||||
sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output, zmq_client
|
||||
sampler_output,
|
||||
model_output,
|
||||
share_inputs,
|
||||
block_size,
|
||||
save_each_rank,
|
||||
skip_save_output,
|
||||
async_output_queue,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user