[Feat] Support streaming transfer data using ZMQ (#3521)

* Support streaming transfer data over ZMQ

* fix typo

* fix typo

* support tp

* add unittest

* update

* update

* fix typo

* fix typo

* fix tp_num in ci machine

---------

Co-authored-by: Wanglongzhi2001 <>
Longzhi Wang committed via GitHub on 2025-09-02 19:52:19 +08:00
parent 0fe1d62232
commit e0c9a6c76c
6 changed files with 314 additions and 32 deletions
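
In short, when FD_USE_GET_SAVE_OUTPUT_V1 is enabled the GPU worker pushes each step's sampled tokens to the token processor over a per-rank ZMQ PUSH/PULL pair instead of going through the get_output/save_output custom ops. A minimal, self-contained pyzmq sketch of that pattern (endpoint name and payload are illustrative; FastDeploy's ZmqClient wrapper is not shown in this diff):

# Minimal PUSH/PULL sketch of the streaming pattern this commit adopts.
# Plain pyzmq over an in-process transport; the real code wraps this in
# ZmqClient with per-rank endpoints (see the hunks below).
import zmq

ctx = zmq.Context.instance()
addr = "inproc://get_save_output_rank0"  # illustrative endpoint name

puller = ctx.socket(zmq.PULL)            # TokenProcessor side
puller.bind(addr)

pusher = ctx.socket(zmq.PUSH)            # GPUModelRunner side
pusher.connect(addr)
pusher.SNDTIMEO = 3000                   # same 3 s send timeout the commit sets

pusher.send_pyobj({"tokens": [1, 2, 3], "not_need_stop": True})
print(puller.recv_pyobj())               # {'tokens': [1, 2, 3], 'not_need_stop': True}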

View File

@@ -95,6 +95,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))), "FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))),
# force disable default chunked prefill # force disable default chunked prefill
"FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))), "FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))),
# Whether to use new get_output and save_output method (0 or 1)
"FD_USE_GET_SAVE_OUTPUT_V1": lambda: bool(int(os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0"))),
} }
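
The new flag follows the same convention as the neighbouring FD_* switches, so it can be flipped per process through the environment before the worker starts. A standalone sketch of the parsing the lambda performs (the lazy evaluation through envs attribute access is an assumption based on the envs.FD_USE_GET_SAVE_OUTPUT_V1 checks further down):

import os

os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"

# Same "0"/"1" string -> int -> bool conversion as the lambda above;
# FastDeploy reads it as envs.FD_USE_GET_SAVE_OUTPUT_V1 in the hunks below.
use_v1 = bool(int(os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0")))
print(use_v1)  # True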

View File

@@ -76,6 +76,12 @@ else:
        update_inputs_v1,
    )

from fastdeploy.inter_communicator import ZmqClient
from fastdeploy.output.stream_transfer_data import (
    DecoderState,
    StreamTransferData,
    TextData,
)
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput

DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
@@ -163,6 +169,7 @@ def post_process_normal(
    block_size: int = 64,
    save_each_rank: bool = False,
    skip_save_output: bool = False,
    zmq_client: ZmqClient = None,
) -> ModelRunnerOutput:
    """Post-processing steps after completing a single token generation."""
    # handle vl:
@@ -289,11 +296,29 @@ def post_process_normal(
    # In the future, we will abandon this approach.
    if not skip_save_output:
        if sampler_output.logprobs_tensors is None:
            if envs.FD_USE_GET_SAVE_OUTPUT_V1:
                # TODO(Wanglongzhi2001): adapt more message types.
                stream_transfer_data = StreamTransferData(
                    decoder_state=DecoderState.TEXT,
                    data=TextData(
                        tokens=sampler_output.sampled_token_ids.numpy(),
                        not_need_stop=model_output.not_need_stop.numpy().item(),
                        batch=sampler_output.sampled_token_ids.shape[0],
                        speculative_decoding=False,
                    ),
                )
                if not (not save_each_rank and model_output.mp_rank > 0):
                    try:
                        zmq_client.send_pyobj(stream_transfer_data)
                    except Exception as e:
                        print(f"Send message error: {e}")
            else:
                save_output(
                    sampler_output.sampled_token_ids,
                    model_output.not_need_stop,
                    model_output.mp_rank,
                    save_each_rank,
                )
        else:
            save_output_topk(
@@ -355,12 +380,15 @@ def post_process(
    save_each_rank: bool = False,
    speculative_decoding: bool = False,
    skip_save_output: bool = False,
    zmq_client: ZmqClient = None,
) -> None:
    """Post-processing steps after completing a single token generation."""
    if speculative_decoding:
        post_process_specualate(model_output, save_each_rank, skip_save_output)
    else:
        post_process_normal(
            sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output, zmq_client
        )


def step_cuda(
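
On the V1 path, post_process_normal packages one step's sampled tokens into a StreamTransferData and pushes it instead of calling save_output. A sketch of the payload construction with NumPy stand-ins for the paddle tensors (shapes and values are illustrative):

# Sketch of the message built on the FD_USE_GET_SAVE_OUTPUT_V1 branch above.
# NumPy arrays stand in for paddle tensors; send_pyobj pickles the whole object.
import numpy as np

from fastdeploy.output.stream_transfer_data import (  # module added by this commit
    DecoderState,
    StreamTransferData,
    TextData,
)

sampled_token_ids = np.array([[101], [102], [103]])  # [batch, 1] token ids for one step

payload = StreamTransferData(
    decoder_state=DecoderState.TEXT,
    data=TextData(
        tokens=sampled_token_ids,
        not_need_stop=True,  # per the flag name: generation is not finished yet
        batch=sampled_token_ids.shape[0],
        speculative_decoding=False,
    ),
)
# zmq_client.send_pyobj(payload)  # only rank 0 sends unless save_each_rank is set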

View File

@@ -0,0 +1,72 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

import numpy as np


class DecoderState(Enum):
    """Modality of the payload carried by a StreamTransferData message."""

    TEXT = "text"
    VISION = "vision"
    VIDEO = "video"
    AUDIO = "audio"


@dataclass
class TextData:
    """Token-level text output for a single decoding step."""

    tokens: np.ndarray
    not_need_stop: bool
    batch: int
    speculative_decoding: bool
    logprobs: Optional[np.ndarray] = None
    accept_tokens: Optional[np.ndarray] = None
    accept_num: Optional[np.ndarray] = None


@dataclass
class VisionData:
    """Vision payload (reserved; not yet produced by the V1 path)."""

    tokens: np.ndarray


@dataclass
class VideoData:
    """Video payload (reserved; not yet produced by the V1 path)."""

    tokens: np.ndarray


@dataclass
class AudioData:
    """Audio payload (reserved; not yet produced by the V1 path)."""

    tokens: np.ndarray


@dataclass
class StreamTransferData:
    """A single streamed message: the decoder state plus its modality-specific payload."""

    decoder_state: DecoderState
    data: Union[TextData, VisionData, VideoData, AudioData]
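
Because send_pyobj/recv_pyobj (which, following pyzmq convention, pickle the object) move these messages between the worker and the token processor, the whole payload has to survive pickling. A quick round-trip sketch using the classes above:

# Round-trip sketch: the message is pickled on send and unpickled on receive.
import pickle

import numpy as np

from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData, TextData

msg = StreamTransferData(
    decoder_state=DecoderState.TEXT,
    data=TextData(tokens=np.array([[7]]), not_need_stop=True, batch=1, speculative_decoding=False),
)
restored = pickle.loads(pickle.dumps(msg))
assert restored.decoder_state is DecoderState.TEXT  # Enum members unpickle to the same singleton
assert restored.data.tokens[0, 0] == 7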

View File

@@ -24,11 +24,14 @@ from collections import Counter
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import paddle
import zmq

from fastdeploy import envs
from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
from fastdeploy.platforms import current_platform
from fastdeploy.utils import llm_logger, spec_logger
from fastdeploy.worker.output import LogprobsLists
@@ -56,6 +59,11 @@ class TokenProcessor:
        self.tokens_counter = Counter()
        self.split_connector = split_connector

        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
            self.zmq_server = ZmqClient(name=f"get_save_output_rank{self.cfg.local_device_ids[0]}", mode=zmq.PULL)
            self.zmq_server.start_server()
            self.zmq_server.create_router()

        self.speculative_decoding = self.cfg.speculative_config.method is not None
        self.use_logprobs = self.cfg.model_config.enable_logprob
@@ -154,6 +162,25 @@ class TokenProcessor:
        while True:
            try:
                if envs.FD_USE_GET_SAVE_OUTPUT_V1:
                    try:
                        receive_data = self.zmq_server.recv_pyobj()
                        assert isinstance(receive_data, StreamTransferData)
                        if receive_data is not None:
                            # TODO(Wanglongzhi2001): adapt more message types.
                            if receive_data.decoder_state == DecoderState.TEXT:
                                self.output_tokens[0, 0] = paddle.to_tensor(
                                    receive_data.data.not_need_stop, dtype="int64"
                                )
                                self.output_tokens[1, 0] = paddle.to_tensor(receive_data.data.batch, dtype="int64")
                                self.output_tokens[2 : 2 + receive_data.data.batch, 0] = paddle.to_tensor(
                                    receive_data.data.tokens[:, 0], dtype="int64"
                                )
                    except Exception as e:
                        print(f"Receive message error: {e}")
                        continue
                else:
                    is_blocking = True
                    if self.speculative_decoding:
                        speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
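
On the consuming side, the TokenProcessor writes the received message back into the same output_tokens layout the legacy path fills: row 0 holds not_need_stop, row 1 the batch size, and rows 2..2+batch the token ids. A NumPy sketch of that unpacking (the helper name is made up; the real code writes into a paddle tensor as shown above):

# Hypothetical helper mirroring the V1 unpacking above:
# [0, 0] -> not_need_stop, [1, 0] -> batch, [2:2+batch, 0] -> token ids.
import numpy as np


def unpack_text_message(output_tokens: np.ndarray, tokens: np.ndarray, not_need_stop: bool, batch: int) -> None:
    output_tokens[0, 0] = int(not_need_stop)
    output_tokens[1, 0] = batch
    output_tokens[2:2 + batch, 0] = tokens[:batch, 0]


output_tokens = np.zeros((10, 1), dtype=np.int64)
unpack_text_message(output_tokens, np.array([[5], [6], [7]]), not_need_stop=True, batch=3)
print(output_tokens[:5, 0])  # [1 3 5 6 7]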

View File

@@ -70,8 +70,11 @@ from fastdeploy.model_executor.pre_and_post_process import (
if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
    from fastdeploy.spec_decode import MTPProposer, NgramProposer

import zmq

from fastdeploy import envs
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.inter_communicator import ZmqClient
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.worker.model_runner_base import ModelRunnerBase
@@ -163,6 +166,12 @@ class GPUModelRunner(ModelRunnerBase):
        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
        logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")

        self.zmq_client = None
        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
            self.zmq_client = ZmqClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
            self.zmq_client.connect()
            self.zmq_client.socket.SNDTIMEO = 3000

    def exist_prefill(self):
        """
        check whether prefill stage exist
@@ -1219,6 +1228,7 @@ class GPUModelRunner(ModelRunnerBase):
            block_size=self.cache_config.block_size,
            speculative_decoding=self.speculative_decoding,
            skip_save_output=True,
            zmq_client=self.zmq_client,
        )

        if self.speculative_decoding:
@@ -1514,6 +1524,7 @@ class GPUModelRunner(ModelRunnerBase):
            save_each_rank=self.parallel_config.use_ep,
            speculative_decoding=self.speculative_decoding,
            skip_save_output=skip_save_output,
            zmq_client=self.zmq_client,
        )

        if self.guided_backend is not None and sampler_output is not None:
            self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)
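
Each GPU worker rank owns its own PUSH client keyed by local_rank, with a 3-second send timeout, presumably so a stalled consumer cannot block the decode loop indefinitely. A pyzmq-only sketch of that configuration (function name and ipc endpoint are illustrative; ZmqClient's internals are not part of this diff):

# Sketch: per-rank PUSH socket with a bounded send timeout, mirroring
# `self.zmq_client.socket.SNDTIMEO = 3000` above.
from typing import Optional

import zmq


def make_push_client(local_rank: int, ctx: Optional[zmq.Context] = None) -> zmq.Socket:
    ctx = ctx or zmq.Context.instance()
    sock = ctx.socket(zmq.PUSH)
    sock.SNDTIMEO = 3000  # sends raise zmq.Again after 3 s instead of blocking forever
    sock.connect(f"ipc:///tmp/get_save_output_rank{local_rank}.ipc")  # illustrative endpoint
    return sock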

View File

@@ -0,0 +1,142 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import signal
import socket
import subprocess
import time
import traceback

import pytest

from fastdeploy import LLM, SamplingParams

FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
MAX_WAIT_SECONDS = 60

os.environ["LD_LIBRARY_PATH"] = "/usr/local/nccl/"
# enable the new get_save_output (V1) path
os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"
def is_port_open(host: str, port: int, timeout=1.0):
    """
    Check if a TCP port is open on the given host.
    Returns True if connection succeeds, False otherwise.
    """
    try:
        with socket.create_connection((host, port), timeout):
            return True
    except Exception:
        return False
@pytest.fixture(scope="module")
def model_path():
    """
    Get model path from environment variable MODEL_PATH,
    default to "./ERNIE-4.5-0.3B-Paddle" if not set.
    """
    base_path = os.getenv("MODEL_PATH")
    if base_path:
        return os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
    else:
        return "./ERNIE-4.5-0.3B-Paddle"
@pytest.fixture(scope="module")
def llm(model_path):
    """
    Fixture to initialize the LLM model with a given model path.
    """
    # Kill any process still holding the engine queue port from a previous run.
    try:
        output = subprocess.check_output(f"lsof -i:{FD_ENGINE_QUEUE_PORT} -t", shell=True).decode().strip()
        for pid in output.splitlines():
            os.kill(int(pid), signal.SIGKILL)
            print(f"Killed process on port {FD_ENGINE_QUEUE_PORT}, pid={pid}")
    except subprocess.CalledProcessError:
        pass

    try:
        start = time.time()
        llm = LLM(
            model=model_path,
            tensor_parallel_size=2,
            num_gpu_blocks_override=1024,
            engine_worker_queue_port=FD_ENGINE_QUEUE_PORT,
            max_model_len=8192,
            seed=1,
        )
        # Wait for the engine worker queue port to be open
        wait_start = time.time()
        while not is_port_open("127.0.0.1", FD_ENGINE_QUEUE_PORT):
            if time.time() - wait_start > MAX_WAIT_SECONDS:
                pytest.fail(
                    f"Model engine did not start within {MAX_WAIT_SECONDS} seconds on port {FD_ENGINE_QUEUE_PORT}"
                )
            time.sleep(1)
        print(f"Model loaded successfully from {model_path} in {time.time() - start:.2f}s.")
        yield llm
    except Exception:
        print(f"Failed to load model from {model_path}.")
        traceback.print_exc()
        pytest.fail(f"Failed to initialize LLM model from {model_path}")
def test_generate_prompts(llm):
    """
    Test basic prompt generation.
    """
    prompts = [
        "请介绍一下中国的四大发明。",
        "太阳和地球之间的距离是多少?",
        "写一首关于春天的古风诗。",
    ]
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )
    try:
        outputs = llm.generate(prompts, sampling_params)

        # Verify basic properties of the outputs
        assert len(outputs) == len(prompts), "Number of outputs should match number of prompts"
        for i, output in enumerate(outputs):
            assert output.prompt == prompts[i], f"Prompt mismatch for case {i + 1}"
            assert isinstance(output.outputs.text, str), f"Output text should be string for case {i + 1}"
            assert len(output.outputs.text) > 0, f"Generated text should not be empty for case {i + 1}"
            assert isinstance(output.finished, bool), f"'finished' should be boolean for case {i + 1}"
            assert output.metrics.model_execute_time > 0, f"Execution time should be positive for case {i + 1}"
            print(f"=== Prompt generation Case {i + 1} Passed ===")
    except Exception:
        print("Failed during prompt generation.")
        traceback.print_exc()
        pytest.fail("Prompt generation test failed")
if __name__ == "__main__":
    """
    Main entry point for the test script.
    """
    pytest.main(["-sv", __file__])