Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-08 18:11:00 +08:00)
[Feat] Support streaming transfer data using ZMQ (#3521)
* Support streaming transfer data of ZMQ
* fix typo
* fix typo
* support tp
* add unittest
* update
* update
* fix typo
* fix typo
* fix tp_num in ci machine

Co-authored-by: Wanglongzhi2001 <>
@@ -95,6 +95,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))),
     # force disable default chunked prefill
     "FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))),
+    # Whether to use new get_output and save_output method (0 or 1)
+    "FD_USE_GET_SAVE_OUTPUT_V1": lambda: bool(int(os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0"))),
 }
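The FD_USE_GET_SAVE_OUTPUT_V1 flag added above gates the whole feature; the worker and token-processor branches shown in the later hunks both check it. A minimal sketch of opting in from user code, mirroring what the new unit test at the end of this commit does (the assert only restates the bool(int(...)) parsing visible in the registry entry):

import os

# Enable the new ZMQ-based get_output/save_output path before the engine is created;
# leaving the variable unset (or "0") keeps the existing behaviour.
os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"

# The registry entry parses the value as bool(int(...)): "1" -> True, "0"/unset -> False.
assert bool(int(os.environ.get("FD_USE_GET_SAVE_OUTPUT_V1", "0"))) is True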
@@ -76,6 +76,12 @@ else:
         update_inputs_v1,
     )

+from fastdeploy.inter_communicator import ZmqClient
+from fastdeploy.output.stream_transfer_data import (
+    DecoderState,
+    StreamTransferData,
+    TextData,
+)
 from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput

 DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
@@ -163,6 +169,7 @@ def post_process_normal(
     block_size: int = 64,
     save_each_rank: bool = False,
     skip_save_output: bool = False,
+    zmq_client: ZmqClient = None,
 ) -> ModelRunnerOutput:
     """Post-processing steps after completing a single token generation."""
     # handle vl:
@@ -289,11 +296,29 @@ def post_process_normal(
     # In the future, we will abandon this approach.
     if not skip_save_output:
         if sampler_output.logprobs_tensors is None:
-            save_output(
-                sampler_output.sampled_token_ids,
-                model_output.not_need_stop,
-                model_output.mp_rank,
-                save_each_rank,  # save_each_rank
-            )
+            if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+                # TODO(Wanglongzhi2001): adapt more type of message.
+                stream_transfer_data = StreamTransferData(
+                    decoder_state=DecoderState.TEXT,
+                    data=TextData(
+                        tokens=sampler_output.sampled_token_ids.numpy(),
+                        not_need_stop=model_output.not_need_stop.numpy().item(),
+                        batch=sampler_output.sampled_token_ids.shape[0],
+                        speculaive_decoding=False,
+                    ),
+                )
+
+                if not (not save_each_rank and model_output.mp_rank > 0):
+                    try:
+                        zmq_client.send_pyobj(stream_transfer_data)
+                    except Exception as e:
+                        print(f"Send message error: {e}")
+            else:
+                save_output(
+                    sampler_output.sampled_token_ids,
+                    model_output.not_need_stop,
+                    model_output.mp_rank,
+                    save_each_rank,
+                )
         else:
             save_output_topk(
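The send above is gated so that only rank 0 publishes unless save_each_rank is set. The double negation is equivalent to the small reading aid below (a hypothetical helper, not part of the commit):

def should_send(save_each_rank: bool, mp_rank: int) -> bool:
    """Reading aid: equivalent to `not (not save_each_rank and mp_rank > 0)` for mp_rank >= 0."""
    return save_each_rank or mp_rank == 0


assert should_send(False, 0) is True
assert should_send(False, 1) is False
assert should_send(True, 3) is True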
@@ -355,12 +380,15 @@ def post_process(
     save_each_rank: bool = False,
     speculative_decoding: bool = False,
     skip_save_output: bool = False,
+    zmq_client: ZmqClient = None,
 ) -> None:
     """Post-processing steps after completing a single token generation."""
     if speculative_decoding:
         post_process_specualate(model_output, save_each_rank, skip_save_output)
     else:
-        post_process_normal(sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output)
+        post_process_normal(
+            sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output, zmq_client
+        )


 def step_cuda(
fastdeploy/output/stream_transfer_data.py (new file, 72 lines)
@@ -0,0 +1,72 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

import numpy as np


class DecoderState(Enum):
    """DecoderState"""

    TEXT = "text"
    VISION = "vision"
    VEDIO = "vedio"
    AUDIO = "audio"


@dataclass
class TextData:
    """TextData"""

    tokens: np.array
    not_need_stop: bool
    batch: int
    speculaive_decoding: bool
    logprobs: Optional[np.array] = None
    accept_tokens: Optional[np.array] = None
    accept_num: Optional[np.array] = None


@dataclass
class VisionData:
    """VisionData"""

    tokens: np.array


@dataclass
class VedioData:
    """VedioData"""

    tokens: np.array


@dataclass
class AudioData:
    """AudioData"""

    tokens: np.array


@dataclass
class StreamTransferData:
    """StreamTransferData"""

    decoder_state: DecoderState
    data: Union[TextData, VisionData, VedioData, AudioData]
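StreamTransferData is a plain dataclass envelope, so it pickles cleanly, which is presumably what ZmqClient.send_pyobj/recv_pyobj delegate to (pyzmq's send_pyobj uses pickle; ZmqClient's internals are not part of this diff). A minimal sketch, not from the PR, showing how the pieces compose and round-trip:

import pickle

import numpy as np

from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData, TextData

# One sampled token per request in a batch of two; shape (batch, 1) matches how the
# consumer later reads tokens[:, 0].
payload = StreamTransferData(
    decoder_state=DecoderState.TEXT,
    data=TextData(
        tokens=np.array([[101], [102]]),
        not_need_stop=True,
        batch=2,
        speculaive_decoding=False,
    ),
)

restored = pickle.loads(pickle.dumps(payload))
assert restored.decoder_state is DecoderState.TEXT
assert restored.data.batch == 2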
@@ -24,11 +24,14 @@ from collections import Counter
 from concurrent.futures import ThreadPoolExecutor

 import numpy as np
+import paddle
+import zmq

 from fastdeploy import envs
 from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput
-from fastdeploy.inter_communicator import IPCSignal
+from fastdeploy.inter_communicator import IPCSignal, ZmqClient
 from fastdeploy.metrics.metrics import main_process_metrics
+from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
 from fastdeploy.platforms import current_platform
 from fastdeploy.utils import llm_logger, spec_logger
 from fastdeploy.worker.output import LogprobsLists
@@ -56,6 +59,11 @@ class TokenProcessor:
         self.tokens_counter = Counter()
         self.split_connector = split_connector

+        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            self.zmq_server = ZmqClient(name=f"get_save_output_rank{self.cfg.local_device_ids[0]}", mode=zmq.PULL)
+            self.zmq_server.start_server()
+            self.zmq_server.create_router()
+
         self.speculative_decoding = self.cfg.speculative_config.method is not None
         self.use_logprobs = self.cfg.model_config.enable_logprob

@@ -154,6 +162,25 @@ class TokenProcessor:

         while True:
             try:
-                is_blocking = True
-                if self.speculative_decoding:
-                    speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
+                if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+                    try:
+                        receive_data = self.zmq_server.recv_pyobj()
+                        assert isinstance(receive_data, StreamTransferData)
+                        if receive_data is not None:
+                            # TODO(Wanglongzhi2001): adapt more type of message.
+                            if receive_data.decoder_state == DecoderState.TEXT:
+                                self.output_tokens[0, 0] = paddle.to_tensor(
+                                    receive_data.data.not_need_stop, dtype="int64"
+                                )
+                                self.output_tokens[1, 0] = paddle.to_tensor(receive_data.data.batch, dtype="int64")
+                                self.output_tokens[2 : 2 + receive_data.data.batch, 0] = paddle.to_tensor(
+                                    receive_data.data.tokens[:, 0], dtype="int64"
+                                )
+
+                    except Exception as e:
+                        print(f"Recieve message error: {e}")
+                        continue
+                else:
+                    is_blocking = True
+                    if self.speculative_decoding:
+                        speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
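The V1 receive branch above appears to rebuild the same output_tokens layout that the legacy get_output path produces, so downstream parsing stays unchanged: row 0 holds the not_need_stop flag, row 1 the batch size, and rows 2..2+batch the sampled token ids. A small, hypothetical numpy sketch of that packing (the fill value and buffer sizing are illustrative; the real buffer is a paddle tensor allocated elsewhere):

import numpy as np


def pack_output_tokens(tokens: np.ndarray, not_need_stop: bool, max_bsz: int) -> np.ndarray:
    """Hypothetical helper mirroring the assignments in the V1 receive branch above."""
    batch = tokens.shape[0]
    out = np.full((2 + max_bsz, 1), -1, dtype="int64")  # fill value illustrative only
    out[0, 0] = int(not_need_stop)
    out[1, 0] = batch
    out[2 : 2 + batch, 0] = tokens[:, 0]
    return out


packed = pack_output_tokens(np.array([[11], [22], [33]]), not_need_stop=True, max_bsz=8)
assert packed[1, 0] == 3 and packed[2, 0] == 11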
@@ -70,8 +70,11 @@ from fastdeploy.model_executor.pre_and_post_process import (
 if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
     from fastdeploy.spec_decode import MTPProposer, NgramProposer

+import zmq
+
 from fastdeploy import envs
 from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
+from fastdeploy.inter_communicator import ZmqClient
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
 from fastdeploy.worker.model_runner_base import ModelRunnerBase
@@ -163,6 +166,12 @@ class GPUModelRunner(ModelRunnerBase):
         os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
         logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")

+        self.zmq_client = None
+        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            self.zmq_client = ZmqClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
+            self.zmq_client.connect()
+            self.zmq_client.socket.SNDTIMEO = 3000
+
     def exist_prefill(self):
         """
         check whether prefill stage exist
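Taken together with the TokenProcessor changes, the hunk above describes a per-rank PUSH/PULL pair: GPUModelRunner connects a PUSH client named get_save_output_rank{local_rank} with a 3000 ms send timeout, and TokenProcessor starts the matching PULL server. The ZmqClient internals are not part of this diff; the plain-pyzmq sketch below only illustrates the socket pattern, with a made-up IPC endpoint:

import zmq

ctx = zmq.Context()

# TokenProcessor side: PULL server bound to a per-rank endpoint (endpoint name is illustrative).
pull = ctx.socket(zmq.PULL)
pull.bind("ipc:///tmp/get_save_output_rank0.ipc")

# GPUModelRunner side: PUSH client with a 3000 ms send timeout, as configured above.
push = ctx.socket(zmq.PUSH)
push.SNDTIMEO = 3000
push.connect("ipc:///tmp/get_save_output_rank0.ipc")

push.send_pyobj({"decoder_state": "text", "tokens": [101, 102]})
print(pull.recv_pyobj())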
@@ -1219,6 +1228,7 @@ class GPUModelRunner(ModelRunnerBase):
                 block_size=self.cache_config.block_size,
                 speculative_decoding=self.speculative_decoding,
                 skip_save_output=True,
+                zmq_client=self.zmq_client,
             )

             if self.speculative_decoding:
@@ -1514,6 +1524,7 @@ class GPUModelRunner(ModelRunnerBase):
             save_each_rank=self.parallel_config.use_ep,
             speculative_decoding=self.speculative_decoding,
             skip_save_output=skip_save_output,
+            zmq_client=self.zmq_client,
         )
         if self.guided_backend is not None and sampler_output is not None:
             self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)
tests/output/test_get_save_output_v1.py (new file, 142 lines)
@@ -0,0 +1,142 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import signal
import socket
import subprocess
import time
import traceback

import pytest

from fastdeploy import LLM, SamplingParams

FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
MAX_WAIT_SECONDS = 60

os.environ["LD_LIBRARY_PATH"] = "/usr/local/nccl/"
# enbale get_save_output_v1
os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"


def is_port_open(host: str, port: int, timeout=1.0):
    """
    Check if a TCP port is open on the given host.
    Returns True if connection succeeds, False otherwise.
    """
    try:
        with socket.create_connection((host, port), timeout):
            return True
    except Exception:
        return False


@pytest.fixture(scope="module")
def model_path():
    """
    Get model path from environment variable MODEL_PATH,
    default to "./ERNIE-4.5-0.3B-Paddle" if not set.
    """
    base_path = os.getenv("MODEL_PATH")
    if base_path:
        return os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
    else:
        return "./ERNIE-4.5-0.3B-Paddle"


@pytest.fixture(scope="module")
def llm(model_path):
    """
    Fixture to initialize the LLM model with a given model path
    """
    try:
        output = subprocess.check_output(f"lsof -i:{FD_ENGINE_QUEUE_PORT} -t", shell=True).decode().strip()
        for pid in output.splitlines():
            os.kill(int(pid), signal.SIGKILL)
            print(f"Killed process on port {FD_ENGINE_QUEUE_PORT}, pid={pid}")
    except subprocess.CalledProcessError:
        pass

    try:
        start = time.time()
        llm = LLM(
            model=model_path,
            tensor_parallel_size=2,
            num_gpu_blocks_override=1024,
            engine_worker_queue_port=FD_ENGINE_QUEUE_PORT,
            max_model_len=8192,
            seed=1,
        )

        # Wait for the port to be open
        wait_start = time.time()
        while not is_port_open("127.0.0.1", FD_ENGINE_QUEUE_PORT):
            if time.time() - wait_start > MAX_WAIT_SECONDS:
                pytest.fail(
                    f"Model engine did not start within {MAX_WAIT_SECONDS} seconds on port {FD_ENGINE_QUEUE_PORT}"
                )
            time.sleep(1)

        print(f"Model loaded successfully from {model_path} in {time.time() - start:.2f}s.")
        yield llm
    except Exception:
        print(f"Failed to load model from {model_path}.")
        traceback.print_exc()
        pytest.fail(f"Failed to initialize LLM model from {model_path}")


def test_generate_prompts(llm):
    """
    Test basic prompt generation
    """

    # Only one prompt enabled for testing currently
    prompts = [
        "请介绍一下中国的四大发明。",
        "太阳和地球之间的距离是多少?",
        "写一首关于春天的古风诗。",
    ]

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
    )

    try:
        outputs = llm.generate(prompts, sampling_params)

        # Verify basic properties of the outputs
        assert len(outputs) == len(prompts), "Number of outputs should match number of prompts"

        for i, output in enumerate(outputs):
            assert output.prompt == prompts[i], f"Prompt mismatch for case {i + 1}"
            assert isinstance(output.outputs.text, str), f"Output text should be string for case {i + 1}"
            assert len(output.outputs.text) > 0, f"Generated text should not be empty for case {i + 1}"
            assert isinstance(output.finished, bool), f"'finished' should be boolean for case {i + 1}"
            assert output.metrics.model_execute_time > 0, f"Execution time should be positive for case {i + 1}"

            print(f"=== Prompt generation Case {i + 1} Passed ===")

    except Exception:
        print("Failed during prompt generation.")
        traceback.print_exc()
        pytest.fail("Prompt generation test failed")


if __name__ == "__main__":
    """
    Main entry point for the test script.
    """
    pytest.main(["-sv", __file__])
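The new test sets FD_USE_GET_SAVE_OUTPUT_V1=1 itself and builds the LLM with tensor_parallel_size=2, so it assumes at least two visible GPUs plus the ERNIE-4.5-0.3B-Paddle weights (via MODEL_PATH or the default relative path). A typical invocation, equivalent to the file's own __main__ block, with the repository-root relative path assumed:

import pytest

raise SystemExit(pytest.main(["-sv", "tests/output/test_get_save_output_v1.py"]))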