mirror of https://github.com/PaddlePaddle/FastDeploy.git
(synced 2025-10-31 20:02:53 +08:00)

Compare commits (5 commits): release/2. ... release/2.
| Author | SHA1 | Date |
|---|---|---|
|  | 3ec126dc02 |  |
|  | 337d76f094 |  |
|  | ae2f78184d |  |
|  | 6851489425 |  |
|  | ea787d8f62 |  |
@@ -17,13 +17,15 @@
import paddle
import paddle.distributed as dist


@paddle.jit.marker.unified
def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
    """All-reduce the input tensor across model parallel group."""
    if paddle.in_dynamic_mode():
        hcg = dist.fleet.get_hybrid_communicate_group()
        mp_group = hcg.get_model_parallel_group()
        dist.all_reduce(input_, group=mp_group)
    else:
        dist.all_reduce(input_)
try:
    @paddle.jit.marker.unified
    def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
        """All-reduce the input tensor across model parallel group."""
        if paddle.in_dynamic_mode():
            hcg = dist.fleet.get_hybrid_communicate_group()
            mp_group = hcg.get_model_parallel_group()
            dist.all_reduce(input_, group=mp_group)
        else:
            dist.all_reduce(input_)
except Exception:
    tensor_model_parallel_all_reduce = None
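The guarded definition above means the symbol degrades to None when the `paddle.jit.marker.unified` API is unavailable, so callers must check before use. A minimal sketch of a call site (the fallback behavior here is an assumption for illustration, not part of this diff):

    # Hypothetical caller: tolerate the symbol being None on older Paddle builds.
    def all_reduce_if_available(x):
        if tensor_model_parallel_all_reduce is None:
            return x  # assumed single-card fallback
        return tensor_model_parallel_all_reduce(x)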
@@ -17,6 +17,7 @@
import json
import os
from datetime import datetime
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional

from fastdeploy import envs
@@ -467,7 +468,63 @@ class ParallelConfig:
        llm_logger.info("Parallel Configuration Information :")
        for k, v in self.__dict__.items():
            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
        llm_logger.info("==================")
        llm_logger.info(
            "=============================================================")


@dataclass
class CommitConfig:
    """
    Configuration for tracking version information from version.txt

    Attributes:
        fastdeploy_commit: Full FastDeploy git commit hash
        paddle_version: PaddlePaddle version string
        paddle_commit: PaddlePaddle git commit hash
        cuda_version: CUDA version string
        compiler_version: CXX compiler version string
    """
    fastdeploy_commit: str = ""
    paddle_version: str = ""
    paddle_commit: str = ""
    cuda_version: str = ""
    compiler_version: str = ""

    def __post_init__(self):
        """Automatically load version info when initialized"""
        self._load_from_version_file()

    def _load_from_version_file(self, file_path: str = "fastdeploy/version.txt"):
        """Internal method to load version info from file"""
        try:
            with open(file_path, 'r') as f:
                for line in f:
                    line = line.strip()
                    if line.startswith("fastdeploy GIT COMMIT ID:"):
                        self.fastdeploy_commit = line.split(":")[1].strip()
                    elif line.startswith("Paddle version:"):
                        self.paddle_version = line.split(":")[1].strip()
                    elif line.startswith("Paddle GIT COMMIT ID:"):
                        self.paddle_commit = line.split(":")[1].strip()
                    elif line.startswith("CUDA version:"):
                        self.cuda_version = line.split(":")[1].strip()
                    elif line.startswith("CXX compiler version:"):
                        self.compiler_version = line.split(":")[1].strip()
        except FileNotFoundError:
            llm_logger.info(f"Warning: Version file not found at {file_path}")
        except Exception as e:
            llm_logger.info(f"Warning: Could not read version file - {str(e)}")

    def print(self):
        """
        Print all config.
        """
        llm_logger.info("FastDeploy Commit Information :")
        for k, v in self.__dict__.items():
            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
        llm_logger.info(
            "=============================================================")

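A minimal usage sketch for the new dataclass (field values depend on the local fastdeploy/version.txt and the defaults shown above):

    # CommitConfig loads version.txt automatically via __post_init__.
    commit_cfg = CommitConfig()
    print(commit_cfg.fastdeploy_commit)  # "" if version.txt was not found
    commit_cfg.print()                   # logs every field through llm_logger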
class Config:
@@ -502,6 +559,7 @@ class Config:
        cache_config: CacheConfig,
        scheduler_config: SchedulerConfig,
        parallel_config: ParallelConfig,
        commit_config: CommitConfig = CommitConfig(),
        model_name_or_path: str = None,
        tokenizer: str = None,
        tensor_parallel_size: int = 8,
@@ -561,6 +619,7 @@ class Config:
        self.cache_config = cache_config
        self.scheduler_config = scheduler_config
        self.parallel_config = parallel_config
        self.commit_config = commit_config
        self.model_name_or_path = model_name_or_path
        self.tokenizer = tokenizer
        self.max_num_batched_tokens = max_num_batched_tokens
@@ -749,7 +808,11 @@ class Config:
            if k == "generation_config" and v is not None:
                for gck, gcv in v.to_dict().items():
                    llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
            elif k == "cache_config" or k == "model_config" or k == "scheduler_config" or k == "parallel_config":
            elif (k == "cache_config" or
                  k == "model_config" or
                  k == "scheduler_config" or
                  k == "parallel_config" or
                  k == "commit_config"):
                v.print()
            else:
                llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))

@@ -143,7 +143,7 @@ class ErnieBotTokenizer(PretrainedTokenizer):

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings) into a single string."""
        spec_init()
        self.spec_init()
        current_sub_tokens = []
        out_string = ""
        # prev_is_special = False
@@ -216,7 +216,7 @@ class ErnieBotTokenizer(PretrainedTokenizer):
        #     if isinstance(t, AddedToken)
        # )

        spec_init()
        self.spec_init()
        text, kwargs = self.prepare_for_tokenization(text, **kwargs)

        # TODO: should this be in the base class?

@@ -21,7 +21,11 @@ from dataclasses import dataclass, field
from typing import List, Optional

import paddle
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen

try:
    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
except ImportError:
    flash_attention_v3_varlen = None

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention

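This is the same optional-import pattern as the all-reduce guard above: the FA3 varlen kernel import degrades to None on Paddle builds that lack it. A sketch of the availability check a caller would need (hypothetical, not part of this diff):

    # Hypothetical guard before dispatching to the FA3 varlen kernel.
    if flash_attention_v3_varlen is None:
        raise RuntimeError(
            "flash_attention_v3_varlen is unavailable in this Paddle build; "
            "select a non-FA3 attention backend.")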
@@ -293,7 +293,7 @@ class ColumnParallelLinear(LinearBase):
        )
        if self.nranks > 0:
            # col parallel
            _set_var_distributed(self.linear_weight, split_axis=-1)
            _set_var_distributed(self.linear_weight, split_axis=1)

        self.linear_bias = None
        if self.with_bias:
@@ -304,7 +304,7 @@ class ColumnParallelLinear(LinearBase):
            )
            if self.nranks > 0:
                # col parallel
                _set_var_distributed(self.linear_bias, split_axis=-1)
                _set_var_distributed(self.linear_bias, split_axis=1)

        # smooth quant
        self.linear_shift = None

@@ -89,6 +89,7 @@ class FusedMoE(nn.Layer):
        self.routed_scaling_factor = routed_scaling_factor

        moe_quant_config = fd_config.quant_config
        self.moe_quant_type = None
        if moe_quant_config:
            self.quant_method = moe_quant_config.get_quant_method(self)
            self.moe_quant_type = moe_quant_config.name()
@@ -142,7 +143,7 @@ class FusedMoE(nn.Layer):
        if self.moe_quant_type == "fp8":
            # (TODO: gaoziyuan)
            pass
        else:
        elif self.moe_quant_type == "wint8":
            self.weight_dtype = "int8"
            self.init_weight_only_scale()

@@ -91,8 +91,11 @@ class DefaultModelLoader(BaseModelLoader):
    def load_model(self, fd_config: FDConfig) -> nn.Layer:
        context = paddle.LazyGuard()
        architectures = fd_config.model_config.architectures[0]
        # TODO(gongshaotian): Now, only support safetensor
        model_class = MODEL_CLASSES[architectures]

        if fd_config.load_config.dynamic_load_weight:
            # register rl model
            import fastdeploy.rl
            architectures = architectures + "RL"

        with context:
            model_cls = ModelRegistry.get_class(architectures)
@@ -104,6 +107,8 @@ class DefaultModelLoader(BaseModelLoader):
        if fd_config.load_config.dynamic_load_weight:
            return model

        # TODO(gongshaotian): Now, only support safetensor
        model_class = MODEL_CLASSES[architectures]
        state_dict = load_composite_checkpoint(
            fd_config.parallel_config.model_name_or_path,
            model_class,

@@ -36,8 +36,7 @@ def _find_py_files(root_dir):


def auto_models_registry(dir_path,
                         register_path="fastdeploy.model_executor.models",
                         suffix=""):
                         register_path="fastdeploy.model_executor.models"):
    """
    Automatically register all models in this folder.
    """
@@ -49,7 +48,7 @@ def auto_models_registry(dir_path,
                if inspect.isclass(attr) and issubclass(
                        attr,
                        ModelForCasualLM) and attr is not ModelForCasualLM:
                    ModelRegistry.register(attr, suffix=suffix)
                    ModelRegistry.register(attr)
        except ImportError:
            raise ImportError(f"{module_file=} import error")

@@ -28,12 +28,12 @@ class ModelRegistry:
    _registry = {}

    @classmethod
    def register(cls, model_class, suffix=""):
    def register(cls, model_class):
        """register model class"""
        if issubclass(
                model_class,
                ModelForCasualLM) and model_class is not ModelForCasualLM:
            cls._registry[f"{model_class.name()}{suffix}"] = model_class
            cls._registry[model_class.name()] = model_class
        return model_class

    @classmethod

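With the suffix argument gone, the registry key comes entirely from each subclass's name(): the RL variants below bake the "RL" suffix into name() itself, and DefaultModelLoader appends "RL" to the architecture string before lookup. A sketch of the new contract (DemoForCasualLMRL is hypothetical):

    class DemoForCasualLMRL(ModelForCasualLM):  # hypothetical subclass
        @classmethod
        def name(cls):
            # The suffix now lives in name(); no register(..., suffix="RL") needed.
            return "DemoForCasualLMRL"

    ModelRegistry.register(DemoForCasualLMRL)
    model_cls = ModelRegistry.get_class("DemoForCasualLMRL")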
@@ -302,6 +302,7 @@ class Qwen2ForCausalLM(ModelForCasualLM):
        """
        super(Qwen2ForCausalLM, self).__init__(fd_config)

        self.fd_config = fd_config
        self.model = Qwen2Model(fd_config=fd_config)

        self.ori_vocab_size = fd_config.model_config.ori_vocab_size

@@ -13,10 +13,19 @@
# limitations under the License.
"""fastdeploy gpu ops"""

import os
import sys

from fastdeploy.import_ops import import_custom_ops

PACKAGE = "fastdeploy.model_executor.ops.gpu"

import_custom_ops(PACKAGE, "..base.fastdeploy_base_ops", globals())
import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())


def tolerant_import_error():
    class NoneModule:
        def __getattr__(self, name):
            return None

    sys.modules[__name__] = NoneModule()

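tolerant_import_error() swaps this module out of sys.modules for a stub whose attribute lookups all resolve to None, so importing an op on a build without compiled GPU kernels fails soft instead of raising. A sketch of the resulting behavior (illustrative; the op name is made up):

    # After tolerant_import_error() has run somewhere during import:
    import fastdeploy.model_executor.ops.gpu as gpu_ops
    kernel = gpu_ops.some_missing_kernel  # None instead of AttributeError
    if kernel is None:
        pass  # caller picks a non-GPU code path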
@@ -17,4 +17,4 @@ import os

from fastdeploy.model_executor.models import auto_models_registry

auto_models_registry(os.path.dirname(__file__), "fastdeploy.rl", suffix="RL")
auto_models_registry(os.path.dirname(__file__), "fastdeploy.rl")


fastdeploy/rl/rollout_config.py (new file, 108 lines)
@@ -0,0 +1,108 @@
"""
# Copyright (c) 2025  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from fastdeploy.worker.worker_process import initialize_fd_config


class RolloutModelConfig:
    def __init__(
        self,
        model_name_or_path: str,
        max_model_len: int = 32768,
        tensor_parallel_size: int = 4,
        dynamic_load_weight: bool = True,
        load_strategy: str = "meta",
        enable_mm: bool = False,
        # Default values for all other parameters
        max_num_seqs: int = 34,
        total_block_num: int = 2000,
        block_size: int = 64,
        engine_worker_queue_port: int = 9923,
        device_ids: str = "0",
        dtype: str = "bfloat16",
        enc_dec_block_num: int = 1,
        kv_cache_ratio: float = 0.7,
        first_token_id: int = 1,
        gpu_memory_utilization: float = 0.9,
        engine_pid: int = None,
        do_profile: bool = False,
        pad_token_id: int = -1,
        eos_tokens_lens: int = 2,
        enable_chunked_prefill: bool = False,
        speculative_method: str = None,
        speculative_max_draft_token_num: int = 1,
        speculative_model_name_or_path: str = "",
        speculative_model_quantization: str = "WINT8",
        max_num_batched_tokens: int = 2048,
        enable_prefix_caching: bool = False,
        splitwise_role: str = "mixed",
        expert_parallel_size: int = 1,
        enable_expert_parallell: bool = False,
        ori_vocab_size: int = None,
        quantization: str = "None",
        enable_static_graph_inference: bool = False,
        use_cudagraph: bool = False,
        max_capture_batch_size: int = 64,
        guided_decoding_backend: str = "off",
        disable_any_whitespace: bool = True,
    ):
        # Required parameters
        self.model_name_or_path = model_name_or_path
        self.max_model_len = max_model_len
        self.tensor_parallel_size = tensor_parallel_size
        self.dynamic_load_weight = dynamic_load_weight
        self.load_strategy = load_strategy
        self.enable_mm = enable_mm

        # Optional parameters with defaults
        self.max_num_seqs = max_num_seqs
        self.total_block_num = total_block_num
        self.block_size = block_size
        self.engine_worker_queue_port = engine_worker_queue_port
        self.device_ids = device_ids
        self.dtype = dtype
        self.enc_dec_block_num = enc_dec_block_num
        self.kv_cache_ratio = kv_cache_ratio
        self.first_token_id = first_token_id
        self.gpu_memory_utilization = gpu_memory_utilization
        self.engine_pid = engine_pid
        self.do_profile = do_profile
        self.pad_token_id = pad_token_id
        self.eos_tokens_lens = eos_tokens_lens
        self.enable_chunked_prefill = enable_chunked_prefill
        self.speculative_method = speculative_method
        self.speculative_max_draft_token_num = speculative_max_draft_token_num
        self.speculative_model_name_or_path = speculative_model_name_or_path
        self.speculative_model_quantization = speculative_model_quantization
        self.max_num_batched_tokens = max_num_batched_tokens
        self.enable_prefix_caching = enable_prefix_caching
        self.splitwise_role = splitwise_role
        self.expert_parallel_size = expert_parallel_size
        self.enable_expert_parallell = enable_expert_parallell
        self.ori_vocab_size = ori_vocab_size
        self.quantization = quantization
        self.enable_static_graph_inference = enable_static_graph_inference
        self.use_cudagraph = use_cudagraph
        self.max_capture_batch_size = max_capture_batch_size
        self.guided_decoding_backend = guided_decoding_backend
        self.disable_any_whitespace = disable_any_whitespace

    def __str__(self):
        return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())

    def initialize(self):
        """Initialize the final fd config"""
        return initialize_fd_config(self)
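Typical use from the RL side (a sketch; the model path is a placeholder):

    # Build the final FDConfig for rollout from keyword overrides.
    rollout_cfg = RolloutModelConfig(
        model_name_or_path="/path/to/model",  # placeholder
        tensor_parallel_size=4,
    )
    print(rollout_cfg)                    # __str__ lists every field, one per line
    fd_config = rollout_cfg.initialize()  # delegates to initialize_fd_config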
@@ -24,25 +24,18 @@ from fastdeploy.config import FDConfig
from fastdeploy.model_executor.model_loader import ModelRegistry
from fastdeploy.model_executor.models.ernie4_5_moe import \
    Ernie4_5_MoeForCausalLM
from fastdeploy.model_executor.models.qwen2 import Qwen2PretrainedModel
from fastdeploy.model_executor.models.qwen3 import Qwen3PretrainedModel
from fastdeploy.model_executor.models.qwen3moe import Qwen3MoePretrainedModel
from fastdeploy.model_executor.models.qwen2 import Qwen2ForCausalLM
from fastdeploy.model_executor.models.qwen3 import Qwen3ForCausalLM
from fastdeploy.model_executor.models.qwen3moe import Qwen3MoeForCausalLM
from fastdeploy.rl.rollout_config import RolloutModelConfig

RL_MODEL_CLASSES = {
    "Ernie4_5_MoeForCausalLMRL": Ernie4_5_MoeForCausalLM,
    "Qwen2ForCausalLMRL": Qwen2PretrainedModel,
    "Qwen3ForCausalLMRL": Qwen3PretrainedModel,
    "Qwen3MoeForCausalLMRL": Qwen3MoePretrainedModel,
}


class RollOutModel(nn.Layer):
class RolloutModel(nn.Layer):
    """Main model class for rollout operations, supports multimodal components for train."""

    def __init__(self, fd_config: FDConfig):
    def __init__(self, rollout_model_config: RolloutModelConfig):
        """Initialize with FastDeploy configuration."""
        super(RollOutModel, self).__init__()
        self.fd_config = fd_config
        super(RolloutModel, self).__init__()
        self.fd_config = rollout_model_config.initialize()
        self._init_models()

    def _init_models(self):
@@ -90,9 +83,9 @@ class RollOutModel(nn.Layer):
        all_params = {}
        for model in self.rollout_models:
            for name, param in model.state_dict().items():
                logger.debug(
                    f"Model param: {name}, shape={param.shape}, dtype={param.dtype}"
                )
                # logger.debug(
                #     f"Model param: {name}, shape={param.shape}, dtype={param.dtype}"
                # )
                all_params[name] = param
        return all_params

@@ -123,11 +116,13 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
        # Initialize mapping dictionary
        infer_to_train = {}

        infer_base_name = "model"
        train_base_name = "ernie"
        # Static mappings (non-layer specific)
        static_mappings = {
            "model.embeddings.word_embeddings.weight":
            "ernie.embed_tokens.weight",
            "model.norm.ln_weight": "ernie.norm.weight",
            f"{infer_base_name}.embeddings.word_embeddings.weight":
            f"{train_base_name}.embed_tokens.weight",
            f"{infer_base_name}.norm.ln_weight": f"{train_base_name}.norm.weight",
            "lm_head.out_linear.weight": "lm_head.weight"
        }
        if self.fd_config.model_config.get("weight_sharing", False):
@@ -135,53 +130,55 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
            logger.debug("enable tie_word_embeddings")
            static_mappings.pop("lm_head.out_linear.weight")
        infer_to_train.update(static_mappings)
        infer_base_name = "model.hidden_layers"

        infer_base_name = infer_base_name + ".hidden_layers"
        train_base_name = train_base_name + ".layers"

        # Helper function to add layer mappings
        def _add_layer_mappings(layer_idx, is_moe_layer=False):
            # Handle special case for layer 0's input layernorm
            for ph in place_holders:
                infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}"
                train_key = f"ernie.layers.{layer_idx}.input_layernorm.{ph}"
                train_key = f"{train_base_name}.{layer_idx}.input_layernorm.{ph}"
                infer_to_train[infer_key] = train_key

            # Common attention mappings
            for ph in place_holders:
                infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \
                    f"ernie.layers.{layer_idx}.self_attn.qkv_proj.{ph}"
                    f"{train_base_name}.{layer_idx}.self_attn.qkv_proj.{ph}"

                infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \
                    f"ernie.layers.{layer_idx}.self_attn.o_proj.{ph}"
                    f"{train_base_name}.{layer_idx}.self_attn.o_proj.{ph}"

            # Post-attention layernorm
            for ph in place_holders:
                infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \
                    f"ernie.layers.{layer_idx}.post_attention_layernorm.{ph}"
                    f"{train_base_name}.{layer_idx}.post_attention_layernorm.{ph}"

            if not is_moe_layer:
                # Dense FFN mappings
                for ph in place_holders:
                    infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \
                        f"ernie.layers.{layer_idx}.mlp.up_gate_proj.{ph}"
                        f"{train_base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"

                    infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \
                        f"ernie.layers.{layer_idx}.mlp.down_proj.{ph}"
                        f"{train_base_name}.{layer_idx}.mlp.down_proj.{ph}"
            else:
                # MoE specific mappings
                infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_weight"] = \
                    f"ernie.layers.{layer_idx}.mlp.gate.weight"
                    f"{train_base_name}.{layer_idx}.mlp.gate.weight"

                if self.fd_config.moe_config.moe_use_aux_free:
                    infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
                        f"ernie.layers.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
                        f"{train_base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"

                # Support shared experts
                if self.fd_config.model_config.get(
                        "moe_num_shared_experts") > 0:
                    infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.gate_up_proj.linear_weight"] = \
                        f"ernie.layers.{layer_idx}.mlp.shared_experts.up_gate_proj.weight"
                        f"{train_base_name}.{layer_idx}.mlp.shared_experts.up_gate_proj.weight"
                    infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.down_proj.linear_weight"] = \
                        f"ernie.layers.{layer_idx}.mlp.shared_experts.down_proj.weight"
                        f"{train_base_name}.{layer_idx}.mlp.shared_experts.down_proj.weight"

                # MoE experts mappings
                for expert_idx in range(self.fd_config.moe_config.num_experts):
@@ -191,7 +188,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
                        if ffn1_key not in infer_to_train:
                            infer_to_train[ffn1_key] = []
                        infer_to_train[ffn1_key].append(
                            f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
                            f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
                        )

                        # FFN2 (down_proj)
@@ -199,7 +196,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
                        if ffn2_key not in infer_to_train:
                            infer_to_train[ffn2_key] = []
                        infer_to_train[ffn2_key].append(
                            f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
                            f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
                        )

        # Process non-MoE layers
@@ -213,3 +210,118 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
            _add_layer_mappings(layer_idx, is_moe_layer=True)

        return infer_to_train

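For illustration, derived from the refactored mapping code above, layer 0 of a dense Ernie layer now produces entries such as:

    # "model.hidden_layers.0.self_attn.qkv_proj.linear_weight"
    #     -> "ernie.layers.0.self_attn.qkv_proj.weight"
    # "model.norm.ln_weight" -> "ernie.norm.weight"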
class Qwen2ForCausalLMRL(Qwen2ForCausalLM):
    """
    Qwen2ForCausalLMRL
    """

    def __init__(self, fd_config: FDConfig):
        """
        Args:
            fd_config (FDConfig): Configurations for the LLM model.
        """
        super(Qwen2ForCausalLMRL, self).__init__(fd_config)

    @classmethod
    def name(cls):
        """name"""
        return "Qwen2ForCausalLMRL"

    def get_name_mappings_to_training(self):
        """Generate mapping between inference and training parameters for RL (do not delete!)."""
        # Prepare placeholders
        place_holders = ["weight"]

        # Initialize mapping dictionary
        infer_to_train = {}

        infer_base_name = "model"
        train_base_name = "qwen2"
        # Static mappings (non-layer specific)
        static_mappings = {
            f"{infer_base_name}.embeddings.word_embeddings.weight":
            f"{train_base_name}.embed_tokens.weight",
            f"{infer_base_name}.norm.ln_weight": f"{train_base_name}.norm.weight",
            "lm_head.out_linear.weight": "lm_head.weight"
        }
        infer_to_train.update(static_mappings)

        infer_base_name = infer_base_name + ".layers"
        train_base_name = train_base_name + ".layers"

        # Helper function to add layer mappings
        def _add_layer_mappings(layer_idx):
            # Handle special case for layer 0's input layernorm and attn o_proj
            for ph in place_holders:
                infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}"
                train_key = f"{train_base_name}.{layer_idx}.input_layernorm.{ph}"
                infer_to_train[infer_key] = train_key

                infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \
                    f"{train_base_name}.{layer_idx}.self_attn.o_proj.{ph}"

            # qwen qkv proj needs bias
            for ph in ["weight", "bias"]:
                infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \
                    f"{train_base_name}.{layer_idx}.self_attn.qkv_proj.{ph}"

            # Post-attention layernorm
            for ph in place_holders:
                infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \
                    f"{train_base_name}.{layer_idx}.post_attention_layernorm.{ph}"

            # FFN mappings
            for ph in place_holders:
                infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \
                    f"{train_base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"

                infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \
                    f"{train_base_name}.{layer_idx}.mlp.down_proj.{ph}"

        for layer_idx in range(
                self.fd_config.model_config.num_layers):
            _add_layer_mappings(layer_idx)

        return infer_to_train

class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
    """
    Qwen3MoeForCausalLMRL
    """

    def __init__(self, fd_config: FDConfig):
        """
        Args:
            fd_config (FDConfig): Configurations for the LLM model.
        """
        super(Qwen3MoeForCausalLMRL, self).__init__(fd_config)

    @classmethod
    def name(cls):
        """name"""
        return "Qwen3MoeForCausalLMRL"

    def get_name_mappings_to_training(self):
        """Generate mapping between inference and training parameters for RL (do not delete!)."""
        pass


class Qwen3ForCausalLMRL(Qwen3ForCausalLM):
    """
    Qwen3ForCausalLMRL
    """

    def __init__(self, fd_config: FDConfig):
        """
        Args:
            fd_config (FDConfig): Configurations for the LLM model.
        """
        super(Qwen3ForCausalLMRL, self).__init__(fd_config)

    @classmethod
    def name(cls):
        """name"""
        return "Qwen3ForCausalLMRL"
@@ -511,7 +511,7 @@ def parse_args():

    parser.add_argument("--quantization",
                        type=str,
                        default="",
                        default="None",
                        help="Quantization name for the model, currently support " \
                            "'wint4', 'wint8'," \
                            "default is None. The priority of this configuration "\
@@ -555,148 +555,169 @@ def parse_args():
    return args


def initialize_fd_config(args: argparse.Namespace) -> FDConfig:
    """Initialize FDConfig
    TODO(gongshaotian): Unified all configs to FDConfig
def initialize_fd_config(config) -> FDConfig:
    """Initialize FDConfig from either RolloutModelConfig or argparse.Namespace

    Args:
        config: Configuration object containing all parameters (either RolloutModelConfig or argparse.Namespace)

    Returns:
        FDConfig: Initialized FastDeploy configuration object
    """
    # NOTE(gongshaotian): From build stream line model
    config, _ = ModelConfig.get_config_dict(args.model_name_or_path)
    if 'num_experts' in config:
        config['moe_num_experts'] = config.pop('num_experts')
    # Get model config from model directory
    model_config_dict, _ = ModelConfig.get_config_dict(config.model_name_or_path)

    if 'num_experts_per_tok' in config:
        config['moe_topk'] = config.pop('num_experts_per_tok')
    config["head_dim"] = config.get(
        "head_dim", config["hidden_size"] // config["num_attention_heads"])
    config["rope_theta"] = config.get("rope_theta", 10000.0)
    model_config = ModelConfig.from_dict(config)
    # TODO Set `head_dim` again. Because `ModelConfig` class doesn't support feeding head_dim at all!
    model_config.head_dim = config["head_dim"]
    paddle.set_default_dtype(args.dtype)
    # Handle MoE related configs
    if 'num_experts' in model_config_dict:
        model_config_dict['moe_num_experts'] = model_config_dict.pop('num_experts')
    if 'num_experts_per_tok' in model_config_dict:
        model_config_dict['moe_topk'] = model_config_dict.pop('num_experts_per_tok')

    # Set default values for model config
    model_config_dict["head_dim"] = model_config_dict.get(
        "head_dim", model_config_dict["hidden_size"] // model_config_dict["num_attention_heads"])
    model_config_dict["rope_theta"] = model_config_dict.get("rope_theta", 10000.0)

    # Create model config object
    model_config = ModelConfig.from_dict(model_config_dict)
    model_config.head_dim = model_config_dict["head_dim"]
    paddle.set_default_dtype(config.dtype)

    # Initialize all config components
    device_config = DeviceConfig()
    # model_config = ModelConfig()

    decoding_config = DecodingConfig()

    speculative_config = SpeculativeConfig()
    parallel_config = ParallelConfig()
    load_config = LoadConfig()
    moe_config = MoEConfig()
    graph_opt_config = GraphOptimizationConfig(
        args.enable_static_graph_inference, args.use_cudagraph,
        args.max_capture_batch_size)
    model_config.quantization = args.quantization

    # Update speculate config
    speculative_config.method = args.speculative_method
    speculative_config.num_speculative_tokens = args.speculative_max_draft_token_num
    speculative_config.model_name_or_path = args.speculative_model_name_or_path
    speculative_config.quantization = args.speculative_model_quantization
    # Handle graph optimization config (check for attribute existence for backward compatibility)
    enable_static_graph_inference = getattr(config, 'enable_static_graph_inference', False)
    use_cudagraph = getattr(config, 'use_cudagraph', False)
    max_capture_batch_size = getattr(config, 'max_capture_batch_size', 0)

    graph_opt_config = GraphOptimizationConfig(
        enable_static_graph_inference,
        use_cudagraph,
        max_capture_batch_size
    )

    # Handle quantization (check for attribute existence)
    model_config.quantization = getattr(config, 'quantization', None)

    # Update speculative config
    speculative_config.method = getattr(config, 'speculative_method', None)
    speculative_config.num_speculative_tokens = getattr(config, 'speculative_max_draft_token_num', 0)
    speculative_config.model_name_or_path = getattr(config, 'speculative_model_name_or_path', None)
    speculative_config.quantization = getattr(config, 'speculative_model_quantization', None)

    # Update parallel config
    parallel_config.engine_pid = args.engine_pid
    parallel_config.model_name_or_path = args.model_name_or_path
    parallel_config.max_num_seqs = args.max_num_seqs
    parallel_config.max_block_num = args.total_block_num
    parallel_config.block_size = args.block_size
    parallel_config.engine_worker_queue_port = args.engine_worker_queue_port
    parallel_config.max_model_len = args.max_model_len
    model_config.max_seq_len = args.max_model_len
    model_config.max_length = args.max_model_len
    parallel_config.device_ids = args.device_ids
    parallel_config.dtype = args.dtype
    parallel_config.enc_dec_block_num = args.enc_dec_block_num
    parallel_config.kv_cache_ratio = args.kv_cache_ratio
    parallel_config.first_token_id = args.first_token_id
    parallel_config.gpu_memory_utilization = args.gpu_memory_utilization
    parallel_config.engine_pid = args.engine_pid
    parallel_config.do_profile = args.do_profile
    parallel_config.dynamic_load_weight = args.dynamic_load_weight
    parallel_config.pad_token_id = args.pad_token_id
    parallel_config.eos_tokens_lens = args.eos_tokens_lens
    parallel_config.enable_chunked_prefill = args.enable_chunked_prefill
    parallel_config.max_num_batched_tokens = args.max_num_batched_tokens
    parallel_config.enable_prefix_caching = args.enable_prefix_caching
    parallel_config.engine_pid = getattr(config, 'engine_pid', None)
    parallel_config.model_name_or_path = config.model_name_or_path
    parallel_config.max_num_seqs = getattr(config, 'max_num_seqs', 0)
    parallel_config.max_block_num = getattr(config, 'total_block_num', 0)
    parallel_config.block_size = getattr(config, 'block_size', 0)
    parallel_config.engine_worker_queue_port = getattr(config, 'engine_worker_queue_port', 0)
    parallel_config.max_model_len = getattr(config, 'max_model_len', 0)
    model_config.max_seq_len = getattr(config, 'max_model_len', 0)
    model_config.max_length = getattr(config, 'max_model_len', 0)
    parallel_config.device_ids = getattr(config, 'device_ids', [])
    parallel_config.dtype = config.dtype
    parallel_config.enc_dec_block_num = getattr(config, 'enc_dec_block_num', 0)
    parallel_config.kv_cache_ratio = getattr(config, 'kv_cache_ratio', 1.0)
    parallel_config.first_token_id = getattr(config, 'first_token_id', None)
    parallel_config.gpu_memory_utilization = getattr(config, 'gpu_memory_utilization', 0.9)
    parallel_config.engine_pid = getattr(config, 'engine_pid', None)
    parallel_config.do_profile = getattr(config, 'do_profile', False)
    parallel_config.dynamic_load_weight = getattr(config, 'dynamic_load_weight', False)
    parallel_config.pad_token_id = getattr(config, 'pad_token_id', None)
    parallel_config.eos_tokens_lens = getattr(config, 'eos_tokens_lens', 0)
    parallel_config.enable_chunked_prefill = getattr(config, 'enable_chunked_prefill', False)
    parallel_config.max_num_batched_tokens = getattr(config, 'max_num_batched_tokens', 0)
    parallel_config.enable_prefix_caching = getattr(config, 'enable_prefix_caching', False)
    parallel_config.use_ep = getattr(config, 'enable_expert_parallell', False)
    parallel_config.tensor_parallel_degree = getattr(config, 'tensor_parallel_size', 1)
    parallel_config.expert_parallel_degree = getattr(config, 'expert_parallel_size', 1)
    parallel_config.splitwise_role = getattr(config, 'splitwise_role', None)
    parallel_config.guided_decoding_backend = getattr(config, 'guided_decoding_backend', None)
    parallel_config.disable_any_whitespace = getattr(config, 'disable_any_whitespace', False)

    parallel_config.use_ep = args.enable_expert_parallell
    parallel_config.tensor_parallel_degree = args.tensor_parallel_size
    parallel_config.expert_parallel_degree = args.expert_parallel_size
    parallel_config.splitwise_role = args.splitwise_role
    # Handle load config (check for environment variable)
    load_config.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1

    parallel_config.guided_decoding_backend = args.guided_decoding_backend
    parallel_config.disable_any_whitespace = args.disable_any_whitespace

    # Log parallel config info
    logger.info(f"parallel_config.use_ep {parallel_config.use_ep}")
    logger.info(
        f"parallel_config.tensor_parallel_degree {parallel_config.tensor_parallel_degree}"
    )
    logger.info(f"args.splitwise_role {args.splitwise_role}")
    logger.info(f"parallel_config.tensor_parallel_degree {parallel_config.tensor_parallel_degree}")
    logger.info(f"splitwise_role {parallel_config.splitwise_role}")

    if args.splitwise_role == "mixed":
    # Set MoE phase based on splitwise role
    if parallel_config.splitwise_role == "mixed":
        parallel_config.moe_phase = MoEPhase.PREFILL
    elif args.splitwise_role == "prefill":
    elif parallel_config.splitwise_role == "prefill":
        parallel_config.moe_phase = MoEPhase.PREFILL
    elif args.splitwise_role == "decode":
    elif parallel_config.splitwise_role == "decode":
        parallel_config.moe_phase = MoEPhase.DECODER
    else:
    elif parallel_config.splitwise_role is not None:
        raise NotImplementedError

    num_key_value_heads = config.get("num_key_value_heads", -1)
    # Handle model architecture specific configurations
    num_key_value_heads = model_config_dict.get("num_key_value_heads", -1)
    if num_key_value_heads is None:
        num_key_value_heads = -1

    if config.get("ffn_hidden_size", None) is not None:
        ffn_hidden_size = config["ffn_hidden_size"]
    elif config.get("intermediate_size", None) is not None:
        ffn_hidden_size = config["intermediate_size"]
    # Calculate FFN hidden size
    if model_config_dict.get("ffn_hidden_size", None) is not None:
        ffn_hidden_size = model_config_dict["ffn_hidden_size"]
    elif model_config_dict.get("intermediate_size", None) is not None:
        ffn_hidden_size = model_config_dict["intermediate_size"]
    else:
        ffn_hidden_size = 4 * config["hidden_size"]
        if config["hidden_act"].lower() == "swiglu":
        ffn_hidden_size = 4 * model_config_dict["hidden_size"]
        if model_config_dict["hidden_act"].lower() == "swiglu":
            if paddle.distributed.get_world_size() > 1:
                multiple_of = 8 * config["num_attention_heads"]
                multiple_of = 8 * model_config_dict["num_attention_heads"]
            else:
                multiple_of = 4 * config["num_attention_heads"]
                multiple_of = 4 * model_config_dict["num_attention_heads"]
            ffn_hidden_size = multiple_of * (
                (int(2 * ffn_hidden_size / 3) + multiple_of - 1) //
                multiple_of)

    num_layers = config.get("num_layers", None) or config.get(
    # Get number of layers
    num_layers = model_config_dict.get("num_layers", None) or model_config_dict.get(
        "num_hidden_layers", None)
    if num_layers is None:
        raise ValueError(f"num_layers<{num_layers}> is invalid")

    use_moe = config.get("moe_layer_start_index", num_layers) < num_layers
    use_moe = model_config_dict.get("moe_layer_start_index", num_layers) < num_layers

    # Update model config
    model_config.ffn_hidden_size = ffn_hidden_size
    model_config.num_layers = num_layers

    model_config.num_key_value_heads = num_key_value_heads
    model_config.start_layer_index = config.get("start_layer_index", 0)
    moe_config.num_experts = config.get("moe_num_experts", None)
    moe_config.moe_intermediate_size = config.get("moe_intermediate_size",
                                                  None)
    moe_config.top_k = config.get("moe_k", config.get("moe_topk", 8))
    moe_config.moe_num_shared_experts = config.get("moe_num_shared_experts", 0)
    moe_config.moe_layer_start_index = config.get("moe_layer_start_index", 0)
    model_config.start_layer_index = model_config_dict.get("start_layer_index", 0)

    moe_config.num_max_dispatch_tokens_per_rank = config.get(
    # Update MoE config
    moe_config.num_experts = model_config_dict.get("moe_num_experts", None)
    moe_config.moe_intermediate_size = model_config_dict.get("moe_intermediate_size", None)
    moe_config.top_k = model_config_dict.get("moe_k", model_config_dict.get("moe_topk", 8))
    moe_config.moe_num_shared_experts = model_config_dict.get("moe_num_shared_experts", 0)
    moe_config.moe_layer_start_index = model_config_dict.get("moe_layer_start_index", 0)
    moe_config.num_max_dispatch_tokens_per_rank = model_config_dict.get(
        "num_max_dispatch_tokens_per_rank", 256)
    moe_config.moe_use_aux_free = config.get("moe_use_aux_free", False)
    moe_config.moe_use_aux_free = model_config_dict.get("moe_use_aux_free", False)

    model_config.ori_vocab_size = config.get("vocab_size", -1)
    if "Ernie4_5_ForCausalLM" in config.get("architectures"):
        model_config.ori_vocab_size = args.ori_vocab_size
    # Handle vocabulary size
    model_config.ori_vocab_size = model_config_dict.get("vocab_size", -1)
    if "Ernie4_5_ForCausalLM" in model_config_dict.get("architectures", []):
        model_config.ori_vocab_size = getattr(config, 'ori_vocab_size', model_config.ori_vocab_size)

    if "DeepseekV3ForCausalLM" in config.get("architectures"):
    # Handle DeepseekV3 specific config
    if "DeepseekV3ForCausalLM" in model_config_dict.get("architectures", []):
        from paddleformers.transformers import AutoConfig
        model_config.deepseekv3 = AutoConfig.from_pretrained(
            args.model_name_or_path)
            config.model_name_or_path)

    # TODO(@yuanrisheng): kv_cache quant config can only be
    # stored in model config file, which should be unified
    quantization_config = config.get("quantization_config", None)
    # Handle quantization config
    quantization_config = model_config_dict.get("quantization_config", None)
    if not model_config.is_quantized:
        if quantization_config is not None:
            if "kv_cache_quant_type" not in quantization_config:
@@ -711,13 +732,13 @@ def initialize_fd_config(args: argparse.Namespace) -> FDConfig:

    if quantization_config is not None:
        quant_config_name = quantization_config["quantization"]
    elif args.quantization != "None":
    elif getattr(config, 'quantization', None) != "None":
        quantization_config = {}
        quant_config_name = args.quantization
        quant_config_name = getattr(config, 'quantization', None)
        quantization_config["quantization"] = quant_config_name
        # use some trick code for ernie model and will unify it in future.
        is_ernie = "Ernie4_5_ForCausalLM" in config.get("architectures") or \
                    "Ernie4_5_MoeForCausalLM" in config.get("architectures")
        # Special handling for Ernie models
        is_ernie = "Ernie4_5_ForCausalLM" in model_config_dict.get("architectures", []) or \
                   "Ernie4_5_MoeForCausalLM" in model_config_dict.get("architectures", [])
        if use_moe and quant_config_name == "wint4" and is_ernie:
            quantization_config["dense_quant_type"] = "wint8"
            quantization_config["moe_quant_type"] = "wint4"
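Illustration of the branch above: an Ernie MoE checkpoint launched with quantization "wint4" ends up with

    # quantization_config == {
    #     "quantization": "wint4",
    #     "dense_quant_type": "wint8",
    #     "moe_quant_type": "wint4",
    # }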
@@ -732,6 +753,7 @@ def initialize_fd_config(args: argparse.Namespace) -> FDConfig:
        quant_cls = get_quantization_config(quant_config_name)
        quant_config = quant_cls.from_config(quantization_config)

    # Log quantization info
    logger.info("===========quantization_config==============")
    if quant_config is not None:
        if model_config.is_quantized:
@@ -742,29 +764,33 @@ def initialize_fd_config(args: argparse.Namespace) -> FDConfig:
            logger.info(
                "Model Status: Original (will apply online quantization)")

        logger.info(f"Quantization Method: {args.quantization or 'None'}")
        logger.info(f"Quantization Method: {getattr(config, 'quantization', 'None')}")
    else:
        logger.info(
            "No quantization config found and use original weight and act dtype."
        )

    model_config.architectures = config.get("architectures")
    model_config.architectures = model_config_dict.get("architectures")

    # Update load config
    logger.info("===========load_config==============")
    load_config.dynamic_load_weight = args.dynamic_load_weight
    load_config.load_strategy = args.load_strategy
    load_config.dynamic_load_weight = getattr(config, 'dynamic_load_weight', False)
    load_config.load_strategy = getattr(config, 'load_strategy', None)
    logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
    logger.info(f"- Load strategy: {load_config.load_strategy}")

    fd_config = FDConfig(model_config=model_config,
                         parallel_config=parallel_config,
                         speculative_config=speculative_config,
                         device_config=device_config,
                         load_config=load_config,
                         moe_config=moe_config,
                         decoding_config=decoding_config,
                         quant_config=quant_config,
                         graph_opt_config=graph_opt_config)
    # Create and return FDConfig
    fd_config = FDConfig(
        model_config=model_config,
        parallel_config=parallel_config,
        speculative_config=speculative_config,
        device_config=device_config,
        load_config=load_config,
        moe_config=moe_config,
        decoding_config=decoding_config,
        quant_config=quant_config,
        graph_opt_config=graph_opt_config
    )

    return fd_config

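After this refactor both entry points funnel through the same builder; a sketch of the two call paths (illustrative; the model path is a placeholder):

    # CLI worker path:
    fd_config = initialize_fd_config(parse_args())

    # RL rollout path:
    fd_config = initialize_fd_config(
        RolloutModelConfig(model_name_or_path="/path/to/model"))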