Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[RL] Support Rollout Routing Replay (#5321)
* [RL] Support Rollout Routing Replay
* add routing indices cache
* fix config bug and moe forward bug
* R3 Support GLM
* support eb4.5
* fix merge bug
* Apply suggestion from @Copilot (x4)
* add routing replay ci
* support glm topk
* support other top_k
* fix ci bug
* pre-commit
* only support chatcmpl

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yuanle Liu <yuanlehome@163.com>
@@ -1484,6 +1484,31 @@ class StructuredOutputsConfig:
        return json.dumps({key: value for key, value in self.__dict__.items()})


class RoutingReplayConfig:
    """Configuration for Routing Replay used in RL training"""

    def __init__(self, args) -> None:
        self.enable_routing_replay: bool = False
        self.routing_store_type: str = "local"

        # Local routing store
        self.local_store_dir: str = "./routing_replay_output"

        # RDMA routing store
        # TODO: Add RDMA routing store configuration attributes here when the feature is implemented.

        if args is not None:
            for key, value in args.items():
                if hasattr(self, key) and value != "None":
                    setattr(self, key, value)

    def to_json_string(self):
        """
        Convert routing replay config to json string.
        """
        return json.dumps({key: value for key, value in self.__dict__.items()})


class FDConfig:
    """
    The configuration class which contains all fastdeploy-related configuration. This

@@ -1517,6 +1542,7 @@ class FDConfig:
        early_stop_config: Optional[Dict[str, Any]] = None,
        tool_parser: str = None,
        test_mode=False,
        routing_replay_config: Optional[RoutingReplayConfig] = None,
    ):
        self.model_config: ModelConfig = model_config  # type: ignore
        self.cache_config: CacheConfig = cache_config  # type: ignore

@@ -1533,6 +1559,7 @@ class FDConfig:
        self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
        self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config
        self.router_config: RouterConfig = router_config
        self.routing_replay_config = routing_replay_config

        # Initialize cuda graph capture list
        max_capture_shape = self.scheduler_config.max_num_seqs
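For reference, a minimal usage sketch of the new config class above (editor's illustration, not part of the commit): the constructor only copies keys that already exist as attributes, so unknown keys are ignored, and to_json_string() serializes whatever ended up on the instance.

# Editor's sketch: feeding RoutingReplayConfig a dict such as the one parsed
# from the --routing-replay-config JSON flag. Unknown keys are silently ignored.
cfg = RoutingReplayConfig(
    {
        "enable_routing_replay": True,
        "routing_store_type": "local",
        "local_store_dir": "./routing_replay_output",
        "unknown_key": 123,  # ignored: not a predefined attribute
    }
)
assert cfg.enable_routing_replay is True
print(cfg.to_json_string())  # {"enable_routing_replay": true, "routing_store_type": "local", ...}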
@@ -35,6 +35,7 @@ from fastdeploy.config import (
    PlasAttentionConfig,
    PoolerConfig,
    RouterConfig,
    RoutingReplayConfig,
    RunnerOption,
    SpeculativeConfig,
    StructuredOutputsConfig,

@@ -491,6 +492,11 @@ class EngineArgs:
    Configuration for eplb.
    """

    routing_replay_config: Optional[Dict[str, Any]] = None
    """
    Configuration for rollout routing replay (R3).
    """

    def __post_init__(self):
        """
        Post-initialization processing to set default tokenizer if not provided.

@@ -882,6 +888,12 @@ class EngineArgs:
            default=EngineArgs.eplb_config,
            help="Config of eplb.",
        )
        parallel_group.add_argument(
            "--routing-replay-config",
            type=json.loads,
            default=EngineArgs.routing_replay_config,
            help="Configuration for rollout routing replay (R3).",
        )
        parallel_group.add_argument(
            "--enable-chunked-moe",
            action="store_true",

@@ -1235,6 +1247,14 @@ class EngineArgs:
        eplb_args["enable_eplb"] = self.enable_eplb
        return EPLBConfig(eplb_args)

    def create_routing_repaly_config(self) -> RoutingReplayConfig:
        """Create the RoutingReplayConfig from the current engine args."""
        routing_replay_args = asdict(self)
        if self.routing_replay_config is not None:
            for k, v in self.routing_replay_config.items():
                routing_replay_args[k] = v
        return RoutingReplayConfig(routing_replay_args)

    def create_engine_config(self, port_availability_check=True) -> FDConfig:
        """
        Create and return a Config object based on the current settings.

@@ -1278,6 +1298,7 @@ class EngineArgs:
        graph_opt_cfg = self.create_graph_optimization_config()
        plas_attention_config = self.create_plas_attention_config()
        eplb_cfg = self.create_eplb_config()
        routing_replay_config = self.create_routing_repaly_config()
        router_config = RouterConfig(all_dict)

        early_stop_cfg = self.create_early_stop_config()

@@ -1310,4 +1331,5 @@ class EngineArgs:
            graph_opt_config=graph_opt_cfg,
            plas_attention_config=plas_attention_config,
            early_stop_config=early_stop_cfg,
            routing_replay_config=routing_replay_config,
        )
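As a usage note (editor's addition, not part of the commit), --routing-replay-config takes a JSON object; one way to build it in Python, mirroring the CI test later in this diff:

import json

# Value passed as:  --routing-replay-config '<json>'
# argparse applies json.loads, and create_routing_repaly_config() merges the resulting
# dict into the engine args before constructing RoutingReplayConfig.
routing_replay_flag = json.dumps(
    {
        "enable_routing_replay": True,
        "routing_store_type": "local",        # "rdma" is reserved but not implemented yet
        "local_store_dir": "./routing_replay_output",
    }
)
print(routing_replay_flag)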
@@ -568,6 +568,7 @@ class LLMEngine:
            f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
            f" --max_logprobs {self.cfg.model_config.max_logprobs}"
            f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'"
            f" --routing_replay_config '{self.cfg.routing_replay_config.to_json_string()}'"
        )
        if self.cfg.structured_outputs_config.logits_processors is not None:
            arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"
@@ -142,6 +142,8 @@ class ForwardMeta:
    caches: Optional[list[paddle.Tensor]] = None
    # Flag of profile run
    is_dummy_or_profile_run: bool = False
    # Routing Replay table buffer
    routing_replay_table: Optional[paddle.Tensor] = None

    # chunked MoE related
    moe_num_chunk: int = 1
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn

@@ -101,6 +103,7 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.

@@ -117,6 +120,8 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
            scores += layer.gate_correction_bias
        topk_weights, topk_ids = paddle.topk(scores, k=top_k, axis=-1, sorted=False)
        topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        intermediate_cache1 = paddle.empty(
            [token_num * top_k, moe_intermediate_size * 2],
@@ -16,6 +16,7 @@

import multiprocessing
import os
from typing import Callable

import numpy as np
import paddle

@@ -182,6 +183,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Paddle gcu compute Fused MoE.

@@ -194,6 +196,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.

@@ -205,6 +208,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP decoder method.

@@ -216,6 +220,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Paddle Cutlass compute Fused MoE.

@@ -381,6 +386,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Paddle gcu compute Fused MoE.
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn

@@ -245,6 +247,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.

@@ -274,6 +277,9 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
            True,  # apply_norm_weight
            False,
        )
        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        up_gate_proj_out = paddle.empty(
            [token_num * top_k, moe_intermediate_size * 2],
            dtype=x.dtype,
@@ -15,6 +15,7 @@
"""

from abc import abstractmethod
from typing import Callable

import paddle
from paddle import nn

@@ -163,6 +164,7 @@ class MoEMethodBase(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.

@@ -175,6 +177,7 @@ class MoEMethodBase(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP decoder method.

@@ -187,6 +190,7 @@ class MoEMethodBase(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Paddle Cutlass compute Fused MoE.

@@ -198,6 +202,7 @@ class MoEMethodBase(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Paddle Cutlass compute Fused MoE.

@@ -207,13 +212,13 @@ class MoEMethodBase(QuantMethodBase):
            if layer.fd_config.model_config.moe_phase.phase == "prefill":
                if layer.fd_config.scheduler_config.splitwise_role == "mixed" and is_moe_start_layer:
                    self.ep_prefill_runner.clean_low_latency_buffer()
                return self.apply_ep_prefill(layer, x, gate)
                return self.apply_ep_prefill(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
            else:
                if layer.fd_config.scheduler_config.splitwise_role == "mixed" and is_moe_start_layer:
                    self.ep_decoder_runner.clean_low_latency_buffer()
                return self.apply_ep_decode(layer, x, gate)
                return self.apply_ep_decode(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
        else:
            return self.apply_tp(layer, x, gate)
            return self.apply_tp(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)


class UnquantizedFusedMoEMethod(MoEMethodBase):
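The topk_ids_hookfunc parameter threaded through every backend above is simply a callable invoked with the selected expert indices via topk_ids_hookfunc(topk_ids=...). A toy stand-in for save_routing_to_buffer (editor's sketch; the recorder below is hypothetical) shows the calling convention:

from functools import partial

import paddle

# Hypothetical recorder standing in for save_routing_to_buffer.
def record_topk_ids(topk_ids: paddle.Tensor, storage: list, layer_idx: int):
    storage.append((layer_idx, topk_ids.numpy().copy()))

captured = []
hook = partial(record_topk_ids, storage=captured, layer_idx=0)

# What a MoE backend effectively does right after expert selection:
scores = paddle.rand([4, 8])                      # [token_num, num_experts]
_, topk_ids = paddle.topk(scores, k=2, axis=-1)   # [token_num, top_k]
if hook is not None:
    hook(topk_ids=topk_ids)

print(captured[0][1].shape)  # (4, 2)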
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn
from paddle.nn.quant import weight_quantize

@@ -132,6 +134,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.

@@ -148,8 +151,13 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
            handle,
            event,
        ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_idx)

        if self.ep_prefill_runner.ep_engine.async_finish:
            event.current_stream_wait()

        token_all_num = sum(recv_num_tokens_per_expert_list)

        # 3. Compute ffn

@@ -217,6 +225,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP decoder method.

@@ -225,6 +234,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
        estimate_total_token_nums = gate_out.shape[0] * layer.top_k
        # 1. Select topk experts and weights
        topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_idx)

        expertwise_scale = None
        if hasattr(layer, "up_gate_proj_in_scale_all_experts"):  # only use in w4a8
            expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)

@@ -269,6 +282,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Paddle Cutlass compute Fused MoE.

@@ -369,6 +383,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
        if hasattr(layer, "up_gate_proj_in_scale"):
            dequant_scale = None

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_idx)

        if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
            # only w4a8 need expert_idx_per_token
            # Other need not this tensor, so we make it None.
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn
from paddle.distributed.communication import deep_ep

@@ -139,6 +141,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.

@@ -147,6 +150,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):

        # 1. Select topk experts and weights
        topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_idx)

        # 2. Dynamic compute blockwise quantization scales
        x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
            x, self.quant_config.weight_block_size[0]

@@ -264,6 +271,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP decoder method.

@@ -271,6 +279,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
        gate_out = gate(x.cast("float32"))
        # 1. Select topk experts and weights
        topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_idx)

        # 2. EP Dispatch
        permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
            x, topk_idx, topk_weights, use_fp8=True

@@ -335,6 +347,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Paddle Use DeepGemm compute Fused MoE.

@@ -363,6 +376,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
            False,
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)

        recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, 128)
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn

@@ -239,6 +241,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Marlin compute Fused MoE.

@@ -273,6 +276,9 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
            False,
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        block_size_m = 64

        for m in [8, 16, 32, 48, 64]:
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn

@@ -282,6 +284,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.

@@ -314,6 +317,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
            True,  # apply_norm_weight,
            False,
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        up_gate_proj_out = paddle.empty(
            [token_num * top_k, moe_intermediate_size * 2],
            dtype=x.dtype,

@@ -664,6 +671,7 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.

@@ -724,6 +732,9 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
            * ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]),
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        up_gate_proj_out = paddle.empty(
            [token_num * top_k, moe_intermediate_size * 2],
            dtype=x.dtype,

@@ -953,6 +964,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.

@@ -974,6 +986,9 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
            False,
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        up_gate_proj_out = paddle.empty(
            [token_num * top_k, moe_intermediate_size * 2],
            dtype=x.dtype,

@@ -1466,6 +1481,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.

@@ -1488,6 +1504,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
            True,  # apply_norm_weight
            False,
        )
        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        config = {
            "BLOCK_SIZE_M": 64,
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn

@@ -261,6 +263,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Use Wint2 Triton Fusedmoe compute Fused MoE.

@@ -288,6 +291,9 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
            topk_only_mode=False,
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_idx)

        ffn_out = fastdeploy.model_executor.ops.gpu.moe_expert_ffn_wint2(
            permute_input,
            token_nums_per_expert,

@@ -328,6 +334,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Use Wint2 Triton Fusedmoe compute Fused MoE.

@@ -343,6 +350,9 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
            False,
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        num_tokens, K = x.shape
        E, _, N = layer.up_gate_proj_weight.shape
        M = num_tokens
@@ -14,7 +14,8 @@
# limitations under the License.
"""

from typing import Optional
from functools import partial
from typing import Callable, Optional

import paddle
from paddle import nn

@@ -26,6 +27,9 @@ from fastdeploy.distributed.communication import (
    tensor_model_parallel_all_reduce_custom,
)
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
    save_routing_to_buffer,
)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import h2d_copy, slice_fn
from fastdeploy.platforms import current_platform

@@ -226,7 +230,7 @@ class FusedMoE(nn.Layer):
        self.is_rearrange = False
        if self.ep_size > 1:
            self.quant_method.init_ep(self)

        self.enable_routing_replay = fd_config.routing_replay_config.enable_routing_replay
        # Merge normal and RL build model
        if gate_correction_bias is not None:
            self.gate_correction_bias = gate_correction_bias

@@ -600,7 +604,7 @@ class FusedMoE(nn.Layer):
        else:
            self.quant_method.process_loaded_weights(self, state_dict)

    def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer):
    def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None):
        """
        Forward split allgather function.
        """

@@ -615,14 +619,14 @@ class FusedMoE(nn.Layer):
        if end_offset > token_num:
            end_offset = token_num
        part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
        out = self.quant_method.apply(self, part_x, gate)
        out = self.quant_method.apply(self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
        multi_outs = paddle.zeros([token_num_per_rank * self.attn_tp_size, x.shape[1]], dtype=x.dtype)
        paddle.distributed.all_gather(multi_outs, out, self.tp_group)
        out = multi_outs[:token_num, :]

        return out

    def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
    def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None):
        """
        Defines the forward computation of the moe layer.

@@ -633,6 +637,21 @@ class FusedMoE(nn.Layer):
            Tensor: Output tensor.

        """
        topk_ids_hookfunc = None
        if self.enable_routing_replay:
            if forward_meta is not None:  # forward_meta is None when execute empty_input_forward
                topk_ids_hookfunc = partial(
                    save_routing_to_buffer,
                    routing_replay_table=forward_meta.routing_replay_table,
                    batch_id_per_token=forward_meta.batch_id_per_token,
                    seq_lens_decoder=forward_meta.seq_lens_decoder,
                    cu_seqlens_q=forward_meta.cu_seqlens_q,
                    layer_idx=self.layer_idx,
                    tp_size=self.fd_config.parallel_config.tensor_parallel_size,
                    ep_size=self.fd_config.parallel_config.expert_parallel_size,
                    tp_group=self.fd_config.parallel_config.tp_group,
                )

        token_num = x.shape[0]
        if (
            self.ep_size > 1

@@ -640,11 +659,16 @@ class FusedMoE(nn.Layer):
            and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
            and token_num >= self.attn_tp_size
        ):
            out = self.forward_split_allgather(x, gate)
            out = self.forward_split_allgather(x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
        elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
            out = self.forward_chunked_moe(x, gate, forward_meta)
            out = self.forward_chunked_moe(
                x,
                gate,
                forward_meta,
                topk_ids_hookfunc=topk_ids_hookfunc,
            )
        else:
            out = self.forward_normal(x, gate)
            out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)

        if self.reduce_results and self.tp_size > 1:
            if current_platform.is_intel_hpu():

@@ -653,7 +677,9 @@ class FusedMoE(nn.Layer):
            out = tensor_model_parallel_all_reduce(out, self.tp_group)
        return out

    def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
    def forward_chunked_moe(
        self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
    ):
        """
        Split input to multi chunk to reduce the memory usage of moe.

@@ -677,21 +703,25 @@ class FusedMoE(nn.Layer):

            for i in range(forward_meta.max_moe_num_chunk):
                if i < forward_meta.moe_num_chunk:
                    out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
                    out_split_list[i] = self.quant_method.apply(
                        self, x_split_list[i], gate, topk_ids_hookfunc=topk_ids_hookfunc
                    )
                else:
                    # just need to use real data to infer max_moe_num_chunk times.
                    self.quant_method.apply(self, fake_x, gate)
                    self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)

            out = paddle.concat(out_split_list, axis=0)
        else:
            # when only one chunk, just need to use real data to infer once.
            out = self.quant_method.apply(self, x, gate)
            out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
            for i in range(forward_meta.max_moe_num_chunk - 1):
                self.quant_method.apply(self, fake_x, gate)
                self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)

        return out

    def forward_normal(self, x: paddle.Tensor, gate: nn.Layer):
    def forward_normal(
        self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
    ):
        """
        Normal mode of forward.

@@ -702,5 +732,5 @@ class FusedMoE(nn.Layer):
            Tensor: Output tensor.

        """
        out = self.quant_method.apply(self, x, gate)
        out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
        return out
fastdeploy/model_executor/layers/moe/routing_indices_cache.py (new file, 346 lines)
@@ -0,0 +1,346 @@
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
import shutil
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _save_routing_kernel(
|
||||
ROUTING_REPLAY_TABLE_PTR,
|
||||
TOPK_IDS_PTR,
|
||||
BATCH_ID_PER_TOKEN_PTR,
|
||||
CU_SEQLENS_Q_PTR,
|
||||
SEQ_LENS_DECODER_PTR,
|
||||
LAYER_IDX,
|
||||
TOKEN_NUM,
|
||||
TOP_K,
|
||||
NUM_HIDDEN_LAYERS,
|
||||
MAX_MODEL_LEN,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
|
||||
token_offsets = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
token_mask = token_offsets < TOKEN_NUM
|
||||
|
||||
k_offsets = tl.arange(0, BLOCK_SIZE_K)
|
||||
|
||||
k_mask = k_offsets < TOP_K
|
||||
|
||||
topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :]
|
||||
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
|
||||
|
||||
load_mask = token_mask[:, None] & k_mask[None, :]
|
||||
topk_vals = tl.load(topk_ids_ptrs, mask=load_mask)
|
||||
|
||||
batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask)
|
||||
pad_mask = token_mask & (batch_ids != -1)
|
||||
# [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3]
|
||||
# -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
|
||||
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
|
||||
# -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1]
|
||||
start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask)
|
||||
token_relative_index = token_offsets - start_offsets
|
||||
|
||||
# [BLOCK_SIZE_M]
|
||||
len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask)
|
||||
token_seq_pos = len_decoder + token_relative_index
|
||||
|
||||
STRIDE_BUF_SEQ = NUM_HIDDEN_LAYERS * MAX_MODEL_LEN * TOP_K
|
||||
STRIDE_BUF_LAYER = MAX_MODEL_LEN * TOP_K
|
||||
STRIDE_BUF_TOKEN = TOP_K
|
||||
|
||||
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
|
||||
output_ptrs = (
|
||||
ROUTING_REPLAY_TABLE_PTR
|
||||
+ batch_ids[:, None] * STRIDE_BUF_SEQ
|
||||
+ LAYER_IDX * STRIDE_BUF_LAYER
|
||||
+ token_seq_pos[:, None] * STRIDE_BUF_TOKEN
|
||||
+ k_offsets[None, :]
|
||||
)
|
||||
|
||||
pos_mask = token_seq_pos < MAX_MODEL_LEN
|
||||
pos_mask = pos_mask & pad_mask
|
||||
|
||||
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
|
||||
pos_mask = pos_mask[:, None] & k_mask[None, :]
|
||||
|
||||
final_mask = load_mask & pos_mask
|
||||
|
||||
tl.store(output_ptrs, topk_vals, mask=final_mask)
|
||||
|
||||
|
||||
def save_routing_to_buffer(
|
||||
routing_replay_table: paddle.Tensor, # [max_num_seqs, num_layers, max_len, top_k]
|
||||
topk_ids: paddle.Tensor, # [token_num, top_k]
|
||||
batch_id_per_token: paddle.Tensor, # [token_num, 1]
|
||||
seq_lens_decoder: paddle.Tensor, # [max_num_seqs, 1]
|
||||
cu_seqlens_q: paddle.Tensor, # [max_num_seqs + 1, 1]
|
||||
layer_idx: int,
|
||||
tp_size: int,
|
||||
ep_size: int,
|
||||
tp_group: dist.communication.group.Group,
|
||||
):
|
||||
if tp_size > 1 and ep_size > 1:
|
||||
token_num_per_rank = topk_ids.shape[0]
|
||||
topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype)
|
||||
paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group)
|
||||
topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :]
|
||||
|
||||
token_num, top_k = topk_ids.shape
|
||||
max_num_seqs, num_hidden_layers, max_model_len, _ = routing_replay_table.shape
|
||||
assert token_num > 0
|
||||
|
||||
assert topk_ids.shape[1] == routing_replay_table.shape[3], (topk_ids.shape[1], routing_replay_table.shape[3])
|
||||
assert batch_id_per_token.shape[0] == token_num, (batch_id_per_token.shape[0], token_num)
|
||||
assert seq_lens_decoder.shape[0] == max_num_seqs, (seq_lens_decoder.shape[0], max_num_seqs)
|
||||
|
||||
BLOCK_SIZE_M = 128
|
||||
BLOCK_SIZE_K = triton.next_power_of_2(top_k) # top_k
|
||||
|
||||
grid = (triton.cdiv(token_num, BLOCK_SIZE_M),)
|
||||
_save_routing_kernel[grid](
|
||||
routing_replay_table,
|
||||
topk_ids,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens_decoder,
|
||||
LAYER_IDX=layer_idx,
|
||||
TOKEN_NUM=token_num,
|
||||
TOP_K=top_k,
|
||||
NUM_HIDDEN_LAYERS=num_hidden_layers,
|
||||
MAX_MODEL_LEN=max_model_len,
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
)
|
||||
|
||||
|
||||
class RoutingReplayManager:
|
||||
"""Request level routing replay table manager"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
):
|
||||
self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
|
||||
self.max_model_len = fd_config.model_config.max_model_len
|
||||
self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index
|
||||
|
||||
if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM":
|
||||
self.moe_top_k = fd_config.model_config.num_experts_per_tok
|
||||
else:
|
||||
self.moe_top_k = fd_config.model_config.moe_k
|
||||
self.tp_rank = fd_config.parallel_config.tensor_parallel_rank
|
||||
|
||||
self.routing_store = get_routing_store(fd_config=fd_config)
|
||||
self.routing_batch_to_request: Dict[int, str] = {}
|
||||
self.routing_replay_table = paddle.full(
|
||||
shape=[self.max_num_seqs, self.num_moe_layers, self.max_model_len, self.moe_top_k],
|
||||
fill_value=-1,
|
||||
dtype="int32",
|
||||
)
|
||||
|
||||
def register_request(self, batch_id: int, request_id: str):
|
||||
"""
|
||||
Register a new request to routing replay table
|
||||
Args:
|
||||
batch_id: The batch ID of this request
|
||||
request_id: The global ID of the request is usually executed by the training process in RL
|
||||
"""
|
||||
# Save requests that have been finished for the current slot
|
||||
if batch_id in self.routing_batch_to_request:
|
||||
pre_request_id = self._deregister_request(batch_id)
|
||||
self._put_request_to_store(batch_id, pre_request_id)
|
||||
# Register the new request
|
||||
self.routing_batch_to_request[batch_id] = request_id
|
||||
|
||||
def _deregister_request(self, batch_id: int) -> str:
|
||||
"""
|
||||
Deregister a request from routing replay table
|
||||
"""
|
||||
assert batch_id in self.routing_batch_to_request
|
||||
return self.routing_batch_to_request.pop(batch_id)
|
||||
|
||||
def _put_request_to_store(
|
||||
self,
|
||||
batch_id: int,
|
||||
request_id: str,
|
||||
):
|
||||
if self.tp_rank == 0:
|
||||
batch_buffer = self.routing_replay_table[batch_id]
|
||||
for layer_id in range(self.num_moe_layers):
|
||||
layer_buffer = batch_buffer[layer_id]
|
||||
rollout_id = self.split_request_id(request_id)
|
||||
self.routing_store.put(routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id)
|
||||
|
||||
self._clear_table_slot(batch_id)
|
||||
|
||||
def put_table_to_store(self):
|
||||
"""Put the routing table"""
|
||||
batch_ids = copy.deepcopy(list(self.routing_batch_to_request.keys()))
|
||||
for batch_id in batch_ids:
|
||||
request_id = self._deregister_request(batch_id)
|
||||
self._put_request_to_store(batch_id, request_id)
|
||||
|
||||
def _clear_table_slot(self, batch_id: int):
|
||||
assert 0 <= batch_id < self.max_num_seqs
|
||||
self.routing_replay_table[batch_id].fill_(-1)
|
||||
|
||||
def clear_routing_table(self):
|
||||
"""Clear all slots of the routing replay table"""
|
||||
self.routing_replay_table.fill_(-1)
|
||||
|
||||
def _clear_store(self):
|
||||
"""Clear routing store"""
|
||||
self.routing_store.clear_store()
|
||||
|
||||
def _clear_request_of_store(self, request_id):
|
||||
"""Clear one request of routing store"""
|
||||
rollout_id = self.split_request_id(request_id)
|
||||
for layer_idx in range(self.num_moe_layers):
|
||||
self.routing_store.clear(rollout_id=rollout_id, layer_idx=layer_idx)
|
||||
|
||||
def get_request_from_store(self, request_id: str) -> List[paddle.Tensor]:
|
||||
"""Get the routing indices of the request from store"""
|
||||
routing_list = []
|
||||
rollout_id = self.split_request_id(request_id)
|
||||
for layer_idx in range(self.num_moe_layers):
|
||||
one_layer_routing = self.routing_store.get(rollout_id, layer_idx)
|
||||
routing_list.append(one_layer_routing)
|
||||
|
||||
return routing_list
|
||||
|
||||
def get_routing_table(self) -> paddle.Tensor:
|
||||
return self.routing_replay_table
|
||||
|
||||
def split_request_id(self, request_id: str):
|
||||
"""Split the request id to get rollout id"""
|
||||
chat_type, tmp_str = request_id.split("-", 1)
|
||||
# NOTE(gongshaotian): only support chatcmpl now
|
||||
# assert chat_type == "chatcmpl"
|
||||
reversed_tmp_str = tmp_str[::-1].split("-", 5)
|
||||
rollout_id = reversed_tmp_str[-1][::-1]
|
||||
return rollout_id
|
||||
|
||||
|
||||
class RoutingStoreBase(ABC):
|
||||
"""Base class for routing store"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig) -> None:
|
||||
self.fd_config = fd_config
|
||||
|
||||
@abstractmethod
|
||||
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: Optional[int] = None) -> None:
|
||||
"""Put the routing indices into store"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get(self, rollout_id: str, layer_idx: Optional[int] = None) -> paddle.Tensor:
|
||||
"""Get the routing indices from store"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clear(self, rollout_id: str, layer_idx: Optional[int] = None) -> None:
|
||||
"""Clear the routing indices of the request"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clear_store(
|
||||
self,
|
||||
):
|
||||
"""Clear the routing indices store"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RoutingStoreLocal(RoutingStoreBase):
|
||||
"""Routing Store using local memory"""
|
||||
|
||||
def __init__(self, fd_config) -> None:
|
||||
super().__init__(fd_config=fd_config)
|
||||
self.local_store_dir = fd_config.routing_replay_config.local_store_dir
|
||||
|
||||
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
|
||||
"""Put the routing indices into store"""
|
||||
dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
|
||||
paddle.save(routing_indices, file_path)
|
||||
|
||||
def get(
|
||||
self,
|
||||
rollout_id: str,
|
||||
layer_idx: int = None,
|
||||
) -> paddle.Tensor:
|
||||
"""Get the routing indices from store"""
|
||||
dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
|
||||
file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
|
||||
assert os.path.exists(file_path), f"File not found: {file_path}"
|
||||
layer_routing_indices = paddle.load(file_path)
|
||||
|
||||
return layer_routing_indices
|
||||
|
||||
def clear(
|
||||
self,
|
||||
rollout_id: str,
|
||||
layer_idx: int = None,
|
||||
) -> None:
|
||||
"""Clear the routing indices of the request"""
|
||||
dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
|
||||
file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
|
||||
assert os.path.exists(file_path), f"File not found: {file_path}"
|
||||
os.remove(file_path)
|
||||
|
||||
# Delete empty directory
|
||||
if len(os.listdir(dir_path)) == 0:
|
||||
os.rmdir(dir_path)
|
||||
|
||||
def clear_store(self):
|
||||
"""Clear the routing indices store"""
|
||||
if os.path.isdir(self.local_store_dir):
|
||||
for file_name in os.listdir(self.local_store_dir):
|
||||
file_path = os.path.join(self.local_store_dir, file_name)
|
||||
shutil.rmtree(file_path)
|
||||
|
||||
|
||||
class RoutingStoreRDMA(RoutingStoreBase):
|
||||
"""Routing Store using RDMA"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase:
|
||||
if fd_config.routing_replay_config.routing_store_type == "local":
|
||||
return RoutingStoreLocal(fd_config=fd_config)
|
||||
elif fd_config.routing_replay_config.routing_store_type == "rdma":
|
||||
return RoutingStoreRDMA(fd_config=fd_config)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid routing store type: '{fd_config.routing_replay_config.routing_store_type}'. "
|
||||
"Valid types are: 'local', 'rdma'"
|
||||
)
|
||||
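For intuition, the Triton kernel near the top of this file is equivalent to the following pure-Python reference over NumPy arrays (editor's sketch, for understanding only): every non-padding token writes its top-k expert ids into table[batch_id, layer_idx, seq_pos].

import numpy as np

# Editor's reference sketch of what _save_routing_kernel stores (slow, illustrative only).
# batch_id_per_token, seq_lens_decoder and cu_seqlens_q are taken as 1-D integer arrays here.
def save_routing_reference(routing_replay_table, topk_ids, batch_id_per_token,
                           seq_lens_decoder, cu_seqlens_q, layer_idx):
    token_num, top_k = topk_ids.shape
    max_model_len = routing_replay_table.shape[2]
    for token in range(token_num):
        batch_id = int(batch_id_per_token[token])
        if batch_id == -1:  # padded slot, nothing to record
            continue
        # position inside the request = tokens already decoded + offset within this step
        seq_pos = int(seq_lens_decoder[batch_id]) + (token - int(cu_seqlens_q[batch_id]))
        if seq_pos < max_model_len:
            routing_replay_table[batch_id, layer_idx, seq_pos, :] = topk_ids[token]

And a quick demonstration of the rollout-id parsing and the local-store round trip defined above (editor's sketch; the request id layout below is a hypothetical chat-completions id, the paths follow RoutingStoreLocal):

import os

import paddle

# split_request_id: drop the "chatcmpl-" prefix and the last five dash-separated fields.
request_id = "chatcmpl-rollout_0007-a1b2-c3d4-e5f6-g7h8-i9j0"   # hypothetical id layout
chat_type, tmp_str = request_id.split("-", 1)
rollout_id = tmp_str[::-1].split("-", 5)[-1][::-1]
print(rollout_id)  # rollout_0007

# RoutingStoreLocal.put/get for one layer of one rollout, spelled out.
layer_buffer = paddle.full([1024, 8], -1, dtype="int32")        # [max_model_len, top_k] slice
dir_path = os.path.join("./routing_replay_output", rollout_id)
os.makedirs(dir_path, exist_ok=True)
file_path = os.path.join(dir_path, "layer_0.pdtensor")
paddle.save(layer_buffer, file_path)                            # put()
restored = paddle.load(file_path)                               # get()
assert bool((restored == layer_buffer).all())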
@@ -161,7 +161,7 @@ class Glm4Moe(nn.Layer):
            reduce_results=False,
        )

    def forward(self, x, forward_meta):
    def forward(self, x, forward_meta: ForwardMeta = None):
        shared_experts_out = self.shared_experts(x)
        out = self.experts(x, self.gate, forward_meta)
        out = out + shared_experts_out

@@ -306,10 +306,7 @@ class Glm4MoeDecoderLayer(nn.Layer):
        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

        hidden_states = self.mlp(
            hidden_states,
            forward_meta,
        )
        hidden_states = self.mlp(hidden_states, forward_meta)

        return hidden_states, residual


@@ -65,6 +65,7 @@ class RolloutModelConfig:
        data_parallel_size: int = 1,
        num_nextn_predict_layers: int = 0,
        eplb_config: str = {},
        routing_replay_config: str = None,
    ):
        # Required parameters
        self.model = model_name_or_path

@@ -113,6 +114,7 @@ class RolloutModelConfig:
        self.plas_attention_config = plas_attention_config
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.eplb_config = eplb_config
        self.routing_replay_config = routing_replay_config

    def __str__(self):
        return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())


@@ -45,6 +45,9 @@ from fastdeploy.model_executor.layers.attention.append_attn_backend import (
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend,
)
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
    RoutingReplayManager,
)
from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler

@@ -202,6 +205,11 @@ class GPUModelRunner(ModelRunnerBase):
        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
        logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")

        # Rollout routing replay config
        self.routing_replay_manager = None
        if self.fd_config.routing_replay_config.enable_routing_replay:
            self.routing_replay_manager = RoutingReplayManager(fd_config=self.fd_config)

        self.zmq_client = None
        self.async_output_queue = None
        if envs.FD_USE_GET_SAVE_OUTPUT_V1:

@@ -648,6 +656,7 @@ class GPUModelRunner(ModelRunnerBase):
            self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
            self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
            self.share_inputs["is_block_step"][idx : idx + 1] = False
            self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids)
            self.share_inputs["step_idx"][idx : idx + 1] = (
                len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
            )

@@ -656,6 +665,12 @@ class GPUModelRunner(ModelRunnerBase):
            if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None:
                self.prompt_logprobs_reqs[request.request_id] = request
            has_prefill_task = True

            # Routing Replay
            if self.fd_config.routing_replay_config.enable_routing_replay:
                if prefill_start_index == 0:
                    self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id)

            if (
                self.fd_config.scheduler_config.splitwise_role == "decode"
            ):  # In PD, we continue to decode after P generates the first token

@@ -1148,6 +1163,7 @@ class GPUModelRunner(ModelRunnerBase):
        self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
        self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
        self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
        self.share_inputs["is_chunk_step"] = paddle.full([max_num_seqs], False, dtype="bool").cpu()
        self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
        self.share_inputs["step_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
        self.share_inputs["step_lens"] = paddle.full([1], 0, dtype="int32")

@@ -1418,6 +1434,9 @@ class GPUModelRunner(ModelRunnerBase):
        Initialize forward meta, attention meta data and update some config.
        """
        # Initialize forward meta
        routing_replay_table = None
        if self.routing_replay_manager is not None:
            routing_replay_table = self.routing_replay_manager.get_routing_table()
        self.forward_meta = ForwardMeta(
            ids_remove_padding=self.share_inputs["ids_remove_padding"],
            rotary_embs=self.share_inputs["rope_emb"],

@@ -1444,6 +1463,7 @@ class GPUModelRunner(ModelRunnerBase):
            kv_batch_ids=self.share_inputs["kv_batch_ids"],
            kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"],
            kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
            routing_replay_table=routing_replay_table,
        )

        dist_status = self.collect_distributed_status()

@@ -1932,6 +1952,9 @@ class GPUModelRunner(ModelRunnerBase):
            if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
                break

        if self.fd_config.routing_replay_config.enable_routing_replay:
            self.routing_replay_manager.clear_routing_table()

    def _update_chunked_prefill(self, tasks):
        """
        Update chunked prefill related parameters

@@ -2429,6 +2452,15 @@ class GPUModelRunner(ModelRunnerBase):
                self.speculative_config.num_speculative_tokens,
            )

        # Routing replay
        if self.fd_config.routing_replay_config.enable_routing_replay:
            if (
                not self.exist_prefill()
                and not self.exist_decode()
                and self.share_inputs["is_block_step"].sum() == 0
                and self.share_inputs["is_chunk_step"].sum() == 0
            ):
                self.routing_replay_manager.put_table_to_store()
        return None

    def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]:


@@ -38,6 +38,7 @@ from fastdeploy.config import (
    ModelConfig,
    ParallelConfig,
    PlasAttentionConfig,
    RoutingReplayConfig,
    SpeculativeConfig,
    StructuredOutputsConfig,
)

@@ -885,6 +886,13 @@ def parse_args():
        help="EPLB Configuration.",
    )

    parser.add_argument(
        "--routing_replay_config",
        type=json.loads,
        default=None,
        help="Configuration of Rollout Routing Replay.",
    )

    args = parser.parse_args()
    return args

@@ -944,6 +952,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
    eplb_config = EPLBConfig(args.eplb_config)

    structured_outputs_config: StructuredOutputsConfig = StructuredOutputsConfig(args=vars(args))
    routing_replay_config = RoutingReplayConfig(args.routing_replay_config)

    # Note(tangbinhan): used for load_checkpoint
    model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank

@@ -1003,6 +1012,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
        plas_attention_config=plas_attention_config,
        structured_outputs_config=structured_outputs_config,
        eplb_config=eplb_config,
        routing_replay_config=routing_replay_config,
    )
    update_fd_config_for_mm(fd_config)
    if fd_config.load_config.load_choices == "default_v1" and not v1_loader_support(fd_config):
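Taken together, the runner and worker changes above give the following lifecycle: the manager is built when enable_routing_replay is set, each request is registered against its batch slot at the start of prefill, every step hands the shared table to the model through ForwardMeta so each FusedMoE layer can write its top-k ids into it, and once no prefill, decode, blocked, or chunked work remains the table is flushed to the store and reset. A small sketch of the buffer itself (editor's illustration; the sizes are placeholders):

import paddle

# Shape used by RoutingReplayManager: [max_num_seqs, num_moe_layers, max_model_len, moe_top_k],
# initialized to -1 so positions that were never written are easy to spot.
max_num_seqs, num_moe_layers, max_model_len, moe_top_k = 4, 2, 128, 8
routing_replay_table = paddle.full(
    shape=[max_num_seqs, num_moe_layers, max_model_len, moe_top_k],
    fill_value=-1,
    dtype="int32",
)
print(routing_replay_table.shape)  # [4, 2, 128, 8]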
@@ -90,7 +90,7 @@ class MockAttentionBackend:


class MockQuantMethod:
    def apply(self, layer, x, gate):
    def apply(self, layer, x, gate, topk_ids_hookfunc=None):
        return x


@@ -129,6 +129,7 @@ class TestChunkedMoE(unittest.TestCase):
        model_runner.speculative_decoding = False
        model_runner._init_share_inputs(mock_fd_config.scheduler_config.max_num_seqs)
        model_runner.share_inputs["caches"] = None
        model_runner.routing_replay_manager = None

        if dist.get_rank() == 0:
            model_runner.share_inputs["ids_remove_padding"] = paddle.ones([10])

@@ -148,6 +149,7 @@ class TestChunkedMoE(unittest.TestCase):

        fused_moe.fd_config = mock_fd_config
        fused_moe.quant_method = MockQuantMethod()
        fused_moe.enable_routing_replay = None
        return fused_moe

    def run_model_runner(self):
@@ -78,6 +78,8 @@ def setup_and_run_server():
        "wint4",
        "--graph-optimization-config",
        '{"cudagraph_capture_sizes": [1], "use_cudagraph":true}',
        "--routing-replay-config",
        '{"enable_routing_replay":true, "routing_store_type":"local", "local_store_dir":"./routing_replay_output"}',
    ]

    # Start subprocess in new process group
@@ -31,6 +31,7 @@ from fastdeploy.config import (
    LoadConfig,
    ModelConfig,
    ParallelConfig,
    RoutingReplayConfig,
)
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.quantization.block_wise_fp8 import (

@@ -476,6 +477,7 @@ class FuseMoEWrapper(paddle.nn.Layer):
            graph_opt_config=GraphOptimizationConfig({}),
            load_config=LoadConfig({}),
            ips=",".join(["0"] * nnodes),
            routing_replay_config=RoutingReplayConfig({}),
        )
        self.fd_config.parallel_config.tp_group = None
        self.fd_config.parallel_config.tensor_parallel_rank = tp_rank

@@ -13,6 +13,7 @@ from fastdeploy.config import (
    LoadConfig,
    ModelConfig,
    ParallelConfig,
    RoutingReplayConfig,
)
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.quantization.w4a8 import W4A8Config

@@ -59,6 +60,7 @@ class FuseMoEWrapper(paddle.nn.Layer):
            graph_opt_config=GraphOptimizationConfig({}),
            load_config=LoadConfig({}),
            ips=",".join(["0"] * nnodes),
            routing_replay_config=RoutingReplayConfig({}),
        )
        self.fd_config.parallel_config.tp_group = None
        self.fd_config.parallel_config.tensor_parallel_rank = tp_rank

@@ -13,6 +13,7 @@ from fastdeploy.config import (
    LoadConfig,
    ModelConfig,
    ParallelConfig,
    RoutingReplayConfig,
)
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.quantization.w4afp8 import W4AFP8Config

@@ -65,6 +66,7 @@ class FuseMoEWrapper(paddle.nn.Layer):
            graph_opt_config=GraphOptimizationConfig({}),
            load_config=LoadConfig({}),
            ips=",".join(["0"] * nnodes),
            routing_replay_config=RoutingReplayConfig({}),
        )
        self.fd_config.parallel_config.tp_group = None
        self.fd_config.parallel_config.tensor_parallel_rank = tp_rank