diff --git a/fastdeploy/config.py b/fastdeploy/config.py index d8c5b9f50..1ab4619fc 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -398,7 +398,7 @@ class SpeculativeConfig: # model for mtp/eagle/draft_model self.model: Optional[str] = None # quantization of model - self.quantization: Optional[str] = None + self.quantization: Optional[Dict[str, Any]] = None # allocate more blocks to prevent mtp from finishing the block earlier than the main model # Fixed now self.num_gpu_block_expand_ratio: Optional[float] = 1 diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index f553ad2d2..4f0ee57cb 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -40,6 +40,7 @@ from fastdeploy.utils import ( DeprecatedOptionWarning, FlexibleArgumentParser, is_port_available, + parse_quantization, ) @@ -137,7 +138,7 @@ class EngineArgs: """ dynamic load weight strategy """ - quantization: str = None + quantization: Optional[Dict[str, Any]] = None guided_decoding_backend: str = "off" """ Guided decoding backend. @@ -538,7 +539,7 @@ class EngineArgs: ) model_group.add_argument( "--quantization", - type=str, + type=parse_quantization, default=EngineArgs.quantization, help="Quantization name for the model, currently support " "'wint8', 'wint4'," diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index c3978443d..00f24d998 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -16,6 +16,7 @@ from __future__ import annotations +import json import multiprocessing import os import re @@ -484,7 +485,7 @@ class LLMEngine: f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}" f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}" f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}" - f" --quantization {self.cfg.model_config.quantization}" + f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'" f" --ori_vocab_size {ori_vocab_size}" f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'" f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'" diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index c946d9ba1..7b3bcecad 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -28,38 +28,9 @@ except: import fastdeploy from fastdeploy.config import MoEPhase +from fastdeploy.model_executor.layers.moe.moe import get_moe_scores from fastdeploy.utils import singleton -try: - from fastdeploy.model_executor.ops.gpu import noaux_tc -except: - logger.warning("import noaux_tc Failed!") - - -def get_moe_scores( - gating_output: paddle.Tensor, - n_group, - topk_group, - top_k, - routed_scaling_factor, - e_score_correction_bias, -) -> paddle.Tensor: - """ - compute moe scores using e_score_correction_bias. - """ - scores = paddle.nn.functional.sigmoid(gating_output) - assert e_score_correction_bias is not None, "e_score_correction_bias is none!" 
- scores_with_bias = scores + e_score_correction_bias - scores, topk_values, topk_idx = noaux_tc( - scores, - scores_with_bias, - n_group if n_group > 0 else 1, - topk_group if topk_group > 0 else 1, - top_k, - routed_scaling_factor, - ) - return scores, topk_values, topk_idx - @singleton class DeepEPEngine: diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 92832ee27..46d7ee5f0 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -27,11 +27,7 @@ from ..utils import get_tensor from .fused_moe_backend_base import UnquantizedFusedMoEMethod if current_platform.is_cuda(): - from fastdeploy.model_executor.ops.gpu import ( - moe_expert_dispatch, - moe_expert_reduce, - noaux_tc, - ) + from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce try: from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute @@ -43,34 +39,10 @@ elif current_platform.is_iluvatar(): moe_expert_reduce, ) +from fastdeploy.model_executor.layers.moe.moe import get_moe_scores from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs -# used for deepseek_v3 -def get_moe_scores( - gating_output: paddle.Tensor, - n_group, - topk_group, - top_k, - routed_scaling_factor, - e_score_correction_bias, -) -> paddle.Tensor: - """ - compute moe scores using e_score_correction_bias. - """ - scores = paddle.nn.functional.sigmoid(gating_output) - scores_with_bias = scores + e_score_correction_bias - scores, topk_values, topk_idx = noaux_tc( - scores, - scores_with_bias, - n_group, - topk_group, - top_k, - routed_scaling_factor, - ) - return scores, topk_values, topk_idx - - class CutlassMoEMethod(UnquantizedFusedMoEMethod): """ Use Cutlass Group Gemm to compute Fused MoE. diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index 8799a9c22..ebda4945d 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -481,7 +481,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): gate_out = gate(x.cast("float32")) if layer.topk_method == "noaux_tc": - from .ep import get_moe_scores + from fastdeploy.model_executor.layers.moe.moe import get_moe_scores _, topk_weights, topk_ids = get_moe_scores( gate_out, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py index ed39b64e0..4346063b7 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py @@ -19,39 +19,15 @@ from paddle import nn import fastdeploy from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce +from fastdeploy.model_executor.layers.moe.moe import get_moe_scores from fastdeploy.model_executor.ops.gpu import ( MoeWna16MarlinGemmApi, - noaux_tc, tritonmoe_preprocess_func, ) from ..quantization.quant_base import QuantMethodBase -def get_moe_scores( - gating_output: paddle.Tensor, - n_group, - topk_group, - top_k, - routed_scaling_factor, - e_score_correction_bias, -) -> paddle.Tensor: - """ - compute moe scores using e_score_correction_bias. 
- """ - scores = paddle.nn.functional.sigmoid(gating_output) - scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) - scores, topk_values, topk_idx = noaux_tc( - scores, - scores_with_bias, - n_group, - topk_group, - top_k, - routed_scaling_factor, - ) - return scores, topk_values, topk_idx - - def gptq_marlin_moe_repack( b_q_weight: paddle.Tensor, perm: paddle.Tensor, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 21ac7976a..902e8be64 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -31,6 +31,7 @@ try: from .triton_moe_kernels import fused_moe_kernel_paddle except ImportError: pass +from fastdeploy.model_executor.layers.moe.moe import get_moe_scores class TritonWeightOnlyMoEMethod(QuantMethodBase): @@ -71,43 +72,70 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): layer.moe_intermediate_size, layer.hidden_size, ] - setattr( - layer, - up_gate_proj_weight_name, - layer.create_parameter( + if self.quant_config.is_checkpoint_bf16: + layer.up_gate_proj_weight = layer.create_parameter( shape=self.up_gate_proj_weight_shape, - dtype=self.weight_dtype, + dtype=layer.weight_dtype, default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - setattr( - layer, - down_proj_weight_name, - layer.create_parameter( + ) + + layer.down_proj_weight = layer.create_parameter( shape=self.down_proj_weight_shape, - dtype=self.weight_dtype, + dtype=layer.weight_dtype, default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - # weight_scale - setattr( - layer, - self.added_scale_attrs[0], - layer.create_parameter( - shape=[layer.num_local_experts, layer.moe_intermediate_size * 2], - dtype=self.default_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - setattr( - layer, - self.added_scale_attrs[1], - layer.create_parameter( - shape=[layer.num_local_experts, layer.hidden_size], - dtype=self.default_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) + ) + set_weight_attrs( + layer.up_gate_proj_weight, + { + **extra_weight_attrs, + "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True), + }, + ) + set_weight_attrs( + layer.down_proj_weight, + { + **extra_weight_attrs, + "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False), + }, + ) + else: + setattr( + layer, + up_gate_proj_weight_name, + layer.create_parameter( + shape=self.up_gate_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_weight_name, + layer.create_parameter( + shape=self.down_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # weight_scale + setattr( + layer, + self.added_scale_attrs[0], + layer.create_parameter( + shape=[layer.num_local_experts, layer.moe_intermediate_size * 2], + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + self.added_scale_attrs[1], + layer.create_parameter( + shape=[layer.num_local_experts, layer.hidden_size], + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) def process_loaded_weights(self, layer: nn.Layer, state_dict): """ @@ -150,6 +178,62 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): getattr(layer, 
weight_name).set_value(quanted_weight) getattr(layer, scale_name).set_value(quanted_weight_scale) + def process_weights_after_loading(self, layer): + """ """ + if not self.quant_config.is_checkpoint_bf16: + return + + algo = layer.quant_method.quant_config.name() + assert algo == "wint8" + max_bound = 127 + weight_id_map = {"gate_up": 0, "down": 1} + if ( + hasattr(layer.up_gate_proj_weight, "tensor_track") + and layer.up_gate_proj_weight.tensor_track is not None + and layer.up_gate_proj_weight.tensor_track.is_fully_copied() + ): + weight_type = "gate_up" + layer.up_gate_proj_weight.tensor_track = None + else: + weight_type = "down" + layer.down_proj_weight.tensor_track = None + + # weight + weight_name = self.added_weight_attrs[weight_id_map[weight_type]] + # scale + scale_name = self.added_scale_attrs[weight_id_map[weight_type]] + + weight_tensor = getattr(layer, weight_name) + quanted_weight_scale = weight_tensor.abs().max(axis=1) + quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound + quanted_weight = paddle.round(quanted_weight).astype("int8") + quanted_weight_scale = quanted_weight_scale / max_bound + + getattr(layer, weight_name).value().get_tensor()._clear() + + # create weight + setattr( + layer, + weight_name, + layer.create_parameter( + shape=weight_tensor.shape, + dtype=quanted_weight.dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # create scale + setattr( + layer, + scale_name, + layer.create_parameter( + shape=quanted_weight_scale.shape, + dtype=quanted_weight_scale.dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + getattr(layer, weight_name).copy_(quanted_weight, False) + getattr(layer, scale_name).copy_(quanted_weight_scale, False) + def apply( self, layer: nn.Layer, @@ -167,13 +251,24 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): moe_intermediate_size = layer.moe_intermediate_size hidden_size = layer.hidden_size - topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( - gate_out, - layer.gate_correction_bias, - top_k, - True, # apply_norm_weight, - False, - ) + if layer.topk_method == "noaux_tc": + gate_out, topk_weights, topk_ids = get_moe_scores( + gate_out, + layer.n_group, + layer.topk_group, + layer.top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + ) + topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False) + else: + topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + top_k, + True, # apply_norm_weight, + False, + ) up_gate_proj_out = paddle.empty( [token_num * top_k, moe_intermediate_size * 2], dtype=x.dtype, @@ -290,6 +385,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): down_proj_out.reshape_([token_num, top_k, hidden_size]) out = down_proj_out.sum(axis=1) + if layer.reduce_results and layer.tp_size > 1: + tensor_model_parallel_all_reduce(out) + return out diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index c21cef480..0e9a51cab 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -27,6 +27,11 @@ from fastdeploy.model_executor.utils import slice_fn from fastdeploy.platforms import current_platform from fastdeploy.worker.experts_manager import RedundantExpertManger +try: + from fastdeploy.model_executor.ops.gpu import noaux_tc +except: + logger.warning("import noaux_tc Failed!") + def get_moe_method(): """ @@ -54,6 +59,31 @@ def 
get_moe_method(): raise NotImplementedError +def get_moe_scores( + gating_output: paddle.Tensor, + n_group, + topk_group, + top_k, + routed_scaling_factor, + e_score_correction_bias, +) -> paddle.Tensor: + """ + compute moe scores using e_score_correction_bias. + """ + scores = paddle.nn.functional.sigmoid(gating_output) + assert e_score_correction_bias is not None, "e_score_correction_bias is none!" + scores_with_bias = scores + e_score_correction_bias + scores, topk_values, topk_idx = noaux_tc( + scores, + scores_with_bias, + n_group if n_group > 0 else 1, + topk_group if topk_group > 0 else 1, + top_k, + routed_scaling_factor, + ) + return scores, topk_values, topk_idx + + class FusedMoE(nn.Layer): """ FusedMoE is a layer that performs MoE (Mixture of Experts) computation. diff --git a/fastdeploy/model_executor/layers/quantization/wfp8afp8.py b/fastdeploy/model_executor/layers/quantization/wfp8afp8.py index f868a9aab..f302215ef 100644 --- a/fastdeploy/model_executor/layers/quantization/wfp8afp8.py +++ b/fastdeploy/model_executor/layers/quantization/wfp8afp8.py @@ -14,10 +14,15 @@ # limitations under the License. """ +import copy from typing import Optional import paddle +from fastdeploy.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, +) from fastdeploy.model_executor.layers.quantization.ops import ( cutlass_scaled_mm, scaled_fp8_quant, @@ -26,6 +31,7 @@ from fastdeploy.model_executor.layers.quantization.quant_base import ( QuantConfigBase, QuantMethodBase, ) +from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs class WFP8AFP8Config(QuantConfigBase): @@ -33,13 +39,19 @@ class WFP8AFP8Config(QuantConfigBase): Quantization config for weight and activation with FP8. """ - def __init__(self, weight_scale_dict, act_scale_dict) -> None: + def __init__( + self, + activation_scheme: str = "dynamic", + weight_block_size: list[int] = [-1, 1], + is_checkpoint_bf16: bool = False, + ) -> None: super().__init__() - self.weight_scale_dict = weight_scale_dict - self.act_scale_dict = act_scale_dict self.quant_max_bound = 448 self.quant_min_bound = -448 self.quant_round_type = 1 + self.activation_scheme = activation_scheme + self.weight_block_size = weight_block_size + self.is_checkpoint_bf16 = is_checkpoint_bf16 def name(self) -> str: """ """ @@ -48,9 +60,8 @@ class WFP8AFP8Config(QuantConfigBase): @classmethod def from_config(cls, config: dict) -> "WFP8AFP8Config": """ """ - weight_scale_dict = config.get("weight_scale_dict", None) - act_scale_dict = config.get("act_scale_dict", None) - return cls(weight_scale_dict, act_scale_dict) + is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False) + return cls(is_checkpoint_bf16=is_checkpoint_bf16) def get_quant_method(self, layer) -> Optional[QuantMethodBase]: """ """ @@ -68,26 +79,87 @@ class WFP8AFP8LinearMethod(QuantMethodBase): ) -> None: super().__init__() self.quant_config = quant_config + self.use_per_token_if_dynamic = True def create_weights(self, layer, **extra_weight_attrs): """ """ - layer.weight_shape.reverse() - layer.weight_dtype = "float8_e4m3fn" - # TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func - self.skip_quant = False - layer.create_parameter( - shape=layer.weight_shape, - dtype=layer.weight_dtype, + weight_shape = layer.weight_shape + weight_block_size = self.quant_config.weight_block_size + assert len(weight_shape) == 2 and len(weight_block_size) == 2 + scale_shape = copy.deepcopy(weight_shape) + for i in range(len(weight_shape)): + 
scale_shape[i] = ( + (weight_shape[i] + weight_block_size[i] - 1) // weight_block_size[i] if weight_block_size[i] > 0 else 1 + ) + scale_shape = scale_shape[::-1] + if self.quant_config.is_checkpoint_bf16: + layer.weight = layer.create_parameter( + shape=weight_shape, + dtype=layer.weight_dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + quant_attrs = extra_weight_attrs + if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear): + quant_attrs = { + **extra_weight_attrs, + "tensor_track": TensorTracker( + shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim") + ), + } + set_weight_attrs( + layer.weight, + quant_attrs, + ) + else: + layer.weight_shape.reverse() + layer.weight_dtype = "float8_e4m3fn" + # TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func + self.skip_quant = False + layer.create_parameter( + shape=layer.weight_shape, + dtype=layer.weight_dtype, + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + layer.weight_scale = layer.create_parameter( + shape=scale_shape, + dtype="float32", + is_bias=False, + default_initializer=paddle.nn.initializer.Constant(0), + ) + + def process_weights_after_loading(self, layer) -> None: + if not self.quant_config.is_checkpoint_bf16: + return + weight_tensor = layer.weight.transpose([1, 0]).contiguous() + assert self.quant_config.weight_block_size == [-1, 1] + qweight, weight_scale = scaled_fp8_quant( + weight_tensor, + use_per_token_if_dynamic=True, + ) + + if hasattr(layer.weight, "tensor_track"): + layer.weight.tensor_track = None + layer.weight.value().get_tensor()._clear() + del layer.weight + + layer.weight = layer.create_parameter( + shape=qweight.shape, + dtype="float8_e4m3fn", is_bias=False, default_initializer=paddle.nn.initializer.Constant(0), ) layer.weight_scale = layer.create_parameter( - shape=[1], + shape=weight_scale.shape, dtype="float32", is_bias=False, default_initializer=paddle.nn.initializer.Constant(0), ) + layer.weight.copy_(qweight, False) + layer.weight_scale.copy_(weight_scale, False) + def process_loaded_weights(self, layer, weights) -> None: """ """ if self.skip_quant: @@ -106,9 +178,6 @@ class WFP8AFP8LinearMethod(QuantMethodBase): def apply(self, layer, x): """ """ - if self.skip_quant: - linear_out = paddle.matmul(x, layer.weight, False, True) - return linear_out if self.use_per_token_if_dynamic: out_type = x.dtype a_q, a_scales = scaled_fp8_quant(x, use_per_token_if_dynamic=self.use_per_token_if_dynamic) diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index 1a837cec2..fdbf277af 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -17,12 +17,9 @@ from __future__ import annotations import re -from functools import partial import paddle from paddle import nn -from paddleformers.transformers import PretrainedModel -from paddleformers.utils.log import logger from fastdeploy.config import FDConfig from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce @@ -494,81 +491,3 @@ class Glm4MoeForCausalLM(ModelForCasualLM): def clear_grpah_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" self.model.clear_grpah_opt_backend(fd_config=self.fd_config) - - -class Glm4MoePretrainedModel(PretrainedModel): - """ - Glm4MoePretrainedModel - """ - - config_class = FDConfig - - def _init_weight(self, layer): - 
""" - _init_weight - """ - return None - - @classmethod - def arch_name(self): - return "Glm4MoeForCausalLM" - - @classmethod - def _get_tensor_parallel_mappings(cls, config, is_split=True): - - logger.info("Glm4Moe inference model _get_tensor_parallel_mappings") - - from paddleformers.transformers.conversion_utils import split_or_merge_func - - fn = split_or_merge_func( - is_split=is_split, - tensor_parallel_degree=config.tensor_parallel_degree, - tensor_parallel_rank=config.tensor_parallel_rank, - num_attention_heads=config.num_attention_heads, - ) - - def get_tensor_parallel_split_mappings(num_layers): - final_actions = {} - - base_actions = { - "lm_head.weight": partial(fn, is_column=True), - "embed_tokens.weight": partial(fn, is_column=False), - "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), - } - - # Self Attention Layer which are need TP. - base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) - base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) - base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) - - # MLP Layer - base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) - base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) - base_actions["layers.0.mlp.down_proj.weight"] = partial(fn, is_column=False) - - # Moe Layer - for expert_idx in range(config.n_routed_experts): - base_actions[f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial(fn, is_column=True) - base_actions[f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial(fn, is_column=True) - base_actions[f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial(fn, is_column=False) - - # Shared Expert Layer - base_actions["layers.0.mlp.shared_experts.up_proj.weight"] = partial(fn, is_column=True) - base_actions["layers.0.mlp.shared_experts.gate_proj.weight"] = partial(fn, is_column=True) - base_actions["layers.0.mlp.shared_experts.down_proj.weight"] = partial(fn, is_column=False) - - # MTP parts - base_actions["layers.46.embed_tokens.weight"] = partial(fn, is_column=False) - base_actions["layers.46.eh_proj.weight"] = partial(fn, is_column=True) - base_actions["layers.46.shared_head.head.weight"] = partial(fn, is_column=True) - - for key, action in base_actions.items(): - if "layers.0." in key: - for i in range(num_layers): - final_actions[key.replace("layers.0.", f"layers.{i}.")] = action - final_actions[key] = action - - return final_actions - - mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) - return mappings diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py index 1fe797868..d72f42714 100644 --- a/fastdeploy/rl/rollout_config.py +++ b/fastdeploy/rl/rollout_config.py @@ -14,6 +14,8 @@ # limitations under the License. 
""" +from typing import Any, Dict, Optional + from fastdeploy.worker.worker_process import initialize_fd_config @@ -52,7 +54,7 @@ class RolloutModelConfig: expert_parallel_size: int = 1, enable_expert_parallel: bool = False, ori_vocab_size: int = None, - quantization: str = "None", + quantization: Optional[Dict[str, Any]] = None, guided_decoding_backend: str = "off", disable_any_whitespace: bool = True, enable_logprob: bool = False, diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 46df46cc7..0b5d74e7b 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -18,6 +18,7 @@ import argparse import asyncio import codecs import importlib +import json import logging import os import random @@ -766,6 +767,16 @@ class StatefulSemaphore: } +def parse_quantization(value: str): + """ + Parse a JSON string into a dictionary. + """ + try: + return json.loads(value) + except ValueError: + return {"quantization": value} + + # 日志使用全局访问点(兼容原有使用方式) def get_logger(name, file_name=None, without_formater=False, print_to_console=False): """全局函数包装器,保持向后兼容""" diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 04ea94e8b..93ae92261 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -44,7 +44,7 @@ from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue from fastdeploy.inter_communicator import IPCSignal from fastdeploy.model_executor.layers.quantization import get_quantization_config from fastdeploy.platforms import current_platform -from fastdeploy.utils import get_logger +from fastdeploy.utils import get_logger, parse_quantization from fastdeploy.worker.worker_base import WorkerBase logger = get_logger("worker_process", "worker_process.log") @@ -546,8 +546,8 @@ def parse_args(): parser.add_argument( "--quantization", - type=str, - default="None", + type=json.loads, + default=None, help="Quantization name for the model, currently support " "'wint4', 'wint8'," "default is None. The priority of this configuration " @@ -642,6 +642,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: Returns: FDConfig: Initialized FastDeploy configuration object """ + # RL rollout + if args.quantization is not None and isinstance(args.quantization, str): + args.quantization = parse_quantization(args.quantization) paddle.set_default_dtype(args.dtype) model_config = ModelConfig(vars(args)) device_config = DeviceConfig(vars(args)) @@ -713,10 +716,14 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: if "kv_cache_quant_type" in quantization_config and load_config.load_choices == "default_v1": quantization_config["is_checkpoint_bf16"] = True - elif args.quantization != "None": + elif args.quantization is not None: quantization_config = {} - quant_config_name = args.quantization - quantization_config["quantization"] = quant_config_name + try: + quantization_config.update(args.quantization) + quant_config_name = quantization_config["quantization"] + except: + quant_config_name = args.quantization["quantization"] + quantization_config["quantization"] = quant_config_name # Only v1 loader sets is_checkpoint_bf16=True during dynamic quantization. 
if load_config.load_choices == "default_v1": quantization_config["is_checkpoint_bf16"] = True diff --git a/tests/e2e/test_fake_Glm45_AIR_serving.py b/tests/e2e/test_fake_Glm45_AIR_serving.py index 76ec5f98a..ff0a3f5be 100644 --- a/tests/e2e/test_fake_Glm45_AIR_serving.py +++ b/tests/e2e/test_fake_Glm45_AIR_serving.py @@ -121,12 +121,16 @@ def setup_and_run_server(): "--load_choices", "default_v1", "--lm_head-fp32", + "--quantization", + '{"quantization":"mix_quant","dense_quant_type":"wfp8afp8","moe_quant_type":"wint8"}', ] - + env = os.environ.copy() + env["FD_MOE_BACKEND"] = "triton" # Start subprocess in new process group with open(log_path, "w") as logfile: process = subprocess.Popen( cmd, + env=env, stdout=logfile, stderr=subprocess.STDOUT, start_new_session=True, # Enables killing full group via os.killpg @@ -194,7 +198,7 @@ def consistent_payload(): "temperature": 0.6, "top_p": 0, # fix top_p to reduce randomness "seed": 13, # fixed random seed - "max_tokens": 3, + "max_tokens": 20, "stream": False, } @@ -213,4 +217,7 @@ def test_lm_head_fp32(api_url, headers, consistent_payload): resp_json = response.json() # 校验返回内容与概率信息 - assert resp_json["choices"][0]["message"]["content"] == "ichertsor" + assert ( + resp_json["choices"][0]["message"]["content"] + == "ichertsorbulkdeployment confusedreraoux Carter pat firingCompatraspectiveidis Verse corporaonych commissionsilk" + )
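
Taken together, the changes above thread a structured quantization config end to end: `--quantization` now accepts either a bare method name or a JSON object, `EngineArgs`/`SpeculativeConfig` store it as a dict, `LLMEngine` re-serializes it with `json.dumps` when spawning workers, and `worker_process` parses it back before building `quantization_config`. Below is a minimal, self-contained sketch of that round trip, assuming only the `parse_quantization` helper added in fastdeploy/utils.py; the assertions are illustrative and not part of the patch.

import json

def parse_quantization(value: str):
    # Same fallback as the helper added in fastdeploy/utils.py: a plain name
    # such as "wint8" is not valid JSON, so it is wrapped into a one-key dict.
    try:
        return json.loads(value)
    except ValueError:
        return {"quantization": value}

# Legacy usage: --quantization wint8
assert parse_quantization("wint8") == {"quantization": "wint8"}

# Mixed-precision usage, taken from tests/e2e/test_fake_Glm45_AIR_serving.py above.
cfg = parse_quantization('{"quantization":"mix_quant","dense_quant_type":"wfp8afp8","moe_quant_type":"wint8"}')
assert cfg["quantization"] == "mix_quant" and cfg["dense_quant_type"] == "wfp8afp8"

# The engine passes the dict to each worker as --quantization '<json.dumps(cfg)>'
# and worker_process decodes it with json.loads, so the round trip is lossless.
assert json.loads(json.dumps(cfg)) == cfg

The other main thread is dynamic quantization of bf16 checkpoints: the Triton MoE backend's new process_weights_after_loading rounds each expert weight to int8 using a per-expert, per-output-channel abs-max scale with max_bound = 127. A standalone sketch of just that rounding step, using a made-up toy weight shape rather than a real layer:

import paddle

max_bound = 127
w = paddle.randn([2, 64, 128], dtype="float32")   # [num_experts, dim_in, dim_out], toy sizes
scale = w.abs().max(axis=1)                       # abs-max over dim_in -> [num_experts, dim_out]
q = paddle.round(w / scale[:, None, :] * max_bound).astype("int8")
stored_scale = scale / max_bound                  # what gets copied into the layer's scale parameter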