[RL] Support Rollout Routing Replay (#5321)

* [RL] Support Rollout Routing Replay

* add routing indices cache

* fix config bug and moe forward bug

* R3 Support GLM

* support eb4.5

* fix merge bug

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* add routing replay ci

* support glm topk

* support orther top_k

* fix ci bug

* pre-commit

* only support chatcmpl

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
RAM
2025-12-05 20:01:33 +08:00
committed by GitHub
parent 8545b705ed
commit 96d2d4877b
24 changed files with 592 additions and 24 deletions

View File

@@ -1484,6 +1484,31 @@ class StructuredOutputsConfig:
return json.dumps({key: value for key, value in self.__dict__.items()})
class RoutingReplayConfig:
"""Configuration for Routing Replay used in RL training"""
def __init__(self, args) -> None:
self.enable_routing_replay: bool = False
self.routing_store_type: str = "local"
# Local routing store
self.local_store_dir: str = "./routing_replay_output"
# RDMA routing store
# TODO: Add RDMA routing store configuration attributes here when the feature is implemented.
if args is not None:
for key, value in args.items():
if hasattr(self, key) and value != "None":
setattr(self, key, value)
def to_json_string(self):
"""
Convert routing replay config to json string.
"""
return json.dumps({key: value for key, value in self.__dict__.items()})
class FDConfig:
"""
The configuration class which contains all fastdeploy-related configuration. This
@@ -1517,6 +1542,7 @@ class FDConfig:
early_stop_config: Optional[Dict[str, Any]] = None,
tool_parser: str = None,
test_mode=False,
routing_replay_config: Optional[RoutingReplayConfig] = None,
):
self.model_config: ModelConfig = model_config # type: ignore
self.cache_config: CacheConfig = cache_config # type: ignore
@@ -1533,6 +1559,7 @@ class FDConfig:
self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config
self.router_config: RouterConfig = router_config
self.routing_replay_config = routing_replay_config
# Initialize cuda graph capture list
max_capture_shape = self.scheduler_config.max_num_seqs

View File

@@ -35,6 +35,7 @@ from fastdeploy.config import (
PlasAttentionConfig,
PoolerConfig,
RouterConfig,
RoutingReplayConfig,
RunnerOption,
SpeculativeConfig,
StructuredOutputsConfig,
@@ -491,6 +492,11 @@ class EngineArgs:
Configuration for eplb.
"""
routing_replay_config: Optional[Dict[str, Any]] = None
"""
Flag to rollout routing replay(r3)
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -882,6 +888,12 @@ class EngineArgs:
default=EngineArgs.eplb_config,
help="Config of eplb.",
)
parallel_group.add_argument(
"--routing-replay-config",
type=json.loads,
default=EngineArgs.routing_replay_config,
help="Flag of rollout routing replay(r3).",
)
parallel_group.add_argument(
"--enable-chunked-moe",
action="store_true",
@@ -1235,6 +1247,14 @@ class EngineArgs:
eplb_args["enable_eplb"] = self.enable_eplb
return EPLBConfig(eplb_args)
def create_routing_repaly_config(self) -> RoutingReplayConfig:
""" """
routing_replay_args = asdict(self)
if self.routing_replay_config is not None:
for k, v in self.routing_replay_config.items():
routing_replay_args[k] = v
return RoutingReplayConfig(routing_replay_args)
def create_engine_config(self, port_availability_check=True) -> FDConfig:
"""
Create and return a Config object based on the current settings.
@@ -1278,6 +1298,7 @@ class EngineArgs:
graph_opt_cfg = self.create_graph_optimization_config()
plas_attention_config = self.create_plas_attention_config()
eplb_cfg = self.create_eplb_config()
routing_replay_config = self.create_routing_repaly_config()
router_config = RouterConfig(all_dict)
early_stop_cfg = self.create_early_stop_config()
@@ -1310,4 +1331,5 @@ class EngineArgs:
graph_opt_config=graph_opt_cfg,
plas_attention_config=plas_attention_config,
early_stop_config=early_stop_cfg,
routing_replay_config=routing_replay_config,
)

View File

@@ -568,6 +568,7 @@ class LLMEngine:
f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
f" --max_logprobs {self.cfg.model_config.max_logprobs}"
f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'"
f" --routing_replay_config '{self.cfg.routing_replay_config.to_json_string()}'"
)
if self.cfg.structured_outputs_config.logits_processors is not None:
arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"

View File

@@ -142,6 +142,8 @@ class ForwardMeta:
caches: Optional[list[paddle.Tensor]] = None
# Flag of profile run
is_dummy_or_profile_run: bool = False
# Routing Replay table buffer
routing_replay_table: Optional[paddle.Tensor] = None
# chunked MoE related
moe_num_chunk: int = 1

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
@@ -101,6 +103,7 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
@@ -117,6 +120,8 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
scores += layer.gate_correction_bias
topk_weights, topk_ids = paddle.topk(scores, k=top_k, axis=-1, sorted=False)
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
intermediate_cache1 = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],

View File

@@ -16,6 +16,7 @@
import multiprocessing
import os
from typing import Callable
import numpy as np
import paddle
@@ -182,6 +183,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle gcu compute Fused MoE.
@@ -194,6 +196,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -205,6 +208,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -216,6 +220,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
@@ -381,6 +386,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle gcu compute Fused MoE.

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
@@ -245,6 +247,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
@@ -274,6 +277,9 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
True, # apply_norm_weight
False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,

View File

@@ -15,6 +15,7 @@
"""
from abc import abstractmethod
from typing import Callable
import paddle
from paddle import nn
@@ -163,6 +164,7 @@ class MoEMethodBase(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -175,6 +177,7 @@ class MoEMethodBase(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -187,6 +190,7 @@ class MoEMethodBase(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
@@ -198,6 +202,7 @@ class MoEMethodBase(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
@@ -207,13 +212,13 @@ class MoEMethodBase(QuantMethodBase):
if layer.fd_config.model_config.moe_phase.phase == "prefill":
if layer.fd_config.scheduler_config.splitwise_role == "mixed" and is_moe_start_layer:
self.ep_prefill_runner.clean_low_latency_buffer()
return self.apply_ep_prefill(layer, x, gate)
return self.apply_ep_prefill(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
else:
if layer.fd_config.scheduler_config.splitwise_role == "mixed" and is_moe_start_layer:
self.ep_decoder_runner.clean_low_latency_buffer()
return self.apply_ep_decode(layer, x, gate)
return self.apply_ep_decode(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
else:
return self.apply_tp(layer, x, gate)
return self.apply_tp(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
class UnquantizedFusedMoEMethod(MoEMethodBase):

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
from paddle.nn.quant import weight_quantize
@@ -132,6 +134,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -148,8 +151,13 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
handle,
event,
) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
if self.ep_prefill_runner.ep_engine.async_finish:
event.current_stream_wait()
token_all_num = sum(recv_num_tokens_per_expert_list)
# 3. Compute ffn
@@ -217,6 +225,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -225,6 +234,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
estimate_total_token_nums = gate_out.shape[0] * layer.top_k
# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
expertwise_scale = None
if hasattr(layer, "up_gate_proj_in_scale_all_experts"): # only use in w4a8
expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
@@ -269,6 +282,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
@@ -369,6 +383,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
if hasattr(layer, "up_gate_proj_in_scale"):
dequant_scale = None
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
# only w4a8 need expert_idx_per_token
# Other need not this tensor, so we make it None.

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
from paddle.distributed.communication import deep_ep
@@ -139,6 +141,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -147,6 +150,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
# 2. Dynamic compute blockwise quantization scales
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
x, self.quant_config.weight_block_size[0]
@@ -264,6 +271,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -271,6 +279,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
gate_out = gate(x.cast("float32"))
# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
# 2. EP Dispatch
permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
x, topk_idx, topk_weights, use_fp8=True
@@ -335,6 +347,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Use DeepGemm compute Fused MoE.
@@ -363,6 +376,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, 128)

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
@@ -239,6 +241,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Marlin compute Fused MoE.
@@ -273,6 +276,9 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
block_size_m = 64
for m in [8, 16, 32, 48, 64]:

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
@@ -282,6 +284,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
@@ -314,6 +317,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
True, # apply_norm_weight,
False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
@@ -664,6 +671,7 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
@@ -724,6 +732,9 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
* ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]),
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
@@ -953,6 +964,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
@@ -974,6 +986,9 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
@@ -1466,6 +1481,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
@@ -1488,6 +1504,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
True, # apply_norm_weight
False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
config = {
"BLOCK_SIZE_M": 64,

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
@@ -261,6 +263,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Use Wint2 Triton Fusedmoe compute Fused MoE.
@@ -288,6 +291,9 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
topk_only_mode=False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
ffn_out = fastdeploy.model_executor.ops.gpu.moe_expert_ffn_wint2(
permute_input,
token_nums_per_expert,
@@ -328,6 +334,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Use Wint2 Triton Fusedmoe compute Fused MoE.
@@ -343,6 +350,9 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
num_tokens, K = x.shape
E, _, N = layer.up_gate_proj_weight.shape
M = num_tokens

View File

@@ -14,7 +14,8 @@
# limitations under the License.
"""
from typing import Optional
from functools import partial
from typing import Callable, Optional
import paddle
from paddle import nn
@@ -26,6 +27,9 @@ from fastdeploy.distributed.communication import (
tensor_model_parallel_all_reduce_custom,
)
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
save_routing_to_buffer,
)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import h2d_copy, slice_fn
from fastdeploy.platforms import current_platform
@@ -226,7 +230,7 @@ class FusedMoE(nn.Layer):
self.is_rearrange = False
if self.ep_size > 1:
self.quant_method.init_ep(self)
self.enable_routing_replay = fd_config.routing_replay_config.enable_routing_replay
# Merge normal and RL build model
if gate_correction_bias is not None:
self.gate_correction_bias = gate_correction_bias
@@ -600,7 +604,7 @@ class FusedMoE(nn.Layer):
else:
self.quant_method.process_loaded_weights(self, state_dict)
def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer):
def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None):
"""
Forward split allgather function.
"""
@@ -615,14 +619,14 @@ class FusedMoE(nn.Layer):
if end_offset > token_num:
end_offset = token_num
part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
out = self.quant_method.apply(self, part_x, gate)
out = self.quant_method.apply(self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
multi_outs = paddle.zeros([token_num_per_rank * self.attn_tp_size, x.shape[1]], dtype=x.dtype)
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
out = multi_outs[:token_num, :]
return out
def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None):
"""
Defines the forward computation of the moe layer.
@@ -633,6 +637,21 @@ class FusedMoE(nn.Layer):
Tensor: Output tensor.s
"""
topk_ids_hookfunc = None
if self.enable_routing_replay:
if forward_meta is not None: # forward_meta is None when execute empty_input_forward
topk_ids_hookfunc = partial(
save_routing_to_buffer,
routing_replay_table=forward_meta.routing_replay_table,
batch_id_per_token=forward_meta.batch_id_per_token,
seq_lens_decoder=forward_meta.seq_lens_decoder,
cu_seqlens_q=forward_meta.cu_seqlens_q,
layer_idx=self.layer_idx,
tp_size=self.fd_config.parallel_config.tensor_parallel_size,
ep_size=self.fd_config.parallel_config.expert_parallel_size,
tp_group=self.fd_config.parallel_config.tp_group,
)
token_num = x.shape[0]
if (
self.ep_size > 1
@@ -640,11 +659,16 @@ class FusedMoE(nn.Layer):
and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
and token_num >= self.attn_tp_size
):
out = self.forward_split_allgather(x, gate)
out = self.forward_split_allgather(x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
out = self.forward_chunked_moe(x, gate, forward_meta)
out = self.forward_chunked_moe(
x,
gate,
forward_meta,
topk_ids_hookfunc=topk_ids_hookfunc,
)
else:
out = self.forward_normal(x, gate)
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)
if self.reduce_results and self.tp_size > 1:
if current_platform.is_intel_hpu():
@@ -653,7 +677,9 @@ class FusedMoE(nn.Layer):
out = tensor_model_parallel_all_reduce(out, self.tp_group)
return out
def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
def forward_chunked_moe(
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
):
"""
Split input to multi chunk to reduce the memory usage of moe.
@@ -677,21 +703,25 @@ class FusedMoE(nn.Layer):
for i in range(forward_meta.max_moe_num_chunk):
if i < forward_meta.moe_num_chunk:
out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
out_split_list[i] = self.quant_method.apply(
self, x_split_list[i], gate, topk_ids_hookfunc=topk_ids_hookfunc
)
else:
# just need to use real data to infer max_moe_num_chunk times.
self.quant_method.apply(self, fake_x, gate)
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
out = paddle.concat(out_split_list, axis=0)
else:
# when only one chunk, just need to use real data to infer once.
out = self.quant_method.apply(self, x, gate)
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
for i in range(forward_meta.max_moe_num_chunk - 1):
self.quant_method.apply(self, fake_x, gate)
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
return out
def forward_normal(self, x: paddle.Tensor, gate: nn.Layer):
def forward_normal(
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
):
"""
Normal mode of forward.
@@ -702,5 +732,5 @@ class FusedMoE(nn.Layer):
Tensor: Output tensor.s
"""
out = self.quant_method.apply(self, x, gate)
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
return out

View File

@@ -0,0 +1,346 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import copy
import os
import shutil
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import paddle
import paddle.distributed as dist
import triton
import triton.language as tl
from fastdeploy.config import FDConfig
@triton.jit
def _save_routing_kernel(
ROUTING_REPLAY_TABLE_PTR,
TOPK_IDS_PTR,
BATCH_ID_PER_TOKEN_PTR,
CU_SEQLENS_Q_PTR,
SEQ_LENS_DECODER_PTR,
LAYER_IDX,
TOKEN_NUM,
TOP_K,
NUM_HIDDEN_LAYERS,
MAX_MODEL_LEN,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
token_offsets = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
token_mask = token_offsets < TOKEN_NUM
k_offsets = tl.arange(0, BLOCK_SIZE_K)
k_mask = k_offsets < TOP_K
topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :]
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
load_mask = token_mask[:, None] & k_mask[None, :]
topk_vals = tl.load(topk_ids_ptrs, mask=load_mask)
batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask)
pad_mask = token_mask & (batch_ids != -1)
# [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3]
# -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
# -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1]
start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask)
token_relative_index = token_offsets - start_offsets
# [BLOCK_SIZE_M]
len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask)
token_seq_pos = len_decoder + token_relative_index
STRIDE_BUF_SEQ = NUM_HIDDEN_LAYERS * MAX_MODEL_LEN * TOP_K
STRIDE_BUF_LAYER = MAX_MODEL_LEN * TOP_K
STRIDE_BUF_TOKEN = TOP_K
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
output_ptrs = (
ROUTING_REPLAY_TABLE_PTR
+ batch_ids[:, None] * STRIDE_BUF_SEQ
+ LAYER_IDX * STRIDE_BUF_LAYER
+ token_seq_pos[:, None] * STRIDE_BUF_TOKEN
+ k_offsets[None, :]
)
pos_mask = token_seq_pos < MAX_MODEL_LEN
pos_mask = pos_mask & pad_mask
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
pos_mask = pos_mask[:, None] & k_mask[None, :]
final_mask = load_mask & pos_mask
tl.store(output_ptrs, topk_vals, mask=final_mask)
def save_routing_to_buffer(
routing_replay_table: paddle.Tensor, # [max_num_seqs, num_layers, max_len, top_k]
topk_ids: paddle.Tensor, # [token_num, top_k]
batch_id_per_token: paddle.Tensor, # [token_num, 1]
seq_lens_decoder: paddle.Tensor, # [max_num_seqs, 1]
cu_seqlens_q: paddle.Tensor, # [max_num_seqs + 1, 1]
layer_idx: int,
tp_size: int,
ep_size: int,
tp_group: dist.communication.group.Group,
):
if tp_size > 1 and ep_size > 1:
token_num_per_rank = topk_ids.shape[0]
topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype)
paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group)
topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :]
token_num, top_k = topk_ids.shape
max_num_seqs, num_hidden_layers, max_model_len, _ = routing_replay_table.shape
assert token_num > 0
assert topk_ids.shape[1] == routing_replay_table.shape[3], (topk_ids.shape[1], routing_replay_table.shape[3])
assert batch_id_per_token.shape[0] == token_num, (batch_id_per_token.shape[0], token_num)
assert seq_lens_decoder.shape[0] == max_num_seqs, (seq_lens_decoder.shape[0], max_num_seqs)
BLOCK_SIZE_M = 128
BLOCK_SIZE_K = triton.next_power_of_2(top_k) # top_k
grid = (triton.cdiv(token_num, BLOCK_SIZE_M),)
_save_routing_kernel[grid](
routing_replay_table,
topk_ids,
batch_id_per_token,
cu_seqlens_q,
seq_lens_decoder,
LAYER_IDX=layer_idx,
TOKEN_NUM=token_num,
TOP_K=top_k,
NUM_HIDDEN_LAYERS=num_hidden_layers,
MAX_MODEL_LEN=max_model_len,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_K=BLOCK_SIZE_K,
)
class RoutingReplayManager:
"""Request level routing replay table manager"""
def __init__(
self,
fd_config: FDConfig,
):
self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
self.max_model_len = fd_config.model_config.max_model_len
self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index
if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM":
self.moe_top_k = fd_config.model_config.num_experts_per_tok
else:
self.moe_top_k = fd_config.model_config.moe_k
self.tp_rank = fd_config.parallel_config.tensor_parallel_rank
self.routing_store = get_routing_store(fd_config=fd_config)
self.routing_batch_to_request: Dict[int, str] = {}
self.routing_replay_table = paddle.full(
shape=[self.max_num_seqs, self.num_moe_layers, self.max_model_len, self.moe_top_k],
fill_value=-1,
dtype="int32",
)
def register_request(self, batch_id: int, request_id: str):
"""
Register a new request to routing replay table
Args:
batch_id: The batch ID of this request
request_id: The global ID of the request is usually executed by the training process in RL
"""
# Save requests that have been finished for the current slot
if batch_id in self.routing_batch_to_request:
pre_request_id = self._deregister_request(batch_id)
self._put_request_to_store(batch_id, pre_request_id)
# Register the new request
self.routing_batch_to_request[batch_id] = request_id
def _deregister_request(self, batch_id: int) -> str:
"""
Deregister a request from routing replay table
"""
assert batch_id in self.routing_batch_to_request
return self.routing_batch_to_request.pop(batch_id)
def _put_request_to_store(
self,
batch_id: int,
request_id: str,
):
if self.tp_rank == 0:
batch_buffer = self.routing_replay_table[batch_id]
for layer_id in range(self.num_moe_layers):
layer_buffer = batch_buffer[layer_id]
rollout_id = self.split_request_id(request_id)
self.routing_store.put(routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id)
self._clear_table_slot(batch_id)
def put_table_to_store(self):
"""Put the routing table"""
batch_ids = copy.deepcopy(list(self.routing_batch_to_request.keys()))
for batch_id in batch_ids:
request_id = self._deregister_request(batch_id)
self._put_request_to_store(batch_id, request_id)
def _clear_table_slot(self, batch_id: int):
assert 0 <= batch_id < self.max_num_seqs
self.routing_replay_table[batch_id].fill_(-1)
def clear_routing_table(self):
"""Clear all slots of the routing replay table"""
self.routing_replay_table.fill_(-1)
def _clear_store(self):
"""Clear routing store"""
self.routing_store.clear_store()
def _clear_request_of_store(self, request_id):
"""Clear one request of routing store"""
rollout_id = self.split_request_id(request_id)
for layer_idx in range(self.num_moe_layers):
self.routing_store.clear(rollout_id=rollout_id, layer_idx=layer_idx)
def get_request_from_store(self, request_id: str) -> List[paddle.Tensor]:
"""Get the routing indices of the request from store"""
routing_list = []
rollout_id = self.split_request_id(request_id)
for layer_idx in range(self.num_moe_layers):
one_layer_routing = self.routing_store.get(rollout_id, layer_idx)
routing_list.append(one_layer_routing)
return routing_list
def get_routing_table(self) -> paddle.Tensor:
return self.routing_replay_table
def split_request_id(self, request_id: str):
"""Split the request id to get rollout id"""
chat_type, tmp_str = request_id.split("-", 1)
# NOTE(gongshaotian): only support chatcmpl now
# assert chat_type == "chatcmpl"
reversed_tmp_str = tmp_str[::-1].split("-", 5)
rollout_id = reversed_tmp_str[-1][::-1]
return rollout_id
class RoutingStoreBase(ABC):
"""Base class for routing store"""
def __init__(self, fd_config: FDConfig) -> None:
self.fd_config = fd_config
@abstractmethod
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: Optional[int] = None) -> None:
"""Put the routing indices into store"""
raise NotImplementedError
@abstractmethod
def get(self, rollout_id: str, layer_idx: Optional[int] = None) -> paddle.Tensor:
"""Get the routing indices from store"""
raise NotImplementedError
@abstractmethod
def clear(self, rollout_id: str, layer_idx: Optional[int] = None) -> None:
"""Clear the routing indices of the request"""
raise NotImplementedError
@abstractmethod
def clear_store(
self,
):
"""Clear the routing indices store"""
raise NotImplementedError
class RoutingStoreLocal(RoutingStoreBase):
"""Routing Store using local memory"""
def __init__(self, fd_config) -> None:
super().__init__(fd_config=fd_config)
self.local_store_dir = fd_config.routing_replay_config.local_store_dir
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
"""Put the routing indices into store"""
dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
os.makedirs(dir_path, exist_ok=True)
file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
paddle.save(routing_indices, file_path)
def get(
self,
rollout_id: str,
layer_idx: int = None,
) -> paddle.Tensor:
"""Get the routing indices from store"""
dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
assert os.path.exists(file_path), f"File not found: {file_path}"
layer_routing_indices = paddle.load(file_path)
return layer_routing_indices
def clear(
self,
rollout_id: str,
layer_idx: int = None,
) -> None:
"""Clear the routing indices of the request"""
dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
assert os.path.exists(file_path), f"File not found: {file_path}"
os.remove(file_path)
# Delete empty directory
if len(os.listdir(dir_path)) == 0:
os.rmdir(dir_path)
def clear_store(self):
"""Clear the routing indices store"""
if os.path.isdir(self.local_store_dir):
for file_name in os.listdir(self.local_store_dir):
file_path = os.path.join(self.local_store_dir, file_name)
shutil.rmtree(file_path)
class RoutingStoreRDMA(RoutingStoreBase):
"""Routing Store using RDMA"""
def __init__(self) -> None:
super().__init__()
def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase:
if fd_config.routing_replay_config.routing_store_type == "local":
return RoutingStoreLocal(fd_config=fd_config)
elif fd_config.routing_replay_config.routing_store_type == "rdma":
return RoutingStoreRDMA(fd_config=fd_config)
else:
raise ValueError(
f"Invalid routing store type: '{fd_config.routing_replay_config.routing_store_type}'. "
"Valid types are: 'local', 'rdma'"
)

View File

@@ -161,7 +161,7 @@ class Glm4Moe(nn.Layer):
reduce_results=False,
)
def forward(self, x, forward_meta):
def forward(self, x, forward_meta: ForwardMeta = None):
shared_experts_out = self.shared_experts(x)
out = self.experts(x, self.gate, forward_meta)
out = out + shared_experts_out
@@ -306,10 +306,7 @@ class Glm4MoeDecoderLayer(nn.Layer):
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(
hidden_states,
forward_meta,
)
hidden_states = self.mlp(hidden_states, forward_meta)
return hidden_states, residual

View File

@@ -65,6 +65,7 @@ class RolloutModelConfig:
data_parallel_size: int = 1,
num_nextn_predict_layers: int = 0,
eplb_config: str = {},
routing_replay_config: str = None,
):
# Required parameters
self.model = model_name_or_path
@@ -113,6 +114,7 @@ class RolloutModelConfig:
self.plas_attention_config = plas_attention_config
self.num_nextn_predict_layers = num_nextn_predict_layers
self.eplb_config = eplb_config
self.routing_replay_config = routing_replay_config
def __str__(self):
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())

View File

@@ -45,6 +45,9 @@ from fastdeploy.model_executor.layers.attention.append_attn_backend import (
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
)
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
RoutingReplayManager,
)
from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
@@ -202,6 +205,11 @@ class GPUModelRunner(ModelRunnerBase):
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")
# Rollout routing replay config
self.routing_replay_manager = None
if self.fd_config.routing_replay_config.enable_routing_replay:
self.routing_replay_manager = RoutingReplayManager(fd_config=self.fd_config)
self.zmq_client = None
self.async_output_queue = None
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
@@ -648,6 +656,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
self.share_inputs["is_block_step"][idx : idx + 1] = False
self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids)
self.share_inputs["step_idx"][idx : idx + 1] = (
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
)
@@ -656,6 +665,12 @@ class GPUModelRunner(ModelRunnerBase):
if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None:
self.prompt_logprobs_reqs[request.request_id] = request
has_prefill_task = True
# Routing Replay
if self.fd_config.routing_replay_config.enable_routing_replay:
if prefill_start_index == 0:
self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id)
if (
self.fd_config.scheduler_config.splitwise_role == "decode"
): # In PD, we continue to decode after P generate first token
@@ -1148,6 +1163,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
self.share_inputs["is_chunk_step"] = paddle.full([max_num_seqs], False, dtype="bool").cpu()
self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
self.share_inputs["step_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
self.share_inputs["step_lens"] = paddle.full([1], 0, dtype="int32")
@@ -1418,6 +1434,9 @@ class GPUModelRunner(ModelRunnerBase):
Initialize forward meta, attention meta data and update some config.
"""
# Initialize forward meta
routing_replay_table = None
if self.routing_replay_manager is not None:
routing_replay_table = self.routing_replay_manager.get_routing_table()
self.forward_meta = ForwardMeta(
ids_remove_padding=self.share_inputs["ids_remove_padding"],
rotary_embs=self.share_inputs["rope_emb"],
@@ -1444,6 +1463,7 @@ class GPUModelRunner(ModelRunnerBase):
kv_batch_ids=self.share_inputs["kv_batch_ids"],
kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"],
kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
routing_replay_table=routing_replay_table,
)
dist_status = self.collect_distributed_status()
@@ -1932,6 +1952,9 @@ class GPUModelRunner(ModelRunnerBase):
if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
break
if self.fd_config.routing_replay_config.enable_routing_replay:
self.routing_replay_manager.clear_routing_table()
def _update_chunked_prefill(self, tasks):
"""
Update chunked prefill related parameters
@@ -2429,6 +2452,15 @@ class GPUModelRunner(ModelRunnerBase):
self.speculative_config.num_speculative_tokens,
)
# Routing replay
if self.fd_config.routing_replay_config.enable_routing_replay:
if (
not self.exist_prefill()
and not self.exist_decode()
and self.share_inputs["is_block_step"].sum() == 0
and self.share_inputs["is_chunk_step"].sum() == 0
):
self.routing_replay_manager.put_table_to_store()
return None
def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]:

View File

@@ -38,6 +38,7 @@ from fastdeploy.config import (
ModelConfig,
ParallelConfig,
PlasAttentionConfig,
RoutingReplayConfig,
SpeculativeConfig,
StructuredOutputsConfig,
)
@@ -885,6 +886,13 @@ def parse_args():
help="EPLB Configuration.",
)
parser.add_argument(
"--routing_replay_config",
type=json.loads,
default=None,
help="Configation of Rollout Routing Replay.",
)
args = parser.parse_args()
return args
@@ -944,6 +952,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
eplb_config = EPLBConfig(args.eplb_config)
structured_outputs_config: StructuredOutputsConfig = StructuredOutputsConfig(args=vars(args))
routing_replay_config = RoutingReplayConfig(args.routing_replay_config)
# Note(tangbinhan): used for load_checkpoint
model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
@@ -1003,6 +1012,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
plas_attention_config=plas_attention_config,
structured_outputs_config=structured_outputs_config,
eplb_config=eplb_config,
routing_replay_config=routing_replay_config,
)
update_fd_config_for_mm(fd_config)
if fd_config.load_config.load_choices == "default_v1" and not v1_loader_support(fd_config):

View File

@@ -90,7 +90,7 @@ class MockAttentionBackend:
class MockQuantMethod:
def apply(self, layer, x, gate):
def apply(self, layer, x, gate, topk_ids_hookfunc=None):
return x
@@ -129,6 +129,7 @@ class TestChunkedMoE(unittest.TestCase):
model_runner.speculative_decoding = False
model_runner._init_share_inputs(mock_fd_config.scheduler_config.max_num_seqs)
model_runner.share_inputs["caches"] = None
model_runner.routing_replay_manager = None
if dist.get_rank() == 0:
model_runner.share_inputs["ids_remove_padding"] = paddle.ones([10])
@@ -148,6 +149,7 @@ class TestChunkedMoE(unittest.TestCase):
fused_moe.fd_config = mock_fd_config
fused_moe.quant_method = MockQuantMethod()
fused_moe.enable_routing_replay = None
return fused_moe
def run_model_runner(self):

View File

@@ -78,6 +78,8 @@ def setup_and_run_server():
"wint4",
"--graph-optimization-config",
'{"cudagraph_capture_sizes": [1], "use_cudagraph":true}',
"--routing-replay-config",
'{"enable_routing_replay":true, "routing_store_type":"local", "local_store_dir":"./routing_replay_output"}',
]
# Start subprocess in new process group

View File

@@ -31,6 +31,7 @@ from fastdeploy.config import (
LoadConfig,
ModelConfig,
ParallelConfig,
RoutingReplayConfig,
)
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.quantization.block_wise_fp8 import (
@@ -476,6 +477,7 @@ class FuseMoEWrapper(paddle.nn.Layer):
graph_opt_config=GraphOptimizationConfig({}),
load_config=LoadConfig({}),
ips=",".join(["0"] * nnodes),
routing_replay_config=RoutingReplayConfig({}),
)
self.fd_config.parallel_config.tp_group = None
self.fd_config.parallel_config.tensor_parallel_rank = tp_rank

View File

@@ -13,6 +13,7 @@ from fastdeploy.config import (
LoadConfig,
ModelConfig,
ParallelConfig,
RoutingReplayConfig,
)
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.quantization.w4a8 import W4A8Config
@@ -59,6 +60,7 @@ class FuseMoEWrapper(paddle.nn.Layer):
graph_opt_config=GraphOptimizationConfig({}),
load_config=LoadConfig({}),
ips=",".join(["0"] * nnodes),
routing_replay_config=RoutingReplayConfig({}),
)
self.fd_config.parallel_config.tp_group = None
self.fd_config.parallel_config.tensor_parallel_rank = tp_rank

View File

@@ -13,6 +13,7 @@ from fastdeploy.config import (
LoadConfig,
ModelConfig,
ParallelConfig,
RoutingReplayConfig,
)
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.quantization.w4afp8 import W4AFP8Config
@@ -65,6 +66,7 @@ class FuseMoEWrapper(paddle.nn.Layer):
graph_opt_config=GraphOptimizationConfig({}),
load_config=LoadConfig({}),
ips=",".join(["0"] * nnodes),
routing_replay_config=RoutingReplayConfig({}),
)
self.fd_config.parallel_config.tp_group = None
self.fd_config.parallel_config.tensor_parallel_rank = tp_rank