diff --git a/build.sh b/build.sh
index 8608435c1..5597aec2d 100644
--- a/build.sh
+++ b/build.sh
@@ -162,12 +162,11 @@ function copy_ops(){
     is_maca=`$python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"`
     if [ "$is_maca" = "True" ]; then
         DEVICE_TYPE="metax_gpu"
-        mkdir -p ../fastdeploy/model_executor/ops/base
-        cp -r ${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
         cp -r ${TMP_PACKAGE_DIR}/* ../fastdeploy/model_executor/ops/gpu
         echo -e "MACA ops have been copy to fastdeploy"
         return
     fi
+
     is_intel_hpu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('intel_hpu'))"`
     if [ "$is_intel_hpu" = "True" ]; then
         DEVICE_TYPE="intel-hpu"
diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
index 4608bd81e..375bb9792 100644
--- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py
@@ -43,6 +43,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionMetadata,
 )
 from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
+from fastdeploy.platforms import current_platform
 
 
 @dataclass
@@ -87,7 +88,10 @@ def allocate_launch_related_buffer(
     res = {}
     res["decoder_batch_ids"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
     res["decoder_tile_ids_per_batch"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
-    res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
+    if current_platform.is_maca():
+        res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
+    else:
+        res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
     # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor,
     # adapted to cudagraph.
     res["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32")
diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py
index 87d3c2543..cd7649ae7 100644
--- a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py
+++ b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py
@@ -206,6 +206,9 @@ class MetaxMLAAttentionBackend(AttentionBackend):
             self.seq_lens = seq_lens_decoder + seq_lens_this_time
             self.block_tables = forward_meta.block_tables[non_zero_index]
 
+        self.tile_scheduler_metadata = None
+        self.num_splits = None
+
     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
         return self.attention_metadata
@@ -250,13 +253,13 @@ class MetaxMLAAttentionBackend(AttentionBackend):
             ]
         )
 
-        query = query.reshape([-1, seq_len_q, num_heads_q, head_dim_qk])
+        query = query.reshape_([-1, seq_len_q, num_heads_q, head_dim_qk])
 
-        tile_scheduler_metadata, num_splits = get_mla_metadata(
-            self.seq_lens, seq_len_q * num_heads_q // num_heads_kv, num_heads_kv
-        )
-
-        assert tile_scheduler_metadata.shape[0] != 0
+        if self.tile_scheduler_metadata is None or self.num_splits is None:
+            self.tile_scheduler_metadata, self.num_splits = get_mla_metadata(
+                self.seq_lens, seq_len_q * num_heads_q // num_heads_kv, num_heads_kv
+            )
+            assert self.tile_scheduler_metadata.shape[0] != 0
 
         out = flash_mla_with_kvcache(
             query,
@@ -264,8 +267,8 @@ class MetaxMLAAttentionBackend(AttentionBackend):
             self.block_tables,
             self.seq_lens,
             head_dim_v,
-            tile_scheduler_metadata,
-            num_splits,
+            self.tile_scheduler_metadata,
+            self.num_splits,
             softmax_scale=self.attn_softmax_scale,
             causal=self.causal,
         )[0]
@@ -273,7 +276,7 @@ class MetaxMLAAttentionBackend(AttentionBackend):
         if seq_len_q != self.seq_lens_this_time_min:
             out = paddle.concat([paddle.split(x, [n, seq_len_q - n])[0] for x, n in zip(out, self.seq_lens_this_time)])
         else:
-            out = out.reshape([-1, num_heads_q, head_dim_v])
+            out = out.reshape_([-1, num_heads_q, head_dim_v])
 
         return out
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index 680d565e1..1b7371392 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -411,14 +411,14 @@ class DeepseekV3MLAAttention(nn.Layer):
                 forward_meta=forward_meta,
             )
 
-            fmha_out_decode = fmha_out_decode.reshape([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose(
+            fmha_out_decode = fmha_out_decode.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose(
                 [1, 0, 2]
             )
 
             fmha_out_decode = (
                 self.kv_b_proj_bmm(fmha_out_decode, proj_type="v")
                 .transpose([1, 0, 2])
-                .reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
+                .reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
             )
 
             if fmha_out is None:
diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py
index 3038a34fc..bcb558fc6 100644
--- a/fastdeploy/worker/metax_model_runner.py
+++ b/fastdeploy/worker/metax_model_runner.py
@@ -17,6 +17,7 @@
 import os
 import queue
 import time
+from concurrent.futures import Future
 from threading import Thread
 from typing import List, Optional, cast
 
@@ -38,15 +39,25 @@ from fastdeploy.model_executor.graph_optimization.utils import (
     profile_run_guard,
     sot_warmup_guard,
 )
-from fastdeploy.model_executor.guided_decoding import get_guided_backend
+from fastdeploy.model_executor.guided_decoding import (
+    LogitsProcessorBase,
+    get_guided_backend,
+)
 from fastdeploy.model_executor.layers.attention import get_attention_backend
+from fastdeploy.model_executor.layers.attention.append_attn_backend import (
+    allocate_launch_related_buffer,
+)
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend,
 )
+from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
+    RoutingReplayManager,
+)
 from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata
 from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d
 from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
+from fastdeploy.model_executor.logits_processor import build_logits_processors
 from fastdeploy.model_executor.model_loader import get_model_loader
 from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
 from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
@@ -65,8 +76,12 @@ from fastdeploy.model_executor.pre_and_post_process import (
 )
 from fastdeploy.output.pooler import PoolerOutput
 from fastdeploy.spec_decode import MTPProposer, NgramProposer
-from fastdeploy.worker.model_runner_base import ModelRunnerBase
-from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
+from fastdeploy.worker.model_runner_base import (
+    DistributedOut,
+    DistributedStatus,
+    ModelRunnerBase,
+)
+from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput
 
 
 class MetaxModelRunner(ModelRunnerBase):
@@ -88,6 +103,12 @@ class MetaxModelRunner(ModelRunnerBase):
         self.enable_logprob = fd_config.model_config.enable_logprob
         self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
         self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling"
+        self.ori_vocab_size = self.fd_config.model_config.ori_vocab_size
+        self.max_logprobs = (
+            self.ori_vocab_size if fd_config.model_config.max_logprobs == -1 else fd_config.model_config.max_logprobs
+        )
+        self.prompt_logprobs_reqs: dict[str, Request] = {}
+        self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {}
 
         # VL model config:
         if self.enable_mm:
@@ -111,6 +132,12 @@
                 "matmul_v2",
                 "fused_gemm_epilogue",
             ]
+
+        if self.cache_config.max_encoder_cache > 0:
+            self.encoder_cache: dict[str, paddle.Tensor] = {}
+        else:
+            self.encoder_cache = None
+
         # Sampler
         if not self.speculative_decoding:
             self.sampler = Sampler(fd_config)
@@ -126,7 +153,7 @@
         # self.kv_caches: list[paddle.Tensor] = []
 
         # CUDA Graph
-        self.use_cudagraph = False
+        self.use_cudagraph = self.graph_opt_config.use_cudagraph
         self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
         self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
         self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill
@@ -154,6 +181,11 @@
             os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
             logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")
 
+        # Rollout routing replay config
+        self.routing_replay_manager = None
+        if self.fd_config.routing_replay_config.enable_routing_replay:
+            self.routing_replay_manager = RoutingReplayManager(fd_config=self.fd_config)
+
self.zmq_client = None self.async_output_queue = None if envs.FD_USE_GET_SAVE_OUTPUT_V1: @@ -182,13 +214,13 @@ class MetaxModelRunner(ModelRunnerBase): """ check whether prefill stage exist """ - return int(paddle.max(self.share_inputs["seq_lens_encoder"])) > 0 + return np.any(self.share_inputs["seq_lens_encoder"].numpy() > 0) def exist_decode(self): """ check whether decode stage exist """ - return int(paddle.max(self.share_inputs["seq_lens_decoder"])) > 0 + return np.any(self.share_inputs["seq_lens_decoder"].numpy() > 0) def only_prefill(self): """ @@ -206,6 +238,56 @@ class MetaxModelRunner(ModelRunnerBase): return if_only_prefill + def collect_distributed_status(self): + """ + Collect distributed status + """ + dist_status_list = [] + dist_status_obj = DistributedStatus() + dist_out = DistributedOut() + + prefill_exists = None + if_only_decode = True + # mix ep in single node + if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": + prefill_exists = self.exist_prefill() + dist_status_obj.only_decode = not prefill_exists + + # whether chunked moe + if self.fd_config.parallel_config.enable_chunked_moe: + chunk_size = self.fd_config.parallel_config.chunked_moe_size + token_num = self.share_inputs["ids_remove_padding"].shape[0] + + if token_num > chunk_size: + self.forward_meta.moe_num_chunk = (token_num + chunk_size - 1) // chunk_size + else: + self.forward_meta.moe_num_chunk = 1 + + dist_status_obj.moe_num_chunk = self.forward_meta.moe_num_chunk + + # only ep need to collect and sync distributed status + if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": + # call once to gather all status + paddle.distributed.all_gather_object(dist_status_list, dist_status_obj) + + # Update Batch type for cuda graph for if_only_decode + if_only_decode = all(dist_status.only_decode for dist_status in dist_status_list) + + if_only_decode = if_only_decode and not ( + prefill_exists if prefill_exists is not None else self.exist_prefill() + ) + + max_moe_num_chunk = None + if self.fd_config.parallel_config.enable_chunked_moe: + max_moe_num_chunk = max(dist_status.moe_num_chunk for dist_status in dist_status_list) + + dist_out = DistributedOut( + if_only_decode=if_only_decode, + max_moe_num_chunk=max_moe_num_chunk, + ) + + return dist_out + def only_decode(self): """ check whether decode only @@ -244,7 +326,7 @@ class MetaxModelRunner(ModelRunnerBase): else: self.proposer = None - def _init_logits_processor(self, request): + def _init_logits_processor(self, request) -> tuple[Future[LogitsProcessorBase],]: """ init logits processor for guided decoding """ @@ -264,24 +346,210 @@ class MetaxModelRunner(ModelRunnerBase): return ( self.guided_backend.get_logits_processor( schemata_key=schemata_key, - enable_thinking=True, + enable_thinking=False, # TODO cfg ), schemata_key, ) + def get_chunked_inputs(self, req: Request): + """ + Get inputs in current chunk + """ + prefill_start_index = req.prefill_start_index + prefill_end_index = req.prefill_end_index + inputs = req.multimodal_inputs + input_ids = inputs["input_ids"][prefill_start_index:prefill_end_index] + token_type_ids = inputs["token_type_ids"][prefill_start_index:prefill_end_index] + image_type_ids = inputs["image_type_ids"][req.image_type_ids_start : req.image_type_ids_end] + images = inputs["images"][req.image_start : req.image_end] + grid_thw = inputs["grid_thw"][req.num_image_start : req.num_image_end] + mm_hashes = 
inputs["mm_hashes"][req.num_image_start : req.num_image_end] + + return ( + input_ids, + token_type_ids, + image_type_ids, + images, + grid_thw, + mm_hashes, + ) + + def batch_uncached_inputs(self, req: Request): + """ + Batch uncached multimodal inputs + """ + (input_ids, token_type_ids, image_type_ids, images, grid_thw, mm_hashes) = self.get_chunked_inputs(req) + + image_type_ids_size = grid_thw[:, 0] + image_type_ids_split = np.cumsum(image_type_ids_size)[:-1] + image_type_ids_lst = np.array_split(image_type_ids, image_type_ids_split, axis=0) + + images_size = np.prod(grid_thw, axis=1) + images_split = np.cumsum(images_size)[:-1] + images_lst = np.array_split(images, images_split, axis=0) + + assert len(image_type_ids_lst) == len( + mm_hashes + ), f"image_type_ids_lst length {len(image_type_ids_lst)} != mm_hashes length {len(mm_hashes)}" + assert len(images_lst) == len( + mm_hashes + ), f"images_lst length {len(images_lst)} != mm_hashes length {len(mm_hashes)}" + + uncached_image_type_ids = [] + uncached_images = [] + uncached_grid_thw = [] + uncached_mm_hashes = [] + for i, mm_hash in enumerate(mm_hashes): + if mm_hash in self.encoder_cache: + continue + uncached_image_type_ids.append(image_type_ids_lst[i]) + uncached_images.append(images_lst[i]) + uncached_grid_thw.append(grid_thw[i]) + uncached_mm_hashes.append(mm_hash) + + uncached_input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64) + uncached_token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64) + if len(uncached_mm_hashes) > 0: + uncached_image_type_ids = paddle.to_tensor(np.hstack(uncached_image_type_ids), dtype=paddle.int64) + uncached_images = paddle.to_tensor( + np.vstack(uncached_images), dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16" + ) + uncached_grid_thw = paddle.to_tensor(uncached_grid_thw, dtype=paddle.int64) + + return ( + uncached_input_ids, + uncached_token_type_ids, + uncached_image_type_ids, + uncached_images, + uncached_grid_thw, + uncached_mm_hashes, + ) + + def scatter_and_cache_features(self, image_features, inputs): + """ + Split batched image features and cache them + """ + merge_size = 2 + grid_thw = inputs["grid_thw"] + mm_hashes = inputs["mm_hashes"] + image_features_size = (paddle.prod(grid_thw[:, 1:], axis=1) // (merge_size**2)).tolist() + image_features_lst = paddle.split(image_features, image_features_size, axis=0) + + assert len(image_features_lst) == len( + mm_hashes + ), f"image_features_lst length {len(image_features_lst)} != mm_hashes length {len(mm_hashes)}" + for i, mm_hash in enumerate(mm_hashes): + self.encoder_cache[mm_hash] = image_features_lst[i].cpu() + + def _apply_mm_inputs(self, request: Request, multi_vision_inputs: dict, rope_3d_position_ids: dict): + """ + Apply multimodal inputs to share_inputs + - add image_features, extract and cache vision features from model + - add rope_emb, rotate position embeddings + """ + if self.encoder_cache: + evict_mm_hashes = request.get("evict_mm_hashes", None) + if evict_mm_hashes: + for mm_hash in evict_mm_hashes: + self.encoder_cache.pop(mm_hash, None) + + inputs = request.multimodal_inputs + if request.with_image: + if envs.FD_ENABLE_MAX_PREFILL: + multi_vision_inputs["images_lst"].append( + inputs["images"][request.image_start : request.image_end].cuda() + ) + multi_vision_inputs["grid_thw_lst"].extend( + inputs["grid_thw"][request.num_image_start : request.num_image_end] + ) + if "vit_seqlen" in inputs: + multi_vision_inputs["cu_seqlens"].extend( + inputs["vit_seqlen"][request.num_image_start : 
request.num_image_end] + ) + if "vit_position_ids" in inputs: + multi_vision_inputs["vit_position_ids_lst"].extend( + inputs["vit_position_ids"][request.num_image_start : request.num_image_end] + ) + else: + vision_inputs = inputs + if self.encoder_cache: + ( + vision_inputs["input_ids"], + vision_inputs["token_type_ids"], + vision_inputs["image_type_ids"], + vision_inputs["images"], + vision_inputs["grid_thw"], + vision_inputs["mm_hashes"], + ) = self.batch_uncached_inputs(request) + if len(vision_inputs["mm_hashes"]) > 0: + # uncached multimodal inputs exist + image_features = self.extract_vision_features(vision_inputs) + self.scatter_and_cache_features(image_features, vision_inputs) + + full_image_features_lst = [] + for mm_hash in inputs["mm_hashes"][request.num_image_start : request.num_image_end]: + feature = self.encoder_cache[mm_hash].cuda() + full_image_features_lst.append(feature) + image_features = paddle.concat(full_image_features_lst, axis=0) + else: + ( + input_ids, + token_type_ids, + image_type_ids, + images, + grid_thw, + mm_hashes, + ) = self.get_chunked_inputs(request) + vision_inputs["input_ids"] = paddle.to_tensor(input_ids, dtype=paddle.int64) + vision_inputs["token_type_ids"] = paddle.to_tensor(token_type_ids, dtype=paddle.int64) + vision_inputs["image_type_ids"] = paddle.to_tensor(image_type_ids, dtype=paddle.int64) + vision_inputs["images"] = paddle.to_tensor( + images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16" + ) + vision_inputs["grid_thw"] = paddle.to_tensor(grid_thw, dtype=paddle.int64) + vision_inputs["mm_hashes"] = mm_hashes + + image_features = self.extract_vision_features(vision_inputs) + + # part of the first image may be already cached + if "ernie" in self.model_config.model_type: + actual_image_token_num = paddle.sum(vision_inputs["input_ids"] == self.model_config.im_patch_id) + elif "qwen" in self.model_config.model_type: + actual_image_token_num = paddle.sum( + vision_inputs["input_ids"] == vision_inputs["image_patch_id"] + ) + paddle.sum(vision_inputs["input_ids"] == vision_inputs["video_patch_id"]) + else: + raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported") + self.share_inputs["image_features"] = image_features[-actual_image_token_num:] + + position_ids = request.multimodal_inputs["position_ids"] + rope_3d_position_ids["position_ids_idx"].append(request.idx) + rope_3d_position_ids["position_ids_lst"].append(position_ids) + rope_3d_position_ids["position_ids_offset"].append( + position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1] + ) + + if self.is_pooling_model: + rope_3d_position_ids["max_tokens_lst"].append(0) + else: + rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048)) + def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None): """ Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1 req_dict: A list of Request dict num_running_requests: batch_size """ - # Lazy initialize kv cache + # NOTE(luotingdan): Lazy initialize kv cache if "caches" not in self.share_inputs: self.initialize_kv_cache() req_len = len(req_dicts) has_prefill_task = False has_decode_task = False + + batch_pooling_params = [] + self.share_inputs["image_features"] = None multi_vision_inputs = {"images_lst": [], "grid_thw_lst": [], "vit_position_ids_lst": [], "cu_seqlens": [0]} rope_3d_position_ids = { "position_ids_idx": [], @@ -293,75 +561,60 @@ class MetaxModelRunner(ModelRunnerBase): for i in range(req_len): 
request = req_dicts[i] idx = request.idx + + if hasattr(request, "pooling_params") and request.pooling_params is not None: + batch_pooling_params.append(request.pooling_params) + + logits_info = None + prefill_tokens = [] if request.task_type.value == RequestType.PREFILL.value: # prefill task + # guided decoding + if ( + request.guided_json is not None + or request.guided_regex is not None + or request.structural_tag is not None + or request.guided_grammar is not None + ): + logits_info, schemata_key = self._init_logits_processor(request) + request.schemata_key = schemata_key + + if self.scheduler_config.splitwise_role == "decode": + if ( + hasattr(request, "prefill_end_index") + and hasattr(request, "prompt_token_ids") + and request.prefill_end_index > len(request.prompt_token_ids) + ): + if hasattr(request, "output_token_ids"): + prefill_tokens.extend(request.output_token_ids) + prefill_start_index = request.prefill_start_index prefill_end_index = request.prefill_end_index length = prefill_end_index - prefill_start_index if self.enable_mm: - inputs = request.multimodal_inputs - if request.with_image: - if envs.FD_ENABLE_MAX_PREFILL: - multi_vision_inputs["images_lst"].append( - inputs["images"][request.image_start : request.image_end].cuda() - ) - multi_vision_inputs["grid_thw_lst"].extend( - inputs["grid_thw"][request.num_image_start : request.num_image_end] - ) - multi_vision_inputs["cu_seqlens"].extend( - inputs["vit_seqlen"][request.num_image_start : request.num_image_end] - ) - multi_vision_inputs["vit_position_ids_lst"].extend( - inputs["vit_position_ids"][request.num_image_start : request.num_image_end] - ) - else: - vision_inputs = {} - vision_inputs["input_ids"] = paddle.to_tensor( - inputs["input_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64 - ) - vision_inputs["token_type_ids"] = paddle.to_tensor( - inputs["token_type_ids"][prefill_start_index:prefill_end_index], dtype=paddle.int64 - ) - vision_inputs["image_type_ids"] = paddle.to_tensor( - inputs["image_type_ids"][request.image_type_ids_start : request.image_type_ids_end], - dtype=paddle.int64, - ) - vision_inputs["images"] = paddle.to_tensor( - inputs["images"][request.image_start : request.image_end], - dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16", - ) - vision_inputs["grid_thw"] = paddle.to_tensor( - inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64" - ) - self.share_inputs["image_features"] = self.extract_vision_features(vision_inputs) + self._apply_mm_inputs(request, multi_vision_inputs, rope_3d_position_ids) + + if not self.is_pooling_model: + if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None: + # Enable thinking + self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get("reasoning_max_tokens") + self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 else: - self.share_inputs["image_features"] = None - - position_ids = request.multimodal_inputs["position_ids"] - rope_3d_position_ids["position_ids_idx"].append(idx) - rope_3d_position_ids["position_ids_lst"].append(position_ids) - rope_3d_position_ids["position_ids_offset"].append( - position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1] - ) - rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048)) - - if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None: - # Enable thinking - self.share_inputs["max_think_lens"][idx : idx + 1, :] = 
request.get("reasoning_max_tokens") - self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 - else: - # Disable thinking - self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1 - self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 + # Disable thinking + self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1 + self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 if isinstance(request.prompt_token_ids, np.ndarray): prompt_token_ids = request.prompt_token_ids.tolist() else: prompt_token_ids = request.prompt_token_ids input_ids = prompt_token_ids + request.output_token_ids + prompt_len = len(prompt_token_ids) + self.share_inputs["prompt_ids"][idx : idx + 1, :prompt_len] = np.array(prompt_token_ids, dtype="int64") logger.debug( f"Handle prefill request {request} at idx {idx}, " f"{prefill_start_index=}, {prefill_end_index=}, " f"need_prefilled_token_num={len(input_ids)}" + f"prompt_len={prompt_len}" ) self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array( input_ids[prefill_start_index:prefill_end_index] @@ -379,11 +632,25 @@ class MetaxModelRunner(ModelRunnerBase): self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0 self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids) self.share_inputs["is_block_step"][idx : idx + 1] = False + self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids) self.share_inputs["step_idx"][idx : idx + 1] = ( len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 ) self.share_inputs["pre_ids"][idx : idx + 1] = -1 + # pooling model request.sampling_params is None + if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None: + self.prompt_logprobs_reqs[request.request_id] = request has_prefill_task = True + + # Routing Replay + if self.fd_config.routing_replay_config.enable_routing_replay: + if prefill_start_index == 0: + self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id) + + if ( + self.fd_config.scheduler_config.splitwise_role == "decode" + ): # In PD, we continue to decode after P generate first token + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 elif request.task_type.value == RequestType.DECODE.value: # decode task logger.debug(f"Handle decode request {request} at idx {idx}") encoder_block_num = len(request.block_tables) @@ -403,6 +670,8 @@ class MetaxModelRunner(ModelRunnerBase): self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 self.share_inputs["is_block_step"][idx : idx + 1] = False + self.prompt_logprobs_reqs.pop(request.request_id, None) + self.in_progress_prompt_logprobs.pop(request.request_id, None) continue assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens @@ -428,7 +697,6 @@ class MetaxModelRunner(ModelRunnerBase): ) self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1] - self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = length if request.get("seed") is not None: self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") @@ -456,6 +724,12 @@ class MetaxModelRunner(ModelRunnerBase): else: self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0 + self.pooling_params = batch_pooling_params + # For logits processors + self.share_inputs["logits_processors_args"][idx] = request.get("logits_processors_args") or {} + + self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens) + if 
len(multi_vision_inputs["images_lst"]) > 0: self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs) @@ -496,6 +770,7 @@ class MetaxModelRunner(ModelRunnerBase): length = len(request.prompt_token_ids) assert length > 0, "The prompt requested must not be empty." + logits_info = None prefill_tokens = [] if ( request.guided_json is not None @@ -504,7 +779,6 @@ class MetaxModelRunner(ModelRunnerBase): or request.guided_grammar is not None ): logits_info, schemata_key = self._init_logits_processor(request) - request.logits_processor = logits_info request.schemata_key = schemata_key # Is Decode Node @@ -591,14 +865,15 @@ class MetaxModelRunner(ModelRunnerBase): )[0] self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 - if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None: - # Enable thinking - self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get("reasoning_max_tokens") - self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 - else: - # Disable thinking - self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1 - self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 + if not self.is_pooling_model: + if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None: + # Enable thinking + self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get("reasoning_max_tokens") + self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 + else: + # Disable thinking + self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1 + self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 def get_attr_from_request(request, attr, default_value=None): res = request.get(attr, default_value) @@ -639,7 +914,6 @@ class MetaxModelRunner(ModelRunnerBase): self.share_inputs["stop_flags"][idx : idx + 1] = False self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1] - self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = length if request.get("seed") is not None: self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") @@ -673,7 +947,7 @@ class MetaxModelRunner(ModelRunnerBase): else: self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0 - self.sampler.apply_logits_processor(idx, request.get("logits_processor"), prefill_tokens) + self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens) self.share_inputs["not_need_stop"][0] = True @@ -722,10 +996,14 @@ class MetaxModelRunner(ModelRunnerBase): """ # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token max_dec_len = expected_decode_len + 1 - input_length = min( - num_tokens // (1 if capture_prefill else batch_size), - self.model_config.max_model_len - max_dec_len, - ) + if batch_size == 0: + # Note(ZKK): divided by 0 is invalid, here we give a input_length = 1 + input_length = 1 + else: + input_length = min( + num_tokens // (1 if capture_prefill else batch_size), + self.model_config.max_model_len - max_dec_len, + ) # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan. # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP. 
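For reference, the guarded sizing logic added in get_input_length_list above can be read in isolation as the following sketch; the function name and the example numbers are illustrative only, while the formula and the batch_size == 0 fallback mirror the diff.

def dummy_input_length(num_tokens, batch_size, max_model_len, expected_decode_len, capture_prefill=False):
    # The maximum decoding length is the expected decoded tokens plus the eos token.
    max_dec_len = expected_decode_len + 1
    if batch_size == 0:
        # Dividing by zero is invalid, so fall back to an input length of 1, as the diff does.
        return 1
    return min(num_tokens // (1 if capture_prefill else batch_size), max_model_len - max_dec_len)

# e.g. dummy_input_length(8192, 16, 4096, 63) == min(512, 4032) == 512
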
@@ -760,7 +1038,7 @@ class MetaxModelRunner(ModelRunnerBase): if self.cache_config.enable_chunked_prefill and "encode" in supported_tasks: supported_tasks.remove("encode") - logger.warning( + logger.debug( "Chunked prefill is not supported with " "encode task which using ALL pooling. " "Please turn off chunked prefill by export=FD_DISABLE_CHUNKED_PREFILL=1 before using it." @@ -791,9 +1069,7 @@ class MetaxModelRunner(ModelRunnerBase): self.share_inputs["min_dec_len"][idx : idx + 1] = max_dec_len self.share_inputs["stop_flags"][idx : idx + 1] = False self.share_inputs["temperature"][idx : idx + 1] = 1 - self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1] - self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = input_length self.share_inputs["encoder_block_lens"][idx : idx + 1] = block_num self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange( @@ -867,6 +1143,7 @@ class MetaxModelRunner(ModelRunnerBase): self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64") self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64") self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool") + self.share_inputs["is_chunk_step"] = paddle.full([max_num_seqs], False, dtype="bool").cpu() self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32") self.share_inputs["step_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32") self.share_inputs["step_lens"] = paddle.full([1], 0, dtype="int32") @@ -877,7 +1154,6 @@ class MetaxModelRunner(ModelRunnerBase): self.share_inputs["used_list_len"] = paddle.full([max_num_seqs], 0, dtype="int32") self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], 0, dtype="int64").cpu() self.share_inputs["first_token_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int64") - self.share_inputs["ori_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["system_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int32") @@ -1000,7 +1276,12 @@ class MetaxModelRunner(ModelRunnerBase): if self.enable_mm: head_dim = self.model_config.head_dim - rope_head_dim = head_dim // 2 + if ( + "qwen" in self.model_config.model_type or "paddleocr" in self.model_config.model_type + ): # neox style = True + rope_head_dim = head_dim + else: # neox style = False + rope_head_dim = head_dim // 2 self.share_inputs["rope_emb"] = paddle.full( shape=[ @@ -1016,7 +1297,14 @@ class MetaxModelRunner(ModelRunnerBase): ) self.share_inputs["image_features"] = None - def _prepare_inputs(self) -> None: + # For logits processors + self.share_inputs["logits_processors"] = build_logits_processors(self.fd_config) + self.share_inputs["logits_processors_args"] = [{} for _ in range(max_num_seqs)] + logger.info(f"Enabled logits processors: {self.share_inputs['logits_processors']}") + + self.share_inputs["mask_rollback"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") + + def _prepare_inputs(self, is_dummy_or_profile_run=False) -> None: """Prepare the model inputs""" if envs.ENABLE_V1_KVCACHE_SCHEDULER: recover_decode_task( @@ -1054,6 +1342,7 @@ class MetaxModelRunner(ModelRunnerBase): self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) # NOTE: (changwenbin) Initialized to max_num_seq '-1' before copying, marking illegal positions self.share_inputs["batch_id_per_token"][:] = -1 + 
self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False) self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) @@ -1063,11 +1352,10 @@ class MetaxModelRunner(ModelRunnerBase): self.share_inputs["output_padding_offset"].copy_(output_padding_offset, False) # Update bad tokens len - max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"]) + max_bad_tokens_len = np.max(self.share_inputs["bad_tokens_len"].numpy()) # Initialize forward meta data - self.initialize_forward_meta() - self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False) + self.initialize_forward_meta(is_dummy_or_profile_run=is_dummy_or_profile_run) # Get sampling metadata self.sampling_metadata = SamplingMetadata( @@ -1088,11 +1376,12 @@ class MetaxModelRunner(ModelRunnerBase): min_dec_lens=self.share_inputs["min_dec_len"], bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len], eos_token_ids=self.share_inputs["eos_token_id"], - max_num_logprobs=20 if self.enable_logprob else None, + max_num_logprobs=self.max_logprobs if self.enable_logprob else None, enable_early_stop=self.enable_early_stop, stop_flags=self.share_inputs["stop_flags"], temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"], top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"], + logits_processors=self.share_inputs["logits_processors"], share_inputs=self.share_inputs, ) @@ -1102,6 +1391,7 @@ class MetaxModelRunner(ModelRunnerBase): # 1. Load original model model_loader = get_model_loader(load_config=self.fd_config.load_config) self.model = model_loader.load_model(fd_config=self.fd_config) + # 1.1 Load RL dynamic model if self.fd_config.load_config.dynamic_load_weight: from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager @@ -1119,11 +1409,14 @@ class MetaxModelRunner(ModelRunnerBase): """Get current model""" return self.model - def initialize_forward_meta(self): + def initialize_forward_meta(self, is_dummy_or_profile_run=False): """ - Initialize forward meta and attention meta data + Initialize forward meta, attention meta data and update some config. """ # Initialize forward meta + routing_replay_table = None + if self.routing_replay_manager is not None: + routing_replay_table = self.routing_replay_manager.get_routing_table() self.forward_meta = ForwardMeta( ids_remove_padding=self.share_inputs["ids_remove_padding"], rotary_embs=self.share_inputs["rope_emb"], @@ -1150,16 +1443,23 @@ class MetaxModelRunner(ModelRunnerBase): kv_batch_ids=self.share_inputs["kv_batch_ids"], kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], + routing_replay_table=routing_replay_table, ) - # Update Batch type for cuda graph for only_decode_batch - if_only_decode = self.only_decode() + dist_status = self.collect_distributed_status() + + if_only_decode = dist_status.if_only_decode + if self.fd_config.parallel_config.enable_chunked_moe: + self.forward_meta.max_moe_num_chunk = dist_status.max_moe_num_chunk + only_decode_use_cudagraph = self.use_cudagraph and if_only_decode # Update config about moe for better performance # TODO(wanglongzhi):Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. 
It will be used in MoEMethodBase.apply() if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill" + if self.speculative_decoding: + self.proposer.fd_config.parallel_config.moe_phase.phase = "decode" if if_only_decode else "prefill" # Update Batch type for cuda graph for only_prefill_batch only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill() @@ -1169,6 +1469,9 @@ class MetaxModelRunner(ModelRunnerBase): only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph ) + # Set forward_meta.is_dummy_or_profile_run to True to skip init_kv_signal_per_query for attention backends + self.forward_meta.is_dummy_or_profile_run = is_dummy_or_profile_run + # Initialzie attention meta data for attn_backend in self.attn_backends: attn_backend.init_attention_metadata(self.forward_meta) @@ -1182,7 +1485,6 @@ class MetaxModelRunner(ModelRunnerBase): # Get kv cache dtype cache_type = self.model_config.dtype - kv_cache_quant_type = None if ( self.quant_config @@ -1223,12 +1525,20 @@ class MetaxModelRunner(ModelRunnerBase): logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}") cache_kvs_list = [] + # NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention, + # To rationalize the allocation of kvcache. + from fastdeploy import envs + + self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN" for i in range(self.model_config.num_hidden_layers): + # init key cache key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}" + key_cache_scales_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device}" if value_cache_shape: val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}" + value_cache_scales_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device}" if create_cache_tensor: - logger.info(f"..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}") + logger.info(f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}") key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type) set_data_ipc(key_cache, key_cache_name) if value_cache_shape: @@ -1241,7 +1551,7 @@ class MetaxModelRunner(ModelRunnerBase): key_cache_scales = paddle.full( shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() ) - if not self.mla_cache: + if value_cache_shape: val_cache_scales = paddle.full( shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() ) @@ -1249,15 +1559,28 @@ class MetaxModelRunner(ModelRunnerBase): else: cache_kvs_list.extend([key_cache_scales]) else: - logger.info(f"..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}") + logger.info(f"..attaching kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}") key_cache = paddle.empty(shape=[], dtype=cache_type) key_cache = share_external_data(key_cache, key_cache_name, key_cache_shape) + if kv_cache_quant_type == "block_wise_fp8": + key_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype()) + key_cache_scales = share_external_data( + key_cache_scales, key_cache_scales_name, kv_cache_scale_shape + ) if value_cache_shape: val_cache = paddle.empty(shape=[], dtype=cache_type) val_cache = share_external_data(val_cache, val_cache_name, value_cache_shape) cache_kvs_list.extend([key_cache, val_cache]) + if 
kv_cache_quant_type == "block_wise_fp8": + val_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype()) + val_cache_scales = share_external_data( + val_cache_scales, value_cache_scales_name, kv_cache_scale_shape + ) + cache_kvs_list.extend([key_cache_scales, val_cache_scales]) else: cache_kvs_list.extend([key_cache]) + if kv_cache_quant_type == "block_wise_fp8": + cache_kvs_list.extend([key_cache_scales]) self.share_inputs["caches"] = cache_kvs_list @@ -1280,41 +1603,20 @@ class MetaxModelRunner(ModelRunnerBase): ) head_dim = self.model_config.head_dim - # Initialize AttentionBackend buffers encoder_block_shape_q = 64 decoder_block_shape_q = 16 - decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1 - group_size = np.ceil(num_heads / self.model_config.kv_num_heads) - # NOTE: (changwenbin) When using auto_chunk, - # decode_max_tile_size must take into account the maximum case, where *1024 can cover 128K. - decode_max_tile_size = ( - 1024 - * self.scheduler_config.max_num_seqs - * np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q) + res_buffer = allocate_launch_related_buffer( + max_batch_size=self.scheduler_config.max_num_seqs, + max_model_len=self.model_config.max_model_len, + encoder_block_shape_q=encoder_block_shape_q, + decoder_block_shape_q=decoder_block_shape_q, + decoder_step_token_num=self.speculative_config.num_speculative_tokens + 1, + num_heads=num_heads, + kv_num_heads=self.model_config.kv_num_heads, + block_size=self.fd_config.cache_config.block_size, ) - encode_max_tile_size = self.scheduler_config.max_num_seqs * np.ceil( - (self.model_config.max_model_len * group_size) / encoder_block_shape_q - ) - kv_max_tile_size = self.scheduler_config.max_num_seqs * np.ceil( - self.model_config.max_model_len / self.fd_config.cache_config.block_size - ) - self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") - self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") - self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu() - # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor, - # adapted to cudagraph. 
- self.share_inputs["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32") - self.share_inputs["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32") - self.share_inputs["max_len_tensor_cpu"] = paddle.full([9], 0, dtype="int32").cpu() - - self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") - self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") - self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() - - self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") - self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") - self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + self.share_inputs.update(res_buffer) # Get the attention backend attn_cls = get_attention_backend() @@ -1344,12 +1646,8 @@ class MetaxModelRunner(ModelRunnerBase): assert len(num_scheduled_tokens_list) == num_reqs req_num_tokens = num_tokens // num_reqs - - dummy_prompt_lens = paddle.to_tensor(num_scheduled_tokens_list, dtype="int64") - dummy_token_ids = paddle.zeros( - [num_reqs, req_num_tokens], - dtype="int64", - ) + dummy_prompt_lens = paddle.to_tensor(num_scheduled_tokens_list, dtype="int64", place=paddle.CPUPlace()) + dummy_token_ids = paddle.zeros([num_reqs, req_num_tokens], dtype="int64", device=hidden_states.place) model = cast(FdModelForPooling, self.get_model()) dummy_pooling_params = PoolingParams(task=task) to_update = model.pooler.get_pooling_updates(task) @@ -1378,22 +1676,66 @@ class MetaxModelRunner(ModelRunnerBase): def _dummy_pooler_run( self, hidden_states: paddle.Tensor, + model_output: paddle.Tensor, ) -> PoolerOutput: output_size = dict[PoolingTask, float]() for task in self.get_supported_pooling_tasks(): + output = self._dummy_pooler_run_task(hidden_states, task) - output_size[task] = output.get_data_nbytes() + output_size[task] = sum(o.numel() * o.element_size() if hasattr(o, "numel") else 0 for o in output) del output max_task = max(output_size.items(), key=lambda x: x[1])[0] - final_output = self._dummy_pooler_run_task(hidden_states, max_task) + pooler_output = self._dummy_pooler_run_task(hidden_states, max_task) - return final_output + model_output_data = ModelOutputData( + next_tokens=self.share_inputs["next_tokens"], + stop_flags=self.share_inputs["stop_flags"], + step_idx=self.share_inputs["step_idx"], + max_dec_len=self.share_inputs["max_dec_len"], + pre_ids=self.share_inputs["pre_ids"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + eos_token_id=self.share_inputs["eos_token_id"], + not_need_stop=self.share_inputs["not_need_stop"], + input_ids=self.share_inputs["input_ids"], + stop_nums=self.share_inputs["stop_nums"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + is_block_step=self.share_inputs["is_block_step"], + full_hidden_states=model_output, + msg_queue_id=self.parallel_config.msg_queue_id, + mp_rank=self.parallel_config.tensor_parallel_rank, + use_ep=self.parallel_config.use_ep, + draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None), + actual_draft_token_num=( + self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None + ), + accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), + accept_num=(self.share_inputs["accept_num"] if 
self.speculative_decoding else None), + stop_token_ids=self.share_inputs["stop_seqs"], + stop_seqs_len=self.share_inputs["stop_seqs_len"], + prompt_lens=self.share_inputs["prompt_lens"], + ) + + post_process( + sampler_or_pooler_output=pooler_output, + model_output=model_output_data, + share_inputs=self.share_inputs, + block_size=self.cache_config.block_size, + speculative_decoding=self.speculative_decoding, + skip_save_output=True, + async_output_queue=self.async_output_queue, + think_end_id=self.model_config.think_end_id, + line_break_id=self.model_config.line_break_id, + ) + return pooler_output def _dummy_sampler_run( self, hidden_states: paddle.Tensor, model_output: paddle.Tensor, + accept_all_drafts=False, + reject_all_drafts=False, ) -> paddle.Tensor: logits = self.model.compute_logits(hidden_states) @@ -1420,6 +1762,8 @@ class MetaxModelRunner(ModelRunnerBase): self.sampling_metadata, self.model_config.max_model_len, self.share_inputs, + accept_all_drafts, + reject_all_drafts, ) sampler_output = None if self.parallel_config.tensor_parallel_size > 1: @@ -1470,6 +1814,8 @@ class MetaxModelRunner(ModelRunnerBase): accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], + prompt_lens=self.share_inputs["prompt_lens"], + mask_rollback=self.share_inputs["mask_rollback"], ) post_process( @@ -1486,7 +1832,9 @@ class MetaxModelRunner(ModelRunnerBase): if self.speculative_decoding: if self.speculative_method == "mtp": self.proposer.run( - full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph + full_hidden_states=model_output, + step_use_cudagraph=self.forward_meta.step_use_cudagraph, + is_dummy_run=True, ) else: self.proposer.run(share_inputs=self.share_inputs) @@ -1495,12 +1843,13 @@ class MetaxModelRunner(ModelRunnerBase): def _dummy_run( self, - num_tokens: paddle.Tensor, - batch_size: paddle.Tensor, + num_tokens: int, + batch_size: int, expected_decode_len: int = 1, in_capturing: bool = False, capture_prefill: bool = False, accept_all_drafts: bool = False, + reject_all_drafts: bool = False, ) -> paddle.Tensor: """ Use dummy inputs to run before formal execution. @@ -1510,8 +1859,8 @@ class MetaxModelRunner(ModelRunnerBase): in_capturing: Is cuda graph in capturing state capture_prefill: Capture pure prefill for cuda graph accept_all_drafts: Target model will accept all draft tokens + reject_all_drafts: Target model will reject all draft tokens """ - input_length_list, max_dec_len_list, block_num = self.get_input_length_list( num_tokens=num_tokens, batch_size=batch_size, @@ -1532,7 +1881,7 @@ class MetaxModelRunner(ModelRunnerBase): while True: # 1. Initialize forward meta and attention meta data - self._prepare_inputs() + self._prepare_inputs(is_dummy_or_profile_run=True) # 2. Padding inputs for cuda graph self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph @@ -1541,35 +1890,34 @@ class MetaxModelRunner(ModelRunnerBase): # 3. 
Run model if self.enable_mm: model_output = self.model( - self.share_inputs["ids_remove_padding"], + self.forward_meta.ids_remove_padding, self.share_inputs["image_features"], self.forward_meta, ) else: model_output = self.model( - ids_remove_padding=self.share_inputs["ids_remove_padding"], - forward_meta=self.forward_meta, + self.forward_meta.ids_remove_padding, + self.forward_meta, ) if self.use_cudagraph: model_output = model_output[: self.real_token_num] - hidden_states = rebuild_padding( - model_output, - self.share_inputs["cu_seqlens_q"], - self.share_inputs["seq_lens_this_time"], - self.share_inputs["seq_lens_decoder"], - self.share_inputs["seq_lens_encoder"], - ( - self.share_inputs["output_padding_offset"] if self.speculative_decoding else None - ), # speculative decoding requires - self.model_config.max_model_len, - ) - if self.is_pooling_model: - self._dummy_pooler_run(hidden_states) + self._dummy_pooler_run(model_output, model_output) break else: - self._dummy_sampler_run(hidden_states, model_output) + hidden_states = rebuild_padding( + model_output, + self.share_inputs["cu_seqlens_q"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["seq_lens_encoder"], + ( + self.share_inputs["output_padding_offset"] if self.speculative_decoding else None + ), # speculative decoding requires + self.model_config.max_model_len, + ) + self._dummy_sampler_run(hidden_states, model_output, accept_all_drafts, reject_all_drafts) # 7. Updata 'infer_seed' and step_cuda() self.share_inputs["infer_seed"].add_(self.infer_seed_increment) @@ -1584,6 +1932,9 @@ class MetaxModelRunner(ModelRunnerBase): if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0: break + if self.fd_config.routing_replay_config.enable_routing_replay: + self.routing_replay_manager.clear_routing_table() + def _update_chunked_prefill(self, tasks): """ Update chunked prefill related parameters @@ -1680,7 +2031,12 @@ class MetaxModelRunner(ModelRunnerBase): else: assert batch_size % 2 == 0 self._dummy_run( - num_tokens=self.scheduler_config.max_num_batched_tokens, + num_tokens=( + self.scheduler_config.max_num_seqs + * (self.speculative_config.num_speculative_tokens + 1) + if self.scheduler_config.splitwise_role == "decode" + else self.scheduler_config.max_num_batched_tokens + ), batch_size=int(batch_size / 2), in_capturing=True, expected_decode_len=1, @@ -1688,38 +2044,53 @@ class MetaxModelRunner(ModelRunnerBase): logger.info( f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}" ) - # Capture Draft Model without bsz 1 - # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph - for batch_size in sorted(capture_sizes, reverse=True): - if batch_size == 1: - logger.info("Skip token_num = 1, when capture Draft model for mtp") - else: - assert batch_size % 2 == 0 + if self.graph_opt_config.draft_model_use_cudagraph: + # Capture Draft Model without bsz 1 + # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph + for batch_size in sorted(capture_sizes, reverse=True): + if batch_size == 1: + logger.info("Skip token_num = 1, when capture Draft model for mtp") + else: + assert batch_size % 2 == 0 + self._dummy_run( + num_tokens=( + self.scheduler_config.max_num_seqs + if self.scheduler_config.splitwise_role == "decode" + else self.scheduler_config.max_num_batched_tokens + ), + batch_size=int(batch_size / 2), + in_capturing=True, + expected_decode_len=3, + accept_all_drafts=True, + ) + logger.info( 
+ f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}" + ) + # Capture Draft Model with bsz 1 + if 1 in capture_sizes: self._dummy_run( - num_tokens=self.scheduler_config.max_num_batched_tokens, - batch_size=int(batch_size / 2), + num_tokens=( + self.scheduler_config.max_num_seqs + if self.scheduler_config.splitwise_role == "decode" + else self.scheduler_config.max_num_batched_tokens + ), + batch_size=int(1), in_capturing=True, expected_decode_len=3, - accept_all_drafts=True, + accept_all_drafts=False, + reject_all_drafts=True, ) logger.info( f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}" ) - # Capture Draft Model with bsz 1 - if 1 in capture_sizes: - self._dummy_run( - num_tokens=self.scheduler_config.max_num_batched_tokens, - batch_size=int(1), - in_capturing=True, - expected_decode_len=3, - accept_all_drafts=False, - ) - logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}") - else: for batch_size in sorted(capture_sizes, reverse=True): self._dummy_run( - num_tokens=self.scheduler_config.max_num_batched_tokens, + num_tokens=( + self.scheduler_config.max_num_seqs + if self.scheduler_config.splitwise_role == "decode" + else self.scheduler_config.max_num_batched_tokens + ), batch_size=batch_size, in_capturing=True, expected_decode_len=expected_decode_len, @@ -1752,7 +2123,11 @@ class MetaxModelRunner(ModelRunnerBase): start_time = time.perf_counter() for batch_size in self.sot_warmup_sizes: self._dummy_run( - num_tokens=self.scheduler_config.max_num_batched_tokens, + num_tokens=( + self.scheduler_config.max_num_seqs + if self.scheduler_config.splitwise_role == "decode" + else self.scheduler_config.max_num_batched_tokens + ), batch_size=batch_size, ) logger.info(f"SOT warmup the model with the batch size:{batch_size}") @@ -1771,6 +2146,21 @@ class MetaxModelRunner(ModelRunnerBase): if self.share_inputs["step_idx"][idx] == 0: prefill_done_idxs.append(idx) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + if model_forward_batch is None: + return prefill_done_idxs + + for task in model_forward_batch: + if task.task_type.value != RequestType.PREFILL.value: + continue + # in chunk prefill + if self.cache_config.enable_chunked_prefill: + if hasattr(task, "prefill_end_index") and hasattr(task, "prompt_token_ids"): + if len(task.prompt_token_ids) > task.prefill_end_index and task.idx in prefill_done_idxs: + prefill_done_idxs.remove(task.idx) + + return prefill_done_idxs + if self.cache_config.enable_chunked_prefill: if model_forward_batch is not None: for task in model_forward_batch: @@ -1793,7 +2183,7 @@ class MetaxModelRunner(ModelRunnerBase): self, model_forward_batch: Optional[List[Request]] = None, num_running_requests: int = None, - ) -> Optional[ModelRunnerOutput]: + ) -> None: """ The Entrance of model execute. Args: @@ -1808,12 +2198,9 @@ class MetaxModelRunner(ModelRunnerBase): self._prepare_inputs() self.sampler.pre_process(p_done_idxs) - # NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state. - # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode, - # when there is data on other runner, the current runner is required to execute part of the model. - if not self.not_need_stop(): - self._execute_empty_input(self.forward_meta) - return None + # 1.1 Update state of logits processor + for proc in self.sampling_metadata.logits_processors: + proc.update_state(self.share_inputs) # 2. 
Padding inputs for cuda graph self.padding_cudagraph_inputs() @@ -1821,182 +2208,284 @@ class MetaxModelRunner(ModelRunnerBase): # 3. Execute model if self.enable_mm: model_output = self.model( - self.share_inputs["ids_remove_padding"], + self.forward_meta.ids_remove_padding, self.share_inputs["image_features"], self.forward_meta, ) else: model_output = self.model( - ids_remove_padding=self.share_inputs["ids_remove_padding"], - forward_meta=self.forward_meta, + self.forward_meta.ids_remove_padding, + self.forward_meta, ) + + # NOTE(wufeisheng): If `not_need_stop`` is False, it means the current worker is in an idle state. + # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode, + # Then there is data on other runner, the current runner is required to execute part of the model. + # But not need to run the below code. + if not self.not_need_stop(): + return None + if self.use_cudagraph: model_output = model_output[: self.real_token_num] - hidden_states = rebuild_padding( - model_output, - self.share_inputs["cu_seqlens_q"], - self.share_inputs["seq_lens_this_time"], - self.share_inputs["seq_lens_decoder"], - self.share_inputs["seq_lens_encoder"], - (self.share_inputs["output_padding_offset"] if self.speculative_decoding else None), - self.model_config.max_model_len, - ) - logits = None - # 4. Compute logits, Sample - if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model: - # TODO(lizexu123) The execution of the pooling function have not been implemented yet. - pass + prompt_logprobs_list = self._get_prompt_logprobs_list(model_output) + + if self.is_pooling_model: + pooler_output = self._pool(model_output, num_running_requests) + + model_output_data = ModelOutputData( + next_tokens=self.share_inputs["next_tokens"], + stop_flags=self.share_inputs["stop_flags"], + step_idx=self.share_inputs["step_idx"], + max_dec_len=self.share_inputs["max_dec_len"], + pre_ids=self.share_inputs["pre_ids"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + eos_token_id=self.share_inputs["eos_token_id"], + not_need_stop=self.share_inputs["not_need_stop"], + input_ids=self.share_inputs["input_ids"], + stop_nums=self.share_inputs["stop_nums"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + is_block_step=self.share_inputs["is_block_step"], + full_hidden_states=model_output, + msg_queue_id=self.parallel_config.msg_queue_id, + mp_rank=self.parallel_config.tensor_parallel_rank, + use_ep=self.parallel_config.use_ep, + draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None), + actual_draft_token_num=( + self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None + ), + accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), + accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), + stop_token_ids=self.share_inputs["stop_seqs"], + stop_seqs_len=self.share_inputs["stop_seqs_len"], + prompt_lens=self.share_inputs["prompt_lens"], + ) + + post_process( + sampler_or_pooler_output=pooler_output, + model_output=model_output_data, + share_inputs=self.share_inputs, + block_size=self.cache_config.block_size, + save_each_rank=self.parallel_config.use_ep, + speculative_decoding=self.speculative_decoding, + skip_save_output=False, + async_output_queue=self.async_output_queue, + ) + + return None else: + hidden_states = rebuild_padding( + model_output, + 
self.share_inputs["cu_seqlens_q"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["seq_lens_encoder"], + (self.share_inputs["output_padding_offset"] if self.speculative_decoding else None), + self.model_config.max_model_len, + ) + + # 4. Compute logits, Sample logits = self.model.compute_logits(hidden_states) - if not self.speculative_decoding: - set_value_by_flags_and_idx( - self.share_inputs["pre_ids"], - self.share_inputs["input_ids"], - self.share_inputs["seq_lens_this_time"], - self.share_inputs["seq_lens_encoder"], - self.share_inputs["seq_lens_decoder"], - self.share_inputs["step_idx"], - self.share_inputs["stop_flags"], - ) - sampler_output = self.sampler( - logits, - self.sampling_metadata, - p_done_idxs, - ) - if self.parallel_config.tensor_parallel_size > 1: - paddle.distributed.broadcast( - sampler_output.sampled_token_ids, - self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, - group=self.parallel_config.tp_group, - ) - - else: - sampler_output = self.sampler( - logits, - self.sampling_metadata, - self.model_config.max_model_len, - self.share_inputs, - ) - if self.parallel_config.tensor_parallel_size > 1: - paddle.distributed.broadcast( - self.share_inputs["accept_tokens"], - self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, - group=self.parallel_config.tp_group, - ) - paddle.distributed.broadcast( - self.share_inputs["accept_num"], - self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, - group=self.parallel_config.tp_group, - ) - paddle.distributed.broadcast( + if not self.speculative_decoding: + set_value_by_flags_and_idx( + self.share_inputs["pre_ids"], + self.share_inputs["input_ids"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], self.share_inputs["step_idx"], - self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, - group=self.parallel_config.tp_group, - ) - paddle.distributed.broadcast( self.share_inputs["stop_flags"], - self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, - group=self.parallel_config.tp_group, ) - - # 5. 
Post Process - model_output_data = ModelOutputData( - next_tokens=self.share_inputs["next_tokens"], - stop_flags=self.share_inputs["stop_flags"], - step_idx=self.share_inputs["step_idx"], - max_dec_len=self.share_inputs["max_dec_len"], - pre_ids=self.share_inputs["pre_ids"], - seq_lens_this_time=self.share_inputs["seq_lens_this_time"], - eos_token_id=self.share_inputs["eos_token_id"], - not_need_stop=self.share_inputs["not_need_stop"], - input_ids=self.share_inputs["input_ids"], - stop_nums=self.share_inputs["stop_nums"], - seq_lens_encoder=self.share_inputs["seq_lens_encoder"], - seq_lens_decoder=self.share_inputs["seq_lens_decoder"], - is_block_step=self.share_inputs["is_block_step"], - full_hidden_states=model_output, - msg_queue_id=self.parallel_config.msg_queue_id, - mp_rank=self.parallel_config.tensor_parallel_rank, - use_ep=self.parallel_config.use_ep, - draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None), - actual_draft_token_num=( - self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None - ), - accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), - accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), - stop_token_ids=self.share_inputs["stop_seqs"], - stop_seqs_len=self.share_inputs["stop_seqs_len"], - prompt_lens=self.share_inputs["prompt_lens"], - ) - - if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill": - skip_save_output = True - else: - skip_save_output = False - post_process( - sampler_or_pooler_output=sampler_output, - model_output=model_output_data, - share_inputs=self.share_inputs, - block_size=self.cache_config.block_size, - save_each_rank=self.parallel_config.use_ep, - speculative_decoding=self.speculative_decoding, - skip_save_output=skip_save_output, - async_output_queue=self.async_output_queue, - think_end_id=self.model_config.think_end_id, - line_break_id=self.model_config.line_break_id, - ) - if self.guided_backend is not None and sampler_output is not None: - self.sampler.post_process(sampler_output.sampled_token_ids) - - # 6. Speculative decode - if self.speculative_decoding: - if self.speculative_method == "mtp": - self.proposer.run( - full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph + sampler_output = self.sampler( + logits, + self.sampling_metadata, + p_done_idxs, ) + if self.parallel_config.tensor_parallel_size > 1: + paddle.distributed.broadcast( + sampler_output.sampled_token_ids, + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) else: - self.proposer.run(share_inputs=self.share_inputs) - - # 7. 
Update 'infer_seed' and step_cuda() - self.share_inputs["infer_seed"].add_(self.infer_seed_increment) - self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED - if not envs.ENABLE_V1_KVCACHE_SCHEDULER: - step_cuda( - self.share_inputs, - self.cache_config.block_size, - self.cache_config.enc_dec_block_num, - self.speculative_config, - self.cache_config.enable_prefix_caching, + sampler_output = self.sampler( + logits, + self.sampling_metadata, + self.model_config.max_model_len, + self.share_inputs, + ) + if self.parallel_config.tensor_parallel_size > 1: + paddle.distributed.broadcast( + self.share_inputs["accept_tokens"], + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) + paddle.distributed.broadcast( + self.share_inputs["accept_num"], + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) + paddle.distributed.broadcast( + self.share_inputs["step_idx"], + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) + paddle.distributed.broadcast( + self.share_inputs["stop_flags"], + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) + # 5. Post Process + model_output_data = ModelOutputData( + next_tokens=self.share_inputs["next_tokens"], + stop_flags=self.share_inputs["stop_flags"], + step_idx=self.share_inputs["step_idx"], + max_dec_len=self.share_inputs["max_dec_len"], + pre_ids=self.share_inputs["pre_ids"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + eos_token_id=self.share_inputs["eos_token_id"], + not_need_stop=self.share_inputs["not_need_stop"], + input_ids=self.share_inputs["input_ids"], + stop_nums=self.share_inputs["stop_nums"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + is_block_step=self.share_inputs["is_block_step"], + full_hidden_states=model_output, + msg_queue_id=self.parallel_config.msg_queue_id, + mp_rank=self.parallel_config.tensor_parallel_rank, + use_ep=self.parallel_config.use_ep, + draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None), + actual_draft_token_num=( + self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None + ), + accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), + accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), + stop_token_ids=self.share_inputs["stop_seqs"], + stop_seqs_len=self.share_inputs["stop_seqs_len"], + prompt_lens=self.share_inputs["prompt_lens"], + mask_rollback=self.share_inputs["mask_rollback"], + prompt_logprobs_list=prompt_logprobs_list, ) - self._update_chunked_prefill(model_forward_batch) - elif self.speculative_decoding: - speculate_schedule_cache( - self.share_inputs["draft_tokens"], - self.share_inputs["block_tables"], - self.share_inputs["stop_flags"], - self.share_inputs["prompt_lens"], - self.share_inputs["seq_lens_this_time"], - self.share_inputs["seq_lens_encoder"], - self.share_inputs["seq_lens_decoder"], - self.share_inputs["step_seq_lens_decoder"], - self.share_inputs["step_draft_tokens"], - self.share_inputs["step_seq_lens_this_time"], - self.share_inputs["accept_num"], - self.share_inputs["accept_tokens"], - self.share_inputs["is_block_step"], - self.share_inputs["not_need_stop"], - 
self.share_inputs["stop_nums"], - self.cache_config.block_size, - self.speculative_config.num_speculative_tokens, - ) + if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill": + skip_save_output = True + else: + skip_save_output = False - self.seq_lens_this_time_buffer[:num_running_requests].copy_( - self.share_inputs["seq_lens_this_time"][:num_running_requests], False + post_process( + sampler_or_pooler_output=sampler_output, + model_output=model_output_data, + share_inputs=self.share_inputs, + block_size=self.cache_config.block_size, + save_each_rank=self.parallel_config.use_ep, + speculative_decoding=self.speculative_decoding, + skip_save_output=skip_save_output, + async_output_queue=self.async_output_queue, + think_end_id=self.model_config.think_end_id, + line_break_id=self.model_config.line_break_id, + ) + if self.guided_backend is not None and sampler_output is not None: + self.sampler.post_process(sampler_output.sampled_token_ids) + + # 6. Speculative decode + if self.speculative_decoding: + if self.speculative_method == "mtp": + self.proposer.run( + full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph + ) + else: + self.proposer.run(share_inputs=self.share_inputs) + + # 7. Update 'infer_seed' and step_cuda() + self.share_inputs["infer_seed"].add_(self.infer_seed_increment) + self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: + step_cuda( + self.share_inputs, + self.cache_config.block_size, + self.cache_config.enc_dec_block_num, + self.speculative_config, + self.cache_config.enable_prefix_caching, + ) + + self._update_chunked_prefill(model_forward_batch) + elif self.speculative_decoding: + speculate_schedule_cache( + self.share_inputs["draft_tokens"], + self.share_inputs["block_tables"], + self.share_inputs["stop_flags"], + self.share_inputs["prompt_lens"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["step_seq_lens_decoder"], + self.share_inputs["step_draft_tokens"], + self.share_inputs["step_seq_lens_this_time"], + self.share_inputs["accept_num"], + self.share_inputs["accept_tokens"], + self.share_inputs["is_block_step"], + self.share_inputs["not_need_stop"], + self.share_inputs["stop_nums"], + self.cache_config.block_size, + self.speculative_config.num_speculative_tokens, + ) + + # Routing replay + if self.fd_config.routing_replay_config.enable_routing_replay: + if ( + not self.exist_prefill() + and not self.exist_decode() + and self.share_inputs["is_block_step"].sum() == 0 + and self.share_inputs["is_chunk_step"].sum() == 0 + ): + self.routing_replay_manager.put_table_to_store() + return None + + def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]: + + num_scheduled_tokens = int(self.share_inputs["seq_lens_this_time"][:num_running_requests].sum()) + hidden_states = hidden_states[:num_scheduled_tokens] + + prompt_lens = self.share_inputs["prompt_lens"][:num_running_requests] + prompt_token_ids = self.share_inputs["prompt_ids"] + + pooling_metadata = PoolingMetadata( + prompt_lens=prompt_lens, + prompt_token_ids=prompt_token_ids, + pooling_params=self.pooling_params, ) - return None + num_scheduled_tokens_list = [ + int(self.share_inputs["seq_lens_this_time"][i]) for i in range(num_running_requests) + ] + device_str = "gpu" if hidden_states.place.is_gpu_place() else "cpu" + 
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_list, device=device_str) + + raw_pooler_output = self.model.pooler(hidden_states=hidden_states, pooling_metadata=pooling_metadata) + + seq_lens_cpu = self.share_inputs["seq_lens_this_time"][:num_running_requests] + pooler_output: list[Optional[paddle.Tensor]] = [] + + seq_lens_decoder_batch = self.share_inputs["seq_lens_decoder"][:num_running_requests] + + for i, (seq_len, prompt_len) in enumerate(zip(seq_lens_cpu, pooling_metadata.prompt_lens)): + if not self.cache_config.enable_prefix_caching: + output = raw_pooler_output[0].data if int(seq_len) == int(prompt_len) else None + pooler_output.append(output) + else: + current_seq_len_decoder = seq_lens_decoder_batch[i] + if int(current_seq_len_decoder) + int(seq_len) == int(prompt_len): + output = raw_pooler_output[0].data + else: + output = None + pooler_output.append(output) + + pooler_output = PoolerOutput( + outputs=pooler_output, + ) + + return pooler_output def _execute_empty_input(self, forward_meta) -> None: """ @@ -2024,7 +2513,11 @@ class MetaxModelRunner(ModelRunnerBase): # 2. Dummy run self._dummy_run( - num_tokens=self.scheduler_config.max_num_batched_tokens, + num_tokens=( + self.scheduler_config.max_num_seqs + if self.scheduler_config.splitwise_role == "decode" + else self.scheduler_config.max_num_batched_tokens + ), batch_size=self.scheduler_config.max_num_seqs, ) @@ -2066,6 +2559,7 @@ class MetaxModelRunner(ModelRunnerBase): def cal_theortical_kvcache(self): """ Calculate the total block memory required at the model level + TODO(gongshaotian): Move to Attention Backend """ """ Byte of dtype: @@ -2087,7 +2581,7 @@ class MetaxModelRunner(ModelRunnerBase): byte_of_dtype = 2 hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads - + # NOTE(liuzichang): Implement multi-layer MTP architecture in the future num_layers = ( self.model_config.num_hidden_layers + self.speculative_config.num_gpu_block_expand_ratio if self.speculative_method in ["mtp"] @@ -2106,6 +2600,9 @@ class MetaxModelRunner(ModelRunnerBase): ) # compress_kv + k_pe else: required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers # k + v + + logger.info(f"Block Memory: {required_memory / 1024} KB") + return required_memory def not_need_stop(self) -> bool: @@ -2134,6 +2631,9 @@ class MetaxModelRunner(ModelRunnerBase): def clear_requests(self): """Dynamic model loader use to clear requests use for RL""" self.share_inputs["stop_flags"][:] = True + # prompt_logprobs + self.prompt_logprobs_reqs.clear() + self.in_progress_prompt_logprobs.clear() def update_parameters(self, pid): """Dynamic model loader use to update parameters use for RL""" @@ -2220,6 +2720,7 @@ class MetaxModelRunner(ModelRunnerBase): def extract_vision_features_ernie(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: assert inputs["images"] is not None grid_thw = inputs["grid_thw"] + # ernie-vl has images norm images = inputs["images"].cast("float32") images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor images = images / self.image_preprocess.image_std_tensor @@ -2244,6 +2745,7 @@ class MetaxModelRunner(ModelRunnerBase): image_features = image_features.reshape([-1, C * self.model_config.spatial_conv_size**2]) image_features = ScatterOp.apply(image_features, axis=-1) # mp 切 Fea image_features = image_features.reshape([S, -1]) + # ernie-vl has resampler_model image_features = self.model.resampler_model( image_features, image_mask, @@ -2253,11 
+2755,77 @@ class MetaxModelRunner(ModelRunnerBase): ) return image_features + def extract_vision_features_qwen(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: + if envs.FD_ENABLE_MAX_PREFILL: + images = paddle.concat(inputs["images_lst"]).cast("bfloat16") + grid_thw = paddle.to_tensor(inputs["grid_thw_lst"], dtype="int64") + else: + assert inputs["images"] is not None + grid_thw = inputs["grid_thw"] + images = inputs["images"] + with paddle.amp.auto_cast( + True, + custom_black_list=self.amp_black, + custom_white_list=self.amp_white, + level="O2", + dtype=self.model_config.dtype, + ): + image_features = self.model.visual.extract_feature(images, grid_thw) + + return image_features + + def extract_vision_features_paddleocr(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: + if envs.FD_ENABLE_MAX_PREFILL: + inputs["vit_position_ids_lst"] = np.concatenate(inputs["vit_position_ids_lst"]) + images = paddle.concat(inputs["images_lst"]).cast("bfloat16") + grid_thw = paddle.to_tensor(inputs["grid_thw_lst"], dtype="int64") + position_ids = paddle.to_tensor(inputs["vit_position_ids_lst"], dtype="int64") + cu_seqlens = paddle.cumsum(paddle.to_tensor(inputs["cu_seqlens"])).cast("int32") + else: + assert inputs["images"] is not None + grid_thw = inputs["grid_thw"] + images = inputs["images"] + + position_ids = [] + cu_seqlens = [0] + for idx, thw in enumerate(grid_thw): + numel = np.prod(np.array(thw)) + position_ids.append(paddle.arange(numel) % np.prod(thw[1:])) + cu_seqlens.append(cu_seqlens[-1] + numel) + + position_ids = paddle.concat(position_ids, axis=0).to(images.place) + cu_seqlens = paddle.to_tensor(cu_seqlens, dtype=paddle.int32).to(images.place) + + with paddle.amp.auto_cast( + True, + custom_black_list=self.amp_black, + custom_white_list=self.amp_white, + level="O2", + dtype=self.model_config.dtype, + ): + image_features = self.model.visual( + pixel_values=images, + image_grid_thw=grid_thw, + position_ids=position_ids, + interpolate_pos_encoding=True, + cu_seqlens=cu_seqlens, + use_rope=True, + window_size=-1, + ) + image_features = self.model.projector(image_features, grid_thw) + image_features = paddle.concat(image_features, axis=0) + + return image_features + @paddle.no_grad() def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: """extract_vision_features""" if "ernie" in self.model_config.model_type: return self.extract_vision_features_ernie(inputs) + elif "qwen" in self.model_config.model_type: + return self.extract_vision_features_qwen(inputs) + elif "paddleocr" in self.model_config.model_type: + return self.extract_vision_features_paddleocr(inputs) else: raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported") @@ -2279,3 +2847,69 @@ class MetaxModelRunner(ModelRunnerBase): cumsum_seqlens=cumsum_seqlens, ) return rope_emb_lst + + def _get_prompt_logprobs_list( + self, + hidden_states: paddle.Tensor, + ) -> list[Optional[LogprobsTensors]]: + if len(self.prompt_logprobs_reqs) > 0: + assert ( + not self.fd_config.cache_config.enable_prefix_caching + ), "prompt_logprobs must disable prefix caching, --no-enable-prefix-caching." 
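The `_get_prompt_logprobs_list` hunk that continues below accumulates prompt logprobs chunk by chunk under chunked prefill. A minimal sketch of the per-chunk accounting it relies on; `chunk_logit_count` is a hypothetical helper used only for illustration, not part of this patch, and the variable meanings follow the hunk below.

def chunk_logit_count(prefill_start_index: int, prefill_end_index: int, num_prompt_tokens: int) -> int:
    # Prompt tokens prefilled in this step.
    num_tokens = prefill_end_index - prefill_start_index
    # The logprob of prompt token i comes from the hidden state at position i - 1,
    # so the first prompt token never receives one.
    start_tok = prefill_start_index + 1
    num_remaining_tokens = num_prompt_tokens - start_tok
    # The last chunk may have more tokens than remaining prompt positions; it may even
    # return <= 0 when a prior step already covered num_prompt_tokens - 1 positions,
    # which the hunk below handles by skipping the request.
    return min(num_tokens, num_remaining_tokens)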
+ logprobs_mode = self.fd_config.model_config.logprobs_mode + prompt_logprobs_list: list[Optional[LogprobsTensors]] = self.scheduler_config.max_num_seqs * [None] + completed_prefill_reqs: list[Request] = [] + for req_id, request in self.prompt_logprobs_reqs.items(): + num_prompt_logprobs = request.sampling_params.prompt_logprobs + if request.prompt_token_ids is None or num_prompt_logprobs is None: + continue + if num_prompt_logprobs == -1: + num_prompt_logprobs = self.ori_vocab_size + + num_tokens = request.prefill_end_index - request.prefill_start_index + num_prompt_tokens = len(request.prompt_token_ids) + + logprobs_tensors = self.in_progress_prompt_logprobs.get(req_id) + if not logprobs_tensors: + logprobs_tensors = LogprobsTensors.empty_cpu(num_prompt_tokens - 1, num_prompt_logprobs + 1) + self.in_progress_prompt_logprobs[req_id] = logprobs_tensors + start_idx = request.prefill_start_index + start_tok = start_idx + 1 + num_remaining_tokens = num_prompt_tokens - start_tok + if num_tokens <= num_remaining_tokens: + # This is a chunk, more tokens remain. + # In the == case, there are no more prompt logprobs to produce + # but we want to defer returning them to the next step where we + # have new generated tokens to return. + num_logits = num_tokens + else: + # This is the last chunk of prompt tokens to return. + num_logits = num_remaining_tokens + completed_prefill_reqs.append(request) + prompt_logprobs_list[request.idx] = logprobs_tensors + if num_logits <= 0: + # This can happen for the final chunk if we prefilled exactly + # (num_prompt_tokens - 1) tokens for this request in the prior + # step. There are no more prompt logprobs to produce. + continue + offset = self.share_inputs["cu_seqlens_q"][request.idx] + prompt_hidden_states = hidden_states[offset : offset + num_logits] + logits = self.model.compute_logits(prompt_hidden_states) + prompt_token_ids = request.prompt_token_ids[start_tok : start_tok + num_logits] + prompt_token_ids_tensor = paddle.to_tensor(prompt_token_ids, dtype="int64") + if logprobs_mode == "raw_logprobs": + raw_logprobs = self.sampler.compute_logprobs(logits) + elif logprobs_mode == "raw_logits": + raw_logprobs = logits + token_ids, logprobs, ranks = self.sampler.gather_logprobs( + raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor + ) + chunk_slice = slice(start_idx, start_idx + num_logits) + logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, False) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, False) + logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, False) + + for req in completed_prefill_reqs: + del self.prompt_logprobs_reqs[req.request_id] + del self.in_progress_prompt_logprobs[req.request_id] + return prompt_logprobs_list diff --git a/fastdeploy/worker/metax_worker.py b/fastdeploy/worker/metax_worker.py index b57b4e3dd..a3d74b0bd 100644 --- a/fastdeploy/worker/metax_worker.py +++ b/fastdeploy/worker/metax_worker.py @@ -120,23 +120,28 @@ class MetaxWorker(WorkerBase): before_run_meminfo_used = info.vramUse * 1024 before_run_meminfo_free = before_run_meminfo_total - before_run_meminfo_used - logger.info("Before running the profile, the memory usage info of Metax GPU is as follows:") - logger.info(f"Device Index: {device_id}") - logger.info(f"Device Total memory: {before_run_meminfo_total / Gb}") - logger.info(f"Device used memory: {before_run_meminfo_used / Gb}") - logger.info(f"Device free memory: {before_run_meminfo_free / Gb}") - logger.info(f"Paddle reserved memory: {paddle_reserved_mem_before_run / Gb}") - 
logger.info(f"Paddle allocated memory: {paddle_allocated_mem_before_run / Gb}") + logger.info( + ( + "Before running the profile, the memory usage info is as follows:", + f"\nDevice Index: {device_id}", + f"\nDevice Total memory: {before_run_meminfo_total / Gb}", + f"\nDevice used memory: {before_run_meminfo_used / Gb}", + f"\nDevice free memory: {before_run_meminfo_free / Gb}", + f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}", + f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}", + ) + ) # 2. Profile run self.model_runner.profile_run() + set_random_seed(self.fd_config.model_config.seed) # 3. Statistical memory information paddle_reserved_mem_after_run = paddle.device.max_memory_reserved(local_rank) paddle_allocated_mem_after_run = paddle.device.max_memory_allocated(local_rank) model_block_memory_used = self.cal_theortical_kvcache() - paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run + paddle_peak_increase = paddle_allocated_mem_after_run - paddle_allocated_mem_before_run paddle.device.empty_cache() @@ -154,15 +159,19 @@ class MetaxWorker(WorkerBase): end_time = time.perf_counter() - logger.info("After running the profile, the memory usage info of Metax GPU is as follows:") - logger.info(f"Device Index: {device_id}") - logger.info(f"Device Total memory: {after_run_meminfo_total / Gb}") - logger.info(f"Device used memory: {after_run_meminfo_used / Gb}") - logger.info(f"Device free memory: {after_run_meminfo_free / Gb}") - logger.info(f"Paddle reserved memory: {paddle_reserved_mem_after_run / Gb}") - logger.info(f"Paddle allocated memory: {paddle_allocated_mem_after_run / Gb}") - logger.info(f"Paddle available_kv_cache_memory: {available_kv_cache_memory / Gb}") - logger.info(f"Profile time: {end_time - start_time}") + logger.info( + ( + "After running the profile, the memory usage info is as follows:", + f"\nDevice Index: {device_id}", + f"\nDevice Total memory: {after_run_meminfo_total / Gb}", + f"\nDevice used memory: {after_run_meminfo_used / Gb}", + f"\nDevice free memory: {after_run_meminfo_free / Gb}", + f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}", + f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}", + f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}", + f"\nProfile time: {end_time - start_time}", + ) + ) return available_kv_cache_memory