Support limit thinking lengths (#4069)
Co-authored-by: K11OntheBoat <"ruianmaidanglao@163.com">
@@ -224,6 +224,7 @@ class ModelConfig:
             self.vision_config = PretrainedConfig.from_dict(self.vision_config)
 
         self.ori_vocab_size = args.get("ori_vocab_size", self.vocab_size)
+        self.think_end_id = args.get("think_end_id", -1)
 
         architectures = self.architectures[0]
 
@@ -34,6 +34,7 @@ import numpy as np
 import paddle
 from tqdm import tqdm
 
+from fastdeploy.config import ErnieArchitectures
 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.engine.common_engine import EngineService
 from fastdeploy.engine.expert_service import start_data_parallel_service
@@ -470,6 +471,14 @@ class LLMEngine:
             else len(self.data_processor.tokenizer.vocab)
         )
 
+        is_ernie = ErnieArchitectures.contains_ernie_arch(self.cfg.model_config.architectures)
+        if is_ernie:
+            self.cfg.model_config.think_end_id = self.data_processor.tokenizer.get_vocab().get("</think>", -1)
+            if self.cfg.model_config.think_end_id != -1:
+                llm_logger.info(f"Get think_end_id {self.cfg.model_config.think_end_id} from vocab.")
+            else:
+                llm_logger.info("No </think> token found in vocabulary, the model can not do reasoning.")
+
         ports = ",".join(self.cfg.parallel_config.engine_worker_queue_port)
         ips = None
         if self.cfg.ips is not None:
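The hunk above resolves the id of the "</think>" token once at engine start and stores it on model_config.think_end_id, with -1 meaning the vocabulary has no such token and thinking cannot be limited. A minimal standalone sketch of that resolution, assuming only a tokenizer that exposes get_vocab() returning a dict[str, int] (the helper name is illustrative, not part of FastDeploy):

def resolve_think_end_id(tokenizer, think_end_token: str = "</think>") -> int:
    # Return the vocabulary id of the end-of-thinking token, or -1 if the
    # tokenizer has no such token (downstream code then disables reasoning).
    vocab = tokenizer.get_vocab()
    return vocab.get(think_end_token, -1)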
@@ -496,6 +505,7 @@ class LLMEngine:
             f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
             f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
             f" --ori_vocab_size {ori_vocab_size}"
+            f" --think_end_id {self.cfg.model_config.think_end_id}"
             f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'"
             f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
             f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
@@ -155,8 +155,6 @@ class EngineClient:
         task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
         input_ids_len = task["prompt_token_ids_len"]
         task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
-        if task.get("reasoning_max_tokens", None) is None:
-            task["reasoning_max_tokens"] = max(int(task["max_tokens"] * 0.8), 1)
         min_tokens = task.get("min_tokens", 1)
         if "messages" in task:
             del task["messages"]
@@ -252,6 +252,10 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
             request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
         if request.get("max_tokens") is None:
             request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
+        else:
+            request["max_tokens"] = min(max_model_len - len(request["prompt_token_ids"]), request["max_tokens"])
+        if request.get("reasoning_max_tokens") is None:
+            request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
         data_processor_logger.info(f"Processed request {request}")
 
         return request
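The processor change above caps max_tokens against the remaining context window and, when reasoning_max_tokens is not supplied, defaults it to 80% of max_tokens (never below 1), replacing the default that was removed from EngineClient. A small sketch of that budgeting rule in isolation, with a hypothetical helper name and example numbers:

def budget_tokens(max_model_len: int, prompt_len: int,
                  max_tokens: int | None, reasoning_max_tokens: int | None):
    # Cap the completion budget by the remaining context window.
    if max_tokens is None:
        max_tokens = max(1, max_model_len - prompt_len)
    else:
        max_tokens = min(max_model_len - prompt_len, max_tokens)
    # Default the reasoning budget to 80% of the completion budget.
    if reasoning_max_tokens is None:
        reasoning_max_tokens = max(int(max_tokens * 0.8), 1)
    return max_tokens, reasoning_max_tokens

# Example: an 8192-token window, a 1000-token prompt, and max_tokens=4096
# yields max_tokens=4096 and reasoning_max_tokens=3276.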
@@ -195,8 +195,9 @@ def post_process_normal(
 ) -> ModelRunnerOutput:
     """Post-processing steps after completing a single token generation."""
     # handle vl:
-    if model_output.enable_thinking:
-        exists_think_end = sampler_output.sampled_token_ids == model_output.think_end_id
+    if model_output.think_end_id != -1:
+        thinking_mask = model_output.enable_thinking
+        exists_think_end = (sampler_output.sampled_token_ids == model_output.think_end_id) & thinking_mask
         paddle.assign(
             paddle.where(
                 exists_think_end,
@@ -206,9 +207,10 @@ def post_process_normal(
             model_output.need_think_end,
         )
 
+        reasoning_index_update_cond = model_output.need_think_end.cast("bool") & thinking_mask
         paddle.assign(
             paddle.where(
-                model_output.need_think_end.cast("bool"),
+                reasoning_index_update_cond,
                 model_output.reasoning_index - 1,
                 model_output.reasoning_index,
             ),
@@ -219,6 +221,8 @@ def post_process_normal(
             (sampler_output.sampled_token_ids == model_output.eos_token_id.T).any(axis=1, keepdim=True)
             | (model_output.reasoning_index == 0)
         ) & (model_output.need_think_end > 0)
+
+        stop_wo_think = stop_wo_think & thinking_mask
         sampler_output.sampled_token_ids = paddle.where(
             stop_wo_think,
             model_output.think_end_id,
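Taken together, the three post-processing hunks above do a masked update per batch row: detect a sampled "</think>" (only where thinking is enabled), count the reasoning budget down while the model is still thinking, and force a "</think>" token when EOS is sampled or the budget reaches zero mid-thought. A NumPy re-sketch of that logic with made-up values, shown only to make the tensor ops easier to follow (the real code uses paddle tensors and supports multiple EOS ids):

import numpy as np

think_end_id, eos_token_id = 7, 2                      # illustrative ids
sampled = np.array([[7], [3], [9]])                    # token sampled this step, per row
enable_thinking = np.array([[True], [True], [False]])  # per-request thinking flag
need_think_end = np.array([[1], [1], [0]])             # 1 while inside the thinking block
reasoning_index = np.array([[5], [1], [4]])            # remaining reasoning-token budget

# 1) A sampled </think> closes the thinking phase, but only where thinking is enabled.
exists_think_end = (sampled == think_end_id) & enable_thinking
need_think_end = np.where(exists_think_end, need_think_end - 1, need_think_end)

# 2) While still thinking, each generated token consumes one unit of budget.
still_thinking = need_think_end.astype(bool) & enable_thinking
reasoning_index = np.where(still_thinking, reasoning_index - 1, reasoning_index)

# 3) If EOS arrived or the budget hit zero while still thinking, overwrite the
#    sampled token with </think> so the reasoning block is closed explicitly
#    (single EOS id here for simplicity).
stop_wo_think = ((sampled == eos_token_id) | (reasoning_index == 0)) & (need_think_end > 0)
stop_wo_think = stop_wo_think & enable_thinking
sampled = np.where(stop_wo_think, think_end_id, sampled)

print(sampled.ravel())  # [7 7 9]: row 0 ended thinking itself, row 1 was forced to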
@@ -322,15 +322,27 @@ class GPUModelRunner(ModelRunnerBase):
                 else:
                     position_ids = None
 
-                enable_thinking = request.get("enable_thinking", True)
-                enable_thinking = enable_thinking if enable_thinking is not None else True
-                self.share_inputs["enable_thinking"][:] = enable_thinking
-                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
-                self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
                 self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
                     position_ids, request.get("max_tokens", 2048)
                 )
 
+                if request.get("enable_thinking", False):
+                    # Enable thinking
+                    req_reasoning_max_tokens = request.get("reasoning_max_tokens")
+                    req_max_tokens = request.get("max_tokens")
+                    final_reasoning_tokens = (
+                        req_reasoning_max_tokens if req_reasoning_max_tokens is not None else req_max_tokens
+                    )
+
+                    self.share_inputs["enable_thinking"][idx : idx + 1] = True
+                    self.share_inputs["need_think_end"][idx : idx + 1, :] = 1
+                    self.share_inputs["reasoning_index"][idx : idx + 1, :] = final_reasoning_tokens
+                else:
+                    # Disable thinking
+                    self.share_inputs["enable_thinking"][idx : idx + 1] = False
+                    self.share_inputs["need_think_end"][idx : idx + 1, :] = 0
+                    self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0
+
                 if isinstance(request.prompt_token_ids, np.ndarray):
                     prompt_token_ids = request.prompt_token_ids.tolist()
                 else:
@@ -549,16 +561,28 @@ class GPUModelRunner(ModelRunnerBase):
            self.share_inputs["prompt_lens"][idx : idx + 1] = length
 
            if self.enable_mm:
-                enable_thinking = request.get("enable_thinking", True)
-                enable_thinking = enable_thinking if enable_thinking is not None else True
-                self.share_inputs["enable_thinking"][:] = enable_thinking
-                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
-                self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
                self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
                    position_ids, request.get("max_tokens", 2048)
                )
                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
 
+            if request.get("enable_thinking", False):
+                # Enable thinking
+                req_reasoning_max_tokens = request.get("reasoning_max_tokens")
+                req_max_tokens = request.get("max_tokens")
+                final_reasoning_tokens = (
+                    req_reasoning_max_tokens if req_reasoning_max_tokens is not None else req_max_tokens
+                )
+
+                self.share_inputs["enable_thinking"][idx : idx + 1] = True
+                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1
+                self.share_inputs["reasoning_index"][idx : idx + 1, :] = final_reasoning_tokens
+            else:
+                # Disable thinking
+                self.share_inputs["enable_thinking"][idx : idx + 1] = False
+                self.share_inputs["need_think_end"][idx : idx + 1, :] = 0
+                self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0
+
        def get_attr_from_request(request, attr, default_value=None):
            res = request.get(attr, default_value)
            if res is not None:
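Both request-insertion paths in GPUModelRunner now write the same three per-slot buffers instead of one global flag: enable_thinking, need_think_end, and reasoning_index, with the reasoning budget falling back to max_tokens when reasoning_max_tokens is absent. A condensed sketch of that branch, using a NumPy dict as a stand-in for the paddle share_inputs buffers (buffer names mirror the diff; the harness around them is hypothetical):

import numpy as np

max_num_seqs = 4
share_inputs = {
    "enable_thinking": np.zeros((max_num_seqs, 1), dtype=bool),
    "need_think_end": np.zeros((max_num_seqs, 1), dtype=np.int32),
    "reasoning_index": np.zeros((max_num_seqs, 1), dtype=np.int32),
}

def set_thinking_slots(share_inputs, idx, request: dict):
    # Enable thinking only when the request asks for it; budget reasoning tokens
    # from reasoning_max_tokens, falling back to max_tokens.
    if request.get("enable_thinking", False):
        budget = request.get("reasoning_max_tokens")
        if budget is None:
            budget = request.get("max_tokens")
        share_inputs["enable_thinking"][idx] = True
        share_inputs["need_think_end"][idx] = 1
        share_inputs["reasoning_index"][idx] = budget
    else:
        share_inputs["enable_thinking"][idx] = False
        share_inputs["need_think_end"][idx] = 0
        share_inputs["reasoning_index"][idx] = 0

set_thinking_slots(share_inputs, idx=0, request={"enable_thinking": True, "max_tokens": 512})
print(share_inputs["reasoning_index"][0])  # [512] -- budget fell back to max_tokens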
@@ -853,6 +877,11 @@ class GPUModelRunner(ModelRunnerBase):
         # Initialize rotary position embedding
         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
 
+        # Initialize thinking related buffers
+        self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
+        self.share_inputs["enable_thinking"] = paddle.full(shape=[max_num_seqs, 1], fill_value=False, dtype="bool")
+        self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
+
         # TODO(gongshaotian): move to models
         if not self.enable_mm:
             self.share_inputs["rope_emb"] = get_rope(
@@ -952,11 +981,6 @@ class GPUModelRunner(ModelRunnerBase):
                 dtype="float32",
             )
             self.share_inputs["image_features"] = None
-            self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
-            self.share_inputs["enable_thinking"] = paddle.full(
-                shape=[1], fill_value=("ernie" in self.model_config.model_type), dtype="bool"
-            )
-            self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
 
     def _prepare_inputs(self) -> None:
         """Prepare the model inputs"""
@@ -1399,10 +1423,10 @@ class GPUModelRunner(ModelRunnerBase):
             ),
             accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
             accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
-            enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
-            think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
-            need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
-            reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
+            enable_thinking=self.share_inputs["enable_thinking"],
+            think_end_id=self.model_config.think_end_id,
+            need_think_end=self.share_inputs["need_think_end"],
+            reasoning_index=self.share_inputs["reasoning_index"],
             stop_token_ids=self.share_inputs["stop_seqs"],
             stop_seqs_len=self.share_inputs["stop_seqs_len"],
         )
@@ -1715,10 +1739,10 @@ class GPUModelRunner(ModelRunnerBase):
             ),
             accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
             accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
-            enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
-            think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
-            need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
-            reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
+            enable_thinking=self.share_inputs["enable_thinking"],
+            think_end_id=self.model_config.think_end_id,
+            need_think_end=self.share_inputs["need_think_end"][:num_running_requests],
+            reasoning_index=self.share_inputs["reasoning_index"][:num_running_requests],
             stop_token_ids=self.share_inputs["stop_seqs"],
             stop_seqs_len=self.share_inputs["stop_seqs_len"],
         )
@@ -587,6 +587,7 @@ def parse_args():
         help="enable expert parallel",
     )
     parser.add_argument("--ori_vocab_size", type=int, default=None)
+    parser.add_argument("--think_end_id", type=int, default=-1)
 
     parser.add_argument(
         "--quantization",
@@ -516,6 +516,21 @@ def test_chat_with_thinking(openai_client, capsys):
     assert response.choices[0].message.reasoning_content is None
     assert "</think>" not in response.choices[0].message.content
 
+    # test logic
+    reasoning_max_tokens = None
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
+        temperature=1,
+        stream=False,
+        max_tokens=20,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": True},
+            "reasoning_max_tokens": reasoning_max_tokens,
+        },
+    )
+    assert response.choices[0].message.reasoning_content is not None
+
     # enable thinking, streaming
     reasoning_max_tokens = 3
     response = openai_client.chat.completions.create(
@@ -927,3 +942,50 @@ def test_profile_reset_block_num():
         f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内"
         f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
     )
+
+
+def test_thinking_logic_flag(openai_client, capsys):
+    """
+    Test the interaction between token calculation logic and conditional thinking.
+    This test covers:
+    1. Default max_tokens calculation when not provided.
+    2. Capping of max_tokens when it exceeds model limits.
+    3. Default reasoning_max_tokens calculation when not provided.
+    4. Activation of thinking based on the final state of reasoning_max_tokens.
+    """
+
+    response_case_1 = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity briefly."}],
+        temperature=1,
+        stream=False,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": True},
+        },
+    )
+    assert response_case_1.choices[0].message.reasoning_content is not None
+
+    response_case_2 = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
+        temperature=1,
+        stream=False,
+        max_tokens=20,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": True},
+            "reasoning_max_tokens": 5,
+        },
+    )
+    assert response_case_2.choices[0].message.reasoning_content is not None
+
+    response_case_3 = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
+        temperature=1,
+        stream=False,
+        max_tokens=20,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": False},
+        },
+    )
+    assert response_case_3.choices[0].message.reasoning_content is None
@@ -535,6 +535,21 @@ def test_chat_with_thinking(openai_client, capsys):
     assert response.choices[0].message.reasoning_content is None
     assert "</think>" not in response.choices[0].message.content
 
+    # test logic
+    reasoning_max_tokens = None
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
+        temperature=1,
+        stream=False,
+        max_tokens=20,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": True},
+            "reasoning_max_tokens": reasoning_max_tokens,
+        },
+    )
+    assert response.choices[0].message.reasoning_content is not None
+
     # enable thinking, streaming
     reasoning_max_tokens = 3
     response = openai_client.chat.completions.create(
@@ -642,3 +657,50 @@ def test_profile_reset_block_num():
         f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内"
         f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
     )
+
+
+def test_thinking_logic_flag(openai_client, capsys):
+    """
+    Test the interaction between token calculation logic and conditional thinking.
+    This test covers:
+    1. Default max_tokens calculation when not provided.
+    2. Capping of max_tokens when it exceeds model limits.
+    3. Default reasoning_max_tokens calculation when not provided.
+    4. Activation of thinking based on the final state of reasoning_max_tokens.
+    """
+
+    response_case_1 = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity briefly."}],
+        temperature=1,
+        stream=False,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": True},
+        },
+    )
+    assert response_case_1.choices[0].message.reasoning_content is not None
+
+    response_case_2 = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
+        temperature=1,
+        stream=False,
+        max_tokens=20,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": True},
+            "reasoning_max_tokens": 5,
+        },
+    )
+    assert response_case_2.choices[0].message.reasoning_content is not None
+
+    response_case_3 = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
+        temperature=1,
+        stream=False,
+        max_tokens=20,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": False},
+        },
+    )
+    assert response_case_3.choices[0].message.reasoning_content is None