""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import json from dataclasses import asdict, dataclass from dataclasses import fields as dataclass_fields from typing import Any, Dict, List, Optional from fastdeploy.engine.config import (CacheConfig, Config, GraphOptimizationConfig, ModelConfig, ParallelConfig, SpeculativeConfig, TaskOption) from fastdeploy.scheduler.config import SchedulerConfig from fastdeploy.utils import FlexibleArgumentParser def nullable_str(x: str) -> Optional[str]: """ Convert an empty string to None, preserving other string values. """ return x if x else None @dataclass class EngineArgs: # Model configuration parameters model: str = "baidu/ernie-45-turbo" """ The name or path of the model to be used. """ model_config_name: Optional[str] = "config.json" """ The name of the model configuration file. """ tokenizer: str = None """ The name or path of the tokenizer (defaults to model path if not provided). """ max_model_len: int = 2048 """ Maximum context length supported by the model. """ tensor_parallel_size: int = 1 """ Degree of tensor parallelism. """ block_size: int = 64 """ Number of tokens in one processing block. """ task: TaskOption = "generate" """ The task to be executed by the model. """ max_num_seqs: int = 8 """ Maximum number of sequences per iteration. """ mm_processor_kwargs: Optional[Dict[str, Any]] = None """ Additional keyword arguments for the multi-modal processor. """ limit_mm_per_prompt: Optional[Dict[str, Any]] = None """ Limitation of numbers of multi-modal data. """ reasoning_parser: str = None """ specifies the reasoning parser to use for extracting reasoning content from the model output """ enable_mm: bool = False """ Flags to enable multi-modal model """ speculative_config: Optional[Dict[str, Any]] = None """ Configuration for speculative execution. """ dynamic_load_weight: bool = False """ dynamic load weight """ load_strategy: str = "ipc_snapshot" """ dynamic load weight strategy """ quantization: str = None guided_decoding_backend: str = "off" """ Guided decoding backend. """ guided_decoding_disable_any_whitespace: bool = False """ Disable any whitespace in guided decoding. """ # Inference configuration parameters gpu_memory_utilization: float = 0.9 """ The fraction of GPU memory to be utilized. """ num_gpu_blocks_override: Optional[int] = None """ Override for the number of GPU blocks. """ max_num_batched_tokens: Optional[int] = None """ Maximum number of tokens to batch together. """ kv_cache_ratio: float = 0.75 """ Ratio of tokens to process in a block. """ dist_init_ip: Optional[str] = None """ The master node ip of multinode deployment """ nnodes: int = 1 """ The number of nodes in multinode deployment """ node_rank: int = 0 """ The rank of the current node in multinode deployment """ swap_space: float = None """ The amount of CPU memory to offload to. """ cache_queue_port: int = 8003 """ Port for cache queue. 
""" # System configuration parameters use_warmup: int = 0 """ Flag to indicate whether to use warm-up before inference. """ enable_prefix_caching: bool = False """ Flag to enable prefix caching. """ enable_custom_all_reduce: bool = False """ Flag to enable the custom all-reduce kernel. """ engine_worker_queue_port: int = 8002 """ Port for worker queue communication. """ splitwise_role: str = "mixed" """ Splitwise role: prefill, decode or mixed """ data_parallel_size: int = 1 """ Number of data parallelism. """ enable_expert_parallel: bool = False """ Enable expert parallelism. """ cache_transfer_protocol: str = "ipc" """ Protocol to use for cache transfer. """ pd_comm_port: Optional[List[int]] = None """ Port for splitwise communication. """ innode_prefill_ports: Optional[List[int]] = None """ Ports for innode dispatch request. """ rdma_comm_ports: Optional[List[int]] = None """ Ports for rdma communication. """ enable_chunked_prefill: bool = False """ Flag to enable chunked prefilling. """ max_num_partial_prefills: int = 1 """ For chunked prefill, the max number of concurrent partial prefills. """ max_long_partial_prefills: int = 1 """ For chunked prefill, the maximum number of prompts longer than –long-prefill-token-threshold that will be prefilled concurrently. """ long_prefill_token_threshold: int = 0 """ For chunked prefill, a request is considered long if the prompt is longer than this number of tokens. """ static_decode_blocks: int = 2 """ additional decode block num """ scheduler_name: str = "local" """ Scheduler name to be used """ scheduler_max_size: int = -1 """ Size of scheduler """ scheduler_ttl: int = 900 """ TTL of request """ scheduler_host: str = "127.0.0.1" """ Host of redis """ scheduler_port: int = 6379 """ Port of redis """ scheduler_db: int = 0 """ DB of redis """ scheduler_password: Optional[str] = None """ Password of redis """ scheduler_topic: str = "default" """ Topic of scheduler """ scheduler_min_load_score: float = 3 """ Minimum load score for task assignment """ scheduler_load_shards_num: int = 1 """ Number of shards for load balancing table """ scheduler_sync_period: int = 5 """ SplitWise Use, node load sync period """ scheduler_expire_period: int = 3000 """ SplitWise Use, node will not be scheduled after expire_period ms not sync load """ scheduler_release_load_expire_period: int = 600 """ SplitWise Use, scheduler will release req load after expire period(s) """ scheduler_reader_parallel: int = 4 """ SplitWise Use, Results Reader Sync Parallel """ scheduler_writer_parallel: int = 4 """ SplitWise Use, Results Writer Sync Parallel """ scheduler_reader_batch_size: int = 200 """ SplitWise Use, Results Reader Batch Size """ scheduler_writer_batch_size: int = 200 """ SplitWise Use, Results Writer Batch Size """ use_cudagraph: bool = False """ Flags to enable Cuda Graph """ graph_optimization_config: Optional[Dict[str, Any]] = None """ Configuration for graph optimization backend execution. """ enable_logprob: bool = False """ Flag to enable logprob output. Default is False (disabled). Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values. """ def __post_init__(self): """ Post-initialization processing to set default tokenizer if not provided. """ if not self.tokenizer: self.tokenizer = self.model @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: """ Add command line interface arguments to the parser. 
""" # Model parameters group model_group = parser.add_argument_group("Model Configuration") model_group.add_argument("--model", type=str, default=EngineArgs.model, help="Model name or path to be used.") model_group.add_argument("--model-config-name", type=nullable_str, default=EngineArgs.model_config_name, help="The model configuration file name.") model_group.add_argument( "--tokenizer", type=nullable_str, default=EngineArgs.tokenizer, help= "Tokenizer name or path (defaults to model path if not specified)." ) model_group.add_argument( "--max-model-len", type=int, default=EngineArgs.max_model_len, help="Maximum context length supported by the model.") model_group.add_argument( "--block-size", type=int, default=EngineArgs.block_size, help="Number of tokens processed in one block.") model_group.add_argument("--task", type=str, default=EngineArgs.task, help="Task to be executed by the model.") model_group.add_argument( "--use-warmup", type=int, default=EngineArgs.use_warmup, help="Flag to indicate whether to use warm-up before inference.") model_group.add_argument( "--limit-mm-per-prompt", default=EngineArgs.limit_mm_per_prompt, type=json.loads, help="Limitation of numbers of multi-modal data.") model_group.add_argument( "--mm-processor-kwargs", default=EngineArgs.mm_processor_kwargs, type=json.loads, help="Additional keyword arguments for the multi-modal processor.") model_group.add_argument("--enable-mm", action='store_true', default=EngineArgs.enable_mm, help="Flag to enable multi-modal model.") model_group.add_argument("--reasoning-parser", type=str, default=EngineArgs.reasoning_parser, help="Flag specifies the reasoning parser to use for extracting "\ "reasoning content from the model output") model_group.add_argument( "--speculative-config", type=json.loads, default=EngineArgs.speculative_config, help="Configuration for speculative execution.") model_group.add_argument( "--dynamic-load-weight", action='store_true', default=EngineArgs.dynamic_load_weight, help="Flag to indicate whether to load weight dynamically.") model_group.add_argument( "--load-strategy", type=str, default=EngineArgs.load_strategy, help="Flag to dynamic load strategy.") model_group.add_argument("--engine-worker-queue-port", type=int, default=EngineArgs.engine_worker_queue_port, help="port for engine worker queue") model_group.add_argument("--quantization", type=str, default=EngineArgs.quantization, help="Quantization name for the model, currentlly support " \ "'wint8', 'wint4'," \ "default is None. The priority of this configuration "\ "is lower than that of the config file. " \ "More complex quantization methods need to be configured via the config file.") model_group.add_argument("--use-cudagraph", action='store_true', default=EngineArgs.use_cudagraph, help="Flags to enable cuda graph.") model_group.add_argument("--graph-optimization-config", type=json.loads, default=EngineArgs.graph_optimization_config, help="") model_group.add_argument("--guided-decoding-backend", type=str, default=EngineArgs.guided_decoding_backend, help="Guided Decoding Backend") model_group.add_argument( "--guided-decoding-disable-any-whitespace", type=str, default=EngineArgs.guided_decoding_disable_any_whitespace, help= "Disabled any whitespaces when using guided decoding backend XGrammar." ) model_group.add_argument("--enable-logprob", action="store_true", default=EngineArgs.enable_logprob, help="Enable output of token-level log probabilities." 
        # Parallel processing parameters group
        parallel_group = parser.add_argument_group("Parallel Configuration")
        parallel_group.add_argument("--tensor-parallel-size",
                                    "-tp",
                                    type=int,
                                    default=EngineArgs.tensor_parallel_size,
                                    help="Degree of tensor parallelism.")
        parallel_group.add_argument("--enable-custom-all-reduce",
                                    action='store_true',
                                    default=EngineArgs.enable_custom_all_reduce,
                                    help="Flag to enable custom all-reduce.")
        parallel_group.add_argument(
            "--max-num-seqs",
            type=int,
            default=EngineArgs.max_num_seqs,
            help="Maximum number of sequences per iteration.")
        parallel_group.add_argument(
            "--num-gpu-blocks-override",
            type=int,
            default=EngineArgs.num_gpu_blocks_override,
            help="Override for the number of GPU blocks.")
        parallel_group.add_argument(
            "--max-num-batched-tokens",
            type=int,
            default=EngineArgs.max_num_batched_tokens,
            help="Maximum number of tokens to batch together.")
        parallel_group.add_argument(
            "--gpu-memory-utilization",
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help="Fraction of GPU memory to be utilized.")
        parallel_group.add_argument("--data-parallel-size",
                                    type=int,
                                    default=EngineArgs.data_parallel_size,
                                    help="Degree of data parallelism.")
        parallel_group.add_argument("--enable-expert-parallel",
                                    action='store_true',
                                    default=EngineArgs.enable_expert_parallel,
                                    help="Enable expert parallelism.")

        # CacheConfig parameters group
        cache_group = parser.add_argument_group("Cache Configuration")
        cache_group.add_argument("--kv-cache-ratio",
                                 type=float,
                                 default=EngineArgs.kv_cache_ratio,
                                 help="Ratio of tokens to process in a block.")
        cache_group.add_argument(
            "--swap-space",
            type=float,
            default=EngineArgs.swap_space,
            help="The amount of CPU memory to offload to.")
        cache_group.add_argument("--cache-queue-port",
                                 type=int,
                                 default=EngineArgs.cache_queue_port,
                                 help="Port for the cache queue.")
        cache_group.add_argument("--static-decode-blocks",
                                 type=int,
                                 default=EngineArgs.static_decode_blocks,
                                 help="Number of static decode blocks.")

        # Cluster system parameters group
        system_group = parser.add_argument_group("System Configuration")
        system_group.add_argument("--dist-init-ip",
                                  default=EngineArgs.dist_init_ip,
                                  help="IP address of the master node.")
        system_group.add_argument("--nnodes",
                                  type=int,
                                  default=EngineArgs.nnodes,
                                  help="The total number of nodes.")
        system_group.add_argument("--node-rank",
                                  type=int,
                                  default=EngineArgs.node_rank,
                                  help="Node rank id (range [0, nnodes)).")

        # Performance tuning parameters group
        perf_group = parser.add_argument_group("Performance Tuning")
        perf_group.add_argument("--enable-prefix-caching",
                                action='store_true',
                                default=EngineArgs.enable_prefix_caching,
                                help="Flag to enable prefix caching.")
        perf_group.add_argument("--splitwise-role",
                                type=str,
                                default=EngineArgs.splitwise_role,
                                help="Role of splitwise. Default is 'mixed' "
                                "(prefill, decode, mixed).")
        perf_group.add_argument("--innode-prefill-ports",
                                type=lambda s: s.split(",") if s else None,
                                default=EngineArgs.innode_prefill_ports,
                                help="Ports for in-node prefill, comma separated.")
        perf_group.add_argument("--enable-chunked-prefill",
                                action='store_true',
                                default=EngineArgs.enable_chunked_prefill,
                                help="Flag to enable chunked prefill.")
        perf_group.add_argument("--max-num-partial-prefills",
                                type=int,
                                default=EngineArgs.max_num_partial_prefills,
                                help="For chunked prefill, the maximum number "
                                "of concurrent partial prefill requests.")
        perf_group.add_argument(
            "--max-long-partial-prefills",
            type=int,
            default=EngineArgs.max_long_partial_prefills,
            help=("For chunked prefill, the maximum number of prompts longer "
                  "than long-prefill-token-threshold that will be prefilled "
                  "concurrently."))
        perf_group.add_argument(
            "--long-prefill-token-threshold",
            type=int,
            default=EngineArgs.long_prefill_token_threshold,
            help=("For chunked prefill, the threshold number of"
                  " tokens for a prompt to be considered long."))
        perf_group.add_argument(
            "--cache-transfer-protocol",
            type=str,
            default=EngineArgs.cache_transfer_protocol,
            help="Supported protocol list, comma separated; default is 'ipc'.")
        perf_group.add_argument("--pd-comm-port",
                                type=lambda s: s.split(",") if s else None,
                                default=EngineArgs.pd_comm_port,
                                help="Port for splitwise communication.")
        perf_group.add_argument("--rdma-comm-ports",
                                type=lambda s: s.split(",") if s else None,
                                default=EngineArgs.rdma_comm_ports,
                                help="Ports for RDMA communication.")

        # Scheduler parameters group
        scheduler_group = parser.add_argument_group("Scheduler")
        scheduler_group.add_argument(
            "--scheduler-name",
            default=EngineArgs.scheduler_name,
            help=f"Scheduler name to be used. Default is "
            f"{EngineArgs.scheduler_name}. (local,global)")
        scheduler_group.add_argument(
            "--scheduler-max-size",
            type=int,
            default=EngineArgs.scheduler_max_size,
            help=f"Size of the scheduler. Default is "
            f"{EngineArgs.scheduler_max_size}. (local)")
        scheduler_group.add_argument(
            "--scheduler-ttl",
            type=int,
            default=EngineArgs.scheduler_ttl,
            help=f"TTL of a request. Default is "
            f"{EngineArgs.scheduler_ttl} seconds. (local,global)")
        scheduler_group.add_argument(
            "--scheduler-host",
            default=EngineArgs.scheduler_host,
            help=f"Host address of Redis. Default is "
            f"{EngineArgs.scheduler_host}. (global)")
        scheduler_group.add_argument(
            "--scheduler-port",
            type=int,
            default=EngineArgs.scheduler_port,
            help=f"Port of Redis. Default is {EngineArgs.scheduler_port}. (global)")
        scheduler_group.add_argument(
            "--scheduler-db",
            type=int,
            default=EngineArgs.scheduler_db,
            help=f"DB of Redis. Default is {EngineArgs.scheduler_db}. (global)")
        scheduler_group.add_argument(
            "--scheduler-password",
            default=EngineArgs.scheduler_password,
            help=f"Password of Redis. Default is "
            f"{EngineArgs.scheduler_password}. (global)")
        scheduler_group.add_argument(
            "--scheduler-topic",
            default=EngineArgs.scheduler_topic,
            help=f"Topic of the scheduler. Default is "
            f"{EngineArgs.scheduler_topic}. (global)")
        scheduler_group.add_argument(
            "--scheduler-min-load-score",
            type=float,
            default=EngineArgs.scheduler_min_load_score,
            help=f"Minimum load score for task assignment. Default is "
            f"{EngineArgs.scheduler_min_load_score}. (global)")
        scheduler_group.add_argument(
            "--scheduler-load-shards-num",
            type=int,
            default=EngineArgs.scheduler_load_shards_num,
            help=("Number of shards for the load-balancing table. Default is "
                  f"{EngineArgs.scheduler_load_shards_num}. (global)"))
Default is " f"{EngineArgs.scheduler_load_shards_num} (global)")) scheduler_group.add_argument( "--scheduler-sync-period", type=int, default=EngineArgs.scheduler_sync_period, help=f"SplitWise Use, node load sync period, " f"Default is {EngineArgs.scheduler_sync_period}ms. (global)") scheduler_group.add_argument( "--scheduler-expire-period", type=int, default=EngineArgs.scheduler_expire_period, help=f"SplitWise Use, node will not be scheduled after " f"expire-period ms not sync load, Default is " f"{EngineArgs.scheduler_expire_period}ms. (global)") scheduler_group.add_argument( "--scheduler-release-load-expire-period", type=int, default=EngineArgs.scheduler_release_load_expire_period, help=f"SplitWise Use, scheduler will release req load after " f"expire period(s). Default is " f"{EngineArgs.scheduler_release_load_expire_period}. (global)") scheduler_group.add_argument( "--scheduler-reader-parallel", type=int, default=EngineArgs.scheduler_reader_parallel, help=f"SplitWise Use, Results Reader Sync Parallel, " f"Default is {EngineArgs.scheduler_reader_parallel}. (global)") scheduler_group.add_argument( "--scheduler-writer-parallel", type=int, default=EngineArgs.scheduler_writer_parallel, help=f"SplitWise Use, Results Writer Sync Parallel, " f"Default is {EngineArgs.scheduler_writer_parallel}. (global)") scheduler_group.add_argument( "--scheduler-reader-batch-size", type=int, default=EngineArgs.scheduler_reader_batch_size, help=f"SplitWise Use, Results Reader Batch Size, " f"Default is {EngineArgs.scheduler_reader_batch_size}. (global)") scheduler_group.add_argument( "--scheduler-writer-batch-size", type=int, default=EngineArgs.scheduler_writer_batch_size, help=f"SplitWise Use, Results Writer Batch Size, " f"Default is {EngineArgs.scheduler_writer_batch_size}. (global)") return parser @classmethod def from_cli_args(cls, args: FlexibleArgumentParser) -> "EngineArgs": """ Create an instance of EngineArgs from command line arguments. """ return cls( **{ field.name: getattr(args, field.name) for field in dataclass_fields(cls) }) def create_model_config(self) -> ModelConfig: """ Create and return a ModelConfig object based on the current settings. """ return ModelConfig(model_name_or_path=self.model, config_json_file=self.model_config_name, quantization=self.quantization, dynamic_load_weight=self.dynamic_load_weight, load_strategy=self.load_strategy) def create_cache_config(self, model_cfg) -> CacheConfig: """ Create and return a CacheConfig object based on the current settings. """ return CacheConfig( block_size=self.block_size, tensor_parallel_size=self.tensor_parallel_size, gpu_memory_utilization=self.gpu_memory_utilization, num_gpu_blocks_override=self.num_gpu_blocks_override, kv_cache_ratio=self.kv_cache_ratio, enable_prefix_caching=self.enable_prefix_caching, swap_space=self.swap_space, cache_queue_port=self.cache_queue_port, model_cfg=model_cfg, enable_chunked_prefill=self.enable_chunked_prefill, enc_dec_block_num=self.static_decode_blocks, rdma_comm_ports=self.rdma_comm_ports, cache_transfer_protocol=self.cache_transfer_protocol, pd_comm_port=self.pd_comm_port, ) def create_speculative_config(self) -> SpeculativeConfig: """ """ if self.speculative_config is not None: return SpeculativeConfig(**self.speculative_config) else: return SpeculativeConfig() def create_scheduler_config(self) -> SchedulerConfig: """ Create and retuan a SchedulerConfig object based on the current settings. 
""" prefix = "scheduler_" prefix_len = len(prefix) extra_params = [ "max_model_len", "enable_chunked_prefill", "max_num_partial_prefills", "max_long_partial_prefills", "long_prefill_token_threshold" ] all = asdict(self) params = dict() for k, v in all.items(): if k[:prefix_len] == prefix: params[k[prefix_len:]] = v elif k in extra_params: params[k] = v return SchedulerConfig(**params) def create_parallel_config(self) -> ParallelConfig: """ Create and return a ParallelConfig object based on the current settings. """ return ParallelConfig( tensor_parallel_size=self.tensor_parallel_size, enable_expert_parallel=self.enable_expert_parallel, data_parallel_size=self.data_parallel_size, enable_custom_all_reduce=self.enable_custom_all_reduce ) def create_graph_optimization_config(self) -> GraphOptimizationConfig: """ Create and retuan a GraphOptimizationConfig object based on the current settings. """ if self.graph_optimization_config is not None: return GraphOptimizationConfig(**self.graph_optimization_config) else: return GraphOptimizationConfig() def create_engine_config(self) -> Config: """ Create and return a Config object based on the current settings. """ model_cfg = self.create_model_config() if not model_cfg.is_unified_ckpt and hasattr(model_cfg, 'tensor_parallel_size'): self.tensor_parallel_size = model_cfg.tensor_parallel_size if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: self.max_num_batched_tokens = 2048 else: self.max_num_batched_tokens = self.max_model_len scheduler_cfg = self.create_scheduler_config() speculative_cfg = self.create_speculative_config() graph_opt_cfg = self.create_graph_optimization_config() graph_opt_cfg.update_use_cudagraph(self.use_cudagraph) assert not (self.use_cudagraph and self.enable_prefix_caching), \ "Prefix caching cannot be used with CUDA graph" assert not (self.tensor_parallel_size<=1 and self.enable_custom_all_reduce), \ "enable_custom_all_reduce must be used with tensor_parallel_size>1" return Config( model_name_or_path=self.model, model_config=model_cfg, scheduler_config=scheduler_cfg, tokenizer=self.tokenizer, cache_config=self.create_cache_config(model_cfg), parallel_config=self.create_parallel_config(), max_model_len=self.max_model_len, tensor_parallel_size=self.tensor_parallel_size, max_num_seqs=self.max_num_seqs, speculative_config=speculative_cfg, max_num_batched_tokens=self.max_num_batched_tokens, dist_init_ip=self.dist_init_ip, nnodes=self.nnodes, node_rank=self.node_rank, use_warmup=self.use_warmup, engine_worker_queue_port=self.engine_worker_queue_port, limit_mm_per_prompt=self.limit_mm_per_prompt, mm_processor_kwargs=self.mm_processor_kwargs, enable_mm=self.enable_mm, reasoning_parser=self.reasoning_parser, splitwise_role=self.splitwise_role, innode_prefill_ports=self.innode_prefill_ports, max_num_partial_prefills=self.max_num_partial_prefills, max_long_partial_prefills=self.max_long_partial_prefills, long_prefill_token_threshold=self.long_prefill_token_threshold, graph_optimization_config=graph_opt_cfg, guided_decoding_backend=self.guided_decoding_backend, disable_any_whitespace=self.guided_decoding_disable_any_whitespace, enable_custom_all_reduce=self.enable_custom_all_reduce, enable_logprob = self.enable_logprob, )