Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
Remove CUDA ERROR 9 of inputs of get_padding_offset kernel (#5440)
Co-authored-by: K11OntheBoat <“ruianmaidanglao@163.com”>
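Note on the change: both get_padding_offset and speculate_get_padding_offset previously took the total token count as a device tensor (token_num) and copied it to the host inside the op; this commit drops that input and passes the count as a host-side int64 attribute (cpu_token_num), so the value that sizes the outputs and the kernel launch is already valid on the CPU. A minimal sketch of the call-site change in Python, using names from this diff (requires the compiled FastDeploy GPU ops; shapes follow the unit test below):

    import paddle
    from fastdeploy.model_executor.ops.gpu import get_padding_offset

    input_ids = paddle.zeros([3, 10], dtype="int64")
    seq_lens_this_time = paddle.to_tensor([4, 3, 6], dtype="int32").reshape([-1, 1])

    # Before: the count travelled as a GPU tensor and was synced inside the op.
    # token_num = paddle.sum(seq_lens_this_time)
    # get_padding_offset(input_ids, token_num, seq_lens_this_time)

    # After: resolve the count on the host once and pass it as an int64 attribute.
    token_num_cpu = seq_lens_this_time.numpy().sum().item()
    out = get_padding_offset(input_ids, seq_lens_this_time, token_num_cpu)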
@@ -388,8 +388,8 @@ void GetBlockShapeAndSplitKVBlock(
     const int block_size);
 
 std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
-                                             const paddle::Tensor& token_num,
-                                             const paddle::Tensor& seq_len);
+                                             const paddle::Tensor& seq_len,
+                                             const int64_t token_num_cpu);
 
 void SetValueByFlagsAndIdx(const paddle::Tensor& pre_ids_all,
                            const paddle::Tensor& input_ids,
@@ -725,9 +725,9 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
     const paddle::Tensor& input_ids,
     const paddle::Tensor& draft_tokens,
     const paddle::Tensor& cum_offsets,
-    const paddle::Tensor& token_num,
     const paddle::Tensor& seq_len,
-    const paddle::Tensor& seq_lens_encoder);
+    const paddle::Tensor& seq_lens_encoder,
+    const int64_t token_num_cpu);
 
 std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(
     const paddle::Tensor& seq_lens_this_time,
@@ -64,8 +64,8 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding,
 }
 
 std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
-                                             const paddle::Tensor &token_num,
-                                             const paddle::Tensor &seq_len) {
+                                             const paddle::Tensor &seq_len,
+                                             const int64_t cpu_token_num) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
   auto dev_ctx = static_cast<const phi::CustomContext *>(
       paddle::experimental::DeviceContextPool::Instance().Get(
@@ -77,9 +77,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
   std::vector<int64_t> input_ids_shape = input_ids.shape();
   const int bsz = seq_len.shape()[0];
   const int max_seq_len = input_ids_shape[1];
-  auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
-
-  const int token_num_data = cpu_token_num.data<int64_t>()[0];
+  const int token_num_data = cpu_token_num;
   auto x_remove_padding = paddle::empty(
       {token_num_data}, paddle::DataType::INT64, input_ids.place());
   auto batch_id_per_token = paddle::empty(
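The lines removed above were this op's device-to-host synchronization point. A rough Python equivalent of that old internal read, for illustration only (runs on any Paddle build):

    import paddle

    seq_lens_this_time = paddle.to_tensor([4, 3, 6], dtype="int32")
    token_num = paddle.sum(seq_lens_this_time)  # 0-d tensor living on the device
    token_num_data = int(token_num.numpy())     # blocking device -> host read,
                                                # much like copy_to(CPUPlace())
    assert token_num_data == 13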
@@ -124,11 +122,12 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
 }
 
 PD_BUILD_STATIC_OP(get_padding_offset)
-    .Inputs({"input_ids", "token_num", "seq_len"})
+    .Inputs({"input_ids", "seq_len"})
     .Outputs({"x_remove_padding",
               "batch_id_per_token",
               "cu_seqlens_q",
               "cu_seqlens_k"})
+    .Attrs({"cpu_token_num: int64_t"})
     .SetKernelFn(PD_KERNEL(GetPaddingOffset))
     .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype));
@@ -26,19 +26,19 @@ __global__ void SpeculateRemovePadding(int64_t* output_data,
                                        const int* cum_offsets,
                                        const int sequence_length,
                                        const int max_draft_tokens) {
-  const int bi = blockIdx.x;
-  const int tid = threadIdx.x;
+  const int bi = blockIdx.x;
+  const int tid = threadIdx.x;
 
-  for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
-    const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
-    if (seq_lens_encoder[bi] > 0) {
-      const int src_seq_id = bi * sequence_length + i;
-      output_data[tgt_seq_id] = input_data[src_seq_id];
-    } else {
-      const int src_seq_id = bi * max_draft_tokens + i;
-      output_data[tgt_seq_id] = draft_tokens[src_seq_id];
-    }
+  for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
+    const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
+    if (seq_lens_encoder[bi] > 0) {
+      const int src_seq_id = bi * sequence_length + i;
+      output_data[tgt_seq_id] = input_data[src_seq_id];
+    } else {
+      const int src_seq_id = bi * max_draft_tokens + i;
+      output_data[tgt_seq_id] = draft_tokens[src_seq_id];
+    }
   }
 }
 
 __global__ void SpeculateGetPaddingOffsetKernel(int* batch_id_per_token,
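SpeculateRemovePadding packs each batch row's valid tokens into a dense buffer, reading from input_ids for prefill rows (seq_lens_encoder[bi] > 0) and from draft_tokens otherwise; note that the host code below passes cum_offsets_out, the exclusive prefix of padding counts, as this kernel's cum_offsets argument. A numpy restatement of that indexing under assumed toy shapes:

    import numpy as np

    sequence_length, max_draft_tokens = 6, 3
    seq_lens = np.array([2, 3])
    seq_lens_encoder = np.array([2, 0])  # row 0 is prefill, row 1 is decode
    input_ids = np.arange(12).reshape(2, 6)
    draft_tokens = 100 + np.arange(6).reshape(2, 3)

    cum_offsets = np.cumsum(sequence_length - seq_lens)        # [4, 7]
    cum_offsets_out = np.concatenate(([0], cum_offsets[:-1]))  # exclusive: [0, 4]

    output = np.zeros(seq_lens.sum(), dtype=np.int64)
    for bi, n in enumerate(seq_lens):
        tgt = bi * sequence_length - cum_offsets_out[bi] + np.arange(n)
        src = input_ids[bi, :n] if seq_lens_encoder[bi] > 0 else draft_tokens[bi, :n]
        output[tgt] = src
    # output -> [0, 1, 103, 104, 105]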
@@ -48,67 +48,65 @@ __global__ void SpeculateGetPaddingOffsetKernel(int* batch_id_per_token,
                                                 const int* cum_offsets,
                                                 const int* seq_lens,
                                                 const int max_seq_len) {
-  // get padding offset of each batch
-  const int bi = blockIdx.x;
-  const int ti = threadIdx.x;
-  int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
-  for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
-    batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
-  }
-  if (ti == 0) {
-    cum_offsets_out[bi] = cum_offset;
-    int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
-    cu_seqlens_q[bi + 1] = cum_seq_len;
-    cu_seqlens_k[bi + 1] = cum_seq_len;
-  }
+  // get padding offset of each batch
+  const int bi = blockIdx.x;
+  const int ti = threadIdx.x;
+  int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
+  for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
+    batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
+  }
+  if (ti == 0) {
+    cum_offsets_out[bi] = cum_offset;
+    int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
+    cu_seqlens_q[bi + 1] = cum_seq_len;
+    cu_seqlens_k[bi + 1] = cum_seq_len;
+  }
 }
 
 std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
     const paddle::Tensor& input_ids,
     const paddle::Tensor& draft_tokens,
    const paddle::Tensor& cum_offsets,
-    const paddle::Tensor& token_num,
     const paddle::Tensor& seq_len,
-    const paddle::Tensor& seq_lens_encoder) {
-  auto cu_stream = input_ids.stream();
-  std::vector<int64_t> input_ids_shape = input_ids.shape();
-  const int bsz = seq_len.shape()[0];
-  const int seq_length = input_ids_shape[1];
-  const int max_draft_tokens = draft_tokens.shape()[1];
-  auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
-  auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
-
-  const int token_num_data = cpu_token_num.data<int64_t>()[0];
-  auto x_remove_padding = paddle::full(
-      {token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
-  auto batch_id_per_token = paddle::full(
-      {token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
-  auto cu_seqlens_q =
-      paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
-  auto cu_seqlens_k =
-      paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
-  int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
-  SpeculateGetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
-      batch_id_per_token.data<int>(),
-      cum_offsets_out.data<int>(),
-      cu_seqlens_q.data<int>(),
-      cu_seqlens_k.data<int>(),
-      cum_offsets.data<int>(),
-      seq_len.data<int>(),
-      seq_length);
-  SpeculateRemovePadding<<<bsz, blockSize, 0, cu_stream>>>(
-      x_remove_padding.data<int64_t>(),
-      input_ids.data<int64_t>(),
-      draft_tokens.data<int64_t>(),
-      seq_len.data<int>(),
-      seq_lens_encoder.data<int>(),
-      cum_offsets_out.data<int>(),
-      seq_length,
-      max_draft_tokens);
-  return {x_remove_padding,
-          batch_id_per_token,
-          cu_seqlens_q,
-          cu_seqlens_k};  // , enc_token_num, dec_token_num};
+    const paddle::Tensor& seq_lens_encoder,
+    const int64_t cpu_token_num) {
+  auto cu_stream = input_ids.stream();
+  std::vector<int64_t> input_ids_shape = input_ids.shape();
+  const int bsz = seq_len.shape()[0];
+  const int seq_length = input_ids_shape[1];
+  const int max_draft_tokens = draft_tokens.shape()[1];
+  auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
+  const int token_num_data = cpu_token_num;
+  auto x_remove_padding = paddle::full(
+      {token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
+  auto batch_id_per_token = paddle::full(
+      {token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
+  auto cu_seqlens_q =
+      paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
+  auto cu_seqlens_k =
+      paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
+  int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
+  SpeculateGetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
+      batch_id_per_token.data<int>(),
+      cum_offsets_out.data<int>(),
+      cu_seqlens_q.data<int>(),
+      cu_seqlens_k.data<int>(),
+      cum_offsets.data<int>(),
+      seq_len.data<int>(),
+      seq_length);
+  SpeculateRemovePadding<<<bsz, blockSize, 0, cu_stream>>>(
+      x_remove_padding.data<int64_t>(),
+      input_ids.data<int64_t>(),
+      draft_tokens.data<int64_t>(),
+      seq_len.data<int>(),
+      seq_lens_encoder.data<int>(),
+      cum_offsets_out.data<int>(),
+      seq_length,
+      max_draft_tokens);
+  return {x_remove_padding,
+          batch_id_per_token,
+          cu_seqlens_q,
+          cu_seqlens_k};  // , enc_token_num, dec_token_num};
 }
 
 std::vector<std::vector<int64_t>> SpeculateGetPaddingOffsetInferShape(
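A note on the unchanged blockSize line above: it rounds the token count up to a whole number of 32-thread warps and caps the result at 128, and it feeds the SpeculateRemovePadding launch. If a garbage or zero count reaches it, the launch gets a zero block dimension; cudaErrorInvalidConfiguration is error 9, which is plausibly the failure the commit title refers to. The same arithmetic, checked in Python:

    def block_size(token_num_data: int) -> int:
        # C++: min((token_num_data + 32 - 1) / 32 * 32, 128)
        return min((token_num_data + 32 - 1) // 32 * 32, 128)

    assert block_size(1) == 32     # rounded up to one warp
    assert block_size(33) == 64    # two warps
    assert block_size(500) == 128  # capped
    assert block_size(0) == 0      # zero block dim -> invalid launch configuration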
@@ -118,9 +116,9 @@ std::vector<std::vector<int64_t>> SpeculateGetPaddingOffsetInferShape(
     const std::vector<int64_t>& token_num_shape,
     const std::vector<int64_t>& seq_len_shape,
     const std::vector<int64_t>& seq_lens_encoder_shape) {
-  int64_t bsz = seq_len_shape[0];
-  int64_t seq_len = input_ids_shape[1];
-  return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
+  int64_t bsz = seq_len_shape[0];
+  int64_t seq_len = input_ids_shape[1];
+  return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
 }
 
 std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
@@ -130,23 +128,22 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
     const paddle::DataType& token_num_dtype,
     const paddle::DataType& seq_len_dtype,
     const paddle::DataType& seq_lens_encoder_dtype) {
-  return {input_ids_dtype,
-          seq_len_dtype,
-          seq_len_dtype,
-          seq_len_dtype};
+  return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype};
 }
 
 PD_BUILD_STATIC_OP(speculate_get_padding_offset)
-    .Inputs({"input_ids",
-             "draft_tokens",
-             "cum_offsets",
-             "token_num",
-             "seq_len",
-             "seq_lens_encoder"})
+    .Inputs({
+        "input_ids",
+        "draft_tokens",
+        "cum_offsets",
+        "seq_len",
+        "seq_lens_encoder",
+    })
     .Outputs({"x_remove_padding",
               "batch_id_per_token",
               "cu_seqlens_q",
               "cu_seqlens_k"})
+    .Attrs({"cpu_token_num: int64_t"})
     .SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset))
     .SetInferShapeFn(PD_INFER_SHAPE(SpeculateGetPaddingOffsetInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(SpeculateGetPaddingOffsetInferDtype));
@@ -27,10 +27,7 @@ from fastdeploy.platforms import current_platform
 
 if current_platform.is_cuda() and current_platform.available():
     try:
-        from fastdeploy.model_executor.ops.gpu import (
-            get_padding_offset,
-            speculate_get_padding_offset,
-        )
+        from fastdeploy.model_executor.ops.gpu import get_padding_offset
     except Exception:
         raise ImportError(
             "Verify environment consistency between compilation and FastDeploy installation. "
@@ -458,57 +455,6 @@ def remove_padding(
         )
 
 
-def speculate_remove_padding(
-    max_len: paddle.Tensor,
-    input_ids: paddle.Tensor,
-    seq_lens_this_time: paddle.Tensor,
-    draft_tokens: paddle.Tensor,
-    seq_lens_encoder: paddle.Tensor,
-) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
-    """
-    Remove padding from sequences.
-
-    Args:
-        max_len (paddle.Tensor): The maximum length of the sequences.
-        input_ids (paddle.Tensor): The IDs of the input sequences.
-        seq_lens_this_time (paddle.Tensor): The lengths of the sequences in the current batch.
-        draft_tokens (paddle.Tensor): The draft tokens.
-        seq_lens_encoder (paddle.Tensor): The lengths of the encoder sequences.
-
-    Returns:
-        tuple: A tuple containing:
-            - The input sequence IDs with padding removed (paddle.Tensor).
-            - Padding offsets (paddle.Tensor).
-            - Cumulative offsets (paddle.Tensor).
-            - Query sequence lengths (paddle.Tensor).
-            - Key sequence lengths (paddle.Tensor).
-    """
-    if current_platform.is_cuda():
-        cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
-        token_num = paddle.sum(seq_lens_this_time)
-        (
-            ids_remove_padding,
-            cum_offsets,
-            padding_offset,
-            cu_seqlens_q,
-            cu_seqlens_k,
-        ) = speculate_get_padding_offset(
-            input_ids,
-            draft_tokens,
-            cum_offsets_now,
-            token_num,
-            seq_lens_this_time,
-            seq_lens_encoder,
-        )
-        return (
-            ids_remove_padding,
-            padding_offset,
-            cum_offsets,
-            cu_seqlens_q,
-            cu_seqlens_k,
-        )
-
-
 class CpuGuard:
     """CpuGuard"""
 
@@ -183,7 +183,7 @@ def speculate_limit_thinking_content_length(
 
 def pre_process(
     input_ids: paddle.Tensor,
-    seq_lens_this_time: int,
+    seq_lens_this_time: paddle.Tensor,
     speculative_decoding: bool,
     draft_tokens: Optional[paddle.Tensor] = None,
     seq_lens_encoder: Optional[paddle.Tensor] = None,
@@ -204,15 +204,13 @@ def pre_process(
         cu_seqlens_q:
         cu_seqlens_k:
     """
-    token_num = paddle.sum(seq_lens_this_time)
-
+    token_num_cpu = seq_lens_this_time.numpy().sum().item()
     specific_platform = current_platform.is_cuda() or current_platform.is_maca() or current_platform.is_iluvatar()
     if specific_platform and not speculative_decoding:
         # Note(ZKK): This case's code is very simple!
         ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
-            input_ids, token_num, seq_lens_this_time
+            input_ids, seq_lens_this_time, token_num_cpu
         )
-
         return (
             ids_remove_padding,
             batch_id_per_token,
@@ -221,7 +219,6 @@ def pre_process(
             None,
             None,
         )
-
     # Remove padding
     max_len = input_ids.shape[1]
     cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
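To make the offset bookkeeping concrete, a small numpy walk-through with hypothetical values (max_len and seq_lens chosen to match the get_padding_offset unit test below):

    import numpy as np

    seq_lens = np.array([4, 3, 6], dtype=np.int32)
    max_len = 10
    cum_offsets_now = np.cumsum(max_len - seq_lens)  # [ 6 13 17]: padding skipped so far
    token_num_cpu = int(seq_lens.sum())              # 13 real tokens out of 30 slots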
@@ -234,12 +231,7 @@ def pre_process(
         cu_seqlens_q,
         cu_seqlens_k,
     ) = speculate_get_padding_offset(
-        input_ids,
-        draft_tokens,
-        cum_offsets_now,
-        token_num,
-        seq_lens_this_time,
-        seq_lens_encoder,
+        input_ids, draft_tokens, cum_offsets_now, seq_lens_this_time, seq_lens_encoder, token_num_cpu
     )
     seq_lens_output = speculate_get_seq_lens_output(
         seq_lens_this_time,
@@ -257,6 +249,7 @@ def pre_process(
             max_len,
         )
     else:
+        token_num = paddle.sum(seq_lens_this_time)
         (
             ids_remove_padding,
             batch_id_per_token,
@@ -270,10 +270,10 @@ class TestAttentionPerformance(unittest.TestCase):
             partial_rotary_factor=fd_config.model_config.partial_rotary_factor,
         )
 
-        input_ids = paddle.zeros([batch_size, max_model_len], dtype="int64")
-        token_num = paddle.sum(seq_lens_this_time)
+        input_ids = paddle.zeros([batch_size, seq_len if mode == ForwardMode.EXTEND else 1], dtype="int64")
+        token_num = np.sum(seq_lens_this_time)
         ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
-            input_ids, token_num, seq_lens_this_time
+            input_ids, seq_lens_this_time, token_num
         )
 
         forward_meta = ForwardMeta(
@@ -23,7 +23,7 @@ from fastdeploy.model_executor.ops.gpu import get_padding_offset
 class TestGetPaddingOffset(unittest.TestCase):
     def test_get_padding_offset(self):
         seq_lens = np.array([4, 3, 6], "int32").reshape(-1, 1)
-        token_num = np.sum(seq_lens)
+        token_num_cpu = np.sum(seq_lens)
         input_ids = np.array(
             [[8, 7, 8, 2, 0, 0, 0, 0, 0, 0], [4, 5, 5, 0, 0, 0, 0, 0, 0, 0], [7, 6, 1, 7, 2, 6, 0, 0, 0, 0]], "int64"
         )
@@ -32,11 +32,7 @@ class TestGetPaddingOffset(unittest.TestCase):
             batch_id_per_token,
             cu_seqlens_q,
             cu_seqlens_k,
-        ) = get_padding_offset(
-            paddle.to_tensor(input_ids),
-            paddle.to_tensor(token_num),
-            paddle.to_tensor(seq_lens),
-        )
+        ) = get_padding_offset(paddle.to_tensor(input_ids), paddle.to_tensor(seq_lens), token_num_cpu)
 
         ref_x_remove_padding = np.array([8, 7, 8, 2, 4, 5, 5, 7, 6, 1, 7, 2, 6], "int64")
         ref_batch_id_per_token = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2], "int32")
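The reference arrays can be re-derived by hand. A short numpy sketch of the expected outputs for seq_lens = [4, 3, 6] (derivation follows the kernels above; this is not code from the test):

    import numpy as np

    seq_lens = np.array([4, 3, 6])
    batch_id_per_token = np.repeat(np.arange(len(seq_lens)), seq_lens)
    # -> [0 0 0 0 1 1 1 2 2 2 2 2 2], i.e. ref_batch_id_per_token
    cu_seqlens = np.concatenate(([0], np.cumsum(seq_lens)))
    # -> [ 0  4  7 13]; cu_seqlens_q and cu_seqlens_k are both this prefix sum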
@@ -86,14 +86,13 @@ class TestSpeculateGetPaddingOffset(unittest.TestCase):
 
         input_ids = np.random.randint(0, 1000, (test_case["bsz"], test_case["max_seq_len"]), dtype=np.int64)
         draft_tokens = np.random.randint(0, 1000, (test_case["bsz"], max_draft_tokens), dtype=np.int64)
-        token_num = np.array([test_case["token_num_data"]], dtype=np.int64)
+        token_num_cpu = np.array([test_case["token_num_data"]], dtype=np.int64).item()
 
         input_ids_tensor = paddle.to_tensor(input_ids)
         draft_tokens_tensor = paddle.to_tensor(draft_tokens)
         cum_offsets_tensor = paddle.to_tensor(test_case["cum_offsets"])
         seq_lens_tensor = paddle.to_tensor(test_case["seq_lens"])
         seq_lens_encoder_tensor = paddle.to_tensor(test_case["seq_lens_encoder"])
-        token_num_tensor = paddle.to_tensor(token_num)
 
         (
             x_remove_padding,
@@ -104,9 +103,9 @@ class TestSpeculateGetPaddingOffset(unittest.TestCase):
             input_ids_tensor,
             draft_tokens_tensor,
             cum_offsets_tensor,
-            token_num_tensor,
             seq_lens_tensor,
             seq_lens_encoder_tensor,
+            token_num_cpu,
         )
 
         (