Merge vl execution path into normal execution path (#2829)

* merge vl model into gpu_model runner

Change-Id: I9f4691a3d5f135e8d72b1d58abcd15ef3aa3f2a6

* fix chinese

Change-Id: Ic7405109b984c21e076fb3b01ff6feb571d0119a

* fix the parse parameter

Change-Id: I4cd62ee87c06220af580d91e347145d4394917fe

* fix the bug in online_inference

Change-Id: Idb111bb2114e83017c4050b2a68cf039c6d3c559

* polish code

Change-Id: I7d4194102c2f1b0743b74fbd5fc284eb8ef4d17c
Author: Zero Rains
Date: 2025-07-15 22:20:03 +08:00
Committed by: GitHub
Parent: 5fc659b900
Commit: e7bcbbab52
9 changed files with 441 additions and 1732 deletions
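
In effect, the separate VL worker/runner stack is folded into the regular GPU runner, with the multimodal-specific steps gated behind a single enable_mm flag. A minimal sketch of that pattern (class and method names here are illustrative, not FastDeploy's actual API):

    class UnifiedRunner:
        """Toy runner: one execution path, multimodal steps gated by a flag."""

        def __init__(self, enable_mm: bool):
            self.enable_mm = enable_mm

        def execute(self, token_ids, image_inputs=None):
            image_features = None
            if self.enable_mm and image_inputs is not None:
                image_features = self.extract_vision_features(image_inputs)
            # Text-only and multimodal requests share the same decode loop from here on.
            return self.forward(token_ids, image_features)

        def extract_vision_features(self, image_inputs):
            return f"features({image_inputs})"

        def forward(self, token_ids, image_features):
            return {"tokens": token_ids, "image_features": image_features}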


@@ -18,7 +18,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import Literal, Optional, Union
from typing import Literal, Optional
from paddleformers.transformers.configuration_utils import PretrainedConfig
from paddleformers.trl import llm_utils
@@ -89,6 +89,7 @@ class ModelConfig:
self.max_model_len = 0
self.dtype = ""
self.enable_logprob = False
self.enable_mm = False
for key, value in args.items():
if hasattr(self, key):
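
The hunk above adds enable_mm to ModelConfig's defaults; the loop (truncated here) copies matching keys from the engine args onto the config. A toy version of that kwargs-driven update (the setattr body is assumed, not shown in the hunk):

    class ToyModelConfig:
        """Toy stand-in for the kwargs-driven config update shown above."""

        def __init__(self, args: dict):
            self.max_model_len = 0
            self.dtype = ""
            self.enable_logprob = False
            self.enable_mm = False
            for key, value in args.items():
                if hasattr(self, key):
                    setattr(self, key, value)  # assumed body of the truncated loop

    cfg = ToyModelConfig({"enable_mm": True, "unknown_key": 123})
    assert cfg.enable_mm is True and not hasattr(cfg, "unknown_key")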


@@ -990,8 +990,6 @@ class LLMEngine(object):
pd_cmd = pd_cmd + f" --log_dir {log_dir}"
worker_path = "../worker/worker_process.py"
if self.cfg.enable_mm:
worker_path = "../worker/vl_worker_process.py"
py_script = os.path.join(current_dir_path, worker_path)
ori_vocab_size = (
@@ -1030,7 +1028,9 @@ class LLMEngine(object):
f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
f" --load_strategy {self.cfg.model_config.load_strategy}")
f" --load_strategy {self.cfg.model_config.load_strategy}"
f" --enable_mm {self.cfg.enable_mm}")
worker_append_flag = {
"enable_expert_parallel":


@@ -129,6 +129,36 @@ def post_process_normal(sampler_output: SamplerOutput,
save_each_rank: bool = False,
skip_save_output: bool = False) -> ModelRunnerOutput:
""" Post-processing steps after completing a single token generation. """
# handle vl:
if model_output.enable_thinking:
exists_think_end = sampler_output.sampled_token_ids == model_output.think_end_id
paddle.assign(
paddle.where(
exists_think_end,
model_output.need_think_end - 1,
model_output.need_think_end,
), model_output.need_think_end)
paddle.assign(
paddle.where(
model_output.need_think_end.cast("bool"),
model_output.reasoning_index - 1,
model_output.reasoning_index,
), model_output.reasoning_index)
stop_wo_think = (
(sampler_output.sampled_token_ids == model_output.eos_token_id) |
(model_output.reasoning_index == 0)) & (
model_output.need_think_end > 0)
sampler_output.sampled_token_ids = paddle.where(stop_wo_think,
model_output.think_end_id,
sampler_output.sampled_token_ids)
paddle.assign(
paddle.where(
stop_wo_think,
model_output.need_think_end - 1,
model_output.need_think_end,
), model_output.need_think_end)
# 1. Set stop value
paddle.assign(
paddle.where(


@@ -30,7 +30,8 @@ from fastdeploy.model_executor.guided_decoding.base_guided_decoding import \
from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.attention.base_attention_backend import \
AttentionBackend
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.layers.rotary_embedding import (get_rope,
get_rope_3d)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import (
Sampler, SpeculativeSampler)
@@ -46,9 +47,14 @@ from fastdeploy.platforms import current_platform
if not current_platform.is_dcu():
from fastdeploy.spec_decode import MTPProposer, NgramProposer
from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
from fastdeploy.input.mm_processor import DataProcessor
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import \
ScatterOp
from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
from fastdeploy.worker.utils import check_safetensors_model
class GPUModelRunner(ModelRunnerBase):
@@ -61,6 +67,7 @@ class GPUModelRunner(ModelRunnerBase):
rank: int,
local_rank: int):
super().__init__(fd_config=fd_config, device=device)
self.enable_mm = self.model_config.enable_mm
self.rank = rank
self.local_rank = local_rank
self.device_id = device_id
@@ -72,6 +79,37 @@ class GPUModelRunner(ModelRunnerBase):
if self.fd_config.parallel_config.guided_decoding_backend != "off":
self.guided_backend = get_guided_backend(fd_config=self.fd_config)
# VL model config:
if self.enable_mm:
model_path = os.path.dirname(self.parallel_config.model_name_or_path)
self.is_safetensors_model = check_safetensors_model(
self.parallel_config.model_name_or_path)
if not self.is_safetensors_model:
self.tokenizer_path = self.image_preprocessor_path = model_path
else:
self.tokenizer_path = self.parallel_config.model_name_or_path
self.image_preprocessor_path = self.parallel_config.model_name_or_path
self.vision_model_name_or_path = os.path.join(
model_path, "DFNRopeVisionTransformer")
self.amp_black = [
"reduce_sum",
"c_softmax_with_cross_entropy",
"elementwise_div",
"sin",
"cos",
"sort",
"multinomial",
]
self.amp_white = [
"lookup_table",
"lookup_table_v2",
"flash_attn",
"matmul",
"matmul_v2",
"fused_gemm_epilogue",
]
# Sampler
if not self.speculative_decoding:
self.sampler = Sampler()
@@ -216,19 +254,52 @@ class GPUModelRunner(ModelRunnerBase):
logger.info(
f"prefill_chunk_info: {request.prefill_chunk_info}")
token_chunk_size = request.prefill_chunk_info[0]
self.share_inputs["seq_lens_this_time"][
idx:idx + 1] = token_chunk_size
if self.enable_mm:
inputs = self._preprocess_mm_task(token_chunk_size)
if inputs.get("images") is not None:
self.share_inputs["image_features"] = self.extract_vision_features(
inputs)
else:
# Compatible with the situation that lacks images and videos
self.share_inputs["image_features"] = None
if request.multimodal_inputs["position_ids"] is not None:
position_ids = paddle.to_tensor(
request.multimodal_inputs["position_ids"],
dtype="int64").unsqueeze([0])
else:
position_ids = None
token_chunk_size = inputs["input_ids"].shape[1]
request.set("start_idx", token_chunk_size)
self.share_inputs["input_ids"][
idx:idx + 1, :token_chunk_size] = inputs["input_ids"]
else:
self.share_inputs['input_ids'][
idx, :token_chunk_size] = np.array(
request.prompt_token_ids[:token_chunk_size])
self.share_inputs['step_seq_lens_encoder'][
idx:idx + 1] = token_chunk_size
self.share_inputs['seq_lens_encoder'][idx:idx +
1] = token_chunk_size
self.share_inputs['seq_lens_decoder'][
idx:idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs['step_seq_lens_decoder'][
idx:idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["seq_lens_this_time"][
idx:idx + 1] = token_chunk_size
self.share_inputs['step_seq_lens_encoder'][
idx:idx + 1] = token_chunk_size
self.share_inputs['seq_lens_encoder'][idx:idx +
1] = token_chunk_size
else:
if self.enable_mm:
inputs = self._preprocess_mm_task(request.multimodal_inputs)
if inputs.get("images") is not None:
self.share_inputs[
"image_features"] = self.extract_vision_features(
inputs)
else:
# Compatible with the situation that lacks images and videos
self.share_inputs["image_features"] = None
position_ids = inputs["position_ids"]
length = inputs["input_ids"].shape[1]
self.share_inputs["input_ids"][
idx:idx + 1, :length] = inputs["input_ids"]
else:
self.share_inputs['seq_lens_decoder'][
idx:idx + 1] = request.get("seq_lens_decoder", 0)
@@ -240,21 +311,41 @@ class GPUModelRunner(ModelRunnerBase):
1] = length
self.share_inputs['seq_lens_encoder'][idx:idx + 1] = length
if self.enable_mm:
enable_thinking = request.get("enable_thinking", True)
enable_thinking = enable_thinking if enable_thinking is not None else True
self.share_inputs["enable_thinking"][:] = enable_thinking
self.share_inputs["need_think_end"][
idx:idx + 1, :] = 1 if enable_thinking else 0
self.share_inputs["reasoning_index"][
idx:idx + 1, :] = request.get("reasoning_max_tokens", 2048)
self.share_inputs["rope_emb"][idx:idx +
1, :] = self.prepare_rope3d(
position_ids, request.get("max_tokens", 2048))
self.share_inputs["seq_lens_decoder"][idx:idx + 1] = 0
def get_attr_from_request(request, attr, default_value=None):
res = request.get(attr, default_value)
if res is not None:
return res
else:
return default_value
if len(request.eos_token_ids
) < self.parallel_config.eos_tokens_lens:
request.eos_token_ids.append(request.eos_token_ids[0])
self.share_inputs["eos_token_id"][:] = np.array(
request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7)
self.share_inputs["top_p"][idx:idx + 1] = get_attr_from_request(request, "top_p", 0.7)
self.share_inputs["top_k"][idx:idx + 1] = request.get("top_k", 0)
self.share_inputs["temperature"][idx:idx + 1] = request.get(
"temperature", 0.95)
self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
"repetition_penalty", 1.0)
self.share_inputs["frequency_score"][idx:idx + 1] = request.get(
"frequency_penalty", 0.0)
self.share_inputs["presence_score"][idx:idx + 1] = request.get(
"presence_penalty", 0.0)
self.share_inputs["temperature"][idx:idx + 1] = get_attr_from_request(request,"temperature", 0.95)
self.share_inputs["penalty_score"][idx:idx + 1] = get_attr_from_request(
request, "repetition_penalty", 1.0)
self.share_inputs["frequency_score"][idx:idx + 1] = get_attr_from_request(
request, "frequency_penalty", 0.0)
self.share_inputs["presence_score"][idx:idx + 1] = get_attr_from_request(
request, "presence_penalty", 0.0)
self.share_inputs["min_dec_len"][idx:idx + 1] = request.get(
"min_tokens", 1)
@@ -301,6 +392,9 @@ class GPUModelRunner(ModelRunnerBase):
expected_decode_len: int):
""" Set dummy prefill inputs to share_inputs """
# NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token
if self.enable_mm:
self.share_inputs["free_list"] = paddle.to_tensor([], dtype="int32")
self.share_inputs["free_list_len"][0] = 0
max_dec_len = expected_decode_len + 1
full_length = min(num_tokens // batch_size,
self.parallel_config.max_model_len - max_dec_len)
@@ -476,6 +570,7 @@ class GPUModelRunner(ModelRunnerBase):
self.parallel_config.max_model_len).reshape((1, -1))
# TODO(gongshaotian): move to models
if not self.enable_mm:
self.share_inputs["rope_emb"] = get_rope(
rotary_dim=self.model_config.head_dim,
position_ids=tmp_position_ids,
@@ -541,6 +636,24 @@ class GPUModelRunner(ModelRunnerBase):
fill_value=0,
dtype="int32")
if self.enable_mm:
head_dim = self.model_config.head_dim
self.share_inputs["rope_emb"] = paddle.full(shape=[
max_num_seqs, 2, 1, self.parallel_config.max_model_len, 1, head_dim // 2
],
fill_value=0,
dtype="float32")
self.share_inputs["image_features"] = None
self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1],
fill_value=0,
dtype="int32")
self.share_inputs["enable_thinking"] = paddle.full(shape=[1],
fill_value=True,
dtype="bool")
self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1],
fill_value=0,
dtype="int32")
def _prepare_inputs(self) -> None:
""" Prepare the model inputs """
# Remove padding
@@ -598,6 +711,8 @@ class GPUModelRunner(ModelRunnerBase):
f"Starting to load model {self.model_config.architectures[0]}")
time_before_load = time.perf_counter()
# 1. Load original model
if self.enable_mm:
self.load_mm_config_and_image_preprocess()
self.model = get_model_from_loader(fd_config=self.fd_config)
# 1.1 Load RL dynamic model
if self.fd_config.load_config.dynamic_load_weight:
@@ -756,11 +871,16 @@ class GPUModelRunner(ModelRunnerBase):
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing
self.forward_meta.is_decode_batch = is_decode_batch
if self.enable_mm:
hidden_states = model_output = self.model(self.share_inputs["ids_remove_padding"],
self.share_inputs["image_features"],
self.forward_meta)
else:
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta)
hiddden_states = rebuild_padding(
hidden_states = rebuild_padding(
model_output,
self.share_inputs["cum_offsets"],
self.share_inputs["seq_lens_this_time"],
@@ -773,7 +893,7 @@ class GPUModelRunner(ModelRunnerBase):
)
# 4. Execute spec decode
logits = self.model.compute_logits(hiddden_states)
logits = self.model.compute_logits(hidden_states)
if not self.speculative_decoding:
set_value_by_flags_and_idx(
@@ -831,7 +951,15 @@ class GPUModelRunner(ModelRunnerBase):
accept_tokens=self.share_inputs["accept_tokens"]
if self.speculative_decoding else None,
accept_num=self.share_inputs["accept_num"]
if self.speculative_decoding else None)
if self.speculative_decoding else None,
enable_thinking= self.share_inputs["enable_thinking"]
if self.enable_mm else None,
think_end_id=self.model_config.think_end_id
if self.enable_mm else -1,
need_think_end=self.share_inputs["need_think_end"]
if self.enable_mm else None,
reasoning_index=self.share_inputs["reasoning_index"]
if self.enable_mm else None)
post_process(sampler_output=sampler_output,
model_output=model_output_data,
@@ -861,7 +989,6 @@ class GPUModelRunner(ModelRunnerBase):
"""
if not self.parallel_config.enable_chunked_prefill:
return
for task in tasks:
if task.get("prefill_chunk_info", None) is None:
continue
@@ -875,28 +1002,46 @@ class GPUModelRunner(ModelRunnerBase):
logger.debug(
f"{task.request_id} chunked prefill {task.chunk_idx}/{len(task.prefill_chunk_info)}"
)
if not self.enable_mm:
start_idx = sum(task.prefill_chunk_info[:task.chunk_idx])
if task.chunk_idx == len(task.prefill_chunk_info):
self.share_inputs["seq_lens_this_time"][idx:idx + 1] = 1
self.share_inputs['seq_lens_encoder'][idx:idx + 1] = 0
self.share_inputs["step_idx"][idx:idx + 1] = 1
if self.enable_mm:
self.share_inputs["seq_lens_decoder"][idx:idx +
1] = task.start_idx
else:
self.share_inputs["seq_lens_decoder"][
idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
del self.restore_chunked_prefill_request[task.request_id]
else:
token_chunk_size = task.prefill_chunk_info[task.chunk_idx]
self.share_inputs["seq_lens_this_time"][idx:idx +
1] = token_chunk_size
self.share_inputs['input_ids'][
idx, :token_chunk_size] = np.array(
if self.enable_mm:
inputs = self._preprocess_mm_task(task.prefill_chunk_info[task.chunk_idx])
if inputs.get("images") is not None:
self.share_inputs[
"image_features"] = self.extract_vision_features(
inputs)
else:
# Compatible with the situation that lacks images and videos
self.share_inputs["image_features"] = None
token_chunk_size = inputs["input_ids"].shape[1]
self.share_inputs["input_ids"][idx:idx + 1, :token_chunk_size] = inputs["input_ids"]
self.share_inputs["seq_lens_decoder"][idx:idx +1] = task.start_idx
task.start_idx += token_chunk_size
else:
self.share_inputs['input_ids'][idx, :token_chunk_size] = np.array(
task.prompt_token_ids[start_idx:start_idx +
token_chunk_size])
self.share_inputs["seq_lens_decoder"][
idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
self.share_inputs["seq_lens_this_time"][idx:idx +
1] = token_chunk_size
self.share_inputs['seq_lens_encoder'][idx:idx +
1] = token_chunk_size
self.share_inputs["step_idx"][idx:idx + 1] = 0
self.share_inputs["seq_lens_decoder"][
idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled(
):
self.proposer.update_task_chunk_prefill(task)
@@ -988,11 +1133,16 @@ class GPUModelRunner(ModelRunnerBase):
> 1).sum() > 0)
self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
self.forward_meta.is_decode_batch = is_decode_batch
if self.enable_mm:
hidden_states = model_output = self.model(self.share_inputs["ids_remove_padding"],
self.share_inputs["image_features"],
self.forward_meta)
else:
model_output = self.model(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta)
hiddden_states = rebuild_padding(
hidden_states = rebuild_padding(
model_output,
self.share_inputs["cum_offsets"],
self.share_inputs["seq_lens_this_time"],
@@ -1004,7 +1154,7 @@ class GPUModelRunner(ModelRunnerBase):
)
# 4. Compute logits, Sample
logits = self.model.compute_logits(hiddden_states)
logits = self.model.compute_logits(hidden_states)
if not self.speculative_decoding:
set_value_by_flags_and_idx(
@@ -1063,7 +1213,15 @@ class GPUModelRunner(ModelRunnerBase):
accept_tokens=self.share_inputs["accept_tokens"]
if self.speculative_decoding else None,
accept_num=self.share_inputs["accept_num"]
if self.speculative_decoding else None)
if self.speculative_decoding else None,
enable_thinking= self.share_inputs["enable_thinking"]
if self.enable_mm else None,
think_end_id=self.model_config.think_end_id
if self.enable_mm else -1,
need_think_end=self.share_inputs["need_think_end"]
if self.enable_mm else None,
reasoning_index=self.share_inputs["reasoning_index"]
if self.enable_mm else None)
if self.speculative_config.method in ["mtp"] and \
self.parallel_config.splitwise_role == "prefill":
@@ -1240,3 +1398,155 @@ class GPUModelRunner(ModelRunnerBase):
self.initialize_kv_cache()
self.dynamic_weight_manager._log_memory(
"dynamic weight manager update all memory")
def _init_image_preprocess(self) -> None:
processor = DataProcessor(
tokenizer_name=self.tokenizer_path,
image_preprocessor_name=str(self.image_preprocessor_path),
)
processor.eval()
image_preprocess = processor.image_preprocessor
image_preprocess.image_mean_tensor = paddle.to_tensor(
image_preprocess.image_mean, dtype="float32").reshape([1, 3, 1, 1])
image_preprocess.image_std_tensor = paddle.to_tensor(
image_preprocess.image_std, dtype="float32").reshape([1, 3, 1, 1])
image_preprocess.rescale_factor = paddle.to_tensor(
image_preprocess.rescale_factor, dtype="float32")
image_preprocess.image_mean_tensor = image_preprocess.image_mean_tensor.squeeze(
[-2, -1]).repeat_interleave(self.model_config.vision_config.patch_size**2 * 1,
-1)
image_preprocess.image_std_tensor = image_preprocess.image_std_tensor.squeeze(
[-2, -1]).repeat_interleave(self.model_config.vision_config.patch_size**2 * 1,
-1)
self.image_preprocess = image_preprocess
def load_mm_config_and_image_preprocess(self) -> None:
tokenizer = ErnieBotTokenizer.from_pretrained(
self.tokenizer_path,
model_max_length=self.parallel_config.max_model_len,
padding_side="right",
use_fast=False,
)
tokenizer.ignored_index = -100
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.unk_token
self.fd_config.model_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_size
self.fd_config.model_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank
self.fd_config.model_config.moe_group="dummy"
self.fd_config.parallel_config.column_cut = False
vision_config = self.fd_config.model_config.vision_config
vision_config.attn_sep = False
vision_config.dtype = "bfloat16"
vision_config.tensor_parallel_degree = self.parallel_config.tensor_parallel_size
vision_config.tensor_parallel_rank = self.parallel_config.tensor_parallel_rank
self.fd_config.model_config.pixel_hidden_size = vision_config.hidden_size
self.fd_config.model_config.im_patch_id = tokenizer.get_vocab()[
"<|IMAGE_PLACEHOLDER|>"
]
self.fd_config.model_config.think_end_id = tokenizer.get_vocab()["</think>"]
self.fd_config.model_config.max_text_id = self.fd_config.model_config.im_patch_id
self.fd_config.model_config.sequence_parallel = False
self.model_config = self.fd_config.model_config
self._init_image_preprocess()
def _preprocess_mm_task(self, one: dict) -> None:
"""process batch"""
input_ids = one["input_ids"][np.newaxis, :]
input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64)
token_type_ids = one["token_type_ids"][np.newaxis, :]
token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
if one["images"] is not None:
image_type_ids = one["image_type_ids"][np.newaxis, :]
images = one["images"]
image_type_ids = paddle.to_tensor(image_type_ids,
dtype=paddle.int64)
images = paddle.to_tensor(images, dtype="uint8")
grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64")
else:
image_type_ids = None
images = None
grid_thw = None
if one["position_ids"] is not None:
position_ids = paddle.to_tensor(one["position_ids"],
dtype="int64").unsqueeze([0])
else:
position_ids = None
result = dict(
input_ids=input_ids,
image_type_ids=image_type_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
grid_thw=grid_thw,
images=images,
)
return result
@paddle.no_grad()
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
"""extract_vision_features"""
assert inputs["images"] is not None
grid_thw = inputs["grid_thw"]
images = inputs["images"].cast("float32")
images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor
images = images / self.image_preprocess.image_std_tensor
images = images.cast("bfloat16")
token_type_ids = inputs["token_type_ids"]
token_type_ids_w_video = token_type_ids
input_ids = inputs["input_ids"]
# convert to img patch id
# TODO(lulinjun): may need to check model_config and model_cfg
image_mask = input_ids == self.model_config.im_patch_id
image_type_ids = inputs["image_type_ids"]
with paddle.amp.auto_cast(
True,
custom_black_list=self.amp_black,
custom_white_list=self.amp_white,
level="O2",
dtype=self.parallel_config.dtype,
):
image_features = self.model.vision_model.extract_feature(
images, grid_thw)
if self.parallel_config.tensor_parallel_size > 1:
S, C = image_features.shape
image_features = image_features.reshape(
[-1, C * self.model_config.spatial_conv_size**2])
image_features = ScatterOp.apply(image_features,
axis=-1)  # split features along the model-parallel (mp) dimension
image_features = image_features.reshape([S, -1])
image_features = self.model.resampler_model(
image_features,
image_mask,
token_type_ids_w_video,
image_type_ids,
grid_thw,
)
return image_features
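
extract_vision_features normalizes the raw uint8 patches with the mean/std tensors that _init_image_preprocess pre-expanded to patch-flattened channel vectors, then runs the vision tower under AMP. A NumPy sketch of just that normalization step, assuming 3-channel, channel-major patch flattening and illustrative mean/std values:

    import numpy as np

    patch_size, channels = 14, 3
    patch_dim = channels * patch_size * patch_size             # flattened patch features

    image_mean = np.array([0.48, 0.46, 0.41], dtype=np.float32)
    image_std = np.array([0.27, 0.26, 0.28], dtype=np.float32)
    rescale = np.float32(1.0 / 255.0)

    # _init_image_preprocess pre-expands mean/std to per-patch-feature vectors once, up front.
    mean_t = np.repeat(image_mean, patch_size * patch_size)    # shape [patch_dim]
    std_t = np.repeat(image_std, patch_size * patch_size)

    patches = np.random.randint(0, 256, size=(10, patch_dim)).astype(np.float32)  # 10 raw uint8 patches
    normalized = (rescale * patches - mean_t) / std_t          # same arithmetic as extract_vision_features
    print(normalized.shape)                                    # (10, 588)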
@paddle.no_grad()
def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Tensor:
"""prepare_rope3d"""
prefix_max_position_ids = paddle.max(position_ids) + 1
dec_pos_ids = paddle.tile(
paddle.arange(max_len,
dtype="int64").unsqueeze(0).unsqueeze(-1), [1, 1, 3])
dec_pos_ids = dec_pos_ids + prefix_max_position_ids
position_ids_3d_real = paddle.concat([position_ids, dec_pos_ids],
axis=1)
rope_emb = get_rope_3d(
position_ids=position_ids_3d_real,
rotary_dim=self.model_config.head_dim,
paritial_rotary_factor=1.0,
base=self.model_config.rope_theta,
max_position=self.parallel_config.max_model_len,
freq_allocation=self.model_config.freq_allocation,
)
return rope_emb
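
prepare_rope3d appends max_len decode positions after the prompt's maximum position so every decode step keeps monotonically increasing positions on all three rope axes, then hands the combined ids to get_rope_3d. The position-id construction in NumPy (the three axes are assumed to be t/h/w; get_rope_3d itself is the project's helper and is not reproduced):

    import numpy as np

    # Prompt positions: shape [1, prompt_len, 3].
    position_ids = np.array([[[0, 0, 0], [1, 0, 1], [2, 1, 0]]], dtype=np.int64)
    max_dec_len = 4

    prefix_max = position_ids.max() + 1                        # first free position after the prompt
    dec_pos = np.tile(np.arange(max_dec_len, dtype=np.int64)[None, :, None], (1, 1, 3))
    dec_pos = dec_pos + prefix_max                             # decode steps continue after the prompt
    position_ids_3d = np.concatenate([position_ids, dec_pos], axis=1)
    print(position_ids_3d.shape)                               # (1, 7, 3): prompt then 4 decode positions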


@@ -201,6 +201,27 @@ class ModelOutputData:
"""
accept_num: paddle.Tensor
"""
vl model enable to think
"""
enable_thinking: paddle.Tensor = None
"""
vl model think end id
"""
think_end_id: int = -1
"""
vl model need to think
"""
need_think_end: paddle.Tensor = None
"""
vl model reasoning index
"""
reasoning_index: paddle.Tensor = None
@dataclass
class ModelRunnerOutput:


@@ -1,842 +0,0 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import argparse
import os
import random
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
from fastdeploy.config import ModelConfig
from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
from fastdeploy.input.mm_processor import DataProcessor
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.rotary_embedding import get_rope_3d
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import \
ScatterOp
from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import SamplerOutput
from fastdeploy.worker.utils import check_safetensors_model
from fastdeploy.worker.vl_model_runner_base import VLModelRunnerBase
if current_platform.is_cuda() and current_platform.available():
from fastdeploy.model_executor.layers.utils import (
remove_padding, speculate_remove_padding)
from fastdeploy.model_executor.ops.gpu import (save_output, save_output_topk,
set_stop_value_multi_ends,
set_value_by_flags_and_idx,
update_inputs)
class GPUVLModelRunner(VLModelRunnerBase):
"""
The GPUVLModelRunner class for vision-language tasks on GPU.
"""
def __init__(
self,
config: ModelConfig,
args: argparse.Namespace,
nranks: int,
rank: int,
) -> None:
"""
GPUVLModelRunner init
"""
self.nranks = nranks
self.rank = rank
hcg = fleet.get_hybrid_communicate_group()
self.tensor_parallel_degree = max(hcg.get_model_parallel_world_size(),
1)
self.tensor_parallel_rank = hcg.get_model_parallel_rank()
self.mp_src_rank = hcg.get_model_parallel_group_src_rank()
self.mp_group = hcg.get_model_parallel_group()
self.is_safetensors_model = check_safetensors_model(
args.model_name_or_path)
self.enable_logprob = args.enable_logprob
model_path = os.path.dirname(args.model_name_or_path)
args.llm_model_name_or_path = args.model_name_or_path
if not self.is_safetensors_model:
args.tokenizer = args.image_preprocessor = model_path
else:
args.tokenizer = args.image_preprocessor = args.model_name_or_path
args.vision_model_name_or_path = os.path.join(
model_path, "DFNRopeVisionTransformer")
self.amp_black = [
"reduce_sum",
"c_softmax_with_cross_entropy",
"elementwise_div",
"sin",
"cos",
"sort",
"multinomial",
]
self.amp_white = [
"lookup_table",
"lookup_table_v2",
"flash_attn",
"matmul",
"matmul_v2",
"fused_gemm_epilogue",
]
super().__init__(config, args)
self.init_extra_input(config, args)
self._reset_paddle_env()
self.sampler = Sampler()
def _reset_paddle_env(self):
pass
def update_chunked_prefill(self, tasks: list[any]) -> None:
"""
update chunked prefill
"""
if not self.args.enable_chunked_prefill:
return
for task in tasks:
if task.chunk_idx > len(task.prefill_chunk_info):
continue
idx = task.idx
if task.chunk_idx == len(task.prefill_chunk_info):
self.share_inputs["seq_lens_this_time"][idx:idx + 1] = 1
self.share_inputs['seq_lens_encoder'][idx:idx + 1] = 0
self.share_inputs["seq_lens_decoder"][idx:idx +
1] = task.start_idx
self.share_inputs["step_idx"][idx:idx + 1] = 1
else:
inputs = self._preprocess_task(
task.prefill_chunk_info[task.chunk_idx])
if inputs.get("images") is not None:
self.share_inputs[
"image_features"] = self.extract_vision_features(
inputs)
else:
# Compatible with the situation that lacks images and videos
self.share_inputs["image_features"] = None
token_chunk_size = inputs["input_ids"].shape[1]
self.share_inputs["input_ids"][
idx:idx + 1, :token_chunk_size] = inputs["input_ids"]
self.share_inputs["seq_lens_this_time"][idx:idx +
1] = token_chunk_size
self.share_inputs['seq_lens_encoder'][idx:idx +
1] = token_chunk_size
self.share_inputs["seq_lens_decoder"][idx:idx +
1] = task.start_idx
self.share_inputs["step_idx"][idx:idx + 1] = 0
task.start_idx += token_chunk_size
task.chunk_idx += 1
def _init_image_preprocess(self, vision_config) -> None:
processor = DataProcessor(
tokenizer_name=self.args.tokenizer,
image_preprocessor_name=str(self.args.image_preprocessor),
)
processor.eval()
image_preprocess = processor.image_preprocessor
image_preprocess.image_mean_tensor = paddle.to_tensor(
image_preprocess.image_mean, dtype="float32"
).reshape([1, 3, 1, 1])
image_preprocess.image_std_tensor = paddle.to_tensor(
image_preprocess.image_std, dtype="float32"
).reshape([1, 3, 1, 1])
image_preprocess.rescale_factor = paddle.to_tensor(
image_preprocess.rescale_factor, dtype="float32"
)
image_preprocess.image_mean_tensor = (
image_preprocess.image_mean_tensor.squeeze(
[-2, -1]
).repeat_interleave(vision_config.patch_size**2 * 1, -1)
)
image_preprocess.image_std_tensor = (
image_preprocess.image_std_tensor.squeeze(
[-2, -1]
).repeat_interleave(vision_config.patch_size**2 * 1, -1)
)
return image_preprocess
def _load_model(
self,
model_name: str,
dynamic_load_weight: int = 0,
) -> None:
"""
Load the model from the given model name.
"""
vocab_file_names = [
"tokenizer.model", "spm.model", "ernie_token_100k.model"
]
for i in range(len(vocab_file_names)):
if os.path.exists(
os.path.join(self.args.tokenizer, vocab_file_names[i])):
ErnieBotTokenizer.resource_files_names[
"vocab_file"] = vocab_file_names[i]
break
tokenizer = ErnieBotTokenizer.from_pretrained(
self.args.tokenizer,
model_max_length=self.args.max_model_len,
padding_side="right",
use_fast=False,
)
tokenizer.ignored_index = -100
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.unk_token
self.dtype = self.args.dtype
paddle.set_default_dtype(self.dtype)
from fastdeploy.worker.worker_process import initialize_fd_config
fd_config = initialize_fd_config(
self.args, self.tensor_parallel_degree, self.tensor_parallel_rank
)
fd_config.model_config.tensor_parallel_degree=self.tensor_parallel_degree
fd_config.model_config.tensor_parallel_rank=self.tensor_parallel_rank
fd_config.model_config.moe_group="dummy"
fd_config.parallel_config.column_cut = False
vision_config = fd_config.model_config.vision_config
vision_config.attn_sep = False
vision_config.dtype = "bfloat16"
vision_config.tensor_parallel_degree = self.tensor_parallel_degree
vision_config.tensor_parallel_rank = self.tensor_parallel_rank
fd_config.model_config.pixel_hidden_size = vision_config.hidden_size
fd_config.model_config.im_patch_id = tokenizer.get_vocab()[
"<|IMAGE_PLACEHOLDER|>"
]
fd_config.model_config.think_end_id = tokenizer.get_vocab()["</think>"]
fd_config.model_config.max_text_id = fd_config.model_config.im_patch_id
fd_config.model_config.sequence_parallel = False
self.fd_config = fd_config
self.model_cfg = self.fd_config.model_config
self.image_preprocess = self._init_image_preprocess(
self.fd_config.model_config.vision_config
)
from fastdeploy.model_executor.model_loader import \
get_model_from_loader
self.model = get_model_from_loader(self.fd_config)
attn_backend_cls = get_attention_backend()
num_heads = self.fd_config.model_config.num_attention_heads // \
self.fd_config.parallel_config.tensor_parallel_size
self.fd_config.model_config.kv_num_heads = int(
self.fd_config.model_config.num_key_value_heads
) // self.fd_config.parallel_config.tensor_parallel_size
head_dim = self.fd_config.model_config.head_dim
self.attn_backend = attn_backend_cls(
self.fd_config,
kv_num_heads=self.fd_config.model_config.kv_num_heads,
num_heads=num_heads,
head_dim=head_dim)
self._init_kvcache()
def init_extra_input(self, config: ModelConfig, args: argparse.Namespace) -> None:
"""
Initialize extra input tensors.
"""
head_dim = self.model_cfg.head_dim
self.share_inputs.update({
"rope_emb":
paddle.full(shape=[
args.max_num_seqs, 2, 1, self.max_length, 1, head_dim // 2
],
fill_value=0,
dtype="float32")
})
self.share_inputs.update({"image_features": None})
self.share_inputs.update({
"need_think_end":
paddle.full(shape=[args.max_num_seqs, 1],
fill_value=0,
dtype="int32")
})
self.share_inputs.update({
"enable_thinking":
paddle.full(shape=[1], fill_value=True, dtype="bool")
})
self.share_inputs.update({
"reasoning_index":
paddle.full(shape=[args.max_num_seqs, 1],
fill_value=0,
dtype="int32")
})
def init_rotary_position_embedding(self, max_model_len: int) -> None:
"""
Init rotary position embedding
"""
pass
def _init_kvcache(self):
"""
Init kv cache
"""
cache_kvs = {}
total_block_num = self.num_gpu_blocks
num_layers = self.model_cfg.num_hidden_layers
kv_num_head = self.model_cfg.num_key_value_heads if self.model_cfg.num_key_value_heads != -1 else self.model_cfg.num_attention_heads
kv_num_head = kv_num_head // self.tensor_parallel_degree
self.model_cfg.kv_num_head = kv_num_head
for i in range(num_layers):
cache_type = self.args.dtype
cache_kvs["key_caches_{}".format(i)] = paddle.full(
shape=[
total_block_num,
kv_num_head,
self.args.block_size,
self.model_cfg.head_dim,
],
fill_value=0,
dtype=cache_type,
)
cache_kvs["value_caches_{}".format(i)] = paddle.full(
shape=[
total_block_num,
kv_num_head,
self.args.block_size,
self.model_cfg.head_dim,
],
fill_value=0,
dtype=cache_type,
)
self.share_inputs["caches"] = list(cache_kvs.values())
for value in cache_kvs.values():
del value
paddle.device.cuda.empty_cache()
def clear_parameters(self, pid: int) -> None:
""" clear_parameters """
if "caches" in self.share_inputs:
self.model.clear_parameters(pid)
del self.share_inputs["caches"]
paddle.device.cuda.empty_cache()
self.model.log_memory_usage("clear all memory")
def update_parameters(self, pid: int) -> None:
""" update_parameters """
if "caches" not in self.share_inputs:
self.model.update_parameters(pid)
self._init_kvcache()
self.model.log_memory_usage("update all memory")
@paddle.no_grad()
def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
"""extract_vision_features"""
assert inputs["images"] is not None
grid_thw = inputs["grid_thw"]
images = inputs["images"].cast("float32")
images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor
images = images / self.image_preprocess.image_std_tensor
images = images.cast("bfloat16")
token_type_ids = inputs["token_type_ids"]
token_type_ids_w_video = token_type_ids
input_ids = inputs["input_ids"]
# convert to img patch id
image_mask = input_ids == self.model_cfg.im_patch_id
image_type_ids = inputs["image_type_ids"]
with paddle.amp.auto_cast(
True,
custom_black_list=self.amp_black,
custom_white_list=self.amp_white,
level="O2",
dtype=self.dtype,
):
image_features = self.model.vision_model.extract_feature(
images, grid_thw)
if self.tensor_parallel_degree > 1:
S, C = image_features.shape
image_features = image_features.reshape(
[-1, C * self.model_cfg.spatial_conv_size**2])
image_features = ScatterOp.apply(image_features,
axis=-1)  # split features along the model-parallel (mp) dimension
image_features = image_features.reshape([S, -1])
image_features = self.model.resampler_model(
image_features,
image_mask,
token_type_ids_w_video,
image_type_ids,
grid_thw,
)
return image_features
@paddle.no_grad()
def prepare_rope3d(self, position_ids: paddle.Tensor, **kwargs) -> paddle.Tensor:
"""prepare_rope3d"""
prefix_max_position_ids = paddle.max(position_ids) + 1
dec_pos_ids = paddle.tile(
paddle.arange(kwargs["max_length"],
dtype="int64").unsqueeze(0).unsqueeze(-1), [1, 1, 3])
dec_pos_ids = dec_pos_ids + prefix_max_position_ids
position_ids_3d_real = paddle.concat([position_ids, dec_pos_ids],
axis=1)
rope_emb = get_rope_3d(
position_ids=position_ids_3d_real,
rotary_dim=self.model_cfg.head_dim,
paritial_rotary_factor=1.0,
base=self.model_cfg.rope_theta,
max_position=self.args.max_model_len,
freq_allocation=self.model_cfg.freq_allocation,
)
return rope_emb
def prefill_finished(self):
"""
Verify prefill operation completion
"""
prefill_statue = (self.share_inputs["seq_lens_this_time"] != 0) & (
self.share_inputs["seq_lens_this_time"] != 1)
return not paddle.any(prefill_statue).numpy()
def dy_input_preprocess(self, tasks: list[any]) -> None:
"""
dynamic insertion
"""
def get_numeric_value(task, key, default_value):
if task.get(key, None) is not None:
return task.get(key)
else:
return default_value
for i in range(len(tasks)):
task = tasks[i]
idx = task.idx
kwargs = {
"max_length":
get_numeric_value(task, "max_tokens", 2048),
"top_p":
get_numeric_value(task, "top_p", 0.8),
"temperature":
get_numeric_value(task, "temperature", 0.2),
"top_k":
get_numeric_value(task, "top_k", 0),
"penalty_score":
get_numeric_value(task, "repetition_penalty", 1.0),
"frequency_score":
get_numeric_value(task, "frequency_penalty", 0.0),
"presence_score":
get_numeric_value(task, "presence_penalty", 0.0),
"decode_strategy":
"sampling",
"pad_token_id":
self.args.pad_token_id,
"enable_thinking":
get_numeric_value(task, "enable_thinking", True),
"reasoning_max_tokens":
get_numeric_value(task, "reasoning_max_tokens", 2048),
}
if self.args.enable_chunked_prefill:
task.set("chunk_idx", 1)
inputs = self._preprocess_task(task.prefill_chunk_info[0])
if inputs.get("images") is not None:
self.share_inputs[
"image_features"] = self.extract_vision_features(
inputs)
else:
# Compatible with the situation that lacks images and videos
self.share_inputs["image_features"] = None
if task.multimodal_inputs["position_ids"] is not None:
position_ids = paddle.to_tensor(
task.multimodal_inputs["position_ids"],
dtype="int64").unsqueeze([0])
else:
position_ids = None
token_chunk_size = inputs["input_ids"].shape[1]
task.set("start_idx", token_chunk_size)
self.share_inputs["input_ids"][
idx:idx + 1, :token_chunk_size] = inputs["input_ids"]
self.share_inputs["seq_lens_this_time"][idx:idx +
1] = token_chunk_size
self.share_inputs["seq_lens_encoder"][idx:idx +
1] = token_chunk_size
self.share_inputs["step_seq_lens_encoder"][
idx:idx + 1] = token_chunk_size
else:
inputs = self._preprocess_task(task.multimodal_inputs)
if inputs.get("images") is not None:
self.share_inputs[
"image_features"] = self.extract_vision_features(
inputs)
else:
# Compatible with the situation that lacks images and videos
self.share_inputs["image_features"] = None
position_ids = inputs["position_ids"]
length = inputs["input_ids"].shape[1]
self.share_inputs["input_ids"][
idx:idx + 1, :length] = inputs["input_ids"]
self.share_inputs["seq_lens_this_time"][idx:idx + 1] = length
self.share_inputs["seq_lens_encoder"][idx:idx + 1] = length
self.share_inputs["step_seq_lens_encoder"][idx:idx +
1] = length
# force </think>
self.share_inputs["enable_thinking"][:] = kwargs["enable_thinking"]
self.share_inputs["need_think_end"][
idx:idx + 1, :] = 1 if kwargs["enable_thinking"] else 0
self.share_inputs["reasoning_index"][
idx:idx + 1, :] = kwargs["reasoning_max_tokens"]
self.share_inputs["rope_emb"][idx:idx +
1, :] = self.prepare_rope3d(
position_ids, **kwargs)
self.share_inputs["top_p"][idx:idx + 1] = kwargs["top_p"]
self.share_inputs["temperature"][idx:idx +
1] = kwargs["temperature"]
self.share_inputs["eos_token_id"][:] = np.array(
task.eos_token_ids).astype("int64").reshape(-1, 1)
self.share_inputs["penalty_score"][idx:idx +
1] = kwargs["penalty_score"]
self.share_inputs["frequency_score"][idx:idx +
1] = kwargs["frequency_score"]
self.share_inputs["presence_score"][idx:idx +
1] = kwargs["presence_score"]
self.share_inputs["seq_lens_decoder"][idx:idx + 1] = 0
self.share_inputs["step_idx"][idx:idx + 1] = 0
self.share_inputs["min_dec_len"][idx:idx + 1] = 1
self.share_inputs["max_dec_len"][idx:idx +
1] = kwargs["max_length"]
self.share_inputs["stop_flags"][idx:idx + 1] = False
self.share_inputs["pre_ids"][idx:idx + 1] = -1
encoder_block_num = len(task.get("block_tables"))
self.share_inputs["encoder_block_lens"][idx:idx +
1] = encoder_block_num
self.share_inputs["block_tables"][idx:idx + 1, :] = -1
self.share_inputs["block_tables"][
idx:idx + 1, :encoder_block_num] = np.array(task.block_tables,
dtype="int32")
def pre_process(self) -> None:
"""
pre_process
"""
if current_platform.is_cuda():
if self.args.speculative_method is not None:
(
ids_remove_padding,
padding_offset,
cum_offsets,
cu_seqlens_q,
cu_seqlens_k,
) = speculate_remove_padding(
max_len=self.args.max_model_len,
input_ids=self.share_inputs["input_ids"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
draft_tokens=self.share_inputs["draft_tokens"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"])
else:
(
ids_remove_padding,
padding_offset,
cum_offsets,
cu_seqlens_q,
cu_seqlens_k,
) = remove_padding(
max_len=self.args.max_model_len,
input_ids=self.share_inputs["input_ids"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"])
self.share_inputs["ids_remove_padding"] = ids_remove_padding
self.share_inputs["padding_offset"] = padding_offset
self.share_inputs["cum_offsets"] = cum_offsets
self.share_inputs["cu_seqlens_q"] = cu_seqlens_q
self.share_inputs["cu_seqlens_k"] = cu_seqlens_k
self.share_inputs["decoder_batch_ids"] = paddle.full(
[self.fd_config.parallel_config.max_num_seqs, 1], 0, dtype='int32')
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full(
[self.fd_config.parallel_config.max_num_seqs, 1], 0, dtype='int32')
# initialize_forward_meta
self.forward_meta = ForwardMeta(
input_ids=self.share_inputs["input_ids"],
ids_remove_padding=self.share_inputs["ids_remove_padding"],
rotary_embs=self.share_inputs["rope_emb"],
attn_backend=self.attn_backend,
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
cum_offsets=self.share_inputs["cum_offsets"],
padding_offset=self.share_inputs["padding_offset"],
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"]
)
self.attn_backend.init_attention_metadata(self.forward_meta)
self.sampling_metadata = SamplingMetadata(
temperature=self.share_inputs["temperature"],
top_p=self.share_inputs["top_p"],
step_idx=self.share_inputs["step_idx"],
pre_token_ids=self.share_inputs["pre_ids"],
frequency_penalties=self.share_inputs["frequency_score"],
presence_penalties=self.share_inputs["presence_score"],
repetition_penalties=self.share_inputs["penalty_score"],
min_dec_lens=self.share_inputs["min_dec_len"],
bad_words_token_ids=self.share_inputs["bad_tokens"],
eos_token_ids=self.share_inputs["eos_token_id"],
max_num_logprobs=20 if self.enable_logprob else None,
)
def generate(self) -> None:
"""
generate
"""
self.pre_process()
hiddden_states = self.model(self.share_inputs["ids_remove_padding"],
self.share_inputs["image_features"],
self.forward_meta)
logits = self.model.compute_logits(hiddden_states)
set_value_by_flags_and_idx(
self.share_inputs["pre_ids"],
self.share_inputs["input_ids"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_encoder"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["step_idx"],
self.share_inputs["stop_flags"],
)
# sampler & save_output
sampler_output = self.sampler(logits, self.sampling_metadata)
if self.fd_config.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
self.post_process(sampler_output)
def post_process(self, sampler_output: SamplerOutput) -> None:
"""
post_process
"""
if self.share_inputs["enable_thinking"]:
exists_think_end = sampler_output.sampled_token_ids == self.model_cfg.think_end_id
paddle.assign(
paddle.where(
exists_think_end,
self.share_inputs["need_think_end"] - 1,
self.share_inputs["need_think_end"],
), self.share_inputs["need_think_end"])
paddle.assign(
paddle.where(
self.share_inputs["need_think_end"].cast("bool"),
self.share_inputs["reasoning_index"] - 1,
self.share_inputs["reasoning_index"],
), self.share_inputs["reasoning_index"])
stop_wo_think = (
(sampler_output.sampled_token_ids == self.share_inputs["eos_token_id"]) |
(self.share_inputs["reasoning_index"] == 0)) & (
self.share_inputs["need_think_end"] > 0)
sampler_output.sampled_token_ids = paddle.where(stop_wo_think,
self.model_cfg.think_end_id,
sampler_output.sampled_token_ids)
paddle.assign(
paddle.where(
stop_wo_think,
self.share_inputs["need_think_end"] - 1,
self.share_inputs["need_think_end"],
), self.share_inputs["need_think_end"])
paddle.assign(
paddle.where(
self.share_inputs["stop_flags"],
self.share_inputs["step_idx"],
self.share_inputs["step_idx"] + 1,
),
self.share_inputs["step_idx"],
)
length_cond = paddle.greater_equal(self.share_inputs["step_idx"],
self.share_inputs["max_dec_len"])
paddle.assign(
paddle.logical_or(self.share_inputs["stop_flags"], length_cond),
self.share_inputs["stop_flags"],
)
set_stop_value_multi_ends(
sampler_output.sampled_token_ids,
self.share_inputs["stop_flags"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["eos_token_id"],
self.share_inputs["next_tokens"],
False,
) # multi ends
# update inputs
update_inputs(
self.share_inputs["stop_flags"],
self.share_inputs["not_need_stop"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_encoder"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["input_ids"],
self.share_inputs["stop_nums"],
sampler_output.sampled_token_ids,
self.share_inputs["is_block_step"],
)
if sampler_output.logprobs_tensors is None:
save_output(
sampler_output.sampled_token_ids,
self.share_inputs["not_need_stop"],
self.rank,
False, # use_ep
)
else:
save_output_topk(
sampler_output.sampled_token_ids,
sampler_output.logprobs_tensors.logprob_token_ids,
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
self.share_inputs["not_need_stop"],
self.rank,
)
def _cal_theortical_kvcache(self):
"""
Calculate the size of kvcache for computational theory
"""
num_layers = self.model_cfg.num_hidden_layers
byte_of_cache = 2
# support c8 c4
hidden_dim = self.model_cfg.head_dim * self.model_cfg.kv_num_head
theoretical_kv_cache_memory = (2 * byte_of_cache *
self.args.block_size * num_layers *
hidden_dim)
return theoretical_kv_cache_memory
def _update_share_input_block_num(self):
"""
Update share_inputs['block_tables'] and share_inputs['free_list']
"""
num_gpu_blocks = self.num_gpu_blocks
del self.share_inputs["caches"]
self._init_kvcache()
del self.share_inputs["block_tables"]
self.share_inputs["block_tables"] = paddle.full(
[self.args.max_num_seqs, num_gpu_blocks], -1, dtype="int32")
# Init free list
free_list = list(
range(num_gpu_blocks - 1,
int(num_gpu_blocks * self.args.kv_cache_ratio) - 1, -1))
self.free_list_len = len(free_list)
self.share_inputs.update({
"free_list":
paddle.to_tensor(free_list, dtype="int32"),
"free_list_len":
paddle.full([1], self.free_list_len, dtype="int32"),
})
def dummy_input(self, num_total_tokens: int, number_of_tasks: int) -> None:
"""
fake input to profile
"""
input_length = min(num_total_tokens // number_of_tasks,
self.args.max_model_len - 10)
block_num = (input_length + self.args.block_size - 1 ) // self.args.block_size \
+ self.args.enc_dec_block_num
self.share_inputs["free_list"] = paddle.to_tensor([], dtype="int32")
self.share_inputs["free_list_len"][0] = 0
for i in range(number_of_tasks):
idx = i
self.share_inputs["input_ids"][idx:idx +
1, :input_length] = np.array(
[5] * input_length)
self.share_inputs["eos_token_id"][:] = np.array(
[2], dtype="int64").reshape(-1, 1)
self.share_inputs["seq_lens_this_time"][idx:idx + 1] = input_length
self.share_inputs["step_seq_lens_encoder"][idx:idx +
1] = input_length
self.share_inputs["seq_lens_encoder"][idx:idx + 1] = input_length
self.share_inputs["seq_lens_decoder"][idx:idx + 1] = 0
self.share_inputs["step_idx"][idx:idx + 1] = 0
self.share_inputs["max_dec_len"][idx:idx + 1] = 10
self.share_inputs["stop_flags"][idx:idx + 1] = False
self.share_inputs["first_token_ids"][
idx:idx + 1] = self.share_inputs["input_ids"][idx:idx + 1, :1]
self.share_inputs["ori_seq_lens_encoder"][idx:idx +
1] = input_length
self.share_inputs["infer_seed"][idx:idx + 1] = random.randint(
0, 922337203685477580)
self.share_inputs["encoder_block_lens"][idx:idx + 1] = block_num
self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(idx * block_num, \
(idx + 1) * block_num, 1)
def _preprocess_task(self, one: dict) -> None:
"""process batch"""
input_ids = one["input_ids"][np.newaxis, :]
input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64)
token_type_ids = one["token_type_ids"][np.newaxis, :]
token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64)
if one["images"] is not None:
image_type_ids = one["image_type_ids"][np.newaxis, :]
images = one["images"]
image_type_ids = paddle.to_tensor(image_type_ids,
dtype=paddle.int64)
images = paddle.to_tensor(images, dtype="uint8")
grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64")
else:
image_type_ids = None
images = None
grid_thw = None
if one["position_ids"] is not None:
position_ids = paddle.to_tensor(one["position_ids"],
dtype="int64").unsqueeze([0])
else:
position_ids = None
result = dict(
input_ids=input_ids,
image_type_ids=image_type_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
grid_thw=grid_thw,
images=images,
)
return result


@@ -1,277 +0,0 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import argparse
from abc import ABC, abstractmethod
import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from fastdeploy.config import ModelConfig
from fastdeploy.utils import get_logger
logger = get_logger("worker", "worker.log")
class VLModelRunnerBase(ABC):
"""
Engine -> (WIP)Executor -> Worker -> VLModelRunnerBase -> Model
VLModelRunnerBase interface abstracts the model execution logic that
contains input preparation, token generation, and token processing.
"""
def __init__(
self,
config: ModelConfig,
args: argparse.Namespace,
) -> None:
"""
VLModelRunnerBase init
"""
self.share_inputs = {}
self.model_cfg = config
self.args = args
self.init_dist_env()
self._init_share_inputs(args.max_num_seqs)
self.init_rotary_position_embedding(args.max_model_len)
self.num_gpu_blocks = args.total_block_num
self._load_model(config.model_name_or_path, args.dynamic_load_weight)
def _log_memory_usage(self, context: str = "") -> None:
"""Log current GPU memory usage."""
max_alloc = paddle.device.cuda.max_memory_allocated() / (1024**3)
max_reserved = paddle.device.cuda.max_memory_reserved() / (1024**3)
curr_alloc = paddle.device.cuda.memory_allocated() / (1024**3)
curr_reserved = paddle.device.cuda.memory_reserved() / (1024**3)
logger.info(f"GPU memory usage {context}:")
logger.warning(f"max_allocated: {max_alloc:.2f}GB\n"
f"max_reserved: {max_reserved:.2f}GB\n"
f"current_allocated: {curr_alloc:.2f}GB\n"
f"current_reserved: {curr_reserved:.2f}GB")
def init_dist_env(self, seed=20) -> None:
"""
init distributed env
"""
self.nranks = dist.get_world_size()
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": self.nranks,
"pp_degree": 1,
"sharding_degree": 1,
}
# Set control in tensor parallel
strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
def _load_model_init_val(self) -> None:
"""
initialize model config from config file
"""
def _get_attr(key, default=None):
if hasattr(self.model_cfg, key):
return getattr(self.model_cfg, key)
return default
self.top_p = _get_attr("top_p", 0.0)
self.temperature = _get_attr("temperature", 1.0)
self.rope_theta = _get_attr("rope_theta", 10000.0)
self.rope_scaling = _get_attr("rope_scaling", None)
self.penalty_score = _get_attr("penalty_score", 1.0)
self.frequency_score = _get_attr("frequency_score", 0.0)
self.presence_score = _get_attr("presence_score", 0.0)
self.min_length = _get_attr("min_length", 1)
self.max_length = self.args.max_model_len
def _init_share_inputs(self, max_num_seqs: int) -> None:
"""
initialize shared inputs
"""
self._load_model_init_val()
int64_config = {"dtype": "int64"}
int32_config = {"dtype": "int32"}
float32_config = {"dtype": "float32"}
bool_config = {"dtype": "bool"}
self.share_inputs.update({
"pre_ids":
paddle.full([max_num_seqs, self.max_length], -1, **int64_config),
"input_ids":
paddle.full([max_num_seqs, self.args.max_model_len],
self.args.pad_token_id, **int64_config),
"eos_token_id":
paddle.full([self.args.eos_tokens_lens, 1], 0, **int64_config),
"top_p":
paddle.full([max_num_seqs, 1], self.top_p, **float32_config),
"temperature":
paddle.full([max_num_seqs, 1], self.temperature, **float32_config),
"penalty_score":
paddle.full([max_num_seqs, 1], self.penalty_score,
**float32_config),
"frequency_score":
paddle.full([max_num_seqs, 1], self.frequency_score,
**float32_config),
"presence_score":
paddle.full([max_num_seqs, 1], self.presence_score,
**float32_config),
"min_dec_len":
paddle.full([max_num_seqs, 1], self.min_length, **int64_config),
"max_dec_len":
paddle.full([max_num_seqs, 1], self.max_length, **int64_config),
"min_length":
paddle.full([max_num_seqs, 1], self.min_length, **int64_config),
"max_length":
paddle.full([max_num_seqs, 1], self.max_length, **int64_config),
"seq_lens_this_time":
paddle.full(max_num_seqs, 0, **int32_config),
"seq_lens_encoder":
paddle.full([max_num_seqs, 1], 0, **int32_config),
"step_seq_lens_encoder":
paddle.full([max_num_seqs, 1], 0, **int32_config),
"step_seq_lens_decoder":
paddle.full([max_num_seqs, 1], 0, **int32_config),
"seq_lens_decoder":
paddle.full([max_num_seqs, 1], 0, **int32_config),
"step_idx":
paddle.full([max_num_seqs, 1], 0, **int64_config),
"not_need_stop":
paddle.full([1], False, **bool_config).cpu(),
"stop_flags":
paddle.full([max_num_seqs, 1], True, **bool_config),
"stop_nums":
paddle.full([1], max_num_seqs, **int64_config),
"bad_tokens":
paddle.full([1], -1, **int64_config),
"next_tokens":
paddle.full([max_num_seqs, 1], -1, **int64_config),
"is_block_step":
paddle.full([max_num_seqs], False, **bool_config),
"encoder_block_lens":
paddle.full([max_num_seqs], 0, **int32_config),
"step_block_list":
paddle.full([max_num_seqs], -1, **int32_config),
"step_lens":
paddle.full([1], 0, **int32_config),
"recover_block_list":
paddle.full([max_num_seqs], -1, **int32_config),
"recover_lens":
paddle.full([1], 0, **int32_config),
"need_block_list":
paddle.full([max_num_seqs], -1, **int32_config),
"need_block_len":
paddle.full([1], 0, **int32_config),
"used_list_len":
paddle.full([max_num_seqs], 0, **int32_config),
"infer_seed":
paddle.full([max_num_seqs, 1], 0, **int64_config),
"first_token_ids":
paddle.full([max_num_seqs, 1], -1, **int64_config),
"ori_seq_lens_encoder":
paddle.full([max_num_seqs, 1], 0, **int32_config),
"system_lens":
paddle.full([max_num_seqs, 1], 0, **int32_config),
"system_ids":
paddle.full([max_num_seqs, 1], -1, **int32_config),
})
pre_max_block_num = (
self.args.max_model_len + self.args.block_size -
1) // self.args.block_size + self.args.enc_dec_block_num
self.share_inputs["block_tables"] = paddle.full(
[max_num_seqs, pre_max_block_num], -1, **int32_config)
free_list = list(
range(
self.args.total_block_num - 1,
int(self.args.total_block_num * self.args.kv_cache_ratio) - 1,
-1))
self.free_list_len = len(free_list)
self.share_inputs.update({
"free_list":
paddle.to_tensor(free_list, dtype="int32"),
"free_list_len":
paddle.full([1], self.free_list_len, **int32_config),
})
self.share_inputs.update({
"stop_seqs_len":
paddle.full([self.model_cfg.max_stop_seqs_num], 0, **int32_config),
"stop_seqs":
paddle.full([
self.model_cfg.max_stop_seqs_num,
self.model_cfg.stop_seqs_max_len
], -1, **int64_config),
})
def update_chunked_prefill(self, tasks: list[any]) -> None:
"""
update chunked prefill
"""
if not self.args.enable_chunked_prefill:
return
raise NotImplementedError(
"currently chunked_prefill is not supported.")
def prefill_finished(self):
"""
Verify prefill operation completion
"""
return True
@abstractmethod
def init_rotary_position_embedding(self, max_model_len: int) -> None:
"""
Init rotary position embedding
"""
raise NotImplementedError
@abstractmethod
def _load_model(
self,
model_name: str,
dynamic_load_weight: int = 0,
) -> None:
"""
Load the model from the given model name.
"""
raise NotImplementedError
@abstractmethod
def _init_kvcache(self):
"""
Init kv cache
"""
raise NotImplementedError
@abstractmethod
def dy_input_preprocess(self, tasks: list[any]) -> None:
"""
dynamic insertion
"""
raise NotImplementedError


@@ -1,540 +0,0 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import argparse
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from fastdeploy.engine.config import ModelConfig
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
from fastdeploy.utils import get_logger, none_or_str
from fastdeploy.worker.worker_process import initialize_fd_config, parse_args
logger = get_logger("worker", "worker.log")
class PrefillTracker:
"""
Record the prefill time of the request
"""
def __init__(
self,
engine_pid: int,
) -> None:
"""
Initialize the PrefillTracker.
"""
super().__init__()
self.start_times = defaultdict(float)
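# Ring buffer of 100 float32 slots shared with the engine process via IPC;
# prefill durations are written here so the engine can expose them as metrics.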
prefill_time_data = np.zeros([100], dtype=np.float32)
self.prefill_time_signal = IPCSignal(name="prefill_time_signal",
array=prefill_time_data,
dtype=np.float32,
suffix=engine_pid,
create=False)
self.current_index = 0
self.executor = ThreadPoolExecutor(max_workers=1)
def start_prefill(self, task_idx: int):
"""
Record the start time of the prefill process for a given task index.
Args:
task_idx (int): The index of the task being prefilled.
"""
self.start_times[task_idx] = time.time()
def end_prefill(self, task_idx: int):
"""
Record the end time of the prefill process for a given task index and
asynchronously submit the duration for metric recording.
Args:
task_idx (int): The index of the task being prefilled.
"""
if task_idx in self.start_times:
duration = time.time() - self.start_times[task_idx]
# Submit metric recording to the executor for asynchronous execution
self.executor.submit(self._record_metrics, duration)
del self.start_times[task_idx]
def _record_metrics(self, duration: float):
"""
Internal method to record the prefill duration into the signal buffer.
Logs the duration and updates a circular buffer of timing metrics.
Args:
duration (float): Time taken for the prefill process in seconds.
"""
self.prefill_time_signal.value[self.current_index] = duration
self.current_index = (self.current_index + 1) % len(
self.prefill_time_signal.value)
def __del__(self):
"""Clean up resources"""
if hasattr(self, 'executor'):
self.executor.shutdown(wait=False)
class Worker:
"""
Engine -> (WIP)Executor -> Worker -> ModelRunner -> Model
Worker interface that allows the inference framework to cleanly separate implementations for different hardware.
"""
def __init__(
self,
args,
) -> None:
"""
Initialize the Worker.
"""
super().__init__()
self.args = args
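# Largest value the per-request sampling seed may take before wrapping (close to the int64 maximum).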
self.MAX_INFER_SEED = 9223372036854775806
paddle.set_default_dtype(args.dtype)
self.device_ids = self.args.device_ids.split(",")
self.model_cfg = ModelConfig(args.model_name_or_path)
from fastdeploy.worker.vl_gpu_model_runner import GPUVLModelRunner
self.init_dist_env()
self.format_print_configuration()
self.helper_tensors = {}
local_rank = self.rank % self.args.tensor_parallel_size
self.local_data_parallel_id = self.rank // self.args.tensor_parallel_size
self.infer_engine = GPUVLModelRunner(config=self.model_cfg,
args=self.args,
nranks=self.nranks,
rank=self.rank)
self.prefill_tracker = PrefillTracker(args.engine_pid)
address = (self.args.pod_ip, self.args.engine_worker_queue_port)
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,
num_client=self.nranks,
client_id=local_rank,
local_data_parallel_id=self.local_data_parallel_id)
self.init_health()
def init_dist_env(self, seed=20):
"""
init distributed env
"""
self.nranks = dist.get_world_size()
strategy = fleet.DistributedStrategy()
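# Pure tensor parallelism: every rank joins a single model-parallel group (dp/pp/sharding degrees are all 1).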
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": self.nranks,
"pp_degree": 1,
"sharding_degree": 1,
}
# Set control in tensor parallel
strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
def init_health(self):
"""
init health signals
"""
# To perceive whether each worker process is ready
worker_ready_signal_data = np.zeros(shape=[self.nranks],
dtype=np.int32)
self.worker_ready_signal = IPCSignal(name="worker_ready_signal",
array=worker_ready_signal_data,
dtype=np.int32,
suffix=self.args.engine_pid,
create=False)
self.worker_ready_signal.value[self.rank] = 1
# To monitor the liveness of worker processes and record each step's timestamp
worker_healthy_live_recorded_time_array = np.zeros(shape=[self.nranks],
dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal(
name="worker_healthy_live_signal",
array=worker_healthy_live_recorded_time_array,
dtype=np.int32,
suffix=self.args.engine_pid,
create=False)
self.worker_healthy_live_signal.value[self.rank] = int(time.time())
# To perceive whether there is a new task to be processed
exist_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_task_signal = IPCSignal(name="exist_task_signal",
array=exist_task_signal_data,
dtype=np.int32,
suffix=self.args.engine_pid,
create=False)
# To detect whether there are swapped tasks in the worker
exist_swapped_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_swapped_task_signal = IPCSignal(
name="exist_swapped_task_signal",
array=exist_swapped_task_signal_data,
dtype=np.int32,
suffix=self.args.engine_pid,
create=False)
model_weights_status = np.zeros([1], dtype=np.int32)
self.model_weights_status_signal = IPCSignal(
name="model_weights_status",
array=model_weights_status,
dtype=np.int32,
suffix=self.args.engine_pid,
create=False)
def format_print_configuration(self):
"""
print model config
"""
logger.info("=============== Model Information ==============")
for k, v in self.model_cfg.__dict__.items():
logger.info("{:<20}:{:<6}{}".format(k, "", v))
logger.info("=============== Service Configuration ===============")
for k, v in vars(self.args).items():
logger.info("{:<20}:{:<6}{}".format(k, "", v))
logger.info("=====================================================\n")
def step_cuda(self):
"""
step cuda
"""
from fastdeploy.model_executor.ops.gpu import (step_reschedule,
step_system_cache)
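# Both ops recycle KV blocks of finished or blocked requests each step;
# step_system_cache is the variant used when prefix caching is enabled.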
if self.args.enable_prefix_caching:
step_system_cache(
self.infer_engine.share_inputs["stop_flags"],
self.infer_engine.share_inputs["seq_lens_this_time"],
self.infer_engine.share_inputs["step_seq_lens_encoder"],
self.infer_engine.share_inputs["step_seq_lens_decoder"],
self.infer_engine.share_inputs["seq_lens_encoder"],
self.infer_engine.share_inputs["seq_lens_decoder"],
self.infer_engine.share_inputs["block_tables"],
self.infer_engine.share_inputs["encoder_block_lens"],
self.infer_engine.share_inputs["is_block_step"],
self.infer_engine.share_inputs["step_block_list"],
self.infer_engine.share_inputs["step_lens"],
self.infer_engine.share_inputs["recover_block_list"],
self.infer_engine.share_inputs["recover_lens"],
self.infer_engine.share_inputs["need_block_list"],
self.infer_engine.share_inputs["need_block_len"],
self.infer_engine.share_inputs["used_list_len"],
self.infer_engine.share_inputs["free_list"],
self.infer_engine.share_inputs["free_list_len"],
self.infer_engine.share_inputs["input_ids"],
self.infer_engine.share_inputs["pre_ids"],
self.infer_engine.share_inputs["step_idx"],
self.infer_engine.share_inputs["next_tokens"],
self.infer_engine.share_inputs["first_token_ids"],
self.args.block_size, self.args.enc_dec_block_num)
else:
step_reschedule(
self.infer_engine.share_inputs["stop_flags"],
self.infer_engine.share_inputs["seq_lens_this_time"],
self.infer_engine.share_inputs["step_seq_lens_encoder"],
self.infer_engine.share_inputs["seq_lens_encoder"],
self.infer_engine.share_inputs["seq_lens_decoder"],
self.infer_engine.share_inputs["block_tables"],
self.infer_engine.share_inputs["encoder_block_lens"],
self.infer_engine.share_inputs["is_block_step"],
self.infer_engine.share_inputs["step_block_list"],
self.infer_engine.share_inputs["step_lens"],
self.infer_engine.share_inputs["recover_block_list"],
self.infer_engine.share_inputs["recover_lens"],
self.infer_engine.share_inputs["need_block_list"],
self.infer_engine.share_inputs["need_block_len"],
self.infer_engine.share_inputs["used_list_len"],
self.infer_engine.share_inputs["free_list"],
self.infer_engine.share_inputs["free_list_len"],
self.infer_engine.share_inputs["input_ids"],
self.infer_engine.share_inputs["pre_ids"],
self.infer_engine.share_inputs["step_idx"],
self.infer_engine.share_inputs["next_tokens"],
self.infer_engine.share_inputs["first_token_ids"],
self.args.block_size,
self.args.enc_dec_block_num,
)
def check_model_weights_status(self):
"""
check model weights status
"""
is_stop = 0
while self.model_weights_status_signal.value[0] != 0:
if self.model_weights_status_signal.value[0] == 1:
logger.info(
f"infer engine stopped! start to load new checkpoint... {self.rank}"
)
self.infer_engine.update_parameters(self.args.engine_pid)
elif self.model_weights_status_signal.value[0] == -1:
logger.info(
f"infer engine stopped! start to clear checkpoint... {self.rank}"
)
self.infer_engine.clear_parameters(self.args.engine_pid)
while True:
if self.model_weights_status_signal.value[0] == 0:
logger.info(f"finished loading new checkpoint {self.rank}")
break
elif is_stop == 1 or (self.model_weights_status_signal.value[0]
== -2 and is_stop == 0):
if is_stop == 0:
logger.info(
f"finished clearing checkpoint {self.rank}")
is_stop = 1
time.sleep(0.001)
break
else:
time.sleep(0.001)
def run(self):
"""
Main loop: continuously fetch new tasks and run inference.
"""
infer_seed_increment = paddle.full(shape=[self.args.max_num_seqs, 1],
fill_value=4,
dtype="int64")
self.nnode = int((self.nranks + 7) // 8)
mp_num_per_node = self.nranks // self.nnode
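# Assumes at most 8 ranks per node; only the first rank on each node polls the task queue
# and raises the signal/flag for the other ranks.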
while True:
if self.rank == 0:
if self.model_weights_status_signal.value[0] != 0:
self.exist_task_signal.value[0] = 2
else:
self.exist_task_signal.value[0] = 0
if self.nranks > 1:
paddle.distributed.barrier()
if self.exist_task_signal.value[0] == 2:
self.check_model_weights_status()
self.insert_step = False
self.worker_healthy_live_signal.value[self.rank] = int(time.time())
if self.rank % mp_num_per_node == 0:
if self.engine_worker_queue.num_tasks(
) > 0 and self.infer_engine.prefill_finished():
if self.nnode > 1:
self.engine_worker_queue.read_finish_flag.set(1)
else:
self.exist_task_signal.value[0] = 1
if self.nranks > 1:
paddle.distributed.barrier()
if self.exist_task_signal.value[
0] == 1 or self.engine_worker_queue.read_finish_flag.get(
) == 1:
logger.info(f"Rank: {self.rank} Detected new requests.")
self.insert_step = True
tasks, read_finish = self.engine_worker_queue.get_tasks()
if read_finish:
self.exist_task_signal.value[0] = 0
self.engine_worker_queue.read_finish_flag.set(0)
req_dicts = []
for req_dict, bsz in tasks:
num_running_requests = int(bsz)
req_dicts.extend(req_dict)
req_ids = [req.request_id for req in req_dicts]
logger.info(f"Rank: {self.rank}, num_running_requests: {num_running_requests}, " \
f"num_insert_requests: {len(req_dicts)}. {req_ids}")
self.infer_engine.dy_input_preprocess(req_dicts)
for req_dict in req_dicts:
if self.infer_engine.share_inputs["seq_lens_this_time"][
req_dict.idx] > 1:
self.prefill_tracker.start_prefill(req_dict.idx)
self.infer_engine.share_inputs["not_need_stop"][0] = True
if not self.infer_engine.share_inputs["not_need_stop"]:
time.sleep(0.001)
continue
self.infer_engine.generate()
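# Advance the per-request sampling seed and wrap it so it stays within the int64 range.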
self.infer_engine.share_inputs["infer_seed"].add_(
infer_seed_increment)
self.infer_engine.share_inputs[
"infer_seed"][:] %= self.MAX_INFER_SEED
for req_dict in req_dicts:
if (self.infer_engine.share_inputs["seq_lens_this_time"][
req_dict.idx] == 1
and req_dict.idx in self.prefill_tracker.start_times):
self.prefill_tracker.end_prefill(req_dict.idx)
self.infer_engine.update_chunked_prefill(req_dicts)
self.step_cuda()
def determine_num_available_blocks(self):
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine first profiles the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
start_time = time.time()
GiB = 1024**3
paddle.device.cuda.empty_cache()
paddle.device.cuda.reset_max_memory_allocated()
before_activation_gpu_memory = paddle.device.cuda.max_memory_allocated(
) / GiB
logger.info(
f"before activate gpu memory: {before_activation_gpu_memory} GiB.")
import gc
import pynvml
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(
int(self.device_ids[self.rank]))
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
total_gpu_memory = meminfo.total / GiB
used_gpu_memory = meminfo.used / GiB
pynvml.nvmlShutdown()
logger.info(f"used gpu memory: {used_gpu_memory} GiB.")
self.run_profile()
current_max_peak_gpu_memory = paddle.device.cuda.max_memory_reserved(
) / GiB
logger.info(
f"current max peak gpu memory: {current_max_peak_gpu_memory} GiB.")
per_block_memory_used = self.infer_engine._cal_theortical_kvcache(
) / GiB
logger.info(f"each kv cache block takes {per_block_memory_used} GiB.")
used_cache_gpu_memory = self.args.total_block_num * per_block_memory_used
logger.info(f"used cache gpu memory: {used_cache_gpu_memory} GiB.")
model_weights_memory = used_gpu_memory - used_cache_gpu_memory
paddle_peak_increase = current_max_peak_gpu_memory - before_activation_gpu_memory
memory_for_current_instance = total_gpu_memory * self.args.gpu_memory_utilization
available_kv_cache_memory = memory_for_current_instance - used_gpu_memory - \
paddle_peak_increase + used_cache_gpu_memory
num_gpu_blocks = max(
int(available_kv_cache_memory // per_block_memory_used),
self.args.total_block_num)
profile_time = time.time() - start_time
msg = (f"Memory profiling takes {profile_time:.2f} seconds\n"
"the current instance can use "
"total_gpu_memory "
f"({(total_gpu_memory):.2f}GiB)"
" x gpu_memory_utilization "
f"({self.args.gpu_memory_utilization})"
f" = {(memory_for_current_instance):.2f}GiB\n"
"model weights take "
f"{(model_weights_memory ):.2f}GiB;"
" Paddle activation peak memory takes "
f"{(paddle_peak_increase):.2f}GiB;"
" the rest of the memory reserved for KV Cache is "
f"{(available_kv_cache_memory):.2f}GiB.")
self.infer_engine.record_profile_msg = {
"per_block_memory_used": per_block_memory_used,
"paddle_peak_increase": paddle_peak_increase,
}
logger.info(msg)
# Final cleanup
get_profile_block_num = np.zeros(shape=[self.nranks], dtype=np.int32)
self.get_profile_block_num_signal = IPCSignal(
name="get_profile_block_num",
array=get_profile_block_num,
dtype=np.int32,
suffix=self.args.engine_pid,
create=False)
self.get_profile_block_num_signal.value[self.rank] = int(
num_gpu_blocks)
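# Wait until every rank has published its block count, then adopt the minimum
# so all ranks size the KV cache identically.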
while np.any(self.get_profile_block_num_signal.value <= 0):
time.sleep(0.01)
num_gpu_blocks = self.get_profile_block_num_signal.value.min().item()
self.get_profile_block_num_signal.value[self.rank] = int(
num_gpu_blocks)
logger.info(
f"{self.get_profile_block_num_signal.value[self.rank]} GPU KV blocks can be allocated."
)
self.infer_engine.num_gpu_blocks = num_gpu_blocks
self.infer_engine._update_share_input_block_num()
paddle.device.cuda.empty_cache()
gc.collect()
def run_profile(self):
"""
run profile
"""
infer_seed_increment = paddle.full(shape=[self.args.max_num_seqs, 1],
fill_value=4,
dtype="int64")
self.infer_engine.dummy_input(self.args.max_num_batched_tokens,
self.args.max_num_seqs)
while True:
if self.nranks > 1:
paddle.distributed.barrier()
self.infer_engine.generate()
self.infer_engine.share_inputs["infer_seed"].add_(
infer_seed_increment)
self.infer_engine.share_inputs[
"infer_seed"][:] %= self.MAX_INFER_SEED
self.step_cuda()
if int((self.infer_engine.share_inputs['seq_lens_this_time']
> 0).sum()) == 0:
break
def main():
"""
start worker
"""
args = parse_args()
worker = Worker(args)
if args.do_profile:
worker.determine_num_available_blocks()
worker.run()
if __name__ == "__main__":
main()

View File

@@ -549,6 +549,10 @@ def parse_args():
"'ipc_snapshot': load from disk snapshot of IPC weights, "
"'meta': provide RL training worker, no weights load, "
"'normal': normal load weight")
parser.add_argument("--enable_mm",
type=str,
default="false",
help="Whether to use vl")
parser.add_argument("--enable_logprob",
action='store_true',
help="Enable output of token-level log probabilities.")
@@ -650,6 +654,8 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
"No quantization config found and use original weight and act dtype."
)
# Set VL tag
model_config.enable_mm = getattr(args, 'enable_mm', 'false').lower() == 'true'
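# enable_mm typically arrives as the string "True"/"False" from the engine launcher,
# hence the case-insensitive comparison.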
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
logger.info(f"- Load strategy: {load_config.load_strategy}")