[Feat] Support streaming transfer data using ZMQ (#3521)

* Support streaming transfer data using ZMQ

* fix typo

* fix typo

* support tp

* add unittest

* update

* update

* fix typo

* fix typo

* fix tp_num in ci machine

---------

Co-authored-by: Wanglongzhi2001 <>
Authored by Longzhi Wang on 2025-09-02 19:52:19 +08:00, committed by GitHub
parent 0fe1d62232
commit e0c9a6c76c
6 changed files with 314 additions and 32 deletions

View File

@@ -95,6 +95,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))),
# force disable default chunked prefill
"FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))),
# Whether to use new get_output and save_output method (0 or 1)
"FD_USE_GET_SAVE_OUTPUT_V1": lambda: bool(int(os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0"))),
}
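A minimal sketch of how the new flag is meant to be switched on (illustrative, not part of the commit; it assumes fastdeploy.envs resolves the lambdas above on attribute access, as the envs.FD_USE_GET_SAVE_OUTPUT_V1 checks later in this diff imply):

import os

# Opt in to the ZMQ streaming output path before FastDeploy reads the flag.
os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"

from fastdeploy import envs

# True -> post-processing pushes StreamTransferData over ZMQ instead of
# calling save_output; False (the default) keeps the old path.
assert envs.FD_USE_GET_SAVE_OUTPUT_V1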

View File

@@ -76,6 +76,12 @@ else:
update_inputs_v1,
)
from fastdeploy.inter_communicator import ZmqClient
from fastdeploy.output.stream_transfer_data import (
DecoderState,
StreamTransferData,
TextData,
)
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
@@ -163,6 +169,7 @@ def post_process_normal(
block_size: int = 64,
save_each_rank: bool = False,
skip_save_output: bool = False,
zmq_client: ZmqClient = None,
) -> ModelRunnerOutput:
"""Post-processing steps after completing a single token generation."""
# handle vl:
@@ -289,11 +296,29 @@ def post_process_normal(
# In the future, we will abandon this approach.
if not skip_save_output:
if sampler_output.logprobs_tensors is None:
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
# TODO(Wanglongzhi2001): support more message types.
stream_transfer_data = StreamTransferData(
decoder_state=DecoderState.TEXT,
data=TextData(
tokens=sampler_output.sampled_token_ids.numpy(),
not_need_stop=model_output.not_need_stop.numpy().item(),
batch=sampler_output.sampled_token_ids.shape[0],
speculative_decoding=False,
),
)
if not (not save_each_rank and model_output.mp_rank > 0):
try:
zmq_client.send_pyobj(stream_transfer_data)
except Exception as e:
print(f"Send message error: {e}")
else:
save_output(
sampler_output.sampled_token_ids,
model_output.not_need_stop,
model_output.mp_rank,
save_each_rank,
)
else:
save_output_topk(
@@ -355,12 +380,15 @@ def post_process(
save_each_rank: bool = False,
speculative_decoding: bool = False,
skip_save_output: bool = False,
zmq_client: ZmqClient = None,
) -> None:
"""Post-processing steps after completing a single token generation."""
if speculative_decoding:
post_process_specualate(model_output, save_each_rank, skip_save_output)
else:
post_process_normal(sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output)
post_process_normal(
sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output, zmq_client
)
def step_cuda(

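For readability, the send gate used in post_process_normal above, if not (not save_each_rank and model_output.mp_rank > 0), is equivalent to the positive form below (a sketch; the helper name is illustrative):

# Equivalent positive form of the send gate in post_process_normal:
# push the payload when every rank saves its own output, or only on rank 0.
def should_send(save_each_rank: bool, mp_rank: int) -> bool:
    return save_each_rank or mp_rank == 0


assert should_send(save_each_rank=False, mp_rank=0)
assert not should_send(save_each_rank=False, mp_rank=1)
assert should_send(save_each_rank=True, mp_rank=3)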
View File

@@ -0,0 +1,72 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

import numpy as np


class DecoderState(Enum):
    """Output modality carried by a StreamTransferData message."""

    TEXT = "text"
    VISION = "vision"
    VIDEO = "video"
    AUDIO = "audio"


@dataclass
class TextData:
    """Token-level text output for one decode step."""

    tokens: np.ndarray
    not_need_stop: bool
    batch: int
    speculative_decoding: bool
    logprobs: Optional[np.ndarray] = None
    accept_tokens: Optional[np.ndarray] = None
    accept_num: Optional[np.ndarray] = None


@dataclass
class VisionData:
    """Vision output (placeholder)."""

    tokens: np.ndarray


@dataclass
class VideoData:
    """Video output (placeholder)."""

    tokens: np.ndarray


@dataclass
class AudioData:
    """Audio output (placeholder)."""

    tokens: np.ndarray


@dataclass
class StreamTransferData:
    """Envelope streamed from the worker to the token processor over ZMQ."""

    decoder_state: DecoderState
    data: Union[TextData, VisionData, VideoData, AudioData]
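A minimal round-trip sketch for the payload types defined above (illustrative values; it assumes ZmqClient.send_pyobj/recv_pyobj pickle the object, which is what pyzmq's send_pyobj/recv_pyobj do):

import pickle

import numpy as np

# Build the same kind of message post_process_normal sends for plain text
# decoding; the token ids and batch size here are made up.
payload = StreamTransferData(
    decoder_state=DecoderState.TEXT,
    data=TextData(
        tokens=np.array([[101], [2045]], dtype=np.int64),  # shape [batch, 1]
        not_need_stop=True,
        batch=2,
        speculative_decoding=False,
    ),
)

# Round-trip through pickle, standing in for send_pyobj -> recv_pyobj.
restored = pickle.loads(pickle.dumps(payload))
assert restored.decoder_state is DecoderState.TEXT
assert restored.data.tokens.shape == (2, 1)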

View File

@@ -24,11 +24,14 @@ from collections import Counter
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import paddle
import zmq
from fastdeploy import envs
from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
from fastdeploy.platforms import current_platform
from fastdeploy.utils import llm_logger, spec_logger
from fastdeploy.worker.output import LogprobsLists
@@ -56,6 +59,11 @@ class TokenProcessor:
self.tokens_counter = Counter()
self.split_connector = split_connector
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
self.zmq_server = ZmqClient(name=f"get_save_output_rank{self.cfg.local_device_ids[0]}", mode=zmq.PULL)
self.zmq_server.start_server()
self.zmq_server.create_router()
self.speculative_decoding = self.cfg.speculative_config.method is not None
self.use_logprobs = self.cfg.model_config.enable_logprob
@@ -154,6 +162,25 @@ class TokenProcessor:
while True:
try:
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
try:
receive_data = self.zmq_server.recv_pyobj()
assert isinstance(receive_data, StreamTransferData)
if receive_data is not None:
# TODO(Wanglongzhi2001): support more message types.
if receive_data.decoder_state == DecoderState.TEXT:
self.output_tokens[0, 0] = paddle.to_tensor(
receive_data.data.not_need_stop, dtype="int64"
)
self.output_tokens[1, 0] = paddle.to_tensor(receive_data.data.batch, dtype="int64")
self.output_tokens[2 : 2 + receive_data.data.batch, 0] = paddle.to_tensor(
receive_data.data.tokens[:, 0], dtype="int64"
)
except Exception as e:
print(f"Recieve message error: {e}")
continue
else:
is_blocking = True
if self.speculative_decoding:
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)

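A sketch of how the received TEXT payload maps onto the output_tokens buffer that the rest of TokenProcessor already consumes (numpy stands in for the paddle tensor; the buffer size and fill value are illustrative):

import numpy as np

max_num_seqs = 8  # illustrative capacity of the output buffer
output_tokens = np.full((2 + max_num_seqs, 1), -1, dtype=np.int64)

# Fields taken from a received StreamTransferData with DecoderState.TEXT.
not_need_stop, batch = 1, 2
tokens = np.array([[101], [2045]], dtype=np.int64)  # shape [batch, 1]

output_tokens[0, 0] = not_need_stop             # row 0: keep-generating flag
output_tokens[1, 0] = batch                     # row 1: sequences in this step
output_tokens[2 : 2 + batch, 0] = tokens[:, 0]  # rows 2..: one token per sequence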
View File

@@ -70,8 +70,11 @@ from fastdeploy.model_executor.pre_and_post_process import (
if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
from fastdeploy.spec_decode import MTPProposer, NgramProposer
import zmq
from fastdeploy import envs
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import ZmqClient
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.worker.model_runner_base import ModelRunnerBase
@@ -163,6 +166,12 @@ class GPUModelRunner(ModelRunnerBase):
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")
self.zmq_client = None
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
self.zmq_client = ZmqClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
self.zmq_client.connect()
self.zmq_client.socket.SNDTIMEO = 3000
def exist_prefill(self):
"""
check whether prefill stage exist
@@ -1219,6 +1228,7 @@ class GPUModelRunner(ModelRunnerBase):
block_size=self.cache_config.block_size,
speculative_decoding=self.speculative_decoding,
skip_save_output=True,
zmq_client=self.zmq_client,
)
if self.speculative_decoding:
@@ -1514,6 +1524,7 @@ class GPUModelRunner(ModelRunnerBase):
save_each_rank=self.parallel_config.use_ep,
speculative_decoding=self.speculative_decoding,
skip_save_output=skip_save_output,
zmq_client=self.zmq_client,
)
if self.guided_backend is not None and sampler_output is not None:
self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)

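A self-contained sketch of the PUSH/PULL wiring the two sides above rely on, written with plain pyzmq because ZmqClient's internals are not part of this diff (the IPC endpoint name is made up):

import zmq

ctx = zmq.Context.instance()
endpoint = "ipc:///tmp/get_save_output_rank0.ipc"  # hypothetical path

pull = ctx.socket(zmq.PULL)  # TokenProcessor side: bind and receive
pull.bind(endpoint)

push = ctx.socket(zmq.PUSH)  # GPUModelRunner side: connect and send
push.connect(endpoint)
push.SNDTIMEO = 3000  # fail after 3s instead of blocking, as in the diff above

push.send_pyobj({"decoder_state": "text", "tokens": [101, 2045]})
msg = pull.recv_pyobj()
assert msg["tokens"] == [101, 2045]

push.close()
pull.close()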
View File

@@ -0,0 +1,142 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import signal
import socket
import subprocess
import time
import traceback
import pytest
from fastdeploy import LLM, SamplingParams
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
MAX_WAIT_SECONDS = 60
os.environ["LD_LIBRARY_PATH"] = "/usr/local/nccl/"
# enable get_save_output_v1
os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"
def is_port_open(host: str, port: int, timeout=1.0):
"""
Check if a TCP port is open on the given host.
Returns True if connection succeeds, False otherwise.
"""
try:
with socket.create_connection((host, port), timeout):
return True
except Exception:
return False
@pytest.fixture(scope="module")
def model_path():
"""
Get model path from environment variable MODEL_PATH,
default to "./ERNIE-4.5-0.3B-Paddle" if not set.
"""
base_path = os.getenv("MODEL_PATH")
if base_path:
return os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
else:
return "./ERNIE-4.5-0.3B-Paddle"
@pytest.fixture(scope="module")
def llm(model_path):
"""
Fixture to initialize the LLM model with a given model path
"""
try:
output = subprocess.check_output(f"lsof -i:{FD_ENGINE_QUEUE_PORT} -t", shell=True).decode().strip()
for pid in output.splitlines():
os.kill(int(pid), signal.SIGKILL)
print(f"Killed process on port {FD_ENGINE_QUEUE_PORT}, pid={pid}")
except subprocess.CalledProcessError:
pass
try:
start = time.time()
llm = LLM(
model=model_path,
tensor_parallel_size=2,
num_gpu_blocks_override=1024,
engine_worker_queue_port=FD_ENGINE_QUEUE_PORT,
max_model_len=8192,
seed=1,
)
# Wait for the port to be open
wait_start = time.time()
while not is_port_open("127.0.0.1", FD_ENGINE_QUEUE_PORT):
if time.time() - wait_start > MAX_WAIT_SECONDS:
pytest.fail(
f"Model engine did not start within {MAX_WAIT_SECONDS} seconds on port {FD_ENGINE_QUEUE_PORT}"
)
time.sleep(1)
print(f"Model loaded successfully from {model_path} in {time.time() - start:.2f}s.")
yield llm
except Exception:
print(f"Failed to load model from {model_path}.")
traceback.print_exc()
pytest.fail(f"Failed to initialize LLM model from {model_path}")
def test_generate_prompts(llm):
"""
Test basic prompt generation
"""
# Three short Chinese prompts used as a basic generation smoke test.
prompts = [
    "请介绍一下中国的四大发明。",  # "Please introduce China's four great inventions."
    "太阳和地球之间的距离是多少?",  # "What is the distance between the Sun and the Earth?"
    "写一首关于春天的古风诗。",  # "Write a classical-style poem about spring."
]
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
)
try:
outputs = llm.generate(prompts, sampling_params)
# Verify basic properties of the outputs
assert len(outputs) == len(prompts), "Number of outputs should match number of prompts"
for i, output in enumerate(outputs):
assert output.prompt == prompts[i], f"Prompt mismatch for case {i + 1}"
assert isinstance(output.outputs.text, str), f"Output text should be string for case {i + 1}"
assert len(output.outputs.text) > 0, f"Generated text should not be empty for case {i + 1}"
assert isinstance(output.finished, bool), f"'finished' should be boolean for case {i + 1}"
assert output.metrics.model_execute_time > 0, f"Execution time should be positive for case {i + 1}"
print(f"=== Prompt generation Case {i + 1} Passed ===")
except Exception:
print("Failed during prompt generation.")
traceback.print_exc()
pytest.fail("Prompt generation test failed")
if __name__ == "__main__":
"""
Main entry point for the test script.
"""
pytest.main(["-sv", __file__])