[Iluvatar GPU] Optimze attention and moe performance (#3234)

This commit is contained in:
yzwu
2025-08-08 10:51:24 +08:00
committed by GitHub
parent 37569cca86
commit fbdd6b0663
24 changed files with 1130 additions and 1653 deletions

View File

@@ -13,7 +13,8 @@ concurrency:
jobs:
CI_GCU:
runs-on: [self-hosted, GCU-S60-8Card]
runs-on:
group: GCU
steps:
- name: Print current runner name
run: |

View File

@@ -11,7 +11,8 @@ concurrency:
jobs:
CI_ILUVATAR:
runs-on: [self-hosted, IXUCA]
runs-on:
group: IXUCA
steps:
- name: Print current runner name
run: |

View File

@@ -29,7 +29,11 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
// need_batch_random
if (seed == -1) {
#ifdef PADDLE_WITH_COREX
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(probs.place()));
#else
phi::GPUContext* dev_ctx = static_cast<phi::GPUContext*>(phi::DeviceContextPool::Instance().Get(probs.place()));
#endif
auto gen_cuda = dev_ctx->GetGenerator();
auto seed_offset = gen_cuda->IncrementOffset(32 * batch_size);
philox_seed = seed_offset.first;

View File

@@ -212,9 +212,15 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : 0;
valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
}
#ifdef PADDLE_WITH_COREX
float aggregate_local =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
.Sum(prob_greater_than_threshold);
#else
float aggregate_local =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
.Sum<VEC_SIZE>(prob_greater_than_threshold);
#endif
if (tx == 0) {
temp_storage->block_aggregate.value = aggregate_local;
}
@@ -226,8 +232,13 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>(
prob_greater_than_threshold, inclusive_cdf, temp_storage);
} else {
#ifdef PADDLE_WITH_COREX
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
.InclusiveSum(prob_greater_than_threshold, inclusive_cdf);
#else
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
.InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
#endif
__syncthreads();
}
@@ -239,11 +250,21 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
bool greater_than_u_diff[VEC_SIZE];
#ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
#ifdef PADDLE_WITH_COREX
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp());
#else
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
#endif
#else
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
#ifdef PADDLE_WITH_COREX
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
#else
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
#endif
#endif
__syncthreads();
@@ -355,18 +376,30 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
}
#ifdef PADDLE_WITH_COREX
aggregate_gt_pivot_0 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
.Sum(probs_gt_pivot_0);
#else
aggregate_gt_pivot_0 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_0);
#endif
if (tx == 0) {
temp_storage.block_aggregate.pair = aggregate_gt_pivot_0;
}
__syncthreads();
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
#ifdef PADDLE_WITH_COREX
aggregate_gt_pivot_1 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
.Sum(probs_gt_pivot_1);
#else
aggregate_gt_pivot_1 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_1);
#endif
if (tx == 0) {
temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
}
@@ -466,16 +499,26 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0;
}
#ifdef PADDLE_WITH_COREX
aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum(probs_gt_pivot_0);
#else
aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot_0);
#endif
if (tx == 0) {
temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
}
__syncthreads();
aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;
#ifdef PADDLE_WITH_COREX
aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum(probs_gt_pivot_1);
#else
aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot_1);
#endif
if (tx == 0) {
temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
}
@@ -521,9 +564,15 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
in_data_[j] = in_data_vec[j];
}
#ifdef PADDLE_WITH_COREX
max_val = max(
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce(in_data_, cub::Max()));
#else
max_val = max(
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
#endif
__syncthreads();
}
if (tx == 0) {
@@ -605,7 +654,11 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
const uint32_t row_idx = bx;
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
#ifdef PADDLE_WITH_COREX
double pivot = std::numeric_limits<float>::infinity(), normalizer = 1;
#else
double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
#endif
vec_t<float, VEC_SIZE> probs_vec;
if (k < d) {
extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
@@ -659,14 +712,26 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
}
}
#ifdef PADDLE_WITH_COREX
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum(probs_gt_pivot_0_pair);
#else
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
#endif
__syncthreads();
#ifdef PADDLE_WITH_COREX
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum(probs_gt_pivot_1_pair);
#else
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
#endif
__syncthreads();
}
min_gt_low =

View File

@@ -258,9 +258,13 @@ inline std::pair<int, int> GetCudaComputeCapability() {
/******************* math *******************/
__forceinline__ __device__ float ptx_rcp(float x) {
#ifdef PADDLE_WITH_COREX
return __ivcorex_rcpf(x);
#else
float y;
asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
return y;
#endif
}
template <typename T1, typename T2>

View File

@@ -15,15 +15,6 @@
#include "helper.h"
#include "iluvatar_context.h"
#define CUINFER_CHECK(func) \
do { \
cuinferStatus_t status = (func); \
if (status != CUINFER_STATUS_SUCCESS) { \
std::cerr << "Error in file " << __FILE__ << " on line " << __LINE__ << ": " \
<< cuinferGetErrorString(status) << std::endl; \
throw std::runtime_error("CUINFER_CHECK ERROR"); \
} \
} while (0)
template <paddle::DataType T>
void PagedAttnKernel(const paddle::Tensor& q,
@@ -34,6 +25,8 @@ void PagedAttnKernel(const paddle::Tensor& q,
const paddle::optional<paddle::Tensor> &alibi_slopes,
const paddle::optional<paddle::Tensor> &k,
const paddle::optional<paddle::Tensor> &v,
const paddle::optional<paddle::Tensor> &rope_sin,
const paddle::optional<paddle::Tensor> &rope_cos,
int num_kv_heads,
float scale,
int block_size,
@@ -44,6 +37,7 @@ void PagedAttnKernel(const paddle::Tensor& q,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi,
bool merged_qkv,
paddle::Tensor& out) {
if (alibi_slopes) {
PADDLE_ENFORCE_EQ(alibi_slopes.get().dtype(),
@@ -75,14 +69,6 @@ void PagedAttnKernel(const paddle::Tensor& q,
true,
common::errors::InvalidArgument(
"paged_attention expects k_cache is contiguous"));
PADDLE_ENFORCE_EQ(v_cache.dtype(),
dtype,
common::errors::InvalidArgument(
"v_cache dtype must be the same as query dtype"));
PADDLE_ENFORCE_EQ(v_cache.is_contiguous(),
true,
common::errors::InvalidArgument(
"paged_attention expects v_cache is contiguous"));
PADDLE_ENFORCE_EQ(block_table.dtype(),
paddle::DataType::INT32,
common::errors::InvalidArgument(
@@ -99,14 +85,14 @@ void PagedAttnKernel(const paddle::Tensor& q,
true,
common::errors::InvalidArgument(
"paged_attention expects seq_lens is contiguous"));
// check dim and shape
// out: [num_seqs, num_heads, head_size]
// q: [num_seqs, num_heads, head_size]
// k_chache: [num_blocks, kv_num_heads, block_size, head_size]
// v_chache: [num_blocks, kv_num_heads, block_size, head_size]
// k_cache: [num_blocks, kv_num_heads, block_size, head_size]
// v_cache: [num_blocks, kv_num_heads, block_size, head_size]
// block_table: [num_seqs, max_num_blocks_per_seq]
// seq_lens: [num_seqs]
// q and out:
// merged_qkv = false: [num_seqs, num_heads, head_size]
// merged_qkv = true: [num_seqs, num_heads+2*num_kv_heads, head_size]
const auto& q_dims = q.dims();
PADDLE_ENFORCE_EQ(q_dims.size(),
@@ -119,11 +105,6 @@ void PagedAttnKernel(const paddle::Tensor& q,
common::errors::InvalidArgument(
"paged_attn receive out dims is "
"[num_seqs, num_heads, head_size]"));
PADDLE_ENFORCE_EQ(k_cache.dims(),
v_cache.dims(),
common::errors::InvalidArgument(
"paged_attn requires k_cache size is the "
"same as v_cache"));
const auto& kv_cache_dims = k_cache.dims();
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
@@ -146,7 +127,7 @@ void PagedAttnKernel(const paddle::Tensor& q,
"paged_attn receive seq_lens dims is [num_seqs]"));
int num_seqs = q_dims[0];
int num_heads = q_dims[1];
int num_heads = merged_qkv ? q_dims[1] - 2 * num_kv_heads : q_dims[1];
int head_size = q_dims[2];
int max_num_blocks_per_seq = block_table_dims[1];
int q_stride = q.strides()[0];
@@ -178,22 +159,28 @@ void PagedAttnKernel(const paddle::Tensor& q,
const float *alibi_slopes_ptr = alibi_slopes ? alibi_slopes.get().data<float>() : nullptr;
const void *key_ptr = k ? k.get().data() : nullptr;
const void *value_ptr = v ? v.get().data() : nullptr;
size_t workspace_size = 0;
void* workspace_ptr = nullptr;
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(
num_seqs, num_heads, num_kv_heads, head_size, block_size, max_context_len, &workspace_size));
CUDA_CHECK(cudaMalloc((void**)&workspace_ptr, workspace_size));
CUDA_CHECK(cudaMemset(workspace_ptr, 0xff, workspace_size));
const float *rope_sin_ptr = merged_qkv ? rope_sin.get().data<float>() : nullptr;
const float *rope_cos_ptr = merged_qkv ? rope_cos.get().data<float>() : nullptr;
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(q.place()));
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
size_t workspace_size = 0;
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(num_seqs,
num_heads,
num_kv_heads,
head_size,
block_size,
max_context_len,
&workspace_size));
auto* allocator = paddle::GetAllocator(q.place());
phi::Allocator::AllocationPtr tmp_workspace = allocator->Allocate(workspace_size);
void* workspace_ptr = tmp_workspace->ptr();
PageAttentionWithKVCacheArguments args{
static_cast<float>(scale), 1.0, 1.0, static_cast<float>(softcap), window_left, window_right,
causal, use_sqrt_alibi, enable_cuda_graph, false, alibi_slopes_ptr, key_ptr, value_ptr, workspace_ptr};
causal, use_sqrt_alibi, enable_cuda_graph, false, alibi_slopes_ptr, key_ptr, value_ptr,
workspace_ptr, merged_qkv, rope_sin_ptr, rope_cos_ptr};
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
out.data(),
data_type,
@@ -216,8 +203,6 @@ void PagedAttnKernel(const paddle::Tensor& q,
block_table.data<int32_t>(),
seq_lens.data<int32_t>(),
args));
CUDA_CHECK(cudaFree(workspace_ptr));
}
std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
@@ -228,6 +213,8 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
const paddle::optional<paddle::Tensor> &alibi_slopes,
const paddle::optional<paddle::Tensor> &k,
const paddle::optional<paddle::Tensor> &v,
const paddle::optional<paddle::Tensor> &rope_sin,
const paddle::optional<paddle::Tensor> &rope_cos,
int num_kv_heads,
float scale,
int block_size,
@@ -237,10 +224,15 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi) {
bool use_sqrt_alibi,
bool merged_qkv) {
const auto dtype = q.dtype();
auto out = paddle::empty_like(q, dtype);
auto out_shape = q.shape();
if (merged_qkv) {
out_shape[1] -= 2 * num_kv_heads;
}
auto out = paddle::empty(out_shape, dtype, q.place());
switch (dtype) {
case paddle::DataType::BFLOAT16:
@@ -252,6 +244,8 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
alibi_slopes,
k,
v,
rope_sin,
rope_cos,
num_kv_heads,
scale,
block_size,
@@ -262,6 +256,7 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
merged_qkv,
out);
break;
case paddle::DataType::FLOAT16:
@@ -273,6 +268,8 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
alibi_slopes,
k,
v,
rope_sin,
rope_cos,
num_kv_heads,
scale,
block_size,
@@ -283,6 +280,7 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
softcap,
enable_cuda_graph,
use_sqrt_alibi,
merged_qkv,
out);
break;
default:
@@ -298,8 +296,28 @@ std::vector<std::vector<int64_t>> PagedAttnInferShape(const std::vector<int64_t>
const std::vector<int64_t>& seq_lens_shape,
const std::vector<int64_t>& alibi_slopes_shape,
const std::vector<int64_t>& k_shape,
const std::vector<int64_t>& v_shape) {
return {q_shape};
const std::vector<int64_t>& v_shape,
const std::vector<int64_t>& rope_sin_shape,
const std::vector<int64_t>& rope_cos_shape,
int num_kv_heads,
float scale,
int block_size,
int max_context_len,
bool causal,
int window_left,
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi,
bool merged_qkv) {
if (merged_qkv) {
int64_t num_tokens = q_shape[0];
int64_t num_heads = q_shape[1] - 2 * num_kv_heads;
int64_t head_dim = q_shape[2];
return {{num_tokens, num_heads, head_dim}};
} else {
return {q_shape};
}
}
std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtype,
@@ -309,13 +327,29 @@ std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtyp
const paddle::DataType& seq_lens_dtype,
const paddle::DataType& alibi_slopes_dtype,
const paddle::DataType& k_dtype,
const paddle::DataType& v_dtype) {
const paddle::DataType& v_dtype,
const paddle::DataType& rope_sin_dtype,
const paddle::DataType& rope_cos_dtype,
int num_kv_heads,
float scale,
int block_size,
int max_context_len,
bool causal,
int window_left,
int window_right,
float softcap,
bool enable_cuda_graph,
bool use_sqrt_alibi,
bool merged_qkv) {
return {q_dtype};
}
PD_BUILD_STATIC_OP(paged_attn)
.Inputs({"q", "k_cache", "v_cache", "block_table", "seq_lens", paddle::Optional("alibi_slopes"), paddle::Optional("k"), paddle::Optional("v")})
.Inputs({"q", "k_cache", "v_cache", "block_table", "seq_lens",
paddle::Optional("alibi_slopes"), paddle::Optional("k"),
paddle::Optional("v"), paddle::Optional("rope_sin"),
paddle::Optional("rope_cos")})
.Outputs({"out"})
.Attrs({"num_kv_heads:int",
"scale:float",
@@ -326,12 +360,8 @@ PD_BUILD_STATIC_OP(paged_attn)
"window_right:int",
"softcap:float",
"enable_cuda_graph:bool",
"use_sqrt_alibi:bool"})
"use_sqrt_alibi:bool",
"merged_qkv:bool"})
.SetKernelFn(PD_KERNEL(PagedAttn))
.SetInferShapeFn(PD_INFER_SHAPE(PagedAttnInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PagedAttnInferDtype));
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("paged_attn", &PagedAttn, "paged attn function");
}

View File

@@ -13,20 +13,47 @@
// limitations under the License.
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ixinfer.h>
#include <iostream>
#include <vector>
#define CUINFER_CHECK(func) \
do { \
cuinferStatus_t status = (func); \
if (status != CUINFER_STATUS_SUCCESS) { \
std::cerr << "Error in file " << __FILE__ << " on line " \
<< __LINE__ << ": " << cuinferGetErrorString(status) \
<< std::endl; \
throw std::runtime_error("CUINFER_CHECK ERROR"); \
} \
} while (0)
namespace iluvatar {
class IluvatarContext {
public:
IluvatarContext() = default;
~IluvatarContext();
public:
IluvatarContext() = default;
~IluvatarContext();
cuinferHandle_t getIxInferHandle();
cuinferHandle_t getIxInferHandle();
private:
cuinferHandle_t ixinfer_handle_{nullptr};
private:
cuinferHandle_t ixinfer_handle_{nullptr};
};
IluvatarContext* getContextInstance();

View File

@@ -0,0 +1,200 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "iluvatar_context.h"
std::vector<paddle::Tensor> GroupGemm(const paddle::Tensor& x,
const paddle::Tensor& weight,
const paddle::Tensor& weight_scale,
const paddle::Tensor& prefix_sum,
const int32_t group_size) {
auto dev_ctx = static_cast<const phi::CustomContext*>(
paddle::experimental::DeviceContextPool::Instance().Get(x.place()));
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
const auto& x_dims = x.dims();
const auto& w_dims = weight.dims();
const auto& ws_dims = weight_scale.dims();
const auto& prefix_sum_dims = prefix_sum.dims();
// [m, k]
PD_CHECK(x_dims.size() == 2, "x should be 2D");
// [n_experts, n, k]
PD_CHECK(w_dims.size() == 3, "weight should be 3D");
// [n_experts, n]
PD_CHECK(ws_dims.size() == 2, "weight_scale should be 2D");
// [n_experts]
PD_CHECK(prefix_sum_dims.size() == 1, "prefix_sum should be 1D");
PD_CHECK(group_size == -1);
auto m = x_dims[0];
auto k = x_dims[1];
auto n_experts = w_dims[0];
auto n = w_dims[1];
PD_CHECK(w_dims[2] == k);
PD_CHECK(ws_dims[0] == n_experts);
PD_CHECK(ws_dims[1] == n);
PD_CHECK(prefix_sum_dims[0] == n_experts);
PD_CHECK(prefix_sum.dtype() == paddle::DataType::INT64);
PD_CHECK(prefix_sum.is_cpu());
PD_CHECK(x.dtype() == paddle::DataType::BFLOAT16 ||
x.dtype() == paddle::DataType::FLOAT16);
PD_CHECK(weight.dtype() == paddle::DataType::INT8);
PD_CHECK(weight_scale.dtype() == x.dtype());
PD_CHECK(x.is_contiguous());
PD_CHECK(weight.is_contiguous());
PD_CHECK(weight_scale.is_contiguous());
const int64_t* prefix_sum_ptr = prefix_sum.data<int64_t>();
auto output = GetEmptyTensor({m, n}, x.dtype(), x.place());
int16_t* out_data = static_cast<int16_t*>(output.data());
const int16_t* x_data = static_cast<const int16_t*>(x.data());
const int8_t* weight_data = weight.data<int8_t>();
const int16_t* weight_scale_data =
static_cast<const int16_t*>(weight_scale.data());
cuinferHandle_t handle = iluvatar::getContextInstance()->getIxInferHandle();
cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST;
cuinferOperation_t transa = CUINFER_OP_T;
cuinferOperation_t transb = CUINFER_OP_N;
cudaDataType_t a_type = CUDA_R_8I;
cudaDataType_t b_type;
cudaDataType_t c_type;
if (x.dtype() == paddle::DataType::FLOAT16) {
b_type = CUDA_R_16F;
} else if (x.dtype() == paddle::DataType::BFLOAT16) {
b_type = CUDA_R_16BF;
} else {
PADDLE_THROW(common::errors::Unimplemented("Unsupported input dtype."));
}
c_type = b_type;
cudaDataType_t Atype = a_type;
cudaDataType_t Btype = b_type;
cudaDataType_t Ctype = c_type;
cudaDataType_t computeType = CUDA_R_32F;
cudaDataType_t scaleType = CUDA_R_32F;
cuinferGEMMCustomOption_t customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE;
cuinferQuantGEMMHostParam cust_host_param;
cust_host_param.size = sizeof(cuinferQuantGEMMHostParam);
cust_host_param.persistent = 0;
cust_host_param.groupSize = group_size;
cuinferQuantGEMMDeviceParam cust_device_param;
cust_device_param.bias = nullptr;
cust_device_param.workspace = nullptr;
int lda = k;
int ldb = k;
int ldc = n;
float beta = 0.f;
float alpha = 1.f;
int batch_count = 1;
size_t pre = 0;
auto* allocator = paddle::GetAllocator(x.place());
phi::Allocator::AllocationPtr tmp_workspace;
for (int i = 0; i < n_experts; i++) {
size_t expert_i_end = prefix_sum_ptr[i];
size_t cur_len = expert_i_end - pre;
pre = expert_i_end;
if (cur_len != 0) {
cust_device_param.scale = weight_scale_data;
if (k % 64 != 0) {
size_t workspace_size;
CUINFER_CHECK(cuinferGetCustomGemmWorkspace(transa,
transb,
n,
cur_len,
k,
Atype,
lda,
lda,
Btype,
ldb,
ldb,
Ctype,
ldc,
ldc,
batch_count,
computeType,
scaleType,
&workspace_size));
tmp_workspace = allocator->Allocate(workspace_size);
cust_device_param.workspace = tmp_workspace->ptr();
} else {
cust_device_param.workspace = nullptr;
}
CUINFER_CHECK(cuinferCustomGemm(handle,
stream,
cuinfer_ptr_mode,
transa,
transb,
n,
cur_len,
k,
&alpha,
weight_data,
Atype,
lda,
lda,
x_data,
Btype,
ldb,
ldb,
&beta,
out_data,
Ctype,
ldc,
ldc,
batch_count,
computeType,
scaleType,
&cust_host_param,
&cust_device_param,
customOption));
}
x_data += cur_len * k;
weight_data += k * n;
weight_scale_data += n;
out_data += cur_len * n;
}
return {output};
}
std::vector<std::vector<int64_t>> GroupGemmInferShape(
const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& weight_shape,
const std::vector<int64_t>& weight_scale_shape,
const std::vector<int64_t>& prefix_sum_shape) {
return {{x_shape[0], weight_shape[1]}};
}
std::vector<paddle::DataType> GroupGemmInferDtype(
const paddle::DataType& input_dtype,
const paddle::DataType& weight_output_dtype,
const paddle::DataType& weight_scale_dtype,
const paddle::DataType& prefix_sum_dtype,
const int moe_topk) {
return {input_dtype};
}
PD_BUILD_STATIC_OP(w8a16_group_gemm)
.Inputs({"x", "weight", "weight_scale", "prefix_sum"})
.Outputs({"output"})
.Attrs({
"group_size:int",
})
.SetKernelFn(PD_KERNEL(GroupGemm))
.SetInferShapeFn(PD_INFER_SHAPE(GroupGemmInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GroupGemmInferDtype));

View File

@@ -539,9 +539,12 @@ elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
"gpu_ops/stop_generation_multi_ends.cu",
"gpu_ops/step.cu",
"gpu_ops/token_penalty_multi_scores.cu",
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
"gpu_ops/sample_kernels/top_k_renorm_probs.cu",
"iluvatar_ops/moe_dispatch.cu",
"iluvatar_ops/moe_reduce.cu",
"iluvatar_ops/paged_attn.cu",
"iluvatar_ops/w8a16_group_gemm.cu",
"iluvatar_ops/runtime/iluvatar_context.cc",
],
include_dirs=["iluvatar_ops/runtime", "gpu_ops"],

View File

@@ -1,12 +1,12 @@
# Run ERNIE-4.5-300B-A47B & ERNIE-4.5-21B-A3B model on iluvatar machine
The current version of the software merely serves as a demonstration demo for the Iluvatar CoreX combined with the Fastdeploy inference framework for large models. There may be issues when running the latest ERNIE4.5 model, and we will conduct repairs and performance optimization in the future. Subsequent versions will provide customers with a more stable version.
The current version of the software merely serves as a demonstration demo for the Iluvatar CoreX combined with the Fastdeploy inference framework for large models. Running the latest ERNIE4.5 300B model on the GSM8K dataset takes about 6.3 hours.
## Machine Preparation
First, you need to prepare a machine with the following configurations:
First, the `TP=16` when running the ERNIE4.5 300B model and so you need to prepare a machine with the following configurations:
| CPU | Memory | Card | Hard Disk|
| :---: | :---: | :---: | :---: |
| x86 | 1TB| 8xBI150| 1TB|
| x86 | 1TB| 16xBI150| 1TB|
Currently, the entire model needs to be loaded into the host memory, which requires more than 600GB of host memory. This issue will be optimized in subsequent versions.
@@ -46,6 +46,7 @@ script list below:
export PADDLE_XCCL_BACKEND=iluvatar_gpu
export INFERENCE_MSG_QUEUE_ID=232132
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
export FD_SAMPLING_CLASS=rejection
export FD_DEBUG=1
python3 run_demo.py
```
@@ -64,7 +65,7 @@ prompts = [
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)
# load the model
llm = LLM(model="/home/paddle/ernie-4_5-21b-a3b-bf16-paddle", tensor_parallel_size=4, max_model_len=8192, static_decode_blocks=0, quantization='wint8')
llm = LLM(model="/home/paddle/ernie-4_5-21b-a3b-bf16-paddle", tensor_parallel_size=4, max_model_len=8192, static_decode_blocks=0, block_size=16, quantization='wint8')
# Perform batch inference
outputs = llm.generate(prompts, sampling_params)
@@ -118,3 +119,281 @@ Now, let's break down each step:
**Step 3: Drawing the
The largest ocean is the Pacific Ocean, covering an area of approximately ⦠[3], The first scientific expeditions to determine the ocean's depth were the Challenger expedition (1872â1876) and the U.S. Navy Hydrographic Office survey (1877â1879). The oceanic crust is thin and irregular, consisting of upward moving magma from the mantle below, and cooling and solidifying on the surface. The shallowest parts of the ocean are called the continental shelves. Large tides are caused mainly by the alignment of the Sun, Moon, and Earth during new or full moons. The origin of the word "ocean" is not clear. The first global oceanic topography survey was completed by the Challenger expedition (1872â1876). [57] The sound speed in the ocean is primarily a function of water temperature and salinity, and varies with depth. The deep-ocean floor is mostly flat and devoid of life, with the exception of seamounts and various underwater volcanic features, including seamounts and hydrothermal vents. [73] Today, the five ocean
```
## Run ernie4.5 300B model with the GSM8K dataset
1. Download GSM8K dataset
```bash
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
```
2. Prepare `bench_gsm8k.py`
```python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fastdeploy + ERNIE-4.5-Turbo 的指标评估 """
# adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py
import argparse
import ast
import json
import re
import time
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import requests
from tqdm import tqdm
INVALID = -9999999
def call_generate(prompt, **kwargs):
"""
Generates response based on the input prompt.
Args:
prompt (str): The input prompt text.
**kwargs: Keyword arguments, including server IP address and port number.
Returns:
str: The response generated based on the prompt.
"""
url = f"http://{kwargs['ip']}:{kwargs['port']}/v1/chat/completions"
headers = {"Content-Type": "application/json"}
data = {
"messages": [
{
"role": "user",
"content": prompt,
}
],
"temperature": 0.6,
"max_tokens": 2047,
"top_p": 0.95,
"do_sample": True,
}
response = requests.post(url, headers=headers, data=json.dumps(data))
out = response.json()
return out["choices"][0]["message"]["content"]
def get_one_example(lines, i, include_answer):
"""
Retrieves a question-answer example from the given list of text lines.
Args:
lines (list of dict): A list of question-answer pairs.
i (int): The index of the question-answer pair to retrieve from lines.
include_answer (bool): Whether to include the answer in the returned string.
Returns:
str: A formatted question-answer string in the format "Question: <question>\nAnswer: <answer>".
"""
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
if include_answer:
ret += " " + lines[i]["answer"]
return ret
def get_few_shot_examples(lines, k):
"""
Selects k examples from the given list of text lines and concatenates them into a single string.
Args:
lines (list): A list containing text lines.
k (int): The number of examples to select.
Returns:
str: A string composed of k examples, separated by two newline characters.
"""
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret
def get_answer_value(answer_str):
"""
Extracts numerical values from an answer string and returns them.
Args:
answer_str (str): The string containing the answer.
Returns:
The extracted numerical value; returns "INVALID" if extraction fails.
"""
answer_str = answer_str.replace(",", "")
numbers = re.findall(r"\d+", answer_str)
if len(numbers) < 1:
return INVALID
try:
return ast.literal_eval(numbers[-1])
except SyntaxError:
return INVALID
def read_jsonl(filename: str):
"""
Reads a JSONL file.
Args:
filename (str): Path to the JSONL file.
Yields:
dict: A dictionary object corresponding to each line in the JSONL file.
"""
with open(filename) as fin:
for line in fin:
if line.startswith("#"):
continue
yield json.loads(line)
def main(args):
"""
Process inputs and generate answers by calling the model in parallel using a thread pool.
Args:
args (argparse.Namespace):
- num_questions (int): Number of questions to process.
- num_shots (int): Number of few-shot learning examples.
- ip (str): IP address of the model service.
- port (int): Port number of the model service.
- parallel (int): Number of questions to process in parallel.
- result_file (str): File path to store the results.
Returns:
None
"""
# Read data
filename = "test.jsonl"
lines = list(read_jsonl(filename))
# Construct prompts
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
labels = []
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
states = [None] * len(labels)
# Use thread pool
def get_one_answer(i):
answer = call_generate(
prompt=few_shot_examples + questions[i],
# stop=["Question", "Assistant:", "<|separator|>"],
ip=args.ip,
port=args.port,
)
states[i] = answer
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(questions))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)
latency = time.time() - tic
preds = []
for i in range(len(states)):
preds.append(get_answer_value(states[i]))
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
# Print results
print(f"Accuracy: {acc:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Latency: {latency:.3f} s")
with open(args.result_file, "a") as fout:
value = {
"task": "gsm8k",
"backend": "paddlepaddle",
"num_gpus": 1,
"latency": round(latency, 3),
"accuracy": round(acc, 3),
"num_requests": args.num_questions,
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default="127.0.0.1")
parser.add_argument("--port", type=str, default="8188")
parser.add_argument("--num-shots", type=int, default=10)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=1319)
parser.add_argument("--result-file", type=str, default="result.jsonl")
parser.add_argument("--parallel", type=int, default=1)
args = parser.parse_args()
main(args)
```
3. Prepare `run_bench.sh`
```bash
#!/bin/bash
export PADDLE_XCCL_BACKEND=iluvatar_gpu
export INFERENCE_MSG_QUEUE_ID=232132
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
export FD_SAMPLING_CLASS=rejection
python3 -m fastdeploy.entrypoints.openai.api_server --model "/home/paddle/ernie-45t" --port 8188 --tensor-parallel-size 16 --block-size 16 --static-decode-blocks 0 --quantization wint8
```
4. Running the Script
Firstly, open a terminal and run:
```bash
./run_bench.sh
```
After the service is ready, open another terminal and run:
```bash
python3 -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 8
```
It takes about 6.3 hours to run the GSM8K dataset.
```
Accuracy: 0.964
Invaild: 0.000
Latency: 22918.186 s
```

View File

@@ -27,6 +27,7 @@ from paddleformers.transformers.configuration_utils import PretrainedConfig
import fastdeploy
from fastdeploy import envs
from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
from fastdeploy.platforms import current_platform
from fastdeploy.utils import check_unified_ckpt, get_logger
logger = get_logger("config", "config.log")
@@ -733,7 +734,7 @@ class CacheConfig:
self.gpu_memory_utilization = 0.9
self.num_gpu_blocks_override = None
self.kv_cache_ratio = 0.75
self.enc_dec_block_num = 2
self.enc_dec_block_num = 0 if current_platform.is_iluvatar() else 2
self.prealloc_dec_block_slot_num_threshold = 5
self.cache_dtype = "bfloat16"
self.model_cfg = None

View File

@@ -961,7 +961,10 @@ class LLMEngine:
)
if self.do_profile:
get_profile_block_num = np.zeros([1], dtype=np.int32)
if paddle.is_compiled_with_custom_device("iluvatar_gpu"):
get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32)
else:
get_profile_block_num = np.zeros([1], dtype=np.int32)
self.get_profile_block_num_signal = IPCSignal(
name="get_profile_block_num",
array=get_profile_block_num,

View File

@@ -85,45 +85,120 @@ class IluvatarAttnBackend(AttentionBackend):
Which is used only for testing purpose.
"""
def __init__(
self,
llm_config: FDConfig,
kv_num_heads: int,
num_heads: int,
head_dim: int,
):
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int):
super().__init__()
self.attention_metadata = IluvatarAttentionMetadata()
self.attention_metadata.block_size = llm_config.cache_config.block_size
assert llm_config.cache_config.enc_dec_block_num == 0, "Iluvatar does not support yet"
self.attention_metadata.block_size = fd_config.parallel_config.block_size
assert (
fd_config.parallel_config.enc_dec_block_num == 0
), f"Iluvatar does not support yet, {fd_config.parallel_config.enc_dec_block_num}"
assert self.attention_metadata.block_size == 16, "Iluvatar paged attn requires block_size must be 16."
self.attention_metadata.max_context_len = llm_config.parallel_config.max_model_len
self.attention_metadata.causal = getattr(llm_config.model_config, "causal", True)
self.speculate_method = getattr(llm_config.parallel_config, "speculate_method", None)
self.attention_metadata.max_context_len = fd_config.parallel_config.max_model_len
self.attention_metadata.causal = getattr(fd_config.model_config, "causal", True)
self.speculate_method = getattr(fd_config.parallel_config, "speculate_method", None)
self.use_speculate = self.speculate_method is not None
self.attention_metadata.num_kv_heads = kv_num_heads
self.attention_metadata.dropout = llm_config.model_config.hidden_dropout_prob
self.attention_metadata.dropout = fd_config.model_config.hidden_dropout_prob
self.num_heads = num_heads
self.total_num_heads = num_heads + 2 * kv_num_heads
self.head_dim = head_dim
self.hidden_dim = num_heads * head_dim
self.total_hidden_dim = self.total_num_heads * head_dim
# note: scale need to change if using MLA
self.attention_metadata.scale = 1.0 / sqrt(head_dim)
self.num_layers = llm_config.model_config.num_hidden_layers
self.num_layers = fd_config.model_config.num_hidden_layers
self.dtype = paddle.get_default_dtype()
self.record_block_table_metadata = {}
self.only_use_flash_attn = int(os.getenv("FD_ILUVATAR_ONLY_USE_FLASH_ATTN", 0)) == 1
self.do_check_kv_cache = int(os.getenv("FD_ILUVATAR_CHECK_KV_CACHE_CORRECTNESS", 0)) == 1
if not self.only_use_flash_attn:
assert self.attention_metadata.block_size == 16, "Iluvatar paged attn requires block_size must be 16."
if self.do_check_kv_cache:
self.record_batched_k = [{} for _ in range(self.num_layers)]
self.record_batched_v = [{} for _ in range(self.num_layers)]
self.enable_fused_attention = int(os.getenv("FD_ILUVATAR_ENABLE_FUSED_ATTN", 1))
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
self.attention_metadata.block_tables = forward_meta.block_tables
self.attention_metadata.attn_mask = forward_meta.attn_mask
self.attention_metadata.seq_lens = forward_meta.seq_lens_decoder
self.attention_metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
self.attention_metadata.cu_seqlens_k = forward_meta.cu_seqlens_k
self.prefill_info_dict = {}
self.decode_info_dict = {}
prefill_non_zeros_ids = forward_meta.seq_lens_this_time > 1
decode_non_zeros_ids = forward_meta.seq_lens_this_time == 1
self.prefill_info_dict["batch_ids"] = paddle.where(prefill_non_zeros_ids)[0]
self.decode_info_dict["batch_ids"] = paddle.where(decode_non_zeros_ids)[0]
self.prefill_len = len(self.prefill_info_dict["batch_ids"])
self.decode_len = len(self.decode_info_dict["batch_ids"])
# only prefill
if self.decode_len == 0:
cu_seq_ids = list(range(self.prefill_len + 1))
self.prefill_info_dict["cu_seqlens_q"] = forward_meta.cu_seqlens_q[cu_seq_ids]
# only decode
elif self.prefill_len == 0:
pass
# both prefill and decode
else:
prefill_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[prefill_non_zeros_ids])
decode_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[decode_non_zeros_ids])
self.prefill_info_dict["cu_seqlens_q"] = paddle.zeros(
[self.prefill_len + 1], dtype=forward_meta.cu_seqlens_q.dtype
)
self.prefill_info_dict["cu_seqlens_q"][1:] = forward_meta.seq_lens_encoder[
self.prefill_info_dict["batch_ids"], 0
]
self.prefill_info_dict["cu_seqlens_q"] = paddle.cumsum(self.prefill_info_dict["cu_seqlens_q"])
self.prefill_qkv = paddle.zeros([prefill_num_tokens, self.total_hidden_dim], dtype=self.dtype)
self.decode_qkv = paddle.zeros([decode_num_tokens, self.total_hidden_dim], dtype=self.dtype)
self.merged_output = paddle.zeros(
[prefill_num_tokens + decode_num_tokens, self.num_heads, self.head_dim], dtype=self.dtype
)
prefill_start, decode_start, start = 0, 0, 0
non_zeros_ids = forward_meta.seq_lens_this_time != 0
non_zeros_seq_lens = forward_meta.seq_lens_this_time[non_zeros_ids]
end = non_zeros_seq_lens[0]
if end > 1:
last_stage = "prefill"
prefill_end = end
decode_end = 0
else:
last_stage = "decode"
prefill_end = 0
decode_end = end
self.prefill_info_dict["id_group"] = []
self.prefill_info_dict["reverse_id_group"] = []
self.decode_info_dict["id_group"] = []
self.decode_info_dict["reverse_id_group"] = []
self.record_stages = []
for seq_len in non_zeros_seq_lens[1:]:
if seq_len > 1:
if last_stage == "decode":
self.record_stages.append((last_stage, len(self.decode_info_dict["id_group"])))
self.decode_info_dict["id_group"].append((decode_start, decode_end))
self.decode_info_dict["reverse_id_group"].append((start, end))
decode_start = decode_end
start = end
last_stage = "prefill"
prefill_end += seq_len
end += seq_len
else:
if last_stage == "prefill":
self.record_stages.append((last_stage, len(self.prefill_info_dict["id_group"])))
self.prefill_info_dict["id_group"].append((prefill_start, prefill_end))
self.prefill_info_dict["reverse_id_group"].append((start, end))
prefill_start = prefill_end
start = end
last_stage = "decode"
decode_end += seq_len
end += seq_len
if prefill_start < prefill_end:
self.record_stages.append(("prefill", len(self.prefill_info_dict["id_group"])))
self.prefill_info_dict["id_group"].append((prefill_start, prefill_end))
self.prefill_info_dict["reverse_id_group"].append((start, end))
if decode_start < decode_end:
self.record_stages.append(("decode", len(self.decode_info_dict["id_group"])))
self.decode_info_dict["id_group"].append((decode_start, decode_end))
self.decode_info_dict["reverse_id_group"].append((start, end))
def get_attntion_meta(self):
"""get_attntion_meta"""
@@ -144,93 +219,15 @@ class IluvatarAttnBackend(AttentionBackend):
self.head_dim,
)
def get_new_kv(
self,
k,
v,
k_cache_id: int,
v_cache_id: int,
forward_meta: ForwardMeta,
debug_paged_attn=False,
):
new_k = []
new_v = []
tensor_start = 0
for batch_idx in range(forward_meta.block_tables.shape[0]):
seq_len = forward_meta.seq_lens_this_time[batch_idx]
if seq_len == 0:
continue
tensor_end = tensor_start + seq_len
slice_k = k[tensor_start:tensor_end, :, :]
slice_v = v[tensor_start:tensor_end, :, :]
if seq_len > 1:
# prefill
new_k.append(slice_k)
new_v.append(slice_v)
else:
# decode
assert seq_len == 1
cur_block_tables = forward_meta.block_tables[batch_idx]
cur_used_block_tables = cur_block_tables[cur_block_tables != -1]
assert (
batch_idx in self.record_block_table_metadata
), f"Key error: {batch_idx} vs {self.record_block_table_metadata}."
cur_block_table_metadata = self.record_block_table_metadata[batch_idx]
record_last_block_id = cur_block_table_metadata["block_id"]
assert record_last_block_id != -1
for block_id in cur_used_block_tables:
if block_id == record_last_block_id:
cache_end = cur_block_table_metadata["cache_end"]
block_k_cache = forward_meta.caches[k_cache_id][block_id, :, 0:cache_end, :]
block_v_cache = forward_meta.caches[v_cache_id][block_id, :, 0:cache_end, :]
else:
block_k_cache = forward_meta.caches[k_cache_id][block_id]
block_v_cache = forward_meta.caches[v_cache_id][block_id]
# [num_kv_heads, block_size, head_dim] -> [block_size, num_kv_heads, head_dim]
new_k.append(block_k_cache.transpose([1, 0, 2]).contiguous())
new_v.append(block_v_cache.transpose([1, 0, 2]).contiguous())
if block_id == record_last_block_id:
break
# as line 301 show, record_block_table_metadata updates when executing the last layer,
# so slice_k and slice_v has been updated in block_k_cache and block_v_cache
if not (debug_paged_attn and (k_cache_id / 2 == self.num_layers - 1)):
new_k.append(slice_k)
new_v.append(slice_v)
tensor_start = tensor_end
if len(new_k) == 1:
return new_k[0], new_v[0]
else:
new_k = paddle.concat(new_k, axis=0)
new_v = paddle.concat(new_v, axis=0)
return new_k, new_v
def update_kv_cache(
self,
k,
v,
k_cache_id: int,
v_cache_id: int,
layer_id: int,
forward_meta: ForwardMeta,
specific_batch_ids=None,
debug_paged_attn=False,
def prefill_update_kv_cache(
self, k, v, k_cache_id: int, v_cache_id: int, layer_id: int, forward_meta: ForwardMeta, prefill_batch_ids: list
):
# [num_tokens, num_kv_heads, head_dim] -> [num_kv_heads, num_tokens, head_dim]
trans_k = k.transpose([1, 0, 2]).contiguous()
trans_v = v.transpose([1, 0, 2]).contiguous()
tensor_start = 0
for batch_idx in range(forward_meta.block_tables.shape[0]):
if specific_batch_ids is not None and batch_idx not in specific_batch_ids:
continue
for batch_idx in prefill_batch_ids:
seq_len = forward_meta.seq_lens_this_time[batch_idx]
if seq_len == 0:
continue
tensor_end = tensor_start + seq_len
slice_trans_k = trans_k[:, tensor_start:tensor_end, :]
@@ -239,146 +236,67 @@ class IluvatarAttnBackend(AttentionBackend):
cur_block_tables = forward_meta.block_tables[batch_idx]
cur_used_block_tables = cur_block_tables[cur_block_tables != -1]
# prefill
if seq_len > 1:
cache_start = 0
cur_used_num_blocks = cur_used_block_tables.shape[0]
for i, block_id in enumerate(cur_used_block_tables):
# last block: seq_len - cache_start <= block_size
if i == cur_used_num_blocks - 1:
cache_end = seq_len - cache_start
assert cache_end <= self.attention_metadata.block_size
forward_meta.caches[k_cache_id][block_id, :, 0:cache_end, :] = slice_trans_k[
:, cache_start:seq_len, :
]
forward_meta.caches[v_cache_id][block_id, :, 0:cache_end, :] = slice_trans_v[
:, cache_start:seq_len, :
]
if layer_id == self.num_layers - 1:
self.record_block_table_metadata[batch_idx] = {
"block_id": block_id.item(),
"cache_end": cache_end,
}
# non last block: seq_lens_this_time > block_size
else:
assert seq_len > self.attention_metadata.block_size
cache_end = cache_start + self.attention_metadata.block_size
forward_meta.caches[k_cache_id][block_id] = slice_trans_k[:, cache_start:cache_end, :]
forward_meta.caches[v_cache_id][block_id] = slice_trans_v[:, cache_start:cache_end, :]
cache_start += self.attention_metadata.block_size
else:
# decode
assert seq_len == 1
cur_last_block_id = cur_used_block_tables[-1].item()
assert cur_last_block_id != -1
assert (
batch_idx in self.record_block_table_metadata
), f"Key error: {batch_idx} vs {self.record_block_table_metadata}."
cur_block_table_metadata = self.record_block_table_metadata[batch_idx]
record_last_block_id = cur_block_table_metadata["block_id"]
if cur_last_block_id == record_last_block_id:
# not alloc new block in decode stage
cache_start = cur_block_table_metadata["cache_end"]
cache_start = 0
cur_used_num_blocks = cur_used_block_tables.shape[0]
for i, block_id in enumerate(cur_used_block_tables):
# last block: seq_len - cache_start <= block_size
if i == cur_used_num_blocks - 1:
cache_end = seq_len - cache_start
assert cache_end <= self.attention_metadata.block_size
paddle.assign(
slice_trans_k[:, cache_start:seq_len, :],
output=forward_meta.caches[k_cache_id][block_id, :, 0:cache_end, :],
)
paddle.assign(
slice_trans_v[:, cache_start:seq_len, :],
output=forward_meta.caches[v_cache_id][block_id, :, 0:cache_end, :],
)
if layer_id == self.num_layers - 1:
self.record_block_table_metadata[batch_idx] = {
"block_id": block_id.item(),
"cache_end": cache_end.item(),
}
# non last block: seq_lens_this_time > block_size
else:
# alloc new block in decode stage
cache_start = 0
cache_end = cache_start + 1
assert cache_end <= self.attention_metadata.block_size
# paged attn API will update kv cache with inplace mode
if not debug_paged_attn:
forward_meta.caches[k_cache_id][cur_last_block_id, :, cache_start:cache_end, :] = slice_trans_k
forward_meta.caches[v_cache_id][cur_last_block_id, :, cache_start:cache_end, :] = slice_trans_v
# update record_block_table_metadata
if layer_id == self.num_layers - 1:
self.record_block_table_metadata[batch_idx]["block_id"] = cur_last_block_id
self.record_block_table_metadata[batch_idx]["cache_end"] = cache_end
assert seq_len > self.attention_metadata.block_size
cache_end = cache_start + self.attention_metadata.block_size
paddle.assign(
slice_trans_k[:, cache_start:cache_end, :], output=forward_meta.caches[k_cache_id][block_id]
)
paddle.assign(
slice_trans_v[:, cache_start:cache_end, :], output=forward_meta.caches[v_cache_id][block_id]
)
cache_start += self.attention_metadata.block_size
tensor_start = tensor_end
def _check_new_kv_correctness(self, k, v, new_k, new_v, layer_id: int, forward_meta: ForwardMeta):
tensor_start = 0
for batch_idx, seq_lens_this_time in enumerate(forward_meta.seq_lens_this_time):
if seq_lens_this_time == 0:
continue
# note: the second request will also use the batch_idx 0 instead of 1 in
# the streaming inference mode, so use seq_lens_this_time > 1 with the same
# batch_idx represents the second request comes.
if seq_lens_this_time > 1 and batch_idx in self.record_batched_k[layer_id]:
print(
f"clear self.record_batched_batched_k: "
f"layer_id={layer_id}, batch_id={batch_idx}, "
f"record_lens={len(self.record_batched_k[layer_id][batch_idx])}"
)
self.record_batched_k[layer_id][batch_idx].clear()
self.record_batched_v[layer_id][batch_idx].clear()
tensor_end = tensor_start + seq_lens_this_time
slice_k = k[tensor_start:tensor_end, :, :]
slice_v = v[tensor_start:tensor_end, :, :]
if batch_idx not in self.record_batched_k[layer_id]:
self.record_batched_k[layer_id][batch_idx] = []
self.record_batched_v[layer_id][batch_idx] = []
self.record_batched_k[layer_id][batch_idx].append(slice_k)
self.record_batched_v[layer_id][batch_idx].append(slice_v)
tensor_start = tensor_end
ref_k, ref_v = [], []
for batch_idx, seq_lens_this_time in enumerate(forward_meta.seq_lens_this_time):
if seq_lens_this_time == 0:
continue
bached_k_list = self.record_batched_k[layer_id][batch_idx]
bached_v_list = self.record_batched_v[layer_id][batch_idx]
ref_k.extend(bached_k_list)
ref_v.extend(bached_v_list)
ref_k = paddle.concat(ref_k, axis=0)
ref_v = paddle.concat(ref_v, axis=0)
print(
f"_check_new_kv_correctness: layer_id={layer_id}, "
f"k.shape={k.shape}, v.shape={v.shape}, "
f"ref_k.shape={ref_k.shape}, ref_v.shape={ref_v.shape}, "
f"new_k.shape={new_k.shape}, new_v.shape={new_v.shape}, "
f"len(self.record_batched_k[layer_id])={len(self.record_batched_k[layer_id])}, "
f"len(self.record_batched_k[layer_id][0])={len(self.record_batched_k[layer_id][0])}, "
f"forward_meta.seq_lens_this_time={forward_meta.seq_lens_this_time}"
f"ref_k[-2:, 0:2, 0:2]={ref_k[-2:, 0:2, 0:2]}, "
f"ref_v[-2:, 0:2, 0:2]={ref_v[-2:, 0:2, 0:2]}, "
f"new_k[-2:, 0:2, 0:2]={new_k[-2:, 0:2, 0:2]}, "
f"new_v[-2:, 0:2, 0:2]={new_v[-2:, 0:2, 0:2]}"
)
assert paddle.allclose(
ref_k.to("cpu").to(paddle.float32),
new_k.to("cpu").to(paddle.float32),
)
assert paddle.allclose(
ref_v.to("cpu").to(paddle.float32),
new_v.to("cpu").to(paddle.float32),
)
def get_splited_qkv(self, qkv: paddle.Tensor, forward_meta: ForwardMeta):
q_end = self.num_heads * self.head_dim
def get_splited_qkv(
self, qkv: paddle.Tensor, forward_meta: ForwardMeta, cu_seqlens_q: paddle.Tensor, batch_ids=None
):
q_end = self.hidden_dim
k_end = q_end + self.attention_metadata.num_kv_heads * self.head_dim
v_end = k_end + self.attention_metadata.num_kv_heads * self.head_dim
assert v_end == qkv.shape[-1], f"Shape mistach: {v_end} vs {qkv.shape[-1]}"
assert qkv.shape[0] == forward_meta.cu_seqlens_q[-1]
assert v_end == qkv.shape[-1], f"Shape mismatch: {v_end} vs {qkv.shape[-1]}"
assert qkv.shape[0] == cu_seqlens_q[-1], f"Shape mismatch: {qkv.shape[0]} vs {cu_seqlens_q[-1]}"
if batch_ids is None:
batch_ids = list(range(forward_meta.seq_lens_this_time.shape[0]))
q = qkv[..., 0:q_end]
k = qkv[..., q_end:k_end]
v = qkv[..., k_end:v_end]
q = q.view([-1, self.num_heads, self.head_dim]).contiguous()
k = k.view([-1, self.attention_metadata.num_kv_heads, self.head_dim]).contiguous()
v = v.view([-1, self.attention_metadata.num_kv_heads, self.head_dim]).contiguous()
# forward_meta.seq_lens_this_time [max_batch,]
for batch_idx in range(forward_meta.seq_lens_this_time.shape[0]):
q = q.view([-1, self.num_heads, self.head_dim])
k = k.view([-1, self.attention_metadata.num_kv_heads, self.head_dim])
v = v.view([-1, self.attention_metadata.num_kv_heads, self.head_dim])
for idx in range(len(cu_seqlens_q) - 1):
batch_idx = batch_ids[idx]
seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
if seq_len_i == 0:
continue
cached_kv_len = forward_meta.seq_lens_decoder[batch_idx][0]
cu_seq_start_q = forward_meta.cu_seqlens_q[batch_idx]
cu_seq_end_q = forward_meta.cu_seqlens_q[batch_idx + 1]
cu_seq_start_q = cu_seqlens_q[idx]
cu_seq_end_q = cu_seqlens_q[idx + 1]
# forward_meta.rotary_embs is [2, 1, S, 1, D]
if forward_meta.rotary_embs is not None:
cos = forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
@@ -388,75 +306,114 @@ class IluvatarAttnBackend(AttentionBackend):
return q, k, v
def get_splited_info_by_stage(self, q, k, v, forward_meta: ForwardMeta):
prefill_info_dict = {"q": [], "k": [], "v": [], "batch_ids": []}
decode_info_dict = {"q": [], "k": [], "v": [], "batch_ids": []}
tensor_start = 0
for batch_idx, seq_lens_this_time in enumerate(forward_meta.seq_lens_this_time):
if seq_lens_this_time == 0:
continue
tensor_end = tensor_start + seq_lens_this_time
slice_q = q[tensor_start:tensor_end, :, :]
slice_k = k[tensor_start:tensor_end, :, :]
slice_v = v[tensor_start:tensor_end, :, :]
if seq_lens_this_time > 1:
prefill_info_dict["q"].append(slice_q)
prefill_info_dict["k"].append(slice_k)
prefill_info_dict["v"].append(slice_v)
prefill_info_dict["batch_ids"].append(batch_idx)
def split_pd_qkv(self, qkv):
for ids, reverse_ids in zip(self.prefill_info_dict["id_group"], self.prefill_info_dict["reverse_id_group"]):
self.prefill_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :]
for ids, reverse_ids in zip(self.decode_info_dict["id_group"], self.decode_info_dict["reverse_id_group"]):
self.decode_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :]
return self.prefill_qkv, self.decode_qkv
def merge_pd_output(self, prefill_out, decode_out):
for stage, idx in self.record_stages:
if stage == "prefill":
ids = self.prefill_info_dict["id_group"][idx]
reverse_ids = self.prefill_info_dict["reverse_id_group"][idx]
self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = prefill_out[ids[0] : ids[1], :, :]
else:
assert seq_lens_this_time == 1
decode_info_dict["q"].append(slice_q)
decode_info_dict["k"].append(slice_k)
decode_info_dict["v"].append(slice_v)
decode_info_dict["batch_ids"].append(batch_idx)
tensor_start = tensor_end
ids = self.decode_info_dict["id_group"][idx]
reverse_ids = self.decode_info_dict["reverse_id_group"][idx]
self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = decode_out[ids[0] : ids[1], :, :]
return self.merged_output
if len(prefill_info_dict["batch_ids"]) > 0:
prefill_info_dict["q"] = paddle.concat(prefill_info_dict["q"], axis=0)
prefill_info_dict["k"] = paddle.concat(prefill_info_dict["k"], axis=0)
prefill_info_dict["v"] = paddle.concat(prefill_info_dict["v"], axis=0)
cu_seq_ids = list(map(lambda x: x + 1, prefill_info_dict["batch_ids"]))
prefill_info_dict["cu_seq_ids"] = [0, *cu_seq_ids]
def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
prefill_q, prefill_k, prefill_v = self.get_splited_qkv(
prefill_qkv,
forward_meta,
self.prefill_info_dict["cu_seqlens_q"],
batch_ids=self.prefill_info_dict["batch_ids"],
)
if len(decode_info_dict["batch_ids"]) > 0:
decode_info_dict["q"] = paddle.concat(decode_info_dict["q"], axis=0)
decode_info_dict["k"] = paddle.concat(decode_info_dict["k"], axis=0)
decode_info_dict["v"] = paddle.concat(decode_info_dict["v"], axis=0)
prefill_out = flash_attn_unpadded(
prefill_q,
prefill_k,
prefill_v,
cu_seqlens_q=self.prefill_info_dict["cu_seqlens_q"],
cu_seqlens_k=self.prefill_info_dict["cu_seqlens_q"],
max_seqlen_q=self.attention_metadata.max_context_len,
max_seqlen_k=self.attention_metadata.max_context_len,
scale=self.attention_metadata.scale,
dropout=self.attention_metadata.dropout,
causal=self.attention_metadata.causal,
return_softmax=self.attention_metadata.return_softmax,
)[0]
self.prefill_update_kv_cache(
prefill_k, prefill_v, k_cache_id, v_cache_id, layer_id, forward_meta, self.prefill_info_dict["batch_ids"]
)
return prefill_info_dict, decode_info_dict
return prefill_out
def merge_output(self, prefill_out, decode_out, forward_meta: ForwardMeta):
assert not (prefill_out is None and decode_out is None), "prefill and decode output cannot both be None"
if prefill_out is None:
return decode_out
elif decode_out is None:
return prefill_out
def forward_decode(self, decode_qkv, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
k_cache = forward_meta.caches[k_cache_id]
v_cache = forward_meta.caches[v_cache_id]
if self.enable_fused_attention:
rope_cos = forward_meta.rotary_embs[0, 0, :, :, :]
rope_sin = forward_meta.rotary_embs[1, 0, :, :, :]
decode_out = paged_attention(
decode_qkv.view([-1, self.total_num_heads, self.head_dim]),
k_cache,
v_cache,
block_tables=forward_meta.block_tables[self.decode_info_dict["batch_ids"], :],
seq_lens=forward_meta.seq_lens_decoder[self.decode_info_dict["batch_ids"], 0] + 1,
num_kv_heads=self.attention_metadata.num_kv_heads,
scale=self.attention_metadata.scale,
block_size=self.attention_metadata.block_size,
max_context_len=self.attention_metadata.max_context_len,
alibi_slopes=self.attention_metadata.alibi_slopes,
causal=self.attention_metadata.causal,
window_left=self.attention_metadata.window_left,
window_right=self.attention_metadata.window_right,
softcap=self.attention_metadata.softcap,
use_cuda_graph=self.attention_metadata.use_cuda_graph,
use_sqrt_alibi=self.attention_metadata.use_sqrt_alibi,
merged_qkv=True,
k=decode_qkv,
v=decode_qkv,
rope_sin=rope_sin,
rope_cos=rope_cos,
)
else:
merged_output = []
prefill_tensor_start = 0
decode_tensor_start = 0
for seq_lens_this_time in forward_meta.seq_lens_this_time:
if seq_lens_this_time == 0:
continue
if seq_lens_this_time > 1:
tensor_end = prefill_tensor_start + seq_lens_this_time
merged_output.append(prefill_out[prefill_tensor_start:tensor_end, :, :])
prefill_tensor_start = tensor_end
else:
assert seq_lens_this_time == 1
tensor_end = decode_tensor_start + seq_lens_this_time
merged_output.append(decode_out[decode_tensor_start:tensor_end, :, :])
decode_tensor_start = tensor_end
decode_q, decode_k, decode_v = self.get_splited_qkv(
decode_qkv,
forward_meta,
self.decode_info_dict["cu_seqlens_q"],
batch_ids=self.decode_info_dict["batch_ids"],
)
assert (
prefill_tensor_start == prefill_out.shape[0]
), f"prefill merged unfinished: {prefill_tensor_start} vs {prefill_out.shape[0]}"
assert (
decode_tensor_start == decode_out.shape[0]
), f"decode merged unfinished: {decode_tensor_start} vs {decode_out.shape[0]}"
merged_output = paddle.concat(merged_output, axis=0)
return merged_output
decode_out = paged_attention(
decode_q,
k_cache,
v_cache,
block_tables=forward_meta.block_tables[self.decode_info_dict["batch_ids"], :],
seq_lens=forward_meta.seq_lens_decoder[self.decode_info_dict["batch_ids"], 0] + 1,
num_kv_heads=self.attention_metadata.num_kv_heads,
scale=self.attention_metadata.scale,
block_size=self.attention_metadata.block_size,
max_context_len=self.attention_metadata.max_context_len,
alibi_slopes=self.attention_metadata.alibi_slopes,
causal=self.attention_metadata.causal,
window_left=self.attention_metadata.window_left,
window_right=self.attention_metadata.window_right,
softcap=self.attention_metadata.softcap,
use_cuda_graph=self.attention_metadata.use_cuda_graph,
use_sqrt_alibi=self.attention_metadata.use_sqrt_alibi,
k=decode_k,
v=decode_v,
)
return decode_out
def forward_mixed(
self,
@@ -476,110 +433,19 @@ class IluvatarAttnBackend(AttentionBackend):
layer_id = layer.layer_id
k_cache_id = layer_id * 2
v_cache_id = k_cache_id + 1
assert qkv is not None
q_dim = qkv.dim()
q, k, v = self.get_splited_qkv(qkv, forward_meta)
assert q_dim == 2
if self.only_use_flash_attn:
new_k, new_v = self.get_new_kv(k, v, k_cache_id, v_cache_id, forward_meta)
if self.do_check_kv_cache:
self._check_new_kv_correctness(k, v, new_k, new_v, layer_id, forward_meta)
if self.decode_len == 0:
output = self.forward_prefill(qkv, layer_id, k_cache_id, v_cache_id, forward_meta)
out = flash_attn_unpadded(
q,
new_k,
new_v,
cu_seqlens_q=self.attention_metadata.cu_seqlens_q,
cu_seqlens_k=self.attention_metadata.cu_seqlens_k,
max_seqlen_q=self.attention_metadata.max_context_len,
max_seqlen_k=self.attention_metadata.max_context_len,
scale=self.attention_metadata.scale,
dropout=self.attention_metadata.dropout,
causal=self.attention_metadata.causal,
return_softmax=self.attention_metadata.return_softmax,
)[0]
self.update_kv_cache(k, v, k_cache_id, v_cache_id, layer_id, forward_meta)
elif self.prefill_len == 0:
output = self.forward_decode(qkv, k_cache_id, v_cache_id, forward_meta)
else:
prefill_info_dict, decode_info_dict = self.get_splited_info_by_stage(q, k, v, forward_meta)
prefill_out, decode_out = None, None
prefill_qkv, decode_qkv = self.split_pd_qkv(qkv)
prefill_output = self.forward_prefill(prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta)
decode_output = self.forward_decode(decode_qkv, k_cache_id, v_cache_id, forward_meta)
output = self.merge_pd_output(prefill_output, decode_output)
if len(prefill_info_dict["batch_ids"]) > 0:
prefill_out = flash_attn_unpadded(
prefill_info_dict["q"],
prefill_info_dict["k"],
prefill_info_dict["v"],
cu_seqlens_q=forward_meta.cu_seqlens_q[prefill_info_dict["cu_seq_ids"]],
cu_seqlens_k=forward_meta.cu_seqlens_k[prefill_info_dict["cu_seq_ids"]],
max_seqlen_q=self.attention_metadata.max_context_len,
max_seqlen_k=self.attention_metadata.max_context_len,
scale=self.attention_metadata.scale,
dropout=self.attention_metadata.dropout,
causal=self.attention_metadata.causal,
return_softmax=self.attention_metadata.return_softmax,
)[0]
self.update_kv_cache(
prefill_info_dict["k"],
prefill_info_dict["v"],
k_cache_id,
v_cache_id,
layer_id,
forward_meta,
specific_batch_ids=prefill_info_dict["batch_ids"],
)
if len(decode_info_dict["batch_ids"]) > 0:
k_cache = forward_meta.caches[k_cache_id]
v_cache = forward_meta.caches[v_cache_id]
decode_out = paged_attention(
decode_info_dict["q"],
k_cache,
v_cache,
block_tables=forward_meta.block_tables[decode_info_dict["batch_ids"], :],
seq_lens=forward_meta.seq_lens_decoder[decode_info_dict["batch_ids"], 0] + 1,
num_kv_heads=self.attention_metadata.num_kv_heads,
scale=self.attention_metadata.scale,
block_size=self.attention_metadata.block_size,
max_context_len=self.attention_metadata.max_context_len,
alibi_slopes=self.attention_metadata.alibi_slopes,
causal=self.attention_metadata.causal,
window_left=self.attention_metadata.window_left,
window_right=self.attention_metadata.window_right,
softcap=self.attention_metadata.softcap,
use_cuda_graph=self.attention_metadata.use_cuda_graph,
use_sqrt_alibi=self.attention_metadata.use_sqrt_alibi,
k=decode_info_dict["k"],
v=decode_info_dict["v"],
)
if self.do_check_kv_cache:
self.update_kv_cache(
decode_info_dict["k"],
decode_info_dict["v"],
k_cache_id,
v_cache_id,
layer_id,
forward_meta,
specific_batch_ids=decode_info_dict["batch_ids"],
debug_paged_attn=True,
)
if self.do_check_kv_cache:
new_k, new_v = self.get_new_kv(
k,
v,
k_cache_id,
v_cache_id,
forward_meta,
debug_paged_attn=True,
)
self._check_new_kv_correctness(k, v, new_k, new_v, layer_id, forward_meta)
out = self.merge_output(prefill_out, decode_out, forward_meta)
if q_dim == 2:
out = out.view([-1, self.num_heads * self.head_dim])
return out
output = output.view([-1, self.num_heads * self.head_dim])
return output

View File

@@ -128,10 +128,16 @@ def rejection_top_p_sampling(
rejection_top_p_sampling
"""
try:
from fastdeploy.model_executor.ops.gpu import (
rejection_top_p_sampling,
top_k_renorm_probs,
)
if current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import (
rejection_top_p_sampling,
top_k_renorm_probs,
)
else:
from fastdeploy.model_executor.ops.gpu import (
rejection_top_p_sampling,
top_k_renorm_probs,
)
if paddle.count_nonzero(top_k) == 0:
ids = rejection_top_p_sampling(

View File

@@ -20,6 +20,11 @@ import paddle
from paddle.incubate.nn.functional import swiglu
from paddle.nn.quant import weight_only_linear
try:
from fastdeploy.model_executor.ops.iluvatar import w8a16_group_gemm
except ImportError:
w8a16_group_gemm = None
def group_gemm(
input: paddle.Tensor,
@@ -67,53 +72,32 @@ def group_gemm(
scale_i = scale[i]
# avoid d2d?
output[expert_start:expert_end] = weight_only_linear(
input_i,
weight_i,
weight_scale=scale_i,
weight_dtype="int8",
group_size=-1,
input_i, weight_i, weight_scale=scale_i, weight_dtype="int8", group_size=-1
)
def iluvatar_moe_expert_ffn(
permute_input: paddle.Tensor,
tokens_expert_prefix_sum: paddle.Tensor,
up_gate_proj_weight: paddle.Tensor,
down_proj_weight: paddle.Tensor,
up_gate_proj_bias: Optional[paddle.Tensor],
up_gate_proj_scale: Optional[paddle.Tensor],
down_proj_scale: Optional[paddle.Tensor],
down_proj_in_scale: Optional[paddle.Tensor],
ffn1_weight: paddle.Tensor,
ffn2_weight: paddle.Tensor,
ffn1_bias: Optional[paddle.Tensor],
ffn1_scale: Optional[paddle.Tensor],
ffn2_scale: Optional[paddle.Tensor],
ffn2_in_scale: Optional[paddle.Tensor],
expert_idx_per_token: Optional[paddle.Tensor],
quant_method: str,
used_in_ep_low_latency: bool,
):
assert up_gate_proj_bias is None
assert up_gate_proj_scale is not None
assert down_proj_scale is not None
assert down_proj_in_scale is None
assert ffn1_bias is None
assert ffn1_scale is not None
assert ffn2_scale is not None
assert ffn2_in_scale is None
assert expert_idx_per_token is None
assert quant_method in ("weight_only_int8")
assert not used_in_ep_low_latency
tokens_expert_prefix_sum_cpu = tokens_expert_prefix_sum.to("cpu")
up_gate_proj_output = paddle.empty(
[permute_input.shape[0], up_gate_proj_weight.shape[1]],
dtype=permute_input.dtype,
)
group_gemm(
permute_input,
tokens_expert_prefix_sum_cpu,
up_gate_proj_weight,
up_gate_proj_scale,
up_gate_proj_output,
)
act_out = swiglu(up_gate_proj_output)
output = paddle.empty([act_out.shape[0], down_proj_weight.shape[1]], dtype=act_out.dtype)
group_gemm(
act_out,
tokens_expert_prefix_sum_cpu,
down_proj_weight,
down_proj_scale,
output,
)
ffn1_output = w8a16_group_gemm(permute_input, ffn1_weight, ffn1_scale, tokens_expert_prefix_sum_cpu, -1)
act_out = swiglu(ffn1_output)
output = w8a16_group_gemm(act_out, ffn2_weight, ffn2_scale, tokens_expert_prefix_sum_cpu, -1)
return output

View File

@@ -39,8 +39,11 @@ def paged_attention(
softcap: float = 0.0,
use_cuda_graph: bool = False,
use_sqrt_alibi: bool = False,
merged_qkv: bool = False,
k: paddle.Tensor = None,
v: paddle.Tensor = None,
rope_sin: paddle.Tensor = None,
rope_cos: paddle.Tensor = None,
):
output = paged_attn(
q,
@@ -51,6 +54,8 @@ def paged_attention(
alibi_slopes,
k,
v,
rope_sin,
rope_cos,
num_kv_heads,
scale,
block_size,
@@ -61,5 +66,6 @@ def paged_attention(
softcap,
use_cuda_graph,
use_sqrt_alibi,
merged_qkv,
)
return output[0] if isinstance(output, list) else output

View File

@@ -211,7 +211,7 @@ def post_process_normal(
model_output.stop_flags,
)
if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_iluvatar():
set_stop_value_multi_ends(
sampler_output.sampled_token_ids,
model_output.stop_flags,

View File

@@ -41,20 +41,28 @@ from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.ops.gpu import (
recover_decode_task,
set_value_by_flags_and_idx,
share_external_data,
)
from fastdeploy.platforms import current_platform
if current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx
recover_decode_task = None
share_external_data = None
else:
from fastdeploy.model_executor.ops.gpu import (
recover_decode_task,
set_value_by_flags_and_idx,
share_external_data,
)
from fastdeploy.model_executor.pre_and_post_process import (
post_process,
pre_process,
rebuild_padding,
step_cuda,
)
from fastdeploy.platforms import current_platform
if not current_platform.is_dcu():
if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
from fastdeploy.spec_decode import MTPProposer, NgramProposer
from fastdeploy import envs

File diff suppressed because it is too large Load Diff

View File

@@ -16,22 +16,22 @@
import gc
import os
from typing import List, Optional
import time
import numpy as np
import paddle
from paddle import nn
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.utils import get_logger, set_random_seed
from fastdeploy.worker.gpu_worker import GpuWorker
from fastdeploy.worker.iluvatar_model_runner import IluvatarModelRunner
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.worker_base import WorkerBase
from fastdeploy.worker.worker_process import PaddleDisWorkerProc
logger = get_logger("iluvatar_worker", "iluvatar_worker.log")
class IluvatarWorker(WorkerBase):
class IluvatarWorker(GpuWorker):
""" """
def __init__(
@@ -40,15 +40,16 @@ class IluvatarWorker(WorkerBase):
local_rank: int,
rank: int,
):
super().__init__(
super(IluvatarWorker, self).__init__(
fd_config=fd_config,
local_rank=local_rank,
rank=rank,
)
pass
def init_device(self):
"""Initialize device and Construct model runner"""
"""
Initialize device and construct model runner
"""
if paddle.is_compiled_with_custom_device("iluvatar_gpu"):
# Set evironment variable
self.device = f"iluvatar_gpu:{self.local_rank}"
@@ -70,12 +71,6 @@ class IluvatarWorker(WorkerBase):
local_rank=self.local_rank,
)
def exist_prefill(self):
"""
check whether prefill stage exist
"""
return self.model_runner.exist_prefill()
def determine_available_memory(self) -> int:
"""
Profiles the peak memory usage of the model to determine how much
@@ -92,51 +87,86 @@ class IluvatarWorker(WorkerBase):
# 1. Record memory state before profile run
return int(float(os.getenv("FD_ILUVATAR_KVCACHE_MEM", "3")) * 1024**3)
def load_model(self) -> None:
""" """
self.model_runner.load_model()
def get_model(self) -> nn.Layer:
""" """
return self.model_runner.get_model()
class IluvatarPaddleDisWorkerProc(PaddleDisWorkerProc):
"""
Paddle Distributed wrapper for fastdeploy.worker.Worker,
for handling single-node multi-GPU tensor parallel.
The wrapper internally executes an event loop that continuously executes requests
in the task queue. Control flow is transmitted by IPC.
"""
def initialize_cache(self, num_gpu_blocks: int) -> None:
""" """
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
def __init__(self, fd_config: FDConfig, ranks: int = 1, local_rank: int = 0):
super(IluvatarPaddleDisWorkerProc, self).__init__(
fd_config=fd_config,
ranks=ranks,
local_rank=local_rank,
)
def execute_model(
self,
model_forward_batch: Optional[List[Request]] = None,
num_running_requests: int = None,
) -> Optional[ModelRunnerOutput]:
""" """
output = self.model_runner.execute_model(model_forward_batch, num_running_requests)
return output
def initialize_kv_cache(self) -> None:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int) -> None:
"""Process new requests and then start the decode loop
TODO(gongshaotian):The scheduler should schedule the handling of prefill,
and workers and modelrunners should not perceive it.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
self.model_runner.insert_prefill_inputs(req_dicts=req_dicts, num_running_requests=num_running_requests)
if self.fd_config.parallel_config.do_profile:
# 1. Get available memory(bytes)
available_kv_cache_memory = self.worker.determine_available_memory()
logger.info(f"------- available_kv_cache_memory:{available_kv_cache_memory / 1024**3} GB --------")
def graph_optimize_and_warm_up_model(self) -> None:
"""
Perform the warm-up and the graph optimization
"""
# 1. Warm up model
# NOTE(gongshaotian): may be not need warm_up at this place
if self.model_runner.graph_opt_level >= 1:
self.model_runner.sot_warmup()
# 2. Calculate the appropriate number of blocks
model_block_memory_used = self.worker.cal_theortical_kvcache()
num_blocks_local = int(available_kv_cache_memory // model_block_memory_used)
# NOTE(liuzichang): Too many block will lead to illegal memory access
# We will develop dynamic limits in future.
if num_blocks_local > 40000:
logger.info(f"------- Reset num_blocks_local {num_blocks_local} to 40000")
num_blocks_local = min(40000, num_blocks_local)
logger.info(f"------- model_block_memory_used:{model_block_memory_used} --------")
logger.info(f"------- num_blocks_local:{num_blocks_local} --------")
# 2. Triger cuda grpah capture
self.model_runner.capture_model()
set_random_seed(self.fd_config.model_config.seed)
# NOTE(yuzhe.wu): Using the old version of the calculation num_blocks_global method,
# because the new version that adopting allreduce min will report a bad request error
# when running 300b model. The Relation commit:
# https://github.com/PaddlePaddle/FastDeploy/commit/2f74e93d7e87aa3ffec3fc6966bf11ab5363b956
def check_health(self) -> bool:
""" """
return True
# 3. Send IPCSignal
get_profile_block_num = np.zeros(shape=[self.ranks], dtype=np.int32)
self.get_profile_block_num_signal = IPCSignal(
name="get_profile_block_num",
array=get_profile_block_num,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False,
)
self.get_profile_block_num_signal.value[self.local_rank] = num_blocks_local
def cal_theortical_kvcache(self) -> int:
""" """
return self.model_runner.cal_theortical_kvcache()
# Wait all worker send the signal
while np.any(self.get_profile_block_num_signal.value <= 0):
time.sleep(0.01)
num_blocks_global = self.get_profile_block_num_signal.value.min().item()
if num_blocks_global < 0:
logger.error(
"The total number of blocks cannot be less than zero."
"Please increase gpu_memory_utilization"
"Or decrease max_num_batched_tokens(max model length) "
)
raise ValueError(
"The total number of blocks cannot be less than zero."
"Please increase gpu_memory_utilization"
"Or decrease max_num_batched_tokens(max model length) "
)
self.get_profile_block_num_signal.value[self.local_rank] = num_blocks_global
else:
num_blocks_global = self.fd_config.parallel_config.total_block_num
# 4. init kv_cache with accurate num_blocks
logger.info(f"------- num_blocks_global:{num_blocks_global} --------")
self.worker.initialize_cache(num_gpu_blocks=num_blocks_global)

View File

@@ -723,7 +723,12 @@ def run_worker_proc() -> None:
fd_config = initialize_fd_config(args, ranks, local_rank)
# Create worker process
worker_proc = PaddleDisWorkerProc(fd_config, ranks, local_rank)
if current_platform.is_iluvatar():
from fastdeploy.worker.iluvatar_worker import IluvatarPaddleDisWorkerProc
worker_proc = IluvatarPaddleDisWorkerProc(fd_config, ranks, local_rank)
else:
worker_proc = PaddleDisWorkerProc(fd_config, ranks, local_rank)
# Initialize device and create model runner
worker_proc.init_device()

View File

@@ -1,11 +1,11 @@
setuptools>=79.0.1,<80.0
setuptools>=62.3.0,<80.0
pre-commit
yapf
flake8
ruamel.yaml
zmq
aiozmq
openai
openai>=1.93.0
tqdm
pynvml
uvicorn
@@ -24,7 +24,15 @@ setuptools-scm>=8
prometheus-client
decord
moviepy
wheel
use-triton-in-paddle
crcmod
fastsafetensors==0.1.14
msgpack
opentelemetry-api>=1.24.0
opentelemetry-sdk>=1.24.0
opentelemetry-instrumentation-redis
opentelemetry-instrumentation-mysql
opentelemetry-distro
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi

View File

@@ -13,10 +13,10 @@ python -m pip install -r requirements_iluvatar.txt
echo "uninstall org"
python -m pip uninstall paddlepaddle -y
python -m pip uninstall paddle-iluvatar-gpu -y
python -m pip install paddlepaddle==3.1.0a0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
python -m pip install --pre paddlepaddle==3.0.0.dev20250708 -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
# TODO: Change to open access URL
# python -m pip install --pre paddle-iluvatar-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/ixuca/
python -m pip install /data1/fastdeploy/packages/paddle_iluvatar_gpu-0.0.0-cp310-cp310-linux_x86_64.whl
python -m pip install --pre paddle-iluvatar-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/ixuca/
# python -m pip install /data1/fastdeploy/packages/paddle_iluvatar_gpu-0.0.0-cp310-cp310-linux_x86_64.whl
# Patch, remove if image updated
cp /data1/fastdeploy/packages/cusolver.h /usr/local/lib/python3.10/site-packages/paddle/include/paddle/phi/backends/dynload/cusolver.h
echo "build whl"
@@ -30,6 +30,7 @@ rm -rf log/*
export INFERENCE_MSG_QUEUE_ID=232132
export FD_DEBUG=1
export PADDLE_XCCL_BACKEND=iluvatar_gpu
export FD_SAMPLING_CLASS=rejection
python test/ci_use/iluvatar_UT/run_ernie300B_4layer.py
exit_code=$?
echo exit_code is ${exit_code}

View File

@@ -10,7 +10,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.00001, max_tokens=16)
# 加载模型
llm = LLM(
model="/data1/fastdeploy/ERNIE_300B_4L",
tensor_parallel_size=16,
tensor_parallel_size=8,
max_model_len=8192,
static_decode_blocks=0,
quantization="wint8",
@@ -27,14 +27,14 @@ assert outputs[0].outputs.token_ids == [
59335,
68170,
183,
49080,
94717,
82966,
99140,
31615,
51497,
94851,
60764,
10889,
97404,
100088,
36310,
95633,
95913,
41459,
95049,
94970,
96840,
2,
]
], f"{outputs[0].outputs.token_ids}"