mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-27 04:46:16 +08:00
[Test]add glm45_air logprob test and rollout model (#4175)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* add glm45_air logprob test * add glm rollout model and pretrainedmodel for rl * add glm rollout model and test * check * delete cudagraph in glm45 * add UT for glm rollout model * revert glm UT
This commit is contained in:
@@ -17,9 +17,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
|
from paddleformers.transformers import PretrainedModel
|
||||||
|
from paddleformers.utils.log import logger
|
||||||
|
|
||||||
from fastdeploy.config import FDConfig
|
from fastdeploy.config import FDConfig
|
||||||
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
||||||
@@ -504,3 +507,86 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
|
|||||||
def clear_grpah_opt_backend(self):
|
def clear_grpah_opt_backend(self):
|
||||||
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
"""Clear graph optimization backend, the captured cuda graph will be cleaned"""
|
||||||
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
|
self.model.clear_grpah_opt_backend(fd_config=self.fd_config)
|
||||||
|
|
||||||
|
|
||||||
|
class Glm4MoePretrainedModel(PretrainedModel):
|
||||||
|
"""
|
||||||
|
Glm4MoePretrainedModel
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = FDConfig
|
||||||
|
|
||||||
|
def _init_weight(self, layer):
|
||||||
|
"""
|
||||||
|
_init_weight
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def arch_name(self):
|
||||||
|
return "Glm4MoeForCausalLM"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_tensor_parallel_mappings(cls, config, is_split=True):
|
||||||
|
|
||||||
|
logger.info("Glm4Moe inference model _get_tensor_parallel_mappings")
|
||||||
|
|
||||||
|
from fastdeploy.model_executor.models.tp_utils import split_or_merge_func_v1
|
||||||
|
|
||||||
|
fn = split_or_merge_func_v1(
|
||||||
|
is_split=is_split,
|
||||||
|
tensor_parallel_degree=config.tensor_parallel_degree,
|
||||||
|
tensor_parallel_rank=config.tensor_parallel_rank,
|
||||||
|
num_attention_heads=config.num_attention_heads,
|
||||||
|
num_key_value_heads=config.num_key_value_heads,
|
||||||
|
head_dim=config.head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_tensor_parallel_split_mappings(num_layers):
|
||||||
|
final_actions = {}
|
||||||
|
|
||||||
|
base_actions = {
|
||||||
|
"lm_head.weight": partial(fn, is_column=True),
|
||||||
|
"embed_tokens.weight": partial(fn, is_column=False),
|
||||||
|
"layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Self Attention Layer which are need TP.
|
||||||
|
base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
|
||||||
|
base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
|
||||||
|
base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
|
||||||
|
base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
|
||||||
|
base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True)
|
||||||
|
base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True)
|
||||||
|
|
||||||
|
# MLP Layer
|
||||||
|
base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True)
|
||||||
|
base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True)
|
||||||
|
base_actions["layers.0.mlp.down_proj.weight"] = partial(fn, is_column=False)
|
||||||
|
|
||||||
|
# Moe Layer
|
||||||
|
for expert_idx in range(config.n_routed_experts):
|
||||||
|
base_actions[f"layers.0.mlp.experts.{expert_idx}.up_proj.weight"] = partial(fn, is_column=True)
|
||||||
|
base_actions[f"layers.0.mlp.experts.{expert_idx}.gate_proj.weight"] = partial(fn, is_column=True)
|
||||||
|
base_actions[f"layers.0.mlp.experts.{expert_idx}.down_proj.weight"] = partial(fn, is_column=False)
|
||||||
|
|
||||||
|
# Shared Expert Layer
|
||||||
|
base_actions["layers.0.mlp.shared_experts.up_proj.weight"] = partial(fn, is_column=True)
|
||||||
|
base_actions["layers.0.mlp.shared_experts.gate_proj.weight"] = partial(fn, is_column=True)
|
||||||
|
base_actions["layers.0.mlp.shared_experts.down_proj.weight"] = partial(fn, is_column=False)
|
||||||
|
|
||||||
|
# MTP parts
|
||||||
|
base_actions["layers.46.embed_tokens.weight"] = partial(fn, is_column=False)
|
||||||
|
base_actions["layers.46.eh_proj.weight"] = partial(fn, is_column=True)
|
||||||
|
base_actions["layers.46.shared_head.head.weight"] = partial(fn, is_column=True)
|
||||||
|
|
||||||
|
for key, action in base_actions.items():
|
||||||
|
if "layers.0." in key:
|
||||||
|
for i in range(num_layers):
|
||||||
|
final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
|
||||||
|
final_actions[key] = action
|
||||||
|
|
||||||
|
return final_actions
|
||||||
|
|
||||||
|
mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
|
||||||
|
return mappings
|
||||||
|
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
@@ -28,6 +29,10 @@ from fastdeploy.model_executor.models.ernie4_5_vl.ernie4_5_vl_moe import (
|
|||||||
Ernie4_5_VLMoeForConditionalGeneration,
|
Ernie4_5_VLMoeForConditionalGeneration,
|
||||||
Ernie4_5_VLPretrainedModel,
|
Ernie4_5_VLPretrainedModel,
|
||||||
)
|
)
|
||||||
|
from fastdeploy.model_executor.models.glm4_moe import (
|
||||||
|
Glm4MoeForCausalLM,
|
||||||
|
Glm4MoePretrainedModel,
|
||||||
|
)
|
||||||
from fastdeploy.model_executor.models.model_base import ModelRegistry
|
from fastdeploy.model_executor.models.model_base import ModelRegistry
|
||||||
from fastdeploy.model_executor.models.qwen2 import (
|
from fastdeploy.model_executor.models.qwen2 import (
|
||||||
Qwen2ForCausalLM,
|
Qwen2ForCausalLM,
|
||||||
@@ -529,3 +534,83 @@ class Qwen2_5_VLForConditionalGenerationRL(Qwen2_5_VLForConditionalGeneration, B
|
|||||||
self._complete_missing_mappings()
|
self._complete_missing_mappings()
|
||||||
|
|
||||||
return self.infer_to_train_mapping
|
return self.infer_to_train_mapping
|
||||||
|
|
||||||
|
|
||||||
|
class Glm4MoeForCausalLMRL(Glm4MoeForCausalLM, BaseRLModel):
|
||||||
|
"""
|
||||||
|
Glm4MoeForCausalLMRL
|
||||||
|
"""
|
||||||
|
|
||||||
|
_get_tensor_parallel_mappings = Glm4MoePretrainedModel._get_tensor_parallel_mappings
|
||||||
|
|
||||||
|
def __init__(self, fd_config: FDConfig):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
fd_config (FDConfig): Configurations for the LLM model.
|
||||||
|
"""
|
||||||
|
super(Glm4MoeForCausalLMRL, self).__init__(fd_config)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def name(self) -> str:
|
||||||
|
"""name"""
|
||||||
|
return "Glm4MoeForCausalLMRL"
|
||||||
|
|
||||||
|
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
|
||||||
|
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
|
||||||
|
if self._mappings_built:
|
||||||
|
return self.infer_to_train_mapping
|
||||||
|
|
||||||
|
self.infer_to_train_mapping = {}
|
||||||
|
self._mappings_built = True
|
||||||
|
# Prepare placeholders
|
||||||
|
place_holders = ["weight"]
|
||||||
|
|
||||||
|
# Initialize mapping dictionary
|
||||||
|
self._update_base_mappings("model")
|
||||||
|
|
||||||
|
base_name = "model.layers"
|
||||||
|
|
||||||
|
# Helper function to add layer mappings
|
||||||
|
def _add_layer_mappings(layer_idx: int):
|
||||||
|
# MoE specific mappings
|
||||||
|
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.weight"] = (
|
||||||
|
f"{base_name}.{layer_idx}.mlp.gate.weight"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.gate.e_score_correction_bias"] = (
|
||||||
|
f"{base_name}.{layer_idx}.mlp.gate.e_score_correction_bias"
|
||||||
|
)
|
||||||
|
|
||||||
|
# MoE experts mappings
|
||||||
|
for expert_idx in range(self.fd_config.model_config.n_routed_experts):
|
||||||
|
for ph in place_holders:
|
||||||
|
# up_gate_proj (up_gate_proj)
|
||||||
|
up_gate_proj_key = f"{base_name}.{layer_idx}.mlp.experts.up_gate_proj_weight"
|
||||||
|
if up_gate_proj_key not in self.infer_to_train_mapping:
|
||||||
|
self.infer_to_train_mapping[up_gate_proj_key] = []
|
||||||
|
self.infer_to_train_mapping[up_gate_proj_key].append(
|
||||||
|
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# down_proj (down_proj)
|
||||||
|
down_proj_key = f"{base_name}.{layer_idx}.mlp.experts.down_proj_weight"
|
||||||
|
if down_proj_key not in self.infer_to_train_mapping:
|
||||||
|
self.infer_to_train_mapping[down_proj_key] = []
|
||||||
|
self.infer_to_train_mapping[down_proj_key].append(
|
||||||
|
f"{base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process MoE layers
|
||||||
|
for layer_idx in range(
|
||||||
|
self.fd_config.model_config.first_k_dense_replace,
|
||||||
|
self.fd_config.model_config.num_hidden_layers,
|
||||||
|
):
|
||||||
|
_add_layer_mappings(layer_idx)
|
||||||
|
|
||||||
|
self._complete_missing_mappings()
|
||||||
|
infer_to_train_mapping_copy = copy.deepcopy(self.infer_to_train_mapping)
|
||||||
|
for key in infer_to_train_mapping_copy.keys():
|
||||||
|
if "mlp.experts.gate_correction_bias" in key:
|
||||||
|
self.infer_to_train_mapping.pop(key)
|
||||||
|
|
||||||
|
return self.infer_to_train_mapping
|
||||||
|
@@ -22,8 +22,9 @@ def test_rollout_model_with_distributed_launch():
|
|||||||
test_rollout_model
|
test_rollout_model
|
||||||
"""
|
"""
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
utils_dir = os.path.join(os.path.dirname(current_dir), "utils")
|
||||||
rollout_script = os.path.join(current_dir, "rollout_model.py")
|
rollout_script = os.path.join(utils_dir, "rollout_model.py")
|
||||||
|
baseline_path = os.path.join(current_dir, "baseline.txt")
|
||||||
|
|
||||||
base_path = os.getenv("MODEL_PATH")
|
base_path = os.getenv("MODEL_PATH")
|
||||||
if base_path:
|
if base_path:
|
||||||
@@ -40,6 +41,11 @@ def test_rollout_model_with_distributed_launch():
|
|||||||
rollout_script,
|
rollout_script,
|
||||||
"--model_path",
|
"--model_path",
|
||||||
model_path,
|
model_path,
|
||||||
|
"--baseline_path",
|
||||||
|
baseline_path,
|
||||||
|
"--enable_mm",
|
||||||
|
"--quantization",
|
||||||
|
"wint8",
|
||||||
]
|
]
|
||||||
|
|
||||||
print(f"Executing command: {' '.join(command)}")
|
print(f"Executing command: {' '.join(command)}")
|
||||||
|
43
tests/ci_use/GLM-45-AIR/baseline.txt
Normal file
43
tests/ci_use/GLM-45-AIR/baseline.txt
Normal file
File diff suppressed because one or more lines are too long
66
tests/ci_use/GLM-45-AIR/test_rollout_model.py
Normal file
66
tests/ci_use/GLM-45-AIR/test_rollout_model.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def test_rollout_model_with_distributed_launch():
|
||||||
|
"""
|
||||||
|
test_rollout_model
|
||||||
|
"""
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
utils_dir = os.path.join(os.path.dirname(current_dir), "utils")
|
||||||
|
rollout_script = os.path.join(utils_dir, "rollout_model.py")
|
||||||
|
baseline_path = os.path.join(current_dir, "baseline.txt")
|
||||||
|
|
||||||
|
base_path = os.getenv("MODEL_PATH")
|
||||||
|
if base_path:
|
||||||
|
model_path = os.path.join(base_path, "GLM-4.5-Air-Fake")
|
||||||
|
else:
|
||||||
|
model_path = "./GLM-4.5-Air-Fake"
|
||||||
|
print(f"model_path = {model_path}")
|
||||||
|
|
||||||
|
command = [
|
||||||
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
"paddle.distributed.launch",
|
||||||
|
"--gpus",
|
||||||
|
"0,1",
|
||||||
|
rollout_script,
|
||||||
|
"--model_path",
|
||||||
|
model_path,
|
||||||
|
"--baseline_path",
|
||||||
|
baseline_path,
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"Executing command: {' '.join(command)}")
|
||||||
|
|
||||||
|
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
stdout, stderr = process.communicate(timeout=300)
|
||||||
|
return_code = process.returncode
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
process.kill()
|
||||||
|
stdout, stderr = process.communicate()
|
||||||
|
return_code = -1
|
||||||
|
|
||||||
|
print("\n" + "=" * 50 + " STDOUT " + "=" * 50)
|
||||||
|
print(stdout)
|
||||||
|
print("\n" + "=" * 50 + " STDERR " + "=" * 50)
|
||||||
|
print(stderr)
|
||||||
|
|
||||||
|
assert return_code == 0, f"Process exited with code {return_code}"
|
@@ -23,6 +23,9 @@ _, ranks = init_dist_env()
|
|||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--model_path", type=str, required=True, help="Path to the model directory")
|
parser.add_argument("--model_path", type=str, required=True, help="Path to the model directory")
|
||||||
|
parser.add_argument("--baseline_path", type=str, required=True, help="Path to the baseline path")
|
||||||
|
parser.add_argument("--quantization", type=str, default=None, help="Quantization")
|
||||||
|
parser.add_argument("--enable_mm", action="store_true", required=False, help="Flags to enable multi-modal model")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# base result
|
# base result
|
||||||
@@ -35,9 +38,11 @@ init_kwargs = {
|
|||||||
"tensor_parallel_size": ranks,
|
"tensor_parallel_size": ranks,
|
||||||
"dynamic_load_weight": True,
|
"dynamic_load_weight": True,
|
||||||
"load_strategy": "ipc_snapshot",
|
"load_strategy": "ipc_snapshot",
|
||||||
"enable_mm": True,
|
"quantization": args.quantization,
|
||||||
"quantization": "wint8",
|
|
||||||
}
|
}
|
||||||
|
if args.enable_mm:
|
||||||
|
init_kwargs["enable_mm"] = True
|
||||||
|
|
||||||
|
|
||||||
rollout_config = RolloutModelConfig(**init_kwargs)
|
rollout_config = RolloutModelConfig(**init_kwargs)
|
||||||
actor_eval_model = RolloutModel(rollout_config)
|
actor_eval_model = RolloutModel(rollout_config)
|
||||||
@@ -75,7 +80,7 @@ def compare_strings_line_by_line(a: str, b: str) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
with open("baseline.txt", "r", encoding="utf-8") as f:
|
with open(args.baseline_path, "r", encoding="utf-8") as f:
|
||||||
baseline = f.read()
|
baseline = f.read()
|
||||||
assert compare_strings_line_by_line(baseline, content), (
|
assert compare_strings_line_by_line(baseline, content), (
|
||||||
"In the unittest of RL scenario, your modification "
|
"In the unittest of RL scenario, your modification "
|
Reference in New Issue
Block a user