[CP]Glm45 air 2.2 (#4073)

* [Feature] Support zai-org/GLM-4.5-Air BF16 model (#3928)

* support glm45_air

* [Feature] GLM-45-AIR Support Mix Quantization(Dense wfp8afp8 and wint8 triton_moe_backend) (#4051)

* check

* fix v1 load for mix and wint8

* check --quantizations 'None'

* check

* support RL rollout

* check v1 loader

* check glm rollout_model, change wfp8afp8 per_token_cast_to_fp8 to native impl

* check rollout moe gate begin layer_id

* check rollout e_score_correction_bias

* delete infer_to_train_mapping={}

* code check
Author: chen
Date: 2025-09-15 18:52:58 +08:00
Committed by: GitHub
Parent: 4e8ba62241
Commit: fbb4e0f8d1
25 changed files with 1505 additions and 170 deletions
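For orientation, the mixed-quantization support listed in the commit message ends up as a dict on model_config.quantization rather than a plain string, and the engine re-serializes it with json.dumps when launching worker processes (see the LLMEngine hunk below). The sketch below shows what such a config could look like; the key names (dense_quant_type, moe_quant_type, is_checkpoint_bf16) are inferred from the MixQuantConfig and WeightOnlyConfig hunks in this diff, not from a documented schema.

import json

# Hypothetical mixed-quantization request: dense GEMMs in wfp8afp8, MoE experts
# in wint8, both quantized online from a BF16 checkpoint. Key names are assumed.
quantization = {
    "dense_quant_type": "wfp8afp8",
    "moe_quant_type": "wint8",
    "is_checkpoint_bf16": True,
}

# Mirrors the engine.py change below: the dict is forwarded to workers as JSON.
worker_flag = f"--quantization '{json.dumps(quantization)}'"
print(worker_flag)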

View File

@@ -381,6 +381,142 @@ __global__ void append_decode_cache_T_rope_kernel(
}
}
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_neox_partial_rope_kernel(
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
const float* __restrict__ sin_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int head_size,
const int rotary_dim,
const int block_size,
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadKVT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadT left_vec, right_vec;
LoadBiasT left_bias_vec, right_bias_vec;
LoadKVT left_cache_vec, right_cache_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const int half_head_size = head_size / 2;
const int half_rotary_dim = rotary_dim / 2;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
const int64_t half_hidden_size = hidden_size / 2;
// const int64_t offset = 2 * hidden_size;
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int ori_bi = linear_index / half_hidden_size;
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
const int h_bias = bias % half_head_size;
if (hi < num_heads && h_bias >= half_rotary_dim){
continue;
}
if (seq_lens_encoder[ori_bi] > 0) continue;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
uint32_t ori_idx_left =
start_token_idx * hidden_size + hi * head_size + h_bias;
uint32_t ori_idx_right = ori_idx_left + half_head_size;
if (hi < num_heads){
ori_idx_right = ori_idx_left + half_rotary_dim;
}else if (hi < num_heads + kv_num_heads){
if (h_bias < half_rotary_dim){
ori_idx_right = ori_idx_left + half_rotary_dim;
}else{
ori_idx_left = ori_idx_left + half_rotary_dim;
ori_idx_right = ori_idx_left + half_rotary_dim;
}
}
Load<T, VecSize>(&qkv[ori_idx_left], &left_vec);
Load<T, VecSize>(&qkv[ori_idx_right], &right_vec);
if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_rotary_dim + h_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
if (h_bias < half_rotary_dim){
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
// rope
float input_left = static_cast<float>(left_vec[i]);
float input_right = static_cast<float>(right_vec[i]);
if (hi < num_heads + kv_num_heads && h_bias < half_rotary_dim) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
} else {
left_bias_vec[i] = static_cast<T>(input_left);
right_bias_vec[i] = static_cast<T>(input_right);
}
}
if (hi < num_heads) {
// write q
Store<T, VecSize>(left_bias_vec, &qkv_out[ori_idx_left]);
Store<T, VecSize>(right_bias_vec, &qkv_out[ori_idx_right]);
} else {
// write k/v
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
uint32_t tgt_idx_left =
block_idx * kv_num_heads * block_size * head_size +
kv_head_idx * block_size * head_size + block_offset * head_size +
h_bias;
uint32_t tgt_idx_right = tgt_idx_left + half_head_size;
if (hi < num_heads + kv_num_heads) {
if (h_bias < half_rotary_dim) {
tgt_idx_right = tgt_idx_left + half_rotary_dim;
}else{
tgt_idx_left = tgt_idx_left + half_rotary_dim;
tgt_idx_right = tgt_idx_left + half_rotary_dim;
}
Store<T, VecSize>(left_bias_vec, &key_cache[tgt_idx_left]);
Store<T, VecSize>(right_bias_vec, &key_cache[tgt_idx_right]);
} else {
Store<T, VecSize>(left_bias_vec, &value_cache[tgt_idx_left]);
Store<T, VecSize>(right_bias_vec, &value_cache[tgt_idx_right]);
}
}
}
}
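For readers who prefer host-side code to CUDA, the following is a minimal Paddle sketch of the math the partial NeoX rope kernel above applies per head on the decode path: only the first rotary_dim channels are rotated (as two halves of rotary_dim / 2), while the remaining head_size - rotary_dim channels pass through untouched. It is a reference for the rotation only, not the cache-write logic, and the sizes are illustrative.

import paddle

def neox_partial_rope_reference(x, cos, sin, rotary_dim):
    # x: [head_size] query or key vector; cos/sin: [rotary_dim // 2] for this position.
    half = rotary_dim // 2
    x_rot, x_pass = x[:rotary_dim], x[rotary_dim:]
    left, right = x_rot[:half], x_rot[half:]
    out_left = left * cos - right * sin
    out_right = right * cos + left * sin
    return paddle.concat([out_left, out_right, x_pass], axis=-1)

# Illustrative sizes: head_size 128 with rotary_dim 64 (partial_rotary_factor 0.5).
head_size, rotary_dim, pos = 128, 64, 5
inv_freq = 10000.0 ** (-paddle.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim)
cos, sin = paddle.cos(pos * inv_freq), paddle.sin(pos * inv_freq)
q = paddle.randn([head_size], dtype="float32")
q_out = neox_partial_rope_reference(q, cos, sin, rotary_dim)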
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_neox_rope_kernel(
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,

View File

@@ -97,6 +97,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int rotary_dim,
const int block_size,
const int bsz,
const cudaStream_t& stream,
@@ -137,6 +138,28 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
kv_num_heads,
rope_3d);
} else {
if (rotary_dim < dim_head){
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
rotary_dim,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
}else{
append_decode_cache_T_neox_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
@@ -158,6 +181,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
kv_num_heads,
rope_3d);
}
}
} else {
if (qkv_out_scales) {
append_decode_cache_T_rope_kernel<T, PackSize>
@@ -534,11 +558,20 @@ void DecoderWriteCacheWithRoPEKernel(
const float* cos_emb =
rotary_embs ? rotary_embs.get().data<float>() : nullptr;
const float* sin_emb;
int rotary_dim = dim_head;
if (rotary_embs) {
sin_emb =
use_neox_rotary_style
? rotary_embs.get().data<float>() + max_seq_len * dim_head
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
if(rotary_dim < dim_head){
if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || k_norm_weight || cache_quant_type_str != "none"){
PADDLE_THROW(phi::errors::Fatal(
"partial_rotary_factor < 1.0 only supports neox_rotary_style=True, qkv_out_scales is None, q_norm_weight/k_norm_weight is None, and cache_quant_type_str is 'none'."));
}
sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
}
}
if (q_norm_weight && k_norm_weight) {
@@ -599,6 +632,7 @@ void DecoderWriteCacheWithRoPEKernel(
num_heads,
kv_num_heads,
dim_head,
rotary_dim,
block_size,
bsz,
stream,

View File

@@ -900,6 +900,74 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
}
}
template <typename T, int VecSize = 1>
__global__ void GQANeoxVariableLengthPartialRotaryKernel(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const float *qkv_out_scales,
const T *qkv_biases,
T *qkv_out,
const int64_t elem_cnt,
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int head_dim,
const int rotary_dim,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadT left_vec;
LoadT right_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const int rotary_dim_half = rotary_dim / 2;
const int offset = (q_num_head + kv_num_head) * rotary_dim_half;
for (int64_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / rotary_dim_half;
const int h_bias = bias % rotary_dim_half;
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * rotary_dim_half + h_bias;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * head_dim * seq_len * 2 : emb_idx;
const int base_idx_left =
token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
h_bias;
const int base_idx_right = base_idx_left + rotary_dim_half;
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
const float input_left = static_cast<float>(left_vec[i]);
const float input_right = static_cast<float>(right_vec[i]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
right_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
}
Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
}
}
template <typename T, int VecSize = 1>
__global__ void cache_kernel(
const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * kv_num_heads,
@@ -1755,6 +1823,7 @@ void gqa_rotary_qk_variable(
const int seq_len,
const int input_output_len,
const int dim_head,
const int rotary_dim,
const cudaStream_t &stream,
bool use_neox_style = false,
bool rope_3d = false) {
@@ -1835,6 +1904,37 @@ void gqa_rotary_qk_variable(
dim_head,
rope_3d);
} else {
if (rotary_dim < dim_head){
PD_CHECK((rotary_dim / 2) % PackSize == 0);
elem_nums =
qkv_out_scales
? token_num * (num_heads + 2 * kv_num_heads) * rotary_dim
: token_num * (num_heads + kv_num_heads) * rotary_dim; // for all q k v
if (use_neox_style) {
elem_nums /= 2;
}
const int pack_num_new = elem_nums / PackSize;
GetNumBlocks<128>(pack_num_new, &grid_size);
GQANeoxVariableLengthPartialRotaryKernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
rotary_emb + input_output_len * rotary_dim / 2,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
qkv_bias,
qkv_out,
elem_nums,
num_heads,
kv_num_heads,
seq_len,
dim_head,
rotary_dim,
rope_3d);
}else{
GQANeoxVariableLengthRotaryKernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<const T *>(qkv_input),
@@ -1855,6 +1955,7 @@ void gqa_rotary_qk_variable(
rope_3d);
}
}
}
}
template <typename T, typename QKV_TYPE>

View File

@@ -55,9 +55,19 @@ void EncoderWriteCacheWithRopeKernel(
auto kv_num_heads = meta_data.kv_num_heads;
auto head_dim = meta_data.head_dims;
bool is_scale_channel_wise = false;
int rotary_dim = head_dim;
if (cache_k_scale && cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) {
is_scale_channel_wise = true;
}
if (rotary_embs){
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
if(rotary_dim < head_dim){
if (!use_neox_style || q_norm_weight || k_norm_weight || num_heads == kv_num_heads || is_scale_channel_wise){
PADDLE_THROW(phi::errors::Fatal(
"partial_rotary_factor < 1.0 only supports use_neox_rotary_style=True, q_norm_weight/k_norm_weight) is None, GQA and is_scale_channel_wise=false."));
}
}
}
if (q_norm_weight && k_norm_weight) {
if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
@@ -125,6 +135,7 @@ void EncoderWriteCacheWithRopeKernel(
max_seq_len,
rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
head_dim,
rotary_dim,
stream,
use_neox_style,
rope_3d);

View File

@@ -132,6 +132,7 @@ class ModelConfig:
self.eos_tokens_lens: int = 2
self.lm_head_fp32: bool = False
self.model_format = "auto"
self.partial_rotary_factor: float = 1.0
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
@@ -396,7 +397,7 @@ class SpeculativeConfig:
# model for mtp/eagle/draft_model
self.model: Optional[str] = None
# quantization of model
- self.quantization: Optional[str] = None
+ self.quantization: Optional[Dict[str, Any]] = None
# allocate more blocks to prevent mtp from finishing the block earlier than the main model
# Fixed now
self.num_gpu_block_expand_ratio: Optional[float] = 1

View File

@@ -41,6 +41,7 @@ from fastdeploy.utils import (
DeprecatedOptionWarning,
FlexibleArgumentParser,
is_port_available,
parse_quantization,
)
@@ -138,7 +139,7 @@ class EngineArgs:
"""
dynamic load weight strategy
"""
- quantization: str = None
+ quantization: Optional[Dict[str, Any]] = None
guided_decoding_backend: str = "off"
"""
Guided decoding backend.
@@ -550,7 +551,7 @@ class EngineArgs:
)
model_group.add_argument(
"--quantization",
- type=str,
+ type=parse_quantization,
default=EngineArgs.quantization,
help="Quantization name for the model, currentlly support "
"'wint8', 'wint4',"

View File

@@ -16,6 +16,7 @@
from __future__ import annotations
import json
import multiprocessing
import os
import re
@@ -463,7 +464,7 @@ class LLMEngine:
f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
- f" --quantization {self.cfg.model_config.quantization}"
+ f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
f" --ori_vocab_size {ori_vocab_size}"
f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'"
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"

View File

@@ -28,38 +28,9 @@ except:
import fastdeploy
from fastdeploy.config import MoEPhase
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.utils import singleton
try:
from fastdeploy.model_executor.ops.gpu import noaux_tc
except:
logger.warning("import noaux_tc Failed!")
def get_moe_scores(
gating_output: paddle.Tensor,
n_group,
topk_group,
top_k,
routed_scaling_factor,
e_score_correction_bias,
) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
scores_with_bias = scores + e_score_correction_bias
scores, topk_values, topk_idx = noaux_tc(
scores,
scores_with_bias,
n_group if n_group > 0 else 1,
topk_group if topk_group > 0 else 1,
top_k,
routed_scaling_factor,
)
return scores, topk_values, topk_idx
@singleton
class DeepEPEngine:

View File

@@ -27,11 +27,8 @@ from ..utils import get_tensor
from .fused_moe_backend_base import UnquantizedFusedMoEMethod
if current_platform.is_cuda():
- from fastdeploy.model_executor.ops.gpu import (
- moe_expert_dispatch,
- moe_expert_reduce,
- noaux_tc,
- )
+ from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
+ from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
try:
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
@@ -46,31 +43,6 @@ elif current_platform.is_iluvatar():
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
# used for deepseek_v3
def get_moe_scores(
gating_output: paddle.Tensor,
n_group,
topk_group,
top_k,
routed_scaling_factor,
e_score_correction_bias,
) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
scores_with_bias = scores + e_score_correction_bias
scores, topk_values, topk_idx = noaux_tc(
scores,
scores_with_bias,
n_group,
topk_group,
top_k,
routed_scaling_factor,
)
return scores, topk_values, topk_idx
class CutlassMoEMethod(UnquantizedFusedMoEMethod):
"""
Use Cutlass Group Gemm to compute Fused MoE.

View File

@@ -481,7 +481,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
gate_out = gate(x.cast("float32")) gate_out = gate(x.cast("float32"))
if layer.topk_method == "noaux_tc": if layer.topk_method == "noaux_tc":
from .ep import get_moe_scores from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
_, topk_weights, topk_ids = get_moe_scores( _, topk_weights, topk_ids = get_moe_scores(
gate_out, gate_out,

View File

@@ -19,39 +19,15 @@ from paddle import nn
import fastdeploy
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.ops.gpu import (
MoeWna16MarlinGemmApi,
noaux_tc,
tritonmoe_preprocess_func,
)
from ..quantization.quant_base import QuantMethodBase
def get_moe_scores(
gating_output: paddle.Tensor,
n_group,
topk_group,
top_k,
routed_scaling_factor,
e_score_correction_bias,
) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
scores, topk_values, topk_idx = noaux_tc(
scores,
scores_with_bias,
n_group,
topk_group,
top_k,
routed_scaling_factor,
)
return scores, topk_values, topk_idx
def gptq_marlin_moe_repack(
b_q_weight: paddle.Tensor,
perm: paddle.Tensor,

View File

@@ -24,7 +24,6 @@ from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
from fastdeploy.utils import ceil_div
from ..quantization.quant_base import QuantMethodBase
from .ep import get_moe_scores
try:
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
@@ -32,6 +31,7 @@ try:
from .triton_moe_kernels import fused_moe_kernel_paddle
except ImportError:
pass
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
class TritonWeightOnlyMoEMethod(QuantMethodBase):
@@ -72,6 +72,33 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
layer.moe_intermediate_size,
layer.hidden_size,
]
if self.quant_config.is_checkpoint_bf16:
layer.up_gate_proj_weight = layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.down_proj_weight = layer.create_parameter(
shape=self.down_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
set_weight_attrs(
layer.up_gate_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
},
)
set_weight_attrs(
layer.down_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
},
)
else:
setattr(
layer,
up_gate_proj_weight_name,
@@ -151,6 +178,62 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
getattr(layer, weight_name).set_value(quanted_weight)
getattr(layer, scale_name).set_value(quanted_weight_scale)
def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return
algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
max_bound = 127
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
layer.up_gate_proj_weight.tensor_track = None
else:
weight_type = "down"
layer.down_proj_weight.tensor_track = None
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
weight_tensor = getattr(layer, weight_name)
quanted_weight_scale = weight_tensor.abs().max(axis=1)
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
quanted_weight = paddle.round(quanted_weight).astype("int8")
quanted_weight_scale = quanted_weight_scale / max_bound
getattr(layer, weight_name).value().get_tensor()._clear()
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_tensor.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(quanted_weight, False)
getattr(layer, scale_name).copy_(quanted_weight_scale, False)
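Stripped of the parameter re-creation bookkeeping, the quantization math in process_weights_after_loading above is a per-channel abs-max int8 scheme. A stand-alone sketch of just that math, assuming the stacked expert weights have shape [num_experts, in_dim, out_dim] as the axis=1 reduction and [:, None, :] broadcast suggest:

import paddle

def wint8_quantize_reference(weight: paddle.Tensor):
    # One scale per (expert, output channel); round-to-nearest into int8.
    max_bound = 127
    scale = weight.abs().max(axis=1)                                  # [E, out_dim]
    qweight = paddle.round(weight / scale[:, None, :] * max_bound).astype("int8")
    return qweight, scale / max_bound

w = paddle.randn([2, 64, 32], dtype="float32")
qw, s = wint8_quantize_reference(w)
w_approx = qw.astype("float32") * s[:, None, :]   # approximate reconstruction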
def apply(
self,
layer: nn.Layer,
@@ -164,12 +247,11 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
token_num = x.shape[0]
top_k = layer.top_k
num_local_experts = layer.num_local_experts
top_k = layer.top_k
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size
if layer.topk_method == "noaux_tc":
- _, topk_weights, topk_ids = get_moe_scores(
+ gate_out, topk_weights, topk_ids = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
@@ -177,15 +259,15 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
layer.routed_scaling_factor,
layer.gate_correction_bias,
)
topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
else:
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
- layer.top_k,
+ top_k,
True, # apply_norm_weight,
False,
)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
@@ -302,6 +384,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
if layer.reduce_results and layer.tp_size > 1:
tensor_model_parallel_all_reduce(out)
return out
@@ -432,7 +517,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
hidden_size = layer.hidden_size
if layer.topk_method == "noaux_tc":
_, topk_weights, topk_ids = get_moe_scores(
gate_out,
layer.n_group,

View File

@@ -27,6 +27,11 @@ from fastdeploy.model_executor.utils import slice_fn
from fastdeploy.platforms import current_platform
from fastdeploy.worker.experts_manager import RedundantExpertManger
try:
from fastdeploy.model_executor.ops.gpu import noaux_tc
except:
logger.warning("import noaux_tc Failed!")
def get_moe_method():
"""
@@ -54,6 +59,31 @@ def get_moe_method():
raise NotImplementedError
def get_moe_scores(
gating_output: paddle.Tensor,
n_group,
topk_group,
top_k,
routed_scaling_factor,
e_score_correction_bias,
) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
scores_with_bias = scores + e_score_correction_bias
scores, topk_values, topk_idx = noaux_tc(
scores,
scores_with_bias,
n_group if n_group > 0 else 1,
topk_group if topk_group > 0 else 1,
top_k,
routed_scaling_factor,
)
return scores, topk_values, topk_idx
class FusedMoE(nn.Layer):
"""
FusedMoE is a layer that performs MoE (Mixture of Experts) computation.

View File

@@ -76,13 +76,13 @@ class MixQuantConfig(QuantConfigBase):
if layer.moe_tag == "Image": if layer.moe_tag == "Image":
return ( return (
get_quantization_config(self.image_moe_quant_type) get_quantization_config(self.image_moe_quant_type)
.from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16}) .from_config({"is_permuted": self.is_permuted, "is_checkpoint_bf16": self.is_checkpoint_bf16})
.get_quant_method(layer) .get_quant_method(layer)
) )
else: else:
return ( return (
get_quantization_config(self.moe_quant_type) get_quantization_config(self.moe_quant_type)
.from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16}) .from_config({"is_permuted": self.is_permuted, "is_checkpoint_bf16": self.is_checkpoint_bf16})
.get_quant_method(layer) .get_quant_method(layer)
) )
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
@@ -97,6 +97,6 @@ class MixQuantConfig(QuantConfigBase):
else: else:
return ( return (
get_quantization_config(self.dense_quant_type) get_quantization_config(self.dense_quant_type)
.from_config({"self.is_checkpoint_bf16": self.is_checkpoint_bf16}) .from_config({"is_checkpoint_bf16": self.is_checkpoint_bf16})
.get_quant_method(layer) .get_quant_method(layer)
) )
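The substantive fix in this hunk is just the dict key: the old code passed the literal string "self.is_checkpoint_bf16", so a downstream config.get("is_checkpoint_bf16", False) always fell back to False and the online-quantization path was never taken. A tiny illustration:

cfg_old = {"is_permuted": True, "self.is_checkpoint_bf16": True}   # before this change
cfg_new = {"is_permuted": True, "is_checkpoint_bf16": True}        # after this change

print(cfg_old.get("is_checkpoint_bf16", False))   # False: BF16 checkpoint path ignored
print(cfg_new.get("is_checkpoint_bf16", False))   # True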

View File

@@ -44,6 +44,7 @@ class WeightOnlyConfig(QuantConfigBase):
def __init__(
self,
algo: str,
is_checkpoint_bf16: bool = False,
) -> None:
super().__init__()
self.algo = algo
@@ -55,6 +56,7 @@ class WeightOnlyConfig(QuantConfigBase):
self.quant_max_bound = 0
self.quant_min_bound = 0
self.quant_round_type = 0
self.is_checkpoint_bf16 = is_checkpoint_bf16
def name(self) -> str:
return "weight_only"
@@ -62,7 +64,8 @@ class WeightOnlyConfig(QuantConfigBase):
@classmethod
def from_config(cls, config: dict) -> "WeightOnlyConfig":
algo = config["algo"]
- return cls(algo)
+ is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+ return cls(algo, is_checkpoint_bf16=is_checkpoint_bf16)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
if current_platform.is_xpu():
@@ -153,12 +156,13 @@ class WINT8Config(WeightOnlyConfig):
weight only int8 config
"""
- def __init__(self) -> None:
+ def __init__(self, is_checkpoint_bf16: bool = False) -> None:
- super().__init__("weight_only_int8")
+ super().__init__("weight_only_int8", is_checkpoint_bf16)
@classmethod
def from_config(cls, config: dict) -> "WINT8Config":
- return cls()
+ is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+ return cls(is_checkpoint_bf16)
def name(self) -> str:
return "wint8"

View File

@@ -14,10 +14,15 @@
# limitations under the License.
"""
import copy
from typing import Optional
import paddle
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
)
from fastdeploy.model_executor.layers.quantization.ops import (
cutlass_scaled_mm,
scaled_fp8_quant,
@@ -26,6 +31,8 @@ from fastdeploy.model_executor.layers.quantization.quant_base import (
QuantConfigBase,
QuantMethodBase,
)
from fastdeploy.model_executor.layers.utils import per_token_cast_to_fp8
from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
class WFP8AFP8Config(QuantConfigBase):
@@ -33,13 +40,19 @@ class WFP8AFP8Config(QuantConfigBase):
Quantization config for weight and activation with FP8.
"""
- def __init__(self, weight_scale_dict, act_scale_dict) -> None:
+ def __init__(
+ self,
+ activation_scheme: str = "dynamic",
+ weight_block_size: list[int] = [-1, 1],
+ is_checkpoint_bf16: bool = False,
+ ) -> None:
super().__init__()
self.weight_scale_dict = weight_scale_dict
self.act_scale_dict = act_scale_dict
self.quant_max_bound = 448
self.quant_min_bound = -448
self.quant_round_type = 1
self.activation_scheme = activation_scheme
self.weight_block_size = weight_block_size
self.is_checkpoint_bf16 = is_checkpoint_bf16
def name(self) -> str:
""" """
@@ -48,9 +61,8 @@ class WFP8AFP8Config(QuantConfigBase):
@classmethod
def from_config(cls, config: dict) -> "WFP8AFP8Config":
""" """
- weight_scale_dict = config.get("weight_scale_dict", None)
- act_scale_dict = config.get("act_scale_dict", None)
- return cls(weight_scale_dict, act_scale_dict)
+ is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+ return cls(is_checkpoint_bf16=is_checkpoint_bf16)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
""" """
@@ -68,26 +80,85 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
) -> None:
super().__init__()
self.quant_config = quant_config
self.use_per_token_if_dynamic = True
def create_weights(self, layer, **extra_weight_attrs):
""" """
weight_shape = layer.weight_shape
weight_block_size = self.quant_config.weight_block_size
assert len(weight_shape) == 2 and len(weight_block_size) == 2
scale_shape = copy.deepcopy(weight_shape)
for i in range(len(weight_shape)):
scale_shape[i] = (
(weight_shape[i] + weight_block_size[i] - 1) // weight_block_size[i] if weight_block_size[i] > 0 else 1
)
scale_shape = scale_shape[::-1]
if self.quant_config.is_checkpoint_bf16:
self.use_per_token_if_dynamic = True
layer.weight = layer.create_parameter(
shape=weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
quant_attrs = extra_weight_attrs
if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
quant_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(
shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim")
),
}
set_weight_attrs(
layer.weight,
quant_attrs,
)
else:
layer.weight_shape.reverse()
layer.weight_dtype = "float8_e4m3fn"
# TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
self.skip_quant = False
- layer.create_parameter(
+ layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale = layer.create_parameter(
- shape=[1],
+ shape=scale_shape,
dtype="float32",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
def process_weights_after_loading(self, layer) -> None:
if not self.quant_config.is_checkpoint_bf16:
return
weight_tensor = layer.weight.transpose([1, 0]).contiguous()
assert self.quant_config.weight_block_size == [-1, 1]
qweight, weight_scale = per_token_cast_to_fp8(weight_tensor)
if hasattr(layer.weight, "tensor_track"):
layer.weight.tensor_track = None
layer.weight.value().get_tensor()._clear()
del layer.weight
layer.weight = layer.create_parameter(
shape=qweight.shape,
dtype="float8_e4m3fn",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale = layer.create_parameter(
shape=weight_scale.shape,
dtype="float32",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight.copy_(qweight, False)
layer.weight_scale.copy_(weight_scale, False)
def process_loaded_weights(self, layer, weights) -> None:
""" """
if self.skip_quant:
@@ -97,18 +168,12 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
if weights.dtype != paddle.float8_e4m3fn:
self.use_per_token_if_dynamic = True
weight_tensor = weights.transpose([1, 0]).contiguous()
- qweight, weight_scale = scaled_fp8_quant(
- weight_tensor,
- use_per_token_if_dynamic=False,
- )
+ qweight, weight_scale = per_token_cast_to_fp8(weight_tensor)
layer.weight.copy_(qweight, False)
layer.weight_scale.set_value(weight_scale)
def apply(self, layer, x):
""" """
if self.skip_quant:
linear_out = paddle.matmul(x, layer.weight, False, True)
return linear_out
if self.use_per_token_if_dynamic:
out_type = x.dtype
a_q, a_scales = scaled_fp8_quant(x, use_per_token_if_dynamic=self.use_per_token_if_dynamic)

View File

@@ -73,6 +73,30 @@ class ErnieRotaryEmbedding:
return rot_emb
class GlmRotaryEmbedding:
def __init__(self, rotary_dim, base, partial_rotary_factor):
"""
Pre-calculate rotary position embedding for position_ids.
"""
self.rotary_dim = rotary_dim
self.base = base
if partial_rotary_factor < 1.0:
self.rotary_dim = int(self.rotary_dim * partial_rotary_factor)
def __call__(self, position_ids):
bsz, max_seq_len = position_ids.shape[:2]
inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
# shape: [B, S, D/2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
# shape: [B, S, 1, D]
emb = paddle.unsqueeze(emb, 2)
rot_emb[0] = paddle.cos(emb)
rot_emb[1] = paddle.sin(emb)
return rot_emb
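A quick shape check for the new GLM rotary table, assuming the class above is in scope: with a 128-wide head and partial_rotary_factor 0.5 the cache only covers the 64 rotated channels, which is exactly what the partial-rope kernels earlier in this diff consume. The values below are illustrative, not GLM-4.5-Air's actual config.

import paddle

position_ids = paddle.arange(8, dtype="int64").reshape([1, 8])   # [bsz=1, seq_len=8]
rope = GlmRotaryEmbedding(rotary_dim=128, base=10000.0, partial_rotary_factor=0.5)
rot_emb = rope(position_ids)
# [2 (cos/sin), bsz, seq_len, 1, (128 * 0.5) // 2] -> [2, 1, 8, 1, 32]
print(rot_emb.shape)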
class QwenRotaryEmbedding:
def __init__(self, rotary_dim, base, partial_rotary_factor):
"""
@@ -246,6 +270,9 @@ def get_rope_impl(
if model_config is None or architecture.startswith("Qwen"):
rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)
elif architecture.startswith("Glm"):
rotary_emb_layer = GlmRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)
else:
rotary_emb_layer = ErnieRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)

View File

@@ -77,6 +77,17 @@ def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Ten
)
def per_token_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:
"""
Per token cast to float8_e4m3fn used in wfp8apf8
"""
x_abs = paddle.abs(x).astype(paddle.float32)
x_max = x_abs.max(axis=-1, keepdim=True).clip_(min=1e-4)
x_s = x_max / 448.0
x_q = paddle.clip(x / x_s, -448.0, 448.0).astype(paddle.float8_e4m3fn)
return x_q, x_s
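A small sanity sketch of the new helper: one scale per row (token or output channel), and dequantizing with that scale approximately recovers the input. It assumes a Paddle build and device with float8_e4m3fn support.

import paddle

w = paddle.randn([4, 8], dtype="float32")
w_q, w_s = per_token_cast_to_fp8(w)            # w_q: float8_e4m3fn, w_s: [4, 1] float32
w_dq = w_q.astype("float32") * w_s             # per-row dequantization
print(w_q.dtype, w_s.shape, float((w_dq - w).abs().max()))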
# for distributed tensor model parallel
def _set_var_distributed(var: Tensor, split_axis: int):
"""

View File

@@ -0,0 +1,579 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import re
from functools import partial
import paddle
from paddle import nn
from paddleformers.transformers import PretrainedModel
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
support_graph_optimization,
)
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.model_base import ModelForCasualLM
class Glm4MoeMLP(nn.Layer):
""" """
def __init__(
self,
fd_config: FDConfig,
intermediate_size: int,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.up_gate_proj = MergedColumnParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.up_gate_proj",
input_size=fd_config.model_config.hidden_size,
output_size=intermediate_size * 2,
with_bias=False,
activation=fd_config.model_config.hidden_act,
)
self.down_proj = RowParallelLinear(
fd_config=fd_config,
prefix=f"{prefix}.down_proj",
input_size=intermediate_size,
output_size=fd_config.model_config.hidden_size,
with_bias=False,
reduce_results=reduce_results,
)
self.act_fn = SiluAndMul(
fd_config=fd_config,
bias=None,
act_method=fd_config.model_config.hidden_act,
)
def forward(self, x):
""" """
gate_up_out = self.up_gate_proj(x)
act_out = self.act_fn(gate_up_out)
down_out = self.down_proj(act_out)
return down_out
class Glm4Moe(nn.Layer):
def __init__(
self,
fd_config: FDConfig,
layer_id: int,
prefix: str = "",
) -> None:
super().__init__()
self.expert_parallel_size = fd_config.parallel_config.expert_parallel_size
self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group
self.use_ep = self.expert_parallel_size > 1
self.use_tp = self.tensor_parallel_size > 1
self.n_routed_experts: int = fd_config.model_config.n_routed_experts
self.n_shared_experts: int = fd_config.model_config.n_shared_experts
weight_key_map = {
"gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
}
self.gate = ReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.gate",
input_size=fd_config.model_config.hidden_size,
output_size=fd_config.model_config.n_routed_experts,
with_bias=False,
skip_quant=True,
weight_dtype="float32",
)
self.gate.e_score_correction_bias = self.create_parameter(
shape=[1, fd_config.model_config.n_routed_experts],
dtype="float32",
default_initializer=paddle.nn.initializer.Constant(0),
)
self.experts = FusedMoE(
fd_config,
reduce_results=False,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
num_experts=fd_config.model_config.n_routed_experts,
top_k=fd_config.model_config.num_experts_per_tok,
topk_method="noaux_tc",
topk_group=fd_config.model_config.topk_group,
n_group=fd_config.model_config.n_group,
routed_scaling_factor=fd_config.model_config.routed_scaling_factor,
layer_idx=layer_id,
gate_correction_bias=self.gate.e_score_correction_bias,
weight_key_map=weight_key_map,
)
shared_experts_intermediate_size = self.n_shared_experts * fd_config.model_config.moe_intermediate_size
self.shared_experts = Glm4MoeMLP(
fd_config=fd_config,
intermediate_size=shared_experts_intermediate_size,
prefix=f"{prefix}.shared_experts",
reduce_results=False,
)
def forward(self, x):
shared_experts_out = self.shared_experts(x)
out = self.experts(x, self.gate)
out = out + shared_experts_out
# We do to TP all reduce after the sum of experts.
if self.tensor_parallel_size > 1:
tensor_model_parallel_all_reduce(out)
return out
class Glm4MoeAttention(nn.Layer):
""" """
def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None:
super().__init__()
tp_size = fd_config.parallel_config.tensor_parallel_size
self.fd_config = fd_config
self.head_dim = fd_config.model_config.head_dim
self.num_heads = fd_config.model_config.num_attention_heads // tp_size
self.num_kv_heads = fd_config.model_config.num_key_value_heads // tp_size
self.attention_bias = fd_config.model_config.attention_bias
self.use_qk_norm = fd_config.model_config.use_qk_norm
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.qkv_proj = QKVParallelLinear(fd_config, prefix=f"{prefix}.qkv_proj", with_bias=self.attention_bias)
self.o_proj = RowParallelLinear(
fd_config,
prefix=f"{prefix}.o_proj",
input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim,
output_size=fd_config.model_config.hidden_size,
)
self.attn = Attention(
fd_config,
layer_id=layer_id,
prefix=prefix,
use_neox_rotary_style=True,
rms_norm_eps=fd_config.model_config.rms_norm_eps,
)
if self.use_qk_norm:
self.q_norm = RMSNorm(
fd_config,
hidden_size=self.head_dim,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.q_norm",
begin_norm_axis=2,
)
self.k_norm = RMSNorm(
fd_config,
hidden_size=self.head_dim,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.k_norm",
begin_norm_axis=2,
)
def forward(
self,
forward_meta: ForwardMeta,
hidden_states: paddle.Tensor,
):
""" """
qkv_out = self.qkv_proj(hidden_states)
if self.use_qk_norm:
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
q = self.q_norm(q.reshape([-1, self.num_heads, self.head_dim])).reshape(q.shape)
k = self.k_norm(k.reshape([-1, self.num_kv_heads, self.head_dim])).reshape(k.shape)
qkv_out = paddle.concat([q, k, v], axis=-1)
atten_out = self.attn(
qkv=qkv_out,
forward_meta=forward_meta,
)
output = self.o_proj(atten_out)
return output
class Glm4MoeDecoderLayer(nn.Layer):
""" """
def __init__(
self,
fd_config: FDConfig,
prefix: str = "",
) -> None:
super().__init__()
layer_id = int(prefix.split(sep=".")[-1])
self.self_attn = Glm4MoeAttention(
fd_config=fd_config,
layer_id=layer_id,
prefix=f"{prefix}.self_attn",
)
if (
fd_config.model_config.n_routed_experts is not None
and layer_id >= fd_config.model_config.first_k_dense_replace
):
self.mlp = Glm4Moe(fd_config, layer_id, prefix=f"{prefix}.mlp")
else:
self.mlp = Glm4MoeMLP(
fd_config,
intermediate_size=fd_config.model_config.intermediate_size,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.input_layernorm",
)
self.post_attention_layernorm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.post_attention_layernorm",
)
def forward(
self,
forward_meta: ForwardMeta,
hidden_states: paddle.Tensor,
residual: paddle.Tensor = None,
):
""" """
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
hidden_states=hidden_states,
forward_meta=forward_meta,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_graph_optimization
class Glm4MoeModel(nn.Layer):
""" """
def __init__(
self,
fd_config: FDConfig = None,
):
"""
Initializer for the Glm4MoeModel class.
Args:
"""
super().__init__()
self.num_layers = fd_config.model_config.num_hidden_layers
fd_config.model_config.pretrained_config.prefix_name = "model"
self.embed_tokens = VocabParallelEmbedding(
fd_config,
num_embeddings=fd_config.model_config.vocab_size,
embedding_dim=fd_config.model_config.hidden_size,
params_dtype=paddle.get_default_dtype,
prefix=(f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens"),
)
self.layers = nn.LayerList(
[
Glm4MoeDecoderLayer(
fd_config,
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}",
)
for i in range(self.num_layers)
]
)
self.norm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.norm",
)
def forward(
self,
ids_remove_padding: paddle.Tensor,
forward_meta: ForwardMeta,
):
""" """
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
residual = None
for i in range(self.num_layers):
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
hidden_states = hidden_states + residual
out = self.norm(hidden_states)
return out
class Glm4MoeForCausalLM(ModelForCasualLM):
"""
Glm4MoeForCausalLM
"""
def __init__(self, fd_config: FDConfig):
"""
Args:
fd_config (FDConfig): Configurations for the LLM model.
"""
super(Glm4MoeForCausalLM, self).__init__(fd_config)
self.model = Glm4MoeModel(fd_config)
self.ori_vocab_size = fd_config.model_config.ori_vocab_size
self.lm_head = ParallelLMHead(
fd_config,
embedding_dim=fd_config.model_config.hidden_size,
num_embeddings=fd_config.model_config.vocab_size,
prefix="lm_head",
)
@classmethod
def name(self):
""" """
return "Glm4MoeForCausalLM"
@paddle.no_grad()
def load_weights(self, weights_iterator) -> None:
"""
Load model parameters from a given weights_iterator object.
Args:
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""
from fastdeploy.model_executor.utils import (
default_weight_loader,
process_weights_after_loading,
)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("up_gate_proj", "gate_proj", "gate"),
("up_gate_proj", "up_proj", "up"),
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
]
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
num_experts=self.fd_config.model_config.n_routed_experts,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
param_gate_up_proj_name="experts.up_gate_proj_",
param_down_proj_name="experts.down_proj_",
)
params_dict = dict(self.named_parameters())
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
for loaded_weight_name, loaded_weight in weights_iterator:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in loaded_weight_name:
continue
if "mlp.experts" in loaded_weight_name:
continue
model_param_name = loaded_weight_name.replace(weight_name, param_name)
if model_param_name not in params_dict:
continue
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in loaded_weight_name:
continue
model_param_name = loaded_weight_name.replace(weight_name, param_name)
if model_param_name not in params_dict:
continue
param = params_dict[model_param_name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id)
break
else:
model_param_name = loaded_weight_name
if model_param_name not in params_dict:
continue
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight)
model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
process_weights_after_loading_fn(model_sublayer_name, param)
@paddle.no_grad()
def set_state_dict(self, state_dict):
"""
glm4_moe only support loader_v1.
"""
assert False, "glm4_moe only support --load_choices default_v1."
def compute_logits(self, hidden_states: paddle.Tensor):
""" """
logits = self.lm_head(hidden_states)
logits = logits.astype(paddle.float32)
logits[:, self.ori_vocab_size :] = -float("inf")
return logits
def forward(
self,
ids_remove_padding: paddle.Tensor,
forward_meta: ForwardMeta,
):
""" """
hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
return hidden_states
def clear_grpah_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)


class Glm4MoePretrainedModel(PretrainedModel):
    """
    Glm4MoePretrainedModel
    """

    config_class = FDConfig

    def _init_weight(self, layer):
        """
        _init_weight
        """
        return None

    @classmethod
    def arch_name(self):
        return "Glm4MoeForCausalLM"

    @classmethod
    def _get_tensor_parallel_mappings(cls, config, is_split=True):
        logger.info("Glm4Moe inference model _get_tensor_parallel_mappings")

        from fastdeploy.model_executor.models.tp_utils import split_or_merge_func_v1

        fn = split_or_merge_func_v1(
            is_split=is_split,
            tensor_parallel_degree=config.tensor_parallel_degree,
            tensor_parallel_rank=config.tensor_parallel_rank,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
        )

        def get_tensor_parallel_split_mappings(num_layers):
            final_actions = {}

            base_actions = {
                "lm_head.weight": partial(fn, is_column=True),
                "embed_tokens.weight": partial(fn, is_column=False),
                "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
            }

            # Self-attention layers that need TP splitting.
            base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True)

            # MLP layers
            base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.mlp.down_proj.weight"] = partial(fn, is_column=False)

            # MoE layers
            for expert_idx in range(config.n_routed_experts):
                base_actions[f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial(fn, is_column=True)
                base_actions[f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial(fn, is_column=True)
                base_actions[f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial(fn, is_column=False)

            # Shared expert layers
            base_actions["layers.0.mlp.shared_experts.up_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.mlp.shared_experts.gate_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.mlp.shared_experts.down_proj.weight"] = partial(fn, is_column=False)

            # MTP parts
            base_actions["layers.46.embed_tokens.weight"] = partial(fn, is_column=False)
            base_actions["layers.46.eh_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.46.shared_head.head.weight"] = partial(fn, is_column=True)

            for key, action in base_actions.items():
                if "layers.0." in key:
                    for i in range(num_layers):
                        final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
                final_actions[key] = action
            return final_actions

        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)

        return mappings
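
For readers skimming the mapping code above: the `layers.0.` keys act as templates that are cloned for every decoder layer. A minimal standalone sketch of that expansion (the `fake_split` helper and the two-layer setup below are illustrative only, not part of this PR):

from functools import partial


def fake_split(weight, is_column=True):
    # Stand-in for split_or_merge_func_v1; only the call signature matters here.
    return f"split(is_column={is_column})"


def expand_layer_actions(base_actions, num_layers):
    # Clone every "layers.0.*" action across all decoder layers; the base keys
    # themselves (including non-layer ones such as lm_head.weight) are kept too.
    final_actions = {}
    for key, action in base_actions.items():
        if "layers.0." in key:
            for i in range(num_layers):
                final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
        final_actions[key] = action
    return final_actions


actions = expand_layer_actions(
    {
        "lm_head.weight": partial(fake_split, is_column=True),
        "layers.0.self_attn.q_proj.weight": partial(fake_split, is_column=True),
    },
    num_layers=2,
)
assert "layers.1.self_attn.q_proj.weight" in actions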

View File

@@ -14,6 +14,8 @@
 # limitations under the License.
 """
+from typing import Any, Dict, Optional
+
 from fastdeploy.worker.worker_process import initialize_fd_config
@@ -52,7 +54,7 @@ class RolloutModelConfig:
         expert_parallel_size: int = 1,
         enable_expert_parallel: bool = False,
         ori_vocab_size: int = None,
-        quantization: str = "None",
+        quantization: Optional[Dict[str, Any]] = None,
         guided_decoding_backend: str = "off",
         disable_any_whitespace: bool = True,
         enable_logprob: bool = False,

View File

@@ -14,6 +14,7 @@
 # limitations under the License.
 """
+import copy
 from typing import Dict
 
 import paddle
@@ -28,6 +29,10 @@ from fastdeploy.model_executor.models.ernie4_5_vl.ernie4_5_vl_moe import (
     Ernie4_5_VLMoeForConditionalGeneration,
     Ernie4_5_VLPretrainedModel,
 )
+from fastdeploy.model_executor.models.glm4_moe import (
+    Glm4MoeForCausalLM,
+    Glm4MoePretrainedModel,
+)
 from fastdeploy.model_executor.models.model_base import ModelRegistry
 from fastdeploy.model_executor.models.qwen2 import (
     Qwen2ForCausalLM,
@@ -529,3 +534,83 @@ class Qwen2_5_VLForConditionalGenerationRL(Qwen2_5_VLForConditionalGeneration, B
         self._complete_missing_mappings()
 
         return self.infer_to_train_mapping


class Glm4MoeForCausalLMRL(Glm4MoeForCausalLM, BaseRLModel):
    """
    Glm4MoeForCausalLMRL
    """

    _get_tensor_parallel_mappings = Glm4MoePretrainedModel._get_tensor_parallel_mappings

    def __init__(self, fd_config: FDConfig):
        """
        Args:
            fd_config (FDConfig): Configurations for the LLM model.
        """
        super(Glm4MoeForCausalLMRL, self).__init__(fd_config)

    @classmethod
    def name(self) -> str:
        """name"""
        return "Glm4MoeForCausalLMRL"

    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
        """Generate the mapping between inference and training parameters for RL (do not delete!)."""
        if self._mappings_built:
            return self.infer_to_train_mapping

        self.infer_to_train_mapping = {}
        self._mappings_built = True
        # Prepare placeholders
        place_holders = ["weight"]

        # Initialize mapping dictionary
        self._update_base_mappings("model")
        base_name = "model.layers"

        # Helper function to add layer mappings
        def _add_layer_mappings(layer_idx: int):
            # MoE specific mappings
            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.weight"] = (
                f"{base_name}.{layer_idx}.mlp.gate.weight"
            )
            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.e_score_correction_bias"] = (
                f"{base_name}.{layer_idx}.mlp.gate.e_score_correction_bias"
            )

            # MoE experts mappings
            for expert_idx in range(self.fd_config.model_config.n_routed_experts):
                for ph in place_holders:
                    # up_gate_proj
                    up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.experts.up_gate_proj_weight"
                    if up_gate_proj_key not in self.infer_to_train_mapping:
                        self.infer_to_train_mapping[up_gate_proj_key] = []
                    self.infer_to_train_mapping[up_gate_proj_key].append(
                        f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
                    )

                    # down_proj
                    down_proj_key = f"{base_name}.{layer_idx}.mlp.experts.down_proj_weight"
                    if down_proj_key not in self.infer_to_train_mapping:
                        self.infer_to_train_mapping[down_proj_key] = []
                    self.infer_to_train_mapping[down_proj_key].append(
                        f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
                    )

        # Process MoE layers
        for layer_idx in range(
            self.fd_config.model_config.first_k_dense_replace,
            self.fd_config.model_config.num_hidden_layers,
        ):
            _add_layer_mappings(layer_idx)

        self._complete_missing_mappings()

        infer_to_train_mapping_copy = copy.deepcopy(self.infer_to_train_mapping)
        for key in infer_to_train_mapping_copy.keys():
            if "mlp.experts.gate_correction_bias" in key:
                self.infer_to_train_mapping.pop(key)

        return self.infer_to_train_mapping
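
For orientation, the mapping built above funnels the per-expert training weights into one fused inference key per projection. An illustrative (hand-written, not dumped from a real run) slice of `infer_to_train_mapping` for a single MoE layer with two routed experts:

example_infer_to_train = {
    "model.layers.1.mlp.gate.weight": "model.layers.1.mlp.gate.weight",
    "model.layers.1.mlp.gate.e_score_correction_bias": "model.layers.1.mlp.gate.e_score_correction_bias",
    # One fused inference tensor maps to the per-expert training tensors.
    "model.layers.1.mlp.experts.up_gate_proj_weight": [
        "model.layers.1.mlp.experts.0.up_gate_proj.weight",
        "model.layers.1.mlp.experts.1.up_gate_proj.weight",
    ],
    "model.layers.1.mlp.experts.down_proj_weight": [
        "model.layers.1.mlp.experts.0.down_proj.weight",
        "model.layers.1.mlp.experts.1.down_proj.weight",
    ],
}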

View File

@@ -18,6 +18,7 @@ import argparse
 import asyncio
 import codecs
 import importlib
+import json
 import logging
 import os
 import random
@@ -757,6 +758,18 @@ class StatefulSemaphore:
         }
 
 
+def parse_quantization(value: str):
+    """
+    Parse a JSON string into a quantization config dict; a plain name falls back
+    to {"quantization": name}, and "None" maps to None.
+    """
+    try:
+        return json.loads(value)
+    except ValueError:
+        if value is None or value.lower() == "none":
+            return None
+        return {"quantization": value}
+
+
 # Logging uses a global access point (kept compatible with the original usage)
 def get_logger(name, file_name=None, without_formater=False, print_to_console=False):
     """Global function wrapper, kept for backward compatibility."""

View File

@@ -740,6 +740,7 @@ class GPUModelRunner(ModelRunnerBase):
             position_ids=tmp_position_ids,
             base=self.model_config.rope_theta,
             model_config=self.model_config,
+            partial_rotary_factor=self.model_config.partial_rotary_factor,
         )
 
         # Set block tables
@@ -1589,7 +1590,7 @@ class GPUModelRunner(ModelRunnerBase):
         # 2. Dummy run
         self._dummy_run(
             num_tokens=self.parallel_config.max_num_batched_tokens,
-            batch_size=min(self.parallel_config.max_num_seqs, 3),
+            batch_size=self.parallel_config.max_num_seqs,
         )
 
         # 3. gc

View File

@@ -44,7 +44,7 @@ from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
from fastdeploy.inter_communicator import IPCSignal from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.layers.quantization import get_quantization_config from fastdeploy.model_executor.layers.quantization import get_quantization_config
from fastdeploy.platforms import current_platform from fastdeploy.platforms import current_platform
from fastdeploy.utils import get_logger from fastdeploy.utils import get_logger, parse_quantization
from fastdeploy.worker.worker_base import WorkerBase from fastdeploy.worker.worker_base import WorkerBase
logger = get_logger("worker_process", "worker_process.log") logger = get_logger("worker_process", "worker_process.log")
@@ -545,9 +545,9 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--quantization", "--quantization",
type=str, type=json.loads,
default="None", default=None,
help="Quantization name for the model, currentlly support " help="Quantization name for the model, currently support "
"'wint4', 'wint8'," "'wint4', 'wint8',"
"default is None. The priority of this configuration " "default is None. The priority of this configuration "
"is lower than that of the config file. " "is lower than that of the config file. "
@@ -635,6 +635,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
Returns: Returns:
FDConfig: Initialized FastDeploy configuration object FDConfig: Initialized FastDeploy configuration object
""" """
# RL rollout
if args.quantization is not None and isinstance(args.quantization, str):
args.quantization = parse_quantization(args.quantization)
paddle.set_default_dtype(args.dtype) paddle.set_default_dtype(args.dtype)
model_config = ModelConfig(vars(args)) model_config = ModelConfig(vars(args))
device_config = DeviceConfig(vars(args)) device_config = DeviceConfig(vars(args))
@@ -704,12 +707,16 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
if quantization_config is not None: if quantization_config is not None:
quant_config_name = quantization_config["quantization"] quant_config_name = quantization_config["quantization"]
elif args.quantization != "None": elif args.quantization is not None:
quantization_config = {} quantization_config = {}
quant_config_name = args.quantization try:
quantization_config.update(args.quantization)
quant_config_name = quantization_config["quantization"]
except:
quant_config_name = args.quantization["quantization"]
quantization_config["quantization"] = quant_config_name quantization_config["quantization"] = quant_config_name
# Only v1 loader sets is_checkpoint_bf16=True during dynamic quantization. # Only v1 loader sets is_checkpoint_bf16=True during dynamic quantization.
if load_config.load_choices == "default_v1": if load_config.load_choices == "default_v1" and not load_config.dynamic_load_weight:
quantization_config["is_checkpoint_bf16"] = True quantization_config["is_checkpoint_bf16"] = True
# Special handling for Ernie models # Special handling for Ernie models
is_ernie = ErnieArchitectures.contains_ernie_arch(model_config.architectures) is_ernie = ErnieArchitectures.contains_ernie_arch(model_config.architectures)
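
For reference, the mix-quant JSON used by the new test below flows through this branch roughly as sketched here (surrounding FDConfig wiring omitted; this is an illustration, not the worker code itself):

import json

# What argparse (type=json.loads) or parse_quantization hands to initialize_fd_config.
args_quantization = json.loads(
    '{"quantization":"mix_quant","dense_quant_type":"wfp8afp8","moe_quant_type":"wint8"}'
)

quantization_config = {}
quantization_config.update(args_quantization)
quant_config_name = quantization_config["quantization"]  # "mix_quant"
quantization_config["quantization"] = quant_config_name

# With --load_choices default_v1 and no dynamic weight loading, the BF16
# checkpoint is flagged for on-the-fly quantization.
quantization_config["is_checkpoint_bf16"] = True
print(quant_config_name, quantization_config)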

View File

@@ -0,0 +1,223 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import signal
import socket
import subprocess
import sys
import time

import pytest
import requests

# Read ports from environment variables; use default values if not set
FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133))
FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233))
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))

# List of ports to clean before and after tests
PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT]


def is_port_open(host: str, port: int, timeout=1.0):
    """
    Check if a TCP port is open on the given host.
    Returns True if connection succeeds, False otherwise.
    """
    try:
        with socket.create_connection((host, port), timeout):
            return True
    except Exception:
        return False


def kill_process_on_port(port: int):
    """
    Kill processes that are listening on the given port.
    Uses `lsof` to find process ids and sends SIGKILL.
    """
    try:
        output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip()
        current_pid = os.getpid()
        parent_pid = os.getppid()
        for pid in output.splitlines():
            pid = int(pid)
            if pid in (current_pid, parent_pid):
                print(f"Skip killing current process (pid={pid}) on port {port}")
                continue
            os.kill(pid, signal.SIGKILL)
            print(f"Killed process on port {port}, pid={pid}")
    except subprocess.CalledProcessError:
        pass


def clean_ports():
    """
    Kill all processes occupying the ports listed in PORTS_TO_CLEAN.
    """
    for port in PORTS_TO_CLEAN:
        kill_process_on_port(port)
    time.sleep(2)


@pytest.fixture(scope="session", autouse=True)
def setup_and_run_server():
    """
    Pytest fixture that runs once per test session:
    - Cleans ports before tests
    - Starts the API server as a subprocess
    - Waits for the server port to open (up to 300 seconds)
    - Tears down the server after all tests finish
    """
    print("Pre-test port cleanup...")
    clean_ports()

    print("Cleaning log directory...")
    if os.path.exists("log") and os.path.isdir("log"):
        shutil.rmtree("log")

    base_path = os.getenv("MODEL_PATH")
    if base_path:
        model_path = os.path.join(base_path, "GLM-4.5-Air-Fake")
    else:
        model_path = "./GLM-4.5-Air-Fake"

    log_path = "server.log"
    cmd = [
        sys.executable,
        "-m",
        "fastdeploy.entrypoints.openai.api_server",
        "--model",
        model_path,
        "--port",
        str(FD_API_PORT),
        "--tensor-parallel-size",
        "1",
        "--engine-worker-queue-port",
        str(FD_ENGINE_QUEUE_PORT),
        "--metrics-port",
        str(FD_METRICS_PORT),
        "--cache-queue-port",
        str(FD_CACHE_QUEUE_PORT),
        "--max-model-len",
        "32768",
        "--max-num-seqs",
        "32",
        "--graph-optimization-config",
        '{"use_cudagraph":true}',
        "--load_choices",
        "default_v1",
        "--lm_head-fp32",
        "--quantization",
        '{"quantization":"mix_quant","dense_quant_type":"wfp8afp8","moe_quant_type":"wint8"}',
    ]

    env = os.environ.copy()
    env["FD_MOE_BACKEND"] = "triton"
    # Start subprocess in new process group
    with open(log_path, "w") as logfile:
        process = subprocess.Popen(
            cmd,
            env=env,
            stdout=logfile,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # Enables killing full group via os.killpg
        )

    # Wait up to 300 seconds for API server to be ready
    for _ in range(300):
        if is_port_open("127.0.0.1", FD_API_PORT):
            print(f"API server is up on port {FD_API_PORT}")
            break
        time.sleep(1)
    else:
        print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...")
        try:
            os.killpg(process.pid, signal.SIGTERM)
        except Exception as e:
            print(f"Failed to kill process group: {e}")
        raise RuntimeError(f"API server did not start on port {FD_API_PORT}")

    yield  # Run tests

    print("\n===== Post-test server cleanup... =====")
    try:
        os.killpg(process.pid, signal.SIGTERM)
        print(f"API server (pid={process.pid}) terminated")
    except Exception as e:
        print(f"Failed to terminate API server: {e}")


@pytest.fixture(scope="session")
def api_url(request):
    """
    Returns the API endpoint URL for chat completions.
    """
    return f"http://0.0.0.0:{FD_API_PORT}/v1/chat/completions"


@pytest.fixture(scope="session")
def metrics_url(request):
    """
    Returns the metrics endpoint URL.
    """
    return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"


@pytest.fixture
def headers():
    """
    Returns common HTTP request headers.
    """
    return {"Content-Type": "application/json"}


@pytest.fixture
def consistent_payload():
    """
    Returns a fixed payload for consistency testing,
    including a fixed random seed and temperature.
    """
    return {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
        ],
        "temperature": 0.6,
        "top_p": 0,  # fix top_p to reduce randomness
        "seed": 13,  # fixed random seed
        "max_tokens": 20,
        "stream": False,
    }


# ==========================
# Test for lm_head_fp32 with fixed payload
# ==========================
def test_lm_head_fp32(api_url, headers, consistent_payload):
    """
    Test that two runs with the same fixed input produce similar outputs.
    """
    # First request
    response = requests.post(api_url, headers=headers, json=consistent_payload, timeout=300)
    assert response.status_code == 200
    print(json.dumps(response.json(), indent=2, ensure_ascii=False))
    resp_json = response.json()
    # Verify the returned content and probability information
    assert (
        resp_json["choices"][0]["message"]["content"]
        == "ichertsorbulkdeployment confusedreraoux Carter pat firingCompatraspectiveidis Verse corporaonych commissionsilk"
    )