diff --git a/custom_ops/xpu_ops/src/ops/block_attn.cc b/custom_ops/xpu_ops/src/ops/block_attn.cc index c9e3313f2..d3dd4c726 100644 --- a/custom_ops/xpu_ops/src/ops/block_attn.cc +++ b/custom_ops/xpu_ops/src/ops/block_attn.cc @@ -585,7 +585,8 @@ std::vector BlockAttnKernel( int, E_Scale>( xpu_ctx->x_context(), - reinterpret_cast(qkv.data()), // qkv + reinterpret_cast(qkv.data()) + + total_enc_len * qkv_shape[qkv_shape.size() - 1], // qkv reinterpret_cast( rotary_embs.data()), // rotary_pos_emb reinterpret_cast(