diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 583d1ba70..e6bd44f58 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1163,6 +1163,31 @@ class CommitConfig: logger.info("=============================================================") +class RoutingReplayConfig: + """Configuration for Routing Replay used in RL training""" + + def __init__(self, args) -> None: + self.enable_routing_replay: bool = False + self.routing_store_type: str = "local" + + # Local routing store + self.local_store_dir: str = "./routing_replay_output" + + # RDMA routing store + pass + + if args is not None: + for key, value in args.items(): + if hasattr(self, key) and value != "None": + setattr(self, key, value) + + def to_json_string(self): + """ + Convert routing replay config to json string. + """ + return json.dumps({key: value for key, value in self.__dict__.items()}) + + class FDConfig: """ The configuration class which contains all fastdeploy-related configuration. This @@ -1206,6 +1231,7 @@ class FDConfig: test_mode=False, enable_attention_dp_balance: bool = False, attention_dp_time_out_iters: int = 0, + routing_replay_config: Optional[RoutingReplayConfig] = None, ): self.model_config: ModelConfig = model_config # type: ignore self.cache_config: CacheConfig = cache_config # type: ignore @@ -1221,8 +1247,10 @@ class FDConfig: self.cache_config: CacheConfig = cache_config # type: ignore self.eplb_config: Optional[EPLBConfig] = eplb_config self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config + self.routing_replay_config = routing_replay_config self.enable_attention_dp_balance = enable_attention_dp_balance self.attention_dp_time_out_iters = attention_dp_time_out_iters + # Initialize cuda graph capture list max_capture_shape = self.parallel_config.max_num_seqs if self.speculative_config is not None and self.speculative_config.method == "mtp": diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 
a53658963..ac4354df0 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -33,6 +33,7 @@ from fastdeploy.config import ( MobaAttentionConfig, ModelConfig, ParallelConfig, + RoutingReplayConfig, SpeculativeConfig, TaskOption, ) @@ -421,6 +422,11 @@ class EngineArgs: Configuration for eplb. """ + routing_replay_config: Optional[Dict[str, Any]] = None + """ + Flag to rollout routing replay(r3) + """ + def __post_init__(self): """ Post-initialization processing to set default tokenizer if not provided. @@ -731,6 +737,12 @@ class EngineArgs: default=EngineArgs.eplb_config, help="Config of eplb.", ) + parallel_group.add_argument( + "--routing-replay-config", + type=json.loads, + default=EngineArgs.routing_replay_config, + help="Flag of rollout routing replay(r3).", + ) # Load group load_group = parser.add_argument_group("Load Configuration") @@ -1076,6 +1088,14 @@ class EngineArgs: eplb_args[k] = v return EPLBConfig(eplb_args) + def create_routing_repaly_config(self) -> RoutingReplayConfig: + """ """ + routing_replay_args = asdict(self) + if self.routing_replay_config is not None: + for k, v in self.routing_replay_config.items(): + routing_replay_args[k] = v + return RoutingReplayConfig(routing_replay_args) + def create_engine_config(self, port_availability_check: bool = True) -> FDConfig: """ Create and return a Config object based on the current settings. 
@@ -1118,6 +1138,7 @@ class EngineArgs: graph_opt_cfg.update_use_cudagraph(self.use_cudagraph) moba_attention_config = self.create_moba_attention_config() eplb_cfg = self.create_eplb_config() + routing_replay_config = self.create_routing_repaly_config() early_stop_cfg = self.create_early_stop_config() early_stop_cfg.update_enable_early_stop(self.enable_early_stop) @@ -1163,4 +1184,5 @@ class EngineArgs: early_stop_config=early_stop_cfg, enable_attention_dp_balance=self.enable_attention_dp_balance, attention_dp_time_out_iters=self.attention_dp_time_out_iters, + routing_replay_config=routing_replay_config, ) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index b26337da3..4b712614d 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -462,6 +462,7 @@ class LLMEngine: f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'" f" --attention_dp_time_out_iters {self.cfg.attention_dp_time_out_iters}" f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'" + f" --routing_replay_config '{self.cfg.routing_replay_config.to_json_string()}'" f" --ips {ips}" ) diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 06ef4b755..e66289fdb 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -110,6 +110,8 @@ class ForwardMeta: block_tables: Optional[paddle.Tensor] = None # KV caches caches: Optional[list[paddle.Tensor]] = None + # Routing Replay table buffer + routing_replay_table: Optional[paddle.Tensor] = None def clear_caches(self): """Safely clean up the caches""" diff --git a/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py b/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py index f1ea6572f..5c9879963 100644 --- a/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py +++ 
b/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py @@ -14,6 +14,8 @@ # limitations under the License. """ +from typing import Callable + import paddle from paddle import nn @@ -102,6 +104,7 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Triton compute Fused MoE. @@ -119,6 +122,9 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase): topk_weights, topk_ids = paddle.topk(scores, k=top_k, axis=-1, sorted=False) topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + intermediate_cache1 = paddle.empty( [token_num * top_k, moe_intermediate_size * 2], dtype=x.dtype, diff --git a/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py b/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py index c13a68f31..66464a037 100644 --- a/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py +++ b/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py @@ -16,6 +16,7 @@ import multiprocessing import os +from typing import Callable import numpy as np import paddle @@ -189,6 +190,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle gcu compute Fused MoE. @@ -201,6 +203,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP prefill method. @@ -212,6 +215,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP decoder method. 
@@ -223,6 +227,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle Cutlass compute Fused MoE. @@ -388,6 +393,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle gcu compute Fused MoE. diff --git a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py index e945a189a..78ac8982d 100644 --- a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py @@ -14,6 +14,8 @@ # limitations under the License. """ +from typing import Callable + import paddle from paddle import nn @@ -132,6 +134,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Triton compute Fused MoE. 
@@ -151,6 +154,10 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase): True, # apply_norm_weight, False, ) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + up_gate_proj_out = paddle.empty( [token_num * top_k, moe_intermediate_size * 2], dtype=x.dtype, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py index 039bb2b13..0c5bccc3d 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py @@ -15,6 +15,7 @@ """ from abc import abstractmethod +from typing import Callable import paddle from paddle import nn @@ -120,6 +121,7 @@ class MoEMethodBase(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP prefill method. @@ -144,6 +146,7 @@ class MoEMethodBase(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle Cutlass compute Fused MoE. @@ -155,6 +158,7 @@ class MoEMethodBase(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle Cutlass compute Fused MoE. 
@@ -163,13 +167,13 @@ class MoEMethodBase(QuantMethodBase): if layer.fd_config.parallel_config.moe_phase.phase == "prefill": if layer.fd_config.parallel_config.splitwise_role == "mixed" and layer.layer_idx == 0: self.ep_prefill_runner.clean_low_latency_buffer() - return self.apply_ep_prefill(layer, x, gate) + return self.apply_ep_prefill(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc) else: if layer.fd_config.parallel_config.splitwise_role == "mixed" and layer.layer_idx == 0: self.ep_decoder_runner.clean_low_latency_buffer() - return self.apply_ep_decode(layer, x, gate) + return self.apply_ep_decode(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc) else: - return self.apply_tp(layer, x, gate) + return self.apply_tp(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc) class UnquantizedFusedMoEMethod(MoEMethodBase): diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index b547e1129..35680c5f1 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -14,6 +14,8 @@ # limitations under the License. """ +from typing import Callable + import paddle from paddle import nn from paddle.nn.quant import weight_quantize @@ -105,6 +107,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP prefill method. @@ -121,6 +124,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): handle, _, ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + token_all_num = sum(recv_num_tokens_per_expert_list) # 3. 
Compute ffn @@ -178,6 +185,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP decoder method. @@ -186,6 +194,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): estimate_total_token_nums = gate_out.shape[0] * layer.top_k # 1. Select topk experts and weights topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + expertwise_scale = None if hasattr(layer, "up_gate_proj_in_scale_all_experts"): # only use in w4a8 expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None) @@ -220,6 +232,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle Cutlass compute Fused MoE. @@ -277,6 +290,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod): topk_only_mode=False, ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + if self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8": # only w4a8 need expert_idx_per_token # Other need not this tensor, so we make it None. diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index e76bdf55c..1f3f46796 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -14,6 +14,8 @@ # limitations under the License. """ +from typing import Callable + import paddle from paddle import nn from paddleformers.utils.log import logger @@ -298,6 +300,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP prefill method. 
@@ -305,6 +308,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): gate_out = gate(x.cast("float32")) # 1. Select topk experts and weights topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + # 2. Dynamic compute blockwise quantization scales x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant( x, self.quant_config.weight_block_size[0] @@ -406,6 +413,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP decoder method. @@ -413,6 +421,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): gate_out = gate(x.cast("float32")) # 1. Select topk experts and weights topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + # 2. EP Dispatch permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch( x, topk_idx, topk_weights, use_fp8=True @@ -477,6 +489,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle Use DeepGemm compute Fused MoE. @@ -504,6 +517,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): False, ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts) recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, 128) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py index 705dfee92..43ca6f892 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py @@ -14,6 +14,8 @@ # limitations under the License. 
""" +from typing import Callable + import paddle from paddle import nn @@ -240,6 +242,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Marlin compute Fused MoE. @@ -275,6 +278,9 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase): False, ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + block_size_m = 64 for m in [8, 16, 32, 48, 64]: diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index e9bf781a2..6a28db5c7 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -14,6 +14,8 @@ # limitations under the License. """ +from typing import Callable + import paddle from paddle import nn @@ -156,6 +158,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Triton compute Fused MoE. @@ -186,6 +189,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase): False, ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + up_gate_proj_out = paddle.empty( [token_num * top_k, moe_intermediate_size * 2], dtype=x.dtype, @@ -420,6 +426,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Triton compute Fused MoE. 
@@ -451,6 +458,9 @@ class TensorWiseFP8MoEMethod(QuantMethodBase): False, ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + up_gate_proj_out = paddle.empty( [token_num * top_k, moe_intermediate_size * 2], dtype=x.dtype, @@ -840,6 +850,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Triton compute Fused MoE. @@ -871,6 +882,9 @@ class BlockWiseFP8MoEMethod(QuantMethodBase): False, ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + config = { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": self.quant_config.weight_block_size[1], diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py index 7cbb46dc1..f7bbba151 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py @@ -14,6 +14,8 @@ # limitations under the License. """ +from typing import Callable + import paddle from paddle import nn @@ -261,6 +263,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Use Wint2 Triton Fusedmoe compute Fused MoE. @@ -288,6 +291,9 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod): topk_only_mode=False, ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + ffn_out = fastdeploy.model_executor.ops.gpu.moe_expert_ffn_wint2( permute_input, token_nums_per_expert, @@ -333,6 +339,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Use Wint2 Triton Fusedmoe compute Fused MoE. 
@@ -348,6 +355,9 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod): False, ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + num_tokens, K = x.shape E, _, N = layer.up_gate_proj_weight.shape M = num_tokens diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_xpu_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_xpu_backend.py index 272899531..2a2fd0cab 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_xpu_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_xpu_backend.py @@ -14,6 +14,8 @@ # limitations under the License. """ +from typing import Callable + import paddle from paddle import nn @@ -47,6 +49,7 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Paddle Cutlass compute Fused MoE. @@ -82,6 +85,7 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP prefill method. @@ -93,6 +97,7 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ Apply the EP decoder method. @@ -227,6 +232,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase): layer: nn.Layer, x: paddle.Tensor, gate: nn.Layer, + topk_ids_hookfunc: Callable = None, ) -> paddle.Tensor: """ XPU compute Fused MoE. diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 8b83aecca..5afd22688 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -14,6 +14,7 @@ # limitations under the License. 
""" +from functools import partial from typing import Optional import numpy as np @@ -22,6 +23,10 @@ from paddle import nn from paddleformers.utils.log import logger from fastdeploy import envs +from fastdeploy.model_executor.forward_meta import ForwardMeta +from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( + save_routing_to_buffer, +) from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.utils import slice_fn from fastdeploy.platforms import current_platform @@ -195,6 +200,7 @@ class FusedMoE(nn.Layer): if self.ep_size > 1: self.quant_method.init_ep(self) + self.enable_routing_replay = fd_config.routing_replay_config.enable_routing_replay # Merge normal and RL build model if gate_correction_bias is not None: self.gate_correction_bias = gate_correction_bias @@ -532,7 +538,7 @@ class FusedMoE(nn.Layer): else: self.quant_method.process_loaded_weights(self, state_dict) - def forward(self, x: paddle.Tensor, gate: nn.Layer): + def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None): """ Defines the forward computation of the moe layer. 
@@ -543,5 +549,20 @@ class FusedMoE(nn.Layer): Tensor: Output tensor.s """ - out = self.quant_method.apply(self, x, gate) + topk_ids_hookfunc = None + if self.enable_routing_replay: + if forward_meta is not None: # forward_meta is None when execute empty_input_forward + topk_ids_hookfunc = partial( + save_routing_to_buffer, + routing_replay_table=forward_meta.routing_replay_table, + batch_id_per_token=forward_meta.batch_id_per_token, + seq_lens_decoder=forward_meta.seq_lens_decoder, + cu_seqlens_q=forward_meta.cu_seqlens_q, + layer_idx=self.layer_idx, + tp_size=self.fd_config.parallel_config.tensor_parallel_size, + ep_size=self.fd_config.parallel_config.expert_parallel_size, + tp_group=self.fd_config.parallel_config.tp_group, + ) + + out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc) return out diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py new file mode 100644 index 000000000..d4d971a93 --- /dev/null +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -0,0 +1,328 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import copy +import os +import shutil +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + +import paddle +import paddle.distributed as dist +import triton +import triton.language as tl + +from fastdeploy.config import FDConfig + + +@triton.jit +def _save_routing_kernel( + ROUTING_REPLAY_TABLE_PTR, + TOPK_IDS_PTR, + BATCH_ID_PER_TOKEN_PTR, + CU_SEQLENS_Q_PTR, + SEQ_LENS_DECODER_PTR, + LAYER_IDX, + TOKEN_NUM, + TOP_K, + NUM_HIDDEN_LAYERS, + MAX_MODEL_LEN, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + + token_offsets = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + token_mask = token_offsets < TOKEN_NUM + + k_offsets = tl.arange(0, BLOCK_SIZE_K) + topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :] + # [BLOCK_SIZE_M, BLOCK_SIZE_K] + topk_vals = tl.load(topk_ids_ptrs, mask=token_mask[:, None]) + + batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask) + pad_mask = token_mask & (batch_ids != -1) + # [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3] + # -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10] + # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10] + # -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1] + start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask) + token_relative_index = token_offsets - start_offsets + + # [BLOCK_SIZE_M] + len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask) + token_seq_pos = len_decoder + token_relative_index + + STRIDE_BUF_SEQ = NUM_HIDDEN_LAYERS * MAX_MODEL_LEN * TOP_K + STRIDE_BUF_LAYER = MAX_MODEL_LEN * TOP_K + STRIDE_BUF_TOKEN = TOP_K + + # [BLOCK_SIZE_M, BLOCK_SIZE_K] + output_ptrs = ( + ROUTING_REPLAY_TABLE_PTR + + batch_ids[:, None] * STRIDE_BUF_SEQ + + LAYER_IDX * STRIDE_BUF_LAYER + + token_seq_pos[:, None] * STRIDE_BUF_TOKEN + + k_offsets[None, :] + ) + + pos_mask = token_seq_pos < MAX_MODEL_LEN + pos_mask = pos_mask & pad_mask + 
final_mask = token_mask[:, None] & pos_mask[:, None] + + tl.store(output_ptrs, topk_vals, mask=final_mask) + + +def save_routing_to_buffer( + routing_replay_table: paddle.Tensor, # [max_num_seqs, num_layers, max_len, top_k] + topk_ids: paddle.Tensor, # [token_num, top_k] + batch_id_per_token: paddle.Tensor, # [token_num, 1] + seq_lens_decoder: paddle.Tensor, # [max_num_seqs, 1] + cu_seqlens_q: paddle.Tensor, # [max_num_seqs + 1, 1] + layer_idx: int, + tp_size: int, + ep_size: int, + tp_group: dist.communication.group.Group, +): + if tp_size > 1 and ep_size > 1: + token_num_per_rank = topk_ids.shape[0] + topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype) + paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group) + topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :] + + token_num, top_k = topk_ids.shape + max_num_seqs, num_hidden_layers, max_model_len, _ = routing_replay_table.shape + assert token_num > 0 + assert topk_ids.shape[1] == routing_replay_table.shape[3], (topk_ids.shape[1], routing_replay_table.shape[3]) + assert batch_id_per_token.shape[0] == token_num, (batch_id_per_token.shape[0], token_num) + assert seq_lens_decoder.shape[0] == max_num_seqs, (seq_lens_decoder.shape[0], max_num_seqs) + + BLOCK_SIZE_M = 128 + BLOCK_SIZE_K = top_k + + grid = (triton.cdiv(token_num, BLOCK_SIZE_M),) + _save_routing_kernel[grid]( + routing_replay_table, + topk_ids, + batch_id_per_token, + cu_seqlens_q, + seq_lens_decoder, + LAYER_IDX=layer_idx, + TOKEN_NUM=token_num, + TOP_K=top_k, + NUM_HIDDEN_LAYERS=num_hidden_layers, + MAX_MODEL_LEN=max_model_len, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + + +class RoutingReplayManager: + """Request level routing replay table manager""" + + def __init__( + self, + fd_config: FDConfig, + ): + self.max_num_seqs = fd_config.parallel_config.max_num_seqs + self.max_model_len = fd_config.model_config.max_model_len + self.num_moe_layers = 
fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index + self.moe_top_k = fd_config.model_config.moe_k + self.tp_rank = fd_config.parallel_config.tensor_parallel_rank + + self.routing_store = get_routing_store(fd_config=fd_config) + self.routing_batch_to_request: Dict[int, str] = {} + self.routing_replay_table = paddle.full( + shape=[self.max_num_seqs, self.num_moe_layers, self.max_model_len, self.moe_top_k], + fill_value=-1, + dtype="int32", + ) + + def register_request(self, batch_id: int, request_id: str): + """ + Register a new request to routing replay table + Args: + batch_id: The batch ID of this request + request_id: The global ID of the request is usually executed by the training process in RL + """ + # Save requests that have been finished for the current slot + if batch_id in self.routing_batch_to_request: + pre_request_id = self._deregister_request(batch_id) + self._put_request_to_store(batch_id, pre_request_id) + # Register the new request + self.routing_batch_to_request[batch_id] = request_id + + def _deregister_request(self, batch_id: int) -> str: + """ + Deregister a request from routing replay table + """ + assert batch_id in self.routing_batch_to_request + return self.routing_batch_to_request.pop(batch_id) + + def _put_request_to_store( + self, + batch_id: int, + request_id: str, + ): + if self.tp_rank == 0: + batch_buffer = self.routing_replay_table[batch_id] + for layer_id in range(self.num_moe_layers): + layer_buffer = batch_buffer[layer_id] + rollout_id = self.split_request_id(request_id) + self.routing_store.put(routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id) + + self._clear_table_slot(batch_id) + + def put_table_to_store(self): + """Put the routin table""" + batch_ids = copy.deepcopy(list(self.routing_batch_to_request.keys())) + for batch_id in batch_ids: + request_id = self._deregister_request(batch_id) + self._put_request_to_store(batch_id, request_id) + + def 
_clear_table_slot(self, batch_id: int): + assert 0 <= batch_id < self.max_num_seqs + self.routing_replay_table[batch_id].fill_(-1) + + def clear_routing_table(self): + """Clear all slots of the routing replay table""" + self.routing_replay_table.fill_(-1) + + def _clear_store(self): + """Clear routing store""" + self.routing_store.clear_store() + + def _clear_request_of_store(self, request_id): + """Clear one request of routing store""" + rollout_id = self.split_request_id(request_id) + for layer_idx in range(self.num_moe_layers): + self.routing_store.clear(rollout_id=rollout_id, layer_idx=layer_idx) + + def get_request_from_store(self, request_id: str) -> List[paddle.Tensor]: + """Get the routing indices of the reuest from store""" + routing_list = [] + rollout_id = self.split_request_id(request_id) + for layer_idx in range(self.num_moe_layers): + one_layer_routing = self.routing_store.get(rollout_id, layer_idx) + routing_list.append(one_layer_routing) + + return routing_list + + def get_routing_table(self) -> paddle.Tensor: + return self.routing_replay_table + + def split_request_id(self, request_id: str): + """Split the request id to get rollout id""" + chat_type, tmp_str = request_id.split("-", 1) + assert chat_type == "chatcmpl" + reversed_tmp_str = tmp_str[::-1].split("-", 5) + rollout_id = reversed_tmp_str[-1][::-1] + return rollout_id + + +class RoutingStoreBase(ABC): + """Base class for routing store""" + + def __init__(self, fd_config: FDConfig) -> None: + self.fd_config = fd_config + + @abstractmethod + def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: Optional[int] = None) -> None: + """Put the routing indices into store""" + raise NotImplementedError + + @abstractmethod + def get(self, rollout_id: str, layer_idx: Optional[int] = None) -> paddle.Tensor: + """Get the routing indices from store""" + raise NotImplementedError + + @abstractmethod + def clear(self, rollout_id: str, layer_idx: Optional[int] = None) -> None: + """Clear 
class RoutingStoreBase(ABC):
    """Base class for routing stores.

    A store keeps routing indices keyed by ``rollout_id`` and ``layer_idx``.
    Subclasses must implement ``put``, ``get``, ``clear`` and ``clear_store``.
    """

    def __init__(self, fd_config: "FDConfig") -> None:
        self.fd_config = fd_config

    @abstractmethod
    def put(self, routing_indices: "paddle.Tensor", rollout_id: str, layer_idx: Optional[int] = None) -> None:
        """Put the routing indices into store"""
        raise NotImplementedError

    @abstractmethod
    def get(self, rollout_id: str, layer_idx: Optional[int] = None) -> "paddle.Tensor":
        """Get the routing indices from store"""
        raise NotImplementedError

    @abstractmethod
    def clear(self, rollout_id: str, layer_idx: Optional[int] = None) -> None:
        """Clear the routing indices of the request"""
        raise NotImplementedError

    @abstractmethod
    def clear_store(self) -> None:
        """Clear the routing indices store"""
        raise NotImplementedError


class RoutingStoreLocal(RoutingStoreBase):
    """Routing store backed by files on the local filesystem.

    Each entry is saved with ``paddle.save`` to
    ``<local_store_dir>/<rollout_id>/layer_<layer_idx>.pdtensor``.
    """

    def __init__(self, fd_config) -> None:
        super().__init__(fd_config=fd_config)
        self.local_store_dir = fd_config.routing_replay_config.local_store_dir

    def _entry_paths(self, rollout_id: str, layer_idx: Optional[int]) -> tuple:
        """Return ``(request_dir, layer_file)`` for one store entry."""
        dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
        file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
        return dir_path, file_path

    def put(self, routing_indices: "paddle.Tensor", rollout_id: str, layer_idx: int) -> None:
        """Put the routing indices into store, creating the request directory if needed."""
        dir_path, file_path = self._entry_paths(rollout_id, layer_idx)
        os.makedirs(dir_path, exist_ok=True)
        paddle.save(routing_indices, file_path)

    def get(
        self,
        rollout_id: str,
        layer_idx: Optional[int] = None,
    ) -> "paddle.Tensor":
        """Get the routing indices from store.

        Raises:
            FileNotFoundError: If no file exists for this rollout id / layer.
        """
        _, file_path = self._entry_paths(rollout_id, layer_idx)
        # Explicit raise rather than assert so the check survives ``python -O``.
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")
        return paddle.load(file_path)

    def clear(
        self,
        rollout_id: str,
        layer_idx: Optional[int] = None,
    ) -> None:
        """Clear the routing indices of the request for one layer.

        The request directory is removed as well once its last file is gone.

        Raises:
            FileNotFoundError: If no file exists for this rollout id / layer.
        """
        dir_path, file_path = self._entry_paths(rollout_id, layer_idx)
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")
        os.remove(file_path)

        # Delete empty directory
        if not os.listdir(dir_path):
            os.rmdir(dir_path)

    def clear_store(self) -> None:
        """Clear the routing indices store.

        Removes everything inside ``local_store_dir`` but keeps the directory
        itself.  Does nothing if the directory does not exist.
        """
        if not os.path.isdir(self.local_store_dir):
            return
        for entry_name in os.listdir(self.local_store_dir):
            entry_path = os.path.join(self.local_store_dir, entry_name)
            # shutil.rmtree only accepts real directories; stray files and
            # symlinks must be unlinked instead or the cleanup would abort.
            if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                shutil.rmtree(entry_path)
            else:
                os.remove(entry_path)


class RoutingStoreRDMA(RoutingStoreBase):
    """Routing Store using RDMA.

    Placeholder: none of the storage methods are implemented here, so the
    class still carries the base class's abstract methods and cannot be
    instantiated until a subclass (or this class) provides them.
    """

    def __init__(self, fd_config: "FDConfig") -> None:
        # Accept and forward fd_config: the base initializer requires it and
        # get_routing_store() passes it by keyword.
        super().__init__(fd_config=fd_config)


def get_routing_store(fd_config: "FDConfig") -> RoutingStoreBase:
    """Create the routing store selected by ``routing_replay_config.routing_store_type``.

    Args:
        fd_config: Configuration object; only its ``routing_replay_config``
            attribute is read here, and the whole object is handed to the store.

    Returns:
        A ``RoutingStoreLocal`` for ``"local"``.

    Raises:
        ValueError: If the store type is neither ``"local"`` nor ``"rdma"``.
        TypeError: For ``"rdma"``, because ``RoutingStoreRDMA`` is still abstract.
    """
    store_type = fd_config.routing_replay_config.routing_store_type
    if store_type == "local":
        return RoutingStoreLocal(fd_config=fd_config)
    if store_type == "rdma":
        return RoutingStoreRDMA(fd_config=fd_config)
    raise ValueError("Invalid store type")