Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)

[Feature] Support block scheduler v1 for FD (#2928)

* Support FD block scheduler v1
* Support FD block scheduler v1
* Support FD block scheduler v1
* Fix according to copilot review
* Fix according to review
* Remove is_dummy
* Fix bug when real_bsz=1
* Fix infer first token cost time

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
@@ -130,6 +130,11 @@ class EngineArgs:
    Ratio of tokens to process in a block.
    """

    prealloc_dec_block_slot_num_threshold: int = 5
    """
    Token slot threshold for preallocating decoder blocks.
    """

    dist_init_ip: Optional[str] = None
    """
    The master node ip of multinode deployment
@@ -525,10 +530,14 @@ class EngineArgs:
        )

        cache_group.add_argument(
            "--swap-space",
            type=float,
            default=EngineArgs.swap_space,
            help="The amount of CPU memory to offload to.",
            "--swap-space", type=float, default=EngineArgs.swap_space, help="The amount of CPU memory to offload to."
        )

        cache_group.add_argument(
            "--prealloc-dec-block-slot-num-threshold",
            type=int,
            default=5,
            help="Token slot threshold for allocating the next blocks for decoding.",
        )

        cache_group.add_argument(
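For illustration, a minimal sketch of setting these cache-related knobs when building engine arguments programmatically rather than through the CLI parser above; the model path and the direct EngineArgs construction are assumptions for the example, not part of this change.

    from fastdeploy.engine.args_utils import EngineArgs

    # Hypothetical values; only the two cache-related options shown in the hunk
    # above are exercised here.
    args = EngineArgs(
        model="./path/to/model",                   # placeholder, not from this diff
        swap_space=4,                              # CPU memory (GB) to offload KV cache to
        prealloc_dec_block_slot_num_threshold=5,   # preallocate decode blocks at <= 5 free slots
    )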
@@ -784,6 +793,7 @@ class EngineArgs:
            gpu_memory_utilization=self.gpu_memory_utilization,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            kv_cache_ratio=self.kv_cache_ratio,
            prealloc_dec_block_slot_num_threshold=self.prealloc_dec_block_slot_num_threshold,
            enable_prefix_caching=self.enable_prefix_caching,
            swap_space=self.swap_space,
            cache_queue_port=self.cache_queue_port,
@@ -171,6 +171,7 @@ class CacheConfig:
            Overrides profiled num_gpu_blocks if provided.
        kv_cache_ratio (float): Ratio for calculating the maximum block number.
        enc_dec_block_num (int): Number of encoder-decoder blocks.
        prealloc_dec_block_slot_num_threshold (int): Token slot threshold for allocating the next blocks for decoding.
        enable_prefix_caching (bool): Flag to enable prefix caching.
    """

@@ -183,6 +184,7 @@ class CacheConfig:
        swap_space: Optional[int] = None,
        kv_cache_ratio: float = 0.75,
        enc_dec_block_num: int = 2,
        prealloc_dec_block_slot_num_threshold: int = 5,
        tensor_parallel_size: int = 1,
        enable_prefix_caching=False,
        enable_ssd_cache=False,
@@ -204,6 +206,7 @@ class CacheConfig:
            num_cpu_blocks (Optional[int]): Number of CPU blocks.
            kv_cache_ratio (float): Ratio for max block calculation.
            enc_dec_block_num (int): Number of encoder-decoder blocks.
            prealloc_dec_block_slot_num_threshold (int): Token slot threshold for allocating the next blocks for decoding, used when ENABLE_V1_KVCACHE_SCHEDULER=1.
            enable_prefix_caching (bool): Enable prefix caching.
        """
        self.block_size = block_size
@@ -211,6 +214,7 @@ class CacheConfig:
        self.num_gpu_blocks_override = num_gpu_blocks_override
        self.kv_cache_ratio = kv_cache_ratio
        self.enc_dec_block_num = enc_dec_block_num
        self.prealloc_dec_block_slot_num_threshold = prealloc_dec_block_slot_num_threshold
        self.cache_dtype = cache_dtype
        if hasattr(model_cfg, "quantization_config"):
            self.cache_dtype = model_cfg.quantization_config.get("kv_cache_quant_type", cache_dtype)
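To make the new threshold concrete, here is a small illustrative sketch of the check the v1 scheduler performs; the same comparison appears in ResourceManagerV1.schedule() later in this diff, and all numbers below are made up for the example.

    block_size = 64                                     # illustrative value
    prealloc_dec_block_slot_num_threshold = 5
    enc_dec_block_num = 2

    block_tables = [0, 1, 2]                            # 3 blocks held by one decoding request
    allocated_slots = len(block_tables) * block_size    # 192 token slots
    num_total_tokens = 188                              # prompt + generated tokens so far

    # Mirrors the condition in ResourceManagerV1.schedule():
    if allocated_slots - num_total_tokens <= prealloc_dec_block_slot_num_threshold:
        # Only 4 free slots remain, so enc_dec_block_num (2) more blocks would be
        # preallocated for this request before its decoding continues.
        pass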
@@ -28,6 +28,7 @@ import time
import traceback
import uuid
import weakref
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple

import numpy as np
@@ -40,6 +41,7 @@ from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.expert_service import start_expert_service
from fastdeploy.engine.request import Request, RequestOutput
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import (
    EngineCacheQueue,
@@ -52,7 +54,7 @@ from fastdeploy.metrics.trace_util import start_span, start_span_request
from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.utils import EngineError, console_logger, llm_logger
from fastdeploy.utils import EngineError, console_logger, envs, llm_logger


class LLMEngine:
@@ -108,7 +110,18 @@ class LLMEngine:

        self.start_queue_service()

        self.resource_manager = ResourceManager(cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role)
        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
            self.resource_manager = ResourceManagerV1(
                cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
            )
            if cfg.splitwise_role != "mixed":
                raise NotImplementedError(
                    "Currently ENABLE_V1_KVCACHE_SCHEDULER=1 is only supported when splitwise_role is mixed."
                )
        else:
            self.resource_manager = ResourceManager(
                cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
            )

        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port)

@@ -203,7 +216,10 @@ class LLMEngine:

        self.token_processor.tasks_queue = self.engine_worker_queue

        self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True)
        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
            self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True)
        else:
            self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True)
        self.insert_task_to_worker_thread.start()

        if self.api_server_pid is not None:
@@ -343,6 +359,56 @@ class LLMEngine:
            err_msg = f"Error happened while inserting task to engine: {e}, {traceback.format_exc()!s}."
            llm_logger.error(err_msg)

    def _scheduler_task_to_worker_v1(self):
        """
        Insert tasks to the worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1).
        """
        get_request_pool = ThreadPoolExecutor(max_workers=1)
        is_fetching = False

        def _fetch_request():
            nonlocal is_fetching
            is_fetching = True
            num_prefill_batch = min(
                int(self.resource_manager.available_batch()),
                self.cfg.max_prefill_batch,
            )
            tasks = self.scheduler.get_requests(
                available_blocks=self.resource_manager.available_block_num(),
                block_size=self.cfg.cache_config.block_size,
                reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
                max_num_batched_tokens=self.cfg.max_model_len,
                batch=num_prefill_batch,
            )
            # Fetch requests and add them to the scheduling queue
            for task in tasks:
                self.resource_manager.add_request(task)
            is_fetching = False

        while self.running:
            try:
                if self.engine_worker_queue.num_tasks() > 0:
                    time.sleep(0.001)
                    continue
                if (
                    len(self.resource_manager.waiting) == 0
                    and (not is_fetching)
                    and self.exist_prefill_task_signal.value[0] == 0
                ):
                    get_request_pool.submit(_fetch_request)
                # 2. Schedule requests
                tasks = self.resource_manager.schedule()
                # 3. Send to engine
                if tasks:
                    self.resource_manager.get_real_bsz()
                    self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
                else:
                    time.sleep(0.005)

            except Exception as e:
                err_msg = "Error happened while inserting task to engine: {}, {}.".format(e, str(traceback.format_exc()))
                llm_logger.error(err_msg)

    def _insert_zmq_task_to_scheduler(self):
        if self.api_server_pid is None:
            return
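The new code path is gated entirely on envs.ENABLE_V1_KVCACHE_SCHEDULER. A minimal sketch of opting in, assuming the flag is read from the process environment when the engine starts:

    import os

    # Assumption: fastdeploy's envs module picks this variable up from the environment.
    os.environ["ENABLE_V1_KVCACHE_SCHEDULER"] = "1"

    # With the flag set, LLMEngine builds a ResourceManagerV1 and runs
    # _scheduler_task_to_worker_v1; with it unset, the engine keeps the original
    # ResourceManager and the _insert_task_to_worker path.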
@@ -18,6 +18,7 @@ from __future__ import annotations

import time
from dataclasses import asdict, dataclass, fields
from enum import Enum
from typing import Any, Dict, Optional, Union

import numpy as np
@@ -27,6 +28,19 @@ from fastdeploy.utils import data_processor_logger
from fastdeploy.worker.output import LogprobsLists


class RequestStatus(Enum):
    WAITING = 0
    RUNNING = 1
    PREEMPTED = 2
    FINISHED = 3


class RequestType(Enum):
    PREFILL = 0
    DECODE = 1
    PREEMPTED = 2


@dataclass
class Request:
    def __init__(
@@ -93,6 +107,15 @@ class Request:
        self.enable_thinking = enable_thinking
        self.trace_carrier = trace_carrier

        # token num
        self.block_tables = []
        self.output_token_ids = []
        self.num_computed_tokens = 0
        # status
        self.status = RequestStatus.WAITING
        self.task_type = RequestType.PREFILL
        self.idx = None

    @classmethod
    def from_dict(cls, d: dict):
        data_processor_logger.debug(f"{d}")
@@ -125,6 +148,21 @@ class Request:
            trace_carrier=d.get("trace_carrier", {}),
        )

    @property
    def num_total_tokens(self):
        """
        Total tokens of the request, including prompt tokens and generated tokens.
        """
        return self.prompt_token_ids_len + len(self.output_token_ids)

    def __eq__(self, other):
        """
        EQ operator.
        """
        if not isinstance(other, Request):
            return False
        return self.request_id == other.request_id

    def to_dict(self) -> dict:
        """convert Request into a serializable dict"""
        data = {
fastdeploy/engine/sched/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""

fastdeploy/engine/sched/resource_manager_v1.py (new file, 261 lines)
@@ -0,0 +1,261 @@
import threading
import time
from collections import deque
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Union

from fastdeploy.engine.request import Request, RequestStatus, RequestType
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.utils import llm_logger


@dataclass
class ScheduledDecodeTask:
    """
    Task for allocating new blocks to decode.
    """

    idx: int
    request_id: str
    block_tables: list[int]
    task_type: RequestType = RequestType.DECODE


@dataclass
class ScheduledPreemptTask:
    """
    Task for terminating inference to recycle resources.
    """

    idx: int
    request_id: str
    task_type: RequestType = RequestType.PREEMPTED


class ResourceManagerV1(ResourceManager):
    """
    Resource manager for scheduler v1.
    In scheduler v1, all GPU blocks are managed by PrefixCacheManager.
    Tasks sent to the worker are divided into 3 types: PREFILL, DECODE and PREEMPTED.
    For a PREFILL task, the worker runs one step and then stops for this query if not all prompt tokens are computed.
    For a DECODE task, the worker keeps decoding until the allocated blocks are exhausted.
    For a PREEMPTED task, the worker resets all inputs to terminate the inference.
    """

    def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, local_data_parallel_id=0):
        super(ResourceManagerV1, self).__init__(
            max_num_seqs, config, tensor_parallel_size, splitwise_role, local_data_parallel_id
        )
        # req_id -> Request
        self.config = config
        self.requests: dict[str, Request] = {}
        # Priority queues for requests.
        self.waiting: deque[Request] = deque()
        self.running: list[Request] = []
        self.finish_execution_pool = ThreadPoolExecutor(max_workers=1)
        self.lock = threading.Lock()

    def allocated_slots(self, request: Request):
        return len(request.block_tables) * self.config.cache_config.block_size

    def get_new_block_nums(self, request: Request, num_new_tokens: int):
        return (
            request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
        ) // self.config.cache_config.block_size - len(request.block_tables)
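    # Worked example for get_new_block_nums (illustrative numbers): with block_size=64,
    # num_computed_tokens=100, num_new_tokens=60 and two blocks already in block_tables,
    # (100 + 60 + 63) // 64 - 2 == 3 - 2 == 1, i.e. one additional block is needed.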
    def _prepare_prefill_task(self, request, new_token_num):
        request.prefill_start_index = request.num_computed_tokens
        request.prefill_end_index = request.num_computed_tokens + new_token_num
        request.task_type = RequestType.PREFILL
        return request

    def _prepare_decode_task(self, request):
        return ScheduledDecodeTask(idx=request.idx, request_id=request.request_id, block_tables=request.block_tables)

    def _prepare_preempt_task(self, request):
        return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)

    def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
        can_schedule = True
        while True:
            if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
                preempted_req = self.running.pop()
                preempted_req.status = RequestStatus.PREEMPTED
                preempted_req.num_computed_tokens = 0
                self._free_blocks(preempted_req)
                self.waiting.appendleft(preempted_req)
                preempted_reqs.append(preempted_req)
                scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
                if preempted_req == request:
                    # No more requests to preempt.
                    can_schedule = False
                    break
            else:
                # The request can be scheduled.
                can_schedule = True
                break
        return can_schedule

    def schedule(self):
        with self.lock:
            scheduled_reqs: list[Request] = []
            preempted_reqs: list[Request] = []
            token_budget = self.config.max_num_batched_tokens

            # First, schedule the RUNNING requests.
            req_index = 0
            num_decoding_req_nums = 0
            while req_index < len(self.running) and token_budget > 0:
                request = self.running[req_index]
                if request.num_computed_tokens >= request.prompt_token_ids_len:  # to be decoding
                    if request.num_total_tokens > request.prompt_token_ids_len:  # has generated tokens
                        request.num_computed_tokens = request.num_total_tokens - 1
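                        # Presumably the most recently sampled token is excluded here because
                        # its KV entry has not been computed yet; this leaves exactly one token
                        # to process in the next decode step.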
                    if (
                        self.allocated_slots(request) - request.num_total_tokens
                        <= self.config.cache_config.prealloc_dec_block_slot_num_threshold
                    ):
                        # Allocation for next decoding blocks
                        if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
                            llm_logger.debug(
                                f"schedule decoding task: {request} request.num_total_tokens {request.num_total_tokens} request.num_computed_tokens {request.num_computed_tokens}"
                            )
                            request.block_tables.extend(
                                self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num)
                            )
                            # Prepare decoding task
                            scheduled_reqs.append(self._prepare_decode_task(request))
                        else:
                            # Not enough blocks to allocate, trigger preemption
                            can_schedule = self._trigger_preempt(
                                request, self.config.cache_config.enc_dec_block_num, preempted_reqs, scheduled_reqs
                            )
                            if not can_schedule:
                                break
                            # Allocation for next decoding blocks
                            request.block_tables.extend(
                                self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num)
                            )
                            # Prepare decoding task
                            scheduled_reqs.append(self._prepare_decode_task(request))
                    num_decoding_req_nums += 1
                    token_budget -= 1
                else:  # need to prefill
                    llm_logger.debug(
                        f"scheduler prefill task: {request} request.prompt_token_ids_len {request.prompt_token_ids_len} request.num_computed_tokens {request.num_computed_tokens}"
                    )
                    num_new_tokens = request.prompt_token_ids_len - request.num_computed_tokens
                    num_new_tokens = min(num_new_tokens, token_budget)
                    num_new_block = self.get_new_block_nums(request, num_new_tokens)
                    # Allocate blocks to prefill
                    if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
                        request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
                        # Prepare prefill task
                        scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                    else:
                        can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs)
                        if not can_schedule:
                            break
                        request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
                        # Prepare prefill task
                        scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                    token_budget -= num_new_tokens
                    request.num_computed_tokens += num_new_tokens
                req_index += 1
            # Schedule the WAITING requests.
            if not preempted_reqs:
                while self.waiting and token_budget > 0:
                    if len(self.running) == self.max_num_seqs:
                        break
                    request = self.waiting[0]
                    if request.status == RequestStatus.WAITING:
                        num_new_tokens = request.num_total_tokens - request.num_computed_tokens
                        num_new_tokens = min(num_new_tokens, token_budget)
                        num_new_block = self.get_new_block_nums(request, num_new_tokens)
                        # Allocate blocks to prefill
                        if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
                            request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
                            self.waiting.popleft()
                            self.running.append(request)
                            scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                            request.inference_start_time = time.time()
                            request.schedule_start_time = time.time()
                            token_budget -= num_new_tokens
                            request.num_computed_tokens += num_new_tokens
                            request.status = RequestStatus.RUNNING
                            allocated_position = self.get_available_position()
                            request.idx = allocated_position
                            self.tasks_list[allocated_position] = request
                            self.stop_flags[allocated_position] = False
                            self.req_dict[request.request_id] = allocated_position
                        else:
                            break
                    elif request.status == RequestStatus.PREEMPTED:
                        num_new_tokens = request.num_total_tokens - request.num_computed_tokens
                        num_new_tokens = min(num_new_tokens, token_budget)
                        num_new_block = self.get_new_block_nums(request, num_new_tokens)
                        # Allocate blocks to prefill
                        if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
                            request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
                            self.waiting.popleft()
                            self.running.append(request)
                            scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
                            token_budget -= num_new_tokens
                            request.num_computed_tokens += num_new_tokens
                            request.status = RequestStatus.RUNNING
                        else:
                            break
                    else:
                        llm_logger.error("Unknown request status type")
            if scheduled_reqs:
                llm_logger.debug(f"scheduled_reqs: {scheduled_reqs}")
            return scheduled_reqs

    def get_available_position(self) -> int:
        position = 0
        while position < self.max_num_seqs:
            if self.stop_flags[position] is True:
                return position
            position += 1
        raise RuntimeError("No available position for the new request")

    def get_real_bsz(self) -> int:
        for i in range(self.max_num_seqs - 1, -1, -1):
            if not self.stop_flags[i]:
                self.real_bsz = i + 1
                break
        return self.real_bsz
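    # Reading of get_real_bsz above: it scans stop_flags from the highest slot down and
    # returns (index of the last active slot) + 1, so with only slot 0 active the real
    # batch size is 1.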
    def add_request(self, request: Request) -> None:
        self.waiting.append(request)
        self.requests[request.request_id] = request

    def _free_blocks(self, request: Request):
        self.cache_manager.recycle_gpu_blocks(request.block_tables)
        request.block_tables = []

    def finish_requests_async(self, request_ids: Union[str, Iterable[str]]):
        return self.finish_execution_pool.submit(self.finish_requests, request_ids)

    def finish_requests(self, request_ids: Union[str, Iterable[str]]):
        llm_logger.info(f"recycle resources for requests: {request_ids}")
        try:
            with self.lock:
                if isinstance(request_ids, str):
                    request_ids = (request_ids,)
                else:
                    request_ids = set(request_ids)
                for req_id in request_ids:
                    request = self.requests.get(req_id)
                    if request is None:
                        # Invalid request ID.
                        continue
                    request.status = RequestStatus.FINISHED
                    self.running.remove(request)
                    self._free_blocks(request)
                    self.tasks_list[request.idx] = None
                    self.stop_flags[request.idx] = True
                    del self.requests[req_id]
        except Exception as e:
            llm_logger.error(e)