mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[LLM] support send batch data and aggregate data (#2860)
* [LLM] support send batch data and aggregate data * [LLM] fix ci bugs * [LLM] fix ci bugs * [LLM] fix ci bugs * [LLM] fix ci bugs * [LLM] update
This commit is contained in:
@@ -263,10 +263,11 @@ class LLMEngine(object):
|
||||
try:
|
||||
results = self.scheduler.get_results()
|
||||
if len(results) == 0:
|
||||
time.sleep(0.001)
|
||||
time.sleep(0.005)
|
||||
continue
|
||||
for request_id, contents in results.items():
|
||||
for result in contents:
|
||||
self.zmq_server.send_multipart(request_id, result)
|
||||
self.zmq_server.send_multipart(request_id, contents)
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error("Unexcepted error happend: {}, {}".format(
|
||||
e, str(traceback.format_exc())))
|
||||
|
@@ -20,7 +20,7 @@ import time
|
||||
from dataclasses import asdict, dataclass, fields
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import numpy
|
||||
import numpy as np
|
||||
|
||||
from fastdeploy.engine.sampling_params import SamplingParams
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
@@ -181,7 +181,7 @@ class Request:
|
||||
f"sampling_params={self.sampling_params})")
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class CompletionOutput:
|
||||
"""The output data of one completion output of a request.
|
||||
|
||||
@@ -235,7 +235,7 @@ class CompletionOutput:
|
||||
f"reasoning_content={self.reasoning_content!r}")
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class RequestMetrics:
|
||||
"""Metrics associated with a request.
|
||||
|
||||
@@ -310,6 +310,10 @@ class RequestOutput:
|
||||
None if decoder-only.
|
||||
num_cached_tokens: The number of tokens with prefix cache hit.
|
||||
"""
|
||||
__slots__ = (
|
||||
'request_id', 'prompt', 'prompt_token_ids', 'outputs',
|
||||
'finished', 'metrics', 'num_cached_tokens', 'error_code', 'error_msg'
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -333,6 +337,12 @@ class RequestOutput:
|
||||
self.error_code = error_code
|
||||
self.error_msg = error_msg
|
||||
|
||||
|
||||
if prompt_token_ids is None:
|
||||
self.prompt_token_ids = []
|
||||
elif isinstance(self.prompt_token_ids, np.ndarray):
|
||||
self.prompt_token_ids = self.prompt_token_ids.tolist()
|
||||
|
||||
def add(self, next_output: "RequestOutput") -> None:
|
||||
"""Merge RequestOutput into this one"""
|
||||
|
||||
@@ -365,11 +375,6 @@ class RequestOutput:
|
||||
|
||||
def to_dict(self):
|
||||
"""convert RequestOutput into a serializable dict """
|
||||
if self.prompt_token_ids is None:
|
||||
self.prompt_token_ids = []
|
||||
|
||||
if type(self.prompt_token_ids) is numpy.ndarray:
|
||||
self.prompt_token_ids = self.prompt_token_ids.tolist()
|
||||
|
||||
return {
|
||||
"request_id": self.request_id,
|
||||
|
@@ -169,6 +169,8 @@ class LLM:
|
||||
|
||||
# get output
|
||||
outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
|
||||
for i in range(len(outputs)):
|
||||
outputs[i].prompt = prompts[i]
|
||||
return outputs
|
||||
|
||||
def chat(
|
||||
|
@@ -21,6 +21,7 @@ import traceback
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
import msgpack
|
||||
import aiozmq
|
||||
from aiozmq import zmq
|
||||
|
||||
@@ -143,6 +144,8 @@ class OpenAIServingChat:
|
||||
dealer.write([b"", request_id.encode('utf-8')])
|
||||
choices = []
|
||||
current_waiting_time = 0
|
||||
if request.metadata is not None:
|
||||
enable_thinking = request.metadata.get("enable_thinking")
|
||||
while num_choices > 0:
|
||||
try:
|
||||
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
|
||||
@@ -158,102 +161,106 @@ class OpenAIServingChat:
|
||||
raise ValueError(f"Engine is not healthy: {msg}")
|
||||
else:
|
||||
current_waiting_time = 0
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
response = msgpack.unpackb(raw_data[-1])
|
||||
for res in response:
|
||||
if res.get("error_code", 200) != 200:
|
||||
raise ValueError("{}".format(res["error_msg"]))
|
||||
|
||||
res = json.loads(raw_data[-1].decode('utf-8'))
|
||||
if res.get("error_code", 200) != 200:
|
||||
raise ValueError("{}".format(res["error_msg"]))
|
||||
if request.metadata is not None:
|
||||
enable_thinking = request.metadata.get("enable_thinking")
|
||||
self.engine_client.data_processor.process_response_dict(
|
||||
res, stream=True, enable_thinking=enable_thinking)
|
||||
self.engine_client.data_processor.process_response_dict(
|
||||
res, stream=True, enable_thinking=enable_thinking)
|
||||
|
||||
if res['metrics']['first_token_time'] is not None:
|
||||
arrival_time = res['metrics']['first_token_time']
|
||||
inference_start_time = res['metrics']['inference_start_time']
|
||||
else:
|
||||
arrival_time = res['metrics']['arrival_time'] - inference_start_time
|
||||
if first_iteration:
|
||||
num_prompt_tokens = len(prompt_token_ids)
|
||||
num_cached_tokens = res.get("num_cached_tokens", 0)
|
||||
for i in range(num_choices):
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(role="assistant", content="", reasoning_content="", tool_calls=None)
|
||||
)
|
||||
if request.metadata is not None and request.metadata.get("training", False):
|
||||
choice.delta.token_ids = prompt_token_ids
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice],
|
||||
model=model_name
|
||||
)
|
||||
if include_continuous_usage:
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=num_prompt_tokens,
|
||||
prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens)
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n"
|
||||
first_iteration = False
|
||||
|
||||
output = res["outputs"]
|
||||
delta_text = output["text"]
|
||||
raw_top_logprobs = output["top_logprobs"]
|
||||
logprobs_res = None
|
||||
if raw_top_logprobs is not None:
|
||||
top_logprobs = LogprobsLists(
|
||||
logprob_token_ids=raw_top_logprobs[0],
|
||||
logprobs=raw_top_logprobs[1],
|
||||
sampled_token_ranks=raw_top_logprobs[2],
|
||||
)
|
||||
logprobs_res = self.build_logprobs_response(
|
||||
request_logprobs=request.logprobs,
|
||||
response_logprobs=top_logprobs,
|
||||
request_top_logprobs=request.top_logprobs,
|
||||
)
|
||||
|
||||
previous_num_tokens += len(output["token_ids"])
|
||||
delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \
|
||||
token_ids=output.get("token_ids"), tool_calls=output.get("tool_call_content", []))
|
||||
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=delta_message,
|
||||
logprobs=logprobs_res,
|
||||
arrival_time=arrival_time
|
||||
)
|
||||
if res["finished"]:
|
||||
num_choices -= 1
|
||||
work_process_metrics.e2e_request_latency.observe(time.time() - res["metrics"]["request_start_time"])
|
||||
has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
|
||||
max_tokens = request.max_completion_tokens or request.max_tokens
|
||||
if has_no_token_limit or previous_num_tokens != max_tokens:
|
||||
choice.finish_reason = "stop"
|
||||
if self.engine_client.reasoning_parser == "ernie_x1" and \
|
||||
output.get("finish_reason", "") == "tool_calls":
|
||||
choice.finish_reason = "tool_calls"
|
||||
if res['metrics']['first_token_time'] is not None:
|
||||
arrival_time = res['metrics']['first_token_time']
|
||||
inference_start_time = res['metrics']['inference_start_time']
|
||||
else:
|
||||
choice.finish_reason = "length"
|
||||
arrival_time = res['metrics']['arrival_time'] - inference_start_time
|
||||
if first_iteration:
|
||||
num_prompt_tokens = len(prompt_token_ids)
|
||||
num_cached_tokens = res.get("num_cached_tokens", 0)
|
||||
for i in range(num_choices):
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(role="assistant", content="", reasoning_content="", tool_calls=None)
|
||||
)
|
||||
if request.metadata is not None and request.metadata.get("training", False):
|
||||
choice.delta.token_ids = prompt_token_ids
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice],
|
||||
model=model_name
|
||||
)
|
||||
if include_continuous_usage:
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=num_prompt_tokens,
|
||||
prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens)
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json(exclude_unset=True)} \n\n"
|
||||
first_iteration = False
|
||||
|
||||
if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
|
||||
choice.finish_reason = "recover_stop"
|
||||
output = res["outputs"]
|
||||
delta_text = output["text"]
|
||||
raw_top_logprobs = output["top_logprobs"]
|
||||
logprobs_res = None
|
||||
if raw_top_logprobs is not None:
|
||||
top_logprobs = LogprobsLists(
|
||||
logprob_token_ids=raw_top_logprobs[0],
|
||||
logprobs=raw_top_logprobs[1],
|
||||
sampled_token_ranks=raw_top_logprobs[2],
|
||||
)
|
||||
logprobs_res = self.build_logprobs_response(
|
||||
request_logprobs=request.logprobs,
|
||||
response_logprobs=top_logprobs,
|
||||
request_top_logprobs=request.top_logprobs,
|
||||
)
|
||||
|
||||
if request.metadata is not None and request.metadata.get("training", False) and delta_text != "":
|
||||
choice.delta.token_ids = output["token_ids"]
|
||||
if include_continuous_usage:
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=previous_num_tokens,
|
||||
total_tokens=num_prompt_tokens + previous_num_tokens
|
||||
previous_num_tokens += len(output["token_ids"])
|
||||
delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \
|
||||
token_ids=output.get("token_ids"), tool_calls=output.get("tool_call_content", []))
|
||||
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=delta_message,
|
||||
logprobs=logprobs_res,
|
||||
arrival_time=arrival_time
|
||||
)
|
||||
choices.append(choice)
|
||||
if res["finished"]:
|
||||
num_choices -= 1
|
||||
work_process_metrics.e2e_request_latency.observe(time.time() - res["metrics"]["request_start_time"])
|
||||
has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
|
||||
max_tokens = request.max_completion_tokens or request.max_tokens
|
||||
if has_no_token_limit or previous_num_tokens != max_tokens:
|
||||
choice.finish_reason = "stop"
|
||||
if self.engine_client.reasoning_parser == "ernie_x1" and \
|
||||
output.get("finish_reason", "") == "tool_calls":
|
||||
choice.finish_reason = "tool_calls"
|
||||
else:
|
||||
choice.finish_reason = "length"
|
||||
|
||||
if len(choices) == max_streaming_response_tokens or res["finished"]:
|
||||
if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
|
||||
choice.finish_reason = "recover_stop"
|
||||
|
||||
if request.metadata is not None and request.metadata.get("training", False) and delta_text != "":
|
||||
choice.delta.token_ids = output["token_ids"]
|
||||
if include_continuous_usage:
|
||||
chunk.usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=previous_num_tokens,
|
||||
total_tokens=num_prompt_tokens + previous_num_tokens
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
if len(choices) == max_streaming_response_tokens or res["finished"]:
|
||||
chunk.choices = choices
|
||||
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||
choices = []
|
||||
|
||||
if choices:
|
||||
chunk.choices = choices
|
||||
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||
choices = []
|
||||
@@ -321,33 +328,38 @@ class OpenAIServingChat:
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
data = json.loads(raw_data[-1].decode('utf-8'))
|
||||
if data.get("error_code", 200) != 200:
|
||||
raise ValueError("{}".format(data["error_msg"]))
|
||||
if request.metadata is not None:
|
||||
enable_thinking = request.metadata.get("enable_thinking")
|
||||
data = self.engine_client.data_processor.process_response_dict(
|
||||
data, stream=False, enable_thinking=enable_thinking)
|
||||
# api_server_logger.debug(f"Client {request_id} received: {data}")
|
||||
previous_num_tokens += len(data["outputs"]["token_ids"])
|
||||
# The logprob for handling the response
|
||||
output = data["outputs"]
|
||||
raw_top_logprobs = output["top_logprobs"]
|
||||
if raw_top_logprobs is not None:
|
||||
top_logprobs = LogprobsLists(
|
||||
logprob_token_ids=raw_top_logprobs[0],
|
||||
logprobs=raw_top_logprobs[1],
|
||||
sampled_token_ranks=raw_top_logprobs[2],
|
||||
)
|
||||
logprobs_res = self.build_logprobs_response(
|
||||
request_logprobs=request.logprobs,
|
||||
response_logprobs=top_logprobs,
|
||||
request_top_logprobs=request.top_logprobs,
|
||||
)
|
||||
if logprobs_res and logprobs_res.content is not None:
|
||||
logprob_contents.extend(logprobs_res.content)
|
||||
if data["finished"]:
|
||||
final_res = data
|
||||
response = msgpack.unpackb(raw_data[-1])
|
||||
task_is_finished = False
|
||||
for data in response:
|
||||
if data.get("error_code", 200) != 200:
|
||||
raise ValueError("{}".format(data["error_msg"]))
|
||||
if request.metadata is not None:
|
||||
enable_thinking = request.metadata.get("enable_thinking")
|
||||
data = self.engine_client.data_processor.process_response_dict(
|
||||
data, stream=False, enable_thinking=enable_thinking)
|
||||
# api_server_logger.debug(f"Client {request_id} received: {data}")
|
||||
previous_num_tokens += len(data["outputs"]["token_ids"])
|
||||
# The logprob for handling the response
|
||||
output = data["outputs"]
|
||||
raw_top_logprobs = output["top_logprobs"]
|
||||
if raw_top_logprobs is not None:
|
||||
top_logprobs = LogprobsLists(
|
||||
logprob_token_ids=raw_top_logprobs[0],
|
||||
logprobs=raw_top_logprobs[1],
|
||||
sampled_token_ranks=raw_top_logprobs[2],
|
||||
)
|
||||
logprobs_res = self.build_logprobs_response(
|
||||
request_logprobs=request.logprobs,
|
||||
response_logprobs=top_logprobs,
|
||||
request_top_logprobs=request.top_logprobs,
|
||||
)
|
||||
if logprobs_res and logprobs_res.content is not None:
|
||||
logprob_contents.extend(logprobs_res.content)
|
||||
if data["finished"]:
|
||||
final_res = data
|
||||
task_is_finished = True
|
||||
break
|
||||
if task_is_finished:
|
||||
break
|
||||
finally:
|
||||
dealer.close()
|
||||
|
@@ -17,6 +17,7 @@
|
||||
import asyncio
|
||||
import aiozmq
|
||||
import json
|
||||
import msgpack
|
||||
from aiozmq import zmq
|
||||
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
|
||||
import time
|
||||
@@ -179,18 +180,20 @@ class OpenAIServingCompletion:
|
||||
current_waiting_time = 0
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
data = json.loads(raw_data[-1].decode("utf-8"))
|
||||
rid = int(data["request_id"].split("-")[-1])
|
||||
if data.get("error_code", 200) != 200:
|
||||
raise ValueError("{}".format(data["error_msg"]))
|
||||
response = msgpack.unpackb(raw_data[-1])
|
||||
for data in response:
|
||||
rid = int(data["request_id"].split("-")[-1])
|
||||
if data.get("error_code", 200) != 200:
|
||||
raise ValueError("{}".format(data["error_msg"]))
|
||||
|
||||
self.engine_client.data_processor.process_response_dict(
|
||||
data, stream=False)
|
||||
output_tokens[rid] += len(data["outputs"]["token_ids"])
|
||||
if data.get("finished", False):
|
||||
data["output_token_ids"] = output_tokens[rid]
|
||||
valid_results[rid] = data
|
||||
num_choices -= 1
|
||||
self.engine_client.data_processor.process_response_dict(
|
||||
data, stream=False)
|
||||
output_tokens[rid] += len(data["outputs"]["token_ids"])
|
||||
if data.get("finished", False):
|
||||
data["output_token_ids"] = output_tokens[rid]
|
||||
valid_results[rid] = data
|
||||
num_choices -= 1
|
||||
break
|
||||
|
||||
return self.request_output_to_completion_response(
|
||||
final_res_batch=valid_results,
|
||||
@@ -238,6 +241,12 @@ class OpenAIServingCompletion:
|
||||
if request.suffix is not None and request.suffix.get("max_streaming_response_tokens", 1) > 1:
|
||||
max_streaming_response_tokens = request.suffix["max_streaming_response_tokens"]
|
||||
choices = []
|
||||
chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices
|
||||
)
|
||||
|
||||
current_waiting_time = 0
|
||||
while num_choices > 0:
|
||||
@@ -256,82 +265,86 @@ class OpenAIServingCompletion:
|
||||
continue
|
||||
|
||||
|
||||
res = json.loads(raw_data[-1].decode('utf-8'))
|
||||
idx = int(res["request_id"].split("-")[-1])
|
||||
if res.get("error_code", 200) != 200:
|
||||
raise ValueError("{}".format(res["error_msg"]))
|
||||
response = msgpack.unpackb(raw_data[-1])
|
||||
for res in response:
|
||||
idx = int(res["request_id"].split("-")[-1])
|
||||
if res.get("error_code", 200) != 200:
|
||||
raise ValueError("{}".format(res["error_msg"]))
|
||||
|
||||
if first_iteration[idx]:
|
||||
if request.suffix is not None and request.suffix.get("training", False):
|
||||
if first_iteration[idx]:
|
||||
if request.suffix is not None and request.suffix.get("training", False):
|
||||
chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[CompletionResponseStreamChoice(
|
||||
index=idx,
|
||||
text="",
|
||||
token_ids=list(prompt_batched_token_ids[idx])
|
||||
)]
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||
first_iteration[idx] = False
|
||||
|
||||
|
||||
self.engine_client.data_processor.process_response_dict(
|
||||
res, stream=True)
|
||||
if res['metrics'].get('first_token_time') is not None:
|
||||
arrival_time = res['metrics']['first_token_time']
|
||||
inference_start_time[idx] = res['metrics']['inference_start_time']
|
||||
else:
|
||||
arrival_time = res['metrics']['arrival_time'] - inference_start_time[idx]
|
||||
|
||||
output = res["outputs"]
|
||||
|
||||
choices.append(CompletionResponseStreamChoice(
|
||||
index=idx,
|
||||
text=output["text"],
|
||||
token_ids=output.get("token_ids"),
|
||||
tool_calls=output.get("tool_call_content"),
|
||||
reasoning_content=output.get("reasoning_content"),
|
||||
arrival_time=arrival_time
|
||||
))
|
||||
if res["finished"]:
|
||||
if request.max_tokens is None or output_tokens[idx] + 1 != request.max_tokens:
|
||||
chunk.choices[0].finish_reason = "stop"
|
||||
if self.engine_client.reasoning_parser == "ernie_x1" and \
|
||||
output.get("finish_reason", "") == "tool_calls":
|
||||
chunk.choices[0].finish_reason = "tool_calls"
|
||||
else:
|
||||
chunk.choices[0].finish_reason = "length"
|
||||
|
||||
output_tokens[idx] += 1
|
||||
|
||||
if len(choices) == max_streaming_response_tokens or res["finished"]:
|
||||
chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[CompletionResponseStreamChoice(
|
||||
index=idx,
|
||||
text="",
|
||||
token_ids=list(prompt_batched_token_ids[idx])
|
||||
)]
|
||||
choices=choices
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||
first_iteration[idx] = False
|
||||
choices = []
|
||||
|
||||
|
||||
self.engine_client.data_processor.process_response_dict(
|
||||
res, stream=True)
|
||||
if res['metrics'].get('first_token_time') is not None:
|
||||
arrival_time = res['metrics']['first_token_time']
|
||||
inference_start_time[idx] = res['metrics']['inference_start_time']
|
||||
else:
|
||||
arrival_time = res['metrics']['arrival_time'] - inference_start_time[idx]
|
||||
# api_server_logger.info(f"{arrival_time}")
|
||||
|
||||
output = res["outputs"]
|
||||
|
||||
choices.append(CompletionResponseStreamChoice(
|
||||
index=idx,
|
||||
text=output["text"],
|
||||
token_ids=output.get("token_ids"),
|
||||
tool_calls=output.get("tool_call_content"),
|
||||
reasoning_content=output.get("reasoning_content"),
|
||||
arrival_time=arrival_time
|
||||
))
|
||||
if res["finished"]:
|
||||
if request.max_tokens is None or output_tokens[idx] + 1 != request.max_tokens:
|
||||
chunk.choices[0].finish_reason = "stop"
|
||||
if self.engine_client.reasoning_parser == "ernie_x1" and \
|
||||
output.get("finish_reason", "") == "tool_calls":
|
||||
chunk.choices[0].finish_reason = "tool_calls"
|
||||
else:
|
||||
chunk.choices[0].finish_reason = "length"
|
||||
|
||||
output_tokens[idx] += 1
|
||||
|
||||
if len(choices) == max_streaming_response_tokens or res["finished"]:
|
||||
chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices
|
||||
)
|
||||
choices = []
|
||||
|
||||
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||
|
||||
if res["finished"]:
|
||||
num_choices -= 1
|
||||
if getattr(request, "stream_options", None) and request.stream_options.include_usage:
|
||||
usage_chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[],
|
||||
usage=UsageInfo(
|
||||
prompt_tokens=len(prompt_batched_token_ids[idx]),
|
||||
completion_tokens=output_tokens[idx]
|
||||
if res["finished"]:
|
||||
num_choices -= 1
|
||||
if getattr(request, "stream_options", None) and request.stream_options.include_usage:
|
||||
usage_chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[],
|
||||
usage=UsageInfo(
|
||||
prompt_tokens=len(prompt_batched_token_ids[idx]),
|
||||
completion_tokens=output_tokens[idx]
|
||||
)
|
||||
)
|
||||
)
|
||||
yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||
yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||
if choices:
|
||||
chunk.choices = choices
|
||||
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||
choices = []
|
||||
|
||||
|
||||
except Exception as e:
|
||||
|
@@ -101,6 +101,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# Whether to use DeepGemm for FP8 blockwise MoE.
|
||||
"FD_USE_DEEP_GEMM":
|
||||
lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
|
||||
|
||||
# Whether to use aggregate send.
|
||||
"FD_USE_AGGREGATE_SEND":
|
||||
lambda: bool(int(os.getenv("FD_USE_AGGREGATE_SEND", "0"))),
|
||||
}
|
||||
|
||||
|
||||
|
@@ -20,6 +20,7 @@ import threading
|
||||
import time
|
||||
|
||||
import zmq
|
||||
import msgpack
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.utils import llm_logger
|
||||
@@ -37,6 +38,7 @@ class ZmqClient:
|
||||
self.router_path = f"/dev/shm/router_{name}.ipc"
|
||||
|
||||
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
|
||||
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
|
||||
|
||||
self.mutex = threading.Lock()
|
||||
self.req_dict = dict()
|
||||
@@ -93,6 +95,16 @@ class ZmqClient:
|
||||
"""
|
||||
return self.socket.recv_pyobj()
|
||||
|
||||
def pack_aggregated_data(self, data):
|
||||
"""
|
||||
Aggregate multiple responses into one and send them to the client.
|
||||
"""
|
||||
result = data[0]
|
||||
if len(data) > 1:
|
||||
for response in data[1:]:
|
||||
result.add(response)
|
||||
result = msgpack.packb([result.to_dict()])
|
||||
return result
|
||||
def send_multipart(self, req_id, data):
|
||||
"""
|
||||
Send a multipart message to the router socket.
|
||||
@@ -116,14 +128,22 @@ class ZmqClient:
|
||||
break
|
||||
|
||||
try:
|
||||
result = json.dumps(data.to_dict()).encode('utf-8')
|
||||
start_send = time.time()
|
||||
if self.aggregate_send:
|
||||
result = self.pack_aggregated_data(data)
|
||||
else:
|
||||
result = msgpack.packb([response.to_dict() for response in data])
|
||||
self.router.send_multipart([self.req_dict[req_id], b'', result])
|
||||
llm_logger.debug(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Send result to zmq client failed: {e}")
|
||||
|
||||
if data.finished:
|
||||
if data[-1].finished:
|
||||
with self.mutex:
|
||||
self.req_dict.pop(data.request_id, None)
|
||||
self.req_dict.pop(req_id, None)
|
||||
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
|
||||
|
||||
|
||||
def receive_json_once(self, block=False):
|
||||
"""
|
||||
|
@@ -505,8 +505,6 @@ class TokenProcessor(object):
|
||||
result.outputs.token_ids.append(token_id)
|
||||
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
|
||||
result.finished = True
|
||||
result.prompt = task.prompt
|
||||
result.prompt_token_ids = task.prompt_token_ids
|
||||
if recovery_stop:
|
||||
result.error_msg = "Recover is not supported, the result is incomplete!"
|
||||
llm_logger.info(
|
||||
|
@@ -29,9 +29,11 @@ triton==3.3
|
||||
use-triton-in-paddle
|
||||
crcmod
|
||||
fastsafetensors==0.1.14
|
||||
msgpack
|
||||
opentelemetry-api>=1.24.0
|
||||
opentelemetry-sdk>=1.24.0
|
||||
opentelemetry-instrumentation-redis
|
||||
opentelemetry-instrumentation-mysql
|
||||
opentelemetry-distro
|
||||
opentelemetry-exporter-otlp
|
||||
|
||||
|
@@ -27,3 +27,4 @@ moviepy
|
||||
use-triton-in-paddle
|
||||
crcmod
|
||||
fastsafetensors==0.1.14
|
||||
msgpack
|
@@ -27,3 +27,4 @@ moviepy
|
||||
use-triton-in-paddle
|
||||
crcmod
|
||||
fastsafetensors==0.1.14
|
||||
msgpack
|
||||
|
Reference in New Issue
Block a user