mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-08 10:00:29 +08:00
[Feat] Support streaming transfer data using ZMQ (#3521)
* Support streaming transfer data of ZMQ
* fix typo
* fix typo
* support tp
* add unittest
* update
* update
* fix typo
* fix typo
* fix tp_num in ci machine

Co-authored-by: Wanglongzhi2001 <>
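The change wires a ZMQ PUSH socket in each worker's GPUModelRunner to a PULL socket in the TokenProcessor, so per-step outputs are streamed as pickled Python objects instead of going through the legacy get_output/save_output path. A minimal sketch of the underlying pyzmq pattern; the IPC endpoint below is illustrative, not FastDeploy's actual ZmqClient internals:

    import zmq

    ctx = zmq.Context.instance()

    # Consumer side (token processor): bind a PULL socket and block on pickled objects.
    pull = ctx.socket(zmq.PULL)
    pull.bind("ipc:///tmp/get_save_output_rank0.ipc")  # endpoint is illustrative

    # Producer side (model runner): connect a PUSH socket and stream objects downstream.
    push = ctx.socket(zmq.PUSH)
    push.connect("ipc:///tmp/get_save_output_rank0.ipc")

    push.send_pyobj({"tokens": [101, 102], "not_need_stop": True})
    print(pull.recv_pyobj())  # -> {'tokens': [101, 102], 'not_need_stop': True}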
@@ -95,6 +95,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))),
    # force disable default chunked prefill
    "FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))),
    # Whether to use new get_output and save_output method (0 or 1)
    "FD_USE_GET_SAVE_OUTPUT_V1": lambda: bool(int(os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0"))),
}
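As with the neighbouring entries, the new flag is read lazily and coerced from a "0"/"1" string to a bool. A small standalone sketch of how a caller would toggle it and how the registered lambda resolves:

    import os

    os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"

    # Same coercion as the registered lambda: "0"/"1" string -> int -> bool.
    use_v1 = bool(int(os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0")))
    assert use_v1 is True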
@@ -76,6 +76,12 @@ else:
        update_inputs_v1,
    )

from fastdeploy.inter_communicator import ZmqClient
from fastdeploy.output.stream_transfer_data import (
    DecoderState,
    StreamTransferData,
    TextData,
)
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput

DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"

@@ -163,6 +169,7 @@ def post_process_normal(
    block_size: int = 64,
    save_each_rank: bool = False,
    skip_save_output: bool = False,
    zmq_client: ZmqClient = None,
) -> ModelRunnerOutput:
    """Post-processing steps after completing a single token generation."""
    # handle vl:

@@ -289,11 +296,29 @@ def post_process_normal(
    # In the future, we will abandon this approach.
    if not skip_save_output:
        if sampler_output.logprobs_tensors is None:
            if envs.FD_USE_GET_SAVE_OUTPUT_V1:
                # TODO(Wanglongzhi2001): adapt more type of message.
                stream_transfer_data = StreamTransferData(
                    decoder_state=DecoderState.TEXT,
                    data=TextData(
                        tokens=sampler_output.sampled_token_ids.numpy(),
                        not_need_stop=model_output.not_need_stop.numpy().item(),
                        batch=sampler_output.sampled_token_ids.shape[0],
                        speculaive_decoding=False,
                    ),
                )

                if not (not save_each_rank and model_output.mp_rank > 0):
                    try:
                        zmq_client.send_pyobj(stream_transfer_data)
                    except Exception as e:
                        print(f"Send message error: {e}")
            else:
                save_output(
                    sampler_output.sampled_token_ids,
                    model_output.not_need_stop,
                    model_output.mp_rank,
                    save_each_rank,  # save_each_rank
                    save_each_rank,
                )
        else:
            save_output_topk(

@@ -355,12 +380,15 @@ def post_process(
    save_each_rank: bool = False,
    speculative_decoding: bool = False,
    skip_save_output: bool = False,
    zmq_client: ZmqClient = None,
) -> None:
    """Post-processing steps after completing a single token generation."""
    if speculative_decoding:
        post_process_specualate(model_output, save_each_rank, skip_save_output)
    else:
        post_process_normal(sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output)
        post_process_normal(
            sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output, zmq_client
        )


def step_cuda(
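The send in post_process_normal is gated by not (not save_each_rank and model_output.mp_rank > 0): only tensor-parallel rank 0 publishes to the token processor unless save_each_rank is set. A standalone sketch of that condition (the helper name is illustrative):

    def should_send(save_each_rank: bool, mp_rank: int) -> bool:
        # Same double-negative condition as in post_process_normal:
        # skip non-zero ranks unless every rank saves its own output.
        return not (not save_each_rank and mp_rank > 0)

    assert should_send(save_each_rank=False, mp_rank=0) is True   # rank 0 always sends
    assert should_send(save_each_rank=False, mp_rank=1) is False  # other ranks stay quiet
    assert should_send(save_each_rank=True, mp_rank=1) is True    # unless save_each_rank is on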
fastdeploy/output/stream_transfer_data.py (new file, 72 lines)
@@ -0,0 +1,72 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

import numpy as np


class DecoderState(Enum):
    """DecoderState"""

    TEXT = "text"
    VISION = "vision"
    VEDIO = "vedio"
    AUDIO = "audio"


@dataclass
class TextData:
    """TextData"""

    tokens: np.array
    not_need_stop: bool
    batch: int
    speculaive_decoding: bool
    logprobs: Optional[np.array] = None
    accept_tokens: Optional[np.array] = None
    accept_num: Optional[np.array] = None


@dataclass
class VisionData:
    """VisionData"""

    tokens: np.array


@dataclass
class VedioData:
    """VedioData"""

    tokens: np.array


@dataclass
class AudioData:
    """AudioData"""

    tokens: np.array


@dataclass
class StreamTransferData:
    """StreamTransferData"""

    decoder_state: DecoderState
    data: Union[TextData, VisionData, VedioData, AudioData]
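A hedged usage sketch of the new dataclasses: the worker wraps one decode step in TextData, tags it with DecoderState.TEXT, and the object survives a pickle round trip, which is what send_pyobj/recv_pyobj rely on. Shapes and values below are illustrative:

    import pickle
    import numpy as np

    from fastdeploy.output.stream_transfer_data import (
        DecoderState,
        StreamTransferData,
        TextData,
    )

    payload = StreamTransferData(
        decoder_state=DecoderState.TEXT,
        data=TextData(
            tokens=np.array([[101], [102]]),  # one sampled token per request, batch of 2
            not_need_stop=True,
            batch=2,
            speculaive_decoding=False,  # field name as defined in the new module
        ),
    )

    restored = pickle.loads(pickle.dumps(payload))
    assert restored.decoder_state == DecoderState.TEXT
    assert restored.data.batch == 2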
@@ -24,11 +24,14 @@ from collections import Counter
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import paddle
import zmq

from fastdeploy import envs
from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
from fastdeploy.platforms import current_platform
from fastdeploy.utils import llm_logger, spec_logger
from fastdeploy.worker.output import LogprobsLists

@@ -56,6 +59,11 @@ class TokenProcessor:
        self.tokens_counter = Counter()
        self.split_connector = split_connector

        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
            self.zmq_server = ZmqClient(name=f"get_save_output_rank{self.cfg.local_device_ids[0]}", mode=zmq.PULL)
            self.zmq_server.start_server()
            self.zmq_server.create_router()

        self.speculative_decoding = self.cfg.speculative_config.method is not None
        self.use_logprobs = self.cfg.model_config.enable_logprob

@@ -154,6 +162,25 @@ class TokenProcessor:

        while True:
            try:
                if envs.FD_USE_GET_SAVE_OUTPUT_V1:
                    try:
                        receive_data = self.zmq_server.recv_pyobj()
                        assert isinstance(receive_data, StreamTransferData)
                        if receive_data is not None:
                            # TODO(Wanglongzhi2001): adapt more type of message.
                            if receive_data.decoder_state == DecoderState.TEXT:
                                self.output_tokens[0, 0] = paddle.to_tensor(
                                    receive_data.data.not_need_stop, dtype="int64"
                                )
                                self.output_tokens[1, 0] = paddle.to_tensor(receive_data.data.batch, dtype="int64")
                                self.output_tokens[2 : 2 + receive_data.data.batch, 0] = paddle.to_tensor(
                                    receive_data.data.tokens[:, 0], dtype="int64"
                                )

                    except Exception as e:
                        print(f"Receive message error: {e}")
                        continue
                else:
                    is_blocking = True
                    if self.speculative_decoding:
                        speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
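On the receive side the payload is unpacked into the flat output_tokens buffer: row 0 carries not_need_stop, row 1 the batch size, and rows 2 .. 2 + batch the sampled token ids. A numpy-only sketch of that layout (MAX_BSZ and the shapes are illustrative; the real code writes into a paddle tensor):

    import numpy as np

    MAX_BSZ = 512  # illustrative capacity of the output buffer
    output_tokens = np.zeros((MAX_BSZ + 2, 1), dtype=np.int64)

    # Pretend `tokens` came from receive_data.data.tokens, shape [batch, 1].
    tokens = np.array([[101], [102], [103]])
    not_need_stop, batch = True, tokens.shape[0]

    output_tokens[0, 0] = int(not_need_stop)        # row 0: stop flag
    output_tokens[1, 0] = batch                     # row 1: requests in this step
    output_tokens[2 : 2 + batch, 0] = tokens[:, 0]  # rows 2..2+batch: token ids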
@@ -70,8 +70,11 @@ from fastdeploy.model_executor.pre_and_post_process import (
if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
    from fastdeploy.spec_decode import MTPProposer, NgramProposer

import zmq

from fastdeploy import envs
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import ZmqClient
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.worker.model_runner_base import ModelRunnerBase

@@ -163,6 +166,12 @@ class GPUModelRunner(ModelRunnerBase):
        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
        logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")

        self.zmq_client = None
        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
            self.zmq_client = ZmqClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
            self.zmq_client.connect()
            self.zmq_client.socket.SNDTIMEO = 3000

    def exist_prefill(self):
        """
        check whether prefill stage exist

@@ -1219,6 +1228,7 @@ class GPUModelRunner(ModelRunnerBase):
            block_size=self.cache_config.block_size,
            speculative_decoding=self.speculative_decoding,
            skip_save_output=True,
            zmq_client=self.zmq_client,
        )

        if self.speculative_decoding:

@@ -1514,6 +1524,7 @@ class GPUModelRunner(ModelRunnerBase):
            save_each_rank=self.parallel_config.use_ep,
            speculative_decoding=self.speculative_decoding,
            skip_save_output=skip_save_output,
            zmq_client=self.zmq_client,
        )
        if self.guided_backend is not None and sampler_output is not None:
            self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)
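The worker-side client is a PUSH socket with SNDTIMEO = 3000, so a stalled or absent consumer surfaces as an exception (caught and logged in post_process_normal) instead of blocking the decode loop. A raw pyzmq sketch of that timeout behaviour, again with an illustrative endpoint:

    import zmq

    ctx = zmq.Context.instance()
    push = ctx.socket(zmq.PUSH)
    push.connect("ipc:///tmp/get_save_output_rank0.ipc")  # endpoint is illustrative
    push.SNDTIMEO = 3000  # milliseconds

    try:
        # With no PULL peer bound on the endpoint, the send gives up after ~3 s
        # and raises zmq.Again rather than hanging the worker.
        push.send_pyobj({"probe": True})
    except zmq.Again:
        print("send timed out; consumer not ready")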
tests/output/test_get_save_output_v1.py (new file, 142 lines)
@@ -0,0 +1,142 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import signal
import socket
import subprocess
import time
import traceback

import pytest

from fastdeploy import LLM, SamplingParams

FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
MAX_WAIT_SECONDS = 60

os.environ["LD_LIBRARY_PATH"] = "/usr/local/nccl/"
# enable get_save_output_v1
os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"


def is_port_open(host: str, port: int, timeout=1.0):
    """
    Check if a TCP port is open on the given host.
    Returns True if connection succeeds, False otherwise.
    """
    try:
        with socket.create_connection((host, port), timeout):
            return True
    except Exception:
        return False


@pytest.fixture(scope="module")
def model_path():
    """
    Get model path from environment variable MODEL_PATH,
    default to "./ERNIE-4.5-0.3B-Paddle" if not set.
    """
    base_path = os.getenv("MODEL_PATH")
    if base_path:
        return os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
    else:
        return "./ERNIE-4.5-0.3B-Paddle"


@pytest.fixture(scope="module")
def llm(model_path):
    """
    Fixture to initialize the LLM model with a given model path
    """
    try:
        output = subprocess.check_output(f"lsof -i:{FD_ENGINE_QUEUE_PORT} -t", shell=True).decode().strip()
        for pid in output.splitlines():
            os.kill(int(pid), signal.SIGKILL)
            print(f"Killed process on port {FD_ENGINE_QUEUE_PORT}, pid={pid}")
    except subprocess.CalledProcessError:
        pass

    try:
        start = time.time()
        llm = LLM(
            model=model_path,
            tensor_parallel_size=2,
            num_gpu_blocks_override=1024,
            engine_worker_queue_port=FD_ENGINE_QUEUE_PORT,
            max_model_len=8192,
            seed=1,
        )

        # Wait for the port to be open
        wait_start = time.time()
        while not is_port_open("127.0.0.1", FD_ENGINE_QUEUE_PORT):
            if time.time() - wait_start > MAX_WAIT_SECONDS:
                pytest.fail(
                    f"Model engine did not start within {MAX_WAIT_SECONDS} seconds on port {FD_ENGINE_QUEUE_PORT}"
                )
            time.sleep(1)

        print(f"Model loaded successfully from {model_path} in {time.time() - start:.2f}s.")
        yield llm
    except Exception:
        print(f"Failed to load model from {model_path}.")
        traceback.print_exc()
        pytest.fail(f"Failed to initialize LLM model from {model_path}")


def test_generate_prompts(llm):
    """
    Test basic prompt generation
    """

    # Only one prompt enabled for testing currently
    prompts = [
        "请介绍一下中国的四大发明。",
        "太阳和地球之间的距离是多少?",
        "写一首关于春天的古风诗。",
    ]

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )

    try:
        outputs = llm.generate(prompts, sampling_params)

        # Verify basic properties of the outputs
        assert len(outputs) == len(prompts), "Number of outputs should match number of prompts"

        for i, output in enumerate(outputs):
            assert output.prompt == prompts[i], f"Prompt mismatch for case {i + 1}"
            assert isinstance(output.outputs.text, str), f"Output text should be string for case {i + 1}"
            assert len(output.outputs.text) > 0, f"Generated text should not be empty for case {i + 1}"
            assert isinstance(output.finished, bool), f"'finished' should be boolean for case {i + 1}"
            assert output.metrics.model_execute_time > 0, f"Execution time should be positive for case {i + 1}"

            print(f"=== Prompt generation Case {i + 1} Passed ===")

    except Exception:
        print("Failed during prompt generation.")
        traceback.print_exc()
        pytest.fail("Prompt generation test failed")


if __name__ == "__main__":
    """
    Main entry point for the test script.
    """
    pytest.main(["-sv", __file__])