diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index c20455e06..d74a77121 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -284,6 +284,32 @@ void UpdateInputes(const paddle::Tensor &stop_flags,
                    const paddle::Tensor &next_tokens,
                    const paddle::Tensor &is_block_step);
 
+void UpdateInputesV1(const paddle::Tensor &stop_flags,
+                     const paddle::Tensor &not_need_stop,  // only on cpu
+                     const paddle::Tensor &seq_lens_this_time,
+                     const paddle::Tensor &seq_lens_encoder,
+                     const paddle::Tensor &seq_lens_decoder,
+                     const paddle::Tensor &step_seq_lens_decoder,
+                     const paddle::Tensor &prompt_lens,
+                     const paddle::Tensor &topk_ids,
+                     const paddle::Tensor &input_ids,
+                     const paddle::Tensor &block_tables,
+                     const paddle::Tensor &stop_nums,
+                     const paddle::Tensor &next_tokens,
+                     const paddle::Tensor &is_block_step,
+                     const int block_size);
+
+void RecoverDecodeTask(const paddle::Tensor &stop_flags,
+                       const paddle::Tensor &seq_lens_this_time,
+                       const paddle::Tensor &seq_lens_encoder,
+                       const paddle::Tensor &seq_lens_decoder,
+                       const paddle::Tensor &step_seq_lens_decoder,
+                       const paddle::Tensor &block_tables,
+                       const paddle::Tensor &is_block_step,
+                       const int block_size);
+
+
+
 paddle::Tensor GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor,
                                      const paddle::Tensor &token_nums_per_expert);
 
@@ -941,6 +967,18 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   */
  m.def("update_inputs", &UpdateInputes, "update_inputs function");
 
+  /**
+   * update_inputs_v1.cu
+   * update_inputs_v1
+   */
+  m.def("update_inputs_v1", &UpdateInputesV1, "update inputs for scheduler v1 function");
+
+  /**
+   * recover_decode_task.cu
+   * recover_decode_task
+   */
+  m.def("recover_decode_task", &RecoverDecodeTask, "recover decode task for scheduler v1 function");
+
   /**
    * extract_text_token_output.cu
    * extract_text_token_output
diff --git a/custom_ops/gpu_ops/recover_decode_task.cu b/custom_ops/gpu_ops/recover_decode_task.cu
new file mode 100644
index 000000000..88c7dd51c
--- /dev/null
+++ b/custom_ops/gpu_ops/recover_decode_task.cu
@@ -0,0 +1,91 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
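+
+// recover_decode_task: re-activates requests that were parked by the v1 KV-cache
+// scheduler (is_block_step == true). Once the block table again holds a block for
+// position step_seq_lens_decoder / block_size, the request is switched back to
+// decoding: is_block_step and stop_flags are cleared, seq_lens_this_time is set
+// to 1, and seq_lens_decoder is restored from step_seq_lens_decoder.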
+
+#include "helper.h"
+
+__global__ void recover_decode_task(bool *stop_flags,
+                                    int *seq_lens_this_time,
+                                    int *seq_lens_encoder,
+                                    int *seq_lens_decoder,
+                                    int *step_seq_lens_decoder,
+                                    int *block_tables,
+                                    bool *is_block_step,
+                                    const int bsz,
+                                    const int block_num_per_seq,
+                                    const int block_size) {
+  int thread_idx = threadIdx.x;
+  if (thread_idx < bsz) {
+    if (is_block_step[thread_idx] == true) {
+      int *block_table_now = block_tables + thread_idx * block_num_per_seq;
+      if (block_table_now[step_seq_lens_decoder[thread_idx] / block_size] != -1) {
+        // can be recovered for decoding
+        is_block_step[thread_idx] = false;
+        seq_lens_this_time[thread_idx] = 1;
+        stop_flags[thread_idx] = false;
+        seq_lens_encoder[thread_idx] = 0;
+        seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
+      }
+    }
+  }
+}
+
+void RecoverDecodeTask(const paddle::Tensor &stop_flags,
+                       const paddle::Tensor &seq_lens_this_time,
+                       const paddle::Tensor &seq_lens_encoder,
+                       const paddle::Tensor &seq_lens_decoder,
+                       const paddle::Tensor &step_seq_lens_decoder,
+                       const paddle::Tensor &block_tables,
+                       const paddle::Tensor &is_block_step,
+                       const int block_size) {
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place()));
+  auto cu_stream = dev_ctx->stream();
+#else
+  auto cu_stream = seq_lens_this_time.stream();
+#endif
+  const int bsz = seq_lens_this_time.shape()[0];
+  const int block_num_per_seq = block_tables.shape()[1];
+  recover_decode_task<<<1, 1024, 0, cu_stream>>>(
+      const_cast<bool *>(stop_flags.data<bool>()),
+      const_cast<int *>(seq_lens_this_time.data<int>()),
+      const_cast<int *>(seq_lens_encoder.data<int>()),
+      const_cast<int *>(seq_lens_decoder.data<int>()),
+      const_cast<int *>(step_seq_lens_decoder.data<int>()),
+      const_cast<int *>(block_tables.data<int>()),
+      const_cast<bool *>(is_block_step.data<bool>()),
+      bsz,
+      block_num_per_seq,
+      block_size);
+}
+
+PD_BUILD_STATIC_OP(recover_decode_task)
+    .Inputs({"stop_flags",
+             "seq_lens_this_time",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "step_seq_lens_decoder",
+             "block_tables",
+             "is_block_step"})
+    .Attrs({"block_size: int"})
+    .Outputs({"seq_lens_this_time_out",
+              "seq_lens_encoder_out",
+              "seq_lens_decoder_out",
+              "stop_flags_out",
+              "is_block_step_out"})
+    .SetInplaceMap({{"seq_lens_this_time", "seq_lens_this_time_out"},
+                    {"seq_lens_encoder", "seq_lens_encoder_out"},
+                    {"seq_lens_decoder", "seq_lens_decoder_out"},
+                    {"stop_flags", "stop_flags_out"},
+                    {"is_block_step", "is_block_step_out"}})
+    .SetKernelFn(PD_KERNEL(RecoverDecodeTask));
diff --git a/custom_ops/gpu_ops/update_inputs_v1.cu b/custom_ops/gpu_ops/update_inputs_v1.cu
new file mode 100644
index 000000000..9229fdcf0
--- /dev/null
+++ b/custom_ops/gpu_ops/update_inputs_v1.cu
@@ -0,0 +1,176 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
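+
+// update_inputs_v1: post-sampling input update used with the v1 KV-cache
+// scheduler. Unlike update_inputs, it advances seq_lens_decoder by the number
+// of tokens processed this step and compares against prompt_lens to decide
+// whether a request is still prefilling or has entered decoding. If the next
+// decode position has no block in block_tables, the request is parked
+// (is_block_step = true, stop_flags = true) and its decode length is saved to
+// step_seq_lens_decoder so that recover_decode_task can resume it later.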
+
+#include "helper.h"
+
+template <int THREADBLOCK_SIZE>
+__global__ void update_inputs_kernel_v1(bool *not_need_stop,
+                                        int *seq_lens_this_time,
+                                        int *seq_lens_encoder,
+                                        int *seq_lens_decoder,
+                                        int *step_seq_lens_decoder,
+                                        int64_t *prompt_lens,
+                                        int64_t *topk_ids,
+                                        int64_t *input_ids,
+                                        int *block_tables,
+                                        const int64_t *stop_nums,
+                                        bool *stop_flags,
+                                        bool *is_block_step,
+                                        const int64_t *next_tokens,
+                                        const int bsz,
+                                        const int max_bsz,
+                                        const int input_ids_stride,
+                                        const int block_num_per_seq,
+                                        const int block_size) {
+  int thread_idx = threadIdx.x;
+  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  bool stop_flag_now = false;
+  int64_t stop_flag_now_int = 0;
+  if (thread_idx < max_bsz) {
+    if (thread_idx < bsz) {
+      stop_flag_now = stop_flags[thread_idx];
+      stop_flag_now_int = static_cast<int64_t>(stop_flag_now);
+    } else {
+      stop_flag_now_int = 1;
+    }
+  }
+  if (thread_idx < bsz) {
+    if (stop_flag_now) {
+      seq_lens_this_time[thread_idx] = 0;  // stop at next step
+      seq_lens_decoder[thread_idx] = 0;
+      seq_lens_encoder[thread_idx] = 0;
+    } else {
+      if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= prompt_lens[thread_idx]) {
+        // decoding
+        seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
+        seq_lens_this_time[thread_idx] = 1;
+        seq_lens_encoder[thread_idx] = 0;
+        int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
+        input_ids_now[0] = next_tokens[thread_idx];
+
+        // to judge whether block is not enough
+        int *block_table_now = block_tables + thread_idx * block_num_per_seq;
+        if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) {
+          // should be scheduled by server
+          is_block_step[thread_idx] = true;
+          seq_lens_this_time[thread_idx] = 0;
+          stop_flags[thread_idx] = true;
+          step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
+          seq_lens_decoder[thread_idx] = 0;
+          stop_flag_now_int = 1;
+        }
+      } else
+      {
+        stop_flags[thread_idx] = true;
+        seq_lens_this_time[thread_idx] = 0;
+        seq_lens_decoder[thread_idx] = 0;
+        seq_lens_encoder[thread_idx] = 0;
+        topk_ids[thread_idx] = -1;
+        stop_flag_now_int = 1;
+      }
+    }
+  }
+  __syncthreads();
+  int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
+  if (thread_idx == 0) {
+    not_need_stop[0] = stop_sum < stop_nums[0];
+  }
+}
+
+void UpdateInputesV1(const paddle::Tensor &stop_flags,
+                     const paddle::Tensor &not_need_stop,  // only on cpu
+                     const paddle::Tensor &seq_lens_this_time,
+                     const paddle::Tensor &seq_lens_encoder,
+                     const paddle::Tensor &seq_lens_decoder,
+                     const paddle::Tensor &step_seq_lens_decoder,
+                     const paddle::Tensor &prompt_lens,
+                     const paddle::Tensor &topk_ids,
+                     const paddle::Tensor &input_ids,
+                     const paddle::Tensor &block_tables,
+                     const paddle::Tensor &stop_nums,
+                     const paddle::Tensor &next_tokens,
+                     const paddle::Tensor &is_block_step,
+                     const int block_size) {
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(input_ids.place()));
+  auto cu_stream = dev_ctx->stream();
+#else
+  auto cu_stream = input_ids.stream();
+#endif
+  const int max_bsz = stop_flags.shape()[0];
+  const int now_bsz = seq_lens_this_time.shape()[0];
+  const int input_ids_stride = input_ids.shape()[1];
+  const int block_num_per_seq = block_tables.shape()[1];
+  auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
+  update_inputs_kernel_v1<1024><<<1, 1024, 0, cu_stream>>>(
+      const_cast<bool *>(not_need_stop_gpu.data<bool>()),
+      const_cast<int *>(seq_lens_this_time.data<int>()),
+      const_cast<int *>(seq_lens_encoder.data<int>()),
+      const_cast<int *>(seq_lens_decoder.data<int>()),
+      const_cast<int *>(step_seq_lens_decoder.data<int>()),
+      const_cast<int64_t *>(prompt_lens.data<int64_t>()),
+      const_cast<int64_t *>(topk_ids.data<int64_t>()),
+      const_cast<int64_t *>(input_ids.data<int64_t>()),
+      const_cast<int *>(block_tables.data<int>()),
+      stop_nums.data<int64_t>(),
+      const_cast<bool *>(stop_flags.data<bool>()),
+      const_cast<bool *>(is_block_step.data<bool>()),
+      next_tokens.data<int64_t>(),
+      now_bsz,
+      max_bsz,
+      input_ids_stride,
+      block_num_per_seq,
+      block_size);
+  auto not_need_stop_cpu =
+      not_need_stop_gpu.copy_to(not_need_stop.place(), false);
+  bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
+  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
+}
+
+PD_BUILD_STATIC_OP(update_inputs_v1)
+    .Inputs({"stop_flags",
+             "not_need_stop",
+             "seq_lens_this_time",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "step_seq_lens_decoder",
+             "prompt_lens",
+             "topk_ids",
+             "input_ids",
+             "block_tables",
+             "stop_nums",
+             "next_tokens",
+             "is_block_step"})
+    .Attrs({"block_size: int"})
+    .Outputs({"not_need_stop_out",
+              "seq_lens_this_time_out",
+              "seq_lens_encoder_out",
+              "seq_lens_decoder_out",
+              "step_seq_lens_decoder_out",
+              "topk_ids_out",
+              "input_ids_out",
+              "stop_flags_out",
+              "is_block_step_out"})
+    .SetInplaceMap({{"not_need_stop", "not_need_stop_out"},
+                    {"seq_lens_this_time", "seq_lens_this_time_out"},
+                    {"seq_lens_encoder", "seq_lens_encoder_out"},
+                    {"seq_lens_decoder", "seq_lens_decoder_out"},
+                    {"topk_ids", "topk_ids_out"},
+                    {"input_ids", "input_ids_out"},
+                    {"stop_flags", "stop_flags_out"},
+                    {"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
+                    {"is_block_step", "is_block_step_out"}})
+    .SetKernelFn(PD_KERNEL(UpdateInputesV1));
diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py
index 6e25b6b13..8307330e9 100644
--- a/custom_ops/setup_ops.py
+++ b/custom_ops/setup_ops.py
@@ -262,6 +262,8 @@ elif paddle.is_compiled_with_cuda():
             "gpu_ops/stop_generation_multi_ends.cu",
             "gpu_ops/stop_generation_multi_stop_seqs.cu",
             "gpu_ops/set_flags.cu",
+            "gpu_ops/update_inputs_v1.cu",
+            "gpu_ops/recover_decode_task.cu",
             "gpu_ops/step.cu",
             "gpu_ops/step_reschedule.cu",
             "gpu_ops/fused_get_rope.cu",
diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py
index e64dbb5ae..0114914bb 100644
--- a/fastdeploy/cache_manager/prefix_cache_manager.py
+++ b/fastdeploy/cache_manager/prefix_cache_manager.py
@@ -210,8 +210,16 @@ class PrefixCacheManager:
         update cache config
         """
         self.cache_config = cache_config
-        self.num_gpu_blocks = cache_config.prefill_kvcache_block_num
-        self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1))  # 服务端管理的GPU上剩余的block id
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            self.num_gpu_blocks = cache_config.total_block_num
+            self.gpu_free_block_list = list(
+                range(self.num_gpu_blocks - 1, -1, -1)
+            )  # All gpu blocks are managed by the cache manager
+        else:
+            self.num_gpu_blocks = cache_config.prefill_kvcache_block_num
+            self.gpu_free_block_list = list(
+                range(self.num_gpu_blocks - 1, -1, -1)
+            )  # Only the blocks reserved for prefill are managed by the server
         heapq.heapify(self.gpu_free_block_list)
         self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
 
@@ -231,6 +239,15 @@
             self.transfer_recv_thread = threading.Thread(target=self.recv_data_transfer_result)
             self.transfer_recv_thread.start()
 
+    def can_allocate_gpu_blocks(self, num_blocks: int):
+        """
+        Check if num_blocks gpu blocks can be allocated.
+        """
+        if len(self.gpu_free_block_list) < num_blocks:
+            return False
+        else:
+            return True
+
     def allocate_gpu_blocks(self, num_blocks):
         """
         allocate gpu blocks.
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 633b6837c..b1e464a6a 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -130,6 +130,11 @@ class EngineArgs:
     Ratio of tokens to process in a block.
     """
 
+    prealloc_dec_block_slot_num_threshold: int = 5
+    """
+    Token slot threshold for preallocating decoder blocks.
+    """
+
     dist_init_ip: Optional[str] = None
     """
     The master node ip of multinode deployment
@@ -525,10 +530,14 @@
         )
 
         cache_group.add_argument(
-            "--swap-space",
-            type=float,
-            default=EngineArgs.swap_space,
-            help="The amount of CPU memory to offload to.",
+            "--swap-space", type=float, default=EngineArgs.swap_space, help="The amount of CPU memory to offload to."
+        )
+
+        cache_group.add_argument(
+            "--prealloc-dec-block-slot-num-threshold",
+            type=int,
+            default=5,
+            help="Token slot threshold for allocating the next blocks of a decoding request.",
         )
 
         cache_group.add_argument(
@@ -784,6 +793,7 @@
             gpu_memory_utilization=self.gpu_memory_utilization,
             num_gpu_blocks_override=self.num_gpu_blocks_override,
             kv_cache_ratio=self.kv_cache_ratio,
+            prealloc_dec_block_slot_num_threshold=self.prealloc_dec_block_slot_num_threshold,
             enable_prefix_caching=self.enable_prefix_caching,
             swap_space=self.swap_space,
             cache_queue_port=self.cache_queue_port,
diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py
index 9c559ce32..d8ebb38f0 100644
--- a/fastdeploy/engine/config.py
+++ b/fastdeploy/engine/config.py
@@ -171,6 +171,7 @@ class CacheConfig:
                            Overrides profiled num_gpu_blocks if provided.
         kv_cache_ratio (float): Ratio for calculating the maximum block number.
         enc_dec_block_num (int): Number of encoder-decoder blocks.
+        prealloc_dec_block_slot_num_threshold (int): Token slot threshold for allocating the next blocks of a decoding request.
         enable_prefix_caching (bool): Flag to enable prefix caching.
     """
 
@@ -183,6 +184,7 @@
         swap_space: Optional[int] = None,
         kv_cache_ratio: float = 0.75,
         enc_dec_block_num: int = 2,
+        prealloc_dec_block_slot_num_threshold: int = 5,
         tensor_parallel_size: int = 1,
         enable_prefix_caching=False,
         enable_ssd_cache=False,
@@ -204,6 +206,7 @@
             num_cpu_blocks (Optional[int]): Number of CPU blocks.
             kv_cache_ratio (float): Ratio for max block calculation.
             enc_dec_block_num (int): Number of encoder-decoder blocks.
+            prealloc_dec_block_slot_num_threshold (int): Token slot threshold for allocating the next blocks of a decoding request, used when ENABLE_V1_KVCACHE_SCHEDULER=1.
             enable_prefix_caching (bool): Enable prefix caching.
""" self.block_size = block_size @@ -211,6 +214,7 @@ class CacheConfig: self.num_gpu_blocks_override = num_gpu_blocks_override self.kv_cache_ratio = kv_cache_ratio self.enc_dec_block_num = enc_dec_block_num + self.prealloc_dec_block_slot_num_threshold = prealloc_dec_block_slot_num_threshold self.cache_dtype = cache_dtype if hasattr(model_cfg, "quantization_config"): self.cache_dtype = model_cfg.quantization_config.get("kv_cache_quant_type", cache_dtype) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 8bdc64a7f..89070ae3e 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -28,6 +28,7 @@ import time import traceback import uuid import weakref +from concurrent.futures import ThreadPoolExecutor from typing import Dict, List, Optional, Tuple import numpy as np @@ -40,6 +41,7 @@ from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.engine.expert_service import start_expert_service from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.engine.resource_manager import ResourceManager +from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1 from fastdeploy.input.preprocess import InputPreprocessor from fastdeploy.inter_communicator import ( EngineCacheQueue, @@ -52,7 +54,7 @@ from fastdeploy.metrics.trace_util import start_span, start_span_request from fastdeploy.model_executor.guided_decoding import schema_checker from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector -from fastdeploy.utils import EngineError, console_logger, llm_logger +from fastdeploy.utils import EngineError, console_logger, envs, llm_logger class LLMEngine: @@ -108,7 +110,18 @@ class LLMEngine: self.start_queue_service() - self.resource_manager = ResourceManager(cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.resource_manager = ResourceManagerV1( + cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role + ) + if cfg.splitwise_role != "mixed": + raise NotImplementedError( + "Currently ENABLE_V1_KVCACHE_SCHEDULER=1 only supported in mixed sampling now." + ) + else: + self.resource_manager = ResourceManager( + cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role + ) os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port) @@ -203,7 +216,10 @@ class LLMEngine: self.token_processor.tasks_queue = self.engine_worker_queue - self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True) + else: + self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True) self.insert_task_to_worker_thread.start() if self.api_server_pid is not None: @@ -343,6 +359,56 @@ class LLMEngine: err_msg = f"Error happend while insert task to engine: {e}, {traceback.format_exc()!s}." llm_logger.error(err_msg) + def _scheduler_task_to_worker_v1(self): + """ + Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1). 
+ """ + get_request_pool = ThreadPoolExecutor(max_workers=1) + is_fetching = False + + def _fetch_request(): + nonlocal is_fetching + is_fetching = True + num_prefill_batch = min( + int(self.resource_manager.available_batch()), + self.cfg.max_prefill_batch, + ) + tasks = self.scheduler.get_requests( + available_blocks=self.resource_manager.available_block_num(), + block_size=self.cfg.cache_config.block_size, + reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num, + max_num_batched_tokens=self.cfg.max_model_len, + batch=num_prefill_batch, + ) + # Fetch requests and add them to the scheduling queue + for task in tasks: + self.resource_manager.add_request(task) + is_fetching = False + + while self.running: + try: + if self.engine_worker_queue.num_tasks() > 0: + time.sleep(0.001) + continue + if ( + len(self.resource_manager.waiting) == 0 + and (not is_fetching) + and self.exist_prefill_task_signal.value[0] == 0 + ): + get_request_pool.submit(_fetch_request) + # 2. Schedule requests + tasks = self.resource_manager.schedule() + # 3. Send to engine + if tasks: + self.resource_manager.get_real_bsz() + self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) + else: + time.sleep(0.005) + + except Exception as e: + err_msg = "Error happend while insert task to engine: {}, {}.".format(e, str(traceback.format_exc())) + llm_logger.error(err_msg) + def _insert_zmq_task_to_scheduler(self): if self.api_server_pid is None: return diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index f31a00ce0..e1c255b49 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -18,6 +18,7 @@ from __future__ import annotations import time from dataclasses import asdict, dataclass, fields +from enum import Enum from typing import Any, Dict, Optional, Union import numpy as np @@ -27,6 +28,19 @@ from fastdeploy.utils import data_processor_logger from fastdeploy.worker.output import LogprobsLists +class RequestStatus(Enum): + WAITING = 0 + RUNNING = 1 + PREEMPTED = 2 + FINISHED = 3 + + +class RequestType(Enum): + PREFILL = 0 + DECODE = 1 + PREEMPTED = 2 + + @dataclass class Request: def __init__( @@ -93,6 +107,15 @@ class Request: self.enable_thinking = enable_thinking self.trace_carrier = trace_carrier + # token num + self.block_tables = [] + self.output_token_ids = [] + self.num_computed_tokens = 0 + # status + self.status = RequestStatus.WAITING + self.task_type = RequestType.PREFILL + self.idx = None + @classmethod def from_dict(cls, d: dict): data_processor_logger.debug(f"{d}") @@ -125,6 +148,21 @@ class Request: trace_carrier=d.get("trace_carrier", {}), ) + @property + def num_total_tokens(self): + """ + Total tokens of the request, include prompt tokens and generated tokens. + """ + return self.prompt_token_ids_len + len(self.output_token_ids) + + def __eq__(self, other): + """ + EQ operator. + """ + if not isinstance(other, Request): + return False + return self.request_id == other.request_id + def to_dict(self) -> dict: """convert Request into a serializable dict""" data = { diff --git a/fastdeploy/engine/sched/__init__.py b/fastdeploy/engine/sched/__init__.py new file mode 100644 index 000000000..f4ede9062 --- /dev/null +++ b/fastdeploy/engine/sched/__init__.py @@ -0,0 +1,15 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
new file mode 100644
index 000000000..4b99f35a9
--- /dev/null
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -0,0 +1,261 @@
+import threading
+import time
+from collections import deque
+from collections.abc import Iterable
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
+from typing import Union
+
+from fastdeploy.engine.request import Request, RequestStatus, RequestType
+from fastdeploy.engine.resource_manager import ResourceManager
+from fastdeploy.utils import llm_logger
+
+
+@dataclass
+class ScheduledDecodeTask:
+    """
+    Task for allocating new blocks to decode.
+    """
+
+    idx: int
+    request_id: str
+    block_tables: list[int]
+    task_type: RequestType = RequestType.DECODE
+
+
+@dataclass
+class ScheduledPreemptTask:
+    """
+    Task for terminating inference to recycle resources.
+    """
+
+    idx: int
+    request_id: str
+    task_type: RequestType = RequestType.PREEMPTED
+
+
+class ResourceManagerV1(ResourceManager):
+    """
+    Resource manager for scheduler v1.
+    In scheduler v1, all gpu blocks are managed by PrefixCacheManager.
+    Tasks sent to the worker are divided into three types: PREFILL, DECODE and PREEMPTED.
+    For a prefill task, the worker runs one step and then pauses the query if not all prompt tokens have been computed yet.
+    For a decode task, the worker keeps decoding until the allocated blocks are exhausted.
+    For a preempted task, the worker resets all inputs to terminate the inference.
+    """
+
+    def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, local_data_parallel_id=0):
+        super(ResourceManagerV1, self).__init__(
+            max_num_seqs, config, tensor_parallel_size, splitwise_role, local_data_parallel_id
+        )
+        # req_id -> Request
+        self.config = config
+        self.requests: dict[str, Request] = {}
+        # Priority queues for requests.
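+        # waiting: FIFO of requests that still need resources (new or preempted);
+        # running: requests currently scheduled on the worker.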
+ self.waiting: deque[Request] = deque() + self.running: list[Request] = [] + self.finish_execution_pool = ThreadPoolExecutor(max_workers=1) + self.lock = threading.Lock() + + def allocated_slots(self, request: Request): + return len(request.block_tables) * self.config.cache_config.block_size + + def get_new_block_nums(self, request: Request, num_new_tokens: int): + return ( + request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1 + ) // self.config.cache_config.block_size - len(request.block_tables) + + def _prepare_prefill_task(self, request, new_token_num): + request.prefill_start_index = request.num_computed_tokens + request.prefill_end_index = request.num_computed_tokens + new_token_num + request.task_type = RequestType.PREFILL + return request + + def _prepare_decode_task(self, request): + return ScheduledDecodeTask(idx=request.idx, request_id=request.request_id, block_tables=request.block_tables) + + def _prepare_preempt_task(self, request): + return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id) + + def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs): + can_schedule = True + while True: + if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks): + preempted_req = self.running.pop() + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + self._free_blocks(preempted_req) + self.waiting.appendleft(preempted_req) + preempted_reqs.append(preempted_req) + scheduled_reqs.append(self._prepare_preempt_task(preempted_req)) + if preempted_req == request: + # No more request to preempt. + can_schedule = False + break + else: + # The request can be scheduled. + can_schedule = True + break + return can_schedule + + def schedule(self): + with self.lock: + scheduled_reqs: list[Request] = [] + preempted_reqs: list[Request] = [] + token_budget = self.config.max_num_batched_tokens + + # First, schedule the RUNNING requests. 
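+            # A running request is either decoding (prompt fully computed), in which
+            # case it gets enc_dec_block_num extra blocks once its free slots drop to
+            # prealloc_dec_block_slot_num_threshold, or it still has prompt chunks to
+            # prefill within the remaining token budget.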
+ req_index = 0 + num_decoding_req_nums = 0 + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + if request.num_computed_tokens >= request.prompt_token_ids_len: # to be decoding + if request.num_total_tokens > request.prompt_token_ids_len: # has generated tokens + request.num_computed_tokens = request.num_total_tokens - 1 + if ( + self.allocated_slots(request) - request.num_total_tokens + <= self.config.cache_config.prealloc_dec_block_slot_num_threshold + ): + # Allocation for next decoding blocks + if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num): + llm_logger.debug( + f"schedule decoding task: {request} request.num_total_tokens {request.num_total_tokens} request.num_computed_tokens {request.num_computed_tokens}" + ) + request.block_tables.extend( + self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num) + ) + # Prepare decoding task + scheduled_reqs.append(self._prepare_decode_task(request)) + else: + # Not enough blocks to allocate, trigger preemption + can_schedule = self._trigger_preempt( + request, self.config.cache_config.enc_dec_block_num, preempted_reqs, scheduled_reqs + ) + if not can_schedule: + break + # Allocation for next decoding blocks + request.block_tables.extend( + self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num) + ) + # Prepare decoding task + scheduled_reqs.append(self._prepare_decode_task(request)) + num_decoding_req_nums += 1 + token_budget -= 1 + else: # need to prefill + llm_logger.debug( + f"scheduler prefill task: {request} request.prompt_token_ids_len {request.prompt_token_ids_len} request.num_computed_tokens {request.num_computed_tokens}" + ) + num_new_tokens = request.prompt_token_ids_len - request.num_computed_tokens + num_new_tokens = min(num_new_tokens, token_budget) + num_new_block = self.get_new_block_nums(request, num_new_tokens) + # Allocate blocks to prefill + if self.cache_manager.can_allocate_gpu_blocks(num_new_block): + request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block)) + # Prepare prefill task + scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) + else: + can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs) + if not can_schedule: + break + request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block)) + # Prepare prefill task + scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) + token_budget -= num_new_tokens + request.num_computed_tokens += num_new_tokens + req_index += 1 + # schedule the WAITING requests. 
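+            # New requests are admitted only if nothing was preempted in this step.
+            # A WAITING request additionally needs a free slot (idx) in tasks_list,
+            # while a PREEMPTED request keeps the slot it already owns and only needs
+            # its blocks re-allocated.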
+ if not preempted_reqs: + while self.waiting and token_budget > 0: + if len(self.running) == self.max_num_seqs: + break + request = self.waiting[0] + if request.status == RequestStatus.WAITING: + num_new_tokens = request.num_total_tokens - request.num_computed_tokens + num_new_tokens = min(num_new_tokens, token_budget) + num_new_block = self.get_new_block_nums(request, num_new_tokens) + # Allocate blocks to prefill + if self.cache_manager.can_allocate_gpu_blocks(num_new_block): + request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block)) + self.waiting.popleft() + self.running.append(request) + scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) + request.inference_start_time = time.time() + request.schedule_start_time = time.time() + token_budget -= num_new_tokens + request.num_computed_tokens += num_new_tokens + request.status = RequestStatus.RUNNING + allocated_position = self.get_available_position() + request.idx = allocated_position + self.tasks_list[allocated_position] = request + self.stop_flags[allocated_position] = False + self.req_dict[request.request_id] = allocated_position + else: + break + elif request.status == RequestStatus.PREEMPTED: + num_new_tokens = request.num_total_tokens - request.num_computed_tokens + num_new_tokens = min(num_new_tokens, token_budget) + num_new_block = self.get_new_block_nums(request, num_new_tokens) + # Allocate blocks to prefill + if self.cache_manager.can_allocate_gpu_blocks(num_new_block): + request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block)) + self.waiting.popleft() + self.running.append(request) + scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) + token_budget -= num_new_tokens + request.num_computed_tokens += num_new_tokens + request.status = RequestStatus.RUNNING + else: + break + else: + llm_logger.error("Unknown request status type") + if scheduled_reqs: + llm_logger.debug(f"schedued_reqs: {scheduled_reqs}") + return scheduled_reqs + + def get_available_position(self) -> int: + position = 0 + while position < self.max_num_seqs: + if self.stop_flags[position] is True: + return position + position += 1 + raise RuntimeError("No available position is available for new request") + + def get_real_bsz(self) -> int: + for i in range(self.max_num_seqs - 1, -1, -1): + if not self.stop_flags[i]: + self.real_bsz = i + 1 + break + return self.real_bsz + + def add_request(self, request: Request) -> None: + self.waiting.append(request) + self.requests[request.request_id] = request + + def _free_blocks(self, request: Request): + self.cache_manager.recycle_gpu_blocks(request.block_tables) + request.block_tables = [] + + def finish_requests_async(self, request_ids: Union[str, Iterable[str]]): + return self.finish_execution_pool.submit(self.finish_requests, request_ids) + + def finish_requests(self, request_ids: Union[str, Iterable[str]]): + llm_logger.info(f"recycle resources for requests: {request_ids}") + try: + with self.lock: + if isinstance(request_ids, str): + request_ids = (request_ids,) + else: + request_ids = set(request_ids) + for req_id in request_ids: + request = self.requests.get(req_id) + if request is None: + # Invalid request ID. 
+ continue + request.status = RequestStatus.FINISHED + self.running.remove(request) + self._free_blocks(request) + self.tasks_list[request.idx] = None + self.stop_flags[request.idx] = True + del self.requests[req_id] + except Exception as e: + llm_logger.error(e) diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 58f51aa78..40203b485 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -76,6 +76,8 @@ environment_variables: dict[str, Callable[[], Any]] = { "EXPORTER_OTLP_ENDPOINT": lambda: os.getenv("EXPORTER_OTLP_ENDPOINT"), # set traec exporter_otlp_headers. "EXPORTER_OTLP_HEADERS": lambda: os.getenv("EXPORTER_OTLP_HEADERS"), + # enable kv cache block scheduler v1 (no need for kv_cache_ratio) + "ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")), } diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index eba068e89..7d0a2aef7 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -61,9 +61,10 @@ else: speculate_step_system_cache, speculate_update_v3, step_paddle, - step_reschedule, step_system_cache, update_inputs, + step_reschedule, + update_inputs_v1, ) from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput @@ -152,6 +153,8 @@ def pre_process( def post_process_normal( sampler_output: SamplerOutput, model_output: ModelOutputData, + share_inputs: Dict[str, paddle.Tensor], + block_size: int = 64, save_each_rank: bool = False, skip_save_output: bool = False, ) -> ModelRunnerOutput: @@ -219,17 +222,35 @@ def post_process_normal( # 2. Update the input buffer of the model with paddle.framework._no_check_dy2st_diff(): - update_inputs( - model_output.stop_flags, - model_output.not_need_stop, - model_output.seq_lens_this_time, - model_output.seq_lens_encoder, - model_output.seq_lens_decoder, - model_output.input_ids, - model_output.stop_nums, - sampler_output.sampled_token_ids, - model_output.is_block_step, - ) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + update_inputs_v1( + model_output.stop_flags, + model_output.not_need_stop, + model_output.seq_lens_this_time, + model_output.seq_lens_encoder, + model_output.seq_lens_decoder, + share_inputs["step_seq_lens_decoder"], + share_inputs["prompt_lens"], + sampler_output.sampled_token_ids, + model_output.input_ids, + share_inputs["block_tables"], + model_output.stop_nums, + model_output.next_tokens, + model_output.is_block_step, + block_size, + ) + else: + update_inputs( + model_output.stop_flags, + model_output.not_need_stop, + model_output.seq_lens_this_time, + model_output.seq_lens_encoder, + model_output.seq_lens_decoder, + model_output.input_ids, + model_output.stop_nums, + sampler_output.sampled_token_ids, + model_output.is_block_step, + ) # 3. Transmit the model's output and stop generation signal via message queue. # In the future, we will abandon this approach. 
if not skip_save_output: @@ -295,6 +316,8 @@ def post_process_specualate(model_output, save_each_rank: bool = False, skip_sav def post_process( sampler_output: SamplerOutput, model_output: ModelOutputData, + share_inputs: Dict[str, paddle.Tensor], + block_size: int = 64, save_each_rank: bool = False, speculative_decoding: bool = False, skip_save_output: bool = False, @@ -303,7 +326,7 @@ def post_process( if speculative_decoding: post_process_specualate(model_output, save_each_rank, skip_save_output) else: - post_process_normal(sampler_output, model_output, save_each_rank, skip_save_output) + post_process_normal(sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output) def step_cuda( diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 953b93571..1f46a0952 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -25,6 +25,7 @@ from concurrent.futures import ThreadPoolExecutor import numpy as np +from fastdeploy import envs from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput from fastdeploy.inter_communicator import IPCSignal from fastdeploy.metrics.metrics import main_process_metrics @@ -269,9 +270,12 @@ class TokenProcessor: else: time.sleep(0.002) else: - self.resource_manager.stop_flags[index] = True - self.resource_manager.tasks_list[index] = None - self.resource_manager._recycle_block_tables(task) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.resource_manager.finish_requests_async(task_id) + else: + self.resource_manager.stop_flags[index] = True + self.resource_manager.tasks_list[index] = None + self.resource_manager._recycle_block_tables(task) if task_id in self.tokens_counter: del self.tokens_counter[task_id] @@ -508,6 +512,7 @@ class TokenProcessor: self.tokens_counter[task_id] += 1 if token_id != RECOVERY_STOP_SIGNAL: result.outputs.token_ids.append(token_id) + task.output_token_ids.append(token_id) if token_id in task.eos_token_ids or is_prefill or recovery_stop: result.finished = True if recovery_stop: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index c639c29ef..432a12ddf 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -24,7 +24,7 @@ from paddle import nn from paddleformers.utils.log import logger from fastdeploy.config import FDConfig -from fastdeploy.engine.request import Request +from fastdeploy.engine.request import Request, RequestType from fastdeploy.model_executor.graph_optimization.utils import ( profile_run_guard, sot_warmup_guard, @@ -42,6 +42,7 @@ from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler from fastdeploy.model_executor.model_loader import get_model_from_loader from fastdeploy.model_executor.ops.gpu import ( + recover_decode_task, set_value_by_flags_and_idx, share_external_data, ) @@ -56,6 +57,7 @@ from fastdeploy.platforms import current_platform if not current_platform.is_dcu(): from fastdeploy.spec_decode import MTPProposer, NgramProposer +from fastdeploy import envs from fastdeploy.input.mm_processor import DataProcessor from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp @@ -189,10 +191,97 @@ class GPUModelRunner(ModelRunnerBase): elif request.structural_tag is not None: schemata_key = ("structural_tag", 
request.structural_tag) - return ( - self.guided_backend.get_logits_processor(schemata_key=schemata_key), - schemata_key, - ) + return self.guided_backend.get_logits_processor(schemata_key=schemata_key), schemata_key + + def insert_tasks_v1(self, req_dicts: List[Request]): + """ + Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1 + """ + # NOTE(luotingdan): Lazy initialize kv cache + if "caches" not in self.share_inputs: + self.initialize_kv_cache() + + req_len = len(req_dicts) + has_prefill_task = False + for i in range(req_len): + request = req_dicts[i] + idx = request.idx + if request.task_type.value == RequestType.PREFILL.value: # prefill task + logger.debug(f"Handle prefill request {request} at idx {idx}") + prefill_start_index = request.prefill_start_index + prefill_end_index = request.prefill_end_index + length = prefill_end_index - prefill_start_index + input_ids = request.prompt_token_ids + request.output_token_ids + self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array( + input_ids[prefill_start_index:prefill_end_index] + ) + encoder_block_num = len(request.block_tables) + self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num + self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( + request.block_tables, dtype="int32" + ) + self.share_inputs["stop_flags"][idx : idx + 1] = False + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length + self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0 + self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids) + self.share_inputs["is_block_step"][idx : idx + 1] = False + self.share_inputs["step_idx"][idx : idx + 1] = ( + len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 + ) + has_prefill_task = True + elif request.task_type.value == RequestType.DECODE.value: # decode task + logger.debug(f"Handle decode request {request} at idx {idx}") + encoder_block_num = len(request.block_tables) + self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num + self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( + request.block_tables, dtype="int32" + ) + continue + else: # preempted task + logger.debug(f"Handle preempted request {request} at idx {idx}") + self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + self.share_inputs["stop_flags"][idx : idx + 1] = True + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 0 + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 + self.share_inputs["is_block_step"][idx : idx + 1] = False + continue + + if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens: + request.eos_token_ids.append(request.eos_token_ids[0]) + self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1) + + self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7) + self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95) + self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0) + self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0) + self.share_inputs["presence_score"][idx : idx + 1] 
= request.get("presence_penalty", 0.0) + + self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) + self.share_inputs["max_dec_len"][idx : idx + 1] = request.get( + "max_tokens", self.model_config.max_model_len + ) + + self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1] + self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = length + + if request.get("seed") is not None: + self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") + + if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: + stop_seqs_num = len(request.get("stop_seqs_len")) + for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): + request.stop_seqs_len.append(0) + self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32") + self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array( + request.get("stop_token_ids"), dtype="int64" + ) + if has_prefill_task: + self.share_inputs["not_need_stop"][0] = True def insert_prefill_inputs(self, req_dicts: List[Request]): """ @@ -591,6 +680,18 @@ class GPUModelRunner(ModelRunnerBase): def _prepare_inputs(self) -> None: """Prepare the model inputs""" + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + recover_decode_task( + self.share_inputs["stop_flags"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["step_seq_lens_decoder"], + self.share_inputs["block_tables"], + self.share_inputs["is_block_step"], + self.parallel_config.block_size, + ) + # Remove padding ( ids_remove_padding, @@ -901,6 +1002,8 @@ class GPUModelRunner(ModelRunnerBase): post_process( sampler_output=sampler_output, model_output=model_output_data, + share_inputs=self.share_inputs, + block_size=self.parallel_config.block_size, speculative_decoding=self.speculative_decoding, skip_save_output=True, ) @@ -1165,6 +1268,8 @@ class GPUModelRunner(ModelRunnerBase): post_process( sampler_output=sampler_output, model_output=model_output_data, + share_inputs=self.share_inputs, + block_size=self.parallel_config.block_size, save_each_rank=self.parallel_config.use_ep, speculative_decoding=self.speculative_decoding, skip_save_output=skip_save_output, @@ -1180,16 +1285,17 @@ class GPUModelRunner(ModelRunnerBase): # 7. 
Updata 'infer_seed' and step_cuda() self.share_inputs["infer_seed"].add_(self.infer_seed_increment) self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED - step_cuda( - self.share_inputs, - self.parallel_config.block_size, - self.parallel_config.enc_dec_block_num, - self.speculative_config, - self.parallel_config.enable_prefix_caching, - ) + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: + step_cuda( + self.share_inputs, + self.parallel_config.block_size, + self.parallel_config.enc_dec_block_num, + self.speculative_config, + self.parallel_config.enable_prefix_caching, + ) - self._update_chunked_prefill(model_forward_batch) - self._add_cache(model_forward_batch) + self._update_chunked_prefill(model_forward_batch) + self._add_cache(model_forward_batch) return None def _add_cache(self, model_forward_batch) -> None: diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index a3abbf6d0..50ac4d231 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -22,6 +22,7 @@ import paddle import pynvml from paddle import nn +from fastdeploy import envs from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request from fastdeploy.platforms import current_platform @@ -183,7 +184,10 @@ class GpuWorker(WorkerBase): TODO(gongshaotian):The scheduler should schedule the handling of prefill, and workers and modelrunners should not perceive it. """ - self.model_runner.insert_prefill_inputs(req_dicts=req_dicts) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.model_runner.insert_tasks_v1(req_dicts=req_dicts) + else: + self.model_runner.insert_prefill_inputs(req_dicts=req_dicts) def graph_optimize_and_warm_up_model(self) -> None: """
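
The decode path in ResourceManagerV1.schedule() reduces to one rule: a running request receives another enc_dec_block_num blocks as soon as its free slots drop to prealloc_dec_block_slot_num_threshold or below. A minimal standalone sketch of that check follows; the default values (block_size=64, threshold=5) mirror defaults appearing in this patch and are illustrative only, not read from a real config.

def needs_more_decode_blocks(num_blocks: int, num_total_tokens: int,
                             block_size: int = 64,
                             prealloc_dec_block_slot_num_threshold: int = 5) -> bool:
    """Mirrors the decode-phase check in ResourceManagerV1.schedule()."""
    allocated_slots = num_blocks * block_size          # len(block_tables) * block_size
    free_slots = allocated_slots - num_total_tokens    # slots not yet holding a token
    return free_slots <= prealloc_dec_block_slot_num_threshold

# A request with 2 blocks of 64 slots and 123 prompt+generated tokens has 5 free
# slots left, so the scheduler would extend its block table by enc_dec_block_num:
assert needs_more_decode_blocks(2, 123)
assert not needs_more_decode_blocks(2, 100)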