[Metax] optimize flash mla (#4915)

This commit is contained in:
xiaozude
2025-11-12 16:43:46 +08:00
committed by GitHub
parent 9d9f5df8d0
commit c45b3ccb52
5 changed files with 37 additions and 38 deletions

View File

@@ -91,7 +91,7 @@ void LimitThinkingContentLengthV1(const paddle::Tensor &next_tokens,
const int64_t think_end_id) {
const int batch_size = next_tokens.shape()[0];
const int eos_token_id_len = eos_token_ids.shape()[0];
-  limit_thinking_content_length_kernel_v1<<<1, 1024>>>(
+  limit_thinking_content_length_kernel_v1<<<1, 1024, 0, next_tokens.stream()>>>(
const_cast<int64_t *>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
step_idx.data<int64_t>(),