Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-07 17:41:52 +08:00)
[XPU] support XPU VL model inference (#4030)
* [XPU] support XPU VL model inference
* fix image op import and device check
* rebase develop
* fix perf
@@ -25,6 +25,7 @@ from paddle import nn
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta
from fastdeploy.model_executor.graph_optimization.utils import (
    profile_run_guard,
@@ -34,10 +35,11 @@ from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend,
)
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
from fastdeploy.model_executor.ops.xpu import (
    adjust_batch,
    get_infer_param,
@@ -201,6 +203,45 @@ def xpu_post_process(
        update_inputs,
    )

    # handle vl:
    if model_output.enable_thinking:
        exists_think_end = sampled_token_ids == model_output.think_end_id
        paddle.assign(
            paddle.where(
                exists_think_end,
                model_output.need_think_end - 1,
                model_output.need_think_end,
            ),
            model_output.need_think_end,
        )

        paddle.assign(
            paddle.where(
                model_output.need_think_end.cast("bool"),
                model_output.reasoning_index - 1,
                model_output.reasoning_index,
            ),
            model_output.reasoning_index,
        )

        stop_wo_think = (
            (sampled_token_ids == model_output.eos_token_id.T).any(axis=1, keepdim=True)
            | (model_output.reasoning_index == 0)
        ) & (model_output.need_think_end > 0)
        sampled_token_ids = paddle.where(
            stop_wo_think,
            model_output.think_end_id,
            sampled_token_ids,
        )
        paddle.assign(
            paddle.where(
                stop_wo_think,
                model_output.need_think_end - 1,
                model_output.need_think_end,
            ),
            model_output.need_think_end,
        )

    # 1. Set stop value
    paddle.assign(
        paddle.where(
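The block above tracks a per-request "thinking" budget for the VL model: `need_think_end` marks requests that still owe a closing think token, and `reasoning_index` counts the remaining reasoning tokens. A minimal sketch of that bookkeeping with toy tensors (token ids, shapes, and batch values are invented for illustration; the real buffers live in `model_output`, and the diff writes results back in place with `paddle.assign`, whereas the sketch just reassigns):

```python
import paddle

# Toy batch of 3 requests, one sampled token each (values are illustrative).
think_end_id = paddle.to_tensor([[7]], dtype="int64")
eos_token_id = paddle.to_tensor([[2]], dtype="int64")
sampled_token_ids = paddle.to_tensor([[7], [5], [5]], dtype="int64")

need_think_end = paddle.to_tensor([[1], [1], [0]], dtype="int64")   # 1 = still inside the think phase
reasoning_index = paddle.to_tensor([[4], [1], [3]], dtype="int64")  # remaining reasoning-token budget

# Request 0 emitted the think-end token itself: clear its flag.
exists_think_end = sampled_token_ids == think_end_id
need_think_end = paddle.where(exists_think_end, need_think_end - 1, need_think_end)

# Every request still thinking spends one token of its budget.
reasoning_index = paddle.where(need_think_end.cast("bool"), reasoning_index - 1, reasoning_index)

# Request 1 hit EOS or exhausted its budget before closing the think phase,
# so its sampled token is overwritten with think_end_id.
stop_wo_think = (
    (sampled_token_ids == eos_token_id.T).any(axis=1, keepdim=True) | (reasoning_index == 0)
) & (need_think_end > 0)
sampled_token_ids = paddle.where(stop_wo_think, think_end_id, sampled_token_ids)
need_think_end = paddle.where(stop_wo_think, need_think_end - 1, need_think_end)

print(sampled_token_ids.numpy().tolist())  # [[7], [7], [5]]
```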
@@ -340,11 +381,36 @@ class XPUModelRunner(ModelRunnerBase):

    def __init__(self, fd_config: FDConfig, device: str, rank: int, local_rank: int):
        super().__init__(fd_config=fd_config, device=device)
        self.enable_mm = self.model_config.enable_mm
        self.rank = rank
        self.local_rank = local_rank
        self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop

        # VL model config:
        if self.enable_mm:
            self._init_image_preprocess()

            self.amp_black = [
                "reduce_sum",
                "c_softmax_with_cross_entropy",
                "elementwise_div",
                "sin",
                "cos",
                "sort",
                "multinomial",
            ]
            self.amp_white = [
                "lookup_table",
                "lookup_table_v2",
                "flash_attn",
                "matmul",
                "matmul_v2",
                "fused_gemm_epilogue",
            ]

        # Sampler
        self.sampler = Sampler()
        # TODU(lilujia): sync with GPU
        self.sampler = Sampler(fd_config)

        # Lazy initialize kv cache after model loading
        # self.kv_caches: list[paddle.Tensor] = []
@@ -364,18 +430,28 @@ class XPUModelRunner(ModelRunnerBase):
        ).cpu()

        # Initialize attention Backend
        # Note(gonshaotian): Currently, all attention layers share one attention backend instance.
        # NOTE(gonshaotian): Currently, all attention layers share one attention backend instance.
        # In the future, we will expand it as a list.
        self.attn_backends: list[AttentionBackend] = []

        self.initialize_attn_backend()

        # Forward meta store the global meta information of the forward
        self.forward_meta: ForwardMeta = None

    def exist_prefill(self):
        """
        check whether prefill stage exist
        """
        if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0:
            return 1
        else:
            return 0

    def insert_tasks_v1(self, req_dicts: List[Request]):
        """
        Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
        req_dict: A list of Request dict
        num_running_requests: batch_size
        """
        # NOTE(luotingdan): Lazy initialize kv cache
        if "caches" not in self.share_inputs:
@@ -388,11 +464,54 @@ class XPUModelRunner(ModelRunnerBase):
            request = req_dicts[i]
            idx = request.idx
            if request.task_type.value == RequestType.PREFILL.value:  # prefill task
                logger.debug(f"Handle prefill request {request} at idx {idx}")
                prefill_start_index = request.prefill_start_index
                prefill_end_index = request.prefill_end_index
                length = prefill_end_index - prefill_start_index
                input_ids = request.prompt_token_ids + request.output_token_ids
                if self.enable_mm:
                    inputs = request.multimodal_inputs
                    if request.with_image:
                        vision_inputs = {}
                        vision_inputs["input_ids"] = paddle.to_tensor(
                            inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
                        )
                        vision_inputs["token_type_ids"] = paddle.to_tensor(
                            inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64
                        )
                        vision_inputs["image_type_ids"] = paddle.to_tensor(
                            inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end],
                            dtype=paddle.int64,
                        )
                        vision_inputs["images"] = paddle.to_tensor(
                            inputs["images"][request.image_start : request.image_end], dtype="uint8"
                        )
                        vision_inputs["grid_thw"] = paddle.to_tensor(
                            inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64"
                        )
                        self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs)
                    else:
                        self.share_inputs["image_features"] = None

                    if inputs["position_ids"] is not None:
                        position_ids = paddle.to_tensor(
                            request.multimodal_inputs["position_ids"],
                            dtype="int64",
                        ).unsqueeze([0])
                    else:
                        position_ids = None

                    enable_thinking = request.get("enable_thinking", True)
                    enable_thinking = enable_thinking if enable_thinking is not None else True
                    self.share_inputs["enable_thinking"][:] = enable_thinking
                    self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
                    self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
                    self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
                        position_ids, request.get("max_tokens", 2048)
                    )

                if len(request.output_token_ids) == 0:
                    input_ids = request.prompt_token_ids
                else:
                    input_ids = request.prompt_token_ids + request.output_token_ids
                logger.debug(
                    f"Handle prefill request {request} at idx {idx} prefill_start_index {prefill_start_index} prefill_end_index {prefill_end_index} need_prefilled_token_num {len(input_ids)}"
                )
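For a chunked prefill request, the vision inputs above are sliced in three different index spaces: token positions (`prefill_start_index:prefill_end_index`), patch positions (`image_start:image_end`, `image_type_ids_start:image_type_ids_end`), and whole images (`num_image_start:num_image_end`). A toy numpy sketch of that slicing (all sizes and index values are invented for illustration):

```python
import numpy as np

# Toy multimodal prompt: 10 tokens, of which tokens 3..6 are image-patch placeholders.
inputs = {
    "input_ids": np.arange(100, 110),
    "token_type_ids": np.array([0, 0, 0, 1, 1, 1, 1, 0, 0, 0]),  # 1 marks image tokens
    "images": np.zeros((4, 588), dtype=np.uint8),                # 4 flattened patches
    "image_type_ids": np.array([0, 0, 0, 0]),                    # every patch belongs to image 0
    "grid_thw": np.array([[1, 2, 2]]),                           # one image, a 2x2 patch grid
}

# One chunk of a chunked-prefill schedule (toy indices).
prefill_start_index, prefill_end_index = 3, 7        # token-aligned
image_start, image_end = 0, 4                        # patch-aligned
image_type_ids_start, image_type_ids_end = 0, 4      # patch-aligned
num_image_start, num_image_end = 0, 1                # image-aligned

chunk = {
    "input_ids": inputs["input_ids"][prefill_start_index:prefill_end_index],
    "token_type_ids": inputs["token_type_ids"][prefill_start_index:prefill_end_index],
    "images": inputs["images"][image_start:image_end],
    "image_type_ids": inputs["image_type_ids"][image_type_ids_start:image_type_ids_end],
    "grid_thw": inputs["grid_thw"][num_image_start:num_image_end],
}
print(chunk["input_ids"])     # [103 104 105 106] -> only this chunk's tokens go to the vision tower
print(chunk["images"].shape)  # (4, 588)
```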
@@ -475,41 +594,86 @@ class XPUModelRunner(ModelRunnerBase):
            if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                stop_seqs_num = len(request.get("stop_seqs_len"))
                for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
                    request.stop_seqs_len.append(0)
                self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
                self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
                    request.get("stop_token_ids"), dtype="int64"
                    request.sampling_params.stop_seqs_len.append(0)
                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array(
                    request.sampling_params.stop_seqs_len, dtype="int32"
                )
                self.share_inputs["stop_seqs"][
                    idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0])
                ] = np.array(request.get("stop_token_ids"), dtype="int64")
            else:
                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0

        if has_prefill_task or has_decode_task:
            self.share_inputs["not_need_stop"][0] = True

    def process_prefill_inputs(self, req_dicts: List[Request]):
    def insert_prefill_inputs(self, req_dicts: List[Request]):
        """Process inputs for prefill tasks and update share_inputs buffer"""
        req_len = len(req_dicts)
        for i in range(req_len):
            request = req_dicts[i]
            idx = request.idx
            length = request.prompt_token_ids_len
            length = len(request.prompt_token_ids)
            assert length > 0, "The prompt requested must not be empty."
            self.share_inputs["pre_ids"][idx : idx + 1] = -1
            self.share_inputs["step_idx"][idx : idx + 1] = 0
            self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
            self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
            if self.enable_mm:
                inputs = self._preprocess_mm_task(request.multimodal_inputs)
                if inputs.get("images") is not None:
                    self.share_inputs["image_features"] = self.extract_vision_features(inputs)
                else:
                    # Compatible with the situation that lacks images and videos
                    self.share_inputs["image_features"] = None
                position_ids = inputs["position_ids"]
                length = inputs["input_ids"].shape[1]
                self.share_inputs["input_ids"][idx : idx + 1, :length] = inputs["input_ids"]
            else:
                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
            self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
            self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
            self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
            self.share_inputs["prompt_lens"][idx : idx + 1] = length

            if self.enable_mm:
                enable_thinking = request.get("enable_thinking", True)
                enable_thinking = enable_thinking if enable_thinking is not None else True
                self.share_inputs["enable_thinking"][:] = enable_thinking
                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
                self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
                self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
                    position_ids, request.get("max_tokens", 2048)
                )
                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0

            def get_attr_from_request(request, attr, default_value=None):
                res = request.get(attr, default_value)
                if res is not None:
                    return res
                else:
                    return default_value

            assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
            self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
            self.share_inputs["pre_ids"][idx : idx + 1] = -1
            self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
            self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
            self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
            self.share_inputs["top_k_list"][idx] = request.get("top_k", 0)
            self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
            self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0)
            self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
            self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
            self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
            self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0)
            self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
            self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
            self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
            self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
            self.share_inputs["step_idx"][idx : idx + 1] = 0
            self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)

            self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95)
            self.share_inputs["penalty_score"][idx : idx + 1] = get_attr_from_request(
                request, "repetition_penalty", 1.0
            )
            self.share_inputs["frequency_score"][idx : idx + 1] = get_attr_from_request(
                request, "frequency_penalty", 0.0
            )
            self.share_inputs["presence_score"][idx : idx + 1] = get_attr_from_request(
                request, "presence_penalty", 0.0
            )
            self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
            self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
                "max_tokens", self.model_config.max_model_len
            )
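The nested `get_attr_from_request` helper above exists because a request can carry an explicit `None` for a sampling field, and `request.get(attr, default)` would hand that `None` straight into the sampling buffers. A small stand-alone sketch of the difference (a plain dict stands in for the Request object, which exposes a `.get()` with the same semantics):

```python
request = {"top_p": None, "temperature": 0.2}

def get_attr_from_request(request, attr, default_value=None):
    res = request.get(attr, default_value)
    return res if res is not None else default_value

print(request.get("top_p", 0.7))                            # None -> would corrupt the buffer write
print(get_attr_from_request(request, "top_p", 0.7))         # 0.7
print(get_attr_from_request(request, "temperature", 0.95))  # 0.2
```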
@@ -540,11 +704,15 @@ class XPUModelRunner(ModelRunnerBase):
            if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                stop_seqs_num = len(request.get("stop_seqs_len"))
                for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
                    request.stop_seqs_len.append(0)
                self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
                self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
                    request.get("stop_token_ids"), dtype="int64"
                    request.sampling_params.stop_seqs_len.append(0)
                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array(
                    request.sampling_params.stop_seqs_len, dtype="int32"
                )
                self.share_inputs["stop_seqs"][
                    idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0])
                ] = np.array(request.get("stop_token_ids"), dtype="int64")
            else:
                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0

        self.share_inputs["not_need_stop"][0] = True

@@ -565,6 +733,11 @@ class XPUModelRunner(ModelRunnerBase):
            self.model_config.pad_token_id,
            dtype="int64",
        )
        self.share_inputs["prompt_ids"] = paddle.full(
            [max_num_seqs, self.parallel_config.max_model_len],
            self.model_config.pad_token_id,
            dtype="int64",
        )
        self.share_inputs["eos_token_id"] = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64")
        self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
        self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
@@ -627,13 +800,15 @@ class XPUModelRunner(ModelRunnerBase):

        # Initialize rotary position embedding
        tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))

        # TODO(gongshaotian): move to models
        self.share_inputs["rope_emb"] = get_rope(
            rotary_dim=self.model_config.head_dim,
            position_ids=tmp_position_ids,
            base=self.model_config.rope_theta,
            model_config=self.model_config,
        )
        if not self.enable_mm:
            self.share_inputs["rope_emb"] = get_rope(
                rotary_dim=self.model_config.head_dim,
                position_ids=tmp_position_ids,
                base=self.model_config.rope_theta,
                model_config=self.model_config,
            )

        # Set block tables
        pre_max_block_num = (
@@ -654,18 +829,40 @@ class XPUModelRunner(ModelRunnerBase):
        self.share_inputs["free_list_len"] = paddle.full([1], self.free_list_len, dtype="int32")

        # Initialize stop seqs
        self.share_inputs["stop_seqs_len"] = paddle.full([self.model_config.max_stop_seqs_num], 0, dtype="int32")
        self.share_inputs["stop_seqs_len"] = paddle.full(
            [max_num_seqs, self.model_config.max_stop_seqs_num], 0, dtype="int32"
        )
        self.share_inputs["stop_seqs"] = paddle.full(
            [
                max_num_seqs,
                self.model_config.max_stop_seqs_num,
                self.model_config.stop_seqs_max_len,
            ],
            -1,
            dtype="int32",
            dtype="int64",
        )

        if self.enable_mm:
            head_dim = self.model_config.head_dim
            self.share_inputs["rope_emb"] = paddle.full(
                shape=[
                    max_num_seqs,
                    2,
                    1,
                    self.parallel_config.max_model_len,
                    1,
                    head_dim // 2,
                ],
                fill_value=0,
                dtype="float32",
            )
            self.share_inputs["image_features"] = None
            self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
            self.share_inputs["enable_thinking"] = paddle.full(shape=[1], fill_value=True, dtype="bool")
            self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")

    def _prepare_inputs(self, is_dummy_run=False) -> None:
        """prepare the model inputs"""
        """Prepare the model inputs"""
        if envs.ENABLE_V1_KVCACHE_SCHEDULER and not is_dummy_run:
            recover_decode_task(
                self.share_inputs["stop_flags"],
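With this change the stop-sequence buffers become per request: `stop_seqs_len` is `[max_num_seqs, max_stop_seqs_num]` and `stop_seqs` is `[max_num_seqs, max_stop_seqs_num, stop_seqs_max_len]`, padded with -1. A toy sketch of how one request's stop sequences land in its row (sizes are invented, and the loop stands in for the vectorized assignment used in the diff):

```python
import numpy as np

max_num_seqs, max_stop_seqs_num, stop_seqs_max_len = 4, 3, 5

stop_seqs_len = np.zeros((max_num_seqs, max_stop_seqs_num), dtype=np.int32)
stop_seqs = np.full((max_num_seqs, max_stop_seqs_num, stop_seqs_max_len), -1, dtype=np.int64)

# Request in slot idx=2 defines two stop sequences of lengths 2 and 3.
idx = 2
request_stop_token_ids = [[11, 12], [21, 22, 23]]
request_stop_seqs_len = [2, 3]

stop_seqs_num = len(request_stop_seqs_len)
padded_len = request_stop_seqs_len + [0] * (max_stop_seqs_num - stop_seqs_num)
stop_seqs_len[idx, :] = np.array(padded_len, dtype=np.int32)
for s, ids in enumerate(request_stop_token_ids):
    stop_seqs[idx, s, : len(ids)] = ids

print(stop_seqs_len[idx])  # [2 3 0]
print(stop_seqs[idx, 1])   # [21 22 23 -1 -1]
```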
@@ -689,10 +886,13 @@ class XPUModelRunner(ModelRunnerBase):
        # Update bad tokens len
        max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])

        if self.enable_mm:  # pos_emb_type is different in EB and VL
            self.forward_meta.pos_emb_type = "HALF_HEAD_DIM"
        self.forward_meta.attn_backend = self.attn_backends[0]
        self.initialize_attention_backend()

        # Get sampling metadata
        # TODU(lilujia): sync with GPU
        self.sampling_metadata = SamplingMetadata(
            temperature=self.share_inputs["temperature"],
            top_p=self.share_inputs["top_p"],
@@ -703,12 +903,16 @@ class XPUModelRunner(ModelRunnerBase):
            seed=self.share_inputs["infer_seed"],
            step_idx=self.share_inputs["step_idx"],
            pre_token_ids=self.share_inputs["pre_ids"],
            prompt_ids=self.share_inputs["prompt_ids"],
            prompt_lens=self.share_inputs["prompt_lens"],
            frequency_penalties=self.share_inputs["frequency_score"],
            presence_penalties=self.share_inputs["presence_score"],
            repetition_penalties=self.share_inputs["penalty_score"],
            min_dec_lens=self.share_inputs["min_dec_len"],
            bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
            eos_token_ids=self.share_inputs["eos_token_id"],
            enable_early_stop=self.enable_early_stop,
            stop_flags=self.share_inputs["stop_flags"],
        )

    def load_model(self) -> None:
@@ -723,7 +927,7 @@ class XPUModelRunner(ModelRunnerBase):
        # 3. Load drafter model(for speculative decoding)

    def get_model(self) -> nn.Layer:
        """get current model"""
        """Get current model"""
        return self.model

    def initialize_attention_backend(self):
@@ -741,6 +945,7 @@ class XPUModelRunner(ModelRunnerBase):
        cache_kvs = {}
        max_block_num = self.num_gpu_blocks

        # Get kv cache dtype
        cache_type = self.parallel_config.dtype

        kv_cache_quant_type = None
@@ -800,33 +1005,6 @@ class XPUModelRunner(ModelRunnerBase):
            )
            self.attn_backends.append(attn_backend)

    def capture_model(self) -> None:
        """
        Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes'
        """
        logger.warn("XPU not support cuda graph currently")
        pass

    @sot_warmup_guard(True)
    def sot_warmup(self) -> None:
        start_time = time.perf_counter()
        for batch_size in self.sot_warmup_sizes:
            self._dummy_run(
                num_tokens=self.scheduler_config.max_num_batched_tokens,
                batch_size=batch_size,
            )
            logger.info(f"SOT warmup the model with the batch size:{batch_size}")
        logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")

    def exist_prefill(self):
        """
        check whether prefill stage exist
        """
        if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0:
            return 1
        else:
            return 0

    def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int):
        """Set dummy prefill inputs to share_inputs"""
        full_length = min(num_tokens // batch_size, self.parallel_config.max_model_len - 10)
@@ -838,7 +1016,7 @@ class XPUModelRunner(ModelRunnerBase):
        for i in range(batch_size):
            idx = i
            self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)

            self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
            self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
            self.share_inputs["seq_lens_this_time"][idx : idx + 1] = input_length

@@ -897,6 +1075,24 @@ class XPUModelRunner(ModelRunnerBase):
        else:
            paddle.device.xpu.set_debug_level(debug_level)

    def capture_model(self) -> None:
        """
        Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes'
        """
        logger.warn("XPU not support cuda graph currently")
        pass

    @sot_warmup_guard(True)
    def sot_warmup(self) -> None:
        start_time = time.perf_counter()
        for batch_size in self.sot_warmup_sizes:
            self._dummy_run(
                num_tokens=self.parallel_config.max_num_batched_tokens,
                batch_size=batch_size,
            )
            logger.info(f"SOT warmup the model with the batch size:{batch_size}")
        logger.info(f"SOT warmup took {time.perf_counter() - start_time} seconds")

    def execute_model(
        self,
        model_forward_batch: Optional[List[Request]] = None,
@@ -921,13 +1117,20 @@ class XPUModelRunner(ModelRunnerBase):
        # 2. Padding inputs for cuda grph

        # 3. Execute model
        model_output = self.model(self.share_inputs["ids_remove_padding"], self.forward_meta)
        if self.enable_mm:
            model_output = self.model(
                self.share_inputs["ids_remove_padding"], self.share_inputs["image_features"], self.forward_meta
            )
        else:
            model_output = self.model(
                ids_remove_padding=self.share_inputs["ids_remove_padding"],
                forward_meta=self.forward_meta,
            )

        hiddden_states = xpu_process_output(model_output, self.share_inputs["cum_offsets"], self.forward_meta)
        hidden_states = xpu_process_output(model_output, self.share_inputs["cum_offsets"], self.forward_meta)

        # 4. Compute logits, Sample
        logits = self.model.compute_logits(hiddden_states)

        logits = self.model.compute_logits(hidden_states)
        sampler_output = self.sampler(logits, self.sampling_metadata)

        # 5. Speculative decode
@@ -947,15 +1150,21 @@ class XPUModelRunner(ModelRunnerBase):
            seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
            seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
            is_block_step=self.share_inputs["is_block_step"],
            # speculative decoding
            full_hidden_states=None,
            msg_queue_id=self.parallel_config.msg_queue_id,
            mp_rank=self.local_rank,
            use_ep=self.parallel_config.use_ep,
            # speculative decoding
            full_hidden_states=None,
            draft_tokens=None,
            actual_draft_token_num=None,
            accept_tokens=None,
            accept_num=None,
            enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
            think_end_id=(self.model_config.think_end_id if self.enable_mm else -1),
            need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
            reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
            stop_token_ids=self.share_inputs["stop_seqs"],
            stop_seqs_len=self.share_inputs["stop_seqs_len"],
        )
        xpu_post_process(
            sampled_token_ids=sampler_output.sampled_token_ids,
@@ -984,13 +1193,43 @@ class XPUModelRunner(ModelRunnerBase):

    @profile_run_guard(True)
    def profile_run(self) -> None:
        """Execute a forward pass with dummy inputs to profile the memory usage of the model."""
        """Execute a forward pass with dummy inputs to profile the memory usage of the model"""

        self.num_gpu_blocks = self.parallel_config.total_block_num
        self.initialize_kv_cache()

        self._dummy_run(
            num_tokens=int(self.scheduler_config.max_num_batched_tokens),
            batch_size=min(self.scheduler_config.max_num_seqs, 1),
        )

    def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
        """
        Set a globally unified block number and update the model's shared input.
        Args:
            num_gpu_blocks:
        """
        self.num_gpu_blocks = num_gpu_blocks

        # Reset block table and kv cache with global block num
        self.initialize_kv_cache()

        # Reset free list
        free_list = list(
            range(
                self.num_gpu_blocks - 1,
                int(self.num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
        self.free_list_len = len(free_list)
        self.share_inputs.update(
            {
                "free_list": paddle.to_tensor(free_list, dtype="int32"),
                "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
            }
        )

    def clear_block_table(self) -> None:
        """
        Clear the block tables and kv cache after profiling.
@@ -1025,41 +1264,135 @@ class XPUModelRunner(ModelRunnerBase):
        byte_of_dtype = 2

        hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads
        required_memory = (
            byte_of_dtype
            * 2  # k + v
            * (self.cache_config.block_size * hidden_dim)
            * self.model_config.num_hidden_layers
        )
        num_layers = self.model_config.num_hidden_layers
        required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers  # k + v
        return required_memory
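The per-block KV cache size computed above is `byte_of_dtype * 2 * block_size * hidden_dim * num_layers`, where the factor 2 covers the K and V tensors. A worked example with made-up numbers:

```python
# All numbers below are illustrative, not a specific model's configuration.
byte_of_dtype = 2          # bf16 / fp16
block_size = 64            # tokens per KV cache block
head_dim = 128
kv_num_heads = 8
num_hidden_layers = 32

hidden_dim = head_dim * kv_num_heads                                                  # 1024
required_memory = byte_of_dtype * 2 * (block_size * hidden_dim) * num_hidden_layers   # k + v

print(required_memory)                  # 8388608 bytes
print(required_memory / (1024 * 1024))  # 8.0 MiB per block
```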

    def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
        """
        Set a globally unified block number and update the model's shared input.
        Args:
            num_gpu_blocks:
        """
        self.num_gpu_blocks = num_gpu_blocks

        # Reset block table and kv cache with global block num
        self.initialize_kv_cache()

        # Reset free list
        free_list = list(
            range(
                self.num_gpu_blocks - 1,
                int(self.num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
        self.free_list_len = len(free_list)
        self.share_inputs.update(
            {
                "free_list": paddle.to_tensor(free_list, dtype="int32"),
                "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
            }
        )

    def not_need_stop(self) -> bool:
        """ """
        """Stop decoding if the tensor meets the termination condition"""
        return self.share_inputs["not_need_stop"][0]

    def clear_cache(self):
        """Clear cached data from shared inputs and forward metadata"""
        self.share_inputs.pop("caches", None)
        if self.forward_meta is not None:
            self.forward_meta.clear_caches()

    def _init_image_preprocess(self) -> None:
        processor = DataProcessor(
            tokenizer_name=self.model_config.model,
            image_preprocessor_name=str(self.model_config.model),
        )
        processor.eval()
        image_preprocess = processor.image_preprocessor
        image_preprocess.image_mean_tensor = paddle.to_tensor(image_preprocess.image_mean, dtype="float32").reshape(
            [1, 3, 1, 1]
        )
        image_preprocess.image_std_tensor = paddle.to_tensor(image_preprocess.image_std, dtype="float32").reshape(
            [1, 3, 1, 1]
        )
        image_preprocess.rescale_factor = paddle.to_tensor(image_preprocess.rescale_factor, dtype="float32")
        image_preprocess.image_mean_tensor = image_preprocess.image_mean_tensor.squeeze([-2, -1]).repeat_interleave(
            self.model_config.vision_config.patch_size**2 * 1, -1
        )
        image_preprocess.image_std_tensor = image_preprocess.image_std_tensor.squeeze([-2, -1]).repeat_interleave(
            self.model_config.vision_config.patch_size**2 * 1, -1
        )
        self.image_preprocess = image_preprocess

    def _preprocess_mm_task(self, one: dict) -> None:
        """process batch"""

        input_ids = one["input_ids"][np.newaxis, :]
        input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64)
        token_type_ids = one["token_type_ids"][np.newaxis, :]
        token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)

        if one["images"] is not None:
            image_type_ids = one["image_type_ids"][np.newaxis, :]
            images = one["images"]
            image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64)
            images = paddle.to_tensor(images, dtype="uint8")
            grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64")
        else:
            image_type_ids = None
            images = None
            grid_thw = None

        if one["position_ids"] is not None:
            position_ids = paddle.to_tensor(one["position_ids"], dtype="int64").unsqueeze([0])
        else:
            position_ids = None

        result = dict(
            input_ids=input_ids,
            image_type_ids=image_type_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            grid_thw=grid_thw,
            images=images,
        )
        return result

    @paddle.no_grad()
    def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
        """extract_vision_features"""
        assert inputs["images"] is not None
        grid_thw = inputs["grid_thw"]

        images = inputs["images"].cast("float32")
        images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor
        images = images / self.image_preprocess.image_std_tensor
        images = images.cast("bfloat16")

        token_type_ids = inputs["token_type_ids"]
        token_type_ids_w_video = token_type_ids
        input_ids = inputs["input_ids"]
        # convert to img patch id
        # TODO(lulinjun): may need to check model_config and model_cfg
        image_mask = input_ids == self.model_config.im_patch_id
        image_type_ids = inputs["image_type_ids"]
        with paddle.amp.auto_cast(
            True,
            custom_black_list=self.amp_black,
            custom_white_list=self.amp_white,
            level="O2",
            dtype=self.parallel_config.dtype,
        ):
            image_features = self.model.vision_model.extract_feature(images, grid_thw)
            if self.parallel_config.tensor_parallel_size > 1:
                S, C = image_features.shape
                image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2])
                image_features = ScatterOp.apply(image_features, axis=-1)  # split the feature across model-parallel ranks
                image_features = image_features.reshape([S, -1])
            image_features = self.model.resampler_model(
                image_features,
                image_mask,
                token_type_ids_w_video,
                image_type_ids,
                grid_thw,
            )
        return image_features
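`extract_vision_features` normalizes the raw uint8 patches as `(rescale_factor * x - mean) / std` before handing them to the vision tower, with the mean/std tensors pre-expanded to the flattened patch layout in `_init_image_preprocess`. A scalar-per-channel sketch of the same arithmetic (the mean/std values are placeholders, not the checkpoint's preprocessor config):

```python
import numpy as np

rescale_factor = 1.0 / 255.0
image_mean = np.array([0.48, 0.46, 0.41], dtype=np.float32)
image_std = np.array([0.27, 0.26, 0.28], dtype=np.float32)

pixels = np.array([[128, 64, 255]], dtype=np.uint8)  # one RGB pixel
x = pixels.astype(np.float32)
x = rescale_factor * x - image_mean
x = x / image_std
print(np.round(x, 3))  # [[ 0.081 -0.804  2.107]]
```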

    @paddle.no_grad()
    def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Tensor:
        """prepare_rope3d"""

        prefix_max_position_ids = paddle.max(position_ids) + 1
        dec_pos_ids = paddle.tile(
            paddle.arange(max_len, dtype="int64").unsqueeze(0).unsqueeze(-1),
            [1, 1, 3],
        )
        dec_pos_ids = dec_pos_ids + prefix_max_position_ids
        position_ids_3d_real = paddle.concat([position_ids, dec_pos_ids], axis=1)

        rope_emb = get_rope_3d(
            position_ids=position_ids_3d_real,
            rotary_dim=self.model_config.head_dim,
            partial_rotary_factor=1.0,
            base=self.model_config.rope_theta,
            max_position=self.parallel_config.max_model_len,
            freq_allocation=getattr(self.model_config, "freq_allocation", 20),
            model_type=self.model_config.model_type,
        )
        return rope_emb
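`prepare_rope3d` appends decode-step positions after the prefill 3D position ids: a `0..max_len-1` ramp is tiled across the (t, h, w) axes and offset by the largest prefill position plus one, then the concatenated ids are fed to `get_rope_3d`. A toy sketch of just the position-id construction (shapes follow the `[batch, seq, 3]` layout; values are invented):

```python
import paddle

position_ids = paddle.to_tensor([[[0, 0, 0], [1, 0, 1], [1, 1, 0]]], dtype="int64")  # prefill, [1, 3, 3]
max_len = 4                                                                           # decode budget

prefix_max_position_ids = paddle.max(position_ids) + 1   # 2
dec_pos_ids = paddle.tile(
    paddle.arange(max_len, dtype="int64").unsqueeze(0).unsqueeze(-1),  # [1, 4, 1]
    [1, 1, 3],                                                         # same index on t / h / w
)
dec_pos_ids = dec_pos_ids + prefix_max_position_ids
position_ids_3d = paddle.concat([position_ids, dec_pos_ids], axis=1)

print(position_ids_3d.shape)                    # [1, 7, 3]
print(position_ids_3d[0, 3:].numpy().tolist())  # [[2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5]]
```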