polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

View File

@@ -41,10 +41,7 @@ step_idx = (seq_lens_decoder - ori_seq_lens_encoder).astype("int64")
max_block_num = block_bs * max_seq_len // block_size
free_list_len = int(max_block_num * (1 - block_ratio))
free_list_len = np.full([1], free_list_len, "int32")
free_list = np.arange(max_block_num - 1,
max_block_num - free_list_len - 1,
-1,
dtype="int32")
free_list = np.arange(max_block_num - 1, max_block_num - free_list_len - 1, -1, dtype="int32")
encoder_block_lens = np.zeros([max_bs], "int32")
used_list_len = np.zeros([max_bs], "int32")
@@ -53,19 +50,15 @@ encoder_block_id = 0
for i in range(bs):
enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size
encoder_block_lens[i] = enc_block_num
dec_block_num = (seq_lens_decoder[i] + block_size -
1) // block_size - enc_block_num
dec_block_num = (seq_lens_decoder[i] + block_size - 1) // block_size - enc_block_num
used_list_len[i] = dec_block_num
block_tables[i, :enc_block_num] = np.arange(
encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
block_tables[i, :enc_block_num] = np.arange(encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
encoder_block_id += enc_block_num
if dec_block_num > 0:
block_tables[
i, enc_block_num:enc_block_num +
dec_block_num] = free_list[free_list_len[0] - 1 -
dec_block_num:free_list_len[0] - 1]
free_list[free_list_len[0] - 1 - dec_block_num:free_list_len[0] -
1] = -1
block_tables[i, enc_block_num : enc_block_num + dec_block_num] = free_list[
free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1
]
free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1
free_list_len[0] -= dec_block_num
assert free_list_len[0] >= 0
@@ -137,13 +130,32 @@ first_token_ids = paddle.to_tensor(first_token_ids)
# print("step_idx: ", step_idx)
# print("next_tokens: ", next_tokens)
step_paddle(stop_flags, seq_lens_this_time, ori_seq_lens_encoder,
seq_lens_encoder, seq_lens_decoder, block_tables,
encoder_block_lens, is_block_step, step_block_list, step_lens,
recover_block_list, recover_lens, need_block_list, need_block_len,
used_list_len, free_list, free_list_len, input_ids, pre_ids,
step_idx, next_tokens, first_token_ids, block_size,
encoder_decoder_block_num)
step_paddle(
stop_flags,
seq_lens_this_time,
ori_seq_lens_encoder,
seq_lens_encoder,
seq_lens_decoder,
block_tables,
encoder_block_lens,
is_block_step,
step_block_list,
step_lens,
recover_block_list,
recover_lens,
need_block_list,
need_block_len,
used_list_len,
free_list,
free_list_len,
input_ids,
pre_ids,
step_idx,
next_tokens,
first_token_ids,
block_size,
encoder_decoder_block_num,
)
print("-" * 50 + "after step op" + "-" * 50)
print("stop_flags: ", stop_flags)