[Bug Fix] fix ep config bug (#2920)

Author: ming1753
Date: 2025-07-18 19:12:56 +08:00
Committed by: GitHub
Parent: a42fc3f40b
Commit: 5328daa333
2 changed files with 37 additions and 22 deletions

File 1 of 2 (expert-parallel checkpoint loading):

@@ -25,14 +25,14 @@ from paddleformers.transformers.model_utils import load_tp_checkpoint
 from safetensors import safe_open
 from tqdm import tqdm
 
-from fastdeploy.config import FDConfig, ModelConfig
+from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.models.tp_utils import \
     check_tensor_parallel_prerequisites
 from fastdeploy.platforms import current_platform
 
 
 def load_ep_checkpoint(model_path: str,
-                       config: ModelConfig,
+                       fd_config: FDConfig,
                        return_numpy: bool = False):
     """
     load ep checkpoint
@@ -44,17 +44,17 @@ def load_ep_checkpoint(model_path: str,
     num_local_ffn_keys = []
 
     from itertools import chain
 
-    def get_expert_ranges(config):
+    def get_expert_ranges(fd_config):
         """
         Generate expert index ranges based on configuration parameters
 
         This function is primarily used in Mixture-of-Experts (MoE) models to generate
         expert index ranges according to configuration parameters. When moe_num_experts
-        is a list in the config, it returns a chained combination of two ranges, otherwise
+        is a list in the fd_config, it returns a chained combination of two ranges, otherwise
         returns a single range.
 
         Args:
-            config: Configuration object
+            fd_config: FastDeploy Configuration object
 
         Returns:
             If moe_num_experts is a list:
@@ -65,16 +65,16 @@ def load_ep_checkpoint(model_path: str,
             Returns single range: [num_experts_start_offset, num_experts_start_offset + num_experts_per_rank)
         """
         base_range = range(
-            config.num_experts_start_offset,
-            config.num_experts_start_offset + config.num_experts_per_rank
+            fd_config.parallel_config.num_experts_start_offset,
+            fd_config.parallel_config.num_experts_start_offset + fd_config.parallel_config.num_experts_per_rank
         )
-        if isinstance(config.moe_num_experts, list):
+        if isinstance(fd_config.model_config.moe_num_experts, list):
             return chain(base_range,
-                         range(base_range.start + config.moe_num_experts[0], base_range.stop + config.moe_num_experts[0]))
+                         range(base_range.start + fd_config.model_config.moe_num_experts[0], base_range.stop + fd_config.model_config.moe_num_experts[0]))
         return base_range
 
-    for i in range(config.moe_layer_start_index, config.num_hidden_layers):
-        for j in get_expert_ranges(config):
+    for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
+        for j in get_expert_ranges(fd_config):
             up_gate_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
             down_proj_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")
@@ -280,7 +280,7 @@ def load_composite_checkpoint(
     if fd_config.parallel_config.use_ep and \
        fd_config.speculative_config.model_type != "mtp":
         state_dict = load_ep_checkpoint(model_path,
-                                        fd_config.model_config,
+                                        fd_config,
                                         return_numpy=True)
     else:
         rank_dirs = [
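The signature change above is the heart of the fix: the range logic reads fields from two different sub-configs, so the old single ModelConfig argument could not supply the parallel-side values. A rough sketch of the structure being relied on, keeping only the fields referenced in the diff, with everything else assumed:

from dataclasses import dataclass, field
from typing import List, Union

# Assumed minimal stand-ins for the real FastDeploy config classes.
@dataclass
class ParallelConfig:
    num_experts_start_offset: int = 0
    num_experts_per_rank: int = 0

@dataclass
class ModelConfig:
    moe_num_experts: Union[int, List[int]] = 0
    moe_layer_start_index: int = 0
    num_hidden_layers: int = 0

@dataclass
class FDConfig:
    parallel_config: ParallelConfig = field(default_factory=ParallelConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)

Passing the whole FDConfig lets load_ep_checkpoint reach both halves.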

File 2 of 2 (worker process and FDConfig initialization):

@@ -398,13 +398,13 @@ class PaddleDisWorkerProc():
         if num_blocks_global < 0:
             logger.error(
-                f"The total number of blocks cannot be less than zero."
-                f"Please increase gpu_memory_utilization"
-                f"Or decrease max_num_batched_tokens(max model length) ")
+                "The total number of blocks cannot be less than zero. "
+                "Please increase gpu_memory_utilization "
+                "or decrease max_num_batched_tokens (max model length).")
             raise ValueError(
-                f"The total number of blocks cannot be less than zero."
-                f"Please increase gpu_memory_utilization"
-                f"Or decrease max_num_batched_tokens(max model length) ")
+                "The total number of blocks cannot be less than zero. "
+                "Please increase gpu_memory_utilization "
+                "or decrease max_num_batched_tokens (max model length).")
             self.get_profile_block_num_signal.value[
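Dropping the f prefixes is correct because none of these literals contain placeholders. Note also that adjacent Python string literals concatenate with no separator, which is why the trailing spaces in the rewritten message matter; a quick illustration (message text abbreviated):

# Without trailing spaces the words run together:
msg = ("The total number of blocks cannot be less than zero."
       "Please increase gpu_memory_utilization")
print(msg)  # ...zero.Please increase...

# With trailing spaces the message reads correctly:
msg = ("The total number of blocks cannot be less than zero. "
       "Please increase gpu_memory_utilization.")
print(msg)  # ...zero. Please increase...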
@@ -604,9 +604,24 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     decoding_config = DecodingConfig(vars(args))
     speculative_config = SpeculativeConfig(vars(args))
     parallel_config = ParallelConfig(vars(args))
-    parallel_config.tensor_parallel_rank = local_rank
-    parallel_config.tensor_parallel_size = ranks
-    parallel_config.expert_parallel_rank = int(local_rank / ranks)
+    parallel_config.tensor_parallel_size = args.tensor_parallel_size
+    parallel_config.tensor_parallel_rank = local_rank % args.tensor_parallel_size
+    parallel_config.expert_parallel_size = args.expert_parallel_size
+
+    # config for EP
+    if args.expert_parallel_size > 1:
+        expert_parallel_rank = int(local_rank / args.tensor_parallel_size)
+        if isinstance(model_config.moe_num_experts, list):
+            num_experts = model_config.moe_num_experts[0]
+        else:
+            num_experts = model_config.moe_num_experts
+        num_experts_per_rank = num_experts // args.expert_parallel_size
+        num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
+
+        parallel_config.expert_parallel_rank = expert_parallel_rank
+        parallel_config.num_experts_per_rank = num_experts_per_rank
+        parallel_config.num_experts_start_offset = num_experts_start_offset
+
     load_config = LoadConfig(vars(args))
     graph_opt_config = GraphOptimizationConfig()
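The new EP block derives every per-rank value from the flattened local rank. A worked sketch with hypothetical sizes (TP=2, EP=4, 64 experts) shows how the ranks and expert slices fall out:

tensor_parallel_size = 2   # hypothetical
expert_parallel_size = 4   # hypothetical
moe_num_experts = 64       # hypothetical

for local_rank in range(tensor_parallel_size * expert_parallel_size):
    tensor_parallel_rank = local_rank % tensor_parallel_size
    # Same as int(local_rank / tensor_parallel_size) for non-negative ints.
    expert_parallel_rank = local_rank // tensor_parallel_size
    num_experts_per_rank = moe_num_experts // expert_parallel_size  # 16
    num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
    print(f"local_rank={local_rank}: tp_rank={tensor_parallel_rank}, "
          f"ep_rank={expert_parallel_rank}, experts "
          f"[{num_experts_start_offset}, "
          f"{num_experts_start_offset + num_experts_per_rank})")

For example, local_rank 5 maps to tp_rank 1, ep_rank 2, and experts [32, 48). The old code used the total rank count for both roles, which broke whenever tensor_parallel_size and expert_parallel_size differed.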