mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Intel HPU] enable level 1 prefix caching and fix some bugs (#4971)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
* [Intel HPU] enable prefix caching and dense tp moe ep and fix some bugs
* update code by copilot
* remove dense tp and moe ep code
@@ -478,7 +478,7 @@ class EngineArgs:
             self.enable_prefix_caching = False
         if self.speculative_config is not None:
             self.enable_prefix_caching = False
-        if not current_platform.is_cuda() and not current_platform.is_xpu():
+        if not current_platform.is_cuda() and not current_platform.is_xpu() and not current_platform.is_intel_hpu():
             self.enable_prefix_caching = False
         # if self.dynamic_load_weight:
         #     self.enable_prefix_caching = False
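Context for the gate above: with this change, Intel HPU is treated like CUDA/XPU and `enable_prefix_caching` survives EngineArgs post-processing instead of being forced to False. A minimal, hedged sketch of turning it on through the offline entry point (the model name, sampling values, and the kwarg pass-through into `LLM` are assumptions for illustration; only `enable_prefix_caching` itself comes from this diff):

    from fastdeploy import LLM, SamplingParams

    # Illustrative only: model path and sampling settings are placeholders, and passing
    # enable_prefix_caching as an LLM kwarg assumes the usual EngineArgs pass-through.
    llm = LLM(
        model="baidu/ERNIE-4.5-0.3B-Paddle",  # placeholder model
        enable_prefix_caching=True,           # no longer reset to False on intel_hpu
    )
    for out in llm.generate(["Hello, who are you?"], SamplingParams(temperature=0.8, top_p=0.95)):
        print(out)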
@@ -40,6 +40,7 @@ from fastdeploy.engine.expert_service import start_data_parallel_service
 from fastdeploy.engine.request import Request
 from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
 from fastdeploy.metrics.metrics import main_process_metrics
+from fastdeploy.platforms import current_platform
 from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
 
 
@@ -136,8 +137,9 @@ class LLMEngine:
 
         # If block numer is specified and model is deployed in mixed mode, start cache manager first
         if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed":
-            device_ids = self.cfg.parallel_config.device_ids.split(",")
-            self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
+            if not current_platform.is_intel_hpu():
+                device_ids = self.cfg.parallel_config.device_ids.split(",")
+                self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
 
         # Start workers
         self.worker_proc = self._start_worker_service()
@@ -170,8 +172,9 @@ class LLMEngine:
         if self.do_profile:
             self._stop_profile()
         elif self.cfg.scheduler_config.splitwise_role == "mixed" and self.cfg.cache_config.enable_prefix_caching:
-            device_ids = self.cfg.parallel_config.device_ids.split(",")
-            self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
+            if not current_platform.is_intel_hpu():
+                device_ids = self.cfg.parallel_config.device_ids.split(",")
+                self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
 
         # Launch components: scheduler, cache_manager, expert_service et.al.
         if self.cfg.scheduler_config.splitwise_role != "mixed":
@@ -673,8 +676,9 @@ class LLMEngine:
         self.cfg.cache_config.reset(num_gpu_blocks)
         self.engine.resource_manager.reset_cache_config(self.cfg.cache_config)
         if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
-            device_ids = self.cfg.parallel_config.device_ids.split(",")
-            self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
+            if not current_platform.is_intel_hpu():
+                device_ids = self.cfg.parallel_config.device_ids.split(",")
+                self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
 
     def check_health(self, time_interval_threashold=30):
         """
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import math
 import os
 from abc import abstractmethod
 from dataclasses import dataclass, field
@@ -39,6 +40,31 @@ if TYPE_CHECKING:
     from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 
 
+def get_attention_mask(seq_lens_encoder, seq_lens_decoder, batch_size, query_len):
+    max_context_len = int(paddle.max(seq_lens_decoder).item())
+    past_mask = paddle.arange(0, max_context_len, dtype=paddle.int32)
+    past_mask = paddle.greater_equal(
+        past_mask.reshape([1, -1]).expand([batch_size, -1]), seq_lens_decoder.reshape([-1, 1]).astype(paddle.int32)
+    )
+    past_mask = (
+        past_mask.reshape([batch_size, 1, -1])
+        .expand([batch_size, query_len, -1])
+        .reshape([batch_size, 1, query_len, -1])
+    )
+    len_mask = paddle.greater_equal(
+        paddle.arange(0, query_len, dtype=paddle.int32).reshape([1, query_len]),
+        seq_lens_encoder.unsqueeze(-1).astype(paddle.int32),
+    )
+    len_mask = len_mask.reshape([batch_size, 1, 1, query_len])
+    attn_mask = paddle.triu(paddle.ones((batch_size, 1, query_len, query_len), dtype=paddle.bool), diagonal=1)
+    mask = attn_mask.logical_or(len_mask)
+    mask = paddle.concat((past_mask, mask), axis=-1)
+    off_value = -math.inf
+    attn_mask = paddle.zeros_like(mask, dtype=paddle.bfloat16).masked_fill_(mask, off_value)
+    attn_mask = paddle.unsqueeze(attn_mask, axis=1)
+    return attn_mask
+
+
 class AttentionBackend_HPU(AttentionBackend):
     """The base class of attention backends"""
 
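For orientation, the helper added above merges three pieces into one additive bfloat16 mask: a past-context mask driven by `seq_lens_decoder`, a padding mask driven by `seq_lens_encoder`, and a causal triangle over the query window. A toy invocation, with `get_attention_mask` assumed to be in scope and made-up lengths (only the shapes matter):

    import paddle

    # Two requests: the first has 5 cached context tokens and 3 new tokens,
    # the second has no cached context and 2 new tokens, padded to query_len=4.
    seq_lens_encoder = paddle.to_tensor([3, 2], dtype="int32")
    seq_lens_decoder = paddle.to_tensor([5, 0], dtype="int32")
    mask = get_attention_mask(seq_lens_encoder, seq_lens_decoder, batch_size=2, query_len=4)
    print(mask.shape)  # [2, 1, 1, 4, 9]: max_context_len (5) cached columns plus query_len (4) new columns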
@@ -254,16 +280,40 @@ class HPUAttentionBackend(AttentionBackend_HPU):
         index_copy_(k_cache, forward_meta.block_indices, key_states, 0)
         index_copy_(v_cache, forward_meta.block_indices, value_states, 0)
 
-        out_linear_out = fused_sdpa_proj_t(
-            query_states,
-            key_value_states,
-            forward_meta.attn_mask,
-            None,
-            o_proj.weight,
-            scaling_factor=self.head_dim**-0.5,
-            causal=True,
-            softmax_mode=0,
-        )
+        if forward_meta.block_list.shape == forward_meta.block_indices.shape:
+            out_linear_out = fused_sdpa_proj_t(
+                query_states,
+                key_value_states,
+                forward_meta.attn_mask,
+                None,
+                o_proj.weight,
+                scaling_factor=self.head_dim**-0.5,
+                causal=True,
+                softmax_mode=0,
+            )
+        else:
+            key_states_with_context = k_cache.index_select(forward_meta.block_list)
+            val_states_with_context = v_cache.index_select(forward_meta.block_list)
+            key_value_states_with_context = paddle.stack(
+                [key_states_with_context, val_states_with_context], axis=0
+            ).reshape([kv, B, -1, M, H])
+            if forward_meta.attn_mask is None:
+                forward_meta.attn_mask = get_attention_mask(
+                    forward_meta.seq_lens_encoder[forward_meta.batch_ids],
+                    forward_meta.seq_lens_decoder[forward_meta.batch_ids],
+                    query_states.shape[0],
+                    query_states.shape[1],
+                )
+            out_linear_out = fused_sdpa_proj_t(
+                query_states,
+                key_value_states_with_context,
+                forward_meta.attn_mask,
+                None,
+                o_proj.weight,
+                scaling_factor=self.head_dim**-0.5,
+                causal=False,
+                softmax_mode=0,
+            )
 
         if self.nranks > 1:
             from fastdeploy.distributed.communication import (
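As a plain-Paddle reference for what the prefix-caching branch above asks the fused kernel to compute (cached KV gathered from the block list, attended with the cached context fully visible and the new tokens kept causal among themselves): the sketch below is illustrative and unfused, with assumed [B, T, H, D] layouts, not the code path used on HPU.

    import math
    import paddle

    def reference_prefix_attention(q, k_ctx, v_ctx, k_new, v_new):
        # q, k_new, v_new: [B, T_new, H, D]; k_ctx, v_ctx: [B, T_ctx, H, D]
        B, T_new, H, D = q.shape
        T_ctx = k_ctx.shape[1]
        k = paddle.concat([k_ctx, k_new], axis=1)
        v = paddle.concat([v_ctx, v_new], axis=1)
        scores = paddle.einsum("bqhd,bkhd->bhqk", q, k) * (D**-0.5)
        # New tokens see the whole cached context; among themselves they stay causal.
        causal = paddle.triu(paddle.ones([T_new, T_new]), diagonal=1).astype("bool")
        full_mask = paddle.concat([paddle.zeros([T_new, T_ctx], dtype="bool"), causal], axis=-1)
        scores = scores.masked_fill(full_mask.reshape([1, 1, T_new, T_ctx + T_new]), -math.inf)
        probs = paddle.nn.functional.softmax(scores, axis=-1)
        return paddle.einsum("bhqk,bkhd->bqhd", probs, v)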
@@ -297,11 +347,14 @@ class HPUAttentionBackend(AttentionBackend_HPU):
             qkv_proj.weight,
             qkv_proj.bias,
             o_proj.weight,
             None,  # past_key: not used in decode mode
             None,  # past_value: not used in decode mode
             self.head_dim,
             self.num_heads,
             scaling_factor=self.head_dim**-0.5,
             transpose=False,
             use_neox_style=layer.use_neox_rotary_style,
             epsilon=1e-6,
         )
 
         # all_reduce
@@ -100,7 +100,9 @@ def recover_block_hpu(
     stop_flags,  # hpu
     seq_lens_this_time,  # hpu
     ori_seq_lens_encoder,  # cpu
+    ori_seq_lens_decoder,  # cpu
     seq_lens_encoder,  # hpu
+    seq_lens_decoder,  # hpu
     block_tables,  # cpu
     free_list,  # cpu
     free_list_len,  # cpu
@@ -116,6 +118,7 @@ def recover_block_hpu(
     for bid in range(recover_len.item()):
         recover_id = recover_block_list[bid].item()
         ori_seq_len_encoder = ori_seq_lens_encoder[recover_id].item()
+        ori_seq_len_decoder = ori_seq_lens_decoder[recover_id].item()
         step_idx_now = step_idx[recover_id].item()
         seq_len = ori_seq_len_encoder + step_idx_now
         encoder_block_len = encoder_block_lens[recover_id].item()
@@ -123,13 +126,13 @@ def recover_block_hpu(
 
         seq_lens_this_time[recover_id] = seq_len
         seq_lens_encoder[recover_id] = seq_len
+        seq_lens_decoder[recover_id] = ori_seq_len_decoder
         stop_flags[recover_id] = False
 
         ori_free_list_len = free_list_len[0]
         free_list_len[0] -= decoder_used_len
 
         for i in range(decoder_used_len):
             block_tables[recover_id, encoder_block_len + i] = free_list[ori_free_list_len - i - 1]
-            free_list_len[0] -= decoder_used_len
 
         recover_block(input_ids, first_token_ids, pre_ids, next_tokens, recover_id, ori_seq_len_encoder, step_idx_now)
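The loop above restores a recovered request's decode-phase blocks by handing blocks back from the tail of the shared free list; the free-list length is decremented once for the whole request, and the hunk drops the second, redundant decrement. A plain-Python sketch of the bookkeeping with made-up numbers:

    free_list = [10, 11, 12, 13, 14]       # block ids still unused
    free_list_len = [5]
    block_table = [3, 4, -1, -1]           # prefix blocks filled, decode slots empty
    encoder_block_len, decoder_used_len = 2, 2

    ori_free_list_len = free_list_len[0]
    free_list_len[0] -= decoder_used_len   # decrement once for the whole request
    for i in range(decoder_used_len):
        block_table[encoder_block_len + i] = free_list[ori_free_list_len - i - 1]

    print(block_table)       # [3, 4, 14, 13] -> decode blocks taken from the free-list tail
    print(free_list_len[0])  # 3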
@@ -160,13 +163,16 @@ def step_intel_hpu(share_inputs: Dict[str, paddle.Tensor], block_size: int, max_
         max_model_len,
     )
     if share_inputs["recover_lens"].item() > 0:
+        logger.info("recover block hpu happening ...")
         recover_block_hpu(
             share_inputs["recover_block_list"],
             share_inputs["recover_lens"],
             share_inputs["stop_flags"],
             share_inputs["seq_lens_this_time"],
             share_inputs["ori_seq_lens_encoder"],
+            share_inputs["ori_seq_lens_decoder"],
             share_inputs["seq_lens_encoder"],
+            share_inputs["seq_lens_decoder"],
             share_inputs["block_tables"],
             share_inputs["free_list"],
             share_inputs["free_list_len"],
@@ -179,6 +185,7 @@ def step_intel_hpu(share_inputs: Dict[str, paddle.Tensor], block_size: int, max_
             share_inputs["first_token_ids"],
         )
         share_inputs["recover_lens"] = paddle.full([1], 0, dtype="int32").cpu()
+        share_inputs["not_need_stop"][0] = True
 
 
 # TODO: replace rebuild_padding_v3 in CustomDevice if we adopt this version pp optimization
@@ -477,9 +484,11 @@ class HPUModelRunner(ModelRunnerBase):
                 self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
+                self.share_inputs["ori_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                 self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
             else:
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
+                self.share_inputs["ori_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                 self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                 self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
                 self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
@@ -615,6 +624,7 @@ class HPUModelRunner(ModelRunnerBase):
         self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], 0, dtype="int64").cpu()
         self.share_inputs["first_token_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
         self.share_inputs["ori_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32").cpu()
+        self.share_inputs["ori_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32").cpu()
         self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["system_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int32")
 
@@ -707,6 +717,7 @@ class HPUModelRunner(ModelRunnerBase):
             self.share_inputs["seq_lens_decoder"],
             self.cache_config.block_size,
             self.model_config.dtype,
+            self.scheduler_config.max_num_batched_tokens,
         )
         is_prompt = is_prompt.item() == 1 if is_prompt.item() > 0 else None
         if is_prompt is True:
@@ -1023,7 +1034,18 @@ class HPUModelRunner(ModelRunnerBase):
         """ """
         pass
 
-    def update_warmup_inputs(self, requests, is_decode=False):
+    def update_warmup_inputs(self, requests, is_decode=False, context_len=0) -> None:
+        """
+        Update the shared input tensors for warmup requests.
+        Args:
+            requests (list): List of request dicts containing input data.
+            is_decode (bool, optional): If True, sets up inputs for decode phase. Defaults to False.
+            context_len (int, optional): The length of the context (prefix) to use for prefix caching during warmup.
+                If >0, this value is used to set the decoder sequence length for prefill (prefix caching).
+                Typically, set to the number of tokens in the prefix to be cached. Defaults to 0 (no prefix caching).
+                This parameter affects the warmup behavior for prefix caching by controlling how much of the input
+                is considered as context for the decoder during the prefill phase.
+        """
         for i in range(len(requests)):
             request = requests[i]
             idx = request["idx"]
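Concretely, a prefix-caching warmup step built from this signature might look as follows; the request fields mirror the warm_up_bucket code further down, and `runner` stands for an HPUModelRunner instance assumed to be in scope (this is an illustrative call, not an additional code path):

    block_size = 128
    context_len = 2 * block_size                       # prefix already held in the KV cache
    prefill_length_with_context = context_len + 256    # total tokens covered by the block table

    requests = [
        {
            "idx": 0,
            "input_ids": [5] * (prefill_length_with_context - context_len - 1),
            "block_tables": list(range(prefill_length_with_context // block_size)),
            "eos_token_ids": [2],
        }
    ]
    runner.update_warmup_inputs(requests, is_decode=False, context_len=context_len)  # runner: assumed HPUModelRunner
    runner.execute_model()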
@@ -1038,7 +1060,7 @@ class HPUModelRunner(ModelRunnerBase):
                 self.share_inputs["step_idx"][idx : idx + 1] = 1
             else:
                 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
-                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
+                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = context_len
                 self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
                 self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
                 self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
@@ -1073,35 +1095,48 @@ class HPUModelRunner(ModelRunnerBase):
         self.share_inputs["not_need_stop"][0] = True
 
     def warm_up_bucket(self) -> None:
-        max_prefill_batch = 3  # Hard-Code in FastDeploy/fastdeploy/engine/config.py
+        max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3"))
         warmup_max_model_len = min(int(os.environ.get("HPU_WARMUP_MODEL_LEN", 4096)), self.model_config.max_model_len)
         prefill_batchs = []
         prefill_batch_step = int(os.environ.get("BATCH_STEP_PREFILL", 1))
+        prefill_seq_step = int(os.environ.get("SEQUENCE_STEP_PREFILL", 128))
         current_prefill_batch = prefill_batch_step
         while current_prefill_batch <= max_prefill_batch:
             prefill_batchs.append(int(current_prefill_batch))
             current_prefill_batch += prefill_batch_step
 
         max_prefill_length = self.cache_config.block_size + warmup_max_model_len
+        prefill_context_block_step = int(os.environ.get("CONTEXT_BLOCK_STEP_PREFILL", 1))
         for prefill_batch in prefill_batchs:
-            for prefill_length in range(
-                self.cache_config.block_size, max_prefill_length, self.cache_config.block_size
+            for prefill_length_with_context in range(
+                self.cache_config.block_size, max_prefill_length, prefill_seq_step
             ):
-                if prefill_length * prefill_batch > self.scheduler_config.max_num_batched_tokens:
+                if prefill_length_with_context * prefill_batch > self.scheduler_config.max_num_batched_tokens:
                     continue
-                logger.info(f"Warmup prefill_batch: {prefill_batch}, prefill_length: {prefill_length} start")
-                requests = [
-                    {
-                        "idx": i,
-                        "input_ids": [5] * (prefill_length - 1),
-                        "block_tables": list(range(prefill_length // self.cache_config.block_size)),
-                        "eos_token_ids": [2],
-                    }
-                    for i in range(prefill_batch)
-                ]
-                self.update_warmup_inputs(requests, is_decode=False)
-                self.execute_model()
-                logger.info(f"warmup prefill_batch: {prefill_batch}, prefill_length: {prefill_length} done")
+                for context_len in range(
+                    0, prefill_length_with_context, self.cache_config.block_size * prefill_context_block_step
+                ):
+                    prefill_length = prefill_length_with_context - context_len
+                    logger.info(
+                        f"Warmup prefill_batch: {prefill_batch}, prefill_length: {prefill_length}, context_len: {context_len} start"
+                    )
+                    requests = [
+                        {
+                            "idx": i,
+                            "input_ids": [5] * (prefill_length_with_context - context_len - 1),
+                            "block_tables": list(range(prefill_length_with_context // self.cache_config.block_size)),
+                            "eos_token_ids": [2],
+                        }
+                        for i in range(prefill_batch)
+                    ]
+                    self.update_warmup_inputs(requests, is_decode=False, context_len=context_len)
+                    self.execute_model()
+                    logger.info(
+                        f"warmup prefill_batch: {prefill_batch}, prefill_length: {prefill_length}, context_len: {context_len} done"
+                    )
+                    # when disable prefix caching, only run context_len = 0 for each prefill_batch
+                    if not self.cache_config.enable_prefix_caching:
+                        break
 
         decode_batchs = []
         decode_batch_step = int(os.environ.get("BATCH_STEP_DECODE", 4))
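To see which prefill buckets the loop above actually visits, here is a standalone sketch that enumerates the same (batch, prefill length, context length) grid, with small assumed values standing in for the environment variables:

    block_size = 128
    max_prefill_batch = 2            # MAX_PREFILL_NUM
    prefill_batch_step = 1           # BATCH_STEP_PREFILL
    prefill_seq_step = 256           # SEQUENCE_STEP_PREFILL
    prefill_context_block_step = 1   # CONTEXT_BLOCK_STEP_PREFILL
    warmup_max_model_len = 512       # HPU_WARMUP_MODEL_LEN (capped by max_model_len)
    max_num_batched_tokens = 1024
    enable_prefix_caching = True

    prefill_batchs = list(range(prefill_batch_step, max_prefill_batch + 1, prefill_batch_step))
    max_prefill_length = block_size + warmup_max_model_len
    for prefill_batch in prefill_batchs:
        for total_len in range(block_size, max_prefill_length, prefill_seq_step):
            if total_len * prefill_batch > max_num_batched_tokens:
                continue
            for context_len in range(0, total_len, block_size * prefill_context_block_step):
                print(f"bucket: batch={prefill_batch} prefill_length={total_len - context_len} context_len={context_len}")
                if not enable_prefix_caching:
                    break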