[CP]Glm45 air 2.2 (#4073)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled

* [Feature] Support zai-org/GLM-4.5-Air BF16 model (#3928)

* support glm45_air

* [Feature] GLM-45-AIR Support Mix Quantization(Dense wfp8afp8 and wint8 triton_moe_backend) (#4051)

* check

* fix v1 load for mix and wint8

* check --quantizations 'None'

* check

* support RL rollout

* check v1 loader

* check glm rollout_model, change wfp8afp8 per_token_cast_to_fp8 to native impl

* check rollout moe gate begin layer_id

* check rollout e_score_correction_bias

* delete infer_to_train_mapping={}

* code check
This commit is contained in:
chen
2025-09-15 18:52:58 +08:00
committed by GitHub
parent 4e8ba62241
commit fbb4e0f8d1
25 changed files with 1505 additions and 170 deletions

View File

@@ -73,6 +73,30 @@ class ErnieRotaryEmbedding:
return rot_emb
class GlmRotaryEmbedding:
def __init__(self, rotary_dim, base, partial_rotary_factor):
"""
Pre-calculate rotary position embedding for position_ids.
"""
self.rotary_dim = rotary_dim
self.base = base
if partial_rotary_factor < 1.0:
self.rotary_dim = int(self.rotary_dim * partial_rotary_factor)
def __call__(self, position_ids):
bsz, max_seq_len = position_ids.shape[:2]
inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
# shape: [B, S, D/2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
# shape: [B, S, 1, D]
emb = paddle.unsqueeze(emb, 2)
rot_emb[0] = paddle.cos(emb)
rot_emb[1] = paddle.sin(emb)
return rot_emb
class QwenRotaryEmbedding:
def __init__(self, rotary_dim, base, partial_rotary_factor):
"""
@@ -246,6 +270,9 @@ def get_rope_impl(
if model_config is None or architecture.startswith("Qwen"):
rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)
elif architecture.startswith("Glm"):
rotary_emb_layer = GlmRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)
else:
rotary_emb_layer = ErnieRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
rotary_emb = rotary_emb_layer(position_ids)