Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-09-29 22:02:30 +08:00)

Commit: online chat support logprobs * check xpu * check vl_gpu_model_runner and xpu_model_runner * get_worker() check platform

815 lines · 30 KiB · Python
"""
|
||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
"""
|
||
import json
from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional

from fastdeploy.engine.config import (CacheConfig, Config, ModelConfig,
                                      ParallelConfig, SpeculativeConfig,
                                      TaskOption)
from fastdeploy.scheduler.config import SchedulerConfig
from fastdeploy.utils import FlexibleArgumentParser


def nullable_str(x: str) -> Optional[str]:
    """
    Convert an empty string to None, preserving other string values.
    """
    return x if x else None


@dataclass
class EngineArgs:
    # Model configuration parameters
    model: str = "baidu/ernie-45-turbo"
    """
    The name or path of the model to be used.
    """
    model_config_name: Optional[str] = "config.json"
    """
    The name of the model configuration file.
    """
    tokenizer: Optional[str] = None
    """
    The name or path of the tokenizer (defaults to model path if not provided).
    """
    max_model_len: int = 2048
    """
    Maximum context length supported by the model.
    """
    tensor_parallel_size: int = 1
    """
    Degree of tensor parallelism.
    """
    block_size: int = 64
    """
    Number of tokens in one processing block.
    """
    task: TaskOption = "generate"
    """
    The task to be executed by the model.
    """
    max_num_seqs: int = 8
    """
    Maximum number of sequences per iteration.
    """
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
    """
    Additional keyword arguments for the multi-modal processor.
    """
    limit_mm_per_prompt: Optional[Dict[str, Any]] = None
    """
    Limit on the number of multi-modal data items per prompt.
    """
    reasoning_parser: Optional[str] = None
    """
    Specifies the reasoning parser to use for extracting reasoning content from the model output.
    """
    enable_mm: bool = False
    """
    Flag to enable the multi-modal model.
    """
    speculative_config: Optional[Dict[str, Any]] = None
    """
    Configuration for speculative execution.
    """
    dynamic_load_weight: bool = False
    """
    Flag to enable dynamic weight loading.
    """
    load_strategy: str = "meta"
    """
    Strategy used for dynamic weight loading.
    """
    quantization: Optional[str] = None
    """
    Quantization method for the model (e.g. 'wint8' or 'wint4'); defaults to None.
    """
    guided_decoding_backend: str = "off"
    """
    Guided decoding backend.
    """
    guided_decoding_disable_any_whitespace: bool = False
    """
    Disable any whitespace in guided decoding.
    """

    # Inference configuration parameters
    gpu_memory_utilization: float = 0.9
    """
    The fraction of GPU memory to be utilized.
    """
    num_gpu_blocks_override: Optional[int] = None
    """
    Override for the number of GPU blocks.
    """
    max_num_batched_tokens: Optional[int] = None
    """
    Maximum number of tokens to batch together.
    """
    kv_cache_ratio: float = 0.75
    """
    Ratio of tokens to process in a block.
    """

    pod_ips: Optional[List[str]] = None
    """
    List of IP addresses for nodes in the cluster.
    """

    swap_space: Optional[float] = None
    """
    The amount of CPU memory to offload to.
    """

    cache_queue_port: int = 8003
    """
    Port for the cache queue.
    """

    # System configuration parameters
    use_warmup: int = 0
    """
    Flag to indicate whether to use warm-up before inference.
    """
    enable_prefix_caching: bool = False
    """
    Flag to enable prefix caching.
    """

    enable_custom_all_reduce: bool = False
    """
    Flag to enable the custom all-reduce kernel.
    """

    engine_worker_queue_port: int = 8002
    """
    Port for worker queue communication.
    """

    splitwise_role: str = "mixed"
    """
    Splitwise role: prefill, decode or mixed.
    """

    data_parallel_size: int = 1
    """
    Degree of data parallelism.
    """

    enable_expert_parallel: bool = False
    """
    Enable expert parallelism.
    """

    cache_transfer_protocol: str = "ipc"
    """
    Protocol to use for cache transfer.
    """

    pd_comm_port: Optional[List[int]] = None
    """
    Port for splitwise communication.
    """

    innode_prefill_ports: Optional[List[int]] = None
    """
    Ports for in-node dispatch requests.
    """

    rdma_comm_ports: Optional[List[int]] = None
    """
    Ports for RDMA communication.
    """

    enable_chunked_prefill: bool = False
    """
    Flag to enable chunked prefilling.
    """
    max_num_partial_prefills: int = 1
    """
    For chunked prefill, the maximum number of concurrent partial prefills.
    """
    max_long_partial_prefills: int = 1
    """
    For chunked prefill, the maximum number of prompts longer than --long-prefill-token-threshold
    that will be prefilled concurrently.
    """
    long_prefill_token_threshold: int = 0
    """
    For chunked prefill, a request is considered long if the prompt is longer than this number of tokens.
    """
    static_decode_blocks: int = 2
    """
    Number of additional decode blocks.
    """

scheduler_name: str = "local"
|
||
"""
|
||
Scheduler name to be used
|
||
"""
|
||
scheduler_max_size: int = -1
|
||
"""
|
||
Size of scheduler
|
||
"""
|
||
scheduler_ttl: int = 900
|
||
"""
|
||
TTL of request
|
||
"""
|
||
scheduler_host: str = "127.0.0.1"
|
||
"""
|
||
Host of redis
|
||
"""
|
||
scheduler_port: int = 6379
|
||
"""
|
||
Port of redis
|
||
"""
|
||
scheduler_db: int = 0
|
||
"""
|
||
DB of redis
|
||
"""
|
||
scheduler_password: Optional[str] = None
|
||
"""
|
||
Password of redis
|
||
"""
|
||
scheduler_topic: str = "default"
|
||
"""
|
||
Topic of scheduler
|
||
"""
|
||
scheduler_min_load_score: float = 3
|
||
"""
|
||
Minimum load score for task assignment
|
||
"""
|
||
scheduler_load_shards_num: int = 1
|
||
"""
|
||
Number of shards for load balancing table
|
||
"""
|
||
scheduler_sync_period: int = 5
|
||
"""
|
||
SplitWise Use, node load sync period
|
||
"""
|
||
scheduler_expire_period: int = 3000
|
||
"""
|
||
SplitWise Use, node will not be scheduled after expire_period ms not sync load
|
||
"""
|
||
scheduler_release_load_expire_period: int = 600
|
||
"""
|
||
SplitWise Use, scheduler will release req load after expire period(s)
|
||
"""
|
||
scheduler_reader_parallel: int = 4
|
||
"""
|
||
SplitWise Use, Results Reader Sync Parallel
|
||
"""
|
||
scheduler_writer_parallel: int = 4
|
||
"""
|
||
SplitWise Use, Results Writer Sync Parallel
|
||
"""
|
||
scheduler_reader_batch_size: int = 200
|
||
"""
|
||
SplitWise Use, Results Reader Batch Size
|
||
"""
|
||
scheduler_writer_batch_size: int = 200
|
||
"""
|
||
SplitWise Use, Results Writer Batch Size
|
||
"""
|
||
enable_static_graph_inference: bool = False
|
||
"""
|
||
Whether to use static mode
|
||
"""
|
||
use_cudagraph: bool = False
|
||
"""
|
||
Flags to enable Cuda Graph
|
||
"""
|
||
max_capture_batch_size: int = 64
|
||
"""
|
||
Maximum Batch Size for Cuda Graph Capture
|
||
NOTE: Now only support to capture continuous batch size,
|
||
Example:
|
||
max_capture_batch_size=64, FastDeploy will capture graphs for batches [1,64].
|
||
"""
|
||
|
||
enable_logprob: bool = False
|
||
"""
|
||
Flag to enable logprob output. Default is False (disabled).
|
||
Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
|
||
"""
|
||
|
||
    def __post_init__(self):
        """
        Post-initialization processing to set default tokenizer if not provided.
        """
        if not self.tokenizer:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """
        Add command line interface arguments to the parser.
        """
        # Model parameters group
        model_group = parser.add_argument_group("Model Configuration")
        model_group.add_argument("--model",
                                 type=str,
                                 default=EngineArgs.model,
                                 help="Model name or path to be used.")
        model_group.add_argument("--model-config-name",
                                 type=nullable_str,
                                 default=EngineArgs.model_config_name,
                                 help="The model configuration file name.")
        model_group.add_argument(
            "--tokenizer",
            type=nullable_str,
            default=EngineArgs.tokenizer,
            help="Tokenizer name or path (defaults to model path if not specified).")
        model_group.add_argument(
            "--max-model-len",
            type=int,
            default=EngineArgs.max_model_len,
            help="Maximum context length supported by the model.")
        model_group.add_argument(
            "--block-size",
            type=int,
            default=EngineArgs.block_size,
            help="Number of tokens processed in one block.")
        model_group.add_argument("--task",
                                 type=str,
                                 default=EngineArgs.task,
                                 help="Task to be executed by the model.")
        model_group.add_argument(
            "--use-warmup",
            type=int,
            default=EngineArgs.use_warmup,
            help="Flag to indicate whether to use warm-up before inference.")
        model_group.add_argument(
            "--limit-mm-per-prompt",
            default=EngineArgs.limit_mm_per_prompt,
            type=json.loads,
            help="Limit on the number of multi-modal data items per prompt.")
        model_group.add_argument(
            "--mm-processor-kwargs",
            default=EngineArgs.mm_processor_kwargs,
            type=json.loads,
            help="Additional keyword arguments for the multi-modal processor.")
        model_group.add_argument("--enable-mm",
                                 action='store_true',
                                 default=EngineArgs.enable_mm,
                                 help="Flag to enable multi-modal model.")
        model_group.add_argument("--reasoning-parser",
                                 type=str,
                                 default=EngineArgs.reasoning_parser,
                                 help="Specifies the reasoning parser to use for extracting "
                                      "reasoning content from the model output.")
        model_group.add_argument(
            "--speculative-config",
            type=json.loads,
            default=EngineArgs.speculative_config,
            help="Configuration for speculative execution.")
        model_group.add_argument(
            "--dynamic-load-weight",
            action='store_true',
            default=EngineArgs.dynamic_load_weight,
            help="Flag to indicate whether to load weight dynamically.")
        model_group.add_argument(
            "--load-strategy",
            type=str,
            default=EngineArgs.load_strategy,
            help="Strategy for dynamic weight loading.")
        model_group.add_argument("--engine-worker-queue-port",
                                 type=int,
                                 default=EngineArgs.engine_worker_queue_port,
                                 help="Port for the engine worker queue.")
        model_group.add_argument("--quantization",
                                 type=str,
                                 default=EngineArgs.quantization,
                                 help="Quantization name for the model; currently supports "
                                      "'wint8' and 'wint4'. Default is None. The priority of this "
                                      "configuration is lower than that of the config file. "
                                      "More complex quantization methods need to be configured via the config file.")

        model_group.add_argument(
            "--enable-static-graph-inference",
            action='store_true',
            default=EngineArgs.enable_static_graph_inference,
            help="Whether to use static graph mode; if enabled, "
                 "'paddle.to_static' will be used to convert dynamic to static.")
        model_group.add_argument("--use-cudagraph",
                                 action='store_true',
                                 default=EngineArgs.use_cudagraph,
                                 help="Flag to enable CUDA graph.")
        model_group.add_argument("--max-capture-batch-size",
                                 type=int,
                                 default=EngineArgs.max_capture_batch_size,
                                 help="Maximum batch size for CUDA graph capture.")
        model_group.add_argument("--guided-decoding-backend",
                                 type=str,
                                 default=EngineArgs.guided_decoding_backend,
                                 help="Guided decoding backend.")
        model_group.add_argument(
            "--guided-decoding-disable-any-whitespace",
            action='store_true',
            default=EngineArgs.guided_decoding_disable_any_whitespace,
            help="Disable any whitespace when using the XGrammar guided decoding backend.")
        model_group.add_argument("--enable-logprob",
                                 action="store_true",
                                 default=EngineArgs.enable_logprob,
                                 help="Enable output of token-level log probabilities.")

        # Parallel processing parameters group
        parallel_group = parser.add_argument_group("Parallel Configuration")
        parallel_group.add_argument("--tensor-parallel-size",
                                    "-tp",
                                    type=int,
                                    default=EngineArgs.tensor_parallel_size,
                                    help="Degree of tensor parallelism.")
        parallel_group.add_argument("--enable-custom-all-reduce",
                                    action='store_true',
                                    default=EngineArgs.enable_custom_all_reduce,
                                    help="Flag to enable custom all-reduce.")
        parallel_group.add_argument(
            "--max-num-seqs",
            type=int,
            default=EngineArgs.max_num_seqs,
            help="Maximum number of sequences per iteration.")
        parallel_group.add_argument(
            "--num-gpu-blocks-override",
            type=int,
            default=EngineArgs.num_gpu_blocks_override,
            help="Override for the number of GPU blocks.")
        parallel_group.add_argument(
            "--max-num-batched-tokens",
            type=int,
            default=EngineArgs.max_num_batched_tokens,
            help="Maximum number of tokens to batch together.")
        parallel_group.add_argument(
            "--gpu-memory-utilization",
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help="Fraction of GPU memory to be utilized.")

        parallel_group.add_argument("--data-parallel-size",
                                    type=int,
                                    default=EngineArgs.data_parallel_size,
                                    help="Degree of data parallelism.")
        parallel_group.add_argument("--enable-expert-parallel",
                                    action='store_true',
                                    default=EngineArgs.enable_expert_parallel,
                                    help="Enable expert parallelism.")

        # CacheConfig parameters group
        cache_group = parser.add_argument_group("Cache Configuration")

        cache_group.add_argument("--kv-cache-ratio",
                                 type=float,
                                 default=EngineArgs.kv_cache_ratio,
                                 help="Ratio of tokens to process in a block.")

        cache_group.add_argument(
            "--swap-space",
            type=float,
            default=EngineArgs.swap_space,
            help="The amount of CPU memory to offload to.")

        cache_group.add_argument("--cache-queue-port",
                                 type=int,
                                 default=EngineArgs.cache_queue_port,
                                 help="port for cache queue")
        cache_group.add_argument("--static-decode-blocks",
                                 type=int,
                                 default=EngineArgs.static_decode_blocks,
                                 help="Static decoding blocks num.")

        # Cluster system parameters group
        system_group = parser.add_argument_group("System Configuration")
        system_group.add_argument(
            "--pod-ips",
            type=lambda s: s.split(",") if s else None,
            default=EngineArgs.pod_ips,
            help="List of IP addresses for nodes in the cluster (comma-separated).")

        # Performance tuning parameters group
        perf_group = parser.add_argument_group("Performance Tuning")
        perf_group.add_argument("--enable-prefix-caching",
                                action='store_true',
                                default=EngineArgs.enable_prefix_caching,
                                help="Flag to enable prefix caching.")

        perf_group.add_argument("--splitwise-role",
                                type=str,
                                default=EngineArgs.splitwise_role,
                                help="Role of splitwise. Default is 'mixed'. "
                                     "(prefill, decode, mixed)")

        perf_group.add_argument("--innode-prefill-ports",
                                type=lambda s: s.split(",") if s else None,
                                default=EngineArgs.innode_prefill_ports,
                                help="Ports for in-node prefill.")

        perf_group.add_argument("--enable-chunked-prefill",
                                action='store_true',
                                default=EngineArgs.enable_chunked_prefill,
                                help="Flag to enable chunked prefill.")
        perf_group.add_argument("--max-num-partial-prefills",
                                type=int,
                                default=EngineArgs.max_num_partial_prefills,
                                help="For chunked prefill, maximum number "
                                     "of concurrent partial prefill requests.")
        perf_group.add_argument(
            "--max-long-partial-prefills",
            type=int,
            default=EngineArgs.max_long_partial_prefills,
            help=("For chunked prefill, the maximum number of prompts longer than "
                  "long-prefill-token-threshold that will be prefilled concurrently."))
        perf_group.add_argument(
            "--long-prefill-token-threshold",
            type=int,
            default=EngineArgs.long_prefill_token_threshold,
            help=("For chunked prefill, the threshold number of"
                  " tokens for a prompt to be considered long."))

        perf_group.add_argument(
            "--cache-transfer-protocol",
            type=str,
            default=EngineArgs.cache_transfer_protocol,
            help="Supported protocol list, comma separated; default is 'ipc'.")

        perf_group.add_argument("--pd-comm-port",
                                type=lambda s: s.split(",") if s else None,
                                default=EngineArgs.pd_comm_port,
                                help="Port for splitwise communication.")

        perf_group.add_argument("--rdma-comm-ports",
                                type=lambda s: s.split(",") if s else None,
                                default=EngineArgs.rdma_comm_ports,
                                help="Ports for RDMA communication.")

        # Scheduler parameters group
        scheduler_group = parser.add_argument_group("Scheduler")
        scheduler_group.add_argument(
            "--scheduler-name",
            default=EngineArgs.scheduler_name,
            help=f"Scheduler name to be used. Default is "
                 f"{EngineArgs.scheduler_name}. (local,global)")
        scheduler_group.add_argument(
            "--scheduler-max-size",
            type=int,
            default=EngineArgs.scheduler_max_size,
            help=f"Size of scheduler. Default is "
                 f"{EngineArgs.scheduler_max_size}. (local)")
        scheduler_group.add_argument(
            "--scheduler-ttl",
            type=int,
            default=EngineArgs.scheduler_ttl,
            help=f"TTL of request. Default is "
                 f"{EngineArgs.scheduler_ttl} seconds. (local,global)")
        scheduler_group.add_argument(
            "--scheduler-host",
            default=EngineArgs.scheduler_host,
            help=f"Host address of redis. Default is "
                 f"{EngineArgs.scheduler_host}. (global)")
        scheduler_group.add_argument(
            "--scheduler-port",
            type=int,
            default=EngineArgs.scheduler_port,
            help=f"Port of redis. Default is {EngineArgs.scheduler_port}. (global)")
        scheduler_group.add_argument(
            "--scheduler-db",
            type=int,
            default=EngineArgs.scheduler_db,
            help=f"DB of redis. Default is {EngineArgs.scheduler_db}. (global)")
        scheduler_group.add_argument(
            "--scheduler-password",
            default=EngineArgs.scheduler_password,
            help=f"Password of redis. Default is "
                 f"{EngineArgs.scheduler_password}. (global)")
        scheduler_group.add_argument(
            "--scheduler-topic",
            default=EngineArgs.scheduler_topic,
            help=f"Topic of scheduler. Default is "
                 f"{EngineArgs.scheduler_topic}. (global)")
        scheduler_group.add_argument(
            "--scheduler-min-load-score",
            type=float,
            default=EngineArgs.scheduler_min_load_score,
            help=f"Minimum load score for task assignment. Default is "
                 f"{EngineArgs.scheduler_min_load_score}. (global)")
        scheduler_group.add_argument(
            "--scheduler-load-shards-num",
            type=int,
            default=EngineArgs.scheduler_load_shards_num,
            help=("Number of shards for load balancing table. Default is "
                  f"{EngineArgs.scheduler_load_shards_num}. (global)"))
        scheduler_group.add_argument(
            "--scheduler-sync-period",
            type=int,
            default=EngineArgs.scheduler_sync_period,
            help=f"SplitWise use: node load sync period. "
                 f"Default is {EngineArgs.scheduler_sync_period}ms. (global)")
        scheduler_group.add_argument(
            "--scheduler-expire-period",
            type=int,
            default=EngineArgs.scheduler_expire_period,
            help=f"SplitWise use: a node will not be scheduled after "
                 f"expire-period ms without a load sync. Default is "
                 f"{EngineArgs.scheduler_expire_period}ms. (global)")
        scheduler_group.add_argument(
            "--scheduler-release-load-expire-period",
            type=int,
            default=EngineArgs.scheduler_release_load_expire_period,
            help=f"SplitWise use: the scheduler releases a request's load after "
                 f"the expire period (seconds). Default is "
                 f"{EngineArgs.scheduler_release_load_expire_period}. (global)")
        scheduler_group.add_argument(
            "--scheduler-reader-parallel",
            type=int,
            default=EngineArgs.scheduler_reader_parallel,
            help=f"SplitWise use: results reader sync parallelism. "
                 f"Default is {EngineArgs.scheduler_reader_parallel}. (global)")
        scheduler_group.add_argument(
            "--scheduler-writer-parallel",
            type=int,
            default=EngineArgs.scheduler_writer_parallel,
            help=f"SplitWise use: results writer sync parallelism. "
                 f"Default is {EngineArgs.scheduler_writer_parallel}. (global)")
        scheduler_group.add_argument(
            "--scheduler-reader-batch-size",
            type=int,
            default=EngineArgs.scheduler_reader_batch_size,
            help=f"SplitWise use: results reader batch size. "
                 f"Default is {EngineArgs.scheduler_reader_batch_size}. (global)")
        scheduler_group.add_argument(
            "--scheduler-writer-batch-size",
            type=int,
            default=EngineArgs.scheduler_writer_batch_size,
            help=f"SplitWise use: results writer batch size. "
                 f"Default is {EngineArgs.scheduler_writer_batch_size}. (global)")

        return parser

    @classmethod
    def from_cli_args(cls, args: FlexibleArgumentParser) -> "EngineArgs":
        """
        Create an instance of EngineArgs from command line arguments.
        """
        return cls(**{
            field.name: getattr(args, field.name)
            for field in dataclass_fields(cls)
        })

    def create_model_config(self) -> ModelConfig:
        """
        Create and return a ModelConfig object based on the current settings.
        """
        return ModelConfig(model_name_or_path=self.model,
                           config_json_file=self.model_config_name,
                           quantization=self.quantization,
                           dynamic_load_weight=self.dynamic_load_weight,
                           load_strategy=self.load_strategy)

    def create_cache_config(self, model_cfg) -> CacheConfig:
        """
        Create and return a CacheConfig object based on the current settings.
        """
        return CacheConfig(
            block_size=self.block_size,
            tensor_parallel_size=self.tensor_parallel_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            kv_cache_ratio=self.kv_cache_ratio,
            enable_prefix_caching=self.enable_prefix_caching,
            swap_space=self.swap_space,
            cache_queue_port=self.cache_queue_port,
            model_cfg=model_cfg,
            enable_chunked_prefill=self.enable_chunked_prefill,
            enc_dec_block_num=self.static_decode_blocks,
            rdma_comm_ports=self.rdma_comm_ports,
            cache_transfer_protocol=self.cache_transfer_protocol,
            pd_comm_port=self.pd_comm_port,
        )

    def create_speculative_config(self) -> SpeculativeConfig:
        """
        Create and return a SpeculativeConfig object based on the current settings.
        """
        if self.speculative_config is not None:
            return SpeculativeConfig(**self.speculative_config)
        else:
            return SpeculativeConfig()

    def create_scheduler_config(self) -> SchedulerConfig:
        """
        Create and return a SchedulerConfig object based on the current settings.
        """
        prefix = "scheduler_"
        prefix_len = len(prefix)
        extra_params = [
            "max_model_len", "enable_chunked_prefill",
            "max_num_partial_prefills", "max_long_partial_prefills",
            "long_prefill_token_threshold"
        ]

        # Collect every "scheduler_*" field (with the prefix stripped) plus the
        # shared parameters listed in extra_params.
        all_args = asdict(self)
        params = dict()
        for k, v in all_args.items():
            if k[:prefix_len] == prefix:
                params[k[prefix_len:]] = v
            elif k in extra_params:
                params[k] = v

        return SchedulerConfig(**params)

    def create_parallel_config(self) -> ParallelConfig:
        """
        Create and return a ParallelConfig object based on the current settings.
        """
        return ParallelConfig(
            tensor_parallel_size=self.tensor_parallel_size,
            enable_expert_parallel=self.enable_expert_parallel,
            data_parallel_size=self.data_parallel_size,
            enable_custom_all_reduce=self.enable_custom_all_reduce,
        )

    def create_engine_config(self) -> Config:
        """
        Create and return a Config object based on the current settings.
        """
        model_cfg = self.create_model_config()
        if not model_cfg.is_unified_ckpt and hasattr(model_cfg,
                                                     'tensor_parallel_size'):
            self.tensor_parallel_size = model_cfg.tensor_parallel_size
        if self.max_num_batched_tokens is None:
            if self.enable_chunked_prefill:
                self.max_num_batched_tokens = 2048
            else:
                self.max_num_batched_tokens = self.max_model_len
        scheduler_cfg = self.create_scheduler_config()

        speculative_cfg = self.create_speculative_config()

        assert not (self.use_cudagraph and self.enable_prefix_caching), \
            "Prefix caching cannot be used with CUDA graph"

        assert not (self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce), \
            "enable_custom_all_reduce must be used with tensor_parallel_size > 1"

        return Config(
            model_name_or_path=self.model,
            model_config=model_cfg,
            scheduler_config=scheduler_cfg,
            tokenizer=self.tokenizer,
            cache_config=self.create_cache_config(model_cfg),
            parallel_config=self.create_parallel_config(),
            max_model_len=self.max_model_len,
            tensor_parallel_size=self.tensor_parallel_size,
            max_num_seqs=self.max_num_seqs,
            speculative_config=speculative_cfg,
            max_num_batched_tokens=self.max_num_batched_tokens,
            pod_ips=self.pod_ips,
            use_warmup=self.use_warmup,
            engine_worker_queue_port=self.engine_worker_queue_port,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            mm_processor_kwargs=self.mm_processor_kwargs,
            enable_mm=self.enable_mm,
            reasoning_parser=self.reasoning_parser,
            splitwise_role=self.splitwise_role,
            innode_prefill_ports=self.innode_prefill_ports,
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
            enable_static_graph_inference=self.enable_static_graph_inference,
            use_cudagraph=self.use_cudagraph,
            max_capture_batch_size=self.max_capture_batch_size,
            guided_decoding_backend=self.guided_decoding_backend,
            disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
            enable_custom_all_reduce=self.enable_custom_all_reduce,
            enable_logprob=self.enable_logprob,
        )
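
A minimal usage sketch for orientation, appended after the file rather than being part of it: it assumes this module is importable as fastdeploy.engine.args_utils and that FlexibleArgumentParser behaves like a standard argparse parser (both are assumptions, not confirmed by this file); only add_cli_args, from_cli_args, and create_engine_config come from the class above.

# Illustrative sketch only; the entry point and import path below are assumptions.
from fastdeploy.utils import FlexibleArgumentParser
from fastdeploy.engine.args_utils import EngineArgs  # assumed module path


def main():
    # Register every engine flag defined by EngineArgs on a fresh parser.
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)

    # e.g. --model baidu/ernie-45-turbo --tensor-parallel-size 2 --enable-logprob
    args = parser.parse_args()

    # Build the dataclass from the parsed namespace, then the full engine Config.
    engine_args = EngineArgs.from_cli_args(args)
    config = engine_args.create_engine_config()
    return config


if __name__ == "__main__":
    main()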