[Feature] GLM-45-AIR Support Mixed Quantization (Dense wfp8afp8 and wint8 triton_moe_backend) (#4051)

Authored by chen on 2025-09-11 20:08:09 +08:00, committed by GitHub
parent 2056a428bd · commit 4859f40b20
15 changed files with 302 additions and 238 deletions

View File

@@ -398,7 +398,7 @@ class SpeculativeConfig:
# model for mtp/eagle/draft_model
self.model: Optional[str] = None
# quantization of model
self.quantization: Optional[str] = None
self.quantization: Optional[Dict[str, Any]] = None
# allocate more blocks to prevent mtp from finishing the block earlier than the main model
# Fixed now
self.num_gpu_block_expand_ratio: Optional[float] = 1
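Note: quantization is now stored as a parsed dict rather than a raw string. A minimal illustration of the shape of that dict, assuming the mix-quant payload exercised by the CI test at the end of this diff:

    # hypothetical parsed value of --quantization for a mixed-precision run
    # (keys taken from the CI test added at the end of this diff)
    quantization = {
        "quantization": "mix_quant",      # top-level scheme name
        "dense_quant_type": "wfp8afp8",   # dense (non-MoE) linear layers
        "moe_quant_type": "wint8",        # MoE expert weights
    }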

View File

@@ -40,6 +40,7 @@ from fastdeploy.utils import (
DeprecatedOptionWarning,
FlexibleArgumentParser,
is_port_available,
parse_quantization,
)
@@ -137,7 +138,7 @@ class EngineArgs:
"""
dynamic load weight strategy
"""
quantization: str = None
quantization: Optional[Dict[str, Any]] = None
guided_decoding_backend: str = "off"
"""
Guided decoding backend.
@@ -538,7 +539,7 @@ class EngineArgs:
)
model_group.add_argument(
"--quantization",
type=str,
type=parse_quantization,
default=EngineArgs.quantization,
help="Quantization name for the model, currently support "
"'wint8', 'wint4',"

View File

@@ -16,6 +16,7 @@
from __future__ import annotations
import json
import multiprocessing
import os
import re
@@ -484,7 +485,7 @@ class LLMEngine:
f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
f" --quantization {self.cfg.model_config.quantization}"
f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
f" --ori_vocab_size {ori_vocab_size}"
f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'"
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"

View File

@@ -28,38 +28,9 @@ except:
import fastdeploy
from fastdeploy.config import MoEPhase
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.utils import singleton
try:
from fastdeploy.model_executor.ops.gpu import noaux_tc
except:
logger.warning("import noaux_tc Failed!")
def get_moe_scores(
gating_output: paddle.Tensor,
n_group,
topk_group,
top_k,
routed_scaling_factor,
e_score_correction_bias,
) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
scores_with_bias = scores + e_score_correction_bias
scores, topk_values, topk_idx = noaux_tc(
scores,
scores_with_bias,
n_group if n_group > 0 else 1,
topk_group if topk_group > 0 else 1,
top_k,
routed_scaling_factor,
)
return scores, topk_values, topk_idx
@singleton
class DeepEPEngine:

View File

@@ -27,11 +27,7 @@ from ..utils import get_tensor
from .fused_moe_backend_base import UnquantizedFusedMoEMethod
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
moe_expert_dispatch,
moe_expert_reduce,
noaux_tc,
)
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
try:
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
@@ -43,34 +39,10 @@ elif current_platform.is_iluvatar():
moe_expert_reduce,
)
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
# used for deepseek_v3
def get_moe_scores(
gating_output: paddle.Tensor,
n_group,
topk_group,
top_k,
routed_scaling_factor,
e_score_correction_bias,
) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
scores_with_bias = scores + e_score_correction_bias
scores, topk_values, topk_idx = noaux_tc(
scores,
scores_with_bias,
n_group,
topk_group,
top_k,
routed_scaling_factor,
)
return scores, topk_values, topk_idx
class CutlassMoEMethod(UnquantizedFusedMoEMethod):
"""
Use Cutlass Group Gemm to compute Fused MoE.

View File

@@ -481,7 +481,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
gate_out = gate(x.cast("float32"))
if layer.topk_method == "noaux_tc":
from .ep import get_moe_scores
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
_, topk_weights, topk_ids = get_moe_scores(
gate_out,

View File

@@ -19,39 +19,15 @@ from paddle import nn
import fastdeploy
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.ops.gpu import (
MoeWna16MarlinGemmApi,
noaux_tc,
tritonmoe_preprocess_func,
)
from ..quantization.quant_base import QuantMethodBase
def get_moe_scores(
gating_output: paddle.Tensor,
n_group,
topk_group,
top_k,
routed_scaling_factor,
e_score_correction_bias,
) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
scores, topk_values, topk_idx = noaux_tc(
scores,
scores_with_bias,
n_group,
topk_group,
top_k,
routed_scaling_factor,
)
return scores, topk_values, topk_idx
def gptq_marlin_moe_repack(
b_q_weight: paddle.Tensor,
perm: paddle.Tensor,

View File

@@ -31,6 +31,7 @@ try:
from .triton_moe_kernels import fused_moe_kernel_paddle
except ImportError:
pass
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
class TritonWeightOnlyMoEMethod(QuantMethodBase):
@@ -71,43 +72,70 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
layer.moe_intermediate_size,
layer.hidden_size,
]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
if self.quant_config.is_checkpoint_bf16:
layer.up_gate_proj_weight = layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=self.weight_dtype,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
)
layer.down_proj_weight = layer.create_parameter(
shape=self.down_proj_weight_shape,
dtype=self.weight_dtype,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
)
set_weight_attrs(
layer.up_gate_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
},
)
set_weight_attrs(
layer.down_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
},
)
else:
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.down_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
@@ -150,6 +178,62 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
getattr(layer, weight_name).set_value(quanted_weight)
getattr(layer, scale_name).set_value(quanted_weight_scale)
def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return
algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
max_bound = 127
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
layer.up_gate_proj_weight.tensor_track = None
else:
weight_type = "down"
layer.down_proj_weight.tensor_track = None
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
weight_tensor = getattr(layer, weight_name)
quanted_weight_scale = weight_tensor.abs().max(axis=1)
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
quanted_weight = paddle.round(quanted_weight).astype("int8")
quanted_weight_scale = quanted_weight_scale / max_bound
getattr(layer, weight_name).value().get_tensor()._clear()
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_tensor.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(quanted_weight, False)
getattr(layer, scale_name).copy_(quanted_weight_scale, False)
def apply(
self,
layer: nn.Layer,
@@ -167,13 +251,24 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
top_k,
True, # apply_norm_weight,
False,
)
if layer.topk_method == "noaux_tc":
gate_out, topk_weights, topk_ids = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
)
topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
else:
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
top_k,
True, # apply_norm_weight,
False,
)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
@@ -290,6 +385,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
if layer.reduce_results and layer.tp_size > 1:
tensor_model_parallel_all_reduce(out)
return out
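Note (sketch): process_weights_after_loading above converts the bf16 expert weights to wint8 with a per-channel abs-max scale. A standalone restatement of that math, assuming an expert weight tensor of shape [num_local_experts, K, N]:

    import paddle

    def absmax_wint8(weight: paddle.Tensor, max_bound: int = 127):
        # one scale per expert and per output channel, taken along the K axis
        scale = weight.abs().max(axis=1)
        qweight = paddle.round(weight / scale[:, None, :] * max_bound).astype("int8")
        return qweight, scale / max_bound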

View File

@@ -27,6 +27,11 @@ from fastdeploy.model_executor.utils import slice_fn
from fastdeploy.platforms import current_platform
from fastdeploy.worker.experts_manager import RedundantExpertManger
try:
from fastdeploy.model_executor.ops.gpu import noaux_tc
except:
logger.warning("import noaux_tc Failed!")
def get_moe_method():
"""
@@ -54,6 +59,31 @@ def get_moe_method():
raise NotImplementedError
def get_moe_scores(
gating_output: paddle.Tensor,
n_group,
topk_group,
top_k,
routed_scaling_factor,
e_score_correction_bias,
) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
scores_with_bias = scores + e_score_correction_bias
scores, topk_values, topk_idx = noaux_tc(
scores,
scores_with_bias,
n_group if n_group > 0 else 1,
topk_group if topk_group > 0 else 1,
top_k,
routed_scaling_factor,
)
return scores, topk_values, topk_idx
class FusedMoE(nn.Layer):
"""
FusedMoE is a layer that performs MoE (Mixture of Experts) computation.

View File

@@ -14,10 +14,15 @@
# limitations under the License.
"""
import copy
from typing import Optional
import paddle
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
)
from fastdeploy.model_executor.layers.quantization.ops import (
cutlass_scaled_mm,
scaled_fp8_quant,
@@ -26,6 +31,7 @@ from fastdeploy.model_executor.layers.quantization.quant_base import (
QuantConfigBase,
QuantMethodBase,
)
from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
class WFP8AFP8Config(QuantConfigBase):
@@ -33,13 +39,19 @@ class WFP8AFP8Config(QuantConfigBase):
Quantization config for weight and activation with FP8.
"""
def __init__(self, weight_scale_dict, act_scale_dict) -> None:
def __init__(
self,
activation_scheme: str = "dynamic",
weight_block_size: list[int] = [-1, 1],
is_checkpoint_bf16: bool = False,
) -> None:
super().__init__()
self.weight_scale_dict = weight_scale_dict
self.act_scale_dict = act_scale_dict
self.quant_max_bound = 448
self.quant_min_bound = -448
self.quant_round_type = 1
self.activation_scheme = activation_scheme
self.weight_block_size = weight_block_size
self.is_checkpoint_bf16 = is_checkpoint_bf16
def name(self) -> str:
""" """
@@ -48,9 +60,8 @@ class WFP8AFP8Config(QuantConfigBase):
@classmethod
def from_config(cls, config: dict) -> "WFP8AFP8Config":
""" """
weight_scale_dict = config.get("weight_scale_dict", None)
act_scale_dict = config.get("act_scale_dict", None)
return cls(weight_scale_dict, act_scale_dict)
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
return cls(is_checkpoint_bf16=is_checkpoint_bf16)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
""" """
@@ -68,26 +79,87 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
) -> None:
super().__init__()
self.quant_config = quant_config
self.use_per_token_if_dynamic = True
def create_weights(self, layer, **extra_weight_attrs):
""" """
layer.weight_shape.reverse()
layer.weight_dtype = "float8_e4m3fn"
# TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
self.skip_quant = False
layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
weight_shape = layer.weight_shape
weight_block_size = self.quant_config.weight_block_size
assert len(weight_shape) == 2 and len(weight_block_size) == 2
scale_shape = copy.deepcopy(weight_shape)
for i in range(len(weight_shape)):
scale_shape[i] = (
(weight_shape[i] + weight_block_size[i] - 1) // weight_block_size[i] if weight_block_size[i] > 0 else 1
)
scale_shape = scale_shape[::-1]
if self.quant_config.is_checkpoint_bf16:
layer.weight = layer.create_parameter(
shape=weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
quant_attrs = extra_weight_attrs
if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
quant_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(
shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim")
),
}
set_weight_attrs(
layer.weight,
quant_attrs,
)
else:
layer.weight_shape.reverse()
layer.weight_dtype = "float8_e4m3fn"
# TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
self.skip_quant = False
layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale = layer.create_parameter(
shape=scale_shape,
dtype="float32",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
def process_weights_after_loading(self, layer) -> None:
if not self.quant_config.is_checkpoint_bf16:
return
weight_tensor = layer.weight.transpose([1, 0]).contiguous()
assert self.quant_config.weight_block_size == [-1, 1]
qweight, weight_scale = scaled_fp8_quant(
weight_tensor,
use_per_token_if_dynamic=True,
)
if hasattr(layer.weight, "tensor_track"):
layer.weight.tensor_track = None
layer.weight.value().get_tensor()._clear()
del layer.weight
layer.weight = layer.create_parameter(
shape=qweight.shape,
dtype="float8_e4m3fn",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale = layer.create_parameter(
shape=[1],
shape=weight_scale.shape,
dtype="float32",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight.copy_(qweight, False)
layer.weight_scale.copy_(weight_scale, False)
def process_loaded_weights(self, layer, weights) -> None:
""" """
if self.skip_quant:
@@ -106,9 +178,6 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
def apply(self, layer, x):
""" """
if self.skip_quant:
linear_out = paddle.matmul(x, layer.weight, False, True)
return linear_out
if self.use_per_token_if_dynamic:
out_type = x.dtype
a_q, a_scales = scaled_fp8_quant(x, use_per_token_if_dynamic=self.use_per_token_if_dynamic)
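Note (sketch): the scale_shape loop in create_weights maps each weight dimension to a count of block scales, with a block size of -1 collapsing that dimension to a single scale. A small restatement under the default weight_block_size = [-1, 1], assuming the usual [in_features, out_features] weight layout:

    def scale_shape_for(weight_shape, weight_block_size):
        shape = [
            (dim + blk - 1) // blk if blk > 0 else 1
            for dim, blk in zip(weight_shape, weight_block_size)
        ]
        return shape[::-1]

    # e.g. a [4096, 8192] weight with weight_block_size [-1, 1] gets one
    # fp32 scale per output channel: scale_shape_for([4096, 8192], [-1, 1]) == [8192, 1]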

View File

@@ -17,12 +17,9 @@
from __future__ import annotations
import re
from functools import partial
import paddle
from paddle import nn
from paddleformers.transformers import PretrainedModel
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
@@ -494,81 +491,3 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
def clear_grpah_opt_backend(self):
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
class Glm4MoePretrainedModel(PretrainedModel):
"""
Glm4MoePretrainedModel
"""
config_class = FDConfig
def _init_weight(self, layer):
"""
_init_weight
"""
return None
@classmethod
def arch_name(self):
return "Glm4MoeForCausalLM"
@classmethod
def _get_tensor_parallel_mappings(cls, config, is_split=True):
logger.info("Glm4Moe inference model _get_tensor_parallel_mappings")
from paddleformers.transformers.conversion_utils import split_or_merge_func
fn = split_or_merge_func(
is_split=is_split,
tensor_parallel_degree=config.tensor_parallel_degree,
tensor_parallel_rank=config.tensor_parallel_rank,
num_attention_heads=config.num_attention_heads,
)
def get_tensor_parallel_split_mappings(num_layers):
final_actions = {}
base_actions = {
"lm_head.weight": partial(fn, is_column=True),
"embed_tokens.weight": partial(fn, is_column=False),
"layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
}
# Self Attention Layer which are need TP.
base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
# MLP Layer
base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.mlp.down_proj.weight"] = partial(fn, is_column=False)
# Moe Layer
for expert_idx in range(config.n_routed_experts):
base_actions[f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial(fn, is_column=True)
base_actions[f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial(fn, is_column=True)
base_actions[f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial(fn, is_column=False)
# Shared Expert Layer
base_actions["layers.0.mlp.shared_experts.up_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.mlp.shared_experts.gate_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.mlp.shared_experts.down_proj.weight"] = partial(fn, is_column=False)
# MTP parts
base_actions["layers.46.embed_tokens.weight"] = partial(fn, is_column=False)
base_actions["layers.46.eh_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.46.shared_head.head.weight"] = partial(fn, is_column=True)
for key, action in base_actions.items():
if "layers.0." in key:
for i in range(num_layers):
final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
final_actions[key] = action
return final_actions
mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
return mappings

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Any, Dict, Optional
from fastdeploy.worker.worker_process import initialize_fd_config
@@ -52,7 +54,7 @@ class RolloutModelConfig:
expert_parallel_size: int = 1,
enable_expert_parallel: bool = False,
ori_vocab_size: int = None,
quantization: str = "None",
quantization: Optional[Dict[str, Any]] = None,
guided_decoding_backend: str = "off",
disable_any_whitespace: bool = True,
enable_logprob: bool = False,

View File

@@ -18,6 +18,7 @@ import argparse
import asyncio
import codecs
import importlib
import json
import logging
import os
import random
@@ -766,6 +767,16 @@ class StatefulSemaphore:
}
def parse_quantization(value: str):
"""
Parse a JSON string into a dictionary.
"""
try:
return json.loads(value)
except ValueError:
return {"quantization": value}
# Global access point for logging (kept for compatibility with existing usage)
def get_logger(name, file_name=None, without_formater=False, print_to_console=False):
"""Global function wrapper, kept for backward compatibility"""

View File

@@ -44,7 +44,7 @@ from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.model_executor.layers.quantization import get_quantization_config
from fastdeploy.platforms import current_platform
from fastdeploy.utils import get_logger
from fastdeploy.utils import get_logger, parse_quantization
from fastdeploy.worker.worker_base import WorkerBase
logger = get_logger("worker_process", "worker_process.log")
@@ -546,8 +546,8 @@ def parse_args():
parser.add_argument(
"--quantization",
type=str,
default="None",
type=json.loads,
default=None,
help="Quantization name for the model, currently support "
"'wint4', 'wint8',"
"default is None. The priority of this configuration "
@@ -642,6 +642,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
Returns:
FDConfig: Initialized FastDeploy configuration object
"""
# RL rollout
if args.quantization is not None and isinstance(args.quantization, str):
args.quantization = parse_quantization(args.quantization)
paddle.set_default_dtype(args.dtype)
model_config = ModelConfig(vars(args))
device_config = DeviceConfig(vars(args))
@@ -713,10 +716,14 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
if "kv_cache_quant_type" in quantization_config and load_config.load_choices == "default_v1":
quantization_config["is_checkpoint_bf16"] = True
elif args.quantization != "None":
elif args.quantization is not None:
quantization_config = {}
quant_config_name = args.quantization
quantization_config["quantization"] = quant_config_name
try:
quantization_config.update(args.quantization)
quant_config_name = quantization_config["quantization"]
except:
quant_config_name = args.quantization["quantization"]
quantization_config["quantization"] = quant_config_name
# Only v1 loader sets is_checkpoint_bf16=True during dynamic quantization.
if load_config.load_choices == "default_v1":
quantization_config["is_checkpoint_bf16"] = True

View File

@@ -121,12 +121,16 @@ def setup_and_run_server():
"--load_choices",
"default_v1",
"--lm_head-fp32",
"--quantization",
'{"quantization":"mix_quant","dense_quant_type":"wfp8afp8","moe_quant_type":"wint8"}',
]
env = os.environ.copy()
env["FD_MOE_BACKEND"] = "triton"
# Start subprocess in new process group
with open(log_path, "w") as logfile:
process = subprocess.Popen(
cmd,
env=env,
stdout=logfile,
stderr=subprocess.STDOUT,
start_new_session=True, # Enables killing full group via os.killpg
@@ -194,7 +198,7 @@ def consistent_payload():
"temperature": 0.6,
"top_p": 0, # fix top_p to reduce randomness
"seed": 13, # fixed random seed
"max_tokens": 3,
"max_tokens": 20,
"stream": False,
}
@@ -213,4 +217,7 @@ def test_lm_head_fp32(api_url, headers, consistent_payload):
resp_json = response.json()
# Verify the returned content and probability information
assert resp_json["choices"][0]["message"]["content"] == "ichertsor"
assert (
resp_json["choices"][0]["message"]["content"]
== "ichertsorbulkdeployment confusedreraoux Carter pat firingCompatraspectiveidis Verse corporaonych commissionsilk"
)