Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-09-29 13:52:26 +08:00)

Compare commits: remove_use...v2.0.2 (23 commits)
Commit SHAs in this range:

e421d51001, c71d955e9c, 2d2468ae72, 7deac64233, 5a5f17cf97, 0d61c65de1, e5de28bff2, b9eede57b6, 94e1a895e3, 87203ec87b, 4596dd7248, ec986642df, 94691bcd90, 4025ea7e5b, e681e1e719, 823a47e64a, 39d2a1de46, 1107e08cd9, 1fe37cb7e8, 337d76f094, ae2f78184d, 6851489425, ea787d8f62
@@ -24,16 +24,18 @@
 #endif

 #define MAX_BSZ 512
-#define K 10
+#define K 20

 struct msgdata {
   long mtype;
   int mtext[MAX_BSZ * (K + 1) + 2];   // stop_flag, bsz, tokens
   float mtext_f[MAX_BSZ * (K + 1)];   // score
+  int mtext_ranks[MAX_BSZ];           // ranks
 };

 void GetOutputTopK(const paddle::Tensor& x,
                    const paddle::Tensor& scores,
+                   const paddle::Tensor& ranks,
                    int k,
                    int64_t rank_id,
                    bool wait_flag) {
@@ -66,17 +68,18 @@ void GetOutputTopK(const paddle::Tensor& x,

   int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
   float* scores_data = const_cast<float*>(scores.data<float>());
+  int64_t* ranks_data = const_cast<int64_t*>(ranks.data<int64_t>());
   int ret = -1;
   if (!wait_flag) {
     ret = msgrcv(msgid,
                  &msg_rcv,
-                 (MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
+                 (MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
                  0,
                  IPC_NOWAIT);
   } else {
     ret = msgrcv(msgid,
                  &msg_rcv,
-                 (MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4,
+                 (MAX_BSZ * (K + 1) + 2) * 4 + MAX_BSZ * (K + 1) * 4 + MAX_BSZ * 4,
                  0,
                  0);
   }
@@ -97,13 +100,14 @@ void GetOutputTopK(const paddle::Tensor& x,
       out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2];
       scores_data[offset] = msg_rcv.mtext_f[offset];
     }
+    ranks_data[i] = (int64_t)msg_rcv.mtext_ranks[i];
   }
   return;
 }

 PD_BUILD_STATIC_OP(get_output_topk)
-    .Inputs({"x", "scores"})
+    .Inputs({"x", "scores", "ranks"})
     .Attrs({"k: int", "rank_id: int64_t", "wait_flag: bool"})
-    .Outputs({"x_out", "scores_out"})
-    .SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}})
+    .Outputs({"x_out", "scores_out", "ranks_out"})
+    .SetInplaceMap({{"x", "x_out"}, {"scores", "scores_out"}, {"ranks", "ranks_out"}})
     .SetKernelFn(PD_KERNEL(GetOutputTopK));
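The larger `msgrcv` length in the hunk above is simply the ranks array appended to the payload. A quick sketch of the arithmetic with the constants from the header (the 4-byte slot size mirrors the `int`/`float` fields of `msgdata`):

```python
# Sketch of the message layout implied by struct msgdata above.
# MAX_BSZ=512 and K=20 come from the #defines; the 4-byte slot size is an
# assumption matching the int/float fields.
MAX_BSZ, K = 512, 20

tokens_bytes = (MAX_BSZ * (K + 1) + 2) * 4   # mtext: stop_flag, bsz, then bsz*(K+1) token ids
scores_bytes = MAX_BSZ * (K + 1) * 4         # mtext_f: one score per token slot
ranks_bytes = MAX_BSZ * 4                    # mtext_ranks: newly appended per-request ranks

old_msg_size = tokens_bytes + scores_bytes
new_msg_size = tokens_bytes + scores_bytes + ranks_bytes  # size now passed to msgrcv/msgsnd
print(old_msg_size, new_msg_size)
```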
@@ -18,6 +18,7 @@

 std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
                                                const paddle::Tensor &top_p,
+                                               const paddle::optional<paddle::Tensor> &top_k,
                                                int seed) {
   std::vector<int64_t> probs_shape = probs.shape();
   unsigned int batch_size = probs_shape[0];
@@ -40,10 +41,18 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,

   cudaError_t status;

-  status = sampling::TopKTopPSamplingFromProb<float, int64_t>(
-      const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
-      batch_size, top_p.data<float>(), vocab_size,
-      true, philox_seed, philox_offset, cu_stream);
+  if (top_k) {
+    status = sampling::TopKTopPSamplingFromProb<float, int64_t>(
+        const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
+        batch_size, top_p.data<float>(), top_k.get().data<int64_t>(), vocab_size,
+        true, philox_seed, philox_offset, cu_stream);
+  }
+  else {
+    status = sampling::TopPSamplingFromProb<float, int64_t>(
+        const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
+        batch_size, top_p.data<float>(), vocab_size,
+        true, philox_seed, philox_offset, cu_stream);
+  }

   PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
                                       std::string(cudaGetErrorString(status)));
@@ -53,19 +62,21 @@ std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,

 std::vector<std::vector<int64_t>>
 TopPSamplingRejectInferShape(const std::vector<int64_t> &probs_shape,
-                             const std::vector<int64_t> &top_p_shape) {
+                             const std::vector<int64_t> &top_p_shape,
+                             const paddle::optional<std::vector<int64_t>> &top_k_shape) {
   int64_t bs = probs_shape[0];
   return {{bs, 1}};
 }

 std::vector<paddle::DataType>
 TopPSamplingRejectInferDtype(const paddle::DataType &probs_dtype,
-                             const paddle::DataType &top_p_shape) {
+                             const paddle::DataType &top_p_dtype,
+                             const paddle::optional<paddle::DataType> &top_k_dtype) {
   return {paddle::DataType::INT64};
 }

 PD_BUILD_STATIC_OP(rejection_top_p_sampling)
-    .Inputs({"probs", "top_p"})
+    .Inputs({"probs", "top_p", paddle::Optional("top_k")})
     .Outputs({"samples"})
     .Attrs({"seed: int"})
     .SetKernelFn(PD_KERNEL(TopPSamplingReject))
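A hedged sketch of how the updated op might be driven from Python; the op name, inputs, and the optional top_k come from the registration above, while the import location of the compiled custom op is an assumption:

```python
import paddle
# Assumed location of the compiled custom op; adjust to the local build.
# from fastdeploy.model_executor.ops.gpu import rejection_top_p_sampling

probs = paddle.nn.functional.softmax(paddle.randn([4, 32000]), axis=-1)  # [batch, vocab]
top_p = paddle.full([4, 1], 0.8, dtype="float32")
top_k = paddle.to_tensor([[20], [0], [50], [0]], dtype="int64")  # 0 leaves a row unrestricted

# With top_k supplied the kernel takes the TopKTopPSamplingFromProb path;
# omitting it falls back to plain TopPSamplingFromProb, matching the C++ branch above.
# samples = rejection_top_p_sampling(probs, top_p, top_k, seed=0)
```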
@@ -279,7 +279,8 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
 template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
           typename DType, typename IdType>
-__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, float* top_p_arr,
+__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
+                                               float* top_p_arr, IdType* top_k_arr,
                                                uint32_t d, uint64_t philox_seed,
                                                uint64_t philox_offset) {
   const uint32_t batch_size = gridDim.x;
@@ -287,7 +288,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, flo
   curandStatePhilox4_32_10_t state;
   curand_init(philox_seed, bx, philox_offset, &state);
   const uint32_t row_idx = bx;
-  const uint32_t k = top_p_arr[row_idx] == 0 ? 1 : 20;
+  const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
   const float p = top_p_arr[row_idx] == 0 ? 1e-6 : top_p_arr[row_idx];

   extern __shared__ __align__(
@@ -479,7 +480,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
       if (aggregate_gt_pivot_0 < top_p) {
         // case 1: pivot_0 accepted
         break;
       }
     }
     if (aggregate_gt_pivot_1 < top_p) {
       // case 2: pivot_0 rejected, pivot_1 accepted
       low = pivot_0;
@@ -497,6 +498,183 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
   }
 }

+template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
+          typename TempStorage>
+__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
+                                             TempStorage& temp_storage) {
+  const uint32_t tx = threadIdx.x;
+  vec_t<float, VEC_SIZE> in_data_vec;
+
+  float max_val = 0;
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    in_data_vec.fill(0);
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+    }
+    float in_data_[VEC_SIZE];
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      in_data_[j] = in_data_vec[j];
+    }
+    max_val = max(
+        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+                     .Reduce<VEC_SIZE>(in_data_, cub::Max()));
+    __syncthreads();
+  }
+  if (tx == 0) {
+    temp_storage.max_val = max_val;
+  }
+  __syncthreads();
+  return temp_storage.max_val;
+}
+
+template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
+struct RenormTempStorage {
+  union {
+    typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
+    typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
+    typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
+        reduce_value_count;
+  } block_prim;
+  struct {
+    float max_val;
+    float min_val;
+    union {
+      struct {
+        float values[2];
+      };
+      struct {
+        int counts[2];
+      };
+      struct {
+        ValueCount<float> pairs[2];
+      };
+    } block_aggregate;
+  };
+};
+
+template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          typename DType, typename IdType>
+__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  const uint32_t row_idx = bx;
+  const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
+  double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
+  vec_t<float, VEC_SIZE> probs_vec;
+  if (k < d) {
+    extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
+        uint8_t smem_renorm[];
+    auto& temp_storage =
+        reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
+    temp_storage.max_val = 0;
+
+    float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
+                                RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
+        probs, row_idx, d, temp_storage);
+
+    double low = 0, high = max_val;
+    float min_gt_low, max_le_high;
+    float sum_low = 1;
+    // f(x) = len(nonzero(probs > x)), f(x) is non-increasing
+    // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
+    // loop invariant:
+    // - f(low) >= k, f(high) < k
+    // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
+    // stopping condition: min_gt_low == max_le_high
+    // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
+    do {
+      double pivot_0 = (high + 2 * low) / 3;
+      double pivot_1 = (2 * high + low) / 3;
+
+      ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
+      min_gt_low = high;
+      max_le_high = low;
+#pragma unroll 2
+      for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+        probs_vec.fill(0);
+        if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+          probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+        }
+        ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
+#pragma unroll
+        for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+          probs_gt_pivot_0_pair[j] = {
+              (probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
+              (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+          probs_gt_pivot_1_pair[j] = {
+              (probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
+              (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+
+          if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            min_gt_low = min(min_gt_low, probs_vec[j]);
+          }
+          if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            max_le_high = max(max_le_high, probs_vec[j]);
+          }
+        }
+
+        aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                    temp_storage.block_prim.reduce_value_count)
+                                    .Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
+        __syncthreads();
+
+        aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                                    temp_storage.block_prim.reduce_value_count)
+                                    .Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
+        __syncthreads();
+      }
+      min_gt_low =
+          BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+              .Reduce(min_gt_low, cub::Min());
+      __syncthreads();
+      max_le_high =
+          BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+              .Reduce(max_le_high, cub::Max());
+      if (tx == 0) {
+        temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
+        temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
+        temp_storage.min_val = min_gt_low;
+        temp_storage.max_val = max_le_high;
+      }
+      __syncthreads();
+      aggregate_gt_pivot_0 = temp_storage.block_aggregate.pairs[0];
+      aggregate_gt_pivot_1 = temp_storage.block_aggregate.pairs[1];
+      min_gt_low = temp_storage.min_val;
+      max_le_high = temp_storage.max_val;
+
+      if (aggregate_gt_pivot_1.count >= k) {
+        low = pivot_1;
+        sum_low = float(aggregate_gt_pivot_1.value);
+      } else if (aggregate_gt_pivot_0.count >= k) {
+        low = pivot_0;
+        high = min(pivot_1, max_le_high);
+        sum_low = float(aggregate_gt_pivot_0.value);
+      } else {
+        high = min(pivot_0, max_le_high);
+      }
+    } while (min_gt_low != max_le_high);
+
+    normalizer = ptx_rcp(max(sum_low, 1e-8));
+    pivot = low;
+  }
+
+  // normalize
+#pragma unroll 2
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(0);
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0;
+    }
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+  }
+}
+
 template <typename T, typename IdType>
 cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
                                  uint32_t batch_size, const T *top_p_val,
@@ -529,7 +707,7 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,

 template <typename T, typename IdType>
 cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
-                                     uint32_t batch_size, const T *top_p_val,
+                                     uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,
                                      uint32_t d, bool deterministic,
                                      uint64_t philox_seed, uint64_t philox_offset,
                                      cudaStream_t stream = 0) {
@@ -540,7 +718,7 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
   const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
-  void* args[] = {&probs, &output, &top_p_val,
+  void* args[] = {&probs, &output, &top_p_val, &top_k_val,
                   &d, &philox_seed, &philox_offset};

   DISPATCH_ALIGNED_VEC_SIZE(
@@ -556,4 +734,26 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
   });
 }

-}  // namespace sampling
+template <typename DType, typename IdType>
+cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
+                           uint32_t batch_size, uint32_t d,
+                           cudaStream_t stream = 0) {
+  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
+
+  auto compute_capacity = GetCudaComputeCapability();
+  DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
+    const uint32_t smem_size = sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
+    dim3 nblks(batch_size);
+    dim3 nthrs(BLOCK_THREADS);
+    void* args[] = {&probs, &renormed_prob, &top_k_arr, &d};
+    DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+      auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
+      CUDA_CALL(
+          cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+      CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+    });
+    return cudaSuccess;
+  });
+}
+
+}  // namespace sampling
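TopKRenormProbKernel searches, per row, for a pivot just below the k-th largest probability, zeroes everything at or below it, and rescales the survivors by the reciprocal of their total mass. A rough NumPy reference of that behavior, ignoring the ternary-search mechanics (ties are resolved slightly differently in the kernel):

```python
import numpy as np

def top_k_renorm_probs_reference(probs: np.ndarray, top_k: np.ndarray) -> np.ndarray:
    """Keep the k largest probabilities per row, zero the rest, and renormalize.
    top_k == 0 means "keep the whole row", matching the kernel's convention."""
    out = np.zeros_like(probs)
    for row, (p, k) in enumerate(zip(probs, top_k.reshape(-1))):
        k = p.size if k == 0 else int(k)
        pivot = np.sort(p)[-k]               # k-th largest value
        kept = np.where(p >= pivot, p, 0.0)  # kernel keeps strictly-above-pivot entries
        out[row] = kept / kept.sum()
    return out
```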
custom_ops/gpu_ops/sample_kernels/top_k_renorm_probs.cu (new file, 61 lines)
@@ -0,0 +1,61 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/phi/backends/context_pool.h"
#include "sample_kernels/sampling.cuh"

std::vector<paddle::Tensor> TopKRenorm(const paddle::Tensor &probs,
                                       const paddle::Tensor &top_k) {
  std::vector<int64_t> probs_shape = probs.shape();
  uint32_t batch_size = probs_shape[0];
  uint32_t vocab_size = probs_shape[1];
  auto cu_stream = probs.stream();

  auto renorm_probs =
      GetEmptyTensor(probs.dims(), paddle::DataType::FLOAT32, probs.place());

  cudaError_t status;

  status = sampling::TopKRenormProb<float>(
      const_cast<float *>(probs.data<float>()),
      renorm_probs.data<float>(),
      const_cast<int64_t *>(top_k.data<int64_t>()),
      batch_size, vocab_size, cu_stream);

  PD_CHECK(status == cudaSuccess, "TopKRenormProb failed with error code " +
                                      std::string(cudaGetErrorString(status)));

  return {renorm_probs};
}

std::vector<std::vector<int64_t>>
TopKRenormInferShape(const std::vector<int64_t> &probs_shape,
                     const std::vector<int64_t> &top_k_shape) {
  return {probs_shape};
}

std::vector<paddle::DataType>
TopKRenormInferDtype(const paddle::DataType &probs_dtype,
                     const paddle::DataType &top_k_shape) {
  return {probs_dtype};
}

PD_BUILD_STATIC_OP(top_k_renorm_probs)
    .Inputs({"probs", "top_k"})
    .Outputs({"renorm_probs"})
    .SetKernelFn(PD_KERNEL(TopKRenorm))
    .SetInferShapeFn(PD_INFER_SHAPE(TopKRenormInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(TopKRenormInferDtype));
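The new file registers the kernel as the static op `top_k_renorm_probs`. A hedged usage sketch (the import path is an assumption):

```python
import paddle
# from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs  # assumed import path

probs = paddle.nn.functional.softmax(paddle.randn([2, 32000]), axis=-1)
top_k = paddle.to_tensor([20, 0], dtype="int64")  # 0 keeps the full distribution for that row

# renorm = top_k_renorm_probs(probs, top_k)
# Each row of `renorm` sums to 1 and has at most k non-zero entries.
```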
@@ -23,34 +23,34 @@
 #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
 #endif

-#define MAX_BSZ 128
-#define K 10
+#define MAX_BSZ 512
+#define K 20
 // #define SAVE_WITH_OUTPUT_DEBUG

 struct msgdata {
   long mtype;
   int mtext[MAX_BSZ * (K + 1) + 2];   // stop_flag, bsz, tokens
   float mtext_f[MAX_BSZ * (K + 1)];   // score
+  int mtext_ranks[MAX_BSZ];           // ranks
 };

 void SaveOutMmsgTopK(const paddle::Tensor& x,
-                     const paddle::Tensor& scores,
-                     const paddle::Tensor& topk_ids,
-                     const paddle::Tensor& topk_scores,  // [bsz, k]
+                     const paddle::Tensor& logprob_token_ids,  // [bsz, k+1]
+                     const paddle::Tensor& logprob_scores,     // [bsz, k+1]
+                     const paddle::Tensor& ranks,
                      const paddle::Tensor& not_need_stop,
-                     int k,
                      int64_t rank_id) {
   if (rank_id > 0) {
     return;
   }
   auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
-  auto scores_cpu = scores.copy_to(paddle::CPUPlace(), false);
-  auto topk_ids_cpu = topk_ids.copy_to(paddle::CPUPlace(), false);
-  auto topk_scores_cpu = topk_scores.copy_to(paddle::CPUPlace(), false);
+  auto logprob_token_ids_cpu = logprob_token_ids.copy_to(paddle::CPUPlace(), false);
+  auto logprob_scores_cpu = logprob_scores.copy_to(paddle::CPUPlace(), false);
+  auto ranks_cpu = ranks.copy_to(paddle::CPUPlace(), false);
   int64_t* x_data = x_cpu.data<int64_t>();
-  float* scores_data = scores_cpu.data<float>();
-  int64_t* topk_ids_data = topk_ids_cpu.data<int64_t>();
-  float* topk_scores_data = topk_scores_cpu.data<float>();
+  int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data<int64_t>();
+  float* logprob_scores_data = logprob_scores_cpu.data<float>();
+  int64_t* ranks_data = ranks_cpu.data<int64_t>();
   static struct msgdata msg_sed;
   int msg_queue_id = 1;
   if (const char* inference_msg_queue_id_env_p =
@@ -106,21 +106,23 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
   msg_sed.mtext[0] = not_need_stop_data ? inference_msg_id_from_env
                                         : -inference_msg_id_from_env;
   int bsz = x.shape()[0];
+  int max_num_logprobs = logprob_token_ids.shape()[1];
   msg_sed.mtext[1] = bsz;
   for (int i = 0; i < bsz; i++) {
-    for (int j = 0; j < k + 1; j++) {
+    for (int j = 0; j < K + 1; j++) {
       const int64_t offset = i * (K + 1) + j;
       if (j == 0) {
         msg_sed.mtext[offset + 2] = (int)x_data[i];
-        msg_sed.mtext_f[offset] = scores_data[i];
-      } else if (j <= k + 1) {
-        msg_sed.mtext[offset + 2] = (int)topk_ids_data[i * k + j - 1];
-        msg_sed.mtext_f[offset] = topk_scores_data[i * k + j - 1];
+        msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j];
+      } else if (j < max_num_logprobs) {
+        msg_sed.mtext[offset + 2] = (int)logprob_token_ids_data[i * max_num_logprobs + j];
+        msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j];
       } else {
         msg_sed.mtext[offset + 2] = -1;
         msg_sed.mtext_f[offset] = 0.0;
       }
     }
+    msg_sed.mtext_ranks[i] = (int)ranks_data[i];
   }
 #ifdef SAVE_WITH_OUTPUT_DEBUG
   std::cout << "msg data: ";
@@ -131,7 +133,7 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
 #endif
   if ((msgsnd(msgid,
               &msg_sed,
-              (MAX_BSZ * (K + 1) + 2) * 4 + (MAX_BSZ * (K + 1)) * 4,
+              (MAX_BSZ * (K + 1) + 2) * 4 + (MAX_BSZ * (K + 1)) * 4 + MAX_BSZ * 4,
               0)) == -1) {
     printf("full msg buffer\n");
   }
@@ -139,8 +141,8 @@ void SaveOutMmsgTopK(const paddle::Tensor& x,
 }

 PD_BUILD_STATIC_OP(save_output_topk)
-    .Inputs({"x", "scores", "topk_ids", "topk_scores", "not_need_stop"})
-    .Attrs({"k: int", "rank_id: int64_t"})
+    .Inputs({"x", "topk_ids", "logprob_scores", "ranks", "not_need_stop"})
+    .Attrs({"rank_id: int64_t"})
     .Outputs({"x_out"})
     .SetInplaceMap({{"x", "x_out"}})
     .SetKernelFn(PD_KERNEL(SaveOutMmsgTopK));
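The loop above packs each request into K + 1 fixed slots of the message: slot 0 carries the sampled token and its logprob, the next max_num_logprobs - 1 slots carry the candidate logprobs, and the remainder are padded. A small sketch of the same packing rule:

```python
# Sketch of how one request's K+1 message slots are filled by SaveOutMmsgTopK above.
# K=20 comes from the #define; max_num_logprobs is the per-request k+1.
K = 20

def pack_slots(sampled_id, logprob_ids, logprob_scores):
    """logprob_ids and logprob_scores have length max_num_logprobs = k + 1."""
    ids, scores = [], []
    max_num_logprobs = len(logprob_ids)
    for j in range(K + 1):
        if j == 0:
            ids.append(sampled_id)             # slot 0: the sampled token id
            scores.append(logprob_scores[0])   # and its logprob
        elif j < max_num_logprobs:
            ids.append(logprob_ids[j])         # top-k candidate token ids
            scores.append(logprob_scores[j])
        else:
            ids.append(-1)                     # unused slots are padded with -1
            scores.append(0.0)
    return ids, scores
```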
@@ -267,6 +267,7 @@ elif paddle.is_compiled_with_cuda():
         "gpu_ops/text_image_index_out.cu",
         "gpu_ops/text_image_gather_scatter.cu",
         "gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
+        "gpu_ops/sample_kernels/top_k_renorm_probs.cu",
         "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
         "gpu_ops/fused_rotary_position_encoding.cu",
         "gpu_ops/noaux_tc.cu",
@@ -22,6 +22,7 @@ setup(
         "gpu_ops/save_with_output_msg.cc",
         "gpu_ops/get_output.cc",
         "gpu_ops/get_output_msg_with_topk.cc",
+        "gpu_ops/save_output_msg_with_topk.cc",
         "gpu_ops/transfer_output.cc",
         "cpu_ops/rebuild_padding.cc",
     ],
@@ -29,6 +29,7 @@ for output in outputs:
 ```

 ### Chat Interface (LLM.chat)

 ```python
 from fastdeploy import LLM, SamplingParams

@@ -99,6 +100,7 @@ For ```LLM``` configuration, refer to [Parameter Documentation](parameters.md).
 * repetition_penalty(float): Direct penalty for repeated tokens (>1 penalizes, <1 encourages)
 * temperature(float): Controls randomness (higher = more random)
 * top_p(float): Probability threshold for token selection
+* top_k(int): Number of tokens considered for sampling
 * max_tokens(int): Maximum generated tokens (input + output)
 * min_tokens(int): Minimum forced generation length

@@ -129,4 +131,4 @@ For ```LLM``` configuration, refer to [Parameter Documentation](parameters.md).
 * first_token_time(float): First token latency
 * time_in_queue(float): Queuing time
 * model_forward_time(float): Forward pass duration
-* model_execute_time(float): Total execution time (including preprocessing)
+* model_execute_time(float): Total execution time (including preprocessing)
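A minimal offline-inference sketch using the parameters listed above, including the newly documented top_k (model path and prompt are placeholders):

```python
from fastdeploy import LLM, SamplingParams

# Placeholder model; the generate/chat calls follow the interfaces documented above.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, top_k=20, max_tokens=128)
# llm = LLM(model="ERNIE-4.5-0.3B-Paddle")
# outputs = llm.generate(["Write a haiku about GPUs."], sampling_params)
```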
@@ -52,7 +52,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_ATTENTION_BACKEND":
     lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),

-    # Sampling class ("base", "air", or "rejection")
+    # Sampling class ("base", "base_non_truncated", "air", or "rejection")
     "FD_SAMPLING_CLASS":
     lambda: os.getenv("FD_SAMPLING_CLASS", "base"),

@@ -67,6 +67,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Switch from standalone PD to centralized inference (0 or 1)
     "FD_PD_CHANGEABLE":
     lambda: os.getenv("FD_PD_CHANGEABLE", "1"),

 }
 ```
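These switches are read through os.getenv, so they need to be set before the engine process starts; for example:

```python
import os

# Select the sampling implementation; "rejection" uses the rejection-sampling kernels
# modified above, while "base" remains the default.
os.environ.setdefault("FD_SAMPLING_CLASS", "rejection")
os.environ.setdefault("FD_ATTENTION_BACKEND", "APPEND_ATTN")
```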
@@ -100,6 +100,7 @@ for output in outputs:
 * repetition_penalty(float): Coefficient that directly penalizes repeated tokens (>1 penalizes repetition, <1 encourages it)
 * temperature(float): Controls generation randomness; higher values are more random, lower values more deterministic
 * top_p(float): Cumulative-probability truncation threshold; only the most likely tokens whose cumulative probability reaches this threshold are considered
+* top_k(int): Number of highest-probability tokens to sample from; only the k most likely tokens are considered
 * max_tokens(int): Maximum number of tokens the model may generate (input + output)
 * min_tokens(int): Minimum number of tokens the model is forced to generate, preventing premature termination
@@ -1,5 +1,6 @@
 # FastDeploy Environment Variable Reference
 FastDeploy's environment variables are defined in fastdeploy/envs.py at the repository root; the descriptions below correspond to that file:

 ```python
 environment_variables: dict[str, Callable[[], Any]] = {
     # CUDA architecture versions used when building FastDeploy, as a list of strings, e.g. [80,90]
@@ -50,7 +51,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_ATTENTION_BACKEND":
     lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),

-    # Sampling class; currently "base", "air", or "rejection"
+    # Sampling class; currently "base", "base_non_truncated", "air", or "rejection"
     "FD_SAMPLING_CLASS":
     lambda: os.getenv("FD_SAMPLING_CLASS", "base"),

@@ -65,6 +66,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Whether to switch from standalone PD separation to centralized inference
     "FD_PD_CHANGEABLE":
     lambda: os.getenv("FD_PD_CHANGEABLE", "1"),

 }
 ```
@@ -84,6 +84,7 @@ class ModelConfig(PretrainedConfig):
         head_dim: Optional[int] = None,
         tie_word_embeddings: bool = False,
         is_quantized: bool = False,
+        rms_norm_eps: float = 1e-5,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -123,6 +124,7 @@ class ModelConfig(PretrainedConfig):
         self.dtype = dtype
         self.tie_word_embeddings = tie_word_embeddings
         self.is_quantized = is_quantized
+        self.rms_norm_eps = rms_norm_eps


 @dataclass
@@ -17,13 +17,15 @@
 import paddle
 import paddle.distributed as dist


-@paddle.jit.marker.unified
-def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
-    """All-reduce the input tensor across model parallel group."""
-    if paddle.in_dynamic_mode():
-        hcg = dist.fleet.get_hybrid_communicate_group()
-        mp_group = hcg.get_model_parallel_group()
-        dist.all_reduce(input_, group=mp_group)
-    else:
-        dist.all_reduce(input_)
+try:
+    @paddle.jit.marker.unified
+    def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
+        """All-reduce the input tensor across model parallel group."""
+        if paddle.in_dynamic_mode():
+            hcg = dist.fleet.get_hybrid_communicate_group()
+            mp_group = hcg.get_model_parallel_group()
+            dist.all_reduce(input_, group=mp_group)
+        else:
+            dist.all_reduce(input_)
+except:
+    tensor_model_parallel_all_reduce = None
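Since the helper collapses to None when the @paddle.jit.marker.unified path is unavailable, call sites presumably have to guard against that; a minimal hedged sketch:

```python
import paddle
# Hypothetical caller-side guard; `all_reduce_fn` would be the (possibly None)
# tensor_model_parallel_all_reduce imported from the module above.

def maybe_all_reduce(tensor: paddle.Tensor, all_reduce_fn) -> paddle.Tensor:
    if all_reduce_fn is not None:  # helper defined: reduce across the model-parallel group
        all_reduce_fn(tensor)
    return tensor                  # helper unavailable: leave the tensor untouched
```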
@@ -296,6 +296,12 @@ class EngineArgs:
     max_capture_batch_size=64, FastDeploy will capture graphs for batches [1,64].
     """

+    enable_logprob: bool = False
+    """
+    Flag to enable logprob output. Default is False (disabled).
+    Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
+    """
+
     def __post_init__(self):
         """
         Post-initialization processing to set default tokenizer if not provided.
@@ -416,6 +422,11 @@ class EngineArgs:
             help=
             "Disabled any whitespaces when using guided decoding backend XGrammar."
         )
+        model_group.add_argument("--enable-logprob",
+                                 action="store_true",
+                                 default=EngineArgs.enable_logprob,
+                                 help="Enable output of token-level log probabilities."
+                                 )

         # Parallel processing parameters group
         parallel_group = parser.add_argument_group("Parallel Configuration")
@@ -791,4 +802,5 @@ class EngineArgs:
             max_capture_batch_size=self.max_capture_batch_size,
             guided_decoding_backend=self.guided_decoding_backend,
             disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
+            enable_logprob=self.enable_logprob,
         )
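The new switch behaves like any store_true flag; a small standalone reproduction of the argument definition above:

```python
import argparse

# Minimal reproduction of the new argument, mirroring the add_argument call above.
parser = argparse.ArgumentParser()
parser.add_argument("--enable-logprob", action="store_true", default=False,
                    help="Enable output of token-level log probabilities.")
args = parser.parse_args(["--enable-logprob"])
assert args.enable_logprob is True  # the value is then forwarded into Config (last hunk above)
```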
@@ -16,6 +16,7 @@

 import json
 import os
 from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional

@@ -467,7 +468,63 @@ class ParallelConfig:
         llm_logger.info("Parallel Configuration Information :")
         for k, v in self.__dict__.items():
             llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
-        llm_logger.info("==================")
+        llm_logger.info(
+            "=============================================================")


+@dataclass
+class CommitConfig:
+    """
+    Configuration for tracking version information from version.txt
+
+    Attributes:
+        fastdeploy_commit: Full FastDeploy git commit hash
+        paddle_version: PaddlePaddle version string
+        paddle_commit: PaddlePaddle git commit hash
+        cuda_version: CUDA version string
+        compiler_version: CXX compiler version string
+    """
+    fastdeploy_commit: str = ""
+    paddle_version: str = ""
+    paddle_commit: str = ""
+    cuda_version: str = ""
+    compiler_version: str = ""
+
+    def __post_init__(self):
+        """Automatically load version info when initialized"""
+        self._load_from_version_file()
+
+    def _load_from_version_file(self, file_path: str = "fastdeploy/version.txt"):
+        """Internal method to load version info from file"""
+        try:
+            with open(file_path, 'r') as f:
+                for line in f:
+                    line = line.strip()
+                    if line.startswith("fastdeploy GIT COMMIT ID:"):
+                        self.fastdeploy_commit = line.split(":")[1].strip()
+                    elif line.startswith("Paddle version:"):
+                        self.paddle_version = line.split(":")[1].strip()
+                    elif line.startswith("Paddle GIT COMMIT ID:"):
+                        self.paddle_commit = line.split(":")[1].strip()
+                    elif line.startswith("CUDA version:"):
+                        self.cuda_version = line.split(":")[1].strip()
+                    elif line.startswith("CXX compiler version:"):
+                        self.compiler_version = line.split(":")[1].strip()
+        except FileNotFoundError:
+            llm_logger.info(f"Warning: Version file not found at {file_path}")
+        except Exception as e:
+            llm_logger.info(f"Warning: Could not read version file - {str(e)}")
+
+    def print(self):
+        """
+        print all config
+
+        """
+        llm_logger.info("Fasedeploy Commit Information :")
+        for k, v in self.__dict__.items():
+            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
+        llm_logger.info(
+            "=============================================================")
+
+
 class Config:
@@ -502,6 +559,7 @@ class Config:
         cache_config: CacheConfig,
         scheduler_config: SchedulerConfig,
         parallel_config: ParallelConfig,
+        commit_config: CommitConfig = CommitConfig(),
         model_name_or_path: str = None,
         tokenizer: str = None,
         tensor_parallel_size: int = 8,
@@ -527,6 +585,7 @@ class Config:
         max_capture_batch_size: int = 64,
         guided_decoding_backend: Optional[str] = None,
         disable_any_whitespace: bool = False,
+        enable_logprob: bool = False,
     ):
         """
         Initialize the Config class.
@@ -561,6 +620,7 @@ class Config:
         self.cache_config = cache_config
         self.scheduler_config = scheduler_config
         self.parallel_config = parallel_config
+        self.commit_config = commit_config
         self.model_name_or_path = model_name_or_path
         self.tokenizer = tokenizer
         self.max_num_batched_tokens = max_num_batched_tokens
@@ -619,6 +679,8 @@ class Config:
                     self.parallel_config.expert_parallel_size), 8))])
         self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)

+        self.enable_logprob = enable_logprob
+
         self.read_from_config()
         self.postprocess()
         self.check()
@@ -749,7 +811,11 @@ class Config:
             if k == "generation_config" and v is not None:
                 for gck, gcv in v.to_dict().items():
                     llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
-            elif k == "cache_config" or k == "model_config" or k == "scheduler_config" or k == "parallel_config":
+            elif (k == "cache_config" or
+                  k == "model_config" or
+                  k == "scheduler_config" or
+                  k == "parallel_config" or
+                  k == "commit_config"):
                 v.print()
             else:
                 llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
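CommitConfig._load_from_version_file only looks at line prefixes, so a version.txt shaped like the following is what it expects (the values here are placeholders):

```python
# Example of the version.txt layout parsed by _load_from_version_file();
# only the "key:" prefixes matter to the parser, the values are illustrative.
sample_version_txt = """\
fastdeploy GIT COMMIT ID: 0123456789abcdef
Paddle version: 3.0.0
Paddle GIT COMMIT ID: fedcba9876543210
CUDA version: 12.3
CXX compiler version: 11.4.0
"""
```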
@@ -47,7 +47,8 @@ from fastdeploy.output.token_processor import (TokenProcessor,
                                                WarmUpTokenProcessor)
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
 from fastdeploy.utils import EngineError, console_logger, llm_logger
-
+from fastdeploy.metrics.trace_util import extract_from_metadata, start_span, start_span_request
+from opentelemetry import trace

 class LLMEngine(object):
     """
@@ -165,12 +166,6 @@ class LLMEngine(object):
             disable_any_whitespace=self.cfg.disable_any_whitespace,
         )

-    def reset_scheduler(self):
-        """
-        Reset the scheduler to its initial state.
-        """
-        self.scheduler.reset()
-
     def start(self, api_server_pid=None):
         """
         Initializes the engine and starts its sub-services.
@@ -381,7 +376,10 @@ class LLMEngine(object):
             request, insert_task = None, []
             results: List[Tuple[str, Optional[str]]] = list()
             if data:
-                request = Request.from_dict(data)
+                request = Request.from_dict(data)
+                start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)


             llm_logger.debug(f"Receive request: {request}")

             err_msg = None
@@ -712,6 +710,8 @@ class LLMEngine(object):
         """
         Insert tasks to engine.
         """
+        for task in tasks:
+            start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
         # TODO: return to the scheduler
         if allocated:
             current_tasks = []
@@ -1021,8 +1021,8 @@ class LLMEngine(object):
         py_script = os.path.join(current_dir_path, worker_path)

         ori_vocab_size = (
-            len(self.data_processor.tokenizer.sp_model)
-            if hasattr(self.data_processor.tokenizer, 'sp_model')
+            len(self.data_processor.tokenizer.sp_model)
+            if hasattr(self.data_processor.tokenizer, 'sp_model')
             else len(self.data_processor.tokenizer.vocab)
         )

@@ -1068,6 +1068,7 @@ class LLMEngine(object):
             self.cfg.enable_static_graph_inference,
             "use_cudagraph": self.cfg.use_cudagraph,
             "disable_any_whitespace": self.cfg.disable_any_whitespace,
+            "enable_logprob": self.cfg.enable_logprob,
         }
         for worker_flag, value in worker_append_flag.items():
             if value:
@@ -24,6 +24,7 @@ import numpy

 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.utils import data_processor_logger
+from fastdeploy.worker.output import LogprobsLists


 @dataclass
@@ -54,7 +55,8 @@ class Request:
                  guided_grammar: Optional[Any] = None,
                  structural_tag: Optional[Any] = None,
                  guided_json_object: Optional[bool] = None,
-                 enable_thinking: Optional[bool] = True) -> None:
+                 enable_thinking: Optional[bool] = True,
+                 trace_carrier: dict = dict()) -> None:
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
@@ -90,6 +92,7 @@ class Request:
         self.multimodal_data = multimodal_data

         self.enable_thinking = enable_thinking
+        self.trace_carrier = trace_carrier

     @classmethod
     def from_dict(cls, d: dict):
@@ -119,7 +122,8 @@ class Request:
                    guided_grammar=d.get("guided_grammar", None),
                    structural_tag=d.get("structural_tag", None),
                    guided_json_object=d.get("guided_json_object", None),
-                   enable_thinking=d.get("enable_thinking", True))
+                   enable_thinking=d.get("enable_thinking", True),
+                   trace_carrier=d.get("trace_carrier", {}))

     def to_dict(self) -> dict:
         """convert Request into a serializable dict """
@@ -141,7 +145,8 @@ class Request:
             "raw_request": self.raw_request,
             "disaggregate_info": self.disaggregate_info,
             "draft_token_ids": self.draft_token_ids,
-            "enable_thinking": self.enable_thinking
+            "enable_thinking": self.enable_thinking,
+            "trace_carrier": self.trace_carrier
         }
         add_params = [
             "guided_json", "guided_regex", "guided_choice", "guided_grammar",
@@ -189,6 +194,8 @@ class CompletionOutput:
     index: int
     send_idx: int
     token_ids: list[int]
+    logprob: Optional[float] = None
+    top_logprobs: Optional[LogprobsLists] = None
     draft_token_ids: list[int] = None
     text: Optional[str] = None
     reasoning_content: Optional[str] = None
@@ -201,6 +208,8 @@ class CompletionOutput:
             "index": self.index,
             "send_idx": self.send_idx,
             "token_ids": self.token_ids,
+            "logprob": self.logprob,
+            "top_logprobs": self.top_logprobs,
             "draft_token_ids": self.draft_token_ids,
             "text": self.text,
             "reasoning_content": self.reasoning_content
@@ -52,6 +52,7 @@ class SamplingParams:
         the model more random. Zero means greedy sampling.
     top_p: Float that controls the cumulative probability of the top tokens
         to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
+    top_k: Int that controls the number of top tokens to consider. Must be a positive integer.
     seed: Random seed to use for the generation.
     stop: list of strings that stop the generation when they are generated.
         The returned output will not contain the stop strings.
@@ -82,6 +83,7 @@ class SamplingParams:
     repetition_penalty: float = None
     temperature: float = None
     top_p: float = None
+    top_k: int = 0
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
     stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
@@ -111,6 +113,7 @@ class SamplingParams:
                    repetition_penalty,
                    temperature,
                    top_p,
+                   top_k,
                    seed=None,
                    stop=None,
                    stop_token_ids=None,
@@ -129,7 +132,8 @@ class SamplingParams:
                    repetition_penalty=repetition_penalty
                    if repetition_penalty is not None else 1.0,
                    temperature=temperature if temperature is not None else 1.0,
-                   top_p=top_p if top_p is not None else 0.7,
+                   top_p=top_p,
+                   top_k=top_k if top_k is not None else 0,
                    seed=seed,
                    stop=stop,
                    stop_token_ids=stop_token_ids,
@@ -169,6 +173,13 @@ class SamplingParams:
                 f"temperature must be non-negative, got {self.temperature}.")
         if self.top_p is not None and not 0.0 <= self.top_p <= 1.0:
             raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.")
+        # quietly accept -1 as disabled, but prefer 0
+        if self.top_k < -1:
+            raise ValueError(f"top_k must be 0 (disable), or at least 1, "
+                             f"got {self.top_k}.")
+        if not isinstance(self.top_k, int):
+            raise TypeError(
+                f"top_k must be an integer, got {type(self.top_k).__name__}")

         if self.max_tokens is not None and self.max_tokens < 1:
             raise ValueError(
@@ -188,6 +199,9 @@ class SamplingParams:
         if self.logprobs is not None and self.logprobs < 0:
             raise ValueError(
                 f"logprobs must be non-negative, got {self.logprobs}.")
+        if self.logprobs is not None and self.logprobs > 20:
+            raise ValueError(
+                "Invalid value for 'top_logprobs': must be less than or equal to 20.")

         if not 0 <= self.seed <= 922337203685477580:
             raise ValueError("seed must be in [0, 922337203685477580], got "
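A quick sketch of what the new validation accepts and rejects, inferred directly from the checks above (the import path is the one used in the docs earlier):

```python
from fastdeploy import SamplingParams

SamplingParams(temperature=0.8, top_p=0.9, top_k=20)   # valid: sample from the 20 most likely tokens
SamplingParams(temperature=0.8, top_p=0.9, top_k=0)    # valid: 0 disables top-k filtering
SamplingParams(temperature=0.8, top_p=0.9, top_k=-1)   # also accepted as "disabled" (see comment above)
# SamplingParams(top_k=-2)   -> ValueError: top_k must be 0 (disable), or at least 1
# SamplingParams(top_k=1.5)  -> TypeError: top_k must be an integer
```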
@@ -24,6 +24,7 @@ import zmq
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from prometheus_client import CONTENT_TYPE_LATEST
+from fastdeploy.metrics.trace_util import inject_to_metadata, instrument

 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.engine.engine import LLMEngine
@@ -32,7 +33,8 @@ from fastdeploy.entrypoints.openai.protocol import (ChatCompletionRequest,
                                                     ChatCompletionResponse,
                                                     CompletionRequest,
                                                     CompletionResponse,
-                                                    ErrorResponse)
+                                                    ErrorResponse,
+                                                    ControlSchedulerRequest)
 from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
 from fastdeploy.entrypoints.openai.serving_completion import \
     OpenAIServingCompletion
@@ -44,6 +46,7 @@ from fastdeploy.utils import (FlexibleArgumentParser, api_server_logger,
                               console_logger, is_port_available,
                               retrive_model_from_server)


 parser = FlexibleArgumentParser()
 parser.add_argument("--port",
                     default=8000,
@@ -139,6 +142,7 @@ async def lifespan(app: FastAPI):


 app = FastAPI(lifespan=lifespan)
+instrument(app)


 # TODO: pass the real engine state, fetched via pid
@@ -209,6 +213,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
         return JSONResponse(
             content={"error": "Worker Service Not Healthy"},
             status_code=304)
+    inject_to_metadata(request)
     generator = await app.state.chat_handler.create_chat_completion(request)

     if isinstance(generator, ErrorResponse):
@@ -273,10 +278,13 @@ def clear_load_weight(request: Request) -> Response:
                         status_code=404)


-def launch_api_server(args) -> None:
+def launch_api_server() -> None:
     """
     Start the HTTP server.
     """
+    if not is_port_available(args.host, args.port):
+        raise Exception(f"The parameter `port`:{args.port} is already in use.")
+
     api_server_logger.info(
         f"launch Fastdeploy api server... port: {args.port}")
     api_server_logger.info(f"args: {args.__dict__}")
@@ -319,6 +327,11 @@ def run_metrics_server():

 def launch_metrics_server():
     """Metrics server running the sub thread"""
+    if not is_port_available(args.host, args.metrics_port):
+        raise Exception(
+            f"The parameter `metrics_port`:{args.metrics_port} is already in use."
+        )
+
     prom_dir = cleanup_prometheus_files(True)
     os.environ["PROMETHEUS_MULTIPROC_DIR"] = prom_dir
     metrics_server_thread = threading.Thread(target=run_metrics_server,
@@ -339,10 +352,39 @@ def reset_scheduler():

     if llm_engine is None:
         return Response("Engine not loaded", status_code=500)
-    llm_engine.reset_scheduler()
+    llm_engine.scheduler.reset()
     return Response("Scheduler Reset Successfully", status_code=200)


+@controller_app.post("/controller/scheduler")
+def control_scheduler(request: ControlSchedulerRequest):
+    """
+    Control the scheduler behavior with the given parameters.
+    """
+    content = ErrorResponse(object="", message="Scheduler updated successfully", code=0)
+
+    global llm_engine
+    if llm_engine is None:
+        content.message = "Engine is not loaded"
+        content.code = 500
+        return JSONResponse(content=content.model_dump(), status_code=500)
+
+    if request.reset:
+        llm_engine.scheduler.reset()
+
+    if request.load_shards_num or request.reallocate_shard:
+        if hasattr(llm_engine.scheduler, "update_config") and callable(llm_engine.scheduler.update_config):
+            llm_engine.scheduler.update_config(
+                load_shards_num=request.load_shards_num,
+                reallocate=request.reallocate_shard)
+        else:
+            content.message = "This scheduler doesn't support the `update_config()` method."
+            content.code = 400
+            return JSONResponse(content=content.model_dump(), status_code=400)
+
+    return JSONResponse(content=content.model_dump(), status_code=200)
+
+
 def run_controller_server():
     """
     run controller server
@@ -358,6 +400,11 @@ def launch_controller_server():
     if args.controller_port < 0:
         return

+    if not is_port_available(args.host, args.controller_port):
+        raise Exception(
+            f"The parameter `controller_port`:{args.controller_port} is already in use."
+        )
+
     controller_server_thread = threading.Thread(target=run_controller_server,
                                                 daemon=True)
     controller_server_thread.start()
@@ -366,19 +413,13 @@ def launch_controller_server():

 def main():
     """Main entry point."""
-    if not is_port_available(args.host, args.port):
-        raise Exception(f"The parameter `port`:{args.port} is already in use.")
-    if not is_port_available(args.host, args.metrics_port):
-        raise Exception(
-            f"The parameter `metrics_port`:{args.metrics_port} is already in use."
-        )
-
     if load_engine() is None:
         return

     launch_controller_server()
     launch_metrics_server()
-    launch_api_server(args)
+    launch_api_server()


 if __name__ == "__main__":
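A hedged client sketch for the new /controller/scheduler route; host and controller port are placeholders, and the body fields mirror ControlSchedulerRequest:

```python
import json
import urllib.request

# reset / load_shards_num / reallocate_shard map onto the ControlSchedulerRequest model.
payload = {"reset": True, "load_shards_num": None, "reallocate_shard": False}
req = urllib.request.Request(
    "http://localhost:8001/controller/scheduler",  # controller host/port are placeholders
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
# with urllib.request.urlopen(req) as resp:
#     print(resp.status, json.load(resp))  # 200 and "Scheduler updated successfully" on success
```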
@@ -122,6 +122,7 @@ class ChatCompletionResponseChoice(BaseModel):
     """
     index: int
     message: ChatMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]


@@ -136,6 +137,21 @@ class ChatCompletionResponse(BaseModel):
     choices: List[ChatCompletionResponseChoice]
     usage: UsageInfo

+
+class LogProbEntry(BaseModel):
+    """
+    Log probability entry.
+    """
+    token: str
+    logprob: float
+    bytes: Optional[List[int]] = None
+    top_logprobs: Optional[List["LogProbEntry"]] = None
+
+
+class LogProbs(BaseModel):
+    """
+    LogProbs.
+    """
+    content: Optional[List[LogProbEntry]] = None
+    refusal: Optional[Union[str, None]] = None
+

 class DeltaMessage(BaseModel):
     """
@@ -154,6 +170,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     """
     index: int
     delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
     arrival_time: Optional[float] = None

@@ -292,6 +309,7 @@ class CompletionRequest(BaseModel):
     suffix: Optional[dict] = None
     temperature: Optional[float] = None
     top_p: Optional[float] = None
+    top_k: Optional[int] = None
     user: Optional[str] = None

     response_format: Optional[AnyResponseFormat] = None
@@ -391,6 +409,8 @@ class ChatCompletionRequest(BaseModel):
     tools: Optional[List[ChatCompletionToolsParam]] = None
     model: Optional[str] = "default"
     frequency_penalty: Optional[float] = None
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = 0
     # remove max_tokens when field is removed from OpenAI API
     max_tokens: Optional[int] = Field(
         default=None,
@@ -405,6 +425,7 @@ class ChatCompletionRequest(BaseModel):
     stream_options: Optional[StreamOptions] = None
     temperature: Optional[float] = None
     top_p: Optional[float] = None
+    top_k: Optional[int] = None
     user: Optional[str] = None
     metadata: Optional[dict] = None

@@ -432,6 +453,9 @@ class ChatCompletionRequest(BaseModel):
         if request_id is not None:
             req_dict['request_id'] = request_id

+        req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
+        req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
+
         if self.metadata is not None:
             for key, value in self.metadata.items():
                 req_dict[key] = value
@@ -503,3 +527,27 @@ class ChatCompletionRequest(BaseModel):
         )

         return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+
+        if (top_logprobs := data.get("top_logprobs")) is not None:
+            if top_logprobs < 0:
+                raise ValueError("`top_logprobs` must be a positive value.")
+
+            if top_logprobs > 0 and not data.get("logprobs"):
+                raise ValueError(
+                    "when using `top_logprobs`, `logprobs` must be set to true."
+                )
+
+        return data
+
+
+class ControlSchedulerRequest(BaseModel):
+    """
+    Control scheduler request to the engine.
+    """
+    reset: Optional[bool] = False
+    load_shards_num: Optional[int] = None
+    reallocate_shard: Optional[bool] = False
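Putting the new request fields together, a hedged chat-completions call that asks for per-token logprobs (endpoint and model are placeholders; check_logprobs above rejects top_logprobs > 0 unless logprobs is true, and the SamplingParams validation earlier caps it at 20):

```python
import json
import urllib.request

payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Hello!"}],
    "logprobs": True,
    "top_logprobs": 5,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/chat/completions",  # placeholder host/port
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
# with urllib.request.urlopen(req) as resp:
#     print(json.load(resp)["choices"][0]["logprobs"])
```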
@@ -15,34 +15,23 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import aiozmq
|
||||
from aiozmq import zmq
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Callable, Optional, Union, List
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
import aiozmq
|
||||
from aiozmq import zmq
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel
|
||||
from fastdeploy.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatMessage,
|
||||
UsageInfo,
|
||||
PromptTokenUsageInfo,
|
||||
ChatCompletionResponse,
|
||||
ErrorResponse,
|
||||
)
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
||||
LogProbEntry, LogProbs, PromptTokenUsageInfo, UsageInfo)
|
||||
from fastdeploy.metrics.work_metrics import work_process_metrics
|
||||
|
||||
from fastdeploy.utils import api_server_logger
|
||||
|
||||
from fastdeploy.engine.request import RequestOutput
|
||||
|
||||
from fastdeploy.worker.output import LogprobsLists
|
||||
|
||||
|
||||
class OpenAIServingChat:
|
||||
@@ -115,6 +104,7 @@ class OpenAIServingChat:
|
||||
num_choices = 1
|
||||
max_streaming_response_tokens = 1
|
||||
enable_thinking = None
|
||||
include_stop_str_in_output = False
|
||||
if request.metadata is not None and request.metadata.get("max_streaming_response_tokens", 1) > 1:
|
||||
max_streaming_response_tokens = request.metadata["max_streaming_response_tokens"]
|
||||
|
||||
@@ -157,14 +147,15 @@ class OpenAIServingChat:
|
||||
current_waiting_time = 0
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
|
||||
res = json.loads(raw_data[-1].decode('utf-8'))
|
||||
if res.get("error_code", 200) != 200:
|
||||
raise ValueError("{}".format(res["error_msg"]))
|
||||
if request.metadata is not None:
|
||||
enable_thinking = request.metadata.get("enable_thinking")
|
||||
include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
|
||||
self.engine_client.data_processor.process_response_dict(
|
||||
res, stream=True, enable_thinking=enable_thinking)
|
||||
res, stream=True, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
|
||||
|
||||
if res['metrics']['first_token_time'] is not None:
|
||||
arrival_time = res['metrics']['first_token_time']
|
||||
@@ -200,6 +191,19 @@ class OpenAIServingChat:
|
||||
|
||||
output = res["outputs"]
|
||||
delta_text = output["text"]
|
||||
raw_top_logprobs = output["top_logprobs"]
|
||||
logprobs_res = None
|
||||
if raw_top_logprobs is not None:
|
||||
top_logprobs = LogprobsLists(
|
||||
logprob_token_ids=raw_top_logprobs[0],
|
||||
logprobs=raw_top_logprobs[1],
|
||||
sampled_token_ranks=raw_top_logprobs[2],
|
||||
)
|
||||
logprobs_res = self.build_logprobs_response(
|
||||
request_logprobs= request.logprobs,
|
||||
response_logprobs=top_logprobs,
|
||||
request_top_logprobs=request.top_logprobs,
|
||||
)
|
||||
|
||||
previous_num_tokens += len(output["token_ids"])
|
||||
delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \
|
||||
@@ -208,19 +212,22 @@ class OpenAIServingChat:
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=delta_message,
|
||||
logprobs=logprobs_res,
|
||||
arrival_time=arrival_time
|
||||
)
|
||||
if res["finished"]:
|
||||
num_choices -= 1
|
||||
work_process_metrics.e2e_request_latency.observe(time.time() - res["metrics"]["request_start_time"])
|
||||
if request.max_tokens is None or previous_num_tokens != request.max_tokens:
|
||||
has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
|
||||
max_tokens = request.max_completion_tokens or request.max_tokens
|
||||
if has_no_token_limit or previous_num_tokens != max_tokens:
|
||||
choice.finish_reason = "stop"
|
||||
if self.engine_client.reasoning_parser == "ernie_x1" and \
|
||||
output.get("finish_reason", "") == "tool_calls":
|
||||
choice.finish_reason = "tool_calls"
|
||||
else:
|
||||
choice.finish_reason = "length"
|
||||
|
||||
|
||||
if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
|
||||
choice.finish_reason = "recover_stop"
|
||||
|
||||
@@ -277,6 +284,7 @@ class OpenAIServingChat:
|
||||
created_time = int(time.time())
|
||||
final_res = None
|
||||
enable_thinking = None
|
||||
include_stop_str_in_output = False
|
||||
try:
|
||||
dealer = await aiozmq.create_zmq_stream(
|
||||
zmq.DEALER,
|
||||
@@ -286,6 +294,7 @@ class OpenAIServingChat:
|
||||
final_res = None
|
||||
previous_num_tokens = 0
|
||||
current_waiting_time = 0
|
||||
logprob_contents = []
|
||||
while True:
|
||||
try:
|
||||
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
|
||||
@@ -306,10 +315,27 @@ class OpenAIServingChat:
|
||||
raise ValueError("{}".format(data["error_msg"]))
|
||||
if request.metadata is not None:
|
||||
enable_thinking = request.metadata.get("enable_thinking")
|
||||
include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
|
||||
data = self.engine_client.data_processor.process_response_dict(
|
||||
data, stream=False, enable_thinking=enable_thinking)
|
||||
data, stream=False, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
|
||||
# api_server_logger.debug(f"Client {request_id} received: {data}")
|
||||
previous_num_tokens += len(data["outputs"]["token_ids"])
|
||||
# The logprob for handling the response
|
||||
output = data["outputs"]
|
||||
raw_top_logprobs = output["top_logprobs"]
|
||||
if raw_top_logprobs is not None:
|
||||
top_logprobs = LogprobsLists(
|
||||
logprob_token_ids=raw_top_logprobs[0],
|
||||
logprobs=raw_top_logprobs[1],
|
||||
sampled_token_ranks=raw_top_logprobs[2],
|
||||
)
|
||||
logprobs_res = self.build_logprobs_response(
|
||||
request_logprobs=request.logprobs,
|
||||
response_logprobs=top_logprobs,
|
||||
request_top_logprobs=request.top_logprobs,
|
||||
)
|
||||
if logprobs_res and logprobs_res.content is not None:
|
||||
logprob_contents.extend(logprobs_res.content)
|
||||
if data["finished"]:
|
||||
final_res = data
|
||||
break
|
||||
@@ -325,20 +351,28 @@ class OpenAIServingChat:
|
||||
tool_calls=output.get("tool_call_content"),
|
||||
token_ids=output.get("token_ids")
|
||||
)
|
||||
logprobs_full_res = None
|
||||
if logprob_contents:
|
||||
logprobs_full_res = LogProbs(
|
||||
content=logprob_contents
|
||||
)
|
||||
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=message,
|
||||
logprobs=logprobs_full_res,
|
||||
finish_reason=None
|
||||
)
|
||||
if request.max_tokens is None or previous_num_tokens != request.max_tokens:
|
||||
has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
|
||||
max_tokens = request.max_completion_tokens or request.max_tokens
|
||||
if has_no_token_limit or previous_num_tokens != max_tokens:
|
||||
choice.finish_reason = "stop"
|
||||
if self.engine_client.reasoning_parser == "ernie_x1" and \
|
||||
output.get("finish_reason", "") == "tool_calls":
|
||||
choice.finish_reason = "tool_calls"
|
||||
else:
|
||||
choice.finish_reason = "length"
|
||||
|
||||
|
||||
if final_res.get("error_msg") is not None and "Recover" in final_res["error_msg"]:
|
||||
choice.finish_reason = "recover_stop"
|
||||
choices.append(choice)
|
||||
@@ -359,3 +393,55 @@ class OpenAIServingChat:
|
||||
choices=choices,
|
||||
usage=usage
|
||||
)
|
||||
|
||||
def build_logprobs_response(
|
||||
self,
|
||||
request_logprobs: bool,
|
||||
response_logprobs: Optional[LogprobsLists],
|
||||
request_top_logprobs: int,
|
||||
) -> Optional[LogProbs]:
|
||||
"""
|
||||
Construct a logprobs response object in line with the OpenAI style.
|
||||
Retain the complete top-k candidates and avoid circular references.
|
||||
"""
|
||||
|
||||
# Parameter validation
|
||||
if (
|
||||
response_logprobs is None
|
||||
or not request_logprobs
|
||||
or request_top_logprobs is None
|
||||
or request_top_logprobs < 0
|
||||
):
|
||||
return None
|
||||
|
||||
try:
|
||||
# The top-k candidates for the current token
|
||||
topk_token_ids = response_logprobs.logprob_token_ids[0][:request_top_logprobs + 1]
|
||||
topk_logprobs = response_logprobs.logprobs[0][:request_top_logprobs + 1]
|
||||
|
||||
# Construct the candidate token structure (LogProbEntry) of topk
|
||||
top_logprob_entries: List[LogProbEntry] = []
|
||||
for tid, lp in zip(topk_token_ids, topk_logprobs):
|
||||
token_str = self.engine_client.data_processor.process_logprob_response([tid],
|
||||
clean_up_tokenization_spaces=False)
|
||||
# token_bytes = token_str.encode("utf-8", errors="replace")
|
||||
entry = LogProbEntry(
|
||||
token=token_str,
|
||||
logprob=lp,
|
||||
# bytes=list(token_bytes)
|
||||
)
|
||||
top_logprob_entries.append(entry)
|
||||
# Construct the sampled token object (avoid sharing references with top_logprob_entries)
|
||||
sampled_entry = LogProbEntry(
|
||||
token=top_logprob_entries[0].token,
|
||||
logprob=top_logprob_entries[0].logprob,
|
||||
bytes=top_logprob_entries[0].bytes,
|
||||
top_logprobs=top_logprob_entries[1:] # Here are the complete topk candidates
|
||||
)
|
||||
|
||||
return LogProbs(content=[sampled_entry])
|
||||
|
||||
except Exception as e:
|
||||
api_server_logger.error("Error in build_logprobs_response: %s", e)
|
||||
api_server_logger.error(traceback.format_exc())
|
||||
return None
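For illustration only, a minimal sketch of how this helper is meant to be called for one decode step; `serving_chat` stands for a hypothetical OpenAIServingChat instance and the token ids/logprob values are made up:

from fastdeploy.worker.output import LogprobsLists  # import path as used elsewhere in this change

step_logprobs = LogprobsLists(
    logprob_token_ids=[[1023, 88, 412, 7]],   # sampled token first, then top-k candidates
    logprobs=[[-0.11, -2.30, -3.02, -4.25]],
    sampled_token_ranks=[0],
)
logprobs_res = serving_chat.build_logprobs_response(   # serving_chat is assumed, not from this diff
    request_logprobs=True,
    response_logprobs=step_logprobs,
    request_top_logprobs=3,
)
# logprobs_res.content[0].token is the sampled token's text; its .top_logprobs
# list carries the remaining top-k candidates for that position.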
|
||||
|
@@ -74,7 +74,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"FD_ATTENTION_BACKEND":
|
||||
lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),
|
||||
|
||||
# Set sampling class. "base", "air" and "rejection" can be set currently.
|
||||
# Set sampling class. "base", "base_non_truncated", "air" and "rejection" can be set currently.
|
||||
"FD_SAMPLING_CLASS":
|
||||
lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
|
||||
|
||||
@@ -97,6 +97,30 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# Whether to use fastsafetensor load weight (0 or 1)
|
||||
"FD_USE_FASTSAFETENSOR":
|
||||
lambda: os.getenv("FD_USE_FASTSAFETENSOR", "0"),
|
||||
|
||||
# Whether to enable tracing.
|
||||
"TRACES_ENABLE":
|
||||
lambda: os.getenv("TRACES_ENABLE", "false"),
|
||||
|
||||
# set trace service name.
|
||||
"FD_SERVICE_NAME":
|
||||
lambda: os.getenv("FD_SERVICE_NAME", "FastDeploy"),
|
||||
|
||||
# set trace host name.
|
||||
"FD_HOST_NAME":
|
||||
lambda: os.getenv("FD_HOST_NAME", "localhost"),
|
||||
|
||||
# set trace exporter.
|
||||
"TRACES_EXPORTER":
|
||||
lambda: os.getenv("TRACES_EXPORTER", "console"),
|
||||
|
||||
# set trace exporter_otlp_endpoint.
|
||||
"EXPORTER_OTLP_ENDPOINT":
|
||||
lambda: os.getenv("EXPORTER_OTLP_ENDPOINT"),
|
||||
|
||||
# set trace exporter_otlp_headers.
|
||||
"EXPORTER_OTLP_HEADERS":
|
||||
lambda: os.getenv("EXPORTER_OTLP_HEADERS"),
|
||||
}
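As a hedged example (the OTLP endpoint and header values below are placeholders, not defaults from this change), the tracing-related variables above could be set before launching the server; TRACES_EXPORTER falls back to "console" when unset:

import os

os.environ["TRACES_ENABLE"] = "true"
os.environ["TRACES_EXPORTER"] = "otlp"
os.environ["EXPORTER_OTLP_ENDPOINT"] = "http://localhost:4318/v1/traces"  # assumed collector address
os.environ["EXPORTER_OTLP_HEADERS"] = "Authentication=my-token,k2=v2"     # "k=v,k2=v2" format
os.environ["FD_SERVICE_NAME"] = "FastDeploy"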
|
||||
|
||||
|
||||
|
@@ -20,10 +20,9 @@ import numpy as np
|
||||
from paddleformers.generation import GenerationConfig
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
|
||||
|
||||
from fastdeploy.input.text_processor import BaseDataProcessor
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
@@ -101,7 +100,6 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
|
||||
if request.prompt_token_ids is None or len(
|
||||
request.prompt_token_ids) == 0:
|
||||
system = request.get("system")
|
||||
if request.prompt is None and request.messages is None:
|
||||
raise ValueError(
|
||||
f"The request should have `input_ids`, `text` or `messages`: {request}.")
|
||||
@@ -150,7 +148,6 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
request['stop_token_ids'] = stop_seqs
|
||||
request['stop_seqs_len'] = stop_seqs_len
|
||||
|
||||
system = request.get("system")
|
||||
# 处理prompt_token_ids
|
||||
if not request.get('prompt_token_ids'):
|
||||
if request.get('prompt') is None and request.get(
|
||||
@@ -250,7 +247,7 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
token_ids = response_dict["outputs"]["token_ids"]
|
||||
is_end = response_dict["finished"]
|
||||
req_id = response_dict["request_id"]
|
||||
if is_end and len(token_ids) > 0:
|
||||
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
@@ -285,7 +282,7 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
req_id = response_dict["request_id"]
|
||||
token_ids = response_dict["outputs"]["token_ids"]
|
||||
|
||||
if is_end and len(token_ids) > 0:
|
||||
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(
|
||||
@@ -444,3 +441,7 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
data_processor_logger.debug(
|
||||
f"processed stop_seqs: {stop_seqs}, {stop_seqs_len}")
|
||||
return stop_seqs, stop_seqs_len
|
||||
|
||||
def process_logprob_response(self, token_ids, **kwargs):
|
||||
full_text = self.tokenizer.decode(token_ids, **kwargs)
|
||||
return full_text
|
||||
|
@@ -143,7 +143,7 @@ class ErnieBotTokenizer(PretrainedTokenizer):
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
spec_init()
|
||||
self.spec_init()
|
||||
current_sub_tokens = []
|
||||
out_string = ""
|
||||
# prev_is_special = False
|
||||
@@ -216,7 +216,7 @@ class ErnieBotTokenizer(PretrainedTokenizer):
|
||||
# if isinstance(t, AddedToken)
|
||||
# )
|
||||
|
||||
spec_init()
|
||||
self.spec_init()
|
||||
text, kwargs = self.prepare_for_tokenization(text, **kwargs)
|
||||
|
||||
# TODO: should this be in the base class?
|
||||
|
@@ -309,6 +309,10 @@ class DataProcessor(BaseDataProcessor):
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
return request
|
||||
|
||||
def process_logprob_response(self, token_ids, **kwargs):
|
||||
full_text = self.tokenizer.decode(token_ids, **kwargs)
|
||||
return full_text
|
||||
|
||||
def process_response(self, response_dict, **kwargs):
|
||||
"""
|
||||
Preprocess the response
|
||||
@@ -351,7 +355,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
token_ids = response_dict["outputs"]["token_ids"]
|
||||
is_end = response_dict["finished"]
|
||||
req_id = response_dict["request_id"]
|
||||
if is_end and len(token_ids) > 0:
|
||||
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
@@ -386,7 +390,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
req_id = response_dict["request_id"]
|
||||
token_ids = response_dict["outputs"]["token_ids"]
|
||||
|
||||
if is_end and len(token_ids) > 0:
|
||||
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(
|
||||
@@ -426,7 +430,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
response_dict, enable_thinking=enable_thinking, **kwargs)
|
||||
else:
|
||||
return self.process_response_dict_normal(
|
||||
response_dict=response_dict, enable_thinking=enable_thinking)
|
||||
response_dict=response_dict, enable_thinking=enable_thinking, **kwargs)
|
||||
|
||||
def text2ids(self, text, max_model_len, raw_request=True):
|
||||
"""
|
||||
|
198
fastdeploy/metrics/trace_util.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from opentelemetry.propagate import inject, extract
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.trace.export import ConsoleSpanExporter
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||
from fastapi import FastAPI
|
||||
from fastdeploy.utils import llm_logger
|
||||
from fastdeploy import envs
|
||||
import json
|
||||
|
||||
|
||||
# OpenTelemetry Trace context store in metadata
|
||||
TRACE_CARRIER = "trace_carrier"
|
||||
|
||||
traces_enable = False
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
def set_up():
|
||||
try:
|
||||
# start tracing when TRACES_ENABLE=true
|
||||
global traces_enable
|
||||
traces_enable = envs.TRACES_ENABLE.lower() == "true"
|
||||
if not traces_enable:
|
||||
llm_logger.warning("Opentelemetry is DISABLED.")
|
||||
return
|
||||
|
||||
llm_logger.info("Opentelemetry is ENABLED, configuring...")
|
||||
# --- read env ---
|
||||
service_name = envs.FD_SERVICE_NAME
|
||||
host_name = envs.FD_HOST_NAME
|
||||
# --- set attributes (Service Name, Host Name, etc.) ---
|
||||
resource_attributes = {
|
||||
"service.name": service_name
|
||||
}
|
||||
if host_name:
|
||||
resource_attributes["host.name"] = host_name
|
||||
|
||||
resource = Resource(attributes=resource_attributes)
|
||||
|
||||
# --- set Exporter ---
|
||||
exporter_type = envs.TRACES_EXPORTER.lower()
|
||||
if exporter_type == "otlp":
|
||||
endpoint = envs.EXPORTER_OTLP_ENDPOINT # should be set
|
||||
headers = envs.EXPORTER_OTLP_HEADERS # e.g., "Authentication=***,k2=v2"
|
||||
|
||||
otlp_exporter = OTLPSpanExporter(
|
||||
endpoint=endpoint,
|
||||
headers=dict(item.split("=") for item in headers.split(",")) if headers else None
|
||||
)
|
||||
processor = BatchSpanProcessor(otlp_exporter)
|
||||
llm_logger.info(f"Using OTLP Exporter, sending to {endpoint} with headers {headers}")
|
||||
else: # default console
|
||||
processor = BatchSpanProcessor(ConsoleSpanExporter())
|
||||
llm_logger.info("Using Console Exporter.")
|
||||
|
||||
# --- set Tracer Provider ---
|
||||
provider = TracerProvider(resource=resource)
|
||||
provider.add_span_processor(processor)
|
||||
trace.set_tracer_provider(provider)
|
||||
global tracer
|
||||
tracer = trace.get_tracer(__name__)
|
||||
except:
|
||||
llm_logger.error("set_up failed")
|
||||
pass
|
||||
|
||||
def instrument(app: FastAPI):
|
||||
try:
|
||||
set_up()
|
||||
if traces_enable:
|
||||
llm_logger.info("Applying instrumentors...")
|
||||
FastAPIInstrumentor.instrument_app(app)
|
||||
except:
|
||||
llm_logger.info("instrument failed")
|
||||
pass
|
||||
|
||||
|
||||
|
||||
def inject_to_metadata(request, metadata_attr='metadata'):
|
||||
"""
|
||||
Inject OpenTelemetry trace context into the metadata field of the request.
|
||||
|
||||
Parameters:
|
||||
request: can be a dict or object, with metadata attributes or fields.
|
||||
metadata_attr: the field name of metadata, default is 'metadata'.
|
||||
|
||||
Operation:
|
||||
- If metadata does not exist, create a new one and mount it on the request.
|
||||
- Inject the current trace context as a JSON string and store it in metadata.
|
||||
- Use the key TRACE_CARRIER to store the injected content.
|
||||
|
||||
Note:
|
||||
- This function is a non-blocking operation, and errors are silently ignored.
|
||||
- If the request has no metadata attribute, an empty dict is created and attached to it.
|
||||
"""
|
||||
try:
|
||||
if request is None or not traces_enable:
|
||||
return
|
||||
|
||||
metadata = request.get(metadata_attr) if isinstance(request, dict) else getattr(request, metadata_attr, None)
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
if isinstance(request, dict):
|
||||
request[metadata_attr] = metadata
|
||||
else:
|
||||
setattr(request, metadata_attr, metadata)
|
||||
|
||||
trace_carrier = {}
|
||||
inject(trace_carrier)
|
||||
trace_carrier_json_string = json.dumps(trace_carrier)
|
||||
metadata[TRACE_CARRIER] = trace_carrier_json_string
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def extract_from_metadata(request, metadata_attr='metadata'):
|
||||
"""
|
||||
Extract trace context from metadata of request object (dict or class instance).
|
||||
|
||||
Parameters:
|
||||
request: can be a dictionary or any object, containing metadata attributes or fields.
|
||||
metadata_attr: metadata field name, default is 'metadata'.
|
||||
|
||||
Returns:
|
||||
- Extraction success: returns OpenTelemetry context object (Context)
|
||||
- Extraction failure or exception: returns None
|
||||
"""
|
||||
try:
|
||||
metadata = request.get(metadata_attr) if isinstance(request, dict) else getattr(request, metadata_attr, None)
|
||||
if metadata is None:
|
||||
return None
|
||||
|
||||
trace_carrier_json_string = metadata.get(TRACE_CARRIER)
|
||||
if trace_carrier_json_string is None:
|
||||
return None
|
||||
|
||||
trace_carrier = json.loads(trace_carrier_json_string)
|
||||
ctx = extract(trace_carrier)
|
||||
return ctx
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def extract_from_request(request):
|
||||
"""
|
||||
Extract trace context from trace_carrier of request object (dict or class instance).
|
||||
|
||||
Parameters:
|
||||
request: any object carrying a trace_carrier attribute (read via getattr).
|
||||
|
||||
Returns:
|
||||
- Extraction success: returns OpenTelemetry context object (Context)
|
||||
- Extraction failure or exception: returns None
|
||||
"""
|
||||
try:
|
||||
trace_carrier_info = getattr(request, TRACE_CARRIER, None)
|
||||
|
||||
if trace_carrier_info is None:
|
||||
return None
|
||||
|
||||
trace_carrier = json.loads(trace_carrier_info)
|
||||
ctx = extract(trace_carrier)
|
||||
return ctx
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def start_span(span_name, request, kind=trace.SpanKind.CLIENT):
|
||||
"""
|
||||
just start a new span in request trace context
|
||||
"""
|
||||
try:
|
||||
if not traces_enable:
|
||||
return
|
||||
# extract Trace context from request.metadata.trace_carrier
|
||||
ctx = extract_from_metadata(request)
|
||||
with tracer.start_as_current_span(span_name, context=ctx, kind=kind) as span:
|
||||
pass
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def start_span_request(span_name, request, kind=trace.SpanKind.CLIENT):
|
||||
"""
|
||||
just start a new span in request trace context
|
||||
"""
|
||||
try:
|
||||
if not traces_enable:
|
||||
return
|
||||
# extract Trace context from request.metadata.trace_carrier
|
||||
ctx = extract_from_request(request)
|
||||
with tracer.start_as_current_span(span_name, context=ctx, kind=kind) as span:
|
||||
pass
|
||||
except:
|
||||
pass
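A rough usage sketch for the helpers in this file (the request dict shape is assumed for illustration; in FastDeploy the metadata field normally comes from the incoming request):

request = {"prompt": "hello", "metadata": {}}      # assumed request shape
inject_to_metadata(request)           # stores the serialized trace carrier in request["metadata"]
start_span("preprocess", request)     # opens a span linked to the injected context, then closes it
ctx = extract_from_metadata(request)  # or recover the context explicitly for custom spans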
|
@@ -21,7 +21,11 @@ from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
import paddle
|
||||
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
|
||||
|
||||
try:
|
||||
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
|
||||
except:
|
||||
flash_attention_v3_varlen = None
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
|
@@ -293,7 +293,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
)
|
||||
if self.nranks > 0:
|
||||
# col parallel
|
||||
_set_var_distributed(self.linear_weight, split_axis=-1)
|
||||
_set_var_distributed(self.linear_weight, split_axis=1)
|
||||
|
||||
self.linear_bias = None
|
||||
if self.with_bias:
|
||||
@@ -304,7 +304,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
)
|
||||
if self.nranks > 0:
|
||||
# col parallel
|
||||
_set_var_distributed(self.linear_bias, split_axis=-1)
|
||||
_set_var_distributed(self.linear_bias, split_axis=1)
|
||||
|
||||
# smooth quant
|
||||
self.linear_shift = None
|
||||
|
@@ -89,6 +89,7 @@ class FusedMoE(nn.Layer):
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
|
||||
moe_quant_config = fd_config.quant_config
|
||||
self.moe_quant_type = None
|
||||
if moe_quant_config:
|
||||
self.quant_method = moe_quant_config.get_quant_method(self)
|
||||
self.moe_quant_type = moe_quant_config.name()
|
||||
@@ -142,7 +143,7 @@ class FusedMoE(nn.Layer):
|
||||
if self.moe_quant_type == "fp8":
|
||||
#(TODO:gaoziyuan)
|
||||
pass
|
||||
else:
|
||||
elif self.moe_quant_type == "wint8":
|
||||
self.weight_dtype = "int8"
|
||||
self.init_weight_only_scale()
|
||||
|
||||
|
@@ -42,3 +42,4 @@ class SamplingMetadata:
|
||||
|
||||
top_p: paddle.Tensor
|
||||
top_k: Optional[paddle.Tensor] = None
|
||||
max_num_logprobs: Optional[int] = None
|
||||
|
@@ -16,10 +16,10 @@
|
||||
|
||||
from .apply_penalty_multi_scores import (
|
||||
apply_penalty_multi_scores, apply_speculative_penalty_multi_scores)
|
||||
from .top_p_sampling import top_p_sampling
|
||||
from .top_k_top_p_sampling import top_k_top_p_sampling
|
||||
|
||||
__all__ = [
|
||||
"apply_penalty_multi_scores",
|
||||
"apply_speculative_penalty_multi_scores",
|
||||
"top_p_sampling",
|
||||
"top_k_top_p_sampling",
|
||||
]
|
||||
|
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Literal, Optional
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy import envs
|
||||
|
||||
|
||||
def top_k_top_p_sampling(
|
||||
x: paddle.Tensor,
|
||||
top_p: paddle.Tensor,
|
||||
top_k: Optional[paddle.Tensor] = None,
|
||||
threshold: Optional[paddle.Tensor] = None,
|
||||
topp_seed: Optional[paddle.Tensor] = None,
|
||||
seed: int = -1,
|
||||
k: int = 0,
|
||||
mode: Literal['truncated', 'non-truncated'] = "truncated",
|
||||
order: Literal['top_k_first', 'joint'] = "top_k_first",
|
||||
) -> tuple[paddle.Tensor, paddle.Tensor]:
|
||||
"""
|
||||
x(Tensor): An input 2-D Tensor with type float32, float16 and bfloat16.
|
||||
top_p(Tensor): A 1-D Tensor with type float32, float16 and bfloat16,
|
||||
used to specify the top_p corresponding to each query.
|
||||
top_k(Tensor|None, optional): A 1-D Tensor with type int64,
|
||||
used to specify the top_k corresponding to each query.
|
||||
Only used when FD_SAMPLING_CLASS is `rejection`.
|
||||
threshold(Tensor|None, optional): A 1-D Tensor with type float32, float16 and bfloat16,
|
||||
used to avoid sampling low score tokens.
|
||||
topp_seed(Tensor|None, optional): A 1-D Tensor with type int64,
|
||||
used to specify the random seed for each query.
|
||||
seed(int, optional): the random seed. Default is -1,
|
||||
k(int): the number of top_k scores/ids to be returned. Default is 0.
|
||||
Only used when FD_SAMPLING_CLASS is `air`.
|
||||
mode(str): The mode to choose sampling strategy. If the mode is `truncated`, sampling will truncate the probability at top_p_value.
|
||||
If the mode is `non-truncated`, it will not be truncated. Default is `truncated`.
|
||||
Only used when FD_SAMPLING_CLASS is `air` or `base`.
|
||||
order(str): The order of applying top-k and top-p sampling, should be either `top_k_first` or `joint`.
|
||||
If `top_k_first`, we first apply top-k filter, then apply top-p sampling on the top-k results.
|
||||
If `joint`, we apply top-k and top-p filter simultaneously in each round. Default is `top_k_first`.
|
||||
Only used when FD_SAMPLING_CLASS is `rejection`.
|
||||
|
||||
"""
|
||||
top_p_class = envs.FD_SAMPLING_CLASS.lower()
|
||||
if top_p_class == "air":
|
||||
_, ids = air_top_p_sampling(x,
|
||||
top_p,
|
||||
threshold,
|
||||
topp_seed,
|
||||
seed=seed,
|
||||
k=k,
|
||||
mode=mode)
|
||||
# rejection
|
||||
elif top_p_class == "rejection":
|
||||
ids = rejection_top_p_sampling(x, top_p, top_k, seed, order)
|
||||
_ = None
|
||||
# base non-truncated
|
||||
elif top_p_class == "base_non_truncated":
|
||||
_, ids = paddle.tensor.top_p_sampling(x,
|
||||
top_p,
|
||||
threshold=threshold,
|
||||
topp_seed=topp_seed,
|
||||
seed=seed,
|
||||
k=k,
|
||||
mode="non-truncated")
|
||||
# base truncated
|
||||
else:
|
||||
_, ids = paddle.tensor.top_p_sampling(x,
|
||||
top_p,
|
||||
threshold=threshold,
|
||||
topp_seed=topp_seed,
|
||||
seed=seed,
|
||||
k=k,
|
||||
mode="truncated")
|
||||
return _, ids
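A hedged usage sketch of the function above (shapes and values are illustrative, and it assumes FD_SAMPLING_CLASS=rejection is set in the environment so that top_k takes effect):

import paddle

probs = paddle.rand([2, 32000])
probs = probs / probs.sum(axis=-1, keepdim=True)          # rows must be valid distributions
top_p = paddle.to_tensor([0.8, 0.95], dtype="float32")
top_k = paddle.full([2], 20, dtype="int64")
_, next_ids = top_k_top_p_sampling(probs, top_p, top_k, order="top_k_first")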
|
||||
|
||||
|
||||
def air_top_p_sampling(
|
||||
x: paddle.Tensor,
|
||||
top_p: paddle.Tensor,
|
||||
threshold: Optional[paddle.Tensor] = None,
|
||||
topp_seed: Optional[paddle.Tensor] = None,
|
||||
seed: int = -1,
|
||||
k: int = 0,
|
||||
mode: Literal['truncated', 'non-truncated'] = "truncated",
|
||||
) -> tuple[paddle.Tensor, paddle.Tensor]:
|
||||
"""
|
||||
air_top_p_sampling
|
||||
"""
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import air_top_p_sampling
|
||||
out, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed, k,
|
||||
mode)
|
||||
except ImportError:
|
||||
raise RuntimeError("Cannot import air_top_p_sampling op.")
|
||||
return out, ids
|
||||
|
||||
|
||||
def rejection_top_p_sampling(
|
||||
x: paddle.Tensor,
|
||||
top_p: paddle.Tensor,
|
||||
top_k: paddle.Tensor,
|
||||
seed: int = -1,
|
||||
order: Literal['top_k_first', 'joint'] = "top_k_first",
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
rejection_top_p_sampling
|
||||
"""
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
rejection_top_p_sampling, top_k_renorm_probs)
|
||||
|
||||
if paddle.count_nonzero(top_k) == 0:
|
||||
ids = rejection_top_p_sampling(
|
||||
x,
|
||||
top_p,
|
||||
None,
|
||||
seed,
|
||||
)
|
||||
else:
|
||||
if order == "top_k_first":
|
||||
renorm_probs = top_k_renorm_probs(x, top_k)
|
||||
ids = rejection_top_p_sampling(
|
||||
renorm_probs,
|
||||
top_p,
|
||||
None,
|
||||
seed,
|
||||
)
|
||||
else:
|
||||
ids = rejection_top_p_sampling(
|
||||
x,
|
||||
top_p,
|
||||
top_k,
|
||||
seed,
|
||||
)
|
||||
except ImportError:
|
||||
raise RuntimeError("Cannot import rejection_top_p_sampling op.")
|
||||
return ids
|
@@ -1,97 +0,0 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Literal, Optional
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy import envs
|
||||
|
||||
|
||||
def top_p_sampling(
|
||||
x: paddle.Tensor,
|
||||
ps: paddle.Tensor,
|
||||
threshold: Optional[paddle.Tensor] = None,
|
||||
topp_seed: Optional[paddle.Tensor] = None,
|
||||
seed: int = -1,
|
||||
k: int = 0,
|
||||
mode: Literal['truncated', 'non-truncated'] = "truncated",
|
||||
) -> tuple[paddle.Tensor, paddle.Tensor]:
|
||||
"""
|
||||
top_p_sampling
|
||||
"""
|
||||
top_p_class = envs.FD_SAMPLING_CLASS.lower()
|
||||
if top_p_class == "air":
|
||||
_, ids = air_top_p_sampling(x,
|
||||
ps,
|
||||
threshold,
|
||||
topp_seed,
|
||||
seed=seed,
|
||||
k=k,
|
||||
mode=mode)
|
||||
elif top_p_class == "rejection":
|
||||
ids = rejection_top_p_sampling(x, ps, seed)
|
||||
_ = None
|
||||
else:
|
||||
_, ids = paddle.tensor.top_p_sampling(x,
|
||||
ps,
|
||||
threshold=threshold,
|
||||
topp_seed=topp_seed,
|
||||
seed=seed,
|
||||
k=k,
|
||||
mode=mode)
|
||||
return _, ids
|
||||
|
||||
|
||||
def air_top_p_sampling(
|
||||
x: paddle.Tensor,
|
||||
ps: paddle.Tensor,
|
||||
threshold: Optional[paddle.Tensor] = None,
|
||||
topp_seed: Optional[paddle.Tensor] = None,
|
||||
seed: int = -1,
|
||||
k: int = 0,
|
||||
mode: Literal['truncated', 'non-truncated'] = "truncated",
|
||||
) -> tuple[paddle.Tensor, paddle.Tensor]:
|
||||
"""
|
||||
air_top_p_sampling
|
||||
"""
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import air_top_p_sampling
|
||||
out, ids = air_top_p_sampling(x, ps, threshold, topp_seed, seed, k,
|
||||
mode)
|
||||
except ImportError:
|
||||
raise RuntimeError("Cannot import air_top_p_sampling op.")
|
||||
return out, ids
|
||||
|
||||
|
||||
def rejection_top_p_sampling(
|
||||
x: paddle.Tensor,
|
||||
ps: paddle.Tensor,
|
||||
seed: int = -1,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
rejection_top_p_sampling
|
||||
"""
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import rejection_top_p_sampling
|
||||
ids = rejection_top_p_sampling(
|
||||
x,
|
||||
ps,
|
||||
seed,
|
||||
)
|
||||
except ImportError:
|
||||
raise RuntimeError("Cannot import rejection_top_p_sampling op.")
|
||||
return ids
|
@@ -27,8 +27,9 @@ from fastdeploy.model_executor.guided_decoding.base_guided_decoding import \
|
||||
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
|
||||
from fastdeploy.model_executor.layers.sample.ops import (
|
||||
apply_penalty_multi_scores, apply_speculative_penalty_multi_scores,
|
||||
top_p_sampling)
|
||||
top_k_top_p_sampling)
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
|
||||
|
||||
|
||||
class SamplerProcessor:
|
||||
@@ -188,14 +189,65 @@ class Sampler(nn.Layer):
|
||||
""" pre process before running """
|
||||
self.processor.pre_process(skip_idx_list)
|
||||
|
||||
def compute_logprobs(self, logits: paddle.Tensor) -> paddle.Tensor:
|
||||
"""
|
||||
"""
|
||||
return F.log_softmax(logits, axis=-1)
|
||||
|
||||
def gather_logprobs(
|
||||
self,
|
||||
logprobs: paddle.Tensor,
|
||||
num_logprobs: int,
|
||||
token_ids: paddle.Tensor,
|
||||
) -> LogprobsTensors:
|
||||
"""
|
||||
Gather logprobs for topk and sampled/prompt token.
|
||||
Args:
|
||||
logprobs: (num tokens) x (vocab) tensor
|
||||
num_logprobs: minimum number of logprobs to
|
||||
retain per token
|
||||
token_ids: prompt tokens (if prompt logprobs)
|
||||
or sampled tokens (if sampled
|
||||
logprobs); 1D token ID tensor
|
||||
with (num tokens) elements
|
||||
Must be int64.
|
||||
Returns:
|
||||
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
|
||||
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
|
||||
Sampled token rank tensor, (num tokens)
|
||||
"""
|
||||
assert token_ids.dtype == paddle.int64
|
||||
# Get with the logprob of the prompt or sampled token.
|
||||
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
|
||||
|
||||
# Compute the ranks of the actual token.
|
||||
token_ranks = (logprobs >= token_logprobs).sum(-1)
|
||||
|
||||
if num_logprobs >= 1:
|
||||
# Find the topK values.
|
||||
topk_logprobs, topk_indices = paddle.topk(logprobs,
|
||||
num_logprobs,
|
||||
axis=-1)
|
||||
indices = paddle.concat([token_ids, topk_indices], axis=1)
|
||||
top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1)
|
||||
else:
|
||||
indices = token_ids
|
||||
top_logprobs = token_logprobs
|
||||
|
||||
return LogprobsTensors(indices, top_logprobs, token_ranks)
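A small numeric sketch of what gather_logprobs computes for a single token over a toy vocabulary of four, using the same paddle calls as the method above (values are illustrative):

import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([[2.0, 1.0, 0.5, -1.0]])
logprobs = F.log_softmax(logits, axis=-1)
sampled = paddle.to_tensor([[1]], dtype="int64")                    # sampled token id 1
token_logprob = paddle.take_along_axis(logprobs, sampled, axis=-1)
rank = (logprobs >= token_logprob).sum(-1)                          # -> 2: second-best token
topk_logprobs, topk_ids = paddle.topk(logprobs, k=2, axis=-1)
# indices = concat([sampled, topk_ids]), top_logprobs = concat([token_logprob, topk_logprobs])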
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
logits: paddle.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
skip_idx_list: List[int] = [],
|
||||
) -> paddle.Tensor:
|
||||
) -> SamplerOutput:
|
||||
"""
|
||||
"""
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
if num_logprobs is not None:
|
||||
raw_logprobs = self.compute_logprobs(logits)
|
||||
|
||||
logits = self.processor.apply_token_mask(logits, skip_idx_list)
|
||||
|
||||
logits = apply_penalty_multi_scores(
|
||||
@@ -213,10 +265,21 @@ class Sampler(nn.Layer):
|
||||
|
||||
probs = F.softmax(logits)
|
||||
|
||||
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
|
||||
_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
|
||||
|
||||
logprobs_tensors = None if num_logprobs is None else \
|
||||
self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
|
||||
|
||||
self.processor.update_output_tokens(next_tokens, skip_idx_list)
|
||||
return next_tokens
|
||||
|
||||
sampler_output = SamplerOutput(
|
||||
# The sampled tokens are expanded to 2D tensor with shape
|
||||
# [num_requests, 1], where each row represents one generated
|
||||
# token per request.
|
||||
sampled_token_ids=next_tokens,
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
|
||||
class SpeculativeSampler(nn.Layer):
|
||||
@@ -364,5 +427,5 @@ class MTPSampler(nn.Layer):
|
||||
)
|
||||
probs = F.softmax(logits)
|
||||
|
||||
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
|
||||
_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
|
||||
return next_tokens
|
||||
|
@@ -91,8 +91,11 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
def load_model(self, fd_config: FDConfig) -> nn.Layer:
|
||||
context = paddle.LazyGuard()
|
||||
architectures = fd_config.model_config.architectures[0]
|
||||
# TODO(gongshaotian): Now, only support safetensor
|
||||
model_class = MODEL_CLASSES[architectures]
|
||||
|
||||
if fd_config.load_config.dynamic_load_weight:
|
||||
# register rl model
|
||||
import fastdeploy.rl
|
||||
architectures = architectures + "RL"
|
||||
|
||||
with context:
|
||||
model_cls = ModelRegistry.get_class(architectures)
|
||||
@@ -104,6 +107,8 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
if fd_config.load_config.dynamic_load_weight:
|
||||
return model
|
||||
|
||||
# TODO(gongshaotian): Now, only support safetensor
|
||||
model_class = MODEL_CLASSES[architectures]
|
||||
state_dict = load_composite_checkpoint(
|
||||
fd_config.parallel_config.model_name_or_path,
|
||||
model_class,
|
||||
|
@@ -36,8 +36,7 @@ def _find_py_files(root_dir):
|
||||
|
||||
|
||||
def auto_models_registry(dir_path,
|
||||
register_path="fastdeploy.model_executor.models",
|
||||
suffix=""):
|
||||
register_path="fastdeploy.model_executor.models"):
|
||||
"""
|
||||
auto registry all models in this folder
|
||||
"""
|
||||
@@ -49,7 +48,7 @@ def auto_models_registry(dir_path,
|
||||
if inspect.isclass(attr) and issubclass(
|
||||
attr,
|
||||
ModelForCasualLM) and attr is not ModelForCasualLM:
|
||||
ModelRegistry.register(attr, suffix=suffix)
|
||||
ModelRegistry.register(attr)
|
||||
except ImportError:
|
||||
raise ImportError(f"{module_file=} import error")
|
||||
|
||||
|
@@ -288,14 +288,14 @@ class Ernie4_5_DecoderLayer(nn.Layer):
|
||||
self.input_layernorm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=1e-5,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
)
|
||||
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=1e-5,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
)
|
||||
|
||||
@@ -366,7 +366,7 @@ class Ernie4_5_Model(nn.Layer):
|
||||
self.norm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=1e-5,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{fd_config.model_config.prefix_name}.norm",
|
||||
)
|
||||
|
||||
|
@@ -275,14 +275,14 @@ class Ernie4_5_MTPModel(nn.Layer):
|
||||
self.enorm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=1e-5,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix="ernie.mtp_emb_norm.0",
|
||||
)
|
||||
|
||||
self.hnorm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=1e-5,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix="ernie.mtp_hidden_norm.0",
|
||||
)
|
||||
|
||||
|
@@ -271,14 +271,14 @@ class Ernie4_5_VLDecoderLayer(nn.Layer):
|
||||
self.input_layernorm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=1e-5,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
)
|
||||
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=1e-5,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
)
|
||||
|
||||
@@ -355,7 +355,7 @@ class Ernie4_5_VLModel(nn.Layer):
|
||||
self.norm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=1e-5,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{fd_config.model_config.prefix_name}.norm",
|
||||
)
|
||||
|
||||
|
@@ -28,12 +28,12 @@ class ModelRegistry:
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class, suffix=""):
|
||||
def register(cls, model_class):
|
||||
"""register model class"""
|
||||
if issubclass(
|
||||
model_class,
|
||||
ModelForCasualLM) and model_class is not ModelForCasualLM:
|
||||
cls._registry[f"{model_class.name()}{suffix}"] = model_class
|
||||
cls._registry[model_class.name()] = model_class
|
||||
return model_class
|
||||
|
||||
@classmethod
|
||||
@@ -56,7 +56,7 @@ class ModelForCasualLM(nn.Layer, ABC):
|
||||
ori_vocab_size, use_topp_sampling, etc.
|
||||
"""
|
||||
super(ModelForCasualLM, self).__init__()
|
||||
|
||||
self.fd_config = configs
|
||||
@abstractmethod
|
||||
def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray,
|
||||
paddle.Tensor]]):
|
||||
|
@@ -161,14 +161,14 @@ class Qwen2DecoderLayer(nn.Layer):
|
||||
self.input_layernorm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=1e-6,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
)
|
||||
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=1e-6,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
)
|
||||
|
||||
@@ -248,7 +248,7 @@ class Qwen2Model(nn.Layer):
|
||||
self.norm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=1e-5,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{fd_config.model_config.prefix_name}.norm",
|
||||
)
|
||||
|
||||
@@ -302,6 +302,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
|
||||
"""
|
||||
super(Qwen2ForCausalLM, self).__init__(fd_config)
|
||||
|
||||
self.fd_config = fd_config
|
||||
self.model = Qwen2Model(fd_config=fd_config)
|
||||
|
||||
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
|
||||
|
@@ -79,12 +79,12 @@ class Qwen3Attention(nn.Layer):
|
||||
|
||||
self.q_norm = RMSNorm(fd_config=fd_config,
|
||||
hidden_size=fd_config.model_config.head_dim,
|
||||
eps=1e-6,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.q_norm",
|
||||
begin_norm_axis=2)
|
||||
self.k_norm = RMSNorm(fd_config=fd_config,
|
||||
hidden_size=fd_config.model_config.head_dim,
|
||||
eps=1e-6,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.k_norm",
|
||||
begin_norm_axis=2)
|
||||
|
||||
@@ -184,7 +184,7 @@ class Qwen3Model(nn.Layer):
|
||||
self.norm = RMSNorm(
|
||||
fd_config,
|
||||
hidden_size=fd_config.model_config.hidden_size,
|
||||
eps=1e-6,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{fd_config.model_config.prefix_name}.norm",
|
||||
)
|
||||
|
||||
|
@@ -121,12 +121,12 @@ class Qwen3Attention(nn.Layer):
|
||||
|
||||
self.q_norm = RMSNorm(fd_config,
|
||||
hidden_size=self.head_dim,
|
||||
eps=1e-6,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.q_norm",
|
||||
begin_norm_axis=2)
|
||||
self.k_norm = RMSNorm(fd_config,
|
||||
hidden_size=self.head_dim,
|
||||
eps=1e-6,
|
||||
eps=fd_config.model_config.rms_norm_eps,
|
||||
prefix=f"{prefix}.k_norm",
|
||||
begin_norm_axis=2)
|
||||
|
||||
|
@@ -13,10 +13,19 @@
|
||||
# limitations under the License.
|
||||
"""fastdeploy gpu ops"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from fastdeploy.import_ops import import_custom_ops
|
||||
|
||||
PACKAGE = "fastdeploy.model_executor.ops.gpu"
|
||||
|
||||
import_custom_ops(PACKAGE, "..base.fastdeploy_base_ops", globals())
|
||||
import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())
|
||||
|
||||
|
||||
def tolerant_import_error():
|
||||
class NoneModule:
|
||||
def __getattr__(self, name):
|
||||
return None
|
||||
|
||||
sys.modules[__name__] = NoneModule()
|
||||
|
@@ -20,14 +20,16 @@ import paddle
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.engine.config import SpeculativeConfig
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
get_padding_offset, save_output, set_stop_value_multi_ends,
|
||||
speculate_clear_accept_nums, speculate_get_output_padding_offset,
|
||||
speculate_get_padding_offset, speculate_get_seq_lens_output,
|
||||
speculate_save_output, speculate_set_value_by_flags_and_idx,
|
||||
speculate_step_paddle, speculate_step_system_cache, speculate_update_v3,
|
||||
step_paddle, step_system_cache, update_inputs, step_reschedule)
|
||||
get_padding_offset, save_output, save_output_topk,
|
||||
set_stop_value_multi_ends, speculate_clear_accept_nums,
|
||||
speculate_get_output_padding_offset, speculate_get_padding_offset,
|
||||
speculate_get_seq_lens_output, speculate_save_output,
|
||||
speculate_set_value_by_flags_and_idx, speculate_step_paddle,
|
||||
speculate_step_system_cache, speculate_update_v3, step_paddle,
|
||||
step_reschedule, step_system_cache, update_inputs)
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.worker.output import ModelOutputData
|
||||
from fastdeploy.worker.output import (ModelOutputData, ModelRunnerOutput,
|
||||
SamplerOutput)
|
||||
|
||||
DISABLE_RECOVER = (envs.FD_DISABLED_RECOVER == "1")
|
||||
|
||||
@@ -102,10 +104,10 @@ def pre_process(
|
||||
cu_seqlens_k, output_cum_offsets, output_padding_offset)
|
||||
|
||||
|
||||
def post_process_normal(sampled_token_ids: paddle.Tensor,
|
||||
def post_process_normal(sampler_output: SamplerOutput,
|
||||
model_output: ModelOutputData,
|
||||
save_each_rank: bool = False,
|
||||
skip_save_output: bool = False) -> None:
|
||||
skip_save_output: bool = False) -> ModelRunnerOutput:
|
||||
""" Post-processing steps after completing a single token generation. """
|
||||
# 1. Set stop value
|
||||
paddle.assign(
|
||||
@@ -123,7 +125,7 @@ def post_process_normal(sampled_token_ids: paddle.Tensor,
|
||||
model_output.stop_flags,
|
||||
)
|
||||
# TODO(gongshaotian): Add use_stop_seqs
|
||||
set_stop_value_multi_ends(sampled_token_ids, model_output.stop_flags,
|
||||
set_stop_value_multi_ends(sampler_output.sampled_token_ids, model_output.stop_flags,
|
||||
model_output.seq_lens_this_time,
|
||||
model_output.eos_token_id,
|
||||
model_output.next_tokens, False) # multi ends
|
||||
@@ -138,18 +140,28 @@ def post_process_normal(sampled_token_ids: paddle.Tensor,
|
||||
model_output.seq_lens_decoder,
|
||||
model_output.input_ids,
|
||||
model_output.stop_nums,
|
||||
sampled_token_ids,
|
||||
sampler_output.sampled_token_ids,
|
||||
model_output.is_block_step,
|
||||
)
|
||||
# 3. Transmit the model's output and stop generation signal via message queue.
|
||||
# In the future, we will abandon this approach.
|
||||
if not skip_save_output:
|
||||
save_output(
|
||||
sampled_token_ids,
|
||||
model_output.not_need_stop,
|
||||
model_output.mp_rank,
|
||||
save_each_rank, # save_each_rank
|
||||
)
|
||||
if sampler_output.logprobs_tensors is None:
|
||||
save_output(
|
||||
sampler_output.sampled_token_ids,
|
||||
model_output.not_need_stop,
|
||||
model_output.mp_rank,
|
||||
save_each_rank, # save_each_rank
|
||||
)
|
||||
else:
|
||||
save_output_topk(
|
||||
sampler_output.sampled_token_ids,
|
||||
sampler_output.logprobs_tensors.logprob_token_ids,
|
||||
sampler_output.logprobs_tensors.logprobs,
|
||||
sampler_output.logprobs_tensors.selected_token_ranks,
|
||||
model_output.not_need_stop,
|
||||
model_output.mp_rank,
|
||||
)
|
||||
|
||||
def post_process_specualate(model_output, skip_save_output: bool = False):
|
||||
""""""
|
||||
@@ -193,7 +205,7 @@ def post_process_specualate(model_output, skip_save_output: bool = False):
|
||||
)
|
||||
|
||||
|
||||
def post_process(sampled_token_ids: paddle.Tensor,
|
||||
def post_process(sampler_output: SamplerOutput,
|
||||
model_output: ModelOutputData,
|
||||
save_each_rank: bool = False,
|
||||
speculative_decoding: bool = False,
|
||||
@@ -202,7 +214,7 @@ def post_process(sampled_token_ids: paddle.Tensor,
|
||||
if speculative_decoding:
|
||||
post_process_specualate(model_output, skip_save_output)
|
||||
else:
|
||||
post_process_normal(sampled_token_ids, model_output, save_each_rank,
|
||||
post_process_normal(sampler_output, model_output, save_each_rank,
|
||||
skip_save_output)
|
||||
|
||||
|
||||
@@ -217,7 +229,7 @@ def step_cuda(
|
||||
TODO(gongshaotian): normalization name
|
||||
"""
|
||||
|
||||
|
||||
|
||||
if speculative_config.method is not None:
|
||||
if enable_prefix_caching:
|
||||
speculate_step_system_cache(
|
||||
|
@@ -30,9 +30,11 @@ from fastdeploy.inter_communicator import IPCSignal
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.utils import llm_logger, spec_logger
|
||||
from fastdeploy.worker.output import LogprobsLists
|
||||
|
||||
RECOVERY_STOP_SIGNAL = -3
|
||||
MAX_BSZ = 512
|
||||
K = 20
|
||||
MAX_DRAFT_TOKENS = 6
|
||||
SPECULATE_MAX_BSZ = 256
|
||||
|
||||
@@ -62,6 +64,13 @@ class TokenProcessor(object):
|
||||
],
|
||||
fill_value=2,
|
||||
dtype="int64")
|
||||
elif self.cfg.enable_logprob:
|
||||
self.output_tokens = paddle.full(
|
||||
shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64")
|
||||
self.output_scores = paddle.full(
|
||||
shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32")
|
||||
self.output_ranks = paddle.full(
|
||||
shape=[MAX_BSZ], fill_value=0, dtype="int64")
|
||||
else:
|
||||
self.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1],
|
||||
fill_value=2,
|
||||
@@ -109,12 +118,51 @@ class TokenProcessor(object):
|
||||
assert self.resource_manager is not None, "The resource manager is None, cannot run."
|
||||
if self.worker is not None:
|
||||
raise Exception("Worker is already running!")
|
||||
use_logprobs = (
|
||||
self.cfg.enable_logprob
|
||||
and not self.speculative_decoding
|
||||
and not self.cfg.parallel_config.enable_expert_parallel
|
||||
)
|
||||
|
||||
target_func = (
|
||||
self.process_sampling_with_logprob_results
|
||||
if use_logprobs else
|
||||
self.process_sampling_results
|
||||
)
|
||||
|
||||
self.worker = threading.Thread(target=target_func)
|
||||
|
||||
self.worker = threading.Thread(target=self.process_sampling_results,
|
||||
args=())
|
||||
self.worker.daemon = True
|
||||
self.worker.start()
|
||||
|
||||
def process_sampling_with_logprob_results(self):
|
||||
"""
|
||||
read tokens from paddle inference engine and process logprob results
|
||||
"""
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import get_output_topk
|
||||
else:
|
||||
raise NotImplementedError("Only CUDA platform supports logprob.")
|
||||
rank_id = self.cfg.parallel_config.local_data_parallel_id
|
||||
|
||||
while True:
|
||||
try:
|
||||
is_blocking = True
|
||||
get_output_topk(self.output_tokens, self.output_scores, self.output_ranks, K, rank_id, is_blocking)
|
||||
|
||||
if self.output_tokens[0, 0] == -2:
|
||||
continue
|
||||
llm_logger.debug(
|
||||
f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}"
|
||||
f"rank_id {rank_id} self.output_scores[0, 0] {self.output_scores[0, 0]}"
|
||||
)
|
||||
self._process_prefill_metrics()
|
||||
self._process_sampling_with_logprob_batch_output()
|
||||
except Exception as e:
|
||||
llm_logger.info("while get input_data error: {0} {1}".format(
|
||||
e, str(traceback.format_exc())))
|
||||
|
||||
def process_sampling_results(self):
|
||||
"""
|
||||
read tokens from paddle inference engine and process
|
||||
@@ -245,6 +293,122 @@ class TokenProcessor(object):
|
||||
self.number_of_output_tokens = 0
|
||||
self.total_step = 0
|
||||
|
||||
def _process_sampling_with_logprob_batch_output(self):
|
||||
"""
|
||||
batch post-processing logprob output function
|
||||
"""
|
||||
|
||||
batch = self.output_tokens[1, 0]
|
||||
tokens = self.output_tokens[2:batch * (K + 1) + 2].numpy().reshape(
|
||||
[batch, K + 1])[:, :(K + 1)]
|
||||
scores = self.output_scores[:batch * (K + 1)].numpy().reshape(
|
||||
[batch, K + 1])[:, :(K + 1)]
|
||||
ranks = self.output_ranks[:batch].numpy()
|
||||
batch_result = list()
|
||||
for i in range(batch):
|
||||
if self.resource_manager.stop_flags[i]:
|
||||
continue
|
||||
task = self.resource_manager.tasks_list[i]
|
||||
task_id = task.request_id
|
||||
token_id = int(tokens[i, 0])
|
||||
token_ids = [token_id]
|
||||
recovery_stop = token_id == RECOVERY_STOP_SIGNAL
|
||||
if recovery_stop:
|
||||
llm_logger.info(
|
||||
f"recovery stop signal found at task {task_id}")
|
||||
if not recovery_stop and token_id < 0:
|
||||
continue
|
||||
|
||||
if task.get("prefill_chunk_info", None) is not None:
|
||||
prefill_chunk_num = task.get("prefill_chunk_num", 0)
|
||||
task.prefill_chunk_num = prefill_chunk_num + 1
|
||||
|
||||
if task.prefill_chunk_num < len(task.prefill_chunk_info):
|
||||
continue
|
||||
|
||||
self.total_step += 1
|
||||
current_time = time.time()
|
||||
if self.tokens_counter[task_id] == 0:
|
||||
metrics = RequestMetrics(
|
||||
arrival_time=task.arrival_time,
|
||||
inference_start_time=task.inference_start_time,
|
||||
first_token_time=time.time() - task.inference_start_time,
|
||||
time_in_queue=task.schedule_start_time -
|
||||
task.preprocess_end_time,
|
||||
preprocess_cost_time=task.preprocess_end_time -
|
||||
task.preprocess_start_time)
|
||||
|
||||
self._record_first_token_metrics(task, current_time)
|
||||
|
||||
else:
|
||||
metrics = RequestMetrics(
|
||||
arrival_time=time.time(),
|
||||
request_start_time=task.arrival_time,
|
||||
)
|
||||
self.number_of_output_tokens += len(token_ids)
|
||||
self._record_metrics(task, current_time, token_ids)
|
||||
result = RequestOutput(request_id=task_id,
|
||||
outputs=CompletionOutput(
|
||||
index=i,
|
||||
send_idx=self.tokens_counter[task_id],
|
||||
token_ids=[],
|
||||
logprob=None,
|
||||
draft_token_ids=[],
|
||||
top_logprobs=None,
|
||||
),
|
||||
finished=False,
|
||||
metrics=metrics)
|
||||
if self.tokens_counter[task_id] == 0:
|
||||
if task.messages is not None:
|
||||
result.prompt = task.messages
|
||||
result.num_cached_tokens = task.num_cached_tokens
|
||||
|
||||
is_prefill = task.disaggregate_info is not None and task.disaggregate_info[
|
||||
"role"] == "prefill"
|
||||
|
||||
if is_prefill and len(token_ids) > 1:
|
||||
result.outputs.draft_token_ids = copy.deepcopy(token_ids)
|
||||
|
||||
for idx, token_id in enumerate(token_ids):
|
||||
self.tokens_counter[task_id] += 1
|
||||
if token_id != RECOVERY_STOP_SIGNAL:
|
||||
result.outputs.token_ids.append(token_id)
|
||||
result.outputs.logprob = float(scores[i, 0])
|
||||
# Construct top_logprobs
|
||||
topk_token_ids = tokens[i, :].tolist()
|
||||
topk_logprobs = scores[i, :].tolist()
|
||||
sampled_rank = ranks[i].item()
|
||||
|
||||
result.outputs.top_logprobs = LogprobsLists(
|
||||
logprob_token_ids=[topk_token_ids],
|
||||
logprobs=[topk_logprobs],
|
||||
sampled_token_ranks=[sampled_rank]
|
||||
)
|
||||
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
|
||||
result.finished = True
|
||||
result.prompt = task.prompt
|
||||
result.prompt_token_ids = task.prompt_token_ids
|
||||
if recovery_stop:
|
||||
result.error_msg = "Recover is not supported, the result is incomplete!"
|
||||
llm_logger.info(
|
||||
f"Request: {task_id} finished, number of "
|
||||
f"generated tokens: {self.tokens_counter[task_id]}.")
|
||||
llm_logger.info(
|
||||
f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
|
||||
)
|
||||
llm_logger.info(f"{self.resource_manager.info()}")
|
||||
if self.cfg.speculative_config.method:
|
||||
self._compute_speculative_status()
|
||||
if not is_prefill:
|
||||
self._record_completion_metrics(task, current_time)
|
||||
self._recycle_resources(task_id, i, task, result,
|
||||
is_prefill)
|
||||
break
|
||||
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
|
||||
batch_result.append(result)
|
||||
|
||||
self.postprocess(batch_result)
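For reference, a numpy-only sketch of the flattened buffer layout that the reshapes above assume (MAX_BSZ and K as defined in this module; batch is the live batch size):

import numpy as np

MAX_BSZ, K, batch = 512, 20, 3
output_tokens = np.zeros([MAX_BSZ * (K + 1) + 2, 1], dtype=np.int64)   # slot 0: status flag, slot 1: batch size, then tokens
output_scores = np.zeros([MAX_BSZ * (K + 1), 1], dtype=np.float32)
tokens = output_tokens[2:batch * (K + 1) + 2].reshape(batch, K + 1)    # column 0: sampled token id
scores = output_scores[:batch * (K + 1)].reshape(batch, K + 1)         # column 0: its logprob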
|
||||
|
||||
def _process_batch_output(self):
|
||||
"""
|
||||
batch post-processing function
|
||||
|
@@ -17,4 +17,4 @@ import os
|
||||
|
||||
from fastdeploy.model_executor.models import auto_models_registry
|
||||
|
||||
auto_models_registry(os.path.dirname(__file__), "fastdeploy.rl", suffix="RL")
|
||||
auto_models_registry(os.path.dirname(__file__), "fastdeploy.rl")
|
||||
|
110
fastdeploy/rl/rollout_config.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from fastdeploy.worker.worker_process import initialize_fd_config
|
||||
|
||||
|
||||
class RolloutModelConfig:
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
max_model_len: int = 32768,
|
||||
        tensor_parallel_size: int = 4,
        dynamic_load_weight: bool = True,
        load_strategy: str = "meta",
        enable_mm: bool = False,
        # Default values for all other parameters
        max_num_seqs: int = 34,
        total_block_num: int = 2000,
        block_size: int = 64,
        engine_worker_queue_port: int = 9923,
        device_ids: str = "0",
        dtype: str = "bfloat16",
        enc_dec_block_num: int = 1,
        kv_cache_ratio: float = 0.7,
        first_token_id: int = 1,
        gpu_memory_utilization: float = 0.9,
        engine_pid: int = None,
        do_profile: bool = False,
        pad_token_id: int = -1,
        eos_tokens_lens: int = 2,
        enable_chunked_prefill: bool = False,
        speculative_method: str = None,
        speculative_max_draft_token_num: int = 1,
        speculative_model_name_or_path: str = "",
        speculative_model_quantization: str = "WINT8",
        max_num_batched_tokens: int = 2048,
        enable_prefix_caching: bool = False,
        splitwise_role: str = "mixed",
        expert_parallel_size: int = 1,
        enable_expert_parallell: bool = False,
        ori_vocab_size: int = None,
        quantization: str = "None",
        enable_static_graph_inference: bool = False,
        use_cudagraph: bool = False,
        max_capture_batch_size: int = 64,
        guided_decoding_backend: str = "off",
        disable_any_whitespace: bool = True,
        enable_logprob: bool = False,
    ):
        # Required parameters
        self.model_name_or_path = model_name_or_path
        self.max_model_len = max_model_len
        self.tensor_parallel_size = tensor_parallel_size
        self.dynamic_load_weight = dynamic_load_weight
        self.load_strategy = load_strategy
        self.enable_mm = enable_mm

        # Optional parameters with defaults
        self.max_num_seqs = max_num_seqs
        self.total_block_num = total_block_num
        self.block_size = block_size
        self.engine_worker_queue_port = engine_worker_queue_port
        self.device_ids = device_ids
        self.dtype = dtype
        self.enc_dec_block_num = enc_dec_block_num
        self.kv_cache_ratio = kv_cache_ratio
        self.first_token_id = first_token_id
        self.gpu_memory_utilization = gpu_memory_utilization
        self.engine_pid = engine_pid
        self.do_profile = do_profile
        self.pad_token_id = pad_token_id
        self.eos_tokens_lens = eos_tokens_lens
        self.enable_chunked_prefill = enable_chunked_prefill
        self.speculative_method = speculative_method
        self.speculative_max_draft_token_num = speculative_max_draft_token_num
        self.speculative_model_name_or_path = speculative_model_name_or_path
        self.speculative_model_quantization = speculative_model_quantization
        self.max_num_batched_tokens = max_num_batched_tokens
        self.enable_prefix_caching = enable_prefix_caching
        self.splitwise_role = splitwise_role
        self.expert_parallel_size = expert_parallel_size
        self.enable_expert_parallell = enable_expert_parallell
        self.ori_vocab_size = ori_vocab_size
        self.quantization = quantization
        self.enable_static_graph_inference = enable_static_graph_inference
        self.use_cudagraph = use_cudagraph
        self.max_capture_batch_size = max_capture_batch_size
        self.guided_decoding_backend = guided_decoding_backend
        self.disable_any_whitespace = disable_any_whitespace
        self.enable_logprob = enable_logprob

    def __str__(self):
        return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())

    def initialize(self):
        """Initialize the final fd config"""
        return initialize_fd_config(self)
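A minimal usage sketch for the class above; the checkpoint path and field values are illustrative placeholders, not recommended settings:

# Hedged usage sketch for RolloutModelConfig; all values are illustrative.
from fastdeploy.rl.rollout_config import RolloutModelConfig

rollout_config = RolloutModelConfig(
    model_name_or_path="/path/to/checkpoint",   # placeholder path
    max_model_len=8192,
    tensor_parallel_size=4,
    dynamic_load_weight=True,
    load_strategy="meta",
)
print(rollout_config)                    # __str__ dumps "key: value" per field
fd_config = rollout_config.initialize()  # delegates to initialize_fd_config(self)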
@@ -24,25 +24,19 @@ from fastdeploy.config import FDConfig
from fastdeploy.model_executor.model_loader import ModelRegistry
from fastdeploy.model_executor.models.ernie4_5_moe import \
    Ernie4_5_MoeForCausalLM
from fastdeploy.model_executor.models.qwen2 import Qwen2PretrainedModel
from fastdeploy.model_executor.models.qwen3 import Qwen3PretrainedModel
from fastdeploy.model_executor.models.qwen3moe import Qwen3MoePretrainedModel

RL_MODEL_CLASSES = {
    "Ernie4_5_MoeForCausalLMRL": Ernie4_5_MoeForCausalLM,
    "Qwen2ForCausalLMRL": Qwen2PretrainedModel,
    "Qwen3ForCausalLMRL": Qwen3PretrainedModel,
    "Qwen3MoeForCausalLMRL": Qwen3MoePretrainedModel,
}
from fastdeploy.model_executor.models.qwen2 import Qwen2ForCausalLM
from fastdeploy.model_executor.models.qwen3 import Qwen3ForCausalLM
from fastdeploy.model_executor.models.qwen3moe import Qwen3MoeForCausalLM
from fastdeploy.rl.rollout_config import RolloutModelConfig


class RollOutModel(nn.Layer):
class RolloutModel(nn.Layer):
    """Main model class for rollout operations, supports multimodal components for train."""

    def __init__(self, fd_config: FDConfig):
    def __init__(self, rollout_model_config: RolloutModelConfig):
        """Initialize with FastDeploy configuration."""
        super(RollOutModel, self).__init__()
        self.fd_config = fd_config
        super(RolloutModel, self).__init__()
        self.fd_config = rollout_model_config.initialize()
        self._init_models()

    def _init_models(self):
@@ -90,9 +84,9 @@ class RollOutModel(nn.Layer):
|
||||
all_params = {}
|
||||
for model in self.rollout_models:
|
||||
for name, param in model.state_dict().items():
|
||||
logger.debug(
|
||||
f"Model param: {name}, shape={param.shape}, dtype={param.dtype}"
|
||||
)
|
||||
# logger.debug(
|
||||
# f"Model param: {name}, shape={param.shape}, dtype={param.dtype}"
|
||||
# )
|
||||
all_params[name] = param
|
||||
return all_params
|
||||
|
||||
@@ -123,11 +117,13 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
|
||||
# Initialize mapping dictionary
|
||||
infer_to_train = {}
|
||||
|
||||
infer_base_name = "model"
|
||||
train_base_name = "ernie"
|
||||
# Static mappings (non-layer specific)
|
||||
static_mappings = {
|
||||
"model.embeddings.word_embeddings.weight":
|
||||
"ernie.embed_tokens.weight",
|
||||
"model.norm.ln_weight": "ernie.norm.weight",
|
||||
f"{infer_base_name}.embeddings.word_embeddings.weight":
|
||||
f"{train_base_name}.embed_tokens.weight",
|
||||
f"{infer_base_name}.norm.ln_weight": f"{train_base_name}.norm.weight",
|
||||
"lm_head.out_linear.weight": "lm_head.weight"
|
||||
}
|
||||
if self.fd_config.model_config.get("weight_sharing", False):
|
||||
@@ -135,53 +131,55 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
|
||||
logger.debug("enable tie_word_embeddings")
|
||||
static_mappings.pop("lm_head.out_linear.weight")
|
||||
infer_to_train.update(static_mappings)
|
||||
infer_base_name = "model.hidden_layers"
|
||||
|
||||
infer_base_name = infer_base_name + ".hidden_layers"
|
||||
train_base_name = train_base_name + ".layers"
|
||||
|
||||
# Helper function to add layer mappings
|
||||
def _add_layer_mappings(layer_idx, is_moe_layer=False):
|
||||
# Handle special case for layer 0's input layernorm
|
||||
for ph in place_holders:
|
||||
infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}"
|
||||
train_key = f"ernie.layers.{layer_idx}.input_layernorm.{ph}"
|
||||
train_key = f"{train_base_name}.{layer_idx}.input_layernorm.{ph}"
|
||||
infer_to_train[infer_key] = train_key
|
||||
|
||||
# Common attention mappings
|
||||
for ph in place_holders:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \
|
||||
f"ernie.layers.{layer_idx}.self_attn.qkv_proj.{ph}"
|
||||
f"{train_base_name}.{layer_idx}.self_attn.qkv_proj.{ph}"
|
||||
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \
|
||||
f"ernie.layers.{layer_idx}.self_attn.o_proj.{ph}"
|
||||
f"{train_base_name}.{layer_idx}.self_attn.o_proj.{ph}"
|
||||
|
||||
# Post-attention layernorm
|
||||
for ph in place_holders:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \
|
||||
f"ernie.layers.{layer_idx}.post_attention_layernorm.{ph}"
|
||||
f"{train_base_name}.{layer_idx}.post_attention_layernorm.{ph}"
|
||||
|
||||
if not is_moe_layer:
|
||||
# Dense FFN mappings
|
||||
for ph in place_holders:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \
|
||||
f"ernie.layers.{layer_idx}.mlp.up_gate_proj.{ph}"
|
||||
f"{train_base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"
|
||||
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \
|
||||
f"ernie.layers.{layer_idx}.mlp.down_proj.{ph}"
|
||||
f"{train_base_name}.{layer_idx}.mlp.down_proj.{ph}"
|
||||
else:
|
||||
# MoE specific mappings
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_weight"] = \
|
||||
f"ernie.layers.{layer_idx}.mlp.gate.weight"
|
||||
f"{train_base_name}.{layer_idx}.mlp.gate.weight"
|
||||
|
||||
if self.fd_config.moe_config.moe_use_aux_free:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
|
||||
f"ernie.layers.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
|
||||
f"{train_base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
|
||||
|
||||
# Support shared experts
|
||||
if self.fd_config.model_config.get(
|
||||
"moe_num_shared_experts") > 0:
|
||||
"moe_num_shared_experts", 0) > 0:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.gate_up_proj.linear_weight"] = \
|
||||
f"ernie.layers.{layer_idx}.mlp.shared_experts.up_gate_proj.weight"
|
||||
f"{train_base_name}.{layer_idx}.mlp.shared_experts.up_gate_proj.weight"
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.down_proj.linear_weight"] = \
|
||||
f"ernie.layers.{layer_idx}.mlp.shared_experts.down_proj.weight"
|
||||
f"{train_base_name}.{layer_idx}.mlp.shared_experts.down_proj.weight"
|
||||
|
||||
# MoE experts mappings
|
||||
for expert_idx in range(self.fd_config.moe_config.num_experts):
|
||||
@@ -191,7 +189,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
|
||||
if ffn1_key not in infer_to_train:
|
||||
infer_to_train[ffn1_key] = []
|
||||
infer_to_train[ffn1_key].append(
|
||||
f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
|
||||
f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
|
||||
)
|
||||
|
||||
# FFN2 (down_proj)
|
||||
@@ -199,7 +197,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
|
||||
if ffn2_key not in infer_to_train:
|
||||
infer_to_train[ffn2_key] = []
|
||||
infer_to_train[ffn2_key].append(
|
||||
f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
|
||||
f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
|
||||
)
|
||||
|
||||
# Process non-MoE layers
|
||||
@@ -213,3 +211,214 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
|
||||
_add_layer_mappings(layer_idx, is_moe_layer=True)
|
||||
|
||||
return infer_to_train
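For orientation, a hedged, abbreviated illustration of the kind of dictionary this method returns (layer 0 only, weight placeholder only; the fused MoE expert key is elided because its construction falls outside this hunk):

# Abbreviated example of the infer -> train name mapping built above.
# Plain entries are 1:1 renames; fused MoE expert tensors accumulate a
# list of per-expert training parameter names instead.
example_mapping = {
    "model.embeddings.word_embeddings.weight": "ernie.embed_tokens.weight",
    "model.norm.ln_weight": "ernie.norm.weight",
    "lm_head.out_linear.weight": "lm_head.weight",
    "model.hidden_layers.0.self_attn.qkv_proj.linear_weight":
        "ernie.layers.0.self_attn.qkv_proj.weight",
    # "<fused ffn1 key for a MoE layer>": [
    #     "ernie.layers.1.mlp.experts.0.up_gate_proj.weight",
    #     "ernie.layers.1.mlp.experts.1.up_gate_proj.weight",
    #     ...
    # ],
}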
|
||||
|
||||
|
||||
class Qwen2ForCausalLMRL(Qwen2ForCausalLM):
|
||||
"""
|
||||
Qwen2ForCausalLMRL
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
"""
|
||||
Args:
|
||||
fd_config (FDConfig): Configurations for the LLM model.
|
||||
"""
|
||||
super(Qwen2ForCausalLMRL, self).__init__(fd_config)
|
||||
|
||||
@classmethod
|
||||
def name(self):
|
||||
"""name"""
|
||||
return "Qwen2ForCausalLMRL"
|
||||
|
||||
def get_name_mappings_to_training(self):
|
||||
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
|
||||
# Prepare placeholders
|
||||
place_holders = ["weight"]
|
||||
|
||||
# Initialize mapping dictionary
|
||||
infer_to_train = {}
|
||||
|
||||
infer_base_name = "model"
|
||||
train_base_name = "qwen2"
|
||||
# Static mappings (non-layer specific)
|
||||
static_mappings = {
|
||||
f"{infer_base_name}.embeddings.word_embeddings.weight":
|
||||
f"{train_base_name}.embed_tokens.weight",
|
||||
f"{infer_base_name}.norm.ln_weight": f"{train_base_name}.norm.weight",
|
||||
"lm_head.out_linear.weight": "lm_head.weight"
|
||||
}
|
||||
infer_to_train.update(static_mappings)
|
||||
|
||||
infer_base_name = infer_base_name + ".layers"
|
||||
train_base_name = train_base_name + ".layers"
|
||||
|
||||
# Helper function to add layer mappings
|
||||
def _add_layer_mappings(layer_idx):
|
||||
# Handle special case for layer 0's input layernorm and attn o_proj
|
||||
for ph in place_holders:
|
||||
infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}"
|
||||
train_key = f"{train_base_name}.{layer_idx}.input_layernorm.{ph}"
|
||||
infer_to_train[infer_key] = train_key
|
||||
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.self_attn.o_proj.{ph}"
|
||||
|
||||
# qwen qkv proj need bias
|
||||
for ph in ["weight", "bias"]:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.self_attn.qkv_proj.{ph}"
|
||||
|
||||
# Post-attention layernorm
|
||||
for ph in place_holders:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.post_attention_layernorm.{ph}"
|
||||
|
||||
# FFN mappings
|
||||
for ph in place_holders:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"
|
||||
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.mlp.down_proj.{ph}"
|
||||
|
||||
for layer_idx in range(
|
||||
self.fd_config.model_config.num_layers):
|
||||
_add_layer_mappings(layer_idx)
|
||||
|
||||
return infer_to_train
|
||||
|
||||
|
||||
class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
|
||||
"""
|
||||
Qwen3MoeForCausalLMRL
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
"""
|
||||
Args:
|
||||
fd_config (FDConfig): Configurations for the LLM model.
|
||||
"""
|
||||
super(Qwen3MoeForCausalLMRL, self).__init__(fd_config)
|
||||
|
||||
@classmethod
|
||||
def name(self):
|
||||
"""name"""
|
||||
return "Qwen3MoeForCausalLMRL"
|
||||
|
||||
def get_name_mappings_to_training(self):
|
||||
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
|
||||
# Prepare placeholders
|
||||
place_holders = ["weight"]
|
||||
|
||||
# Initialize mapping dictionary
|
||||
infer_to_train = {}
|
||||
|
||||
infer_base_name = "model"
|
||||
train_base_name = "model"
|
||||
# Static mappings (non-layer specific)
|
||||
static_mappings = {
|
||||
f"{infer_base_name}.embeddings.word_embeddings.weight":
|
||||
f"{train_base_name}.embed_tokens.weight",
|
||||
f"{infer_base_name}.norm.ln_weight": f"{train_base_name}.norm.weight",
|
||||
"lm_head.out_linear.weight": "lm_head.weight"
|
||||
}
|
||||
infer_to_train.update(static_mappings)
|
||||
|
||||
infer_base_name = infer_base_name + ".layers"
|
||||
train_base_name = train_base_name + ".layers"
|
||||
|
||||
# Helper function to add layer mappings
|
||||
def _add_layer_mappings(layer_idx, is_moe_layer=False):
|
||||
# Handle special case for layer 0's input layernorm and attn o_proj
|
||||
for ph in place_holders:
|
||||
infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}"
|
||||
train_key = f"{train_base_name}.{layer_idx}.input_layernorm.{ph}"
|
||||
infer_to_train[infer_key] = train_key
|
||||
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.self_attn.o_proj.{ph}"
|
||||
|
||||
# qwen q_norm/k_norm
|
||||
for ph in place_holders:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.q_norm.ln_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.self_attn.q_norm.{ph}"
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.k_norm.ln_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.self_attn.k_norm.{ph}"
|
||||
|
||||
# qwen qkv proj
|
||||
for ph in place_holders:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.self_attn.qkv_proj.{ph}"
|
||||
|
||||
# Post-attention layernorm
|
||||
for ph in place_holders:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.post_attention_layernorm.{ph}"
|
||||
|
||||
if not is_moe_layer:
|
||||
# FFN mappings
|
||||
for ph in place_holders:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"
|
||||
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.mlp.down_proj.{ph}"
|
||||
else:
|
||||
# MoE specific mappings
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_weight"] = \
|
||||
f"{train_base_name}.{layer_idx}.mlp.gate.weight"
|
||||
|
||||
if self.fd_config.moe_config.moe_use_aux_free:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
|
||||
f"{train_base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
|
||||
|
||||
# Support shared experts
|
||||
if self.fd_config.model_config.get(
|
||||
"moe_num_shared_experts", 0) > 0:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.gate_up_proj.linear_weight"] = \
|
||||
f"{train_base_name}.{layer_idx}.mlp.shared_experts.up_gate_proj.weight"
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.down_proj.linear_weight"] = \
|
||||
f"{train_base_name}.{layer_idx}.mlp.shared_experts.down_proj.weight"
|
||||
|
||||
# MoE experts mappings
|
||||
for expert_idx in range(self.fd_config.moe_config.num_experts):
|
||||
for ph in place_holders:
|
||||
# FFN1 (up_gate_proj)
|
||||
ffn1_key = f"{infer_base_name}.{layer_idx}.mlp.moe_ffn1_weight"
|
||||
if ffn1_key not in infer_to_train:
|
||||
infer_to_train[ffn1_key] = []
|
||||
infer_to_train[ffn1_key].append(
|
||||
f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
|
||||
)
|
||||
|
||||
# FFN2 (down_proj)
|
||||
ffn2_key = f"{infer_base_name}.{layer_idx}.mlp.moe_ffn2_weight"
|
||||
if ffn2_key not in infer_to_train:
|
||||
infer_to_train[ffn2_key] = []
|
||||
infer_to_train[ffn2_key].append(
|
||||
f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
|
||||
)
|
||||
|
||||
# Process MoE layers
|
||||
for layer_idx in range(self.fd_config.model_config.num_layers):
|
||||
_add_layer_mappings(layer_idx, is_moe_layer=True)
|
||||
|
||||
return infer_to_train
|
||||
|
||||
|
||||
class Qwen3ForCausalLMRL(Qwen3ForCausalLM):
|
||||
"""
|
||||
Qwen3ForCausalLMRL
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
"""
|
||||
Args:
|
||||
fd_config (FDConfig): Configurations for the LLM model.
|
||||
"""
|
||||
super(Qwen3ForCausalLMRL, self).__init__(fd_config)
|
||||
|
||||
@classmethod
|
||||
def name(self):
|
||||
"""name"""
|
||||
return "Qwen3ForCausalLMRL"
|
||||
|
@@ -110,7 +110,7 @@ class GlobalSchedulerConfig:
|
||||
ttl: int = 900,
|
||||
min_load_score: float = 3,
|
||||
max_model_len: int = 8192,
|
||||
load_shrads_num: int = 1,
|
||||
load_shards_num: int = 1,
|
||||
enable_chunked_prefill: bool = False,
|
||||
max_num_partial_prefills: int = 1,
|
||||
max_long_partial_prefills: int = 1,
|
||||
@@ -129,7 +129,7 @@ class GlobalSchedulerConfig:
|
||||
ttl: Time-to-live in seconds for Redis keys (default 900s)
|
||||
min_load_score: Minimum load score for task assignment (default 3)
|
||||
max_model_len: Maximum model context length in tokens
|
||||
load_shrads_num: Number of load balancing shards
|
||||
load_shards_num: Number of load balancing shards
|
||||
enable_chunked_prefill: Whether to enable chunked prefill processing
|
||||
max_num_partial_prefills: Max partial prefill operations allowed
|
||||
max_long_partial_prefills: Max long-running partial prefill ops
|
||||
@@ -147,7 +147,7 @@ class GlobalSchedulerConfig:
|
||||
self.topic = topic
|
||||
self.ttl = ttl
|
||||
self.min_load_score = min_load_score
|
||||
self.load_shrads_num = load_shrads_num
|
||||
self.load_shards_num = load_shards_num
|
||||
|
||||
self.max_model_len = max_model_len
|
||||
self.enable_chunked_prefill = enable_chunked_prefill
|
||||
@@ -169,8 +169,8 @@ class GlobalSchedulerConfig:
|
||||
raise ValueError("ttl should be greater than 60")
|
||||
if self.min_load_score < 1:
|
||||
raise ValueError("min_load_score should be greater than 0")
|
||||
if self.load_shrads_num < 1:
|
||||
raise ValueError("load_shrads_num should be greater than 0")
|
||||
if self.load_shards_num < 1:
|
||||
raise ValueError("load_shards_num should be greater than 0")
|
||||
|
||||
r = redis.Redis(self.host, self.port, self.db, self.password)
|
||||
try:
|
||||
@@ -262,7 +262,7 @@ class SchedulerConfig:
|
||||
topic=self.config.topic,
|
||||
ttl=self.config.ttl,
|
||||
min_load_score=self.config.min_load_score,
|
||||
load_shrads_num=self.config.load_shrads_num,
|
||||
load_shards_num=self.config.load_shards_num,
|
||||
enable_chunked_prefill=self.config.enable_chunked_prefill,
|
||||
max_num_partial_prefills=self.config.max_num_partial_prefills,
|
||||
max_long_partial_prefills=self.config.max_long_partial_prefills,
|
||||
|
@@ -19,7 +19,6 @@ from typing import List, Optional, Dict, Tuple
|
||||
import traceback
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
import random
|
||||
import uuid
|
||||
import crcmod
|
||||
@@ -28,7 +27,7 @@ from fastdeploy.scheduler.storage import AdaptedRedis
|
||||
from fastdeploy.engine.request import Request, RequestOutput
|
||||
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
|
||||
from fastdeploy.scheduler.workers import Workers, Task
|
||||
from fastdeploy.utils import llm_logger
|
||||
from fastdeploy.utils import scheduler_logger
|
||||
from fastdeploy.scheduler import utils
|
||||
|
||||
|
||||
@@ -51,7 +50,7 @@ class GlobalScheduler(object):
|
||||
topic: str,
|
||||
ttl: int,
|
||||
min_load_score: float,
|
||||
load_shrads_num: int,
|
||||
load_shards_num: int,
|
||||
enable_chunked_prefill: bool,
|
||||
max_num_partial_prefills: int,
|
||||
max_long_partial_prefills: int,
|
||||
@@ -68,7 +67,7 @@ class GlobalScheduler(object):
|
||||
topic: Base topic name for queue namespacing
|
||||
ttl: Time-to-live in seconds for Redis keys
|
||||
min_load_score: Minimum load score for task assignment
|
||||
load_shrads_num: Number of shards for load balancing table
|
||||
load_shards_num: Number of shards for load balancing table
|
||||
enable_chunked_prefill: Whether to enable chunked prefill processing
|
||||
max_num_partial_prefills: Maximum number of partial prefills allowed
|
||||
max_long_partial_prefills: Maximum number of long partial prefills allowed
|
||||
@@ -84,7 +83,7 @@ class GlobalScheduler(object):
|
||||
self.topic = topic
|
||||
self.ttl = ttl
|
||||
self.min_load_score = min_load_score
|
||||
self.load_shrads_num = load_shrads_num
|
||||
self.load_shards_num = load_shards_num
|
||||
|
||||
self.enable_chunked_prefill = enable_chunked_prefill
|
||||
self.max_num_partial_prefills = max_num_partial_prefills
|
||||
@@ -97,14 +96,17 @@ class GlobalScheduler(object):
|
||||
self.crc16_mutex = threading.Lock()
|
||||
self.crc16 = crcmod.predefined.Crc('ccitt-false')
|
||||
self.load_slot_for_getting_request = 0
|
||||
self.load_start = 0 # const
|
||||
self.load_num = 50 # const
|
||||
self.load_offset = 0 # const
|
||||
self.load_count = 50 # const
|
||||
self.load_lookup_num = 5 # const
|
||||
self.keep_alive_duration = 30 # const
|
||||
|
||||
connection_pool = ConnectionPool(
|
||||
host=host, port=port, db=db, password=password, max_connections=10)
|
||||
self.client = AdaptedRedis(connection_pool=connection_pool)
|
||||
|
||||
self.name = self._generate_scheduler_name()
|
||||
self.name, self.shard = self._generate_scheduler_name_and_shard()
|
||||
|
||||
self.keep_alive_workers = threading.Thread(
|
||||
target=self._keep_alive, daemon=True)
|
||||
self.keep_alive_workers.start()
|
||||
@@ -126,10 +128,32 @@ class GlobalScheduler(object):
            target=self._get_results_worker, daemon=True)
        self.get_response_workers.start()

        llm_logger.info(
        scheduler_logger.info(
            f"Scheduler: name={self.name} redis_version={self.client.version}")

    def _get_hash_slot(self, data: str) -> int:
        """
        Calculate the hash slot for a given string using CRC16 algorithm.

        This method is thread-safe and used for consistent hashing in distributed scheduling.
        It implements the same CRC16 algorithm (CCITT-FALSE variant) used by Redis Cluster.

        Args:
            data: Input string to be hashed (typically a scheduler or request identifier)

        Returns:
            int: A 16-bit hash value (0-65535) representing the calculated slot

        Implementation Details:
            1. Encodes input string as UTF-8 bytes
            2. Uses thread-safe CRC16 calculation with mutex protection
            3. Resets CRC state after each calculation
            4. Returns raw CRC value without modulo operation

        Note:
            - The result is typically used with modulo operation for sharding (e.g. % num_shards)
            - Matches Redis Cluster's slot distribution algorithm for compatibility
        """
        data = data.encode("utf-8")
        with self.crc16_mutex:
            self.crc16.update(data)
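A stand-alone sketch of the slot computation the docstring describes, using the same crcmod predefined CRC the class constructs; a fresh Crc object is created per call here, so the mutex is unnecessary in this variant:

# Hedged sketch of the hash-slot computation (not the class method itself).
import crcmod.predefined

_crc16_template = crcmod.predefined.Crc('ccitt-false')

def hash_slot(data: str, num_shards: int) -> int:
    """CRC16(CCITT-FALSE) of the UTF-8 bytes, reduced to a shard index."""
    crc = _crc16_template.new()          # independent state, no lock needed
    crc.update(data.encode("utf-8"))
    return crc.crcValue % num_shards

# e.g. hash_slot("fd.ins.scheduler-a", 4) -> an index in range(4)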
@@ -149,58 +173,66 @@ class GlobalScheduler(object):
|
||||
"""
|
||||
return f"{self.topic}.ins.{scheduler_name}"
|
||||
|
||||
def _generate_scheduler_name(self) -> str:
|
||||
def _generate_scheduler_name_and_shard(self) -> Tuple[str, int]:
|
||||
"""
|
||||
Generate a unique name for this scheduler instance.
|
||||
Generate a unique scheduler name and calculate its shard assignment.
|
||||
|
||||
Uses hostname/IP and timestamp to create a unique identifier,
|
||||
then registers it in Redis with TTL.
|
||||
This method:
|
||||
1. Creates a unique identifier using hostname/IP and timestamp
|
||||
2. Registers the name in Redis with TTL
|
||||
3. Calculates the shard assignment using consistent hashing
|
||||
4. Handles naming conflicts by appending incrementing suffixes
|
||||
|
||||
Returns:
|
||||
Unique scheduler name string
|
||||
Tuple[str, int]:
|
||||
- str: Unique scheduler name
|
||||
- int: Assigned shard number (0 to load_shards_num-1)
|
||||
|
||||
Implementation Details:
|
||||
- Uses hostname/IP as base identifier, falls back to UUID if unavailable
|
||||
- Implements conflict resolution with incrementing suffixes
|
||||
- Registers name in Redis with keep-alive duration
|
||||
- Calculates shard using CRC16 hash of the name
|
||||
|
||||
Error Handling:
|
||||
- Logs IP resolution failures
|
||||
- Handles Redis registration conflicts gracefully
|
||||
- Ensures unique name generation even in edge cases
|
||||
"""
|
||||
try:
|
||||
_, name = utils.get_hostname_ip()
|
||||
except Exception as e:
|
||||
llm_logger.warning(
|
||||
scheduler_logger.warning(
|
||||
f"Scheduler encountered an error while resolving the IP address. {e}")
|
||||
name = str(uuid.uuid4())
|
||||
|
||||
size = len(name)
|
||||
now = time.time()
|
||||
local_time = datetime.fromtimestamp(now)
|
||||
formatted_time = local_time.strftime(
|
||||
"%Y-%m-%d %H:%M:%S") + f"{local_time.microsecond // 1000:03d}"
|
||||
|
||||
count = 1
|
||||
while True:
|
||||
if self.client.set(self._instance_name(name), formatted_time, ex=self.ttl, nx=True):
|
||||
if self.client.set(self._instance_name(name), "", ex=self.keep_alive_duration, nx=True):
|
||||
break
|
||||
name = f"{name[:size]}:{count}"
|
||||
count += 1
|
||||
return name
|
||||
|
||||
shard = self._get_hash_slot(name) % self.load_shards_num
|
||||
self.client.set(self._instance_name(name), self._load_table_name(shard=shard),
|
||||
ex=self.keep_alive_duration)
|
||||
return name, shard
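A condensed sketch of the register-and-shard flow the docstring describes, written against a redis-py style client; the key names mirror the _instance_name/_load_table_name patterns visible in this diff, and hash_slot comes from the earlier sketch:

# Hedged sketch only; not the method above. Assumes a redis-py client and the
# hash_slot helper sketched earlier.
import redis

def register_scheduler(client: redis.Redis, topic: str, base_name: str,
                       load_shards_num: int, keep_alive_duration: int = 30):
    name, count = base_name, 1
    # SET NX reserves the instance key; on a name clash append a suffix.
    while not client.set(f"{topic}.ins.{name}", "",
                         ex=keep_alive_duration, nx=True):
        name = f"{base_name}:{count}"
        count += 1
    shard = hash_slot(name, load_shards_num)
    # Record which load table this instance belongs to, with a TTL.
    client.set(f"{topic}.ins.{name}", f"{topic}.load.{shard}",
               ex=keep_alive_duration)
    return name, shard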
|
||||
|
||||
def _keep_alive(self):
|
||||
"""
|
||||
Background thread that periodically updates the scheduler's TTL in Redis.
|
||||
|
||||
Runs in a loop with interval of TTL/2 to maintain instance registration.
|
||||
Runs in a loop with interval of keep_alive_duration/2 to maintain instance registration.
|
||||
"""
|
||||
interval_time = self.ttl / 2
|
||||
while True:
|
||||
try:
|
||||
now = time.time()
|
||||
local_time = datetime.fromtimestamp(now)
|
||||
formatted_time = local_time.strftime(
|
||||
"%Y-%m-%d %H:%M:%S") + f"{local_time.microsecond // 1000:03d}"
|
||||
self.client.set(self._instance_name(self.name),
|
||||
formatted_time, ex=self.ttl)
|
||||
self.client.set(self._instance_name(
|
||||
self.name), self._load_table_name(), ex=self.keep_alive_duration)
|
||||
time.sleep(self.keep_alive_duration / 2)
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Scheduler keep alive failed: {e}")
|
||||
interval_time = self.ttl / 10
|
||||
|
||||
time.sleep(interval_time)
|
||||
interval_time = self.ttl / 2
|
||||
scheduler_logger.error(f"Scheduler keep alive failed: {e}")
|
||||
time.sleep(min(3, self.keep_alive_duration / 4))
|
||||
|
||||
    def _scheduler_name_from_request_queue(self, request_queue: str) -> str:
        """
@@ -243,22 +275,18 @@ class GlobalScheduler(object):
            return f"{self.topic}.resp.{self.name}"
        return f"{self.topic}.resp.{scheduler_name}"

    def _load_table_name(self, request_queue_name: Optional[str] = None, slot: Optional[int] = None) -> str:
    def _load_table_name(self, shard: Optional[int] = None, slot: Optional[int] = None) -> str:
        """
        Get the Redis sorted set name used for load balancing.

        Returns:
            The load score key name
        """
        if request_queue_name is None:
            request_queue_name = self._request_queue_name()

        if slot is None:
            slot = self._get_hash_slot(
                request_queue_name) % self.load_shrads_num
        else:
            slot %= self.load_shrads_num
        return f"{self.topic}.load.{slot}"
        if shard is None and slot is not None:
            shard = slot % self.load_shards_num
        if shard is None:
            shard = self.shard
        return f"{self.topic}.load.{shard}"

    @staticmethod
    def calc_required_blocks(token_num, block_size):
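The body of calc_required_blocks is outside this hunk; block accounting of this kind is usually a ceiling division over the KV-cache block size, so a plausible sketch (an assumption, not a verbatim copy) is:

# Assumed implementation sketch: ceil(token_num / block_size).
def calc_required_blocks(token_num: int, block_size: int) -> int:
    return (token_num + block_size - 1) // block_size

assert calc_required_blocks(130, 64) == 3   # 130 tokens span 3 blocks of 64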
@@ -330,11 +358,11 @@ class GlobalScheduler(object):
|
||||
self.client.zincrby(self._load_table_name(),
|
||||
len(serialized_requests), self.name,
|
||||
rem_amount=0, ttl=self.ttl)
|
||||
llm_logger.info(
|
||||
scheduler_logger.info(
|
||||
f"Scheduler has enqueued some requests: {requests}")
|
||||
|
||||
if duplicate:
|
||||
llm_logger.warning(
|
||||
scheduler_logger.warning(
|
||||
"Scheduler has received some duplicated requests: "
|
||||
f"{[task for task in tasks if task.reason is not None]}")
|
||||
return tasks
|
||||
@@ -375,7 +403,7 @@ class GlobalScheduler(object):
|
||||
"""
|
||||
|
||||
if available_blocks <= reserved_output_blocks or batch < 1:
|
||||
llm_logger.debug(
|
||||
scheduler_logger.debug(
|
||||
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
|
||||
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
|
||||
f"max_num_batched_tokens={max_num_batched_tokens}")
|
||||
@@ -406,15 +434,17 @@ class GlobalScheduler(object):
|
||||
for element in elements]
|
||||
|
||||
extend_scheduler_names = []
|
||||
extend_scheduler_load_table_name = ""
|
||||
if len(serialized_requests) == 0 and len(batches) > 0:
|
||||
for _ in range(min(5, self.load_shrads_num)):
|
||||
for _ in range(min(self.load_lookup_num, self.load_shards_num)):
|
||||
extend_scheduler_load_table_name = self._load_table_name(
|
||||
slot=self.load_slot_for_getting_request)
|
||||
serialized_members = self.client.zrangebyscore(
|
||||
self._load_table_name(
|
||||
slot=self.load_slot_for_getting_request),
|
||||
extend_scheduler_load_table_name,
|
||||
self.min_load_score,
|
||||
float("+inf"),
|
||||
start=self.load_start,
|
||||
num=self.load_num)
|
||||
start=self.load_offset,
|
||||
num=self.load_count)
|
||||
self.load_slot_for_getting_request += 1
|
||||
if len(serialized_members) > 0:
|
||||
break
|
||||
@@ -433,23 +463,18 @@ class GlobalScheduler(object):
|
||||
|
||||
elements = self.client.lpop(lucky_request_queue_name, batches[0])
|
||||
if elements is not None and len(elements) > 0:
|
||||
self.client.zincrby(
|
||||
self._load_table_name(
|
||||
request_queue_name=lucky_request_queue_name),
|
||||
-len(elements), lucky, rem_amount=0, ttl=self.ttl)
|
||||
self.client.zincrby(extend_scheduler_load_table_name,
|
||||
-len(elements), lucky, rem_amount=0, ttl=self.ttl)
|
||||
serialized_requests += [(lucky_request_queue_name, element)
|
||||
for element in elements]
|
||||
llm_logger.info(
|
||||
scheduler_logger.info(
|
||||
f"Scheduler {self.name} has stolen some requests from another lucky one. "
|
||||
f"(name={lucky} num={len(serialized_requests)})")
|
||||
else:
|
||||
exist_num = self.client.exists(self._instance_name(lucky))
|
||||
if exist_num == 0:
|
||||
if self.client.zrem(
|
||||
self._load_table_name(
|
||||
request_queue_name=lucky_request_queue_name),
|
||||
lucky):
|
||||
llm_logger.info(
|
||||
if self.client.zrem(extend_scheduler_load_table_name, lucky):
|
||||
scheduler_logger.info(
|
||||
f"Scheduler {lucky} has been removed")
|
||||
|
||||
# blocked read
|
||||
@@ -465,12 +490,12 @@ class GlobalScheduler(object):
|
||||
request_queue_name = element[0].decode("utf-8")
|
||||
scheduler_name = self._scheduler_name_from_request_queue(
|
||||
request_queue_name)
|
||||
self.client.zincrby(
|
||||
self._load_table_name(request_queue_name=request_queue_name),
|
||||
-1, scheduler_name, rem_amount=0, ttl=self.ttl)
|
||||
load_table_name = extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
|
||||
self.client.zincrby(load_table_name,
|
||||
-1, scheduler_name, rem_amount=0, ttl=self.ttl)
|
||||
serialized_requests.append((request_queue_name, element[1]))
|
||||
if scheduler_name != self.name:
|
||||
llm_logger.info(
|
||||
scheduler_logger.info(
|
||||
f"Scheduler {self.name} has stolen a request from another scheduler. (name={scheduler_name})")
|
||||
|
||||
long_partial_requests = 0
|
||||
@@ -526,12 +551,12 @@ class GlobalScheduler(object):
|
||||
if request.request_queue_name == local_request_queue_name:
|
||||
continue
|
||||
|
||||
self._mark_request(request)
|
||||
# self._mark_request(request)
|
||||
if request.request_id not in self.stolen_requests:
|
||||
self.stolen_requests[request.request_id] = request
|
||||
continue
|
||||
|
||||
llm_logger.error(
|
||||
scheduler_logger.error(
|
||||
f"Scheduler has received a duplicate request from others: {request}")
|
||||
|
||||
requests: List[Request] = [
|
||||
@@ -548,19 +573,18 @@ class GlobalScheduler(object):
|
||||
serialized_requests)
|
||||
scheduler_name = self._scheduler_name_from_request_queue(
|
||||
request_queue_name)
|
||||
self.client.zincrby(
|
||||
self._load_table_name(
|
||||
request_queue_name=request_queue_name),
|
||||
len(serialized_requests), scheduler_name, ttl=self.ttl)
|
||||
load_table_name = extend_scheduler_load_table_name if scheduler_name != self.name else self._load_table_name()
|
||||
self.client.zincrby(load_table_name,
|
||||
len(serialized_requests), scheduler_name, ttl=self.ttl)
|
||||
|
||||
llm_logger.info(
|
||||
scheduler_logger.info(
|
||||
f"Scheduler has put remaining request into the queue: {len(remaining_request)}")
|
||||
if len(requests) == 0:
|
||||
llm_logger.debug(
|
||||
scheduler_logger.debug(
|
||||
f"Scheduler has put all just-pulled request into the queue: {len(remaining_request)}")
|
||||
|
||||
if len(requests) > 0:
|
||||
llm_logger.info(
|
||||
scheduler_logger.info(
|
||||
f"Scheduler has pulled some request: {[request.request_id for request in requests]}")
|
||||
return requests
|
||||
|
||||
@@ -600,7 +624,7 @@ class GlobalScheduler(object):
|
||||
if response.request_id in stolen_request_id_request_queue:
|
||||
response_queue_name = stolen_request_id_response_queue[response.request_id]
|
||||
request_queue_name = stolen_request_id_request_queue[response.request_id]
|
||||
self._unmark_response(response, request_queue_name)
|
||||
# self._unmark_response(response, request_queue_name)
|
||||
|
||||
if response_queue_name not in stolen_responses:
|
||||
stolen_responses[response_queue_name] = []
|
||||
@@ -608,7 +632,7 @@ class GlobalScheduler(object):
|
||||
response.serialize())
|
||||
continue
|
||||
|
||||
llm_logger.error(
|
||||
scheduler_logger.error(
|
||||
f"Scheduler has recieved a non-existent response from engine: {[response]}")
|
||||
|
||||
with self.mutex:
|
||||
@@ -624,7 +648,7 @@ class GlobalScheduler(object):
|
||||
self.local_response_not_empty.notify_all()
|
||||
|
||||
if len(finished_request_ids) > 0:
|
||||
llm_logger.info(
|
||||
scheduler_logger.info(
|
||||
f"Scheduler has received some finished responses: {finished_request_ids}")
|
||||
|
||||
for response_queue_name, responses in stolen_responses.items():
|
||||
@@ -681,15 +705,15 @@ class GlobalScheduler(object):
|
||||
with self.mutex:
|
||||
for request_id, contents in responses.items():
|
||||
if request_id not in self.local_responses:
|
||||
llm_logger.error(
|
||||
scheduler_logger.error(
|
||||
"Scheduler has received some non-existent response from the queue. "
|
||||
f"response:{contents} queue:{self._response_queue_name()}")
|
||||
continue
|
||||
self.local_responses[request_id] += contents
|
||||
self.local_response_not_empty.notify_all()
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Scheduler get_results_worker exception: {e} "
|
||||
f"traceback: {traceback.format_exc()}")
|
||||
scheduler_logger.error(f"Scheduler get_results_worker exception: {e} "
|
||||
f"traceback: {traceback.format_exc()}")
|
||||
|
||||
def get_results(self) -> Dict[str, List[RequestOutput]]:
|
||||
"""
|
||||
@@ -718,7 +742,7 @@ class GlobalScheduler(object):
|
||||
- Thread-safe operation using condition variables
|
||||
- Short timeout avoids blocking while maintaining responsiveness
|
||||
- First call may return empty to batch small responses
|
||||
- Automatically logs finished requests via llm_logger
|
||||
- Automatically logs finished requests via scheduler_logger
|
||||
"""
|
||||
first = True
|
||||
|
||||
@@ -754,7 +778,7 @@ class GlobalScheduler(object):
|
||||
|
||||
if finished:
|
||||
del self.local_responses[request_id]
|
||||
llm_logger.info(
|
||||
scheduler_logger.info(
|
||||
f"Scheduler has pulled a finished response: {[request_id]}")
|
||||
return results
|
||||
|
||||
@@ -787,4 +811,41 @@ class GlobalScheduler(object):
        self.client.zrem(self._load_table_name(), self.name)
        self.local_responses = dict()
        self.stolen_requests = dict()
        llm_logger.info("Scheduler has been reset")
        scheduler_logger.info("Scheduler has been reset")

    def update_config(self, load_shards_num: Optional[int], reallocate: Optional[bool]):
        """
        Update the scheduler's configuration parameters dynamically.

        This method allows runtime modification of:
        - Total number of load balancing shards
        - Current instance's shard assignment

        Args:
            load_shards_num: New total number of load balancing shards (must be > 0)
            reallocate: If True, recalculates this instance's shard assignment

        Effects:
            - Updates internal load balancing configuration
            - Optionally reallocates this instance to a new shard
            - Logs configuration changes for audit purposes

        Note:
            - Changes take effect immediately for new operations
            - Existing in-progress operations continue with old configuration
            - Reallocation may affect request distribution pattern
        """
        with self.mutex:
            old_load_shards_num = self.load_shards_num
            old_shard = self.shard

            if load_shards_num:
                self.load_shards_num = load_shards_num

            if reallocate:
                self.shard = self._get_hash_slot(
                    self.name) % self.load_shards_num

            scheduler_logger.info("Scheduler has reload config, "
                                  f"load_shards_num({old_load_shards_num} => {self.load_shards_num}) "
                                  f"shard({old_shard} => {self.shard})")
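A short usage sketch, assuming `scheduler` is an existing GlobalScheduler instance:

# Hedged usage sketch: grow the load table from 1 shard to 4 and let this
# instance re-hash itself onto the new layout.
scheduler.update_config(load_shards_num=4, reallocate=True)
# New operations now resolve _load_table_name() to "<topic>.load.<new shard>";
# in-flight operations keep the old configuration, as noted in the docstring.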
@@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from fastdeploy.engine.request import Request, RequestOutput
|
||||
from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse
|
||||
from fastdeploy.utils import llm_logger
|
||||
from fastdeploy.utils import scheduler_logger
|
||||
|
||||
|
||||
class LocalScheduler(object):
|
||||
@@ -115,7 +115,7 @@ class LocalScheduler(object):
|
||||
self.ids = list()
|
||||
self.requests = dict()
|
||||
self.responses = dict()
|
||||
llm_logger.info("Scheduler has been reset")
|
||||
scheduler_logger.info("Scheduler has been reset")
|
||||
|
||||
def _recycle(self, request_id: Optional[str] = None):
|
||||
"""
|
||||
@@ -189,10 +189,10 @@ class LocalScheduler(object):
|
||||
|
||||
self.ids += valid_ids
|
||||
self.requests_not_empty.notify_all()
|
||||
llm_logger.info(f"Scheduler has enqueued some requests: {valid_ids}")
|
||||
scheduler_logger.info(f"Scheduler has enqueued some requests: {valid_ids}")
|
||||
|
||||
if len(duplicated_ids) > 0:
|
||||
llm_logger.warning(
|
||||
scheduler_logger.warning(
|
||||
f"Scheduler has received some duplicated requests: {duplicated_ids}"
|
||||
)
|
||||
|
||||
@@ -234,7 +234,7 @@ class LocalScheduler(object):
|
||||
List of Request objects ready for processing
|
||||
"""
|
||||
if available_blocks <= reserved_output_blocks or batch < 1:
|
||||
llm_logger.debug(
|
||||
scheduler_logger.debug(
|
||||
f"Scheduler's resource are insufficient: available_blocks={available_blocks} "
|
||||
f"reserved_output_blocks={reserved_output_blocks} batch={batch} "
|
||||
f"max_num_batched_tokens={max_num_batched_tokens}")
|
||||
@@ -277,12 +277,12 @@ class LocalScheduler(object):
|
||||
self.ids_read_cursor += len(requests)
|
||||
|
||||
if len(batch_ids) > 0 and len(requests) == 0:
|
||||
llm_logger.debug(
|
||||
scheduler_logger.debug(
|
||||
f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}"
|
||||
)
|
||||
|
||||
if len(requests) > 0:
|
||||
llm_logger.info(
|
||||
scheduler_logger.info(
|
||||
f"Scheduler has pulled some request: {[request.request_id for request in requests]}"
|
||||
)
|
||||
|
||||
@@ -303,14 +303,14 @@ class LocalScheduler(object):
|
||||
response.request_id for response in responses if response.finished
|
||||
]
|
||||
if len(finished_responses) > 0:
|
||||
llm_logger.info(
|
||||
scheduler_logger.info(
|
||||
f"Scheduler has received some finished responses: {finished_responses}"
|
||||
)
|
||||
|
||||
with self.mutex:
|
||||
for response in responses:
|
||||
if response.request_id not in self.requests:
|
||||
llm_logger.warning(
|
||||
scheduler_logger.warning(
|
||||
f"Scheduler has received a expired response: {[response.request_id]}"
|
||||
)
|
||||
continue
|
||||
@@ -342,7 +342,7 @@ class LocalScheduler(object):
|
||||
- Thread-safe operation using condition variables
|
||||
- Has a short timeout (0.001s) to avoid blocking
|
||||
- Automatically recycles completed requests to free memory
|
||||
- Logs finished requests via llm_logger
|
||||
- Logs finished requests via scheduler_logger
|
||||
"""
|
||||
|
||||
def _get_results():
|
||||
@@ -364,7 +364,7 @@ class LocalScheduler(object):
|
||||
|
||||
if finished:
|
||||
self._recycle(request_id)
|
||||
llm_logger.info(
|
||||
scheduler_logger.info(
|
||||
f"Scheduler has pulled a finished response: {[request_id]}"
|
||||
)
|
||||
return results
|
||||
|
@@ -18,7 +18,7 @@ from typing import Callable, List, Any, Dict, Optional
|
||||
import functools
|
||||
import threading
|
||||
import traceback
|
||||
from fastdeploy.utils import llm_logger
|
||||
from fastdeploy.utils import scheduler_logger
|
||||
|
||||
|
||||
class Task:
|
||||
@@ -163,7 +163,7 @@ class Workers:
|
||||
try:
|
||||
results = self.work(tasks)
|
||||
except Exception as e:
|
||||
llm_logger.error(
|
||||
scheduler_logger.error(
|
||||
f"Worker {self.name} execute error: {e}, traceback: {traceback.format_exc()}")
|
||||
continue
|
||||
|
||||
|
@@ -63,6 +63,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.device_id = device_id
|
||||
self.speculative_method = self.fd_config.speculative_config.method
|
||||
self.speculative_decoding = self.speculative_method is not None
|
||||
self.enable_logprob = fd_config.model_config.enable_logprob
|
||||
|
||||
self.guided_backend = None
|
||||
if self.fd_config.parallel_config.guided_decoding_backend != "off":
|
||||
@@ -243,8 +244,8 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
request.eos_token_ids.append(request.eos_token_ids[0])
|
||||
self.share_inputs["eos_token_id"][:] = np.array(
|
||||
request.eos_token_ids, dtype="int64").reshape(-1, 1)
|
||||
|
||||
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
|
||||
self.share_inputs["top_k"][idx:idx + 1] = request.get("top_k", 0)
|
||||
self.share_inputs["temperature"][idx:idx + 1] = request.get(
|
||||
"temperature", 0.95)
|
||||
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
|
||||
@@ -350,8 +351,11 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["eos_token_id"] = paddle.full(
|
||||
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
|
||||
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
|
||||
self.model_config.top_p,
|
||||
dtype='float32')
|
||||
self.model_config.top_p,
|
||||
dtype='float32')
|
||||
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
|
||||
0,
|
||||
dtype='int64')
|
||||
self.share_inputs["temperature"] = paddle.full(
|
||||
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
|
||||
self.share_inputs["penalty_score"] = paddle.full(
|
||||
@@ -574,6 +578,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.sampling_metadata = SamplingMetadata(
|
||||
temperature=self.share_inputs["temperature"],
|
||||
top_p=self.share_inputs["top_p"],
|
||||
top_k=self.share_inputs["top_k"],
|
||||
step_idx=self.share_inputs["step_idx"],
|
||||
pre_token_ids=self.share_inputs["pre_ids"],
|
||||
frequency_penalties=self.share_inputs["frequency_score"],
|
||||
@@ -582,6 +587,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
min_dec_lens=self.share_inputs["min_dec_len"],
|
||||
bad_words_token_ids=self.share_inputs["bad_tokens"],
|
||||
eos_token_ids=self.share_inputs["eos_token_id"],
|
||||
max_num_logprobs=20 if self.enable_logprob else None,
|
||||
)
|
||||
|
||||
def load_model(self) -> None:
|
||||
@@ -786,15 +792,15 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["step_idx"],
|
||||
self.share_inputs["stop_flags"],
|
||||
)
|
||||
sampled_token_ids = self.sampler(logits,
|
||||
sampler_output = self.sampler(logits,
|
||||
self.sampling_metadata)
|
||||
if self.parallel_config.tensor_parallel_degree > 1:
|
||||
paddle.distributed.broadcast(sampled_token_ids, 0)
|
||||
paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
|
||||
else:
|
||||
self.sampler(logits, self.sampling_metadata,
|
||||
self.parallel_config.max_model_len,
|
||||
self.share_inputs)
|
||||
sampled_token_ids = None
|
||||
sampler_output = None
|
||||
if self.parallel_config.tensor_parallel_degree > 1:
|
||||
paddle.distributed.broadcast(
|
||||
self.share_inputs["accept_tokens"], 0)
|
||||
@@ -834,7 +840,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
accept_num=self.share_inputs["accept_num"]
|
||||
if self.speculative_decoding else None)
|
||||
|
||||
post_process(sampled_token_ids=sampled_token_ids,
|
||||
post_process(sampler_output=sampler_output,
|
||||
model_output=model_output_data,
|
||||
speculative_decoding=self.speculative_decoding,
|
||||
skip_save_output=True)
|
||||
@@ -1021,18 +1027,18 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["step_idx"],
|
||||
self.share_inputs["stop_flags"],
|
||||
)
|
||||
sampled_token_ids = self.sampler(
|
||||
sampler_output = self.sampler(
|
||||
logits,
|
||||
self.sampling_metadata,
|
||||
skip_idx_list,
|
||||
)
|
||||
if self.parallel_config.tensor_parallel_degree > 1:
|
||||
paddle.distributed.broadcast(sampled_token_ids, 0)
|
||||
paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
|
||||
|
||||
else:
|
||||
self.sampler(logits, self.sampling_metadata,
|
||||
self.parallel_config.max_model_len, self.share_inputs)
|
||||
sampled_token_ids = None
|
||||
sampler_output = None
|
||||
if self.parallel_config.tensor_parallel_degree > 1:
|
||||
paddle.distributed.broadcast(
|
||||
self.share_inputs["accept_tokens"], 0)
|
||||
@@ -1075,7 +1081,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
skip_save_output = True
|
||||
else:
|
||||
skip_save_output = False
|
||||
post_process(sampled_token_ids=sampled_token_ids,
|
||||
post_process(sampler_output=sampler_output,
|
||||
model_output=model_output_data,
|
||||
save_each_rank=self.parallel_config.use_ep,
|
||||
speculative_decoding=self.speculative_decoding,
|
||||
|
@@ -15,11 +15,80 @@
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import paddle
|
||||
|
||||
|
||||
class LogprobsLists(NamedTuple):
|
||||
"""
|
||||
"""
|
||||
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
logprob_token_ids: list[list[int]]
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
logprobs: list[list[float]]
|
||||
# [num_reqs]
|
||||
sampled_token_ranks: list[int]
|
||||
|
||||
def slice(self, start: int, end: int):
|
||||
"""slice"""
|
||||
return LogprobsLists(
|
||||
self.logprob_token_ids[start:end],
|
||||
self.logprobs[start:end],
|
||||
self.sampled_token_ranks[start:end],
|
||||
)
|
||||
|
||||
|
||||
class LogprobsTensors(NamedTuple):
|
||||
"""
|
||||
"""
|
||||
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
logprob_token_ids: paddle.Tensor
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
logprobs: paddle.Tensor
|
||||
# [num_reqs]
|
||||
selected_token_ranks: paddle.Tensor
|
||||
|
||||
def tolists(self):
|
||||
"""Convert to lists."""
|
||||
return LogprobsLists(
|
||||
self.logprob_token_ids.tolist(),
|
||||
self.logprobs.tolist(),
|
||||
self.selected_token_ranks.tolist(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def empty_cpu(num_positions: int,
|
||||
num_tokens_per_position: int) -> "LogprobsTensors":
|
||||
"""Create empty LogprobsTensors on CPU."""
|
||||
|
||||
logprob_token_ids = paddle.empty(
|
||||
[num_positions, num_tokens_per_position],
|
||||
dtype=paddle.int64).cpu()
|
||||
logprobs = paddle.empty_like(logprob_token_ids, dtype=paddle.float32)
|
||||
selected_token_ranks = paddle.empty([num_positions],
|
||||
dtype=paddle.int64).cpu()
|
||||
return LogprobsTensors(
|
||||
logprob_token_ids=logprob_token_ids,
|
||||
logprobs=logprobs,
|
||||
selected_token_ranks=selected_token_ranks,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class SamplerOutput:
    """
    """

    # [num_reqs, max_num_generated_tokens]
    # Different requests can have different number of generated tokens.
    # All requests are padded to max_num_generated_tokens.
    # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
    sampled_token_ids: paddle.Tensor
    logprobs_tensors: Optional[LogprobsTensors]
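A minimal sketch of how these containers fit together, assuming the classes defined above are in scope; the shapes are illustrative (two requests, top-20 logprobs plus the sampled token):

# Hedged usage sketch for SamplerOutput / LogprobsTensors / LogprobsLists.
import paddle

topk = LogprobsTensors.empty_cpu(num_positions=2, num_tokens_per_position=21)
output = SamplerOutput(
    sampled_token_ids=paddle.zeros([2, 1], dtype="int64"),
    logprobs_tensors=topk,
)
if output.logprobs_tensors is not None:
    lists = output.logprobs_tensors.tolists()  # LogprobsLists of Python lists
    first_request = lists.slice(0, 1)          # keep only request 0's rows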
|
||||
|
||||
@dataclass
|
||||
class ModelOutputData:
|
||||
"""
|
||||
|
@@ -13,10 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
@@ -24,6 +24,9 @@ import paddle.distributed.fleet as fleet
|
||||
from paddleformers.transformers.model_utils import load_tp_checkpoint
|
||||
from safetensors import safe_open
|
||||
|
||||
from fastdeploy.config import (DeviceConfig, FDConfig, KVCacheConfig,
|
||||
LoadConfig, ModelConfig, MoEConfig, MoEPhase,
|
||||
ParallelConfig, SpeculativeConfig)
|
||||
from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
|
||||
from fastdeploy.input.mm_processor import DataProcessor
|
||||
from fastdeploy.model_executor.layers.attention import get_attention_backend
|
||||
@@ -42,17 +45,15 @@ from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import (
|
||||
ScatterOp, VariableResolutionResamplerModel)
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
from fastdeploy.worker.output import SamplerOutput
|
||||
from fastdeploy.worker.utils import check_safetensors_model
|
||||
from fastdeploy.worker.vl_model_runner_base import VLModelRunnerBase
|
||||
from fastdeploy.config import (DeviceConfig, FDConfig, KVCacheConfig,
|
||||
LoadConfig, ModelConfig, MoEConfig,
|
||||
MoEPhase, ParallelConfig, SpeculativeConfig)
|
||||
|
||||
if current_platform.is_cuda() and current_platform.available():
|
||||
from fastdeploy.model_executor.layers.utils import (
|
||||
remove_padding, speculate_remove_padding)
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import (save_output,
|
||||
from fastdeploy.model_executor.ops.gpu import (save_output, save_output_topk,
|
||||
set_stop_value_multi_ends,
|
||||
set_value_by_flags_and_idx,
|
||||
update_inputs)
|
||||
@@ -84,7 +85,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
|
||||
self.mp_group = hcg.get_model_parallel_group()
|
||||
self.is_safetensors_model = check_safetensors_model(
|
||||
args.model_name_or_path)
|
||||
|
||||
self.enable_logprob = args.enable_logprob
|
||||
model_path = os.path.dirname(args.model_name_or_path)
|
||||
args.llm_model_name_or_path = args.model_name_or_path
|
||||
if not self.is_safetensors_model:
|
||||
@@ -825,6 +826,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
|
||||
min_dec_lens=self.share_inputs["min_dec_len"],
|
||||
bad_words_token_ids=self.share_inputs["bad_tokens"],
|
||||
eos_token_ids=self.share_inputs["eos_token_id"],
|
||||
max_num_logprobs=20 if self.enable_logprob else None,
|
||||
)
|
||||
|
||||
def generate(self) -> None:
|
||||
@@ -846,17 +848,17 @@ class GPUVLModelRunner(VLModelRunnerBase):
|
||||
self.share_inputs["stop_flags"],
|
||||
)
|
||||
# sampler & save_output
|
||||
next_tokens = self.sampler(logits, self.sampling_metadata)
|
||||
sampler_output = self.sampler(logits, self.sampling_metadata)
|
||||
if self.fd_config.parallel_config.tensor_parallel_degree > 1:
|
||||
paddle.distributed.broadcast(next_tokens, 0)
|
||||
self.post_process(next_tokens)
|
||||
paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
|
||||
self.post_process(sampler_output)
|
||||
|
||||
def post_process(self, next_tokens: paddle.Tensor) -> None:
|
||||
def post_process(self, sampler_output: SamplerOutput) -> None:
|
||||
"""
|
||||
post_process
|
||||
"""
|
||||
if self.share_inputs["enable_thinking"]:
|
||||
exists_think_end = next_tokens == self.model_cfg.think_end_id
|
||||
exists_think_end = sampler_output.sampled_token_ids == self.model_cfg.think_end_id
|
||||
paddle.assign(
|
||||
paddle.where(
|
||||
exists_think_end,
|
||||
@@ -872,12 +874,12 @@ class GPUVLModelRunner(VLModelRunnerBase):
|
||||
), self.share_inputs["reasoning_index"])
|
||||
|
||||
stop_wo_think = (
|
||||
(next_tokens == self.share_inputs["eos_token_id"]) |
|
||||
(sampler_output.sampled_token_ids == self.share_inputs["eos_token_id"]) |
|
||||
(self.share_inputs["reasoning_index"] == 0)) & (
|
||||
self.share_inputs["need_think_end"] > 0)
|
||||
next_tokens = paddle.where(stop_wo_think,
|
||||
sampler_output.sampled_token_ids = paddle.where(stop_wo_think,
|
||||
self.model_cfg.think_end_id,
|
||||
next_tokens)
|
||||
sampler_output.sampled_token_ids)
|
||||
paddle.assign(
|
||||
paddle.where(
|
||||
stop_wo_think,
|
||||
@@ -900,7 +902,7 @@ class GPUVLModelRunner(VLModelRunnerBase):
|
||||
)
|
||||
|
||||
set_stop_value_multi_ends(
|
||||
next_tokens,
|
||||
sampler_output.sampled_token_ids,
|
||||
self.share_inputs["stop_flags"],
|
||||
self.share_inputs["seq_lens_this_time"],
|
||||
self.share_inputs["eos_token_id"],
|
||||
@@ -917,15 +919,25 @@ class GPUVLModelRunner(VLModelRunnerBase):
|
||||
self.share_inputs["seq_lens_decoder"],
|
||||
self.share_inputs["input_ids"],
|
||||
self.share_inputs["stop_nums"],
|
||||
next_tokens,
|
||||
sampler_output.sampled_token_ids,
|
||||
self.share_inputs["is_block_step"],
|
||||
)
|
||||
save_output(
|
||||
next_tokens,
|
||||
self.share_inputs["not_need_stop"],
|
||||
self.rank,
|
||||
False, # use_ep
|
||||
)
|
||||
if sampler_output.logprobs_tensors is None:
|
||||
save_output(
|
||||
sampler_output.sampled_token_ids,
|
||||
self.share_inputs["not_need_stop"],
|
||||
self.rank,
|
||||
False, # use_ep
|
||||
)
|
||||
else:
|
||||
save_output_topk(
|
||||
sampler_output.sampled_token_ids,
|
||||
sampler_output.logprobs_tensors.logprob_token_ids,
|
||||
sampler_output.logprobs_tensors.logprobs,
|
||||
sampler_output.logprobs_tensors.selected_token_ranks,
|
||||
self.share_inputs["not_need_stop"],
|
||||
self.rank,
|
||||
)
|
||||
|
||||
def _cal_theortical_kvcache(self):
|
||||
"""
|
||||
|
@@ -14,14 +14,14 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import argparse
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
import paddle.distributed.fleet as fleet
|
||||
from fastdeploy.config import ModelConfig
|
||||
|
||||
from fastdeploy.config import ModelConfig
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("worker", "worker.log")
|
||||
|
@@ -42,6 +42,8 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
|
||||
"""
|
||||
get worker of different device
|
||||
"""
|
||||
if fd_config.model_config.enable_logprob and not current_platform.is_cuda():
|
||||
raise NotImplementedError("Only CUDA platform supports logprob.")
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.worker.gpu_worker import GpuWorker
|
||||
return GpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
|
||||
@@ -346,11 +348,11 @@ class PaddleDisWorkerProc():
|
||||
model_block_memory_used)
|
||||
# NOTE(liuzichang): Too many block will lead to illegal memory access
|
||||
# We will develop dynamic limits in future.
|
||||
if num_blocks_local > 20000:
|
||||
if num_blocks_local > 40000:
|
||||
logger.info(
|
||||
f"------- Reset num_blocks_local {num_blocks_local} to 20000"
|
||||
f"------- Reset num_blocks_local {num_blocks_local} to 40000"
|
||||
)
|
||||
num_blocks_local = min(20000, num_blocks_local)
|
||||
num_blocks_local = min(40000, num_blocks_local)
|
||||
logger.info(
|
||||
f"------- model_block_memory_used:{model_block_memory_used} --------"
|
||||
)
|
||||
@@ -511,7 +513,7 @@ def parse_args():
|
||||
|
||||
parser.add_argument("--quantization",
|
||||
type=str,
|
||||
default="",
|
||||
default="None",
|
||||
help="Quantization name for the model, currentlly support " \
|
||||
"'wint4', 'wint8'," \
|
||||
"default is None. The priority of this configuration "\
|
||||
@@ -550,153 +552,178 @@ def parse_args():
"'ipc_snapshot': load from disk snapshot of IPC weights, "
"'meta': provide RL traing worker, no_weights_load"
"'normal':normal load weight")
parser.add_argument("--enable_logprob",
action='store_true',
help="Enable output of token-level log probabilities.")

args = parser.parse_args()
return args


def initialize_fd_config(args: argparse.Namespace) -> FDConfig:
"""Initialize FDConfig
TODO(gongshaotian): Unified all configs to FDConfig
def initialize_fd_config(config) -> FDConfig:
"""Initialize FDConfig from either RolloutModelConfig or argparse.Namespace

Args:
config: Configuration object containing all parameters (either RolloutModelConfig or argparse.Namespace)

Returns:
FDConfig: Initialized FastDeploy configuration object
"""
# NOTE(gongshaotian): From build stream line model
config, _ = ModelConfig.get_config_dict(args.model_name_or_path)
if 'num_experts' in config:
config['moe_num_experts'] = config.pop('num_experts')
# Get model config from model directory
model_config_dict, _ = ModelConfig.get_config_dict(config.model_name_or_path)

if 'num_experts_per_tok' in config:
config['moe_topk'] = config.pop('num_experts_per_tok')
config["head_dim"] = config.get(
"head_dim", config["hidden_size"] // config["num_attention_heads"])
config["rope_theta"] = config.get("rope_theta", 10000.0)
model_config = ModelConfig.from_dict(config)
# TODO Set `head_dim` again. Because `ModelConfig` class doesn't support feeding head_dim at all!
model_config.head_dim = config["head_dim"]
paddle.set_default_dtype(args.dtype)
# Handle MoE related configs
if 'num_experts' in model_config_dict:
model_config_dict['moe_num_experts'] = model_config_dict.pop('num_experts')
if 'num_experts_per_tok' in model_config_dict:
model_config_dict['moe_topk'] = model_config_dict.pop('num_experts_per_tok')

# Set default values for model config
model_config_dict["head_dim"] = model_config_dict.get(
"head_dim", model_config_dict["hidden_size"] // model_config_dict["num_attention_heads"])
model_config_dict["rope_theta"] = model_config_dict.get("rope_theta", 10000.0)

# Create model config object
model_config = ModelConfig.from_dict(model_config_dict)
model_config.head_dim = model_config_dict["head_dim"]
paddle.set_default_dtype(config.dtype)

# Initialize all config components
device_config = DeviceConfig()
# model_config = ModelConfig()

decoding_config = DecodingConfig()

speculative_config = SpeculativeConfig()
parallel_config = ParallelConfig()
load_config = LoadConfig()
moe_config = MoEConfig()
graph_opt_config = GraphOptimizationConfig(
args.enable_static_graph_inference, args.use_cudagraph,
args.max_capture_batch_size)
model_config.quantization = args.quantization

# Update speculate config
speculative_config.method = args.speculative_method
speculative_config.num_speculative_tokens = args.speculative_max_draft_token_num
speculative_config.model_name_or_path = args.speculative_model_name_or_path
speculative_config.quantization = args.speculative_model_quantization
# Handle graph optimization config (check for attribute existence for backward compatibility)
enable_static_graph_inference = getattr(config, 'enable_static_graph_inference', False)
use_cudagraph = getattr(config, 'use_cudagraph', False)
max_capture_batch_size = getattr(config, 'max_capture_batch_size', 0)

graph_opt_config = GraphOptimizationConfig(
enable_static_graph_inference,
use_cudagraph,
max_capture_batch_size
)

# Handle quantization (check for attribute existence)
model_config.quantization = getattr(config, 'quantization', None)

# Update speculative config
speculative_config.method = getattr(config, 'speculative_method', None)
speculative_config.num_speculative_tokens = getattr(config, 'speculative_max_draft_token_num', 0)
speculative_config.model_name_or_path = getattr(config, 'speculative_model_name_or_path', None)
speculative_config.quantization = getattr(config, 'speculative_model_quantization', None)

# Update parallel config
parallel_config.engine_pid = args.engine_pid
parallel_config.model_name_or_path = args.model_name_or_path
parallel_config.max_num_seqs = args.max_num_seqs
parallel_config.max_block_num = args.total_block_num
parallel_config.block_size = args.block_size
parallel_config.engine_worker_queue_port = args.engine_worker_queue_port
parallel_config.max_model_len = args.max_model_len
model_config.max_seq_len = args.max_model_len
model_config.max_length = args.max_model_len
parallel_config.device_ids = args.device_ids
parallel_config.dtype = args.dtype
parallel_config.enc_dec_block_num = args.enc_dec_block_num
parallel_config.kv_cache_ratio = args.kv_cache_ratio
parallel_config.first_token_id = args.first_token_id
parallel_config.gpu_memory_utilization = args.gpu_memory_utilization
parallel_config.engine_pid = args.engine_pid
parallel_config.do_profile = args.do_profile
parallel_config.dynamic_load_weight = args.dynamic_load_weight
parallel_config.pad_token_id = args.pad_token_id
parallel_config.eos_tokens_lens = args.eos_tokens_lens
parallel_config.enable_chunked_prefill = args.enable_chunked_prefill
parallel_config.max_num_batched_tokens = args.max_num_batched_tokens
parallel_config.enable_prefix_caching = args.enable_prefix_caching
parallel_config.engine_pid = getattr(config, 'engine_pid', None)
parallel_config.model_name_or_path = config.model_name_or_path
parallel_config.max_num_seqs = getattr(config, 'max_num_seqs', 0)
parallel_config.max_block_num = getattr(config, 'total_block_num', 0)
parallel_config.block_size = getattr(config, 'block_size', 0)
parallel_config.engine_worker_queue_port = getattr(config, 'engine_worker_queue_port', 0)
parallel_config.max_model_len = getattr(config, 'max_model_len', 0)
model_config.max_seq_len = getattr(config, 'max_model_len', 0)
model_config.max_length = getattr(config, 'max_model_len', 0)
parallel_config.device_ids = getattr(config, 'device_ids', [])
parallel_config.dtype = config.dtype
parallel_config.enc_dec_block_num = getattr(config, 'enc_dec_block_num', 0)
parallel_config.kv_cache_ratio = getattr(config, 'kv_cache_ratio', 1.0)
parallel_config.first_token_id = getattr(config, 'first_token_id', None)
parallel_config.gpu_memory_utilization = getattr(config, 'gpu_memory_utilization', 0.9)
parallel_config.engine_pid = getattr(config, 'engine_pid', None)
parallel_config.do_profile = getattr(config, 'do_profile', False)
parallel_config.dynamic_load_weight = getattr(config, 'dynamic_load_weight', False)
parallel_config.pad_token_id = getattr(config, 'pad_token_id', None)
parallel_config.eos_tokens_lens = getattr(config, 'eos_tokens_lens', 0)
parallel_config.enable_chunked_prefill = getattr(config, 'enable_chunked_prefill', False)
parallel_config.max_num_batched_tokens = getattr(config, 'max_num_batched_tokens', 0)
parallel_config.enable_prefix_caching = getattr(config, 'enable_prefix_caching', False)
parallel_config.use_ep = getattr(config, 'enable_expert_parallell', False)
parallel_config.tensor_parallel_degree = getattr(config, 'tensor_parallel_size', 1)
parallel_config.expert_parallel_degree = getattr(config, 'expert_parallel_size', 1)
parallel_config.splitwise_role = getattr(config, 'splitwise_role', None)
parallel_config.guided_decoding_backend = getattr(config, 'guided_decoding_backend', None)
parallel_config.disable_any_whitespace = getattr(config, 'disable_any_whitespace', False)

parallel_config.use_ep = args.enable_expert_parallell
parallel_config.tensor_parallel_degree = args.tensor_parallel_size
parallel_config.expert_parallel_degree = args.expert_parallel_size
parallel_config.splitwise_role = args.splitwise_role
# Handle load config (check for environment variable)
load_config.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1

parallel_config.guided_decoding_backend = args.guided_decoding_backend
parallel_config.disable_any_whitespace = args.disable_any_whitespace

# Log parallel config info
logger.info(f"parallel_config.use_ep {parallel_config.use_ep}")
logger.info(
f"parallel_config.tensor_parallel_degree {parallel_config.tensor_parallel_degree}"
)
logger.info(f"args.splitwise_role {args.splitwise_role}")
logger.info(f"parallel_config.tensor_parallel_degree {parallel_config.tensor_parallel_degree}")
logger.info(f"splitwise_role {parallel_config.splitwise_role}")

if args.splitwise_role == "mixed":
# Set MoE phase based on splitwise role
if parallel_config.splitwise_role == "mixed":
parallel_config.moe_phase = MoEPhase.PREFILL
elif args.splitwise_role == "prefill":
elif parallel_config.splitwise_role == "prefill":
parallel_config.moe_phase = MoEPhase.PREFILL
elif args.splitwise_role == "decode":
elif parallel_config.splitwise_role == "decode":
parallel_config.moe_phase = MoEPhase.DECODER
else:
elif parallel_config.splitwise_role is not None:
raise NotImplementedError

num_key_value_heads = config.get("num_key_value_heads", -1)
# Handle model architecture specific configurations
num_key_value_heads = model_config_dict.get("num_key_value_heads", -1)
if num_key_value_heads is None:
num_key_value_heads = -1

if config.get("ffn_hidden_size", None) is not None:
ffn_hidden_size = config["ffn_hidden_size"]
elif config.get("intermediate_size", None) is not None:
ffn_hidden_size = config["intermediate_size"]
# Calculate FFN hidden size
if model_config_dict.get("ffn_hidden_size", None) is not None:
ffn_hidden_size = model_config_dict["ffn_hidden_size"]
elif model_config_dict.get("intermediate_size", None) is not None:
ffn_hidden_size = model_config_dict["intermediate_size"]
else:
ffn_hidden_size = 4 * config["hidden_size"]
if config["hidden_act"].lower() == "swiglu":
ffn_hidden_size = 4 * model_config_dict["hidden_size"]
if model_config_dict["hidden_act"].lower() == "swiglu":
if paddle.distributed.get_world_size() > 1:
multiple_of = 8 * config["num_attention_heads"]
multiple_of = 8 * model_config_dict["num_attention_heads"]
else:
multiple_of = 4 * config["num_attention_heads"]
multiple_of = 4 * model_config_dict["num_attention_heads"]
ffn_hidden_size = multiple_of * (
(int(2 * ffn_hidden_size / 3) + multiple_of - 1) //
multiple_of)

num_layers = config.get("num_layers", None) or config.get(
# Get number of layers
num_layers = model_config_dict.get("num_layers", None) or model_config_dict.get(
"num_hidden_layers", None)
if num_layers is None:
raise ValueError(f"num_layers<{num_layers}> is invalid")

use_moe = config.get("moe_layer_start_index", num_layers) < num_layers
use_moe = model_config_dict.get("moe_layer_start_index", num_layers) < num_layers

# Update model config
model_config.ffn_hidden_size = ffn_hidden_size
model_config.num_layers = num_layers

model_config.num_key_value_heads = num_key_value_heads
model_config.start_layer_index = config.get("start_layer_index", 0)
moe_config.num_experts = config.get("moe_num_experts", None)
moe_config.moe_intermediate_size = config.get("moe_intermediate_size",
None)
moe_config.top_k = config.get("moe_k", config.get("moe_topk", 8))
moe_config.moe_num_shared_experts = config.get("moe_num_shared_experts", 0)
moe_config.moe_layer_start_index = config.get("moe_layer_start_index", 0)
model_config.start_layer_index = model_config_dict.get("start_layer_index", 0)

moe_config.num_max_dispatch_tokens_per_rank = config.get(
# Update MoE config
moe_config.num_experts = model_config_dict.get("moe_num_experts", None)
moe_config.moe_intermediate_size = model_config_dict.get("moe_intermediate_size", None)
moe_config.top_k = model_config_dict.get("moe_k", model_config_dict.get("moe_topk", 8))
moe_config.moe_num_shared_experts = model_config_dict.get("moe_num_shared_experts", 0)
moe_config.moe_layer_start_index = model_config_dict.get("moe_layer_start_index", 0)
moe_config.num_max_dispatch_tokens_per_rank = model_config_dict.get(
"num_max_dispatch_tokens_per_rank", 256)
moe_config.moe_use_aux_free = config.get("moe_use_aux_free", False)
moe_config.moe_use_aux_free = model_config_dict.get("moe_use_aux_free", False)

model_config.ori_vocab_size = config.get("vocab_size", -1)
if "Ernie4_5_ForCausalLM" in config.get("architectures"):
model_config.ori_vocab_size = args.ori_vocab_size
# Handle vocabulary size
model_config.ori_vocab_size = model_config_dict.get("vocab_size", -1)
archs = model_config_dict.get("architectures", [])
if "Ernie4_5_ForCausalLM" in archs or "Ernie4_5_MoeForCausalLM" in archs:
model_config.ori_vocab_size = getattr(config, 'ori_vocab_size', model_config.ori_vocab_size)

if "DeepseekV3ForCausalLM" in config.get("architectures"):
# Handle DeepseekV3 specific config
if "DeepseekV3ForCausalLM" in model_config_dict.get("architectures", []):
from paddleformers.transformers import AutoConfig
model_config.deepseekv3 = AutoConfig.from_pretrained(
args.model_name_or_path)
config.model_name_or_path)

#TODO(@yuanrisheng): kv_cache quant config can only be
# stored in model config file, which should be unified
quantization_config = config.get("quantization_config", None)
# Handle quantization config
quantization_config = model_config_dict.get("quantization_config", None)
if not model_config.is_quantized:
if quantization_config is not None:
if "kv_cache_quant_type" not in quantization_config:
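The swiglu branch in the hunk above shrinks the FFN width to roughly two thirds and then rounds it up to a multiple of `multiple_of`. A minimal standalone sketch of that arithmetic, using assumed example values (4096 hidden size, 32 heads, world size 2) rather than anything taken from the diff:

    # Sketch of the swiglu FFN-width rounding, with illustrative numbers only.
    hidden_size = 4096             # assumed example value
    num_attention_heads = 32       # assumed example value
    world_size = 2                 # assumed example value

    ffn_hidden_size = 4 * hidden_size                                   # 16384
    multiple_of = (8 if world_size > 1 else 4) * num_attention_heads    # 256
    # Scale to ~2/3 and round up to the next multiple of `multiple_of`.
    ffn_hidden_size = multiple_of * (
        (int(2 * ffn_hidden_size / 3) + multiple_of - 1) // multiple_of)
    print(ffn_hidden_size)  # 11008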
@@ -711,13 +738,13 @@ def initialize_fd_config(args: argparse.Namespace) -> FDConfig:

if quantization_config is not None:
quant_config_name = quantization_config["quantization"]
elif args.quantization != "None":
elif getattr(config, 'quantization', None) != "None":
quantization_config = {}
quant_config_name = args.quantization
quant_config_name = getattr(config, 'quantization', None)
quantization_config["quantization"] = quant_config_name
# use some trick code for ernie model and will unify it in future.
is_ernie = "Ernie4_5_ForCausalLM" in config.get("architectures") or \
"Ernie4_5_MoeForCausalLM" in config.get("architectures")
# Special handling for Ernie models
is_ernie = "Ernie4_5_ForCausalLM" in model_config_dict.get("architectures", []) or \
"Ernie4_5_MoeForCausalLM" in model_config_dict.get("architectures", [])
if use_moe and quant_config_name == "wint4" and is_ernie:
quantization_config["dense_quant_type"] = "wint8"
quantization_config["moe_quant_type"] = "wint4"
@@ -732,6 +759,7 @@ def initialize_fd_config(args: argparse.Namespace) -> FDConfig:
quant_cls = get_quantization_config(quant_config_name)
quant_config = quant_cls.from_config(quantization_config)

# Log quantization info
logger.info("===========quantization_config==============")
if quant_config is not None:
if model_config.is_quantized:
@@ -742,29 +770,35 @@ def initialize_fd_config(args: argparse.Namespace) -> FDConfig:
logger.info(
"Model Status: Original (will apply online quantization)")

logger.info(f"Quantization Method: {args.quantization or 'None'}")
logger.info(f"Quantization Method: {getattr(config, 'quantization', 'None')}")
else:
logger.info(
"No quantization config found and use original weight and act dtype."
)

model_config.architectures = config.get("architectures")
model_config.enable_logprob = config.enable_logprob

model_config.architectures = model_config_dict.get("architectures")

# Update load config
logger.info("===========load_config==============")
load_config.dynamic_load_weight = args.dynamic_load_weight
load_config.load_strategy = args.load_strategy
load_config.dynamic_load_weight = getattr(config, 'dynamic_load_weight', False)
load_config.load_strategy = getattr(config, 'load_strategy', None)
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
logger.info(f"- Load strategy: {load_config.load_strategy}")

fd_config = FDConfig(model_config=model_config,
parallel_config=parallel_config,
speculative_config=speculative_config,
device_config=device_config,
load_config=load_config,
moe_config=moe_config,
decoding_config=decoding_config,
quant_config=quant_config,
graph_opt_config=graph_opt_config)
# Create and return FDConfig
fd_config = FDConfig(
model_config=model_config,
parallel_config=parallel_config,
speculative_config=speculative_config,
device_config=device_config,
load_config=load_config,
moe_config=moe_config,
decoding_config=decoding_config,
quant_config=quant_config,
graph_opt_config=graph_opt_config
)

return fd_config

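The rewritten initialize_fd_config reads optional fields through getattr with defaults, so one code path accepts both an argparse.Namespace and an RL rollout config object that defines only a subset of fields. A minimal sketch of that pattern, assuming a hypothetical FakeRolloutConfig class and illustrative field names (not the real RolloutModelConfig):

    import argparse

    class FakeRolloutConfig:             # hypothetical stand-in for RolloutModelConfig
        model_name_or_path = "./model"
        dtype = "bfloat16"               # only a subset of fields is defined

    def read_settings(config):
        # Required fields are read directly; optional ones fall back to defaults.
        return {
            "model": config.model_name_or_path,
            "dtype": config.dtype,
            "max_num_seqs": getattr(config, "max_num_seqs", 0),
            "do_profile": getattr(config, "do_profile", False),
        }

    ns = argparse.Namespace(model_name_or_path="./model", dtype="bfloat16",
                            max_num_seqs=8, do_profile=True)
    print(read_settings(ns))                   # all fields present
    print(read_settings(FakeRolloutConfig()))  # missing fields use the defaults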
@@ -280,6 +280,7 @@ class XPUModelRunner(ModelRunnerBase):
request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["pre_ids"][idx:idx + 1] = -1
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
self.share_inputs["top_k"][idx:idx + 1] = request.get("top_k", 0)
self.share_inputs["temperature"][idx:idx + 1] = request.get(
"temperature", 0.95)
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -347,8 +348,11 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["eos_token_id"] = paddle.full(
[self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64')
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1],
self.model_config.top_p,
dtype='float32')
self.model_config.top_p,
dtype='float32')
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1],
0,
dtype='int64')
self.share_inputs["temperature"] = paddle.full(
[max_num_seqs, 1], self.model_config.temperature, dtype='float32')
self.share_inputs["penalty_score"] = paddle.full(
@@ -498,6 +502,7 @@ class XPUModelRunner(ModelRunnerBase):
self.sampling_metadata = SamplingMetadata(
temperature=self.share_inputs["temperature"],
top_p=self.share_inputs["top_p"],
top_k=self.share_inputs["top_k"],
step_idx=self.share_inputs["step_idx"],
pre_token_ids=self.share_inputs["pre_ids"],
frequency_penalties=self.share_inputs["frequency_score"],
@@ -691,7 +696,7 @@ class XPUModelRunner(ModelRunnerBase):
# 4. Compute logits, Sample
logits = self.model.compute_logits(hiddden_states)

sampled_token_ids = self.sampler(logits, self.sampling_metadata)
sampler_output = self.sampler(logits, self.sampling_metadata)

# 5. Speculative decode

@@ -720,7 +725,7 @@ class XPUModelRunner(ModelRunnerBase):
accept_tokens=None,
accept_num=None,
)
xpu_post_process(sampled_token_ids=sampled_token_ids,
xpu_post_process(sampled_token_ids=sampler_output.sampled_token_ids,
model_output=model_output_data)

# 7. Updata 'infer_seed' and step_paddle()
@@ -29,3 +29,10 @@ triton==3.3
use-triton-in-paddle
crcmod
fastsafetensors==0.1.14
opentelemetry-api>=1.24.0
opentelemetry-sdk>=1.24.0
opentelemetry-instrumentation-redis
opentelemetry-instrumentation-mysql
opentelemetry-distro
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
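The newly pinned opentelemetry packages provide the tracing stack wired up by the setup.py changes below. A generic usage sketch of the added opentelemetry-api/sdk/exporter-otlp dependencies, not FastDeploy's actual instrumentation; the collector endpoint and service name are assumptions:

    from opentelemetry import trace
    from opentelemetry.sdk.resources import Resource
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import BatchSpanProcessor
    from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter

    # Export spans over OTLP/gRPC to an assumed local collector.
    provider = TracerProvider(resource=Resource.create({"service.name": "fastdeploy-demo"}))
    provider.add_span_processor(
        BatchSpanProcessor(OTLPSpanExporter(endpoint="localhost:4317", insecure=True)))
    trace.set_tracer_provider(provider)

    tracer = trace.get_tracer(__name__)
    with tracer.start_as_current_span("demo-request"):
        pass  # traced work would go here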
setup.py
@@ -19,6 +19,9 @@ import re
import sys
import paddle
import subprocess
from setuptools import setup
from setuptools.command.install import install
from pathlib import Path
from pathlib import Path
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
@@ -137,8 +140,16 @@ class CMakeBuild(build_ext):
cwd=build_temp,
check=True)
subprocess.run(["cmake", "--build", ".", *build_args],
cwd=build_temp,
check=True)
cwd=build_temp,
check=True)

class PostInstallCommand(install):
"""Run custom commands after the standard installation completes."""
def run(self):
# Run the standard installation steps first
install.run(self)
# Then run the custom command
subprocess.check_call(["opentelemetry-bootstrap", "-a", "install"])

def load_requirements():
"""Load dependencies from requirements.txt"""
@@ -169,10 +180,12 @@ def get_name():

cmdclass_dict = {'bdist_wheel': CustomBdistWheel}
cmdclass_dict['build_ext'] = CMakeBuild
FASTDEPLOY_VERSION = os.environ.get("FASTDEPLOY_VERSION", "2.0.0-dev")
cmdclass_dict['build_optl'] = PostInstallCommand

setup(
name=get_name(),
version="2.0.0",
version="2.0.2",
author="PaddlePaddle",
author_email="dltp@baidu.com",
description="FastDeploy: Large Language Model Serving.",
@@ -211,3 +224,4 @@ setup(
python_requires=">=3.7",
extras_require={"test": ["pytest>=6.0"]},
)

@@ -12,15 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import requests
import time
import subprocess
import socket
import os
import signal
import socket
import subprocess
import sys
import time

import openai
import pytest
import requests

# Read ports from environment variables; use default values if not set
FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
@@ -313,4 +314,66 @@ def test_streaming(openai_client, capsys):
output = []
for chunk in response:
output.append(chunk.choices[0].text)
assert len(output) > 0
assert len(output) > 0

def test_non_streaming_with_stop_str(openai_client):
"""
Test non-streaming chat functionality with the local service
"""
response = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
max_tokens=5,
metadata={"include_stop_str_in_output": True},
stream=False,
)
# Assertions to check the response structure
assert hasattr(response, 'choices')
assert len(response.choices) > 0
assert response.choices[0].message.content.endswith("</s>")

response = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
max_tokens=5,
metadata={"include_stop_str_in_output": False},
stream=False,
)
# Assertions to check the response structure
assert hasattr(response, 'choices')
assert len(response.choices) > 0
assert not response.choices[0].message.content.endswith("</s>")

def test_streaming_with_stop_str(openai_client):
"""
Test non-streaming chat functionality with the local service
"""
response = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
max_tokens=5,
metadata={"include_stop_str_in_output": True},
stream=True,
)
# Assertions to check the response structure
last_token = ""
for chunk in response:
last_token = chunk.choices[0].delta.content
assert last_token == "</s>"

response = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
max_tokens=5,
metadata={"include_stop_str_in_output": False},
stream=True,
)
# Assertions to check the response structure
last_token = ""
for chunk in response:
last_token = chunk.choices[0].delta.content
assert last_token != "</s>"