mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 00:06:38 +08:00
unify parallel config (#3070)
This commit is contained in:
@@ -24,10 +24,12 @@ from fastdeploy.config import (
|
||||
EarlyStopConfig,
|
||||
GraphOptimizationConfig,
|
||||
LoadConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
SpeculativeConfig,
|
||||
TaskOption,
|
||||
)
|
||||
from fastdeploy.engine.config import Config, ModelConfig, ParallelConfig
|
||||
from fastdeploy.engine.config import Config
|
||||
from fastdeploy.scheduler.config import SchedulerConfig
|
||||
from fastdeploy.utils import FlexibleArgumentParser
|
||||
|
||||
@@ -813,17 +815,6 @@ class EngineArgs:
|
||||
|
||||
return SchedulerConfig(**params)
|
||||
|
||||
def create_parallel_config(self) -> ParallelConfig:
|
||||
"""
|
||||
Create and return a ParallelConfig object based on the current settings.
|
||||
"""
|
||||
return ParallelConfig(
|
||||
tensor_parallel_size=self.tensor_parallel_size,
|
||||
enable_expert_parallel=self.enable_expert_parallel,
|
||||
data_parallel_size=self.data_parallel_size,
|
||||
enable_custom_all_reduce=self.enable_custom_all_reduce,
|
||||
)
|
||||
|
||||
def create_graph_optimization_config(self) -> GraphOptimizationConfig:
|
||||
"""
|
||||
Create and return a GraphOptimizationConfig object based on the current settings.
|
||||
@@ -850,9 +841,6 @@ class EngineArgs:
|
||||
"""
|
||||
all_dict = asdict(self)
|
||||
model_cfg = ModelConfig(all_dict)
|
||||
all_dict["model_cfg"] = model_cfg
|
||||
cache_cfg = CacheConfig(all_dict)
|
||||
load_cfg = LoadConfig(all_dict)
|
||||
|
||||
if not model_cfg.is_unified_ckpt and hasattr(model_cfg, "tensor_parallel_size"):
|
||||
self.tensor_parallel_size = model_cfg.tensor_parallel_size
|
||||
@@ -861,6 +849,12 @@ class EngineArgs:
|
||||
self.max_num_batched_tokens = 2048
|
||||
else:
|
||||
self.max_num_batched_tokens = self.max_model_len
|
||||
|
||||
all_dict = asdict(self)
|
||||
all_dict["model_cfg"] = model_cfg
|
||||
cache_cfg = CacheConfig(all_dict)
|
||||
load_cfg = LoadConfig(all_dict)
|
||||
parallel_cfg = ParallelConfig(all_dict)
|
||||
scheduler_cfg = self.create_scheduler_config()
|
||||
speculative_cfg = self.create_speculative_config()
|
||||
graph_opt_cfg = self.create_graph_optimization_config()
|
||||
@@ -880,7 +874,7 @@ class EngineArgs:
|
||||
tokenizer=self.tokenizer,
|
||||
cache_config=cache_cfg,
|
||||
load_config=load_cfg,
|
||||
parallel_config=self.create_parallel_config(),
|
||||
parallel_config=parallel_cfg,
|
||||
max_model_len=self.max_model_len,
|
||||
tensor_parallel_size=self.tensor_parallel_size,
|
||||
max_num_seqs=self.max_num_seqs,
|
||||
|
Reference in New Issue
Block a user