Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-06 09:07:10 +08:00
[Feat] Support streaming transfer data using ZMQ (#3521)
* Support streaming transfer data of ZMQ
* fix typo
* fix typo
* support tp
* add unittest
* update
* update
* fix typo
* fix typo
* fix tp_num in ci machine

---------

Co-authored-by: Wanglongzhi2001 <>
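The feature is gated by the FD_USE_GET_SAVE_OUTPUT_V1 environment flag: when enabled, the worker streams each step's decoded output to the token processor over a ZeroMQ socket instead of the processor fetching it through the get_output custom ops. Below is a minimal sketch of the underlying PUSH/PULL pattern in plain pyzmq, independent of FastDeploy's ZmqClient wrapper; the endpoint path is made up.

import zmq

# One-way streaming: the producer PUSHes pickled Python objects and the
# consumer PULLs them off in arrival order.
ctx = zmq.Context()

consumer = ctx.socket(zmq.PULL)
consumer.bind("ipc:///tmp/example_stream.ipc")    # illustrative endpoint path

producer = ctx.socket(zmq.PUSH)
producer.connect("ipc:///tmp/example_stream.ipc")

producer.send_pyobj({"batch": 2, "tokens": [[101], [102]]})  # pickles the dict
print(consumer.recv_pyobj())  # {'batch': 2, 'tokens': [[101], [102]]}

recv_pyobj() blocks until a message arrives, which is what lets the receive loop added below stand in for the blocking get_output call.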
@@ -24,11 +24,14 @@ from collections import Counter
 from concurrent.futures import ThreadPoolExecutor
 
 import numpy as np
+import paddle
+import zmq
 
 from fastdeploy import envs
 from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput
-from fastdeploy.inter_communicator import IPCSignal
+from fastdeploy.inter_communicator import IPCSignal, ZmqClient
 from fastdeploy.metrics.metrics import main_process_metrics
+from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
 from fastdeploy.platforms import current_platform
 from fastdeploy.utils import llm_logger, spec_logger
 from fastdeploy.worker.output import LogprobsLists
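The new stream_transfer_data import provides the message types consumed in the receive loop further down. Their real definitions live in fastdeploy/output/stream_transfer_data.py; the sketch below is only an assumed minimal shape, inferred from the fields the new code touches (decoder_state, data.not_need_stop, data.batch, data.tokens).

from dataclasses import dataclass
from enum import Enum, auto

import numpy as np


class DecoderState(Enum):
    TEXT = auto()  # the only state the new code handles so far


@dataclass
class StreamData:        # hypothetical inner payload
    not_need_stop: bool  # True while the batch still has unfinished requests
    batch: int           # number of live requests in this step
    tokens: np.ndarray   # shape [batch, k]; only column 0 is read below


@dataclass
class StreamTransferData:
    decoder_state: DecoderState
    data: StreamData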
@@ -56,6 +59,11 @@ class TokenProcessor:
         self.tokens_counter = Counter()
         self.split_connector = split_connector
 
+        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            self.zmq_server = ZmqClient(name=f"get_save_output_rank{self.cfg.local_device_ids[0]}", mode=zmq.PULL)
+            self.zmq_server.start_server()
+            self.zmq_server.create_router()
+
         self.speculative_decoding = self.cfg.speculative_config.method is not None
         self.use_logprobs = self.cfg.model_config.enable_logprob
 
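The socket name embeds local_device_ids[0], which keeps endpoints distinct between instances on one machine (the commit message's "support tp" and "fix tp_num in ci machine" items relate to this wiring). start_server() and create_router() are ZmqClient calls taken verbatim from the diff; the worker-side sender is not shown in this hunk. A plain-pyzmq sketch of the pairing under those assumptions, with an illustrative endpoint path:

import zmq

device_id = 0  # stands in for cfg.local_device_ids[0]
endpoint = f"ipc:///tmp/get_save_output_rank{device_id}.ipc"  # hypothetical path

ctx = zmq.Context()

# TokenProcessor side: a PULL socket drains the worker's output stream.
token_processor = ctx.socket(zmq.PULL)
token_processor.bind(endpoint)

# Worker side: a PUSH socket sends one message per decoding step.
worker = ctx.socket(zmq.PUSH)
worker.connect(endpoint)

worker.send_pyobj("step-0 output")  # placeholder payload
assert token_processor.recv_pyobj() == "step-0 output"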
@@ -154,35 +162,54 @@ class TokenProcessor:
 
         while True:
             try:
-                is_blocking = True
-                if self.speculative_decoding:
-                    speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
-                    if self.output_tokens[0] == -2:
-                        continue
+                if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+                    try:
+                        receive_data = self.zmq_server.recv_pyobj()
+                        assert isinstance(receive_data, StreamTransferData)
+                        if receive_data is not None:
+                            # TODO(Wanglongzhi2001): adapt more type of message.
+                            if receive_data.decoder_state == DecoderState.TEXT:
+                                self.output_tokens[0, 0] = paddle.to_tensor(
+                                    receive_data.data.not_need_stop, dtype="int64"
+                                )
+                                self.output_tokens[1, 0] = paddle.to_tensor(receive_data.data.batch, dtype="int64")
+                                self.output_tokens[2 : 2 + receive_data.data.batch, 0] = paddle.to_tensor(
+                                    receive_data.data.tokens[:, 0], dtype="int64"
+                                )
 
+                    except Exception as e:
+                        print(f"Recieve message error: {e}")
+                        continue
                 else:
-                    if (
-                        self.cfg.parallel_config.enable_expert_parallel
-                        and self.cfg.parallel_config.data_parallel_size > 1
-                    ):
-                        get_output_ep(self.output_tokens, rank_id, is_blocking)
+                    is_blocking = True
+                    if self.speculative_decoding:
+                        speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
+                        if self.output_tokens[0] == -2:
+                            continue
 
                     else:
-                        if self.use_logprobs:
-                            get_output_topk(
-                                self.output_tokens,
-                                self.output_scores,
-                                self.output_ranks,
-                                K,
-                                rank_id,
-                                is_blocking,
-                            )
-                        else:
-                            get_output(self.output_tokens, rank_id, is_blocking)
+                        if (
+                            self.cfg.parallel_config.enable_expert_parallel
+                            and self.cfg.parallel_config.data_parallel_size > 1
+                        ):
+                            get_output_ep(self.output_tokens, rank_id, is_blocking)
 
-                    if self.output_tokens[0, 0] == -2:
-                        continue
-                llm_logger.debug(f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}")
+                        else:
+                            if self.use_logprobs:
+                                get_output_topk(
+                                    self.output_tokens,
+                                    self.output_scores,
+                                    self.output_ranks,
+                                    K,
+                                    rank_id,
+                                    is_blocking,
+                                )
+                            else:
+                                get_output(self.output_tokens, rank_id, is_blocking)
+
+                        if self.output_tokens[0, 0] == -2:
+                            continue
+                    llm_logger.debug(f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}")
                 self._process_prefill_metrics()
                 self._process_batch_output()
             except Exception as e:
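In the DecoderState.TEXT branch, the received payload is packed into the same output_tokens buffer that the rest of the pipeline already consumes: row 0 carries the not_need_stop flag, row 1 the live batch size, and rows 2 through 2 + batch the newest token for each request. A worked example with numpy standing in for paddle (buffer height and values are illustrative; payload fields follow the StreamData sketch above):

import numpy as np

max_batch = 8
output_tokens = np.full((2 + max_batch, 1), -1, dtype=np.int64)

# Assumed payload for one decoding step.
not_need_stop = True
batch = 3
tokens = np.array([[11], [22], [33]])           # shape [batch, 1]

output_tokens[0, 0] = int(not_need_stop)        # row 0: "keep running" flag
output_tokens[1, 0] = batch                     # row 1: live batch size
output_tokens[2 : 2 + batch, 0] = tokens[:, 0]  # rows 2..4: newest token per request

print(output_tokens[:5, 0])                     # [ 1  3 11 22 33]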