unify parallel config (#3070)

Author: YuanRisheng
Date: 2025-07-30 11:41:23 +08:00 (committed by GitHub)
Parent: 5ca684c762
Commit: 99a70fc722
3 changed files with 87 additions and 119 deletions

File 1 of 3

@@ -218,6 +218,9 @@ class ParallelConfig:
        self.tensor_parallel_size = 1  # TP degree
        self.expert_parallel_rank = 0  # EP rank ID
        self.expert_parallel_size = 1  # EP degree
        self.data_parallel_size = 1  # DP degree
        self.enable_expert_parallel = False
        self.local_data_parallel_id = 0
        # Whether the embedding weight, which is distributed across the GPU cards, is split by
        # row or column. Defaults to False, meaning split by row. When vocab_size cannot be
        # divided evenly by world_size but hidden_size can, consider splitting the embedding
        # weight by column.
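
For context, a minimal sketch (not the FastDeploy source) of what the unified ParallelConfig now does: all TP/EP/DP knobs live on one class, and the setattr loop in the next hunk applies only those keys of the args dict that match an existing attribute, leaving everything else at its default.

```python
# Minimal sketch, assuming the dict-driven init shown in this diff;
# ParallelConfigSketch is a stand-in name, not the FastDeploy class.
class ParallelConfigSketch:
    def __init__(self, args: dict):
        self.tensor_parallel_size = 1   # TP degree
        self.expert_parallel_size = 1   # EP degree
        self.data_parallel_size = 1     # DP degree
        self.enable_expert_parallel = False
        self.local_data_parallel_id = 0
        # Apply only keys that correspond to existing attributes.
        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)

cfg = ParallelConfigSketch({"tensor_parallel_size": 4, "data_parallel_size": 2})
assert cfg.tensor_parallel_size == 4
assert cfg.expert_parallel_size == 1  # untouched default
```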
@@ -264,7 +267,8 @@ class ParallelConfig:
        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)
        self.use_ep = args["expert_parallel_size"] > 1
        self.use_ep = self.expert_parallel_size > 1
        if self.splitwise_role == "mixed":
            self.moe_phase = MoEPhase(phase="prefill")
        elif self.splitwise_role == "prefill":
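
The adjacent pair of use_ep lines above shows the old assignment followed by its replacement, and it is the behavioral core of this hunk: the old code indexed the raw args dict, which raises KeyError whenever the caller omits expert_parallel_size, while the new code reads the attribute that the setattr loop either updated or left at its default. A hedged illustration:

```python
# Illustration of the fix; the dict below is hypothetical caller input.
args = {"tensor_parallel_size": 8}  # no "expert_parallel_size" key

# Old form: args["expert_parallel_size"] would raise KeyError here.
# New form: the attribute default (1) makes the check always well-defined.
expert_parallel_size = args.get("expert_parallel_size", 1)
use_ep = expert_parallel_size > 1
assert use_ep is False
```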
@@ -284,6 +288,16 @@ class ParallelConfig:
        else:
            self.pd_disaggregation_mode = "None"

    def print(self):
        """
        print all config
        """
        logger.info("Parallel Configuration Information :")
        for k, v in self.__dict__.items():
            logger.info("{:<20}:{:<6}{}".format(k, "", v))
        logger.info("=============================================================")


class SpeculativeConfig:
    """
@@ -829,6 +843,61 @@ class DecodingConfig:
            setattr(self, key, value)


class CommitConfig:
    """
    Configuration for tracking version information from version.txt

    Attributes:
        fastdeploy_commit: Full FastDeploy git commit hash
        paddle_version: PaddlePaddle version string
        paddle_commit: PaddlePaddle git commit hash
        cuda_version: CUDA version string
        compiler_version: CXX compiler version string
    """

    def __init__(
        self,
    ):
        self.fastdeploy_commit: str = ""
        self.paddle_version: str = ""
        self.paddle_commit: str = ""
        self.cuda_version: str = ""
        self.compiler_version: str = ""

        self._load_from_version_file()

    def _load_from_version_file(self, file_path: str = "fastdeploy/version.txt"):
        """Internal method to load version info from file"""
        try:
            with open(file_path, "r") as f:
                for line in f:
                    line = line.strip()
                    if line.startswith("fastdeploy GIT COMMIT ID:"):
                        self.fastdeploy_commit = line.split(":")[1].strip()
                    elif line.startswith("Paddle version:"):
                        self.paddle_version = line.split(":")[1].strip()
                    elif line.startswith("Paddle GIT COMMIT ID:"):
                        self.paddle_commit = line.split(":")[1].strip()
                    elif line.startswith("CUDA version:"):
                        self.cuda_version = line.split(":")[1].strip()
                    elif line.startswith("CXX compiler version:"):
                        self.compiler_version = line.split(":")[1].strip()
        except FileNotFoundError:
            logger.info(f"Warning: Version file not found at {file_path}")
        except Exception as e:
            logger.info(f"Warning: Could not read version file - {e!s}")

    def print(self):
        """
        print all config
        """
logger.info("Fasedeploy Commit Information :")
        for k, v in self.__dict__.items():
            logger.info("{:<20}:{:<6}{}".format(k, "", v))
        logger.info("=============================================================")
@dataclass
class FDConfig:
"""

File 2 of 3

@@ -24,10 +24,12 @@ from fastdeploy.config import (
    EarlyStopConfig,
    GraphOptimizationConfig,
    LoadConfig,
    ModelConfig,
    ParallelConfig,
    SpeculativeConfig,
    TaskOption,
)
from fastdeploy.engine.config import Config, ModelConfig, ParallelConfig
from fastdeploy.engine.config import Config
from fastdeploy.scheduler.config import SchedulerConfig
from fastdeploy.utils import FlexibleArgumentParser
@@ -813,17 +815,6 @@ class EngineArgs:
        return SchedulerConfig(**params)

    def create_parallel_config(self) -> ParallelConfig:
        """
        Create and return a ParallelConfig object based on the current settings.
        """
        return ParallelConfig(
            tensor_parallel_size=self.tensor_parallel_size,
            enable_expert_parallel=self.enable_expert_parallel,
            data_parallel_size=self.data_parallel_size,
            enable_custom_all_reduce=self.enable_custom_all_reduce,
        )

    def create_graph_optimization_config(self) -> GraphOptimizationConfig:
        """
        Create and return a GraphOptimizationConfig object based on the current settings.
@@ -850,9 +841,6 @@ class EngineArgs:
"""
all_dict = asdict(self)
model_cfg = ModelConfig(all_dict)
all_dict["model_cfg"] = model_cfg
cache_cfg = CacheConfig(all_dict)
load_cfg = LoadConfig(all_dict)
if not model_cfg.is_unified_ckpt and hasattr(model_cfg, "tensor_parallel_size"):
self.tensor_parallel_size = model_cfg.tensor_parallel_size
@@ -861,6 +849,12 @@ class EngineArgs:
            self.max_num_batched_tokens = 2048
        else:
            self.max_num_batched_tokens = self.max_model_len

        all_dict = asdict(self)
        all_dict["model_cfg"] = model_cfg
        cache_cfg = CacheConfig(all_dict)
        load_cfg = LoadConfig(all_dict)
        parallel_cfg = ParallelConfig(all_dict)

        scheduler_cfg = self.create_scheduler_config()
        speculative_cfg = self.create_speculative_config()
        graph_opt_cfg = self.create_graph_optimization_config()
@@ -880,7 +874,7 @@ class EngineArgs:
            tokenizer=self.tokenizer,
            cache_config=cache_cfg,
            load_config=load_cfg,
            parallel_config=self.create_parallel_config(),
            parallel_config=parallel_cfg,
            max_model_len=self.max_model_len,
            tensor_parallel_size=self.tensor_parallel_size,
            max_num_seqs=self.max_num_seqs,
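
Two things happen in the hunks above: the per-field factory create_parallel_config() is deleted, and all_dict is rebuilt after tensor_parallel_size has been reconciled with the checkpoint, so ParallelConfig(all_dict) sees the corrected value rather than a stale snapshot. A small sketch of the staleness issue (names illustrative, not the FastDeploy API):

```python
from dataclasses import asdict, dataclass

@dataclass
class ArgsSketch:               # stand-in for EngineArgs
    tensor_parallel_size: int = 1

args = ArgsSketch()
stale = asdict(args)            # snapshot taken before reconciliation
args.tensor_parallel_size = 4   # e.g. value recovered from the checkpoint
fresh = asdict(args)            # rebuilt snapshot, as the diff now does
assert stale["tensor_parallel_size"] == 1
assert fresh["tensor_parallel_size"] == 4
```

The dict-driven constructor also means a new parallel option added to EngineArgs flows into ParallelConfig without touching this call site.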

File 3 of 3

@@ -15,116 +15,21 @@
import json
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional
from fastdeploy.config import CacheConfig, LoadConfig, ModelConfig
from fastdeploy.config import (
    CacheConfig,
    CommitConfig,
    LoadConfig,
    ModelConfig,
    ParallelConfig,
)
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.utils import ceil_div, get_host_ip, is_port_available, llm_logger
class ParallelConfig:
    """
    Configuration for parallelism.

    Attributes:
        tensor_parallel_size (int): Size of tensor parallelism.
        data_parallel_size (int): Size of data parallelism.
        local_data_parallel_id (int): ID of local data parallel.
        enable_expert_parallel (bool): Whether to enable expert parallel.
    """

    def __init__(
        self,
        tensor_parallel_size: int = 1,
        data_parallel_size: int = 1,
        enable_expert_parallel: bool = False,
        enable_custom_all_reduce: bool = False,
    ):
        """
        Initialize the ParallelConfig class.

        Args:
            tensor_parallel_size (int): Size of tensor parallelism.
            data_parallel_size (int): Size of data parallelism.
            local_data_parallel_id (int): ID of local data parallel.
            enable_expert_parallel (bool): Whether to enable expert parallel.
        """
        self.tensor_parallel_size = tensor_parallel_size
        self.data_parallel_size = data_parallel_size
        self.enable_expert_parallel = enable_expert_parallel
        self.expert_parallel_size = data_parallel_size
        self.local_data_parallel_id = 0
        self.enable_custom_all_reduce = enable_custom_all_reduce

    def print(self):
        """
        print all config
        """
        llm_logger.info("Parallel Configuration Information :")
        for k, v in self.__dict__.items():
            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
        llm_logger.info("=============================================================")

@dataclass
class CommitConfig:
    """
    Configuration for tracking version information from version.txt

    Attributes:
        fastdeploy_commit: Full FastDeploy git commit hash
        paddle_version: PaddlePaddle version string
        paddle_commit: PaddlePaddle git commit hash
        cuda_version: CUDA version string
        compiler_version: CXX compiler version string
    """

    fastdeploy_commit: str = ""
    paddle_version: str = ""
    paddle_commit: str = ""
    cuda_version: str = ""
    compiler_version: str = ""

    def __post_init__(self):
        """Automatically load version info when initialized"""
        self._load_from_version_file()

    def _load_from_version_file(self, file_path: str = "fastdeploy/version.txt"):
        """Internal method to load version info from file"""
        try:
            with open(file_path, "r") as f:
                for line in f:
                    line = line.strip()
                    if line.startswith("fastdeploy GIT COMMIT ID:"):
                        self.fastdeploy_commit = line.split(":")[1].strip()
                    elif line.startswith("Paddle version:"):
                        self.paddle_version = line.split(":")[1].strip()
                    elif line.startswith("Paddle GIT COMMIT ID:"):
                        self.paddle_commit = line.split(":")[1].strip()
                    elif line.startswith("CUDA version:"):
                        self.cuda_version = line.split(":")[1].strip()
                    elif line.startswith("CXX compiler version:"):
                        self.compiler_version = line.split(":")[1].strip()
        except FileNotFoundError:
            llm_logger.info(f"Warning: Version file not found at {file_path}")
        except Exception as e:
            llm_logger.info(f"Warning: Could not read version file - {e!s}")

    def print(self):
        """
        print all config
        """
        llm_logger.info("FastDeploy Commit Information :")
        for k, v in self.__dict__.items():
            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
        llm_logger.info("=============================================================")

class Config:
"""
Initial configuration class.
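
With both classes deleted here, fastdeploy.config becomes the single home for ParallelConfig and CommitConfig, which is what the updated import block at the top of this file reflects. One behavioral note recoverable from the removed code: the old engine-side ParallelConfig silently set expert_parallel_size equal to data_parallel_size, whereas the unified class keeps them as independent fields. Downstream code would migrate roughly as follows (a sketch, assuming no other re-exports):

```python
# Before this commit (classes removed above):
# from fastdeploy.engine.config import CommitConfig, ParallelConfig

# After this commit, both classes live in fastdeploy.config:
from fastdeploy.config import CommitConfig, ParallelConfig
```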