mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[New][RL] Support Rollout Routing Replay (#5405)
* [RL] Support Rollout Routing Replay
* add routing indices cache
* fix config bug and moe forward bug
* R3 Support GLM
* support eb4.5
* fix merge bug
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* add routing replay ci
* support glm topk
* support orther top_k
* fix ci bug
* pre-commit
* only support chatcmpl
* Revert "Revert "[RL] Support Rollout Routing Replay (#5321)" (#5402)"
This reverts commit c45e064f3d.
* Fix XPU and NPU bug
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
@@ -101,6 +103,7 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
@@ -117,6 +120,8 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
scores += layer.gate_correction_bias
|
||||
topk_weights, topk_ids = paddle.topk(scores, k=top_k, axis=-1, sorted=False)
|
||||
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_ids)
|
||||
|
||||
intermediate_cache1 = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
@@ -182,6 +183,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle gcu compute Fused MoE.
|
||||
@@ -194,6 +196,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP prefill method.
|
||||
@@ -205,6 +208,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP decoder method.
|
||||
@@ -216,6 +220,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
@@ -381,6 +386,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle gcu compute Fused MoE.
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
@@ -48,6 +50,7 @@ class HpuMoEMethod(UnquantizedFusedMoEMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP prefill method.
|
||||
@@ -59,6 +62,7 @@ class HpuMoEMethod(UnquantizedFusedMoEMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP decoder method.
|
||||
@@ -70,6 +74,7 @@ class HpuMoEMethod(UnquantizedFusedMoEMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle hpu Fused MoE.
|
||||
@@ -142,6 +147,7 @@ class HpuTensorWiseFP8MoEMethod(HpuMoEMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP prefill method.
|
||||
@@ -153,6 +159,7 @@ class HpuTensorWiseFP8MoEMethod(HpuMoEMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP decoder method.
|
||||
@@ -164,6 +171,7 @@ class HpuTensorWiseFP8MoEMethod(HpuMoEMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle hpu Fused MoE.
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
@@ -245,6 +247,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
@@ -274,6 +277,9 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
True, # apply_norm_weight
|
||||
False,
|
||||
)
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_ids)
|
||||
|
||||
up_gate_proj_out = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
dtype=x.dtype,
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
@@ -235,6 +237,7 @@ class XPUMoEMethod(MoEMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply TP Fused Op.
|
||||
@@ -262,6 +265,7 @@ class XPUMoEMethod(MoEMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply TP Scatter Op.
|
||||
@@ -318,6 +322,7 @@ class XPUMoEMethod(MoEMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
apply tp
|
||||
@@ -368,6 +373,7 @@ class XPUMoEMethod(MoEMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP prefill method.
|
||||
@@ -442,6 +448,7 @@ class XPUMoEMethod(MoEMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP decoder method.
|
||||
@@ -488,6 +495,7 @@ class XPUMoEMethod(MoEMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
compute Fused MoE.
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
"""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
@@ -163,6 +164,7 @@ class MoEMethodBase(QuantMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP prefill method.
|
||||
@@ -175,6 +177,7 @@ class MoEMethodBase(QuantMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP decoder method.
|
||||
@@ -187,6 +190,7 @@ class MoEMethodBase(QuantMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
@@ -198,6 +202,7 @@ class MoEMethodBase(QuantMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
@@ -207,13 +212,13 @@ class MoEMethodBase(QuantMethodBase):
|
||||
if layer.fd_config.model_config.moe_phase.phase == "prefill":
|
||||
if layer.fd_config.scheduler_config.splitwise_role == "mixed" and is_moe_start_layer:
|
||||
self.ep_prefill_runner.clean_low_latency_buffer()
|
||||
return self.apply_ep_prefill(layer, x, gate)
|
||||
return self.apply_ep_prefill(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
else:
|
||||
if layer.fd_config.scheduler_config.splitwise_role == "mixed" and is_moe_start_layer:
|
||||
self.ep_decoder_runner.clean_low_latency_buffer()
|
||||
return self.apply_ep_decode(layer, x, gate)
|
||||
return self.apply_ep_decode(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
else:
|
||||
return self.apply_tp(layer, x, gate)
|
||||
return self.apply_tp(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
|
||||
|
||||
class UnquantizedFusedMoEMethod(MoEMethodBase):
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn.quant import weight_quantize
|
||||
@@ -132,6 +134,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP prefill method.
|
||||
@@ -148,8 +151,13 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
handle,
|
||||
event,
|
||||
) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)
|
||||
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_idx)
|
||||
|
||||
if self.ep_prefill_runner.ep_engine.async_finish:
|
||||
event.current_stream_wait()
|
||||
|
||||
token_all_num = sum(recv_num_tokens_per_expert_list)
|
||||
|
||||
# 3. Compute ffn
|
||||
@@ -217,6 +225,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP decoder method.
|
||||
@@ -225,6 +234,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
estimate_total_token_nums = gate_out.shape[0] * layer.top_k
|
||||
# 1. Select topk experts and weights
|
||||
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
|
||||
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_idx)
|
||||
|
||||
expertwise_scale = None
|
||||
if hasattr(layer, "up_gate_proj_in_scale_all_experts"): # only use in w4a8
|
||||
expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)
|
||||
@@ -269,6 +282,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
@@ -369,6 +383,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
if hasattr(layer, "up_gate_proj_in_scale"):
|
||||
dequant_scale = None
|
||||
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_idx)
|
||||
|
||||
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
|
||||
# only w4a8 need expert_idx_per_token
|
||||
# Other need not this tensor, so we make it None.
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.distributed.communication import deep_ep
|
||||
@@ -139,6 +141,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP prefill method.
|
||||
@@ -147,6 +150,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
|
||||
# 1. Select topk experts and weights
|
||||
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
|
||||
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_idx)
|
||||
|
||||
# 2. Dynamic compute blockwise quantization scales
|
||||
x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
|
||||
x, self.quant_config.weight_block_size[0]
|
||||
@@ -264,6 +271,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP decoder method.
|
||||
@@ -271,6 +279,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
gate_out = gate(x.cast("float32"))
|
||||
# 1. Select topk experts and weights
|
||||
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
|
||||
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_idx)
|
||||
|
||||
# 2. EP Dispatch
|
||||
permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
|
||||
x, topk_idx, topk_weights, use_fp8=True
|
||||
@@ -335,6 +347,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle Use DeepGemm compute Fused MoE.
|
||||
@@ -363,6 +376,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
False,
|
||||
)
|
||||
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_ids)
|
||||
|
||||
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
|
||||
|
||||
recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, 128)
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
@@ -239,6 +241,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Marlin compute Fused MoE.
|
||||
@@ -273,6 +276,9 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
|
||||
False,
|
||||
)
|
||||
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_ids)
|
||||
|
||||
block_size_m = 64
|
||||
|
||||
for m in [8, 16, 32, 48, 64]:
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
@@ -282,6 +284,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
@@ -314,6 +317,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
True, # apply_norm_weight,
|
||||
False,
|
||||
)
|
||||
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_ids)
|
||||
|
||||
up_gate_proj_out = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
dtype=x.dtype,
|
||||
@@ -664,6 +671,7 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
@@ -724,6 +732,9 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
|
||||
* ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_ids)
|
||||
|
||||
up_gate_proj_out = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
dtype=x.dtype,
|
||||
@@ -953,6 +964,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
@@ -974,6 +986,9 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
False,
|
||||
)
|
||||
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_ids)
|
||||
|
||||
up_gate_proj_out = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
dtype=x.dtype,
|
||||
@@ -1466,6 +1481,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Triton compute Fused MoE.
|
||||
@@ -1488,6 +1504,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
True, # apply_norm_weight
|
||||
False,
|
||||
)
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_ids)
|
||||
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
@@ -261,6 +263,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Use Wint2 Triton Fusedmoe compute Fused MoE.
|
||||
@@ -288,6 +291,9 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
|
||||
topk_only_mode=False,
|
||||
)
|
||||
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_idx)
|
||||
|
||||
ffn_out = fastdeploy.model_executor.ops.gpu.moe_expert_ffn_wint2(
|
||||
permute_input,
|
||||
token_nums_per_expert,
|
||||
@@ -328,6 +334,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate: nn.Layer,
|
||||
topk_ids_hookfunc: Callable = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Use Wint2 Triton Fusedmoe compute Fused MoE.
|
||||
@@ -343,6 +350,9 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
|
||||
False,
|
||||
)
|
||||
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_ids)
|
||||
|
||||
num_tokens, K = x.shape
|
||||
E, _, N = layer.up_gate_proj_weight.shape
|
||||
M = num_tokens
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from functools import partial
|
||||
from typing import Callable, Optional
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
@@ -26,6 +27,9 @@ from fastdeploy.distributed.communication import (
|
||||
tensor_model_parallel_all_reduce_custom,
|
||||
)
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
|
||||
save_routing_to_buffer,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.model_executor.utils import h2d_copy, slice_fn
|
||||
from fastdeploy.platforms import current_platform
|
||||
@@ -226,7 +230,7 @@ class FusedMoE(nn.Layer):
|
||||
self.is_rearrange = False
|
||||
if self.ep_size > 1:
|
||||
self.quant_method.init_ep(self)
|
||||
|
||||
self.enable_routing_replay = fd_config.routing_replay_config.enable_routing_replay
|
||||
# Merge normal and RL build model
|
||||
if gate_correction_bias is not None:
|
||||
self.gate_correction_bias = gate_correction_bias
|
||||
@@ -600,7 +604,7 @@ class FusedMoE(nn.Layer):
|
||||
else:
|
||||
self.quant_method.process_loaded_weights(self, state_dict)
|
||||
|
||||
def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer):
|
||||
def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer, topk_ids_hookfunc: Callable = None):
|
||||
"""
|
||||
Forward split allgather function.
|
||||
"""
|
||||
@@ -615,14 +619,14 @@ class FusedMoE(nn.Layer):
|
||||
if end_offset > token_num:
|
||||
end_offset = token_num
|
||||
part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
|
||||
out = self.quant_method.apply(self, part_x, gate)
|
||||
out = self.quant_method.apply(self, part_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
multi_outs = paddle.zeros([token_num_per_rank * self.attn_tp_size, x.shape[1]], dtype=x.dtype)
|
||||
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
|
||||
out = multi_outs[:token_num, :]
|
||||
|
||||
return out
|
||||
|
||||
def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
|
||||
def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None):
|
||||
"""
|
||||
Defines the forward computation of the moe layer.
|
||||
|
||||
@@ -633,6 +637,21 @@ class FusedMoE(nn.Layer):
|
||||
Tensor: Output tensor.s
|
||||
|
||||
"""
|
||||
topk_ids_hookfunc = None
|
||||
if self.enable_routing_replay:
|
||||
if forward_meta is not None: # forward_meta is None when execute empty_input_forward
|
||||
topk_ids_hookfunc = partial(
|
||||
save_routing_to_buffer,
|
||||
routing_replay_table=forward_meta.routing_replay_table,
|
||||
batch_id_per_token=forward_meta.batch_id_per_token,
|
||||
seq_lens_decoder=forward_meta.seq_lens_decoder,
|
||||
cu_seqlens_q=forward_meta.cu_seqlens_q,
|
||||
layer_idx=self.layer_idx,
|
||||
tp_size=self.fd_config.parallel_config.tensor_parallel_size,
|
||||
ep_size=self.fd_config.parallel_config.expert_parallel_size,
|
||||
tp_group=self.fd_config.parallel_config.tp_group,
|
||||
)
|
||||
|
||||
token_num = x.shape[0]
|
||||
if (
|
||||
self.ep_size > 1
|
||||
@@ -640,11 +659,16 @@ class FusedMoE(nn.Layer):
|
||||
and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
|
||||
and token_num >= self.attn_tp_size
|
||||
):
|
||||
out = self.forward_split_allgather(x, gate)
|
||||
out = self.forward_split_allgather(x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
|
||||
out = self.forward_chunked_moe(x, gate, forward_meta)
|
||||
out = self.forward_chunked_moe(
|
||||
x,
|
||||
gate,
|
||||
forward_meta,
|
||||
topk_ids_hookfunc=topk_ids_hookfunc,
|
||||
)
|
||||
else:
|
||||
out = self.forward_normal(x, gate)
|
||||
out = self.forward_normal(x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
if current_platform.is_intel_hpu():
|
||||
@@ -653,7 +677,9 @@ class FusedMoE(nn.Layer):
|
||||
out = tensor_model_parallel_all_reduce(out, self.tp_group)
|
||||
return out
|
||||
|
||||
def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
|
||||
def forward_chunked_moe(
|
||||
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
|
||||
):
|
||||
"""
|
||||
Split input to multi chunk to reduce the memory usage of moe.
|
||||
|
||||
@@ -677,21 +703,25 @@ class FusedMoE(nn.Layer):
|
||||
|
||||
for i in range(forward_meta.max_moe_num_chunk):
|
||||
if i < forward_meta.moe_num_chunk:
|
||||
out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
|
||||
out_split_list[i] = self.quant_method.apply(
|
||||
self, x_split_list[i], gate, topk_ids_hookfunc=topk_ids_hookfunc
|
||||
)
|
||||
else:
|
||||
# just need to use real data to infer max_moe_num_chunk times.
|
||||
self.quant_method.apply(self, fake_x, gate)
|
||||
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
|
||||
out = paddle.concat(out_split_list, axis=0)
|
||||
else:
|
||||
# when only one chunk, just need to use real data to infer once.
|
||||
out = self.quant_method.apply(self, x, gate)
|
||||
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
for i in range(forward_meta.max_moe_num_chunk - 1):
|
||||
self.quant_method.apply(self, fake_x, gate)
|
||||
self.quant_method.apply(self, fake_x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
|
||||
return out
|
||||
|
||||
def forward_normal(self, x: paddle.Tensor, gate: nn.Layer):
|
||||
def forward_normal(
|
||||
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta, topk_ids_hookfunc: Callable = None
|
||||
):
|
||||
"""
|
||||
Normal mode of forward.
|
||||
|
||||
@@ -702,5 +732,5 @@ class FusedMoE(nn.Layer):
|
||||
Tensor: Output tensor.s
|
||||
|
||||
"""
|
||||
out = self.quant_method.apply(self, x, gate)
|
||||
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
|
||||
return out
|
||||
|
||||
346
fastdeploy/model_executor/layers/moe/routing_indices_cache.py
Normal file
346
fastdeploy/model_executor/layers/moe/routing_indices_cache.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
import shutil
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _save_routing_kernel(
|
||||
ROUTING_REPLAY_TABLE_PTR,
|
||||
TOPK_IDS_PTR,
|
||||
BATCH_ID_PER_TOKEN_PTR,
|
||||
CU_SEQLENS_Q_PTR,
|
||||
SEQ_LENS_DECODER_PTR,
|
||||
LAYER_IDX,
|
||||
TOKEN_NUM,
|
||||
TOP_K,
|
||||
NUM_HIDDEN_LAYERS,
|
||||
MAX_MODEL_LEN,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
|
||||
token_offsets = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
token_mask = token_offsets < TOKEN_NUM
|
||||
|
||||
k_offsets = tl.arange(0, BLOCK_SIZE_K)
|
||||
|
||||
k_mask = k_offsets < TOP_K
|
||||
|
||||
topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :]
|
||||
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
|
||||
|
||||
load_mask = token_mask[:, None] & k_mask[None, :]
|
||||
topk_vals = tl.load(topk_ids_ptrs, mask=load_mask)
|
||||
|
||||
batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask)
|
||||
pad_mask = token_mask & (batch_ids != -1)
|
||||
# [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3]
|
||||
# -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
|
||||
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
|
||||
# -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1]
|
||||
start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask)
|
||||
token_relative_index = token_offsets - start_offsets
|
||||
|
||||
# [BLOCK_SIZE_M]
|
||||
len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask)
|
||||
token_seq_pos = len_decoder + token_relative_index
|
||||
|
||||
STRIDE_BUF_SEQ = NUM_HIDDEN_LAYERS * MAX_MODEL_LEN * TOP_K
|
||||
STRIDE_BUF_LAYER = MAX_MODEL_LEN * TOP_K
|
||||
STRIDE_BUF_TOKEN = TOP_K
|
||||
|
||||
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
|
||||
output_ptrs = (
|
||||
ROUTING_REPLAY_TABLE_PTR
|
||||
+ batch_ids[:, None] * STRIDE_BUF_SEQ
|
||||
+ LAYER_IDX * STRIDE_BUF_LAYER
|
||||
+ token_seq_pos[:, None] * STRIDE_BUF_TOKEN
|
||||
+ k_offsets[None, :]
|
||||
)
|
||||
|
||||
pos_mask = token_seq_pos < MAX_MODEL_LEN
|
||||
pos_mask = pos_mask & pad_mask
|
||||
|
||||
# [BLOCK_SIZE_M, BLOCK_SIZE_K]
|
||||
pos_mask = pos_mask[:, None] & k_mask[None, :]
|
||||
|
||||
final_mask = load_mask & pos_mask
|
||||
|
||||
tl.store(output_ptrs, topk_vals, mask=final_mask)
|
||||
|
||||
|
||||
def save_routing_to_buffer(
|
||||
routing_replay_table: paddle.Tensor, # [max_num_seqs, num_layers, max_len, top_k]
|
||||
topk_ids: paddle.Tensor, # [token_num, top_k]
|
||||
batch_id_per_token: paddle.Tensor, # [token_num, 1]
|
||||
seq_lens_decoder: paddle.Tensor, # [max_num_seqs, 1]
|
||||
cu_seqlens_q: paddle.Tensor, # [max_num_seqs + 1, 1]
|
||||
layer_idx: int,
|
||||
tp_size: int,
|
||||
ep_size: int,
|
||||
tp_group: dist.communication.group.Group,
|
||||
):
|
||||
if tp_size > 1 and ep_size > 1:
|
||||
token_num_per_rank = topk_ids.shape[0]
|
||||
topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype)
|
||||
paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group)
|
||||
topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :]
|
||||
|
||||
token_num, top_k = topk_ids.shape
|
||||
max_num_seqs, num_hidden_layers, max_model_len, _ = routing_replay_table.shape
|
||||
assert token_num > 0
|
||||
|
||||
assert topk_ids.shape[1] == routing_replay_table.shape[3], (topk_ids.shape[1], routing_replay_table.shape[3])
|
||||
assert batch_id_per_token.shape[0] == token_num, (batch_id_per_token.shape[0], token_num)
|
||||
assert seq_lens_decoder.shape[0] == max_num_seqs, (seq_lens_decoder.shape[0], max_num_seqs)
|
||||
|
||||
BLOCK_SIZE_M = 128
|
||||
BLOCK_SIZE_K = triton.next_power_of_2(top_k) # top_k
|
||||
|
||||
grid = (triton.cdiv(token_num, BLOCK_SIZE_M),)
|
||||
_save_routing_kernel[grid](
|
||||
routing_replay_table,
|
||||
topk_ids,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens_decoder,
|
||||
LAYER_IDX=layer_idx,
|
||||
TOKEN_NUM=token_num,
|
||||
TOP_K=top_k,
|
||||
NUM_HIDDEN_LAYERS=num_hidden_layers,
|
||||
MAX_MODEL_LEN=max_model_len,
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
)
|
||||
|
||||
|
||||
class RoutingReplayManager:
|
||||
"""Request level routing replay table manager"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
):
|
||||
self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
|
||||
self.max_model_len = fd_config.model_config.max_model_len
|
||||
self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index
|
||||
|
||||
if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM":
|
||||
self.moe_top_k = fd_config.model_config.num_experts_per_tok
|
||||
else:
|
||||
self.moe_top_k = fd_config.model_config.moe_k
|
||||
self.tp_rank = fd_config.parallel_config.tensor_parallel_rank
|
||||
|
||||
self.routing_store = get_routing_store(fd_config=fd_config)
|
||||
self.routing_batch_to_request: Dict[int, str] = {}
|
||||
self.routing_replay_table = paddle.full(
|
||||
shape=[self.max_num_seqs, self.num_moe_layers, self.max_model_len, self.moe_top_k],
|
||||
fill_value=-1,
|
||||
dtype="int32",
|
||||
)
|
||||
|
||||
def register_request(self, batch_id: int, request_id: str):
|
||||
"""
|
||||
Register a new request to routing replay table
|
||||
Args:
|
||||
batch_id: The batch ID of this request
|
||||
request_id: The global ID of the request is usually executed by the training process in RL
|
||||
"""
|
||||
# Save requests that have been finished for the current slot
|
||||
if batch_id in self.routing_batch_to_request:
|
||||
pre_request_id = self._deregister_request(batch_id)
|
||||
self._put_request_to_store(batch_id, pre_request_id)
|
||||
# Register the new request
|
||||
self.routing_batch_to_request[batch_id] = request_id
|
||||
|
||||
def _deregister_request(self, batch_id: int) -> str:
|
||||
"""
|
||||
Deregister a request from routing replay table
|
||||
"""
|
||||
assert batch_id in self.routing_batch_to_request
|
||||
return self.routing_batch_to_request.pop(batch_id)
|
||||
|
||||
def _put_request_to_store(
|
||||
self,
|
||||
batch_id: int,
|
||||
request_id: str,
|
||||
):
|
||||
if self.tp_rank == 0:
|
||||
batch_buffer = self.routing_replay_table[batch_id]
|
||||
for layer_id in range(self.num_moe_layers):
|
||||
layer_buffer = batch_buffer[layer_id]
|
||||
rollout_id = self.split_request_id(request_id)
|
||||
self.routing_store.put(routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id)
|
||||
|
||||
self._clear_table_slot(batch_id)
|
||||
|
||||
def put_table_to_store(self):
|
||||
"""Put the routing table"""
|
||||
batch_ids = copy.deepcopy(list(self.routing_batch_to_request.keys()))
|
||||
for batch_id in batch_ids:
|
||||
request_id = self._deregister_request(batch_id)
|
||||
self._put_request_to_store(batch_id, request_id)
|
||||
|
||||
def _clear_table_slot(self, batch_id: int):
|
||||
assert 0 <= batch_id < self.max_num_seqs
|
||||
self.routing_replay_table[batch_id].fill_(-1)
|
||||
|
||||
def clear_routing_table(self):
|
||||
"""Clear all slots of the routing replay table"""
|
||||
self.routing_replay_table.fill_(-1)
|
||||
|
||||
def _clear_store(self):
|
||||
"""Clear routing store"""
|
||||
self.routing_store.clear_store()
|
||||
|
||||
def _clear_request_of_store(self, request_id):
|
||||
"""Clear one request of routing store"""
|
||||
rollout_id = self.split_request_id(request_id)
|
||||
for layer_idx in range(self.num_moe_layers):
|
||||
self.routing_store.clear(rollout_id=rollout_id, layer_idx=layer_idx)
|
||||
|
||||
def get_request_from_store(self, request_id: str) -> List[paddle.Tensor]:
|
||||
"""Get the routing indices of the request from store"""
|
||||
routing_list = []
|
||||
rollout_id = self.split_request_id(request_id)
|
||||
for layer_idx in range(self.num_moe_layers):
|
||||
one_layer_routing = self.routing_store.get(rollout_id, layer_idx)
|
||||
routing_list.append(one_layer_routing)
|
||||
|
||||
return routing_list
|
||||
|
||||
def get_routing_table(self) -> paddle.Tensor:
|
||||
return self.routing_replay_table
|
||||
|
||||
def split_request_id(self, request_id: str):
|
||||
"""Split the request id to get rollout id"""
|
||||
chat_type, tmp_str = request_id.split("-", 1)
|
||||
# NOTE(gongshaotian): only support chatcmpl now
|
||||
# assert chat_type == "chatcmpl"
|
||||
reversed_tmp_str = tmp_str[::-1].split("-", 5)
|
||||
rollout_id = reversed_tmp_str[-1][::-1]
|
||||
return rollout_id
|
||||
|
||||
|
||||
class RoutingStoreBase(ABC):
|
||||
"""Base class for routing store"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig) -> None:
|
||||
self.fd_config = fd_config
|
||||
|
||||
@abstractmethod
|
||||
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: Optional[int] = None) -> None:
|
||||
"""Put the routing indices into store"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get(self, rollout_id: str, layer_idx: Optional[int] = None) -> paddle.Tensor:
|
||||
"""Get the routing indices from store"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clear(self, rollout_id: str, layer_idx: Optional[int] = None) -> None:
|
||||
"""Clear the routing indices of the request"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clear_store(
|
||||
self,
|
||||
):
|
||||
"""Clear the routing indices store"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RoutingStoreLocal(RoutingStoreBase):
|
||||
"""Routing Store using local memory"""
|
||||
|
||||
def __init__(self, fd_config) -> None:
|
||||
super().__init__(fd_config=fd_config)
|
||||
self.local_store_dir = fd_config.routing_replay_config.local_store_dir
|
||||
|
||||
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
|
||||
"""Put the routing indices into store"""
|
||||
dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
|
||||
paddle.save(routing_indices, file_path)
|
||||
|
||||
def get(
|
||||
self,
|
||||
rollout_id: str,
|
||||
layer_idx: int = None,
|
||||
) -> paddle.Tensor:
|
||||
"""Get the routing indices from store"""
|
||||
dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
|
||||
file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
|
||||
assert os.path.exists(file_path), f"File not found: {file_path}"
|
||||
layer_routing_indices = paddle.load(file_path)
|
||||
|
||||
return layer_routing_indices
|
||||
|
||||
def clear(
|
||||
self,
|
||||
rollout_id: str,
|
||||
layer_idx: int = None,
|
||||
) -> None:
|
||||
"""Clear the routing indices of the request"""
|
||||
dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
|
||||
file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
|
||||
assert os.path.exists(file_path), f"File not found: {file_path}"
|
||||
os.remove(file_path)
|
||||
|
||||
# Delete empty directory
|
||||
if len(os.listdir(dir_path)) == 0:
|
||||
os.rmdir(dir_path)
|
||||
|
||||
def clear_store(self):
|
||||
"""Clear the routing indices store"""
|
||||
if os.path.isdir(self.local_store_dir):
|
||||
for file_name in os.listdir(self.local_store_dir):
|
||||
file_path = os.path.join(self.local_store_dir, file_name)
|
||||
shutil.rmtree(file_path)
|
||||
|
||||
|
||||
class RoutingStoreRDMA(RoutingStoreBase):
|
||||
"""Routing Store using RDMA"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase:
|
||||
if fd_config.routing_replay_config.routing_store_type == "local":
|
||||
return RoutingStoreLocal(fd_config=fd_config)
|
||||
elif fd_config.routing_replay_config.routing_store_type == "rdma":
|
||||
return RoutingStoreRDMA(fd_config=fd_config)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid routing store type: '{fd_config.routing_replay_config.routing_store_type}'. "
|
||||
"Valid types are: 'local', 'rdma'"
|
||||
)
|
||||
Reference in New Issue
Block a user