[Cherry-Pick][RL] Support Rollout Routing Replay (#5166)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled

* support r3

* update

* support tp>1&&ep>1

* support cudagraph padding

* support all backends

* replace env with options

* modularize

* update

* Add RoutingStore and refine code

* add routing replay cofig

* add routing repaly config

* success run routing store

* convert request id as rollout id

* fix rollout config bug

* unify code

* use rollout_id to replace request_id in routing store

* delete code

---------

Co-authored-by: yuanlehome <yuanlehome@163.com>
This commit is contained in:
RAM
2025-12-04 16:35:30 +08:00
committed by GitHub
parent 74ba637b6b
commit fbed0ef851
18 changed files with 511 additions and 5 deletions

View File

@@ -1163,6 +1163,31 @@ class CommitConfig:
logger.info("=============================================================")
class RoutingReplayConfig:
"""Configuration for Routing Replay used in RL training"""
def __init__(self, args) -> None:
self.enable_routing_replay: bool = False
self.routing_store_type: str = "local"
# Local routing store
self.local_store_dir: str = "./routing_replay_output"
# RDMA routing store
pass
if args is not None:
for key, value in args.items():
if hasattr(self, key) and value != "None":
setattr(self, key, value)
def to_json_string(self):
"""
Convert routing replay config to json string.
"""
return json.dumps({key: value for key, value in self.__dict__.items()})
class FDConfig:
"""
The configuration class which contains all fastdeploy-related configuration. This
@@ -1206,6 +1231,7 @@ class FDConfig:
test_mode=False,
enable_attention_dp_balance: bool = False,
attention_dp_time_out_iters: int = 0,
routing_replay_config: Optional[RoutingReplayConfig] = None,
):
self.model_config: ModelConfig = model_config # type: ignore
self.cache_config: CacheConfig = cache_config # type: ignore
@@ -1221,8 +1247,10 @@ class FDConfig:
self.cache_config: CacheConfig = cache_config # type: ignore
self.eplb_config: Optional[EPLBConfig] = eplb_config
self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config
self.routing_replay_config = routing_replay_config
self.enable_attention_dp_balance = enable_attention_dp_balance
self.attention_dp_time_out_iters = attention_dp_time_out_iters
# Initialize cuda graph capture list
max_capture_shape = self.parallel_config.max_num_seqs
if self.speculative_config is not None and self.speculative_config.method == "mtp":

View File

@@ -33,6 +33,7 @@ from fastdeploy.config import (
MobaAttentionConfig,
ModelConfig,
ParallelConfig,
RoutingReplayConfig,
SpeculativeConfig,
TaskOption,
)
@@ -421,6 +422,11 @@ class EngineArgs:
Configuration for eplb.
"""
routing_replay_config: Optional[Dict[str, Any]] = None
"""
Flag to rollout routing replay(r3)
"""
def __post_init__(self):
"""
Post-initialization processing to set default tokenizer if not provided.
@@ -731,6 +737,12 @@ class EngineArgs:
default=EngineArgs.eplb_config,
help="Config of eplb.",
)
parallel_group.add_argument(
"--routing-replay-config",
type=json.loads,
default=EngineArgs.routing_replay_config,
help="Flag of rollout routing replay(r3).",
)
# Load group
load_group = parser.add_argument_group("Load Configuration")
@@ -1076,6 +1088,14 @@ class EngineArgs:
eplb_args[k] = v
return EPLBConfig(eplb_args)
def create_routing_repaly_config(self) -> RoutingReplayConfig:
""" """
routing_replay_args = asdict(self)
if self.routing_replay_config is not None:
for k, v in self.routing_replay_config.items():
routing_replay_args[k] = v
return RoutingReplayConfig(routing_replay_args)
def create_engine_config(self, port_availability_check: bool = True) -> FDConfig:
"""
Create and return a Config object based on the current settings.
@@ -1118,6 +1138,7 @@ class EngineArgs:
graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
moba_attention_config = self.create_moba_attention_config()
eplb_cfg = self.create_eplb_config()
routing_replay_config = self.create_routing_repaly_config()
early_stop_cfg = self.create_early_stop_config()
early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
@@ -1163,4 +1184,5 @@ class EngineArgs:
early_stop_config=early_stop_cfg,
enable_attention_dp_balance=self.enable_attention_dp_balance,
attention_dp_time_out_iters=self.attention_dp_time_out_iters,
routing_replay_config=routing_replay_config,
)

View File

@@ -462,6 +462,7 @@ class LLMEngine:
f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
f" --attention_dp_time_out_iters {self.cfg.attention_dp_time_out_iters}"
f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'"
f" --routing_replay_config '{self.cfg.routing_replay_config.to_json_string()}'"
f" --ips {ips}"
)

View File

@@ -110,6 +110,8 @@ class ForwardMeta:
block_tables: Optional[paddle.Tensor] = None
# KV caches
caches: Optional[list[paddle.Tensor]] = None
# Routing Replay table buffer
routing_replay_table: Optional[paddle.Tensor] = None
def clear_caches(self):
"""Safely clean up the caches"""

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
@@ -102,6 +104,7 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
@@ -119,6 +122,9 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
topk_weights, topk_ids = paddle.topk(scores, k=top_k, axis=-1, sorted=False)
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
intermediate_cache1 = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,

View File

@@ -16,6 +16,7 @@
import multiprocessing
import os
from typing import Callable
import numpy as np
import paddle
@@ -189,6 +190,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle gcu compute Fused MoE.
@@ -201,6 +203,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -212,6 +215,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -223,6 +227,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
@@ -388,6 +393,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle gcu compute Fused MoE.

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
@@ -132,6 +134,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
@@ -151,6 +154,10 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
True, # apply_norm_weight,
False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,

View File

@@ -15,6 +15,7 @@
"""
from abc import abstractmethod
from typing import Callable
import paddle
from paddle import nn
@@ -120,6 +121,7 @@ class MoEMethodBase(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -144,6 +146,7 @@ class MoEMethodBase(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
@@ -155,6 +158,7 @@ class MoEMethodBase(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
@@ -163,13 +167,13 @@ class MoEMethodBase(QuantMethodBase):
if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
if layer.fd_config.parallel_config.splitwise_role == "mixed" and layer.layer_idx == 0:
self.ep_prefill_runner.clean_low_latency_buffer()
return self.apply_ep_prefill(layer, x, gate)
return self.apply_ep_prefill(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
else:
if layer.fd_config.parallel_config.splitwise_role == "mixed" and layer.layer_idx == 0:
self.ep_decoder_runner.clean_low_latency_buffer()
return self.apply_ep_decode(layer, x, gate)
return self.apply_ep_decode(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
else:
return self.apply_tp(layer, x, gate)
return self.apply_tp(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
class UnquantizedFusedMoEMethod(MoEMethodBase):

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
from paddle.nn.quant import weight_quantize
@@ -105,6 +107,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -121,6 +124,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
handle,
_,
) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
token_all_num = sum(recv_num_tokens_per_expert_list)
# 3. Compute ffn
@@ -178,6 +185,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -186,6 +194,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
estimate_total_token_nums = gate_out.shape[0] * layer.top_k
# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
expertwise_scale = None
if hasattr(layer, "up_gate_proj_in_scale_all_experts"): # only use in w4a8
expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
@@ -220,6 +232,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
@@ -277,6 +290,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
topk_only_mode=False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
if self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
# only w4a8 need expert_idx_per_token
# Other need not this tensor, so we make it None.

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
from paddleformers.utils.log import logger
@@ -298,6 +300,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -305,6 +308,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
gate_out = gate(x.cast("float32"))
# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
# 2. Dynamic compute blockwise quantization scales
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
x, self.quant_config.weight_block_size[0]
@@ -406,6 +413,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -413,6 +421,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
gate_out = gate(x.cast("float32"))
# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
# 2. EP Dispatch
permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
x, topk_idx, topk_weights, use_fp8=True
@@ -477,6 +489,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Use DeepGemm compute Fused MoE.
@@ -504,6 +517,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, 128)

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
@@ -240,6 +242,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Marlin compute Fused MoE.
@@ -275,6 +278,9 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
block_size_m = 64
for m in [8, 16, 32, 48, 64]:

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
@@ -156,6 +158,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
@@ -186,6 +189,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
@@ -420,6 +426,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
@@ -451,6 +458,9 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
@@ -840,6 +850,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Triton compute Fused MoE.
@@ -871,6 +882,9 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": self.quant_config.weight_block_size[1],

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
@@ -261,6 +263,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Use Wint2 Triton Fusedmoe compute Fused MoE.
@@ -288,6 +291,9 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
topk_only_mode=False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
ffn_out = fastdeploy.model_executor.ops.gpu.moe_expert_ffn_wint2(
permute_input,
token_nums_per_expert,
@@ -333,6 +339,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Use Wint2 Triton Fusedmoe compute Fused MoE.
@@ -348,6 +355,9 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
False,
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
num_tokens, K = x.shape
E, _, N = layer.up_gate_proj_weight.shape
M = num_tokens

View File

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Callable
import paddle
from paddle import nn
@@ -47,6 +49,7 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
@@ -82,6 +85,7 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
@@ -93,6 +97,7 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
@@ -227,6 +232,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
"""
XPU compute Fused MoE.

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""
from functools import partial
from typing import Optional
import numpy as np
@@ -22,6 +23,10 @@ from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
save_routing_to_buffer,
)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import slice_fn
from fastdeploy.platforms import current_platform
@@ -195,6 +200,7 @@ class FusedMoE(nn.Layer):
if self.ep_size > 1:
self.quant_method.init_ep(self)
self.enable_routing_replay = fd_config.routing_replay_config.enable_routing_replay
# Merge normal and RL build model
if gate_correction_bias is not None:
self.gate_correction_bias = gate_correction_bias
@@ -532,7 +538,7 @@ class FusedMoE(nn.Layer):
else:
self.quant_method.process_loaded_weights(self, state_dict)
def forward(self, x: paddle.Tensor, gate: nn.Layer):
def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None):
"""
Defines the forward computation of the moe layer.
@@ -543,5 +549,20 @@ class FusedMoE(nn.Layer):
Tensor: Output tensor.s
"""
out = self.quant_method.apply(self, x, gate)
topk_ids_hookfunc = None
if self.enable_routing_replay:
if forward_meta is not None: # forward_meta is None when execute empty_input_forward
topk_ids_hookfunc = partial(
save_routing_to_buffer,
routing_replay_table=forward_meta.routing_replay_table,
batch_id_per_token=forward_meta.batch_id_per_token,
seq_lens_decoder=forward_meta.seq_lens_decoder,
cu_seqlens_q=forward_meta.cu_seqlens_q,
layer_idx=self.layer_idx,
tp_size=self.fd_config.parallel_config.tensor_parallel_size,
ep_size=self.fd_config.parallel_config.expert_parallel_size,
tp_group=self.fd_config.parallel_config.tp_group,
)
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
return out

View File

@@ -0,0 +1,328 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import copy
import os
import shutil
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import paddle
import paddle.distributed as dist
import triton
import triton.language as tl
from fastdeploy.config import FDConfig
@triton.jit
def _save_routing_kernel(
ROUTING_REPLAY_TABLE_PTR,
TOPK_IDS_PTR,
BATCH_ID_PER_TOKEN_PTR,
CU_SEQLENS_Q_PTR,
SEQ_LENS_DECODER_PTR,
LAYER_IDX,
TOKEN_NUM,
TOP_K,
NUM_HIDDEN_LAYERS,
MAX_MODEL_LEN,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
token_offsets = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
token_mask = token_offsets < TOKEN_NUM
k_offsets = tl.arange(0, BLOCK_SIZE_K)
topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :]
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
topk_vals = tl.load(topk_ids_ptrs, mask=token_mask[:, None])
batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask)
pad_mask = token_mask & (batch_ids != -1)
# [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3]
# -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
# -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1]
start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask)
token_relative_index = token_offsets - start_offsets
# [BLOCK_SIZE_M]
len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask)
token_seq_pos = len_decoder + token_relative_index
STRIDE_BUF_SEQ = NUM_HIDDEN_LAYERS * MAX_MODEL_LEN * TOP_K
STRIDE_BUF_LAYER = MAX_MODEL_LEN * TOP_K
STRIDE_BUF_TOKEN = TOP_K
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
output_ptrs = (
ROUTING_REPLAY_TABLE_PTR
+ batch_ids[:, None] * STRIDE_BUF_SEQ
+ LAYER_IDX * STRIDE_BUF_LAYER
+ token_seq_pos[:, None] * STRIDE_BUF_TOKEN
+ k_offsets[None, :]
)
pos_mask = token_seq_pos < MAX_MODEL_LEN
pos_mask = pos_mask & pad_mask
final_mask = token_mask[:, None] & pos_mask[:, None]
tl.store(output_ptrs, topk_vals, mask=final_mask)
def save_routing_to_buffer(
routing_replay_table: paddle.Tensor, # [max_num_seqs, num_layers, max_len, top_k]
topk_ids: paddle.Tensor, # [token_num, top_k]
batch_id_per_token: paddle.Tensor, # [token_num, 1]
seq_lens_decoder: paddle.Tensor, # [max_num_seqs, 1]
cu_seqlens_q: paddle.Tensor, # [max_num_seqs + 1, 1]
layer_idx: int,
tp_size: int,
ep_size: int,
tp_group: dist.communication.group.Group,
):
if tp_size > 1 and ep_size > 1:
token_num_per_rank = topk_ids.shape[0]
topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype)
paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group)
topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :]
token_num, top_k = topk_ids.shape
max_num_seqs, num_hidden_layers, max_model_len, _ = routing_replay_table.shape
assert token_num > 0
assert topk_ids.shape[1] == routing_replay_table.shape[3], (topk_ids.shape[1], routing_replay_table.shape[3])
assert batch_id_per_token.shape[0] == token_num, (batch_id_per_token.shape[0], token_num)
assert seq_lens_decoder.shape[0] == max_num_seqs, (seq_lens_decoder.shape[0], max_num_seqs)
BLOCK_SIZE_M = 128
BLOCK_SIZE_K = top_k
grid = (triton.cdiv(token_num, BLOCK_SIZE_M),)
_save_routing_kernel[grid](
routing_replay_table,
topk_ids,
batch_id_per_token,
cu_seqlens_q,
seq_lens_decoder,
LAYER_IDX=layer_idx,
TOKEN_NUM=token_num,
TOP_K=top_k,
NUM_HIDDEN_LAYERS=num_hidden_layers,
MAX_MODEL_LEN=max_model_len,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_K=BLOCK_SIZE_K,
)
class RoutingReplayManager:
"""Request level routing replay table manager"""
def __init__(
self,
fd_config: FDConfig,
):
self.max_num_seqs = fd_config.parallel_config.max_num_seqs
self.max_model_len = fd_config.model_config.max_model_len
self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index
self.moe_top_k = fd_config.model_config.moe_k
self.tp_rank = fd_config.parallel_config.tensor_parallel_rank
self.routing_store = get_routing_store(fd_config=fd_config)
self.routing_batch_to_request: Dict[int, str] = {}
self.routing_replay_table = paddle.full(
shape=[self.max_num_seqs, self.num_moe_layers, self.max_model_len, self.moe_top_k],
fill_value=-1,
dtype="int32",
)
def register_request(self, batch_id: int, request_id: str):
"""
Register a new request to routing replay table
Args:
batch_id: The batch ID of this request
request_id: The global ID of the request is usually executed by the training process in RL
"""
# Save requests that have been finished for the current slot
if batch_id in self.routing_batch_to_request:
pre_request_id = self._deregister_request(batch_id)
self._put_request_to_store(batch_id, pre_request_id)
# Register the new request
self.routing_batch_to_request[batch_id] = request_id
def _deregister_request(self, batch_id: int) -> str:
"""
Deregister a request from routing replay table
"""
assert batch_id in self.routing_batch_to_request
return self.routing_batch_to_request.pop(batch_id)
def _put_request_to_store(
self,
batch_id: int,
request_id: str,
):
if self.tp_rank == 0:
batch_buffer = self.routing_replay_table[batch_id]
for layer_id in range(self.num_moe_layers):
layer_buffer = batch_buffer[layer_id]
rollout_id = self.split_request_id(request_id)
self.routing_store.put(routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id)
self._clear_table_slot(batch_id)
def put_table_to_store(self):
"""Put the routin table"""
batch_ids = copy.deepcopy(list(self.routing_batch_to_request.keys()))
for batch_id in batch_ids:
request_id = self._deregister_request(batch_id)
self._put_request_to_store(batch_id, request_id)
def _clear_table_slot(self, batch_id: int):
assert 0 <= batch_id < self.max_num_seqs
self.routing_replay_table[batch_id].fill_(-1)
def clear_routing_table(self):
"""Clear all slots of the routing replay table"""
self.routing_replay_table.fill_(-1)
def _clear_store(self):
"""Clear routing store"""
self.routing_store.clear_store()
def _clear_request_of_store(self, request_id):
"""Clear one request of routing store"""
rollout_id = self.split_request_id(request_id)
for layer_idx in range(self.num_moe_layers):
self.routing_store.clear(rollout_id=rollout_id, layer_idx=layer_idx)
def get_request_from_store(self, request_id: str) -> List[paddle.Tensor]:
"""Get the routing indices of the reuest from store"""
routing_list = []
rollout_id = self.split_request_id(request_id)
for layer_idx in range(self.num_moe_layers):
one_layer_routing = self.routing_store.get(rollout_id, layer_idx)
routing_list.append(one_layer_routing)
return routing_list
def get_routing_table(self) -> paddle.Tensor:
return self.routing_replay_table
def split_request_id(self, request_id: str):
"""Split the request id to get rollout id"""
chat_type, tmp_str = request_id.split("-", 1)
assert chat_type == "chatcmpl"
reversed_tmp_str = tmp_str[::-1].split("-", 5)
rollout_id = reversed_tmp_str[-1][::-1]
return rollout_id
class RoutingStoreBase(ABC):
"""Base class for routing store"""
def __init__(self, fd_config: FDConfig) -> None:
self.fd_config = fd_config
@abstractmethod
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: Optional[int] = None) -> None:
"""Put the routing indices into store"""
raise NotImplementedError
@abstractmethod
def get(self, rollout_id: str, layer_idx: Optional[int] = None) -> paddle.Tensor:
"""Get the routing indices from store"""
raise NotImplementedError
@abstractmethod
def clear(self, rollout_id: str, layer_idx: Optional[int] = None) -> None:
"""Clear the routing indices of the request"""
raise NotImplementedError
@abstractmethod
def clear_store(
self,
):
"""Clear the routing indices store"""
raise NotImplementedError
class RoutingStoreLocal(RoutingStoreBase):
"""Routing Store using local memory"""
def __init__(self, fd_config) -> None:
super().__init__(fd_config=fd_config)
self.local_store_dir = fd_config.routing_replay_config.local_store_dir
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
"""Put the routing indices into store"""
dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
os.makedirs(dir_path, exist_ok=True)
file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
paddle.save(routing_indices, file_path)
def get(
self,
rollout_id: str,
layer_idx: int = None,
) -> paddle.Tensor:
"""Get the routing indices from store"""
dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
assert os.path.exists(file_path), f"File not found: {file_path}"
layer_routing_indices = paddle.load(file_path)
return layer_routing_indices
def clear(
self,
rollout_id: str,
layer_idx: int = None,
) -> None:
"""Clear the routing indices of the request"""
dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
assert os.path.exists(file_path), f"File not found: {file_path}"
os.remove(file_path)
# Delete empty directory
if len(os.listdir(dir_path)) == 0:
os.rmdir(dir_path)
def clear_store(self):
"""Clear the routing indices store"""
if os.path.isdir(self.local_store_dir):
for file_name in os.listdir(self.local_store_dir):
file_path = os.path.join(self.local_store_dir, file_name)
shutil.rmtree(file_path)
class RoutingStoreRDMA(RoutingStoreBase):
"""Routing Store using RDMA"""
def __init__(self) -> None:
super().__init__()
def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase:
if fd_config.routing_replay_config.routing_store_type == "local":
return RoutingStoreLocal(fd_config=fd_config)
elif fd_config.routing_replay_config.routing_store_type == "rdma":
return RoutingStoreRDMA(fd_config=fd_config)
else:
raise ValueError("Invalid store type")

View File

@@ -67,6 +67,7 @@ class RolloutModelConfig:
enable_attention_dp_balance: bool = False,
attention_dp_time_out_iters: int = 0,
eplb_config: str = {},
routing_replay_config: str = None,
):
# Required parameters
self.model = model_name_or_path
@@ -117,6 +118,7 @@ class RolloutModelConfig:
self.enable_attention_dp_balance = enable_attention_dp_balance
self.attention_dp_time_out_iters = attention_dp_time_out_iters
self.eplb_config = eplb_config
self.routing_replay_config = routing_replay_config
def __str__(self):
return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())

View File

@@ -38,6 +38,7 @@ from fastdeploy.config import (
MobaAttentionConfig,
ModelConfig,
ParallelConfig,
RoutingReplayConfig,
SpeculativeConfig,
)
from fastdeploy.engine.request import RequestType
@@ -840,6 +841,13 @@ def parse_args():
help="EPLB Configuration.",
)
parser.add_argument(
"--routing_replay_config",
type=json.loads,
default=None,
help="Configation of Rollout Routing Replay.",
)
args = parser.parse_args()
return args
@@ -900,6 +908,8 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
eplb_config = EPLBConfig(args.eplb_config)
routing_replay_config = RoutingReplayConfig(args.routing_replay_config)
# Note(tangbinhan): used for load_checkpoint
model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
model_config.pretrained_config.tensor_parallel_degree = parallel_config.tensor_parallel_size
@@ -998,6 +1008,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
enable_attention_dp_balance=args.enable_attention_dp_balance,
attention_dp_time_out_iters=args.attention_dp_time_out_iters,
eplb_config=eplb_config,
routing_replay_config=routing_replay_config,
)
update_fd_config_for_mm(fd_config)