[Bug fix] Fix pd for x1 thinking (#4433)

Author: chenjian
Date: 2025-10-16 12:03:45 +08:00
Committed by: GitHub
Parent: 8e392f0ea6
Commit: 670aaa3f83
5 changed files with 14 additions and 8 deletions

View File

@@ -697,9 +697,7 @@ class EngineService:
                     time.sleep(0.001)
                     continue
                 if self.cfg.scheduler_config.splitwise_role != "mixed":
-                    if self.scheduler.get_unhandled_request_num() <= envs.FD_EP_MAX_PREFETCH_TASK_NUM and (
-                        not is_fetching
-                    ):
+                    if not is_fetching:
                         get_request_pool.submit(_fetch_request)
                 else:
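For non-"mixed" (prefill/decode-disaggregated) roles, the only remaining gate on prefetching is the is_fetching flag; the unhandled-request-count check is gone. A simplified, standalone sketch of that gating pattern is below; the loop, names, and sleeps are stand-ins, not the engine's real code.

# Sketch of the fetch-gating pattern in this hunk: submit a single _fetch_request job
# at a time, guarded only by an is_fetching flag (the scheduler-backlog check was removed).
# Everything here is a simplified stand-in for illustration.
import time
from concurrent.futures import ThreadPoolExecutor

is_fetching = False
get_request_pool = ThreadPoolExecutor(max_workers=1)


def _fetch_request():
    global is_fetching
    is_fetching = True
    try:
        time.sleep(0.01)  # stand-in for pulling a batch of requests from the scheduler
    finally:
        is_fetching = False


for _ in range(3):  # stand-in for the engine's long-running loop
    if not is_fetching:  # the only remaining gate after this commit
        get_request_pool.submit(_fetch_request)
    time.sleep(0.02)

get_request_pool.shutdown(wait=True)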

View File

@@ -75,6 +75,7 @@ class Request:
         structural_tag: Optional[Any] = None,
         guided_json_object: Optional[bool] = None,
         enable_thinking: Optional[bool] = True,
+        reasoning_max_tokens: Optional[int] = None,
         trace_carrier: dict = dict(),
         dp_rank: Optional[int] = None,
         chat_template: Optional[str] = None,
@@ -125,6 +126,7 @@ class Request:
         self.multimodal_img_boundaries = None
         self.enable_thinking = enable_thinking
+        self.reasoning_max_tokens = reasoning_max_tokens
         self.trace_carrier = trace_carrier
         self.chat_template = chat_template
@@ -188,7 +190,8 @@ class Request:
            guided_grammar=d.get("guided_grammar", None),
            structural_tag=d.get("structural_tag", None),
            guided_json_object=d.get("guided_json_object", None),
-            enable_thinking=d.get("enable_thinking", True),
+            enable_thinking=d.get("enable_thinking", False),
+            reasoning_max_tokens=d.get("reasoning_max_tokens", None),
            trace_carrier=d.get("trace_carrier", {}),
            chat_template=d.get("chat_template", None),
            num_computed_tokens=d.get("num_computed_tokens", 0),
@@ -239,6 +242,7 @@ class Request:
            "disaggregate_info": self.disaggregate_info,
            "draft_token_ids": self.draft_token_ids,
            "enable_thinking": self.enable_thinking,
+            "reasoning_max_tokens": self.reasoning_max_tokens,
            "trace_carrier": self.trace_carrier,
            "chat_template": self.chat_template,
            "num_computed_tokens": self.num_computed_tokens,

View File

@@ -796,6 +796,8 @@ class ResourceManagerV1(ResourceManager):
             return False
         if self.available_batch() == 0:
             return False
+        if request.reasoning_max_tokens is not None:
+            request.reasoning_max_tokens -= 1
         request.need_prefill_tokens = len(request.prompt_token_ids)
         need_prealloc_prefill_blocks = (
             request.need_prefill_tokens + self.config.cache_config.block_size - 1
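The preallocation expression that follows the added lines is cut off in this hunk, but the (tokens + block_size - 1) pattern is the usual ceiling division by block_size, so a prompt that spills into a partial block still gets a whole block. A short sketch of that arithmetic, under that assumption:

# Sketch of the prefill block preallocation arithmetic the hunk sets up; the hunk is
# truncated here, and this assumes the standard ceiling division by block_size.
def prealloc_prefill_blocks(need_prefill_tokens: int, block_size: int) -> int:
    return (need_prefill_tokens + block_size - 1) // block_size


assert prealloc_prefill_blocks(64, 64) == 1     # exactly one full block
assert prealloc_prefill_blocks(65, 64) == 2     # one extra token needs a second block
assert prealloc_prefill_blocks(1000, 64) == 16  # ceil(1000 / 64)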

View File

@@ -216,10 +216,7 @@ def post_process_normal(
             model_output.reasoning_index,
         )
-        stop_wo_think = (
-            (sampler_output.sampled_token_ids == model_output.eos_token_id.T).any(axis=1, keepdim=True)
-            | (model_output.reasoning_index == 0)
-        ) & (model_output.need_think_end > 0)
+        stop_wo_think = ((model_output.reasoning_index == 0)) & (model_output.need_think_end > 0)
         stop_wo_think = stop_wo_think & thinking_mask
         sampler_output.sampled_token_ids = paddle.where(
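The behavioral change in this hunk is that sampling an EOS token no longer forces the think-end path; only an exhausted reasoning budget (reasoning_index == 0) does, and only while the request is still in its thinking phase. A toy reproduction of the resulting mask, with made-up shapes and values; only the mask logic mirrors the diff.

# Toy reproduction of the new stop_wo_think mask. Tensor contents are invented for
# illustration; the boolean logic is the same as the added line in the diff.
import paddle

reasoning_index = paddle.to_tensor([[0], [3], [0]])  # remaining thinking budget per request
need_think_end = paddle.to_tensor([[1], [1], [0]])   # 1 = request is still in its thinking phase
thinking_mask = paddle.to_tensor([[True], [True], [True]])

stop_wo_think = (reasoning_index == 0) & (need_think_end > 0)
stop_wo_think = stop_wo_think & thinking_mask
print(stop_wo_think.numpy())  # [[ True] [False] [False]]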

View File

@@ -174,6 +174,7 @@ class DPLocalScheduler(LocalScheduler):
             ):
                 break
             else:
+                required_total_blocks = 0
                 batch_ids = self.requests_not_empty.wait_for(
                     lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch],
                     0.005,
@@ -181,6 +182,10 @@ class DPLocalScheduler(LocalScheduler):
                 if batch_ids:
                     for request_id in batch_ids:
                         request = self.requests[request_id]
+                        required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size)
+                        required_total_blocks += required_input_blocks + reserved_output_blocks
+                        if required_total_blocks > available_blocks:
+                            break
                         requests.append(request.raw)
                         self.ids_read_cursor += 1
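Taken together, the two scheduler hunks make the batch-pulling loop budget-aware: it accumulates the KV-cache blocks each pulled request would need (prompt blocks plus a reserved output allowance) and stops pulling once the free-block budget would be exceeded. A self-contained sketch of that loop follows, with hypothetical numbers and a stand-in calc_required_blocks.

# Standalone sketch of the block budgeting the scheduler hunks add: stop pulling
# requests into a batch once the blocks they would need exceed what is available.
# calc_required_blocks is a stand-in with the same ceiling-division idea.
from math import ceil


def calc_required_blocks(num_tokens: int, block_size: int) -> int:
    return ceil(num_tokens / block_size)


def take_batch(prompt_lens, block_size, reserved_output_blocks, available_blocks):
    taken, required_total_blocks = [], 0
    for prompt_len in prompt_lens:
        required_total_blocks += calc_required_blocks(prompt_len, block_size) + reserved_output_blocks
        if required_total_blocks > available_blocks:
            break  # same early exit as the diff: the remaining requests wait for the next round
        taken.append(prompt_len)
    return taken


# With 40 free blocks, 4 reserved per request for output, and 64-token blocks, only the
# first two prompts fit: (11 + 4) + (10 + 4) = 29 <= 40; the third would push the total to 52.
print(take_batch([700, 600, 1200], block_size=64, reserved_output_blocks=4, available_blocks=40))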