Unify server-side and model-side Config (Part3) (#3047)

* merge model config

* fix arch

* fix rl
Author: YuanRisheng
Date: 2025-07-29 17:07:44 +08:00
Committed by: GitHub
Parent: 907d561523
Commit: 502ee92a0a
14 changed files with 116 additions and 199 deletions

View File

@@ -73,7 +73,13 @@ class ErnieArchitectures:
PRETRAINED_INIT_CONFIGURATION = {
"top_p": 1.0,
"temperature": 1.0,
"rope_theta": 10000.0,
"penalty_score": 1.0,
"frequency_score": 0.0,
"presence_score": 0.0,
"min_length": 1,
"num_key_value_heads": -1,
"start_layer_index": 0,
"moe_num_shared_experts": 0,
@@ -101,16 +107,7 @@ class ModelConfig:
self,
args,
):
# NOTE(gongshaotain): form _load_model_init_val()
self.top_p = 1.0
self.temperature = 1.0
self.rope_theta = 10000.0
self.penalty_score = 1.0
self.frequency_score = 0.0
self.presence_score = 0.0
self.min_length = 1
self.model_name_or_path = ""
self.model = ""
self.is_quantized = False
self.max_model_len = 0
self.dtype = ""
@@ -118,16 +115,13 @@ class ModelConfig:
self.enable_mm = False
self.enable_redundant_experts = False
self.redundant_experts_num = 0
self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
self.stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
self.quantization = None
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
assert self.model_name_or_path != ""
pretrained_config, _ = PretrainedConfig.get_config_dict(self.model_name_or_path)
assert self.model != ""
pretrained_config, _ = PretrainedConfig.get_config_dict(self.model)
self.pretrained_config = PretrainedConfig.from_dict(pretrained_config)
# set attribute from pretrained_config
@@ -149,6 +143,64 @@ class ModelConfig:
if ErnieArchitectures.contains_ernie_arch(self.architectures):
self.ori_vocab_size = args.get("ori_vocab_size", self.ori_vocab_size)
self.is_unified_ckpt = check_unified_ckpt(self.model)
self.override_name_from_config()
self.read_from_env()
def override_name_from_config(self):
"""
Override attribute names from the exported model's configuration.
"""
if not self.is_unified_ckpt and hasattr(self, "infer_model_mp_num"):
self.tensor_parallel_size = self.infer_model_mp_num
del self.infer_model_mp_num
if hasattr(self, "num_hidden_layers"):
if hasattr(self, "remove_tail_layer"):
if self.remove_tail_layer is True:
self.num_hidden_layers -= 1
elif isinstance(self.remove_tail_layer, int):
self.num_hidden_layers -= self.remove_tail_layer
if not hasattr(self, "mla_use_absorb"):
self.mla_use_absorb = False
def read_from_env(self):
"""
Read configuration information from environment variables and update the object's attributes.
If an attribute is not present or is an empty string in the environment variables, use the default value.
"""
self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
self.stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
def reset_config_value(key, value):
if not hasattr(self, key.lower()):
if os.getenv(key, None):
value = eval(os.getenv(key))
logger.info(f"Get parameter `{key}` = {value} from environment.")
else:
logger.info(f"Parameter `{key}` will use default value {value}.")
setattr(self, key.lower(), value)
reset_config_value("COMPRESSION_RATIO", 1.0)
reset_config_value("ROPE_THETA", 10000)
def _get_download_model(self, model_name, model_type="default"):
# TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
pass
def print(self):
"""
Print all configuration information.
"""
logger.info("Model Configuration Information :")
for k, v in self.__dict__.items():
logger.info("{:<20}:{:<6}{}".format(k, "", v))
logger.info("=============================================================")
class ParallelConfig:
"""Configuration for the distributed execution."""
@@ -173,7 +225,6 @@ class ParallelConfig:
From old version worker args
TODO(gongshaotian): Reclassify
"""
self.model_name_or_path: str = "./output"
self.max_num_seqs: int = 34
# Set default block num for profile run
self.total_block_num: int = 2000
@@ -609,7 +660,7 @@ class CacheConfig:
self.enable_hierarchical_cache = True
if self.model_cfg is not None:
if hasattr(self.model_cfg, "quantization_config"):
if self.model_cfg.quantization_config is not None:
self.cache_dtype = self.model_cfg.quantization_config.get("kv_cache_quant_type", self.cache_dtype)
if (
hasattr(self.model_cfg, "num_key_value_heads")
@@ -631,7 +682,7 @@ class CacheConfig:
else:
byte_size = 2
self.each_token_cache_space = int(
self.model_cfg.num_layers * kv_num_head * self.model_cfg.head_dim * byte_size
self.model_cfg.num_hidden_layers * kv_num_head * self.model_cfg.head_dim * byte_size
)
self.bytes_per_block = int(self.each_token_cache_space * self.block_size)
self.bytes_per_layer_per_block = int(
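With the old num_layers alias gone, the per-token cache size is computed straight from num_hidden_layers. A rough worked example of the same formula, using made-up model dimensions rather than any real checkpoint:

```python
# Hypothetical dimensions, purely to make the formula concrete.
num_hidden_layers = 28   # previously read as model_cfg.num_layers
kv_num_head = 4          # key/value heads after GQA grouping
head_dim = 128
byte_size = 2            # bf16/fp16 KV cache; 1 when the cache is int8-quantized
block_size = 64          # tokens per cache block

each_token_cache_space = num_hidden_layers * kv_num_head * head_dim * byte_size
bytes_per_block = each_token_cache_space * block_size

print(each_token_cache_space)  # 28 * 4 * 128 * 2 = 28_672 bytes per token
print(bytes_per_block)         # 28_672 * 64 = 1_835_008 bytes per block
```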

View File

@@ -22,6 +22,7 @@ from typing import Any, Dict, List, Optional
from fastdeploy.config import (
CacheConfig,
GraphOptimizationConfig,
LoadConfig,
SpeculativeConfig,
TaskOption,
)
@@ -756,18 +757,6 @@ class EngineArgs:
"""
return cls(**{field.name: getattr(args, field.name) for field in dataclass_fields(cls)})
def create_model_config(self) -> ModelConfig:
"""
Create and return a ModelConfig object based on the current settings.
"""
return ModelConfig(
model_name_or_path=self.model,
config_json_file=self.model_config_name,
quantization=self.quantization,
dynamic_load_weight=self.dynamic_load_weight,
load_strategy=self.load_strategy,
)
def create_speculative_config(self) -> SpeculativeConfig:
""" """
speculative_args = asdict(self)
@@ -826,7 +815,12 @@ class EngineArgs:
"""
Create and return a Config object based on the current settings.
"""
model_cfg = self.create_model_config()
all_dict = asdict(self)
model_cfg = ModelConfig(all_dict)
all_dict["model_cfg"] = model_cfg
cache_cfg = CacheConfig(all_dict)
load_cfg = LoadConfig(all_dict)
if not model_cfg.is_unified_ckpt and hasattr(model_cfg, "tensor_parallel_size"):
self.tensor_parallel_size = model_cfg.tensor_parallel_size
if self.max_num_batched_tokens is None:
@@ -843,16 +837,13 @@ class EngineArgs:
self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce
), "enable_custom_all_reduce must be used with tensor_parallel_size>1"
all_dict = asdict(self)
all_dict["model_cfg"] = model_cfg
cache_cfg = CacheConfig(all_dict)
return Config(
model_name_or_path=self.model,
model_config=model_cfg,
scheduler_config=scheduler_cfg,
tokenizer=self.tokenizer,
cache_config=cache_cfg,
load_config=load_cfg,
parallel_config=self.create_parallel_config(),
max_model_len=self.max_model_len,
tensor_parallel_size=self.tensor_parallel_size,
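create_model_config() is no longer needed because every sub-config is now built from the same flat asdict(self) dictionary and keeps only the keys it declares. A toy sketch of that pattern (the class names here are illustrative, not the actual fastdeploy.config types):

```python
from dataclasses import asdict, dataclass


class ToyLoadConfig:
    """Illustrative config: declares defaults, then absorbs matching keys."""

    def __init__(self, args: dict):
        self.dynamic_load_weight = False
        self.load_strategy = "ipc_snapshot"
        for key, value in args.items():
            if hasattr(self, key):        # ignore keys meant for other configs
                setattr(self, key, value)


@dataclass
class ToyEngineArgs:
    model: str = "./output"
    dynamic_load_weight: bool = True
    max_num_seqs: int = 34

    def create_engine_config(self):
        all_dict = asdict(self)           # one flat dict shared by every config
        return ToyLoadConfig(all_dict)


print(vars(ToyEngineArgs().create_engine_config()))
# {'dynamic_load_weight': True, 'load_strategy': 'ipc_snapshot'}
```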

View File

@@ -19,141 +19,10 @@ from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional
from fastdeploy import envs
from fastdeploy.config import CacheConfig
from fastdeploy.config import CacheConfig, LoadConfig, ModelConfig
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.utils import (
ceil_div,
check_unified_ckpt,
get_host_ip,
is_port_available,
llm_logger,
)
class ModelConfig:
"""
Configuration class for the model.
Attributes:
model_dir (str): Directory path to the model.
is_unified_ckpt (bool): Flag indicating if the checkpoint is unified.
model_name_or_path (str): Name or path of the model.
"""
def __init__(
self,
model_name_or_path: str,
config_json_file: str = "config.json",
dynamic_load_weight: bool = False,
load_strategy: str = "ipc_snapshot",
quantization: str = None,
download_dir: Optional[str] = None,
):
"""
Initialize the ModelConfig class.
Args:
model_name_or_path (str): Name or path of the model.
config_json_file (str): Path to the configuration JSON file. Default is 'config.json'.
download_dir (Optional[str]): Directory to download model files. Default is None.
"""
self.model_dir = model_name_or_path
self.is_unified_ckpt = check_unified_ckpt(self.model_dir)
self.dynamic_load_weight = dynamic_load_weight
self.load_strategy = load_strategy
self.quantization = quantization
config_file = os.path.join(model_name_or_path, config_json_file)
if os.path.isfile(model_name_or_path):
try:
from paddleformers.transformers import AutoConfig
config = AutoConfig.from_pretrained(model_name_or_path)
config_dict = {k: v for k, v in vars(config).items() if not k.startswith("_")}
for key, value in config_dict.items():
setattr(self, key, value)
except Exception:
llm_logger.error(
"Don't support the current model, you can use `paddleformers` to register your model."
)
raise ValueError(
"Don't support the current model, you can use `paddleformers` to register your model."
)
else:
with open(config_file, "r", encoding="utf-8") as f:
config_dict = json.load(f)
for key, value in config_dict.items():
try:
setattr(self, key, value)
except Exception:
continue
if isinstance(self.architectures, list):
self.architectures = self.architectures[0]
self.model_name_or_path = model_name_or_path
self.override_name_from_config()
self.read_from_env()
def override_name_from_config(self):
"""
Override attribute names from the exported model's configuration.
"""
if not self.is_unified_ckpt and hasattr(self, "infer_model_mp_num"):
self.tensor_parallel_size = self.infer_model_mp_num
del self.infer_model_mp_num
if hasattr(self, "num_hidden_layers"):
if hasattr(self, "remove_tail_layer"):
if self.remove_tail_layer is True:
self.num_hidden_layers -= 1
elif isinstance(self.remove_tail_layer, int):
self.num_hidden_layers -= self.remove_tail_layer
self.num_layers = self.num_hidden_layers
del self.num_hidden_layers
if not hasattr(self, "mla_use_absorb"):
self.mla_use_absorb = False
if not hasattr(self, "head_dim"):
assert hasattr(self, "hidden_size") and hasattr(self, "num_attention_heads")
self.head_dim = self.hidden_size // self.num_attention_heads
def read_from_env(self):
"""
Read configuration information from environment variables and update the object's attributes.
If an attribute is not present or is an empty string in the environment variables, use the default value.
"""
self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
self.stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
def reset_config_value(key, value):
if not hasattr(self, key.lower()):
if os.getenv(key, None):
value = eval(os.getenv(key))
llm_logger.info(f"Get parameter `{key}` = {value} from environment.")
else:
llm_logger.info(f"Parameter `{key}` will use default value {value}.")
setattr(self, key.lower(), value)
reset_config_value("COMPRESSION_RATIO", 1.0)
reset_config_value("ROPE_THETA", 10000)
def _get_download_model(self, model_name, model_type="default"):
# TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
pass
def print(self):
"""
Print all configuration information.
"""
llm_logger.info("Model Configuration Information :")
for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info("=============================================================")
from fastdeploy.utils import ceil_div, get_host_ip, is_port_available, llm_logger
class ParallelConfig:
@@ -288,6 +157,7 @@ class Config:
cache_config: CacheConfig,
scheduler_config: SchedulerConfig,
parallel_config: ParallelConfig,
load_config: LoadConfig,
commit_config: CommitConfig = CommitConfig(),
model_name_or_path: str = None,
tokenizer: str = None,
@@ -345,6 +215,7 @@ class Config:
self.cache_config = cache_config
self.scheduler_config = scheduler_config
self.parallel_config = parallel_config
self.load_config = load_config
self.commit_config = commit_config
self.model_name_or_path = model_name_or_path
self.tokenizer = tokenizer

View File

@@ -1064,7 +1064,7 @@ class LLMEngine:
f" --devices {self.cfg.device_ids} {py_script}"
f" --max_num_seqs {self.cfg.max_num_seqs} --max_model_len {self.cfg.max_model_len}"
f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}"
f" --model_name_or_path {self.cfg.model_name_or_path!s}"
f" --model {self.cfg.model_name_or_path!s}"
f" --device_ids {self.cfg.device_ids}"
f" --tensor_parallel_size {self.cfg.tensor_parallel_size}"
f" --engine_worker_queue_port {self.cfg.engine_worker_queue_port!s}"
@@ -1084,7 +1084,7 @@ class LLMEngine:
f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'"
f" --graph_optimization_config '{self.cfg.graph_optimization_config.to_json_string()}'"
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
f" --load_strategy {self.cfg.model_config.load_strategy}"
f" --load_strategy {self.cfg.load_config.load_strategy}"
)
worker_append_flag = {
@@ -1092,7 +1092,7 @@ class LLMEngine:
"enable_prefix_caching": self.cfg.cache_config.enable_prefix_caching,
"enable_chunked_prefill": self.cfg.cache_config.enable_chunked_prefill,
"do_profile": self.do_profile,
"dynamic_load_weight": self.cfg.model_config.dynamic_load_weight,
"dynamic_load_weight": self.cfg.load_config.dynamic_load_weight,
"disable_any_whitespace": self.cfg.disable_any_whitespace,
"enable_custom_all_reduce": self.cfg.parallel_config.enable_custom_all_reduce,
"enable_logprob": self.cfg.enable_logprob,
@@ -1231,9 +1231,9 @@ class LLMEngine:
elif (match := re.search(r"Start load layer (\d+)", line)) or (
match := re.search(r"set state for layer (\d+)", line)
):
progress = eval(match.group(1)) * 1.0 / self.cfg.model_config.num_layers
progress = eval(match.group(1)) * 1.0 / self.cfg.model_config.num_hidden_layers
self.worker_init_status["layer_loadding"] = progress
if self.worker_init_status["layer_loadding"] == self.cfg.model_config.num_layers - 1:
if self.worker_init_status["layer_loadding"] == self.cfg.model_config.num_hidden_layers - 1:
self.worker_init_status["finished"] = True
self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True)
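The layer-loading progress tracker now divides by num_hidden_layers, since the num_layers alias was removed from ModelConfig. A self-contained sketch of the same parse against a fabricated log line, with int() standing in for the eval() used in the diff:

```python
import re

num_hidden_layers = 32            # assumed value for the sketch
line = "Start load layer 16"      # fabricated worker log line

if (match := re.search(r"Start load layer (\d+)", line)) or (
    match := re.search(r"set state for layer (\d+)", line)
):
    # int() is enough here; the captured group is always digits.
    progress = int(match.group(1)) / num_hidden_layers
    print(f"layer loading progress: {progress:.2%}")  # -> 50.00%
```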

View File

@@ -70,7 +70,7 @@ class InputPreprocessor:
reasoning_parser_obj = None
if self.reasoning_parser:
reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser)
architectures = ModelConfig(self.model_name_or_path).architectures
architectures = ModelConfig({"model": self.model_name_or_path}).architectures[0]
if not self.enable_mm:
if not ErnieArchitectures.contains_ernie_arch(architectures):
from fastdeploy.input.text_processor import DataProcessor

View File

@@ -270,7 +270,7 @@ class BackendBase:
from transformers import AutoTokenizer, PreTrainedTokenizerFast
tokenizer = AutoTokenizer.from_pretrained(
self.fd_config.parallel_config.model_name_or_path,
self.fd_config.model_config.model,
use_fast=False,
)
@@ -289,14 +289,14 @@ class BackendBase:
for i in range(len(vocab_file_names)):
if os.path.exists(
os.path.join(
self.fd_config.parallel_config.model_name_or_path,
self.fd_config.model_config.model,
vocab_file_names[i],
)
):
ErnieBotTokenizer.vocab_files_names["vocab_file"] = vocab_file_names[i]
break
tokenizer = ErnieBotTokenizer.from_pretrained(self.fd_config.parallel_config.model_name_or_path)
tokenizer = ErnieBotTokenizer.from_pretrained(self.fd_config.model_config.model)
return tokenizer
except Exception as e:
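The tokenizer fallback keeps its shape; only the directory it probes now comes from model_config.model. A small sketch of the probe itself, where the candidate file names and the model directory are assumptions for illustration:

```python
import os

model_dir = "/path/to/model"                                   # assumed path
vocab_file_names = ["tokenizer.model", "spm.model", "vocab.txt"]  # assumed candidates

# Pick the first candidate that exists under the model directory; the real
# code then points ErnieBotTokenizer.vocab_files_names at it.
chosen = next(
    (name for name in vocab_file_names if os.path.exists(os.path.join(model_dir, name))),
    None,
)
print(chosen)  # first existing candidate, or None when nothing matches
```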

View File

@@ -85,18 +85,22 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
up_gate_proj_weight_scale.append(
get_tensor(
(
state_dict.pop(up_gate_proj_expert_weight_scale_key_name)
if up_gate_proj_expert_weight_scale_key_name in state_dict
else up_gate_proj_expert_weight_scale_key_name,
layer.fd_config.parallel_config.model_name_or_path,
else up_gate_proj_expert_weight_scale_key_name
),
layer.fd_config.model_config.model,
)
)
down_proj_weight_scale.append(
get_tensor(
(
state_dict.pop(down_proj_expert_weight_scale_key_name)
if down_proj_expert_weight_scale_key_name in state_dict
else down_proj_expert_weight_scale_key_name,
layer.fd_config.parallel_config.model_name_or_path,
else down_proj_expert_weight_scale_key_name
),
layer.fd_config.model_config.model,
)
)

View File

@@ -265,7 +265,7 @@ class FusedMoE(nn.Layer):
if up_gate_proj_expert_weight_key_name in state_dict
else up_gate_proj_expert_weight_key_name
),
self.fd_config.parallel_config.model_name_or_path,
self.fd_config.model_config.model,
)
)
down_proj_weights.append(
@@ -275,7 +275,7 @@ class FusedMoE(nn.Layer):
if down_proj_expert_weight_key_name in state_dict
else down_proj_expert_weight_key_name
),
self.fd_config.parallel_config.model_name_or_path,
self.fd_config.model_config.model,
)
)
else:
@@ -291,7 +291,7 @@ class FusedMoE(nn.Layer):
if gate_expert_weight_key_name in state_dict
else gate_expert_weight_key_name
),
self.fd_config.parallel_config.model_name_or_path,
self.fd_config.model_config.model,
)
up = get_tensor(
(
@@ -299,7 +299,7 @@ class FusedMoE(nn.Layer):
if up_expert_weight_key_name in state_dict
else up_expert_weight_key_name
),
self.fd_config.parallel_config.model_name_or_path,
self.fd_config.model_config.model,
)
up_gate_proj_weights.append(paddle.concat([gate, up], axis=-1))
down_proj_weights.append(
@@ -309,7 +309,7 @@ class FusedMoE(nn.Layer):
if down_proj_expert_weight_key_name in state_dict
else down_proj_expert_weight_key_name
),
self.fd_config.parallel_config.model_name_or_path,
self.fd_config.model_config.model,
)
)
return up_gate_proj_weights, down_proj_weights, logical_expert_ids
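Every MoE weight lookup in this file keeps the same idiom: pop the tensor when the key is already materialized in state_dict, otherwise pass the key through so it can be loaded lazily from the model directory, which is now fd_config.model_config.model. A toy sketch of the idiom with a stand-in get_tensor (the real FastDeploy helper behaves differently internally):

```python
import numpy as np


def get_tensor(value_or_key, model_dir: str):
    # Stand-in: return in-memory tensors as-is; pretend to lazily load
    # anything that is still just a key/file name.
    if isinstance(value_or_key, str):
        return f"lazily loaded `{value_or_key}` from {model_dir}"
    return value_or_key


state_dict = {"experts.0.up_gate_proj.weight": np.zeros((4, 8))}
model_dir = "./model"  # now read from fd_config.model_config.model

for key in ["experts.0.up_gate_proj.weight", "experts.1.up_gate_proj.weight"]:
    tensor = get_tensor(
        state_dict.pop(key) if key in state_dict else key,
        model_dir,
    )
    print(type(tensor).__name__)
# ndarray  (was in state_dict, popped and returned directly)
# str      (missing key falls through to lazy loading)
```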

View File

@@ -111,7 +111,7 @@ class DefaultModelLoader(BaseModelLoader):
# TODO(gongshaotian): Now, only support safetensor
model_class = MODEL_CLASSES[architectures]
state_dict = load_composite_checkpoint(
fd_config.parallel_config.model_name_or_path,
fd_config.model_config.model,
model_class,
fd_config,
return_numpy=True,

View File

@@ -82,7 +82,7 @@ class DynamicWeightManager:
def _update_ipc_snapshot(self):
"""Update using IPC snapshot strategy for elastic recovery."""
model_path = os.path.join(
self.parallel_config.model_name_or_path,
self.model_config.model,
f"model_state.tp0{self.meta_src_id}.pdparams",
)

View File

@@ -60,7 +60,7 @@ class RolloutModelConfig:
local_rank: int = 0,
):
# Required parameters
self.model_name_or_path = model_name_or_path
self.model = model_name_or_path
self.max_model_len = max_model_len
self.tensor_parallel_size = tensor_parallel_size
self.dynamic_load_weight = dynamic_load_weight

View File

@@ -73,7 +73,7 @@ class MTPProposer(Proposer):
self.model_config.architectures[0] = "Ernie4_5_MTPForCausalLM"
self.speculative_config.sharing_model = main_model
self.model_config.num_hidden_layers = 1
self.parallel_config.model_name_or_path = self.speculative_config.model
self.model_config.model = self.speculative_config.model
self.model_config.pretrained_config.prefix_name = "ernie.mtp_block"
if self.speculative_config.quantization != "":
self.model_config.quantization = self.speculative_config.quantization

View File

@@ -1484,8 +1484,8 @@ class GPUModelRunner(ModelRunnerBase):
def _init_image_preprocess(self) -> None:
processor = DataProcessor(
tokenizer_name=self.parallel_config.model_name_or_path,
image_preprocessor_name=str(self.parallel_config.model_name_or_path),
tokenizer_name=self.model_config.model,
image_preprocessor_name=str(self.model_config.model),
)
processor.eval()
image_preprocess = processor.image_preprocessor

View File

@@ -101,7 +101,7 @@ def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
def update_fd_config_for_mm(fd_config: FDConfig) -> None:
if fd_config.model_config.enable_mm:
tokenizer = ErnieBotTokenizer.from_pretrained(
fd_config.parallel_config.model_name_or_path,
fd_config.model_config.model,
model_max_length=fd_config.parallel_config.max_model_len,
padding_side="right",
use_fast=False,
@@ -439,7 +439,7 @@ def parse_args():
parser = argparse.ArgumentParser("FastDeploy LLM Inference")
parser.add_argument(
"-m",
"--model_name_or_path",
"--model",
type=str,
default="./output",
help="model dir",