Sync v2.0 version of code to github repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional
from fastdeploy.engine.config import (CacheConfig, Config, ModelConfig,
ParallelConfig, SpeculativeConfig,
TaskOption)
from fastdeploy.scheduler.config import SchedulerConfig
from fastdeploy.utils import FlexibleArgumentParser
@@ -34,7 +35,7 @@ def nullable_str(x: str) -> Optional[str]:
@dataclass
class EngineArgs:
# Model configuration parameters
model: str = ""
model: str = "baidu/ernie-45-turbo"
"""
The name or path of the model to be used.
"""
@@ -70,6 +71,14 @@ class EngineArgs:
"""
Additional keyword arguments for the multi-modal processor.
"""
limit_mm_per_prompt: Optional[Dict[str, Any]] = None
"""
Limitation of numbers of multi-modal data.
"""
reasoning_parser: str = None
"""
Specifies the reasoning parser to use for extracting reasoning content from the model output.
"""
enable_mm: bool = False
"""
Flag to enable the multi-modal model.
@@ -82,6 +91,15 @@ class EngineArgs:
"""
Flag to load model weights dynamically.
"""
quantization: str = None
"""
Quantization method for the model weights, e.g. 'wint8' or 'wint4'. Default is None.
"""
guided_decoding_backend: str = "off"
"""
Guided decoding backend.
"""
guided_decoding_disable_any_whitespace: bool = False
"""
Disable any whitespace in guided decoding.
"""
# Inference configuration parameters
gpu_memory_utilization: float = 0.9
@@ -109,6 +127,16 @@ class EngineArgs:
List of IP addresses for nodes in the cluster.
"""
swap_space: float = None
"""
The amount of CPU memory (in GiB) to offload the KV cache to.
"""
cache_queue_port: int = 8003
"""
Port for cache queue.
"""
# System configuration parameters
use_warmup: int = 0
"""
@@ -119,51 +147,150 @@ class EngineArgs:
Flag to enable prefix caching.
"""
engine_worker_queue_port: int = 8002
"""
Port for worker queue communication.
"""
splitwise_role: str = "mixed"
"""
Splitwise role: prefill, decode or mixed
"""
data_parallel_size: int = 1
"""
Degree of data parallelism.
"""
enable_expert_parallel: bool = False
"""
Enable expert parallelism.
"""
cache_transfer_protocol: str = "ipc"
"""
Protocol to use for cache transfer.
"""
pd_comm_port: Optional[List[int]] = None
"""
Port for splitwise communication.
"""
innode_prefill_ports: Optional[List[int]] = None
"""
Ports for innode dispatch requests.
"""
rdma_comm_ports: Optional[List[int]] = None
"""
Ports for rdma communication.
"""
enable_chunked_prefill: bool = False
"""
Flag to enable chunked prefilling.
"""
max_num_partial_prefills: int = 1
"""
For chunked prefill, the max number of concurrent partial prefills.
"""
max_long_partial_prefills: int = 1
"""
For chunked prefill, the maximum number of prompts longer than long-prefill-token-threshold
that will be prefilled concurrently.
"""
long_prefill_token_threshold: int = 0
"""
For chunked prefill, a request is considered long if the prompt is longer than this number of tokens.
"""
static_decode_blocks: int = 2
"""
Number of additional decode blocks.
"""
scheduler_name: str = "local"
"""
Scheduler name to be used
"""
scheduler_name: str = "local"
scheduler_max_size: int = -1
"""
Size of scheduler
"""
scheduler_max_size: int = -1
scheduler_ttl: int = 900
"""
TTL of request
"""
scheduler_ttl: int = 900
"""
Timeout for waiting for response
"""
scheduler_wait_response_timeout: float = 0.001
scheduler_host: str = "127.0.0.1"
"""
Host of redis
"""
scheduler_host: str = "127.0.0.1"
scheduler_port: int = 6379
"""
Port of redis
"""
scheduler_port: int = 6379
scheduler_db: int = 0
"""
DB of redis
"""
scheduler_db: int = 0
scheduler_password: Optional[str] = None
"""
Password of redis
"""
scheduler_password: Optional[str] = None
scheduler_topic: str = "default"
"""
Topic of scheduler
"""
scheduler_topic: str = "default"
scheduler_min_load_score: float = 3
"""
Max write time of redis
Minimum load score for task assignment
"""
scheduler_load_shards_num: int = 1
"""
Number of shards for load balancing table
"""
scheduler_sync_period: int = 5
"""
SplitWise Use, node load sync period
"""
scheduler_expire_period: int = 3000
"""
SplitWise use: a node will not be scheduled if its load has not been synced for expire_period ms
"""
scheduler_release_load_expire_period: int = 600
"""
SplitWise use: the scheduler will release a request's load after the expire period (in seconds)
"""
scheduler_reader_parallel: int = 4
"""
SplitWise Use, Results Reader Sync Parallel
"""
scheduler_writer_parallel: int = 4
"""
SplitWise Use, Results Writer Sync Parallel
"""
scheduler_reader_batch_size: int = 200
"""
SplitWise Use, Results Reader Batch Size
"""
scheduler_writer_batch_size: int = 200
"""
SplitWise Use, Results Writer Batch Size
"""
enable_static_graph_inference: bool = False
"""
Whether to use static mode
"""
use_cudagraph: bool = False
"""
Flags to enable Cuda Graph
"""
max_capture_batch_size: int = 64
"""
Maximum Batch Size for Cuda Graph Capture
NOTE: Currently only continuous batch sizes can be captured.
Example:
With max_capture_batch_size=64, FastDeploy will capture graphs for batch sizes [1, 64].
"""
scheduler_remote_write_time: int = 3
def __post_init__(self):
"""
@@ -214,20 +341,32 @@ class EngineArgs:
default=EngineArgs.use_warmup,
help="Flag to indicate whether to use warm-up before inference.")
model_group.add_argument(
"--mm_processor_kwargs",
default=None,
"--limit-mm-per-prompt",
default=EngineArgs.limit_mm_per_prompt,
type=json.loads,
help="Limitation of numbers of multi-modal data.")
model_group.add_argument(
"--mm-processor-kwargs",
default=EngineArgs.mm_processor_kwargs,
type=json.loads,
help="Additional keyword arguments for the multi-modal processor.")
model_group.add_argument("--enable-mm",
action='store_true',
default=EngineArgs.enable_mm,
help="Flag to enable multi-modal model.")
model_group.add_argument("--reasoning-parser",
type=str,
default=EngineArgs.reasoning_parser,
help="Flag specifies the reasoning parser to use for extracting "\
"reasoning content from the model output")
model_group.add_argument(
"--speculative_config",
default=None,
"--speculative-config",
type=json.loads,
default=EngineArgs.speculative_config,
help="Configuration for speculative execution.")
model_group.add_argument(
"--dynamic_load_weight",
"--dynamic-load-weight",
type=int,
default=EngineArgs.dynamic_load_weight,
help="Flag to indicate whether to load weight dynamically.")
@@ -236,6 +375,39 @@ class EngineArgs:
type=int,
default=EngineArgs.engine_worker_queue_port,
help="port for engine worker queue")
model_group.add_argument("--quantization",
type=str,
default=EngineArgs.quantization,
help="Quantization name for the model, currentlly support " \
"'wint8', 'wint4'," \
"default is None. The priority of this configuration "\
"is lower than that of the config file. " \
"More complex quantization methods need to be configured via the config file.")
model_group.add_argument(
"--enable-static-graph-inference",
action='store_true',
default=EngineArgs.enable_static_graph_inference,
help="Whether to use static mode; if enabled, " \
"'paddle.to_static' will be used to convert dynamic to static.")
model_group.add_argument("--use-cudagraph",
action='store_true',
default=EngineArgs.use_cudagraph,
help="Flags to enable cuda graph.")
model_group.add_argument("--max-capture-batch-size",
type=int,
default=EngineArgs.max_capture_batch_size,
help="Maximum of Batch Size for Warm Up.")
model_group.add_argument("--guided-decoding-backend",
type=str,
default=EngineArgs.guided_decoding_backend,
help="Guided Decoding Backend")
model_group.add_argument(
"--guided-decoding-disable-any-whitespace",
type=str,
default=EngineArgs.guided_decoding_disable_any_whitespace,
help=
"Disabled any whitespaces when using guided decoding backend XGrammar."
)
# Parallel processing parameters group
parallel_group = parser.add_argument_group("Parallel Configuration")
@@ -264,11 +436,38 @@ class EngineArgs:
type=float,
default=EngineArgs.gpu_memory_utilization,
help="Fraction of GPU memory to be utilized.")
parallel_group.add_argument(
"--kv-cache-ratio",
parallel_group.add_argument("--data-parallel-size",
type=int,
default=EngineArgs.data_parallel_size,
help="Degree of data parallelism.")
parallel_group.add_argument("--enable-expert-parallel",
action='store_true',
default=EngineArgs.enable_expert_parallel,
help="Enable expert parallelism.")
# CacheConfig parameters group
cache_group = parser.add_argument_group("Cache Configuration")
cache_group.add_argument("--kv-cache-ratio",
type=float,
default=EngineArgs.kv_cache_ratio,
help="Ratio of tokens to process in a block.")
cache_group.add_argument(
"--swap-space",
type=float,
default=EngineArgs.kv_cache_ratio,
help="Ratio of tokens to process in a block.")
default=EngineArgs.swap_space,
help="The amount of CPU memory to offload to.")
cache_group.add_argument("--cache-queue-port",
type=int,
default=EngineArgs.cache_queue_port,
help="port for cache queue")
cache_group.add_argument("--static-decode-blocks",
type=int,
default=EngineArgs.static_decode_blocks,
help="Static decoding blocks num.")
# Cluster system parameters group
system_group = parser.add_argument_group("System Configuration")
@@ -285,18 +484,60 @@ class EngineArgs:
# Performance tuning parameters group
perf_group = parser.add_argument_group("Performance Tuning")
perf_group.add_argument("--enable-prefix-caching",
action='store_true',
default=EngineArgs.enable_prefix_caching,
help="Flag to enable prefix caching.")
perf_group.add_argument("--splitwise-role",
type=str,
default=EngineArgs.splitwise_role,
help="Role of splitwise. Default is \
'mixed'. (prefill, decode, mixed)")
perf_group.add_argument("--innode-prefill-ports",
type=lambda s: s.split(",") if s else None,
default=EngineArgs.innode_prefill_ports,
help="port for innode prefill")
perf_group.add_argument("--enable-chunked-prefill",
action='store_true',
default=EngineArgs.enable_chunked_prefill,
help="Flag to enable chunked prefill.")
perf_group.add_argument("--max-num-partial-prefills",
type=int,
default=EngineArgs.max_num_partial_prefills,
help="For chunked prefill, Maximum number \
of concurrent partial prefill requests.")
perf_group.add_argument(
"--enable-prefix-caching",
action='store_true',
default=EngineArgs.enable_prefix_caching,
help="Flag to enable prefix caching."
)
"--max-long-partial-prefills",
type=int,
default=EngineArgs.max_long_partial_prefills,
help=
("For chunked prefill, the maximum number of prompts longer than long-prefill-token-threshold"
"that will be prefilled concurrently."))
perf_group.add_argument(
"--enable-chunked-prefill",
action='store_true',
default=EngineArgs.enable_chunked_prefill,
help="Flag to enable chunked prefill."
)
"--long-prefill-token-threshold",
type=int,
default=EngineArgs.long_prefill_token_threshold,
help=("For chunked prefill, the threshold number of"
" tokens for a prompt to be considered long."))
perf_group.add_argument(
"--cache-transfer-protocol",
type=str,
default=EngineArgs.cache_transfer_protocol,
help="support protocol list, comma separated, default is ipc")
perf_group.add_argument("--pd-comm-port",
type=lambda s: s.split(",") if s else None,
default=EngineArgs.pd_comm_port,
help="port for splitwise communication.")
perf_group.add_argument("--rdma-comm-ports",
type=lambda s: s.split(",") if s else None,
default=EngineArgs.rdma_comm_ports,
help="ports for rdma communication.")
# Scheduler parameters group
scheduler_group = parser.add_argument_group("Scheduler")
@@ -320,14 +561,6 @@ class EngineArgs:
help=
f"TTL of request. Default is {EngineArgs.scheduler_ttl} seconds. (local,global)"
)
scheduler_group.add_argument(
"--scheduler-wait-response-timeout",
type=float,
default=EngineArgs.scheduler_wait_response_timeout,
help=
("Timeout for waiting for response. Default is "
f"{EngineArgs.scheduler_wait_response_timeout} seconds. (local,global)"
))
scheduler_group.add_argument(
"--scheduler-host",
default=EngineArgs.scheduler_host,
@@ -359,12 +592,62 @@ class EngineArgs:
f"Topic of scheduler. Defaule is {EngineArgs.scheduler_topic}. (global)"
)
scheduler_group.add_argument(
"--scheduler-remote-write-time",
type=int,
default=EngineArgs.scheduler_remote_write_time,
"--scheduler-min-load-score",
type=float,
default=EngineArgs.scheduler_min_load_score,
help=
f"Max write time of redis. Default is {EngineArgs.scheduler_remote_write_time} seconds (global)"
f"Minimum load score for task assignment. Default is {EngineArgs.scheduler_min_load_score} (global)"
)
scheduler_group.add_argument(
"--scheduler-load-shards-num",
type=int,
default=EngineArgs.scheduler_load_shards_num,
help=("Number of shards for load balancing table. Default is "
f"{EngineArgs.scheduler_load_shards_num} (global)"))
scheduler_group.add_argument(
"--scheduler-sync-period",
type=int,
default=EngineArgs.scheduler_sync_period,
help=f"SplitWise Use, node load sync period, "
f"Default is {EngineArgs.scheduler_sync_period}ms. (global)")
scheduler_group.add_argument(
"--scheduler-expire-period",
type=int,
default=EngineArgs.scheduler_expire_period,
help=f"SplitWise Use, node will not be scheduled after "
f"expire-period ms not sync load, Default is "
f"{EngineArgs.scheduler_expire_period}ms. (global)")
scheduler_group.add_argument(
"--scheduler-release-load-expire-period",
type=int,
default=EngineArgs.scheduler_release_load_expire_period,
help=f"SplitWise Use, scheduler will release req load after "
f"expire period(s). Default is "
f"{EngineArgs.scheduler_release_load_expire_period}. (global)")
scheduler_group.add_argument(
"--scheduler-reader-parallel",
type=int,
default=EngineArgs.scheduler_reader_parallel,
help=f"SplitWise Use, Results Reader Sync Parallel, "
f"Default is {EngineArgs.scheduler_reader_parallel}. (global)")
scheduler_group.add_argument(
"--scheduler-writer-parallel",
type=int,
default=EngineArgs.scheduler_writer_parallel,
help=f"SplitWise Use, Results Writer Sync Parallel, "
f"Default is {EngineArgs.scheduler_writer_parallel}. (global)")
scheduler_group.add_argument(
"--scheduler-reader-batch-size",
type=int,
default=EngineArgs.scheduler_reader_batch_size,
help=f"SplitWise Use, Results Reader Batch Size, "
f"Default is {EngineArgs.scheduler_reader_batch_size}. (global)")
scheduler_group.add_argument(
"--scheduler-writer-batch-size",
type=int,
default=EngineArgs.scheduler_writer_batch_size,
help=f"SplitWise Use, Results Writer Batch Size, "
f"Default is {EngineArgs.scheduler_writer_batch_size}. (global)")
return parser
@@ -385,18 +668,37 @@ class EngineArgs:
"""
return ModelConfig(model_name_or_path=self.model,
config_json_file=self.model_config_name,
dynamic_load_weight=self.dynamic_load_weight)
dynamic_load_weight=self.dynamic_load_weight,
quantization=self.quantization)
def create_cache_config(self) -> CacheConfig:
def create_cache_config(self, model_cfg) -> CacheConfig:
"""
Create and return a CacheConfig object based on the current settings.
"""
return CacheConfig(
block_size=self.block_size,
tensor_parallel_size=self.tensor_parallel_size,
gpu_memory_utilization=self.gpu_memory_utilization,
num_gpu_blocks_override=self.num_gpu_blocks_override,
kv_cache_ratio=self.kv_cache_ratio,
enable_prefix_caching=self.enable_prefix_caching)
enable_prefix_caching=self.enable_prefix_caching,
swap_space=self.swap_space,
cache_queue_port=self.cache_queue_port,
model_cfg=model_cfg,
enable_chunked_prefill=self.enable_chunked_prefill,
enc_dec_block_num=self.static_decode_blocks,
rdma_comm_ports=self.rdma_comm_ports,
cache_transfer_protocol=self.cache_transfer_protocol,
pd_comm_port=self.pd_comm_port,
)
def create_speculative_config(self) -> SpeculativeConfig:
"""
"""
if self.speculative_config is not None:
return SpeculativeConfig(**self.speculative_config)
else:
return SpeculativeConfig()
def create_scheduler_config(self) -> SchedulerConfig:
"""
@@ -404,14 +706,32 @@ class EngineArgs:
"""
prefix = "scheduler_"
prefix_len = len(prefix)
extra_params = [
"max_model_len", "enable_chunked_prefill",
"max_num_partial_prefills", "max_long_partial_prefills",
"long_prefill_token_threshold"
]
all = asdict(self)
params = dict()
for k, v in all.items():
if k[:prefix_len] == prefix:
params[k[prefix_len:]] = v
elif k in extra_params:
params[k] = v
return SchedulerConfig(**params)
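A minimal, self-contained sketch of the prefix-stripping above (sample values only):
prefix = "scheduler_"
sample = {"scheduler_name": "local", "scheduler_ttl": 900, "max_model_len": 8192}
extra_params = ["max_model_len"]
params = {(k[len(prefix):] if k.startswith(prefix) else k): v
          for k, v in sample.items() if k.startswith(prefix) or k in extra_params}
assert params == {"name": "local", "ttl": 900, "max_model_len": 8192}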
def create_parallel_config(self) -> ParallelConfig:
"""
Create and return a ParallelConfig object based on the current settings.
"""
return ParallelConfig(
tensor_parallel_size=self.tensor_parallel_size,
enable_expert_parallel=self.enable_expert_parallel,
data_parallel_size=self.data_parallel_size,
)
def create_engine_config(self) -> Config:
"""
Create and return a Config object based on the current settings.
@@ -426,22 +746,37 @@ class EngineArgs:
else:
self.max_num_batched_tokens = self.max_model_len
scheduler_cfg = self.create_scheduler_config()
speculative_cfg = self.create_speculative_config()
return Config(
model_name_or_path=self.model,
model_config=model_cfg,
scheduler_config=scheduler_cfg,
tokenizer=self.tokenizer,
cache_config=self.create_cache_config(),
cache_config=self.create_cache_config(model_cfg),
parallel_config=self.create_parallel_config(),
max_model_len=self.max_model_len,
tensor_parallel_size=self.tensor_parallel_size,
max_num_seqs=self.max_num_seqs,
mm_processor_kwargs=self.mm_processor_kwargs,
speculative_config=self.speculative_config,
speculative_config=speculative_cfg,
max_num_batched_tokens=self.max_num_batched_tokens,
nnode=self.nnode,
pod_ips=self.pod_ips,
use_warmup=self.use_warmup,
engine_worker_queue_port=self.engine_worker_queue_port,
limit_mm_per_prompt=self.limit_mm_per_prompt,
mm_processor_kwargs=self.mm_processor_kwargs,
enable_mm=self.enable_mm,
enable_chunked_prefill=self.enable_chunked_prefill,
reasoning_parser=self.reasoning_parser,
splitwise_role=self.splitwise_role,
innode_prefill_ports=self.innode_prefill_ports,
max_num_partial_prefills=self.max_num_partial_prefills,
max_long_partial_prefills=self.max_long_partial_prefills,
long_prefill_token_threshold=self.long_prefill_token_threshold,
enable_static_graph_inference=self.enable_static_graph_inference,
use_cudagraph=self.use_cudagraph,
max_capture_batch_size=self.max_capture_batch_size,
guided_decoding_backend=self.guided_decoding_backend,
disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
)
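A hedged end-to-end usage sketch; the import path is assumed from this diff's context and the argument values are placeholders:
from fastdeploy.engine.args_utils import EngineArgs  # assumed module path

engine_args = EngineArgs(model="baidu/ernie-45-turbo",
                         max_model_len=8192,
                         enable_chunked_prefill=True)
config = engine_args.create_engine_config()  # wires up Model/Cache/Parallel/Scheduler configs
config.print()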

View File

@@ -1,21 +1,28 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from fastdeploy import envs
from fastdeploy.platforms import current_platform
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.utils import (check_unified_ckpt, get_host_ip,
from fastdeploy.utils import (ceil_div, check_unified_ckpt, get_host_ip,
is_port_available, llm_logger)
TaskOption = Literal["generate"]
@@ -23,37 +30,37 @@ TaskOption = Literal["generate"]
class ModelConfig:
"""
Configuration class for model settings and parameters.
Configuration class for the model.
Attributes:
model_dir (str): Path to the model directory
is_unified_ckpt (bool): Whether the checkpoint uses unified format
model_name_or_path (str): Model identifier or path
dynamic_load_weight (int): Dynamic weight loading flag
"""
Attributes:
model_dir (str): Directory path to the model.
is_unified_ckpt (bool): Flag indicating if the checkpoint is unified.
model_name_or_path (str): Name or path of the model.
"""
def __init__(self,
model_name_or_path: str,
config_json_file: str = "config.json",
dynamic_load_weight: int = 0,
quantization: str = None,
download_dir: Optional[str] = None):
"""
Initialize model configuration.
Initialize the ModelConfig class.
Args:
model_name_or_path (str): Model identifier or path
config_json_file (str): Model config file name (default: 'config.json')
dynamic_load_weight (int): Dynamic weight loading mode (default: 0)
download_dir (Optional[str]): Directory for downloaded models (default: None)
model_name_or_path (str): Name or path of the model.
config_json_file (str): Path to the configuration JSON file. Default is 'config.json'.
download_dir (Optional[str]): Directory to download model files. Default is None.
"""
self.model_dir = model_name_or_path
self.is_unified_ckpt = check_unified_ckpt(self.model_dir)
self.dynamic_load_weight = dynamic_load_weight
self.quantization = quantization
config_file = os.path.join(model_name_or_path, config_json_file)
if os.path.isfile(model_name_or_path):
try:
from paddlenlp.transformers import AutoConfig
from paddleformers.transformers import AutoConfig
config = AutoConfig.from_pretrained(model_name_or_path)
config_dict = {
k: v
@@ -63,10 +70,10 @@ Attributes:
setattr(self, key, value)
except Exception:
llm_logger.error(
"Don't support the current model, you can use `paddlenlp` to register your model."
"Don't support the current model, you can use `paddleformers` to register your model."
)
raise ValueError(
"Don't support the current model, you can use `paddlenlp` to register your model."
"Don't support the current model, you can use `paddleformers` to register your model."
)
else:
with open(config_file, "r", encoding="utf-8") as f:
@@ -85,12 +92,9 @@ Attributes:
def override_name_from_config(self):
"""
Update attribute names from model configuration.
Handles special cases like:
- Renaming infer_model_mp_num to tensor_parallel_size
- Adjusting num_hidden_layers based on remove_tail_layer
- Setting default mla_use_absorb value
Override attribute names from the exported model's configuration.
"""
if not self.is_unified_ckpt and hasattr(self, "infer_model_mp_num"):
self.tensor_parallel_size = self.infer_model_mp_num
del self.infer_model_mp_num
@@ -107,27 +111,19 @@ Attributes:
if not hasattr(self, "mla_use_absorb"):
self.mla_use_absorb = False
if not hasattr(self, "head_dim"):
assert hasattr(self, "hidden_size") and hasattr(
self, "num_attention_heads")
self.head_dim = self.hidden_size // self.num_attention_heads
def read_from_env(self):
"""
Load configuration from environment variables.
Sets default values if env vars not found.
Reads:
- MAX_STOP_SEQS_NUM (default: 5)
- STOP_SEQS_MAX_LEN (default: 8)
- ELLM_DYNAMIC_QUANT_TYPE (default: 'default')
- ELLM_DYNAMIC_USE_STOP_SEQS (default: 0)
- COMPRESSION_RATIO (default: 1.0)
- ROPE_THETA (default: 10000)
"""
self.max_stop_seqs_num = int(os.getenv("MAX_STOP_SEQS_NUM", "5"))
self.stop_seqs_max_len = int(os.getenv("STOP_SEQS_MAX_LEN", "8"))
Read configuration information from environment variables and update the object's attributes.
self.ellm_dynamic_quant_type = os.getenv("ELLM_DYNAMIC_QUANT_TYPE",
"default")
# Whether to use stop sequences in dynamic graph inference
self.ellm_dynamic_use_stop_seqs = int(
os.getenv("ELLM_DYNAMIC_USE_STOP_SEQS", "0")) == 1
If an attribute is not present or is an empty string in the environment variables, use the default value.
"""
self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
self.stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
def reset_config_value(key, value):
if not hasattr(self, key.lower()):
@@ -144,13 +140,12 @@ Attributes:
reset_config_value("ROPE_THETA", 10000)
def _get_download_model(self, model_name, model_type="default"):
# TODO: Implement dynamic graph for self-downloading models
# TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
pass
def print(self):
"""
Print current model configuration.
Logs all attributes and their values.
Print all configuration information.
"""
llm_logger.info("Model Configuration Information :")
for k, v in self.__dict__.items():
@@ -161,19 +156,18 @@ Attributes:
class CacheConfig:
"""
Configuration for key-value cache management.
Configuration for the KV cache.
Attributes:
block_size (int): Tokens per cache block
gpu_memory_utilization (float): GPU memory usage fraction (0-1)
cache_dtype (str): Data type for cache (default: 'bfloat16')
num_gpu_blocks_override (Optional[int]): Manual GPU blocks override
kv_cache_ratio (float): Max blocks ratio (default: 0.75)
enc_dec_block_num (int): Encoder-decoder blocks count
enable_prefix_caching (bool): Prefix caching enable flag
total_block_num (int): Total available blocks
prefill_kvcache_block_num (int): Blocks allocated for prefill
"""
Attributes:
block_size (int): Size of a cache block in number of tokens.
gpu_memory_utilization (float): Fraction of GPU memory to use for model execution.
cache_dtype (str): Data type for kv cache storage. Default is 'bfloat16'.
num_gpu_blocks_override (Optional[int]): Number of GPU blocks to use.
Overrides profiled num_gpu_blocks if provided.
kv_cache_ratio (float): Ratio for calculating the maximum block number.
enc_dec_block_num (int): Number of encoder-decoder blocks.
enable_prefix_caching (bool): Flag to enable prefix caching.
"""
def __init__(
self,
@@ -181,21 +175,31 @@ Attributes:
gpu_memory_utilization: float,
cache_dtype: str = "bfloat16",
num_gpu_blocks_override: Optional[int] = None,
swap_space: Optional[int] = None,
kv_cache_ratio: float = 0.75,
enc_dec_block_num: int = 2,
enable_prefix_caching: bool = False,
tensor_parallel_size: int = 1,
enable_prefix_caching=False,
enable_ssd_cache=False,
model_cfg=None,
cache_queue_port=None,
enable_chunked_prefill=False,
rdma_comm_ports=None,
cache_transfer_protocol=None,
pd_comm_port=None,
):
"""
Initialize cache configuration.
Initialize the CacheConfig class.
Args:
block_size (int): Tokens per cache block
gpu_memory_utilization (float): GPU memory usage target (0-1)
cache_dtype (str): Cache data type (default: 'bfloat16')
num_gpu_blocks_override (Optional[int]): Manual GPU blocks setting
kv_cache_ratio (float): Max blocks ratio (default: 0.75)
enc_dec_block_num (int): Encoder-decoder blocks count
enable_prefix_caching (bool): Enable prefix sharing
block_size (int): Size of a cache block in number of tokens.
gpu_memory_utilization (float): Fraction of GPU memory to use.
cache_dtype (str): Data type for cache storage. Default is 'bfloat16'.
num_gpu_blocks_override (Optional[int]): Override for number of GPU blocks.
swap_space (Optional[int]): Amount of CPU memory (GiB) to offload the KV cache to.
kv_cache_ratio (float): Ratio for max block calculation.
enc_dec_block_num (int): Number of encoder-decoder blocks.
enable_prefix_caching (bool): Enable prefix caching.
"""
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
@@ -203,20 +207,74 @@ Attributes:
self.kv_cache_ratio = kv_cache_ratio
self.enc_dec_block_num = enc_dec_block_num
self.cache_dtype = cache_dtype
if hasattr(model_cfg, "quantization_config"):
self.cache_dtype = model_cfg.quantization_config.get(
"kv_cache_quant_type", cache_dtype)
self.enable_chunked_prefill = enable_chunked_prefill
self.rdma_comm_ports = rdma_comm_ports
self.cache_transfer_protocol = cache_transfer_protocol
self.pd_comm_port = pd_comm_port
if rdma_comm_ports is not None and isinstance(rdma_comm_ports, str):
self.rdma_comm_ports = rdma_comm_ports.split(',')
if pd_comm_port is not None and isinstance(pd_comm_port, str):
self.pd_comm_port = [int(port) for port in pd_comm_port.split(",")]
self.enable_prefix_caching = enable_prefix_caching
if swap_space is None:
self.enable_hierarchical_cache = False
else:
self.enable_hierarchical_cache = True
self.enable_ssd_cache = enable_ssd_cache
self.model_cfg = model_cfg
self.cache_queue_port = cache_queue_port
self.swap_space = swap_space
if (hasattr(self.model_cfg, "num_key_value_heads")
and self.model_cfg.num_key_value_heads is not None
and int(self.model_cfg.num_key_value_heads) > 0):
kv_num_head = int(self.model_cfg.num_key_value_heads)
else:
kv_num_head = self.model_cfg.num_attention_heads
self.model_cfg.kv_num_head = kv_num_head
# TODO check name
if "int4" in self.cache_dtype.lower(
) or "float4" in self.cache_dtype.lower():
byte_size = 0.5
self.cache_dtype = "uint8"
elif "int8" in self.cache_dtype.lower(
) or "float8" in self.cache_dtype.lower():
self.cache_dtype = "uint8"
byte_size = 1
else:
byte_size = 2
self.each_token_cache_space = int(
self.model_cfg.num_layers * kv_num_head * self.model_cfg.head_dim *
byte_size)
self.bytes_per_block = int(self.each_token_cache_space *
self.block_size)
self.bytes_per_layer_per_block = int(
self.block_size * self.model_cfg.kv_num_head *
self.model_cfg.head_dim // tensor_parallel_size * byte_size)
if self.swap_space is None:
self.num_cpu_blocks = 0
else:
self.num_cpu_blocks = int(self.swap_space * 1024**3 /
self.bytes_per_block)
self._verify_args()
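A worked example of the sizing arithmetic above, using illustrative model dimensions (not taken from any real config):
# 32 layers, 8 KV heads, head_dim 128, bfloat16 cache (byte_size = 2), block_size 64
num_layers, kv_num_head, head_dim, byte_size, block_size = 32, 8, 128, 2, 64
each_token_cache_space = num_layers * kv_num_head * head_dim * byte_size  # 65536 bytes/token
bytes_per_block = each_token_cache_space * block_size                     # 4 MiB per block
swap_space_gib = 4
num_cpu_blocks = int(swap_space_gib * 1024**3 / bytes_per_block)          # 1024 CPU blocks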
def metrics_info(self):
"""
Convert config to metrics dictionary.
Returns:
Dict[str, str]: Key-value pairs of all config attributes
"""
"""Convert cache_config to dict(key: str, value: str) for prometheus metrics info."""
return {key: str(value) for key, value in self.__dict__.items()}
def _verify_args(self):
"""Validate configuration arguments."""
if self.gpu_memory_utilization > 1.0:
raise ValueError(
"GPU memory utilization must be less than 1.0. Got "
@@ -227,46 +285,38 @@ Attributes:
def postprocess(self, num_total_tokens, number_of_tasks):
"""
Calculate block allocation based on tokens and tasks.
Args:
num_total_tokens (int): Total tokens to process
number_of_tasks (int): Number of parallel tasks
Sets:
dec_token_num (int): Decoder tokens per block
total_block_num (int): Total blocks needed
prefill_kvcache_block_num (int): Blocks for prefill phase
calculate block num
"""
self.dec_token_num = self.enc_dec_block_num * self.block_size
if self.num_gpu_blocks_override is not None:
self.total_block_num = self.num_gpu_blocks_override
self.prefill_kvcache_block_num= int(self.total_block_num * self.kv_cache_ratio)
self.prefill_kvcache_block_num = int(self.total_block_num *
self.kv_cache_ratio)
else:
length = num_total_tokens // number_of_tasks
block_num = (length + self.block_size - 1 + self.enc_dec_block_num) // self.block_size
self.total_block_num = block_num * number_of_tasks
self.prefill_kvcache_block_num= self.total_block_num
llm_logger.info(f"Doing profile, the total_block_num:{self.total_block_num}")
block_num = (length + self.block_size - 1 +
self.dec_token_num) // self.block_size
self.total_block_num = block_num * number_of_tasks
self.prefill_kvcache_block_num = self.total_block_num
llm_logger.info(
f"Doing profile, the total_block_num:{self.total_block_num}")
def reset(self, num_gpu_blocks):
"""
Reset GPU block allocation.
Args:
num_gpu_blocks (int): New total blocks count
Updates:
total_block_num (int)
prefill_kvcache_block_num (int)
reset gpu block number
"""
self.total_block_num = num_gpu_blocks
self.prefill_kvcache_block_num= int(self.total_block_num * self.kv_cache_ratio)
llm_logger.info((f"Reset block num, the total_block_num:{self.total_block_num},"
f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}"))
self.total_block_num = num_gpu_blocks
self.prefill_kvcache_block_num = int(self.total_block_num *
self.kv_cache_ratio)
llm_logger.info(
(f"Reset block num, the total_block_num:{self.total_block_num},"
f" prefill_kvcache_block_num:{self.prefill_kvcache_block_num}"))
def print(self):
"""Print current cache configuration."""
"""
print all config
"""
llm_logger.info("Cache Configuration Information :")
for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
@@ -274,38 +324,182 @@ Attributes:
"=============================================================")
class SpeculativeConfig:
"""
Speculative Decoding Configuration class.
Attributes:
method (Optional[str]): Method used for speculative decoding.
num_speculative_tokens (int): Maximum draft tokens, default is 1.
model_name_or_path (Optional[str]): Path of the model.
quantization (str): Quantization method for draft model, default is WINT8.
max_model_len (Optional[int]): Maximum model length for the draft model.
"""
def __init__(self,
method: Optional[str] = None,
num_speculative_tokens: Optional[int] = 1,
model: Optional[str] = None,
quantization: Optional[str] = "WINT8",
max_model_len: Optional[int] = None,
**kwargs):
self.model_name_or_path = model
self.method = method
self.num_speculative_tokens = num_speculative_tokens
self.quantization = quantization
self.max_model_len = max_model_len
# Fixed now
self.num_gpu_block_expand_ratio = 1
self.num_extra_cache_layer = 0
for key, value in kwargs.items():
try:
setattr(self, key, value)
except Exception:
continue
self.read_model_config()
self.reset()
def read_model_config(self):
"""
Read configuration from file.
"""
self.model_config = {}
if not self.enabled_speculative_decoding():
return
self.is_unified_ckpt = check_unified_ckpt(self.model_name_or_path)
if self.model_name_or_path is None:
return
self.config_path = os.path.join(self.model_name_or_path, "config.json")
if os.path.exists(self.config_path):
self.model_config = json.load(
open(self.config_path, 'r', encoding='utf-8'))
def reset(self):
"""
Reset configuration.
"""
def reset_value(cls, value_name, key=None, default=None):
if key is not None and key in cls.model_config:
setattr(cls, value_name, cls.model_config[key])
elif getattr(cls, value_name, None) is None:
setattr(cls, value_name, default)
if not self.enabled_speculative_decoding():
return
# NOTE(liuzichang): We will support multi-layer in future
if self.method in ["mtp"]:
self.num_extra_cache_layer = 1
def enabled_speculative_decoding(self):
"""
Check if speculative decoding is enabled.
"""
if self.method is None:
return False
return True
def to_json_string(self):
"""
Convert speculative_config to json string.
"""
return json.dumps({
key: value
for key, value in self.__dict__.items() if value is not None
})
def print(self):
"""
print all config
"""
llm_logger.info("Speculative Decoding Configuration Information :")
for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info(
"=============================================================")
class ParallelConfig:
"""
Configuration for parallelism.
Attributes:
tensor_parallel_size (int): Size of tensor parallelism.
data_parallel_size (int): Size of data parallelism.
local_data_parallel_id (int): ID of local data parallel.
enable_expert_parallel (bool): Whether to enable expert parallel.
"""
def __init__(
self,
tensor_parallel_size: int = 1,
data_parallel_size: int = 1,
enable_expert_parallel: bool = False,
):
"""
Initialize the ParallelConfig class.
Args:
tensor_parallel_size (int): Size of tensor parallelism.
data_parallel_size (int): Size of data parallelism.
local_data_parallel_id (int): ID of local data parallel.
enable_expert_parallel (bool): Whether to enable expert parallel.
"""
self.tensor_parallel_size = tensor_parallel_size
self.data_parallel_size = data_parallel_size
self.enable_expert_parallel = enable_expert_parallel
self.expert_parallel_size = data_parallel_size
self.local_data_parallel_id = 0
def print(self):
"""
print all config
"""
llm_logger.info("Parallel Configuration Information :")
for k, v in self.__dict__.items():
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
llm_logger.info("==================")
class Config:
"""
Main engine configuration class combining all components.
Initial configuration class.
Attributes:
model_config (ModelConfig): Model settings
cache_config (CacheConfig): Cache management settings
scheduler_config (SchedulerConfig): Task scheduling settings
model_name_or_path (str): Model identifier/path
tokenizer (str): Tokenizer identifier
tensor_parallel_size (int): Parallelism degree (default: 8)
nnode (int): Node count (default: 1)
max_model_len (int): Max sequence length (default: 8192)
max_num_seqs (int): Max concurrent sequences (default: 8)
max_num_batched_tokens (Optional[int]): Max batched tokens
pod_ips (Optional[List[str]]): Cluster node IPs
mm_processor_kwargs (Optional[Dict]): Multi-modal processor args
speculative_config (Optional[Dict]): Speculative execution settings
use_warmup (bool): Warmup enable flag
enable_mm (bool): Multi-modal enable flag
enable_chunked_prefill (bool): Chunked prefill enable flag
device_ids (str): GPU device IDs
tp_num_per_node (int): Tensor parallelism per node
host_ip (str): Current host IP
paddle_commit_id (str): PaddlePaddle version
"""
Attributes:
model_config (ModelConfig): Model configuration object.
cache_config (CacheConfig): Cache configuration object.
model_name_or_path (str): Directory path to the model or the model name.
tokenizer (Optional[str]): Default is the model.
max_num_batched_tokens (Optional[int]): Maximum number of batched tokens.
tensor_parallel_size (int): Tensor parallel size.
nnode (int): Number of nodes.
max_model_len (int): Maximum model length. Default is 8192.
max_num_seqs (int): Maximum number of sequences. Default is 8.
mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor.
speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration.
use_warmup (bool): Flag to use warmup.
engine_worker_queue_port (int): Port for engine worker queue.
enable_mm (bool): Flag to enable multi-modal processing.
reasoning_parser (str): Specifies the reasoning parser used to extract
reasoning content from the model output.
splitwise_role (str): Splitwise role.
innode_prefill_ports (Optional[List[int]]): Innode prefill ports.
Temporary configuration, will be removed in the future.
"""
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
scheduler_config: SchedulerConfig,
parallel_config: ParallelConfig,
model_name_or_path: str = None,
tokenizer: str = None,
tensor_parallel_size: int = 8,
@@ -314,38 +508,57 @@ Attributes:
max_num_seqs: int = 8,
max_num_batched_tokens: Optional[int] = None,
pod_ips: Optional[List[str]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
speculative_config: Optional[Dict[str, Any]] = None,
use_warmup: bool = False,
engine_worker_queue_port: int = 8002,
limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
enable_mm: bool = False,
enable_chunked_prefill: bool = False,
splitwise_role: str = "mixed",
innode_prefill_ports: Optional[List[int]] = None,
max_num_partial_prefills: int = 1,
max_long_partial_prefills: int = 1,
long_prefill_token_threshold: int = 0,
reasoning_parser: str = None,
enable_static_graph_inference: bool = False,
use_cudagraph: bool = False,
max_capture_batch_size: int = 64,
guided_decoding_backend: Optional[str] = None,
disable_any_whitespace: bool = False,
):
"""
Initialize engine configuration.
Initialize the Config class.
Args:
model_config (ModelConfig): Model settings
cache_config (CacheConfig): Cache settings
scheduler_config (SchedulerConfig): Scheduler settings
model_name_or_path (str): Model identifier (default: None)
tokenizer (str): Tokenizer identifier (default: None)
tensor_parallel_size (int): Parallelism degree (default: 8)
nnode (int): Node count (default: 1)
max_model_len (int): Max sequence length (default: 8192)
max_num_seqs (int): Max concurrent sequences (default: 8)
max_num_batched_tokens (Optional[int]): Max batched tokens (default: None)
pod_ips (Optional[List[str]]): Cluster node IPs (default: None)
mm_processor_kwargs (Optional[Dict]): Multi-modal args (default: None)
speculative_config (Optional[Dict]): Speculative settings (default: None)
use_warmup (bool): Warmup flag (default: False)
engine_worker_queue_port (int): Worker queue port (default: 8002)
enable_mm (bool): Multi-modal flag (default: False)
enable_chunked_prefill (bool): Chunked prefill flag (default: False)
model_config (ModelConfig): Model configuration object.
cache_config (CacheConfig): Cache configuration object.
parallel_config (ParallelConfig): Parallel configuration object.
scheduler_config (SchedulerConfig): Scheduler configuration object.
model_name_or_path (str): Model directory path or model name.
tokenizer (str): Default is the model.
tensor_parallel_size (int): Tensor parallel size. Default is 8.
nnode (int): Number of nodes. Default is 1.
max_model_len (int): Maximum model length. Default is 8192.
max_num_seqs (int): Maximum number of sequences. Default is 8.
max_num_batched_tokens (Optional[int]): Maximum number of batched tokens. Default is None.
pod_ips (Optional[List[str]]): List of POD IPs. Default is None.
mm_processor_kwargs (Optional[Dict[str, Any]]): Additional arguments for multi-modal processor. Default is None.
speculative_config (Optional[Dict[str, Any]]): Speculative execution configuration. Default is None.
use_warmup (bool): Flag to use warmup. Default is False.
engine_worker_queue_port (int): Engine worker queue port. Default is 8002.
enable_mm (bool): Flag to enable multi-modal processing. Default is False.
splitwise_role (str): Splitwise role. Default is "mixed".
innode_prefill_ports (Optional[List[int]]): Innode prefill ports. Default is None.
reasoning_parser (str): Specifies the reasoning parser used to extract
reasoning content from the model output. Default is None.
guided_decoding_backend(str): Guided decoding backend. Default is None.
disable_any_whitespace(bool): Disable any whitespace when using guided decoding.
Default is False.
"""
self.model_config = model_config
self.cache_config = cache_config
self.scheduler_config = scheduler_config
self.parallel_config = parallel_config
self.model_name_or_path = model_name_or_path
self.tokenizer = tokenizer
self.max_num_batched_tokens = max_num_batched_tokens
@@ -354,20 +567,56 @@ Attributes:
self.pod_ips = pod_ips
self.max_model_len = max_model_len
self.max_num_seqs = max_num_seqs
self.limit_mm_per_prompt = limit_mm_per_prompt
self.mm_processor_kwargs = mm_processor_kwargs
self.enable_mm = enable_mm
self.speculative_config = speculative_config
self.use_warmup = use_warmup
self.enable_chunked_prefill = enable_chunked_prefill
self.splitwise_role = splitwise_role
self.innode_prefill_ports = innode_prefill_ports
self.max_num_partial_prefills = max_num_partial_prefills
self.max_long_partial_prefills = max_long_partial_prefills
self.long_prefill_token_threshold = long_prefill_token_threshold
self.reasoning_parser = reasoning_parser
self.enable_static_graph_inference = enable_static_graph_inference
self.use_cudagraph = use_cudagraph
self.max_capture_batch_size = max_capture_batch_size
self.guided_decoding_backend = guided_decoding_backend
self.disable_any_whitespace = disable_any_whitespace
if self.innode_prefill_ports is not None:
if not isinstance(self.innode_prefill_ports, list):
ports = str(self.innode_prefill_ports).split(',')
self.innode_prefill_ports = [int(port) for port in ports]
assert self.splitwise_role in ["mixed", "prefill", "decode"]
# TODO
self.max_prefill_batch = 3
if current_platform.is_xpu():
self.max_prefill_batch = 1
if enable_mm:
self.max_prefill_batch = 1  # TODO: the multi-modal prefill stage currently only supports a parallelism degree of 1; to be optimized.
# TODO(@wufeisheng): TP and EP need to be supported simultaneously.
assert (self.tensor_parallel_size == 1
and self.parallel_config.expert_parallel_size
>= 1) or (self.tensor_parallel_size >= 1
and self.parallel_config.expert_parallel_size
== 1), "TP and EP cannot be enabled at the same time"
num_ranks = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
if num_ranks > 8:
local_num_ranks = 8
self.nnode = ceil_div(num_ranks, local_num_ranks)
else:
local_num_ranks = num_ranks
self.engine_worker_queue_port = engine_worker_queue_port
self.device_ids = ",".join(
[str(i) for i in range(self.tensor_parallel_size)])
self.device_ids = ",".join([str(i) for i in range(min((self.tensor_parallel_size * \
self.parallel_config.expert_parallel_size), 8))])
self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)
self.read_from_config()
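A worked example of the rank and device bookkeeping above (an EP-only layout, since TP and EP are mutually exclusive here; values are illustrative):
import math

tensor_parallel_size, expert_parallel_size = 1, 16
num_ranks = tensor_parallel_size * expert_parallel_size          # 16
local_num_ranks = 8 if num_ranks > 8 else num_ranks              # 8 ranks per node
nnode = math.ceil(num_ranks / local_num_ranks)                   # 2 nodes
device_ids = ",".join(str(i) for i in range(min(num_ranks, 8)))  # "0,1,2,3,4,5,6,7"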
@@ -377,45 +626,44 @@ Attributes:
def postprocess(self):
"""
Calculate derived parameters:
- Validates GPU device count matches tensor_parallel_size
- Computes tensor parallelism per node
- Gets host IP and Paddle version
- Sets default max_num_batched_tokens if not provided
- Initializes cache configuration
calculate some parameters
"""
if len(self.device_ids.split(',')) > self.tensor_parallel_size:
self.device_ids = ",".join(
self.device_ids.split(',')[:self.tensor_parallel_size:])
assert len(
self.device_ids.split(',')
) == self.tensor_parallel_size, f"The number of available GPUs is {len(self.device_ids.split(','))}, which is less than the tensor parallel required {self.tensor_parallel_size}."
assert self.tensor_parallel_size % self.nnode == 0, f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by nnode: {self.nnode}"
self.tp_num_per_node = self.tensor_parallel_size // self.nnode
total_rank = self.tensor_parallel_size * self.parallel_config.expert_parallel_size
assert self.device_ids.split(',').__len__() == min(total_rank, 8), \
f"invalid CUDA_VISIBLE_DEVICES, should be equal to {min(total_rank, 8)}"
self.local_device_ids = self.device_ids.split(
',')[:self.tensor_parallel_size]
assert self.tensor_parallel_size % self.nnode == 0, \
f"tensor_parallel_size: {self.tensor_parallel_size} should be divisible by nnode: {self.nnode}"
self.worker_num_per_node = total_rank // self.nnode
self.host_ip = get_host_ip()
import paddle
self.paddle_commit_id = paddle.version.commit
if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
if self.cache_config.enable_chunked_prefill:
self.max_num_batched_tokens = 2048
else:
self.max_num_batched_tokens = self.max_model_len
self.cache_config.postprocess(self.max_num_batched_tokens, self.max_num_seqs)
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
self.cache_config.postprocess(self.max_num_batched_tokens,
self.max_num_seqs)
self.cache_config.max_block_num_per_seq = int(
self.max_model_len // self.cache_config.block_size)
if self.guided_decoding_backend == "auto":
if self.enable_mm:
self.guided_decoding_backend = "off"
else:
self.guided_decoding_backend = "xgrammar"
def check(self):
"""
Validate configuration values:
- max_num_seqs <= 256
- engine_worker_queue_port available
- 1 <= tensor_parallel_size <= 8
- nnode >= 1
- max_model_len >= 16
- max_num_seqs >= 1
- Validates scheduler configuration
check the legality of config
"""
assert (
self.max_num_seqs <= 256
@@ -434,15 +682,66 @@ Attributes:
assert (
self.max_num_seqs
>= 1), f"max_num_seqs: {self.max_num_seqs} should be larger than 1"
assert (
self.max_num_batched_tokens >= self.max_num_seqs
), f"max_num_batched_tokens: {self.max_num_batched_tokens} " \
f"should be larger than or equal to max_num_seqs: {self.max_num_seqs}"
assert (self.max_num_batched_tokens <= self.max_model_len * self.max_num_seqs), \
f"max_num_batched_tokens: {self.max_num_batched_tokens} should be larger" \
f"than or equal to max_num_seqs: {self.max_num_seqs} * max_model_len: {self.max_model_len}"
assert (
self.max_num_partial_prefills >= 1
), f"max_num_partial_prefills: {self.max_num_partial_prefills} should be larger than or equal to 1"
assert (
self.max_long_partial_prefills >= 1
), f"max_long_partial_prefills: {self.max_long_partial_prefills} should be larger than or equal to 1"
assert (self.max_long_partial_prefills <= self.max_num_partial_prefills), \
f"max_long_partial_prefills: {self.max_long_partial_prefills} should " \
f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}"
if not self.cache_config.enable_chunked_prefill:
assert (
self.max_num_batched_tokens >= self.max_model_len
), f"max_num_batched_tokens: {self.max_num_batched_tokens} " \
f"should be larger than or equal to max_model_len: {self.max_model_len}"
if self.max_num_partial_prefills > 1:
assert (self.cache_config.enable_chunked_prefill is True), \
"Chunked prefill must be enabled to set max_num_partial_prefills > 1"
assert (self.long_prefill_token_threshold < self.max_model_len), \
f"long_prefill_token_threshold: {self.long_prefill_token_threshold} should be less than"\
f" max_model_len: {self.max_model_len}"
if self.guided_decoding_backend is not None:
assert self.guided_decoding_backend in ["xgrammar", "XGrammar", "auto", "off"], \
f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."
if self.guided_decoding_backend != "off":
# TODO: mm support guided_decoding
assert self.enable_mm is False, "Multi-modal models do not currently support guided_decoding"
# TODO: speculative decoding support guided_decoding
# TODO: xpu support guided_decoding
assert not current_platform.is_xpu(
), "XPU currently do not support guided_decoding"
try:
import xgrammar  # noqa: F401 -- check that xgrammar is importable
except Exception as e:
raise Exception(
f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"
)
self.scheduler_config.check()
def print(self, file=None):
"""
Print or save current configuration.
print all config
Args:
file (Optional[str]): File path to save config (default: None)
file (str): the path of file to save config
"""
llm_logger.info(
"=================== Configuration Information ===============")
@@ -450,7 +749,7 @@ Attributes:
if k == "generation_config" and v is not None:
for gck, gcv in v.to_dict().items():
llm_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
elif k == "cache_config" or k == "model_config" or k == "scheduler_config":
elif k == "cache_config" or k == "model_config" or k == "scheduler_config" or k == "parallel_config":
v.print()
else:
llm_logger.info("{:<20}:{:<6}{}".format(k, "", v))
@@ -464,13 +763,36 @@ Attributes:
f.write("{:<20}:{:<6}{}\n".format(k, "", v))
f.close()
def init_cache_info(self):
"""
initialize cache info
"""
disaggregate_info = {}
if self.splitwise_role != "mixed":
disaggregate_info["role"] = self.splitwise_role
disaggregate_info["cache_info"] = dict()
current_protocol = self.cache_config.cache_transfer_protocol.split(
",")
disaggregate_info["transfer_protocol"] = current_protocol
for protocol in current_protocol:
if protocol == "ipc":
disaggregate_info["cache_info"][protocol] = {
"ip": self.host_ip,
"port": self.engine_worker_queue_port,
"device_ids": self.local_device_ids
}
elif protocol == "rdma":
disaggregate_info["cache_info"][protocol] = {
"ip": self.host_ip,
"port": self.cache_config.pd_comm_port[0],
"rdma_port": self.cache_config.rdma_comm_ports,
}
self.disaggregate_info = disaggregate_info
llm_logger.info(f"disaggregate_info: {self.disaggregate_info}")
def read_from_config(self):
"""
Update configuration from model JSON file.
Handles special cases:
- infer_model_block_size -> block_size
- return_full_hidden_states
- infer_model_dtype -> cache_dtype
reset model config from json file
"""
def reset_value(cls, value_name, key):
@@ -482,9 +804,9 @@ Attributes:
)
reset_value(self.cache_config, "block_size", "infer_model_block_size")
reset_value(self.model_config, "return_full_hidden_states", "return_full_hidden_states")
reset_value(self.model_config, "return_full_hidden_states",
"return_full_hidden_states")
reset_value(self.cache_config, "cache_dtype", "infer_model_dtype")
def __str__(self) -> str:
return json.dumps(self.__dict__, indent=4)

File diff suppressed because it is too large

View File

@@ -0,0 +1,370 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import os
import signal
import threading
import time
import traceback
import weakref
import numpy as np
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.inter_communicator import EngineWorkerQueue
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.output.token_processor import TokenProcessor
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, console_logger, llm_logger
class ExpertService(object):
"""
Expert service responsible for managing Large Language Model (LLM) operations for a single data parallel rank.
Attributes:
cfg (Config): Configuration object containing all the parameters.
local_data_parallel_id (int): Local data parallel ID.
"""
def __init__(self, cfg, local_data_parallel_id):
"""
Initializes the ExpertService with the provided configuration.
Args:
cfg (Config): Config object containing all the configuration parameters.
"""
self.cfg = cfg
start_pos = local_data_parallel_id * self.cfg.tensor_parallel_size
end_pos = (local_data_parallel_id + 1) * self.cfg.tensor_parallel_size
self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[
start_pos:end_pos]
self.cfg.local_device_ids = self.cfg.device_ids.split(
",")[start_pos:end_pos]
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
self.cfg.disaggregate_info = None
self.scheduler = cfg.scheduler_config.scheduler()
self.scheduler.reset_nodeid(
f"{self.scheduler.infer.nodeid}_{str(local_data_parallel_id)}")
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
address = ('0.0.0.0', cfg.engine_worker_queue_port)
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,
client_id=0,
num_client=cfg.tensor_parallel_size,
local_data_parallel_id=local_data_parallel_id,
)
self.resource_manager = ResourceManager(cfg.max_num_seqs, cfg, \
cfg.tensor_parallel_size, cfg.splitwise_role, local_data_parallel_id)
if len(self.cfg.cache_config.pd_comm_port) == 1:
self.cfg.cache_config.pd_comm_port[0] = int(
self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
else:
self.cfg.cache_config.pd_comm_port = [
self.cfg.cache_config.pd_comm_port[local_data_parallel_id]
]
self.split_connector = SplitwiseConnector(self.cfg, self.scheduler,
self.engine_worker_queue,
self.resource_manager)
self.token_processor = TokenProcessor(
cfg=cfg,
cached_generated_tokens=self.scheduler,
engine_worker_queue=self.engine_worker_queue,
split_connector=self.split_connector)
self.token_processor.set_resource_manager(self.resource_manager)
self.partial_chunked_tokens = [0] * (
self.cfg.max_num_partial_prefills + 1)
for idx in range(1, self.cfg.max_num_partial_prefills + 1):
self.partial_chunked_tokens[idx] = (self.cfg.max_num_batched_tokens // idx) \
// self.cfg.cache_config.block_size * self.cfg.cache_config.block_size
self._finalizer = weakref.finalize(self, self._exit_sub_services)
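A worked example of the partial_chunked_tokens table built above (values are illustrative):
max_num_batched_tokens, block_size, max_num_partial_prefills = 2048, 64, 2
partial_chunked_tokens = [0] * (max_num_partial_prefills + 1)
for idx in range(1, max_num_partial_prefills + 1):
    partial_chunked_tokens[idx] = (max_num_batched_tokens // idx) \
        // block_size * block_size
assert partial_chunked_tokens == [0, 2048, 1024]  # per-partial-prefill token budgets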
def start(self, ipc_signal_suffix, local_data_parallel_id):
"""
Initializes the engine and starts its sub-services.
If `api_server_pid` is defined, a thread will be launched
to keep fetching requests from the zmq_server.
"""
# assert not self.is_started, "The engine is already started."
start_time = time.time()
llm_logger.info(f"start expert service {local_data_parallel_id}")
self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
self.cfg.cache_config, self.cfg.tensor_parallel_size,
self.cfg.local_device_ids, self.cfg.engine_worker_queue_port,
f"{local_data_parallel_id}_{ipc_signal_suffix}")
self.insert_task_to_worker_thread = threading.Thread(
target=self._insert_task_to_worker, args=())
self.insert_task_to_worker_thread.daemon = True
self.insert_task_to_worker_thread.start()
# Start TokenProcessor thread
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(
local_data_parallel_id + int(self.cfg.engine_worker_queue_port))
self.token_processor.run()
self.split_mode_get_tasks()
self.cfg.init_cache_info()
role = self.cfg.splitwise_role
host_ip = self.cfg.host_ip
disaggregate = self.cfg.disaggregate_info
self.scheduler.start(role, host_ip, disaggregate)
self.cfg.print()
console_logger.info(
"Worker processes are launched with {} seconds.".format(
time.time() - start_time))
return True
def _insert_task_to_worker(self):
"""
Insert tasks into the engine thread and monitor the scheduler request queue.
If the engine has resources available, insert tasks into the engine.
"""
current_id = -1
while True:
try:
if self.resource_manager.available_batch() == 0:
time.sleep(0.001)
continue
if self.engine_worker_queue.num_tasks() > 0:
time.sleep(0.001)
continue
if len(self.split_connector.current_request_ids) > 0:
time.sleep(0.001)
continue
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch)
self.resource_manager.check_and_free_block_tables()
tasks = self.scheduler.get_requests(
available_blocks=self.resource_manager.available_block_num(
),
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=self.cfg.cache_config.
enc_dec_block_num,
max_num_batched_tokens=self.cfg.max_num_batched_tokens,
batch=num_prefill_batch)
if len(tasks) == 0:
time.sleep(0.001)
continue
if self.cfg.splitwise_role != "mixed":
llm_logger.info("Inserting splitwise tasks")
self.split_connector.send_splitwise_tasks(
tasks, current_id)
current_id = (current_id + 1) % 100003
self.insert_tasks(tasks, current_id)
main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks))
except Exception as e:
err_msg = "Error happend while insert task to engine: {}, {}.".format(
e, str(traceback.format_exc()))
llm_logger.error(err_msg)
def split_mode_get_tasks(self):
"""
Fetch tasks in splitwise (disaggregated prefill/decode) mode.
"""
waiting_requests = []
def receiver_loop():
while True:
try:
if len(waiting_requests) > 0:
for task in list(waiting_requests):
if self.resource_manager.is_resource_sufficient(
task.prompt_token_ids_len):
self.insert_tasks([task])
waiting_requests.remove(task)
else:
break
if not self.engine_worker_queue.disaggregate_queue_empty():
items = self.engine_worker_queue.get_disaggregated_tasks(
)
for item in items:
role = item[0]
tasks = item[1]
if role == "prefill":
llm_logger.info("get prefill tasks")
for task in tasks:
task.max_tokens = task.min_tokens = 2
self.insert_tasks(tasks)
elif role == "decode":
llm_logger.info(f"get decode tasks {tasks}")
if hasattr(tasks[0], 'finished'):
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
task.finished = False
# self.scheduler.put_results(tasks)
self.insert_tasks(tasks, allocated=True)
else:
if len(waiting_requests):
for task in tasks:
waiting_requests.append(task)
else:
for task in tasks:
if not self.resource_manager.is_resource_sufficient(
task.prompt_token_ids_len):
waiting_requests.append(task)
else:
self.insert_tasks([task])
else:
time.sleep(0.001)
continue
except Exception as e:
llm_logger.error(f"get decode tasks error: {e}")
threading.Thread(target=receiver_loop, daemon=True).start()
def insert_tasks(self, tasks, current_id=-1, allocated=False):
"""
Insert tasks into the engine.
"""
if allocated:
current_tasks = []
for task in tasks:
cur_task_idx = self.resource_manager.req_dict[task.request_id]
del self.resource_manager.req_dict[task.request_id]
cur_task = self.resource_manager.tasks_list[cur_task_idx]
if task.error_code != 200:
self.resource_manager.stop_flags[cur_task_idx] = True
self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[
task.request_id]
self.scheduler.put_results([task])
llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
continue
llm_logger.info(f"{cur_task_idx} {task.request_id}")
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
self.token_processor.tokens_counter[task.request_id] = 1
current_tasks.append(cur_task)
self.engine_worker_queue.put_tasks(
(current_tasks, self.resource_manager.real_bsz))
return True
self.resource_manager.check_and_free_block_tables()
if not isinstance(tasks, list):
tasks = [tasks]
for item in tasks:
item.schedule_start_time = time.time()
available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch:
llm_logger.error(
"Inserting batch:{} exceeds the available batch:{}.".format(
len(tasks), available_batch))
llm_logger.error("The exceeded part will be ignored!")
tasks = tasks[:available_batch]
req_ids = [t.request_id for t in tasks]
tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
if not tasks:
error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=500)
return False
self.token_processor.number_of_tasks += len(tasks)
is_decode = False
is_prefill = False
for i in range(len(tasks)):
if tasks[i].disaggregate_info is not None:
if tasks[i].disaggregate_info["role"] == "decode":
is_decode = True
else:
is_prefill = True
self.token_processor.number_of_input_tokens += tasks[
i].prompt_token_ids_len
self.split_connector.send_cache_infos(tasks, current_id)
for task in tasks:
task.infer_start_time = time.time()
if not is_decode:
llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
if not is_prefill:
if not self.cfg.enable_mm:
self.update_requests_chunk_size(tasks)
else:
self.update_mm_requests_chunk_size(tasks)
self.engine_worker_queue.put_tasks(
(tasks, self.resource_manager.real_bsz))
return True
def _exit_sub_services(self):
"""
Exit sub services and release their resources.
"""
if hasattr(self, "cache_manager_processes"):
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear(
)
self.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}")
try:
os.killpg(p.pid, signal.SIGTERM)
except Exception:
pass
if hasattr(self, "zmq_server") and self.zmq_server is not None:
self.zmq_server.close()
def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix):
"""
Start expert service
"""
expert_service = ExpertService(cfg, local_data_parallel_id)
try:
expert_service.start(ipc_signal_suffix, local_data_parallel_id)
expert_service.split_connector.start_receiver()
except Exception as e:
llm_logger.exception(f"Expert service failed to start: {e}")

View File

@@ -15,12 +15,12 @@
"""
from __future__ import annotations
import time
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, Optional, Union
import numpy
from dataclasses import dataclass, asdict, fields
from typing import TYPE_CHECKING, Optional, Union, Any
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.utils import data_processor_logger
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.utils import data_processor_logger
@@ -28,41 +28,33 @@ from fastdeploy.utils import data_processor_logger
@dataclass
class Request:
"""A class representing an inference request to the LLM engine.
Attributes:
request_id: Unique identifier for the request
prompt: Input prompt text or list of prompts
prompt_token_ids: Token IDs of the input prompt
prompt_token_ids_len: Length of prompt token IDs
messages: List of message dictionaries (for chat models)
history: Conversation history (for chat models)
system: System message (for chat models)
sampling_params: Parameters controlling text generation
eos_token_ids: List of end-of-sequence token IDs
arrival_time: Timestamp when request was received
preprocess_start_time: Timestamp when preprocessing started
preprocess_end_time: Timestamp when preprocessing completed
multimodal_inputs: Dictionary of multimodal inputs (images, audio etc.)
raw_request: Flag indicating if this is a raw request
"""
def __init__(
self,
request_id: str,
prompt: Optional[Union[str, list[str]]],
prompt_token_ids: Optional[list[int]],
prompt_token_ids_len: Optional[int],
messages: Optional[list[list[dict[str, Any]]]],
history: Optional[list[list[str]]],
system: Optional[Union[str, list[str]]],
sampling_params: SamplingParams,
eos_token_ids: Optional[list[int]],
arrival_time: float,
preprocess_start_time: Optional[float] = None,
preprocess_end_time: Optional[float] = None,
multimodal_inputs: Optional[dict] = None,
raw_request: bool = True
) -> None:
def __init__(self,
request_id: str,
prompt: Optional[Union[str, list[str]]],
prompt_token_ids: Optional[list[int]],
prompt_token_ids_len: Optional[int],
messages: Optional[list[list[dict[str, Any]]]],
history: Optional[list[list[str]]],
tools: Optional[list[Dict]],
system: Optional[Union[str, list[str]]],
sampling_params: SamplingParams,
eos_token_ids: Optional[list[int]],
arrival_time: float,
preprocess_start_time: Optional[float] = None,
preprocess_end_time: Optional[float] = None,
multimodal_inputs: Optional[dict] = None,
multimodal_data: Optional[dict] = None,
raw_request: bool = True,
disaggregate_info: Optional[dict] = None,
draft_token_ids: Optional[list[int]] = None,
guided_json: Optional[Any] = None,
guided_regex: Optional[Any] = None,
guided_choice: Optional[Any] = None,
guided_grammar: Optional[Any] = None,
structural_tag: Optional[Any] = None,
guided_json_object: Optional[bool] = None,
enable_thinking: Optional[bool] = True) -> None:
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
@@ -71,52 +63,66 @@ class Request:
self.system = system
self.sampling_params = sampling_params
self.history = history
self.tools = tools
# model specific token ids: end of sentence token ids
self.eos_token_ids = eos_token_ids
self.num_cached_tokens = 0
self.arrival_time = arrival_time
self.preprocess_start_time = preprocess_start_time
self.preprocess_end_time = preprocess_end_time
self.raw_request = raw_request
self.disaggregate_info = disaggregate_info
# speculative method in disaggregate-mode
self.draft_token_ids = draft_token_ids
# guided decoding related
self.guided_json = guided_json
self.guided_regex = guided_regex
self.guided_choice = guided_choice
self.guided_grammar = guided_grammar
self.structural_tag = structural_tag
self.guided_json_object = guided_json_object
# Multi-modal related
self.multimodal_inputs = multimodal_inputs
self.multimodal_data = multimodal_data
self.enable_thinking = enable_thinking
@classmethod
def from_dict(cls, d: dict):
"""Create a Request instance from a dictionary.
Args:
d: Dictionary containing request parameters
Returns:
Request: A new Request instance initialized with values from the dictionary
"""
data_processor_logger.debug(f"{d}")
sampling_params = SamplingParams.from_dict(d)
return cls(
request_id=d["request_id"],
prompt=d.get("prompt"),
prompt_token_ids=d.get("prompt_token_ids"),
prompt_token_ids_len=d.get("prompt_token_ids_len"),
messages=d.get("messages"),
system=d.get("system"),
history=d.get("history"),
sampling_params=sampling_params,
eos_token_ids=d.get("eos_token_ids"),
arrival_time=d.get("arrival_time", time.time()),
preprocess_start_time=d.get("preprocess_start_time"),
preprocess_end_time=d.get("preprocess_end_time"),
multimodal_inputs=d.get("multimodal_inputs"),
raw_request=d.get("raw_request", True)
)
return cls(request_id=d["request_id"],
prompt=d.get("prompt"),
prompt_token_ids=d.get("prompt_token_ids"),
prompt_token_ids_len=d.get("prompt_token_ids_len"),
messages=d.get("messages"),
system=d.get("system"),
history=d.get("history"),
tools=d.get("tools"),
sampling_params=sampling_params,
eos_token_ids=d.get("eos_token_ids"),
arrival_time=d.get("arrival_time", time.time()),
preprocess_start_time=d.get("preprocess_start_time"),
preprocess_end_time=d.get("preprocess_end_time"),
multimodal_inputs=d.get("multimodal_inputs"),
multimodal_data=d.get("multimodal_data"),
disaggregate_info=d.get("disaggregate_info"),
draft_token_ids=d.get("draft_token_ids"),
raw_request=d.get("raw_request", True),
guided_json=d.get("guided_json", None),
guided_regex=d.get("guided_regex", None),
guided_choice=d.get("guided_choice", None),
guided_grammar=d.get("guided_grammar", None),
structural_tag=d.get("structural_tag", None),
guided_json_object=d.get("guided_json_object", None),
enable_thinking=d.get("enable_thinking", True))
def to_dict(self) -> dict:
"""Convert the Request object into a serializable dictionary.
Returns:
dict: A dictionary containing all request attributes and sampling parameters
"""
"""convert Request into a serializable dict """
data = {
"request_id": self.request_id,
"prompt": self.prompt,
@@ -125,26 +131,30 @@ class Request:
"messages": self.messages,
"system": self.system,
"history": self.history,
"tools": self.tools,
"eos_token_ids": self.eos_token_ids,
"arrival_time": self.arrival_time,
"preprocess_start_time": self.preprocess_start_time,
"preprocess_end_time": self.preprocess_end_time,
"multimodal_inputs": self.multimodal_inputs,
"raw_request": self.raw_request
"multimodal_data": self.multimodal_data,
"raw_request": self.raw_request,
"disaggregate_info": self.disaggregate_info,
"draft_token_ids": self.draft_token_ids,
"enable_thinking": self.enable_thinking
}
add_params = [
"guided_json", "guided_regex", "guided_choice", "guided_grammar",
"structural_tag", "guided_json_object"
]
for param in add_params:
if getattr(self, param, None) is not None:
data[param] = getattr(self, param)
data.update(asdict(self.sampling_params))
return data
def get(self, key: str, default_value=None):
"""Get an attribute value from either the Request or its sampling parameters.
Args:
key: Attribute name to retrieve
default_value: Default value to return if attribute not found
Returns:
The attribute value if found, otherwise default_value
"""
if hasattr(self, key):
return getattr(self, key)
elif hasattr(self.sampling_params, key):
@@ -153,12 +163,6 @@ class Request:
return default_value
def set(self, key, value):
"""Set an attribute value on either the Request or its sampling parameters.
Args:
key: Attribute name to set
value: Value to assign to the attribute
"""
if hasattr(self.sampling_params, key):
setattr(self.sampling_params, key, value)
else:
@@ -168,6 +172,7 @@ class Request:
return (f"Request(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"draft_token_ids={self.draft_token_ids}, "
f"sampling_params={self.sampling_params})")
@@ -182,22 +187,42 @@ class CompletionOutput:
"""
index: int
send_idx: int
token_ids: list[int]
draft_token_ids: list[int] = None
text: Optional[str] = None
reasoning_content: Optional[str] = None
def to_dict(self):
"""
convert CompletionOutput to a serialized dict
"""
return {
"index": self.index,
"send_idx": self.send_idx,
"token_ids": self.token_ids,
"draft_token_ids": self.draft_token_ids,
"text": self.text,
"reasoning_content": self.reasoning_content
}
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> 'CompletionOutput':
"""Create instance from dict arguments"""
return cls(**{
field.name: req_dict[field.name] if field.name in req_dict else field.default
for field in fields(cls)
})
return cls(
**{
field.name:
req_dict[field.name] if field.name in
req_dict else field.default
for field in fields(cls)
})
def __repr__(self) -> str:
return (f"CompletionOutput(index={self.index}, "
f"send_idx={self.send_idx}, "
f"text={self.text!r}, "
f"token_ids={self.token_ids}, "
f"draft_token_ids={self.draft_token_ids}, "
f"reasoning_content={self.reasoning_content!r}")
@@ -227,13 +252,31 @@ class RequestMetrics:
model_execute_time: Optional[float] = None
request_start_time: Optional[float] = None
def to_dict(self):
"""
Convert the RequestMetrics object to a dictionary.
"""
return {
"arrival_time": self.arrival_time,
"inference_start_time": self.inference_start_time,
"first_token_time": self.first_token_time,
"time_in_queue": self.time_in_queue,
"preprocess_cost_time": self.preprocess_cost_time,
"model_forward_time": self.model_forward_time,
"model_execute_time": self.model_execute_time,
"request_start_time": self.request_start_time
}
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> 'RequestMetrics':
"""Create instance from dict arguments"""
return cls(**{
field.name: req_dict[field.name] if field.name in req_dict else field.default
for field in fields(cls)
})
return cls(
**{
field.name:
req_dict[field.name] if field.name in
req_dict else field.default
for field in fields(cls)
})
class RequestOutput:
@@ -282,16 +325,8 @@ class RequestOutput:
self.error_msg = error_msg
def add(self, next_output: "RequestOutput") -> None:
"""Merge another RequestOutput into this one.
Args:
next_output: The RequestOutput to merge into this one
Updates:
- Combines output sequences
- Updates finish status
- Calculates timing metrics
"""
"""Merge RequestOutput into this one"""
self.prompt = next_output.prompt
self.prompt_token_ids = next_output.prompt_token_ids
self.finished |= next_output.finished
@@ -314,25 +349,13 @@ class RequestOutput:
@classmethod
def from_dict(cls, d: dict):
"""Create a RequestOutput instance from a dictionary.
Args:
d: Dictionary containing request output parameters
Returns:
RequestOutput: A new RequestOutput instance initialized with values from the dictionary
"""
"""Create instance from dict arguments"""
completion_output = CompletionOutput.from_dict(d.pop("outputs"))
metrics = RequestMetrics.from_dict(d.pop("metrics"))
return RequestOutput(**d, outputs=completion_output, metrics=metrics)
def to_dict(self):
"""Convert the RequestOutput object into a serializable dictionary.
Returns:
dict: A dictionary containing all request output attributes,
with token IDs converted to lists if necessary
"""
"""convert RequestOutput into a serializable dict """
if self.prompt_token_ids is None:
self.prompt_token_ids = []
@@ -343,11 +366,12 @@ class RequestOutput:
"request_id": self.request_id,
"prompt": self.prompt,
"prompt_token_ids": self.prompt_token_ids,
"outputs": None if self.outputs is None else asdict(self.outputs),
"outputs":
None if self.outputs is None else self.outputs.to_dict(),
"metrics":
None if self.metrics is None else self.metrics.to_dict(),
"finished": self.finished,
"metrics": None if self.metrics is None else asdict(self.metrics),
"num_cached_tokens": self.num_cached_tokens,
"error_code": self.error_code,
"error_msg": self.error_msg,
}

View File

@@ -14,110 +14,116 @@
# limitations under the License.
"""
import copy
import os
import math
import random
import threading
import time
import numpy as np
from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import llm_logger
class ResourceManager(object):
"""Manages and allocates computational resources for the inference engine.
This class handles the allocation and recycling of memory blocks for KV cache,
manages task scheduling, and tracks resource utilization.
"""
def __init__(self, max_num_seqs, cache_config):
"""Initializes the resource manager with configuration parameters.
Args:
max_num_seqs (int): Maximum number of concurrent sequences the engine can handle
cache_config (Config): Configuration object containing:
- prefill_kvcache_block_num: Number of pre-allocated KV cache blocks
- block_size: Size of each memory block in tokens
- dec_token_num: Number of decoder tokens
record and allocate resources for the engine
"""
def __init__(self,
max_num_seqs,
config,
tensor_parallel_size,
splitwise_role,
local_data_parallel_id=0):
"""
self.cfg = cache_config
Args:
cfg (Config): config object containing parameters for the engine
initialization
Returns:
None
Initializes the engine with the given configuration and sets up necessary
data structures to manage tasks and blocks.
"""
self.cfg = config.cache_config
self.max_num_seqs = max_num_seqs
self.stop_flags = [True] * max_num_seqs
self.free_list = list(range(self.cfg.prefill_kvcache_block_num - 1, -1, -1))
self.enable_prefix_cache = config.cache_config.enable_prefix_caching
self.cache_manager = PrefixCacheManager(config, tensor_parallel_size,
splitwise_role,
local_data_parallel_id)
self.tasks_list = [None] * max_num_seqs
self.req_dict = dict()
# current batch status of the engine
self.real_bsz = 0
llm_logger.info(f"{self.info()}")
def reset_cache_config(self, cfg):
"""Updates the cache configuration with new parameters.
Args:
cfg (Config): New cache configuration object
"""
reset cache config
"""
self.cfg = cfg
self.free_list = list(range(self.cfg.prefill_kvcache_block_num - 1, -1, -1))
self.cache_manager.update_cache_config(cfg)
def get_required_block_number(self, input_token_num):
"""Calculates the total number of blocks needed for a sequence.
Includes both encoder and decoder requirements.
Args:
input_token_num (int): Number of tokens in the input sequence
Returns:
int: Total number of blocks required (rounded up)
"""
block_num = (input_token_num + self.cfg.block_size - 1 + self.cfg.dec_token_num) // self.cfg.block_size
Calculate the total number of blocks needed
Args:
input_token_num (int): input token number
Returns:
int: block number
"""
block_num = (input_token_num + self.cfg.block_size - 1 +
self.cfg.dec_token_num) // self.cfg.block_size
return block_num
def get_encoder_block_number(self, input_token_num):
"""Calculates the number of blocks needed for encoder inputs only.
Args:
input_token_num (int): Number of tokens in the encoder input
Returns:
int: Number of blocks required for encoder (rounded up)
"""
enc_block_num = (input_token_num + self.cfg.block_size - 1) // self.cfg.block_size
get the number of blocks for the encoder
Args:
input_token_num (int): input token number
Returns:
int: encoder block number
"""
enc_block_num = (input_token_num + self.cfg.block_size -
1) // self.cfg.block_size
return enc_block_num
def get_decoder_block_number(self):
"""Calculates the number of blocks needed for decoder outputs.
Returns:
int: Number of blocks required for decoder (rounded up)
"""
return (self.cfg.dec_token_num + self.cfg.block_size - 1) // self.cfg.block_size
get the number of blocks for the decoder
Returns:
int: decoder block number
"""
return (self.cfg.dec_token_num + self.cfg.block_size -
1) // self.cfg.block_size
def total_block_number(self):
"""Gets the total number of pre-allocated KV cache blocks.
Returns:
int: Total number of blocks available in the pool
"""
return self.cfg.prefill_kvcache_block_num
The number of pre-allocated blocks at service startup
Returns:
int: total block number
"""
return self.cache_manager.num_gpu_blocks
def _get_block_tables(self, input_token_num, required_type="all"):
"""Allocates memory blocks from the free pool.
"""
allocate memory resources
Args:
input_token_num (int): Number of input tokens
required_type (str): Type of blocks needed:
- "all": Both encoder and decoder blocks
- "encoder": Encoder blocks only
- "decoder": Decoder blocks only
input_token_num (int): input token number
required_type (str): required type
Returns:
list: List of allocated block IDs
Raises:
ValueError: If unknown required_type is specified
list: block list
"""
if required_type == "all":
block_num = self.get_required_block_number(input_token_num)
@@ -129,50 +135,75 @@ class ResourceManager(object):
raise ValueError('unknown required type')
block_list = list()
if block_num > len(self.free_list):
llm_logger.error("block_num:{0} > free_list len:{1}".format(block_num, len(self.free_list)))
current_block_num = self.available_block_num()
if block_num > current_block_num:
llm_logger.error("block_num:{0} > free_list len:{1}".format(
block_num, current_block_num))
return block_list
for _ in range(block_num):
used_block_id = self.free_list.pop()
block_list.append(used_block_id)
block_list = self.cache_manager.allocate_gpu_blocks(block_num)
llm_logger.debug(f"dispatch {len(block_list)} blocks.")
return block_list
def _recycle_block_tables(self, block_tables):
"""Returns memory blocks to the free pool for reuse.
Args:
block_tables (list): List of block IDs to recycle
def check_and_free_block_tables(self):
"""
ori_number = len(self.free_list)
self.free_list.extend(block_tables)
cur_number = len(self.free_list)
llm_logger.info(f"recycle {cur_number - ori_number} blocks.")
Check and free block tables only in prefix caching mode.
If the number of free blocks is less than a certain threshold, free up to the threshold.
"""
if self.enable_prefix_cache:
if self.available_block_num() < self.cfg.max_block_num_per_seq:
self.free_block_tables(self.cfg.max_block_num_per_seq)
def _recycle_block_tables(self, task):
"""
Recycling memory resource blocks
Args:
task (Request or list): finished task, or a raw block table list, whose blocks are recycled
"""
if self.enable_prefix_cache:
self.cache_manager.release_block_ids_async(task)
else:
req_id = task.request_id
if isinstance(task, list):
block_tables = task
else:
block_tables = task.block_tables
ori_number = self.available_block_num()
self.cache_manager.recycle_gpu_blocks(block_tables)
cur_number = self.available_block_num()
main_process_metrics.gpu_cache_usage_perc.set(
self.get_gpu_cache_usage_perc())
llm_logger.info(
f"recycle {req_id} {cur_number - ori_number} blocks.")
def available_batch(self):
"""Gets the number of available sequence slots.
"""
available batch size for engine
Returns:
int: Number of available sequence slots in the batch
int: available batch size
"""
return np.sum(self.stop_flags)
def available_block_num(self):
"""Gets the number of available memory blocks.
Returns:
int: Number of free blocks in the pool
"""
return len(self.free_list)
number of free blocks currently available to the engine
Returns:
int: available block number
"""
return len(self.cache_manager.gpu_free_block_list)
def is_resource_sufficient(self, input_token_num):
"""Checks if sufficient resources are available for a new sequence.
"""
Check whether the currently available resources meet the new request's requirements
Args:
input_token_num (int): Number of tokens in the new sequence
input_token_num (int): input token number
Returns:
bool: True if both batch slots and memory blocks are available
bool: whether current available resources meet the new requirements
"""
if self.available_batch() < 1:
return False
@@ -181,19 +212,21 @@ class ResourceManager(object):
return False
return True
def free_block_tables(self, need_reserved_block_num):
"""
Recycle blocks back into the available resource pool.
"""
return self.cache_manager.free_block_ids_async(need_reserved_block_num)
def allocate_resources_for_new_tasks(self, tasks):
"""Assigns resources to new inference tasks.
"""
allocate resources for new tasks
Args:
tasks (list): List of Request objects needing resources
tasks (list): task list
Returns:
list: List of successfully allocated Request objects
Note:
- Assigns sequence slots and memory blocks
- Sets initial timestamps and metadata
- Updates real-time batch size statistics
list: processed task list
"""
allocated_position = 0
@@ -205,7 +238,8 @@ class ResourceManager(object):
can_insert = False
while allocated_position + 1 <= self.max_num_seqs:
if sum(self.stop_flags[allocated_position : allocated_position + 1]) == 1:
if sum(self.stop_flags[allocated_position:allocated_position +
1]) == 1:
can_insert = True
break
allocated_position += 1
@@ -215,14 +249,61 @@ class ResourceManager(object):
task = tasks[processing_task_index]
if task.get("seed") is None:
task.set("seed", random.randint(0, 9223372036854775807))
task.set("seed",
random.randint(0, 9223372036854775807))
task.idx = allocated_position
block_tables = self._get_block_tables(task.prompt_token_ids_len)
if not block_tables:
llm_logger.error("req_id: {0} block_tables is empty".format(task.request_id))
continue
if self.enable_prefix_cache:
cache_prepare_time = time.time()
common_block_ids, unique_block_ids, hit_info = self.cache_manager.request_block_ids(
task, self.cfg.block_size, self.cfg.dec_token_num)
if unique_block_ids is None:
llm_logger.warning(
"req_id: {0} not enough blocks available".
format(task["req_id"]))
return
cached_len = self._record_request_cache_info(
task, common_block_ids, unique_block_ids, hit_info)
task.cache_prepare_time = time.time(
) - cache_prepare_time
if task.disaggregate_info is not None:
if task.disaggregate_info['role'] == "prefill":
self.req_dict[
task.request_id] = allocated_position
task.disaggregate_info[
'block_tables'] = task.block_tables
self._delete_cached_data(task, cached_len)
elif task.disaggregate_info['role'] == "decode":
self.req_dict[
task.request_id] = allocated_position
task.disaggregate_info[
'block_tables'] = task.need_block_tables
else:
self._delete_cached_data(task, cached_len)
else:
task.block_tables = block_tables
block_tables = self._get_block_tables(
task.prompt_token_ids_len)
if not block_tables:
llm_logger.error(
"req_id: {0} block_tables is empty".format(
task.request_id))
continue
else:
task.block_tables = block_tables
task.need_block_tables = task.block_tables
if task.disaggregate_info is not None:
task.disaggregate_info[
'block_tables'] = block_tables
if task.disaggregate_info['role'] == "prefill":
self.req_dict[
task.request_id] = allocated_position
elif task.disaggregate_info['role'] == "decode":
self.req_dict[
task.request_id] = allocated_position
processed_tasks.append(task)
self.stop_flags[allocated_position] = False
@@ -230,9 +311,10 @@ class ResourceManager(object):
task.inference_time_cost = -1.0
task.tokens_all_num = int(0)
self.tasks_list[allocated_position] = task
llm_logger.info(f"Allocate request: {task.request_id}, "
f"allocated_position:{allocated_position}, "
f"length of prompt token: {task.prompt_token_ids_len}")
llm_logger.info(
f"Allocate request: {task.request_id}, "
f"allocated_position:{allocated_position}, "
f"length of prompt token: {task.prompt_token_ids_len}")
allocated_position += 1
processing_task_index += 1
@@ -242,20 +324,70 @@ class ResourceManager(object):
self.real_bsz = i + 1
break
llm_logger.info(f"Number of allocated requests: {len(tasks)}, number of "
f"running requests in worker: {self.real_bsz}")
llm_logger.info(
f"Number of allocated requests: {len(tasks)}, number of "
f"running requests in worker: {self.real_bsz}")
llm_logger.info(f"{self.info()}")
main_process_metrics.gpu_cache_usage_perc.set(
self.get_gpu_cache_usage_perc())
return processed_tasks
def _delete_cached_data(self, task, cached_len):
"""
Delete cached data from the task's prompt token ids based on the cached length.
"""
if cached_len == len(task.prompt_token_ids):
task.prompt_token_ids = task.prompt_token_ids[cached_len - 1:]
task.seq_lens_decoder = cached_len - 1
else:
task.prompt_token_ids = task.prompt_token_ids[cached_len:]
task.seq_lens_decoder = cached_len
task.prompt_token_ids_len = len(task.prompt_token_ids)
def _record_request_cache_info(self, task, common_block_ids,
unique_block_ids, hit_info):
"""
Record the cache information for a given task and its corresponding block IDs.
"""
cache_block_num = len(common_block_ids)
no_cache_block_num = math.ceil(len(task.prompt_token_ids) / self.cfg.block_size \
- cache_block_num)
task.num_cached_tokens = cache_block_num * self.cfg.block_size
task.gpu_cache_token_num = hit_info[
"gpu_cache_blocks"] * self.cfg.block_size
task.cpu_cache_token_num = hit_info[
"cpu_cache_blocks"] * self.cfg.block_size
task.cache_info = (cache_block_num, no_cache_block_num)
cached_len = len(common_block_ids) * self.cfg.block_size
task.block_tables = common_block_ids + unique_block_ids
task.need_block_tables = unique_block_ids
llm_logger.debug(f"common: {common_block_ids} ")
llm_logger.debug(f"unique: {unique_block_ids} ")
return cached_len
def info(self):
"""Generates a summary of current resource status.
"""
get resource manager info
Returns:
str: Formatted string showing:
- Total blocks/batch slots
- Available blocks/batch slots
str: resource manager info
"""
info = f"ResourceManager info, " \
f"total_block_number: {self.total_block_number()}, total_batch_number: {len(self.stop_flags)}, " \
f"available_block_num: {self.available_block_num()}, available_batch: {self.available_batch()}"
return info
def get_gpu_cache_usage_perc(self):
"""
Calculate GPU KV-cache usage
Returns:
float: GPU KV-cache usage (0.0 - 1.0)
"""
num_total_gpu = self.total_block_number()
num_free_gpu = len(self.cache_manager.gpu_free_block_list)
if num_total_gpu > 0:
return 1.0 - (num_free_gpu / num_total_gpu)
return 0.0

View File

@@ -15,9 +15,10 @@
"""
from __future__ import annotations
from dataclasses import dataclass, fields
from typing import Any, Optional, Union, List
import random
from dataclasses import dataclass, fields
from typing import Any, List, Optional, Union
@dataclass
@@ -62,6 +63,7 @@ class SamplingParams:
token sequence is not allowed when the next generated token
can complete the sequence.
max_tokens: Maximum number of tokens to generate per output sequence.
reasoning_max_tokens: Maximum number of tokens to generate for reasoning per output sequence.
min_tokens: Minimum number of tokens to generate per output sequence
before EOS or stop_token_ids can be generated
logprobs: Number of log probabilities to return per output token.
@@ -75,131 +77,107 @@ class SamplingParams:
n: int = 1
best_of: Optional[int] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
repetition_penalty: float = 1.0
temperature: float = 1.0
top_p: float = 0.7
presence_penalty: float = None
frequency_penalty: float = None
repetition_penalty: float = None
temperature: float = None
top_p: float = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
max_tokens: Optional[int] = 16
max_tokens: Optional[int] = None
reasoning_max_tokens: Optional[int] = None
min_tokens: int = 1
logprobs: Optional[int] = None
bad_words: Optional[List[str]] = None
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> "SamplingParams":
"""Create a SamplingParams instance from a dictionary.
Args:
req_dict: Dictionary containing sampling parameters where keys match
the field names of SamplingParams
Returns:
SamplingParams: A new instance initialized with values from the dictionary
"""
return cls(**{
field.name: req_dict[field.name] if field.name in req_dict else field.default
for field in fields(cls)
})
"""Create instance from command line arguments"""
return cls(
**{
field.name:
req_dict[field.name] if field.name in
req_dict else field.default
for field in fields(cls)
})
@classmethod
def from_optional(cls,
n,
best_of,
presence_penalty,
frequency_penalty,
repetition_penalty,
temperature,
top_p,
seed=None,
stop=None,
stop_token_ids=None,
max_tokens=None,
min_tokens=1,
logprobs=None,
bad_words=None
) -> "SamplingParams":
"""Create a SamplingParams instance from optional arguments with default fallbacks.
Args:
n: Number of output sequences (default: 1)
best_of: Number of sequences to generate before selecting best (default: None)
presence_penalty: Penalty for new tokens (default: 0.0)
frequency_penalty: Penalty based on token frequency (default: 0.0)
repetition_penalty: Penalty for repeated tokens (default: 1.0)
temperature: Sampling temperature (default: 1.0)
top_p: Nucleus sampling probability (default: 0.7)
seed: Random seed (default: random)
stop: Stop sequences (default: None)
stop_token_ids: Stop token IDs (default: None)
max_tokens: Maximum tokens to generate (default: 8192)
min_tokens: Minimum tokens before stopping (default: 1)
logprobs: Number of logprobs to return (default: None)
bad_words: List of banned words (default: None)
Returns:
SamplingParams: A new instance with provided or default values
"""
return cls(
n=1 if n is None else n,
best_of=best_of,
presence_penalty=presence_penalty if presence_penalty is not None else 0.0,
frequency_penalty=frequency_penalty if frequency_penalty is not None else 0.0,
repetition_penalty=repetition_penalty if repetition_penalty is not None else 1.0,
temperature=temperature if temperature is not None else 1.0,
top_p=top_p if top_p is not None else 0.7,
seed=seed,
stop=stop,
stop_token_ids=stop_token_ids,
max_tokens=max_tokens if max_tokens is not None else 8192,
min_tokens=min_tokens,
logprobs=logprobs,
bad_words=bad_words
)
n,
best_of,
presence_penalty,
frequency_penalty,
repetition_penalty,
temperature,
top_p,
seed=None,
stop=None,
stop_token_ids=None,
max_tokens=None,
reasoning_max_tokens=None,
min_tokens=1,
logprobs=None,
bad_words=None) -> "SamplingParams":
"""Create instance from command line arguments"""
return cls(n=1 if n is None else n,
best_of=best_of,
presence_penalty=presence_penalty
if presence_penalty is not None else 0.0,
frequency_penalty=frequency_penalty
if frequency_penalty is not None else 0.0,
repetition_penalty=repetition_penalty
if repetition_penalty is not None else 1.0,
temperature=temperature if temperature is not None else 1.0,
top_p=top_p if top_p is not None else 0.7,
seed=seed,
stop=stop,
stop_token_ids=stop_token_ids,
max_tokens=max_tokens if max_tokens is not None else 8192,
reasoning_max_tokens=reasoning_max_tokens,
min_tokens=min_tokens,
logprobs=logprobs,
bad_words=bad_words)
def __post_init__(self):
"""Initialize sampling parameters after instance creation.
Sets a random seed if none provided and validates all parameters.
"""
if self.seed is None:
self.seed = random.randint(0, 922337203685477580)
if self.max_tokens is not None and self.reasoning_max_tokens is None:
self.reasoning_max_tokens = max(int(self.max_tokens * 0.8), 1)
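# e.g. max_tokens=1000 -> reasoning_max_tokens defaults to 800, i.e. 80% of
# the overall budget (illustrative value)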
self._verify_args()
def _verify_args(self) -> None:
"""Validate all sampling parameters.
Raises:
ValueError: If any parameter is outside its valid range or of incorrect type
"""
if not isinstance(self.n, int):
raise ValueError(f"n must be an int, but is of type {type(self.n)}")
raise ValueError(
f"n must be an int, but is of type {type(self.n)}")
if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.")
if not -2.0 <= self.presence_penalty <= 2.0:
if self.presence_penalty is not None and (
not -2.0 <= self.presence_penalty <= 2.0):
raise ValueError("presence_penalty must be in [-2, 2], got "
f"{self.presence_penalty}.")
if not -2.0 <= self.frequency_penalty <= 2.0:
if self.frequency_penalty is not None and (
not -2.0 <= self.frequency_penalty <= 2.0):
raise ValueError("frequency_penalty must be in [-2, 2], got "
f"{self.frequency_penalty}.")
if self.repetition_penalty <= 0.0:
if self.repetition_penalty is not None and self.repetition_penalty <= 0.0:
raise ValueError(
"repetition_penalty must be greater than zero, got "
f"{self.repetition_penalty}.")
if self.temperature < 0.0:
if self.temperature is not None and self.temperature < 0.0:
raise ValueError(
f"temperature must be non-negative, got {self.temperature}.")
if not 0.0 <= self.top_p <= 1.0:
if self.top_p is not None and not 0.0 <= self.top_p <= 1.0:
raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.")
if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError(
f"max_tokens must be at least 1, got {self.max_tokens}.")
if self.reasoning_max_tokens is not None and self.reasoning_max_tokens > self.max_tokens:
raise ValueError(
f"reasoning_max_tokens must be less than max_tokens, got {self.reasoning_max_tokens}.")
if self.min_tokens < 0:
raise ValueError(f"min_tokens must be greater than or equal to 0, "
f"got {self.min_tokens}.")
@@ -215,33 +193,17 @@ class SamplingParams:
raise ValueError("seed must be in [0, 922337203685477580], got "
f"{self.seed}.")
def update_from_tokenizer(self, tokenizer):
"""Update sampling parameters based on tokenizer configuration.
Note: Currently a placeholder for future implementation of:
- Stop tokens handling
- Bad words filtering
Args:
tokenizer: The tokenizer instance to use for configuration
"""
# TODO: Implement stop tokens and bad words support
# Currently stop tokens and bad words are not supported yet
"""
pass
@dataclass
class BeamSearchParams:
"""Parameters for beam search text generation.
Args:
beam_width: Number of beams to maintain during search
max_tokens: Maximum number of tokens to generate
ignore_eos: Whether to ignore EOS tokens (default: False)
temperature: Sampling temperature (0 means greedy, default: 0.0)
length_penalty: Penalty applied to length (1.0 means no penalty, default: 1.0)
include_stop_str_in_output: Whether to include stop strings in output (default: False)
"""
"""Beam search parameters for text generation."""
beam_width: int
max_tokens: int
ignore_eos: bool = False