Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[OPs] MoE support wfp8afp8(channelwise) and improve per_token_quant_fp8 (#4238)
@@ -151,6 +151,34 @@ inline int GetGPUComputeCapability(int id) {
#endif

#ifndef FP8_E4M3_MAX
#define FP8_E4M3_MAX 448.0
#endif

#ifndef DISPATCH_FLOAT_FP6_DTYPE
#define DISPATCH_FLOAT_FP6_DTYPE(pd_dtype, c_type, ...)                      \
  switch (pd_dtype) {                                                         \
    case phi::DataType::FLOAT32: {                                            \
      using c_type = float;                                                   \
      __VA_ARGS__                                                             \
      break;                                                                  \
    }                                                                         \
    case phi::DataType::BFLOAT16: {                                           \
      using c_type = phi::dtype::bfloat16;                                    \
      __VA_ARGS__                                                             \
      break;                                                                  \
    }                                                                         \
    case phi::DataType::FLOAT16: {                                            \
      using c_type = phi::dtype::float16;                                     \
      __VA_ARGS__                                                             \
      break;                                                                  \
    }                                                                         \
    default: {                                                                \
      PD_THROW("Only supported attr of input type in [fp32, fp16, bf16].");   \
    }                                                                         \
  }
#endif

inline constexpr uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1)
    return num;
@@ -573,3 +601,28 @@ inline bool GetMlaUseTensorcore() {
      flags_mla_use_tensorcore && enable_mla_tensorcore;
  return mla_use_tensorcore;
}

__device__ __forceinline__ float warpReduceMax(float value) {
  value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 16));
  value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 8));
  value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 4));
  value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 2));
  value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 1));
  return value;
}

__device__ __forceinline__ float blockReduceMax(float value) {
  static __shared__ float warpLevelMaxs[WARP_SIZE];
  const int laneId = threadIdx.x % WARP_SIZE;
  const int warpId = threadIdx.x / WARP_SIZE;

  value = warpReduceMax(value);

  if (laneId == 0) warpLevelMaxs[warpId] = value;
  __syncthreads();

  value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
  if (warpId == 0) value = warpReduceMax(value);

  return value;
}
@@ -3,6 +3,158 @@
#include "quantization/common.cuh"

// adapted from: https://github.com/sgl-project/sglang/blob/v0.5.2rc2/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu

// ---------------------------------------------------------------------------
// 1. Warp-local, no shared memory
//    - One warp handles one token.
//    - Eight tokens per 256-thread CTA.
// ---------------------------------------------------------------------------
template <typename T, typename DST_DTYPE, int kTokensPerCTA = 8, int kVecSize = 16>
__global__ void per_token_quant_fp8_kernel(
    const T* __restrict__ input,
    DST_DTYPE* __restrict__ output_q,
    float* __restrict__ output_s,
    const float scale_ub,
    const int64_t hidden_size,
    const int64_t num_tokens) {
  const int warp_id = threadIdx.x / WARP_SIZE;        // 0-7 (8 warps)
  const int lane_id = threadIdx.x & (WARP_SIZE - 1);  // 0-31
  const int token_id = blockIdx.x * kTokensPerCTA + warp_id;
  if (token_id >= num_tokens) return;

  // Global tensors for this token
  const T* token_input = input + token_id * hidden_size;
  DST_DTYPE* token_output = output_q + token_id * hidden_size;
  float* token_scale = output_s + token_id;

  //
  // Pass-1: warp reduce to find the max absolute value over the token's hidden_size
  //
  float max_value = 0.f;
  using vec_t = AlignedVector<T, kVecSize>;
  const int32_t num_vec_elems = hidden_size / kVecSize;

  for (int32_t i = lane_id; i < num_vec_elems; i += WARP_SIZE) {
    vec_t input_vec;
    Load(token_input + i * kVecSize, &input_vec);

#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      max_value = fmaxf(max_value, fabsf(static_cast<float>(input_vec[j])));
    }
  }

  float warp_max = warpReduceMax(max_value);
  if (scale_ub > 0){
    warp_max = fminf(warp_max, scale_ub);
  }
  float scale;
  scale = warp_max / FP8_E4M3_MAX;
  // Broadcast scale
  if (lane_id == 0) {
    token_scale[0] = scale;
  }
  float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;

  //
  // Pass-2: quantize and write back
  //
  for (int i = lane_id; i < num_vec_elems; i += WARP_SIZE) {
    vec_t input_vec;
    Load(token_input + i * kVecSize, &input_vec);
    DST_DTYPE output_arr[kVecSize];
#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      float val = static_cast<float>(input_vec[j]) * scale_inv;
      val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
      output_arr[j] = static_cast<DST_DTYPE>(val);
    }
    if constexpr (kVecSize == 16) {
      *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
    } else {
      // Use element-wise copy for vector size 8 to ensure correctness
      for (int k = 0; k < kVecSize; ++k) {
        token_output[i * kVecSize + k] = output_arr[k];
      }
    }
  }
}

// ---------------------------------------------------------------------------
// 2. Baseline kernel (1 token / CTA, CUB block reduce)
// ---------------------------------------------------------------------------
template <typename T, typename DST_DTYPE, int kVecSize = 16>
__global__ void per_token_quant_fp8_small_batch_kernel(
    const T* __restrict__ input,
    DST_DTYPE* __restrict__ output_q,
    float* __restrict__ output_s,
    const float scale_ub,
    const int64_t hidden_size,
    const int64_t num_tokens) {
  const int token_idx = blockIdx.x;
  if (token_idx >= num_tokens) return;

  const int tid = threadIdx.x;
  const int block_dim = blockDim.x;

  const T* token_input = input + token_idx * hidden_size;
  DST_DTYPE* token_output = output_q + token_idx * hidden_size;

  float max_value = 0.0f;

  // Use template parameter for vector size
  using vec_t = AlignedVector<T, kVecSize>;
  const int32_t num_vec_elems = hidden_size / kVecSize;

  // Find max using vectorized loads
  for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
    vec_t input_vec;
    Load(token_input + i * kVecSize, &input_vec);

#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      float val = static_cast<float>(input_vec[j]);
      max_value = fmaxf(max_value, fabsf(val));
    }
  }

  max_value = blockReduceMax(max_value);
  if (scale_ub > 0){
    max_value = fminf(max_value, scale_ub);
  }
  __shared__ float scale;
  if (tid == 0) {
    scale = max_value / FP8_E4M3_MAX;
    output_s[token_idx] = scale;
  }
  __syncthreads();

  const float scale_inv = 1.0f / scale;

  // Quantize using vectorized loads
  for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
    vec_t input_vec;
    Load(token_input + i * kVecSize, &input_vec);

    DST_DTYPE output_arr[kVecSize];
#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX);
      output_arr[j] = static_cast<DST_DTYPE>(val);
    }

    if constexpr (kVecSize == 16) {
      *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
    } else {
      // Use element-wise copy for vector size 8 to ensure correctness
      for (int k = 0; k < kVecSize; ++k) {
        token_output[i * kVecSize + k] = output_arr[k];
      }
    }
  }
}

namespace fastdeploy {

template <typename scalar_t, typename fp8_type>
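As a rough cross-check of what both kernels compute, the per-token quantization can be sketched in a few lines of Paddle. This is an illustrative reference under assumed names (x, scale_ub); it is not part of this commit.

# Illustrative Paddle reference of the per-token FP8 quantization above (not part of this commit).
import paddle

def per_token_quant_fp8_reference(x, scale_ub: float = -1.0):
    # x: [num_tokens, hidden_size]; returns (x_q, x_scale) with one scale per token.
    amax = paddle.abs(x.astype("float32")).max(axis=-1, keepdim=True)
    if scale_ub > 0:
        amax = paddle.clip(amax, max=scale_ub)           # optional upper bound, as in the kernels
    scale = amax / 448.0                                 # FP8_E4M3_MAX
    scale_inv = 1.0 / paddle.clip(scale, min=1e-12)      # the kernels use scale_inv = 0 when scale == 0
    x_q = paddle.clip(x.astype("float32") * scale_inv, -448.0, 448.0)
    return x_q.astype(paddle.float8_e4m3fn), scale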
@@ -179,39 +331,78 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out, // [..., d]
  auto rank = input.dims().size();
  int const hidden_size = input.dims()[rank - 1];
  int const num_tokens = input.numel() / hidden_size;
  cudaStream_t stream = input.stream();

  if (hidden_size % 8 == 0){
    int device = 0;
    cudaGetDevice(&device);
    int sm_count = 0;
    cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
    const int TOKENS_PER_CTA = 8;
    const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
    const bool use_vec16 = (hidden_size % 16 == 0);
    DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
      if (use_warp_kernel) {
        // -------- warp-local ---------------------------------------------------
        constexpr int THREADS = TOKENS_PER_CTA * WARP_SIZE;  // 256
        dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
        dim3 block(THREADS);

        if (use_vec16) {
          per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
              reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
              reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
              reinterpret_cast<float*>(scales.data<float>()),
              scale_ub,
              hidden_size,
              num_tokens);
        } else {
          per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8><<<grid, block, 0, stream>>>(
              reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
              reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
              reinterpret_cast<float*>(scales.data<float>()),
              scale_ub,
              hidden_size,
              num_tokens);
        }
      } else {
        // -------- baseline -----------------------------------------------------
        constexpr int THREADS = 256;
        dim3 grid(num_tokens);
        dim3 block(THREADS);

        if (use_vec16) {
          per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 16><<<grid, block, 0, stream>>>(
              reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
              reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
              reinterpret_cast<float*>(scales.data<float>()),
              scale_ub,
              hidden_size,
              num_tokens);
        } else {
          per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 8><<<grid, block, 0, stream>>>(
              reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
              reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
              reinterpret_cast<float*>(scales.data<float>()),
              scale_ub,
              hidden_size,
              num_tokens);
        }
      }
    });
    return;
  }

  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));

-  cudaStream_t stream = input.stream();
  DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
    fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
        <<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
                                     input.data<scalar_t>(), scale_ub,
                                     hidden_size);
  });
-  switch (input.dtype()) {
-    case paddle::DataType::FLOAT32: {
-      using scalar_t = float;
-      fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
-          <<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
-                                       input.data<scalar_t>(), scale_ub,
-                                       hidden_size);
-      break;
-    }
-    case paddle::DataType::FLOAT16: {
-      using scalar_t = phi::dtype::float16;
-      fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
-          <<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
-                                       input.data<scalar_t>(), scale_ub,
-                                       hidden_size);
-      break;
-    }
-    case paddle::DataType::BFLOAT16: {
-      using scalar_t = phi::dtype::bfloat16;
-      fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
-          <<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
-                                       input.data<scalar_t>(), scale_ub,
-                                       hidden_size);
-      break;
-    }
-    default:
-      PD_THROW("Only supported attr of input type in [fp32, fp16, bf16].");
-  }
}

PD_BUILD_STATIC_OP(static_scaled_fp8_quant)
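The launch logic above boils down to a small decision table; a minimal sketch of it follows (hypothetical helper name, mirroring the C++ branches, not an API in the repo).

# Hypothetical summary of the dispatch in DynamicPerTokenScaledFp8Quant above.
def pick_per_token_fp8_kernel(num_tokens: int, hidden_size: int, sm_count: int) -> str:
    TOKENS_PER_CTA = 8
    if hidden_size % 8 != 0:
        return "dynamic_per_token_scaled_fp8_quant_kernel"           # original fallback path
    vec = 16 if hidden_size % 16 == 0 else 8
    if num_tokens >= sm_count * 2 * TOKENS_PER_CTA:
        return f"per_token_quant_fp8_kernel<vec={vec}>"              # one warp per token, 8 tokens/CTA
    return f"per_token_quant_fp8_small_batch_kernel<vec={vec}>"      # one token per CTA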
@@ -32,6 +32,7 @@ try:
except ImportError:
    pass
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant


class TritonWeightOnlyMoEMethod(QuantMethodBase):
@@ -332,6 +333,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
            compute_type_enum=1,
            use_fp8_w8a8=False,
            use_int8_w8a16=True,
            per_channel_quant=False,
            even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
        )
@@ -384,6 +386,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
            compute_type_enum=1,
            use_fp8_w8a8=False,
            use_int8_w8a16=True,
            per_channel_quant=False,
            even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
        )
@@ -395,6 +398,379 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
        return out


class Wfp8Afp8MoEMethod(QuantMethodBase):
    """
    Use Triton Group Gemm to compute Fused wfp8afp8 Quant MoE.
    """

    def __init__(self, quant_config):
        """
        Triton Group Gemm to compute Fused MoE.
        """
        self.quant_config = quant_config
        self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
        self.added_scale_attrs = [
            "up_gate_proj_weight_scale",
            "down_proj_weight_scale",
        ]

    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
        """process_prequanted_weights"""
        raise NotImplementedError

    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
        """
        Triton MoE create weight process.
        """
        self.up_gate_proj_weight_shape = [
            layer.num_local_experts,
            layer.moe_intermediate_size * 2,
            layer.hidden_size,
        ]
        self.down_proj_weight_shape = [
            layer.num_local_experts,
            layer.hidden_size,
            layer.moe_intermediate_size,
        ]
        self.up_gate_proj_scale_shape = [
            layer.num_local_experts,
            layer.moe_intermediate_size * 2,
            1,
        ]
        self.down_proj_scale_shape = [
            layer.num_local_experts,
            layer.hidden_size,
            1,
        ]
        if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
            layer.up_gate_proj_weight = layer.create_parameter(
                shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
                dtype=layer.weight_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            )

            layer.down_proj_weight = layer.create_parameter(
                shape=[layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size],
                dtype=layer.weight_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            )

            extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"

            set_weight_attrs(
                layer.up_gate_proj_weight,
                {
                    **extra_weight_attrs,
                    "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
                },
            )
            set_weight_attrs(
                layer.down_proj_weight,
                {
                    **extra_weight_attrs,
                    "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
                },
            )
        else:
            self.weight_dtype = paddle.float8_e4m3fn
            up_gate_proj_weight_name = self.added_weight_attrs[0]
            down_proj_weight_name = self.added_weight_attrs[1]
            up_gate_proj_scale_name = self.added_scale_attrs[0]
            down_proj_scale_name = self.added_scale_attrs[1]
            setattr(
                layer,
                up_gate_proj_weight_name,
                layer.create_parameter(
                    shape=self.up_gate_proj_weight_shape,
                    dtype=self.weight_dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                ),
            )
            setattr(
                layer,
                down_proj_weight_name,
                layer.create_parameter(
                    shape=self.down_proj_weight_shape,
                    dtype=self.weight_dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                ),
            )
            # weight_scale
            setattr(
                layer,
                up_gate_proj_scale_name,
                layer.create_parameter(
                    shape=self.up_gate_proj_scale_shape,
                    dtype="float32",
                    default_initializer=paddle.nn.initializer.Constant(0),
                ),
            )
            setattr(
                layer,
                down_proj_scale_name,
                layer.create_parameter(
                    shape=self.down_proj_scale_shape,
                    dtype="float32",
                    default_initializer=paddle.nn.initializer.Constant(0),
                ),
            )

    def process_weights_after_loading(self, layer):
        """ """
        if not self.quant_config.is_checkpoint_bf16:
            return
        weight_id_map = {"gate_up": 0, "down": 1}
        if (
            hasattr(layer.up_gate_proj_weight, "tensor_track")
            and layer.up_gate_proj_weight.tensor_track is not None
            and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
        ):
            weight_type = "gate_up"
            layer.up_gate_proj_weight.tensor_track = None
        else:
            weight_type = "down"
            layer.down_proj_weight.tensor_track = None

        # weight
        weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
        weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
        weight_dtype = paddle.float8_e4m3fn
        # scale
        scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
        scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
        scale_dtype = "float32"

        # 2. create tmp tensors
        weight = paddle.empty(shape=weight_shape, dtype=weight_dtype)
        scale = paddle.empty(shape=scale_shape, dtype=scale_dtype)

        # 3. quantize weight
        from fastdeploy.model_executor.layers.utils import per_token_cast_to_fp8

        for expert_id in range(layer.num_experts):
            weight_quant, scale[expert_id] = per_token_cast_to_fp8(
                getattr(layer, weight_name)[expert_id].transpose([1, 0]).contiguous(),
            )
            weight[expert_id].copy_(weight_quant, False)
        getattr(layer, weight_name).value().get_tensor()._clear()

        # create weight
        setattr(
            layer,
            weight_name,
            layer.create_parameter(
                shape=weight_shape,
                dtype=weight_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )
        # create scale
        setattr(
            layer,
            scale_name,
            layer.create_parameter(
                shape=scale_shape,
                dtype=scale_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )
        getattr(layer, weight_name).copy_(weight, False)
        getattr(layer, scale_name).copy_(scale, False)

    def check(self, layer: nn.Layer, up_gate_proj_weights, down_proj_weights):
        """
        check layer is valid for this method
        """
        assert up_gate_proj_weights[0].shape == [
            layer.moe_intermediate_size * 2,
            layer.hidden_size,
        ]
        assert down_proj_weights[0].shape == [
            layer.hidden_size,
            layer.moe_intermediate_size,
        ]

    def apply(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.
        """
        gate_out = gate(x.cast("float32"))
        token_num = x.shape[0]
        top_k = layer.top_k
        num_local_experts = layer.num_local_experts
        moe_intermediate_size = layer.moe_intermediate_size
        hidden_size = layer.hidden_size
        E, N1, _ = getattr(layer, self.added_weight_attrs[0]).shape

        if layer.topk_method == "noaux_tc":
            gate_out, topk_weights, topk_ids = get_moe_scores(
                gate_out,
                layer.n_group,
                layer.topk_group,
                layer.top_k,
                layer.routed_scaling_factor,
                layer.gate_correction_bias,
            )
        else:
            topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
                gate_out,
                layer.gate_correction_bias,
                layer.top_k,
                True,  # apply_norm_weight
                False,
            )

        config = {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 32,
            "num_warps": 8,
            "num_stages": 4,
        }
        if token_num <= E:
            config = {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 128,
                "GROUP_SIZE_M": 1,
                "num_warps": 4,
                "num_stages": 4,
            }

        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
            topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
        )
        max_possible_num_post_padded = sorted_token_ids.shape[0]
        grid = (
            ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
            * ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]),
        )

        up_gate_proj_out = paddle.empty(
            [token_num * top_k, moe_intermediate_size * 2],
            dtype=x.dtype,
        )

        from .triton_moe_kernels import fused_moe_kernel_paddle

        x_q, x_scale = scaled_fp8_quant(x, use_per_token_if_dynamic=True)

        fused_moe_kernel_paddle[grid](
            x_q,
            layer.up_gate_proj_weight,
            up_gate_proj_out,
            x_scale,
            layer.up_gate_proj_weight_scale,
            None,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            max_possible_num_post_padded,
            token_num * top_k,
            N=moe_intermediate_size * 2,
            K=hidden_size,
            stride_am=x_q.strides[0],
            stride_ak=x_q.strides[1],
            stride_be=layer.up_gate_proj_weight.strides[0],
            stride_bk=layer.up_gate_proj_weight.strides[2],
            stride_bn=layer.up_gate_proj_weight.strides[1],
            stride_cm=up_gate_proj_out.strides[0],
            stride_cn=up_gate_proj_out.strides[1],
            #
            stride_asm=x_scale.strides[0],
            stride_ask=x_scale.strides[1],
            stride_bse=layer.up_gate_proj_weight_scale.strides[0],
            stride_bsk=layer.up_gate_proj_weight_scale.strides[2],
            stride_bsn=layer.up_gate_proj_weight_scale.strides[1],
            group_n=-1,
            group_k=-1,
            # Meta-parameters
            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
            GROUP_SIZE_M=config["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=False,
            top_k=top_k,
            compute_type_enum=1,
            use_fp8_w8a8=True,
            use_int8_w8a16=False,
            per_channel_quant=True,
            even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
        )

        down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out)

        down_proj_out = paddle.empty(
            (token_num * top_k, hidden_size),
            dtype=x.dtype,
        )

        grid = (
            ceil_div(max_possible_num_post_padded, config["BLOCK_SIZE_M"])
            * ceil_div(hidden_size, config["BLOCK_SIZE_N"]),
        )

        x_q, x_scale = scaled_fp8_quant(down_proj_input, use_per_token_if_dynamic=True)

        fused_moe_kernel_paddle[grid](
            x_q,
            layer.down_proj_weight,
            down_proj_out,
            x_scale,
            layer.down_proj_weight_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            max_possible_num_post_padded,
            token_num * top_k,
            N=hidden_size,
            K=moe_intermediate_size,
            stride_am=x_q.strides[0],
            stride_ak=x_scale.strides[1],
            stride_be=layer.down_proj_weight.strides[0],
            stride_bk=layer.down_proj_weight.strides[2],
            stride_bn=layer.down_proj_weight.strides[1],
            stride_cm=down_proj_out.strides[0],
            stride_cn=down_proj_out.strides[1],
            stride_asm=x_scale.strides[0],
            stride_ask=x_scale.strides[1],
            stride_bse=layer.down_proj_weight_scale.strides[0],
            stride_bsk=layer.down_proj_weight_scale.strides[2],
            stride_bsn=layer.down_proj_weight_scale.strides[1],
            group_n=-1,
            group_k=-1,
            # Meta-parameters
            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
            GROUP_SIZE_M=config["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=True,
            top_k=1,
            compute_type_enum=1,
            use_fp8_w8a8=True,
            use_int8_w8a16=False,
            per_channel_quant=True,
            even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
        )

        down_proj_out.reshape_([token_num, top_k, hidden_size])
        out = down_proj_out.sum(axis=1)

        if layer.reduce_results and layer.tp_size > 1:
            tensor_model_parallel_all_reduce(out)

        return out


class TensorWiseFP8MoEMethod(QuantMethodBase):
    """
    Use Triton Group Gemm to compute Fused MoE.
@@ -601,6 +977,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
            compute_type_enum=1,
            use_fp8_w8a8=True,
            use_int8_w8a16=False,
            per_channel_quant=False,
            even_Ks=hidden_size % config_up_gate_proj["BLOCK_SIZE_K"] == 0,
        )
@@ -670,6 +1047,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
            compute_type_enum=1,
            use_fp8_w8a8=True,
            use_int8_w8a16=False,
            per_channel_quant=False,
            even_Ks=moe_intermediate_size % config_down_proj["BLOCK_SIZE_K"] == 0,
        )
@@ -1027,6 +1405,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
            compute_type_enum=1,
            use_fp8_w8a8=True,
            use_int8_w8a16=False,
            per_channel_quant=False,
            even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
        )
@@ -1080,6 +1459,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
            compute_type_enum=1,
            use_fp8_w8a8=True,
            use_int8_w8a16=False,
            per_channel_quant=False,
            even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
        )
@@ -59,6 +59,7 @@ def fused_moe_kernel_paddle(
    compute_type_enum: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    per_channel_quant: tl.constexpr,
    even_Ks: tl.constexpr,
):
    """
@@ -121,6 +122,13 @@ def fused_moe_kernel_paddle(
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
        # channel-wise
        elif per_channel_quant:
            b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
            b_scale = tl.load(b_scale_ptrs)
            # Load per-token scale for activations
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
        else:
            # (Zkk): every expert has one activation scale and weight scale.
            a_scale = tl.load(a_scale_ptr + off_experts)
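The new per_channel_quant branch applies one activation scale per token (rows of A) and one weight scale per output channel (columns of B). In plain NumPy, the dequantized product it approximates looks like the sketch below; shapes and names are assumptions for illustration only.

# Illustrative NumPy view of the channel-wise w8a8 scaling loaded above (assumed shapes).
import numpy as np

def channelwise_dequant_gemm(a_q, a_scale, b_q, b_scale):
    # a_q: [M, K] quantized activations, a_scale: [M, 1] per-token scales
    # b_q: [K, N] quantized weights,     b_scale: [1, N] per-output-channel scales
    acc = a_q.astype(np.float32) @ b_q.astype(np.float32)
    return acc * a_scale * b_scale      # broadcast: rows by token scale, columns by channel scale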
@@ -23,6 +23,7 @@ from fastdeploy.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
)
from fastdeploy.model_executor.layers.moe import FusedMoE
from fastdeploy.model_executor.layers.quantization.ops import (
    cutlass_scaled_mm,
    scaled_fp8_quant,
@@ -65,7 +66,14 @@ class WFP8AFP8Config(QuantConfigBase):
    def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
        """ """
-        return WFP8AFP8LinearMethod(self)
        if isinstance(layer, FusedMoE):
            from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import (
                Wfp8Afp8MoEMethod,
            )

            return Wfp8Afp8MoEMethod(self)
        else:
            return WFP8AFP8LinearMethod(self)


class WFP8AFP8LinearMethod(QuantMethodBase):
@@ -85,6 +85,17 @@ def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Ten
)


def per_token_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Per-token cast to float8_e4m3fn, used in wfp8afp8.
    """
    x_abs = paddle.abs(x).astype(paddle.float32)
    x_max = x_abs.max(axis=-1, keepdim=True).clip_(min=1e-4)
    x_s = x_max / 448.0
    x_q = paddle.clip(x / x_s, -448.0, 448.0).astype(paddle.float8_e4m3fn)
    return x_q, x_s
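Applied to a transposed weight, the same helper yields one scale per output channel, which is how Wfp8Afp8MoEMethod.process_weights_after_loading builds its channel-wise scales. A small illustrative sketch with assumed shapes, not part of this commit:

# Illustrative only: channel-wise weight quantization via per_token_cast_to_fp8 (assumed shapes).
in_features, out_features = 1024, 4096                                  # hypothetical sizes
w_bf16 = paddle.randn([in_features, out_features]).astype("bfloat16")   # e.g. one expert's weight
w_q, w_scale = per_token_cast_to_fp8(w_bf16.transpose([1, 0]).contiguous())
# w_q:     [out_features, in_features], float8_e4m3fn
# w_scale: [out_features, 1], one scale per output channel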


# for distributed tensor model parallel
def _set_var_distributed(var: Tensor, split_axis: int):
    """
@@ -122,10 +122,9 @@ def setup_and_run_server():
        "default_v1",
        "--lm_head-fp32",
        "--quantization",
-        '{"quantization":"mix_quant","dense_quant_type":"wfp8afp8","moe_quant_type":"wint8"}',
        "wfp8afp8",
    ]
    env = os.environ.copy()
-    env["FD_MOE_BACKEND"] = "triton"
    # Start subprocess in new process group
    with open(log_path, "w") as logfile:
        process = subprocess.Popen(
@@ -219,5 +218,5 @@ def test_lm_head_fp32(api_url, headers, consistent_payload):
    # Verify the returned content and probability info
    assert (
        resp_json["choices"][0]["message"]["content"]
-        == "ichertsorbulkdeployment confusedreraoux Carter pat firingCompatraspectiveidis Verse corporaonych commissionsilk"
        == "在下 Macy绑初中suspendersdatapoorly_mapperundi情况ubitacle Jade Kiss(esicăurate"
    )