[Feature] support mtp logprob (#4464)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* support mtp logprob

* fix unitest
This commit is contained in:
GoldPancake
2025-10-20 15:18:12 +08:00
committed by GitHub
parent 1b9f351d21
commit 47595a2480
14 changed files with 1181 additions and 32 deletions

View File

@@ -32,6 +32,8 @@ from fastdeploy.model_executor.layers.sample.ops import (
apply_penalty_multi_scores,
apply_speculative_penalty_multi_scores,
min_p_sampling,
speculate_get_target_logits,
speculate_insert_first_token,
top_k_top_p_sampling,
)
from fastdeploy.platforms import current_platform
@@ -455,6 +457,98 @@ class SpeculativeSampler(nn.Layer):
"""apply logits processor to sampler"""
pass
def compute_logprobs(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
) -> paddle.Tensor:
"""compute logprobs"""
share_inputs = sampling_metadata.share_inputs
last_logits = logits
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
batch_token_num = share_inputs["batch_token_num"][:real_bsz]
temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
if temp_scaled_logprobs is not None:
real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
temperature = sampling_metadata.temperature[:real_bsz]
real_bsz_temp_scaled = (
real_bsz_temp_scaled.astype("int32").squeeze(1).repeat_interleave(batch_token_num).astype("bool")
)
temperature = temperature.squeeze(1).repeat_interleave(batch_token_num)
temp_temperature = paddle.where(
real_bsz_temp_scaled, temperature, paddle.ones_like(temperature)
).unsqueeze(1)
last_logits = last_logits / temp_temperature
last_logprobs = F.log_softmax(last_logits, axis=-1)
top_p_logprob = None
top_p_token_mask = None
if top_p_normalized_logprobs is not None and share_inputs is not None:
real_token_top_p = (
sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1)
)
top_p_normalized_logprobs = (
top_p_normalized_logprobs[:real_bsz]
.astype("int32")
.squeeze(1)
.repeat_interleave(batch_token_num)
.astype("bool")
.unsqueeze(1)
)
top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0)
if top_p_token_mask.any():
probs = F.softmax(last_logits, axis=-1)
probs = top_p_normalize_probs_paddle(probs, real_token_top_p)
top_p_logprob = paddle.log(probs)
if top_p_logprob is not None:
last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs)
return last_logprobs
def gather_logprobs(
self,
logprobs: paddle.Tensor,
num_logprobs: int,
token_ids: paddle.Tensor,
) -> LogprobsTensors:
"""
Gather logprobs for topk and sampled/prompt token.
Args:
logprobs: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Must be int64.
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
assert token_ids.dtype == paddle.int64
token_ids = token_ids.unsqueeze(1)
logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
# Get with the logprob of the prompt or sampled token.
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
if num_logprobs >= 1:
# Find the topK values.
topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1)
indices = paddle.concat([token_ids, topk_indices], axis=1)
top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)
else:
indices = token_ids
top_logprobs = token_logprobs
return LogprobsTensors(indices, top_logprobs, token_ranks)
def forward_cuda(
self,
logits: paddle.Tensor,
@@ -521,7 +615,56 @@ class SpeculativeSampler(nn.Layer):
accept_all_drafts,
)
return None
num_logprobs = sampling_metadata.max_num_logprobs
batch_token_num = None
if num_logprobs is not None:
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
batch_token_num = paddle.where(
share_inputs["seq_lens_encoder"][:real_bsz] != 0,
paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
share_inputs["accept_num"][:real_bsz].unsqueeze(1),
).squeeze(1)
share_inputs["batch_token_num"] = batch_token_num
ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
"int32"
)
cu_batch_token_offset = paddle.concat(
[paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
).astype("int32")
share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
target_logtis = paddle.empty(
[share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype
)
speculate_get_target_logits(
target_logtis,
logits,
cu_batch_token_offset,
ori_cu_batch_token_offset,
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
share_inputs["accept_num"],
)
raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata)
logprobs_tensors = None
token_ids = share_inputs["accept_tokens"]
if num_logprobs is not None:
token_ids = paddle.concat(
[
share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]]
for i in range(share_inputs["accept_num"][:real_bsz].shape[0])
]
)
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
logprobs_tensors=logprobs_tensors,
token_num_per_batch=batch_token_num,
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
)
return sampler_output
class MTPSampler(nn.Layer):
@@ -556,6 +699,103 @@ class MTPSampler(nn.Layer):
"""post process after running"""
pass
def compute_logprobs(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
) -> paddle.Tensor:
"""compute logprobs"""
share_inputs = sampling_metadata.share_inputs
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
last_logits = logits
temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
if temp_scaled_logprobs is not None:
real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
temperature = sampling_metadata.temperature[:real_bsz]
real_bsz_temp_scaled = (
real_bsz_temp_scaled.astype("int32")
.squeeze(1)
.repeat_interleave(share_inputs["batch_token_num"][:real_bsz])
.astype("bool")
)
temperature = temperature.squeeze(1).repeat_interleave(share_inputs["batch_token_num"][:real_bsz])
temp_temperature = paddle.where(
real_bsz_temp_scaled, temperature, paddle.ones_like(temperature)
).unsqueeze(1)
last_logits = last_logits / temp_temperature
last_logprobs = F.log_softmax(last_logits, axis=-1)
top_p_logprob = None
top_p_token_mask = None
if top_p_normalized_logprobs is not None and share_inputs is not None:
real_token_top_p = (
sampling_metadata.top_p[:real_bsz]
.squeeze(1)
.repeat_interleave(share_inputs["batch_token_num"][:real_bsz])
.unsqueeze(1)
)
top_p_normalized_logprobs = (
top_p_normalized_logprobs[:real_bsz]
.astype("int32")
.squeeze(1)
.repeat_interleave(share_inputs["batch_token_num"][:real_bsz])
.astype("bool")
.unsqueeze(1)
)
top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0)
if top_p_token_mask.any():
probs = F.softmax(last_logits, axis=-1)
probs = top_p_normalize_probs_paddle(probs, real_token_top_p)
top_p_logprob = paddle.log(probs)
if top_p_logprob is not None:
last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs)
return last_logprobs
def gather_logprobs(
self,
logprobs: paddle.Tensor,
num_logprobs: int,
token_ids: paddle.Tensor,
) -> LogprobsTensors:
"""
Gather logprobs for topk and sampled/prompt token.
Args:
logprobs: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Must be int64.
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
assert token_ids.dtype == paddle.int64
token_ids = token_ids.unsqueeze(1)
logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
# Get with the logprob of the prompt or sampled token.
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
if num_logprobs >= 1:
# Find the topK values.
topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1)
indices = paddle.concat([token_ids, topk_indices], axis=1)
top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)
else:
indices = token_ids
top_logprobs = token_logprobs
return LogprobsTensors(indices, top_logprobs, token_ranks)
def forward_cuda(
self,
logits: paddle.Tensor,
@@ -564,6 +804,12 @@ class MTPSampler(nn.Layer):
share_inputs: List[paddle.Tensor],
) -> paddle.Tensor:
""" """
num_logprobs = sampling_metadata.max_num_logprobs
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
if num_logprobs is not None and share_inputs["substep"] == 0:
real_token_num = share_inputs["batch_token_num"][:real_bsz].sum()
raw_logprobs = self.compute_logprobs(share_inputs["draft_logits"][:real_token_num, :], sampling_metadata)
logits = apply_speculative_penalty_multi_scores(
sampling_metadata.pre_token_ids,
logits,
@@ -585,4 +831,27 @@ class MTPSampler(nn.Layer):
_, next_tokens = top_k_top_p_sampling(
probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
)
return next_tokens
token_ids = None
logprobs_tensors = None
if num_logprobs is not None and share_inputs["substep"] == 0:
token_ids = paddle.empty(real_token_num, dtype="int64")
speculate_insert_first_token(
token_ids,
share_inputs["accept_tokens"],
next_tokens,
share_inputs["cu_next_token_offset"],
share_inputs["cu_batch_token_offset"],
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
)
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
logprobs_tensors=logprobs_tensors,
token_num_per_batch=share_inputs["batch_token_num"][:real_bsz],
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
)
return next_tokens, sampler_output