[Feature] Support pd ep deployment with yiyan adapter (#4029)

* [Feature] Support mixed deployment with yiyan adapter in release2.2

* fix metrics

* add unit test

* add unit test

* add unit test

* Support pd ep deployment with yiyan adapter

* Support pd ep deployment with yiyan adapter

* refactor cache messager

* support scheduler v1 in PD

* support pd v1 + chunk prefill

* support pd v1 + chunk prefill

* add eplb

* support eplb

* support eplb

* support eplb

* support v1

* fix

* fix

* fix bug

* remove eplb support

* support prefix cache in P

* fix bug

* fix bug

* support one stop in V1

* fix bug

* fix ci

* fix ci

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Author: chenjian
Date: 2025-09-22 16:41:38 +08:00
Committed by: GitHub
Parent: 9845f0d010
Commit: 918ccdb123
22 changed files with 1838 additions and 343 deletions


@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
 import copy
+import threading
 import time
 import traceback
@@ -26,7 +27,7 @@ from typing import Union
 import numpy as np
 import paddle
 
-from fastdeploy.engine.request import Request, RequestStatus, RequestType
+from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
 from fastdeploy.engine.resource_manager import ResourceManager
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.utils import llm_logger
@@ -297,6 +298,11 @@ class ResourceManagerV1(ResourceManager):
             while req_index < len(self.running) and token_budget > 0:
                 request = self.running[req_index]
                 if request.num_computed_tokens >= request.need_prefill_tokens:  # to be decoding
+                    if (
+                        self.config.scheduler_config.splitwise_role == "prefill"
+                    ):  # do not need to schedule for decoding
+                        req_index += 1
+                        continue
                     if request.num_total_tokens > request.need_prefill_tokens:  # has generated tokens
                         request.num_computed_tokens = request.num_total_tokens - 1
                     if (
@@ -400,11 +406,12 @@ class ResourceManagerV1(ResourceManager):
                         request.status = RequestStatus.RUNNING
                         main_process_metrics.num_requests_waiting.dec(1)
                         main_process_metrics.num_requests_running.inc(1)
-                        allocated_position = self.get_available_position()
-                        request.idx = allocated_position
-                        self.tasks_list[allocated_position] = request
-                        self.stop_flags[allocated_position] = False
-                        self.req_dict[request.request_id] = allocated_position
+                        if self.config.scheduler_config.splitwise_role == "mixed":
+                            allocated_position = self.get_available_position()
+                            request.idx = allocated_position
+                            self.tasks_list[allocated_position] = request
+                            self.stop_flags[allocated_position] = False
+                            self.req_dict[request.request_id] = allocated_position
                     else:
                         if self.config.cache_config.enable_prefix_caching:
                             self._free_blocks(request)
@@ -569,6 +576,127 @@ class ResourceManagerV1(ResourceManager):
             self.waiting.append(request)
             self.requests[request.request_id] = request
 
+    def prerelease_resource(self, request: Request):
+        """
+        Release resources in P or D before the request finishes, due to an unexpected error.
+        """
+        with self.lock:
+            self.tasks_list[request.idx] = None
+            self.stop_flags[request.idx] = True
+            del self.requests[request.request_id]
+            del self.req_dict[request.request_id]
+            self._free_blocks(request)
+
+    def add_request_in_p(self, requests: list[Request]):
+        with self.lock:
+            for request in requests:
+                request.inference_start_time = time.time()
+                request.schedule_start_time = time.time()
+                self.running.append(request)
+
+    def preallocate_resource_in_p(self, request: Request):
+        """
+        In P/D disaggregated deployment, preallocate resources for a request on the P (prefill) instance.
+        If resources can be allocated, allocate them and return True; otherwise return False.
+        """
+        assert self.config.scheduler_config.splitwise_role == "prefill", "Only P instance can call this method"
+        with self.lock:
+            if self.available_batch() == 0:
+                return False
+            request.need_prefill_tokens = len(request.prompt_token_ids)
+            need_prealloc_prefill_blocks = (
+                request.need_prefill_tokens + self.config.cache_config.block_size - 1
+            ) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num  # consider for mtp, plus enc_dec_block_num
+            if self.config.cache_config.enable_prefix_caching:
+                # Prefix caching enabled
+                if self.config.cache_config.enable_hierarchical_cache and self.cache_manager.num_cpu_blocks > 0:
+                    if not self.cache_manager.can_allocate_gpu_blocks(
+                        need_prealloc_prefill_blocks
+                    ):  # avoid allocating blocks for hierarchical cache matching, which could cause a deadlock
+                        return False
+                success = self.get_prefix_cached_blocks(request)
+                if not success:
+                    self._free_blocks(request)
+                    return False
+                # consider for mtp, plus enc_dec_block_num
+                need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0]
+                if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks):
+                    request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks))
+                    allocated_position = self.get_available_position()
+                    request.idx = allocated_position
+                    self.tasks_list[request.idx] = request
+                    self.stop_flags[request.idx] = False
+                    self.requests[request.request_id] = request
+                    self.req_dict[request.request_id] = allocated_position
+                    return True
+                else:
+                    self._free_blocks(request)
+                    return False
+            else:
+                if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
+                    request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks))
+                    request.num_computed_tokens = 0
+                    allocated_position = self.get_available_position()
+                    request.idx = allocated_position
+                    self.tasks_list[request.idx] = request
+                    self.stop_flags[request.idx] = False
+                    self.requests[request.request_id] = request
+                    self.req_dict[request.request_id] = allocated_position
+                    return True
+                return False
+
+    def preallocate_resource_in_d(self, request: Request):
+        """
+        In P/D disaggregated deployment, D (decode) preallocates resources for a request handed over from P.
+        If resources can be allocated, allocate them and return True; otherwise return False.
+        """
+        assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
+        with self.lock:
+            if len(self.waiting) > 0:
+                return False
+            if self.available_batch() == 0:
+                return False
+            request.need_prefill_tokens = len(request.prompt_token_ids)
+            need_prealloc_prefill_blocks = (
+                request.need_prefill_tokens + self.config.cache_config.block_size - 1
+            ) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num  # consider for mtp, plus enc_dec_block_num
+            if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
+                request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks))
+                request.num_computed_tokens = request.need_prefill_tokens
+                request.disaggregate_info["block_tables"] = request.block_tables
+                allocated_position = self.get_available_position()
+                request.idx = allocated_position
+                self.tasks_list[request.idx] = request
+                self.stop_flags[request.idx] = False
+                self.requests[request.request_id] = request
+                self.req_dict[request.request_id] = allocated_position
+                return True
+            return False
+
+    def insert_task_for_decoding(self, request_output_in_p: RequestOutput):
+        """
+        In P/D disaggregated deployment, D continues decoding after receiving the first token and the KV cache from P.
+        """
+        assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
+        with self.lock:
+            request = self.requests[request_output_in_p.request_id]
+            request.output_token_ids.append(request_output_in_p.outputs.token_ids[0])
+            request.num_cached_tokens = request_output_in_p.num_cached_tokens
+            if (
+                self.config.speculative_config.method in ["mtp"]
+                and self.config.scheduler_config.splitwise_role == "decode"
+            ):
+                request.draft_token_ids = copy.deepcopy(request_output_in_p.outputs.draft_token_ids)
+            # update request.need_prefill_tokens
+            request.need_prefill_tokens = len(request.prompt_token_ids) + 1
+            request.inference_start_time = time.time()
+            request.schedule_start_time = time.time()
+            self.running.append(request)
+
     def _free_blocks(self, request: Request):
         if self.config.cache_config.enable_prefix_caching:
             self.cache_manager.release_block_ids(request)
@@ -620,5 +748,7 @@ class ResourceManagerV1(ResourceManager):
                     self.tasks_list[request.idx] = None
                     self.stop_flags[request.idx] = True
                     del self.requests[req_id]
+                    if req_id in self.req_dict:
+                        del self.req_dict[req_id]
         except Exception as e:
             llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}")