[CP]Glm45 air 2.2 (#4073)

* [Feature] Support zai-org/GLM-4.5-Air BF16 model (#3928)

* support glm45_air

* [Feature] GLM-4.5-Air: support mixed quantization (dense wfp8afp8 and wint8 triton_moe_backend) (#4051)

* check

* fix v1 load for mix and wint8

* check --quantizations 'None'

* check

* support RL rollout

* check v1 loader

* check glm rollout_model; switch the wfp8afp8 weight quantization to the native per_token_cast_to_fp8 implementation (see the sketch after this list)

* check rollout moe gate begin layer_id

* check rollout e_score_correction_bias

* delete infer_to_train_mapping={}

* code check
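A minimal sketch of what a per-token FP8 cast such as per_token_cast_to_fp8 does, for readers following the diff below. This is not the FastDeploy implementation: the helper name comes from the diff, while the scale convention (per-row absmax divided by 448, the float8_e4m3fn maximum matching quant_max_bound in the diff) and the exact paddle calls are assumptions.

import paddle


def per_token_cast_to_fp8_sketch(x: paddle.Tensor):
    """Quantize a 2-D tensor row by row (one dynamic float32 scale per row),
    clamping to the float8_e4m3fn range [-448, 448]. Scale layout is assumed."""
    assert x.ndim == 2
    amax = x.abs().max(axis=-1, keepdim=True).astype("float32")   # per-row absmax
    scale = paddle.clip(amax, min=1e-8) / 448.0                   # 448 = e4m3 max value
    q = paddle.clip(x.astype("float32") / scale, min=-448.0, max=448.0)
    return q.astype("float8_e4m3fn"), scale                       # (qweight, weight_scale)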
Authored by chen on 2025-09-15 18:52:58 +08:00; committed by GitHub.
parent 4e8ba62241
commit fbb4e0f8d1
25 changed files with 1505 additions and 170 deletions
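The diff below reworks WFP8AFP8Config and WFP8AFP8LinearMethod so that a BF16 checkpoint can be loaded as-is and quantized to FP8 after loading. A hypothetical usage sketch of the new config surface follows; the values come from the constructor defaults visible in the diff, while the import path is an assumption (this diff view does not name the file).

# Import path is an assumption; the diff view does not show the file name.
from fastdeploy.model_executor.layers.quantization.wfp8afp8 import WFP8AFP8Config

# from_config now only reads is_checkpoint_bf16 from the quantization config dict.
cfg = WFP8AFP8Config.from_config({"is_checkpoint_bf16": True})
assert cfg.activation_scheme == "dynamic"        # default in the new __init__
assert cfg.weight_block_size == [-1, 1]          # default block pattern used by create_weights
assert cfg.quant_max_bound == 448 and cfg.quant_min_bound == -448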


@@ -14,10 +14,15 @@
 # limitations under the License.
 """
+import copy
 from typing import Optional

 import paddle

+from fastdeploy.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+)
 from fastdeploy.model_executor.layers.quantization.ops import (
     cutlass_scaled_mm,
     scaled_fp8_quant,
@@ -26,6 +31,8 @@ from fastdeploy.model_executor.layers.quantization.quant_base import (
     QuantConfigBase,
     QuantMethodBase,
 )
+from fastdeploy.model_executor.layers.utils import per_token_cast_to_fp8
+from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs


 class WFP8AFP8Config(QuantConfigBase):
@@ -33,13 +40,19 @@ class WFP8AFP8Config(QuantConfigBase):
     Quantization config for weight and activation with FP8.
     """

-    def __init__(self, weight_scale_dict, act_scale_dict) -> None:
+    def __init__(
+        self,
+        activation_scheme: str = "dynamic",
+        weight_block_size: list[int] = [-1, 1],
+        is_checkpoint_bf16: bool = False,
+    ) -> None:
         super().__init__()
-        self.weight_scale_dict = weight_scale_dict
-        self.act_scale_dict = act_scale_dict
         self.quant_max_bound = 448
         self.quant_min_bound = -448
         self.quant_round_type = 1
+        self.activation_scheme = activation_scheme
+        self.weight_block_size = weight_block_size
+        self.is_checkpoint_bf16 = is_checkpoint_bf16

     def name(self) -> str:
         """ """
@@ -48,9 +61,8 @@ class WFP8AFP8Config(QuantConfigBase):
     @classmethod
     def from_config(cls, config: dict) -> "WFP8AFP8Config":
         """ """
-        weight_scale_dict = config.get("weight_scale_dict", None)
-        act_scale_dict = config.get("act_scale_dict", None)
-        return cls(weight_scale_dict, act_scale_dict)
+        is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
+        return cls(is_checkpoint_bf16=is_checkpoint_bf16)

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         """ """
@@ -68,26 +80,85 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
     ) -> None:
         super().__init__()
         self.quant_config = quant_config
+        self.use_per_token_if_dynamic = True

     def create_weights(self, layer, **extra_weight_attrs):
         """ """
-        layer.weight_shape.reverse()
-        layer.weight_dtype = "float8_e4m3fn"
-        # TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
-        self.skip_quant = False
-        layer.create_parameter(
-            shape=layer.weight_shape,
-            dtype=layer.weight_dtype,
+        weight_shape = layer.weight_shape
+        weight_block_size = self.quant_config.weight_block_size
+        assert len(weight_shape) == 2 and len(weight_block_size) == 2
+        scale_shape = copy.deepcopy(weight_shape)
+        for i in range(len(weight_shape)):
+            scale_shape[i] = (
+                (weight_shape[i] + weight_block_size[i] - 1) // weight_block_size[i] if weight_block_size[i] > 0 else 1
+            )
+        scale_shape = scale_shape[::-1]
+        if self.quant_config.is_checkpoint_bf16:
+            self.use_per_token_if_dynamic = True
+            layer.weight = layer.create_parameter(
+                shape=weight_shape,
+                dtype=layer.weight_dtype,
+                is_bias=False,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            quant_attrs = extra_weight_attrs
+            if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
+                quant_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(
+                        shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim")
+                    ),
+                }
+            set_weight_attrs(
+                layer.weight,
+                quant_attrs,
+            )
+        else:
+            layer.weight_shape.reverse()
+            layer.weight_dtype = "float8_e4m3fn"
+            # TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
+            self.skip_quant = False
+            layer.weight = layer.create_parameter(
+                shape=layer.weight_shape,
+                dtype=layer.weight_dtype,
+                is_bias=False,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+            layer.weight_scale = layer.create_parameter(
+                shape=scale_shape,
+                dtype="float32",
+                is_bias=False,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+
+    def process_weights_after_loading(self, layer) -> None:
+        if not self.quant_config.is_checkpoint_bf16:
+            return
+        weight_tensor = layer.weight.transpose([1, 0]).contiguous()
+        assert self.quant_config.weight_block_size == [-1, 1]
+        qweight, weight_scale = per_token_cast_to_fp8(weight_tensor)
+
+        if hasattr(layer.weight, "tensor_track"):
+            layer.weight.tensor_track = None
+        layer.weight.value().get_tensor()._clear()
+        del layer.weight
+
+        layer.weight = layer.create_parameter(
+            shape=qweight.shape,
+            dtype="float8_e4m3fn",
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
         layer.weight_scale = layer.create_parameter(
-            shape=[1],
+            shape=weight_scale.shape,
             dtype="float32",
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
+        layer.weight.copy_(qweight, False)
+        layer.weight_scale.copy_(weight_scale, False)

     def process_loaded_weights(self, layer, weights) -> None:
         """ """
         if self.skip_quant:
@@ -97,18 +168,12 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
         if weights.dtype != paddle.float8_e4m3fn:
             self.use_per_token_if_dynamic = True
             weight_tensor = weights.transpose([1, 0]).contiguous()
-            qweight, weight_scale = scaled_fp8_quant(
-                weight_tensor,
-                use_per_token_if_dynamic=False,
-            )
+            qweight, weight_scale = per_token_cast_to_fp8(weight_tensor)
             layer.weight.copy_(qweight, False)
             layer.weight_scale.set_value(weight_scale)

     def apply(self, layer, x):
         """ """
         if self.skip_quant:
             linear_out = paddle.matmul(x, layer.weight, False, True)
             return linear_out
         if self.use_per_token_if_dynamic:
             out_type = x.dtype
             a_q, a_scales = scaled_fp8_quant(x, use_per_token_if_dynamic=self.use_per_token_if_dynamic)
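As a closing note on the block-size logic in create_weights above: a positive entry in weight_block_size divides that weight dimension into ceil(dim / block) scale entries, a non-positive entry collapses it to a single entry, and the result is reversed to follow the transposed FP8 weight. A small standalone sketch of just that rule (plain Python, function name assumed, not part of the diff):

def scale_shape_for(weight_shape, weight_block_size):
    """Sketch of the scale-shape rule from create_weights: positive block sizes
    divide the dimension (rounded up), non-positive ones collapse it to 1,
    and the result is reversed to follow the transposed weight."""
    assert len(weight_shape) == 2 and len(weight_block_size) == 2
    scale_shape = [
        (dim + block - 1) // block if block > 0 else 1
        for dim, block in zip(weight_shape, weight_block_size)
    ]
    return scale_shape[::-1]


# With the default [-1, 1] blocks, a [4096, 8192] weight_shape yields [8192, 1]:
# one float32 scale per row of the transposed weight, which is the shape that
# per_token_cast_to_fp8 produces in process_weights_after_loading above.
print(scale_shape_for([4096, 8192], [-1, 1]))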