Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-19 15:04:47 +08:00
[CP]Glm45 air 2.2 (#4073)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [Feature] Support zai-org/GLM-4.5-Air BF16 model (#3928)
  * support glm45_air
* [Feature] GLM-45-AIR Support Mix Quantization (Dense wfp8afp8 and wint8 triton_moe_backend) (#4051)
  * check
  * fix v1 load for mix and wint8
  * check --quantizations 'None'
  * check
  * support RL rollout
  * check v1 loader
  * check glm rollout_model, change wfp8afp8 per_token_cast_to_fp8 to native impl
  * check rollout moe gate begin layer_id
  * check rollout e_score_correction_bias
  * delete infer_to_train_mapping={}
  * code check
@@ -28,38 +28,9 @@ except:
 import fastdeploy
 from fastdeploy.config import MoEPhase
+from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
 from fastdeploy.utils import singleton
 
-try:
-    from fastdeploy.model_executor.ops.gpu import noaux_tc
-except:
-    logger.warning("import noaux_tc Failed!")
-
-
-def get_moe_scores(
-    gating_output: paddle.Tensor,
-    n_group,
-    topk_group,
-    top_k,
-    routed_scaling_factor,
-    e_score_correction_bias,
-) -> paddle.Tensor:
-    """
-    compute moe scores using e_score_correction_bias.
-    """
-    scores = paddle.nn.functional.sigmoid(gating_output)
-    assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
-    scores_with_bias = scores + e_score_correction_bias
-    scores, topk_values, topk_idx = noaux_tc(
-        scores,
-        scores_with_bias,
-        n_group if n_group > 0 else 1,
-        topk_group if topk_group > 0 else 1,
-        top_k,
-        routed_scaling_factor,
-    )
-    return scores, topk_values, topk_idx
-
 
 @singleton
 class DeepEPEngine:
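For orientation, the grouped top-k routing that get_moe_scores hands off to the noaux_tc GPU op can be sketched in plain Python. Everything below is illustrative: the group metric, the normalization of the returned weights, and the fallback itself are assumptions about what the fused kernel computes, not a drop-in replacement for it.

import numpy as np

def get_moe_scores_reference(gating_output, n_group, topk_group, top_k,
                             routed_scaling_factor, e_score_correction_bias):
    """Rough NumPy illustration of group-limited top-k routing (not the CUDA kernel)."""
    scores = 1.0 / (1.0 + np.exp(-gating_output))    # sigmoid gate scores
    biased = scores + e_score_correction_bias         # bias only steers expert selection

    tokens, n_experts = biased.shape
    grouped = biased.reshape(tokens, n_group, n_experts // n_group)
    # Assumed group metric: sum of the two largest biased scores in each group.
    group_score = np.sort(grouped, axis=-1)[..., -2:].sum(axis=-1)
    keep = np.argsort(-group_score, axis=-1)[:, :topk_group]

    masked = np.full_like(grouped, -np.inf)
    for t in range(tokens):
        masked[t, keep[t]] = grouped[t, keep[t]]      # experts outside kept groups are excluded
    biased = masked.reshape(tokens, n_experts)

    topk_idx = np.argsort(-biased, axis=-1)[:, :top_k]
    topk_values = np.take_along_axis(scores, topk_idx, axis=-1)
    # Assumed normalization: weights sum to 1 per token, then scaled by routed_scaling_factor.
    topk_values = topk_values / topk_values.sum(axis=-1, keepdims=True) * routed_scaling_factor
    return scores, topk_values, topk_idx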
@@ -27,11 +27,8 @@ from ..utils import get_tensor
 from .fused_moe_backend_base import UnquantizedFusedMoEMethod
 
 if current_platform.is_cuda():
-    from fastdeploy.model_executor.ops.gpu import (
-        moe_expert_dispatch,
-        moe_expert_reduce,
-        noaux_tc,
-    )
+    from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
+    from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
 
     try:
         from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
@@ -46,31 +43,6 @@ elif current_platform.is_iluvatar():
 from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
 
 
-# used for deepseek_v3
-def get_moe_scores(
-    gating_output: paddle.Tensor,
-    n_group,
-    topk_group,
-    top_k,
-    routed_scaling_factor,
-    e_score_correction_bias,
-) -> paddle.Tensor:
-    """
-    compute moe scores using e_score_correction_bias.
-    """
-    scores = paddle.nn.functional.sigmoid(gating_output)
-    scores_with_bias = scores + e_score_correction_bias
-    scores, topk_values, topk_idx = noaux_tc(
-        scores,
-        scores_with_bias,
-        n_group,
-        topk_group,
-        top_k,
-        routed_scaling_factor,
-    )
-    return scores, topk_values, topk_idx
-
-
 class CutlassMoEMethod(UnquantizedFusedMoEMethod):
     """
     Use Cutlass Group Gemm to compute Fused MoE.
@@ -481,7 +481,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         gate_out = gate(x.cast("float32"))
 
         if layer.topk_method == "noaux_tc":
-            from .ep import get_moe_scores
+            from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
 
             _, topk_weights, topk_ids = get_moe_scores(
                 gate_out,
@@ -19,39 +19,15 @@ from paddle import nn
 
 import fastdeploy
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
+from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
 from fastdeploy.model_executor.ops.gpu import (
     MoeWna16MarlinGemmApi,
-    noaux_tc,
     tritonmoe_preprocess_func,
 )
 
 from ..quantization.quant_base import QuantMethodBase
 
 
-def get_moe_scores(
-    gating_output: paddle.Tensor,
-    n_group,
-    topk_group,
-    top_k,
-    routed_scaling_factor,
-    e_score_correction_bias,
-) -> paddle.Tensor:
-    """
-    compute moe scores using e_score_correction_bias.
-    """
-    scores = paddle.nn.functional.sigmoid(gating_output)
-    scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
-    scores, topk_values, topk_idx = noaux_tc(
-        scores,
-        scores_with_bias,
-        n_group,
-        topk_group,
-        top_k,
-        routed_scaling_factor,
-    )
-    return scores, topk_values, topk_idx
-
-
 def gptq_marlin_moe_repack(
     b_q_weight: paddle.Tensor,
     perm: paddle.Tensor,
@@ -24,7 +24,6 @@ from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
 from fastdeploy.utils import ceil_div
 
 from ..quantization.quant_base import QuantMethodBase
-from .ep import get_moe_scores
 
 try:
     from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
@@ -32,6 +31,7 @@ try:
     from .triton_moe_kernels import fused_moe_kernel_paddle
 except ImportError:
     pass
+from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
 
 
 class TritonWeightOnlyMoEMethod(QuantMethodBase):
@@ -72,43 +72,70 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             layer.moe_intermediate_size,
             layer.hidden_size,
         ]
-        setattr(
-            layer,
-            up_gate_proj_weight_name,
-            layer.create_parameter(
-                shape=self.up_gate_proj_weight_shape,
-                dtype=self.weight_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        setattr(
-            layer,
-            down_proj_weight_name,
-            layer.create_parameter(
-                shape=self.down_proj_weight_shape,
-                dtype=self.weight_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        # weight_scale
-        setattr(
-            layer,
-            self.added_scale_attrs[0],
-            layer.create_parameter(
-                shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
-                dtype=self.default_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        setattr(
-            layer,
-            self.added_scale_attrs[1],
-            layer.create_parameter(
-                shape=[layer.num_local_experts, layer.hidden_size],
-                dtype=self.default_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
+        if self.quant_config.is_checkpoint_bf16:
+            layer.up_gate_proj_weight = layer.create_parameter(
+                shape=self.up_gate_proj_weight_shape,
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+
+            layer.down_proj_weight = layer.create_parameter(
+                shape=self.down_proj_weight_shape,
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            set_weight_attrs(
+                layer.up_gate_proj_weight,
+                {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
+                },
+            )
+            set_weight_attrs(
+                layer.down_proj_weight,
+                {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
+                },
+            )
+        else:
+            setattr(
+                layer,
+                up_gate_proj_weight_name,
+                layer.create_parameter(
+                    shape=self.up_gate_proj_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                down_proj_weight_name,
+                layer.create_parameter(
+                    shape=self.down_proj_weight_shape,
+                    dtype=self.weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            # weight_scale
+            setattr(
+                layer,
+                self.added_scale_attrs[0],
+                layer.create_parameter(
+                    shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
+                    dtype=self.default_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            setattr(
+                layer,
+                self.added_scale_attrs[1],
+                layer.create_parameter(
+                    shape=[layer.num_local_experts, layer.hidden_size],
+                    dtype=self.default_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
 
     def process_loaded_weights(self, layer: nn.Layer, state_dict):
         """
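The tensor_track attributes attached in the BF16 branch above feed the process_weights_after_loading method added in the next hunk: quantization is deferred until a weight has been fully copied in from the checkpoint. A hypothetical, much-simplified illustration of that copy-tracking pattern (not FastDeploy's TensorTracker, whose real interface is richer):

import numpy as np

class ToyCopyTracker:
    """Records which rows of a weight have been filled and answers is_fully_copied()."""

    def __init__(self, shape, output_dim=True):
        self.shape = shape
        self.output_dim = output_dim              # axis the shards arrive along (illustrative)
        self._filled = np.zeros(shape[0], dtype=bool)

    def mark(self, row_start, row_end):
        self._filled[row_start:row_end] = True

    def is_fully_copied(self):
        return bool(self._filled.all())

# Usage: the loader marks shards as they are copied in; quantize once everything has landed.
tracker = ToyCopyTracker(shape=[8, 1024], output_dim=True)
tracker.mark(0, 4)
tracker.mark(4, 8)
assert tracker.is_fully_copied()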
@@ -151,6 +178,62 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
             getattr(layer, weight_name).set_value(quanted_weight)
             getattr(layer, scale_name).set_value(quanted_weight_scale)
 
+    def process_weights_after_loading(self, layer):
+        """ """
+        if not self.quant_config.is_checkpoint_bf16:
+            return
+
+        algo = layer.quant_method.quant_config.name()
+        assert algo == "wint8"
+        max_bound = 127
+        weight_id_map = {"gate_up": 0, "down": 1}
+        if (
+            hasattr(layer.up_gate_proj_weight, "tensor_track")
+            and layer.up_gate_proj_weight.tensor_track is not None
+            and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
+        ):
+            weight_type = "gate_up"
+            layer.up_gate_proj_weight.tensor_track = None
+        else:
+            weight_type = "down"
+            layer.down_proj_weight.tensor_track = None
+
+        # weight
+        weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
+        # scale
+        scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
+
+        weight_tensor = getattr(layer, weight_name)
+        quanted_weight_scale = weight_tensor.abs().max(axis=1)
+        quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
+        quanted_weight = paddle.round(quanted_weight).astype("int8")
+        quanted_weight_scale = quanted_weight_scale / max_bound
+
+        getattr(layer, weight_name).value().get_tensor()._clear()
+
+        # create weight
+        setattr(
+            layer,
+            weight_name,
+            layer.create_parameter(
+                shape=weight_tensor.shape,
+                dtype=quanted_weight.dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        # create scale
+        setattr(
+            layer,
+            scale_name,
+            layer.create_parameter(
+                shape=quanted_weight_scale.shape,
+                dtype=quanted_weight_scale.dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        getattr(layer, weight_name).copy_(quanted_weight, False)
+        getattr(layer, scale_name).copy_(quanted_weight_scale, False)
+
     def apply(
         self,
         layer: nn.Layer,
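The int8 path above is plain per-channel abs-max quantization with max_bound = 127. A small self-contained NumPy sketch of the same arithmetic (shapes chosen only for illustration):

import numpy as np

# Hypothetical expert weight: [num_experts, in_features, out_features]
w = np.random.randn(2, 64, 32).astype(np.float32)

max_bound = 127
scale = np.abs(w).max(axis=1)                              # [experts, out_features]
q = np.rint(w / scale[:, None, :] * max_bound).astype(np.int8)
stored_scale = scale / max_bound                           # what ends up in the scale parameter

# Round-trip check: dequantized weights stay close to the originals.
w_hat = q.astype(np.float32) * stored_scale[:, None, :]
print(np.abs(w - w_hat).max())                             # small quantization error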
@@ -164,12 +247,11 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
         token_num = x.shape[0]
         top_k = layer.top_k
         num_local_experts = layer.num_local_experts
-        top_k = layer.top_k
         moe_intermediate_size = layer.moe_intermediate_size
         hidden_size = layer.hidden_size
 
         if layer.topk_method == "noaux_tc":
-            _, topk_weights, topk_ids = get_moe_scores(
+            gate_out, topk_weights, topk_ids = get_moe_scores(
                 gate_out,
                 layer.n_group,
                 layer.topk_group,
@@ -177,15 +259,15 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
                 layer.routed_scaling_factor,
                 layer.gate_correction_bias,
             )
+            topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
         else:
             topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
                 gate_out,
                 layer.gate_correction_bias,
-                layer.top_k,
-                True, # apply_norm_weight
+                top_k,
+                True, # apply_norm_weight,
                 False,
             )
 
         up_gate_proj_out = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
             dtype=x.dtype,
@@ -302,6 +384,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
 
         down_proj_out.reshape_([token_num, top_k, hidden_size])
         out = down_proj_out.sum(axis=1)
+        if layer.reduce_results and layer.tp_size > 1:
+            tensor_model_parallel_all_reduce(out)
+
         return out
 
 
@@ -432,7 +517,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
         hidden_size = layer.hidden_size
 
         if layer.topk_method == "noaux_tc":
-
             _, topk_weights, topk_ids = get_moe_scores(
                 gate_out,
                 layer.n_group,
@@ -27,6 +27,11 @@ from fastdeploy.model_executor.utils import slice_fn
 from fastdeploy.platforms import current_platform
 from fastdeploy.worker.experts_manager import RedundantExpertManger
 
+try:
+    from fastdeploy.model_executor.ops.gpu import noaux_tc
+except:
+    logger.warning("import noaux_tc Failed!")
+
 
 def get_moe_method():
     """
@@ -54,6 +59,31 @@ def get_moe_method():
     raise NotImplementedError
 
 
+def get_moe_scores(
+    gating_output: paddle.Tensor,
+    n_group,
+    topk_group,
+    top_k,
+    routed_scaling_factor,
+    e_score_correction_bias,
+) -> paddle.Tensor:
+    """
+    compute moe scores using e_score_correction_bias.
+    """
+    scores = paddle.nn.functional.sigmoid(gating_output)
+    assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
+    scores_with_bias = scores + e_score_correction_bias
+    scores, topk_values, topk_idx = noaux_tc(
+        scores,
+        scores_with_bias,
+        n_group if n_group > 0 else 1,
+        topk_group if topk_group > 0 else 1,
+        top_k,
+        routed_scaling_factor,
+    )
+    return scores, topk_values, topk_idx
+
+
 class FusedMoE(nn.Layer):
     """
     FusedMoE is a layer that performs MoE (Mixture of Experts) computation.
@@ -76,13 +76,13 @@ class MixQuantConfig(QuantConfigBase):
             if layer.moe_tag == "Image":
                 return (
                     get_quantization_config(self.image_moe_quant_type)
-                    .from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
+                    .from_config({"is_permuted": self.is_permuted, "is_checkpoint_bf16": self.is_checkpoint_bf16})
                    .get_quant_method(layer)
                )
            else:
                return (
                    get_quantization_config(self.moe_quant_type)
-                    .from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
+                    .from_config({"is_permuted": self.is_permuted, "is_checkpoint_bf16": self.is_checkpoint_bf16})
                    .get_quant_method(layer)
                )
        elif isinstance(layer, Attention):
@@ -97,6 +97,6 @@ class MixQuantConfig(QuantConfigBase):
        else:
            return (
                get_quantization_config(self.dense_quant_type)
-                .from_config({"self.is_checkpoint_bf16": self.is_checkpoint_bf16})
+                .from_config({"is_checkpoint_bf16": self.is_checkpoint_bf16})
                .get_quant_method(layer)
            )
@@ -44,6 +44,7 @@ class WeightOnlyConfig(QuantConfigBase):
     def __init__(
         self,
         algo: str,
+        is_checkpoint_bf16: bool = False,
     ) -> None:
         super().__init__()
         self.algo = algo
@@ -55,6 +56,7 @@ class WeightOnlyConfig(QuantConfigBase):
         self.quant_max_bound = 0
         self.quant_min_bound = 0
         self.quant_round_type = 0
+        self.is_checkpoint_bf16 = is_checkpoint_bf16
 
     def name(self) -> str:
         return "weight_only"
@@ -62,7 +64,8 @@ class WeightOnlyConfig(QuantConfigBase):
     @classmethod
     def from_config(cls, config: dict) -> "WeightOnlyConfig":
         algo = config["algo"]
-        return cls(algo)
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(algo, is_checkpoint_bf16=is_checkpoint_bf16)
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         if current_platform.is_xpu():
@@ -153,12 +156,13 @@ class WINT8Config(WeightOnlyConfig):
     weight only int8 config
     """
 
-    def __init__(self) -> None:
-        super().__init__("weight_only_int8")
+    def __init__(self, is_checkpoint_bf16: bool = False) -> None:
+        super().__init__("weight_only_int8", is_checkpoint_bf16)
 
     @classmethod
     def from_config(cls, config: dict) -> "WINT8Config":
-        return cls()
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(is_checkpoint_bf16)
 
     def name(self) -> str:
         return "wint8"
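As a quick illustration of the new flag (a hypothetical snippet, not taken from the repository's tests; the module path is assumed from the file being edited), building the config from a dict now carries is_checkpoint_bf16 through:

from fastdeploy.model_executor.layers.quantization.weight_only import WINT8Config  # path assumed

cfg = WINT8Config.from_config({"is_checkpoint_bf16": True})
assert cfg.name() == "wint8"
assert cfg.is_checkpoint_bf16  # BF16 checkpoints get quantized after loading instead of at load time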
@@ -14,10 +14,15 @@
 # limitations under the License.
 """
 
+import copy
 from typing import Optional
 
 import paddle
 
+from fastdeploy.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+)
 from fastdeploy.model_executor.layers.quantization.ops import (
     cutlass_scaled_mm,
     scaled_fp8_quant,
@@ -26,6 +31,8 @@ from fastdeploy.model_executor.layers.quantization.quant_base import (
     QuantConfigBase,
     QuantMethodBase,
 )
+from fastdeploy.model_executor.layers.utils import per_token_cast_to_fp8
+from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
 
 
 class WFP8AFP8Config(QuantConfigBase):
@@ -33,13 +40,19 @@ class WFP8AFP8Config(QuantConfigBase):
     Quantization config for weight and activation with FP8.
     """
 
-    def __init__(self, weight_scale_dict, act_scale_dict) -> None:
+    def __init__(
+        self,
+        activation_scheme: str = "dynamic",
+        weight_block_size: list[int] = [-1, 1],
+        is_checkpoint_bf16: bool = False,
+    ) -> None:
         super().__init__()
-        self.weight_scale_dict = weight_scale_dict
-        self.act_scale_dict = act_scale_dict
         self.quant_max_bound = 448
         self.quant_min_bound = -448
         self.quant_round_type = 1
+        self.activation_scheme = activation_scheme
+        self.weight_block_size = weight_block_size
+        self.is_checkpoint_bf16 = is_checkpoint_bf16
 
     def name(self) -> str:
         """ """
@@ -48,9 +61,8 @@ class WFP8AFP8Config(QuantConfigBase):
     @classmethod
     def from_config(cls, config: dict) -> "WFP8AFP8Config":
         """ """
-        weight_scale_dict = config.get("weight_scale_dict", None)
-        act_scale_dict = config.get("act_scale_dict", None)
-        return cls(weight_scale_dict, act_scale_dict)
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(is_checkpoint_bf16=is_checkpoint_bf16)
 
     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         """ """
@@ -68,26 +80,85 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
     ) -> None:
         super().__init__()
         self.quant_config = quant_config
+        self.use_per_token_if_dynamic = True
 
     def create_weights(self, layer, **extra_weight_attrs):
         """ """
-        layer.weight_shape.reverse()
-        layer.weight_dtype = "float8_e4m3fn"
-        # TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
-        self.skip_quant = False
-        layer.create_parameter(
-            shape=layer.weight_shape,
-            dtype=layer.weight_dtype,
-            is_bias=False,
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
-        layer.weight_scale = layer.create_parameter(
-            shape=[1],
-            dtype="float32",
-            is_bias=False,
-            default_initializer=paddle.nn.initializer.Constant(0),
-        )
+        weight_shape = layer.weight_shape
+        weight_block_size = self.quant_config.weight_block_size
+        assert len(weight_shape) == 2 and len(weight_block_size) == 2
+        scale_shape = copy.deepcopy(weight_shape)
+        for i in range(len(weight_shape)):
+            scale_shape[i] = (
+                (weight_shape[i] + weight_block_size[i] - 1) // weight_block_size[i] if weight_block_size[i] > 0 else 1
+            )
+        scale_shape = scale_shape[::-1]
+        if self.quant_config.is_checkpoint_bf16:
+            self.use_per_token_if_dynamic = True
+            layer.weight = layer.create_parameter(
+                shape=weight_shape,
+                dtype=layer.weight_dtype,
+                is_bias=False,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            quant_attrs = extra_weight_attrs
+            if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
+                quant_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(
+                        shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim")
+                    ),
+                }
+            set_weight_attrs(
+                layer.weight,
+                quant_attrs,
+            )
+        else:
+            layer.weight_shape.reverse()
+            layer.weight_dtype = "float8_e4m3fn"
+            # TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
+            self.skip_quant = False
+            layer.weight = layer.create_parameter(
+                shape=layer.weight_shape,
+                dtype=layer.weight_dtype,
+                is_bias=False,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            layer.weight_scale = layer.create_parameter(
+                shape=scale_shape,
+                dtype="float32",
+                is_bias=False,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+
+    def process_weights_after_loading(self, layer) -> None:
+        if not self.quant_config.is_checkpoint_bf16:
+            return
+        weight_tensor = layer.weight.transpose([1, 0]).contiguous()
+        assert self.quant_config.weight_block_size == [-1, 1]
+        qweight, weight_scale = per_token_cast_to_fp8(weight_tensor)
+
+        if hasattr(layer.weight, "tensor_track"):
+            layer.weight.tensor_track = None
+        layer.weight.value().get_tensor()._clear()
+        del layer.weight
+
+        layer.weight = layer.create_parameter(
+            shape=qweight.shape,
+            dtype="float8_e4m3fn",
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+        layer.weight_scale = layer.create_parameter(
+            shape=weight_scale.shape,
+            dtype="float32",
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+
+        layer.weight.copy_(qweight, False)
+        layer.weight_scale.copy_(weight_scale, False)
 
     def process_loaded_weights(self, layer, weights) -> None:
         """ """
         if self.skip_quant:
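To make the scale-shape bookkeeping in create_weights concrete, here is a standalone sketch of the same formula (shapes are hypothetical; the orientation comment reflects the layout assumed after the transpose in process_weights_after_loading):

def scale_shape_for(weight_shape, weight_block_size):
    """Ceil-divide each weight dim by its block size; a non-positive block size collapses the dim to 1."""
    return [
        (d + b - 1) // b if b > 0 else 1
        for d, b in zip(weight_shape, weight_block_size)
    ][::-1]

# weight_block_size=[-1, 1] keeps one FP32 scale per row of the transposed weight,
# i.e. roughly one scale per output channel under the assumed layout.
print(scale_shape_for([4096, 8192], [-1, 1]))   # -> [8192, 1]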
@@ -97,18 +168,12 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
         if weights.dtype != paddle.float8_e4m3fn:
             self.use_per_token_if_dynamic = True
             weight_tensor = weights.transpose([1, 0]).contiguous()
-            qweight, weight_scale = scaled_fp8_quant(
-                weight_tensor,
-                use_per_token_if_dynamic=False,
-            )
+            qweight, weight_scale = per_token_cast_to_fp8(weight_tensor)
             layer.weight.copy_(qweight, False)
             layer.weight_scale.set_value(weight_scale)
 
     def apply(self, layer, x):
         """ """
         if self.skip_quant:
             linear_out = paddle.matmul(x, layer.weight, False, True)
             return linear_out
         if self.use_per_token_if_dynamic:
             out_type = x.dtype
             a_q, a_scales = scaled_fp8_quant(x, use_per_token_if_dynamic=self.use_per_token_if_dynamic)
@@ -73,6 +73,30 @@ class ErnieRotaryEmbedding:
         return rot_emb
 
 
+class GlmRotaryEmbedding:
+    def __init__(self, rotary_dim, base, partial_rotary_factor):
+        """
+        Pre-calculate rotary position embedding for position_ids.
+        """
+        self.rotary_dim = rotary_dim
+        self.base = base
+        if partial_rotary_factor < 1.0:
+            self.rotary_dim = int(self.rotary_dim * partial_rotary_factor)
+
+    def __call__(self, position_ids):
+        bsz, max_seq_len = position_ids.shape[:2]
+        inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
+        freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
+        # shape: [B, S, D/2]
+        rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
+        emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
+        # shape: [B, S, 1, D]
+        emb = paddle.unsqueeze(emb, 2)
+        rot_emb[0] = paddle.cos(emb)
+        rot_emb[1] = paddle.sin(emb)
+        return rot_emb
+
+
 class QwenRotaryEmbedding:
     def __init__(self, rotary_dim, base, partial_rotary_factor):
         """
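For readers new to partial rotary embeddings, here is a rough sketch of how a [2, B, S, 1, rotary_dim/2] cos/sin table like rot_emb above is typically consumed (cos = rot_emb[0], sin = rot_emb[1]). The pairing layout and the apply step are assumptions for illustration; FastDeploy's fused attention kernels do this on the GPU and may use a different layout.

import numpy as np

def apply_partial_rope(q, cos, sin, rotary_dim):
    """Rotate only the first rotary_dim channels of each head; pass the rest through unchanged.

    q:   [B, S, H, head_dim] query (or key) tensor
    cos: [B, S, 1, rotary_dim // 2]
    sin: [B, S, 1, rotary_dim // 2]
    Assumes an interleaved (x0, x1, x2, x3, ...) pair layout.
    """
    rot, rest = q[..., :rotary_dim], q[..., rotary_dim:]
    x1, x2 = rot[..., 0::2], rot[..., 1::2]
    rotated = np.empty_like(rot)
    rotated[..., 0::2] = x1 * cos - x2 * sin
    rotated[..., 1::2] = x1 * sin + x2 * cos
    return np.concatenate([rotated, rest], axis=-1)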
@@ -246,6 +270,9 @@ def get_rope_impl(
     if model_config is None or architecture.startswith("Qwen"):
         rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
         rotary_emb = rotary_emb_layer(position_ids)
+    elif architecture.startswith("Glm"):
+        rotary_emb_layer = GlmRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
+        rotary_emb = rotary_emb_layer(position_ids)
     else:
         rotary_emb_layer = ErnieRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
         rotary_emb = rotary_emb_layer(position_ids)
@@ -77,6 +77,17 @@ def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]:
     )
 
 
+def per_token_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:
+    """
+    Per token cast to float8_e4m3fn used in wfp8apf8
+    """
+    x_abs = paddle.abs(x).astype(paddle.float32)
+    x_max = x_abs.max(axis=-1, keepdim=True).clip_(min=1e-4)
+    x_s = x_max / 448.0
+    x_q = paddle.clip(x / x_s, -448.0, 448.0).astype(paddle.float8_e4m3fn)
+    return x_q, x_s
+
+
 # for distributed tensor model parallel
 def _set_var_distributed(var: Tensor, split_axis: int):
     """
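A quick way to sanity-check the per-token cast above (illustrative only; requires a Paddle build and device with float8_e4m3fn support; the import path follows the one added in the wfp8afp8 hunk):

import paddle

from fastdeploy.model_executor.layers.utils import per_token_cast_to_fp8

x = paddle.randn([4, 512])
x_q, x_s = per_token_cast_to_fp8(x)

# One scale per token (per row): x_s is [4, 1], x_q keeps the [4, 512] shape in float8_e4m3fn.
x_hat = x_q.astype("float32") * x_s
print(float((x - x_hat).abs().max()))  # small FP8 rounding error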