Mirror of https://github.com/PaddlePaddle/FastDeploy.git — synced 2025-10-05 00:33:03 +08:00
[Feature] GLM-45-AIR Support Mix Quantization (Dense wfp8afp8 and wint8 triton_moe_backend) (#4051)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
@@ -398,7 +398,7 @@ class SpeculativeConfig:
         # model for mtp/eagle/draft_model
         self.model: Optional[str] = None
         # quantization of model
-        self.quantization: Optional[str] = None
+        self.quantization: Optional[Dict[str, Any]] = None
         # allocate more blocks to prevent mtp from finishing the block earlier than the main model
         # Fixed now
         self.num_gpu_block_expand_ratio: Optional[float] = 1
@@ -40,6 +40,7 @@ from fastdeploy.utils import (
     DeprecatedOptionWarning,
     FlexibleArgumentParser,
     is_port_available,
+    parse_quantization,
 )

@@ -137,7 +138,7 @@ class EngineArgs:
     """
     dynamic load weight strategy
     """
-    quantization: str = None
+    quantization: Optional[Dict[str, Any]] = None
     guided_decoding_backend: str = "off"
     """
     Guided decoding backend.
@@ -538,7 +539,7 @@ class EngineArgs:
         )
         model_group.add_argument(
             "--quantization",
-            type=str,
+            type=parse_quantization,
             default=EngineArgs.quantization,
             help="Quantization name for the model, currently support "
             "'wint8', 'wint4',"
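With `type=parse_quantization`, the flag now accepts either a bare strategy name or a JSON object. A sketch of both invocation styles (the JSON value mirrors the test case at the bottom of this diff; forms shown are illustrative, not exhaustive):

# Plain name:
#   --quantization wint8
#       -> parse_quantization returns {"quantization": "wint8"}
# Mixed quantization as a JSON object:
#   --quantization '{"quantization":"mix_quant","dense_quant_type":"wfp8afp8","moe_quant_type":"wint8"}'
#       -> parsed as-is into a dict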
@@ -16,6 +16,7 @@

 from __future__ import annotations

+import json
 import multiprocessing
 import os
 import re
@@ -484,7 +485,7 @@ class LLMEngine:
             f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
             f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
             f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
-            f" --quantization {self.cfg.model_config.quantization}"
+            f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
             f" --ori_vocab_size {ori_vocab_size}"
             f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'"
             f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
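Since quantization is now a dict, it has to survive the shell boundary to the worker process; `json.dumps` plus single quotes keeps it intact. A sketch of the rendered fragment, assuming a hypothetical mix_quant config:

# quantization = {"quantization": "mix_quant", "moe_quant_type": "wint8"}
# renders into the worker command line as:
#   --quantization '{"quantization": "mix_quant", "moe_quant_type": "wint8"}'
# and is parsed back on the worker side (see the json.loads change in parse_args below).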
@@ -28,38 +28,9 @@ except:

 import fastdeploy
 from fastdeploy.config import MoEPhase
+from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
 from fastdeploy.utils import singleton

-try:
-    from fastdeploy.model_executor.ops.gpu import noaux_tc
-except:
-    logger.warning("import noaux_tc Failed!")
-
-
-def get_moe_scores(
-    gating_output: paddle.Tensor,
-    n_group,
-    topk_group,
-    top_k,
-    routed_scaling_factor,
-    e_score_correction_bias,
-) -> paddle.Tensor:
-    """
-    compute moe scores using e_score_correction_bias.
-    """
-    scores = paddle.nn.functional.sigmoid(gating_output)
-    assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
-    scores_with_bias = scores + e_score_correction_bias
-    scores, topk_values, topk_idx = noaux_tc(
-        scores,
-        scores_with_bias,
-        n_group if n_group > 0 else 1,
-        topk_group if topk_group > 0 else 1,
-        top_k,
-        routed_scaling_factor,
-    )
-    return scores, topk_values, topk_idx
-
-
 @singleton
 class DeepEPEngine:
@@ -27,11 +27,7 @@ from ..utils import get_tensor
 from .fused_moe_backend_base import UnquantizedFusedMoEMethod

 if current_platform.is_cuda():
-    from fastdeploy.model_executor.ops.gpu import (
-        moe_expert_dispatch,
-        moe_expert_reduce,
-        noaux_tc,
-    )
+    from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce

     try:
         from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
@@ -43,34 +39,10 @@ elif current_platform.is_iluvatar():
         moe_expert_reduce,
     )

+from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
 from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs


-# used for deepseek_v3
-def get_moe_scores(
-    gating_output: paddle.Tensor,
-    n_group,
-    topk_group,
-    top_k,
-    routed_scaling_factor,
-    e_score_correction_bias,
-) -> paddle.Tensor:
-    """
-    compute moe scores using e_score_correction_bias.
-    """
-    scores = paddle.nn.functional.sigmoid(gating_output)
-    scores_with_bias = scores + e_score_correction_bias
-    scores, topk_values, topk_idx = noaux_tc(
-        scores,
-        scores_with_bias,
-        n_group,
-        topk_group,
-        top_k,
-        routed_scaling_factor,
-    )
-    return scores, topk_values, topk_idx
-
-
 class CutlassMoEMethod(UnquantizedFusedMoEMethod):
     """
     Use Cutlass Group Gemm to compute Fused MoE.
@@ -481,7 +481,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         gate_out = gate(x.cast("float32"))

         if layer.topk_method == "noaux_tc":
-            from .ep import get_moe_scores
+            from fastdeploy.model_executor.layers.moe.moe import get_moe_scores

             _, topk_weights, topk_ids = get_moe_scores(
                 gate_out,
@@ -19,39 +19,15 @@ from paddle import nn

 import fastdeploy
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
+from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
 from fastdeploy.model_executor.ops.gpu import (
     MoeWna16MarlinGemmApi,
-    noaux_tc,
     tritonmoe_preprocess_func,
 )

 from ..quantization.quant_base import QuantMethodBase


-def get_moe_scores(
-    gating_output: paddle.Tensor,
-    n_group,
-    topk_group,
-    top_k,
-    routed_scaling_factor,
-    e_score_correction_bias,
-) -> paddle.Tensor:
-    """
-    compute moe scores using e_score_correction_bias.
-    """
-    scores = paddle.nn.functional.sigmoid(gating_output)
-    scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
-    scores, topk_values, topk_idx = noaux_tc(
-        scores,
-        scores_with_bias,
-        n_group,
-        topk_group,
-        top_k,
-        routed_scaling_factor,
-    )
-    return scores, topk_values, topk_idx
-
-
 def gptq_marlin_moe_repack(
     b_q_weight: paddle.Tensor,
     perm: paddle.Tensor,
@@ -31,6 +31,7 @@ try:
     from .triton_moe_kernels import fused_moe_kernel_paddle
 except ImportError:
     pass
+from fastdeploy.model_executor.layers.moe.moe import get_moe_scores


 class TritonWeightOnlyMoEMethod(QuantMethodBase):
@@ -71,43 +72,70 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             layer.moe_intermediate_size,
             layer.hidden_size,
         ]
-        setattr(
-            layer,
-            up_gate_proj_weight_name,
-            layer.create_parameter(
-                shape=self.up_gate_proj_weight_shape,
-                dtype=self.weight_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        setattr(
-            layer,
-            down_proj_weight_name,
-            layer.create_parameter(
-                shape=self.down_proj_weight_shape,
-                dtype=self.weight_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        # weight_scale
-        setattr(
-            layer,
-            self.added_scale_attrs[0],
-            layer.create_parameter(
-                shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
-                dtype=self.default_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        setattr(
-            layer,
-            self.added_scale_attrs[1],
-            layer.create_parameter(
-                shape=[layer.num_local_experts, layer.hidden_size],
-                dtype=self.default_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
+        if self.quant_config.is_checkpoint_bf16:
+            layer.up_gate_proj_weight = layer.create_parameter(
+                shape=self.up_gate_proj_weight_shape,
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+
+            layer.down_proj_weight = layer.create_parameter(
+                shape=self.down_proj_weight_shape,
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            set_weight_attrs(
+                layer.up_gate_proj_weight,
+                {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
+                },
+            )
+            set_weight_attrs(
+                layer.down_proj_weight,
+                {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
+                },
+            )
+        else:
+            setattr(
+                layer,
+                up_gate_proj_weight_name,
+                layer.create_parameter(
+                    shape=self.up_gate_proj_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                down_proj_weight_name,
+                layer.create_parameter(
+                    shape=self.down_proj_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            # weight_scale
+            setattr(
+                layer,
+                self.added_scale_attrs[0],
+                layer.create_parameter(
+                    shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
+                    dtype=self.default_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                self.added_scale_attrs[1],
+                layer.create_parameter(
+                    shape=[layer.num_local_experts, layer.hidden_size],
+                    dtype=self.default_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )

     def process_loaded_weights(self, layer: nn.Layer, state_dict):
         """
@@ -150,6 +178,62 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
         getattr(layer, weight_name).set_value(quanted_weight)
         getattr(layer, scale_name).set_value(quanted_weight_scale)

+    def process_weights_after_loading(self, layer):
+        """ """
+        if not self.quant_config.is_checkpoint_bf16:
+            return
+
+        algo = layer.quant_method.quant_config.name()
+        assert algo == "wint8"
+        max_bound = 127
+        weight_id_map = {"gate_up": 0, "down": 1}
+        if (
+            hasattr(layer.up_gate_proj_weight, "tensor_track")
+            and layer.up_gate_proj_weight.tensor_track is not None
+            and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
+        ):
+            weight_type = "gate_up"
+            layer.up_gate_proj_weight.tensor_track = None
+        else:
+            weight_type = "down"
+            layer.down_proj_weight.tensor_track = None
+
+        # weight
+        weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
+        # scale
+        scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
+
+        weight_tensor = getattr(layer, weight_name)
+        quanted_weight_scale = weight_tensor.abs().max(axis=1)
+        quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
+        quanted_weight = paddle.round(quanted_weight).astype("int8")
+        quanted_weight_scale = quanted_weight_scale / max_bound
+
+        getattr(layer, weight_name).value().get_tensor()._clear()
+
+        # create weight
+        setattr(
+            layer,
+            weight_name,
+            layer.create_parameter(
+                shape=weight_tensor.shape,
+                dtype=quanted_weight.dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        # create scale
+        setattr(
+            layer,
+            scale_name,
+            layer.create_parameter(
+                shape=quanted_weight_scale.shape,
+                dtype=quanted_weight_scale.dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        getattr(layer, weight_name).copy_(quanted_weight, False)
+        getattr(layer, scale_name).copy_(quanted_weight_scale, False)
+
     def apply(
         self,
         layer: nn.Layer,
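The dynamic wint8 path added above is plain per-(expert, output-channel) abs-max quantization: scale by the channel's max magnitude, round to int8, and keep scale/127 for dequantization. A minimal NumPy sketch of the same arithmetic (shapes here are hypothetical, chosen only for illustration):

import numpy as np

# Hypothetical expert weight block shaped [num_experts, in_dim, out_dim];
# axis=1 mirrors the weight_tensor.abs().max(axis=1) reduction above.
w = np.random.randn(2, 4, 8).astype("float32")
max_bound = 127
scale = np.abs(w).max(axis=1)                          # [num_experts, out_dim]
q = np.rint(w / scale[:, None, :] * max_bound).astype("int8")
dequant_scale = scale / max_bound                      # stored as the weight scale
w_restored = q.astype("float32") * dequant_scale[:, None, :]
assert np.abs(w - w_restored).max() <= dequant_scale.max()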
@@ -167,13 +251,24 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
         moe_intermediate_size = layer.moe_intermediate_size
         hidden_size = layer.hidden_size

-        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
-            gate_out,
-            layer.gate_correction_bias,
-            top_k,
-            True,  # apply_norm_weight,
-            False,
-        )
+        if layer.topk_method == "noaux_tc":
+            gate_out, topk_weights, topk_ids = get_moe_scores(
+                gate_out,
+                layer.n_group,
+                layer.topk_group,
+                layer.top_k,
+                layer.routed_scaling_factor,
+                layer.gate_correction_bias,
+            )
+            topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
+        else:
+            topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+                gate_out,
+                layer.gate_correction_bias,
+                top_k,
+                True,  # apply_norm_weight,
+                False,
+            )
         up_gate_proj_out = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
             dtype=x.dtype,
@@ -290,6 +385,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):

         down_proj_out.reshape_([token_num, top_k, hidden_size])
         out = down_proj_out.sum(axis=1)
+        if layer.reduce_results and layer.tp_size > 1:
+            tensor_model_parallel_all_reduce(out)
+
         return out

@@ -27,6 +27,11 @@ from fastdeploy.model_executor.utils import slice_fn
 from fastdeploy.platforms import current_platform
 from fastdeploy.worker.experts_manager import RedundantExpertManger

+try:
+    from fastdeploy.model_executor.ops.gpu import noaux_tc
+except:
+    logger.warning("import noaux_tc Failed!")
+

 def get_moe_method():
     """
@@ -54,6 +59,31 @@ def get_moe_method():
     raise NotImplementedError


+def get_moe_scores(
+    gating_output: paddle.Tensor,
+    n_group,
+    topk_group,
+    top_k,
+    routed_scaling_factor,
+    e_score_correction_bias,
+) -> paddle.Tensor:
+    """
+    compute moe scores using e_score_correction_bias.
+    """
+    scores = paddle.nn.functional.sigmoid(gating_output)
+    assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
+    scores_with_bias = scores + e_score_correction_bias
+    scores, topk_values, topk_idx = noaux_tc(
+        scores,
+        scores_with_bias,
+        n_group if n_group > 0 else 1,
+        topk_group if topk_group > 0 else 1,
+        top_k,
+        routed_scaling_factor,
+    )
+    return scores, topk_values, topk_idx
+
+
 class FusedMoE(nn.Layer):
     """
     FusedMoE is a layer that performs MoE (Mixture of Experts) computation.
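The `noaux_tc` op behind the now-shared `get_moe_scores` implements DeepSeek-V3-style group-limited routing with a score-correction bias: the bias steers which experts get selected, while the returned weights come from the unbiased sigmoid scores. A rough pure-NumPy illustration for a single token (assumed semantics only — the CUDA op may differ in details such as tie-breaking and weight normalization):

import numpy as np

def group_limited_topk(logits, n_group, topk_group, top_k,
                       routed_scaling_factor, bias):
    scores = 1.0 / (1.0 + np.exp(-logits))         # sigmoid gating scores
    biased = scores + bias                         # bias only influences selection
    per_group = biased.reshape(n_group, -1)
    # Rank groups by the sum of their two largest biased scores,
    # then keep only the best topk_group groups.
    group_rank = np.sort(per_group, axis=1)[:, -2:].sum(axis=1)
    keep = np.argsort(group_rank)[-topk_group:]
    masked = np.full_like(biased, -np.inf)
    width = per_group.shape[1]
    for g in keep:
        masked[g * width:(g + 1) * width] = biased[g * width:(g + 1) * width]
    topk_idx = np.argsort(masked)[-top_k:]
    # Weights use the unbiased scores; scaling/normalization details omitted.
    return scores[topk_idx] * routed_scaling_factor, topk_idx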
@@ -14,10 +14,15 @@
 # limitations under the License.
 """

+import copy
 from typing import Optional

 import paddle

+from fastdeploy.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+)
 from fastdeploy.model_executor.layers.quantization.ops import (
     cutlass_scaled_mm,
     scaled_fp8_quant,
@@ -26,6 +31,7 @@ from fastdeploy.model_executor.layers.quantization.quant_base import (
     QuantConfigBase,
     QuantMethodBase,
 )
+from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs


 class WFP8AFP8Config(QuantConfigBase):
@@ -33,13 +39,19 @@ class WFP8AFP8Config(QuantConfigBase):
     Quantization config for weight and activation with FP8.
     """

-    def __init__(self, weight_scale_dict, act_scale_dict) -> None:
+    def __init__(
+        self,
+        activation_scheme: str = "dynamic",
+        weight_block_size: list[int] = [-1, 1],
+        is_checkpoint_bf16: bool = False,
+    ) -> None:
         super().__init__()
-        self.weight_scale_dict = weight_scale_dict
-        self.act_scale_dict = act_scale_dict
         self.quant_max_bound = 448
         self.quant_min_bound = -448
         self.quant_round_type = 1
+        self.activation_scheme = activation_scheme
+        self.weight_block_size = weight_block_size
+        self.is_checkpoint_bf16 = is_checkpoint_bf16

     def name(self) -> str:
         """ """
@@ -48,9 +60,8 @@ class WFP8AFP8Config(QuantConfigBase):
     @classmethod
     def from_config(cls, config: dict) -> "WFP8AFP8Config":
         """ """
-        weight_scale_dict = config.get("weight_scale_dict", None)
-        act_scale_dict = config.get("act_scale_dict", None)
-        return cls(weight_scale_dict, act_scale_dict)
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(is_checkpoint_bf16=is_checkpoint_bf16)

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         """ """
@@ -68,26 +79,87 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
     ) -> None:
         super().__init__()
         self.quant_config = quant_config
+        self.use_per_token_if_dynamic = True

     def create_weights(self, layer, **extra_weight_attrs):
         """ """
-        layer.weight_shape.reverse()
-        layer.weight_dtype = "float8_e4m3fn"
-        # TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
-        self.skip_quant = False
-        layer.create_parameter(
-            shape=layer.weight_shape,
-            dtype=layer.weight_dtype,
+        weight_shape = layer.weight_shape
+        weight_block_size = self.quant_config.weight_block_size
+        assert len(weight_shape) == 2 and len(weight_block_size) == 2
+        scale_shape = copy.deepcopy(weight_shape)
+        for i in range(len(weight_shape)):
+            scale_shape[i] = (
+                (weight_shape[i] + weight_block_size[i] - 1) // weight_block_size[i] if weight_block_size[i] > 0 else 1
+            )
+        scale_shape = scale_shape[::-1]
+        if self.quant_config.is_checkpoint_bf16:
+            layer.weight = layer.create_parameter(
+                shape=weight_shape,
+                dtype=layer.weight_dtype,
+                is_bias=False,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            quant_attrs = extra_weight_attrs
+            if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
+                quant_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(
+                        shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim")
+                    ),
+                }
+            set_weight_attrs(
+                layer.weight,
+                quant_attrs,
+            )
+        else:
+            layer.weight_shape.reverse()
+            layer.weight_dtype = "float8_e4m3fn"
+            # TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
+            self.skip_quant = False
+            layer.create_parameter(
+                shape=layer.weight_shape,
+                dtype=layer.weight_dtype,
+                is_bias=False,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            layer.weight_scale = layer.create_parameter(
+                shape=scale_shape,
+                dtype="float32",
+                is_bias=False,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+
+    def process_weights_after_loading(self, layer) -> None:
+        if not self.quant_config.is_checkpoint_bf16:
+            return
+        weight_tensor = layer.weight.transpose([1, 0]).contiguous()
+        assert self.quant_config.weight_block_size == [-1, 1]
+        qweight, weight_scale = scaled_fp8_quant(
+            weight_tensor,
+            use_per_token_if_dynamic=True,
+        )
+
+        if hasattr(layer.weight, "tensor_track"):
+            layer.weight.tensor_track = None
+        layer.weight.value().get_tensor()._clear()
+        del layer.weight
+
+        layer.weight = layer.create_parameter(
+            shape=qweight.shape,
+            dtype="float8_e4m3fn",
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
         layer.weight_scale = layer.create_parameter(
-            shape=[1],
+            shape=weight_scale.shape,
             dtype="float32",
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
+
+        layer.weight.copy_(qweight, False)
+        layer.weight_scale.copy_(weight_scale, False)

     def process_loaded_weights(self, layer, weights) -> None:
         """ """
         if self.skip_quant:
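A worked example of the `scale_shape` arithmetic introduced above, assuming a hypothetical [in_features, out_features] = [4096, 8192] weight and the default `weight_block_size` of [-1, 1]:

weight_shape = [4096, 8192]
weight_block_size = [-1, 1]
scale_shape = [
    (s + b - 1) // b if b > 0 else 1        # ceil-divide; -1 collapses the dim
    for s, b in zip(weight_shape, weight_block_size)
][::-1]
print(scale_shape)  # [8192, 1] -> one FP32 scale per output channel

A block size of -1 spans the whole dimension, so [-1, 1] yields exactly the per-output-channel scaling that the `scaled_fp8_quant(..., use_per_token_if_dynamic=True)` call in `process_weights_after_loading` produces.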
@@ -106,9 +178,6 @@ class WFP8AFP8LinearMethod(QuantMethodBase):

     def apply(self, layer, x):
         """ """
-        if self.skip_quant:
-            linear_out = paddle.matmul(x, layer.weight, False, True)
-            return linear_out
         if self.use_per_token_if_dynamic:
             out_type = x.dtype
             a_q, a_scales = scaled_fp8_quant(x, use_per_token_if_dynamic=self.use_per_token_if_dynamic)
@@ -17,12 +17,9 @@
 from __future__ import annotations

 import re
-from functools import partial

 import paddle
 from paddle import nn
-from paddleformers.transformers import PretrainedModel
-from paddleformers.utils.log import logger

 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
@@ -494,81 +491,3 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
     def clear_grpah_opt_backend(self):
         """Clear graph optimization backend, the captured cuda graph will be cleaned"""
         self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
-
-
-class Glm4MoePretrainedModel(PretrainedModel):
-    """
-    Glm4MoePretrainedModel
-    """
-
-    config_class = FDConfig
-
-    def _init_weight(self, layer):
-        """
-        _init_weight
-        """
-        return None
-
-    @classmethod
-    def arch_name(self):
-        return "Glm4MoeForCausalLM"
-
-    @classmethod
-    def _get_tensor_parallel_mappings(cls, config, is_split=True):
-
-        logger.info("Glm4Moe inference model _get_tensor_parallel_mappings")
-
-        from paddleformers.transformers.conversion_utils import split_or_merge_func
-
-        fn = split_or_merge_func(
-            is_split=is_split,
-            tensor_parallel_degree=config.tensor_parallel_degree,
-            tensor_parallel_rank=config.tensor_parallel_rank,
-            num_attention_heads=config.num_attention_heads,
-        )
-
-        def get_tensor_parallel_split_mappings(num_layers):
-            final_actions = {}
-
-            base_actions = {
-                "lm_head.weight": partial(fn, is_column=True),
-                "embed_tokens.weight": partial(fn, is_column=False),
-                "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
-            }
-
-            # Self Attention Layer which are need TP.
-            base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
-            base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
-            base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
-
-            # MLP Layer
-            base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
-            base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)
-            base_actions["layers.0.mlp.down_proj.weight"] = partial(fn, is_column=False)
-
-            # Moe Layer
-            for expert_idx in range(config.n_routed_experts):
-                base_actions[f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial(fn, is_column=True)
-                base_actions[f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial(fn, is_column=True)
-                base_actions[f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial(fn, is_column=False)
-
-            # Shared Expert Layer
-            base_actions["layers.0.mlp.shared_experts.up_proj.weight"] = partial(fn, is_column=True)
-            base_actions["layers.0.mlp.shared_experts.gate_proj.weight"] = partial(fn, is_column=True)
-            base_actions["layers.0.mlp.shared_experts.down_proj.weight"] = partial(fn, is_column=False)
-
-            # MTP parts
-            base_actions["layers.46.embed_tokens.weight"] = partial(fn, is_column=False)
-            base_actions["layers.46.eh_proj.weight"] = partial(fn, is_column=True)
-            base_actions["layers.46.shared_head.head.weight"] = partial(fn, is_column=True)
-
-            for key, action in base_actions.items():
-                if "layers.0." in key:
-                    for i in range(num_layers):
-                        final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
-                final_actions[key] = action
-
-            return final_actions
-
-        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
-        return mappings
@@ -14,6 +14,8 @@
 # limitations under the License.
 """

+from typing import Any, Dict, Optional
+
 from fastdeploy.worker.worker_process import initialize_fd_config

@@ -52,7 +54,7 @@ class RolloutModelConfig:
         expert_parallel_size: int = 1,
         enable_expert_parallel: bool = False,
         ori_vocab_size: int = None,
-        quantization: str = "None",
+        quantization: Optional[Dict[str, Any]] = None,
         guided_decoding_backend: str = "off",
         disable_any_whitespace: bool = True,
         enable_logprob: bool = False,
@@ -18,6 +18,7 @@ import argparse
 import asyncio
 import codecs
 import importlib
+import json
 import logging
 import os
 import random
@@ -766,6 +767,16 @@ class StatefulSemaphore:
     }


+def parse_quantization(value: str):
+    """
+    Parse a JSON string into a dictionary.
+    """
+    try:
+        return json.loads(value)
+    except ValueError:
+        return {"quantization": value}
+
+
 # Global access point for logging (compatible with the original usage)
 def get_logger(name, file_name=None, without_formater=False, print_to_console=False):
     """Global function wrapper that keeps backward compatibility"""
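A quick check of the two parse paths (return values shown as comments; the JSON value matches the test at the bottom of this diff):

parse_quantization('{"quantization":"mix_quant","dense_quant_type":"wfp8afp8","moe_quant_type":"wint8"}')
# -> {'quantization': 'mix_quant', 'dense_quant_type': 'wfp8afp8', 'moe_quant_type': 'wint8'}

parse_quantization("wint8")
# -> {'quantization': 'wint8'}   (a bare name is not valid JSON, so the fallback wraps it)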
@@ -44,7 +44,7 @@ from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
 from fastdeploy.inter_communicator import IPCSignal
 from fastdeploy.model_executor.layers.quantization import get_quantization_config
 from fastdeploy.platforms import current_platform
-from fastdeploy.utils import get_logger
+from fastdeploy.utils import get_logger, parse_quantization
 from fastdeploy.worker.worker_base import WorkerBase

 logger = get_logger("worker_process", "worker_process.log")
@@ -546,8 +546,8 @@ def parse_args():

     parser.add_argument(
         "--quantization",
-        type=str,
-        default="None",
+        type=json.loads,
+        default=None,
         help="Quantization name for the model, currently support "
         "'wint4', 'wint8',"
         "default is None. The priority of this configuration "
@@ -642,6 +642,9 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     Returns:
         FDConfig: Initialized FastDeploy configuration object
     """
+    # RL rollout
+    if args.quantization is not None and isinstance(args.quantization, str):
+        args.quantization = parse_quantization(args.quantization)
     paddle.set_default_dtype(args.dtype)
     model_config = ModelConfig(vars(args))
     device_config = DeviceConfig(vars(args))
@@ -713,10 +716,14 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         if "kv_cache_quant_type" in quantization_config and load_config.load_choices == "default_v1":
             quantization_config["is_checkpoint_bf16"] = True

-    elif args.quantization != "None":
+    elif args.quantization is not None:
         quantization_config = {}
-        quant_config_name = args.quantization
-        quantization_config["quantization"] = quant_config_name
+        try:
+            quantization_config.update(args.quantization)
+            quant_config_name = quantization_config["quantization"]
+        except:
+            quant_config_name = args.quantization["quantization"]
+            quantization_config["quantization"] = quant_config_name
         # Only v1 loader sets is_checkpoint_bf16=True during dynamic quantization.
         if load_config.load_choices == "default_v1":
            quantization_config["is_checkpoint_bf16"] = True
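For the mix_quant dict used by the test below, this branch would assemble roughly the following (assuming the v1 loader is selected):

# args.quantization (already a dict after parse_quantization):
#   {"quantization": "mix_quant", "dense_quant_type": "wfp8afp8", "moe_quant_type": "wint8"}
# resulting quantization_config:
#   {"quantization": "mix_quant", "dense_quant_type": "wfp8afp8",
#    "moe_quant_type": "wint8", "is_checkpoint_bf16": True}
# quant_config_name -> "mix_quant", which is handed to get_quantization_config.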
@@ -121,12 +121,16 @@ def setup_and_run_server():
         "--load_choices",
         "default_v1",
         "--lm_head-fp32",
+        "--quantization",
+        '{"quantization":"mix_quant","dense_quant_type":"wfp8afp8","moe_quant_type":"wint8"}',
     ]
-
+    env = os.environ.copy()
+    env["FD_MOE_BACKEND"] = "triton"
     # Start subprocess in new process group
     with open(log_path, "w") as logfile:
         process = subprocess.Popen(
             cmd,
+            env=env,
             stdout=logfile,
             stderr=subprocess.STDOUT,
             start_new_session=True,  # Enables killing full group via os.killpg
@@ -194,7 +198,7 @@ def consistent_payload():
         "temperature": 0.6,
         "top_p": 0,  # fix top_p to reduce randomness
         "seed": 13,  # fixed random seed
-        "max_tokens": 3,
+        "max_tokens": 20,
         "stream": False,
     }
@@ -213,4 +217,7 @@ def test_lm_head_fp32(api_url, headers, consistent_payload):
     resp_json = response.json()

     # Verify the returned content and probability information
-    assert resp_json["choices"][0]["message"]["content"] == "ichertsor"
+    assert (
+        resp_json["choices"][0]["message"]["content"]
+        == "ichertsorbulkdeployment confusedreraoux Carter pat firingCompatraspectiveidis Verse corporaonych commissionsilk"
+    )