diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 58e7c4f31..f1eb23852 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -1484,6 +1484,31 @@ class StructuredOutputsConfig:
         return json.dumps({key: value for key, value in self.__dict__.items()})
 
 
+class RoutingReplayConfig:
+    """Configuration for Routing Replay used in RL training"""
+
+    def __init__(self, args) -> None:
+        self.enable_routing_replay: bool = False
+        self.routing_store_type: str = "local"
+
+        # Local routing store
+        self.local_store_dir: str = "./routing_replay_output"
+
+        # RDMA routing store
+        # TODO: Add RDMA routing store configuration attributes here when the feature is implemented.
+
+        if args is not None:
+            for key, value in args.items():
+                if hasattr(self, key) and value != "None":
+                    setattr(self, key, value)
+
+    def to_json_string(self):
+        """
+        Convert routing replay config to json string.
+        """
+        return json.dumps({key: value for key, value in self.__dict__.items()})
+
+
 class FDConfig:
     """
     The configuration class which contains all fastdeploy-related configuration. This
@@ -1517,6 +1542,7 @@ class FDConfig:
         early_stop_config: Optional[Dict[str, Any]] = None,
         tool_parser: str = None,
         test_mode=False,
+        routing_replay_config: Optional[RoutingReplayConfig] = None,
     ):
         self.model_config: ModelConfig = model_config  # type: ignore
         self.cache_config: CacheConfig = cache_config  # type: ignore
@@ -1533,6 +1559,7 @@ class FDConfig:
         self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
         self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config
         self.router_config: RouterConfig = router_config
+        self.routing_replay_config = routing_replay_config
 
         # Initialize cuda graph capture list
         max_capture_shape = self.scheduler_config.max_num_seqs
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 1eaf53549..d2d7c6f90 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -35,6 +35,7 @@ from fastdeploy.config import (
     PlasAttentionConfig,
     PoolerConfig,
     RouterConfig,
+    RoutingReplayConfig,
     RunnerOption,
     SpeculativeConfig,
     StructuredOutputsConfig,
@@ -491,6 +492,11 @@ class EngineArgs:
     Configuration for eplb.
     """
 
+    routing_replay_config: Optional[Dict[str, Any]] = None
+    """
+    Configuration for rollout routing replay (R3).
+    """
+
     def __post_init__(self):
         """
         Post-initialization processing to set default tokenizer if not provided.
@@ -882,6 +888,12 @@ class EngineArgs:
             default=EngineArgs.eplb_config,
             help="Config of eplb.",
         )
+        parallel_group.add_argument(
+            "--routing-replay-config",
+            type=json.loads,
+            default=EngineArgs.routing_replay_config,
+            help="Configuration for rollout routing replay (R3).",
+        )
         parallel_group.add_argument(
             "--enable-chunked-moe",
             action="store_true",
@@ -1235,6 +1247,14 @@ class EngineArgs:
             eplb_args["enable_eplb"] = self.enable_eplb
         return EPLBConfig(eplb_args)
 
+    def create_routing_replay_config(self) -> RoutingReplayConfig:
+        """Create the routing replay config from the current engine args."""
+        routing_replay_args = asdict(self)
+        if self.routing_replay_config is not None:
+            for k, v in self.routing_replay_config.items():
+                routing_replay_args[k] = v
+        return RoutingReplayConfig(routing_replay_args)
+
     def create_engine_config(self, port_availability_check=True) -> FDConfig:
        """
        Create and return a Config object based on the current settings.
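For reference, a minimal sketch (not authoritative) of how the merged arguments built by `create_routing_replay_config` flow into `RoutingReplayConfig`: only keys already defined on the config are applied, and the literal string "None" is skipped. The dict values below are hypothetical.

from fastdeploy.config import RoutingReplayConfig

# Hypothetical merged dict: asdict(self) overlaid with the user-supplied --routing-replay-config JSON.
merged_args = {
    "enable_routing_replay": True,
    "routing_store_type": "local",
    "local_store_dir": "./routing_replay_output",
    "max_num_seqs": 128,  # present in asdict(self) but ignored: not an attribute of RoutingReplayConfig
}
cfg = RoutingReplayConfig(merged_args)
print(cfg.to_json_string())  # serializes only the three RoutingReplayConfig attributes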
@@ -1278,6 +1298,7 @@ class EngineArgs:
         graph_opt_cfg = self.create_graph_optimization_config()
         plas_attention_config = self.create_plas_attention_config()
         eplb_cfg = self.create_eplb_config()
+        routing_replay_config = self.create_routing_replay_config()
         router_config = RouterConfig(all_dict)
         early_stop_cfg = self.create_early_stop_config()
 
@@ -1310,4 +1331,5 @@ class EngineArgs:
             graph_opt_config=graph_opt_cfg,
             plas_attention_config=plas_attention_config,
             early_stop_config=early_stop_cfg,
+            routing_replay_config=routing_replay_config,
         )
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 4a493843d..fadf95467 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -568,6 +568,7 @@ class LLMEngine:
                 f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
                 f" --max_logprobs {self.cfg.model_config.max_logprobs}"
                 f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'"
+                f" --routing_replay_config '{self.cfg.routing_replay_config.to_json_string()}'"
             )
             if self.cfg.structured_outputs_config.logits_processors is not None:
                 arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"
diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py
index 4e9df0d3c..787ec77c0 100644
--- a/fastdeploy/model_executor/forward_meta.py
+++ b/fastdeploy/model_executor/forward_meta.py
@@ -142,6 +142,8 @@ class ForwardMeta:
     caches: Optional[list[paddle.Tensor]] = None
     # Flag of profile run
     is_dummy_or_profile_run: bool = False
+    # Routing Replay table buffer
+    routing_replay_table: Optional[paddle.Tensor] = None
 
     # chunked MoE related
     moe_num_chunk: int = 1
diff --git a/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py b/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py
index 918450c74..192c0b883 100644
--- a/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py
+++ b/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 """
 
+from typing import Callable
+
 import paddle
 from paddle import nn
 
@@ -101,6 +103,7 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
+        topk_ids_hookfunc: Callable = None,
     ) -> paddle.Tensor:
         """
         Triton compute Fused MoE.
@@ -117,6 +120,8 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
             scores += layer.gate_correction_bias
         topk_weights, topk_ids = paddle.topk(scores, k=top_k, axis=-1, sorted=False)
         topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
+        if topk_ids_hookfunc is not None:
+            topk_ids_hookfunc(topk_ids=topk_ids)
 
         intermediate_cache1 = paddle.empty(
             [token_num * top_k, moe_intermediate_size * 2],
diff --git a/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py b/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py
index e67dd6dbd..2260d7caf 100644
--- a/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py
+++ b/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py
@@ -16,6 +16,7 @@
 
 import multiprocessing
 import os
+from typing import Callable
 
 import numpy as np
 import paddle
@@ -182,6 +183,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
         layer: nn.Layer,
         x: paddle.Tensor,
         gate: nn.Layer,
+        topk_ids_hookfunc: Callable = None,
     ) -> paddle.Tensor:
         """
         Paddle gcu compute Fused MoE.
@@ -194,6 +196,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP prefill method. @@ -205,6 +208,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP decoder method. @@ -216,6 +220,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle Cutlass compute Fused MoE. @@ -381,6 +386,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle gcu compute Fused MoE. diff --git a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py index 7b61d58b6..fbbfac277 100644 --- a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py @@ -14,6 +14,8 @@ # limitations under the License. """ +from typing import Callable + import paddle from paddle import nn @@ -245,6 +247,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Triton compute Fused MoE. @@ -274,6 +277,9 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase): True, # apply_norm_weight False, ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + up_gate_proj_out = paddle.empty( [token_num * top_k, moe_intermediate_size * 2], dtype=x.dtype, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py index b34291a96..a8bd70465 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py @@ -15,6 +15,7 @@ """ from abc import abstractmethod +from typing import Callable import paddle from paddle import nn @@ -163,6 +164,7 @@ class MoEMethodBase(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP prefill method. @@ -175,6 +177,7 @@ class MoEMethodBase(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP decoder method. @@ -187,6 +190,7 @@ class MoEMethodBase(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle Cutlass compute Fused MoE. @@ -198,6 +202,7 @@ class MoEMethodBase(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle Cutlass compute Fused MoE. 
@@ -207,13 +212,13 @@ class MoEMethodBase(QuantMethodBase): if layer.fd_config.model_config.moe_phase.phase == "prefill": if layer.fd_config.scheduler_config.splitwise_role == "mixed" and is_moe_start_layer: self.ep_prefill_runner.clean_low_latency_buffer() - return self.apply_ep_prefill(layer, x, gate) + return self.apply_ep_prefill(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc) else: if layer.fd_config.scheduler_config.splitwise_role == "mixed" and is_moe_start_layer: self.ep_decoder_runner.clean_low_latency_buffer() - return self.apply_ep_decode(layer, x, gate) + return self.apply_ep_decode(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc) else: - return self.apply_tp(layer, x, gate) + return self.apply_tp(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc) class UnquantizedFusedMoEMethod(MoEMethodBase): diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index e45ad63b1..c3dbfc9ba 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -14,6 +14,8 @@ # limitations under the License. """ +from typing import Callable + import paddle from paddle import nn from paddle.nn.quant import weight_quantize @@ -132,6 +134,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP prefill method. @@ -148,8 +151,13 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): handle, event, ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + if self.ep_prefill_runner.ep_engine.async_finish: event.current_stream_wait() + token_all_num = sum(recv_num_tokens_per_expert_list) # 3. Compute ffn @@ -217,6 +225,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP decoder method. @@ -225,6 +234,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): estimate_total_token_nums = gate_out.shape[0] * layer.top_k # 1. Select topk experts and weights topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + expertwise_scale = None if hasattr(layer, "up_gate_proj_in_scale_all_experts"): # only use in w4a8 expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None) @@ -269,6 +282,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle Cutlass compute Fused MoE. @@ -369,6 +383,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): if hasattr(layer, "up_gate_proj_in_scale"): dequant_scale = None + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8": # only w4a8 need expert_idx_per_token # Other need not this tensor, so we make it None. 
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index 1245cddce..881f9a22c 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -14,6 +14,8 @@ # limitations under the License. """ +from typing import Callable + import paddle from paddle import nn from paddle.distributed.communication import deep_ep @@ -139,6 +141,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP prefill method. @@ -147,6 +150,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): # 1. Select topk experts and weights topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + # 2. Dynamic compute blockwise quantization scales x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant( x, self.quant_config.weight_block_size[0] @@ -264,6 +271,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP decoder method. @@ -271,6 +279,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): gate_out = gate(x.cast("float32")) # 1. Select topk experts and weights topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + # 2. EP Dispatch permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch( x, topk_idx, topk_weights, use_fp8=True @@ -335,6 +347,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle Use DeepGemm compute Fused MoE. @@ -363,6 +376,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): False, ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts) recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, 128) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py index 094d3df8f..cd836dbaf 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py @@ -14,6 +14,8 @@ # limitations under the License. """ +from typing import Callable + import paddle from paddle import nn @@ -239,6 +241,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Marlin compute Fused MoE. @@ -273,6 +276,9 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase): False, ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + block_size_m = 64 for m in [8, 16, 32, 48, 64]: diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 3c1485937..2861d96e8 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -14,6 +14,8 @@ # limitations under the License. 
""" +from typing import Callable + import paddle from paddle import nn @@ -282,6 +284,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Triton compute Fused MoE. @@ -314,6 +317,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): True, # apply_norm_weight, False, ) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + up_gate_proj_out = paddle.empty( [token_num * top_k, moe_intermediate_size * 2], dtype=x.dtype, @@ -664,6 +671,7 @@ class Wfp8Afp8MoEMethod(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Triton compute Fused MoE. @@ -724,6 +732,9 @@ class Wfp8Afp8MoEMethod(QuantMethodBase): * ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + up_gate_proj_out = paddle.empty( [token_num * top_k, moe_intermediate_size * 2], dtype=x.dtype, @@ -953,6 +964,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Triton compute Fused MoE. @@ -974,6 +986,9 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): False, ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + up_gate_proj_out = paddle.empty( [token_num * top_k, moe_intermediate_size * 2], dtype=x.dtype, @@ -1466,6 +1481,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Triton compute Fused MoE. @@ -1488,6 +1504,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): True, # apply_norm_weight False, ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) config = { "BLOCK_SIZE_M": 64, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py index f75e36bcb..3c548ba57 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py @@ -14,6 +14,8 @@ # limitations under the License. """ +from typing import Callable + import paddle from paddle import nn @@ -261,6 +263,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Use Wint2 Triton Fusedmoe compute Fused MoE. @@ -288,6 +291,9 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod): topk_only_mode=False, ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + ffn_out = fastdeploy.model_executor.ops.gpu.moe_expert_ffn_wint2( permute_input, token_nums_per_expert, @@ -328,6 +334,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Use Wint2 Triton Fusedmoe compute Fused MoE. 
@@ -343,6 +350,9 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod): False, ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + num_tokens, K = x.shape E, _, N = layer.up_gate_proj_weight.shape M = num_tokens diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 743e05031..5b1be52d1 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -14,7 +14,8 @@ # limitations under the License. """ -from typing import Optional +from functools import partial +from typing import Callable, Optional import paddle from paddle import nn @@ -26,6 +27,9 @@ from fastdeploy.distributed.communication import ( tensor_model_parallel_all_reduce_custom, ) from fastdeploy.model_executor.forward_meta import ForwardMeta +from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( + save_routing_to_buffer, +) from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.utils import h2d_copy, slice_fn from fastdeploy.platforms import current_platform @@ -226,7 +230,7 @@ class FusedMoE(nn.Layer): self.is_rearrange = False if self.ep_size > 1: self.quant_method.init_ep(self) - + self.enable_routing_replay = fd_config.routing_replay_config.enable_routing_replay # Merge normal and RL build model if gate_correction_bias is not None: self.gate_correction_bias = gate_correction_bias @@ -600,7 +604,7 @@ class FusedMoE(nn.Layer): else: self.quant_method.process_loaded_weights(self, state_dict) - def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer): + def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None): """ Forward split allgather function. """ @@ -615,14 +619,14 @@ class FusedMoE(nn.Layer): if end_offset > token_num: end_offset = token_num part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :] - out = self.quant_method.apply(self, part_x, gate) + out = self.quant_method.apply(self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc) multi_outs = paddle.zeros([token_num_per_rank * self.attn_tp_size, x.shape[1]], dtype=x.dtype) paddle.distributed.all_gather(multi_outs, out, self.tp_group) out = multi_outs[:token_num, :] return out - def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta): + def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None): """ Defines the forward computation of the moe layer. 
@@ -633,6 +637,21 @@ class FusedMoE(nn.Layer): Tensor: Output tensor.s """ + topk_ids_hookfunc = None + if self.enable_routing_replay: + if forward_meta is not None: # forward_meta is None when execute empty_input_forward + topk_ids_hookfunc = partial( + save_routing_to_buffer, + routing_replay_table=forward_meta.routing_replay_table, + batch_id_per_token=forward_meta.batch_id_per_token, + seq_lens_decoder=forward_meta.seq_lens_decoder, + cu_seqlens_q=forward_meta.cu_seqlens_q, + layer_idx=self.layer_idx, + tp_size=self.fd_config.parallel_config.tensor_parallel_size, + ep_size=self.fd_config.parallel_config.expert_parallel_size, + tp_group=self.fd_config.parallel_config.tp_group, + ) + token_num = x.shape[0] if ( self.ep_size > 1 @@ -640,11 +659,16 @@ class FusedMoE(nn.Layer): and (not self.fd_config.parallel_config.use_sequence_parallel_moe) and token_num >= self.attn_tp_size ): - out = self.forward_split_allgather(x, gate) + out = self.forward_split_allgather(x, gate, topk_ids_hookfunc=topk_ids_hookfunc) elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe: - out = self.forward_chunked_moe(x, gate, forward_meta) + out = self.forward_chunked_moe( + x, + gate, + forward_meta, + topk_ids_hookfunc=topk_ids_hookfunc, + ) else: - out = self.forward_normal(x, gate) + out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc) if self.reduce_results and self.tp_size > 1: if current_platform.is_intel_hpu(): @@ -653,7 +677,9 @@ class FusedMoE(nn.Layer): out = tensor_model_parallel_all_reduce(out, self.tp_group) return out - def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta): + def forward_chunked_moe( + self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None + ): """ Split input to multi chunk to reduce the memory usage of moe. @@ -677,21 +703,25 @@ class FusedMoE(nn.Layer): for i in range(forward_meta.max_moe_num_chunk): if i < forward_meta.moe_num_chunk: - out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate) + out_split_list[i] = self.quant_method.apply( + self, x_split_list[i], gate, topk_ids_hookfunc=topk_ids_hookfunc + ) else: # just need to use real data to infer max_moe_num_chunk times. - self.quant_method.apply(self, fake_x, gate) + self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc) out = paddle.concat(out_split_list, axis=0) else: # when only one chunk, just need to use real data to infer once. - out = self.quant_method.apply(self, x, gate) + out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc) for i in range(forward_meta.max_moe_num_chunk - 1): - self.quant_method.apply(self, fake_x, gate) + self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc) return out - def forward_normal(self, x: paddle.Tensor, gate: nn.Layer): + def forward_normal( + self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None + ): """ Normal mode of forward. 
@@ -702,5 +732,5 @@ class FusedMoE(nn.Layer): Tensor: Output tensor.s """ - out = self.quant_method.apply(self, x, gate) + out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc) return out diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py new file mode 100644 index 000000000..e95a3d856 --- /dev/null +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -0,0 +1,346 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import copy +import os +import shutil +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + +import paddle +import paddle.distributed as dist +import triton +import triton.language as tl + +from fastdeploy.config import FDConfig + + +@triton.jit +def _save_routing_kernel( + ROUTING_REPLAY_TABLE_PTR, + TOPK_IDS_PTR, + BATCH_ID_PER_TOKEN_PTR, + CU_SEQLENS_Q_PTR, + SEQ_LENS_DECODER_PTR, + LAYER_IDX, + TOKEN_NUM, + TOP_K, + NUM_HIDDEN_LAYERS, + MAX_MODEL_LEN, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + + token_offsets = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + token_mask = token_offsets < TOKEN_NUM + + k_offsets = tl.arange(0, BLOCK_SIZE_K) + + k_mask = k_offsets < TOP_K + + topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :] + # [BLOCK_SIZE_M, BLOCK_SIZE_K] + + load_mask = token_mask[:, None] & k_mask[None, :] + topk_vals = tl.load(topk_ids_ptrs, mask=load_mask) + + batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask) + pad_mask = token_mask & (batch_ids != -1) + # [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3] + # -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10] + # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10] + # -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1] + start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask) + token_relative_index = token_offsets - start_offsets + + # [BLOCK_SIZE_M] + len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask) + token_seq_pos = len_decoder + token_relative_index + + STRIDE_BUF_SEQ = NUM_HIDDEN_LAYERS * MAX_MODEL_LEN * TOP_K + STRIDE_BUF_LAYER = MAX_MODEL_LEN * TOP_K + STRIDE_BUF_TOKEN = TOP_K + + # [BLOCK_SIZE_M, BLOCK_SIZE_K] + output_ptrs = ( + ROUTING_REPLAY_TABLE_PTR + + batch_ids[:, None] * STRIDE_BUF_SEQ + + LAYER_IDX * STRIDE_BUF_LAYER + + token_seq_pos[:, None] * STRIDE_BUF_TOKEN + + k_offsets[None, :] + ) + + pos_mask = token_seq_pos < MAX_MODEL_LEN + pos_mask = pos_mask & pad_mask + + # [BLOCK_SIZE_M, BLOCK_SIZE_K] + pos_mask = pos_mask[:, None] & k_mask[None, :] + + final_mask = load_mask & pos_mask + + tl.store(output_ptrs, topk_vals, mask=final_mask) + + +def save_routing_to_buffer( + routing_replay_table: paddle.Tensor, # [max_num_seqs, num_layers, max_len, top_k] + topk_ids: paddle.Tensor, # [token_num, top_k] + 
batch_id_per_token: paddle.Tensor,  # [token_num, 1]
+    seq_lens_decoder: paddle.Tensor,  # [max_num_seqs, 1]
+    cu_seqlens_q: paddle.Tensor,  # [max_num_seqs + 1, 1]
+    layer_idx: int,
+    tp_size: int,
+    ep_size: int,
+    tp_group: dist.communication.group.Group,
+):
+    if tp_size > 1 and ep_size > 1:
+        token_num_per_rank = topk_ids.shape[0]
+        topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype)
+        paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group)
+        topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :]
+
+    token_num, top_k = topk_ids.shape
+    max_num_seqs, num_hidden_layers, max_model_len, _ = routing_replay_table.shape
+    assert token_num > 0
+
+    assert topk_ids.shape[1] == routing_replay_table.shape[3], (topk_ids.shape[1], routing_replay_table.shape[3])
+    assert batch_id_per_token.shape[0] == token_num, (batch_id_per_token.shape[0], token_num)
+    assert seq_lens_decoder.shape[0] == max_num_seqs, (seq_lens_decoder.shape[0], max_num_seqs)
+
+    BLOCK_SIZE_M = 128
+    BLOCK_SIZE_K = triton.next_power_of_2(top_k)  # top_k
+
+    grid = (triton.cdiv(token_num, BLOCK_SIZE_M),)
+    _save_routing_kernel[grid](
+        routing_replay_table,
+        topk_ids,
+        batch_id_per_token,
+        cu_seqlens_q,
+        seq_lens_decoder,
+        LAYER_IDX=layer_idx,
+        TOKEN_NUM=token_num,
+        TOP_K=top_k,
+        NUM_HIDDEN_LAYERS=num_hidden_layers,
+        MAX_MODEL_LEN=max_model_len,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+    )
+
+
+class RoutingReplayManager:
+    """Request level routing replay table manager"""
+
+    def __init__(
+        self,
+        fd_config: FDConfig,
+    ):
+        self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
+        self.max_model_len = fd_config.model_config.max_model_len
+        self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index
+
+        if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM":
+            self.moe_top_k = fd_config.model_config.num_experts_per_tok
+        else:
+            self.moe_top_k = fd_config.model_config.moe_k
+        self.tp_rank = fd_config.parallel_config.tensor_parallel_rank
+
+        self.routing_store = get_routing_store(fd_config=fd_config)
+        self.routing_batch_to_request: Dict[int, str] = {}
+        self.routing_replay_table = paddle.full(
+            shape=[self.max_num_seqs, self.num_moe_layers, self.max_model_len, self.moe_top_k],
+            fill_value=-1,
+            dtype="int32",
+        )
+
+    def register_request(self, batch_id: int, request_id: str):
+        """
+        Register a new request in the routing replay table.
+        Args:
+            batch_id: The batch slot ID assigned to this request
+            request_id: The global request ID, typically assigned by the RL training process
+        """
+        # Flush the previous (finished) request that still occupies this slot
+        if batch_id in self.routing_batch_to_request:
+            pre_request_id = self._deregister_request(batch_id)
+            self._put_request_to_store(batch_id, pre_request_id)
+        # Register the new request
+        self.routing_batch_to_request[batch_id] = request_id
+
+    def _deregister_request(self, batch_id: int) -> str:
+        """
+        Deregister a request from the routing replay table.
+        """
+        assert batch_id in self.routing_batch_to_request
+        return self.routing_batch_to_request.pop(batch_id)
+
+    def _put_request_to_store(
+        self,
+        batch_id: int,
+        request_id: str,
+    ):
+        if self.tp_rank == 0:
+            batch_buffer = self.routing_replay_table[batch_id]
+            for layer_id in range(self.num_moe_layers):
+                layer_buffer = batch_buffer[layer_id]
+                rollout_id = self.split_request_id(request_id)
+                self.routing_store.put(routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id)
+
self._clear_table_slot(batch_id) + + def put_table_to_store(self): + """Put the routing table""" + batch_ids = copy.deepcopy(list(self.routing_batch_to_request.keys())) + for batch_id in batch_ids: + request_id = self._deregister_request(batch_id) + self._put_request_to_store(batch_id, request_id) + + def _clear_table_slot(self, batch_id: int): + assert 0 <= batch_id < self.max_num_seqs + self.routing_replay_table[batch_id].fill_(-1) + + def clear_routing_table(self): + """Clear all slots of the routing replay table""" + self.routing_replay_table.fill_(-1) + + def _clear_store(self): + """Clear routing store""" + self.routing_store.clear_store() + + def _clear_request_of_store(self, request_id): + """Clear one request of routing store""" + rollout_id = self.split_request_id(request_id) + for layer_idx in range(self.num_moe_layers): + self.routing_store.clear(rollout_id=rollout_id, layer_idx=layer_idx) + + def get_request_from_store(self, request_id: str) -> List[paddle.Tensor]: + """Get the routing indices of the request from store""" + routing_list = [] + rollout_id = self.split_request_id(request_id) + for layer_idx in range(self.num_moe_layers): + one_layer_routing = self.routing_store.get(rollout_id, layer_idx) + routing_list.append(one_layer_routing) + + return routing_list + + def get_routing_table(self) -> paddle.Tensor: + return self.routing_replay_table + + def split_request_id(self, request_id: str): + """Split the request id to get rollout id""" + chat_type, tmp_str = request_id.split("-", 1) + # NOTE(gongshaotian): only support chatcmpl now + # assert chat_type == "chatcmpl" + reversed_tmp_str = tmp_str[::-1].split("-", 5) + rollout_id = reversed_tmp_str[-1][::-1] + return rollout_id + + +class RoutingStoreBase(ABC): + """Base class for routing store""" + + def __init__(self, fd_config: FDConfig) -> None: + self.fd_config = fd_config + + @abstractmethod + def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: Optional[int] = None) -> None: + """Put the routing indices into store""" + raise NotImplementedError + + @abstractmethod + def get(self, rollout_id: str, layer_idx: Optional[int] = None) -> paddle.Tensor: + """Get the routing indices from store""" + raise NotImplementedError + + @abstractmethod + def clear(self, rollout_id: str, layer_idx: Optional[int] = None) -> None: + """Clear the routing indices of the request""" + raise NotImplementedError + + @abstractmethod + def clear_store( + self, + ): + """Clear the routing indices store""" + raise NotImplementedError + + +class RoutingStoreLocal(RoutingStoreBase): + """Routing Store using local memory""" + + def __init__(self, fd_config) -> None: + super().__init__(fd_config=fd_config) + self.local_store_dir = fd_config.routing_replay_config.local_store_dir + + def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None: + """Put the routing indices into store""" + dir_path = os.path.join(self.local_store_dir, f"{rollout_id}") + os.makedirs(dir_path, exist_ok=True) + file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor") + paddle.save(routing_indices, file_path) + + def get( + self, + rollout_id: str, + layer_idx: int = None, + ) -> paddle.Tensor: + """Get the routing indices from store""" + dir_path = os.path.join(self.local_store_dir, f"{rollout_id}") + file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor") + assert os.path.exists(file_path), f"File not found: {file_path}" + layer_routing_indices = paddle.load(file_path) + + return layer_routing_indices + + 
def clear(
+        self,
+        rollout_id: str,
+        layer_idx: int = None,
+    ) -> None:
+        """Clear the routing indices of the request"""
+        dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
+        file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
+        assert os.path.exists(file_path), f"File not found: {file_path}"
+        os.remove(file_path)
+
+        # Delete empty directory
+        if len(os.listdir(dir_path)) == 0:
+            os.rmdir(dir_path)
+
+    def clear_store(self):
+        """Clear the routing indices store"""
+        if os.path.isdir(self.local_store_dir):
+            for file_name in os.listdir(self.local_store_dir):
+                file_path = os.path.join(self.local_store_dir, file_name)
+                shutil.rmtree(file_path)
+
+
+class RoutingStoreRDMA(RoutingStoreBase):
+    """Routing Store using RDMA"""
+
+    def __init__(self, fd_config: FDConfig) -> None:
+        super().__init__(fd_config=fd_config)
+
+
+def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase:
+    if fd_config.routing_replay_config.routing_store_type == "local":
+        return RoutingStoreLocal(fd_config=fd_config)
+    elif fd_config.routing_replay_config.routing_store_type == "rdma":
+        return RoutingStoreRDMA(fd_config=fd_config)
+    else:
+        raise ValueError(
+            f"Invalid routing store type: '{fd_config.routing_replay_config.routing_store_type}'. "
+            "Valid types are: 'local', 'rdma'"
+        )
diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py
index d5ad6e391..0cc7c4dae 100644
--- a/fastdeploy/model_executor/models/glm4_moe.py
+++ b/fastdeploy/model_executor/models/glm4_moe.py
@@ -161,7 +161,7 @@ class Glm4Moe(nn.Layer):
             reduce_results=False,
         )
 
-    def forward(self, x, forward_meta):
+    def forward(self, x, forward_meta: ForwardMeta = None):
         shared_experts_out = self.shared_experts(x)
         out = self.experts(x, self.gate, forward_meta)
         out = out + shared_experts_out
@@ -306,10 +306,7 @@ class Glm4MoeDecoderLayer(nn.Layer):
 
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
-        hidden_states = self.mlp(
-            hidden_states,
-            forward_meta,
-        )
+        hidden_states = self.mlp(hidden_states, forward_meta)
 
         return hidden_states, residual
 
diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py
index 6bd3c3bcb..f7ff748fe 100644
--- a/fastdeploy/rl/rollout_config.py
+++ b/fastdeploy/rl/rollout_config.py
@@ -65,6 +65,7 @@ class RolloutModelConfig:
         data_parallel_size: int = 1,
         num_nextn_predict_layers: int = 0,
         eplb_config: str = {},
+        routing_replay_config: str = None,
     ):
         # Required parameters
         self.model = model_name_or_path
@@ -113,6 +114,7 @@ class RolloutModelConfig:
         self.plas_attention_config = plas_attention_config
         self.num_nextn_predict_layers = num_nextn_predict_layers
         self.eplb_config = eplb_config
+        self.routing_replay_config = routing_replay_config
 
     def __str__(self):
         return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index a91611524..94c7a0b3f 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -45,6 +45,9 @@ from fastdeploy.model_executor.layers.attention.append_attn_backend import (
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend,
 )
+from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
+    RoutingReplayManager,
+)
 from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d
 from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
@@ -202,6 +205,11 @@ class GPUModelRunner(ModelRunnerBase):
         os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
         logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")
 
+        # Rollout routing replay manager
+        self.routing_replay_manager = None
+        if self.fd_config.routing_replay_config.enable_routing_replay:
+            self.routing_replay_manager = RoutingReplayManager(fd_config=self.fd_config)
+
         self.zmq_client = None
         self.async_output_queue = None
         if envs.FD_USE_GET_SAVE_OUTPUT_V1:
@@ -648,6 +656,7 @@ class GPUModelRunner(ModelRunnerBase):
             self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
             self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
             self.share_inputs["is_block_step"][idx : idx + 1] = False
+            self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids)
             self.share_inputs["step_idx"][idx : idx + 1] = (
                 len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
             )
@@ -656,6 +665,12 @@ class GPUModelRunner(ModelRunnerBase):
             if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None:
                 self.prompt_logprobs_reqs[request.request_id] = request
             has_prefill_task = True
+
+            # Routing Replay
+            if self.fd_config.routing_replay_config.enable_routing_replay:
+                if prefill_start_index == 0:
+                    self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id)
+
             if (
                 self.fd_config.scheduler_config.splitwise_role == "decode"
             ):  # In PD, we continue to decode after P generate first token
@@ -1148,6 +1163,7 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
         self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
         self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
+        self.share_inputs["is_chunk_step"] = paddle.full([max_num_seqs], False, dtype="bool").cpu()
         self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
         self.share_inputs["step_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
         self.share_inputs["step_lens"] = paddle.full([1], 0, dtype="int32")
@@ -1418,6 +1434,9 @@ class GPUModelRunner(ModelRunnerBase):
         Initialize forward meta, attention meta data and update some config.
         """
         # Initialize forward meta
+        routing_replay_table = None
+        if self.routing_replay_manager is not None:
+            routing_replay_table = self.routing_replay_manager.get_routing_table()
         self.forward_meta = ForwardMeta(
             ids_remove_padding=self.share_inputs["ids_remove_padding"],
             rotary_embs=self.share_inputs["rope_emb"],
@@ -1444,6 +1463,7 @@ class GPUModelRunner(ModelRunnerBase):
             kv_batch_ids=self.share_inputs["kv_batch_ids"],
             kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"],
             kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
+            routing_replay_table=routing_replay_table,
         )
 
         dist_status = self.collect_distributed_status()
@@ -1932,6 +1952,9 @@ class GPUModelRunner(ModelRunnerBase):
             if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
                 break
 
+        if self.fd_config.routing_replay_config.enable_routing_replay:
+            self.routing_replay_manager.clear_routing_table()
+
     def _update_chunked_prefill(self, tasks):
         """
         Update chunked prefill related parameters
@@ -2429,6 +2452,15 @@ class GPUModelRunner(ModelRunnerBase):
                 self.speculative_config.num_speculative_tokens,
             )
 
+        # Routing replay
+        if self.fd_config.routing_replay_config.enable_routing_replay:
+            if (
+                not self.exist_prefill()
+                and not self.exist_decode()
+                and self.share_inputs["is_block_step"].sum() == 0
+                and self.share_inputs["is_chunk_step"].sum() == 0
+            ):
+                self.routing_replay_manager.put_table_to_store()
         return None
 
     def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]:
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 02d66f4bc..0c29ce4d7 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -38,6 +38,7 @@ from fastdeploy.config import (
     ModelConfig,
     ParallelConfig,
     PlasAttentionConfig,
+    RoutingReplayConfig,
     SpeculativeConfig,
     StructuredOutputsConfig,
 )
@@ -885,6 +886,13 @@ def parse_args():
         help="EPLB Configuration.",
     )
 
+    parser.add_argument(
+        "--routing_replay_config",
+        type=json.loads,
+        default=None,
+        help="Configuration of Rollout Routing Replay.",
+    )
+
     args = parser.parse_args()
     return args
 
@@ -944,6 +952,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     eplb_config = EPLBConfig(args.eplb_config)
 
     structured_outputs_config: StructuredOutputsConfig = StructuredOutputsConfig(args=vars(args))
+    routing_replay_config = RoutingReplayConfig(args.routing_replay_config)
 
     # Note(tangbinhan): used for load_checkpoint
     model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
@@ -1003,6 +1012,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         plas_attention_config=plas_attention_config,
         structured_outputs_config=structured_outputs_config,
         eplb_config=eplb_config,
+        routing_replay_config=routing_replay_config,
     )
     update_fd_config_for_mm(fd_config)
     if fd_config.load_config.load_choices == "default_v1" and not v1_loader_support(fd_config):
diff --git a/tests/distributed/chunked_moe.py b/tests/distributed/chunked_moe.py
index 0be645d38..ef41a610d 100644
--- a/tests/distributed/chunked_moe.py
+++ b/tests/distributed/chunked_moe.py
@@ -90,7 +90,7 @@ class MockAttentionBackend:
 
 
 class MockQuantMethod:
-    def apply(self, layer, x, gate):
+    def apply(self, layer, x, gate, topk_ids_hookfunc=None):
         return x
 
 
@@ -129,6 +129,7 @@ class TestChunkedMoE(unittest.TestCase):
         model_runner.speculative_decoding = False
         model_runner._init_share_inputs(mock_fd_config.scheduler_config.max_num_seqs)
model_runner.share_inputs["caches"] = None + model_runner.routing_replay_manager = None if dist.get_rank() == 0: model_runner.share_inputs["ids_remove_padding"] = paddle.ones([10]) @@ -148,6 +149,7 @@ class TestChunkedMoE(unittest.TestCase): fused_moe.fd_config = mock_fd_config fused_moe.quant_method = MockQuantMethod() + fused_moe.enable_routing_replay = None return fused_moe def run_model_runner(self): diff --git a/tests/e2e/test_EB_Lite_serving.py b/tests/e2e/test_EB_Lite_serving.py index bc27daab9..c71b76672 100644 --- a/tests/e2e/test_EB_Lite_serving.py +++ b/tests/e2e/test_EB_Lite_serving.py @@ -78,6 +78,8 @@ def setup_and_run_server(): "wint4", "--graph-optimization-config", '{"cudagraph_capture_sizes": [1], "use_cudagraph":true}', + "--routing-replay-config", + '{"enable_routing_replay":true, "routing_store_type":"local", "local_store_dir":"./routing_replay_output"}', ] # Start subprocess in new process group diff --git a/tests/layers/test_fusedmoe.py b/tests/layers/test_fusedmoe.py index ed4fe5b28..346afc98f 100644 --- a/tests/layers/test_fusedmoe.py +++ b/tests/layers/test_fusedmoe.py @@ -31,6 +31,7 @@ from fastdeploy.config import ( LoadConfig, ModelConfig, ParallelConfig, + RoutingReplayConfig, ) from fastdeploy.model_executor.layers.moe.moe import FusedMoE from fastdeploy.model_executor.layers.quantization.block_wise_fp8 import ( @@ -476,6 +477,7 @@ class FuseMoEWrapper(paddle.nn.Layer): graph_opt_config=GraphOptimizationConfig({}), load_config=LoadConfig({}), ips=",".join(["0"] * nnodes), + routing_replay_config=RoutingReplayConfig({}), ) self.fd_config.parallel_config.tp_group = None self.fd_config.parallel_config.tensor_parallel_rank = tp_rank diff --git a/tests/layers/test_w4a8_moe.py b/tests/layers/test_w4a8_moe.py index dc6dab154..f20c27b06 100644 --- a/tests/layers/test_w4a8_moe.py +++ b/tests/layers/test_w4a8_moe.py @@ -13,6 +13,7 @@ from fastdeploy.config import ( LoadConfig, ModelConfig, ParallelConfig, + RoutingReplayConfig, ) from fastdeploy.model_executor.layers.moe.moe import FusedMoE from fastdeploy.model_executor.layers.quantization.w4a8 import W4A8Config @@ -59,6 +60,7 @@ class FuseMoEWrapper(paddle.nn.Layer): graph_opt_config=GraphOptimizationConfig({}), load_config=LoadConfig({}), ips=",".join(["0"] * nnodes), + routing_replay_config=RoutingReplayConfig({}), ) self.fd_config.parallel_config.tp_group = None self.fd_config.parallel_config.tensor_parallel_rank = tp_rank diff --git a/tests/layers/test_w4afp8_moe.py b/tests/layers/test_w4afp8_moe.py index 65b773317..8f1ae79cd 100644 --- a/tests/layers/test_w4afp8_moe.py +++ b/tests/layers/test_w4afp8_moe.py @@ -13,6 +13,7 @@ from fastdeploy.config import ( LoadConfig, ModelConfig, ParallelConfig, + RoutingReplayConfig, ) from fastdeploy.model_executor.layers.moe.moe import FusedMoE from fastdeploy.model_executor.layers.quantization.w4afp8 import W4AFP8Config @@ -65,6 +66,7 @@ class FuseMoEWrapper(paddle.nn.Layer): graph_opt_config=GraphOptimizationConfig({}), load_config=LoadConfig({}), ips=",".join(["0"] * nnodes), + routing_replay_config=RoutingReplayConfig({}), ) self.fd_config.parallel_config.tp_group = None self.fd_config.parallel_config.tensor_parallel_rank = tp_rank
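For reference, a minimal usage sketch under the defaults introduced above: launching with routing replay enabled, then reading the recorded expert indices back from the local store. The rollout id below is hypothetical; see `RoutingReplayManager.split_request_id` for how it is derived from a request id.

import json
import os

import paddle

# Enable routing replay at launch; this JSON is what --routing-replay-config expects.
routing_replay_config = {
    "enable_routing_replay": True,
    "routing_store_type": "local",
    "local_store_dir": "./routing_replay_output",
}
print(f"--routing-replay-config '{json.dumps(routing_replay_config)}'")

# RoutingStoreLocal writes one tensor per MoE layer under
# <local_store_dir>/<rollout_id>/layer_<idx>.pdtensor (saved with paddle.save).
rollout_id = "example-rollout-id"  # hypothetical
layer_path = os.path.join(routing_replay_config["local_store_dir"], rollout_id, "layer_0.pdtensor")
if os.path.exists(layer_path):
    routing = paddle.load(layer_path)  # int32 [max_model_len, top_k]; -1 marks unused positions
    print(routing.shape)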