mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-10 02:50:19 +08:00
Reconstruct streaming data transfer with zmq (#3836)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* reconstruct USE_GET_SAVE_OUTPUT_V1 * fix ut * use dp rank * fix ci
This commit is contained in:
@@ -14,8 +14,10 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy import envs
|
||||
@@ -77,11 +79,7 @@ else:
|
||||
)
|
||||
|
||||
from fastdeploy.inter_communicator import ZmqClient
|
||||
from fastdeploy.output.stream_transfer_data import (
|
||||
DecoderState,
|
||||
StreamTransferData,
|
||||
TextData,
|
||||
)
|
||||
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
|
||||
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
|
||||
|
||||
DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
|
||||
@@ -162,6 +160,26 @@ def pre_process(
|
||||
)
|
||||
|
||||
|
||||
def _zmq_send_text_outputs(zmq_client: ZmqClient, output_tokens: np.ndarray, save_each_rank: bool, mp_rank: int):
|
||||
"""Split output_tokens and output"""
|
||||
assert zmq_client is not None, "zmq_client should not be None"
|
||||
output_tokens = output_tokens.reshape([-1]).numpy()
|
||||
output_tokens_lists = np.split(output_tokens, output_tokens.shape[0])
|
||||
|
||||
stream_transfer_datas = []
|
||||
for bid, output_token_per_sample in enumerate(output_tokens_lists):
|
||||
stream_transfer_data = StreamTransferData(
|
||||
decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid
|
||||
)
|
||||
stream_transfer_datas.append(stream_transfer_data)
|
||||
|
||||
if save_each_rank or mp_rank == 0:
|
||||
try:
|
||||
zmq_client.send_pyobj(stream_transfer_datas)
|
||||
except Exception as e:
|
||||
print(f"Send message error: {e}")
|
||||
|
||||
|
||||
def post_process_normal(
|
||||
sampler_output: SamplerOutput,
|
||||
model_output: ModelOutputData,
|
||||
@@ -297,22 +315,11 @@ def post_process_normal(
|
||||
if not skip_save_output:
|
||||
if sampler_output.logprobs_tensors is None:
|
||||
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||
# TODO(Wanglongzhi2001): adapt more type of message.
|
||||
stream_transfer_data = StreamTransferData(
|
||||
decoder_state=DecoderState.TEXT,
|
||||
data=TextData(
|
||||
tokens=sampler_output.sampled_token_ids.numpy(),
|
||||
not_need_stop=model_output.not_need_stop.numpy().item(),
|
||||
batch=sampler_output.sampled_token_ids.shape[0],
|
||||
speculaive_decoding=False,
|
||||
),
|
||||
t = threading.Thread(
|
||||
target=_zmq_send_text_outputs,
|
||||
args=(zmq_client, sampler_output.sampled_token_ids, save_each_rank, model_output.mp_rank),
|
||||
)
|
||||
|
||||
if not (not save_each_rank and model_output.mp_rank > 0):
|
||||
try:
|
||||
zmq_client.send_pyobj(stream_transfer_data)
|
||||
except Exception as e:
|
||||
print(f"Send message error: {e}")
|
||||
t.start()
|
||||
else:
|
||||
save_output(
|
||||
sampler_output.sampled_token_ids,
|
||||
|
@@ -16,7 +16,7 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -25,48 +25,19 @@ class DecoderState(Enum):
|
||||
"""DecoderState"""
|
||||
|
||||
TEXT = "text"
|
||||
VISION = "vision"
|
||||
VEDIO = "vedio"
|
||||
IMAGE = "image"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextData:
|
||||
"""TextData"""
|
||||
|
||||
tokens: np.array
|
||||
not_need_stop: bool
|
||||
batch: int
|
||||
speculaive_decoding: bool
|
||||
logprobs: Optional[np.array] = None
|
||||
accept_tokens: Optional[np.array] = None
|
||||
accept_num: Optional[np.array] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionData:
|
||||
"""VisionData"""
|
||||
|
||||
tokens: np.array
|
||||
|
||||
|
||||
@dataclass
|
||||
class VedioData:
|
||||
"""VedioData"""
|
||||
|
||||
tokens: np.array
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioData:
|
||||
"""AudioData"""
|
||||
|
||||
tokens: np.array
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamTransferData:
|
||||
"""StreamTransferData"""
|
||||
|
||||
decoder_state: DecoderState
|
||||
data: Union[TextData, VisionData, VedioData, AudioData]
|
||||
tokens: np.array
|
||||
batch_id: int
|
||||
speculaive_decoding: bool = False
|
||||
logprobs: Optional[np.array] = None
|
||||
accept_tokens: Optional[np.array] = None
|
||||
accept_num: Optional[np.array] = None
|
||||
|
@@ -31,7 +31,6 @@ from fastdeploy import envs
|
||||
from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput
|
||||
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.utils import llm_logger, spec_logger
|
||||
from fastdeploy.worker.output import LogprobsLists
|
||||
@@ -49,7 +48,6 @@ class TokenProcessor:
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_connector):
|
||||
import paddle
|
||||
|
||||
paddle.device.set_device("cpu")
|
||||
self.cfg = cfg
|
||||
@@ -60,7 +58,10 @@ class TokenProcessor:
|
||||
self.split_connector = split_connector
|
||||
|
||||
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||
self.zmq_server = ZmqClient(name=f"get_save_output_rank{self.cfg.local_device_ids[0]}", mode=zmq.PULL)
|
||||
llm_logger.debug(f"create zmq get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}")
|
||||
self.zmq_server = ZmqClient(
|
||||
name=f"get_save_output_rank{self.cfg.parallel_config.local_data_parallel_id}", mode=zmq.PULL
|
||||
)
|
||||
self.zmq_server.start_server()
|
||||
self.zmq_server.create_router()
|
||||
|
||||
@@ -135,11 +136,145 @@ class TokenProcessor:
|
||||
if self.worker is not None:
|
||||
raise Exception("Worker is already running!")
|
||||
|
||||
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||
self.worker = threading.Thread(target=self.process_sampling_results_use_zmq)
|
||||
else:
|
||||
self.worker = threading.Thread(target=self.process_sampling_results)
|
||||
|
||||
self.worker.daemon = True
|
||||
self.worker.start()
|
||||
|
||||
def _reschedule_preempt_task(self, batch_size):
|
||||
"""reschedule when real batch size is smaller than the insert position of preemted_task"""
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
need_to_be_reschedule_req_ids = list(self.resource_manager.to_be_rescheduled_request_id_set)
|
||||
for request_id in need_to_be_reschedule_req_ids:
|
||||
if self.resource_manager.requests[request_id].idx >= (
|
||||
batch_size - 1
|
||||
): # No more token generated for preempted request
|
||||
self.resource_manager.reschedule_preempt_task(request_id)
|
||||
|
||||
def _process_per_token(self, task, batch_id: int, token_ids: np.ndarray, result: RequestOutput, is_prefill: bool):
|
||||
"""
|
||||
process output token by token
|
||||
"""
|
||||
current_time = time.time()
|
||||
task_id = task.request_id
|
||||
token_id_list = token_ids.tolist()
|
||||
|
||||
self._record_metrics(task, current_time, token_id_list)
|
||||
for token_id in token_id_list:
|
||||
recovery_stop = token_id == RECOVERY_STOP_SIGNAL
|
||||
if recovery_stop:
|
||||
llm_logger.info(f"recovery stop signal found at task {task_id}")
|
||||
self.tokens_counter[task_id] += 1
|
||||
if token_id != RECOVERY_STOP_SIGNAL:
|
||||
result.outputs.token_ids.append(token_id)
|
||||
task.output_token_ids.append(token_id)
|
||||
|
||||
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
|
||||
result.finished = True
|
||||
if recovery_stop:
|
||||
result.error_msg = "Recover is not supported, the result is incomplete!"
|
||||
llm_logger.info(
|
||||
f"Request: {task_id} finished, number of " f"generated tokens: {self.tokens_counter[task_id]}."
|
||||
)
|
||||
llm_logger.info(
|
||||
f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
|
||||
)
|
||||
llm_logger.info(f"{self.resource_manager.info()}")
|
||||
if self.cfg.speculative_config.method:
|
||||
self._compute_speculative_status()
|
||||
if not is_prefill:
|
||||
self._record_completion_metrics(task, current_time)
|
||||
self._recycle_resources(task_id, batch_id, task, result, is_prefill)
|
||||
break
|
||||
return result
|
||||
|
||||
def _process_batch_output_use_zmq(self, receive_datas):
|
||||
"""
|
||||
process output sample by sample
|
||||
"""
|
||||
batch_result = list()
|
||||
for _, stream_data in enumerate(receive_datas):
|
||||
i = stream_data.batch_id
|
||||
if self.resource_manager.stop_flags[i]:
|
||||
continue
|
||||
|
||||
task = self.resource_manager.tasks_list[i]
|
||||
|
||||
task_id = task.request_id
|
||||
token_ids = stream_data.tokens # numpy.array
|
||||
|
||||
current_time = time.time()
|
||||
if self.tokens_counter[task_id] == 0:
|
||||
metrics = RequestMetrics(
|
||||
arrival_time=task.arrival_time,
|
||||
inference_start_time=task.inference_start_time,
|
||||
first_token_time=time.time() - task.inference_start_time,
|
||||
time_in_queue=task.schedule_start_time - task.preprocess_end_time,
|
||||
preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
|
||||
request_start_time=task.arrival_time,
|
||||
)
|
||||
self._record_first_token_metrics(task, current_time)
|
||||
|
||||
else:
|
||||
metrics = RequestMetrics(
|
||||
arrival_time=time.time(),
|
||||
request_start_time=task.arrival_time,
|
||||
)
|
||||
|
||||
result = RequestOutput(
|
||||
request_id=task_id,
|
||||
outputs=CompletionOutput(
|
||||
index=i,
|
||||
send_idx=self.tokens_counter[task_id],
|
||||
token_ids=[],
|
||||
draft_token_ids=[],
|
||||
),
|
||||
finished=False,
|
||||
metrics=metrics,
|
||||
)
|
||||
|
||||
if self.tokens_counter[task_id] == 0:
|
||||
if task.messages is not None:
|
||||
result.prompt = task.messages
|
||||
result.num_cached_tokens = task.num_cached_tokens
|
||||
|
||||
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
|
||||
result = self._process_per_token(task, i, token_ids, result, is_prefill)
|
||||
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
|
||||
batch_result.append(result)
|
||||
|
||||
return batch_result
|
||||
|
||||
def process_sampling_results_use_zmq(self):
|
||||
"""
|
||||
use zmq to receive outputs from worker and process them
|
||||
"""
|
||||
if self.speculative_decoding:
|
||||
raise NotImplementedError("GET_SAVE_OUTPUT_V1 does not support speculative decoding")
|
||||
if self.use_logprobs:
|
||||
raise NotImplementedError("GET_SAVE_OUTPUT_V1 does not support use_logprobs")
|
||||
rank_id = self.cfg.parallel_config.local_data_parallel_id
|
||||
while True:
|
||||
try:
|
||||
if (
|
||||
self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1
|
||||
) or (rank_id == 0):
|
||||
receive_datas = self.zmq_server.recv_pyobj()
|
||||
assert isinstance(receive_datas, list)
|
||||
llm_logger.debug(f"token_processor receive_data {receive_datas}")
|
||||
|
||||
batch_size = len(receive_datas)
|
||||
self._reschedule_preempt_task(batch_size)
|
||||
|
||||
batch_result = self._process_batch_output_use_zmq(receive_datas)
|
||||
self.postprocess(batch_result)
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Recieve message error: {e}")
|
||||
continue
|
||||
|
||||
def process_sampling_results(self):
|
||||
"""
|
||||
read tokens from paddle inference engine and process
|
||||
@@ -162,25 +297,6 @@ class TokenProcessor:
|
||||
|
||||
while True:
|
||||
try:
|
||||
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||
try:
|
||||
receive_data = self.zmq_server.recv_pyobj()
|
||||
assert isinstance(receive_data, StreamTransferData)
|
||||
if receive_data is not None:
|
||||
# TODO(Wanglongzhi2001): adapt more type of message.
|
||||
if receive_data.decoder_state == DecoderState.TEXT:
|
||||
self.output_tokens[0, 0] = paddle.to_tensor(
|
||||
receive_data.data.not_need_stop, dtype="int64"
|
||||
)
|
||||
self.output_tokens[1, 0] = paddle.to_tensor(receive_data.data.batch, dtype="int64")
|
||||
self.output_tokens[2 : 2 + receive_data.data.batch, 0] = paddle.to_tensor(
|
||||
receive_data.data.tokens[:, 0], dtype="int64"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Receive message error: {e}")
|
||||
continue
|
||||
else:
|
||||
is_blocking = True
|
||||
if self.speculative_decoding:
|
||||
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
|
||||
@@ -336,13 +452,8 @@ class TokenProcessor:
|
||||
tokens = tokens[2 : batch + 2]
|
||||
|
||||
batch_result = list()
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
need_to_be_reschedule_req_ids = list(self.resource_manager.to_be_rescheduled_request_id_set)
|
||||
for request_id in need_to_be_reschedule_req_ids:
|
||||
if self.resource_manager.requests[request_id].idx >= (
|
||||
batch - 1
|
||||
): # No more token generated for preempted request
|
||||
self.resource_manager.reschedule_preempt_task(request_id)
|
||||
# reschedule
|
||||
self._reschedule_preempt_task(batch)
|
||||
for i in range(batch):
|
||||
if self.resource_manager.stop_flags[i]:
|
||||
continue
|
||||
|
@@ -170,6 +170,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
self.zmq_client = None
|
||||
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
|
||||
logger.info(f"zmq client get_save_output_rank{local_rank}")
|
||||
self.zmq_client = ZmqClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH)
|
||||
self.zmq_client.connect()
|
||||
self.zmq_client.socket.SNDTIMEO = 3000
|
||||
|
Reference in New Issue
Block a user