mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feat] Support streaming transfer data using ZMQ (#3521)
* Support streaming transfer data of ZMQ * fix typo * fix typo * support tp * add unittest * update * update * fix typo * fix typo * fix tp_num in ci machine --------- Co-authored-by: Wanglongzhi2001 <>
This commit is contained in:
@@ -76,6 +76,12 @@ else:
|
||||
update_inputs_v1,
|
||||
)
|
||||
|
||||
from fastdeploy.inter_communicator import ZmqClient
|
||||
from fastdeploy.output.stream_transfer_data import (
|
||||
DecoderState,
|
||||
StreamTransferData,
|
||||
TextData,
|
||||
)
|
||||
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
|
||||
|
||||
DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
|
||||
@@ -163,6 +169,7 @@ def post_process_normal(
|
||||
block_size: int = 64,
|
||||
save_each_rank: bool = False,
|
||||
skip_save_output: bool = False,
|
||||
zmq_client: ZmqClient = None,
|
||||
) -> ModelRunnerOutput:
|
||||
"""Post-processing steps after completing a single token generation."""
|
||||
# handle vl:
|
||||
@@ -289,12 +296,30 @@ def post_process_normal(
|
||||
# In the future, we will abandon this approach.
|
||||
if not skip_save_output:
|
||||
if sampler_output.logprobs_tensors is None:
|
||||
save_output(
|
||||
sampler_output.sampled_token_ids,
|
||||
model_output.not_need_stop,
|
||||
model_output.mp_rank,
|
||||
save_each_rank, # save_each_rank
|
||||
)
|
||||
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||
# TODO(Wanglongzhi2001): adapt more types of messages.
|
||||
stream_transfer_data = StreamTransferData(
|
||||
decoder_state=DecoderState.TEXT,
|
||||
data=TextData(
|
||||
tokens=sampler_output.sampled_token_ids.numpy(),
|
||||
not_need_stop=model_output.not_need_stop.numpy().item(),
|
||||
batch=sampler_output.sampled_token_ids.shape[0],
|
||||
speculaive_decoding=False,
|
||||
),
|
||||
)
|
||||
|
||||
if not (not save_each_rank and model_output.mp_rank > 0):
|
||||
try:
|
||||
zmq_client.send_pyobj(stream_transfer_data)
|
||||
except Exception as e:
|
||||
print(f"Send message error: {e}")
|
||||
else:
|
||||
save_output(
|
||||
sampler_output.sampled_token_ids,
|
||||
model_output.not_need_stop,
|
||||
model_output.mp_rank,
|
||||
save_each_rank,
|
||||
)
|
||||
else:
|
||||
save_output_topk(
|
||||
sampler_output.sampled_token_ids,
|
||||
@@ -355,12 +380,15 @@ def post_process(
|
||||
save_each_rank: bool = False,
|
||||
speculative_decoding: bool = False,
|
||||
skip_save_output: bool = False,
|
||||
zmq_client: ZmqClient = None,
|
||||
) -> None:
|
||||
"""Post-processing steps after completing a single token generation."""
|
||||
if speculative_decoding:
|
||||
post_process_specualate(model_output, save_each_rank, skip_save_output)
|
||||
else:
|
||||
post_process_normal(sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output)
|
||||
post_process_normal(
|
||||
sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output, zmq_client
|
||||
)
|
||||
|
||||
|
||||
def step_cuda(
|
||||
|
||||
Reference in New Issue
Block a user