mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-31 11:56:44 +08:00 
			
		
		
		
	 68b4755587
			
		
	
	68b4755587
	
	
		
			
	
		
	
	
		
			Some checks failed
		
		
	
	Deploy GitHub Pages / deploy (push) Has been cancelled
				
			* [LLM] support multi node deploy * Update engine.py * fix bugs * fix * [LLM] support multi node deploy * [LLM] support multi node deploy --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
		
			
				
	
	
		
			788 lines
		
	
	
		
			29 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			788 lines
		
	
	
		
			29 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | ||
| # Copyright (c) 2025  PaddlePaddle Authors. All Rights Reserved.
 | ||
| #
 | ||
| # Licensed under the Apache License, Version 2.0 (the "License"
 | ||
| # you may not use this file except in compliance with the License.
 | ||
| # You may obtain a copy of the License at
 | ||
| #
 | ||
| #     http://www.apache.org/licenses/LICENSE-2.0
 | ||
| #
 | ||
| # Unless required by applicable law or agreed to in writing, software
 | ||
| # distributed under the License is distributed on an "AS IS" BASIS,
 | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||
| # See the License for the specific language governing permissions and
 | ||
| # limitations under the License.
 | ||
| """
 | ||
| import json
 | ||
| from dataclasses import asdict, dataclass
 | ||
| from dataclasses import fields as dataclass_fields
 | ||
| from typing import Any, Dict, List, Optional
 | ||
| 
 | ||
| from fastdeploy.engine.config import (CacheConfig, Config, ModelConfig,
 | ||
|                                       ParallelConfig, SpeculativeConfig,
 | ||
|                                       TaskOption)
 | ||
| from fastdeploy.scheduler.config import SchedulerConfig
 | ||
| from fastdeploy.utils import FlexibleArgumentParser
 | ||
| 
 | ||
| 
 | ||
def nullable_str(x: str) -> Optional[str]:
    """
    Return ``x`` unchanged unless it is falsy (e.g. the empty string), in which case return None.

    Used as an argparse ``type`` so that an empty value on the command line
    means "not set" rather than "".
    """
    if not x:
        return None
    return x
 | ||
| 
 | ||
| 
 | ||
| @dataclass
 | ||
| class EngineArgs:
 | ||
|     # Model configuration parameters
 | ||
|     model: str = "baidu/ernie-45-turbo"
 | ||
|     """
 | ||
|     The name or path of the model to be used.
 | ||
|     """
 | ||
|     model_config_name: Optional[str] = "config.json"
 | ||
|     """
 | ||
|     The name of the model configuration file.
 | ||
|     """
 | ||
|     tokenizer: str = None
 | ||
|     """
 | ||
|     The name or path of the tokenizer (defaults to model path if not provided).
 | ||
|     """
 | ||
|     max_model_len: int = 2048
 | ||
|     """
 | ||
|     Maximum context length supported by the model.
 | ||
|     """
 | ||
|     tensor_parallel_size: int = 1
 | ||
|     """
 | ||
|     Degree of tensor parallelism.
 | ||
|     """
 | ||
|     block_size: int = 64
 | ||
|     """
 | ||
|     Number of tokens in one processing block.
 | ||
|     """
 | ||
|     task: TaskOption = "generate"
 | ||
|     """
 | ||
|     The task to be executed by the model.
 | ||
|     """
 | ||
|     max_num_seqs: int = 8
 | ||
|     """
 | ||
|     Maximum number of sequences per iteration.
 | ||
|     """
 | ||
|     mm_processor_kwargs: Optional[Dict[str, Any]] = None
 | ||
|     """
 | ||
|     Additional keyword arguments for the multi-modal processor.
 | ||
|     """
 | ||
|     limit_mm_per_prompt: Optional[Dict[str, Any]] = None
 | ||
|     """
 | ||
|     Limitation of numbers of multi-modal data.
 | ||
|     """
 | ||
|     reasoning_parser: str = None
 | ||
|     """
 | ||
|     specifies the reasoning parser to use for extracting reasoning content from the model output
 | ||
|     """
 | ||
|     enable_mm: bool = False
 | ||
|     """
 | ||
|     Flags to enable multi-modal model
 | ||
|     """
 | ||
|     speculative_config: Optional[Dict[str, Any]] = None
 | ||
|     """
 | ||
|     Configuration for speculative execution.
 | ||
|     """
 | ||
|     dynamic_load_weight: bool = False
 | ||
|     """
 | ||
|     dynamic load weight
 | ||
|     """
 | ||
|     load_strategy: str = "meta"
 | ||
|     """
 | ||
|     dynamic load weight strategy
 | ||
|     """
 | ||
|     quantization: str = None
 | ||
|     guided_decoding_backend: str = "off"
 | ||
|     """
 | ||
|     Guided decoding backend.
 | ||
|     """
 | ||
|     guided_decoding_disable_any_whitespace: bool = False
 | ||
|     """
 | ||
|     Disable any whitespace in guided decoding.
 | ||
|     """
 | ||
| 
 | ||
|     # Inference configuration parameters
 | ||
|     gpu_memory_utilization: float = 0.9
 | ||
|     """
 | ||
|     The fraction of GPU memory to be utilized.
 | ||
|     """
 | ||
|     num_gpu_blocks_override: Optional[int] = None
 | ||
|     """
 | ||
|     Override for the number of GPU blocks.
 | ||
|     """
 | ||
|     max_num_batched_tokens: Optional[int] = None
 | ||
|     """
 | ||
|     Maximum number of tokens to batch together.
 | ||
|     """
 | ||
|     kv_cache_ratio: float = 0.75
 | ||
|     """
 | ||
|     Ratio of tokens to process in a block.
 | ||
|     """
 | ||
| 
 | ||
|     pod_ips: Optional[List[str]] = None
 | ||
|     """
 | ||
|     List of IP addresses for nodes in the cluster.
 | ||
|     """
 | ||
| 
 | ||
|     swap_space: float = None
 | ||
|     """
 | ||
|     The amount of CPU memory to offload to.
 | ||
|     """
 | ||
| 
 | ||
|     cache_queue_port: int = 8003
 | ||
|     """
 | ||
|     Port for cache queue.
 | ||
|     """
 | ||
| 
 | ||
|     # System configuration parameters
 | ||
|     use_warmup: int = 0
 | ||
|     """
 | ||
|     Flag to indicate whether to use warm-up before inference.
 | ||
|     """
 | ||
|     enable_prefix_caching: bool = False
 | ||
|     """
 | ||
|     Flag to enable prefix caching.
 | ||
|     """
 | ||
|     engine_worker_queue_port: int = 8002
 | ||
|     """
 | ||
|     Port for worker queue communication.
 | ||
|     """
 | ||
| 
 | ||
|     splitwise_role: str = "mixed"
 | ||
|     """
 | ||
|     Splitwise role: prefill, decode or mixed
 | ||
|     """
 | ||
| 
 | ||
|     data_parallel_size: int = 1
 | ||
|     """
 | ||
|     Number of data parallelism.
 | ||
|     """
 | ||
| 
 | ||
|     enable_expert_parallel: bool = False
 | ||
|     """
 | ||
|     Enable expert parallelism.
 | ||
|     """
 | ||
| 
 | ||
|     cache_transfer_protocol: str = "ipc"
 | ||
|     """
 | ||
|     Protocol to use for cache transfer.
 | ||
|     """
 | ||
| 
 | ||
|     pd_comm_port: Optional[List[int]] = None
 | ||
|     """
 | ||
|     Port for splitwise communication.
 | ||
|     """
 | ||
| 
 | ||
|     innode_prefill_ports: Optional[List[int]] = None
 | ||
|     """
 | ||
|     Ports for innode dispatch request.
 | ||
|     """
 | ||
| 
 | ||
|     rdma_comm_ports: Optional[List[int]] = None
 | ||
|     """
 | ||
|     Ports for rdma communication.
 | ||
|     """
 | ||
| 
 | ||
|     enable_chunked_prefill: bool = False
 | ||
|     """
 | ||
|     Flag to enable chunked prefilling.
 | ||
|     """
 | ||
|     max_num_partial_prefills: int = 1
 | ||
|     """
 | ||
|     For chunked prefill, the max number of concurrent partial prefills.
 | ||
|     """
 | ||
|     max_long_partial_prefills: int = 1
 | ||
|     """
 | ||
|     For chunked prefill, the maximum number of prompts longer than –long-prefill-token-threshold
 | ||
|     that will be prefilled concurrently.
 | ||
|     """
 | ||
|     long_prefill_token_threshold: int = 0
 | ||
|     """
 | ||
|     For chunked prefill, a request is considered long if the prompt is longer than this number of tokens.
 | ||
|     """
 | ||
|     static_decode_blocks: int = 2
 | ||
|     """
 | ||
|     additional decode block num
 | ||
|     """
 | ||
| 
 | ||
|     scheduler_name: str = "local"
 | ||
|     """
 | ||
|     Scheduler name to be used
 | ||
|     """
 | ||
|     scheduler_max_size: int = -1
 | ||
|     """
 | ||
|     Size of scheduler
 | ||
|     """
 | ||
|     scheduler_ttl: int = 900
 | ||
|     """
 | ||
|     TTL of request
 | ||
|     """
 | ||
|     scheduler_host: str = "127.0.0.1"
 | ||
|     """
 | ||
|     Host of redis
 | ||
|     """
 | ||
|     scheduler_port: int = 6379
 | ||
|     """
 | ||
|     Port of redis
 | ||
|     """
 | ||
|     scheduler_db: int = 0
 | ||
|     """
 | ||
|     DB of redis
 | ||
|     """
 | ||
|     scheduler_password: Optional[str] = None
 | ||
|     """
 | ||
|     Password of redis
 | ||
|     """
 | ||
|     scheduler_topic: str = "default"
 | ||
|     """
 | ||
|     Topic of scheduler
 | ||
|     """
 | ||
|     scheduler_min_load_score: float = 3
 | ||
|     """
 | ||
|     Minimum load score for task assignment
 | ||
|     """
 | ||
|     scheduler_load_shards_num: int = 1
 | ||
|     """
 | ||
|     Number of shards for load balancing table
 | ||
|     """
 | ||
|     scheduler_sync_period: int = 5
 | ||
|     """
 | ||
|     SplitWise Use, node load sync period
 | ||
|     """
 | ||
|     scheduler_expire_period: int = 3000
 | ||
|     """
 | ||
|     SplitWise Use, node will not be scheduled after expire_period ms not sync load
 | ||
|     """
 | ||
|     scheduler_release_load_expire_period: int = 600
 | ||
|     """
 | ||
|     SplitWise Use, scheduler will release req load after expire period(s)
 | ||
|     """
 | ||
|     scheduler_reader_parallel: int = 4
 | ||
|     """
 | ||
|     SplitWise Use, Results Reader Sync Parallel
 | ||
|     """
 | ||
|     scheduler_writer_parallel: int = 4
 | ||
|     """
 | ||
|     SplitWise Use, Results Writer Sync Parallel
 | ||
|     """
 | ||
|     scheduler_reader_batch_size: int = 200
 | ||
|     """
 | ||
|     SplitWise Use, Results Reader Batch Size
 | ||
|     """
 | ||
|     scheduler_writer_batch_size: int = 200
 | ||
|     """
 | ||
|     SplitWise Use, Results Writer Batch Size
 | ||
|     """
 | ||
|     enable_static_graph_inference: bool = False
 | ||
|     """
 | ||
|     Whether to use static mode
 | ||
|     """
 | ||
|     use_cudagraph: bool = False
 | ||
|     """
 | ||
|     Flags to enable Cuda Graph
 | ||
|     """
 | ||
|     max_capture_batch_size: int = 64
 | ||
|     """
 | ||
|     Maximum Batch Size for Cuda Graph Capture
 | ||
|     NOTE: Now only support to capture continuous batch size,
 | ||
|     Example:
 | ||
|         max_capture_batch_size=64, FastDeploy will capture graphs for batches [1,64].
 | ||
|     """
 | ||
| 
 | ||
|     def __post_init__(self):
 | ||
|         """
 | ||
|         Post-initialization processing to set default tokenizer if not provided.
 | ||
|         """
 | ||
|         if not self.tokenizer:
 | ||
|             self.tokenizer = self.model
 | ||
| 
 | ||
|     @staticmethod
 | ||
|     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
 | ||
|         """
 | ||
|         Add command line interface arguments to the parser.
 | ||
|         """
 | ||
|         # Model parameters group
 | ||
|         model_group = parser.add_argument_group("Model Configuration")
 | ||
|         model_group.add_argument("--model",
 | ||
|                                  type=str,
 | ||
|                                  default=EngineArgs.model,
 | ||
|                                  help="Model name or path to be used.")
 | ||
|         model_group.add_argument("--model-config-name",
 | ||
|                                  type=nullable_str,
 | ||
|                                  default=EngineArgs.model_config_name,
 | ||
|                                  help="The model configuration file name.")
 | ||
|         model_group.add_argument(
 | ||
|             "--tokenizer",
 | ||
|             type=nullable_str,
 | ||
|             default=EngineArgs.tokenizer,
 | ||
|             help=
 | ||
|             "Tokenizer name or path (defaults to model path if not specified)."
 | ||
|         )
 | ||
|         model_group.add_argument(
 | ||
|             "--max-model-len",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.max_model_len,
 | ||
|             help="Maximum context length supported by the model.")
 | ||
|         model_group.add_argument(
 | ||
|             "--block-size",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.block_size,
 | ||
|             help="Number of tokens processed in one block.")
 | ||
|         model_group.add_argument("--task",
 | ||
|                                  type=str,
 | ||
|                                  default=EngineArgs.task,
 | ||
|                                  help="Task to be executed by the model.")
 | ||
|         model_group.add_argument(
 | ||
|             "--use-warmup",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.use_warmup,
 | ||
|             help="Flag to indicate whether to use warm-up before inference.")
 | ||
|         model_group.add_argument(
 | ||
|             "--limit-mm-per-prompt",
 | ||
|             default=EngineArgs.limit_mm_per_prompt,
 | ||
|             type=json.loads,
 | ||
|             help="Limitation of numbers of multi-modal data.")
 | ||
|         model_group.add_argument(
 | ||
|             "--mm-processor-kwargs",
 | ||
|             default=EngineArgs.mm_processor_kwargs,
 | ||
|             type=json.loads,
 | ||
|             help="Additional keyword arguments for the multi-modal processor.")
 | ||
|         model_group.add_argument("--enable-mm",
 | ||
|                                  action='store_true',
 | ||
|                                  default=EngineArgs.enable_mm,
 | ||
|                                  help="Flag to enable multi-modal model.")
 | ||
|         model_group.add_argument("--reasoning-parser",
 | ||
|                                  type=str,
 | ||
|                                  default=EngineArgs.reasoning_parser,
 | ||
|                                  help="Flag specifies the reasoning parser to use for extracting "\
 | ||
|                                       "reasoning content from the model output")
 | ||
|         model_group.add_argument(
 | ||
|             "--speculative-config",
 | ||
|             type=json.loads,
 | ||
|             default=EngineArgs.speculative_config,
 | ||
|             help="Configuration for speculative execution.")
 | ||
|         model_group.add_argument(
 | ||
|             "--dynamic-load-weight",
 | ||
|             action='store_true',
 | ||
|             default=EngineArgs.dynamic_load_weight,
 | ||
|             help="Flag to indicate whether to load weight dynamically.")
 | ||
|         model_group.add_argument(
 | ||
|             "--load-strategy",
 | ||
|             type=str,
 | ||
|             default=EngineArgs.load_strategy,
 | ||
|             help="Flag to dynamic load strategy.")
 | ||
|         model_group.add_argument("--engine-worker-queue-port",
 | ||
|                                  type=int,
 | ||
|                                  default=EngineArgs.engine_worker_queue_port,
 | ||
|                                  help="port for engine worker queue")
 | ||
|         model_group.add_argument("--quantization",
 | ||
|                                  type=str,
 | ||
|                                  default=EngineArgs.quantization,
 | ||
|                                  help="Quantization name for the model, currentlly support " \
 | ||
|                                  "'wint8', 'wint4'," \
 | ||
|                                  "default is None. The priority of this configuration "\
 | ||
|                                  "is lower than that of the config file. " \
 | ||
|                                  "More complex quantization methods need to be configured via the config file.")
 | ||
| 
 | ||
|         model_group.add_argument(
 | ||
|             "--enable-static-graph-inference",
 | ||
|             action='store_true',
 | ||
|             default=EngineArgs.enable_static_graph_inference,
 | ||
|             help="Whether to use static mode; if enabled, " \
 | ||
|                  "'paddle.to_static' will be used to convert dynamic to static.")
 | ||
|         model_group.add_argument("--use-cudagraph",
 | ||
|                                  action='store_true',
 | ||
|                                  default=EngineArgs.use_cudagraph,
 | ||
|                                  help="Flags to enable cuda graph.")
 | ||
|         model_group.add_argument("--max-capture-batch-size",
 | ||
|                                  type=int,
 | ||
|                                  default=EngineArgs.max_capture_batch_size,
 | ||
|                                  help="Maximum of Batch Size for Warm Up.")
 | ||
|         model_group.add_argument("--guided-decoding-backend",
 | ||
|                                  type=str,
 | ||
|                                  default=EngineArgs.guided_decoding_backend,
 | ||
|                                  help="Guided Decoding Backend")
 | ||
|         model_group.add_argument(
 | ||
|             "--guided-decoding-disable-any-whitespace",
 | ||
|             type=str,
 | ||
|             default=EngineArgs.guided_decoding_disable_any_whitespace,
 | ||
|             help=
 | ||
|             "Disabled any whitespaces when using guided decoding backend XGrammar."
 | ||
|         )
 | ||
| 
 | ||
|         # Parallel processing parameters group
 | ||
|         parallel_group = parser.add_argument_group("Parallel Configuration")
 | ||
|         parallel_group.add_argument("--tensor-parallel-size",
 | ||
|                                     "-tp",
 | ||
|                                     type=int,
 | ||
|                                     default=EngineArgs.tensor_parallel_size,
 | ||
|                                     help="Degree of tensor parallelism.")
 | ||
|         parallel_group.add_argument(
 | ||
|             "--max-num-seqs",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.max_num_seqs,
 | ||
|             help="Maximum number of sequences per iteration.")
 | ||
|         parallel_group.add_argument(
 | ||
|             "--num-gpu-blocks-override",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.num_gpu_blocks_override,
 | ||
|             help="Override for the number of GPU blocks.")
 | ||
|         parallel_group.add_argument(
 | ||
|             "--max-num-batched-tokens",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.max_num_batched_tokens,
 | ||
|             help="Maximum number of tokens to batch together.")
 | ||
|         parallel_group.add_argument(
 | ||
|             "--gpu-memory-utilization",
 | ||
|             type=float,
 | ||
|             default=EngineArgs.gpu_memory_utilization,
 | ||
|             help="Fraction of GPU memory to be utilized.")
 | ||
| 
 | ||
|         parallel_group.add_argument("--data-parallel-size",
 | ||
|                                     type=int,
 | ||
|                                     default=EngineArgs.data_parallel_size,
 | ||
|                                     help="Degree of data parallelism.")
 | ||
|         parallel_group.add_argument("--enable-expert-parallel",
 | ||
|                                     action='store_true',
 | ||
|                                     default=EngineArgs.enable_expert_parallel,
 | ||
|                                     help="Enable expert parallelism.")
 | ||
| 
 | ||
|         # CacheConfig parameters group
 | ||
|         cache_group = parser.add_argument_group("Cache Configuration")
 | ||
| 
 | ||
|         cache_group.add_argument("--kv-cache-ratio",
 | ||
|                                  type=float,
 | ||
|                                  default=EngineArgs.kv_cache_ratio,
 | ||
|                                  help="Ratio of tokens to process in a block.")
 | ||
| 
 | ||
|         cache_group.add_argument(
 | ||
|             "--swap-space",
 | ||
|             type=float,
 | ||
|             default=EngineArgs.swap_space,
 | ||
|             help="The amount of CPU memory to offload to.")
 | ||
| 
 | ||
|         cache_group.add_argument("--cache-queue-port",
 | ||
|                                  type=int,
 | ||
|                                  default=EngineArgs.cache_queue_port,
 | ||
|                                  help="port for cache queue")
 | ||
|         cache_group.add_argument("--static-decode-blocks",
 | ||
|                                  type=int,
 | ||
|                                  default=EngineArgs.static_decode_blocks,
 | ||
|                                  help="Static decoding blocks num.")
 | ||
| 
 | ||
|         # Cluster system parameters group
 | ||
|         system_group = parser.add_argument_group("System Configuration")
 | ||
|         system_group.add_argument(
 | ||
|             "--pod-ips",
 | ||
|             type=lambda s: s.split(",") if s else None,
 | ||
|             default=EngineArgs.pod_ips,
 | ||
|             help=
 | ||
|             "List of IP addresses for nodes in the cluster (comma-separated).")
 | ||
| 
 | ||
| 
 | ||
|         # Performance tuning parameters group
 | ||
|         perf_group = parser.add_argument_group("Performance Tuning")
 | ||
|         perf_group.add_argument("--enable-prefix-caching",
 | ||
|                                 action='store_true',
 | ||
|                                 default=EngineArgs.enable_prefix_caching,
 | ||
|                                 help="Flag to enable prefix caching.")
 | ||
| 
 | ||
|         perf_group.add_argument("--splitwise-role",
 | ||
|                                 type=str,
 | ||
|                                 default=EngineArgs.splitwise_role,
 | ||
|                                 help="Role of splitwise. Default is \
 | ||
|             'mixed'. (prefill, decode, mixed)")
 | ||
| 
 | ||
|         perf_group.add_argument("--innode-prefill-ports",
 | ||
|                                 type=lambda s: s.split(",") if s else None,
 | ||
|                                 default=EngineArgs.innode_prefill_ports,
 | ||
|                                 help="port for innode prefill")
 | ||
| 
 | ||
|         perf_group.add_argument("--enable-chunked-prefill",
 | ||
|                                 action='store_true',
 | ||
|                                 default=EngineArgs.enable_chunked_prefill,
 | ||
|                                 help="Flag to enable chunked prefill.")
 | ||
|         perf_group.add_argument("--max-num-partial-prefills",
 | ||
|                                 type=int,
 | ||
|                                 default=EngineArgs.max_num_partial_prefills,
 | ||
|                                 help="For chunked prefill, Maximum number \
 | ||
|             of concurrent partial prefill requests.")
 | ||
|         perf_group.add_argument(
 | ||
|             "--max-long-partial-prefills",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.max_long_partial_prefills,
 | ||
|             help=
 | ||
|             ("For chunked prefill, the maximum number of prompts longer than long-prefill-token-threshold"
 | ||
|              "that will be prefilled concurrently."))
 | ||
|         perf_group.add_argument(
 | ||
|             "--long-prefill-token-threshold",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.long_prefill_token_threshold,
 | ||
|             help=("For chunked prefill, the threshold number of"
 | ||
|                   " tokens for a prompt to be considered long."))
 | ||
| 
 | ||
|         perf_group.add_argument(
 | ||
|             "--cache-transfer-protocol",
 | ||
|             type=str,
 | ||
|             default=EngineArgs.cache_transfer_protocol,
 | ||
|             help="support protocol list, comma separated, default is ipc")
 | ||
| 
 | ||
|         perf_group.add_argument("--pd-comm-port",
 | ||
|                                 type=lambda s: s.split(",") if s else None,
 | ||
|                                 default=EngineArgs.pd_comm_port,
 | ||
|                                 help="port for splitwise communication.")
 | ||
| 
 | ||
|         perf_group.add_argument("--rdma-comm-ports",
 | ||
|                                 type=lambda s: s.split(",") if s else None,
 | ||
|                                 default=EngineArgs.rdma_comm_ports,
 | ||
|                                 help="ports for rdma communication.")
 | ||
| 
 | ||
|         # Scheduler parameters group
 | ||
|         scheduler_group = parser.add_argument_group("Scheduler")
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-name",
 | ||
|             default=EngineArgs.scheduler_name,
 | ||
|             help=
 | ||
|             f"Scheduler name to be used. Default is {EngineArgs.scheduler_name}. (local,global)"
 | ||
|         )
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-max-size",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.scheduler_max_size,
 | ||
|             help=
 | ||
|             f"Size of scheduler. Default is {EngineArgs.scheduler_max_size}. (Local)"
 | ||
|         )
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-ttl",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.scheduler_ttl,
 | ||
|             help=
 | ||
|             f"TTL of request. Default is {EngineArgs.scheduler_ttl} seconds. (local,global)"
 | ||
|         )
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-host",
 | ||
|             default=EngineArgs.scheduler_host,
 | ||
|             help=
 | ||
|             f"Host address of redis. Default is {EngineArgs.scheduler_host}. (global)"
 | ||
|         )
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-port",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.scheduler_port,
 | ||
|             help=
 | ||
|             f"Port of redis. Default is {EngineArgs.scheduler_port}. (global)")
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-db",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.scheduler_db,
 | ||
|             help=f"DB of redis. Default is {EngineArgs.scheduler_db}. (global)"
 | ||
|         )
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-password",
 | ||
|             default=EngineArgs.scheduler_password,
 | ||
|             help=
 | ||
|             f"Password of redis. Default is {EngineArgs.scheduler_password}. (global)"
 | ||
|         )
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-topic",
 | ||
|             default=EngineArgs.scheduler_topic,
 | ||
|             help=
 | ||
|             f"Topic of scheduler. Defaule is {EngineArgs.scheduler_topic}. (global)"
 | ||
|         )
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-min-load-score",
 | ||
|             type=float,
 | ||
|             default=EngineArgs.scheduler_min_load_score,
 | ||
|             help=
 | ||
|             f"Minimum load score for task assignment. Default is {EngineArgs.scheduler_min_load_score} (global)"
 | ||
|         )
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-load-shards-num",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.scheduler_load_shards_num,
 | ||
|             help=("Number of shards for load balancing table. Default is "
 | ||
|                   f"{EngineArgs.scheduler_load_shards_num} (global)"))
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-sync-period",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.scheduler_sync_period,
 | ||
|             help=f"SplitWise Use, node load sync period, "
 | ||
|             f"Default is {EngineArgs.scheduler_sync_period}ms. (global)")
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-expire-period",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.scheduler_expire_period,
 | ||
|             help=f"SplitWise Use, node will not be scheduled after "
 | ||
|             f"expire-period ms not sync load, Default is "
 | ||
|             f"{EngineArgs.scheduler_expire_period}ms. (global)")
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-release-load-expire-period",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.scheduler_release_load_expire_period,
 | ||
|             help=f"SplitWise Use, scheduler will release req load after "
 | ||
|             f"expire period(s). Default is "
 | ||
|             f"{EngineArgs.scheduler_release_load_expire_period}. (global)")
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-reader-parallel",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.scheduler_reader_parallel,
 | ||
|             help=f"SplitWise Use, Results Reader Sync Parallel, "
 | ||
|             f"Default is {EngineArgs.scheduler_reader_parallel}. (global)")
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-writer-parallel",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.scheduler_writer_parallel,
 | ||
|             help=f"SplitWise Use, Results Writer Sync Parallel, "
 | ||
|             f"Default is {EngineArgs.scheduler_writer_parallel}. (global)")
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-reader-batch-size",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.scheduler_reader_batch_size,
 | ||
|             help=f"SplitWise Use, Results Reader Batch Size, "
 | ||
|             f"Default is {EngineArgs.scheduler_reader_batch_size}. (global)")
 | ||
|         scheduler_group.add_argument(
 | ||
|             "--scheduler-writer-batch-size",
 | ||
|             type=int,
 | ||
|             default=EngineArgs.scheduler_writer_batch_size,
 | ||
|             help=f"SplitWise Use, Results Writer Batch Size, "
 | ||
|             f"Default is {EngineArgs.scheduler_writer_batch_size}. (global)")
 | ||
| 
 | ||
|         return parser
 | ||
| 
 | ||
|     @classmethod
 | ||
|     def from_cli_args(cls, args: FlexibleArgumentParser) -> "EngineArgs":
 | ||
|         """
 | ||
|         Create an instance of EngineArgs from command line arguments.
 | ||
|         """
 | ||
|         return cls(
 | ||
|             **{
 | ||
|                 field.name: getattr(args, field.name)
 | ||
|                 for field in dataclass_fields(cls)
 | ||
|             })
 | ||
| 
 | ||
|     def create_model_config(self) -> ModelConfig:
 | ||
|         """
 | ||
|         Create and return a ModelConfig object based on the current settings.
 | ||
|         """
 | ||
|         return ModelConfig(model_name_or_path=self.model,
 | ||
|                            config_json_file=self.model_config_name,
 | ||
|                            quantization=self.quantization,
 | ||
|                            dynamic_load_weight=self.dynamic_load_weight,
 | ||
|                            load_strategy=self.load_strategy)
 | ||
| 
 | ||
|     def create_cache_config(self, model_cfg) -> CacheConfig:
 | ||
|         """
 | ||
|         Create and return a CacheConfig object based on the current settings.
 | ||
|         """
 | ||
|         return CacheConfig(
 | ||
|             block_size=self.block_size,
 | ||
|             tensor_parallel_size=self.tensor_parallel_size,
 | ||
|             gpu_memory_utilization=self.gpu_memory_utilization,
 | ||
|             num_gpu_blocks_override=self.num_gpu_blocks_override,
 | ||
|             kv_cache_ratio=self.kv_cache_ratio,
 | ||
|             enable_prefix_caching=self.enable_prefix_caching,
 | ||
|             swap_space=self.swap_space,
 | ||
|             cache_queue_port=self.cache_queue_port,
 | ||
|             model_cfg=model_cfg,
 | ||
|             enable_chunked_prefill=self.enable_chunked_prefill,
 | ||
|             enc_dec_block_num=self.static_decode_blocks,
 | ||
|             rdma_comm_ports=self.rdma_comm_ports,
 | ||
|             cache_transfer_protocol=self.cache_transfer_protocol,
 | ||
|             pd_comm_port=self.pd_comm_port,
 | ||
|         )
 | ||
| 
 | ||
|     def create_speculative_config(self) -> SpeculativeConfig:
 | ||
|         """
 | ||
|         """
 | ||
|         if self.speculative_config is not None:
 | ||
|             return SpeculativeConfig(**self.speculative_config)
 | ||
|         else:
 | ||
|             return SpeculativeConfig()
 | ||
| 
 | ||
|     def create_scheduler_config(self) -> SchedulerConfig:
 | ||
|         """
 | ||
|         Create and retuan a SchedulerConfig object based on the current settings.
 | ||
|         """
 | ||
|         prefix = "scheduler_"
 | ||
|         prefix_len = len(prefix)
 | ||
|         extra_params = [
 | ||
|             "max_model_len", "enable_chunked_prefill",
 | ||
|             "max_num_partial_prefills", "max_long_partial_prefills",
 | ||
|             "long_prefill_token_threshold"
 | ||
|         ]
 | ||
| 
 | ||
|         all = asdict(self)
 | ||
|         params = dict()
 | ||
|         for k, v in all.items():
 | ||
|             if k[:prefix_len] == prefix:
 | ||
|                 params[k[prefix_len:]] = v
 | ||
|             elif k in extra_params:
 | ||
|                 params[k] = v
 | ||
| 
 | ||
|         return SchedulerConfig(**params)
 | ||
| 
 | ||
|     def create_parallel_config(self) -> ParallelConfig:
 | ||
|         """
 | ||
|         Create and return a ParallelConfig object based on the current settings.
 | ||
|         """
 | ||
|         return ParallelConfig(
 | ||
|             tensor_parallel_size=self.tensor_parallel_size,
 | ||
|             enable_expert_parallel=self.enable_expert_parallel,
 | ||
|             data_parallel_size=self.data_parallel_size,
 | ||
|         )
 | ||
| 
 | ||
|     def create_engine_config(self) -> Config:
 | ||
|         """
 | ||
|         Create and return a Config object based on the current settings.
 | ||
|         """
 | ||
|         model_cfg = self.create_model_config()
 | ||
|         if not model_cfg.is_unified_ckpt and hasattr(model_cfg,
 | ||
|                                                      'tensor_parallel_size'):
 | ||
|             self.tensor_parallel_size = model_cfg.tensor_parallel_size
 | ||
|         if self.max_num_batched_tokens is None:
 | ||
|             if self.enable_chunked_prefill:
 | ||
|                 self.max_num_batched_tokens = 2048
 | ||
|             else:
 | ||
|                 self.max_num_batched_tokens = self.max_model_len
 | ||
|         scheduler_cfg = self.create_scheduler_config()
 | ||
| 
 | ||
|         speculative_cfg = self.create_speculative_config()
 | ||
| 
 | ||
|         assert not (self.use_cudagraph and self.enable_prefix_caching), \
 | ||
|             "Prefix caching cannot be used with CUDA graph"
 | ||
| 
 | ||
|         return Config(
 | ||
|             model_name_or_path=self.model,
 | ||
|             model_config=model_cfg,
 | ||
|             scheduler_config=scheduler_cfg,
 | ||
|             tokenizer=self.tokenizer,
 | ||
|             cache_config=self.create_cache_config(model_cfg),
 | ||
|             parallel_config=self.create_parallel_config(),
 | ||
|             max_model_len=self.max_model_len,
 | ||
|             tensor_parallel_size=self.tensor_parallel_size,
 | ||
|             max_num_seqs=self.max_num_seqs,
 | ||
|             speculative_config=speculative_cfg,
 | ||
|             max_num_batched_tokens=self.max_num_batched_tokens,
 | ||
|             pod_ips=self.pod_ips,
 | ||
|             use_warmup=self.use_warmup,
 | ||
|             engine_worker_queue_port=self.engine_worker_queue_port,
 | ||
|             limit_mm_per_prompt=self.limit_mm_per_prompt,
 | ||
|             mm_processor_kwargs=self.mm_processor_kwargs,
 | ||
|             enable_mm=self.enable_mm,
 | ||
|             reasoning_parser=self.reasoning_parser,
 | ||
|             splitwise_role=self.splitwise_role,
 | ||
|             innode_prefill_ports=self.innode_prefill_ports,
 | ||
|             max_num_partial_prefills=self.max_num_partial_prefills,
 | ||
|             max_long_partial_prefills=self.max_long_partial_prefills,
 | ||
|             long_prefill_token_threshold=self.long_prefill_token_threshold,
 | ||
|             enable_static_graph_inference=self.enable_static_graph_inference,
 | ||
|             use_cudagraph=self.use_cudagraph,
 | ||
|             max_capture_batch_size=self.max_capture_batch_size,
 | ||
|             guided_decoding_backend=self.guided_decoding_backend,
 | ||
|             disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
 | ||
|         )
 |