From 99a70fc72214e00b5e6862f2b5bfdcf0de003a3d Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Wed, 30 Jul 2025 11:41:23 +0800
Subject: [PATCH] unify parallel config (#3070)

---
 fastdeploy/config.py            |  71 ++++++++++++++++++++-
 fastdeploy/engine/args_utils.py |  26 +++-----
 fastdeploy/engine/config.py     | 109 ++------------------------------
 3 files changed, 87 insertions(+), 119 deletions(-)

diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 83c18b512..0a09d1908 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -218,6 +218,9 @@ class ParallelConfig:
         self.tensor_parallel_size = 1  # TP degree
         self.expert_parallel_rank = 0  # EP rank ID
         self.expert_parallel_size = 1  # EP degree
+        self.data_parallel_size = 1  # DP degree
+        self.enable_expert_parallel = False
+        self.local_data_parallel_id = 0
         # The embedding weight distributed on your gpu cards is divided by row or column.
         # Defaults to False means divide by row. When vocab_size can not be divided by world_size
         # but hidden_size can, we can consider split embedding weight by column.
@@ -264,7 +267,8 @@ class ParallelConfig:
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)
-        self.use_ep = args["expert_parallel_size"] > 1
+
+        self.use_ep = self.expert_parallel_size > 1
         if self.splitwise_role == "mixed":
             self.moe_phase = MoEPhase(phase="prefill")
         elif self.splitwise_role == "prefill":
@@ -284,6 +288,16 @@ class ParallelConfig:
         else:
             self.pd_disaggregation_mode = "None"
 
+    def print(self):
+        """
+        print all config
+
+        """
+        logger.info("Parallel Configuration Information :")
+        for k, v in self.__dict__.items():
+            logger.info("{:<20}:{:<6}{}".format(k, "", v))
+        logger.info("=============================================================")
+
 
 class SpeculativeConfig:
     """
@@ -829,6 +843,61 @@ class DecodingConfig:
             setattr(self, key, value)
 
 
+class CommitConfig:
+    """
+    Configuration for tracking version information from version.txt
+
+    Attributes:
+        fastdeploy_commit: Full FastDeploy git commit hash
+        paddle_version: PaddlePaddle version string
+        paddle_commit: PaddlePaddle git commit hash
+        cuda_version: CUDA version string
+        compiler_version: CXX compiler version string
+    """
+
+    def __init__(
+        self,
+    ):
+        self.fastdeploy_commit: str = ""
+        self.paddle_version: str = ""
+        self.paddle_commit: str = ""
+        self.cuda_version: str = ""
+        self.compiler_version: str = ""
+
+        self._load_from_version_file()
+
+    def _load_from_version_file(self, file_path: str = "fastdeploy/version.txt"):
+        """Internal method to load version info from file"""
+        try:
+            with open(file_path, "r") as f:
+                for line in f:
+                    line = line.strip()
+                    if line.startswith("fastdeploy GIT COMMIT ID:"):
+                        self.fastdeploy_commit = line.split(":")[1].strip()
+                    elif line.startswith("Paddle version:"):
+                        self.paddle_version = line.split(":")[1].strip()
+                    elif line.startswith("Paddle GIT COMMIT ID:"):
+                        self.paddle_commit = line.split(":")[1].strip()
+                    elif line.startswith("CUDA version:"):
+                        self.cuda_version = line.split(":")[1].strip()
+                    elif line.startswith("CXX compiler version:"):
+                        self.compiler_version = line.split(":")[1].strip()
+        except FileNotFoundError:
+            logger.info(f"Warning: Version file not found at {file_path}")
+        except Exception as e:
+            logger.info(f"Warning: Could not read version file - {e!s}")
+
+    def print(self):
+        """
+        print all config
+
+        """
+        logger.info("Fasedeploy Commit Information :")
+        for k, v in self.__dict__.items():
+            logger.info("{:<20}:{:<6}{}".format(k, "", v))
+        logger.info("=============================================================")
+
+
 @dataclass
 class FDConfig:
     """
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index c6262ff70..fdf34514a 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -24,10 +24,12 @@ from fastdeploy.config import (
     EarlyStopConfig,
     GraphOptimizationConfig,
     LoadConfig,
+    ModelConfig,
+    ParallelConfig,
     SpeculativeConfig,
     TaskOption,
 )
-from fastdeploy.engine.config import Config, ModelConfig, ParallelConfig
+from fastdeploy.engine.config import Config
 from fastdeploy.scheduler.config import SchedulerConfig
 from fastdeploy.utils import FlexibleArgumentParser
 
@@ -813,17 +815,6 @@ class EngineArgs:
 
         return SchedulerConfig(**params)
 
-    def create_parallel_config(self) -> ParallelConfig:
-        """
-        Create and return a ParallelConfig object based on the current settings.
-        """
-        return ParallelConfig(
-            tensor_parallel_size=self.tensor_parallel_size,
-            enable_expert_parallel=self.enable_expert_parallel,
-            data_parallel_size=self.data_parallel_size,
-            enable_custom_all_reduce=self.enable_custom_all_reduce,
-        )
-
     def create_graph_optimization_config(self) -> GraphOptimizationConfig:
         """
         Create and retuan a GraphOptimizationConfig object based on the current settings.
@@ -850,9 +841,6 @@ class EngineArgs:
         """
         all_dict = asdict(self)
         model_cfg = ModelConfig(all_dict)
-        all_dict["model_cfg"] = model_cfg
-        cache_cfg = CacheConfig(all_dict)
-        load_cfg = LoadConfig(all_dict)
 
         if not model_cfg.is_unified_ckpt and hasattr(model_cfg, "tensor_parallel_size"):
             self.tensor_parallel_size = model_cfg.tensor_parallel_size
@@ -861,6 +849,12 @@ class EngineArgs:
             self.max_num_batched_tokens = 2048
         else:
             self.max_num_batched_tokens = self.max_model_len
+
+        all_dict = asdict(self)
+        all_dict["model_cfg"] = model_cfg
+        cache_cfg = CacheConfig(all_dict)
+        load_cfg = LoadConfig(all_dict)
+        parallel_cfg = ParallelConfig(all_dict)
         scheduler_cfg = self.create_scheduler_config()
         speculative_cfg = self.create_speculative_config()
         graph_opt_cfg = self.create_graph_optimization_config()
@@ -880,7 +874,7 @@ class EngineArgs:
             tokenizer=self.tokenizer,
             cache_config=cache_cfg,
             load_config=load_cfg,
-            parallel_config=self.create_parallel_config(),
+            parallel_config=parallel_cfg,
             max_model_len=self.max_model_len,
             tensor_parallel_size=self.tensor_parallel_size,
             max_num_seqs=self.max_num_seqs,
diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py
index 06e34d042..a5ebb7745 100644
--- a/fastdeploy/engine/config.py
+++ b/fastdeploy/engine/config.py
@@ -15,116 +15,21 @@
 
 import json
 import os
-from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, Dict, List, Optional
 
-from fastdeploy.config import CacheConfig, LoadConfig, ModelConfig
+from fastdeploy.config import (
+    CacheConfig,
+    CommitConfig,
+    LoadConfig,
+    ModelConfig,
+    ParallelConfig,
+)
 from fastdeploy.platforms import current_platform
 from fastdeploy.scheduler import SchedulerConfig
 from fastdeploy.utils import ceil_div, get_host_ip, is_port_available, llm_logger
 
 
-class ParallelConfig:
-    """
-    Configuration for parallelism.
-
-    Attributes:
-        tensor_parallel_size (int): Size of tensor parallelism.
-        data_parallel_size (int): Size of data parallelism.
-        local_data_parallel_id (int): ID of local data parallel.
-        enable_expert_parallel (bool): Whether to enable expert parallel.
-    """
-
-    def __init__(
-        self,
-        tensor_parallel_size: int = 1,
-        data_parallel_size: int = 1,
-        enable_expert_parallel: bool = False,
-        enable_custom_all_reduce: bool = False,
-    ):
-        """
-        Initialize the ParallelConfig class.
-
-        Args:
-            tensor_parallel_size (int): Size of tensor parallelism.
-            data_parallel_size (int): Size of data parallelism.
-            local_data_parallel_id (int): ID of local data parallel.
-            enable_expert_parallel (bool): Whether to enable expert parallel.
-        """
-        self.tensor_parallel_size = tensor_parallel_size
-        self.data_parallel_size = data_parallel_size
-        self.enable_expert_parallel = enable_expert_parallel
-        self.expert_parallel_size = data_parallel_size
-        self.local_data_parallel_id = 0
-        self.enable_custom_all_reduce = enable_custom_all_reduce
-
-    def print(self):
-        """
-        print all config
-
-        """
-        llm_logger.info("Parallel Configuration Information :")
-        for k, v in self.__dict__.items():
-            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
-        llm_logger.info("=============================================================")
-
-
-@dataclass
-class CommitConfig:
-    """
-    Configuration for tracking version information from version.txt
-
-    Attributes:
-        fastdeploy_commit: Full FastDeploy git commit hash
-        paddle_version: PaddlePaddle version string
-        paddle_commit: PaddlePaddle git commit hash
-        cuda_version: CUDA version string
-        compiler_version: CXX compiler version string
-    """
-
-    fastdeploy_commit: str = ""
-    paddle_version: str = ""
-    paddle_commit: str = ""
-    cuda_version: str = ""
-    compiler_version: str = ""
-
-    def __post_init__(self):
-        """Automatically load version info when initialized"""
-        self._load_from_version_file()
-
-    def _load_from_version_file(self, file_path: str = "fastdeploy/version.txt"):
-        """Internal method to load version info from file"""
-        try:
-            with open(file_path, "r") as f:
-                for line in f:
-                    line = line.strip()
-                    if line.startswith("fastdeploy GIT COMMIT ID:"):
-                        self.fastdeploy_commit = line.split(":")[1].strip()
-                    elif line.startswith("Paddle version:"):
-                        self.paddle_version = line.split(":")[1].strip()
-                    elif line.startswith("Paddle GIT COMMIT ID:"):
-                        self.paddle_commit = line.split(":")[1].strip()
-                    elif line.startswith("CUDA version:"):
-                        self.cuda_version = line.split(":")[1].strip()
-                    elif line.startswith("CXX compiler version:"):
-                        self.compiler_version = line.split(":")[1].strip()
-        except FileNotFoundError:
-            llm_logger.info(f"Warning: Version file not found at {file_path}")
-        except Exception as e:
-            llm_logger.info(f"Warning: Could not read version file - {e!s}")
-
-    def print(self):
-        """
-        print all config
-
-        """
-        llm_logger.info("Fasedeploy Commit Information :")
-        for k, v in self.__dict__.items():
-            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
-        llm_logger.info("=============================================================")
-
-
 class Config:
     """
     Initial configuration class.