[Iluvatar] Support V1_KVCACHE_SCHEDULER and paddleocr-vl rope mode (#5555)

This commit is contained in:
yzwu
2025-12-18 18:14:25 +08:00
committed by GitHub
parent 48f3e9797e
commit ac013803f3
24 changed files with 1212 additions and 1090 deletions
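
The behavioral core of the change, expressed as a minimal Python sketch rather than the actual FastDeploy code (platform and model-type strings are illustrative): Iluvatar joins the platforms that keep the V1 KV-cache scheduler enabled, and models whose type contains "paddleocr" switch the Iluvatar attention kernels to the non-interleaved (OCRV1) rope mode.

def v1_kvcache_scheduler_stays_enabled(platform: str) -> bool:
    # Before this commit "iluvatar" was missing from this set, so
    # ENABLE_V1_KVCACHE_SCHEDULER was forced to 0 on Iluvatar devices.
    return platform in {"cuda", "xpu", "maca", "iluvatar"}

def cuinfer_rope_mode(model_type: str) -> str:
    # The Iluvatar attention backend sets is_interleaved_rope_mode=False for
    # paddleocr models, which the kernels map to CUINFER_ATTEN_OCRV1; all
    # other models keep CUINFER_ATTEN_NORMAL.
    return "CUINFER_ATTEN_OCRV1" if "paddleocr" in model_type else "CUINFER_ATTEN_NORMAL"

assert v1_kvcache_scheduler_stays_enabled("iluvatar")
assert cuinfer_rope_mode("paddleocr_vl") == "CUINFER_ATTEN_OCRV1"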

View File

@@ -11,8 +11,7 @@ concurrency:
jobs:
CI_ILUVATAR:
runs-on:
group: IXUCA
runs-on: [self-hosted, ILUVATAR_8Card]
steps:
- name: Print current runner name
run: |
@@ -23,7 +22,7 @@ jobs:
- name: Code Checkout
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:paddle-ocr-vl-1107
run: |
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
@@ -56,7 +55,7 @@ jobs:
- name: Run CI unittest
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:latest
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:paddle-ocr-vl-1107
run: |
runner_name="${{ runner.name }}"
last_char="${runner_name: -1}"

View File

@@ -28,8 +28,13 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding,
const int max_seq_len) {
const int bi = blockIdx.x;
const int tid = threadIdx.x;
#ifdef PADDLE_WITH_COREX
const int warp_id = threadIdx.x / 64;
const int lane_id = threadIdx.x % 64;
#else
const int warp_id = threadIdx.x / 32;
const int lane_id = threadIdx.x % 32;
#endif
int cum_seq_len = 0;
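
The #ifdef above accounts for Iluvatar COREX hardware using 64-lane warps instead of the 32 lanes CUDA assumes, so the same thread block decomposes into half as many, twice-as-wide warps. A plain-Python illustration of the index arithmetic (not device code):

def warp_lane(tid: int, warp_size: int) -> tuple:
    """Split a thread index into (warp_id, lane_id) for a given warp size."""
    return tid // warp_size, tid % warp_size

# thread 130 of a 256-thread block:
assert warp_lane(130, 32) == (4, 2)   # CUDA: 8 warps of 32 lanes
assert warp_lane(130, 64) == (2, 2)   # COREX: 4 warps of 64 lanes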

View File

@@ -16,32 +16,37 @@
#include "iluvatar_context.h"
template <paddle::DataType T>
void MixedFusedPagedAttnKernel(const paddle::Tensor& qkv,
paddle::Tensor& k_cache,
paddle::Tensor& v_cache,
const paddle::Tensor& prefill_block_table,
const paddle::Tensor& decode_block_table,
const paddle::Tensor& cu_seqlens_qkv,
const paddle::Tensor& seq_lens,
const paddle::optional<paddle::Tensor>& rope_sin,
const paddle::optional<paddle::Tensor>& rope_cos,
int prefill_num_tokens,
int num_heads,
int head_dim,
int num_kv_heads,
int block_size,
int max_seq_len,
float scale,
bool causal,
bool q_rope,
bool k_rope,
bool v_rope,
int window_left,
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi,
paddle::Tensor& out) {
void MixedFusedPagedAttnKernel(
const paddle::Tensor& qkv,
paddle::Tensor& k_cache,
paddle::Tensor& v_cache,
const paddle::Tensor& prefill_block_table,
const paddle::Tensor& decode_block_table,
const paddle::Tensor& cu_seqlens_qkv,
const paddle::Tensor& seq_lens,
const paddle::Tensor& prefill_rope_sin,
const paddle::Tensor& prefill_rope_cos,
const paddle::optional<paddle::Tensor>& decode_rope_sin,
const paddle::optional<paddle::Tensor>& decode_rope_cos,
int prefill_num_tokens,
int num_heads,
int head_dim,
int num_kv_heads,
int block_size,
int max_seq_len,
float scale,
bool causal,
bool q_rope,
bool k_rope,
bool v_rope,
int window_left,
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi,
int rope_batch_stride,
bool is_interleaved_rope_mode,
paddle::Tensor& out) {
typedef PDTraits<T> traits_;
typedef typename traits_::data_t data_t;
@@ -72,8 +77,39 @@ void MixedFusedPagedAttnKernel(const paddle::Tensor& qkv,
int kv_block_stride = k_cache.strides()[0];
int kv_head_stride = k_cache.strides()[1];
int block_table_stride = prefill_block_table.strides()[0];
const float* rope_sin_ptr = rope_sin ? rope_sin.get().data<float>() : nullptr;
const float* rope_cos_ptr = rope_cos ? rope_cos.get().data<float>() : nullptr;
const float* prefill_rope_sin_ptr = prefill_rope_sin.data<float>();
const float* prefill_rope_cos_ptr = prefill_rope_cos.data<float>();
const auto& prefill_rope_dims = prefill_rope_sin.dims();
std::vector<int> prefill_rope_shape_vec, prefill_rope_stride_vec;
int prefill_rope_ndim;
if (prefill_rope_dims.size() == 4) {
// [prefill_batch_size, max_seq_len, 1, head_dim]
PADDLE_ENFORCE_EQ(
prefill_rope_dims[0],
prefill_batch_size,
common::errors::InvalidArgument(
"prefill_rope_dims[0] must be equal to prefill_batch_size"));
prefill_rope_shape_vec =
std::vector<int>({prefill_batch_size, max_seq_len, head_dim});
prefill_rope_stride_vec =
std::vector<int>({max_seq_len * head_dim, head_dim, 1});
prefill_rope_ndim = 3;
} else if (prefill_rope_dims.size() == 3) {
// [max_seq_len, 1, head_dim]
prefill_rope_shape_vec = std::vector<int>({max_seq_len, head_dim});
prefill_rope_stride_vec = std::vector<int>({head_dim, 1});
prefill_rope_ndim = 2;
} else {
PD_THROW("Unsupported prefill_rope_ndim = %d for Paged attn",
prefill_rope_ndim);
}
const float* decode_rope_sin_ptr =
decode_rope_sin ? decode_rope_sin.get().data<float>() : nullptr;
const float* decode_rope_cos_ptr =
decode_rope_cos ? decode_rope_cos.get().data<float>() : nullptr;
cuinferAttentionRopeMode_t rope_mode =
is_interleaved_rope_mode ? CUINFER_ATTEN_NORMAL : CUINFER_ATTEN_OCRV1;
cuinferTensorDescriptor_t qkv_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc));
@@ -139,21 +175,19 @@ void MixedFusedPagedAttnKernel(const paddle::Tensor& qkv,
cuinferTensorDescriptor_t cos_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cos_desc,
CUINFER_DATA_FLOAT,
2,
std::vector<int>({max_seq_len, head_dim}).data(),
std::vector<int>({head_dim, 1}).data()));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(cos_desc,
CUINFER_DATA_FLOAT,
prefill_rope_ndim,
prefill_rope_shape_vec.data(),
prefill_rope_stride_vec.data()));
cuinferTensorDescriptor_t sin_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
sin_desc,
CUINFER_DATA_FLOAT,
2,
std::vector<int>({max_seq_len, head_dim}).data(),
std::vector<int>({head_dim, 1}).data()));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(sin_desc,
CUINFER_DATA_FLOAT,
prefill_rope_ndim,
prefill_rope_shape_vec.data(),
prefill_rope_stride_vec.data()));
cuinferHandle_t cuinfer_handle =
iluvatar::getContextInstance()->getIxInferHandle();
@@ -195,9 +229,9 @@ void MixedFusedPagedAttnKernel(const paddle::Tensor& qkv,
prefill_workspace_ptr,
prefill_workspace_size,
cos_desc,
rope_cos_ptr,
prefill_rope_cos_ptr,
sin_desc,
rope_sin_ptr,
prefill_rope_sin_ptr,
prefill_batch_size,
num_heads,
num_kv_heads,
@@ -206,7 +240,8 @@ void MixedFusedPagedAttnKernel(const paddle::Tensor& qkv,
scale,
q_rope,
k_rope,
v_rope));
v_rope,
rope_mode));
size_t decode_workspace_size = 0;
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(decode_num_tokens,
@@ -241,8 +276,18 @@ void MixedFusedPagedAttnKernel(const paddle::Tensor& qkv,
decode_qkv_ptr,
decode_workspace_ptr,
true,
rope_sin_ptr,
rope_cos_ptr};
decode_rope_sin_ptr,
decode_rope_cos_ptr,
nullptr,
nullptr,
nullptr,
nullptr,
1,
0,
0,
nullptr,
static_cast<size_t>(rope_batch_stride),
rope_mode};
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
decode_out_ptr,
@@ -285,8 +330,10 @@ std::vector<paddle::Tensor> MixedFusedPagedAttn(
const paddle::Tensor& decode_block_table,
const paddle::Tensor& cu_seqlens_qkv,
const paddle::Tensor& seq_lens,
const paddle::optional<paddle::Tensor>& rope_sin,
const paddle::optional<paddle::Tensor>& rope_cos,
const paddle::Tensor& prefill_rope_sin,
const paddle::Tensor& prefill_rope_cos,
const paddle::optional<paddle::Tensor>& decode_rope_sin,
const paddle::optional<paddle::Tensor>& decode_rope_cos,
int prefill_num_tokens,
int num_heads,
int head_dim,
@@ -302,67 +349,79 @@ std::vector<paddle::Tensor> MixedFusedPagedAttn(
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi) {
bool use_sqrt_alibi,
int rope_batch_stride,
bool is_interleaved_rope_mode) {
const auto dtype = qkv.dtype();
auto out =
paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place());
switch (dtype) {
case paddle::DataType::BFLOAT16:
MixedFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(qkv,
k_cache,
v_cache,
prefill_block_table,
decode_block_table,
cu_seqlens_qkv,
seq_lens,
rope_sin,
rope_cos,
prefill_num_tokens,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
out);
MixedFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(
qkv,
k_cache,
v_cache,
prefill_block_table,
decode_block_table,
cu_seqlens_qkv,
seq_lens,
prefill_rope_sin,
prefill_rope_cos,
decode_rope_sin,
decode_rope_cos,
prefill_num_tokens,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
rope_batch_stride,
is_interleaved_rope_mode,
out);
break;
case paddle::DataType::FLOAT16:
MixedFusedPagedAttnKernel<paddle::DataType::FLOAT16>(qkv,
k_cache,
v_cache,
prefill_block_table,
decode_block_table,
cu_seqlens_qkv,
seq_lens,
rope_sin,
rope_cos,
prefill_num_tokens,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
out);
MixedFusedPagedAttnKernel<paddle::DataType::FLOAT16>(
qkv,
k_cache,
v_cache,
prefill_block_table,
decode_block_table,
cu_seqlens_qkv,
seq_lens,
prefill_rope_sin,
prefill_rope_cos,
decode_rope_sin,
decode_rope_cos,
prefill_num_tokens,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
window_left,
window_right,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
rope_batch_stride,
is_interleaved_rope_mode,
out);
break;
default:
PD_THROW("Unsupported data type for mixed paged attn");
@@ -388,8 +447,10 @@ PD_BUILD_STATIC_OP(mixed_fused_paged_attn)
"decode_block_table",
"cu_seqlens_qkv",
"seq_lens",
paddle::Optional("rope_sin"),
paddle::Optional("rope_cos")})
"prefill_rope_sin",
"prefill_rope_cos",
paddle::Optional("decode_rope_sin"),
paddle::Optional("decode_rope_cos")})
.Outputs({"out"})
.Attrs({"prefill_num_tokens:int",
"num_heads: int",
@@ -406,7 +467,9 @@ PD_BUILD_STATIC_OP(mixed_fused_paged_attn)
"window_right:int",
"softcap:float",
"enable_cuda_graph:bool",
"use_sqrt_alibi:bool"})
"use_sqrt_alibi:bool",
"rope_batch_stride:int",
"is_interleaved_rope_mode:bool"})
.SetKernelFn(PD_KERNEL(MixedFusedPagedAttn))
.SetInferShapeFn(PD_INFER_SHAPE(MixedFusedPagedAttnInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MixedFusedPagedAttnInferDtype));
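
The prefill rope handling above now accepts two sin/cos layouts: a per-sequence table of shape [prefill_batch_size, max_seq_len, 1, head_dim] and a shared table of shape [max_seq_len, 1, head_dim]. A hedged Python sketch of how the descriptor shape/stride vectors are derived (it mirrors the C++ branch above; the singleton head axis is dropped and strides are in elements):

import numpy as np

def prefill_rope_descriptor(rope_sin: np.ndarray, max_seq_len: int, head_dim: int):
    """Return (ndim, shape, strides) as they would be handed to cuinferSetTensorNdDescriptor."""
    if rope_sin.ndim == 4:
        # per-sequence rope table: [batch_size, max_seq_len, 1, head_dim]
        batch_size = rope_sin.shape[0]
        return 3, [batch_size, max_seq_len, head_dim], [max_seq_len * head_dim, head_dim, 1]
    if rope_sin.ndim == 3:
        # one rope table shared by every sequence: [max_seq_len, 1, head_dim]
        return 2, [max_seq_len, head_dim], [head_dim, 1]
    raise ValueError(f"unsupported rope ndim = {rope_sin.ndim}")

per_batch = np.zeros((2, 8, 1, 64), dtype=np.float32)   # VL-style, one table per sequence
shared = np.zeros((8, 1, 64), dtype=np.float32)         # text-style, single shared table
print(prefill_rope_descriptor(per_batch, max_seq_len=8, head_dim=64))
print(prefill_rope_descriptor(shared, max_seq_len=8, head_dim=64))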

View File

@@ -39,6 +39,8 @@ void PagedAttnKernel(const paddle::Tensor& q,
bool enable_cuda_graph,
bool use_sqrt_alibi,
bool merged_qkv,
int rope_batch_stride,
bool is_interleaved_rope_mode,
paddle::Tensor& out) {
if (alibi_slopes) {
PADDLE_ENFORCE_EQ(alibi_slopes.get().dtype(),
@@ -186,6 +188,9 @@ void PagedAttnKernel(const paddle::Tensor& q,
allocator->Allocate(workspace_size);
void* workspace_ptr = tmp_workspace->ptr();
cuinferAttentionRopeMode_t rope_mode =
is_interleaved_rope_mode ? CUINFER_ATTEN_NORMAL : CUINFER_ATTEN_OCRV1;
PageAttentionWithKVCacheArguments args{static_cast<float>(scale),
1.0,
1.0,
@@ -202,7 +207,17 @@ void PagedAttnKernel(const paddle::Tensor& q,
workspace_ptr,
merged_qkv,
rope_sin_ptr,
rope_cos_ptr};
rope_cos_ptr,
nullptr,
nullptr,
nullptr,
nullptr,
1,
0,
0,
nullptr,
static_cast<size_t>(rope_batch_stride),
rope_mode};
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
out.data(),
data_type,
@@ -250,7 +265,9 @@ std::vector<paddle::Tensor> PagedAttn(
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi,
bool merged_qkv) {
bool merged_qkv,
int rope_batch_stride,
bool is_interleaved_rope_mode) {
const auto dtype = q.dtype();
auto out =
paddle::empty({q.shape()[0], num_heads * head_dim}, dtype, q.place());
@@ -280,6 +297,8 @@ std::vector<paddle::Tensor> PagedAttn(
enable_cuda_graph,
use_sqrt_alibi,
merged_qkv,
rope_batch_stride,
is_interleaved_rope_mode,
out);
break;
case paddle::DataType::FLOAT16:
@@ -306,6 +325,8 @@ std::vector<paddle::Tensor> PagedAttn(
enable_cuda_graph,
use_sqrt_alibi,
merged_qkv,
rope_batch_stride,
is_interleaved_rope_mode,
out);
break;
default:
@@ -374,7 +395,9 @@ PD_BUILD_STATIC_OP(paged_attn)
"softcap:float",
"enable_cuda_graph:bool",
"use_sqrt_alibi:bool",
"merged_qkv:bool"})
"merged_qkv:bool",
"rope_batch_stride:int",
"is_interleaved_rope_mode:bool"})
.SetKernelFn(PD_KERNEL(PagedAttn))
.SetInferShapeFn(PD_INFER_SHAPE(PagedAttnInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PagedAttnInferDtype));
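
The new rope_batch_stride attribute is forwarded into PageAttentionWithKVCacheArguments above. On the Python side (see the IluvatarAttnBackend hunk further down) it is the element offset between consecutive sequences' rope tables: max_context_len * head_dim when a multimodal model carries one table per sequence, and 0 when a single table is shared. A tiny worked example with assumed sizes:

def rope_batch_stride(enable_mm: bool, max_context_len: int, head_dim: int) -> int:
    # one rope table per sequence (multimodal) vs. one shared table (text-only)
    return max_context_len * head_dim if enable_mm else 0

assert rope_batch_stride(True, max_context_len=8192, head_dim=128) == 1_048_576
assert rope_batch_stride(False, max_context_len=8192, head_dim=128) == 0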

View File

@@ -16,25 +16,25 @@
#include "iluvatar_context.h"
template <paddle::DataType T>
void PrefillFusedPagedAttnKernel(
const paddle::Tensor& qkv,
paddle::Tensor& k_cache,
paddle::Tensor& v_cache,
const paddle::Tensor& block_table,
const paddle::Tensor& cu_seqlens_qkv,
const paddle::optional<paddle::Tensor>& rope_sin,
const paddle::optional<paddle::Tensor>& rope_cos,
int num_heads,
int head_dim,
int num_kv_heads,
int block_size,
int max_seq_len,
float scale,
bool causal,
bool q_rope,
bool k_rope,
bool v_rope,
paddle::Tensor& out) {
void PrefillFusedPagedAttnKernel(const paddle::Tensor& qkv,
paddle::Tensor& k_cache,
paddle::Tensor& v_cache,
const paddle::Tensor& block_table,
const paddle::Tensor& cu_seqlens_qkv,
const paddle::Tensor& rope_sin,
const paddle::Tensor& rope_cos,
int num_heads,
int head_dim,
int num_kv_heads,
int block_size,
int max_seq_len,
float scale,
bool causal,
bool q_rope,
bool k_rope,
bool v_rope,
bool is_interleaved_rope_mode,
paddle::Tensor& out) {
// check dtype and contiguous
const auto& dtype = qkv.dtype();
cuinferDataType_t data_type;
@@ -139,8 +139,28 @@ void PrefillFusedPagedAttnKernel(
"cu_seqlens_qkv_dims[0] must be equal to batch_size + 1"));
int block_table_stride = block_table.strides()[0];
const float* rope_sin_ptr = rope_sin ? rope_sin.get().data<float>() : nullptr;
const float* rope_cos_ptr = rope_cos ? rope_cos.get().data<float>() : nullptr;
const float* rope_sin_ptr = rope_sin.data<float>();
const float* rope_cos_ptr = rope_cos.data<float>();
const auto& rope_dims = rope_sin.dims();
std::vector<int> rope_shape_vec, rope_stride_vec;
int rope_ndim;
if (rope_dims.size() == 4) {
// [batch_size, max_seq_len, 1, head_dim]
PADDLE_ENFORCE_EQ(rope_dims[0],
batch_size,
common::errors::InvalidArgument(
"rope_dims[0] must be equal to batch_size"));
rope_shape_vec = std::vector<int>({batch_size, max_seq_len, head_dim});
rope_stride_vec = std::vector<int>({max_seq_len * head_dim, head_dim, 1});
rope_ndim = 3;
} else if (rope_dims.size() == 3) {
// [max_seq_len, 1, head_dim]
rope_shape_vec = std::vector<int>({max_seq_len, head_dim});
rope_stride_vec = std::vector<int>({head_dim, 1});
rope_ndim = 2;
} else {
PD_THROW("Unsupported rope_ndim = %d for Paged attn", rope_ndim);
}
cuinferHandle_t cuinfer_handle =
iluvatar::getContextInstance()->getIxInferHandle();
@@ -226,22 +246,22 @@ void PrefillFusedPagedAttnKernel(
cuinferTensorDescriptor_t cos_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
cos_desc,
CUINFER_DATA_FLOAT,
2,
std::vector<int>({max_seq_len, head_dim}).data(),
std::vector<int>({head_dim, 1}).data()));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(cos_desc,
CUINFER_DATA_FLOAT,
rope_ndim,
rope_shape_vec.data(),
rope_stride_vec.data()));
cuinferTensorDescriptor_t sin_desc;
CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
sin_desc,
CUINFER_DATA_FLOAT,
2,
std::vector<int>({max_seq_len, head_dim}).data(),
std::vector<int>({head_dim, 1}).data()));
CUINFER_CHECK(cuinferSetTensorNdDescriptor(sin_desc,
CUINFER_DATA_FLOAT,
rope_ndim,
rope_shape_vec.data(),
rope_stride_vec.data()));
cuinferAttentionRopeMode_t rope_mode =
is_interleaved_rope_mode ? CUINFER_ATTEN_NORMAL : CUINFER_ATTEN_OCRV1;
CUINFER_CHECK(cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle,
qkv_desc,
qkv.data(),
@@ -269,7 +289,8 @@ void PrefillFusedPagedAttnKernel(
scale,
q_rope,
k_rope,
v_rope));
v_rope,
rope_mode));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc));
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc));
@@ -287,8 +308,8 @@ std::vector<paddle::Tensor> PrefillFusedPagedAttn(
paddle::Tensor& v_cache,
const paddle::Tensor& block_table,
const paddle::Tensor& cu_seqlens_qkv,
const paddle::optional<paddle::Tensor>& rope_sin,
const paddle::optional<paddle::Tensor>& rope_cos,
const paddle::Tensor& rope_sin,
const paddle::Tensor& rope_cos,
int num_heads,
int head_dim,
int num_kv_heads,
@@ -298,51 +319,56 @@ std::vector<paddle::Tensor> PrefillFusedPagedAttn(
bool causal,
bool q_rope,
bool k_rope,
bool v_rope) {
bool v_rope,
bool is_interleaved_rope_mode) {
const auto dtype = qkv.dtype();
auto out =
paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place());
switch (dtype) {
case paddle::DataType::BFLOAT16:
PrefillFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(qkv,
k_cache,
v_cache,
block_table,
cu_seqlens_qkv,
rope_sin,
rope_cos,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
out);
PrefillFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(
qkv,
k_cache,
v_cache,
block_table,
cu_seqlens_qkv,
rope_sin,
rope_cos,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
is_interleaved_rope_mode,
out);
break;
case paddle::DataType::FLOAT16:
PrefillFusedPagedAttnKernel<paddle::DataType::FLOAT16>(qkv,
k_cache,
v_cache,
block_table,
cu_seqlens_qkv,
rope_sin,
rope_cos,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
out);
PrefillFusedPagedAttnKernel<paddle::DataType::FLOAT16>(
qkv,
k_cache,
v_cache,
block_table,
cu_seqlens_qkv,
rope_sin,
rope_cos,
num_heads,
head_dim,
num_kv_heads,
block_size,
max_seq_len,
scale,
causal,
q_rope,
k_rope,
v_rope,
is_interleaved_rope_mode,
out);
break;
default:
PD_THROW("Unsupported data type for Paged attn");
@@ -382,8 +408,8 @@ PD_BUILD_STATIC_OP(prefill_fused_paged_attn)
"v_cache",
"block_table",
"cu_seqlens_qkv",
paddle::Optional("rope_sin"),
paddle::Optional("rope_cos")})
"rope_sin",
"rope_cos"})
.Outputs({"out"})
.Attrs({"num_heads:int",
"head_dim:int",
@@ -394,7 +420,8 @@ PD_BUILD_STATIC_OP(prefill_fused_paged_attn)
"causal:bool",
"q_rope:bool",
"k_rope:bool",
"v_rope:bool"})
"v_rope:bool",
"is_interleaved_rope_mode:bool"})
.SetKernelFn(PD_KERNEL(PrefillFusedPagedAttn))
.SetInferShapeFn(PD_INFER_SHAPE(PrefillFusedPagedAttnInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PrefillFusedPagedAttnInferDtype));
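
This prefill op gains the same is_interleaved_rope_mode switch, choosing CUINFER_ATTEN_NORMAL when True and CUINFER_ATTEN_OCRV1 when False (the paddleocr-vl case). As general background only — the diff does not show which channel-pairing convention each cuinfer mode implements — RoPE variants typically differ in which channels they rotate together, as in this illustrative sketch:

def rope_pairs_interleaved(head_dim: int):
    # GPT-J / interleaved style: adjacent channels form a rotation pair
    return [(2 * i, 2 * i + 1) for i in range(head_dim // 2)]

def rope_pairs_rotate_half(head_dim: int):
    # GPT-NeoX / rotate-half style: channel i pairs with channel i + head_dim // 2
    return [(i, i + head_dim // 2) for i in range(head_dim // 2)]

print(rope_pairs_interleaved(8))    # [(0, 1), (2, 3), (4, 5), (6, 7)]
print(rope_pairs_rotate_half(8))    # [(0, 4), (1, 5), (2, 6), (3, 7)]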

View File

@@ -555,6 +555,9 @@ elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
"gpu_ops/set_data_ipc.cu",
"gpu_ops/limit_thinking_content_length_v1.cu",
"gpu_ops/limit_thinking_content_length_v2.cu",
"gpu_ops/recover_decode_task.cu",
"gpu_ops/update_inputs_v1.cu",
"gpu_ops/get_img_boundaries.cc",
"iluvatar_ops/moe_dispatch.cu",
"iluvatar_ops/moe_reduce.cu",
"iluvatar_ops/paged_attn.cu",

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -535,7 +535,12 @@ class EngineArgs:
f"scheduler, please provide --router argument."
)
if not (current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_maca()):
if not (
current_platform.is_cuda()
or current_platform.is_xpu()
or current_platform.is_maca()
or current_platform.is_iluvatar()
):
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if "PaddleOCR" in get_model_architecture(self.model, self.model_config_name):

View File

@@ -428,6 +428,10 @@ class ResourceManagerV1(ResourceManager):
grid_thw = paddle.to_tensor(grid_thw, dtype="int64")
if current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import get_img_boundaries
elif current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import (
get_img_boundaries,
)
else:
from fastdeploy.model_executor.ops.gpu import get_img_boundaries

View File

@@ -95,35 +95,54 @@ class IluvatarAttnBackend(AttentionBackend):
self.num_layers = fd_config.model_config.num_hidden_layers
self.dtype = paddle.get_default_dtype()
self.enable_mm = fd_config.model_config.enable_mm
self.rope_batch_stride = self.max_context_len * self.head_dim if self.enable_mm else 0
if "paddleocr" in fd_config.model_config.model_type:
self.is_interleaved_rope_mode = False
else:
self.is_interleaved_rope_mode = True
def split_cos_sin(self, batch_ids, forward_meta: ForwardMeta):
if self.enable_mm:
# the num_seqs dim of rotary_embs > 1 (e.g. ernie-vl and paddleocr-vl)
cos = forward_meta.rotary_embs[batch_ids, 0, 0, :, :, :]
sin = forward_meta.rotary_embs[batch_ids, 1, 0, :, :, :]
else:
# the num_seqs dim of rotary_embs = 1 (e.g. ernie-text)
cos = forward_meta.rotary_embs[0, 0, :, :, :]
sin = forward_meta.rotary_embs[1, 0, :, :, :]
return cos, sin
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
if self.enable_mm:
# VL: TODO: The first 0 may need to be replaced with batch_id
# of max_num_seqs when running multiple batch case later
self.rope_cos = forward_meta.rotary_embs[0, 0, 0, :, :, :]
self.rope_sin = forward_meta.rotary_embs[0, 1, 0, :, :, :]
else:
# text
self.rope_cos = forward_meta.rotary_embs[0, 0, :, :, :]
self.rope_sin = forward_meta.rotary_embs[1, 0, :, :, :]
self.prefill_info_dict = {}
self.decode_info_dict = {}
self.prefill_info_dict["batch_ids"] = paddle.where(forward_meta.seq_lens_encoder)[0]
self.decode_info_dict["batch_ids"] = paddle.where(forward_meta.seq_lens_decoder)[0]
self.prefill_len = len(self.prefill_info_dict["batch_ids"])
self.decode_len = len(self.decode_info_dict["batch_ids"])
prefill_batch_ids = self.prefill_info_dict["batch_ids"]
decode_batch_ids = self.decode_info_dict["batch_ids"]
if prefill_batch_ids.dim() == 0:
prefill_batch_ids = prefill_batch_ids.unsqueeze(0)
if decode_batch_ids.dim() == 0:
decode_batch_ids = decode_batch_ids.unsqueeze(0)
# only prefill
if self.decode_len == 0:
cu_seq_ids = list(range(self.prefill_len + 1))
self.prefill_info_dict["cu_seqlens_q"] = forward_meta.cu_seqlens_q[cu_seq_ids]
self.mixed = False
cu_seq_ids = self.prefill_info_dict["batch_ids"] + 1
self.prefill_info_dict["cu_seqlens_q"] = paddle.concat(
[forward_meta.cu_seqlens_q[:1], forward_meta.cu_seqlens_q[cu_seq_ids]]
)
self.rope_cos, self.rope_sin = self.split_cos_sin(prefill_batch_ids, forward_meta)
# only decode
elif self.prefill_len == 0:
self.mixed = False
self.rope_cos, self.rope_sin = self.split_cos_sin(decode_batch_ids, forward_meta)
# both prefill and decode
else:
self.mixed = True
self.prefill_rope_cos, self.prefill_rope_sin = self.split_cos_sin(prefill_batch_ids, forward_meta)
self.decode_rope_cos, self.decode_rope_sin = self.split_cos_sin(decode_batch_ids, forward_meta)
self.prefill_num_tokens = paddle.sum(forward_meta.seq_lens_encoder).item()
self.prefill_info_dict["cu_seqlens_q"] = paddle.zeros(
[self.prefill_len + 1], dtype=forward_meta.cu_seqlens_q.dtype
@@ -141,7 +160,7 @@ class IluvatarAttnBackend(AttentionBackend):
)
prefill_start, decode_start, start = 0, self.prefill_num_tokens, 0
non_zeros_ids = forward_meta.seq_lens_this_time != 0
non_zeros_ids = paddle.where(forward_meta.seq_lens_this_time)[0]
non_zeros_seq_lens = forward_meta.seq_lens_this_time[non_zeros_ids]
end = non_zeros_seq_lens[0]
if end > 1:
@@ -234,6 +253,8 @@ class IluvatarAttnBackend(AttentionBackend):
v_cache,
block_tables=forward_meta.block_tables[self.prefill_info_dict["batch_ids"], :],
cu_seqlens_qkv=self.prefill_info_dict["cu_seqlens_q"],
rope_sin=self.rope_sin,
rope_cos=self.rope_cos,
num_heads=self.num_heads,
head_dim=self.head_dim,
num_kv_heads=self.num_kv_heads,
@@ -244,8 +265,7 @@ class IluvatarAttnBackend(AttentionBackend):
q_rope=True,
k_rope=True,
v_rope=False,
rope_sin=self.rope_sin,
rope_cos=self.rope_cos,
is_interleaved_rope_mode=self.is_interleaved_rope_mode,
)
elif self.prefill_len == 0:
output = paged_attention(
@@ -272,6 +292,8 @@ class IluvatarAttnBackend(AttentionBackend):
v=qkv,
rope_sin=self.rope_sin,
rope_cos=self.rope_cos,
rope_batch_stride=self.rope_batch_stride,
is_interleaved_rope_mode=self.is_interleaved_rope_mode,
)
else:
output = mixed_fused_paged_attention(
@@ -282,6 +304,8 @@ class IluvatarAttnBackend(AttentionBackend):
decode_block_tables=forward_meta.block_tables[self.decode_info_dict["batch_ids"], :],
cu_seqlens_qkv=self.prefill_info_dict["cu_seqlens_q"],
seq_lens=forward_meta.seq_lens_decoder[self.decode_info_dict["batch_ids"], 0] + 1,
prefill_rope_sin=self.prefill_rope_sin,
prefill_rope_cos=self.prefill_rope_cos,
prefill_num_tokens=self.prefill_num_tokens,
num_heads=self.num_heads,
head_dim=self.head_dim,
@@ -298,8 +322,10 @@ class IluvatarAttnBackend(AttentionBackend):
softcap=self.attention_metadata.softcap,
use_cuda_graph=self.attention_metadata.use_cuda_graph,
use_sqrt_alibi=self.attention_metadata.use_sqrt_alibi,
rope_sin=self.rope_sin,
rope_cos=self.rope_cos,
decode_rope_sin=self.decode_rope_sin,
decode_rope_cos=self.decode_rope_cos,
rope_batch_stride=self.rope_batch_stride,
is_interleaved_rope_mode=self.is_interleaved_rope_mode,
)
return output
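
The metadata initialization above partitions running requests into prefill requests (non-zero seq_lens_encoder) and decode requests (non-zero seq_lens_decoder), keeps separate rope tables per group via split_cos_sin, and only takes the mixed-kernel path when both groups are non-empty. A small numpy stand-in for that partitioning (the real code uses paddle.where on the forward metadata; the lengths below are invented):

import numpy as np

seq_lens_encoder = np.array([7, 0, 0, 5])   # prompt tokens still being prefilled, per request
seq_lens_decoder = np.array([0, 12, 3, 0])  # cached tokens for requests already decoding

prefill_batch_ids = np.nonzero(seq_lens_encoder)[0]   # -> [0, 3]
decode_batch_ids = np.nonzero(seq_lens_decoder)[0]    # -> [1, 2]
mixed = len(prefill_batch_ids) > 0 and len(decode_batch_ids) > 0

# In the mixed case the backend slices rotary_embs twice, producing
# prefill_rope_cos/sin for prefill_batch_ids and decode_rope_cos/sin for
# decode_batch_ids, and dispatches mixed_fused_paged_attention.
print(prefill_batch_ids, decode_batch_ids, mixed)     # [0 3] [1 2] True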

View File

@@ -54,6 +54,7 @@ from fastdeploy.model_executor.models.model_base import (
ModelForCasualLM,
ModelRegistry,
)
from fastdeploy.platforms import current_platform
class Ernie4_5_VLMLP(Ernie4_5_MLP):
@@ -539,6 +540,10 @@ class Ernie4_5_VLModel(nn.Layer):
text_image_index_out(vl_moe_meta.token_type_ids, vl_moe_meta.text_index, vl_moe_meta.image_index)
hidden_states = input_embeddings
if current_platform.is_iluvatar() and forward_meta.attn_backend.mixed:
hidden_states = forward_meta.attn_backend.transpose(hidden_states)
residual = None
for i in range(self.num_layers):
@@ -550,6 +555,10 @@ class Ernie4_5_VLModel(nn.Layer):
)
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
if current_platform.is_iluvatar() and forward_meta.attn_backend.mixed:
out = forward_meta.attn_backend.reverse_transpose(out)
return out

View File

@@ -40,6 +40,7 @@ from fastdeploy.model_executor.utils import (
default_weight_loader,
process_weights_after_loading,
)
from fastdeploy.platforms import current_platform
from .projector import Projector
from .siglip import SiglipVisionModel
@@ -101,12 +102,19 @@ class PaddleOCRVLModel(nn.Layer):
forward_meta: ForwardMeta,
):
hidden_states = input_embeddings
if current_platform.is_iluvatar() and forward_meta.attn_backend.mixed:
hidden_states = forward_meta.attn_backend.transpose(hidden_states)
residual = None
for i in range(self.num_layers):
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
out = self.norm(hidden_states, residual)[0]
if current_platform.is_iluvatar() and forward_meta.attn_backend.mixed:
out = forward_meta.attn_backend.reverse_transpose(out)
return out

View File

@@ -52,6 +52,8 @@ def paged_attention(
v: paddle.Tensor = None,
rope_sin: paddle.Tensor = None,
rope_cos: paddle.Tensor = None,
rope_batch_stride: int = 0,
is_interleaved_rope_mode: bool = True,
):
return paged_attn(
q,
@@ -77,6 +79,8 @@ def paged_attention(
use_cuda_graph,
use_sqrt_alibi,
merged_qkv,
rope_batch_stride,
is_interleaved_rope_mode,
)
@@ -86,6 +90,8 @@ def prefill_fused_paged_attention(
v_cache: paddle.Tensor,
block_tables: paddle.Tensor,
cu_seqlens_qkv: paddle.Tensor,
rope_sin: paddle.Tensor,
rope_cos: paddle.Tensor,
num_heads: int,
head_dim: int,
num_kv_heads: int,
@@ -96,8 +102,7 @@ def prefill_fused_paged_attention(
q_rope: bool = True,
k_rope: bool = True,
v_rope: bool = False,
rope_sin: paddle.Tensor = None,
rope_cos: paddle.Tensor = None,
is_interleaved_rope_mode: bool = True,
):
return prefill_fused_paged_attn(
qkv,
@@ -117,6 +122,7 @@ def prefill_fused_paged_attention(
q_rope,
k_rope,
v_rope,
is_interleaved_rope_mode,
)
@@ -128,6 +134,8 @@ def mixed_fused_paged_attention(
decode_block_tables: paddle.Tensor,
cu_seqlens_qkv: paddle.Tensor,
seq_lens: paddle.Tensor,
prefill_rope_sin: paddle.Tensor,
prefill_rope_cos: paddle.Tensor,
prefill_num_tokens: int,
num_heads: int,
head_dim: int,
@@ -144,8 +152,10 @@ def mixed_fused_paged_attention(
softcap: float = 0.0,
use_cuda_graph: bool = False,
use_sqrt_alibi: bool = False,
rope_sin: paddle.Tensor = None,
rope_cos: paddle.Tensor = None,
decode_rope_sin: paddle.Tensor = None,
decode_rope_cos: paddle.Tensor = None,
rope_batch_stride: int = 0,
is_interleaved_rope_mode: bool = True,
):
return mixed_fused_paged_attn(
qkv,
@@ -155,8 +165,10 @@ def mixed_fused_paged_attention(
decode_block_tables,
cu_seqlens_qkv,
seq_lens,
rope_sin,
rope_cos,
prefill_rope_sin,
prefill_rope_cos,
decode_rope_sin,
decode_rope_cos,
prefill_num_tokens,
num_heads,
head_dim,
@@ -173,4 +185,6 @@ def mixed_fused_paged_attention(
softcap,
use_cuda_graph,
use_sqrt_alibi,
rope_batch_stride,
is_interleaved_rope_mode,
)
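
Note that in these Python wrappers the new parameters are trailing keyword arguments with defaults (rope_batch_stride=0, is_interleaved_rope_mode=True), so call sites that never pass them keep the pre-commit behavior; only the Iluvatar multimodal path overrides them. A stub-based sketch of that compatibility property (illustrative, not the real op):

def paged_attention_stub(q, *, rope_batch_stride: int = 0, is_interleaved_rope_mode: bool = True):
    # stand-in for the compiled paged_attn op; records only the new arguments
    return {"rope_batch_stride": rope_batch_stride, "interleaved": is_interleaved_rope_mode}

assert paged_attention_stub("q") == {"rope_batch_stride": 0, "interleaved": True}
assert paged_attention_stub("q", rope_batch_stride=8192 * 128, is_interleaved_rope_mode=False) == {
    "rope_batch_stride": 1048576,
    "interleaved": False,
}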

View File

@@ -33,6 +33,7 @@ if current_platform.is_iluvatar():
set_stop_value_multi_ends,
step_paddle,
update_inputs,
update_inputs_v1,
)
elif current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import (

View File

@@ -56,11 +56,11 @@ from fastdeploy.platforms import current_platform
if current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import (
recover_decode_task,
set_data_ipc,
set_value_by_flags_and_idx,
)
recover_decode_task = None
share_external_data = None
elif current_platform.is_dcu():
from fastdeploy.model_executor.ops.gpu import set_value_by_flags_and_idx
@@ -467,7 +467,7 @@ class GPUModelRunner(ModelRunnerBase):
multi_vision_inputs["encoder_cache_info"].append((mm_hash, feature_positions[i], False))
if envs.FD_ENABLE_MAX_PREFILL:
multi_vision_inputs["images_lst"].append(
inputs["images"][image_start_idx : image_start_idx + image_offset].cuda()
inputs["images"][image_start_idx : image_start_idx + image_offset].to(self.device)
)
multi_vision_inputs["grid_thw_lst"].append(paddle.to_tensor(grid_thw_list[i]))
multi_vision_inputs["cu_seqlens"].append(vit_seqlen_list[i])
@@ -486,7 +486,7 @@ class GPUModelRunner(ModelRunnerBase):
else:
if envs.FD_ENABLE_MAX_PREFILL:
multi_vision_inputs["images_lst"].append(
inputs["images"][request.image_start : request.image_end].cuda()
inputs["images"][request.image_start : request.image_end].to(self.device)
)
multi_vision_inputs["grid_thw_lst"].extend(
paddle.to_tensor(inputs["grid_thw"][request.num_image_start : request.num_image_end])

View File

@@ -38,7 +38,6 @@ class IluvatarModelRunner(GPUModelRunner):
)
assert not self.speculative_decoding, "Iluvatar does not support speculative decoding"
assert self.guided_backend is None, "Iluvatar does not support guided decoding"
assert not envs.ENABLE_V1_KVCACHE_SCHEDULER, "Iluvatar does not support v1 kvcache scheduler"
assert not self.cache_config.enable_prefix_caching, "Iluvatar does not support prefix caching"
self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN"
assert not self.mla_cache, "Iluvatar does not support MLA"
@@ -48,9 +47,9 @@ class IluvatarModelRunner(GPUModelRunner):
not self.cache_config.enable_chunked_prefill
), "Iluvatar does not support chunked prefill for VL model"
# VL neox style = True
if self.enable_mm:
emb_shape = self.share_inputs["rope_emb"].shape
emb_shape[-1] *= 2
emb_shape = self.share_inputs["rope_emb"].shape
if emb_shape[-1] == self.model_config.head_dim // 2:
emb_shape[-1] = self.model_config.head_dim
self.share_inputs["rope_emb"] = paddle.full(
shape=emb_shape,
fill_value=0,
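
The old code unconditionally doubled the last dimension of the VL rope_emb buffer; the new guard only widens it when the buffer still holds head_dim // 2 values per position, leaving a table that is already head_dim wide untouched. A short worked example with assumed sizes:

def widened_rope_dim(last_dim: int, head_dim: int) -> int:
    # mirrors the new guard: widen only when the table stores half of head_dim
    return head_dim if last_dim == head_dim // 2 else last_dim

assert widened_rope_dim(64, head_dim=128) == 128    # half-width table gets doubled
assert widened_rope_dim(128, head_dim=128) == 128   # already full width, unchanged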

View File

@@ -983,7 +983,12 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
logger.info(f"- Load strategy: {load_config.load_strategy}")
if not (current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_maca()):
if not (
current_platform.is_cuda()
or current_platform.is_xpu()
or current_platform.is_maca()
or current_platform.is_iluvatar()
):
logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported.")
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0

View File

@@ -10,7 +10,7 @@ tqdm
pynvml
uvicorn==0.29.0
fastapi
paddleformers==0.3.1
paddleformers==0.4.0
redis
etcd3
httpx

View File

@@ -4,7 +4,6 @@ echo "$DIR"
# Kill any existing test processes first
ps -efww | grep -E 'run_ernie300B_4layer' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
ixsmi -r
unset http_proxy
unset https_proxy
@@ -15,14 +14,13 @@ ln -sf /usr/local/bin/python3 /usr/local/bin/python
echo "pip requirements"
python -m pip install -r requirements_iluvatar.txt
echo "install paddle cpu and custom device"
python -m pip install paddlepaddle==3.3.0.dev20251028 -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
python -m pip install paddle-iluvatar-gpu==3.0.0.dev20251029 -i https://www.paddlepaddle.org.cn/packages/nightly/ixuca/
python -m pip install paddlepaddle==3.3.0.dev20251103 -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
python -m pip install paddle-iluvatar-gpu==3.0.0.dev20251107 -i https://www.paddlepaddle.org.cn/packages/nightly/ixuca/
echo "build whl"
bash build.sh || exit 1
CI_PATH=tests/ci_use/iluvatar_UT
export INFERENCE_MSG_QUEUE_ID=232132
export FD_DEBUG=1
export PADDLE_XCCL_BACKEND=iluvatar_gpu
export FD_SAMPLING_CLASS=rejection
@@ -42,8 +40,17 @@ do
ps -efww | grep -E '${cur_test_file}' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
if [ ${exit_code} -ne 0 ]; then
echo "log/workerlog.0"
cat log/workerlog.0
if [ ! -f "./log/workerlog.0" ]; then
echo "------------------- log/launch_worker.log -----------------"
cat log/launch_worker.log
else
echo "------------------- log/workerlog.0 -----------------"
cat log/workerlog.0
fi
if [ -f "log/fastdeploy_error.log" ]; then
echo "------------------- log/fastdeploy_error.log -----------------"
cat log/fastdeploy_error.log
fi
exit 1
fi
done

View File

@@ -0,0 +1,235 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fastdeploy + ERNIE-4.5-Turbo 的指标评估"""
# adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py
import argparse
import ast
import json
import re
import time
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import requests
from tqdm import tqdm
INVALID = -9999999
def call_generate(prompt, **kwargs):
"""
Generates response based on the input prompt.
Args:
prompt (str): The input prompt text.
**kwargs: Keyword arguments, including server IP address and port number.
Returns:
str: The response generated based on the prompt.
"""
url = f"http://{kwargs['ip']}:{kwargs['port']}/v1/chat/completions"
headers = {"Content-Type": "application/json"}
data = {
"messages": [
{
"role": "user",
"content": prompt,
}
],
"temperature": 0.6,
"max_tokens": 2047,
"top_p": 0.95,
"do_sample": True,
}
response = requests.post(url, headers=headers, data=json.dumps(data))
out = response.json()
return out["choices"][0]["message"]["content"]
def get_one_example(lines, i, include_answer):
"""
Retrieves a question-answer example from the given list of text lines.
Args:
lines (list of dict): A list of question-answer pairs.
i (int): The index of the question-answer pair to retrieve from lines.
include_answer (bool): Whether to include the answer in the returned string.
Returns:
str: A formatted question-answer string in the format "Question: <question>\nAnswer: <answer>".
"""
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
if include_answer:
ret += " " + lines[i]["answer"]
return ret
def get_few_shot_examples(lines, k):
"""
Selects k examples from the given list of text lines and concatenates them into a single string.
Args:
lines (list): A list containing text lines.
k (int): The number of examples to select.
Returns:
str: A string composed of k examples, separated by two newline characters.
"""
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret
def get_answer_value(answer_str):
"""
Extracts numerical values from an answer string and returns them.
Args:
answer_str (str): The string containing the answer.
Returns:
The extracted numerical value; returns INVALID if extraction fails.
"""
answer_str = answer_str.replace(",", "")
numbers = re.findall(r"\d+", answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def read_jsonl(filename: str):
"""
Reads a JSONL file.
Args:
filename (str): Path to the JSONL file.
Yields:
dict: A dictionary object corresponding to each line in the JSONL file.
"""
with open(filename) as fin:
for line in fin:
if line.startswith("#"):
continue
yield json.loads(line)
def main(args):
"""
Process inputs and generate answers by calling the model in parallel using a thread pool.
Args:
args (argparse.Namespace):
- num_questions (int): Number of questions to process.
- num_shots (int): Number of few-shot learning examples.
- ip (str): IP address of the model service.
- port (int): Port number of the model service.
- parallel (int): Number of questions to process in parallel.
- result_file (str): File path to store the results.
Returns:
None
"""
# Read data
filename = "test.jsonl"
lines = list(read_jsonl(filename))
# Construct prompts
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
labels = []
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
states = [None] * len(labels)
# Use thread pool
def get_one_answer(i):
answer = call_generate(
prompt=few_shot_examples + questions[i],
# stop=["Question", "Assistant:", "<|separator|>"],
ip=args.ip,
port=args.port,
)
states[i] = answer
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(questions))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)
latency = time.time() - tic
preds = []
for i in range(len(states)):
preds.append(get_answer_value(states[i]))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
# Print results
print(f"Accuracy: {acc:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Latency: {latency:.3f} s")
with open(args.result_file, "a") as fout:
value = {
"task": "gsm8k",
"backend": "paddlepaddle",
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default="127.0.0.1")
parser.add_argument("--port", type=str, default="8188")
parser.add_argument("--num-shots", type=int, default=10)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=1319)
parser.add_argument("--result-file", type=str, default="result.jsonl")
parser.add_argument("--parallel", type=int, default=1)
args = parser.parse_args()
main(args)
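
For clarity, get_answer_value strips commas and keeps only the last integer found in the completion; a quick standalone example of the same extraction:

import ast
import re

answer = "So in total she earns 1,234 dollars per month."
numbers = re.findall(r"\d+", answer.replace(",", ""))
assert ast.literal_eval(numbers[-1]) == 1234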

View File

@@ -12,43 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import sys
import threading
from fastdeploy import LLM, SamplingParams
from fastdeploy.utils import set_random_seed
tests_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
sys.path.insert(0, tests_dir)
def timeout(seconds):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
result = [None]
exception = [None]
def target():
try:
result[0] = func(*args, **kwargs)
except Exception as e:
exception[0] = e
thread = threading.Thread(target=target)
thread.daemon = True
thread.start()
thread.join(seconds)
if thread.is_alive():
raise TimeoutError(f"Function timed out after {seconds} seconds")
if exception[0]:
raise exception[0]
return result[0]
return wrapper
return decorator
from ci_use.iluvatar_UT.utils import TIMEOUT_MSG, timeout
@timeout(80)
@@ -75,15 +48,15 @@ def offline_infer_check():
59335,
68170,
183,
49080,
94717,
82966,
99140,
31615,
51497,
94851,
60764,
10889,
97404,
100088,
36310,
95633,
95913,
41459,
95049,
94970,
96840,
2,
], f"{outputs[0].outputs.token_ids}"
print("PASSED")
@@ -94,10 +67,7 @@ if __name__ == "__main__":
result = offline_infer_check()
sys.exit(0)
except TimeoutError:
print(
"The timeout exit may be due to multiple processes sharing the "
"same gpu card. You can check this using ixsmi on the device."
)
print(TIMEOUT_MSG)
sys.exit(124)
except Exception:
sys.exit(1)

View File

@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import io
import os
import sys
import threading
import requests
from PIL import Image
@@ -24,39 +23,13 @@ from fastdeploy import LLM, SamplingParams
from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
from fastdeploy.utils import set_random_seed
tests_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
sys.path.insert(0, tests_dir)
def timeout(seconds):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
result = [None]
exception = [None]
def target():
try:
result[0] = func(*args, **kwargs)
except Exception as e:
exception[0] = e
thread = threading.Thread(target=target)
thread.daemon = True
thread.start()
thread.join(seconds)
if thread.is_alive():
raise TimeoutError(f"Function timed out after {seconds} seconds")
if exception[0]:
raise exception[0]
return result[0]
return wrapper
return decorator
from ci_use.iluvatar_UT.utils import TIMEOUT_MSG, timeout
@timeout(180)
@timeout(210)
def offline_infer_check():
set_random_seed(123)
@@ -122,9 +95,9 @@ def offline_infer_check():
5119,
93956,
68725,
14449,
4356,
38225,
100282,
23,
23,
2,
], f"{outputs[0].outputs.token_ids}"
print("PASSED")
@@ -135,10 +108,7 @@ if __name__ == "__main__":
result = offline_infer_check()
sys.exit(0)
except TimeoutError:
print(
"The timeout exit may be due to multiple processes sharing the "
"same gpu card. You can check this using ixsmi on the device."
)
print(TIMEOUT_MSG)
sys.exit(124)
except Exception:
sys.exit(1)

View File

@@ -0,0 +1,28 @@
import functools
import signal
def timeout(seconds):
def decorator(func):
def _handle_timeout(signum, frame):
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds")
@functools.wraps(func)
def wrapper(*args, **kwargs):
original_handler = signal.signal(signal.SIGALRM, _handle_timeout)
signal.alarm(seconds)
try:
result = func(*args, **kwargs)
signal.alarm(0)
return result
finally:
signal.signal(signal.SIGALRM, original_handler)
signal.alarm(0)
return wrapper
return decorator
TIMEOUT_MSG = "The timeout exit may be due to multiple processes sharing the same gpu card. You can check this using ixsmi on the device."
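
A usage sketch for this shared helper, assuming (as the two tests above do) that the tests directory is on sys.path so the module resolves as ci_use.iluvatar_UT.utils; note the SIGALRM-based implementation only works in the main thread on Unix-like systems:

import sys
import time

sys.path.insert(0, "tests")  # the CI tests insert their tests/ root like this
from ci_use.iluvatar_UT.utils import TIMEOUT_MSG, timeout

@timeout(2)
def fast_enough():
    time.sleep(0.1)
    return "ok"

try:
    print(fast_enough())   # returns normally well before the 2 s limit
except TimeoutError:
    print(TIMEOUT_MSG)     # printed by the tests when inference hangs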