[Feature] support fd return decode response (#4300)

* fix

* fix

* fix

* [Feature] support clear data

* update

* fix

* fix

* fix

* fix

* [BugFix] fix clear data

* Update api_server.py

* Update api_server.py

* [Feature] support fd decode response

* Update engine.py

* Update envs.py

* Update expert_service.py

* Update common_engine.py

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: ltd0924 <luotingdan@baidu.com>
This commit is contained in:
ltd0924
2025-09-28 16:11:50 +08:00
committed by GitHub
parent c8985727a6
commit c35a21a99a
4 changed files with 58 additions and 21 deletions

View File

@@ -33,6 +33,7 @@ from opentelemetry import trace
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import (
EngineCacheQueue,
EngineWorkerQueue,
@@ -129,6 +130,17 @@ class EngineSevice:
self.token_processor.tasks_queue = self.engine_worker_queue
self.token_processor.run()
def create_data_processor(self):
self.input_processor = InputPreprocessor(
self.cfg.tokenizer,
self.cfg.reasoning_parser,
self.cfg.limit_mm_per_prompt,
self.cfg.mm_processor_kwargs,
self.cfg.model_config.enable_mm,
self.cfg.tool_parser,
)
self.data_processor = self.input_processor.create_processor()
def _init_worker_monitor_signals(self): # exist_task_signal 用于各worker进程感知是否有新Task需要处理
current_suffix = int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id])
llm_logger.info(f"current_suffix: {current_suffix}")
@@ -664,6 +676,20 @@ class EngineSevice:
f"traceback={traceback.format_exc()}"
)
def _decode_token(self, token_ids, req_id, is_end):
delta_text = ""
if envs.FD_ENABLE_RETURN_TEXT:
delta_text, cum_tokens, _ = self.data_processor.ids2tokens(token_ids, req_id)
if delta_text != "":
prefix_offset = self.data_processor.decode_status[req_id][0]
read_offset = self.data_processor.decode_status[req_id][1]
token_ids = cum_tokens[prefix_offset:read_offset]
else:
token_ids = []
if is_end:
del self.data_processor.decode_status[req_id]
return delta_text, token_ids
def _zmq_send_generated_tokens(self):
"""
Recieve output for zmq
@@ -675,7 +701,23 @@ class EngineSevice:
time.sleep(0.005)
continue
for request_id, contents in results.items():
self.send_response_server.send_response(request_id, contents)
new_contents = []
for content in contents:
delta_text, token_ids = self._decode_token(
token_ids=content.outputs.token_ids, req_id=request_id, is_end=content.finished
)
if len(token_ids):
content.outputs.token_ids = token_ids
content.outputs.text = delta_text
new_contents.append(content)
else:
llm_logger.warning(
f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
)
if len(new_contents):
llm_logger.info(f"Send response for request id: {request_id}")
self.send_response_server.send_response(request_id, new_contents)
except Exception as e:
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")

View File

@@ -37,7 +37,6 @@ from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.common_engine import EngineSevice
from fastdeploy.engine.expert_service import start_data_parallel_service
from fastdeploy.engine.request import Request
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
@@ -85,14 +84,6 @@ class LLMEngine:
self.running = True
self.is_started = False
self.input_processor = InputPreprocessor(
cfg.tokenizer,
cfg.reasoning_parser,
cfg.limit_mm_per_prompt,
cfg.mm_processor_kwargs,
cfg.model_config.enable_mm,
cfg.tool_parser,
)
self.engine = EngineSevice(cfg)
if self.cfg.cache_config.num_gpu_blocks_override is None:
@@ -114,10 +105,9 @@ class LLMEngine:
self.ipc_signal_suffix = self.cfg.engine_worker_queue_port[0]
self._init_worker_signals()
self.data_processor = self.input_processor.create_processor()
self.engine.data_processor = self.data_processor
self.engine.start()
self.engine.create_data_processor()
self.data_processor = self.engine.data_processor
if api_server_pid is not None:
llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}")
self.engine.start_zmq_service(api_server_pid)
@@ -199,7 +189,7 @@ class LLMEngine:
request.sampling_params = sampling_params
request.preprocess_start_time = time.time()
request = self.data_processor.process_request(request, self.cfg.max_model_len, **kwargs)
request = self.engine.data_processor.process_request(request, self.cfg.max_model_len, **kwargs)
request.prompt_token_ids_len = len(request.prompt_token_ids)
request.need_prefill_tokens = request.prompt_token_ids_len
input_ids_len = request.prompt_token_ids_len
@@ -431,9 +421,9 @@ class LLMEngine:
py_script = os.path.join(current_dir_path, worker_path)
ori_vocab_size = (
len(self.data_processor.tokenizer.sp_model)
if hasattr(self.data_processor.tokenizer, "sp_model")
else len(self.data_processor.tokenizer.vocab)
len(self.engine.data_processor.tokenizer.sp_model)
if hasattr(self.engine.data_processor.tokenizer, "sp_model")
else len(self.engine.data_processor.tokenizer.vocab)
)
ports = ",".join(self.cfg.engine_worker_queue_port)
@@ -452,8 +442,8 @@ class LLMEngine:
f" --total_block_num {self.cfg.cache_config.total_block_num}"
f" --block_size {self.cfg.cache_config.block_size}"
f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}"
f" --eos_tokens_lens {self.data_processor.eos_token_id_len}"
f" --pad_token_id {self.data_processor.pad_token_id}"
f" --eos_tokens_lens {self.engine.data_processor.eos_token_id_len}"
f" --pad_token_id {self.engine.data_processor.pad_token_id}"
f" --engine_pid {self.cfg.engine_worker_queue_port[0]}"
f" --max_num_batched_tokens {self.cfg.max_num_batched_tokens}"
f" --splitwise_role {self.cfg.splitwise_role}"
@@ -545,7 +535,7 @@ class LLMEngine:
for result in self._get_generated_tokens(req_id):
is_end = result.finished
if stream and not is_end:
processed = self.data_processor.process_response(result)
processed = self.engine.data_processor.process_response(result)
if processed is None:
continue
output = processed.to_dict()
@@ -553,7 +543,7 @@ class LLMEngine:
# Exit loop if termination condition is met
if is_end:
processed = self.data_processor.process_response(result)
processed = self.engine.data_processor.process_response(result)
output = processed.to_dict()
llm_logger.debug(f"Generate result: {output}")
if not stream:

View File

@@ -80,6 +80,9 @@ class ExpertService:
start_time = time.time()
self.engine.start()
if envs.FD_ENABLE_RETURN_TEXT:
self.engine.create_data_processor()
if ipc_signal_suffix is not None:
self.api_server_pid = ipc_signal_suffix
self.engine.start_zmq_service(ipc_signal_suffix)

View File

@@ -109,6 +109,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
# LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
# enable return text, used when FD_ENABLE_INTERNAL_ADAPTER=1
"FD_ENABLE_RETURN_TEXT": lambda: bool(int(os.getenv("FD_ENABLE_RETURN_TEXT", "0"))),
}