[MetaxGPU] adapt to the latest fastdeploy on metax gpu (#3492)

This commit is contained in:
Kane2011
2025-08-25 17:44:20 +08:00
committed by GitHub
parent c13c904971
commit 2ae7ab28d2
8 changed files with 338 additions and 115 deletions

View File

@@ -591,6 +591,12 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
if not os.listdir(json_dir): if not os.listdir(json_dir):
raise ValueError("Git clone nlohmann_json failed!") raise ValueError("Git clone nlohmann_json failed!")
sources = [ sources = [
"gpu_ops/update_inputs_v1.cu",
"gpu_ops/save_with_output_msg.cc",
"gpu_ops/get_output.cc",
"gpu_ops/get_output_msg_with_topk.cc",
"gpu_ops/save_output_msg_with_topk.cc",
"gpu_ops/transfer_output.cc",
"gpu_ops/save_with_output.cc", "gpu_ops/save_with_output.cc",
"gpu_ops/set_mask_value.cu", "gpu_ops/set_mask_value.cu",
"gpu_ops/set_value_by_flags.cu", "gpu_ops/set_value_by_flags.cu",

View File

@@ -0,0 +1,83 @@
# Metax GPU Installation for Running ERNIE 4.5 Series Models
The following installation methods are available when your environment meets these requirements:
- Python >= 3.10
- Linux X86_64
Before starting, prepare a machine equipped with MetaX C550 accelerator cards. Requirements:
| Chip Type | Driver Version | KMD Version |
| :---: | :---: | :---: |
| MetaX C550 | 3.0.0.1 | 2.14.6 |
## 1. Pre-built Docker Installation (Recommended)
```shell
docker login --username=cr_temp_user --password=eyJpbnN0YW5jZUlkIjoiY3JpLXpxYTIzejI2YTU5M3R3M2QiLCJ0aW1lIjoiMTc1NTUxODEwODAwMCIsInR5cGUiOiJzdWIiLCJ1c2VySWQiOiIyMDcwOTQwMTA1NjYzNDE3OTIifQ:8226ca50ce5476c42062e24d3c465545de1c1780 cr.metax-tech.com && docker pull cr.metax-tech.com/public-library/maca-native:3.0.0.4-ubuntu20.04-amd64
```
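Optionally, start a working container from the pulled image before installing the Python packages. The command below is only a sketch: the mount path and runtime flags are placeholders and depend on your local MetaX driver setup.
```shell
# Sketch only: adjust the volume mounts and runtime flags to match your MetaX driver installation.
docker run -it --net=host --shm-size=64G \
    -v /path/to/models:/root/model \
    cr.metax-tech.com/public-library/maca-native:3.0.0.4-ubuntu20.04-amd64 /bin/bash
```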
## 2. PaddlePaddle and Custom Device Installation
```shell
pip install paddlepaddle==3.0.0.dev20250729 -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
pip install paddle-metax-gpu==3.0.0.dev20250807 -i https://www.paddlepaddle.org.cn/packages/nightly/maca/
```
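As a quick sanity check (not part of the official steps), you can confirm that the MetaX custom device plugin registered itself with Paddle:
```shell
# Optional sanity check: should print True if the MetaX custom device plugin is registered.
python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"
```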
## 3. Build Wheel from Source
Clone the FastDeploy source code and build the wheel:
```shell
git clone https://github.com/PaddlePaddle/FastDeploy
cd FastDeploy
bash build.sh
```
The built packages will be in the `FastDeploy/dist` directory.
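The wheel can then be installed directly from that directory; the wildcard below is illustrative, since the exact filename depends on the version and Python ABI you built against.
```shell
# Install the locally built wheel (the exact filename varies by version and Python ABI).
pip install dist/fastdeploy*.whl
```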
## 4. Environment Verification
After installation, verify the environment with this Python code:
```python
import paddle
from paddle.jit.marker import unified
# Verify GPU availability
paddle.utils.run_check()
# Verify FastDeploy custom operators compilation
from fastdeploy.model_executor.ops.gpu import beam_search_softmax
```
If the above code executes successfully, the environment is ready.
## 5. Demo
```python
from fastdeploy import LLM, SamplingParams

prompts = [
    "Hello. My name is",
]
sampling_params = SamplingParams(top_p=0.95, max_tokens=32, temperature=0.6)
llm = LLM(
    model="/root/model/ERNIE-4.5-21B-A3B-Paddle",
    tensor_parallel_size=1,
    max_model_len=256,
    engine_worker_queue_port=9135,
    quantization='wint8',
    static_decode_blocks=0,
    gpu_memory_utilization=0.9,
)
outputs = llm.generate(prompts, sampling_params)

print(f"Generated {len(outputs)} outputs")
print("=" * 50 + "\n")

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs.text
    print(prompt)
    print(generated_text)
    print("-" * 50)
```
Output:
```
INFO 2025-08-18 10:54:18,455 416822 engine.py[line:202] Waiting worker processes ready...
Loading Weights: 100%|█████████████████████████████████████████████████████████████████████████| 100/100 [03:33<00:00, 2.14s/it]
Loading Layers: 100%|██████████████████████████████████████████████████████████████████████████| 100/100 [00:18<00:00, 5.54it/s]
INFO 2025-08-18 10:58:16,149 416822 engine.py[line:247] Worker processes are launched with 240.08204197883606 seconds.
Processed prompts: 100%|███████████████████████| 1/1 [00:21<00:00, 21.84s/it, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Generated 1 outputs
==================================================

Hello. My name is
Alice and I'm here to help you. What can I do for you today?

Hello Alice! I'm trying to organize a small party
```

View File

@@ -0,0 +1,82 @@
# Running ERNIE 4.5 Series Models on the Metax GPU C550
FastDeploy has been deeply adapted and optimized for the ERNIE 4.5 series models on the MetaX C550, unifying the inference entry point with that of GPUs, so inference workloads can be migrated without any code changes.
Environment requirements:
- Python >= 3.10
- Linux X86_64
| Chip Type | Driver Version | KMD Version |
| :---: | :---: | :---: |
| MetaX C550 | 3.0.0.1 | 2.14.6 |
## 1. Pull the Container Image
```shell
docker login --username=cr_temp_user --password=eyJpbnN0YW5jZUlkIjoiY3JpLXpxYTIzejI2YTU5M3R3M2QiLCJ0aW1lIjoiMTc1NTUxODEwODAwMCIsInR5cGUiOiJzdWIiLCJ1c2VySWQiOiIyMDcwOTQwMTA1NjYzNDE3OTIifQ:8226ca50ce5476c42062e24d3c465545de1c1780 cr.metax-tech.com && docker pull cr.metax-tech.com/public-library/maca-native:3.0.0.4-ubuntu20.04-amd64
```
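Optionally, start a working container from the pulled image before installing the Python packages. The command below is only a sketch: the mount path and runtime flags are placeholders and depend on your local MetaX driver setup.
```shell
# Sketch only: adjust the volume mounts and runtime flags to match your MetaX driver installation.
docker run -it --net=host --shm-size=64G \
    -v /path/to/models:/root/model \
    cr.metax-tech.com/public-library/maca-native:3.0.0.4-ubuntu20.04-amd64 /bin/bash
```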
## 2. Pre-installation
```shell
pip install paddlepaddle==3.0.0.dev20250729 -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
pip install paddle-metax-gpu==3.0.0.dev20250807 -i https://www.paddlepaddle.org.cn/packages/nightly/maca/
```
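As a quick sanity check (not part of the official steps), you can confirm that the MetaX custom device plugin registered itself with Paddle:
```shell
# Optional sanity check: should print True if the MetaX custom device plugin is registered.
python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"
```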
## 3. Download and Build FastDeploy
```shell
git clone https://github.com/PaddlePaddle/FastDeploy
cd FastDeploy
bash build.sh
```
The built packages will be in the `FastDeploy/dist` directory.
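The wheel can then be installed directly from that directory; the wildcard below is illustrative, since the exact filename depends on the version and Python ABI you built against.
```shell
# Install the locally built wheel (the exact filename varies by version and Python ABI).
pip install dist/fastdeploy*.whl
```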
## 4. Environment Verification
After installation, verify the environment with this Python code:
```python
import paddle
from paddle.jit.marker import unified
# Verify GPU availability
paddle.utils.run_check()
# Verify FastDeploy custom operators compilation
from fastdeploy.model_executor.ops.gpu import beam_search_softmax
```
If the above code executes successfully, the environment is ready.
## 5. Demo
```python
from fastdeploy import LLM, SamplingParams

prompts = [
    "Hello. My name is",
]
sampling_params = SamplingParams(top_p=0.95, max_tokens=32, temperature=0.6)
llm = LLM(
    model="/root/model/ERNIE-4.5-21B-A3B-Paddle",
    tensor_parallel_size=1,
    max_model_len=256,
    engine_worker_queue_port=9135,
    quantization='wint8',
    static_decode_blocks=0,
    gpu_memory_utilization=0.9,
)
outputs = llm.generate(prompts, sampling_params)

print(f"Generated {len(outputs)} outputs")
print("=" * 50 + "\n")

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs.text
    print(prompt)
    print(generated_text)
    print("-" * 50)
```
Output:
```
INFO 2025-08-18 10:54:18,455 416822 engine.py[line:202] Waiting worker processes ready...
Loading Weights: 100%|█████████████████████████████████████████████████████████████████████████| 100/100 [03:33<00:00, 2.14s/it]
Loading Layers: 100%|██████████████████████████████████████████████████████████████████████████| 100/100 [00:18<00:00, 5.54it/s]
INFO 2025-08-18 10:58:16,149 416822 engine.py[line:247] Worker processes are launched with 240.08204197883606 seconds.
Processed prompts: 100%|███████████████████████| 1/1 [00:21<00:00, 21.84s/it, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Generated 1 outputs
==================================================

Hello. My name is
Alice and I'm here to help you. What can I do for you today?

Hello Alice! I'm trying to organize a small party
```

View File

@@ -257,6 +257,7 @@ class FlashAttentionBackend(AttentionBackend):
out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin)) out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin))
return paddle.cast(out, qk.dtype) return paddle.cast(out, qk.dtype)
@paddle.no_grad()
def forward_native_backend( def forward_native_backend(
self, self,
q: paddle.Tensor, q: paddle.Tensor,
@@ -273,7 +274,7 @@ class FlashAttentionBackend(AttentionBackend):
# 1. 分离 encoder / decoder 的 mask # 1. 分离 encoder / decoder 的 mask
seq_lens_encoder = forward_meta.seq_lens_encoder.squeeze(-1) seq_lens_encoder = forward_meta.seq_lens_encoder.squeeze(-1)
seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1) seq_lens_decoder = forward_meta.seq_lens_decoder.squeeze(-1)
seq_lens_this_time = forward_meta.seq_lens_this_time.squeeze(-1) seq_lens_this_time = forward_meta.seq_lens_this_time
encoder_indices = [] encoder_indices = []
decoder_indices = [] decoder_indices = []

View File

@@ -45,6 +45,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
"""process_prequanted_weights""" """process_prequanted_weights"""
pass pass
@paddle.no_grad()
def create_weights(self, layer: nn.Layer, state_dict): def create_weights(self, layer: nn.Layer, state_dict):
""" """
Triton MoE create weight process. Triton MoE create weight process.
@@ -125,11 +126,12 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
) )
getattr(layer, scale_name).set_value(quanted_weight_scale) getattr(layer, scale_name).set_value(quanted_weight_scale)
@paddle.no_grad()
def apply( def apply(
self, self,
layer: nn.Layer, layer: nn.Layer,
x: paddle.Tensor, x: paddle.Tensor,
gate_out: paddle.Tensor, gate: nn.Layer,
) -> paddle.Tensor: ) -> paddle.Tensor:
""" """
Triton compute Fused MoE. Triton compute Fused MoE.
@@ -141,6 +143,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
moe_intermediate_size = layer.moe_intermediate_size moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size hidden_size = layer.hidden_size
gate_out = gate(x.cast("float32"))
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out, gate_out,
layer.gate_correction_bias, layer.gate_correction_bias,

View File

@@ -52,6 +52,7 @@ elif current_platform.is_maca():
set_stop_value_multi_ends, set_stop_value_multi_ends,
step_paddle, step_paddle,
update_inputs, update_inputs,
update_inputs_v1,
) )
else: else:
from fastdeploy.model_executor.ops.gpu import ( from fastdeploy.model_executor.ops.gpu import (

View File

@@ -23,8 +23,11 @@ import paddle
from paddle import nn from paddle import nn
from paddleformers.utils.log import logger from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.config import FDConfig from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request, RequestType from fastdeploy.engine.request import Request, RequestType
from fastdeploy.input.mm_processor import DataProcessor
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.utils import ( from fastdeploy.model_executor.graph_optimization.utils import (
profile_run_guard, profile_run_guard,
sot_warmup_guard, sot_warmup_guard,
@@ -41,6 +44,7 @@ from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
from fastdeploy.model_executor.model_loader import get_model_loader from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.ops.gpu import ( from fastdeploy.model_executor.ops.gpu import (
recover_decode_task, recover_decode_task,
set_value_by_flags_and_idx, set_value_by_flags_and_idx,
@@ -52,15 +56,7 @@ from fastdeploy.model_executor.pre_and_post_process import (
rebuild_padding, rebuild_padding,
step_cuda, step_cuda,
) )
from fastdeploy.platforms import current_platform from fastdeploy.spec_decode import MTPProposer, NgramProposer
if not current_platform.is_dcu():
from fastdeploy.spec_decode import MTPProposer, NgramProposer
from fastdeploy import envs
from fastdeploy.input.mm_processor import DataProcessor
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.worker.model_runner_base import ModelRunnerBase from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
@@ -130,7 +126,7 @@ class MetaxModelRunner(ModelRunnerBase):
shape=[self.parallel_config.max_num_seqs, 1], shape=[self.parallel_config.max_num_seqs, 1],
fill_value=4, fill_value=4,
dtype="int64", dtype="int64",
) ).cpu()
self.restore_chunked_prefill_request = dict() self.restore_chunked_prefill_request = dict()
# Initialize attention Backend # Initialize attention Backend
@@ -164,6 +160,7 @@ class MetaxModelRunner(ModelRunnerBase):
if self.speculative_method == "ngram": if self.speculative_method == "ngram":
self.proposer = NgramProposer(self.fd_config) self.proposer = NgramProposer(self.fd_config)
elif self.speculative_method == "mtp": elif self.speculative_method == "mtp":
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
self.proposer = MTPProposer( self.proposer = MTPProposer(
self.fd_config, self.fd_config,
self.get_model(), self.get_model(),
@@ -193,21 +190,23 @@ class MetaxModelRunner(ModelRunnerBase):
return self.guided_backend.get_logits_processor(schemata_key=schemata_key), schemata_key return self.guided_backend.get_logits_processor(schemata_key=schemata_key), schemata_key
def insert_tasks_v1(self, req_dicts: List[Request]): def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None):
""" """
Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1 Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
req_dict: A list of Request dict
num_running_requests: batch_size
""" """
# NOTE(luotingdan): Lazy initialize kv cache # Lazy initialize kv cache
if "caches" not in self.share_inputs: if "caches" not in self.share_inputs:
self.initialize_kv_cache() self.initialize_kv_cache()
req_len = len(req_dicts) req_len = len(req_dicts)
has_prefill_task = False has_prefill_task = False
has_decode_task = False
for i in range(req_len): for i in range(req_len):
request = req_dicts[i] request = req_dicts[i]
idx = request.idx idx = request.idx
if request.task_type.value == RequestType.PREFILL.value: # prefill task if request.task_type.value == RequestType.PREFILL.value: # prefill task
logger.debug(f"Handle prefill request {request} at idx {idx}")
prefill_start_index = request.prefill_start_index prefill_start_index = request.prefill_start_index
prefill_end_index = request.prefill_end_index prefill_end_index = request.prefill_end_index
length = prefill_end_index - prefill_start_index length = prefill_end_index - prefill_start_index
@@ -253,6 +252,11 @@ class MetaxModelRunner(ModelRunnerBase):
) )
input_ids = request.prompt_token_ids + request.output_token_ids input_ids = request.prompt_token_ids + request.output_token_ids
logger.debug(
f"Handle prefill request {request} at idx {idx}, "
f"{prefill_start_index=}, {prefill_end_index=}, "
f"need_prefilled_token_num={len(input_ids)}"
)
self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array( self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
input_ids[prefill_start_index:prefill_end_index] input_ids[prefill_start_index:prefill_end_index]
) )
@@ -264,7 +268,7 @@ class MetaxModelRunner(ModelRunnerBase):
) )
self.share_inputs["stop_flags"][idx : idx + 1] = False self.share_inputs["stop_flags"][idx : idx + 1] = False
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length self.seq_lens_this_time_buffer[idx : idx + 1] = length
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0 self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids) self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
@@ -281,22 +285,27 @@ class MetaxModelRunner(ModelRunnerBase):
self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32" request.block_tables, dtype="int32"
) )
if self.share_inputs["is_block_step"][idx]: # has tasks to continue to decode
has_decode_task = True
continue continue
else: # preempted task else: # preempted task
logger.debug(f"Handle preempted request {request} at idx {idx}") logger.debug(f"Handle preempted request {request} at idx {idx}")
self.share_inputs["block_tables"][idx : idx + 1, :] = -1 self.share_inputs["block_tables"][idx : idx + 1, :] = -1
self.share_inputs["stop_flags"][idx : idx + 1] = True self.share_inputs["stop_flags"][idx : idx + 1] = True
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 0 self.seq_lens_this_time_buffer[idx : idx + 1] = 0
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.share_inputs["is_block_step"][idx : idx + 1] = False self.share_inputs["is_block_step"][idx : idx + 1] = False
continue continue
if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens: assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
request.eos_token_ids.append(request.eos_token_ids[0])
self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1) self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7) self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
self.share_inputs["top_k_list"][idx] = request.get("top_k", 0)
self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0)
self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95) self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0) self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0) self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
@@ -326,12 +335,15 @@ class MetaxModelRunner(ModelRunnerBase):
else: else:
self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0 self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
if has_prefill_task: if has_prefill_task or has_decode_task:
self.share_inputs["not_need_stop"][0] = True self.share_inputs["not_need_stop"][0] = True
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
def insert_prefill_inputs(self, req_dicts: List[Request]): def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None):
""" """
Process inputs for prefill tasks and insert it to share_inputs buffer Process inputs for prefill tasks and insert it to share_inputs buffer
req_dict: A list of Request dict
num_running_requests: batch_size
TODO(gongshaotian): Refactor this func TODO(gongshaotian): Refactor this func
""" """
@@ -365,7 +377,7 @@ class MetaxModelRunner(ModelRunnerBase):
self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids) self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = length self.share_inputs["seq_lens_decoder"][idx : idx + 1] = length
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 1 self.seq_lens_this_time_buffer[idx : idx + 1] = 1
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = 0 self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = 0
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = length self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = length
self.share_inputs["prompt_lens"][idx : idx + 1] = length self.share_inputs["prompt_lens"][idx : idx + 1] = length
@@ -377,7 +389,7 @@ class MetaxModelRunner(ModelRunnerBase):
request.draft_token_ids[0:num_prefill_send_token], request.draft_token_ids[0:num_prefill_send_token],
dtype="int64", dtype="int64",
) )
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = num_prefill_send_token self.seq_lens_this_time_buffer[idx : idx + 1] = num_prefill_send_token
else: else:
self.share_inputs["pre_ids"][idx : idx + 1] = -1 self.share_inputs["pre_ids"][idx : idx + 1] = -1
self.share_inputs["step_idx"][idx : idx + 1] = 0 self.share_inputs["step_idx"][idx : idx + 1] = 0
@@ -412,7 +424,7 @@ class MetaxModelRunner(ModelRunnerBase):
) )
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0) self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0) self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size self.seq_lens_this_time_buffer[idx : idx + 1] = token_chunk_size
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = token_chunk_size self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = token_chunk_size
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
self.share_inputs["prompt_lens"][idx : idx + 1] = token_chunk_size self.share_inputs["prompt_lens"][idx : idx + 1] = token_chunk_size
@@ -430,7 +442,7 @@ class MetaxModelRunner(ModelRunnerBase):
else: else:
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0) self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0) self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length self.seq_lens_this_time_buffer[idx : idx + 1] = length
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["prompt_lens"][idx : idx + 1] = length self.share_inputs["prompt_lens"][idx : idx + 1] = length
@@ -453,12 +465,13 @@ class MetaxModelRunner(ModelRunnerBase):
else: else:
return default_value return default_value
if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens: assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
request.eos_token_ids.append(request.eos_token_ids[0])
self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1) self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7) self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0) self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
self.share_inputs["top_k_list"][idx] = request.get("top_k", 0)
self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0) self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0)
self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95) self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95)
self.share_inputs["penalty_score"][idx : idx + 1] = get_attr_from_request( self.share_inputs["penalty_score"][idx : idx + 1] = get_attr_from_request(
@@ -489,13 +502,15 @@ class MetaxModelRunner(ModelRunnerBase):
request.block_tables, dtype="int32" request.block_tables, dtype="int32"
) )
if request.get("bad_words_token_ids") is not None: if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0:
bad_words_len = len(request.get("bad_words_token_ids")) bad_words_len = len(request.get("bad_words_token_ids"))
if bad_words_len > 0:
self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len self.share_inputs["bad_tokens_len"][idx : idx + 1] = bad_words_len
self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array( self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array(
request.get("bad_words_token_ids"), dtype="int64" request.get("bad_words_token_ids"), dtype="int64"
) )
else:
self.share_inputs["bad_tokens_len"][idx : idx + 1] = 1
self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64")
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
stop_seqs_num = len(request.get("stop_seqs_len")) stop_seqs_num = len(request.get("stop_seqs_len"))
@@ -514,8 +529,10 @@ class MetaxModelRunner(ModelRunnerBase):
self.share_inputs["not_need_stop"][0] = True self.share_inputs["not_need_stop"][0] = True
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
if self.speculative_method in ["mtp"]: if self.speculative_method in ["mtp"]:
self.proposer.insert_prefill_inputs(req_dicts) self.proposer.insert_prefill_inputs(req_dicts, num_running_requests)
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int): def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
"""Set dummy prefill inputs to share_inputs""" """Set dummy prefill inputs to share_inputs"""
@@ -525,6 +542,12 @@ class MetaxModelRunner(ModelRunnerBase):
num_tokens // batch_size, num_tokens // batch_size,
self.parallel_config.max_model_len - max_dec_len, self.parallel_config.max_model_len - max_dec_len,
) )
# When the full length is too large, DeepEP's buffer size will not be large enough, which can cause NaN results.
# Figure out the accurate buffer size of DeepEP.
if self.fd_config.parallel_config.enable_expert_parallel:
full_length = min(full_length, 32)
input_length = int(full_length * self.cache_config.kv_cache_ratio) input_length = int(full_length * self.cache_config.kv_cache_ratio)
block_num = ( block_num = (
input_length + self.cache_config.block_size - 1 input_length + self.cache_config.block_size - 1
@@ -534,8 +557,10 @@ class MetaxModelRunner(ModelRunnerBase):
idx = i idx = i
self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length) self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length) self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1) self.share_inputs["eos_token_id"][:] = np.array(
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = input_length [2] * self.model_config.eos_tokens_lens, dtype="int64"
).reshape(-1, 1)
self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
@@ -553,6 +578,7 @@ class MetaxModelRunner(ModelRunnerBase):
self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange( self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
idx * block_num, (idx + 1) * block_num, 1 idx * block_num, (idx + 1) * block_num, 1
) )
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
def _init_share_inputs(self, max_num_seqs: int): def _init_share_inputs(self, max_num_seqs: int):
""" """
@@ -568,18 +594,20 @@ class MetaxModelRunner(ModelRunnerBase):
) )
self.share_inputs["input_ids"] = paddle.full( self.share_inputs["input_ids"] = paddle.full(
[max_num_seqs, self.parallel_config.max_model_len], [max_num_seqs, self.parallel_config.max_model_len],
self.parallel_config.pad_token_id, self.model_config.pad_token_id,
dtype="int64", dtype="int64",
) )
self.share_inputs["prompt_ids"] = paddle.full( self.share_inputs["prompt_ids"] = paddle.full(
[max_num_seqs, self.parallel_config.max_model_len], [max_num_seqs, self.parallel_config.max_model_len],
self.parallel_config.pad_token_id, self.model_config.pad_token_id,
dtype="int64", dtype="int64",
) )
self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64") self.share_inputs["eos_token_id"] = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64")
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32") self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
self.share_inputs["top_k_list"] = [0] * max_num_seqs
self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32") self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
self.share_inputs["min_p_list"] = [0.0] * max_num_seqs
self.share_inputs["temperature"] = paddle.full( self.share_inputs["temperature"] = paddle.full(
[max_num_seqs, 1], self.model_config.temperature, dtype="float32" [max_num_seqs, 1], self.model_config.temperature, dtype="float32"
) )
@@ -603,7 +631,9 @@ class MetaxModelRunner(ModelRunnerBase):
self.share_inputs["max_length"] = paddle.full( self.share_inputs["max_length"] = paddle.full(
[max_num_seqs, 1], self.model_config.max_model_len, dtype="int64" [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64"
) )
self.share_inputs["seq_lens_this_time"] = paddle.full(max_num_seqs, 0, dtype="int32") self.seq_lens_this_time_buffer = paddle.full([max_num_seqs, 1], 0, dtype="int32")
if self.fd_config.parallel_config.enable_expert_parallel:
self.share_inputs["seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
@@ -626,7 +656,7 @@ class MetaxModelRunner(ModelRunnerBase):
self.share_inputs["need_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32") self.share_inputs["need_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
self.share_inputs["need_block_len"] = paddle.full([1], 0, dtype="int32") self.share_inputs["need_block_len"] = paddle.full([1], 0, dtype="int32")
self.share_inputs["used_list_len"] = paddle.full([max_num_seqs], 0, dtype="int32") self.share_inputs["used_list_len"] = paddle.full([max_num_seqs], 0, dtype="int32")
self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], 0, dtype="int64").cpu()
self.share_inputs["first_token_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int64") self.share_inputs["first_token_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
self.share_inputs["ori_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["ori_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
@@ -637,10 +667,11 @@ class MetaxModelRunner(ModelRunnerBase):
0, 0,
dtype="int64", dtype="int64",
) )
self.share_inputs["cum_offsets"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["batch_id_per_token"] = paddle.full(
self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") [max_num_seqs * self.parallel_config.max_model_len, 1], 0, dtype="int32"
self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") )
self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs + 1, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs + 1, 1], 0, dtype="int32")
# Declare AttentionBackend buffers # Declare AttentionBackend buffers
self.share_inputs["decoder_batch_ids"] = None self.share_inputs["decoder_batch_ids"] = None
@@ -758,7 +789,6 @@ class MetaxModelRunner(ModelRunnerBase):
# Remove padding # Remove padding
( (
ids_remove_padding, ids_remove_padding,
cum_offsets,
batch_id_per_token, batch_id_per_token,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
@@ -774,7 +804,6 @@ class MetaxModelRunner(ModelRunnerBase):
) )
self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
self.share_inputs["cum_offsets"].copy_(cum_offsets, False)
self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False) self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
@@ -795,7 +824,10 @@ class MetaxModelRunner(ModelRunnerBase):
temperature=self.share_inputs["temperature"], temperature=self.share_inputs["temperature"],
top_p=self.share_inputs["top_p"], top_p=self.share_inputs["top_p"],
top_k=self.share_inputs["top_k"], top_k=self.share_inputs["top_k"],
top_k_list=self.share_inputs["top_k_list"],
min_p=self.share_inputs["min_p"], min_p=self.share_inputs["min_p"],
min_p_list=self.share_inputs["min_p_list"],
seed=self.share_inputs["infer_seed"],
step_idx=self.share_inputs["step_idx"], step_idx=self.share_inputs["step_idx"],
pre_token_ids=self.share_inputs["pre_ids"], pre_token_ids=self.share_inputs["pre_ids"],
prompt_ids=self.share_inputs["prompt_ids"], prompt_ids=self.share_inputs["prompt_ids"],
@@ -933,7 +965,7 @@ class MetaxModelRunner(ModelRunnerBase):
self.share_inputs["caches"] = list(cache_kvs.values()) self.share_inputs["caches"] = list(cache_kvs.values())
for value in cache_kvs.values(): for value in cache_kvs.values():
del value del value
paddle.device.cuda.empty_cache() # paddle.device.empty_cache()
def initialize_attn_backend(self) -> None: def initialize_attn_backend(self) -> None:
""" """
@@ -1023,7 +1055,7 @@ class MetaxModelRunner(ModelRunnerBase):
hidden_states = rebuild_padding( hidden_states = rebuild_padding(
model_output, model_output,
self.share_inputs["cum_offsets"], self.share_inputs["cu_seqlens_q"],
self.share_inputs["seq_lens_this_time"], self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"], self.share_inputs["seq_lens_decoder"],
self.share_inputs["seq_lens_encoder"], self.share_inputs["seq_lens_encoder"],
@@ -1247,6 +1279,7 @@ class MetaxModelRunner(ModelRunnerBase):
def execute_model( def execute_model(
self, self,
model_forward_batch: Optional[List[Request]] = None, model_forward_batch: Optional[List[Request]] = None,
num_running_requests: int = None,
) -> Optional[ModelRunnerOutput]: ) -> Optional[ModelRunnerOutput]:
""" """
The Entrance of model execute. The Entrance of model execute.
@@ -1255,6 +1288,7 @@ class MetaxModelRunner(ModelRunnerBase):
class at the server level, which is too granular for ModelRunner. class at the server level, which is too granular for ModelRunner.
We plan to replace it with 'ModelForwardBatch'. We plan to replace it with 'ModelForwardBatch'.
intermediate_tensors: intermediate_tensors:
num_running_requests: batch_size
""" """
# 1. Prepare inputs of model and sampler. # 1. Prepare inputs of model and sampler.
skip_idx_list = self._get_skip_idx(model_forward_batch) skip_idx_list = self._get_skip_idx(model_forward_batch)
@@ -1286,7 +1320,7 @@ class MetaxModelRunner(ModelRunnerBase):
) )
hidden_states = rebuild_padding( hidden_states = rebuild_padding(
model_output, model_output,
self.share_inputs["cum_offsets"], self.share_inputs["cu_seqlens_q"],
self.share_inputs["seq_lens_this_time"], self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"], self.share_inputs["seq_lens_decoder"],
self.share_inputs["seq_lens_encoder"], self.share_inputs["seq_lens_encoder"],
@@ -1356,8 +1390,8 @@ class MetaxModelRunner(ModelRunnerBase):
accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None), enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
think_end_id=(self.model_config.think_end_id if self.enable_mm else -1), think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None), need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None), reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
stop_token_ids=self.share_inputs["stop_seqs"], stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"], stop_seqs_len=self.share_inputs["stop_seqs_len"],
) )
@@ -1397,6 +1431,9 @@ class MetaxModelRunner(ModelRunnerBase):
self._update_chunked_prefill(model_forward_batch) self._update_chunked_prefill(model_forward_batch)
self._add_cache(model_forward_batch) self._add_cache(model_forward_batch)
self.seq_lens_this_time_buffer[:num_running_requests].copy_(
self.share_inputs["seq_lens_this_time"][:num_running_requests], False
)
return None return None
def _add_cache(self, model_forward_batch) -> None: def _add_cache(self, model_forward_batch) -> None:
@@ -1528,7 +1565,7 @@ class MetaxModelRunner(ModelRunnerBase):
""" " Dynamic model loader use to clear parameters use for RL""" """ " Dynamic model loader use to clear parameters use for RL"""
self.dynamic_weight_manager.clear_parameters(pid) self.dynamic_weight_manager.clear_parameters(pid)
self.clear_cache() self.clear_cache()
paddle.device.cuda.empty_cache() # paddle.device.empty_cache()
self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory") self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")
def update_parameters(self, pid): def update_parameters(self, pid):

View File

@@ -20,6 +20,7 @@ import time
from typing import List, Optional from typing import List, Optional
import paddle import paddle
import pymxsml
from paddle import nn from paddle import nn
from fastdeploy import envs from fastdeploy import envs
@@ -60,7 +61,7 @@ class MetaxWorker(WorkerBase):
paddle.set_default_dtype(self.parallel_config.dtype) paddle.set_default_dtype(self.parallel_config.dtype)
gc.collect() gc.collect()
paddle.device.cuda.empty_cache()
else: else:
raise RuntimeError(f"Not support device type: {self.device_config.device}") raise RuntimeError(f"Not support device type: {self.device_config.device}")
@@ -92,8 +93,12 @@ class MetaxWorker(WorkerBase):
You may limit the usage of GPU memory You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter. by adjusting the `gpu_memory_utilization` parameter.
""" """
"""Will implement later"""
# Temporary fix: allow overriding the KV cache size for testing.
fd_kvache_mem = os.getenv("FD_METAX_KVCACHE_MEM")
if fd_kvache_mem is not None:
return int(float(fd_kvache_mem) * 1024**3)
else:
# 1. Record memory state before profile run # 1. Record memory state before profile run
start_time = time.perf_counter() start_time = time.perf_counter()
Gb = 1024**3 Gb = 1024**3
@@ -110,8 +115,6 @@ class MetaxWorker(WorkerBase):
if os.getenv("MACA_VISIBLE_DEVICES") is not None: if os.getenv("MACA_VISIBLE_DEVICES") is not None:
device_id = int(os.getenv("MACA_VISIBLE_DEVICES").split(",")[device_id]) device_id = int(os.getenv("MACA_VISIBLE_DEVICES").split(",")[device_id])
import pymxsml
pymxsml.mxSmlInit() pymxsml.mxSmlInit()
info = pymxsml.mxSmlGetMemoryInfo(device_id) info = pymxsml.mxSmlGetMemoryInfo(device_id)
before_run_meminfo_total = info.vramTotal * 1024 before_run_meminfo_total = info.vramTotal * 1024
@@ -144,10 +147,8 @@ class MetaxWorker(WorkerBase):
after_run_meminfo_free = after_run_meminfo_total - after_run_meminfo_used after_run_meminfo_free = after_run_meminfo_total - after_run_meminfo_used
available_kv_cache_memory = ( available_kv_cache_memory = (
after_run_meminfo_total * self.cache_config.gpu_memory_utilization after_run_meminfo_free - paddle_peak_increase
- after_run_meminfo_used ) * self.cache_config.gpu_memory_utilization
- paddle_peak_increase
)
available_kv_cache_memory += model_block_memory_used * self.parallel_config.total_block_num available_kv_cache_memory += model_block_memory_used * self.parallel_config.total_block_num
end_time = time.perf_counter() end_time = time.perf_counter()
@@ -180,19 +181,28 @@ class MetaxWorker(WorkerBase):
def execute_model( def execute_model(
self, self,
model_forward_batch: Optional[List[Request]] = None, model_forward_batch: Optional[List[Request]] = None,
num_running_request: int = None,
) -> Optional[ModelRunnerOutput]: ) -> Optional[ModelRunnerOutput]:
""" """ """ """
output = self.model_runner.execute_model(model_forward_batch) output = self.model_runner.execute_model(model_forward_batch, num_running_request)
return output return output
def preprocess_new_task(self, req_dicts: List[Request]) -> None: def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int) -> None:
"""Process new requests and then start the decode loop """Process new requests and then start the decode loop
and workers and modelrunners should not perceive it. and workers and modelrunners should not perceive it.
""" """
if envs.ENABLE_V1_KVCACHE_SCHEDULER: if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.model_runner.insert_tasks_v1(req_dicts=req_dicts) self.model_runner.insert_tasks_v1(req_dicts=req_dicts, num_running_requests=num_running_requests)
else: else:
self.model_runner.insert_prefill_inputs(req_dicts=req_dicts) self.model_runner.insert_prefill_inputs(req_dicts=req_dicts, num_running_requests=num_running_requests)
def graph_optimize_and_warm_up_model(self) -> None:
"""
Perform the warm-up and the graph optimization
"""
if self.model_runner.graph_opt_level >= 1:
self.model_runner.sot_warmup()
# TODO: Trigger CUDA graph capture.
def check_health(self) -> bool: def check_health(self) -> bool:
""" """ """ """