Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 00:33:03 +08:00)

Compare commits: 6 commits, develop ... feature/on
Commit SHA1s: b272ca9f83, db653644ad, 4aa057f28d, 05b7800d80, 12043fc476, acecd5bebe
@@ -190,6 +190,7 @@ class ModelConfig:
self.reasoning_parser = None
self.pad_token_id: int = -1
self.eos_tokens_lens: int = 2
self.think_end_id = None
self.lm_head_fp32: bool = False
self.model_format = "auto"
self.runner = "auto"
@@ -30,7 +30,7 @@ import paddle
import zmq
from opentelemetry import trace

from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.engine.request import Request, RequestOutput, RequestType
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
from fastdeploy.inter_communicator import (
@@ -77,6 +77,7 @@ class EngineService:
self.llm_logger = llm_logger

self.scheduler = cfg.scheduler_config.scheduler()
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"

if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.resource_manager = ResourceManagerV1(
@@ -623,7 +624,7 @@ class EngineService:
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
# release resource in P
self.resource_manager.prerelease_resource(task)
self.resource_manager.prerelease_resource(tmp_task)
if self.cfg.scheduler_config.splitwise_role == "prefill":
# to send cache info to cache messager
if tasks:
@@ -657,9 +658,7 @@ class EngineService:
time.sleep(0.001)
continue
if self.cfg.scheduler_config.splitwise_role != "mixed":
if self.scheduler.get_unhandled_request_num() <= envs.FD_EP_MAX_PREFETCH_TASK_NUM and (
not is_fetching
):
if not is_fetching:
get_request_pool.submit(_fetch_request)

else:
@@ -673,6 +672,22 @@ class EngineService:
tasks = self.resource_manager.schedule()
# 3. Send to engine
if tasks:
if self.cfg.scheduler_config.splitwise_role == "decode":
for task in tasks:
if task.task_type == RequestType.PREEMPTED:
msg = f"{task.request_id} decode not enough blocks, need to be rescheduled."
self.llm_logger.error(msg)
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)

self.resource_manager.get_real_bsz()
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
else:
@@ -72,6 +72,7 @@ class Request:
structural_tag: Optional[Any] = None,
guided_json_object: Optional[bool] = None,
enable_thinking: Optional[bool] = True,
reasoning_max_tokens: Optional[int] = None,
trace_carrier: dict = dict(),
dp_rank: Optional[int] = None,
chat_template: Optional[str] = None,
@@ -121,6 +122,7 @@ class Request:
self.multimodal_img_boundaries = None

self.enable_thinking = enable_thinking
self.reasoning_max_tokens = reasoning_max_tokens
self.trace_carrier = trace_carrier

self.chat_template = chat_template
@@ -178,7 +180,8 @@ class Request:
guided_grammar=d.get("guided_grammar", None),
structural_tag=d.get("structural_tag", None),
guided_json_object=d.get("guided_json_object", None),
enable_thinking=d.get("enable_thinking", True),
enable_thinking=d.get("enable_thinking", False),
reasoning_max_tokens=d.get("reasoning_max_tokens", None),
trace_carrier=d.get("trace_carrier", {}),
chat_template=d.get("chat_template", None),
num_computed_tokens=d.get("num_computed_tokens", 0),
@@ -229,6 +232,7 @@ class Request:
"disaggregate_info": self.disaggregate_info,
"draft_token_ids": self.draft_token_ids,
"enable_thinking": self.enable_thinking,
"reasoning_max_tokens": self.reasoning_max_tokens,
"trace_carrier": self.trace_carrier,
"chat_template": self.chat_template,
"num_computed_tokens": self.num_computed_tokens,
@@ -136,13 +136,23 @@ class ResourceManagerV1(ResourceManager):
preempted_req = self.running.pop()
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
self._free_blocks(preempted_req)
preempted_req.cached_block_num = 0
self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
if self.config.scheduler_config.splitwise_role == "decode":
self.tasks_list[preempted_req.idx] = None
self.stop_flags[preempted_req.idx] = True
if preempted_req.request_id in self.requests:
del self.requests[preempted_req.request_id]
if preempted_req.request_id in self.req_dict:
del self.req_dict[preempted_req.request_id]
self._free_blocks(preempted_req)
main_process_metrics.num_requests_running.dec(1)
else:
self._free_blocks(preempted_req)
preempted_req.cached_block_num = 0
self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
main_process_metrics.num_requests_waiting.inc(1)
main_process_metrics.num_requests_running.dec(1)
preempted_reqs.append(preempted_req)
scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
main_process_metrics.num_requests_waiting.inc(1)
main_process_metrics.num_requests_running.dec(1)
if preempted_req == request:
# No more request to preempt.
can_schedule = False
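For orientation, a minimal stand-alone sketch of the preemption branch in the ResourceManagerV1 hunk above. The class, container, and metrics names here are simplified stand-ins for illustration, not FastDeploy's real objects:

```python
from dataclasses import dataclass


@dataclass
class ToyRequest:  # hypothetical stand-in for fastdeploy's Request
    request_id: str
    idx: int
    status: str = "RUNNING"
    num_computed_tokens: int = 0
    cached_block_num: int = 0


def preempt(req, role, tasks_list, stop_flags, requests, req_dict,
            to_be_rescheduled, free_blocks, metrics):
    """Mirrors the decode-vs-reschedule split shown in the hunk above."""
    req.status = "PREEMPTED"
    req.num_computed_tokens = 0
    if role == "decode":
        # Decode instances drop the request entirely: free its batch slot,
        # forget it, and release its KV-cache blocks.
        tasks_list[req.idx] = None
        stop_flags[req.idx] = True
        requests.pop(req.request_id, None)
        req_dict.pop(req.request_id, None)
        free_blocks(req)
        metrics["running"] -= 1
    else:
        # Other roles requeue the request so it can be scheduled again later.
        free_blocks(req)
        req.cached_block_num = 0
        to_be_rescheduled.add(req.request_id)
        metrics["waiting"] += 1
        metrics["running"] -= 1


req = ToyRequest("req-0", idx=0)
metrics = {"running": 1, "waiting": 0}
preempt(req, "decode", tasks_list=[req], stop_flags=[False], requests={"req-0": req},
        req_dict={"req-0": req}, to_be_rescheduled=set(), free_blocks=lambda r: None,
        metrics=metrics)
print(req.status, metrics)  # PREEMPTED {'running': 0, 'waiting': 0}
```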
@@ -583,8 +593,10 @@ class ResourceManagerV1(ResourceManager):
with self.lock:
self.tasks_list[request.idx] = None
self.stop_flags[request.idx] = True
del self.requests[request.request_id]
del self.req_dict[request.request_id]
if request.request_id in self.requests:
del self.requests[request.request_id]
if request.request_id in self.req_dict:
del self.req_dict[request.request_id]
self._free_blocks(request)

def add_request_in_p(self, requests: list[Request]):
@@ -660,6 +672,8 @@ class ResourceManagerV1(ResourceManager):
return False
if self.available_batch() == 0:
return False
if request.reasoning_max_tokens is not None:
request.reasoning_max_tokens -= 1
request.need_prefill_tokens = len(request.prompt_token_ids)
need_prealloc_prefill_blocks = (
request.need_prefill_tokens + self.config.cache_config.block_size - 1
@@ -154,8 +154,6 @@ class EngineClient:
task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
input_ids_len = task["prompt_token_ids_len"]
task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
if task.get("reasoning_max_tokens", None) is None:
task["reasoning_max_tokens"] = max(int(task["max_tokens"] * 0.8), 1)
min_tokens = task.get("min_tokens", 1)
if "messages" in task:
del task["messages"]
@@ -253,6 +253,10 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
if request.get("max_tokens") is None:
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
else:
request["max_tokens"] = min(max_model_len - len(request["prompt_token_ids"]), request["max_tokens"])
if request.get("reasoning_max_tokens") is None:
request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
data_processor_logger.info(f"Processed request {request}")

return request
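As a side note, a minimal sketch of the token-budget defaulting shown in this hunk, with `request` as a plain dict and made-up numbers; this is illustrative only, not FastDeploy's actual processor code:

```python
def apply_token_defaults(request: dict, max_model_len: int) -> dict:
    """Cap max_tokens by the remaining context and default the reasoning budget."""
    prompt_len = len(request["prompt_token_ids"])
    if request.get("max_tokens") is None:
        request["max_tokens"] = max(1, max_model_len - prompt_len)
    else:
        request["max_tokens"] = min(max_model_len - prompt_len, request["max_tokens"])
    if request.get("reasoning_max_tokens") is None:
        # Default reasoning budget: 80% of the generation budget, at least 1 token.
        request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
    return request


# A 10-token prompt against a 128-token context with no limits supplied gives
# max_tokens = 118 and reasoning_max_tokens = int(118 * 0.8) = 94.
print(apply_token_defaults({"prompt_token_ids": list(range(10))}, max_model_len=128))
```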
@@ -193,7 +193,7 @@ def post_process_normal(
) -> ModelRunnerOutput:
"""Post-processing steps after completing a single token generation."""
# handle vl:
if model_output.enable_thinking:
if model_output.enable_thinking and model_output.think_end_id is not None:
exists_think_end = sampler_output.sampled_token_ids == model_output.think_end_id
paddle.assign(
paddle.where(
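For context, a small illustration of the end-of-thinking check that this hunk now guards with `think_end_id is not None`, written over plain Python lists rather than paddle tensors:

```python
from typing import List, Optional


def detect_think_end(sampled_token_ids: List[int], think_end_id: Optional[int],
                     need_think_end: List[int]) -> List[int]:
    """Clear a sequence's 'still thinking' flag once it samples the </think> token."""
    if think_end_id is None:  # model has no </think> token: skip the bookkeeping
        return need_think_end
    return [
        0 if (flag and token == think_end_id) else flag
        for token, flag in zip(sampled_token_ids, need_think_end)
    ]


# Batch of three sequences; the second one just emitted the end-of-thinking token.
print(detect_think_end([11, 2, 7], think_end_id=2, need_think_end=[1, 1, 0]))  # [1, 0, 0]
```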
@@ -159,7 +159,6 @@ class DPLocalScheduler(LocalScheduler):
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break

requests.append(request.raw)
self.ids_read_cursor += 1
start_batch_time = time.time()
@@ -174,6 +173,7 @@ class DPLocalScheduler(LocalScheduler):
):
break
else:
required_total_blocks = 0
batch_ids = self.requests_not_empty.wait_for(
lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
0.005,
@@ -181,6 +181,10 @@ class DPLocalScheduler(LocalScheduler):
if batch_ids:
for request_id in batch_ids:
request = self.requests[request_id]
required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
required_total_blocks += required_input_blocks + reserved_output_blocks
if required_total_blocks > available_blocks:
break
requests.append(request.raw)
self.ids_read_cursor += 1
@@ -387,14 +387,20 @@ class SplitwiseConnector:
f"{tasks[i].disaggregate_info['cache_info']['rdma']['ip']}:"
+ f"{tasks[i].disaggregate_info['cache_info']['rdma']['port']}"
)
cache_info = {
"request_id": tasks[i].request_id,
"device_ids": self.cfg.device_ids.split(","),
"ip": self.cfg.host_ip,
"rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
"transfer_protocol": "rdma",
"dest_block_ids": tasks[i].disaggregate_info["block_tables"],
}
if tasks[i].get("error_msg", None) is not None:
cache_info = {
"request_id": tasks[i].request_id,
"error_msg": tasks[i].get("error_msg"),
}
else:
cache_info = {
"request_id": tasks[i].request_id,
"device_ids": self.cfg.device_ids.split(","),
"ip": self.cfg.host_ip,
"rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
"transfer_protocol": "rdma",
"dest_block_ids": tasks[i].disaggregate_info["block_tables"],
}
if addr not in temp_cache_info:
temp_cache_info[addr] = []
@@ -322,15 +322,21 @@ class GPUModelRunner(ModelRunnerBase):
else:
position_ids = None

enable_thinking = request.get("enable_thinking", True)
enable_thinking = enable_thinking if enable_thinking is not None else True
self.share_inputs["enable_thinking"][:] = enable_thinking
self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
position_ids, request.get("max_tokens", 2048)
)

if request.get("enable_thinking", False) and request.get("reasoning_max_tokens") is not None:
# Enable thinking
self.share_inputs["enable_thinking"][:] = True
self.share_inputs["need_think_end"][idx : idx + 1, :] = 1
self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
else:
# Disable thinking
self.share_inputs["enable_thinking"][:] = False
self.share_inputs["need_think_end"][idx : idx + 1, :] = 0
self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0

if isinstance(request.prompt_token_ids, np.ndarray):
prompt_token_ids = request.prompt_token_ids.tolist()
else:
@@ -549,16 +555,22 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["prompt_lens"][idx : idx + 1] = length

if self.enable_mm:
enable_thinking = request.get("enable_thinking", True)
enable_thinking = enable_thinking if enable_thinking is not None else True
self.share_inputs["enable_thinking"][:] = enable_thinking
self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
position_ids, request.get("max_tokens", 2048)
)
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0

if request.get("enable_thinking", False) and request.get("reasoning_max_tokens") is not None:
# Enable thinking
self.share_inputs["enable_thinking"][:] = True
self.share_inputs["need_think_end"][idx : idx + 1, :] = 1
self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
else:
# Disable thinking
self.share_inputs["enable_thinking"][:] = False
self.share_inputs["need_think_end"][idx : idx + 1, :] = 0
self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0

def get_attr_from_request(request, attr, default_value=None):
res = request.get(attr, default_value)
if res is not None:
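To summarize the two hunks above: thinking is now enabled for a request only when it both asks for it and carries a reasoning budget. A compact sketch of that gating, with a plain dict standing in for the paddle `share_inputs` tensors (names and values here are illustrative):

```python
def set_thinking_inputs(share_inputs: dict, request: dict, idx: int) -> None:
    """Mirror of the gating above: thinking needs both the flag and a budget."""
    if request.get("enable_thinking", False) and request.get("reasoning_max_tokens") is not None:
        share_inputs["enable_thinking"] = True
        share_inputs["need_think_end"][idx] = 1
        share_inputs["reasoning_index"][idx] = request["reasoning_max_tokens"]
    else:
        share_inputs["enable_thinking"] = False
        share_inputs["need_think_end"][idx] = 0
        share_inputs["reasoning_index"][idx] = 0


inputs = {"enable_thinking": False, "need_think_end": [0, 0], "reasoning_index": [0, 0]}
set_thinking_inputs(inputs, {"enable_thinking": True, "reasoning_max_tokens": 512}, idx=0)
set_thinking_inputs(inputs, {"enable_thinking": True}, idx=1)  # no budget -> disabled
print(inputs)
```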
@@ -853,6 +865,11 @@ class GPUModelRunner(ModelRunnerBase):
# Initialize rotary position embedding
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))

# Initialize thinking related buffers
self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
self.share_inputs["enable_thinking"] = paddle.full(shape=[1], fill_value=False, dtype="bool")
self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")

# TODO(gongshaotian): move to models
if not self.enable_mm:
self.share_inputs["rope_emb"] = get_rope(
@@ -952,11 +969,6 @@ class GPUModelRunner(ModelRunnerBase):
dtype="float32",
)
self.share_inputs["image_features"] = None
self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
self.share_inputs["enable_thinking"] = paddle.full(
shape=[1], fill_value=("ernie" in self.model_config.model_type), dtype="bool"
)
self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")

def _prepare_inputs(self) -> None:
"""Prepare the model inputs"""
@@ -1398,10 +1410,10 @@
),
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
enable_thinking=self.share_inputs["enable_thinking"],
think_end_id=self.model_config.think_end_id,
need_think_end=self.share_inputs["need_think_end"],
reasoning_index=self.share_inputs["reasoning_index"],
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
)
@@ -1714,10 +1726,10 @@
),
accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
enable_thinking=self.share_inputs["enable_thinking"],
think_end_id=self.model_config.think_end_id,
need_think_end=self.share_inputs["need_think_end"][:num_running_requests],
reasoning_index=self.share_inputs["reasoning_index"][:num_running_requests],
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
)
@@ -131,6 +131,28 @@ def update_fd_config_for_mm(fd_config: FDConfig) -> None:
fd_config.model_config.sequence_parallel = fd_config.parallel_config.sequence_parallel


def update_think_end_id_for_ernie(fd_config: FDConfig) -> None:
"""
Updates the think_end_id in the model config. Uses the ID of '</think>'
if it exists, otherwise defaults to None.
"""
is_ernie = ErnieArchitectures.contains_ernie_arch(fd_config.model_config.architectures)
if current_platform.is_cuda() and is_ernie:
tokenizer = Ernie4_5Tokenizer.from_pretrained(
fd_config.model_config.model,
model_max_length=fd_config.parallel_config.max_model_len,
padding_side="right",
use_fast=False,
)

vocab = tokenizer.get_vocab()
fd_config.model_config.think_end_id = vocab.get("</think>", None)
if fd_config.model_config.think_end_id is not None:
logger.info(f"Get think_end_id {fd_config.model_config.think_end_id} from vocab.")
else:
logger.info("No </think> token found in vocabulary, the model can not do reasoning.")


class PaddleDisWorkerProc:
"""
Paddle Distributed wrapper for fastdeploy.worker.Worker,
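For reference, the '</think>' lookup added above reduces to a plain vocabulary query. A toy illustration with a made-up vocab dict (the real code builds the vocab from Ernie4_5Tokenizer):

```python
from typing import Dict, Optional


def resolve_think_end_id(vocab: Dict[str, int]) -> Optional[int]:
    """Return the id of '</think>' if the tokenizer knows it, otherwise None."""
    think_end_id = vocab.get("</think>", None)
    if think_end_id is not None:
        print(f"Get think_end_id {think_end_id} from vocab.")
    else:
        print("No </think> token found in vocabulary; the model cannot do reasoning.")
    return think_end_id


print(resolve_think_end_id({"<s>": 0, "</s>": 1, "</think>": 2}))  # -> 2
print(resolve_think_end_id({"<s>": 0, "</s>": 1}))                 # -> None
```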
@@ -798,6 +820,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
moba_attention_config=moba_attention_config,
)
update_fd_config_for_mm(fd_config)
update_think_end_id_for_ernie(fd_config)

return fd_config
@@ -927,3 +927,63 @@ def test_profile_reset_block_num():
f"Reset total_block_num {actual_value} must be within 5% of baseline {baseline}. "
f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
)


def test_thinking_logic_flag(openai_client, capsys):
"""
Test the interaction between token calculation logic and conditional thinking.
This test covers:
1. Default max_tokens calculation when not provided.
2. Capping of max_tokens when it exceeds model limits.
3. Default reasoning_max_tokens calculation when not provided.
4. Activation of thinking based on the final state of reasoning_max_tokens.
"""

response_case_1 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity briefly."}],
temperature=1,
stream=False,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
},
)
assert response_case_1.choices[0].message.reasoning_content is not None

response_case_2 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
"reasoning_max_tokens": 5,
},
)
assert response_case_2.choices[0].message.reasoning_content is not None

response_case_3 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
"reasoning_max_tokens": None,
},
)
assert response_case_3.choices[0].message.reasoning_content is not None

response_case_4 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": False},
},
)
assert response_case_4.choices[0].message.reasoning_content is None
@@ -642,3 +642,63 @@ def test_profile_reset_block_num():
f"Reset total_block_num {actual_value} must be within 5% of baseline {baseline}. "
f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
)


def test_thinking_logic_flag(openai_client, capsys):
"""
Test the interaction between token calculation logic and conditional thinking.
This test covers:
1. Default max_tokens calculation when not provided.
2. Capping of max_tokens when it exceeds model limits.
3. Default reasoning_max_tokens calculation when not provided.
4. Activation of thinking based on the final state of reasoning_max_tokens.
"""

response_case_1 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity briefly."}],
temperature=1,
stream=False,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
},
)
assert response_case_1.choices[0].message.reasoning_content is not None

response_case_2 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
"reasoning_max_tokens": 5,
},
)
assert response_case_2.choices[0].message.reasoning_content is not None

response_case_3 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
"reasoning_max_tokens": None,
},
)
assert response_case_3.choices[0].message.reasoning_content is not None

response_case_4 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": False},
},
)
assert response_case_4.choices[0].message.reasoning_content is None