mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Iluvatar GPU] Optimze attention and moe performance (#3234)
This commit is contained in:
3
.github/workflows/ci_gcu.yml
vendored
3
.github/workflows/ci_gcu.yml
vendored
@@ -13,7 +13,8 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
CI_GCU:
|
CI_GCU:
|
||||||
runs-on: [self-hosted, GCU-S60-8Card]
|
runs-on:
|
||||||
|
group: GCU
|
||||||
steps:
|
steps:
|
||||||
- name: Print current runner name
|
- name: Print current runner name
|
||||||
run: |
|
run: |
|
||||||
|
3
.github/workflows/ci_iluvatar.yml
vendored
3
.github/workflows/ci_iluvatar.yml
vendored
@@ -11,7 +11,8 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
CI_ILUVATAR:
|
CI_ILUVATAR:
|
||||||
runs-on: [self-hosted, IXUCA]
|
runs-on:
|
||||||
|
group: IXUCA
|
||||||
steps:
|
steps:
|
||||||
- name: Print current runner name
|
- name: Print current runner name
|
||||||
run: |
|
run: |
|
||||||
|
@@ -29,7 +29,11 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
|
|||||||
|
|
||||||
// need_batch_random
|
// need_batch_random
|
||||||
if (seed == -1) {
|
if (seed == -1) {
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(probs.place()));
|
||||||
|
#else
|
||||||
phi::GPUContext* dev_ctx = static_cast<phi::GPUContext*>(phi::DeviceContextPool::Instance().Get(probs.place()));
|
phi::GPUContext* dev_ctx = static_cast<phi::GPUContext*>(phi::DeviceContextPool::Instance().Get(probs.place()));
|
||||||
|
#endif
|
||||||
auto gen_cuda = dev_ctx->GetGenerator();
|
auto gen_cuda = dev_ctx->GetGenerator();
|
||||||
auto seed_offset = gen_cuda->IncrementOffset(32 * batch_size);
|
auto seed_offset = gen_cuda->IncrementOffset(32 * batch_size);
|
||||||
philox_seed = seed_offset.first;
|
philox_seed = seed_offset.first;
|
||||||
|
@@ -212,9 +212,15 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
|||||||
prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : 0;
|
prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : 0;
|
||||||
valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
|
valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
|
||||||
}
|
}
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
float aggregate_local =
|
||||||
|
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
|
||||||
|
.Sum(prob_greater_than_threshold);
|
||||||
|
#else
|
||||||
float aggregate_local =
|
float aggregate_local =
|
||||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
|
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
|
||||||
.Sum<VEC_SIZE>(prob_greater_than_threshold);
|
.Sum<VEC_SIZE>(prob_greater_than_threshold);
|
||||||
|
#endif
|
||||||
if (tx == 0) {
|
if (tx == 0) {
|
||||||
temp_storage->block_aggregate.value = aggregate_local;
|
temp_storage->block_aggregate.value = aggregate_local;
|
||||||
}
|
}
|
||||||
@@ -226,8 +232,13 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
|||||||
DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>(
|
DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>(
|
||||||
prob_greater_than_threshold, inclusive_cdf, temp_storage);
|
prob_greater_than_threshold, inclusive_cdf, temp_storage);
|
||||||
} else {
|
} else {
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
|
||||||
|
.InclusiveSum(prob_greater_than_threshold, inclusive_cdf);
|
||||||
|
#else
|
||||||
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
|
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
|
||||||
.InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
|
.InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
|
||||||
|
#endif
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
@@ -239,11 +250,21 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
|||||||
|
|
||||||
bool greater_than_u_diff[VEC_SIZE];
|
bool greater_than_u_diff[VEC_SIZE];
|
||||||
#ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED
|
#ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED
|
||||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
#ifdef PADDLE_WITH_COREX
|
||||||
.SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
|
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
||||||
|
.SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp());
|
||||||
|
#else
|
||||||
|
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
||||||
|
.SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
|
||||||
|
#endif
|
||||||
#else
|
#else
|
||||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
#ifdef PADDLE_WITH_COREX
|
||||||
.FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
|
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
||||||
|
.FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
|
||||||
|
#else
|
||||||
|
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
||||||
|
.FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
@@ -355,18 +376,30 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
|||||||
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
aggregate_gt_pivot_0 +=
|
||||||
|
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
||||||
|
.Sum(probs_gt_pivot_0);
|
||||||
|
#else
|
||||||
aggregate_gt_pivot_0 +=
|
aggregate_gt_pivot_0 +=
|
||||||
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
||||||
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
||||||
|
#endif
|
||||||
if (tx == 0) {
|
if (tx == 0) {
|
||||||
temp_storage.block_aggregate.pair = aggregate_gt_pivot_0;
|
temp_storage.block_aggregate.pair = aggregate_gt_pivot_0;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
|
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
|
||||||
|
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
aggregate_gt_pivot_1 +=
|
||||||
|
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
||||||
|
.Sum(probs_gt_pivot_1);
|
||||||
|
#else
|
||||||
aggregate_gt_pivot_1 +=
|
aggregate_gt_pivot_1 +=
|
||||||
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
||||||
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
||||||
|
#endif
|
||||||
if (tx == 0) {
|
if (tx == 0) {
|
||||||
temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
|
temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
|
||||||
}
|
}
|
||||||
@@ -466,16 +499,26 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
|||||||
probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0;
|
probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||||
|
.Sum(probs_gt_pivot_0);
|
||||||
|
#else
|
||||||
aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||||
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
||||||
|
#endif
|
||||||
if (tx == 0) {
|
if (tx == 0) {
|
||||||
temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
|
temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;
|
aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;
|
||||||
|
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||||
|
.Sum(probs_gt_pivot_1);
|
||||||
|
#else
|
||||||
aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||||
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
||||||
|
#endif
|
||||||
if (tx == 0) {
|
if (tx == 0) {
|
||||||
temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
|
temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
|
||||||
}
|
}
|
||||||
@@ -521,9 +564,15 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
|
|||||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||||
in_data_[j] = in_data_vec[j];
|
in_data_[j] = in_data_vec[j];
|
||||||
}
|
}
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
max_val = max(
|
||||||
|
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
||||||
|
.Reduce(in_data_, cub::Max()));
|
||||||
|
#else
|
||||||
max_val = max(
|
max_val = max(
|
||||||
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
||||||
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
|
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
|
||||||
|
#endif
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
if (tx == 0) {
|
if (tx == 0) {
|
||||||
@@ -605,7 +654,11 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
|
|||||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||||
const uint32_t row_idx = bx;
|
const uint32_t row_idx = bx;
|
||||||
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
|
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
double pivot = std::numeric_limits<float>::infinity(), normalizer = 1;
|
||||||
|
#else
|
||||||
double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
|
double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
|
||||||
|
#endif
|
||||||
vec_t<float, VEC_SIZE> probs_vec;
|
vec_t<float, VEC_SIZE> probs_vec;
|
||||||
if (k < d) {
|
if (k < d) {
|
||||||
extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
|
extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
|
||||||
@@ -659,14 +712,26 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
|
temp_storage.block_prim.reduce_value_count)
|
||||||
|
.Sum(probs_gt_pivot_0_pair);
|
||||||
|
#else
|
||||||
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
temp_storage.block_prim.reduce_value_count)
|
temp_storage.block_prim.reduce_value_count)
|
||||||
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
|
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
|
||||||
|
#endif
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
|
temp_storage.block_prim.reduce_value_count)
|
||||||
|
.Sum(probs_gt_pivot_1_pair);
|
||||||
|
#else
|
||||||
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
temp_storage.block_prim.reduce_value_count)
|
temp_storage.block_prim.reduce_value_count)
|
||||||
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
|
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
|
||||||
|
#endif
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
min_gt_low =
|
min_gt_low =
|
||||||
|
@@ -258,9 +258,13 @@ inline std::pair<int, int> GetCudaComputeCapability() {
|
|||||||
|
|
||||||
/******************* math *******************/
|
/******************* math *******************/
|
||||||
__forceinline__ __device__ float ptx_rcp(float x) {
|
__forceinline__ __device__ float ptx_rcp(float x) {
|
||||||
|
#ifdef PADDLE_WITH_COREX
|
||||||
|
return __ivcorex_rcpf(x);
|
||||||
|
#else
|
||||||
float y;
|
float y;
|
||||||
asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
|
asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
|
||||||
return y;
|
return y;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T1, typename T2>
|
template <typename T1, typename T2>
|
||||||
|
@@ -15,15 +15,6 @@
|
|||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "iluvatar_context.h"
|
#include "iluvatar_context.h"
|
||||||
|
|
||||||
#define CUINFER_CHECK(func) \
|
|
||||||
do { \
|
|
||||||
cuinferStatus_t status = (func); \
|
|
||||||
if (status != CUINFER_STATUS_SUCCESS) { \
|
|
||||||
std::cerr << "Error in file " << __FILE__ << " on line " << __LINE__ << ": " \
|
|
||||||
<< cuinferGetErrorString(status) << std::endl; \
|
|
||||||
throw std::runtime_error("CUINFER_CHECK ERROR"); \
|
|
||||||
} \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
template <paddle::DataType T>
|
template <paddle::DataType T>
|
||||||
void PagedAttnKernel(const paddle::Tensor& q,
|
void PagedAttnKernel(const paddle::Tensor& q,
|
||||||
@@ -34,6 +25,8 @@ void PagedAttnKernel(const paddle::Tensor& q,
|
|||||||
const paddle::optional<paddle::Tensor> &alibi_slopes,
|
const paddle::optional<paddle::Tensor> &alibi_slopes,
|
||||||
const paddle::optional<paddle::Tensor> &k,
|
const paddle::optional<paddle::Tensor> &k,
|
||||||
const paddle::optional<paddle::Tensor> &v,
|
const paddle::optional<paddle::Tensor> &v,
|
||||||
|
const paddle::optional<paddle::Tensor> &rope_sin,
|
||||||
|
const paddle::optional<paddle::Tensor> &rope_cos,
|
||||||
int num_kv_heads,
|
int num_kv_heads,
|
||||||
float scale,
|
float scale,
|
||||||
int block_size,
|
int block_size,
|
||||||
@@ -44,6 +37,7 @@ void PagedAttnKernel(const paddle::Tensor& q,
|
|||||||
float softcap,
|
float softcap,
|
||||||
bool enable_cuda_graph,
|
bool enable_cuda_graph,
|
||||||
bool use_sqrt_alibi,
|
bool use_sqrt_alibi,
|
||||||
|
bool merged_qkv,
|
||||||
paddle::Tensor& out) {
|
paddle::Tensor& out) {
|
||||||
if (alibi_slopes) {
|
if (alibi_slopes) {
|
||||||
PADDLE_ENFORCE_EQ(alibi_slopes.get().dtype(),
|
PADDLE_ENFORCE_EQ(alibi_slopes.get().dtype(),
|
||||||
@@ -75,14 +69,6 @@ void PagedAttnKernel(const paddle::Tensor& q,
|
|||||||
true,
|
true,
|
||||||
common::errors::InvalidArgument(
|
common::errors::InvalidArgument(
|
||||||
"paged_attention expects k_cache is contiguous"));
|
"paged_attention expects k_cache is contiguous"));
|
||||||
PADDLE_ENFORCE_EQ(v_cache.dtype(),
|
|
||||||
dtype,
|
|
||||||
common::errors::InvalidArgument(
|
|
||||||
"v_cache dtype must be the same as query dtype"));
|
|
||||||
PADDLE_ENFORCE_EQ(v_cache.is_contiguous(),
|
|
||||||
true,
|
|
||||||
common::errors::InvalidArgument(
|
|
||||||
"paged_attention expects v_cache is contiguous"));
|
|
||||||
PADDLE_ENFORCE_EQ(block_table.dtype(),
|
PADDLE_ENFORCE_EQ(block_table.dtype(),
|
||||||
paddle::DataType::INT32,
|
paddle::DataType::INT32,
|
||||||
common::errors::InvalidArgument(
|
common::errors::InvalidArgument(
|
||||||
@@ -99,14 +85,14 @@ void PagedAttnKernel(const paddle::Tensor& q,
|
|||||||
true,
|
true,
|
||||||
common::errors::InvalidArgument(
|
common::errors::InvalidArgument(
|
||||||
"paged_attention expects seq_lens is contiguous"));
|
"paged_attention expects seq_lens is contiguous"));
|
||||||
|
|
||||||
// check dim and shape
|
// check dim and shape
|
||||||
// out: [num_seqs, num_heads, head_size]
|
// k_cache: [num_blocks, kv_num_heads, block_size, head_size]
|
||||||
// q: [num_seqs, num_heads, head_size]
|
// v_cache: [num_blocks, kv_num_heads, block_size, head_size]
|
||||||
// k_chache: [num_blocks, kv_num_heads, block_size, head_size]
|
|
||||||
// v_chache: [num_blocks, kv_num_heads, block_size, head_size]
|
|
||||||
// block_table: [num_seqs, max_num_blocks_per_seq]
|
// block_table: [num_seqs, max_num_blocks_per_seq]
|
||||||
// seq_lens: [num_seqs]
|
// seq_lens: [num_seqs]
|
||||||
|
// q and out:
|
||||||
|
// merged_qkv = false: [num_seqs, num_heads, head_size]
|
||||||
|
// merged_qkv = true: [num_seqs, num_heads+2*num_kv_heads, head_size]
|
||||||
|
|
||||||
const auto& q_dims = q.dims();
|
const auto& q_dims = q.dims();
|
||||||
PADDLE_ENFORCE_EQ(q_dims.size(),
|
PADDLE_ENFORCE_EQ(q_dims.size(),
|
||||||
@@ -119,11 +105,6 @@ void PagedAttnKernel(const paddle::Tensor& q,
|
|||||||
common::errors::InvalidArgument(
|
common::errors::InvalidArgument(
|
||||||
"paged_attn receive out dims is "
|
"paged_attn receive out dims is "
|
||||||
"[num_seqs, num_heads, head_size]"));
|
"[num_seqs, num_heads, head_size]"));
|
||||||
PADDLE_ENFORCE_EQ(k_cache.dims(),
|
|
||||||
v_cache.dims(),
|
|
||||||
common::errors::InvalidArgument(
|
|
||||||
"paged_attn requires k_cache size is the "
|
|
||||||
"same as v_cache"));
|
|
||||||
|
|
||||||
const auto& kv_cache_dims = k_cache.dims();
|
const auto& kv_cache_dims = k_cache.dims();
|
||||||
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
|
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
|
||||||
@@ -146,7 +127,7 @@ void PagedAttnKernel(const paddle::Tensor& q,
|
|||||||
"paged_attn receive seq_lens dims is [num_seqs]"));
|
"paged_attn receive seq_lens dims is [num_seqs]"));
|
||||||
|
|
||||||
int num_seqs = q_dims[0];
|
int num_seqs = q_dims[0];
|
||||||
int num_heads = q_dims[1];
|
int num_heads = merged_qkv ? q_dims[1] - 2 * num_kv_heads : q_dims[1];
|
||||||
int head_size = q_dims[2];
|
int head_size = q_dims[2];
|
||||||
int max_num_blocks_per_seq = block_table_dims[1];
|
int max_num_blocks_per_seq = block_table_dims[1];
|
||||||
int q_stride = q.strides()[0];
|
int q_stride = q.strides()[0];
|
||||||
@@ -178,22 +159,28 @@ void PagedAttnKernel(const paddle::Tensor& q,
|
|||||||
const float *alibi_slopes_ptr = alibi_slopes ? alibi_slopes.get().data<float>() : nullptr;
|
const float *alibi_slopes_ptr = alibi_slopes ? alibi_slopes.get().data<float>() : nullptr;
|
||||||
const void *key_ptr = k ? k.get().data() : nullptr;
|
const void *key_ptr = k ? k.get().data() : nullptr;
|
||||||
const void *value_ptr = v ? v.get().data() : nullptr;
|
const void *value_ptr = v ? v.get().data() : nullptr;
|
||||||
|
const float *rope_sin_ptr = merged_qkv ? rope_sin.get().data<float>() : nullptr;
|
||||||
size_t workspace_size = 0;
|
const float *rope_cos_ptr = merged_qkv ? rope_cos.get().data<float>() : nullptr;
|
||||||
void* workspace_ptr = nullptr;
|
|
||||||
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(
|
|
||||||
num_seqs, num_heads, num_kv_heads, head_size, block_size, max_context_len, &workspace_size));
|
|
||||||
|
|
||||||
CUDA_CHECK(cudaMalloc((void**)&workspace_ptr, workspace_size));
|
|
||||||
CUDA_CHECK(cudaMemset(workspace_ptr, 0xff, workspace_size));
|
|
||||||
|
|
||||||
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(q.place()));
|
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(q.place()));
|
||||||
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
|
||||||
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
|
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
|
||||||
|
|
||||||
|
size_t workspace_size = 0;
|
||||||
|
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(num_seqs,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size,
|
||||||
|
block_size,
|
||||||
|
max_context_len,
|
||||||
|
&workspace_size));
|
||||||
|
auto* allocator = paddle::GetAllocator(q.place());
|
||||||
|
phi::Allocator::AllocationPtr tmp_workspace = allocator->Allocate(workspace_size);
|
||||||
|
void* workspace_ptr = tmp_workspace->ptr();
|
||||||
|
|
||||||
PageAttentionWithKVCacheArguments args{
|
PageAttentionWithKVCacheArguments args{
|
||||||
static_cast<float>(scale), 1.0, 1.0, static_cast<float>(softcap), window_left, window_right,
|
static_cast<float>(scale), 1.0, 1.0, static_cast<float>(softcap), window_left, window_right,
|
||||||
causal, use_sqrt_alibi, enable_cuda_graph, false, alibi_slopes_ptr, key_ptr, value_ptr, workspace_ptr};
|
causal, use_sqrt_alibi, enable_cuda_graph, false, alibi_slopes_ptr, key_ptr, value_ptr,
|
||||||
|
workspace_ptr, merged_qkv, rope_sin_ptr, rope_cos_ptr};
|
||||||
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
|
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
|
||||||
out.data(),
|
out.data(),
|
||||||
data_type,
|
data_type,
|
||||||
@@ -216,8 +203,6 @@ void PagedAttnKernel(const paddle::Tensor& q,
|
|||||||
block_table.data<int32_t>(),
|
block_table.data<int32_t>(),
|
||||||
seq_lens.data<int32_t>(),
|
seq_lens.data<int32_t>(),
|
||||||
args));
|
args));
|
||||||
|
|
||||||
CUDA_CHECK(cudaFree(workspace_ptr));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
|
std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
|
||||||
@@ -228,6 +213,8 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
|
|||||||
const paddle::optional<paddle::Tensor> &alibi_slopes,
|
const paddle::optional<paddle::Tensor> &alibi_slopes,
|
||||||
const paddle::optional<paddle::Tensor> &k,
|
const paddle::optional<paddle::Tensor> &k,
|
||||||
const paddle::optional<paddle::Tensor> &v,
|
const paddle::optional<paddle::Tensor> &v,
|
||||||
|
const paddle::optional<paddle::Tensor> &rope_sin,
|
||||||
|
const paddle::optional<paddle::Tensor> &rope_cos,
|
||||||
int num_kv_heads,
|
int num_kv_heads,
|
||||||
float scale,
|
float scale,
|
||||||
int block_size,
|
int block_size,
|
||||||
@@ -237,10 +224,15 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
|
|||||||
int window_right,
|
int window_right,
|
||||||
float softcap,
|
float softcap,
|
||||||
bool enable_cuda_graph,
|
bool enable_cuda_graph,
|
||||||
bool use_sqrt_alibi) {
|
bool use_sqrt_alibi,
|
||||||
|
bool merged_qkv) {
|
||||||
|
|
||||||
const auto dtype = q.dtype();
|
const auto dtype = q.dtype();
|
||||||
auto out = paddle::empty_like(q, dtype);
|
auto out_shape = q.shape();
|
||||||
|
if (merged_qkv) {
|
||||||
|
out_shape[1] -= 2 * num_kv_heads;
|
||||||
|
}
|
||||||
|
auto out = paddle::empty(out_shape, dtype, q.place());
|
||||||
|
|
||||||
switch (dtype) {
|
switch (dtype) {
|
||||||
case paddle::DataType::BFLOAT16:
|
case paddle::DataType::BFLOAT16:
|
||||||
@@ -252,6 +244,8 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
|
|||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
|
rope_sin,
|
||||||
|
rope_cos,
|
||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
scale,
|
scale,
|
||||||
block_size,
|
block_size,
|
||||||
@@ -262,6 +256,7 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
|
|||||||
softcap,
|
softcap,
|
||||||
enable_cuda_graph,
|
enable_cuda_graph,
|
||||||
use_sqrt_alibi,
|
use_sqrt_alibi,
|
||||||
|
merged_qkv,
|
||||||
out);
|
out);
|
||||||
break;
|
break;
|
||||||
case paddle::DataType::FLOAT16:
|
case paddle::DataType::FLOAT16:
|
||||||
@@ -273,6 +268,8 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
|
|||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
|
rope_sin,
|
||||||
|
rope_cos,
|
||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
scale,
|
scale,
|
||||||
block_size,
|
block_size,
|
||||||
@@ -283,6 +280,7 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
|
|||||||
softcap,
|
softcap,
|
||||||
enable_cuda_graph,
|
enable_cuda_graph,
|
||||||
use_sqrt_alibi,
|
use_sqrt_alibi,
|
||||||
|
merged_qkv,
|
||||||
out);
|
out);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
@@ -298,8 +296,28 @@ std::vector<std::vector<int64_t>> PagedAttnInferShape(const std::vector<int64_t>
|
|||||||
const std::vector<int64_t>& seq_lens_shape,
|
const std::vector<int64_t>& seq_lens_shape,
|
||||||
const std::vector<int64_t>& alibi_slopes_shape,
|
const std::vector<int64_t>& alibi_slopes_shape,
|
||||||
const std::vector<int64_t>& k_shape,
|
const std::vector<int64_t>& k_shape,
|
||||||
const std::vector<int64_t>& v_shape) {
|
const std::vector<int64_t>& v_shape,
|
||||||
return {q_shape};
|
const std::vector<int64_t>& rope_sin_shape,
|
||||||
|
const std::vector<int64_t>& rope_cos_shape,
|
||||||
|
int num_kv_heads,
|
||||||
|
float scale,
|
||||||
|
int block_size,
|
||||||
|
int max_context_len,
|
||||||
|
bool causal,
|
||||||
|
int window_left,
|
||||||
|
int window_right,
|
||||||
|
float softcap,
|
||||||
|
bool enable_cuda_graph,
|
||||||
|
bool use_sqrt_alibi,
|
||||||
|
bool merged_qkv) {
|
||||||
|
if (merged_qkv) {
|
||||||
|
int64_t num_tokens = q_shape[0];
|
||||||
|
int64_t num_heads = q_shape[1] - 2 * num_kv_heads;
|
||||||
|
int64_t head_dim = q_shape[2];
|
||||||
|
return {{num_tokens, num_heads, head_dim}};
|
||||||
|
} else {
|
||||||
|
return {q_shape};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtype,
|
std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtype,
|
||||||
@@ -309,13 +327,29 @@ std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtyp
|
|||||||
const paddle::DataType& seq_lens_dtype,
|
const paddle::DataType& seq_lens_dtype,
|
||||||
const paddle::DataType& alibi_slopes_dtype,
|
const paddle::DataType& alibi_slopes_dtype,
|
||||||
const paddle::DataType& k_dtype,
|
const paddle::DataType& k_dtype,
|
||||||
const paddle::DataType& v_dtype) {
|
const paddle::DataType& v_dtype,
|
||||||
|
const paddle::DataType& rope_sin_dtype,
|
||||||
|
const paddle::DataType& rope_cos_dtype,
|
||||||
|
int num_kv_heads,
|
||||||
|
float scale,
|
||||||
|
int block_size,
|
||||||
|
int max_context_len,
|
||||||
|
bool causal,
|
||||||
|
int window_left,
|
||||||
|
int window_right,
|
||||||
|
float softcap,
|
||||||
|
bool enable_cuda_graph,
|
||||||
|
bool use_sqrt_alibi,
|
||||||
|
bool merged_qkv) {
|
||||||
return {q_dtype};
|
return {q_dtype};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(paged_attn)
|
PD_BUILD_STATIC_OP(paged_attn)
|
||||||
.Inputs({"q", "k_cache", "v_cache", "block_table", "seq_lens", paddle::Optional("alibi_slopes"), paddle::Optional("k"), paddle::Optional("v")})
|
.Inputs({"q", "k_cache", "v_cache", "block_table", "seq_lens",
|
||||||
|
paddle::Optional("alibi_slopes"), paddle::Optional("k"),
|
||||||
|
paddle::Optional("v"), paddle::Optional("rope_sin"),
|
||||||
|
paddle::Optional("rope_cos")})
|
||||||
.Outputs({"out"})
|
.Outputs({"out"})
|
||||||
.Attrs({"num_kv_heads:int",
|
.Attrs({"num_kv_heads:int",
|
||||||
"scale:float",
|
"scale:float",
|
||||||
@@ -326,12 +360,8 @@ PD_BUILD_STATIC_OP(paged_attn)
|
|||||||
"window_right:int",
|
"window_right:int",
|
||||||
"softcap:float",
|
"softcap:float",
|
||||||
"enable_cuda_graph:bool",
|
"enable_cuda_graph:bool",
|
||||||
"use_sqrt_alibi:bool"})
|
"use_sqrt_alibi:bool",
|
||||||
|
"merged_qkv:bool"})
|
||||||
.SetKernelFn(PD_KERNEL(PagedAttn))
|
.SetKernelFn(PD_KERNEL(PagedAttn))
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(PagedAttnInferShape))
|
.SetInferShapeFn(PD_INFER_SHAPE(PagedAttnInferShape))
|
||||||
.SetInferDtypeFn(PD_INFER_DTYPE(PagedAttnInferDtype));
|
.SetInferDtypeFn(PD_INFER_DTYPE(PagedAttnInferDtype));
|
||||||
|
|
||||||
|
|
||||||
PYBIND11_MODULE(fastdeploy_ops, m) {
|
|
||||||
m.def("paged_attn", &PagedAttn, "paged attn function");
|
|
||||||
}
|
|
||||||
|
@@ -13,20 +13,47 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <ixinfer.h>
|
#include <ixinfer.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#define CUINFER_CHECK(func) \
|
||||||
|
do { \
|
||||||
|
cuinferStatus_t status = (func); \
|
||||||
|
if (status != CUINFER_STATUS_SUCCESS) { \
|
||||||
|
std::cerr << "Error in file " << __FILE__ << " on line " \
|
||||||
|
<< __LINE__ << ": " << cuinferGetErrorString(status) \
|
||||||
|
<< std::endl; \
|
||||||
|
throw std::runtime_error("CUINFER_CHECK ERROR"); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
namespace iluvatar {
|
namespace iluvatar {
|
||||||
|
|
||||||
class IluvatarContext {
|
class IluvatarContext {
|
||||||
public:
|
public:
|
||||||
IluvatarContext() = default;
|
IluvatarContext() = default;
|
||||||
~IluvatarContext();
|
~IluvatarContext();
|
||||||
|
|
||||||
cuinferHandle_t getIxInferHandle();
|
cuinferHandle_t getIxInferHandle();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
cuinferHandle_t ixinfer_handle_{nullptr};
|
cuinferHandle_t ixinfer_handle_{nullptr};
|
||||||
};
|
};
|
||||||
IluvatarContext* getContextInstance();
|
IluvatarContext* getContextInstance();
|
||||||
|
|
||||||
|
200
custom_ops/iluvatar_ops/w8a16_group_gemm.cu
Normal file
200
custom_ops/iluvatar_ops/w8a16_group_gemm.cu
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "helper.h"
|
||||||
|
#include "iluvatar_context.h"
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> GroupGemm(const paddle::Tensor& x,
|
||||||
|
const paddle::Tensor& weight,
|
||||||
|
const paddle::Tensor& weight_scale,
|
||||||
|
const paddle::Tensor& prefix_sum,
|
||||||
|
const int32_t group_size) {
|
||||||
|
auto dev_ctx = static_cast<const phi::CustomContext*>(
|
||||||
|
paddle::experimental::DeviceContextPool::Instance().Get(x.place()));
|
||||||
|
auto stream = static_cast<const cudaStream_t>(dev_ctx->stream());
|
||||||
|
const auto& x_dims = x.dims();
|
||||||
|
const auto& w_dims = weight.dims();
|
||||||
|
const auto& ws_dims = weight_scale.dims();
|
||||||
|
const auto& prefix_sum_dims = prefix_sum.dims();
|
||||||
|
// [m, k]
|
||||||
|
PD_CHECK(x_dims.size() == 2, "x should be 2D");
|
||||||
|
// [n_experts, n, k]
|
||||||
|
PD_CHECK(w_dims.size() == 3, "weight should be 3D");
|
||||||
|
// [n_experts, n]
|
||||||
|
PD_CHECK(ws_dims.size() == 2, "weight_scale should be 2D");
|
||||||
|
// [n_experts]
|
||||||
|
PD_CHECK(prefix_sum_dims.size() == 1, "prefix_sum should be 1D");
|
||||||
|
PD_CHECK(group_size == -1);
|
||||||
|
auto m = x_dims[0];
|
||||||
|
auto k = x_dims[1];
|
||||||
|
auto n_experts = w_dims[0];
|
||||||
|
auto n = w_dims[1];
|
||||||
|
PD_CHECK(w_dims[2] == k);
|
||||||
|
PD_CHECK(ws_dims[0] == n_experts);
|
||||||
|
PD_CHECK(ws_dims[1] == n);
|
||||||
|
PD_CHECK(prefix_sum_dims[0] == n_experts);
|
||||||
|
|
||||||
|
PD_CHECK(prefix_sum.dtype() == paddle::DataType::INT64);
|
||||||
|
PD_CHECK(prefix_sum.is_cpu());
|
||||||
|
PD_CHECK(x.dtype() == paddle::DataType::BFLOAT16 ||
|
||||||
|
x.dtype() == paddle::DataType::FLOAT16);
|
||||||
|
PD_CHECK(weight.dtype() == paddle::DataType::INT8);
|
||||||
|
PD_CHECK(weight_scale.dtype() == x.dtype());
|
||||||
|
PD_CHECK(x.is_contiguous());
|
||||||
|
PD_CHECK(weight.is_contiguous());
|
||||||
|
PD_CHECK(weight_scale.is_contiguous());
|
||||||
|
|
||||||
|
const int64_t* prefix_sum_ptr = prefix_sum.data<int64_t>();
|
||||||
|
auto output = GetEmptyTensor({m, n}, x.dtype(), x.place());
|
||||||
|
int16_t* out_data = static_cast<int16_t*>(output.data());
|
||||||
|
const int16_t* x_data = static_cast<const int16_t*>(x.data());
|
||||||
|
const int8_t* weight_data = weight.data<int8_t>();
|
||||||
|
const int16_t* weight_scale_data =
|
||||||
|
static_cast<const int16_t*>(weight_scale.data());
|
||||||
|
|
||||||
|
cuinferHandle_t handle = iluvatar::getContextInstance()->getIxInferHandle();
|
||||||
|
cuinferPointerMode_t cuinfer_ptr_mode = CUINFER_POINTER_MODE_HOST;
|
||||||
|
cuinferOperation_t transa = CUINFER_OP_T;
|
||||||
|
cuinferOperation_t transb = CUINFER_OP_N;
|
||||||
|
cudaDataType_t a_type = CUDA_R_8I;
|
||||||
|
cudaDataType_t b_type;
|
||||||
|
cudaDataType_t c_type;
|
||||||
|
if (x.dtype() == paddle::DataType::FLOAT16) {
|
||||||
|
b_type = CUDA_R_16F;
|
||||||
|
} else if (x.dtype() == paddle::DataType::BFLOAT16) {
|
||||||
|
b_type = CUDA_R_16BF;
|
||||||
|
} else {
|
||||||
|
PADDLE_THROW(common::errors::Unimplemented("Unsupported input dtype."));
|
||||||
|
}
|
||||||
|
c_type = b_type;
|
||||||
|
cudaDataType_t Atype = a_type;
|
||||||
|
cudaDataType_t Btype = b_type;
|
||||||
|
cudaDataType_t Ctype = c_type;
|
||||||
|
cudaDataType_t computeType = CUDA_R_32F;
|
||||||
|
cudaDataType_t scaleType = CUDA_R_32F;
|
||||||
|
cuinferGEMMCustomOption_t customOption = CUINFER_BLAS_GEMM_CUSTOM_NONE;
|
||||||
|
|
||||||
|
cuinferQuantGEMMHostParam cust_host_param;
|
||||||
|
cust_host_param.size = sizeof(cuinferQuantGEMMHostParam);
|
||||||
|
cust_host_param.persistent = 0;
|
||||||
|
cust_host_param.groupSize = group_size;
|
||||||
|
cuinferQuantGEMMDeviceParam cust_device_param;
|
||||||
|
cust_device_param.bias = nullptr;
|
||||||
|
cust_device_param.workspace = nullptr;
|
||||||
|
|
||||||
|
int lda = k;
|
||||||
|
int ldb = k;
|
||||||
|
int ldc = n;
|
||||||
|
float beta = 0.f;
|
||||||
|
float alpha = 1.f;
|
||||||
|
int batch_count = 1;
|
||||||
|
size_t pre = 0;
|
||||||
|
|
||||||
|
auto* allocator = paddle::GetAllocator(x.place());
|
||||||
|
phi::Allocator::AllocationPtr tmp_workspace;
|
||||||
|
for (int i = 0; i < n_experts; i++) {
|
||||||
|
size_t expert_i_end = prefix_sum_ptr[i];
|
||||||
|
size_t cur_len = expert_i_end - pre;
|
||||||
|
pre = expert_i_end;
|
||||||
|
if (cur_len != 0) {
|
||||||
|
cust_device_param.scale = weight_scale_data;
|
||||||
|
|
||||||
|
if (k % 64 != 0) {
|
||||||
|
size_t workspace_size;
|
||||||
|
CUINFER_CHECK(cuinferGetCustomGemmWorkspace(transa,
|
||||||
|
transb,
|
||||||
|
n,
|
||||||
|
cur_len,
|
||||||
|
k,
|
||||||
|
Atype,
|
||||||
|
lda,
|
||||||
|
lda,
|
||||||
|
Btype,
|
||||||
|
ldb,
|
||||||
|
ldb,
|
||||||
|
Ctype,
|
||||||
|
ldc,
|
||||||
|
ldc,
|
||||||
|
batch_count,
|
||||||
|
computeType,
|
||||||
|
scaleType,
|
||||||
|
&workspace_size));
|
||||||
|
tmp_workspace = allocator->Allocate(workspace_size);
|
||||||
|
cust_device_param.workspace = tmp_workspace->ptr();
|
||||||
|
} else {
|
||||||
|
cust_device_param.workspace = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUINFER_CHECK(cuinferCustomGemm(handle,
|
||||||
|
stream,
|
||||||
|
cuinfer_ptr_mode,
|
||||||
|
transa,
|
||||||
|
transb,
|
||||||
|
n,
|
||||||
|
cur_len,
|
||||||
|
k,
|
||||||
|
&alpha,
|
||||||
|
weight_data,
|
||||||
|
Atype,
|
||||||
|
lda,
|
||||||
|
lda,
|
||||||
|
x_data,
|
||||||
|
Btype,
|
||||||
|
ldb,
|
||||||
|
ldb,
|
||||||
|
&beta,
|
||||||
|
out_data,
|
||||||
|
Ctype,
|
||||||
|
ldc,
|
||||||
|
ldc,
|
||||||
|
batch_count,
|
||||||
|
computeType,
|
||||||
|
scaleType,
|
||||||
|
&cust_host_param,
|
||||||
|
&cust_device_param,
|
||||||
|
customOption));
|
||||||
|
}
|
||||||
|
x_data += cur_len * k;
|
||||||
|
weight_data += k * n;
|
||||||
|
weight_scale_data += n;
|
||||||
|
out_data += cur_len * n;
|
||||||
|
}
|
||||||
|
return {output};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<int64_t>> GroupGemmInferShape(
|
||||||
|
const std::vector<int64_t>& x_shape,
|
||||||
|
const std::vector<int64_t>& weight_shape,
|
||||||
|
const std::vector<int64_t>& weight_scale_shape,
|
||||||
|
const std::vector<int64_t>& prefix_sum_shape) {
|
||||||
|
return {{x_shape[0], weight_shape[1]}};
|
||||||
|
}
|
||||||
|
std::vector<paddle::DataType> GroupGemmInferDtype(
|
||||||
|
const paddle::DataType& input_dtype,
|
||||||
|
const paddle::DataType& weight_output_dtype,
|
||||||
|
const paddle::DataType& weight_scale_dtype,
|
||||||
|
const paddle::DataType& prefix_sum_dtype,
|
||||||
|
const int moe_topk) {
|
||||||
|
return {input_dtype};
|
||||||
|
}
|
||||||
|
|
||||||
|
PD_BUILD_STATIC_OP(w8a16_group_gemm)
|
||||||
|
.Inputs({"x", "weight", "weight_scale", "prefix_sum"})
|
||||||
|
.Outputs({"output"})
|
||||||
|
.Attrs({
|
||||||
|
"group_size:int",
|
||||||
|
})
|
||||||
|
.SetKernelFn(PD_KERNEL(GroupGemm))
|
||||||
|
.SetInferShapeFn(PD_INFER_SHAPE(GroupGemmInferShape))
|
||||||
|
.SetInferDtypeFn(PD_INFER_DTYPE(GroupGemmInferDtype));
|
@@ -539,9 +539,12 @@ elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
|||||||
"gpu_ops/stop_generation_multi_ends.cu",
|
"gpu_ops/stop_generation_multi_ends.cu",
|
||||||
"gpu_ops/step.cu",
|
"gpu_ops/step.cu",
|
||||||
"gpu_ops/token_penalty_multi_scores.cu",
|
"gpu_ops/token_penalty_multi_scores.cu",
|
||||||
|
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
|
||||||
|
"gpu_ops/sample_kernels/top_k_renorm_probs.cu",
|
||||||
"iluvatar_ops/moe_dispatch.cu",
|
"iluvatar_ops/moe_dispatch.cu",
|
||||||
"iluvatar_ops/moe_reduce.cu",
|
"iluvatar_ops/moe_reduce.cu",
|
||||||
"iluvatar_ops/paged_attn.cu",
|
"iluvatar_ops/paged_attn.cu",
|
||||||
|
"iluvatar_ops/w8a16_group_gemm.cu",
|
||||||
"iluvatar_ops/runtime/iluvatar_context.cc",
|
"iluvatar_ops/runtime/iluvatar_context.cc",
|
||||||
],
|
],
|
||||||
include_dirs=["iluvatar_ops/runtime", "gpu_ops"],
|
include_dirs=["iluvatar_ops/runtime", "gpu_ops"],
|
||||||
|
@@ -1,12 +1,12 @@
|
|||||||
# Run ERNIE-4.5-300B-A47B & ERNIE-4.5-21B-A3B model on iluvatar machine
|
# Run ERNIE-4.5-300B-A47B & ERNIE-4.5-21B-A3B model on iluvatar machine
|
||||||
The current version of the software merely serves as a demonstration demo for the Iluvatar CoreX combined with the Fastdeploy inference framework for large models. There may be issues when running the latest ERNIE4.5 model, and we will conduct repairs and performance optimization in the future. Subsequent versions will provide customers with a more stable version.
|
The current version of the software merely serves as a demonstration demo for the Iluvatar CoreX combined with the Fastdeploy inference framework for large models. Running the latest ERNIE4.5 300B model on the GSM8K dataset takes about 6.3 hours.
|
||||||
|
|
||||||
## Machine Preparation
|
## Machine Preparation
|
||||||
First, you need to prepare a machine with the following configurations:
|
First, the `TP=16` when running the ERNIE4.5 300B model and so you need to prepare a machine with the following configurations:
|
||||||
|
|
||||||
| CPU | Memory | Card | Hard Disk|
|
| CPU | Memory | Card | Hard Disk|
|
||||||
| :---: | :---: | :---: | :---: |
|
| :---: | :---: | :---: | :---: |
|
||||||
| x86 | 1TB| 8xBI150| 1TB|
|
| x86 | 1TB| 16xBI150| 1TB|
|
||||||
|
|
||||||
Currently, the entire model needs to be loaded into the host memory, which requires more than 600GB of host memory. This issue will be optimized in subsequent versions.
|
Currently, the entire model needs to be loaded into the host memory, which requires more than 600GB of host memory. This issue will be optimized in subsequent versions.
|
||||||
|
|
||||||
@@ -46,6 +46,7 @@ script list below:
|
|||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
export INFERENCE_MSG_QUEUE_ID=232132
|
||||||
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
|
export FD_SAMPLING_CLASS=rejection
|
||||||
export FD_DEBUG=1
|
export FD_DEBUG=1
|
||||||
python3 run_demo.py
|
python3 run_demo.py
|
||||||
```
|
```
|
||||||
@@ -64,7 +65,7 @@ prompts = [
|
|||||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)
|
||||||
|
|
||||||
# load the model
|
# load the model
|
||||||
llm = LLM(model="/home/paddle/ernie-4_5-21b-a3b-bf16-paddle", tensor_parallel_size=4, max_model_len=8192, static_decode_blocks=0, quantization='wint8')
|
llm = LLM(model="/home/paddle/ernie-4_5-21b-a3b-bf16-paddle", tensor_parallel_size=4, max_model_len=8192, static_decode_blocks=0, block_size=16, quantization='wint8')
|
||||||
|
|
||||||
# Perform batch inference
|
# Perform batch inference
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
@@ -118,3 +119,281 @@ Now, let's break down each step:
|
|||||||
**Step 3: Drawing the
|
**Step 3: Drawing the
|
||||||
The largest ocean is the Pacific Ocean, covering an area of approximately ⦠[3], The first scientific expeditions to determine the ocean's depth were the Challenger expedition (1872â1876) and the U.S. Navy Hydrographic Office survey (1877â1879). The oceanic crust is thin and irregular, consisting of upward moving magma from the mantle below, and cooling and solidifying on the surface. The shallowest parts of the ocean are called the continental shelves. Large tides are caused mainly by the alignment of the Sun, Moon, and Earth during new or full moons. The origin of the word "ocean" is not clear. The first global oceanic topography survey was completed by the Challenger expedition (1872â1876). [57] The sound speed in the ocean is primarily a function of water temperature and salinity, and varies with depth. The deep-ocean floor is mostly flat and devoid of life, with the exception of seamounts and various underwater volcanic features, including seamounts and hydrothermal vents. [73] Today, the five ocean
|
The largest ocean is the Pacific Ocean, covering an area of approximately ⦠[3], The first scientific expeditions to determine the ocean's depth were the Challenger expedition (1872â1876) and the U.S. Navy Hydrographic Office survey (1877â1879). The oceanic crust is thin and irregular, consisting of upward moving magma from the mantle below, and cooling and solidifying on the surface. The shallowest parts of the ocean are called the continental shelves. Large tides are caused mainly by the alignment of the Sun, Moon, and Earth during new or full moons. The origin of the word "ocean" is not clear. The first global oceanic topography survey was completed by the Challenger expedition (1872â1876). [57] The sound speed in the ocean is primarily a function of water temperature and salinity, and varies with depth. The deep-ocean floor is mostly flat and devoid of life, with the exception of seamounts and various underwater volcanic features, including seamounts and hydrothermal vents. [73] Today, the five ocean
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Run ernie4.5 300B model with the GSM8K dataset
|
||||||
|
|
||||||
|
1. Download GSM8K dataset
|
||||||
|
|
||||||
|
```bash
|
||||||
|
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Prepare `bench_gsm8k.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
""" Fastdeploy + ERNIE-4.5-Turbo 的指标评估 """
|
||||||
|
# adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py
|
||||||
|
import argparse
|
||||||
|
import ast
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
INVALID = -9999999
|
||||||
|
|
||||||
|
|
||||||
|
def call_generate(prompt, **kwargs):
|
||||||
|
"""
|
||||||
|
Generates response based on the input prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The input prompt text.
|
||||||
|
**kwargs: Keyword arguments, including server IP address and port number.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The response generated based on the prompt.
|
||||||
|
|
||||||
|
"""
|
||||||
|
url = f"http://{kwargs['ip']}:{kwargs['port']}/v1/chat/completions"
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
data = {
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"temperature": 0.6,
|
||||||
|
"max_tokens": 2047,
|
||||||
|
"top_p": 0.95,
|
||||||
|
"do_sample": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(url, headers=headers, data=json.dumps(data))
|
||||||
|
out = response.json()
|
||||||
|
return out["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_one_example(lines, i, include_answer):
|
||||||
|
"""
|
||||||
|
Retrieves a question-answer example from the given list of text lines.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lines (list of dict): A list of question-answer pairs.
|
||||||
|
i (int): The index of the question-answer pair to retrieve from lines.
|
||||||
|
include_answer (bool): Whether to include the answer in the returned string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A formatted question-answer string in the format "Question: <question>\nAnswer: <answer>".
|
||||||
|
|
||||||
|
"""
|
||||||
|
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
|
||||||
|
if include_answer:
|
||||||
|
ret += " " + lines[i]["answer"]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_few_shot_examples(lines, k):
|
||||||
|
"""
|
||||||
|
Selects k examples from the given list of text lines and concatenates them into a single string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lines (list): A list containing text lines.
|
||||||
|
k (int): The number of examples to select.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A string composed of k examples, separated by two newline characters.
|
||||||
|
"""
|
||||||
|
ret = ""
|
||||||
|
for i in range(k):
|
||||||
|
ret += get_one_example(lines, i, True) + "\n\n"
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_answer_value(answer_str):
|
||||||
|
"""
|
||||||
|
Extracts numerical values from an answer string and returns them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
answer_str (str): The string containing the answer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The extracted numerical value; returns "INVALID" if extraction fails.
|
||||||
|
"""
|
||||||
|
answer_str = answer_str.replace(",", "")
|
||||||
|
numbers = re.findall(r"\d+", answer_str)
|
||||||
|
if len(numbers) < 1:
|
||||||
|
return INVALID
|
||||||
|
try:
|
||||||
|
return ast.literal_eval(numbers[-1])
|
||||||
|
except SyntaxError:
|
||||||
|
return INVALID
|
||||||
|
|
||||||
|
|
||||||
|
def read_jsonl(filename: str):
|
||||||
|
"""
|
||||||
|
Reads a JSONL file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename (str): Path to the JSONL file.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
dict: A dictionary object corresponding to each line in the JSONL file.
|
||||||
|
"""
|
||||||
|
with open(filename) as fin:
|
||||||
|
for line in fin:
|
||||||
|
if line.startswith("#"):
|
||||||
|
continue
|
||||||
|
yield json.loads(line)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
"""
|
||||||
|
Process inputs and generate answers by calling the model in parallel using a thread pool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args (argparse.Namespace):
|
||||||
|
- num_questions (int): Number of questions to process.
|
||||||
|
- num_shots (int): Number of few-shot learning examples.
|
||||||
|
- ip (str): IP address of the model service.
|
||||||
|
- port (int): Port number of the model service.
|
||||||
|
- parallel (int): Number of questions to process in parallel.
|
||||||
|
- result_file (str): File path to store the results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Read data
|
||||||
|
filename = "test.jsonl"
|
||||||
|
|
||||||
|
lines = list(read_jsonl(filename))
|
||||||
|
|
||||||
|
# Construct prompts
|
||||||
|
num_questions = args.num_questions
|
||||||
|
num_shots = args.num_shots
|
||||||
|
few_shot_examples = get_few_shot_examples(lines, num_shots)
|
||||||
|
|
||||||
|
questions = []
|
||||||
|
labels = []
|
||||||
|
for i in range(len(lines[:num_questions])):
|
||||||
|
questions.append(get_one_example(lines, i, False))
|
||||||
|
labels.append(get_answer_value(lines[i]["answer"]))
|
||||||
|
assert all(l != INVALID for l in labels)
|
||||||
|
|
||||||
|
states = [None] * len(labels)
|
||||||
|
|
||||||
|
# Use thread pool
|
||||||
|
def get_one_answer(i):
|
||||||
|
answer = call_generate(
|
||||||
|
prompt=few_shot_examples + questions[i],
|
||||||
|
# stop=["Question", "Assistant:", "<|separator|>"],
|
||||||
|
ip=args.ip,
|
||||||
|
port=args.port,
|
||||||
|
)
|
||||||
|
states[i] = answer
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
if args.parallel == 1:
|
||||||
|
for i in tqdm(range(len(questions))):
|
||||||
|
get_one_answer(i)
|
||||||
|
else:
|
||||||
|
with ThreadPoolExecutor(args.parallel) as executor:
|
||||||
|
list(
|
||||||
|
tqdm(
|
||||||
|
executor.map(get_one_answer, list(range(len(questions)))),
|
||||||
|
total=len(questions),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
latency = time.time() - tic
|
||||||
|
preds = []
|
||||||
|
for i in range(len(states)):
|
||||||
|
preds.append(get_answer_value(states[i]))
|
||||||
|
|
||||||
|
# Compute accuracy
|
||||||
|
acc = np.mean(np.array(preds) == np.array(labels))
|
||||||
|
invalid = np.mean(np.array(preds) == INVALID)
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print(f"Accuracy: {acc:.3f}")
|
||||||
|
print(f"Invalid: {invalid:.3f}")
|
||||||
|
print(f"Latency: {latency:.3f} s")
|
||||||
|
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "gsm8k",
|
||||||
|
"backend": "paddlepaddle",
|
||||||
|
"num_gpus": 1,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"accuracy": round(acc, 3),
|
||||||
|
"num_requests": args.num_questions,
|
||||||
|
"other": {
|
||||||
|
"num_questions": args.num_questions,
|
||||||
|
"parallel": args.parallel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--ip", type=str, default="127.0.0.1")
|
||||||
|
parser.add_argument("--port", type=str, default="8188")
|
||||||
|
parser.add_argument("--num-shots", type=int, default=10)
|
||||||
|
parser.add_argument("--data-path", type=str, default="test.jsonl")
|
||||||
|
parser.add_argument("--num-questions", type=int, default=1319)
|
||||||
|
parser.add_argument("--result-file", type=str, default="result.jsonl")
|
||||||
|
parser.add_argument("--parallel", type=int, default=1)
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Prepare `run_bench.sh`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
#!/bin/bash
|
||||||
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
|
export INFERENCE_MSG_QUEUE_ID=232132
|
||||||
|
export LD_PRELOAD=/usr/local/corex/lib64/libcuda.so.1
|
||||||
|
export FD_SAMPLING_CLASS=rejection
|
||||||
|
|
||||||
|
python3 -m fastdeploy.entrypoints.openai.api_server --model "/home/paddle/ernie-45t" --port 8188 --tensor-parallel-size 16 --block-size 16 --static-decode-blocks 0 --quantization wint8
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Running the Script
|
||||||
|
|
||||||
|
Firstly, open a terminal and run:
|
||||||
|
```bash
|
||||||
|
./run_bench.sh
|
||||||
|
```
|
||||||
|
After the service is ready, open another terminal and run:
|
||||||
|
```bash
|
||||||
|
python3 -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 8
|
||||||
|
```
|
||||||
|
It takes about 6.3 hours to run the GSM8K dataset.
|
||||||
|
|
||||||
|
```
|
||||||
|
Accuracy: 0.964
|
||||||
|
Invaild: 0.000
|
||||||
|
Latency: 22918.186 s
|
||||||
|
```
|
||||||
|
@@ -27,6 +27,7 @@ from paddleformers.transformers.configuration_utils import PretrainedConfig
|
|||||||
import fastdeploy
|
import fastdeploy
|
||||||
from fastdeploy import envs
|
from fastdeploy import envs
|
||||||
from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
|
from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
|
||||||
|
from fastdeploy.platforms import current_platform
|
||||||
from fastdeploy.utils import check_unified_ckpt, get_logger
|
from fastdeploy.utils import check_unified_ckpt, get_logger
|
||||||
|
|
||||||
logger = get_logger("config", "config.log")
|
logger = get_logger("config", "config.log")
|
||||||
@@ -733,7 +734,7 @@ class CacheConfig:
|
|||||||
self.gpu_memory_utilization = 0.9
|
self.gpu_memory_utilization = 0.9
|
||||||
self.num_gpu_blocks_override = None
|
self.num_gpu_blocks_override = None
|
||||||
self.kv_cache_ratio = 0.75
|
self.kv_cache_ratio = 0.75
|
||||||
self.enc_dec_block_num = 2
|
self.enc_dec_block_num = 0 if current_platform.is_iluvatar() else 2
|
||||||
self.prealloc_dec_block_slot_num_threshold = 5
|
self.prealloc_dec_block_slot_num_threshold = 5
|
||||||
self.cache_dtype = "bfloat16"
|
self.cache_dtype = "bfloat16"
|
||||||
self.model_cfg = None
|
self.model_cfg = None
|
||||||
|
@@ -961,7 +961,10 @@ class LLMEngine:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.do_profile:
|
if self.do_profile:
|
||||||
get_profile_block_num = np.zeros([1], dtype=np.int32)
|
if paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||||
|
get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32)
|
||||||
|
else:
|
||||||
|
get_profile_block_num = np.zeros([1], dtype=np.int32)
|
||||||
self.get_profile_block_num_signal = IPCSignal(
|
self.get_profile_block_num_signal = IPCSignal(
|
||||||
name="get_profile_block_num",
|
name="get_profile_block_num",
|
||||||
array=get_profile_block_num,
|
array=get_profile_block_num,
|
||||||
|
@@ -85,45 +85,120 @@ class IluvatarAttnBackend(AttentionBackend):
|
|||||||
Which is used only for testing purpose.
|
Which is used only for testing purpose.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int):
|
||||||
self,
|
|
||||||
llm_config: FDConfig,
|
|
||||||
kv_num_heads: int,
|
|
||||||
num_heads: int,
|
|
||||||
head_dim: int,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.attention_metadata = IluvatarAttentionMetadata()
|
self.attention_metadata = IluvatarAttentionMetadata()
|
||||||
self.attention_metadata.block_size = llm_config.cache_config.block_size
|
self.attention_metadata.block_size = fd_config.parallel_config.block_size
|
||||||
assert llm_config.cache_config.enc_dec_block_num == 0, "Iluvatar does not support yet"
|
assert (
|
||||||
|
fd_config.parallel_config.enc_dec_block_num == 0
|
||||||
|
), f"Iluvatar does not support yet, {fd_config.parallel_config.enc_dec_block_num}"
|
||||||
|
assert self.attention_metadata.block_size == 16, "Iluvatar paged attn requires block_size must be 16."
|
||||||
|
|
||||||
self.attention_metadata.max_context_len = llm_config.parallel_config.max_model_len
|
self.attention_metadata.max_context_len = fd_config.parallel_config.max_model_len
|
||||||
self.attention_metadata.causal = getattr(llm_config.model_config, "causal", True)
|
self.attention_metadata.causal = getattr(fd_config.model_config, "causal", True)
|
||||||
self.speculate_method = getattr(llm_config.parallel_config, "speculate_method", None)
|
self.speculate_method = getattr(fd_config.parallel_config, "speculate_method", None)
|
||||||
self.use_speculate = self.speculate_method is not None
|
self.use_speculate = self.speculate_method is not None
|
||||||
self.attention_metadata.num_kv_heads = kv_num_heads
|
self.attention_metadata.num_kv_heads = kv_num_heads
|
||||||
self.attention_metadata.dropout = llm_config.model_config.hidden_dropout_prob
|
self.attention_metadata.dropout = fd_config.model_config.hidden_dropout_prob
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
|
self.total_num_heads = num_heads + 2 * kv_num_heads
|
||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
|
self.hidden_dim = num_heads * head_dim
|
||||||
|
self.total_hidden_dim = self.total_num_heads * head_dim
|
||||||
# note: scale need to change if using MLA
|
# note: scale need to change if using MLA
|
||||||
self.attention_metadata.scale = 1.0 / sqrt(head_dim)
|
self.attention_metadata.scale = 1.0 / sqrt(head_dim)
|
||||||
self.num_layers = llm_config.model_config.num_hidden_layers
|
self.num_layers = fd_config.model_config.num_hidden_layers
|
||||||
|
self.dtype = paddle.get_default_dtype()
|
||||||
|
|
||||||
self.record_block_table_metadata = {}
|
self.record_block_table_metadata = {}
|
||||||
self.only_use_flash_attn = int(os.getenv("FD_ILUVATAR_ONLY_USE_FLASH_ATTN", 0)) == 1
|
self.enable_fused_attention = int(os.getenv("FD_ILUVATAR_ENABLE_FUSED_ATTN", 1))
|
||||||
self.do_check_kv_cache = int(os.getenv("FD_ILUVATAR_CHECK_KV_CACHE_CORRECTNESS", 0)) == 1
|
|
||||||
if not self.only_use_flash_attn:
|
|
||||||
assert self.attention_metadata.block_size == 16, "Iluvatar paged attn requires block_size must be 16."
|
|
||||||
if self.do_check_kv_cache:
|
|
||||||
self.record_batched_k = [{} for _ in range(self.num_layers)]
|
|
||||||
self.record_batched_v = [{} for _ in range(self.num_layers)]
|
|
||||||
|
|
||||||
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||||
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
|
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
|
||||||
self.attention_metadata.block_tables = forward_meta.block_tables
|
self.prefill_info_dict = {}
|
||||||
self.attention_metadata.attn_mask = forward_meta.attn_mask
|
self.decode_info_dict = {}
|
||||||
self.attention_metadata.seq_lens = forward_meta.seq_lens_decoder
|
|
||||||
self.attention_metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
|
prefill_non_zeros_ids = forward_meta.seq_lens_this_time > 1
|
||||||
self.attention_metadata.cu_seqlens_k = forward_meta.cu_seqlens_k
|
decode_non_zeros_ids = forward_meta.seq_lens_this_time == 1
|
||||||
|
self.prefill_info_dict["batch_ids"] = paddle.where(prefill_non_zeros_ids)[0]
|
||||||
|
self.decode_info_dict["batch_ids"] = paddle.where(decode_non_zeros_ids)[0]
|
||||||
|
|
||||||
|
self.prefill_len = len(self.prefill_info_dict["batch_ids"])
|
||||||
|
self.decode_len = len(self.decode_info_dict["batch_ids"])
|
||||||
|
# only prefill
|
||||||
|
if self.decode_len == 0:
|
||||||
|
cu_seq_ids = list(range(self.prefill_len + 1))
|
||||||
|
self.prefill_info_dict["cu_seqlens_q"] = forward_meta.cu_seqlens_q[cu_seq_ids]
|
||||||
|
# only decode
|
||||||
|
elif self.prefill_len == 0:
|
||||||
|
pass
|
||||||
|
# both prefill and decode
|
||||||
|
else:
|
||||||
|
prefill_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[prefill_non_zeros_ids])
|
||||||
|
decode_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[decode_non_zeros_ids])
|
||||||
|
|
||||||
|
self.prefill_info_dict["cu_seqlens_q"] = paddle.zeros(
|
||||||
|
[self.prefill_len + 1], dtype=forward_meta.cu_seqlens_q.dtype
|
||||||
|
)
|
||||||
|
self.prefill_info_dict["cu_seqlens_q"][1:] = forward_meta.seq_lens_encoder[
|
||||||
|
self.prefill_info_dict["batch_ids"], 0
|
||||||
|
]
|
||||||
|
self.prefill_info_dict["cu_seqlens_q"] = paddle.cumsum(self.prefill_info_dict["cu_seqlens_q"])
|
||||||
|
|
||||||
|
self.prefill_qkv = paddle.zeros([prefill_num_tokens, self.total_hidden_dim], dtype=self.dtype)
|
||||||
|
self.decode_qkv = paddle.zeros([decode_num_tokens, self.total_hidden_dim], dtype=self.dtype)
|
||||||
|
self.merged_output = paddle.zeros(
|
||||||
|
[prefill_num_tokens + decode_num_tokens, self.num_heads, self.head_dim], dtype=self.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
prefill_start, decode_start, start = 0, 0, 0
|
||||||
|
non_zeros_ids = forward_meta.seq_lens_this_time != 0
|
||||||
|
non_zeros_seq_lens = forward_meta.seq_lens_this_time[non_zeros_ids]
|
||||||
|
end = non_zeros_seq_lens[0]
|
||||||
|
if end > 1:
|
||||||
|
last_stage = "prefill"
|
||||||
|
prefill_end = end
|
||||||
|
decode_end = 0
|
||||||
|
else:
|
||||||
|
last_stage = "decode"
|
||||||
|
prefill_end = 0
|
||||||
|
decode_end = end
|
||||||
|
|
||||||
|
self.prefill_info_dict["id_group"] = []
|
||||||
|
self.prefill_info_dict["reverse_id_group"] = []
|
||||||
|
self.decode_info_dict["id_group"] = []
|
||||||
|
self.decode_info_dict["reverse_id_group"] = []
|
||||||
|
self.record_stages = []
|
||||||
|
for seq_len in non_zeros_seq_lens[1:]:
|
||||||
|
if seq_len > 1:
|
||||||
|
if last_stage == "decode":
|
||||||
|
self.record_stages.append((last_stage, len(self.decode_info_dict["id_group"])))
|
||||||
|
self.decode_info_dict["id_group"].append((decode_start, decode_end))
|
||||||
|
self.decode_info_dict["reverse_id_group"].append((start, end))
|
||||||
|
decode_start = decode_end
|
||||||
|
start = end
|
||||||
|
last_stage = "prefill"
|
||||||
|
prefill_end += seq_len
|
||||||
|
end += seq_len
|
||||||
|
else:
|
||||||
|
if last_stage == "prefill":
|
||||||
|
self.record_stages.append((last_stage, len(self.prefill_info_dict["id_group"])))
|
||||||
|
self.prefill_info_dict["id_group"].append((prefill_start, prefill_end))
|
||||||
|
self.prefill_info_dict["reverse_id_group"].append((start, end))
|
||||||
|
prefill_start = prefill_end
|
||||||
|
start = end
|
||||||
|
last_stage = "decode"
|
||||||
|
decode_end += seq_len
|
||||||
|
end += seq_len
|
||||||
|
|
||||||
|
if prefill_start < prefill_end:
|
||||||
|
self.record_stages.append(("prefill", len(self.prefill_info_dict["id_group"])))
|
||||||
|
self.prefill_info_dict["id_group"].append((prefill_start, prefill_end))
|
||||||
|
self.prefill_info_dict["reverse_id_group"].append((start, end))
|
||||||
|
if decode_start < decode_end:
|
||||||
|
self.record_stages.append(("decode", len(self.decode_info_dict["id_group"])))
|
||||||
|
self.decode_info_dict["id_group"].append((decode_start, decode_end))
|
||||||
|
self.decode_info_dict["reverse_id_group"].append((start, end))
|
||||||
|
|
||||||
def get_attntion_meta(self):
|
def get_attntion_meta(self):
|
||||||
"""get_attntion_meta"""
|
"""get_attntion_meta"""
|
||||||
@@ -144,93 +219,15 @@ class IluvatarAttnBackend(AttentionBackend):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_new_kv(
|
def prefill_update_kv_cache(
|
||||||
self,
|
self, k, v, k_cache_id: int, v_cache_id: int, layer_id: int, forward_meta: ForwardMeta, prefill_batch_ids: list
|
||||||
k,
|
|
||||||
v,
|
|
||||||
k_cache_id: int,
|
|
||||||
v_cache_id: int,
|
|
||||||
forward_meta: ForwardMeta,
|
|
||||||
debug_paged_attn=False,
|
|
||||||
):
|
|
||||||
new_k = []
|
|
||||||
new_v = []
|
|
||||||
tensor_start = 0
|
|
||||||
for batch_idx in range(forward_meta.block_tables.shape[0]):
|
|
||||||
seq_len = forward_meta.seq_lens_this_time[batch_idx]
|
|
||||||
if seq_len == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
tensor_end = tensor_start + seq_len
|
|
||||||
slice_k = k[tensor_start:tensor_end, :, :]
|
|
||||||
slice_v = v[tensor_start:tensor_end, :, :]
|
|
||||||
|
|
||||||
if seq_len > 1:
|
|
||||||
# prefill
|
|
||||||
new_k.append(slice_k)
|
|
||||||
new_v.append(slice_v)
|
|
||||||
else:
|
|
||||||
# decode
|
|
||||||
assert seq_len == 1
|
|
||||||
cur_block_tables = forward_meta.block_tables[batch_idx]
|
|
||||||
cur_used_block_tables = cur_block_tables[cur_block_tables != -1]
|
|
||||||
assert (
|
|
||||||
batch_idx in self.record_block_table_metadata
|
|
||||||
), f"Key error: {batch_idx} vs {self.record_block_table_metadata}."
|
|
||||||
cur_block_table_metadata = self.record_block_table_metadata[batch_idx]
|
|
||||||
record_last_block_id = cur_block_table_metadata["block_id"]
|
|
||||||
assert record_last_block_id != -1
|
|
||||||
for block_id in cur_used_block_tables:
|
|
||||||
if block_id == record_last_block_id:
|
|
||||||
cache_end = cur_block_table_metadata["cache_end"]
|
|
||||||
block_k_cache = forward_meta.caches[k_cache_id][block_id, :, 0:cache_end, :]
|
|
||||||
block_v_cache = forward_meta.caches[v_cache_id][block_id, :, 0:cache_end, :]
|
|
||||||
else:
|
|
||||||
block_k_cache = forward_meta.caches[k_cache_id][block_id]
|
|
||||||
block_v_cache = forward_meta.caches[v_cache_id][block_id]
|
|
||||||
|
|
||||||
# [num_kv_heads, block_size, head_dim] -> [block_size, num_kv_heads, head_dim]
|
|
||||||
new_k.append(block_k_cache.transpose([1, 0, 2]).contiguous())
|
|
||||||
new_v.append(block_v_cache.transpose([1, 0, 2]).contiguous())
|
|
||||||
if block_id == record_last_block_id:
|
|
||||||
break
|
|
||||||
|
|
||||||
# as line 301 show, record_block_table_metadata updates when executing the last layer,
|
|
||||||
# so slice_k and slice_v has been updated in block_k_cache and block_v_cache
|
|
||||||
if not (debug_paged_attn and (k_cache_id / 2 == self.num_layers - 1)):
|
|
||||||
new_k.append(slice_k)
|
|
||||||
new_v.append(slice_v)
|
|
||||||
|
|
||||||
tensor_start = tensor_end
|
|
||||||
|
|
||||||
if len(new_k) == 1:
|
|
||||||
return new_k[0], new_v[0]
|
|
||||||
else:
|
|
||||||
new_k = paddle.concat(new_k, axis=0)
|
|
||||||
new_v = paddle.concat(new_v, axis=0)
|
|
||||||
return new_k, new_v
|
|
||||||
|
|
||||||
def update_kv_cache(
|
|
||||||
self,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
k_cache_id: int,
|
|
||||||
v_cache_id: int,
|
|
||||||
layer_id: int,
|
|
||||||
forward_meta: ForwardMeta,
|
|
||||||
specific_batch_ids=None,
|
|
||||||
debug_paged_attn=False,
|
|
||||||
):
|
):
|
||||||
# [num_tokens, num_kv_heads, head_dim] -> [num_kv_heads, num_tokens, head_dim]
|
# [num_tokens, num_kv_heads, head_dim] -> [num_kv_heads, num_tokens, head_dim]
|
||||||
trans_k = k.transpose([1, 0, 2]).contiguous()
|
trans_k = k.transpose([1, 0, 2]).contiguous()
|
||||||
trans_v = v.transpose([1, 0, 2]).contiguous()
|
trans_v = v.transpose([1, 0, 2]).contiguous()
|
||||||
tensor_start = 0
|
tensor_start = 0
|
||||||
for batch_idx in range(forward_meta.block_tables.shape[0]):
|
for batch_idx in prefill_batch_ids:
|
||||||
if specific_batch_ids is not None and batch_idx not in specific_batch_ids:
|
|
||||||
continue
|
|
||||||
seq_len = forward_meta.seq_lens_this_time[batch_idx]
|
seq_len = forward_meta.seq_lens_this_time[batch_idx]
|
||||||
if seq_len == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
tensor_end = tensor_start + seq_len
|
tensor_end = tensor_start + seq_len
|
||||||
slice_trans_k = trans_k[:, tensor_start:tensor_end, :]
|
slice_trans_k = trans_k[:, tensor_start:tensor_end, :]
|
||||||
@@ -239,146 +236,67 @@ class IluvatarAttnBackend(AttentionBackend):
|
|||||||
cur_block_tables = forward_meta.block_tables[batch_idx]
|
cur_block_tables = forward_meta.block_tables[batch_idx]
|
||||||
cur_used_block_tables = cur_block_tables[cur_block_tables != -1]
|
cur_used_block_tables = cur_block_tables[cur_block_tables != -1]
|
||||||
|
|
||||||
# prefill
|
cache_start = 0
|
||||||
if seq_len > 1:
|
cur_used_num_blocks = cur_used_block_tables.shape[0]
|
||||||
cache_start = 0
|
for i, block_id in enumerate(cur_used_block_tables):
|
||||||
cur_used_num_blocks = cur_used_block_tables.shape[0]
|
# last block: seq_len - cache_start <= block_size
|
||||||
for i, block_id in enumerate(cur_used_block_tables):
|
if i == cur_used_num_blocks - 1:
|
||||||
# last block: seq_len - cache_start <= block_size
|
cache_end = seq_len - cache_start
|
||||||
if i == cur_used_num_blocks - 1:
|
assert cache_end <= self.attention_metadata.block_size
|
||||||
cache_end = seq_len - cache_start
|
paddle.assign(
|
||||||
assert cache_end <= self.attention_metadata.block_size
|
slice_trans_k[:, cache_start:seq_len, :],
|
||||||
forward_meta.caches[k_cache_id][block_id, :, 0:cache_end, :] = slice_trans_k[
|
output=forward_meta.caches[k_cache_id][block_id, :, 0:cache_end, :],
|
||||||
:, cache_start:seq_len, :
|
)
|
||||||
]
|
paddle.assign(
|
||||||
forward_meta.caches[v_cache_id][block_id, :, 0:cache_end, :] = slice_trans_v[
|
slice_trans_v[:, cache_start:seq_len, :],
|
||||||
:, cache_start:seq_len, :
|
output=forward_meta.caches[v_cache_id][block_id, :, 0:cache_end, :],
|
||||||
]
|
)
|
||||||
if layer_id == self.num_layers - 1:
|
if layer_id == self.num_layers - 1:
|
||||||
self.record_block_table_metadata[batch_idx] = {
|
self.record_block_table_metadata[batch_idx] = {
|
||||||
"block_id": block_id.item(),
|
"block_id": block_id.item(),
|
||||||
"cache_end": cache_end,
|
"cache_end": cache_end.item(),
|
||||||
}
|
}
|
||||||
# non last block: seq_lens_this_time > block_size
|
# non last block: seq_lens_this_time > block_size
|
||||||
else:
|
|
||||||
assert seq_len > self.attention_metadata.block_size
|
|
||||||
cache_end = cache_start + self.attention_metadata.block_size
|
|
||||||
forward_meta.caches[k_cache_id][block_id] = slice_trans_k[:, cache_start:cache_end, :]
|
|
||||||
forward_meta.caches[v_cache_id][block_id] = slice_trans_v[:, cache_start:cache_end, :]
|
|
||||||
cache_start += self.attention_metadata.block_size
|
|
||||||
else:
|
|
||||||
# decode
|
|
||||||
assert seq_len == 1
|
|
||||||
cur_last_block_id = cur_used_block_tables[-1].item()
|
|
||||||
assert cur_last_block_id != -1
|
|
||||||
assert (
|
|
||||||
batch_idx in self.record_block_table_metadata
|
|
||||||
), f"Key error: {batch_idx} vs {self.record_block_table_metadata}."
|
|
||||||
cur_block_table_metadata = self.record_block_table_metadata[batch_idx]
|
|
||||||
record_last_block_id = cur_block_table_metadata["block_id"]
|
|
||||||
|
|
||||||
if cur_last_block_id == record_last_block_id:
|
|
||||||
# not alloc new block in decode stage
|
|
||||||
cache_start = cur_block_table_metadata["cache_end"]
|
|
||||||
else:
|
else:
|
||||||
# alloc new block in decode stage
|
assert seq_len > self.attention_metadata.block_size
|
||||||
cache_start = 0
|
cache_end = cache_start + self.attention_metadata.block_size
|
||||||
|
paddle.assign(
|
||||||
cache_end = cache_start + 1
|
slice_trans_k[:, cache_start:cache_end, :], output=forward_meta.caches[k_cache_id][block_id]
|
||||||
assert cache_end <= self.attention_metadata.block_size
|
)
|
||||||
|
paddle.assign(
|
||||||
# paged attn API will update kv cache with inplace mode
|
slice_trans_v[:, cache_start:cache_end, :], output=forward_meta.caches[v_cache_id][block_id]
|
||||||
if not debug_paged_attn:
|
)
|
||||||
forward_meta.caches[k_cache_id][cur_last_block_id, :, cache_start:cache_end, :] = slice_trans_k
|
cache_start += self.attention_metadata.block_size
|
||||||
forward_meta.caches[v_cache_id][cur_last_block_id, :, cache_start:cache_end, :] = slice_trans_v
|
|
||||||
|
|
||||||
# update record_block_table_metadata
|
|
||||||
if layer_id == self.num_layers - 1:
|
|
||||||
self.record_block_table_metadata[batch_idx]["block_id"] = cur_last_block_id
|
|
||||||
self.record_block_table_metadata[batch_idx]["cache_end"] = cache_end
|
|
||||||
|
|
||||||
tensor_start = tensor_end
|
tensor_start = tensor_end
|
||||||
|
|
||||||
def _check_new_kv_correctness(self, k, v, new_k, new_v, layer_id: int, forward_meta: ForwardMeta):
|
def get_splited_qkv(
|
||||||
tensor_start = 0
|
self, qkv: paddle.Tensor, forward_meta: ForwardMeta, cu_seqlens_q: paddle.Tensor, batch_ids=None
|
||||||
for batch_idx, seq_lens_this_time in enumerate(forward_meta.seq_lens_this_time):
|
):
|
||||||
if seq_lens_this_time == 0:
|
q_end = self.hidden_dim
|
||||||
continue
|
|
||||||
# note: the second request will also use the batch_idx 0 instead of 1 in
|
|
||||||
# the streaming inference mode, so use seq_lens_this_time > 1 with the same
|
|
||||||
# batch_idx represents the second request comes.
|
|
||||||
if seq_lens_this_time > 1 and batch_idx in self.record_batched_k[layer_id]:
|
|
||||||
print(
|
|
||||||
f"clear self.record_batched_batched_k: "
|
|
||||||
f"layer_id={layer_id}, batch_id={batch_idx}, "
|
|
||||||
f"record_lens={len(self.record_batched_k[layer_id][batch_idx])}"
|
|
||||||
)
|
|
||||||
self.record_batched_k[layer_id][batch_idx].clear()
|
|
||||||
self.record_batched_v[layer_id][batch_idx].clear()
|
|
||||||
tensor_end = tensor_start + seq_lens_this_time
|
|
||||||
slice_k = k[tensor_start:tensor_end, :, :]
|
|
||||||
slice_v = v[tensor_start:tensor_end, :, :]
|
|
||||||
if batch_idx not in self.record_batched_k[layer_id]:
|
|
||||||
self.record_batched_k[layer_id][batch_idx] = []
|
|
||||||
self.record_batched_v[layer_id][batch_idx] = []
|
|
||||||
self.record_batched_k[layer_id][batch_idx].append(slice_k)
|
|
||||||
self.record_batched_v[layer_id][batch_idx].append(slice_v)
|
|
||||||
tensor_start = tensor_end
|
|
||||||
|
|
||||||
ref_k, ref_v = [], []
|
|
||||||
for batch_idx, seq_lens_this_time in enumerate(forward_meta.seq_lens_this_time):
|
|
||||||
if seq_lens_this_time == 0:
|
|
||||||
continue
|
|
||||||
bached_k_list = self.record_batched_k[layer_id][batch_idx]
|
|
||||||
bached_v_list = self.record_batched_v[layer_id][batch_idx]
|
|
||||||
ref_k.extend(bached_k_list)
|
|
||||||
ref_v.extend(bached_v_list)
|
|
||||||
|
|
||||||
ref_k = paddle.concat(ref_k, axis=0)
|
|
||||||
ref_v = paddle.concat(ref_v, axis=0)
|
|
||||||
print(
|
|
||||||
f"_check_new_kv_correctness: layer_id={layer_id}, "
|
|
||||||
f"k.shape={k.shape}, v.shape={v.shape}, "
|
|
||||||
f"ref_k.shape={ref_k.shape}, ref_v.shape={ref_v.shape}, "
|
|
||||||
f"new_k.shape={new_k.shape}, new_v.shape={new_v.shape}, "
|
|
||||||
f"len(self.record_batched_k[layer_id])={len(self.record_batched_k[layer_id])}, "
|
|
||||||
f"len(self.record_batched_k[layer_id][0])={len(self.record_batched_k[layer_id][0])}, "
|
|
||||||
f"forward_meta.seq_lens_this_time={forward_meta.seq_lens_this_time}"
|
|
||||||
f"ref_k[-2:, 0:2, 0:2]={ref_k[-2:, 0:2, 0:2]}, "
|
|
||||||
f"ref_v[-2:, 0:2, 0:2]={ref_v[-2:, 0:2, 0:2]}, "
|
|
||||||
f"new_k[-2:, 0:2, 0:2]={new_k[-2:, 0:2, 0:2]}, "
|
|
||||||
f"new_v[-2:, 0:2, 0:2]={new_v[-2:, 0:2, 0:2]}"
|
|
||||||
)
|
|
||||||
assert paddle.allclose(
|
|
||||||
ref_k.to("cpu").to(paddle.float32),
|
|
||||||
new_k.to("cpu").to(paddle.float32),
|
|
||||||
)
|
|
||||||
assert paddle.allclose(
|
|
||||||
ref_v.to("cpu").to(paddle.float32),
|
|
||||||
new_v.to("cpu").to(paddle.float32),
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_splited_qkv(self, qkv: paddle.Tensor, forward_meta: ForwardMeta):
|
|
||||||
q_end = self.num_heads * self.head_dim
|
|
||||||
k_end = q_end + self.attention_metadata.num_kv_heads * self.head_dim
|
k_end = q_end + self.attention_metadata.num_kv_heads * self.head_dim
|
||||||
v_end = k_end + self.attention_metadata.num_kv_heads * self.head_dim
|
v_end = k_end + self.attention_metadata.num_kv_heads * self.head_dim
|
||||||
assert v_end == qkv.shape[-1], f"Shape mistach: {v_end} vs {qkv.shape[-1]}"
|
assert v_end == qkv.shape[-1], f"Shape mismatch: {v_end} vs {qkv.shape[-1]}"
|
||||||
assert qkv.shape[0] == forward_meta.cu_seqlens_q[-1]
|
assert qkv.shape[0] == cu_seqlens_q[-1], f"Shape mismatch: {qkv.shape[0]} vs {cu_seqlens_q[-1]}"
|
||||||
|
|
||||||
|
if batch_ids is None:
|
||||||
|
batch_ids = list(range(forward_meta.seq_lens_this_time.shape[0]))
|
||||||
|
|
||||||
q = qkv[..., 0:q_end]
|
q = qkv[..., 0:q_end]
|
||||||
k = qkv[..., q_end:k_end]
|
k = qkv[..., q_end:k_end]
|
||||||
v = qkv[..., k_end:v_end]
|
v = qkv[..., k_end:v_end]
|
||||||
q = q.view([-1, self.num_heads, self.head_dim]).contiguous()
|
q = q.view([-1, self.num_heads, self.head_dim])
|
||||||
k = k.view([-1, self.attention_metadata.num_kv_heads, self.head_dim]).contiguous()
|
k = k.view([-1, self.attention_metadata.num_kv_heads, self.head_dim])
|
||||||
v = v.view([-1, self.attention_metadata.num_kv_heads, self.head_dim]).contiguous()
|
v = v.view([-1, self.attention_metadata.num_kv_heads, self.head_dim])
|
||||||
# forward_meta.seq_lens_this_time [max_batch,]
|
|
||||||
for batch_idx in range(forward_meta.seq_lens_this_time.shape[0]):
|
for idx in range(len(cu_seqlens_q) - 1):
|
||||||
|
batch_idx = batch_ids[idx]
|
||||||
seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
|
seq_len_i = forward_meta.seq_lens_this_time[batch_idx]
|
||||||
if seq_len_i == 0:
|
if seq_len_i == 0:
|
||||||
continue
|
continue
|
||||||
cached_kv_len = forward_meta.seq_lens_decoder[batch_idx][0]
|
cached_kv_len = forward_meta.seq_lens_decoder[batch_idx][0]
|
||||||
cu_seq_start_q = forward_meta.cu_seqlens_q[batch_idx]
|
cu_seq_start_q = cu_seqlens_q[idx]
|
||||||
cu_seq_end_q = forward_meta.cu_seqlens_q[batch_idx + 1]
|
cu_seq_end_q = cu_seqlens_q[idx + 1]
|
||||||
# forward_meta.rotary_embs is [2, 1, S, 1, D]
|
# forward_meta.rotary_embs is [2, 1, S, 1, D]
|
||||||
if forward_meta.rotary_embs is not None:
|
if forward_meta.rotary_embs is not None:
|
||||||
cos = forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
|
cos = forward_meta.rotary_embs[0, 0, cached_kv_len : cached_kv_len + seq_len_i, :, :]
|
||||||
@@ -388,75 +306,114 @@ class IluvatarAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
return q, k, v
|
return q, k, v
|
||||||
|
|
||||||
def get_splited_info_by_stage(self, q, k, v, forward_meta: ForwardMeta):
|
def split_pd_qkv(self, qkv):
|
||||||
prefill_info_dict = {"q": [], "k": [], "v": [], "batch_ids": []}
|
|
||||||
decode_info_dict = {"q": [], "k": [], "v": [], "batch_ids": []}
|
for ids, reverse_ids in zip(self.prefill_info_dict["id_group"], self.prefill_info_dict["reverse_id_group"]):
|
||||||
tensor_start = 0
|
self.prefill_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :]
|
||||||
for batch_idx, seq_lens_this_time in enumerate(forward_meta.seq_lens_this_time):
|
|
||||||
if seq_lens_this_time == 0:
|
for ids, reverse_ids in zip(self.decode_info_dict["id_group"], self.decode_info_dict["reverse_id_group"]):
|
||||||
continue
|
self.decode_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :]
|
||||||
tensor_end = tensor_start + seq_lens_this_time
|
|
||||||
slice_q = q[tensor_start:tensor_end, :, :]
|
return self.prefill_qkv, self.decode_qkv
|
||||||
slice_k = k[tensor_start:tensor_end, :, :]
|
|
||||||
slice_v = v[tensor_start:tensor_end, :, :]
|
def merge_pd_output(self, prefill_out, decode_out):
|
||||||
if seq_lens_this_time > 1:
|
for stage, idx in self.record_stages:
|
||||||
prefill_info_dict["q"].append(slice_q)
|
if stage == "prefill":
|
||||||
prefill_info_dict["k"].append(slice_k)
|
ids = self.prefill_info_dict["id_group"][idx]
|
||||||
prefill_info_dict["v"].append(slice_v)
|
reverse_ids = self.prefill_info_dict["reverse_id_group"][idx]
|
||||||
prefill_info_dict["batch_ids"].append(batch_idx)
|
self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = prefill_out[ids[0] : ids[1], :, :]
|
||||||
else:
|
else:
|
||||||
assert seq_lens_this_time == 1
|
ids = self.decode_info_dict["id_group"][idx]
|
||||||
decode_info_dict["q"].append(slice_q)
|
reverse_ids = self.decode_info_dict["reverse_id_group"][idx]
|
||||||
decode_info_dict["k"].append(slice_k)
|
self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = decode_out[ids[0] : ids[1], :, :]
|
||||||
decode_info_dict["v"].append(slice_v)
|
return self.merged_output
|
||||||
decode_info_dict["batch_ids"].append(batch_idx)
|
|
||||||
tensor_start = tensor_end
|
|
||||||
|
|
||||||
if len(prefill_info_dict["batch_ids"]) > 0:
|
def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
|
||||||
prefill_info_dict["q"] = paddle.concat(prefill_info_dict["q"], axis=0)
|
prefill_q, prefill_k, prefill_v = self.get_splited_qkv(
|
||||||
prefill_info_dict["k"] = paddle.concat(prefill_info_dict["k"], axis=0)
|
prefill_qkv,
|
||||||
prefill_info_dict["v"] = paddle.concat(prefill_info_dict["v"], axis=0)
|
forward_meta,
|
||||||
cu_seq_ids = list(map(lambda x: x + 1, prefill_info_dict["batch_ids"]))
|
self.prefill_info_dict["cu_seqlens_q"],
|
||||||
prefill_info_dict["cu_seq_ids"] = [0, *cu_seq_ids]
|
batch_ids=self.prefill_info_dict["batch_ids"],
|
||||||
|
)
|
||||||
|
|
||||||
if len(decode_info_dict["batch_ids"]) > 0:
|
prefill_out = flash_attn_unpadded(
|
||||||
decode_info_dict["q"] = paddle.concat(decode_info_dict["q"], axis=0)
|
prefill_q,
|
||||||
decode_info_dict["k"] = paddle.concat(decode_info_dict["k"], axis=0)
|
prefill_k,
|
||||||
decode_info_dict["v"] = paddle.concat(decode_info_dict["v"], axis=0)
|
prefill_v,
|
||||||
|
cu_seqlens_q=self.prefill_info_dict["cu_seqlens_q"],
|
||||||
|
cu_seqlens_k=self.prefill_info_dict["cu_seqlens_q"],
|
||||||
|
max_seqlen_q=self.attention_metadata.max_context_len,
|
||||||
|
max_seqlen_k=self.attention_metadata.max_context_len,
|
||||||
|
scale=self.attention_metadata.scale,
|
||||||
|
dropout=self.attention_metadata.dropout,
|
||||||
|
causal=self.attention_metadata.causal,
|
||||||
|
return_softmax=self.attention_metadata.return_softmax,
|
||||||
|
)[0]
|
||||||
|
self.prefill_update_kv_cache(
|
||||||
|
prefill_k, prefill_v, k_cache_id, v_cache_id, layer_id, forward_meta, self.prefill_info_dict["batch_ids"]
|
||||||
|
)
|
||||||
|
|
||||||
return prefill_info_dict, decode_info_dict
|
return prefill_out
|
||||||
|
|
||||||
def merge_output(self, prefill_out, decode_out, forward_meta: ForwardMeta):
|
def forward_decode(self, decode_qkv, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
|
||||||
assert not (prefill_out is None and decode_out is None), "prefill and decode output cannot both be None"
|
k_cache = forward_meta.caches[k_cache_id]
|
||||||
if prefill_out is None:
|
v_cache = forward_meta.caches[v_cache_id]
|
||||||
return decode_out
|
if self.enable_fused_attention:
|
||||||
elif decode_out is None:
|
rope_cos = forward_meta.rotary_embs[0, 0, :, :, :]
|
||||||
return prefill_out
|
rope_sin = forward_meta.rotary_embs[1, 0, :, :, :]
|
||||||
|
decode_out = paged_attention(
|
||||||
|
decode_qkv.view([-1, self.total_num_heads, self.head_dim]),
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
block_tables=forward_meta.block_tables[self.decode_info_dict["batch_ids"], :],
|
||||||
|
seq_lens=forward_meta.seq_lens_decoder[self.decode_info_dict["batch_ids"], 0] + 1,
|
||||||
|
num_kv_heads=self.attention_metadata.num_kv_heads,
|
||||||
|
scale=self.attention_metadata.scale,
|
||||||
|
block_size=self.attention_metadata.block_size,
|
||||||
|
max_context_len=self.attention_metadata.max_context_len,
|
||||||
|
alibi_slopes=self.attention_metadata.alibi_slopes,
|
||||||
|
causal=self.attention_metadata.causal,
|
||||||
|
window_left=self.attention_metadata.window_left,
|
||||||
|
window_right=self.attention_metadata.window_right,
|
||||||
|
softcap=self.attention_metadata.softcap,
|
||||||
|
use_cuda_graph=self.attention_metadata.use_cuda_graph,
|
||||||
|
use_sqrt_alibi=self.attention_metadata.use_sqrt_alibi,
|
||||||
|
merged_qkv=True,
|
||||||
|
k=decode_qkv,
|
||||||
|
v=decode_qkv,
|
||||||
|
rope_sin=rope_sin,
|
||||||
|
rope_cos=rope_cos,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
merged_output = []
|
decode_q, decode_k, decode_v = self.get_splited_qkv(
|
||||||
prefill_tensor_start = 0
|
decode_qkv,
|
||||||
decode_tensor_start = 0
|
forward_meta,
|
||||||
for seq_lens_this_time in forward_meta.seq_lens_this_time:
|
self.decode_info_dict["cu_seqlens_q"],
|
||||||
if seq_lens_this_time == 0:
|
batch_ids=self.decode_info_dict["batch_ids"],
|
||||||
continue
|
)
|
||||||
if seq_lens_this_time > 1:
|
|
||||||
tensor_end = prefill_tensor_start + seq_lens_this_time
|
|
||||||
merged_output.append(prefill_out[prefill_tensor_start:tensor_end, :, :])
|
|
||||||
prefill_tensor_start = tensor_end
|
|
||||||
else:
|
|
||||||
assert seq_lens_this_time == 1
|
|
||||||
tensor_end = decode_tensor_start + seq_lens_this_time
|
|
||||||
merged_output.append(decode_out[decode_tensor_start:tensor_end, :, :])
|
|
||||||
decode_tensor_start = tensor_end
|
|
||||||
|
|
||||||
assert (
|
decode_out = paged_attention(
|
||||||
prefill_tensor_start == prefill_out.shape[0]
|
decode_q,
|
||||||
), f"prefill merged unfinished: {prefill_tensor_start} vs {prefill_out.shape[0]}"
|
k_cache,
|
||||||
assert (
|
v_cache,
|
||||||
decode_tensor_start == decode_out.shape[0]
|
block_tables=forward_meta.block_tables[self.decode_info_dict["batch_ids"], :],
|
||||||
), f"decode merged unfinished: {decode_tensor_start} vs {decode_out.shape[0]}"
|
seq_lens=forward_meta.seq_lens_decoder[self.decode_info_dict["batch_ids"], 0] + 1,
|
||||||
merged_output = paddle.concat(merged_output, axis=0)
|
num_kv_heads=self.attention_metadata.num_kv_heads,
|
||||||
return merged_output
|
scale=self.attention_metadata.scale,
|
||||||
|
block_size=self.attention_metadata.block_size,
|
||||||
|
max_context_len=self.attention_metadata.max_context_len,
|
||||||
|
alibi_slopes=self.attention_metadata.alibi_slopes,
|
||||||
|
causal=self.attention_metadata.causal,
|
||||||
|
window_left=self.attention_metadata.window_left,
|
||||||
|
window_right=self.attention_metadata.window_right,
|
||||||
|
softcap=self.attention_metadata.softcap,
|
||||||
|
use_cuda_graph=self.attention_metadata.use_cuda_graph,
|
||||||
|
use_sqrt_alibi=self.attention_metadata.use_sqrt_alibi,
|
||||||
|
k=decode_k,
|
||||||
|
v=decode_v,
|
||||||
|
)
|
||||||
|
|
||||||
|
return decode_out
|
||||||
|
|
||||||
def forward_mixed(
|
def forward_mixed(
|
||||||
self,
|
self,
|
||||||
@@ -476,110 +433,19 @@ class IluvatarAttnBackend(AttentionBackend):
|
|||||||
layer_id = layer.layer_id
|
layer_id = layer.layer_id
|
||||||
k_cache_id = layer_id * 2
|
k_cache_id = layer_id * 2
|
||||||
v_cache_id = k_cache_id + 1
|
v_cache_id = k_cache_id + 1
|
||||||
|
|
||||||
assert qkv is not None
|
|
||||||
q_dim = qkv.dim()
|
q_dim = qkv.dim()
|
||||||
q, k, v = self.get_splited_qkv(qkv, forward_meta)
|
assert q_dim == 2
|
||||||
|
|
||||||
if self.only_use_flash_attn:
|
if self.decode_len == 0:
|
||||||
new_k, new_v = self.get_new_kv(k, v, k_cache_id, v_cache_id, forward_meta)
|
output = self.forward_prefill(qkv, layer_id, k_cache_id, v_cache_id, forward_meta)
|
||||||
if self.do_check_kv_cache:
|
|
||||||
self._check_new_kv_correctness(k, v, new_k, new_v, layer_id, forward_meta)
|
|
||||||
|
|
||||||
out = flash_attn_unpadded(
|
elif self.prefill_len == 0:
|
||||||
q,
|
output = self.forward_decode(qkv, k_cache_id, v_cache_id, forward_meta)
|
||||||
new_k,
|
|
||||||
new_v,
|
|
||||||
cu_seqlens_q=self.attention_metadata.cu_seqlens_q,
|
|
||||||
cu_seqlens_k=self.attention_metadata.cu_seqlens_k,
|
|
||||||
max_seqlen_q=self.attention_metadata.max_context_len,
|
|
||||||
max_seqlen_k=self.attention_metadata.max_context_len,
|
|
||||||
scale=self.attention_metadata.scale,
|
|
||||||
dropout=self.attention_metadata.dropout,
|
|
||||||
causal=self.attention_metadata.causal,
|
|
||||||
return_softmax=self.attention_metadata.return_softmax,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
self.update_kv_cache(k, v, k_cache_id, v_cache_id, layer_id, forward_meta)
|
|
||||||
else:
|
else:
|
||||||
prefill_info_dict, decode_info_dict = self.get_splited_info_by_stage(q, k, v, forward_meta)
|
prefill_qkv, decode_qkv = self.split_pd_qkv(qkv)
|
||||||
prefill_out, decode_out = None, None
|
prefill_output = self.forward_prefill(prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta)
|
||||||
|
decode_output = self.forward_decode(decode_qkv, k_cache_id, v_cache_id, forward_meta)
|
||||||
|
output = self.merge_pd_output(prefill_output, decode_output)
|
||||||
|
|
||||||
if len(prefill_info_dict["batch_ids"]) > 0:
|
output = output.view([-1, self.num_heads * self.head_dim])
|
||||||
prefill_out = flash_attn_unpadded(
|
return output
|
||||||
prefill_info_dict["q"],
|
|
||||||
prefill_info_dict["k"],
|
|
||||||
prefill_info_dict["v"],
|
|
||||||
cu_seqlens_q=forward_meta.cu_seqlens_q[prefill_info_dict["cu_seq_ids"]],
|
|
||||||
cu_seqlens_k=forward_meta.cu_seqlens_k[prefill_info_dict["cu_seq_ids"]],
|
|
||||||
max_seqlen_q=self.attention_metadata.max_context_len,
|
|
||||||
max_seqlen_k=self.attention_metadata.max_context_len,
|
|
||||||
scale=self.attention_metadata.scale,
|
|
||||||
dropout=self.attention_metadata.dropout,
|
|
||||||
causal=self.attention_metadata.causal,
|
|
||||||
return_softmax=self.attention_metadata.return_softmax,
|
|
||||||
)[0]
|
|
||||||
self.update_kv_cache(
|
|
||||||
prefill_info_dict["k"],
|
|
||||||
prefill_info_dict["v"],
|
|
||||||
k_cache_id,
|
|
||||||
v_cache_id,
|
|
||||||
layer_id,
|
|
||||||
forward_meta,
|
|
||||||
specific_batch_ids=prefill_info_dict["batch_ids"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(decode_info_dict["batch_ids"]) > 0:
|
|
||||||
k_cache = forward_meta.caches[k_cache_id]
|
|
||||||
v_cache = forward_meta.caches[v_cache_id]
|
|
||||||
|
|
||||||
decode_out = paged_attention(
|
|
||||||
decode_info_dict["q"],
|
|
||||||
k_cache,
|
|
||||||
v_cache,
|
|
||||||
block_tables=forward_meta.block_tables[decode_info_dict["batch_ids"], :],
|
|
||||||
seq_lens=forward_meta.seq_lens_decoder[decode_info_dict["batch_ids"], 0] + 1,
|
|
||||||
num_kv_heads=self.attention_metadata.num_kv_heads,
|
|
||||||
scale=self.attention_metadata.scale,
|
|
||||||
block_size=self.attention_metadata.block_size,
|
|
||||||
max_context_len=self.attention_metadata.max_context_len,
|
|
||||||
alibi_slopes=self.attention_metadata.alibi_slopes,
|
|
||||||
causal=self.attention_metadata.causal,
|
|
||||||
window_left=self.attention_metadata.window_left,
|
|
||||||
window_right=self.attention_metadata.window_right,
|
|
||||||
softcap=self.attention_metadata.softcap,
|
|
||||||
use_cuda_graph=self.attention_metadata.use_cuda_graph,
|
|
||||||
use_sqrt_alibi=self.attention_metadata.use_sqrt_alibi,
|
|
||||||
k=decode_info_dict["k"],
|
|
||||||
v=decode_info_dict["v"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.do_check_kv_cache:
|
|
||||||
self.update_kv_cache(
|
|
||||||
decode_info_dict["k"],
|
|
||||||
decode_info_dict["v"],
|
|
||||||
k_cache_id,
|
|
||||||
v_cache_id,
|
|
||||||
layer_id,
|
|
||||||
forward_meta,
|
|
||||||
specific_batch_ids=decode_info_dict["batch_ids"],
|
|
||||||
debug_paged_attn=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.do_check_kv_cache:
|
|
||||||
new_k, new_v = self.get_new_kv(
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
k_cache_id,
|
|
||||||
v_cache_id,
|
|
||||||
forward_meta,
|
|
||||||
debug_paged_attn=True,
|
|
||||||
)
|
|
||||||
self._check_new_kv_correctness(k, v, new_k, new_v, layer_id, forward_meta)
|
|
||||||
|
|
||||||
out = self.merge_output(prefill_out, decode_out, forward_meta)
|
|
||||||
|
|
||||||
if q_dim == 2:
|
|
||||||
out = out.view([-1, self.num_heads * self.head_dim])
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
@@ -128,10 +128,16 @@ def rejection_top_p_sampling(
|
|||||||
rejection_top_p_sampling
|
rejection_top_p_sampling
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from fastdeploy.model_executor.ops.gpu import (
|
if current_platform.is_iluvatar():
|
||||||
rejection_top_p_sampling,
|
from fastdeploy.model_executor.ops.iluvatar import (
|
||||||
top_k_renorm_probs,
|
rejection_top_p_sampling,
|
||||||
)
|
top_k_renorm_probs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from fastdeploy.model_executor.ops.gpu import (
|
||||||
|
rejection_top_p_sampling,
|
||||||
|
top_k_renorm_probs,
|
||||||
|
)
|
||||||
|
|
||||||
if paddle.count_nonzero(top_k) == 0:
|
if paddle.count_nonzero(top_k) == 0:
|
||||||
ids = rejection_top_p_sampling(
|
ids = rejection_top_p_sampling(
|
||||||
|
@@ -20,6 +20,11 @@ import paddle
|
|||||||
from paddle.incubate.nn.functional import swiglu
|
from paddle.incubate.nn.functional import swiglu
|
||||||
from paddle.nn.quant import weight_only_linear
|
from paddle.nn.quant import weight_only_linear
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fastdeploy.model_executor.ops.iluvatar import w8a16_group_gemm
|
||||||
|
except ImportError:
|
||||||
|
w8a16_group_gemm = None
|
||||||
|
|
||||||
|
|
||||||
def group_gemm(
|
def group_gemm(
|
||||||
input: paddle.Tensor,
|
input: paddle.Tensor,
|
||||||
@@ -67,53 +72,32 @@ def group_gemm(
|
|||||||
scale_i = scale[i]
|
scale_i = scale[i]
|
||||||
# avoid d2d?
|
# avoid d2d?
|
||||||
output[expert_start:expert_end] = weight_only_linear(
|
output[expert_start:expert_end] = weight_only_linear(
|
||||||
input_i,
|
input_i, weight_i, weight_scale=scale_i, weight_dtype="int8", group_size=-1
|
||||||
weight_i,
|
|
||||||
weight_scale=scale_i,
|
|
||||||
weight_dtype="int8",
|
|
||||||
group_size=-1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def iluvatar_moe_expert_ffn(
|
def iluvatar_moe_expert_ffn(
|
||||||
permute_input: paddle.Tensor,
|
permute_input: paddle.Tensor,
|
||||||
tokens_expert_prefix_sum: paddle.Tensor,
|
tokens_expert_prefix_sum: paddle.Tensor,
|
||||||
up_gate_proj_weight: paddle.Tensor,
|
ffn1_weight: paddle.Tensor,
|
||||||
down_proj_weight: paddle.Tensor,
|
ffn2_weight: paddle.Tensor,
|
||||||
up_gate_proj_bias: Optional[paddle.Tensor],
|
ffn1_bias: Optional[paddle.Tensor],
|
||||||
up_gate_proj_scale: Optional[paddle.Tensor],
|
ffn1_scale: Optional[paddle.Tensor],
|
||||||
down_proj_scale: Optional[paddle.Tensor],
|
ffn2_scale: Optional[paddle.Tensor],
|
||||||
down_proj_in_scale: Optional[paddle.Tensor],
|
ffn2_in_scale: Optional[paddle.Tensor],
|
||||||
expert_idx_per_token: Optional[paddle.Tensor],
|
expert_idx_per_token: Optional[paddle.Tensor],
|
||||||
quant_method: str,
|
quant_method: str,
|
||||||
used_in_ep_low_latency: bool,
|
used_in_ep_low_latency: bool,
|
||||||
):
|
):
|
||||||
assert up_gate_proj_bias is None
|
assert ffn1_bias is None
|
||||||
assert up_gate_proj_scale is not None
|
assert ffn1_scale is not None
|
||||||
assert down_proj_scale is not None
|
assert ffn2_scale is not None
|
||||||
assert down_proj_in_scale is None
|
assert ffn2_in_scale is None
|
||||||
assert expert_idx_per_token is None
|
assert expert_idx_per_token is None
|
||||||
assert quant_method in ("weight_only_int8")
|
assert quant_method in ("weight_only_int8")
|
||||||
assert not used_in_ep_low_latency
|
assert not used_in_ep_low_latency
|
||||||
tokens_expert_prefix_sum_cpu = tokens_expert_prefix_sum.to("cpu")
|
tokens_expert_prefix_sum_cpu = tokens_expert_prefix_sum.to("cpu")
|
||||||
up_gate_proj_output = paddle.empty(
|
ffn1_output = w8a16_group_gemm(permute_input, ffn1_weight, ffn1_scale, tokens_expert_prefix_sum_cpu, -1)
|
||||||
[permute_input.shape[0], up_gate_proj_weight.shape[1]],
|
act_out = swiglu(ffn1_output)
|
||||||
dtype=permute_input.dtype,
|
output = w8a16_group_gemm(act_out, ffn2_weight, ffn2_scale, tokens_expert_prefix_sum_cpu, -1)
|
||||||
)
|
|
||||||
group_gemm(
|
|
||||||
permute_input,
|
|
||||||
tokens_expert_prefix_sum_cpu,
|
|
||||||
up_gate_proj_weight,
|
|
||||||
up_gate_proj_scale,
|
|
||||||
up_gate_proj_output,
|
|
||||||
)
|
|
||||||
act_out = swiglu(up_gate_proj_output)
|
|
||||||
output = paddle.empty([act_out.shape[0], down_proj_weight.shape[1]], dtype=act_out.dtype)
|
|
||||||
group_gemm(
|
|
||||||
act_out,
|
|
||||||
tokens_expert_prefix_sum_cpu,
|
|
||||||
down_proj_weight,
|
|
||||||
down_proj_scale,
|
|
||||||
output,
|
|
||||||
)
|
|
||||||
return output
|
return output
|
||||||
|
@@ -39,8 +39,11 @@ def paged_attention(
|
|||||||
softcap: float = 0.0,
|
softcap: float = 0.0,
|
||||||
use_cuda_graph: bool = False,
|
use_cuda_graph: bool = False,
|
||||||
use_sqrt_alibi: bool = False,
|
use_sqrt_alibi: bool = False,
|
||||||
|
merged_qkv: bool = False,
|
||||||
k: paddle.Tensor = None,
|
k: paddle.Tensor = None,
|
||||||
v: paddle.Tensor = None,
|
v: paddle.Tensor = None,
|
||||||
|
rope_sin: paddle.Tensor = None,
|
||||||
|
rope_cos: paddle.Tensor = None,
|
||||||
):
|
):
|
||||||
output = paged_attn(
|
output = paged_attn(
|
||||||
q,
|
q,
|
||||||
@@ -51,6 +54,8 @@ def paged_attention(
|
|||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
|
rope_sin,
|
||||||
|
rope_cos,
|
||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
scale,
|
scale,
|
||||||
block_size,
|
block_size,
|
||||||
@@ -61,5 +66,6 @@ def paged_attention(
|
|||||||
softcap,
|
softcap,
|
||||||
use_cuda_graph,
|
use_cuda_graph,
|
||||||
use_sqrt_alibi,
|
use_sqrt_alibi,
|
||||||
|
merged_qkv,
|
||||||
)
|
)
|
||||||
return output[0] if isinstance(output, list) else output
|
return output[0] if isinstance(output, list) else output
|
||||||
|
@@ -211,7 +211,7 @@ def post_process_normal(
|
|||||||
model_output.stop_flags,
|
model_output.stop_flags,
|
||||||
)
|
)
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda() or current_platform.is_iluvatar():
|
||||||
set_stop_value_multi_ends(
|
set_stop_value_multi_ends(
|
||||||
sampler_output.sampled_token_ids,
|
sampler_output.sampled_token_ids,
|
||||||
model_output.stop_flags,
|
model_output.stop_flags,
|
||||||
|
@@ -41,20 +41,28 @@ from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope
|
|||||||
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
|
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
|
||||||
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
|
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
|
||||||
from fastdeploy.model_executor.model_loader import get_model_loader
|
from fastdeploy.model_executor.model_loader import get_model_loader
|
||||||
from fastdeploy.model_executor.ops.gpu import (
|
from fastdeploy.platforms import current_platform
|
||||||
recover_decode_task,
|
|
||||||
set_value_by_flags_and_idx,
|
if current_platform.is_iluvatar():
|
||||||
share_external_data,
|
from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx
|
||||||
)
|
|
||||||
|
recover_decode_task = None
|
||||||
|
share_external_data = None
|
||||||
|
else:
|
||||||
|
from fastdeploy.model_executor.ops.gpu import (
|
||||||
|
recover_decode_task,
|
||||||
|
set_value_by_flags_and_idx,
|
||||||
|
share_external_data,
|
||||||
|
)
|
||||||
|
|
||||||
from fastdeploy.model_executor.pre_and_post_process import (
|
from fastdeploy.model_executor.pre_and_post_process import (
|
||||||
post_process,
|
post_process,
|
||||||
pre_process,
|
pre_process,
|
||||||
rebuild_padding,
|
rebuild_padding,
|
||||||
step_cuda,
|
step_cuda,
|
||||||
)
|
)
|
||||||
from fastdeploy.platforms import current_platform
|
|
||||||
|
|
||||||
if not current_platform.is_dcu():
|
if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
|
||||||
from fastdeploy.spec_decode import MTPProposer, NgramProposer
|
from fastdeploy.spec_decode import MTPProposer, NgramProposer
|
||||||
|
|
||||||
from fastdeploy import envs
|
from fastdeploy import envs
|
||||||
|
File diff suppressed because it is too large
Load Diff
@@ -16,22 +16,22 @@
|
|||||||
|
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
from paddle import nn
|
|
||||||
|
|
||||||
from fastdeploy.config import FDConfig
|
from fastdeploy.config import FDConfig
|
||||||
from fastdeploy.engine.request import Request
|
from fastdeploy.inter_communicator import IPCSignal
|
||||||
from fastdeploy.utils import get_logger, set_random_seed
|
from fastdeploy.utils import get_logger, set_random_seed
|
||||||
|
from fastdeploy.worker.gpu_worker import GpuWorker
|
||||||
from fastdeploy.worker.iluvatar_model_runner import IluvatarModelRunner
|
from fastdeploy.worker.iluvatar_model_runner import IluvatarModelRunner
|
||||||
from fastdeploy.worker.output import ModelRunnerOutput
|
from fastdeploy.worker.worker_process import PaddleDisWorkerProc
|
||||||
from fastdeploy.worker.worker_base import WorkerBase
|
|
||||||
|
|
||||||
logger = get_logger("iluvatar_worker", "iluvatar_worker.log")
|
logger = get_logger("iluvatar_worker", "iluvatar_worker.log")
|
||||||
|
|
||||||
|
|
||||||
class IluvatarWorker(WorkerBase):
|
class IluvatarWorker(GpuWorker):
|
||||||
""" """
|
""" """
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -40,15 +40,16 @@ class IluvatarWorker(WorkerBase):
|
|||||||
local_rank: int,
|
local_rank: int,
|
||||||
rank: int,
|
rank: int,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super(IluvatarWorker, self).__init__(
|
||||||
fd_config=fd_config,
|
fd_config=fd_config,
|
||||||
local_rank=local_rank,
|
local_rank=local_rank,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
)
|
)
|
||||||
pass
|
|
||||||
|
|
||||||
def init_device(self):
|
def init_device(self):
|
||||||
"""Initialize device and Construct model runner"""
|
"""
|
||||||
|
Initialize device and construct model runner
|
||||||
|
"""
|
||||||
if paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
if paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||||
# Set evironment variable
|
# Set evironment variable
|
||||||
self.device = f"iluvatar_gpu:{self.local_rank}"
|
self.device = f"iluvatar_gpu:{self.local_rank}"
|
||||||
@@ -70,12 +71,6 @@ class IluvatarWorker(WorkerBase):
|
|||||||
local_rank=self.local_rank,
|
local_rank=self.local_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
def exist_prefill(self):
|
|
||||||
"""
|
|
||||||
check whether prefill stage exist
|
|
||||||
"""
|
|
||||||
return self.model_runner.exist_prefill()
|
|
||||||
|
|
||||||
def determine_available_memory(self) -> int:
|
def determine_available_memory(self) -> int:
|
||||||
"""
|
"""
|
||||||
Profiles the peak memory usage of the model to determine how much
|
Profiles the peak memory usage of the model to determine how much
|
||||||
@@ -92,51 +87,86 @@ class IluvatarWorker(WorkerBase):
|
|||||||
# 1. Record memory state before profile run
|
# 1. Record memory state before profile run
|
||||||
return int(float(os.getenv("FD_ILUVATAR_KVCACHE_MEM", "3")) * 1024**3)
|
return int(float(os.getenv("FD_ILUVATAR_KVCACHE_MEM", "3")) * 1024**3)
|
||||||
|
|
||||||
def load_model(self) -> None:
|
|
||||||
""" """
|
|
||||||
self.model_runner.load_model()
|
|
||||||
|
|
||||||
def get_model(self) -> nn.Layer:
|
class IluvatarPaddleDisWorkerProc(PaddleDisWorkerProc):
|
||||||
""" """
|
"""
|
||||||
return self.model_runner.get_model()
|
Paddle Distributed wrapper for fastdeploy.worker.Worker,
|
||||||
|
for handling single-node multi-GPU tensor parallel.
|
||||||
|
The wrapper internally executes an event loop that continuously executes requests
|
||||||
|
in the task queue. Control flow is transmitted by IPC.
|
||||||
|
"""
|
||||||
|
|
||||||
def initialize_cache(self, num_gpu_blocks: int) -> None:
|
def __init__(self, fd_config: FDConfig, ranks: int = 1, local_rank: int = 0):
|
||||||
""" """
|
super(IluvatarPaddleDisWorkerProc, self).__init__(
|
||||||
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
|
fd_config=fd_config,
|
||||||
|
ranks=ranks,
|
||||||
|
local_rank=local_rank,
|
||||||
|
)
|
||||||
|
|
||||||
def execute_model(
|
def initialize_kv_cache(self) -> None:
|
||||||
self,
|
"""Profiles the peak memory usage of the model to determine how many
|
||||||
model_forward_batch: Optional[List[Request]] = None,
|
KV blocks may be allocated without OOMs.
|
||||||
num_running_requests: int = None,
|
|
||||||
) -> Optional[ModelRunnerOutput]:
|
|
||||||
""" """
|
|
||||||
output = self.model_runner.execute_model(model_forward_batch, num_running_requests)
|
|
||||||
return output
|
|
||||||
|
|
||||||
def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int) -> None:
|
The engine will first conduct a profiling of the existing memory usage.
|
||||||
"""Process new requests and then start the decode loop
|
Then, it calculate the maximum possible number of GPU and CPU blocks
|
||||||
TODO(gongshaotian):The scheduler should schedule the handling of prefill,
|
that can be allocated with the remaining free memory.
|
||||||
and workers and modelrunners should not perceive it.
|
|
||||||
|
.. tip::
|
||||||
|
You may limit the usage of GPU memory
|
||||||
|
by adjusting the `gpu_memory_utilization` parameter.
|
||||||
"""
|
"""
|
||||||
self.model_runner.insert_prefill_inputs(req_dicts=req_dicts, num_running_requests=num_running_requests)
|
if self.fd_config.parallel_config.do_profile:
|
||||||
|
# 1. Get available memory(bytes)
|
||||||
|
available_kv_cache_memory = self.worker.determine_available_memory()
|
||||||
|
logger.info(f"------- available_kv_cache_memory:{available_kv_cache_memory / 1024**3} GB --------")
|
||||||
|
|
||||||
def graph_optimize_and_warm_up_model(self) -> None:
|
# 2. Calculate the appropriate number of blocks
|
||||||
"""
|
model_block_memory_used = self.worker.cal_theortical_kvcache()
|
||||||
Perform the warm-up and the graph optimization
|
num_blocks_local = int(available_kv_cache_memory // model_block_memory_used)
|
||||||
"""
|
# NOTE(liuzichang): Too many block will lead to illegal memory access
|
||||||
# 1. Warm up model
|
# We will develop dynamic limits in future.
|
||||||
# NOTE(gongshaotian): may be not need warm_up at this place
|
if num_blocks_local > 40000:
|
||||||
if self.model_runner.graph_opt_level >= 1:
|
logger.info(f"------- Reset num_blocks_local {num_blocks_local} to 40000")
|
||||||
self.model_runner.sot_warmup()
|
num_blocks_local = min(40000, num_blocks_local)
|
||||||
|
logger.info(f"------- model_block_memory_used:{model_block_memory_used} --------")
|
||||||
|
logger.info(f"------- num_blocks_local:{num_blocks_local} --------")
|
||||||
|
|
||||||
# 2. Triger cuda grpah capture
|
# NOTE(yuzhe.wu): Using the old version of the calculation num_blocks_global method,
|
||||||
self.model_runner.capture_model()
|
# because the new version that adopting allreduce min will report a bad request error
|
||||||
set_random_seed(self.fd_config.model_config.seed)
|
# when running 300b model. The Relation commit:
|
||||||
|
# https://github.com/PaddlePaddle/FastDeploy/commit/2f74e93d7e87aa3ffec3fc6966bf11ab5363b956
|
||||||
|
|
||||||
def check_health(self) -> bool:
|
# 3. Send IPCSignal
|
||||||
""" """
|
get_profile_block_num = np.zeros(shape=[self.ranks], dtype=np.int32)
|
||||||
return True
|
self.get_profile_block_num_signal = IPCSignal(
|
||||||
|
name="get_profile_block_num",
|
||||||
|
array=get_profile_block_num,
|
||||||
|
dtype=np.int32,
|
||||||
|
suffix=self.parallel_config.engine_pid,
|
||||||
|
create=False,
|
||||||
|
)
|
||||||
|
self.get_profile_block_num_signal.value[self.local_rank] = num_blocks_local
|
||||||
|
|
||||||
def cal_theortical_kvcache(self) -> int:
|
# Wait all worker send the signal
|
||||||
""" """
|
while np.any(self.get_profile_block_num_signal.value <= 0):
|
||||||
return self.model_runner.cal_theortical_kvcache()
|
time.sleep(0.01)
|
||||||
|
num_blocks_global = self.get_profile_block_num_signal.value.min().item()
|
||||||
|
|
||||||
|
if num_blocks_global < 0:
|
||||||
|
logger.error(
|
||||||
|
"The total number of blocks cannot be less than zero."
|
||||||
|
"Please increase gpu_memory_utilization"
|
||||||
|
"Or decrease max_num_batched_tokens(max model length) "
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
"The total number of blocks cannot be less than zero."
|
||||||
|
"Please increase gpu_memory_utilization"
|
||||||
|
"Or decrease max_num_batched_tokens(max model length) "
|
||||||
|
)
|
||||||
|
|
||||||
|
self.get_profile_block_num_signal.value[self.local_rank] = num_blocks_global
|
||||||
|
else:
|
||||||
|
num_blocks_global = self.fd_config.parallel_config.total_block_num
|
||||||
|
# 4. init kv_cache with accurate num_blocks
|
||||||
|
logger.info(f"------- num_blocks_global:{num_blocks_global} --------")
|
||||||
|
self.worker.initialize_cache(num_gpu_blocks=num_blocks_global)
|
||||||
|
@@ -723,7 +723,12 @@ def run_worker_proc() -> None:
|
|||||||
fd_config = initialize_fd_config(args, ranks, local_rank)
|
fd_config = initialize_fd_config(args, ranks, local_rank)
|
||||||
|
|
||||||
# Create worker process
|
# Create worker process
|
||||||
worker_proc = PaddleDisWorkerProc(fd_config, ranks, local_rank)
|
if current_platform.is_iluvatar():
|
||||||
|
from fastdeploy.worker.iluvatar_worker import IluvatarPaddleDisWorkerProc
|
||||||
|
|
||||||
|
worker_proc = IluvatarPaddleDisWorkerProc(fd_config, ranks, local_rank)
|
||||||
|
else:
|
||||||
|
worker_proc = PaddleDisWorkerProc(fd_config, ranks, local_rank)
|
||||||
|
|
||||||
# Initialize device and create model runner
|
# Initialize device and create model runner
|
||||||
worker_proc.init_device()
|
worker_proc.init_device()
|
||||||
|
@@ -1,11 +1,11 @@
|
|||||||
setuptools>=79.0.1,<80.0
|
setuptools>=62.3.0,<80.0
|
||||||
pre-commit
|
pre-commit
|
||||||
yapf
|
yapf
|
||||||
flake8
|
flake8
|
||||||
ruamel.yaml
|
ruamel.yaml
|
||||||
zmq
|
zmq
|
||||||
aiozmq
|
aiozmq
|
||||||
openai
|
openai>=1.93.0
|
||||||
tqdm
|
tqdm
|
||||||
pynvml
|
pynvml
|
||||||
uvicorn
|
uvicorn
|
||||||
@@ -24,7 +24,15 @@ setuptools-scm>=8
|
|||||||
prometheus-client
|
prometheus-client
|
||||||
decord
|
decord
|
||||||
moviepy
|
moviepy
|
||||||
|
wheel
|
||||||
use-triton-in-paddle
|
use-triton-in-paddle
|
||||||
crcmod
|
crcmod
|
||||||
fastsafetensors==0.1.14
|
fastsafetensors==0.1.14
|
||||||
msgpack
|
msgpack
|
||||||
|
opentelemetry-api>=1.24.0
|
||||||
|
opentelemetry-sdk>=1.24.0
|
||||||
|
opentelemetry-instrumentation-redis
|
||||||
|
opentelemetry-instrumentation-mysql
|
||||||
|
opentelemetry-distro
|
||||||
|
opentelemetry-exporter-otlp
|
||||||
|
opentelemetry-instrumentation-fastapi
|
||||||
|
@@ -13,10 +13,10 @@ python -m pip install -r requirements_iluvatar.txt
|
|||||||
echo "uninstall org"
|
echo "uninstall org"
|
||||||
python -m pip uninstall paddlepaddle -y
|
python -m pip uninstall paddlepaddle -y
|
||||||
python -m pip uninstall paddle-iluvatar-gpu -y
|
python -m pip uninstall paddle-iluvatar-gpu -y
|
||||||
python -m pip install paddlepaddle==3.1.0a0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
|
python -m pip install --pre paddlepaddle==3.0.0.dev20250708 -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
|
||||||
# TODO: Change to open access URL
|
# TODO: Change to open access URL
|
||||||
# python -m pip install --pre paddle-iluvatar-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/ixuca/
|
python -m pip install --pre paddle-iluvatar-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/ixuca/
|
||||||
python -m pip install /data1/fastdeploy/packages/paddle_iluvatar_gpu-0.0.0-cp310-cp310-linux_x86_64.whl
|
# python -m pip install /data1/fastdeploy/packages/paddle_iluvatar_gpu-0.0.0-cp310-cp310-linux_x86_64.whl
|
||||||
# Patch, remove if image updated
|
# Patch, remove if image updated
|
||||||
cp /data1/fastdeploy/packages/cusolver.h /usr/local/lib/python3.10/site-packages/paddle/include/paddle/phi/backends/dynload/cusolver.h
|
cp /data1/fastdeploy/packages/cusolver.h /usr/local/lib/python3.10/site-packages/paddle/include/paddle/phi/backends/dynload/cusolver.h
|
||||||
echo "build whl"
|
echo "build whl"
|
||||||
@@ -30,6 +30,7 @@ rm -rf log/*
|
|||||||
export INFERENCE_MSG_QUEUE_ID=232132
|
export INFERENCE_MSG_QUEUE_ID=232132
|
||||||
export FD_DEBUG=1
|
export FD_DEBUG=1
|
||||||
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
export PADDLE_XCCL_BACKEND=iluvatar_gpu
|
||||||
|
export FD_SAMPLING_CLASS=rejection
|
||||||
python test/ci_use/iluvatar_UT/run_ernie300B_4layer.py
|
python test/ci_use/iluvatar_UT/run_ernie300B_4layer.py
|
||||||
exit_code=$?
|
exit_code=$?
|
||||||
echo exit_code is ${exit_code}
|
echo exit_code is ${exit_code}
|
||||||
|
@@ -10,7 +10,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.00001, max_tokens=16)
|
|||||||
# 加载模型
|
# 加载模型
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
model="/data1/fastdeploy/ERNIE_300B_4L",
|
model="/data1/fastdeploy/ERNIE_300B_4L",
|
||||||
tensor_parallel_size=16,
|
tensor_parallel_size=8,
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
static_decode_blocks=0,
|
static_decode_blocks=0,
|
||||||
quantization="wint8",
|
quantization="wint8",
|
||||||
@@ -27,14 +27,14 @@ assert outputs[0].outputs.token_ids == [
|
|||||||
59335,
|
59335,
|
||||||
68170,
|
68170,
|
||||||
183,
|
183,
|
||||||
49080,
|
97404,
|
||||||
94717,
|
100088,
|
||||||
82966,
|
36310,
|
||||||
99140,
|
95633,
|
||||||
31615,
|
95913,
|
||||||
51497,
|
41459,
|
||||||
94851,
|
95049,
|
||||||
60764,
|
94970,
|
||||||
10889,
|
96840,
|
||||||
2,
|
2,
|
||||||
]
|
], f"{outputs[0].outputs.token_ids}"
|
||||||
|
Reference in New Issue
Block a user