[CP]Glm45 air 2.2 (#4073)

* [Feature] Support zai-org/GLM-4.5-Air BF16 model (#3928)

* support glm45_air

* [Feature] GLM-45-AIR Support Mix Quantization (Dense wfp8afp8 and wint8 triton_moe_backend) (#4051)

* check

* fix v1 load for mix and wint8

* check --quantizations 'None'

* check

* support RL rollout

* check v1 loader

* check glm rollout_model, change wfp8afp8 per_token_cast_to_fp8 to native impl

* check rollout moe gate begin layer_id

* check rollout e_score_correction_bias

* delete infer_to_train_mapping={}

* code check
Author: chen
Date: 2025-09-15 18:52:58 +08:00 (committed by GitHub)
Parent: 4e8ba62241
Commit: fbb4e0f8d1
25 changed files with 1505 additions and 170 deletions

View File

@@ -28,38 +28,9 @@ except:
import fastdeploy
from fastdeploy.config import MoEPhase
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.utils import singleton
try:
from fastdeploy.model_executor.ops.gpu import noaux_tc
except:
logger.warning("import noaux_tc Failed!")
def get_moe_scores(
gating_output: paddle.Tensor,
n_group,
topk_group,
top_k,
routed_scaling_factor,
e_score_correction_bias,
) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
scores_with_bias = scores + e_score_correction_bias
scores, topk_values, topk_idx = noaux_tc(
scores,
scores_with_bias,
n_group if n_group > 0 else 1,
topk_group if topk_group > 0 else 1,
top_k,
routed_scaling_factor,
)
return scores, topk_values, topk_idx
@singleton
class DeepEPEngine:

View File

@@ -27,11 +27,8 @@ from ..utils import get_tensor
from .fused_moe_backend_base import UnquantizedFusedMoEMethod
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
moe_expert_dispatch,
moe_expert_reduce,
noaux_tc,
)
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
try:
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
@@ -46,31 +43,6 @@ elif current_platform.is_iluvatar():
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
# used for deepseek_v3
def get_moe_scores(
gating_output: paddle.Tensor,
n_group,
topk_group,
top_k,
routed_scaling_factor,
e_score_correction_bias,
) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
scores_with_bias = scores + e_score_correction_bias
scores, topk_values, topk_idx = noaux_tc(
scores,
scores_with_bias,
n_group,
topk_group,
top_k,
routed_scaling_factor,
)
return scores, topk_values, topk_idx
class CutlassMoEMethod(UnquantizedFusedMoEMethod):
"""
Use Cutlass Group Gemm to compute Fused MoE.

View File

@@ -481,7 +481,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
gate_out = gate(x.cast("float32"))
if layer.topk_method == "noaux_tc":
from .ep import get_moe_scores
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
_, topk_weights, topk_ids = get_moe_scores(
gate_out,

View File

@@ -19,39 +19,15 @@ from paddle import nn
import fastdeploy
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.ops.gpu import (
MoeWna16MarlinGemmApi,
noaux_tc,
tritonmoe_preprocess_func,
)
from ..quantization.quant_base import QuantMethodBase
def get_moe_scores(
gating_output: paddle.Tensor,
n_group,
topk_group,
top_k,
routed_scaling_factor,
e_score_correction_bias,
) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
scores, topk_values, topk_idx = noaux_tc(
scores,
scores_with_bias,
n_group,
topk_group,
top_k,
routed_scaling_factor,
)
return scores, topk_values, topk_idx
def gptq_marlin_moe_repack(
b_q_weight: paddle.Tensor,
perm: paddle.Tensor,

View File

@@ -24,7 +24,6 @@ from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
from fastdeploy.utils import ceil_div
from ..quantization.quant_base import QuantMethodBase
from .ep import get_moe_scores
try:
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
@@ -32,6 +31,7 @@ try:
from .triton_moe_kernels import fused_moe_kernel_paddle
except ImportError:
pass
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
class TritonWeightOnlyMoEMethod(QuantMethodBase):
@@ -72,43 +72,70 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
layer.moe_intermediate_size,
layer.hidden_size,
]
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
if self.quant_config.is_checkpoint_bf16:
layer.up_gate_proj_weight = layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=self.weight_dtype,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
)
layer.down_proj_weight = layer.create_parameter(
shape=self.down_proj_weight_shape,
dtype=self.weight_dtype,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
)
set_weight_attrs(
layer.up_gate_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
},
)
set_weight_attrs(
layer.down_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
},
)
else:
setattr(
layer,
up_gate_proj_weight_name,
layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
down_proj_weight_name,
layer.create_parameter(
shape=self.down_proj_weight_shape,
dtype=self.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# weight_scale
setattr(
layer,
self.added_scale_attrs[0],
layer.create_parameter(
shape=[layer.num_local_experts, layer.moe_intermediate_size * 2],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
setattr(
layer,
self.added_scale_attrs[1],
layer.create_parameter(
shape=[layer.num_local_experts, layer.hidden_size],
dtype=self.default_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
@@ -151,6 +178,62 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
getattr(layer, weight_name).set_value(quanted_weight)
getattr(layer, scale_name).set_value(quanted_weight_scale)
def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return
algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
max_bound = 127
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
layer.up_gate_proj_weight.tensor_track = None
else:
weight_type = "down"
layer.down_proj_weight.tensor_track = None
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
weight_tensor = getattr(layer, weight_name)
quanted_weight_scale = weight_tensor.abs().max(axis=1)
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
quanted_weight = paddle.round(quanted_weight).astype("int8")
quanted_weight_scale = quanted_weight_scale / max_bound
getattr(layer, weight_name).value().get_tensor()._clear()
# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_tensor.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(quanted_weight, False)
getattr(layer, scale_name).copy_(quanted_weight_scale, False)
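For reference, a minimal standalone sketch (illustrative values, not part of the diff) of the per-channel int8 quantization that process_weights_after_loading performs above, assuming a hypothetical expert weight of shape [num_experts, K, N]:

import paddle

w = paddle.randn([2, 64, 32], dtype="float32")        # hypothetical expert weight [num_experts, K, N]
scale = w.abs().max(axis=1)                            # per (expert, output-channel) absmax -> [2, 32]
q = paddle.round(w / scale[:, None, :] * 127).astype("int8")
scale = scale / 127                                     # dequantize later as: q.cast("float32") * scale[:, None, :]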
def apply(
self,
layer: nn.Layer,
@@ -164,12 +247,11 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
token_num = x.shape[0]
top_k = layer.top_k
num_local_experts = layer.num_local_experts
top_k = layer.top_k
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size
if layer.topk_method == "noaux_tc":
_, topk_weights, topk_ids = get_moe_scores(
gate_out, topk_weights, topk_ids = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
@@ -177,15 +259,15 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
layer.routed_scaling_factor,
layer.gate_correction_bias,
)
topk_weights, topk_ids = paddle.topk(gate_out, k=layer.top_k, axis=-1, sorted=False)
else:
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
layer.top_k,
True, # apply_norm_weight
top_k,
True, # apply_norm_weight,
False,
)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
@@ -302,6 +384,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
if layer.reduce_results and layer.tp_size > 1:
tensor_model_parallel_all_reduce(out)
return out
@@ -432,7 +517,6 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
hidden_size = layer.hidden_size
if layer.topk_method == "noaux_tc":
_, topk_weights, topk_ids = get_moe_scores(
gate_out,
layer.n_group,

View File

@@ -27,6 +27,11 @@ from fastdeploy.model_executor.utils import slice_fn
from fastdeploy.platforms import current_platform
from fastdeploy.worker.experts_manager import RedundantExpertManger
try:
from fastdeploy.model_executor.ops.gpu import noaux_tc
except:
logger.warning("import noaux_tc Failed!")
def get_moe_method():
"""
@@ -54,6 +59,31 @@ def get_moe_method():
raise NotImplementedError
def get_moe_scores(
gating_output: paddle.Tensor,
n_group,
topk_group,
top_k,
routed_scaling_factor,
e_score_correction_bias,
) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
"""
scores = paddle.nn.functional.sigmoid(gating_output)
assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
scores_with_bias = scores + e_score_correction_bias
scores, topk_values, topk_idx = noaux_tc(
scores,
scores_with_bias,
n_group if n_group > 0 else 1,
topk_group if topk_group > 0 else 1,
top_k,
routed_scaling_factor,
)
return scores, topk_values, topk_idx
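A minimal usage sketch of the consolidated helper (not part of the diff), assuming a CUDA build where the noaux_tc op imported above is available; sizes are hypothetical (4 tokens, 16 experts, top-8, no expert grouping):

import paddle

gate_out = paddle.randn([4, 16], dtype="float32")       # [num_tokens, num_experts]
bias = paddle.zeros([16], dtype="float32")               # e_score_correction_bias
scores, topk_weights, topk_ids = get_moe_scores(
    gate_out,
    n_group=0,            # values <= 0 fall back to a single group
    topk_group=0,
    top_k=8,
    routed_scaling_factor=1.0,
    e_score_correction_bias=bias,
)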
class FusedMoE(nn.Layer):
"""
FusedMoE is a layer that performs MoE (Mixture of Experts) computation.

View File

@@ -76,13 +76,13 @@ class MixQuantConfig(QuantConfigBase):
if layer.moe_tag == "Image":
return (
get_quantization_config(self.image_moe_quant_type)
.from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
.from_config({"is_permuted": self.is_permuted, "is_checkpoint_bf16": self.is_checkpoint_bf16})
.get_quant_method(layer)
)
else:
return (
get_quantization_config(self.moe_quant_type)
.from_config({"is_permuted": self.is_permuted, "self.is_checkpoint_bf16": self.is_checkpoint_bf16})
.from_config({"is_permuted": self.is_permuted, "is_checkpoint_bf16": self.is_checkpoint_bf16})
.get_quant_method(layer)
)
elif isinstance(layer, Attention):
@@ -97,6 +97,6 @@ class MixQuantConfig(QuantConfigBase):
else:
return (
get_quantization_config(self.dense_quant_type)
.from_config({"self.is_checkpoint_bf16": self.is_checkpoint_bf16})
.from_config({"is_checkpoint_bf16": self.is_checkpoint_bf16})
.get_quant_method(layer)
)

View File

@@ -44,6 +44,7 @@ class WeightOnlyConfig(QuantConfigBase):
def __init__(
self,
algo: str,
is_checkpoint_bf16: bool = False,
) -> None:
super().__init__()
self.algo = algo
@@ -55,6 +56,7 @@ class WeightOnlyConfig(QuantConfigBase):
self.quant_max_bound = 0
self.quant_min_bound = 0
self.quant_round_type = 0
self.is_checkpoint_bf16 = is_checkpoint_bf16
def name(self) -> str:
return "weight_only"
@@ -62,7 +64,8 @@ class WeightOnlyConfig(QuantConfigBase):
@classmethod
def from_config(cls, config: dict) -> "WeightOnlyConfig":
algo = config["algo"]
return cls(algo)
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
return cls(algo, is_checkpoint_bf16=is_checkpoint_bf16)
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
if current_platform.is_xpu():
@@ -153,12 +156,13 @@ class WINT8Config(WeightOnlyConfig):
weight only int8 config
"""
def __init__(self) -> None:
super().__init__("weight_only_int8")
def __init__(self, is_checkpoint_bf16: bool = False) -> None:
super().__init__("weight_only_int8", is_checkpoint_bf16)
@classmethod
def from_config(cls, config: dict) -> "WINT8Config":
return cls()
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
return cls(is_checkpoint_bf16)
def name(self) -> str:
return "wint8"

View File

@@ -14,10 +14,15 @@
# limitations under the License.
"""
import copy
from typing import Optional
import paddle
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
)
from fastdeploy.model_executor.layers.quantization.ops import (
cutlass_scaled_mm,
scaled_fp8_quant,
@@ -26,6 +31,8 @@ from fastdeploy.model_executor.layers.quantization.quant_base import (
QuantConfigBase,
QuantMethodBase,
)
from fastdeploy.model_executor.layers.utils import per_token_cast_to_fp8
from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
class WFP8AFP8Config(QuantConfigBase):
@@ -33,13 +40,19 @@ class WFP8AFP8Config(QuantConfigBase):
Quantization config for weight and activation with FP8.
"""
def __init__(self, weight_scale_dict, act_scale_dict) -> None:
def __init__(
self,
activation_scheme: str = "dynamic",
weight_block_size: list[int] = [-1, 1],
is_checkpoint_bf16: bool = False,
) -> None:
super().__init__()
self.weight_scale_dict = weight_scale_dict
self.act_scale_dict = act_scale_dict
self.quant_max_bound = 448
self.quant_min_bound = -448
self.quant_round_type = 1
self.activation_scheme = activation_scheme
self.weight_block_size = weight_block_size
self.is_checkpoint_bf16 = is_checkpoint_bf16
def name(self) -> str:
""" """
@@ -48,9 +61,8 @@ class WFP8AFP8Config(QuantConfigBase):
@classmethod
def from_config(cls, config: dict) -> "WFP8AFP8Config":
""" """
weight_scale_dict = config.get("weight_scale_dict", None)
act_scale_dict = config.get("act_scale_dict", None)
return cls(weight_scale_dict, act_scale_dict)
is_checkpoint_bf16 = config.get("is_checkpoint_bf16", False)
return cls(is_checkpoint_bf16=is_checkpoint_bf16)
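Analogously, a minimal sketch (not part of the diff) of constructing the reworked config, relying only on the defaults defined in __init__ above:

cfg = WFP8AFP8Config.from_config({"is_checkpoint_bf16": True})
# activation_scheme defaults to "dynamic", weight_block_size to [-1, 1]
assert cfg.quant_max_bound == 448 and cfg.is_checkpoint_bf16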
def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
""" """
@@ -68,26 +80,85 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
) -> None:
super().__init__()
self.quant_config = quant_config
self.use_per_token_if_dynamic = True
def create_weights(self, layer, **extra_weight_attrs):
""" """
layer.weight_shape.reverse()
layer.weight_dtype = "float8_e4m3fn"
# TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
self.skip_quant = False
layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
weight_shape = layer.weight_shape
weight_block_size = self.quant_config.weight_block_size
assert len(weight_shape) == 2 and len(weight_block_size) == 2
scale_shape = copy.deepcopy(weight_shape)
for i in range(len(weight_shape)):
scale_shape[i] = (
(weight_shape[i] + weight_block_size[i] - 1) // weight_block_size[i] if weight_block_size[i] > 0 else 1
)
scale_shape = scale_shape[::-1]
if self.quant_config.is_checkpoint_bf16:
self.use_per_token_if_dynamic = True
layer.weight = layer.create_parameter(
shape=weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
quant_attrs = extra_weight_attrs
if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
quant_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(
shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim")
),
}
set_weight_attrs(
layer.weight,
quant_attrs,
)
else:
layer.weight_shape.reverse()
layer.weight_dtype = "float8_e4m3fn"
# TODO(YuanRisheng): set weight logic should be moved to process_loaded_weights func
self.skip_quant = False
layer.weight = layer.create_parameter(
shape=layer.weight_shape,
dtype=layer.weight_dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale = layer.create_parameter(
shape=scale_shape,
dtype="float32",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
def process_weights_after_loading(self, layer) -> None:
if not self.quant_config.is_checkpoint_bf16:
return
weight_tensor = layer.weight.transpose([1, 0]).contiguous()
assert self.quant_config.weight_block_size == [-1, 1]
qweight, weight_scale = per_token_cast_to_fp8(weight_tensor)
if hasattr(layer.weight, "tensor_track"):
layer.weight.tensor_track = None
layer.weight.value().get_tensor()._clear()
del layer.weight
layer.weight = layer.create_parameter(
shape=qweight.shape,
dtype="float8_e4m3fn",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight_scale = layer.create_parameter(
shape=[1],
shape=weight_scale.shape,
dtype="float32",
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
layer.weight.copy_(qweight, False)
layer.weight_scale.copy_(weight_scale, False)
def process_loaded_weights(self, layer, weights) -> None:
""" """
if self.skip_quant:
@@ -97,18 +168,12 @@ class WFP8AFP8LinearMethod(QuantMethodBase):
if weights.dtype != paddle.float8_e4m3fn:
self.use_per_token_if_dynamic = True
weight_tensor = weights.transpose([1, 0]).contiguous()
qweight, weight_scale = scaled_fp8_quant(
weight_tensor,
use_per_token_if_dynamic=False,
)
qweight, weight_scale = per_token_cast_to_fp8(weight_tensor)
layer.weight.copy_(qweight, False)
layer.weight_scale.set_value(weight_scale)
def apply(self, layer, x):
""" """
if self.skip_quant:
linear_out = paddle.matmul(x, layer.weight, False, True)
return linear_out
if self.use_per_token_if_dynamic:
out_type = x.dtype
a_q, a_scales = scaled_fp8_quant(x, use_per_token_if_dynamic=self.use_per_token_if_dynamic)

View File

@@ -73,6 +73,30 @@ class ErnieRotaryEmbedding:
return rot_emb
class GlmRotaryEmbedding:
def __init__(self, rotary_dim, base, partial_rotary_factor):
"""
Pre-calculate rotary position embedding for position_ids.
"""
self.rotary_dim = rotary_dim
self.base = base
if partial_rotary_factor < 1.0:
self.rotary_dim = int(self.rotary_dim * partial_rotary_factor)
def __call__(self, position_ids):
bsz, max_seq_len = position_ids.shape[:2]
inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
# shape: [B, S, D/2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
# shape: [B, S, 1, D]
emb = paddle.unsqueeze(emb, 2)
rot_emb[0] = paddle.cos(emb)
rot_emb[1] = paddle.sin(emb)
return rot_emb
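A quick usage sketch of the new class (not part of the diff), with hypothetical sizes: head dim 128 and partial_rotary_factor 0.5, so only 64 rotary dims are used:

import paddle

position_ids = paddle.arange(8, dtype="int64").reshape([1, 8])   # [batch, seq_len]
rope = GlmRotaryEmbedding(rotary_dim=128, base=10000.0, partial_rotary_factor=0.5)
rot_emb = rope(position_ids)
# rot_emb[0] holds cos, rot_emb[1] holds sin; shape [2, 1, 8, 1, 32]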
class QwenRotaryEmbedding:
def __init__(self, rotary_dim, base, partial_rotary_factor):
"""
@@ -246,6 +270,9 @@ def get_rope_impl(
if model_config is None or architecture.startswith("Qwen"):
rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)
elif architecture.startswith("Glm"):
rotary_emb_layer = GlmRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)
else:
rotary_emb_layer = ErnieRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)

View File

@@ -77,6 +77,17 @@ def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Ten
)
def per_token_cast_to_fp8(x: Tensor) -> Tuple[Tensor, Tensor]:
"""
Per token cast to float8_e4m3fn used in wfp8apf8
"""
x_abs = paddle.abs(x).astype(paddle.float32)
x_max = x_abs.max(axis=-1, keepdim=True).clip_(min=1e-4)
x_s = x_max / 448.0
x_q = paddle.clip(x / x_s, -448.0, 448.0).astype(paddle.float8_e4m3fn)
return x_q, x_s
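A brief round-trip sketch (not part of the diff) of the new helper as used by the wfp8afp8 path, on a hypothetical weight already transposed to [out_features, in_features]; the [out_features, 1] scale matches the scale_shape computed in WFP8AFP8LinearMethod.create_weights for weight_block_size=[-1, 1]:

import paddle

w = paddle.randn([1024, 512], dtype="float32")     # hypothetical weight, [out, in]
w_q, w_s = per_token_cast_to_fp8(w)
# w_q: [1024, 512] float8_e4m3fn, w_s: [1024, 1] float32
w_deq = w_q.astype("float32") * w_s                 # approximate reconstruction (assumes fp8 -> float32 cast support)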
# for distributed tensor model parallel
def _set_var_distributed(var: Tensor, split_axis: int):
"""