Support limiting thinking lengths (#4070)

Co-authored-by: K11OntheBoat <ruianmaidanglao@163.com>
Authored by K11OntheBoat on 2025-09-17 12:40:08 +08:00, committed by GitHub
parent b41988f4bc
commit 7f9a9b37f3
8 changed files with 184 additions and 26 deletions

View File

@@ -130,6 +130,7 @@ class ModelConfig:
 self.quantization = None
 self.pad_token_id: int = -1
 self.eos_tokens_lens: int = 2
+self.think_end_id = None
 self.lm_head_fp32: bool = False
 self.model_format = "auto"
 self.partial_rotary_factor: float = 1.0

View File

@@ -177,8 +177,6 @@ class EngineClient:
task["prompt_token_ids_len"] = len(task["prompt_token_ids"]) task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
input_ids_len = task["prompt_token_ids_len"] input_ids_len = task["prompt_token_ids_len"]
task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens")) task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
if task.get("reasoning_max_tokens", None) is None:
task["reasoning_max_tokens"] = max(int(task["max_tokens"] * 0.8), 1)
min_tokens = task.get("min_tokens", 1) min_tokens = task.get("min_tokens", 1)
if "messages" in task: if "messages" in task:
del task["messages"] del task["messages"]

View File

@@ -255,6 +255,10 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1] request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
if request.get("max_tokens") is None: if request.get("max_tokens") is None:
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"])) request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
else:
request["max_tokens"] = min(max_model_len - len(request["prompt_token_ids"]), request["max_tokens"])
if request.get("reasoning_max_tokens") is None:
request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
data_processor_logger.info(f"Processed request {request}") data_processor_logger.info(f"Processed request {request}")
return request return request
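
With this change the 80%-of-max_tokens default for reasoning_max_tokens moves out of EngineClient and into the processor, applied after max_tokens itself has been defaulted or capped against the context window. A rough standalone sketch of the resulting rules (hypothetical helper name, not code from this commit):

def apply_token_limits(request: dict, max_model_len: int) -> dict:
    # Illustrative only; mirrors the processor logic above.
    prompt_len = len(request["prompt_token_ids"])
    if request.get("max_tokens") is None:
        # Default: whatever room is left in the context window, at least 1 token.
        request["max_tokens"] = max(1, max_model_len - prompt_len)
    else:
        # Cap a user-supplied value so prompt + output still fits the window.
        request["max_tokens"] = min(max_model_len - prompt_len, request["max_tokens"])
    if request.get("reasoning_max_tokens") is None:
        # Default thinking budget: 80% of the final max_tokens, at least 1.
        request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
    return request

# Example: a 10-token prompt in a 32-token window with max_tokens=100 requested
# is capped to 22 generated tokens, with a 17-token thinking budget.
req = apply_token_limits({"prompt_token_ids": list(range(10)), "max_tokens": 100}, max_model_len=32)
assert req["max_tokens"] == 22 and req["reasoning_max_tokens"] == 17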

View File

@@ -166,7 +166,7 @@ def post_process_normal(
 ) -> ModelRunnerOutput:
     """Post-processing steps after completing a single token generation."""
     # handle vl:
-    if model_output.enable_thinking:
+    if model_output.enable_thinking and model_output.think_end_id is not None:
         exists_think_end = sampler_output.sampled_token_ids == model_output.think_end_id
         paddle.assign(
             paddle.where(
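
The guard above only runs the thinking bookkeeping when a think_end_id is actually known (it is None for non-Ernie models, see the worker changes below). The per-step semantics implied by the buffer names can be sketched in NumPy roughly as follows; this is an interpretation for illustration, not the Paddle code in this commit, and the actual enforcement of the budget lives in ops outside this excerpt:

import numpy as np

def update_thinking_state(sampled_token_ids, need_think_end, reasoning_index, think_end_id):
    # All arrays have shape [batch, 1]; think_end_id is an int (the '</think>' token id).
    still_thinking = need_think_end == 1
    sampled_end = sampled_token_ids == think_end_id
    # Spend one token of the thinking budget for every request still in its thinking phase.
    reasoning_index = np.where(still_thinking, reasoning_index - 1, reasoning_index)
    # Leave the thinking phase once '</think>' is sampled; forcing '</think>' when the
    # budget hits zero is assumed to happen elsewhere (e.g. in a sampling/stop op).
    need_think_end = np.where(still_thinking & sampled_end, 0, need_think_end)
    return need_think_end, reasoning_index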

View File

@@ -265,15 +265,21 @@ class GPUModelRunner(ModelRunnerBase):
 else:
     position_ids = None
-enable_thinking = request.get("enable_thinking", True)
-enable_thinking = enable_thinking if enable_thinking is not None else True
-self.share_inputs["enable_thinking"][:] = enable_thinking
-self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
-self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
 self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
     position_ids, request.get("max_tokens", 2048)
 )
+if request.get("enable_thinking", False) and request.get("reasoning_max_tokens") is not None:
+    # Enable thinking
+    self.share_inputs["enable_thinking"][:] = True
+    self.share_inputs["need_think_end"][idx : idx + 1, :] = 1
+    self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
+else:
+    # Disable thinking
+    self.share_inputs["enable_thinking"][:] = False
+    self.share_inputs["need_think_end"][idx : idx + 1, :] = 0
+    self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0
 if isinstance(request.prompt_token_ids, np.ndarray):
     prompt_token_ids = request.prompt_token_ids.tolist()
 else:
@@ -495,16 +501,22 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["prompt_lens"][idx : idx + 1] = length self.share_inputs["prompt_lens"][idx : idx + 1] = length
if self.enable_mm: if self.enable_mm:
enable_thinking = request.get("enable_thinking", True)
enable_thinking = enable_thinking if enable_thinking is not None else True
self.share_inputs["enable_thinking"][:] = enable_thinking
self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d( self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
position_ids, request.get("max_tokens", 2048) position_ids, request.get("max_tokens", 2048)
) )
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens") is not None:
# Enable thinking
self.share_inputs["enable_thinking"][:] = True
self.share_inputs["need_think_end"][idx : idx + 1, :] = 1
self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
else:
# Disable thinking
self.share_inputs["enable_thinking"][:] = False
self.share_inputs["need_think_end"][idx : idx + 1, :] = 0
self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0
def get_attr_from_request(request, attr, default_value=None): def get_attr_from_request(request, attr, default_value=None):
res = request.get(attr, default_value) res = request.get(attr, default_value)
if res is not None: if res is not None:
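
Both insertion paths above now share the same rule: thinking is enabled for a request only when enable_thinking is set and a reasoning budget is present; otherwise all three buffers are zeroed, which keeps requests without a budget out of the thinking phase entirely. A compact sketch of that decision (hypothetical helper, for illustration):

def thinking_setup(request: dict) -> tuple:
    # Returns (enable_thinking, need_think_end, reasoning_index) for one request.
    if request.get("enable_thinking", False) and request.get("reasoning_max_tokens") is not None:
        return True, 1, request["reasoning_max_tokens"]
    return False, 0, 0

assert thinking_setup({"enable_thinking": True, "reasoning_max_tokens": 128}) == (True, 1, 128)
assert thinking_setup({"enable_thinking": True}) == (False, 0, 0)  # no budget -> thinking disabled
assert thinking_setup({}) == (False, 0, 0)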
@@ -735,6 +747,11 @@ class GPUModelRunner(ModelRunnerBase):
 # Initialize rotary position embedding
 tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
+# Initialize thinking related buffers
+self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
+self.share_inputs["enable_thinking"] = paddle.full(shape=[1], fill_value=False, dtype="bool")
+self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
 # TODO(gongshaotian): move to models
 if not self.enable_mm:
     self.share_inputs["rope_emb"] = get_rope(
@@ -827,11 +844,6 @@ class GPUModelRunner(ModelRunnerBase):
dtype="float32", dtype="float32",
) )
self.share_inputs["image_features"] = None self.share_inputs["image_features"] = None
self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
self.share_inputs["enable_thinking"] = paddle.full(
shape=[1], fill_value=("ernie" in self.model_config.model_type), dtype="bool"
)
self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
def _prepare_inputs(self) -> None: def _prepare_inputs(self) -> None:
"""Prepare the model inputs""" """Prepare the model inputs"""
@@ -1220,10 +1232,10 @@ class GPUModelRunner(ModelRunnerBase):
 ),
 accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
 accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
-enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
-think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
-need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
-reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
+enable_thinking=self.share_inputs["enable_thinking"],
+think_end_id=self.model_config.think_end_id,
+need_think_end=self.share_inputs["need_think_end"],
+reasoning_index=self.share_inputs["reasoning_index"],
 stop_token_ids=self.share_inputs["stop_seqs"],
 stop_seqs_len=self.share_inputs["stop_seqs_len"],
 )
@@ -1515,10 +1527,10 @@ class GPUModelRunner(ModelRunnerBase):
 ),
 accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
 accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
-enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
-think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
-need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
-reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
+enable_thinking=self.share_inputs["enable_thinking"],
+think_end_id=self.model_config.think_end_id,
+need_think_end=self.share_inputs["need_think_end"][:num_running_requests],
+reasoning_index=self.share_inputs["reasoning_index"][:num_running_requests],
 stop_token_ids=self.share_inputs["stop_seqs"],
 stop_seqs_len=self.share_inputs["stop_seqs_len"],
 )

View File

@@ -129,6 +129,28 @@ def update_fd_config_for_mm(fd_config: FDConfig) -> None:
 fd_config.model_config.sequence_parallel = fd_config.parallel_config.sequence_parallel

+def update_think_end_id_for_ernie(fd_config: FDConfig) -> None:
+    """
+    Updates the think_end_id in the model config. Uses the ID of '</think>'
+    if it exists, otherwise defaults to None.
+    """
+    is_ernie = ErnieArchitectures.contains_ernie_arch(fd_config.model_config.architectures)
+    if current_platform.is_cuda() and is_ernie:
+        tokenizer = Ernie4_5Tokenizer.from_pretrained(
+            fd_config.model_config.model,
+            model_max_length=fd_config.parallel_config.max_model_len,
+            padding_side="right",
+            use_fast=False,
+        )
+        vocab = tokenizer.get_vocab()
+        fd_config.model_config.think_end_id = vocab.get("</think>", None)
+        if fd_config.model_config.think_end_id is not None:
+            logger.info(f"Get think_end_id {fd_config.model_config.think_end_id} from vocab.")
+        else:
+            logger.info("No </think> token found in vocabulary, the model cannot do reasoning.")

 class PaddleDisWorkerProc:
     """
     Paddle Distributed wrapper for fastdeploy.worker.Worker,
@@ -771,6 +793,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
 plas_attention_config=plas_attention_config,
 )
 update_fd_config_for_mm(fd_config)
+update_think_end_id_for_ernie(fd_config)
 return fd_config
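
update_think_end_id_for_ernie resolves the '</think>' token once at startup and stores it on the model config, so downstream code only needs a None check. The same lookup can be sketched with a generic Hugging Face tokenizer (an assumption for illustration; the commit uses FastDeploy's Ernie4_5Tokenizer):

from transformers import AutoTokenizer

def find_think_end_id(model_path: str, token: str = "</think>"):
    # Returns the vocabulary id of the reasoning-end token, or None when the
    # model has no such token (downstream code then skips thinking entirely).
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    return tokenizer.get_vocab().get(token)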

View File

@@ -580,3 +580,63 @@ def test_profile_reset_block_num():
f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内" f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内"
f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]" f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
) )
def test_thinking_logic_flag(openai_client, capsys):
"""
Test the interaction between token calculation logic and conditional thinking.
This test covers:
1. Default max_tokens calculation when not provided.
2. Capping of max_tokens when it exceeds model limits.
3. Default reasoning_max_tokens calculation when not provided.
4. Activation of thinking based on the final state of reasoning_max_tokens.
"""
response_case_1 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity briefly."}],
temperature=1,
stream=False,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
},
)
assert response_case_1.choices[0].message.reasoning_content is not None
response_case_2 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
"reasoning_max_tokens": 5,
},
)
assert response_case_2.choices[0].message.reasoning_content is not None
response_case_3 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
"reasoning_max_tokens": None,
},
)
assert response_case_3.choices[0].message.reasoning_content is not None
response_case_4 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": False},
},
)
assert response_case_4.choices[0].message.reasoning_content is None
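
For reference, the knobs exercised by these tests map directly onto a client request: enable_thinking is passed through chat_template_kwargs and reasoning_max_tokens caps the thinking phase. A hedged usage sketch with the OpenAI Python client (server URL, port, and model name are placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8180/v1", api_key="EMPTY")  # placeholder endpoint
resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Explain gravity briefly."}],
    max_tokens=256,
    extra_body={
        "chat_template_kwargs": {"enable_thinking": True},
        "reasoning_max_tokens": 64,  # cap the thinking phase at 64 tokens
    },
)
print(resp.choices[0].message.reasoning_content)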

View File

@@ -592,3 +592,63 @@ def test_profile_reset_block_num():
f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内" f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内"
f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]" f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
) )
def test_thinking_logic_flag(openai_client, capsys):
"""
Test the interaction between token calculation logic and conditional thinking.
This test covers:
1. Default max_tokens calculation when not provided.
2. Capping of max_tokens when it exceeds model limits.
3. Default reasoning_max_tokens calculation when not provided.
4. Activation of thinking based on the final state of reasoning_max_tokens.
"""
response_case_1 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity briefly."}],
temperature=1,
stream=False,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
},
)
assert response_case_1.choices[0].message.reasoning_content is not None
response_case_2 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
"reasoning_max_tokens": 5,
},
)
assert response_case_2.choices[0].message.reasoning_content is not None
response_case_3 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": True},
"reasoning_max_tokens": None,
},
)
assert response_case_3.choices[0].message.reasoning_content is not None
response_case_4 = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
temperature=1,
stream=False,
max_tokens=20,
extra_body={
"chat_template_kwargs": {"enable_thinking": False},
},
)
assert response_case_4.choices[0].message.reasoning_content is None