Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-27 10:30:34 +08:00)
Sync v2.0 version of code to github repo
@@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import json
from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional

from fastdeploy.engine.config import (CacheConfig, Config, ModelConfig,
                                      ParallelConfig, SpeculativeConfig,
                                      TaskOption)
from fastdeploy.scheduler.config import SchedulerConfig
from fastdeploy.utils import FlexibleArgumentParser
@@ -34,7 +35,7 @@ def nullable_str(x: str) -> Optional[str]:
@dataclass
class EngineArgs:
    # Model configuration parameters
    model: str = ""
    model: str = "baidu/ernie-45-turbo"
    """
    The name or path of the model to be used.
    """
@@ -70,6 +71,14 @@ class EngineArgs:
    """
    Additional keyword arguments for the multi-modal processor.
    """
    limit_mm_per_prompt: Optional[Dict[str, Any]] = None
    """
    Limitation of numbers of multi-modal data.
    """
    reasoning_parser: str = None
    """
    specifies the reasoning parser to use for extracting reasoning content from the model output
    """
    enable_mm: bool = False
    """
    Flags to enable multi-modal model
@@ -82,6 +91,15 @@ class EngineArgs:
    """
    dynamic load weight
    """
    quantization: str = None
    guided_decoding_backend: str = "off"
    """
    Guided decoding backend.
    """
    guided_decoding_disable_any_whitespace: bool = False
    """
    Disable any whitespace in guided decoding.
    """

    # Inference configuration parameters
    gpu_memory_utilization: float = 0.9
@@ -109,6 +127,16 @@ class EngineArgs:
    List of IP addresses for nodes in the cluster.
    """

    swap_space: float = None
    """
    The amount of CPU memory to offload to.
    """

    cache_queue_port: int = 8003
    """
    Port for cache queue.
    """

    # System configuration parameters
    use_warmup: int = 0
    """
@@ -119,51 +147,150 @@ class EngineArgs:
    Flag to enable prefix caching.
    """
    engine_worker_queue_port: int = 8002
    """
    Port for worker queue communication.
    """

    splitwise_role: str = "mixed"
    """
    Splitwise role: prefill, decode or mixed
    """

    data_parallel_size: int = 1
    """
    Number of data parallelism.
    """

    enable_expert_parallel: bool = False
    """
    Enable expert parallelism.
    """

    cache_transfer_protocol: str = "ipc"
    """
    Protocol to use for cache transfer.
    """

    pd_comm_port: Optional[List[int]] = None
    """
    Port for splitwise communication.
    """

    innode_prefill_ports: Optional[List[int]] = None
    """
    Ports for innode dispatch request.
    """

    rdma_comm_ports: Optional[List[int]] = None
    """
    Ports for rdma communication.
    """

    enable_chunked_prefill: bool = False
    """
    Flag to enable chunked prefilling.
    """
    max_num_partial_prefills: int = 1
    """
    For chunked prefill, the max number of concurrent partial prefills.
    """
    max_long_partial_prefills: int = 1
    """
    For chunked prefill, the maximum number of prompts longer than --long-prefill-token-threshold
    that will be prefilled concurrently.
    """
    long_prefill_token_threshold: int = 0
    """
    For chunked prefill, a request is considered long if the prompt is longer than this number of tokens.
    """
    static_decode_blocks: int = 2
    """
    additional decode block num
    """

    scheduler_name: str = "local"
    """
    Scheduler name to be used
    """
    scheduler_max_size: int = -1
    """
    Size of scheduler
    """
    scheduler_ttl: int = 900
    """
    TTL of request
    """
    """
    Timeout for waiting for response
    """
    scheduler_wait_response_timeout: float = 0.001
    scheduler_host: str = "127.0.0.1"
    """
    Host of redis
    """
    scheduler_port: int = 6379
    """
    Port of redis
    """
    scheduler_db: int = 0
    """
    DB of redis
    """
    scheduler_password: Optional[str] = None
    """
    Password of redis
    """
    scheduler_topic: str = "default"
    """
    Topic of scheduler
    """
    scheduler_min_load_score: float = 3
    """
    Max write time of redis
    Minimum load score for task assignment
    """
    scheduler_load_shards_num: int = 1
    """
    Number of shards for load balancing table
    """
    scheduler_sync_period: int = 5
    """
    SplitWise Use, node load sync period
    """
    scheduler_expire_period: int = 3000
    """
    SplitWise Use, node will not be scheduled after expire_period ms not sync load
    """
    scheduler_release_load_expire_period: int = 600
    """
    SplitWise Use, scheduler will release req load after expire period(s)
    """
    scheduler_reader_parallel: int = 4
    """
    SplitWise Use, Results Reader Sync Parallel
    """
    scheduler_writer_parallel: int = 4
    """
    SplitWise Use, Results Writer Sync Parallel
    """
    scheduler_reader_batch_size: int = 200
    """
    SplitWise Use, Results Reader Batch Size
    """
    scheduler_writer_batch_size: int = 200
    """
    SplitWise Use, Results Writer Batch Size
    """
    enable_static_graph_inference: bool = False
    """
    Whether to use static mode
    """
    use_cudagraph: bool = False
    """
    Flags to enable Cuda Graph
    """
    max_capture_batch_size: int = 64
    """
    Maximum Batch Size for Cuda Graph Capture
    NOTE: Now only support to capture continuous batch size,
    Example:
        max_capture_batch_size=64, FastDeploy will capture graphs for batches [1,64].
    """
    scheduler_remote_write_time: int = 3

    def __post_init__(self):
        """
@@ -214,20 +341,32 @@ class EngineArgs:
            default=EngineArgs.use_warmup,
            help="Flag to indicate whether to use warm-up before inference.")
        model_group.add_argument(
            "--mm_processor_kwargs",
            default=None,
            "--limit-mm-per-prompt",
            default=EngineArgs.limit_mm_per_prompt,
            type=json.loads,
            help="Limitation of numbers of multi-modal data.")
        model_group.add_argument(
            "--mm-processor-kwargs",
            default=EngineArgs.mm_processor_kwargs,
            type=json.loads,
            help="Additional keyword arguments for the multi-modal processor.")
        model_group.add_argument("--enable-mm",
            action='store_true',
            default=EngineArgs.enable_mm,
            help="Flag to enable multi-modal model.")
        model_group.add_argument("--reasoning-parser",
            type=str,
            default=EngineArgs.reasoning_parser,
            help="Flag specifies the reasoning parser to use for extracting "\
                 "reasoning content from the model output")
        model_group.add_argument(
            "--speculative_config",
            default=None,
            "--speculative-config",
            type=json.loads,
            default=EngineArgs.speculative_config,
            help="Configuration for speculative execution.")

        model_group.add_argument(
            "--dynamic_load_weight",
            "--dynamic-load-weight",
            type=int,
            default=EngineArgs.dynamic_load_weight,
            help="Flag to indicate whether to load weight dynamically.")
            type=int,
            default=EngineArgs.engine_worker_queue_port,
            help="port for engine worker queue")
        model_group.add_argument("--quantization",
            type=str,
            default=EngineArgs.quantization,
            help="Quantization name for the model, currently support " \
                 "'wint8', 'wint4'," \
                 "default is None. The priority of this configuration "\
                 "is lower than that of the config file. " \
                 "More complex quantization methods need to be configured via the config file.")
        model_group.add_argument(
            "--enable-static-graph-inference",
            action='store_true',
            default=EngineArgs.enable_static_graph_inference,
            help="Whether to use static mode; if enabled, " \
                 "'paddle.to_static' will be used to convert dynamic to static.")
        model_group.add_argument("--use-cudagraph",
            action='store_true',
            default=EngineArgs.use_cudagraph,
            help="Flags to enable cuda graph.")
        model_group.add_argument("--max-capture-batch-size",
            type=int,
            default=EngineArgs.max_capture_batch_size,
            help="Maximum of Batch Size for Warm Up.")
        model_group.add_argument("--guided-decoding-backend",
            type=str,
            default=EngineArgs.guided_decoding_backend,
            help="Guided Decoding Backend")
        model_group.add_argument(
            "--guided-decoding-disable-any-whitespace",
            type=str,
            default=EngineArgs.guided_decoding_disable_any_whitespace,
            help=
            "Disable any whitespace when using guided decoding backend XGrammar."
        )

        # Parallel processing parameters group
        parallel_group = parser.add_argument_group("Parallel Configuration")
@@ -264,11 +436,38 @@ class EngineArgs:
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help="Fraction of GPU memory to be utilized.")
        parallel_group.add_argument(
            "--kv-cache-ratio",

        parallel_group.add_argument("--data-parallel-size",
            type=int,
            default=EngineArgs.data_parallel_size,
            help="Degree of data parallelism.")
        parallel_group.add_argument("--enable-expert-parallel",
            action='store_true',
            default=EngineArgs.enable_expert_parallel,
            help="Enable expert parallelism.")

        # CacheConfig parameters group
        cache_group = parser.add_argument_group("Cache Configuration")

        cache_group.add_argument("--kv-cache-ratio",
            type=float,
            default=EngineArgs.kv_cache_ratio,
            help="Ratio of tokens to process in a block.")

        cache_group.add_argument(
            "--swap-space",
            type=float,
            default=EngineArgs.kv_cache_ratio,
            help="Ratio of tokens to process in a block.")
            default=EngineArgs.swap_space,
            help="The amount of CPU memory to offload to.")

        cache_group.add_argument("--cache-queue-port",
            type=int,
            default=EngineArgs.cache_queue_port,
            help="port for cache queue")
        cache_group.add_argument("--static-decode-blocks",
            type=int,
            default=EngineArgs.static_decode_blocks,
            help="Static decoding blocks num.")

        # Cluster system parameters group
        system_group = parser.add_argument_group("System Configuration")
@@ -285,18 +484,60 @@ class EngineArgs:

        # Performance tuning parameters group
        perf_group = parser.add_argument_group("Performance Tuning")
        perf_group.add_argument("--enable-prefix-caching",
            action='store_true',
            default=EngineArgs.enable_prefix_caching,
            help="Flag to enable prefix caching.")

        perf_group.add_argument("--splitwise-role",
            type=str,
            default=EngineArgs.splitwise_role,
            help="Role of splitwise. Default is \
                'mixed'. (prefill, decode, mixed)")

        perf_group.add_argument("--innode-prefill-ports",
            type=lambda s: s.split(",") if s else None,
            default=EngineArgs.innode_prefill_ports,
            help="port for innode prefill")

        perf_group.add_argument("--enable-chunked-prefill",
            action='store_true',
            default=EngineArgs.enable_chunked_prefill,
            help="Flag to enable chunked prefill.")
        perf_group.add_argument("--max-num-partial-prefills",
            type=int,
            default=EngineArgs.max_num_partial_prefills,
            help="For chunked prefill, Maximum number \
                of concurrent partial prefill requests.")
        perf_group.add_argument(
            "--enable-prefix-caching",
            action='store_true',
            default=EngineArgs.enable_prefix_caching,
            help="Flag to enable prefix caching."
        )
            "--max-long-partial-prefills",
            type=int,
            default=EngineArgs.max_long_partial_prefills,
            help=
            ("For chunked prefill, the maximum number of prompts longer than long-prefill-token-threshold"
             "that will be prefilled concurrently."))
        perf_group.add_argument(
            "--enable-chunked-prefill",
            action='store_true',
            default=EngineArgs.enable_chunked_prefill,
            help="Flag to enable chunked prefill."
        )
            "--long-prefill-token-threshold",
            type=int,
            default=EngineArgs.long_prefill_token_threshold,
            help=("For chunked prefill, the threshold number of"
                  " tokens for a prompt to be considered long."))

        perf_group.add_argument(
            "--cache-transfer-protocol",
            type=str,
            default=EngineArgs.cache_transfer_protocol,
            help="support protocol list, comma separated, default is ipc")

        perf_group.add_argument("--pd-comm-port",
            type=lambda s: s.split(",") if s else None,
            default=EngineArgs.pd_comm_port,
            help="port for splitwise communication.")

        perf_group.add_argument("--rdma-comm-ports",
            type=lambda s: s.split(",") if s else None,
            default=EngineArgs.rdma_comm_ports,
            help="ports for rdma communication.")

        # Scheduler parameters group
        scheduler_group = parser.add_argument_group("Scheduler")
            help=
            f"TTL of request. Default is {EngineArgs.scheduler_ttl} seconds. (local,global)"
        )
        scheduler_group.add_argument(
            "--scheduler-wait-response-timeout",
            type=float,
            default=EngineArgs.scheduler_wait_response_timeout,
            help=
            ("Timeout for waiting for response. Default is "
             f"{EngineArgs.scheduler_wait_response_timeout} seconds. (local,global)"
             ))
        scheduler_group.add_argument(
            "--scheduler-host",
            default=EngineArgs.scheduler_host,
@@ -359,12 +592,62 @@ class EngineArgs:
            f"Topic of scheduler. Default is {EngineArgs.scheduler_topic}. (global)"
        )
        scheduler_group.add_argument(
            "--scheduler-remote-write-time",
            type=int,
            default=EngineArgs.scheduler_remote_write_time,
            "--scheduler-min-load-score",
            type=float,
            default=EngineArgs.scheduler_min_load_score,
            help=
            f"Max write time of redis. Default is {EngineArgs.scheduler_remote_write_time} seconds (global)"
            f"Minimum load score for task assignment. Default is {EngineArgs.scheduler_min_load_score} (global)"
        )
        scheduler_group.add_argument(
            "--scheduler-load-shards-num",
            type=int,
            default=EngineArgs.scheduler_load_shards_num,
            help=("Number of shards for load balancing table. Default is "
                  f"{EngineArgs.scheduler_load_shards_num} (global)"))
        scheduler_group.add_argument(
            "--scheduler-sync-period",
            type=int,
            default=EngineArgs.scheduler_sync_period,
            help=f"SplitWise Use, node load sync period, "
                 f"Default is {EngineArgs.scheduler_sync_period}ms. (global)")
        scheduler_group.add_argument(
            "--scheduler-expire-period",
            type=int,
            default=EngineArgs.scheduler_expire_period,
            help=f"SplitWise Use, node will not be scheduled after "
                 f"expire-period ms not sync load, Default is "
                 f"{EngineArgs.scheduler_expire_period}ms. (global)")
        scheduler_group.add_argument(
            "--scheduler-release-load-expire-period",
            type=int,
            default=EngineArgs.scheduler_release_load_expire_period,
            help=f"SplitWise Use, scheduler will release req load after "
                 f"expire period(s). Default is "
                 f"{EngineArgs.scheduler_release_load_expire_period}. (global)")
        scheduler_group.add_argument(
            "--scheduler-reader-parallel",
            type=int,
            default=EngineArgs.scheduler_reader_parallel,
            help=f"SplitWise Use, Results Reader Sync Parallel, "
                 f"Default is {EngineArgs.scheduler_reader_parallel}. (global)")
        scheduler_group.add_argument(
            "--scheduler-writer-parallel",
            type=int,
            default=EngineArgs.scheduler_writer_parallel,
            help=f"SplitWise Use, Results Writer Sync Parallel, "
                 f"Default is {EngineArgs.scheduler_writer_parallel}. (global)")
        scheduler_group.add_argument(
            "--scheduler-reader-batch-size",
            type=int,
            default=EngineArgs.scheduler_reader_batch_size,
            help=f"SplitWise Use, Results Reader Batch Size, "
                 f"Default is {EngineArgs.scheduler_reader_batch_size}. (global)")
        scheduler_group.add_argument(
            "--scheduler-writer-batch-size",
            type=int,
            default=EngineArgs.scheduler_writer_batch_size,
            help=f"SplitWise Use, Results Writer Batch Size, "
                 f"Default is {EngineArgs.scheduler_writer_batch_size}. (global)")

        return parser
@@ -385,18 +668,37 @@ class EngineArgs:
        """
        return ModelConfig(model_name_or_path=self.model,
                           config_json_file=self.model_config_name,
                           dynamic_load_weight=self.dynamic_load_weight)
                           dynamic_load_weight=self.dynamic_load_weight,
                           quantization=self.quantization)

    def create_cache_config(self) -> CacheConfig:
    def create_cache_config(self, model_cfg) -> CacheConfig:
        """
        Create and return a CacheConfig object based on the current settings.
        """
        return CacheConfig(
            block_size=self.block_size,
            tensor_parallel_size=self.tensor_parallel_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            kv_cache_ratio=self.kv_cache_ratio,
            enable_prefix_caching=self.enable_prefix_caching)
            enable_prefix_caching=self.enable_prefix_caching,
            swap_space=self.swap_space,
            cache_queue_port=self.cache_queue_port,
            model_cfg=model_cfg,
            enable_chunked_prefill=self.enable_chunked_prefill,
            enc_dec_block_num=self.static_decode_blocks,
            rdma_comm_ports=self.rdma_comm_ports,
            cache_transfer_protocol=self.cache_transfer_protocol,
            pd_comm_port=self.pd_comm_port,
        )

    def create_speculative_config(self) -> SpeculativeConfig:
        """
        """
        if self.speculative_config is not None:
            return SpeculativeConfig(**self.speculative_config)
        else:
            return SpeculativeConfig()
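The dict parsed from `--speculative-config` is forwarded directly as keyword arguments, so a JSON value such as the hypothetical one below ends up in the `SpeculativeConfig` constructor shown later in this diff (field values are invented for illustration):

```python
import json

from fastdeploy.engine.config import SpeculativeConfig  # same import as at the top of this file

# Hypothetical CLI value: --speculative-config '{"method": "mtp", "num_speculative_tokens": 1}'
speculative_config = json.loads('{"method": "mtp", "num_speculative_tokens": 1}')
cfg = SpeculativeConfig(**speculative_config)  # equivalent to the branch above
```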

    def create_scheduler_config(self) -> SchedulerConfig:
        """
@@ -404,14 +706,32 @@ class EngineArgs:
        """
        prefix = "scheduler_"
        prefix_len = len(prefix)
        extra_params = [
            "max_model_len", "enable_chunked_prefill",
            "max_num_partial_prefills", "max_long_partial_prefills",
            "long_prefill_token_threshold"
        ]

        all = asdict(self)
        params = dict()
        for k, v in all.items():
            if k[:prefix_len] == prefix:
                params[k[prefix_len:]] = v
            elif k in extra_params:
                params[k] = v

        return SchedulerConfig(**params)
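A standalone sketch of the `scheduler_` prefix stripping performed above, with made-up field values:

```python
prefix = "scheduler_"
extra_params = ["max_model_len", "enable_chunked_prefill"]
all_args = {"scheduler_name": "local", "scheduler_ttl": 900,
            "max_model_len": 8192, "gpu_memory_utilization": 0.9}

params = {}
for k, v in all_args.items():
    if k.startswith(prefix):
        params[k[len(prefix):]] = v   # "scheduler_ttl" -> "ttl"
    elif k in extra_params:
        params[k] = v                 # forwarded unchanged
# params == {"name": "local", "ttl": 900, "max_model_len": 8192}
```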

    def create_parallel_config(self) -> ParallelConfig:
        """
        Create and return a ParallelConfig object based on the current settings.
        """
        return ParallelConfig(
            tensor_parallel_size=self.tensor_parallel_size,
            enable_expert_parallel=self.enable_expert_parallel,
            data_parallel_size=self.data_parallel_size,
        )

    def create_engine_config(self) -> Config:
        """
        Create and return a Config object based on the current settings.
@@ -426,22 +746,37 @@ class EngineArgs:
        else:
            self.max_num_batched_tokens = self.max_model_len
        scheduler_cfg = self.create_scheduler_config()

        speculative_cfg = self.create_speculative_config()

        return Config(
            model_name_or_path=self.model,
            model_config=model_cfg,
            scheduler_config=scheduler_cfg,
            tokenizer=self.tokenizer,
            cache_config=self.create_cache_config(),
            cache_config=self.create_cache_config(model_cfg),
            parallel_config=self.create_parallel_config(),
            max_model_len=self.max_model_len,
            tensor_parallel_size=self.tensor_parallel_size,
            max_num_seqs=self.max_num_seqs,
            mm_processor_kwargs=self.mm_processor_kwargs,
            speculative_config=self.speculative_config,
            speculative_config=speculative_cfg,
            max_num_batched_tokens=self.max_num_batched_tokens,
            nnode=self.nnode,
            pod_ips=self.pod_ips,
            use_warmup=self.use_warmup,
            engine_worker_queue_port=self.engine_worker_queue_port,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            mm_processor_kwargs=self.mm_processor_kwargs,
            enable_mm=self.enable_mm,
            enable_chunked_prefill=self.enable_chunked_prefill,
            reasoning_parser=self.reasoning_parser,
            splitwise_role=self.splitwise_role,
            innode_prefill_ports=self.innode_prefill_ports,
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
            enable_static_graph_inference=self.enable_static_graph_inference,
            use_cudagraph=self.use_cudagraph,
            max_capture_batch_size=self.max_capture_batch_size,
            guided_decoding_backend=self.guided_decoding_backend,
            disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
        )

@@ -1,21 +1,28 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import json
import os
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional

from fastdeploy import envs
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.utils import (check_unified_ckpt, get_host_ip,
from fastdeploy.utils import (ceil_div, check_unified_ckpt, get_host_ip,
                              is_port_available, llm_logger)

TaskOption = Literal["generate"]
@@ -23,37 +30,37 @@ TaskOption = Literal["generate"]

class ModelConfig:
    """
    Configuration class for model settings and parameters.
    Configuration class for the model.

    Attributes:
        model_dir (str): Path to the model directory
        is_unified_ckpt (bool): Whether the checkpoint uses unified format
        model_name_or_path (str): Model identifier or path
        dynamic_load_weight (int): Dynamic weight loading flag
    """
    Attributes:
        model_dir (str): Directory path to the model.
        is_unified_ckpt (bool): Flag indicating if the checkpoint is unified.
        model_name_or_path (str): Name or path of the model.
    """

    def __init__(self,
                 model_name_or_path: str,
                 config_json_file: str = "config.json",
                 dynamic_load_weight: int = 0,
                 quantization: str = None,
                 download_dir: Optional[str] = None):
        """
        Initialize model configuration.
        Initialize the ModelConfig class.

        Args:
            model_name_or_path (str): Model identifier or path
            config_json_file (str): Model config file name (default: 'config.json')
            dynamic_load_weight (int): Dynamic weight loading mode (default: 0)
            download_dir (Optional[str]): Directory for downloaded models (default: None)
            model_name_or_path (str): Name or path of the model.
            config_json_file (str): Path to the configuration JSON file. Default is 'config.json'.
            download_dir (Optional[str]): Directory to download model files. Default is None.
        """
        self.model_dir = model_name_or_path
        self.is_unified_ckpt = check_unified_ckpt(self.model_dir)
        self.dynamic_load_weight = dynamic_load_weight
        self.quantization = quantization

        config_file = os.path.join(model_name_or_path, config_json_file)
        if os.path.isfile(model_name_or_path):
            try:
                from paddlenlp.transformers import AutoConfig
                from paddleformers.transformers import AutoConfig
                config = AutoConfig.from_pretrained(model_name_or_path)
                config_dict = {
                    k: v
@@ -63,10 +70,10 @@ Attributes:
                    setattr(self, key, value)
            except Exception:
                llm_logger.error(
                    "Don't support the current model, you can use `paddlenlp` to register your model."
                    "Don't support the current model, you can use `paddleformers` to register your model."
                )
                raise ValueError(
                    "Don't support the current model, you can use `paddlenlp` to register your model."
                    "Don't support the current model, you can use `paddleformers` to register your model."
                )
        else:
            with open(config_file, "r", encoding="utf-8") as f:
@@ -85,12 +92,9 @@ Attributes:

    def override_name_from_config(self):
        """
        Update attribute names from model configuration.
        Handles special cases like:
        - Renaming infer_model_mp_num to tensor_parallel_size
        - Adjusting num_hidden_layers based on remove_tail_layer
        - Setting default mla_use_absorb value
        Override attribute names from the exported model's configuration.
        """

        if not self.is_unified_ckpt and hasattr(self, "infer_model_mp_num"):
            self.tensor_parallel_size = self.infer_model_mp_num
            del self.infer_model_mp_num
@@ -107,27 +111,19 @@ Attributes:

        if not hasattr(self, "mla_use_absorb"):
            self.mla_use_absorb = False
        if not hasattr(self, "head_dim"):
            assert hasattr(self, "hidden_size") and hasattr(
                self, "num_attention_heads")
            self.head_dim = self.hidden_size // self.num_attention_heads

    def read_from_env(self):
        """
        Load configuration from environment variables.
        Sets default values if env vars not found.
        Reads:
        - MAX_STOP_SEQS_NUM (default: 5)
        - STOP_SEQS_MAX_LEN (default: 8)
        - ELLM_DYNAMIC_QUANT_TYPE (default: 'default')
        - ELLM_DYNAMIC_USE_STOP_SEQS (default: 0)
        - COMPRESSION_RATIO (default: 1.0)
        - ROPE_THETA (default: 10000)
        """
        self.max_stop_seqs_num = int(os.getenv("MAX_STOP_SEQS_NUM", "5"))
        self.stop_seqs_max_len = int(os.getenv("STOP_SEQS_MAX_LEN", "8"))
        Read configuration information from environment variables and update the object's attributes.

        self.ellm_dynamic_quant_type = os.getenv("ELLM_DYNAMIC_QUANT_TYPE",
                                                 "default")
        # Whether to use stop sequences in dynamic graph inference
        self.ellm_dynamic_use_stop_seqs = int(
            os.getenv("ELLM_DYNAMIC_USE_STOP_SEQS", "0")) == 1
        If an attribute is not present or is an empty string in the environment variables, use the default value.
        """
        self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
        self.stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)

        def reset_config_value(key, value):
            if not hasattr(self, key.lower()):
@@ -144,13 +140,12 @@ Attributes:
        reset_config_value("ROPE_THETA", 10000)

    def _get_download_model(self, model_name, model_type="default"):
        # TODO: Implement dynamic graph for self-downloading models
        # TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
        pass

    def print(self):
        """
        Print current model configuration.
        Logs all attributes and their values.
        Print all configuration information.
        """
        llm_logger.info("Model Configuration Information :")
        for k, v in self.__dict__.items():
@@ -161,19 +156,18 @@ Attributes:

class CacheConfig:
    """
    Configuration for key-value cache management.
    Configuration for the KV cache.

    Attributes:
        block_size (int): Tokens per cache block
        gpu_memory_utilization (float): GPU memory usage fraction (0-1)
        cache_dtype (str): Data type for cache (default: 'bfloat16')
        num_gpu_blocks_override (Optional[int]): Manual GPU blocks override
        kv_cache_ratio (float): Max blocks ratio (default: 0.75)
        enc_dec_block_num (int): Encoder-decoder blocks count
        enable_prefix_caching (bool): Prefix caching enable flag
        total_block_num (int): Total available blocks
        prefill_kvcache_block_num (int): Blocks allocated for prefill
    """
    Attributes:
        block_size (int): Size of a cache block in number of tokens.
        gpu_memory_utilization (float): Fraction of GPU memory to use for model execution.
        cache_dtype (str): Data type for kv cache storage. Default is 'bfloat16'.
        num_gpu_blocks_override (Optional[int]): Number of GPU blocks to use.
            Overrides profiled num_gpu_blocks if provided.
        kv_cache_ratio (float): Ratio for calculating the maximum block number.
        enc_dec_block_num (int): Number of encoder-decoder blocks.
        enable_prefix_caching (bool): Flag to enable prefix caching.
    """

    def __init__(
        self,
@@ -181,21 +175,31 @@ Attributes:
        gpu_memory_utilization: float,
        cache_dtype: str = "bfloat16",
        num_gpu_blocks_override: Optional[int] = None,
        swap_space: Optional[int] = None,
        kv_cache_ratio: float = 0.75,
        enc_dec_block_num: int = 2,
        enable_prefix_caching: bool = False,
        tensor_parallel_size: int = 1,
        enable_prefix_caching=False,
        enable_ssd_cache=False,
        model_cfg=None,
        cache_queue_port=None,
        enable_chunked_prefill=False,
        rdma_comm_ports=None,
        cache_transfer_protocol=None,
        pd_comm_port=None,
    ):
        """
        Initialize cache configuration.
        Initialize the CacheConfig class.

        Args:
            block_size (int): Tokens per cache block
            gpu_memory_utilization (float): GPU memory usage target (0-1)
            cache_dtype (str): Cache data type (default: 'bfloat16')
            num_gpu_blocks_override (Optional[int]): Manual GPU blocks setting
            kv_cache_ratio (float): Max blocks ratio (default: 0.75)
            enc_dec_block_num (int): Encoder-decoder blocks count
            enable_prefix_caching (bool): Enable prefix sharing
            block_size (int): Size of a cache block in number of tokens.
            gpu_memory_utilization (float): Fraction of GPU memory to use.
            cache_dtype (str): Data type for cache storage. Default is 'bfloat16'.
            num_gpu_blocks_override (Optional[int]): Override for number of GPU blocks.
            num_cpu_blocks (Optional[int]): Number of CPU blocks.
            kv_cache_ratio (float): Ratio for max block calculation.
            enc_dec_block_num (int): Number of encoder-decoder blocks.
            enable_prefix_caching (bool): Enable prefix caching.
        """
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
@@ -203,20 +207,74 @@ Attributes:
        self.kv_cache_ratio = kv_cache_ratio
        self.enc_dec_block_num = enc_dec_block_num
        self.cache_dtype = cache_dtype
        if hasattr(model_cfg, "quantization_config"):
            self.cache_dtype = model_cfg.quantization_config.get(
                "kv_cache_quant_type", cache_dtype)

        self.enable_chunked_prefill = enable_chunked_prefill
        self.rdma_comm_ports = rdma_comm_ports
        self.cache_transfer_protocol = cache_transfer_protocol
        self.pd_comm_port = pd_comm_port

        if rdma_comm_ports is not None and isinstance(rdma_comm_ports, str):
            self.rdma_comm_ports = rdma_comm_ports.split(',')

        if pd_comm_port is not None and isinstance(pd_comm_port, str):
            self.pd_comm_port = [int(port) for port in pd_comm_port.split(",")]

        self.enable_prefix_caching = enable_prefix_caching
        if swap_space is None:
            self.enable_hierarchical_cache = False
        else:
            self.enable_hierarchical_cache = True

        self.enable_ssd_cache = enable_ssd_cache
        self.model_cfg = model_cfg
        self.cache_queue_port = cache_queue_port
        self.swap_space = swap_space

        if (hasattr(self.model_cfg, "num_key_value_heads")
                and hasattr(self.model_cfg, "num_key_value_heads")
                and self.model_cfg.num_key_value_heads is not None
                and int(self.model_cfg.num_key_value_heads) > 0):
            kv_num_head = int(self.model_cfg.num_key_value_heads)
        else:
            kv_num_head = self.model_cfg.num_attention_heads
        self.model_cfg.kv_num_head = kv_num_head

        # TODO check name
        if "int4" in self.cache_dtype.lower(
        ) or "float4" in self.cache_dtype.lower():
            byte_size = 0.5
            self.cache_dtype = "uint8"
        elif "int8" in self.cache_dtype.lower(
        ) or "float8" in self.cache_dtype.lower():
            self.cache_dtype = "uint8"
            byte_size = 1
        else:
            byte_size = 2

        self.each_token_cache_space = int(
            self.model_cfg.num_layers * kv_num_head * self.model_cfg.head_dim *
            byte_size)
        self.bytes_per_block = int(self.each_token_cache_space *
                                   self.block_size)
        self.bytes_per_layer_per_block = int(
            self.block_size * self.model_cfg.kv_num_head *
            self.model_cfg.head_dim // tensor_parallel_size * byte_size)

        if self.swap_space is None:
            self.num_cpu_blocks = 0
        else:
            self.num_cpu_blocks = int(self.swap_space * 1024**3 /
                                      self.bytes_per_block)
        self._verify_args()
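A worked example of the byte-size arithmetic above, with assumed model dimensions (the numbers are not taken from this diff):

```python
# Assumed values for illustration only.
num_layers, kv_num_head, head_dim = 32, 8, 128
byte_size = 2            # bfloat16 cache
block_size = 64
swap_space_gb = 4        # e.g. --swap-space 4

each_token_cache_space = num_layers * kv_num_head * head_dim * byte_size  # 65,536 bytes
bytes_per_block = each_token_cache_space * block_size                     # 4 MiB
num_cpu_blocks = int(swap_space_gb * 1024**3 / bytes_per_block)           # 1,024 blocks
```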

    def metrics_info(self):
        """
        Convert config to metrics dictionary.

        Returns:
            Dict[str, str]: Key-value pairs of all config attributes
        """
        """Convert cache_config to dict(key: str, value: str) for prometheus metrics info."""
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self):
        """Validate configuration arguments."""
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than 1.0. Got "
@@ -227,46 +285,38 @@ Attributes:

    def postprocess(self, num_total_tokens, number_of_tasks):
        """
        Calculate block allocation based on tokens and tasks.

        Args:
            num_total_tokens (int): Total tokens to process
            number_of_tasks (int): Number of parallel tasks

        Sets:
            dec_token_num (int): Decoder tokens per block
            total_block_num (int): Total blocks needed
            prefill_kvcache_block_num (int): Blocks for prefill phase
        calculate block num
        """
        self.dec_token_num = self.enc_dec_block_num * self.block_size
        if self.num_gpu_blocks_override is not None:
            self.total_block_num = self.num_gpu_blocks_override
            self.prefill_kvcache_block_num= int(self.total_block_num * self.kv_cache_ratio)
            self.prefill_kvcache_block_num = int(self.total_block_num *
                                                 self.kv_cache_ratio)
        else:
            length = num_total_tokens // number_of_tasks
            block_num = (length + self.block_size - 1 + self.enc_dec_block_num) // self.block_size
            self.total_block_num = block_num * number_of_tasks
            self.prefill_kvcache_block_num= self.total_block_num
            llm_logger.info(f"Doing profile, the total_block_num:{self.total_block_num}")
            block_num = (length + self.block_size - 1 +
                         self.dec_token_num) // self.block_size
            self.total_block_num = block_num * number_of_tasks
            self.prefill_kvcache_block_num = self.total_block_num
            llm_logger.info(
                f"Doing profile, the total_block_num:{self.total_block_num}")

    def reset(self, num_gpu_blocks):
        """
        Reset GPU block allocation.

        Args:
            num_gpu_blocks (int): New total blocks count

        Updates:
            total_block_num (int)
            prefill_kvcache_block_num (int)
        reset gpu block number
        """
        self.total_block_num = num_gpu_blocks
        self.prefill_kvcache_block_num= int(self.total_block_num * self.kv_cache_ratio)
        llm_logger.info((f"Reset block num, the total_block_num:{self.total_block_num},"
                         f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}"))
        self.total_block_num = num_gpu_blocks
        self.prefill_kvcache_block_num = int(self.total_block_num *
                                             self.kv_cache_ratio)
        llm_logger.info(
            (f"Reset block num, the total_block_num:{self.total_block_num},"
             f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}"))

    def print(self):
        """Print current cache configuration."""
        """
        print all config

        """
        llm_logger.info("Cache Configuration Information :")
        for k, v in self.__dict__.items():
            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
@@ -274,38 +324,182 @@ Attributes:
            "=============================================================")


class SpeculativeConfig:
    """
    Speculative Decoding Configuration class.

    Attributes:
        method (Optional[str]): Method used for speculative decoding.
        num_speculative_tokens (int): Maximum draft tokens, default is 1.
        model_name_or_path (Optional[str]): Path of the model.
        quantization (str): Quantization method for draft model, default is WINT8.
        max_model_len: Optional[int]: Maximum model length for draft model.
    """

    def __init__(self,
                 method: Optional[str] = None,
                 num_speculative_tokens: Optional[int] = 1,
                 model: Optional[str] = None,
                 quantization: Optional[str] = "WINT8",
                 max_model_len: Optional[int] = None,
                 **kwargs):
        self.model_name_or_path = model
        self.method = method
        self.num_speculative_tokens = num_speculative_tokens
        self.quantization = quantization
        self.max_model_len = max_model_len
        # Fixed now
        self.num_gpu_block_expand_ratio = 1
        self.num_extra_cache_layer = 0

        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except Exception:
                continue

        self.read_model_config()
        self.reset()

    def read_model_config(self):
        """
        Read configuration from file.
        """
        self.model_config = {}
        if not self.enabled_speculative_decoding():
            return

        self.is_unified_ckpt = check_unified_ckpt(self.model_name_or_path)
        if self.model_name_or_path is None:
            return

        self.config_path = os.path.join(self.model_name_or_path, "config.json")
        if os.path.exists(self.config_path):
            self.model_config = json.load(
                open(self.config_path, 'r', encoding='utf-8'))

    def reset(self):
        """
        Reset configuration.
        """

        def reset_value(cls, value_name, key=None, default=None):
            if key is not None and key in cls.model_config:
                setattr(cls, value_name, cls.model_config[key])
            elif getattr(cls, value_name, None) is None:
                setattr(cls, value_name, default)

        if not self.enabled_speculative_decoding():
            return

        # NOTE(liuzichang): We will support multi-layer in future
        if self.method in ["mtp"]:
            self.num_extra_cache_layer = 1

    def enabled_speculative_decoding(self):
        """
        Check if speculative decoding is enabled.
        """
        if self.method is None:
            return False
        return True

    def to_json_string(self):
        """
        Convert speculative_config to json string.
        """
        return json.dumps({
            key: value
            for key, value in self.__dict__.items() if value is not None
        })

    def print(self):
        """
        print all config

        """
        llm_logger.info("Speculative Decoding Configuration Information :")
        for k, v in self.__dict__.items():
            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
        llm_logger.info(
            "=============================================================")

class ParallelConfig:
    """
    Configuration for parallelism.

    Attributes:
        tensor_parallel_size (int): Size of tensor parallelism.
        data_parallel_size (int): Size of data parallelism.
        local_data_parallel_id (int): ID of local data parallel.
        enable_expert_parallel (bool): Whether to enable expert parallel.
    """

    def __init__(
        self,
        tensor_parallel_size: int = 1,
        data_parallel_size: int = 1,
        enable_expert_parallel: bool = False,
    ):
        """
        Initialize the ParallelConfig class.

        Args:
            tensor_parallel_size (int): Size of tensor parallelism.
            data_parallel_size (int): Size of data parallelism.
            local_data_parallel_id (int): ID of local data parallel.
            enable_expert_parallel (bool): Whether to enable expert parallel.
        """
        self.tensor_parallel_size = tensor_parallel_size
        self.data_parallel_size = data_parallel_size
        self.enable_expert_parallel = enable_expert_parallel
        self.expert_parallel_size = data_parallel_size
        self.local_data_parallel_id = 0

    def print(self):
        """
        print all config

        """
        llm_logger.info("Parallel Configuration Information :")
        for k, v in self.__dict__.items():
            llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
        llm_logger.info("==================")


class Config:
    """
    Main engine configuration class combining all components.
    Initial configuration class.

    Attributes:
        model_config (ModelConfig): Model settings
        cache_config (CacheConfig): Cache management settings
        scheduler_config (SchedulerConfig): Task scheduling settings
        model_name_or_path (str): Model identifier/path
        tokenizer (str): Tokenizer identifier
        tensor_parallel_size (int): Parallelism degree (default: 8)
        nnode (int): Node count (default: 1)
        max_model_len (int): Max sequence length (default: 8192)
        max_num_seqs (int): Max concurrent sequences (default: 8)
        max_num_batched_tokens (Optional[int]): Max batched tokens
        pod_ips (Optional[List[str]]): Cluster node IPs
        mm_processor_kwargs (Optional[Dict]): Multi-modal processor args
        speculative_config (Optional[Dict]): Speculative execution settings
        use_warmup (bool): Warmup enable flag
        enable_mm (bool): Multi-modal enable flag
        enable_chunked_prefill (bool): Chunked prefill enable flag
        device_ids (str): GPU device IDs
        tp_num_per_node (int): Tensor parallelism per node
        host_ip (str): Current host IP
        paddle_commit_id (str): PaddlePaddle version
    """
    Attributes:
        model_config (ModelConfig): Model configuration object.
        cache_config (CacheConfig): Cache configuration object.
        model_name_or_path (str): Directory path to the model or the model name.
        tokenizer (Optional[str]): Default is the model.
        max_num_batched_tokens (Optional[int]): Maximum number of batched tokens.
        tensor_parallel_size (int): Tensor parallel size.
        nnode (int): Number of nodes.
        max_model_len (int): Maximum model length. Default is 8192.
        max_num_seqs (int): Maximum number of sequences. Default is 8.
        mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor.
        speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration.
        use_warmup (bool): Flag to use warmup.
        engine_worker_queue_port (int): Port for engine worker queue.
        enable_mm (bool): Flag to enable multi-modal processing.
        reasoning_parser(str): Flag specifies the reasoning parser to use for
            extracting reasoning content from the model output
        splitwise_role (str): Splitwise role.
        innode_prefill_ports (Optional[List[int]]): Innode prefill ports.
            Temporary configuration, will be removed in the future.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        scheduler_config: SchedulerConfig,
        parallel_config: ParallelConfig,
        model_name_or_path: str = None,
        tokenizer: str = None,
        tensor_parallel_size: int = 8,
@@ -314,38 +508,57 @@ Attributes:
        max_num_seqs: int = 8,
        max_num_batched_tokens: Optional[int] = None,
        pod_ips: Optional[List[str]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        speculative_config: Optional[Dict[str, Any]] = None,
        use_warmup: bool = False,
        engine_worker_queue_port: int = 8002,
        limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        enable_mm: bool = False,
        enable_chunked_prefill: bool = False,
        splitwise_role: str = "mixed",
        innode_prefill_ports: Optional[List[int]] = None,
        max_num_partial_prefills: int = 1,
        max_long_partial_prefills: int = 1,
        long_prefill_token_threshold: int = 0,
        reasoning_parser: str = None,
        enable_static_graph_inference: bool = False,
        use_cudagraph: bool = False,
        max_capture_batch_size: int = 64,
        guided_decoding_backend: Optional[str] = None,
        disable_any_whitespace: bool = False,
    ):
        """
        Initialize engine configuration.
        Initialize the Config class.

        Args:
            model_config (ModelConfig): Model settings
            cache_config (CacheConfig): Cache settings
            scheduler_config (SchedulerConfig): Scheduler settings
            model_name_or_path (str): Model identifier (default: None)
            tokenizer (str): Tokenizer identifier (default: None)
            tensor_parallel_size (int): Parallelism degree (default: 8)
            nnode (int): Node count (default: 1)
            max_model_len (int): Max sequence length (default: 8192)
            max_num_seqs (int): Max concurrent sequences (default: 8)
            max_num_batched_tokens (Optional[int]): Max batched tokens (default: None)
            pod_ips (Optional[List[str]]): Cluster node IPs (default: None)
            mm_processor_kwargs (Optional[Dict]): Multi-modal args (default: None)
            speculative_config (Optional[Dict]): Speculative settings (default: None)
            use_warmup (bool): Warmup flag (default: False)
            engine_worker_queue_port (int): Worker queue port (default: 8002)
            enable_mm (bool): Multi-modal flag (default: False)
            enable_chunked_prefill (bool): Chunked prefill flag (default: False)
            model_config (ModelConfig): Model configuration object.
            cache_config (CacheConfig): Cache configuration object.
            parallel_config (ParallelConfig): Parallel configuration object.
            scheduler_config (SchedulerConfig): Scheduler configuration object.
            model_name_or_path (str): Model directory path or model name.
            tokenizer (str): Default is the model.
            tensor_parallel_size (int): Tensor parallel size. Default is 8.
            nnode (int): Number of nodes. Default is 1.
            max_model_len (int): Maximum model length. Default is 8192.
            max_num_seqs (int): Maximum number of sequences. Default is 8.
            max_num_batched_tokens (Optional[int]): Maximum number of batched tokens. Default is None.
            pod_ips (Optional[List[str]]): List of POD IPs. Default is None.
            mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor. Default is None.
            speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration. Default is None.
            use_warmup (bool): Flag to use warmup. Default is False.
            engine_worker_queue_port (int): Engine worker queue port. Default is 8002.
            enable_mm (bool): Flag to enable multi-modal processing. Default is False.
            splitwise_role (str): Splitwise role. Default is "mixed".
            innode_prefill_ports (Optional[List[int]]): Innode prefill ports. Default is None.
            reasoning_parser (str): Flag specifies the reasoning parser to use for
                extracting reasoning content from the model output. Default is None.
            guided_decoding_backend(str): Guided decoding backend. Default is None.
            disable_any_whitespace(bool): Disable any whitespace when using guided decoding.
                Default is False.
        """
        self.model_config = model_config
        self.cache_config = cache_config
        self.scheduler_config = scheduler_config
        self.parallel_config = parallel_config
        self.model_name_or_path = model_name_or_path
        self.tokenizer = tokenizer
        self.max_num_batched_tokens = max_num_batched_tokens
@@ -354,20 +567,56 @@ Attributes:
        self.pod_ips = pod_ips
        self.max_model_len = max_model_len
        self.max_num_seqs = max_num_seqs
        self.limit_mm_per_prompt = limit_mm_per_prompt
        self.mm_processor_kwargs = mm_processor_kwargs
        self.enable_mm = enable_mm
        self.speculative_config = speculative_config
        self.use_warmup = use_warmup
        self.enable_chunked_prefill = enable_chunked_prefill
        self.splitwise_role = splitwise_role
        self.innode_prefill_ports = innode_prefill_ports
        self.max_num_partial_prefills = max_num_partial_prefills
        self.max_long_partial_prefills = max_long_partial_prefills
        self.long_prefill_token_threshold = long_prefill_token_threshold
        self.reasoning_parser = reasoning_parser
        self.enable_static_graph_inference = enable_static_graph_inference
        self.use_cudagraph = use_cudagraph
        self.max_capture_batch_size = max_capture_batch_size
        self.guided_decoding_backend = guided_decoding_backend
        self.disable_any_whitespace = disable_any_whitespace


        if self.innode_prefill_ports is not None:
            if not isinstance(self.innode_prefill_ports, list):
                ports = str(self.innode_prefill_ports).split(',')
                self.innode_prefill_ports = [int(port) for port in ports]


        assert self.splitwise_role in ["mixed", "prefill", "decode"]

        # TODO
        self.max_prefill_batch = 3
        if current_platform.is_xpu():
            self.max_prefill_batch = 1
        if enable_mm:
            self.max_prefill_batch = 1  # TODO: Currently multi-modal prefill only supports parallelism=1 (needs optimization)
            self.max_prefill_batch = 1  # TODO: Currently the multi-modal prefill stage only supports parallelism of 1; to be optimized

        # TODO(@wufeisheng): TP and EP need to be supported simultaneously.
        assert (self.tensor_parallel_size == 1
                and self.parallel_config.expert_parallel_size
                >= 1) or (self.tensor_parallel_size >= 1
                          and self.parallel_config.expert_parallel_size
                          == 1), "TP and EP cannot be enabled at the same time"

        num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
        if num_ranks > 8:
            local_num_ranks = 8
            self.nnode = ceil_div(num_ranks, local_num_ranks)
        else:
            local_num_ranks = num_ranks
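A worked example of the rank and node arithmetic above, assuming an expert-parallel deployment (values invented; `ceil_div` is reimplemented here only to keep the sketch self-contained):

```python
def ceil_div(a: int, b: int) -> int:      # stand-in for fastdeploy.utils.ceil_div
    return -(-a // b)

tensor_parallel_size, expert_parallel_size = 1, 16   # TP and EP cannot both exceed 1
num_ranks = tensor_parallel_size * expert_parallel_size          # 16
local_num_ranks = 8 if num_ranks > 8 else num_ranks              # 8
nnode = ceil_div(num_ranks, local_num_ranks)                     # 2 nodes
device_ids = ",".join(str(i) for i in range(min(num_ranks, 8)))  # "0,1,2,3,4,5,6,7"
```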

        self.engine_worker_queue_port = engine_worker_queue_port
        self.device_ids = ",".join(
            [str(i) for i in range(self.tensor_parallel_size)])
        self.device_ids = ",".join([str(i) for i in range(min((self.tensor_parallel_size * \
            self.parallel_config.expert_parallel_size), 8))])
        self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)

        self.read_from_config()
@@ -377,45 +626,44 @@ Attributes:

    def postprocess(self):
        """
        Calculate derived parameters:
        - Validates GPU device count matches tensor_parallel_size
        - Computes tensor parallelism per node
        - Gets host IP and Paddle version
        - Sets default max_num_batched_tokens if not provided
        - Initializes cache configuration
        calculate some parameters
        """
        if len(self.device_ids.split(',')) > self.tensor_parallel_size:
            self.device_ids = ",".join(
                self.device_ids.split(',')[:self.tensor_parallel_size:])
        assert len(
            self.device_ids.split(',')
        ) == self.tensor_parallel_size, f"The number of available GPUs is {len(self.device_ids.split(','))}, which is less than the tensor parallel required {self.tensor_parallel_size}."

        assert self.tensor_parallel_size % self.nnode == 0, f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by nnode: {self.nnode}"
        self.tp_num_per_node = self.tensor_parallel_size // self.nnode
        total_rank = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
        assert self.device_ids.split(',').__len__() == min(total_rank, 8), \
            f"invalid CUDA_VISIBLE_DEVICES, should be equal to {min(total_rank, 8)}"
        self.local_device_ids = self.device_ids.split(
            ',')[:self.tensor_parallel_size]
        assert self.tensor_parallel_size % self.nnode == 0, \
            f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by nnode: {self.nnode}"
        self.worker_num_per_node = total_rank // self.nnode
        self.host_ip = get_host_ip()

        import paddle
        self.paddle_commit_id = paddle.version.commit

        if self.max_num_batched_tokens is None:
            if self.enable_chunked_prefill:
            if self.cache_config.enable_chunked_prefill:
                self.max_num_batched_tokens = 2048
            else:
                self.max_num_batched_tokens = self.max_model_len
        self.cache_config.postprocess(self.max_num_batched_tokens, self.max_num_seqs)

        if self.long_prefill_token_threshold == 0:
            self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
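A small sketch of the defaults derived in `postprocess()` above, with an assumed `max_model_len`:

```python
max_model_len = 8192                     # assumed value
enable_chunked_prefill = False
long_prefill_token_threshold = 0         # left unset by the user

max_num_batched_tokens = 2048 if enable_chunked_prefill else max_model_len  # 8192
if long_prefill_token_threshold == 0:
    long_prefill_token_threshold = int(max_model_len * 0.04)                # 327
```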
|
||||
|
||||
self.cache_config.postprocess(self.max_num_batched_tokens,
|
||||
self.max_num_seqs)
|
||||
self.cache_config.max_block_num_per_seq = int(
|
||||
self.max_model_len // self.cache_config.block_size)
|
||||
|
||||
if self.guided_decoding_backend == "auto":
|
||||
if self.enable_mm:
|
||||
self.guided_decoding_backend = "off"
|
||||
else:
|
||||
self.guided_decoding_backend = "xgrammar"
|
||||
|
||||
def check(self):
|
||||
"""
|
||||
Validate configuration values:
|
||||
- max_num_seqs <= 256
|
||||
- engine_worker_queue_port available
|
||||
- 1 <= tensor_parallel_size <= 8
|
||||
- nnode >= 1
|
||||
- max_model_len >= 16
|
||||
- max_num_seqs >= 1
|
||||
- Validates scheduler configuration
|
||||
check the legality of config
|
||||
"""
|
||||
assert (
|
||||
self.max_num_seqs <= 256
|
||||
@@ -434,15 +682,66 @@ Attributes:
|
||||
assert (
|
||||
self.max_num_seqs
|
||||
>= 1), f"max_num_seqs: {self.max_num_seqs} should be larger than 1"
|
||||
assert (
|
||||
self.max_num_batched_tokens >= self.max_num_seqs
|
||||
), f"max_num_batched_tokens: {self.max_num_batched_tokens} " \
|
||||
f"should be larger than or equal to max_num_seqs: {self.max_num_seqs}"
|
||||
assert (self.max_num_batched_tokens <= self.max_model_len * self.max_num_seqs), \
|
||||
f"max_num_batched_tokens: {self.max_num_batched_tokens} should be larger" \
|
||||
f"than or equal to max_num_seqs: {self.max_num_seqs} * max_model_len: {self.max_model_len}"
|
||||
assert (
|
||||
self.max_num_partial_prefills >= 1
|
||||
), f"max_num_partial_prefills: {self.max_num_partial_prefills} should be larger than or equal to 1"
|
||||
|
||||
assert (
|
||||
self.max_long_partial_prefills >= 1
|
||||
), f"max_long_partial_prefills: {self.max_long_partial_prefills} should be larger than or equal to 1"
|
||||
assert (self.max_long_partial_prefills <= self.max_num_partial_prefills), \
|
||||
f"max_long_partial_prefills: {self.max_long_partial_prefills} should " \
|
||||
f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}"
|
||||
|
||||
if not self.cache_config.enable_chunked_prefill:
|
||||
assert (
|
||||
self.max_num_batched_tokens >= self.max_model_len
|
||||
), f"max_num_batched_tokens: {self.max_num_batched_tokens} " \
|
||||
f"should be larger than or equal to max_model_len: {self.max_model_len}"
|
||||
|
||||
if self.max_num_partial_prefills > 1:
|
||||
assert (self.cache_config.enable_chunked_prefill is True), \
|
||||
"Chunked prefill must be enabled to set max_num_partial_prefills > 1"
|
||||
assert (self.long_prefill_token_threshold < self.max_model_len), \
|
||||
f"long_prefill_token_threshold: {self.long_prefill_token_threshold} should be less than"\
|
||||
f" max_model_len: {self.max_model_len}"
|
||||
|
||||
if self.guided_decoding_backend is not None:
|
||||
assert self.guided_decoding_backend in ["xgrammar", "XGrammar", "auto", "off"], \
|
||||
f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."
|
||||
|
||||
if self.guided_decoding_backend != "off":
|
||||
# TODO: mm support guided_decoding
|
||||
assert self.enable_mm is False, "Multimodal model currently do not support guided_decoding"
|
||||
|
||||
# TODO: speculative decoding support guided_decoding
|
||||
|
||||
# TODO: xpu support guided_decoding
|
||||
assert not current_platform.is_xpu(
|
||||
), "XPU currently do not support guided_decoding"
|
||||
|
||||
try:
    # Ensure the xgrammar backend is importable before it is needed at
    # runtime (assumed to be the original intent of this check, based on
    # the error message below).
    import xgrammar  # noqa: F401
except Exception as e:
    raise Exception(
        f"import XGrammar failed, please install XGrammar with `pip install xgrammar==0.1.19`. \n\t {e}"
    )
|
||||
|
||||
self.scheduler_config.check()
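# Illustrative sketch (not part of the original diff): the inequality chain that
# check() enforces, shown for one assumed configuration (values are illustrative).
max_num_seqs, max_model_len, max_num_batched_tokens = 64, 8192, 8192
assert 1 <= max_num_seqs <= 256
assert max_model_len >= 16
assert max_num_seqs <= max_num_batched_tokens <= max_model_len * max_num_seqs
assert max_num_batched_tokens >= max_model_len  # required only when chunked prefill is off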
|
||||
|
||||
def print(self, file=None):
|
||||
"""
|
||||
Print or save current configuration.
|
||||
|
||||
print all config
|
||||
|
||||
Args:
|
||||
file (Optional[str]): File path to save config (default: None)
|
||||
file (str): the path of file to save config
|
||||
"""
|
||||
llm_logger.info(
|
||||
"=================== Configuration Information ===============")
|
||||
@@ -450,7 +749,7 @@ Attributes:
|
||||
if k == "generation_config" and v is not None:
|
||||
for gck, gcv in v.to_dict().items():
|
||||
llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
|
||||
elif k == "cache_config" or k == "model_config" or k == "scheduler_config":
|
||||
elif k == "cache_config" or k == "model_config" or k == "scheduler_config" or k == "parallel_config":
|
||||
v.print()
|
||||
else:
|
||||
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
|
||||
@@ -464,13 +763,36 @@ Attributes:
|
||||
f.write("{:<20}:{:<6}{}\n".format(k, "", v))
|
||||
f.close()
|
||||
|
||||
def init_cache_info(self):
|
||||
"""
|
||||
initialize cache info
|
||||
"""
|
||||
disaggregate_info = {}
|
||||
if self.splitwise_role != "mixed":
|
||||
disaggregate_info["role"] = self.splitwise_role
|
||||
disaggregate_info["cache_info"] = dict()
|
||||
current_protocol = self.cache_config.cache_transfer_protocol.split(
|
||||
",")
|
||||
disaggregate_info["transfer_protocol"] = current_protocol
|
||||
for protocol in current_protocol:
|
||||
if protocol == "ipc":
|
||||
disaggregate_info["cache_info"][protocol] = {
|
||||
"ip": self.host_ip,
|
||||
"port": self.engine_worker_queue_port,
|
||||
"device_ids": self.local_device_ids
|
||||
}
|
||||
elif protocol == "rdma":
|
||||
disaggregate_info["cache_info"][protocol] = {
|
||||
"ip": self.host_ip,
|
||||
"port": self.cache_config.pd_comm_port[0],
|
||||
"rdma_port": self.cache_config.rdma_comm_ports,
|
||||
}
|
||||
self.disaggregate_info = disaggregate_info
|
||||
llm_logger.info(f"disaggregate_info: {self.disaggregate_info}")
|
||||
|
||||
def read_from_config(self):
|
||||
"""
|
||||
Update configuration from model JSON file.
|
||||
Handles special cases:
|
||||
- infer_model_block_size -> block_size
|
||||
- return_full_hidden_states
|
||||
- infer_model_dtype -> cache_dtype
|
||||
reset model config from json file
|
||||
"""
|
||||
|
||||
def reset_value(cls, value_name, key):
|
||||
@@ -482,9 +804,9 @@ Attributes:
|
||||
)
|
||||
|
||||
reset_value(self.cache_config, "block_size", "infer_model_block_size")
|
||||
reset_value(self.model_config, "return_full_hidden_states", "return_full_hidden_states")
|
||||
reset_value(self.model_config, "return_full_hidden_states",
|
||||
"return_full_hidden_states")
|
||||
reset_value(self.cache_config, "cache_dtype", "infer_model_dtype")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return json.dumps(self.__dict__, indent=4)
|
||||
|
||||
|
||||
File diff suppressed because it is too large (Load Diff)
370  fastdeploy/engine/expert_service.py  (new file)
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import weakref
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastdeploy.engine.resource_manager import ResourceManager
|
||||
from fastdeploy.inter_communicator import EngineWorkerQueue
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.output.token_processor import TokenProcessor
|
||||
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
|
||||
from fastdeploy.utils import EngineError, console_logger, llm_logger
|
||||
|
||||
|
||||
class ExpertService(object):
|
||||
"""
|
||||
Engine class responsible for managing the Large Language Model (LLM) operations.
|
||||
|
||||
Attributes:
|
||||
cfg (Config): Configuration object containing all the parameters.
|
||||
local_data_parallel_id (int): Local data parallel ID.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, local_data_parallel_id):
|
||||
"""
|
||||
Initializes the LLMEngine with the provided configuration.
|
||||
|
||||
Args:
|
||||
cfg (Config): Config object containing all the configuration parameters.
|
||||
"""
|
||||
self.cfg = cfg
|
||||
start_pos = local_data_parallel_id * self.cfg.tensor_parallel_size
|
||||
end_pos = (local_data_parallel_id + 1) * self.cfg.tensor_parallel_size
|
||||
self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[
|
||||
start_pos:end_pos]
|
||||
self.cfg.local_device_ids = self.cfg.device_ids.split(
|
||||
",")[start_pos:end_pos]
|
||||
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
|
||||
self.cfg.disaggregate_info = None
|
||||
|
||||
self.scheduler = cfg.scheduler_config.scheduler()
|
||||
|
||||
self.scheduler.reset_nodeid(
|
||||
f"{self.scheduler.infer.nodeid}_{str(local_data_parallel_id)}")
|
||||
|
||||
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
|
||||
|
||||
address = ('0.0.0.0', cfg.engine_worker_queue_port)
|
||||
self.engine_worker_queue = EngineWorkerQueue(
|
||||
address=address,
|
||||
is_server=False,
|
||||
client_id=0,
|
||||
num_client=cfg.tensor_parallel_size,
|
||||
local_data_parallel_id=local_data_parallel_id,
|
||||
)
|
||||
self.resource_manager = ResourceManager(cfg.max_num_seqs, cfg, \
|
||||
cfg.tensor_parallel_size, cfg.splitwise_role, local_data_parallel_id)
|
||||
|
||||
if len(self.cfg.cache_config.pd_comm_port) == 1:
|
||||
self.cfg.cache_config.pd_comm_port[0] = int(
|
||||
self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
|
||||
else:
|
||||
self.cfg.cache_config.pd_comm_port = [
|
||||
self.cfg.cache_config.pd_comm_port[local_data_parallel_id]
|
||||
]
|
||||
|
||||
self.split_connector = SplitwiseConnector(self.cfg, self.scheduler,
|
||||
self.engine_worker_queue,
|
||||
self.resource_manager)
|
||||
|
||||
self.token_processor = TokenProcessor(
|
||||
cfg=cfg,
|
||||
cached_generated_tokens=self.scheduler,
|
||||
engine_worker_queue=self.engine_worker_queue,
|
||||
split_connector=self.split_connector)
|
||||
self.token_processor.set_resource_manager(self.resource_manager)
|
||||
|
||||
self.partial_chunked_tokens = [0] * (
|
||||
self.cfg.max_num_partial_prefills + 1)
|
||||
for idx in range(1, self.cfg.max_num_partial_prefills + 1):
|
||||
self.partial_chunked_tokens[idx] = (self.cfg.max_num_batched_tokens // idx) \
|
||||
// self.cfg.cache_config.block_size * self.cfg.cache_config.block_size
|
||||
|
||||
self._finalizer = weakref.finalize(self, self._exit_sub_services)
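# Illustrative sketch (not part of the original diff): the per-request chunk
# budget computed above, for assumed values max_num_batched_tokens=8192,
# block_size=64 and max_num_partial_prefills=3.
assumed_budget, assumed_block, assumed_parts = 8192, 64, 3
example_chunked_tokens = [0] * (assumed_parts + 1)
for idx in range(1, assumed_parts + 1):
    example_chunked_tokens[idx] = (assumed_budget // idx) // assumed_block * assumed_block
# -> [0, 8192, 4096, 2688]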
|
||||
|
||||
def start(self, ipc_signal_suffix, local_data_parallel_id):
|
||||
"""
|
||||
Initializes the engine and starts its sub-services.
|
||||
If `api_server_pid` is defined, will launch a thread
|
||||
to keep getting request from zmq_server.
|
||||
"""
|
||||
# assert not self.is_started, "The engine is already started."
|
||||
start_time = time.time()
|
||||
|
||||
llm_logger.info(f"start expert service {local_data_parallel_id}")
|
||||
|
||||
self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
|
||||
self.cfg.cache_config, self.cfg.tensor_parallel_size,
|
||||
self.cfg.local_device_ids, self.cfg.engine_worker_queue_port,
|
||||
f"{local_data_parallel_id}_{ipc_signal_suffix}")
|
||||
|
||||
self.insert_task_to_worker_thread = threading.Thread(
|
||||
target=self._insert_task_to_worker, args=())
|
||||
self.insert_task_to_worker_thread.daemon = True
|
||||
self.insert_task_to_worker_thread.start()
|
||||
|
||||
# Start TokenProcessor thread
|
||||
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(
|
||||
local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
|
||||
|
||||
self.token_processor.run()
|
||||
|
||||
self.split_mode_get_tasks()
|
||||
|
||||
self.cfg.init_cache_info()
|
||||
|
||||
role = self.cfg.splitwise_role
|
||||
host_ip = self.cfg.host_ip
|
||||
disaggregate = self.cfg.disaggregate_info
|
||||
self.scheduler.start(role, host_ip, disaggregate)
|
||||
self.cfg.print()
|
||||
|
||||
console_logger.info(
|
||||
"Worker processes are launched with {} seconds.".format(
|
||||
time.time() - start_time))
|
||||
return True
|
||||
|
||||
def _insert_task_to_worker(self):
|
||||
"""
|
||||
Insert task to engine thread, monitor scheduler request queue.
|
||||
if the engine has resource, insert task to engine
|
||||
"""
|
||||
current_id = -1
|
||||
while True:
|
||||
try:
|
||||
if self.resource_manager.available_batch() == 0:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
if self.engine_worker_queue.num_tasks() > 0:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
if len(self.split_connector.current_request_ids) > 0:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
|
||||
num_prefill_batch = min(
|
||||
int(self.resource_manager.available_batch()),
|
||||
self.cfg.max_prefill_batch)
|
||||
|
||||
self.resource_manager.check_and_free_block_tables()
|
||||
tasks = self.scheduler.get_requests(
|
||||
available_blocks=self.resource_manager.available_block_num(
|
||||
),
|
||||
block_size=self.cfg.cache_config.block_size,
|
||||
reserved_output_blocks=self.cfg.cache_config.
|
||||
enc_dec_block_num,
|
||||
max_num_batched_tokens=self.cfg.max_num_batched_tokens,
|
||||
batch=num_prefill_batch)
|
||||
|
||||
if len(tasks) == 0:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
|
||||
if self.cfg.splitwise_role != "mixed":
|
||||
llm_logger.info("Inserting splitwise tasks")
|
||||
self.split_connector.send_splitwise_tasks(
|
||||
tasks, current_id)
|
||||
|
||||
current_id = (current_id + 1) % 100003
|
||||
|
||||
self.insert_tasks(tasks, current_id)
|
||||
|
||||
main_process_metrics.num_requests_waiting.dec(len(tasks))
|
||||
main_process_metrics.num_requests_running.inc(len(tasks))
|
||||
except Exception as e:
|
||||
err_msg = "Error happend while insert task to engine: {}, {}.".format(
|
||||
e, str(traceback.format_exc()))
|
||||
llm_logger.error(err_msg)
|
||||
|
||||
def split_mode_get_tasks(self):
|
||||
"""
|
||||
Split mode get tasks
|
||||
"""
|
||||
waiting_requests = []
|
||||
|
||||
def receiver_loop():
|
||||
while True:
|
||||
try:
|
||||
if len(waiting_requests) > 0:
|
||||
for task in waiting_requests:
|
||||
if self.resource_manager.is_resource_sufficient(
|
||||
task.prompt_token_ids_len):
|
||||
self.insert_tasks([task])
|
||||
waiting_requests.remove(task)
|
||||
else:
|
||||
break
|
||||
if not self.engine_worker_queue.disaggregate_queue_empty():
|
||||
items = self.engine_worker_queue.get_disaggregated_tasks(
|
||||
)
|
||||
for item in items:
|
||||
role = item[0]
|
||||
tasks = item[1]
|
||||
if role == "prefill":
|
||||
llm_logger.info("get prefill tasks")
|
||||
for task in tasks:
|
||||
task.max_tokens = task.min_tokens = 2
|
||||
self.insert_tasks(tasks)
|
||||
elif role == "decode":
|
||||
llm_logger.info(f"get decode tasks {tasks}")
|
||||
if hasattr(tasks[0], 'finished'):
|
||||
if not isinstance(tasks, list):
|
||||
tasks = [tasks]
|
||||
for task in tasks:
|
||||
task.finished = False
|
||||
# self.scheduler.put_results(tasks)
|
||||
|
||||
self.insert_tasks(tasks, allocated=True)
|
||||
else:
|
||||
if len(waiting_requests):
|
||||
for task in tasks:
|
||||
waiting_requests.append(task)
|
||||
else:
|
||||
for task in tasks:
|
||||
if not self.resource_manager.is_resource_sufficient(
|
||||
task.prompt_token_ids_len):
|
||||
waiting_requests.append(task)
|
||||
else:
|
||||
self.insert_tasks([task])
|
||||
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
except Exception as e:
|
||||
llm_logger.error(f"get decode tasks error: {e}")
|
||||
|
||||
threading.Thread(target=receiver_loop, daemon=True).start()
|
||||
|
||||
def insert_tasks(self, tasks, current_id=-1, allocated=False):
|
||||
"""
|
||||
Insert tasks to engine.
|
||||
"""
|
||||
if allocated:
|
||||
current_tasks = []
|
||||
for task in tasks:
|
||||
cur_task_idx = self.resource_manager.req_dict[task.request_id]
|
||||
del self.resource_manager.req_dict[task.request_id]
|
||||
cur_task = self.resource_manager.tasks_list[cur_task_idx]
|
||||
if task.error_code != 200:
|
||||
self.resource_manager.stop_flags[cur_task_idx] = True
|
||||
self.resource_manager.tasks_list[cur_task_idx] = None
|
||||
self.resource_manager._recycle_block_tables(cur_task)
|
||||
if task.request_id in self.token_processor.tokens_counter:
|
||||
del self.token_processor.tokens_counter[
|
||||
task.request_id]
|
||||
self.scheduler.put_results([task])
|
||||
llm_logger.warning(
|
||||
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
|
||||
)
|
||||
continue
|
||||
llm_logger.info(f"{cur_task_idx} {task.request_id}")
|
||||
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
|
||||
self.token_processor.tokens_counter[task.request_id] = 1
|
||||
current_tasks.append(cur_task)
|
||||
self.engine_worker_queue.put_tasks(
|
||||
(current_tasks, self.resource_manager.real_bsz))
|
||||
return True
|
||||
|
||||
self.resource_manager.check_and_free_block_tables()
|
||||
|
||||
if not isinstance(tasks, list):
|
||||
tasks = [tasks]
|
||||
|
||||
for item in tasks:
|
||||
item.schedule_start_time = time.time()
|
||||
|
||||
available_batch = np.sum(self.resource_manager.stop_flags)
|
||||
if len(tasks) > available_batch:
|
||||
llm_logger.error(
|
||||
"Inserting batch:{} exceeds the available batch:{}.".format(
|
||||
len(tasks), available_batch))
|
||||
llm_logger.error("The exceeded part will be ignored!")
|
||||
tasks = tasks[:available_batch]
|
||||
|
||||
req_ids = [t.request_id for t in tasks]
|
||||
|
||||
tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
|
||||
|
||||
if not tasks:
|
||||
error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
|
||||
llm_logger.error(error_msg)
|
||||
raise EngineError(error_msg, error_code=500)
|
||||
return False
|
||||
|
||||
self.token_processor.number_of_tasks += len(tasks)
|
||||
|
||||
is_decode = False
|
||||
is_prefill = False
|
||||
for i in range(len(tasks)):
|
||||
if tasks[i].disaggregate_info is not None:
|
||||
if tasks[i].disaggregate_info["role"] == "decode":
|
||||
is_decode = True
|
||||
else:
|
||||
is_prefill = True
|
||||
self.token_processor.number_of_input_tokens += tasks[
|
||||
i].prompt_token_ids_len
|
||||
|
||||
self.split_connector.send_cache_infos(tasks, current_id)
|
||||
for task in tasks:
|
||||
task.infer_start_time = time.time()
|
||||
if not is_decode:
|
||||
llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
|
||||
if not is_prefill:
|
||||
if not self.cfg.enable_mm:
|
||||
self.update_requests_chunk_size(tasks)
|
||||
else:
|
||||
self.update_mm_requests_chunk_size(tasks)
|
||||
self.engine_worker_queue.put_tasks(
|
||||
(tasks, self.resource_manager.real_bsz))
|
||||
return True
|
||||
|
||||
def _exit_sub_services(self):
|
||||
"""
|
||||
exit sub services
|
||||
"""
|
||||
|
||||
if hasattr(self, "cache_manager_processes"):
|
||||
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear(
|
||||
)
|
||||
self.resource_manager.cache_manager.cache_ready_signal.clear()
|
||||
for p in self.cache_manager_processes:
|
||||
llm_logger.info(f"Killing cache manager process {p.pid}")
|
||||
try:
|
||||
os.killpg(p.pid, signal.SIGTERM)
|
||||
except:
|
||||
pass
|
||||
|
||||
if hasattr(self, "zmq_server") and self.zmq_server is not None:
|
||||
self.zmq_server.close()
|
||||
|
||||
|
||||
def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix):
|
||||
"""
|
||||
Start expert service
|
||||
"""
|
||||
expert_service = ExpertService(cfg, local_data_parallel_id)
|
||||
try:
|
||||
expert_service.start(ipc_signal_suffix, local_data_parallel_id)
|
||||
expert_service.split_connector.start_receiver()
|
||||
except Exception as e:
|
||||
llm_logger.exception(f"Expert service failed to start: {e}")
|
||||
@@ -15,12 +15,12 @@
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import asdict, dataclass, fields
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import numpy
|
||||
from dataclasses import dataclass, asdict, fields
|
||||
from typing import TYPE_CHECKING, Optional, Union, Any
|
||||
from fastdeploy.engine.sampling_params import SamplingParams
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
|
||||
from fastdeploy.engine.sampling_params import SamplingParams
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
@@ -28,41 +28,33 @@ from fastdeploy.utils import data_processor_logger
|
||||
|
||||
@dataclass
|
||||
class Request:
|
||||
"""A class representing an inference request to the LLM engine.
|
||||
|
||||
Attributes:
|
||||
request_id: Unique identifier for the request
|
||||
prompt: Input prompt text or list of prompts
|
||||
prompt_token_ids: Token IDs of the input prompt
|
||||
prompt_token_ids_len: Length of prompt token IDs
|
||||
messages: List of message dictionaries (for chat models)
|
||||
history: Conversation history (for chat models)
|
||||
system: System message (for chat models)
|
||||
sampling_params: Parameters controlling text generation
|
||||
eos_token_ids: List of end-of-sequence token IDs
|
||||
arrival_time: Timestamp when request was received
|
||||
preprocess_start_time: Timestamp when preprocessing started
|
||||
preprocess_end_time: Timestamp when preprocessing completed
|
||||
multimodal_inputs: Dictionary of multimodal inputs (images, audio etc.)
|
||||
raw_request: Flag indicating if this is a raw request
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[Union[str, list[str]]],
|
||||
prompt_token_ids: Optional[list[int]],
|
||||
prompt_token_ids_len: Optional[int],
|
||||
messages: Optional[list[list[dict[str, Any]]]],
|
||||
history: Optional[list[list[str]]],
|
||||
system: Optional[Union[str, list[str]]],
|
||||
sampling_params: SamplingParams,
|
||||
eos_token_ids: Optional[list[int]],
|
||||
arrival_time: float,
|
||||
preprocess_start_time: Optional[float] = None,
|
||||
preprocess_end_time: Optional[float] = None,
|
||||
multimodal_inputs: Optional[dict] = None,
|
||||
raw_request: bool = True
|
||||
) -> None:
|
||||
|
||||
def __init__(self,
|
||||
request_id: str,
|
||||
prompt: Optional[Union[str, list[str]]],
|
||||
prompt_token_ids: Optional[list[int]],
|
||||
prompt_token_ids_len: Optional[int],
|
||||
messages: Optional[list[list[dict[str, Any]]]],
|
||||
history: Optional[list[list[str]]],
|
||||
tools: Optional[list[Dict]],
|
||||
system: Optional[Union[str, list[str]]],
|
||||
sampling_params: SamplingParams,
|
||||
eos_token_ids: Optional[list[int]],
|
||||
arrival_time: float,
|
||||
preprocess_start_time: Optional[float] = None,
|
||||
preprocess_end_time: Optional[float] = None,
|
||||
multimodal_inputs: Optional[dict] = None,
|
||||
multimodal_data: Optional[dict] = None,
|
||||
raw_request: bool = True,
|
||||
disaggregate_info: Optional[dict] = None,
|
||||
draft_token_ids: Optional[list[int]] = None,
|
||||
guided_json: Optional[Any] = None,
|
||||
guided_regex: Optional[Any] = None,
|
||||
guided_choice: Optional[Any] = None,
|
||||
guided_grammar: Optional[Any] = None,
|
||||
structural_tag: Optional[Any] = None,
|
||||
guided_json_object: Optional[bool] = None,
|
||||
enable_thinking: Optional[bool] = True) -> None:
|
||||
self.request_id = request_id
|
||||
self.prompt = prompt
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
@@ -71,52 +63,66 @@ class Request:
|
||||
self.system = system
|
||||
self.sampling_params = sampling_params
|
||||
self.history = history
|
||||
self.tools = tools
|
||||
# model specific token ids: end of sentence token ids
|
||||
self.eos_token_ids = eos_token_ids
|
||||
self.num_cached_tokens = 0
|
||||
|
||||
self.arrival_time = arrival_time
|
||||
self.preprocess_start_time = preprocess_start_time
|
||||
self.preprocess_end_time = preprocess_end_time
|
||||
self.raw_request = raw_request
|
||||
self.disaggregate_info = disaggregate_info
|
||||
|
||||
# speculative method in disaggregate-mode
|
||||
self.draft_token_ids = draft_token_ids
|
||||
|
||||
# guided decoding related
|
||||
self.guided_json = guided_json
|
||||
self.guided_regex = guided_regex
|
||||
self.guided_choice = guided_choice
|
||||
self.guided_grammar = guided_grammar
|
||||
self.structural_tag = structural_tag
|
||||
self.guided_json_object = guided_json_object
|
||||
|
||||
# Multi-modal related
|
||||
self.multimodal_inputs = multimodal_inputs
|
||||
self.multimodal_data = multimodal_data
|
||||
|
||||
self.enable_thinking = enable_thinking
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict):
|
||||
"""Create a Request instance from a dictionary.
|
||||
|
||||
Args:
|
||||
d: Dictionary containing request parameters
|
||||
|
||||
Returns:
|
||||
Request: A new Request instance initialized with values from the dictionary
|
||||
"""
|
||||
data_processor_logger.debug(f"{d}")
|
||||
sampling_params = SamplingParams.from_dict(d)
|
||||
return cls(
|
||||
request_id=d["request_id"],
|
||||
prompt=d.get("prompt"),
|
||||
prompt_token_ids=d.get("prompt_token_ids"),
|
||||
prompt_token_ids_len=d.get("prompt_token_ids_len"),
|
||||
messages=d.get("messages"),
|
||||
system=d.get("system"),
|
||||
history=d.get("history"),
|
||||
sampling_params=sampling_params,
|
||||
eos_token_ids=d.get("eos_token_ids"),
|
||||
arrival_time=d.get("arrival_time", time.time()),
|
||||
preprocess_start_time=d.get("preprocess_start_time"),
|
||||
preprocess_end_time=d.get("preprocess_end_time"),
|
||||
multimodal_inputs=d.get("multimodal_inputs"),
|
||||
raw_request=d.get("raw_request", True)
|
||||
)
|
||||
return cls(request_id=d["request_id"],
|
||||
prompt=d.get("prompt"),
|
||||
prompt_token_ids=d.get("prompt_token_ids"),
|
||||
prompt_token_ids_len=d.get("prompt_token_ids_len"),
|
||||
messages=d.get("messages"),
|
||||
system=d.get("system"),
|
||||
history=d.get("history"),
|
||||
tools=d.get("tools"),
|
||||
sampling_params=sampling_params,
|
||||
eos_token_ids=d.get("eos_token_ids"),
|
||||
arrival_time=d.get("arrival_time", time.time()),
|
||||
preprocess_start_time=d.get("preprocess_start_time"),
|
||||
preprocess_end_time=d.get("preprocess_end_time"),
|
||||
multimodal_inputs=d.get("multimodal_inputs"),
|
||||
multimodal_data=d.get("multimodal_data"),
|
||||
disaggregate_info=d.get("disaggregate_info"),
|
||||
draft_token_ids=d.get("draft_token_ids"),
|
||||
raw_request=d.get("raw_request", True),
|
||||
guided_json=d.get("guided_json", None),
|
||||
guided_regex=d.get("guided_regex", None),
|
||||
guided_choice=d.get("guided_choice", None),
|
||||
guided_grammar=d.get("guided_grammar", None),
|
||||
structural_tag=d.get("structural_tag", None),
|
||||
guided_json_object=d.get("guided_json_object", None),
|
||||
enable_thinking=d.get("enable_thinking", True))
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert the Request object into a serializable dictionary.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing all request attributes and sampling parameters
|
||||
"""
|
||||
"""convert Request into a serializable dict """
|
||||
data = {
|
||||
"request_id": self.request_id,
|
||||
"prompt": self.prompt,
|
||||
@@ -125,26 +131,30 @@ class Request:
|
||||
"messages": self.messages,
|
||||
"system": self.system,
|
||||
"history": self.history,
|
||||
"tools": self.tools,
|
||||
"eos_token_ids": self.eos_token_ids,
|
||||
"arrival_time": self.arrival_time,
|
||||
"preprocess_start_time": self.preprocess_start_time,
|
||||
"preprocess_end_time": self.preprocess_end_time,
|
||||
"multimodal_inputs": self.multimodal_inputs,
|
||||
"raw_request": self.raw_request
|
||||
"multimodal_data": self.multimodal_data,
|
||||
"raw_request": self.raw_request,
|
||||
"disaggregate_info": self.disaggregate_info,
|
||||
"draft_token_ids": self.draft_token_ids,
|
||||
"enable_thinking": self.enable_thinking
|
||||
}
|
||||
add_params = [
|
||||
"guided_json", "guided_regex", "guided_choice", "guided_grammar",
|
||||
"structural_tag", "guided_json_object"
|
||||
]
|
||||
for param in add_params:
|
||||
if getattr(self, param, None) is not None:
|
||||
data[param] = getattr(self, param)
|
||||
|
||||
data.update(asdict(self.sampling_params))
|
||||
return data
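# Illustrative usage sketch (not part of the original diff), assuming this
# module is importable: round-trip a minimal request through from_dict()/to_dict().
# All field values below are made up for illustration.
example_req = Request.from_dict({
    "request_id": "req-0",
    "prompt": "Hello",
    "prompt_token_ids": [1, 2, 3],
    "prompt_token_ids_len": 3,
    "eos_token_ids": [2],
    "max_tokens": 32,
})
example_payload = example_req.to_dict()  # plain dict, includes the sampling params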
|
||||
|
||||
def get(self, key: str, default_value=None):
|
||||
"""Get an attribute value from either the Request or its sampling parameters.
|
||||
|
||||
Args:
|
||||
key: Attribute name to retrieve
|
||||
default_value: Default value to return if attribute not found
|
||||
|
||||
Returns:
|
||||
The attribute value if found, otherwise default_value
|
||||
"""
|
||||
if hasattr(self, key):
|
||||
return getattr(self, key)
|
||||
elif hasattr(self.sampling_params, key):
|
||||
@@ -153,12 +163,6 @@ class Request:
|
||||
return default_value
|
||||
|
||||
def set(self, key, value):
|
||||
"""Set an attribute value on either the Request or its sampling parameters.
|
||||
|
||||
Args:
|
||||
key: Attribute name to set
|
||||
value: Value to assign to the attribute
|
||||
"""
|
||||
if hasattr(self.sampling_params, key):
|
||||
setattr(self.sampling_params, key, value)
|
||||
else:
|
||||
@@ -168,6 +172,7 @@ class Request:
|
||||
return (f"Request(request_id={self.request_id}, "
|
||||
f"prompt={self.prompt!r}, "
|
||||
f"prompt_token_ids={self.prompt_token_ids}, "
|
||||
f"draft_token_ids={self.draft_token_ids}, "
|
||||
f"sampling_params={self.sampling_params})")
|
||||
|
||||
|
||||
@@ -182,22 +187,42 @@ class CompletionOutput:
|
||||
"""
|
||||
|
||||
index: int
|
||||
send_idx: int
|
||||
token_ids: list[int]
|
||||
draft_token_ids: list[int] = None
|
||||
text: Optional[str] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
convert CompletionOutput to a serialized dict
|
||||
"""
|
||||
return {
|
||||
"index": self.index,
|
||||
"send_idx": self.send_idx,
|
||||
"token_ids": self.token_ids,
|
||||
"draft_token_ids": self.draft_token_ids,
|
||||
"text": self.text,
|
||||
"reasoning_content": self.reasoning_content
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, req_dict: dict[str, Any]) -> 'CompletionOutput':
|
||||
"""Create instance from dict arguments"""
|
||||
return cls(**{
|
||||
field.name: req_dict[field.name] if field.name in req_dict else field.default
|
||||
for field in fields(cls)
|
||||
})
|
||||
return cls(
|
||||
**{
|
||||
field.name:
|
||||
req_dict[field.name] if field.name in
|
||||
req_dict else field.default
|
||||
for field in fields(cls)
|
||||
})
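# Illustrative usage sketch (not part of the original diff): construct a
# CompletionOutput from a partial dict; fields that are absent fall back to
# their dataclass defaults (draft_token_ids, text and reasoning_content -> None).
example_output = CompletionOutput.from_dict(
    {"index": 0, "send_idx": 0, "token_ids": [5, 6]})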
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"CompletionOutput(index={self.index}, "
|
||||
f"send_idx={self.send_idx}, "
|
||||
f"text={self.text!r}, "
|
||||
f"token_ids={self.token_ids}, "
|
||||
f"draft_token_ids={self.draft_token_ids}, "
|
||||
f"reasoning_content={self.reasoning_content!r}")
|
||||
|
||||
|
||||
@@ -227,13 +252,31 @@ class RequestMetrics:
|
||||
model_execute_time: Optional[float] = None
|
||||
request_start_time: Optional[float] = None
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Convert the RequestMetrics object to a dictionary.
|
||||
"""
|
||||
return {
|
||||
"arrival_time": self.arrival_time,
|
||||
"inference_start_time": self.inference_start_time,
|
||||
"first_token_time": self.first_token_time,
|
||||
"time_in_queue": self.time_in_queue,
|
||||
"preprocess_cost_time": self.preprocess_cost_time,
|
||||
"model_forward_time": self.model_forward_time,
|
||||
"model_execute_time": self.model_execute_time,
|
||||
"request_start_time": self.request_start_time
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, req_dict: dict[str, Any]) -> 'RequestMetrics':
|
||||
"""Create instance from dict arguments"""
|
||||
return cls(**{
|
||||
field.name: req_dict[field.name] if field.name in req_dict else field.default
|
||||
for field in fields(cls)
|
||||
})
|
||||
return cls(
|
||||
**{
|
||||
field.name:
|
||||
req_dict[field.name] if field.name in
|
||||
req_dict else field.default
|
||||
for field in fields(cls)
|
||||
})
|
||||
|
||||
|
||||
class RequestOutput:
|
||||
@@ -282,16 +325,8 @@ class RequestOutput:
|
||||
self.error_msg = error_msg
|
||||
|
||||
def add(self, next_output: "RequestOutput") -> None:
|
||||
"""Merge another RequestOutput into this one.
|
||||
|
||||
Args:
|
||||
next_output: The RequestOutput to merge into this one
|
||||
|
||||
Updates:
|
||||
- Combines output sequences
|
||||
- Updates finish status
|
||||
- Calculates timing metrics
|
||||
"""
|
||||
"""Merge RequestOutput into this one"""
|
||||
|
||||
self.prompt = next_output.prompt
|
||||
self.prompt_token_ids = next_output.prompt_token_ids
|
||||
self.finished |= next_output.finished
|
||||
@@ -314,25 +349,13 @@ class RequestOutput:
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict):
|
||||
"""Create a RequestOutput instance from a dictionary.
|
||||
|
||||
Args:
|
||||
d: Dictionary containing request output parameters
|
||||
|
||||
Returns:
|
||||
RequestOutput: A new RequestOutput instance initialized with values from the dictionary
|
||||
"""
|
||||
"""Create instance from dict arguments"""
|
||||
completion_output = CompletionOutput.from_dict(d.pop("outputs"))
|
||||
metrics = RequestMetrics.from_dict(d.pop("metrics"))
|
||||
return RequestOutput(**d, outputs=completion_output, metrics=metrics)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert the RequestOutput object into a serializable dictionary.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing all request output attributes,
|
||||
with token IDs converted to lists if necessary
|
||||
"""
|
||||
"""convert RequestOutput into a serializable dict """
|
||||
if self.prompt_token_ids is None:
|
||||
self.prompt_token_ids = []
|
||||
|
||||
@@ -343,11 +366,12 @@ class RequestOutput:
|
||||
"request_id": self.request_id,
|
||||
"prompt": self.prompt,
|
||||
"prompt_token_ids": self.prompt_token_ids,
|
||||
"outputs": None if self.outputs is None else asdict(self.outputs),
|
||||
"outputs":
|
||||
None if self.outputs is None else self.outputs.to_dict(),
|
||||
"metrics":
|
||||
None if self.metrics is None else self.metrics.to_dict(),
|
||||
"finished": self.finished,
|
||||
"metrics": None if self.metrics is None else asdict(self.metrics),
|
||||
"num_cached_tokens": self.num_cached_tokens,
|
||||
"error_code": self.error_code,
|
||||
"error_msg": self.error_msg,
|
||||
}
|
||||
|
||||
|
||||
@@ -14,110 +14,116 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
import math
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import llm_logger
|
||||
|
||||
|
||||
class ResourceManager(object):
|
||||
"""Manages and allocates computational resources for the inference engine.
|
||||
|
||||
This class handles the allocation and recycling of memory blocks for KV cache,
|
||||
manages task scheduling, and tracks resource utilization.
|
||||
"""
|
||||
def __init__(self, max_num_seqs, cache_config):
|
||||
"""Initializes the resource manager with configuration parameters.
|
||||
|
||||
Args:
|
||||
max_num_seqs (int): Maximum number of concurrent sequences the engine can handle
|
||||
cache_config (Config): Configuration object containing:
|
||||
- prefill_kvcache_block_num: Number of pre-allocated KV cache blocks
|
||||
- block_size: Size of each memory block in tokens
|
||||
- dec_token_num: Number of decoder tokens
|
||||
record and allocate resources for the engine
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
max_num_seqs,
|
||||
config,
|
||||
tensor_parallel_size,
|
||||
splitwise_role,
|
||||
local_data_parallel_id=0):
|
||||
"""
|
||||
self.cfg = cache_config
|
||||
Args:
|
||||
cfg (Config): config object containing parameters for the engine
|
||||
initialization
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Initializes the engine with the given configuration and sets up necessary
|
||||
data structures to manage tasks and blocks.
|
||||
"""
|
||||
self.cfg = config.cache_config
|
||||
self.max_num_seqs = max_num_seqs
|
||||
self.stop_flags = [True] * max_num_seqs
|
||||
|
||||
|
||||
self.free_list = list(range(self.cfg.prefill_kvcache_block_num - 1, -1, -1))
|
||||
self.enable_prefix_cache = config.cache_config.enable_prefix_caching
|
||||
self.cache_manager = PrefixCacheManager(config, tensor_parallel_size,
|
||||
splitwise_role,
|
||||
local_data_parallel_id)
|
||||
self.tasks_list = [None] * max_num_seqs
|
||||
self.req_dict = dict()
|
||||
# current batch status of the engine
|
||||
self.real_bsz = 0
|
||||
llm_logger.info(f"{self.info()}")
|
||||
|
||||
def reset_cache_config(self, cfg):
|
||||
"""Updates the cache configuration with new parameters.
|
||||
|
||||
Args:
|
||||
cfg (Config): New cache configuration object
|
||||
"""
|
||||
reset cache config
|
||||
"""
|
||||
self.cfg = cfg
|
||||
self.free_list = list(range(self.cfg.prefill_kvcache_block_num - 1, -1, -1))
|
||||
|
||||
self.cache_manager.update_cache_config(cfg)
|
||||
|
||||
def get_required_block_number(self, input_token_num):
|
||||
"""Calculates the total number of blocks needed for a sequence.
|
||||
|
||||
Includes both encoder and decoder requirements.
|
||||
|
||||
Args:
|
||||
input_token_num (int): Number of tokens in the input sequence
|
||||
|
||||
Returns:
|
||||
int: Total number of blocks required (rounded up)
|
||||
"""
|
||||
block_num = (input_token_num + self.cfg.block_size - 1 + self.cfg.dec_token_num) // self.cfg.block_size
|
||||
Calculate how many block resources are needed
|
||||
|
||||
Args:
|
||||
input_token_num (int): input token number
|
||||
|
||||
Returns:
|
||||
int: block number
|
||||
"""
|
||||
block_num = (input_token_num + self.cfg.block_size - 1 +
|
||||
self.cfg.dec_token_num) // self.cfg.block_size
|
||||
return block_num
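# Illustrative sketch (not part of the original diff): the rounding above for
# assumed values block_size=64, dec_token_num=32 and a 1000-token input.
example_block_size, example_dec_token_num, example_input = 64, 32, 1000
example_block_num = (example_input + example_block_size - 1
                     + example_dec_token_num) // example_block_size  # 1095 // 64 = 17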
|
||||
|
||||
def get_encoder_block_number(self, input_token_num):
|
||||
"""Calculates the number of blocks needed for encoder inputs only.
|
||||
|
||||
Args:
|
||||
input_token_num (int): Number of tokens in the encoder input
|
||||
|
||||
Returns:
|
||||
int: Number of blocks required for encoder (rounded up)
|
||||
"""
|
||||
enc_block_num = (input_token_num + self.cfg.block_size - 1) // self.cfg.block_size
|
||||
get the number of blocks for the encoder
|
||||
|
||||
Args:
|
||||
input_token_num (int): input token number
|
||||
|
||||
Returns:
|
||||
int: encoder block number
|
||||
"""
|
||||
enc_block_num = (input_token_num + self.cfg.block_size -
|
||||
1) // self.cfg.block_size
|
||||
return enc_block_num
|
||||
|
||||
def get_decoder_block_number(self):
|
||||
"""Calculates the number of blocks needed for decoder outputs.
|
||||
|
||||
Returns:
|
||||
int: Number of blocks required for decoder (rounded up)
|
||||
"""
|
||||
return (self.cfg.dec_token_num + self.cfg.block_size - 1) // self.cfg.block_size
|
||||
get the number of blocks for the decoder
|
||||
|
||||
Returns:
|
||||
int: decoder block number
|
||||
"""
|
||||
return (self.cfg.dec_token_num + self.cfg.block_size -
|
||||
1) // self.cfg.block_size
|
||||
|
||||
def total_block_number(self):
|
||||
"""Gets the total number of pre-allocated KV cache blocks.
|
||||
|
||||
Returns:
|
||||
int: Total number of blocks available in the pool
|
||||
"""
|
||||
return self.cfg.prefill_kvcache_block_num
|
||||
the number of pre-allocated blocks at service startup
|
||||
|
||||
Returns:
|
||||
int: total block number
|
||||
"""
|
||||
return self.cache_manager.num_gpu_blocks
|
||||
|
||||
def _get_block_tables(self, input_token_num, required_type="all"):
|
||||
"""Allocates memory blocks from the free pool.
|
||||
|
||||
"""
|
||||
allocate memory resources
|
||||
|
||||
Args:
|
||||
input_token_num (int): Number of input tokens
|
||||
required_type (str): Type of blocks needed:
|
||||
- "all": Both encoder and decoder blocks
|
||||
- "encoder": Encoder blocks only
|
||||
- "decoder": Decoder blocks only
|
||||
|
||||
input_token_num (int): input token number
|
||||
required_type (str): required type
|
||||
|
||||
Returns:
|
||||
list: List of allocated block IDs
|
||||
|
||||
Raises:
|
||||
ValueError: If unknown required_type is specified
|
||||
list: block list
|
||||
"""
|
||||
if required_type == "all":
|
||||
block_num = self.get_required_block_number(input_token_num)
|
||||
@@ -129,50 +135,75 @@ class ResourceManager(object):
|
||||
raise ValueError('unknown required type')
|
||||
|
||||
block_list = list()
|
||||
if block_num > len(self.free_list):
|
||||
llm_logger.error("block_num:{0} > free_list len:{1}".format(block_num, len(self.free_list)))
|
||||
current_block_num = self.available_block_num()
|
||||
if block_num > current_block_num:
|
||||
llm_logger.error("block_num:{0} > free_list len:{1}".format(
|
||||
block_num, current_block_num))
|
||||
return block_list
|
||||
for _ in range(block_num):
|
||||
used_block_id = self.free_list.pop()
|
||||
block_list.append(used_block_id)
|
||||
block_list = self.cache_manager.allocate_gpu_blocks(block_num)
|
||||
llm_logger.debug(f"dispatch {len(block_list)} blocks.")
|
||||
return block_list
|
||||
|
||||
def _recycle_block_tables(self, block_tables):
|
||||
"""Returns memory blocks to the free pool for reuse.
|
||||
|
||||
Args:
|
||||
block_tables (list): List of block IDs to recycle
|
||||
def check_and_free_block_tables(self):
|
||||
"""
|
||||
ori_number = len(self.free_list)
|
||||
self.free_list.extend(block_tables)
|
||||
cur_number = len(self.free_list)
|
||||
llm_logger.info(f"recycle {cur_number - ori_number} blocks.")
|
||||
Check and free block tables only in prefix caching mode.
|
||||
If the number of free blocks is less than a certain threshold, free up to the threshold.
|
||||
"""
|
||||
if self.enable_prefix_cache:
|
||||
if self.available_block_num() < self.cfg.max_block_num_per_seq:
|
||||
self.free_block_tables(self.cfg.max_block_num_per_seq)
|
||||
|
||||
def _recycle_block_tables(self, task):
|
||||
"""
|
||||
Recycling memory resource blocks
|
||||
|
||||
Args:
|
||||
block_tables (list): block list
|
||||
"""
|
||||
|
||||
if self.enable_prefix_cache:
|
||||
self.cache_manager.release_block_ids_async(task)
|
||||
else:
|
||||
req_id = task.request_id
|
||||
if isinstance(task, list):
|
||||
block_tables = task
|
||||
else:
|
||||
block_tables = task.block_tables
|
||||
ori_number = self.available_block_num()
|
||||
self.cache_manager.recycle_gpu_blocks(block_tables)
|
||||
cur_number = self.available_block_num()
|
||||
main_process_metrics.gpu_cache_usage_perc.set(
|
||||
self.get_gpu_cache_usage_perc())
|
||||
llm_logger.info(
|
||||
f"recycle {req_id} {cur_number - ori_number} blocks.")
|
||||
|
||||
def available_batch(self):
|
||||
"""Gets the number of available sequence slots.
|
||||
|
||||
"""
|
||||
available batch size for engine
|
||||
|
||||
Returns:
|
||||
int: Number of available sequence slots in the batch
|
||||
int: available batch size
|
||||
"""
|
||||
return np.sum(self.stop_flags)
|
||||
|
||||
def available_block_num(self):
|
||||
"""Gets the number of available memory blocks.
|
||||
|
||||
Returns:
|
||||
int: Number of free blocks in the pool
|
||||
"""
|
||||
return len(self.free_list)
|
||||
available block size for engine
|
||||
|
||||
Returns:
|
||||
int: available block size
|
||||
"""
|
||||
return len(self.cache_manager.gpu_free_block_list)
|
||||
|
||||
def is_resource_sufficient(self, input_token_num):
|
||||
"""Checks if sufficient resources are available for a new sequence.
|
||||
|
||||
"""
|
||||
check whether the currently available resources meet the new requirements
|
||||
|
||||
Args:
|
||||
input_token_num (int): Number of tokens in the new sequence
|
||||
|
||||
input_token_num (int): input token number
|
||||
|
||||
Returns:
|
||||
bool: True if both batch slots and memory blocks are available
|
||||
bool: whether current available resources meet the new requirements
|
||||
"""
|
||||
if self.available_batch() < 1:
|
||||
return False
|
||||
@@ -181,19 +212,21 @@ class ResourceManager(object):
|
||||
return False
|
||||
return True
|
||||
|
||||
def free_block_tables(self, need_reserved_block_num):
|
||||
"""
|
||||
Recycle blocks back into the available resource pool
|
||||
"""
|
||||
return self.cache_manager.free_block_ids_async(need_reserved_block_num)
|
||||
|
||||
def allocate_resources_for_new_tasks(self, tasks):
|
||||
"""Assigns resources to new inference tasks.
|
||||
|
||||
"""
|
||||
allocate resources for new tasks
|
||||
|
||||
Args:
|
||||
tasks (list): List of Request objects needing resources
|
||||
|
||||
tasks (list): task list
|
||||
|
||||
Returns:
|
||||
list: List of successfully allocated Request objects
|
||||
|
||||
Note:
|
||||
- Assigns sequence slots and memory blocks
|
||||
- Sets initial timestamps and metadata
|
||||
- Updates real-time batch size statistics
|
||||
list: processed task list
|
||||
"""
|
||||
|
||||
allocated_position = 0
|
||||
@@ -205,7 +238,8 @@ class ResourceManager(object):
|
||||
|
||||
can_insert = False
|
||||
while allocated_position + 1 <= self.max_num_seqs:
|
||||
if sum(self.stop_flags[allocated_position : allocated_position + 1]) == 1:
|
||||
if sum(self.stop_flags[allocated_position:allocated_position +
|
||||
1]) == 1:
|
||||
can_insert = True
|
||||
break
|
||||
allocated_position += 1
|
||||
@@ -215,14 +249,61 @@ class ResourceManager(object):
|
||||
task = tasks[processing_task_index]
|
||||
|
||||
if task.get("seed") is None:
|
||||
task.set("seed", random.randint(0, 9223372036854775807))
|
||||
task.set("seed",
|
||||
random.randint(0, 9223372036854775807))
|
||||
task.idx = allocated_position
|
||||
block_tables = self._get_block_tables(task.prompt_token_ids_len)
|
||||
if not block_tables:
|
||||
llm_logger.error("req_id: {0} block_tables is empty".format(task.request_id))
|
||||
continue
|
||||
|
||||
if self.enable_prefix_cache:
|
||||
cache_prepare_time = time.time()
|
||||
common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids(
|
||||
task, self.cfg.block_size, self.cfg.dec_token_num)
|
||||
if unique_block_ids is None:
|
||||
llm_logger.warning(
|
||||
"req_id: {0} not enough blocks available".
|
||||
format(task["req_id"]))
|
||||
return
|
||||
|
||||
cached_len = self._record_request_cache_info(
|
||||
task, common_block_ids, unique_block_ids, hit_info)
|
||||
task.cache_prepare_time = time.time(
|
||||
) - cache_prepare_time
|
||||
|
||||
if task.disaggregate_info is not None:
|
||||
if task.disaggregate_info['role'] == "prefill":
|
||||
self.req_dict[
|
||||
task.request_id] = allocated_position
|
||||
task.disaggregate_info[
|
||||
'block_tables'] = task.block_tables
|
||||
self._delete_cached_data(task, cached_len)
|
||||
elif task.disaggregate_info['role'] == "decode":
|
||||
self.req_dict[
|
||||
task.request_id] = allocated_position
|
||||
task.disaggregate_info[
|
||||
'block_tables'] = task.need_block_tables
|
||||
else:
|
||||
self._delete_cached_data(task, cached_len)
|
||||
|
||||
else:
|
||||
task.block_tables = block_tables
|
||||
block_tables = self._get_block_tables(
|
||||
task.prompt_token_ids_len)
|
||||
if not block_tables:
|
||||
llm_logger.error(
|
||||
"req_id: {0} block_tables is empty".format(
|
||||
task.request_id))
|
||||
continue
|
||||
else:
|
||||
task.block_tables = block_tables
|
||||
task.need_block_tables = task.block_tables
|
||||
|
||||
if task.disaggregate_info is not None:
|
||||
task.disaggregate_info[
|
||||
'block_tables'] = block_tables
|
||||
if task.disaggregate_info['role'] == "prefill":
|
||||
self.req_dict[
|
||||
task.request_id] = allocated_position
|
||||
elif task.disaggregate_info['role'] == "decode":
|
||||
self.req_dict[
|
||||
task.request_id] = allocated_position
|
||||
|
||||
processed_tasks.append(task)
|
||||
self.stop_flags[allocated_position] = False
|
||||
@@ -230,9 +311,10 @@ class ResourceManager(object):
|
||||
task.inference_time_cost = -1.0
|
||||
task.tokens_all_num = int(0)
|
||||
self.tasks_list[allocated_position] = task
|
||||
llm_logger.info(f"Allocate request: {task.request_id}, "
|
||||
f"allocated_position:{allocated_position}, "
|
||||
f"length of prompt token: {task.prompt_token_ids_len}")
|
||||
llm_logger.info(
|
||||
f"Allocate request: {task.request_id}, "
|
||||
f"allocated_position:{allocated_position}, "
|
||||
f"length of prompt token: {task.prompt_token_ids_len}")
|
||||
allocated_position += 1
|
||||
processing_task_index += 1
|
||||
|
||||
@@ -242,20 +324,70 @@ class ResourceManager(object):
|
||||
self.real_bsz = i + 1
|
||||
break
|
||||
|
||||
llm_logger.info(f"Number of allocated requests: {len(tasks)}, number of "
|
||||
f"running requests in worker: {self.real_bsz}")
|
||||
llm_logger.info(
|
||||
f"Number of allocated requests: {len(tasks)}, number of "
|
||||
f"running requests in worker: {self.real_bsz}")
|
||||
llm_logger.info(f"{self.info()}")
|
||||
main_process_metrics.gpu_cache_usage_perc.set(
|
||||
self.get_gpu_cache_usage_perc())
|
||||
|
||||
return processed_tasks
|
||||
|
||||
def _delete_cached_data(self, task, cached_len):
|
||||
"""
|
||||
Delete cached data from the task's prompt token ids based on the cached length.
|
||||
"""
|
||||
if cached_len == len(task.prompt_token_ids):
|
||||
task.prompt_token_ids = task.prompt_token_ids[cached_len - 1:]
|
||||
task.seq_lens_decoder = cached_len - 1
|
||||
else:
|
||||
task.prompt_token_ids = task.prompt_token_ids[cached_len:]
|
||||
task.seq_lens_decoder = cached_len
|
||||
task.prompt_token_ids_len = len(task.prompt_token_ids)
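# Illustrative sketch (not part of the original diff): the slicing rule above.
# A fully cached 128-token prompt keeps its last token so the decoder still has
# one input token; a partially cached prompt drops exactly cached_len tokens.
example_prompt = list(range(128))
example_cached_len = 128
example_remaining = (example_prompt[example_cached_len - 1:]
                     if example_cached_len == len(example_prompt)
                     else example_prompt[example_cached_len:])
# len(example_remaining) == 1, seq_lens_decoder would be 127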
|
||||
|
||||
def _record_request_cache_info(self, task, common_block_ids,
|
||||
unique_block_ids, hit_info):
|
||||
"""
|
||||
Record the cache information for a given task and its corresponding block IDs.
|
||||
"""
|
||||
cache_block_num = len(common_block_ids)
|
||||
no_cache_block_num = math.ceil(len(task.prompt_token_ids) / self.cfg.block_size \
|
||||
- cache_block_num)
|
||||
task.num_cached_tokens = cache_block_num * self.cfg.block_size
|
||||
task.gpu_cache_token_num = hit_info[
|
||||
"gpu_cache_blocks"] * self.cfg.block_size
|
||||
task.cpu_cache_token_num = hit_info[
|
||||
"cpu_cache_blocks"] * self.cfg.block_size
|
||||
task.cache_info = (cache_block_num, no_cache_block_num)
|
||||
|
||||
cached_len = len(common_block_ids) * self.cfg.block_size
|
||||
task.block_tables = common_block_ids + unique_block_ids
|
||||
task.need_block_tables = unique_block_ids
|
||||
llm_logger.debug(f"common: {common_block_ids} ")
|
||||
llm_logger.debug(f"unique: {unique_block_ids} ")
|
||||
return cached_len
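# Illustrative sketch (not part of the original diff): the cache bookkeeping
# above for an assumed prompt of 300 tokens, block_size=64 and 2 matched
# (common) blocks.
import math
example_block_size, example_prompt_len, example_common_blocks = 64, 300, 2
example_no_cache_blocks = math.ceil(
    example_prompt_len / example_block_size - example_common_blocks)    # ceil(2.6875) = 3
example_num_cached_tokens = example_common_blocks * example_block_size  # 128
example_cached_len = example_common_blocks * example_block_size         # 128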
|
||||
|
||||
def info(self):
|
||||
"""Generates a summary of current resource status.
|
||||
|
||||
"""
|
||||
get resource manager info
|
||||
|
||||
Returns:
|
||||
str: Formatted string showing:
|
||||
- Total blocks/batch slots
|
||||
- Available blocks/batch slots
|
||||
str: resource manager info
|
||||
"""
|
||||
info = f"ResourceManager info, " \
|
||||
f"total_block_number: {self.total_block_number()}, total_batch_number: {len(self.stop_flags)}, " \
|
||||
f"available_block_num: {self.available_block_num()}, available_batch: {self.available_batch()}"
|
||||
return info
|
||||
|
||||
def get_gpu_cache_usage_perc(self):
|
||||
"""
|
||||
Calculate GPU KV-cache usage
|
||||
|
||||
Returns:
|
||||
float: GPU KV-cache usage (0.0 - 1.0)
|
||||
"""
|
||||
num_total_gpu = self.total_block_number()
|
||||
num_free_gpu = len(self.cache_manager.gpu_free_block_list)
|
||||
if num_total_gpu > 0:
|
||||
return 1.0 - (num_free_gpu / num_total_gpu)
|
||||
return 0.0
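# Illustrative sketch (not part of the original diff): usage for an assumed
# pool of 1024 GPU blocks with 256 of them currently free.
example_total, example_free = 1024, 256
example_usage = 1.0 - (example_free / example_total)  # 0.75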
|
||||
|
||||
@@ -15,9 +15,10 @@
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, Optional, Union, List
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -62,6 +63,7 @@ class SamplingParams:
|
||||
token sequence is not allowed when the next generated token
|
||||
can complete the sequence.
|
||||
max_tokens: Maximum number of tokens to generate per output sequence.
|
||||
reasoning_max_tokens: Maximum number of tokens to generate for reasoning per output sequence.
|
||||
min_tokens: Minimum number of tokens to generate per output sequence
|
||||
before EOS or stop_token_ids can be generated
|
||||
logprobs: Number of log probabilities to return per output token.
|
||||
@@ -75,131 +77,107 @@ class SamplingParams:
|
||||
|
||||
n: int = 1
|
||||
best_of: Optional[int] = None
|
||||
presence_penalty: float = 0.0
|
||||
frequency_penalty: float = 0.0
|
||||
repetition_penalty: float = 1.0
|
||||
temperature: float = 1.0
|
||||
top_p: float = 0.7
|
||||
presence_penalty: float = None
|
||||
frequency_penalty: float = None
|
||||
repetition_penalty: float = None
|
||||
temperature: float = None
|
||||
top_p: float = None
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||
max_tokens: Optional[int] = 16
|
||||
max_tokens: Optional[int] = None
|
||||
reasoning_max_tokens: Optional[int] = None
|
||||
min_tokens: int = 1
|
||||
logprobs: Optional[int] = None
|
||||
bad_words: Optional[List[str]] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, req_dict: dict[str, Any]) -> "SamplingParams":
|
||||
"""Create a SamplingParams instance from a dictionary.
|
||||
|
||||
Args:
|
||||
req_dict: Dictionary containing sampling parameters where keys match
|
||||
the field names of SamplingParams
|
||||
|
||||
Returns:
|
||||
SamplingParams: A new instance initialized with values from the dictionary
|
||||
"""
|
||||
return cls(**{
|
||||
field.name: req_dict[field.name] if field.name in req_dict else field.default
|
||||
for field in fields(cls)
|
||||
})
|
||||
|
||||
"""Create instance from command line arguments"""
|
||||
return cls(
|
||||
**{
|
||||
field.name:
|
||||
req_dict[field.name] if field.name in
|
||||
req_dict else field.default
|
||||
for field in fields(cls)
|
||||
})
|
||||
|
||||
    @classmethod
    def from_optional(cls,
                      n,
                      best_of,
                      presence_penalty,
                      frequency_penalty,
                      repetition_penalty,
                      temperature,
                      top_p,
                      seed=None,
                      stop=None,
                      stop_token_ids=None,
                      max_tokens=None,
                      reasoning_max_tokens=None,
                      min_tokens=1,
                      logprobs=None,
                      bad_words=None) -> "SamplingParams":
        """Create a SamplingParams instance from optional arguments with default fallbacks.

        Args:
            n: Number of output sequences (default: 1)
            best_of: Number of sequences to generate before selecting the best (default: None)
            presence_penalty: Penalty for new tokens (default: 0.0)
            frequency_penalty: Penalty based on token frequency (default: 0.0)
            repetition_penalty: Penalty for repeated tokens (default: 1.0)
            temperature: Sampling temperature (default: 1.0)
            top_p: Nucleus sampling probability (default: 0.7)
            seed: Random seed (default: random)
            stop: Stop sequences (default: None)
            stop_token_ids: Stop token IDs (default: None)
            max_tokens: Maximum tokens to generate (default: 8192)
            reasoning_max_tokens: Maximum tokens reserved for reasoning content
                (default: None; derived from max_tokens in __post_init__)
            min_tokens: Minimum tokens before stopping (default: 1)
            logprobs: Number of logprobs to return (default: None)
            bad_words: List of banned words (default: None)

        Returns:
            SamplingParams: A new instance with provided or default values
        """
        return cls(
            n=1 if n is None else n,
            best_of=best_of,
            presence_penalty=presence_penalty if presence_penalty is not None else 0.0,
            frequency_penalty=frequency_penalty if frequency_penalty is not None else 0.0,
            repetition_penalty=repetition_penalty if repetition_penalty is not None else 1.0,
            temperature=temperature if temperature is not None else 1.0,
            top_p=top_p if top_p is not None else 0.7,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            max_tokens=max_tokens if max_tokens is not None else 8192,
            reasoning_max_tokens=reasoning_max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            bad_words=bad_words)
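
    # Usage sketch (illustrative): passing None falls back to the documented
    # defaults rather than propagating None into the instance.
    #
    #     params = SamplingParams.from_optional(
    #         n=None, best_of=None, presence_penalty=None, frequency_penalty=None,
    #         repetition_penalty=None, temperature=None, top_p=None)
    #     params.temperature  # -> 1.0
    #     params.max_tokens   # -> 8192
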
    def __post_init__(self):
        """Initialize sampling parameters after instance creation.

        Sets a random seed if none provided and validates all parameters.
        """
        if self.seed is None:
            self.seed = random.randint(0, 922337203685477580)
        if self.max_tokens is not None and self.reasoning_max_tokens is None:
            self.reasoning_max_tokens = max(int(self.max_tokens * 0.8), 1)
        self._verify_args()
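
    # Worked example (illustrative, assuming the remaining fields keep their
    # dataclass defaults): with max_tokens=1000 and no explicit
    # reasoning_max_tokens, __post_init__ reserves 80% of the token budget:
    #
    #     params = SamplingParams(max_tokens=1000)
    #     params.reasoning_max_tokens  # -> max(int(1000 * 0.8), 1) == 800
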
    def _verify_args(self) -> None:
        """Validate all sampling parameters.

        Raises:
            ValueError: If any parameter is outside its valid range or of incorrect type
        """
        if not isinstance(self.n, int):
            raise ValueError(
                f"n must be an int, but is of type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if self.presence_penalty is not None and (
                not -2.0 <= self.presence_penalty <= 2.0):
            raise ValueError("presence_penalty must be in [-2, 2], got "
                             f"{self.presence_penalty}.")
        if self.frequency_penalty is not None and (
                not -2.0 <= self.frequency_penalty <= 2.0):
            raise ValueError("frequency_penalty must be in [-2, 2], got "
                             f"{self.frequency_penalty}.")
        if self.repetition_penalty is not None and self.repetition_penalty <= 0.0:
            raise ValueError(
                "repetition_penalty must be greater than zero, got "
                f"{self.repetition_penalty}.")
        if self.temperature is not None and self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}.")
        if self.top_p is not None and not 0.0 <= self.top_p <= 1.0:
            raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.")

        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {self.max_tokens}.")

        if self.reasoning_max_tokens is not None and self.reasoning_max_tokens > self.max_tokens:
            raise ValueError(
                "reasoning_max_tokens must be less than max_tokens, got "
                f"{self.reasoning_max_tokens}.")

        if self.min_tokens < 0:
            raise ValueError(f"min_tokens must be greater than or equal to 0, "
                             f"got {self.min_tokens}.")

@@ -215,33 +193,17 @@ class SamplingParams:
            raise ValueError("seed must be in [0, 922337203685477580], got "
                             f"{self.seed}.")
    def update_from_tokenizer(self, tokenizer):
        """Update sampling parameters based on tokenizer configuration.

        Note: Currently a placeholder for future implementation of:
        - Stop tokens handling
        - Bad words filtering

        Args:
            tokenizer: The tokenizer instance to use for configuration
        """
        # TODO: Implement stop tokens and bad words support
        # Currently stop tokens and bad words are not supported yet
        pass
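
    # Hypothetical sketch only (an assumption, not the project's implementation):
    # once stop-token support lands, update_from_tokenizer could map stop strings
    # to token IDs, assuming a HuggingFace-style tokenizer with an encode() method.
    #
    #     def update_from_tokenizer(self, tokenizer):
    #         if self.stop and not self.stop_token_ids:
    #             self.stop_token_ids = [
    #                 tid for s in self.stop
    #                 for tid in tokenizer.encode(s, add_special_tokens=False)
    #             ]
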
@dataclass
class BeamSearchParams:
    """Parameters for beam search text generation.

    Args:
        beam_width: Number of beams to maintain during search
        max_tokens: Maximum number of tokens to generate
        ignore_eos: Whether to ignore EOS tokens (default: False)
        temperature: Sampling temperature (0 means greedy, default: 0.0)
        length_penalty: Penalty applied to length (1.0 means no penalty, default: 1.0)
        include_stop_str_in_output: Whether to include stop strings in output (default: False)
    """
    beam_width: int
    max_tokens: int
    ignore_eos: bool = False
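
    # Usage sketch (illustrative; the remaining fields described in the docstring
    # are cut off by this hunk and are assumed to keep their documented defaults):
    #
    #     beam_params = BeamSearchParams(beam_width=4, max_tokens=128)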