mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
unify parallel config (#3070)
This commit is contained in:
@@ -15,116 +15,21 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastdeploy.config import CacheConfig, LoadConfig, ModelConfig
|
||||
from fastdeploy.config import (
|
||||
CacheConfig,
|
||||
CommitConfig,
|
||||
LoadConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
)
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.scheduler import SchedulerConfig
|
||||
from fastdeploy.utils import ceil_div, get_host_ip, is_port_available, llm_logger
|
||||
|
||||
|
||||
class ParallelConfig:
|
||||
"""
|
||||
Configuration for parallelism.
|
||||
|
||||
Attributes:
|
||||
tensor_parallel_size (int): Size of tensor parallelism.
|
||||
data_parallel_size (int): Size of data parallelism.
|
||||
local_data_parallel_id (int): ID of local data parallel.
|
||||
enable_expert_parallel (bool): Whether to enable expert parallel.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tensor_parallel_size: int = 1,
|
||||
data_parallel_size: int = 1,
|
||||
enable_expert_parallel: bool = False,
|
||||
enable_custom_all_reduce: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the ParallelConfig class.
|
||||
|
||||
Args:
|
||||
tensor_parallel_size (int): Size of tensor parallelism.
|
||||
data_parallel_size (int): Size of data parallelism.
|
||||
local_data_parallel_id (int): ID of local data parallel.
|
||||
enable_expert_parallel (bool): Whether to enable expert parallel.
|
||||
"""
|
||||
self.tensor_parallel_size = tensor_parallel_size
|
||||
self.data_parallel_size = data_parallel_size
|
||||
self.enable_expert_parallel = enable_expert_parallel
|
||||
self.expert_parallel_size = data_parallel_size
|
||||
self.local_data_parallel_id = 0
|
||||
self.enable_custom_all_reduce = enable_custom_all_reduce
|
||||
|
||||
def print(self):
|
||||
"""
|
||||
print all config
|
||||
|
||||
"""
|
||||
llm_logger.info("Parallel Configuration Information :")
|
||||
for k, v in self.__dict__.items():
|
||||
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
llm_logger.info("=============================================================")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommitConfig:
|
||||
"""
|
||||
Configuration for tracking version information from version.txt
|
||||
|
||||
Attributes:
|
||||
fastdeploy_commit: Full FastDeploy git commit hash
|
||||
paddle_version: PaddlePaddle version string
|
||||
paddle_commit: PaddlePaddle git commit hash
|
||||
cuda_version: CUDA version string
|
||||
compiler_version: CXX compiler version string
|
||||
"""
|
||||
|
||||
fastdeploy_commit: str = ""
|
||||
paddle_version: str = ""
|
||||
paddle_commit: str = ""
|
||||
cuda_version: str = ""
|
||||
compiler_version: str = ""
|
||||
|
||||
def __post_init__(self):
|
||||
"""Automatically load version info when initialized"""
|
||||
self._load_from_version_file()
|
||||
|
||||
def _load_from_version_file(self, file_path: str = "fastdeploy/version.txt"):
|
||||
"""Internal method to load version info from file"""
|
||||
try:
|
||||
with open(file_path, "r") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("fastdeploy GIT COMMIT ID:"):
|
||||
self.fastdeploy_commit = line.split(":")[1].strip()
|
||||
elif line.startswith("Paddle version:"):
|
||||
self.paddle_version = line.split(":")[1].strip()
|
||||
elif line.startswith("Paddle GIT COMMIT ID:"):
|
||||
self.paddle_commit = line.split(":")[1].strip()
|
||||
elif line.startswith("CUDA version:"):
|
||||
self.cuda_version = line.split(":")[1].strip()
|
||||
elif line.startswith("CXX compiler version:"):
|
||||
self.compiler_version = line.split(":")[1].strip()
|
||||
except FileNotFoundError:
|
||||
llm_logger.info(f"Warning: Version file not found at {file_path}")
|
||||
except Exception as e:
|
||||
llm_logger.info(f"Warning: Could not read version file - {e!s}")
|
||||
|
||||
def print(self):
|
||||
"""
|
||||
print all config
|
||||
|
||||
"""
|
||||
llm_logger.info("Fasedeploy Commit Information :")
|
||||
for k, v in self.__dict__.items():
|
||||
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
llm_logger.info("=============================================================")
|
||||
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Initial configuration class.
|
||||
|
Reference in New Issue
Block a user