Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[Cherry-Pick][RL] Support Rollout Routing Replay (#5166)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* support r3
* update
* support tp>1 && ep>1
* support cudagraph padding
* support all backends
* replace env with options
* modularize
* update
* Add RoutingStore and refine code
* add routing replay config
* add routing replay config
* successfully run routing store
* convert request id to rollout id
* fix rollout config bug
* unify code
* use rollout_id to replace request_id in routing store
* delete code
---------
Co-authored-by: yuanlehome <yuanlehome@163.com>
@@ -1163,6 +1163,31 @@ class CommitConfig:
        logger.info("=============================================================")


class RoutingReplayConfig:
    """Configuration for Routing Replay used in RL training"""

    def __init__(self, args) -> None:
        self.enable_routing_replay: bool = False
        self.routing_store_type: str = "local"

        # Local routing store
        self.local_store_dir: str = "./routing_replay_output"

        # RDMA routing store
        pass

        if args is not None:
            for key, value in args.items():
                if hasattr(self, key) and value != "None":
                    setattr(self, key, value)

    def to_json_string(self):
        """
        Convert routing replay config to json string.
        """
        return json.dumps({key: value for key, value in self.__dict__.items()})


class FDConfig:
    """
    The configuration class which contains all fastdeploy-related configuration. This

@@ -1206,6 +1231,7 @@ class FDConfig:
        test_mode=False,
        enable_attention_dp_balance: bool = False,
        attention_dp_time_out_iters: int = 0,
        routing_replay_config: Optional[RoutingReplayConfig] = None,
    ):
        self.model_config: ModelConfig = model_config  # type: ignore
        self.cache_config: CacheConfig = cache_config  # type: ignore

@@ -1221,8 +1247,10 @@ class FDConfig:
        self.cache_config: CacheConfig = cache_config  # type: ignore
        self.eplb_config: Optional[EPLBConfig] = eplb_config
        self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config
        self.routing_replay_config = routing_replay_config
        self.enable_attention_dp_balance = enable_attention_dp_balance
        self.attention_dp_time_out_iters = attention_dp_time_out_iters

        # Initialize cuda graph capture list
        max_capture_shape = self.parallel_config.max_num_seqs
        if self.speculative_config is not None and self.speculative_config.method == "mtp":
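As a quick illustration of how this config behaves (a sketch; the dict keys are the attribute names defined above, and values arriving as the string "None" are ignored):

    # Hypothetical usage of RoutingReplayConfig as defined above
    cfg = RoutingReplayConfig({"enable_routing_replay": True, "local_store_dir": "/tmp/r3"})
    assert cfg.enable_routing_replay is True
    assert cfg.routing_store_type == "local"  # default preserved
    cfg.to_json_string()
    # -> '{"enable_routing_replay": true, "routing_store_type": "local", "local_store_dir": "/tmp/r3"}'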
@@ -33,6 +33,7 @@ from fastdeploy.config import (
    MobaAttentionConfig,
    ModelConfig,
    ParallelConfig,
    RoutingReplayConfig,
    SpeculativeConfig,
    TaskOption,
)

@@ -421,6 +422,11 @@ class EngineArgs:
    Configuration for eplb.
    """

    routing_replay_config: Optional[Dict[str, Any]] = None
    """
    Configuration dict for rollout routing replay (r3).
    """

    def __post_init__(self):
        """
        Post-initialization processing to set default tokenizer if not provided.

@@ -731,6 +737,12 @@ class EngineArgs:
            default=EngineArgs.eplb_config,
            help="Config of eplb.",
        )
        parallel_group.add_argument(
            "--routing-replay-config",
            type=json.loads,
            default=EngineArgs.routing_replay_config,
            help="Configuration of rollout routing replay (r3).",
        )

        # Load group
        load_group = parser.add_argument_group("Load Configuration")

@@ -1076,6 +1088,14 @@ class EngineArgs:
            eplb_args[k] = v
        return EPLBConfig(eplb_args)

    def create_routing_repaly_config(self) -> RoutingReplayConfig:
        """Create a RoutingReplayConfig from the current engine args."""
        routing_replay_args = asdict(self)
        if self.routing_replay_config is not None:
            for k, v in self.routing_replay_config.items():
                routing_replay_args[k] = v
        return RoutingReplayConfig(routing_replay_args)

    def create_engine_config(self, port_availability_check: bool = True) -> FDConfig:
        """
        Create and return a Config object based on the current settings.

@@ -1118,6 +1138,7 @@ class EngineArgs:
        graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
        moba_attention_config = self.create_moba_attention_config()
        eplb_cfg = self.create_eplb_config()
        routing_replay_config = self.create_routing_repaly_config()

        early_stop_cfg = self.create_early_stop_config()
        early_stop_cfg.update_enable_early_stop(self.enable_early_stop)

@@ -1163,4 +1184,5 @@ class EngineArgs:
            early_stop_config=early_stop_cfg,
            enable_attention_dp_balance=self.enable_attention_dp_balance,
            attention_dp_time_out_iters=self.attention_dp_time_out_iters,
            routing_replay_config=routing_replay_config,
        )
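For context, a minimal sketch of enabling r3 programmatically through EngineArgs (field and method names come from the diff above; the model path and other EngineArgs fields are placeholders):

    # Hypothetical: enable rollout routing replay via EngineArgs
    args = EngineArgs(
        model="./my_moe_model",  # placeholder
        routing_replay_config={
            "enable_routing_replay": True,
            "routing_store_type": "local",
            "local_store_dir": "./routing_replay_output",
        },
    )
    fd_config = args.create_engine_config()
    assert fd_config.routing_replay_config.enable_routing_replay

The same dict can be passed on the command line as a JSON string via --routing-replay-config.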
@@ -462,6 +462,7 @@ class LLMEngine:
            f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
            f" --attention_dp_time_out_iters {self.cfg.attention_dp_time_out_iters}"
            f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'"
            f" --routing_replay_config '{self.cfg.routing_replay_config.to_json_string()}'"
            f" --ips {ips}"
        )
@@ -110,6 +110,8 @@ class ForwardMeta:
    block_tables: Optional[paddle.Tensor] = None
    # KV caches
    caches: Optional[list[paddle.Tensor]] = None
    # Routing Replay table buffer
    routing_replay_table: Optional[paddle.Tensor] = None

    def clear_caches(self):
        """Safely clean up the caches"""
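The diff does not show the model-runner side of this wiring; a plausible sketch, assuming a RoutingReplayManager (introduced in routing_indices_cache.py below) owned by the runner:

    # Hypothetical runner-side wiring (not part of this diff): ForwardMeta
    # carries a view of the manager's persistent table so each FusedMoE layer
    # can write its top-k ids during the step.
    manager = RoutingReplayManager(fd_config)
    forward_meta.routing_replay_table = manager.get_routing_table()
    # shape: [max_num_seqs, num_moe_layers, max_model_len, moe_top_k]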
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn

@@ -102,6 +104,7 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.

@@ -119,6 +122,9 @@ class DCUTritonWeightOnlyMoEMethod(QuantMethodBase):
        topk_weights, topk_ids = paddle.topk(scores, k=top_k, axis=-1, sorted=False)
        topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        intermediate_cache1 = paddle.empty(
            [token_num * top_k, moe_intermediate_size * 2],
            dtype=x.dtype,
@@ -16,6 +16,7 @@

import multiprocessing
import os
from typing import Callable

import numpy as np
import paddle

@@ -189,6 +190,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Paddle gcu compute Fused MoE.

@@ -201,6 +203,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.

@@ -212,6 +215,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP decoder method.

@@ -223,6 +227,7 @@ class GCUFusedMoeMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Paddle Cutlass compute Fused MoE.

@@ -388,6 +393,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Paddle gcu compute Fused MoE.
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn

@@ -132,6 +134,7 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.

@@ -151,6 +154,10 @@ class MetaxTritonWeightOnlyMoEMethod(QuantMethodBase):
            True,  # apply_norm_weight,
            False,
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        up_gate_proj_out = paddle.empty(
            [token_num * top_k, moe_intermediate_size * 2],
            dtype=x.dtype,
@@ -15,6 +15,7 @@
"""

from abc import abstractmethod
from typing import Callable

import paddle
from paddle import nn

@@ -120,6 +121,7 @@ class MoEMethodBase(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.

@@ -144,6 +146,7 @@ class MoEMethodBase(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Paddle Cutlass compute Fused MoE.

@@ -155,6 +158,7 @@ class MoEMethodBase(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Paddle Cutlass compute Fused MoE.

@@ -163,13 +167,13 @@ class MoEMethodBase(QuantMethodBase):
            if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
                if layer.fd_config.parallel_config.splitwise_role == "mixed" and layer.layer_idx == 0:
                    self.ep_prefill_runner.clean_low_latency_buffer()
-               return self.apply_ep_prefill(layer, x, gate)
+               return self.apply_ep_prefill(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
            else:
                if layer.fd_config.parallel_config.splitwise_role == "mixed" and layer.layer_idx == 0:
                    self.ep_decoder_runner.clean_low_latency_buffer()
-               return self.apply_ep_decode(layer, x, gate)
+               return self.apply_ep_decode(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
        else:
-           return self.apply_tp(layer, x, gate)
+           return self.apply_tp(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)


class UnquantizedFusedMoEMethod(MoEMethodBase):
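Every backend now accepts the same optional hook. A minimal sketch of the contract (the hook is any callable taking topk_ids as a keyword argument; the call sites above always invoke it right after expert selection):

    # Illustrative hook: record the routing decisions of each call
    captured = []

    def my_hook(topk_ids):
        # topk_ids: [token_num, top_k] integer tensor of selected expert ids
        captured.append(topk_ids)

    # Inside a backend's apply()/apply_tp()/apply_ep_*():
    #     if topk_ids_hookfunc is not None:
    #         topk_ids_hookfunc(topk_ids=topk_ids)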
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn
from paddle.nn.quant import weight_quantize

@@ -105,6 +107,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.

@@ -121,6 +124,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
            handle,
            _,
        ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_idx)

        token_all_num = sum(recv_num_tokens_per_expert_list)

        # 3. Compute ffn

@@ -178,6 +185,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP decoder method.

@@ -186,6 +194,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
        estimate_total_token_nums = gate_out.shape[0] * layer.top_k
        # 1. Select topk experts and weights
        topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_idx)

        expertwise_scale = None
        if hasattr(layer, "up_gate_proj_in_scale_all_experts"):  # only used in w4a8
            expertwise_scale = getattr(layer, "up_gate_proj_in_scale_all_experts", None)

@@ -220,6 +232,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Paddle Cutlass compute Fused MoE.

@@ -277,6 +290,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
            topk_only_mode=False,
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_idx)

        if self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
            # only w4a8 needs expert_idx_per_token;
            # others don't need this tensor, so we pass None.
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn
from paddleformers.utils.log import logger

@@ -298,6 +300,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.

@@ -305,6 +308,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
        gate_out = gate(x.cast("float32"))
        # 1. Select topk experts and weights
        topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_idx)

        # 2. Dynamic compute blockwise quantization scales
        x, x_scale_tensor = fastdeploy.model_executor.ops.gpu.per_token_quant(
            x, self.quant_config.weight_block_size[0]

@@ -406,6 +413,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP decoder method.

@@ -413,6 +421,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
        gate_out = gate(x.cast("float32"))
        # 1. Select topk experts and weights
        topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_idx)

        # 2. EP Dispatch
        permute_input, token_nums_per_expert, handle = self.ep_decoder_runner.dispatch(
            x, topk_idx, topk_weights, use_fp8=True

@@ -477,6 +489,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Paddle Use DeepGemm compute Fused MoE.

@@ -504,6 +517,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
            False,
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)

        recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, 128)
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn

@@ -240,6 +242,7 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Marlin compute Fused MoE.

@@ -275,6 +278,9 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
            False,
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        block_size_m = 64

        for m in [8, 16, 32, 48, 64]:
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn

@@ -156,6 +158,7 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.

@@ -186,6 +189,9 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
            False,
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        up_gate_proj_out = paddle.empty(
            [token_num * top_k, moe_intermediate_size * 2],
            dtype=x.dtype,

@@ -420,6 +426,7 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.

@@ -451,6 +458,9 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
            False,
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        up_gate_proj_out = paddle.empty(
            [token_num * top_k, moe_intermediate_size * 2],
            dtype=x.dtype,

@@ -840,6 +850,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Triton compute Fused MoE.

@@ -871,6 +882,9 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
            False,
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": self.quant_config.weight_block_size[1],
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn

@@ -261,6 +263,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Use Wint2 Triton Fusedmoe compute Fused MoE.

@@ -288,6 +291,9 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
            topk_only_mode=False,
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_idx)

        ffn_out = fastdeploy.model_executor.ops.gpu.moe_expert_ffn_wint2(
            permute_input,
            token_nums_per_expert,

@@ -333,6 +339,7 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Use Wint2 Triton Fusedmoe compute Fused MoE.

@@ -348,6 +355,9 @@ class TritonWint2FusedMoeMethod(CutlassWint2FusedMoeMethod):
            False,
        )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        num_tokens, K = x.shape
        E, _, N = layer.up_gate_proj_weight.shape
        M = num_tokens
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Callable

import paddle
from paddle import nn

@@ -47,6 +49,7 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Paddle Cutlass compute Fused MoE.

@@ -82,6 +85,7 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.

@@ -93,6 +97,7 @@ class XPUMoEMethod(UnquantizedFusedMoEMethod):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        Apply the EP decoder method.

@@ -227,6 +232,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
    ) -> paddle.Tensor:
        """
        XPU compute Fused MoE.
@@ -14,6 +14,7 @@
# limitations under the License.
"""

from functools import partial
from typing import Optional

import numpy as np

@@ -22,6 +23,10 @@ from paddle import nn
from paddleformers.utils.log import logger

from fastdeploy import envs
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
    save_routing_to_buffer,
)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import slice_fn
from fastdeploy.platforms import current_platform

@@ -195,6 +200,7 @@ class FusedMoE(nn.Layer):
        if self.ep_size > 1:
            self.quant_method.init_ep(self)

        self.enable_routing_replay = fd_config.routing_replay_config.enable_routing_replay
        # Merge normal and RL build model
        if gate_correction_bias is not None:
            self.gate_correction_bias = gate_correction_bias

@@ -532,7 +538,7 @@ class FusedMoE(nn.Layer):
        else:
            self.quant_method.process_loaded_weights(self, state_dict)

-   def forward(self, x: paddle.Tensor, gate: nn.Layer):
+   def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None):
        """
        Defines the forward computation of the moe layer.

@@ -543,5 +549,20 @@ class FusedMoE(nn.Layer):
            Tensor: Output tensor.

        """
-       out = self.quant_method.apply(self, x, gate)
+       topk_ids_hookfunc = None
+       if self.enable_routing_replay:
+           if forward_meta is not None:  # forward_meta is None when executing empty_input_forward
+               topk_ids_hookfunc = partial(
+                   save_routing_to_buffer,
+                   routing_replay_table=forward_meta.routing_replay_table,
+                   batch_id_per_token=forward_meta.batch_id_per_token,
+                   seq_lens_decoder=forward_meta.seq_lens_decoder,
+                   cu_seqlens_q=forward_meta.cu_seqlens_q,
+                   layer_idx=self.layer_idx,
+                   tp_size=self.fd_config.parallel_config.tensor_parallel_size,
+                   ep_size=self.fd_config.parallel_config.expert_parallel_size,
+                   tp_group=self.fd_config.parallel_config.tp_group,
+               )
+
+       out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
        return out
fastdeploy/model_executor/layers/moe/routing_indices_cache.py (new file, 328 lines)
@@ -0,0 +1,328 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import copy
import os
import shutil
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

import paddle
import paddle.distributed as dist
import triton
import triton.language as tl

from fastdeploy.config import FDConfig


@triton.jit
def _save_routing_kernel(
    ROUTING_REPLAY_TABLE_PTR,
    TOPK_IDS_PTR,
    BATCH_ID_PER_TOKEN_PTR,
    CU_SEQLENS_Q_PTR,
    SEQ_LENS_DECODER_PTR,
    LAYER_IDX,
    TOKEN_NUM,
    TOP_K,
    NUM_HIDDEN_LAYERS,
    MAX_MODEL_LEN,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)

    token_offsets = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    token_mask = token_offsets < TOKEN_NUM

    k_offsets = tl.arange(0, BLOCK_SIZE_K)
    topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :]
    # [BLOCK_SIZE_M, BLOCK_SIZE_K]
    topk_vals = tl.load(topk_ids_ptrs, mask=token_mask[:, None])

    batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask)
    pad_mask = token_mask & (batch_ids != -1)
    # Example: cu_seqlens_q = [0, 3, 4, 10, 12], batch_id_per_token = [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3]
    #   start_offsets        -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
    #   token_relative_index = [0, 1, 2, ..., 11] - start_offsets
    #                        -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1]
    start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask)
    token_relative_index = token_offsets - start_offsets

    # [BLOCK_SIZE_M]
    len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask)
    token_seq_pos = len_decoder + token_relative_index

    STRIDE_BUF_SEQ = NUM_HIDDEN_LAYERS * MAX_MODEL_LEN * TOP_K
    STRIDE_BUF_LAYER = MAX_MODEL_LEN * TOP_K
    STRIDE_BUF_TOKEN = TOP_K

    # [BLOCK_SIZE_M, BLOCK_SIZE_K]
    output_ptrs = (
        ROUTING_REPLAY_TABLE_PTR
        + batch_ids[:, None] * STRIDE_BUF_SEQ
        + LAYER_IDX * STRIDE_BUF_LAYER
        + token_seq_pos[:, None] * STRIDE_BUF_TOKEN
        + k_offsets[None, :]
    )

    pos_mask = token_seq_pos < MAX_MODEL_LEN
    pos_mask = pos_mask & pad_mask
    final_mask = token_mask[:, None] & pos_mask[:, None]

    tl.store(output_ptrs, topk_vals, mask=final_mask)

def save_routing_to_buffer(
    routing_replay_table: paddle.Tensor,  # [max_num_seqs, num_layers, max_len, top_k]
    topk_ids: paddle.Tensor,  # [token_num, top_k]
    batch_id_per_token: paddle.Tensor,  # [token_num, 1]
    seq_lens_decoder: paddle.Tensor,  # [max_num_seqs, 1]
    cu_seqlens_q: paddle.Tensor,  # [max_num_seqs + 1, 1]
    layer_idx: int,
    tp_size: int,
    ep_size: int,
    tp_group: dist.communication.group.Group,
):
    if tp_size > 1 and ep_size > 1:
        token_num_per_rank = topk_ids.shape[0]
        topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype)
        paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group)
        topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :]

    token_num, top_k = topk_ids.shape
    max_num_seqs, num_hidden_layers, max_model_len, _ = routing_replay_table.shape
    assert token_num > 0
    assert topk_ids.shape[1] == routing_replay_table.shape[3], (topk_ids.shape[1], routing_replay_table.shape[3])
    assert batch_id_per_token.shape[0] == token_num, (batch_id_per_token.shape[0], token_num)
    assert seq_lens_decoder.shape[0] == max_num_seqs, (seq_lens_decoder.shape[0], max_num_seqs)

    BLOCK_SIZE_M = 128
    BLOCK_SIZE_K = top_k

    grid = (triton.cdiv(token_num, BLOCK_SIZE_M),)
    _save_routing_kernel[grid](
        routing_replay_table,
        topk_ids,
        batch_id_per_token,
        cu_seqlens_q,
        seq_lens_decoder,
        LAYER_IDX=layer_idx,
        TOKEN_NUM=token_num,
        TOP_K=top_k,
        NUM_HIDDEN_LAYERS=num_hidden_layers,
        MAX_MODEL_LEN=max_model_len,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
    )

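For readers unfamiliar with Triton, the kernel's write pattern is equivalent to this (slow) pure-Python reference; a sketch for understanding only:

    # Reference semantics of _save_routing_kernel (illustrative, not part of the code):
    # each valid token's top-k expert ids are stored at the token's absolute
    # position within its sequence, for the current layer.
    def save_routing_reference(routing_replay_table, topk_ids, batch_id_per_token,
                               seq_lens_decoder, cu_seqlens_q, layer_idx):
        for t in range(topk_ids.shape[0]):
            b = int(batch_id_per_token[t])
            if b == -1:  # padding token (e.g. cudagraph padding)
                continue
            pos = int(seq_lens_decoder[b]) + (t - int(cu_seqlens_q[b]))
            if pos < routing_replay_table.shape[2]:  # MAX_MODEL_LEN bound
                routing_replay_table[b, layer_idx, pos, :] = topk_ids[t]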

class RoutingReplayManager:
    """Request-level routing replay table manager"""

    def __init__(
        self,
        fd_config: FDConfig,
    ):
        self.max_num_seqs = fd_config.parallel_config.max_num_seqs
        self.max_model_len = fd_config.model_config.max_model_len
        self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index
        self.moe_top_k = fd_config.model_config.moe_k
        self.tp_rank = fd_config.parallel_config.tensor_parallel_rank

        self.routing_store = get_routing_store(fd_config=fd_config)
        self.routing_batch_to_request: Dict[int, str] = {}
        self.routing_replay_table = paddle.full(
            shape=[self.max_num_seqs, self.num_moe_layers, self.max_model_len, self.moe_top_k],
            fill_value=-1,
            dtype="int32",
        )

    def register_request(self, batch_id: int, request_id: str):
        """
        Register a new request to the routing replay table
        Args:
            batch_id: The batch ID (table slot) of this request
            request_id: The global ID of the request, usually assigned by the training process in RL
        """
        # Flush the finished request that currently occupies this slot
        if batch_id in self.routing_batch_to_request:
            pre_request_id = self._deregister_request(batch_id)
            self._put_request_to_store(batch_id, pre_request_id)
        # Register the new request
        self.routing_batch_to_request[batch_id] = request_id

    def _deregister_request(self, batch_id: int) -> str:
        """
        Deregister a request from the routing replay table
        """
        assert batch_id in self.routing_batch_to_request
        return self.routing_batch_to_request.pop(batch_id)

    def _put_request_to_store(
        self,
        batch_id: int,
        request_id: str,
    ):
        if self.tp_rank == 0:
            batch_buffer = self.routing_replay_table[batch_id]
            for layer_id in range(self.num_moe_layers):
                layer_buffer = batch_buffer[layer_id]
                rollout_id = self.split_request_id(request_id)
                self.routing_store.put(routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id)

        self._clear_table_slot(batch_id)

    def put_table_to_store(self):
        """Flush every registered request's routing table into the store"""
        batch_ids = copy.deepcopy(list(self.routing_batch_to_request.keys()))
        for batch_id in batch_ids:
            request_id = self._deregister_request(batch_id)
            self._put_request_to_store(batch_id, request_id)

    def _clear_table_slot(self, batch_id: int):
        assert 0 <= batch_id < self.max_num_seqs
        self.routing_replay_table[batch_id].fill_(-1)

    def clear_routing_table(self):
        """Clear all slots of the routing replay table"""
        self.routing_replay_table.fill_(-1)

    def _clear_store(self):
        """Clear the routing store"""
        self.routing_store.clear_store()

    def _clear_request_of_store(self, request_id):
        """Clear one request of the routing store"""
        rollout_id = self.split_request_id(request_id)
        for layer_idx in range(self.num_moe_layers):
            self.routing_store.clear(rollout_id=rollout_id, layer_idx=layer_idx)

    def get_request_from_store(self, request_id: str) -> List[paddle.Tensor]:
        """Get the routing indices of the request from the store"""
        routing_list = []
        rollout_id = self.split_request_id(request_id)
        for layer_idx in range(self.num_moe_layers):
            one_layer_routing = self.routing_store.get(rollout_id, layer_idx)
            routing_list.append(one_layer_routing)

        return routing_list

    def get_routing_table(self) -> paddle.Tensor:
        return self.routing_replay_table

    def split_request_id(self, request_id: str):
        """Split the request id to recover the rollout id"""
        chat_type, tmp_str = request_id.split("-", 1)
        assert chat_type == "chatcmpl"
        reversed_tmp_str = tmp_str[::-1].split("-", 5)
        rollout_id = reversed_tmp_str[-1][::-1]
        return rollout_id

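Concretely, split_request_id appears to assume ids of the form chatcmpl-<rollout_id>-<uuid4>: a uuid4 contributes exactly five dash-separated tokens, so stripping the last five tokens recovers the rollout id even when it contains dashes itself. A worked example (the ids are made up):

    request_id = "chatcmpl-task_7-123e4567-e89b-42d3-a456-556642440000"  # hypothetical
    chat_type, tmp_str = request_id.split("-", 1)       # "chatcmpl", "task_7-123e..."
    rollout_id = tmp_str[::-1].split("-", 5)[-1][::-1]  # drop the five uuid tokens
    assert rollout_id == "task_7"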

class RoutingStoreBase(ABC):
    """Base class for routing store"""

    def __init__(self, fd_config: FDConfig) -> None:
        self.fd_config = fd_config

    @abstractmethod
    def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: Optional[int] = None) -> None:
        """Put the routing indices into the store"""
        raise NotImplementedError

    @abstractmethod
    def get(self, rollout_id: str, layer_idx: Optional[int] = None) -> paddle.Tensor:
        """Get the routing indices from the store"""
        raise NotImplementedError

    @abstractmethod
    def clear(self, rollout_id: str, layer_idx: Optional[int] = None) -> None:
        """Clear the routing indices of the request"""
        raise NotImplementedError

    @abstractmethod
    def clear_store(
        self,
    ):
        """Clear the routing indices store"""
        raise NotImplementedError

class RoutingStoreLocal(RoutingStoreBase):
    """Routing store backed by the local filesystem"""

    def __init__(self, fd_config) -> None:
        super().__init__(fd_config=fd_config)
        self.local_store_dir = fd_config.routing_replay_config.local_store_dir

    def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
        """Put the routing indices into the store"""
        dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
        os.makedirs(dir_path, exist_ok=True)
        file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
        paddle.save(routing_indices, file_path)

    def get(
        self,
        rollout_id: str,
        layer_idx: int = None,
    ) -> paddle.Tensor:
        """Get the routing indices from the store"""
        dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
        file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
        assert os.path.exists(file_path), f"File not found: {file_path}"
        layer_routing_indices = paddle.load(file_path)

        return layer_routing_indices

    def clear(
        self,
        rollout_id: str,
        layer_idx: int = None,
    ) -> None:
        """Clear the routing indices of the request"""
        dir_path = os.path.join(self.local_store_dir, f"{rollout_id}")
        file_path = os.path.join(dir_path, f"layer_{layer_idx}.pdtensor")
        assert os.path.exists(file_path), f"File not found: {file_path}"
        os.remove(file_path)

        # Delete the directory once it is empty
        if len(os.listdir(dir_path)) == 0:
            os.rmdir(dir_path)

    def clear_store(self):
        """Clear the routing indices store"""
        if os.path.isdir(self.local_store_dir):
            for file_name in os.listdir(self.local_store_dir):
                file_path = os.path.join(self.local_store_dir, file_name)
                shutil.rmtree(file_path)

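A quick roundtrip through the local store shows the on-disk layout, <local_store_dir>/<rollout_id>/layer_<i>.pdtensor (the values below are made up):

    # Hypothetical roundtrip; fd_config is assumed to carry
    # routing_replay_config.local_store_dir = "./routing_replay_output"
    store = RoutingStoreLocal(fd_config)
    indices = paddle.to_tensor([[3, 7], [1, 4]], dtype="int32")  # [seq_positions, top_k]
    store.put(routing_indices=indices, rollout_id="task_7", layer_idx=0)
    # -> writes ./routing_replay_output/task_7/layer_0.pdtensor
    restored = store.get(rollout_id="task_7", layer_idx=0)
    store.clear(rollout_id="task_7", layer_idx=0)  # removes the file, then the empty dir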

class RoutingStoreRDMA(RoutingStoreBase):
    """Routing store using RDMA (placeholder, not yet implemented)"""

    def __init__(self, fd_config) -> None:
        super().__init__(fd_config=fd_config)


def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase:
    if fd_config.routing_replay_config.routing_store_type == "local":
        return RoutingStoreLocal(fd_config=fd_config)
    elif fd_config.routing_replay_config.routing_store_type == "rdma":
        return RoutingStoreRDMA(fd_config=fd_config)
    else:
        raise ValueError("Invalid store type")

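Putting the pieces together, a sketch of the intended per-rollout lifecycle (the call order on the runner side is inferred from the APIs above, not shown in this diff):

    # Hypothetical end-to-end flow for one request
    manager = RoutingReplayManager(fd_config)
    rid = "chatcmpl-task_7-123e4567-e89b-42d3-a456-556642440000"  # made-up id

    # 1. The request is assigned table slot batch_id=0
    manager.register_request(batch_id=0, request_id=rid)

    # 2. During generation, every FusedMoE layer writes its top-k ids into
    #    manager.get_routing_table() via save_routing_to_buffer (see forward() above)

    # 3. When the rollout finishes, flush all slots to the store
    manager.put_table_to_store()

    # 4. The RL trainer can then replay the exact routing decisions, layer by layer
    per_layer_routing = manager.get_request_from_store(rid)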
@@ -67,6 +67,7 @@ class RolloutModelConfig:
        enable_attention_dp_balance: bool = False,
        attention_dp_time_out_iters: int = 0,
        eplb_config: str = {},
        routing_replay_config: str = None,
    ):
        # Required parameters
        self.model = model_name_or_path

@@ -117,6 +118,7 @@ class RolloutModelConfig:
        self.enable_attention_dp_balance = enable_attention_dp_balance
        self.attention_dp_time_out_iters = attention_dp_time_out_iters
        self.eplb_config = eplb_config
        self.routing_replay_config = routing_replay_config

    def __str__(self):
        return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
@@ -38,6 +38,7 @@ from fastdeploy.config import (
    MobaAttentionConfig,
    ModelConfig,
    ParallelConfig,
    RoutingReplayConfig,
    SpeculativeConfig,
)
from fastdeploy.engine.request import RequestType

@@ -840,6 +841,13 @@ def parse_args():
        help="EPLB Configuration.",
    )

    parser.add_argument(
        "--routing_replay_config",
        type=json.loads,
        default=None,
        help="Configuration of Rollout Routing Replay.",
    )

    args = parser.parse_args()
    return args

@@ -900,6 +908,8 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:

    eplb_config = EPLBConfig(args.eplb_config)

    routing_replay_config = RoutingReplayConfig(args.routing_replay_config)

    # Note(tangbinhan): used for load_checkpoint
    model_config.pretrained_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank
    model_config.pretrained_config.tensor_parallel_degree = parallel_config.tensor_parallel_size

@@ -998,6 +1008,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
        enable_attention_dp_balance=args.enable_attention_dp_balance,
        attention_dp_time_out_iters=args.attention_dp_time_out_iters,
        eplb_config=eplb_config,
        routing_replay_config=routing_replay_config,
    )
    update_fd_config_for_mm(fd_config)