[Test]add glm45_air logprob test and rollout model (#4175)

* add glm45_air logprob test

* add glm rollout model and pretrainedmodel for rl

* add glm rollout model and test

* check

* delete cudagraph in glm45

* add UT for glm rollout model

* revert glm UT
Authored by chen on 2025-09-23 21:06:07 +08:00; committed by GitHub
parent 62d1c48363
commit ec99474e71
6 changed files with 296 additions and 5 deletions

View File

@@ -17,9 +17,12 @@
from __future__ import annotations

import re
+from functools import partial

import paddle
from paddle import nn
+from paddleformers.transformers import PretrainedModel
+from paddleformers.utils.log import logger

from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
@@ -504,3 +507,86 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
    def clear_grpah_opt_backend(self):
        """Clear graph optimization backend, the captured cuda graph will be cleaned"""
        self.model.clear_grpah_opt_backend(fd_config=self.fd_config)


class Glm4MoePretrainedModel(PretrainedModel):
    """
    Glm4MoePretrainedModel
    """

    config_class = FDConfig

    def _init_weight(self, layer):
        """
        _init_weight
        """
        return None

    @classmethod
    def arch_name(self):
        return "Glm4MoeForCausalLM"

    @classmethod
    def _get_tensor_parallel_mappings(cls, config, is_split=True):
        logger.info("Glm4Moe inference model _get_tensor_parallel_mappings")

        from fastdeploy.model_executor.models.tp_utils import split_or_merge_func_v1

        fn = split_or_merge_func_v1(
            is_split=is_split,
            tensor_parallel_degree=config.tensor_parallel_degree,
            tensor_parallel_rank=config.tensor_parallel_rank,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
        )

        def get_tensor_parallel_split_mappings(num_layers):
            final_actions = {}

            base_actions = {
                "lm_head.weight": partial(fn, is_column=True),
                "embed_tokens.weight": partial(fn, is_column=False),
                "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
            }

            # Self Attention Layer which are need TP.
            base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True)
            base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True)

            # MLP Layer
            base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.mlp.down_proj.weight"] = partial(fn, is_column=False)

            # Moe Layer
            for expert_idx in range(config.n_routed_experts):
                base_actions[f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial(fn, is_column=True)
                base_actions[f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial(fn, is_column=True)
                base_actions[f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial(fn, is_column=False)

            # Shared Expert Layer
            base_actions["layers.0.mlp.shared_experts.up_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.mlp.shared_experts.gate_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.0.mlp.shared_experts.down_proj.weight"] = partial(fn, is_column=False)

            # MTP parts
            base_actions["layers.46.embed_tokens.weight"] = partial(fn, is_column=False)
            base_actions["layers.46.eh_proj.weight"] = partial(fn, is_column=True)
            base_actions["layers.46.shared_head.head.weight"] = partial(fn, is_column=True)

            for key, action in base_actions.items():
                if "layers.0." in key:
                    for i in range(num_layers):
                        final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
                final_actions[key] = action

            return final_actions

        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)

        return mappings
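The table above is built with a template-and-expand pattern: each tensor-parallel action is declared once against the "layers.0." prefix and then copied to every layer index. A minimal, self-contained sketch of that expansion (the string actions and the toy layer count are illustrative stand-ins, not values from the diff):

# Illustrative sketch only: strings stand in for the real partial(fn, ...) callables.
base_actions = {
    "lm_head.weight": "split_column",
    "layers.0.self_attn.q_proj.weight": "split_column",
    "layers.0.mlp.down_proj.weight": "split_row",
}
final_actions = {}
num_layers = 3  # toy value; the real code passes config.num_hidden_layers
for key, action in base_actions.items():
    if "layers.0." in key:
        for i in range(num_layers):
            final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
    final_actions[key] = action
# final_actions now holds "layers.0...", "layers.1...", "layers.2..." variants of each
# per-layer key, plus the global "lm_head.weight" entry.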

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""

+import copy
from typing import Dict

import paddle
@@ -28,6 +29,10 @@ from fastdeploy.model_executor.models.ernie4_5_vl.ernie4_5_vl_moe import (
    Ernie4_5_VLMoeForConditionalGeneration,
    Ernie4_5_VLPretrainedModel,
)
+from fastdeploy.model_executor.models.glm4_moe import (
+    Glm4MoeForCausalLM,
+    Glm4MoePretrainedModel,
+)
from fastdeploy.model_executor.models.model_base import ModelRegistry
from fastdeploy.model_executor.models.qwen2 import (
    Qwen2ForCausalLM,
@@ -529,3 +534,83 @@ class Qwen2_5_VLForConditionalGenerationRL(Qwen2_5_VLForConditionalGeneration, B
        self._complete_missing_mappings()

        return self.infer_to_train_mapping


class Glm4MoeForCausalLMRL(Glm4MoeForCausalLM, BaseRLModel):
    """
    Glm4MoeForCausalLMRL
    """

    _get_tensor_parallel_mappings = Glm4MoePretrainedModel._get_tensor_parallel_mappings

    def __init__(self, fd_config: FDConfig):
        """
        Args:
            fd_config (FDConfig): Configurations for the LLM model.
        """
        super(Glm4MoeForCausalLMRL, self).__init__(fd_config)

    @classmethod
    def name(self) -> str:
        """name"""
        return "Glm4MoeForCausalLMRL"

    def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
        """Generate mapping between inference and training parameter for RL(donot delete!)."""
        if self._mappings_built:
            return self.infer_to_train_mapping

        self.infer_to_train_mapping = {}
        self._mappings_built = True
        # Prepare placeholders
        place_holders = ["weight"]

        # Initialize mapping dictionary
        self._update_base_mappings("model")

        base_name = "model.layers"

        # Helper function to add layer mappings
        def _add_layer_mappings(layer_idx: int):
            # MoE specific mappings
            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.weight"] = (
                f"{base_name}.{layer_idx}.mlp.gate.weight"
            )

            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.e_score_correction_bias"] = (
                f"{base_name}.{layer_idx}.mlp.gate.e_score_correction_bias"
            )

            # MoE experts mappings
            for expert_idx in range(self.fd_config.model_config.n_routed_experts):
                for ph in place_holders:
                    # up_gate_proj (up_gate_proj)
                    up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.experts.up_gate_proj_weight"
                    if up_gate_proj_key not in self.infer_to_train_mapping:
                        self.infer_to_train_mapping[up_gate_proj_key] = []
                    self.infer_to_train_mapping[up_gate_proj_key].append(
                        f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
                    )

                    # down_proj (down_proj)
                    down_proj_key = f"{base_name}.{layer_idx}.mlp.experts.down_proj_weight"
                    if down_proj_key not in self.infer_to_train_mapping:
                        self.infer_to_train_mapping[down_proj_key] = []
                    self.infer_to_train_mapping[down_proj_key].append(
                        f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
                    )

        # Process MoE layers
        for layer_idx in range(
            self.fd_config.model_config.first_k_dense_replace,
            self.fd_config.model_config.num_hidden_layers,
        ):
            _add_layer_mappings(layer_idx)

        self._complete_missing_mappings()

        infer_to_train_mapping_copy = copy.deepcopy(self.infer_to_train_mapping)
        for key in infer_to_train_mapping_copy.keys():
            if "mlp.experts.gate_correction_bias" in key:
                self.infer_to_train_mapping.pop(key)

        return self.infer_to_train_mapping
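For orientation, a hedged illustration of the shape of infer_to_train_mapping produced by the loop above for a single MoE layer with two routed experts (layer index 3 and the expert count are made-up example values; the key patterns follow the code):

# Hypothetical excerpt of infer_to_train_mapping for one MoE layer with 2 routed experts.
{
    "model.layers.3.mlp.gate.weight": "model.layers.3.mlp.gate.weight",
    "model.layers.3.mlp.gate.e_score_correction_bias": "model.layers.3.mlp.gate.e_score_correction_bias",
    # fused inference expert weights map to a list of per-expert training weights
    "model.layers.3.mlp.experts.up_gate_proj_weight": [
        "model.layers.3.mlp.experts.0.up_gate_proj.weight",
        "model.layers.3.mlp.experts.1.up_gate_proj.weight",
    ],
    "model.layers.3.mlp.experts.down_proj_weight": [
        "model.layers.3.mlp.experts.0.down_proj.weight",
        "model.layers.3.mlp.experts.1.down_proj.weight",
    ],
}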

View File

@@ -22,8 +22,9 @@ def test_rollout_model_with_distributed_launch():
    test_rollout_model
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
-    rollout_script = os.path.join(current_dir, "rollout_model.py")
+    utils_dir = os.path.join(os.path.dirname(current_dir), "utils")
+    rollout_script = os.path.join(utils_dir, "rollout_model.py")
+    baseline_path = os.path.join(current_dir, "baseline.txt")

    base_path = os.getenv("MODEL_PATH")
    if base_path:
@@ -40,6 +41,11 @@ def test_rollout_model_with_distributed_launch():
        rollout_script,
        "--model_path",
        model_path,
+        "--baseline_path",
+        baseline_path,
+        "--enable_mm",
+        "--quantization",
+        "wint8",
    ]

    print(f"Executing command: {' '.join(command)}")

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,66 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import sys


def test_rollout_model_with_distributed_launch():
    """
    test_rollout_model
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    utils_dir = os.path.join(os.path.dirname(current_dir), "utils")
    rollout_script = os.path.join(utils_dir, "rollout_model.py")
    baseline_path = os.path.join(current_dir, "baseline.txt")

    base_path = os.getenv("MODEL_PATH")
    if base_path:
        model_path = os.path.join(base_path, "GLM-4.5-Air-Fake")
    else:
        model_path = "./GLM-4.5-Air-Fake"
    print(f"model_path = {model_path}")

    command = [
        sys.executable,
        "-m",
        "paddle.distributed.launch",
        "--gpus",
        "0,1",
        rollout_script,
        "--model_path",
        model_path,
        "--baseline_path",
        baseline_path,
    ]

    print(f"Executing command: {' '.join(command)}")

    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

    try:
        stdout, stderr = process.communicate(timeout=300)
        return_code = process.returncode
    except subprocess.TimeoutExpired:
        process.kill()
        stdout, stderr = process.communicate()
        return_code = -1

    print("\n" + "=" * 50 + " STDOUT " + "=" * 50)
    print(stdout)
    print("\n" + "=" * 50 + " STDERR " + "=" * 50)
    print(stderr)

    assert return_code == 0, f"Process exited with code {return_code}"
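A hedged note on running this test locally (the exact test file location and pytest invocation are assumptions, not part of the diff): it expects two visible GPUs and a GLM-4.5-Air-Fake checkpoint resolved either from the MODEL_PATH environment variable or the working directory, so something along the lines of MODEL_PATH=/path/to/models python -m pytest -s <this test file> should reproduce the CI run.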

View File

@@ -23,6 +23,9 @@ _, ranks = init_dist_env()
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True, help="Path to the model directory")
+parser.add_argument("--baseline_path", type=str, required=True, help="Path to the baseline path")
+parser.add_argument("--quantization", type=str, default=None, help="Quantization")
+parser.add_argument("--enable_mm", action="store_true", required=False, help="Flags to enable multi-modal model")
args = parser.parse_args()

# base result
@@ -35,9 +38,11 @@ init_kwargs = {
"tensor_parallel_size": ranks, "tensor_parallel_size": ranks,
"dynamic_load_weight": True, "dynamic_load_weight": True,
"load_strategy": "ipc_snapshot", "load_strategy": "ipc_snapshot",
"enable_mm": True, "quantization": args.quantization,
"quantization": "wint8",
} }
if args.enable_mm:
init_kwargs["enable_mm"] = True
rollout_config = RolloutModelConfig(**init_kwargs) rollout_config = RolloutModelConfig(**init_kwargs)
actor_eval_model = RolloutModel(rollout_config) actor_eval_model = RolloutModel(rollout_config)
@@ -75,7 +80,7 @@ def compare_strings_line_by_line(a: str, b: str) -> bool:
    return True

-with open("baseline.txt", "r", encoding="utf-8") as f:
+with open(args.baseline_path, "r", encoding="utf-8") as f:
    baseline = f.read()

assert compare_strings_line_by_line(baseline, content), (
    "In the unittest of RL scenario, your modification "