Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-06 00:57:33 +08:00)
[fix] Modify follow-up push parameters and the verification method for thinking length (#4177)

* [fix] Modify follow-up push parameters and the verification method for thinking length (#4086)
* Rename the follow-up push (continuation) parameter generated_token_ids to completion_token_ids; change how the thinking length is verified
* add completion_token_ids
* add logger
* fix reasoning_max_tokens ParameterError
* add unit tests
* fix
@@ -255,8 +255,13 @@ class EngineClient:
                 raise ValueError(f"max_tokens can be defined [1, {self.max_model_len}).")
 
         if data.get("reasoning_max_tokens") is not None:
-            if data["reasoning_max_tokens"] > data["max_tokens"] or data["reasoning_max_tokens"] < 1:
-                raise ValueError("reasoning_max_tokens must be between max_tokens and 1")
+            if data["reasoning_max_tokens"] < 1:
+                raise ValueError("reasoning_max_tokens must be greater than 1")
+            if data["reasoning_max_tokens"] > data["max_tokens"]:
+                data["reasoning_max_tokens"] = data["max_tokens"]
+                api_server_logger.warning(
+                    f"req_id: {data['request_id']}, reasoning_max_tokens exceeds max_tokens, the value of reasoning_max_tokens will be adjusted to match that of max_tokens"
+                )
 
         if data.get("top_p") is not None:
             if data["top_p"] > 1 or data["top_p"] < 0:
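For reference, a minimal standalone sketch of the behavior introduced above, with a plain logging.Logger standing in for api_server_logger and a bare dict standing in for the request payload (both are simplifying assumptions, not FastDeploy's actual objects):

import logging

logger = logging.getLogger("api_server")  # stand-in for api_server_logger

def check_reasoning_max_tokens(data: dict) -> None:
    """Reject values below 1; clamp values above max_tokens instead of rejecting them."""
    if data.get("reasoning_max_tokens") is not None:
        if data["reasoning_max_tokens"] < 1:
            raise ValueError("reasoning_max_tokens must be greater than 1")
        if data["reasoning_max_tokens"] > data["max_tokens"]:
            data["reasoning_max_tokens"] = data["max_tokens"]
            logger.warning(f"req_id: {data['request_id']}, reasoning_max_tokens exceeds max_tokens")

# A request asking for more reasoning tokens than max_tokens is now clamped rather than rejected.
req = {"request_id": "demo", "max_tokens": 10, "reasoning_max_tokens": 20}
check_reasoning_max_tokens(req)
assert req["reasoning_max_tokens"] == 10

Before this change, the same request would have raised a ValueError, since reasoning_max_tokens greater than max_tokens was rejected outright.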
@@ -588,6 +588,7 @@ class ChatCompletionRequest(BaseModel):
     prompt_token_ids: Optional[List[int]] = None
     max_streaming_response_tokens: Optional[int] = None
     disable_chat_template: Optional[bool] = False
+    completion_token_ids: Optional[List[int]] = None
     # doc: end-chat-completion-extra-params
 
     def to_dict_for_infer(self, request_id=None):
@@ -613,6 +614,9 @@ class ChatCompletionRequest(BaseModel):
             ), "The parameter `raw_request` is not supported now, please use completion api instead."
             for key, value in self.metadata.items():
                 req_dict[key] = value
+            from fastdeploy.utils import api_server_logger
+
+            api_server_logger.warning("The parameter metadata is obsolete.")
         for key, value in self.dict().items():
             if value is not None:
                 req_dict[key] = value
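The new completion_token_ids field is an extra chat-completion parameter, and requests that still use metadata now only trigger the obsolescence warning above. A sketch of passing the field through the OpenAI client's extra_body, mirroring the unit test added later in this commit (the base_url, api_key, and token id are placeholders for a locally running FastDeploy server):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder endpoint

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello"}],
    extra_body={
        "completion_token_ids": [94936],  # tokens from a previous turn to continue from
        "return_token_ids": True,
    },
    max_tokens=10,
    stream=False,
)
# With return_token_ids=True, the appended ids show up in the returned prompt token ids.
print(response.choices[0].message.prompt_token_ids)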
@@ -241,10 +241,8 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
         else:
             raise ValueError(f"Request must contain 'prompt', or 'messages': {request}")
 
-        metadata = request.get("metadata")
-        # If metadata carries previously generated tokens, append them to the end of input_ids
-        if metadata and metadata.get("generated_token_ids"):
-            self.append_generated_tokens(outputs, metadata["generated_token_ids"])
+        if request.get("completion_token_ids"):
+            self.append_completion_tokens(outputs, request["completion_token_ids"])
         outputs = self.pack_outputs(outputs)
         request["prompt_token_ids"] = outputs["input_ids"].tolist()
         request["prompt_token_ids_len"] = len(request["prompt_token_ids"])
@@ -263,11 +261,11 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
 
         return request
 
-    def append_generated_tokens(self, multimodal_inputs, generated_token_ids):
-        "append already generated tokens"
+    def append_completion_tokens(self, multimodal_inputs, completion_token_ids):
+        "append already completion tokens"
 
-        num_tokens = len(generated_token_ids)
-        multimodal_inputs["input_ids"].extend(generated_token_ids)
+        num_tokens = len(completion_token_ids)
+        multimodal_inputs["input_ids"].extend(completion_token_ids)
         multimodal_inputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens)
 
         start = multimodal_inputs["cur_position"]
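A simplified sketch of what the renamed helper does to the packed multimodal inputs; IDS_TYPE_FLAG and the 1-D position handling here are simplifying assumptions (the real processor tracks richer position ids for text, image, and video tokens):

IDS_TYPE_FLAG = {"text": 0}  # assumed stand-in for the processor's real constant

def append_completion_tokens(multimodal_inputs: dict, completion_token_ids: list) -> None:
    """Append previously generated text tokens to the end of the packed inputs."""
    num_tokens = len(completion_token_ids)
    multimodal_inputs["input_ids"].extend(completion_token_ids)
    multimodal_inputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens)
    start = multimodal_inputs["cur_position"]
    multimodal_inputs["position_ids"].extend(range(start, start + num_tokens))  # assumed 1-D positions
    multimodal_inputs["cur_position"] = start + num_tokens

# Example buffers before pack_outputs
inputs = {"input_ids": [1, 2], "token_type_ids": [0, 0], "position_ids": [0, 1], "cur_position": 2}
append_completion_tokens(inputs, [7, 8, 9])
assert inputs["input_ids"] == [1, 2, 7, 8, 9] and inputs["cur_position"] == 5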
@@ -245,15 +245,11 @@ class QwenVLProcessor(TextProcessor):
         else:
             raise ValueError(f"Request must contain 'prompt', or 'messages': {request}")
 
-        metadata = request.get("metadata")
         # Handle continuation of previous generation by appending existing tokens
-        if metadata and metadata.get("generated_token_ids"):
-            self.append_generated_tokens(outputs, metadata["generated_token_ids"])
+        if request.get("completion_token_ids"):
+            self.append_completion_tokens(outputs, request["completion_token_ids"])
 
         enable_thinking = False
-        if metadata:
-            enable_thinking = metadata.get("enable_thinking", False)
-
         if request.get("chat_template_kwargs"):
             chat_template_kwargs = request.get("chat_template_kwargs")
             enable_thinking = chat_template_kwargs.get("enable_thinking", False)
@@ -278,16 +274,16 @@ class QwenVLProcessor(TextProcessor):
 
         return request
 
-    def append_generated_tokens(self, outputs, generated_token_ids):
+    def append_completion_tokens(self, outputs, completion_token_ids):
         """
-        Append generated tokens to existing outputs.
+        Append completion tokens to existing outputs.
 
         Args:
             outputs: Current model outputs
-            generated_token_ids: Generated tokens to append
+            completion_token_ids: completion tokens to append
         """
         out = {"input_ids": [], "token_type_ids": [], "position_ids": [], "cur_position": outputs["cur_position"]}
-        self.processor._add_text(generated_token_ids, out)
+        self.processor._add_text(completion_token_ids, out)
 
         outputs["input_ids"] = np.concatenate(
             [outputs["input_ids"], np.array(out["input_ids"], dtype=np.int64)], axis=0
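Unlike the Ernie processor, the Qwen processor keeps its buffers as numpy arrays, so the renamed helper first materializes the new ids into a temporary dict (via the processor's internal _add_text) and then concatenates. A rough sketch of just the concatenation step, with the _add_text call replaced by a plain list copy for illustration (an assumption, not the real tokenizer path):

import numpy as np

def append_completion_tokens(outputs: dict, completion_token_ids: list) -> None:
    out = {"input_ids": [], "token_type_ids": [], "position_ids": [], "cur_position": outputs["cur_position"]}
    out["input_ids"] = list(completion_token_ids)  # stands in for self.processor._add_text(completion_token_ids, out)
    outputs["input_ids"] = np.concatenate(
        [outputs["input_ids"], np.array(out["input_ids"], dtype=np.int64)], axis=0
    )

outputs = {"input_ids": np.array([1, 2], dtype=np.int64), "cur_position": 2}
append_completion_tokens(outputs, [7, 8])
assert outputs["input_ids"].tolist() == [1, 2, 7, 8]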
@@ -255,6 +255,16 @@ def test_consistency_between_runs(api_url, headers, consistent_payload):
     assert content1 == content2
 
 
+def test_with_metadata(api_url, headers, consistent_payload):
+    """
+    Test that result is same as the base result.
+    """
+    # request
+    consistent_payload["metadata"] = {"enable_thinking": True}
+    resp1 = requests.post(api_url, headers=headers, json=consistent_payload)
+    assert resp1.status_code == 200
+
+
 # ==========================
 # OpenAI Client Chat Completion Test
 # ==========================
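Since metadata now only emits an obsolescence warning, the non-deprecated way to reach the same enable_thinking switch is chat_template_kwargs, which the QwenVLProcessor hunk above reads directly from the request. A hedged sketch of the equivalent raw HTTP request (the URL and the rest of the payload skeleton are placeholders):

import requests

api_url = "http://localhost:8000/v1/chat/completions"  # placeholder endpoint
payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Hello"}],
    "chat_template_kwargs": {"enable_thinking": True},  # replaces metadata={"enable_thinking": True}
}
resp = requests.post(api_url, headers={"Content-Type": "application/json"}, json=payload)
assert resp.status_code == 200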
@@ -555,6 +565,46 @@ def test_chat_with_thinking(openai_client, capsys):
     assert reasoning_tokens <= reasoning_max_tokens
 
 
+def test_chat_with_completion_token_ids(openai_client):
+    """Test completion_token_ids"""
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello"}],
+        extra_body={
+            "completion_token_ids": [94936],
+            "return_token_ids": True,
+            "reasoning_max_tokens": 20,
+            "max_tokens": 10,
+        },
+        max_tokens=10,
+        stream=False,
+    )
+    assert hasattr(response, "choices")
+    assert len(response.choices) > 0
+    assert hasattr(response.choices[0], "message")
+    assert hasattr(response.choices[0].message, "prompt_token_ids")
+    assert isinstance(response.choices[0].message.prompt_token_ids, list)
+    assert 94936 in response.choices[0].message.prompt_token_ids
+
+
+def test_chat_with_reasoning_max_tokens(openai_client):
+    """Test completion_token_ids"""
+    assertion_executed = False
+    try:
+        openai_client.chat.completions.create(
+            model="default",
+            messages=[{"role": "user", "content": "Hello"}],
+            extra_body={"completion_token_ids": [18900], "return_token_ids": True, "reasoning_max_tokens": -1},
+            max_tokens=10,
+            stream=False,
+        )
+    except Exception as e:
+        error_message = str(e)
+        assertion_executed = True
+        assert "reasoning_max_tokens must be greater than 1" in error_message
+    assert assertion_executed, "Assertion was not executed (no exception raised)"
+
+
 def test_profile_reset_block_num():
     """Test the profile reset_block_num feature; the diff against baseline must not exceed 5%"""
     log_file = "./log/config.log"
@@ -176,12 +176,10 @@ class TestQwenVLProcessor(unittest.TestCase):
         3. Video processing produces expected output dimensions
         4. Correct counts for images (1) and videos (1)
         """
-        num_generated_token_ids = 10
+        num_completion_token_ids = 10
         request = {
             "request_id": "12345",
-            "metadata": {
-                "generated_token_ids": [1] * num_generated_token_ids,
-            },
+            "completion_token_ids": [1] * num_completion_token_ids,
             "stop": ["stop", "eof"],
             "messages": [
                 {