Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
Simplify the Config code (#2770)

* simplify the code
* fix vl
* delete config
* fix
* perfect code
* fix ci
* fix xpu
* fix xpu
* fix server
* resolve conflict
* fix mtp
* resolve conflict
* fix xpu
* fix xpu
* fix vl
* fix log
* fix qwen moe
* fix qwen moe
* fix qwen moe
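The diff below replaces the old per-field getattr(config_or_args, ...) wiring in the worker with config objects that consume the argparse namespace directly, and renames tensor_parallel_degree / expert_parallel_degree to tensor_parallel_size / expert_parallel_size. A minimal sketch of the new construction pattern, assuming an args namespace carrying the fields the configs read (the values and field subset here are illustrative, not the full argument set):

    # Sketch only: mirrors the pattern introduced in this diff; values are placeholders.
    from argparse import Namespace
    from fastdeploy.config import LoadConfig, ModelConfig, ParallelConfig

    args = Namespace(model_name_or_path="./model", dtype="bfloat16",
                     tensor_parallel_size=8, expert_parallel_size=1)

    # Each config class now takes the same flat dict of CLI arguments.
    model_config = ModelConfig(vars(args))
    parallel_config = ParallelConfig(vars(args))
    load_config = LoadConfig(vars(args))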
@@ -22,11 +22,9 @@ import paddle
 import paddle.distributed as dist
 import paddle.distributed.fleet as fleet
 
 from fastdeploy import envs
 from fastdeploy.config import (DecodingConfig, DeviceConfig, FDConfig,
                                GraphOptimizationConfig, LoadConfig,
-                               ModelConfig, MoEConfig, MoEPhase,
-                               ParallelConfig, SpeculativeConfig)
+                               ModelConfig, ParallelConfig, SpeculativeConfig)
 from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
 from fastdeploy.inter_communicator import IPCSignal
 from fastdeploy.model_executor.layers.quantization import \
@@ -122,7 +120,7 @@ class PaddleDisWorkerProc():
         self.task_queue = TaskQueue(
             address=task_address,
             is_server=False,
-            num_client=self.parallel_config.tensor_parallel_degree,
+            num_client=self.parallel_config.tensor_parallel_size,
             client_id=self.parallel_config.tensor_parallel_rank,
             local_data_parallel_id=self.parallel_config.expert_parallel_rank)
 
@@ -139,8 +137,8 @@ class PaddleDisWorkerProc():
         # init worker_ready_signal
         max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
         array_size = min(
-            max_chips_per_node, self.parallel_config.tensor_parallel_degree *
-            self.parallel_config.expert_parallel_degree)
+            max_chips_per_node, self.parallel_config.tensor_parallel_size *
+            self.parallel_config.expert_parallel_size)
         workers_ready = np.zeros(shape=[array_size], dtype=np.int32)
         self.worker_ready_signal = IPCSignal(
             name="worker_ready_signal",
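For example (illustrative values): with tensor_parallel_size = 4 and expert_parallel_size = 2 on a non-Iluvatar node, array_size = min(8, 4 * 2) = 8, i.e. one ready flag per local rank.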
@@ -173,7 +171,7 @@ class PaddleDisWorkerProc():
 
         # init exist_task_signal
         workers_exist_task = np.zeros(
-            [self.parallel_config.expert_parallel_degree], dtype=np.int32)
+            [self.parallel_config.expert_parallel_size], dtype=np.int32)
         self.exist_task_signal = IPCSignal(
             name="exist_task_signal",
             array=workers_exist_task,
@@ -183,7 +181,7 @@ class PaddleDisWorkerProc():
 
         # init exist_swapped_task_signal
         workers_swapped_task = np.zeros(
-            shape=[self.parallel_config.expert_parallel_degree],
+            shape=[self.parallel_config.expert_parallel_size],
             dtype=np.int32)
         self.exist_swapped_task_signal = IPCSignal(
             name="exist_swapped_task_signal",
@@ -231,8 +229,8 @@ class PaddleDisWorkerProc():
         TODO(gongshaotian): support remote calling of functions that control worker.
         """
         # Currently, only support single node
-        self.nnode = int((self.parallel_config.tensor_parallel_degree + 7) // 8)
-        mp_num_per_node = self.parallel_config.tensor_parallel_degree // self.nnode
+        self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
+        mp_num_per_node = self.parallel_config.tensor_parallel_size // self.nnode
         req_ids = []
         while True:
             if self.local_rank == 0:
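For example (illustrative values): with tensor_parallel_size = 16, self.nnode = (16 + 7) // 8 = 2 and mp_num_per_node = 16 // 2 = 8 ranks per node; a single-node tensor_parallel_size = 8 run gives nnode = 1.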
@@ -241,7 +239,7 @@ class PaddleDisWorkerProc():
             else:
                 self.exist_task_signal.value[0] = 0
 
-            if self.parallel_config.tensor_parallel_degree > 1:
+            if self.parallel_config.tensor_parallel_size > 1:
                 # Synchronize before updating weights
                 paddle.distributed.barrier()
 
@@ -259,7 +257,7 @@ class PaddleDisWorkerProc():
                     self.fd_config.parallel_config.
                     expert_parallel_rank] = 1
 
-            if self.parallel_config.tensor_parallel_degree > 1:
+            if self.parallel_config.tensor_parallel_size > 1:
                 # Synchronize the signal for other workers
                 # TODO(@wufeisheng): Split TP group and EP group
                 paddle.distributed.barrier()
@@ -479,8 +477,8 @@ def parse_args():
     )
     parser.add_argument(
         "--speculative_benchmark_mode",
-        default="false",
-        type=str,
+        default=False,
+        type=bool,
     )
     parser.add_argument("--max_num_batched_tokens",
                         type=int,
@@ -559,7 +557,7 @@ def parse_args():
     return args
 
 
-def initialize_fd_config(config_or_args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
+def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     """Initialize FDConfig from either RolloutModelConfig or argparse.Namespace
 
     Args:
@@ -568,196 +566,37 @@ def initialize_fd_config(config_or_args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     Returns:
         FDConfig: Initialized FastDeploy configuration object
     """
-    # Get model config from model directory
-    model_config_dict, _ = ModelConfig.get_config_dict(config_or_args.model_name_or_path)
-
-    # Handle MoE related configs
-    if 'num_experts' in model_config_dict:
-        model_config_dict['moe_num_experts'] = model_config_dict.pop('num_experts')
-    if 'num_experts_per_tok' in model_config_dict:
-        model_config_dict['moe_topk'] = model_config_dict.pop('num_experts_per_tok')
-
-    # Set default values for model config
-    model_config_dict["head_dim"] = model_config_dict.get(
-        "head_dim", model_config_dict["hidden_size"] // model_config_dict["num_attention_heads"])
-    model_config_dict["rope_theta"] = model_config_dict.get("rope_theta", 10000.0)
-
-    # Create model config object
-    model_config = ModelConfig.from_dict(model_config_dict)
-    model_config.head_dim = model_config_dict["head_dim"]
-    paddle.set_default_dtype(config_or_args.dtype)
-    if 'tie_word_embeddings' in model_config_dict:
-        model_config.tie_word_embeddings = model_config_dict['tie_word_embeddings']
-
-    # Initialize all config components
-    device_config = DeviceConfig()
-    decoding_config = DecodingConfig()
-    speculative_config = SpeculativeConfig()
-    parallel_config = ParallelConfig()
-    load_config = LoadConfig()
-    moe_config = MoEConfig()
-
-    # Handle graph optimization config (check for attribute existence for backward compatibility)
-    enable_static_graph_inference = getattr(config_or_args, 'enable_static_graph_inference', False)
-    use_cudagraph = getattr(config_or_args, 'use_cudagraph', False)
-    max_capture_batch_size = getattr(config_or_args, 'max_capture_batch_size', 0)
+    paddle.set_default_dtype(args.dtype)
+    model_config = ModelConfig(vars(args))
+    device_config = DeviceConfig(vars(args))
+    decoding_config = DecodingConfig(vars(args))
+    speculative_config = SpeculativeConfig(vars(args))
+    parallel_config = ParallelConfig(vars(args))
+    load_config = LoadConfig(vars(args))
 
     graph_opt_config = GraphOptimizationConfig(
-        enable_static_graph_inference,
-        use_cudagraph,
-        max_capture_batch_size
-    )
+        args.enable_static_graph_inference,
+        args.max_capture_batch_size,
+        vars(args))
 
-    # Handle quantization (check for attribute existence)
-    model_config.quantization = getattr(config_or_args, 'quantization', None)
+    # Note(tangbinhan): used for load_checkpoint
+    model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
+    model_config.pretrained_config.tensor_parallel_degree = parallel_config.tensor_parallel_size
+    model_config.pretrained_config.is_mtp = False
+    model_config.pretrained_config.head_dim = model_config.head_dim
 
-    # Update speculative config_or_args
-    speculative_config.method = getattr(config_or_args, 'speculative_method', None)
-    speculative_config.num_speculative_tokens = getattr(config_or_args, 'speculative_max_draft_token_num', 0)
-    speculative_config.model_name_or_path = getattr(config_or_args, 'speculative_model_name_or_path', None)
-    speculative_config.quantization = getattr(config_or_args, 'speculative_model_quantization', None)
-    speculative_config.benchmark_mode = (
-        getattr(config_or_args, "speculative_benchmark_mode", "false").lower() == "true"
-    )
-
-    # Update parallel config
-    parallel_config.engine_pid = getattr(config_or_args, 'engine_pid', None)
-    parallel_config.model_name_or_path = config_or_args.model_name_or_path
-    parallel_config.max_num_seqs = getattr(config_or_args, 'max_num_seqs', 0)
-    parallel_config.max_block_num = getattr(config_or_args, 'total_block_num', 0)
-    parallel_config.block_size = getattr(config_or_args, 'block_size', 64)
-    parallel_config.pod_ip = getattr(config_or_args, 'pod_ip', None)
-    parallel_config.engine_worker_queue_port = getattr(config_or_args, 'engine_worker_queue_port', 0)
-    parallel_config.max_model_len = getattr(config_or_args, 'max_model_len', 0)
-    model_config.max_seq_len = getattr(config_or_args, 'max_model_len', 0)
-    model_config.max_length = getattr(config_or_args, 'max_model_len', 0)
-    parallel_config.device_ids = getattr(config_or_args, 'device_ids', [])
-    parallel_config.dtype = config_or_args.dtype
-    parallel_config.enc_dec_block_num = getattr(config_or_args, 'enc_dec_block_num', 0)
-    parallel_config.kv_cache_ratio = getattr(config_or_args, 'kv_cache_ratio', 1.0)
-    parallel_config.first_token_id = getattr(config_or_args, 'first_token_id', None)
-    parallel_config.gpu_memory_utilization = getattr(config_or_args, 'gpu_memory_utilization', 0.9)
-    parallel_config.do_profile = getattr(config_or_args, 'do_profile', False)
-    parallel_config.dynamic_load_weight = getattr(config_or_args, 'dynamic_load_weight', False)
-    parallel_config.pad_token_id = getattr(config_or_args, 'pad_token_id', None)
-    parallel_config.eos_tokens_lens = getattr(config_or_args, 'eos_tokens_lens', 0)
-    parallel_config.enable_chunked_prefill = getattr(config_or_args, 'enable_chunked_prefill', False)
-    parallel_config.max_num_batched_tokens = getattr(config_or_args, 'max_num_batched_tokens', 0)
-    parallel_config.enable_prefix_caching = getattr(config_or_args, 'enable_prefix_caching', False)
-    parallel_config.enable_custom_all_reduce = getattr(config_or_args, 'enable_custom_all_reduce', False)
-    parallel_config.use_ep = getattr(config_or_args, 'enable_expert_parallell', False)
-    parallel_config.tensor_parallel_degree = getattr(config_or_args, 'tensor_parallel_size', 1)
-    parallel_config.expert_parallel_degree = getattr(config_or_args, 'expert_parallel_size', 1)
-    parallel_config.splitwise_role = getattr(config_or_args, 'splitwise_role', None)
-    parallel_config.guided_decoding_backend = getattr(config_or_args, 'guided_decoding_backend', None)
-    parallel_config.disable_any_whitespace = getattr(config_or_args, 'disable_any_whitespace', False)
-
-    # Log parallel config info
-    logger.info(f"parallel_config.use_ep {parallel_config.use_ep}")
-    logger.info(f"parallel_config.tensor_parallel_degree {parallel_config.tensor_parallel_degree}")
-    logger.info(f"splitwise_role {parallel_config.splitwise_role}")
+    logger.info(
+        f"parallel_config.tensor_parallel_size {parallel_config.tensor_parallel_size}"
+    )
+    logger.info(
+        f"parallel_config.tensor_parallel_rank {parallel_config.tensor_parallel_rank}"
+    )
 
-    # Set MoE phase based on splitwise role
-    if parallel_config.splitwise_role == "mixed":
-        parallel_config.moe_phase = MoEPhase.PREFILL
-    elif parallel_config.splitwise_role == "prefill":
-        parallel_config.moe_phase = MoEPhase.PREFILL
-    elif parallel_config.splitwise_role == "decode":
-        parallel_config.moe_phase = MoEPhase.DECODER
-    elif parallel_config.splitwise_role is not None:
-        raise NotImplementedError
+    if getattr(model_config, 'num_hidden_layers', None) is None:
+        raise ValueError("num_hidden_layers is None")
 
-    # Handle model architecture specific configurations
-    num_key_value_heads = model_config_dict.get("num_key_value_heads", -1)
-    if num_key_value_heads is None:
-        num_key_value_heads = -1
-
-    # Calculate FFN hidden size
-    if model_config_dict.get("ffn_hidden_size", None) is not None:
-        ffn_hidden_size = model_config_dict["ffn_hidden_size"]
-    elif model_config_dict.get("intermediate_size", None) is not None:
-        ffn_hidden_size = model_config_dict["intermediate_size"]
-    else:
-        ffn_hidden_size = 4 * model_config_dict["hidden_size"]
-        if model_config_dict["hidden_act"].lower() == "swiglu":
-            if paddle.distributed.get_world_size() > 1:
-                multiple_of = 8 * model_config_dict["num_attention_heads"]
-            else:
-                multiple_of = 4 * model_config_dict["num_attention_heads"]
-            ffn_hidden_size = multiple_of * (
-                (int(2 * ffn_hidden_size / 3) + multiple_of - 1) //
-                multiple_of)
-
-    # Get number of layers
-    num_layers = model_config_dict.get("num_layers", None) or model_config_dict.get(
-        "num_hidden_layers", None)
-    if num_layers is None:
-        raise ValueError(f"num_layers<{num_layers}> is invalid")
-
-    if "moe_layer_start_index" in model_config_dict:
-        moe_layer_start_index = model_config_dict["moe_layer_start_index"]
-        use_moe = (
-            isinstance(moe_layer_start_index, int)
-            and moe_layer_start_index < num_layers
-        ) or (
-            isinstance(moe_layer_start_index, list)
-            and min(moe_layer_start_index) < num_layers
-        )
-    else:
-        use_moe = False
-
-    # Update model config
-    model_config.ffn_hidden_size = ffn_hidden_size
-    model_config.num_layers = num_layers
-    model_config.num_key_value_heads = num_key_value_heads
-    model_config.start_layer_index = model_config_dict.get("start_layer_index", 0)
-
-    # Update MoE config
-    moe_config.num_experts = model_config_dict.get("moe_num_experts", None)
-    moe_config.moe_intermediate_size = model_config_dict.get("moe_intermediate_size", None)
-    moe_config.top_k = model_config_dict.get("moe_k", model_config_dict.get("moe_topk", 8))
-    moe_config.moe_num_shared_experts = model_config_dict.get("moe_num_shared_experts", 0)
-    moe_config.moe_layer_start_index = model_config_dict.get("moe_layer_start_index", 0)
-    moe_config.num_max_dispatch_tokens_per_rank = model_config_dict.get(
-        "num_max_dispatch_tokens_per_rank", 256)
-    moe_config.moe_use_aux_free = model_config_dict.get("moe_use_aux_free", False)
-
-    # Handle vocabulary size
-    model_config.ori_vocab_size = model_config_dict.get("vocab_size", -1)
-    archs = model_config_dict.get("architectures", [])
-    if "Ernie4_5_ForCausalLM" in archs or "Ernie4_5_MoeForCausalLM" in archs:
-        model_config.ori_vocab_size = getattr(config_or_args, 'ori_vocab_size', model_config.ori_vocab_size)
-
-    # Handle DeepseekV3 specific config
-    if "DeepseekV3ForCausalLM" in model_config_dict.get("architectures", []):
-        from paddleformers.transformers import AutoConfig
-        model_config.deepseekv3 = AutoConfig.from_pretrained(
-            config_or_args.model_name_or_path)
-
-    assert parallel_config.tensor_parallel_degree * parallel_config.expert_parallel_degree == ranks
-
-    parallel_config.tensor_parallel_rank = \
-        local_rank % parallel_config.tensor_parallel_degree
-    parallel_config.expert_parallel_rank = \
-        int(local_rank / parallel_config.tensor_parallel_degree)
-
-    if parallel_config.use_ep:
-        moe_config.num_experts_per_rank = \
-            moe_config.num_experts // parallel_config.expert_parallel_degree
-        moe_config.num_experts_start_offset = \
-            parallel_config.expert_parallel_rank * moe_config.num_experts_per_rank
-
-    # For auto TP split
-    model_config.tensor_parallel_degree = parallel_config.tensor_parallel_degree
-    model_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
-    model_config.use_ep = parallel_config.use_ep
-
-    if parallel_config.use_ep:
-        model_config.num_experts_per_rank = moe_config.num_experts_per_rank
-        model_config.num_experts_start_offset = moe_config.num_experts_start_offset
-
-    # Handle quantization config
-    quantization_config = model_config_dict.get("quantization_config", None)
+    quantization_config = model_config.quantization_config
     if not model_config.is_quantized:
         if quantization_config is not None:
             if "kv_cache_quant_type" not in quantization_config:
@@ -772,16 +611,15 @@ def initialize_fd_config(config_or_args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
 
     if quantization_config is not None:
         quant_config_name = quantization_config["quantization"]
-    elif getattr(config_or_args, 'quantization', None) != "None":
+    elif args.quantization != "None":
         quantization_config = {}
-        quant_config_name = getattr(config_or_args, 'quantization', None)
+        quant_config_name = args.quantization
         quantization_config["quantization"] = quant_config_name
         # Special handling for Ernie models
-        is_ernie = "Ernie4_5_ForCausalLM" in model_config_dict.get("architectures", []) or \
-                   "Ernie4_5_MoeForCausalLM" in model_config_dict.get("architectures", []) or \
-                   "Ernie4_5_VLMoeForConditionalGeneration" in model_config_dict.get(
-                       "architectures", [])
-        if use_moe and quant_config_name == "wint4" and is_ernie:
+        is_ernie = "Ernie4_5_ForCausalLM" in model_config.architectures or \
+                   "Ernie4_5_MoeForCausalLM" in model_config.architectures or \
+                   "Ernie4_5_VLMoeForConditionalGeneration" in model_config.architectures
+        if quant_config_name == "wint4" and is_ernie:
             quantization_config["dense_quant_type"] = "wint8"
             quantization_config["moe_quant_type"] = "wint4"
             quantization_config["quantization"] = "mix_quant"
@@ -806,38 +644,23 @@ def initialize_fd_config(config_or_args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
             logger.info(
                 "Model Status: Original (will apply online quantization)")
 
-        logger.info(f"Quantization Method: {getattr(config_or_args, 'quantization', 'None')}")
         logger.info(f"{quantization_config}")
     else:
         logger.info(
             "No quantization config found and use original weight and act dtype."
         )
 
-    model_config.enable_logprob = config_or_args.enable_logprob
-
-    model_config.architectures = model_config_dict.get("architectures")
-
     # Update load config
     logger.info("===========load_config==============")
     # Handle load config (check for environment variable)
     load_config.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
-    load_config.dynamic_load_weight = getattr(config_or_args, 'dynamic_load_weight', False)
-    load_config.load_strategy = getattr(config_or_args, 'load_strategy', None)
     logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
     logger.info(f"- Load strategy: {load_config.load_strategy}")
     logger.info(f"- Use fastsafetensor: {load_config.use_fastsafetensor}")
 
-    # Create and return FDConfig
-    fd_config = FDConfig(
-        model_config=model_config,
-        parallel_config=parallel_config,
-        speculative_config=speculative_config,
-        device_config=device_config,
-        load_config=load_config,
-        moe_config=moe_config,
-        decoding_config=decoding_config,
-        quant_config=quant_config,
-        graph_opt_config=graph_opt_config
-    )
+    fd_config = FDConfig(model_config=model_config,
+                         parallel_config=parallel_config,
+                         speculative_config=speculative_config,
+                         device_config=device_config,
+                         load_config=load_config,
+                         decoding_config=decoding_config,
+                         quant_config=quant_config,
+                         graph_opt_config=graph_opt_config)
 
     return fd_config
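For orientation, a hypothetical call site for the simplified entry point (the ranks and local_rank values are illustrative; in the worker they come from the launch environment, not hard-coded constants):

    # Hypothetical usage sketch of the new initialize_fd_config signature shown above.
    args = parse_args()
    fd_config = initialize_fd_config(args, ranks=8, local_rank=0)
    logger.info(f"tp size: {fd_config.parallel_config.tensor_parallel_size}")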