[RL] Support Rollout Routing Replay (#5321)

* [RL] Support Rollout Routing Replay

* add routing indices cache

* fix config bug and moe forward bug

* R3 Support GLM

* support eb4.5

* fix merge bug

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* add routing replay ci

* support glm topk

* support orther top_k

* fix ci bug

* pre-commit

* only support chatcmpl

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
RAM
2025-12-05 20:01:33 +08:00
committed by GitHub
parent 8545b705ed
commit 96d2d4877b
24 changed files with 592 additions and 24 deletions

View File

@@ -1484,6 +1484,31 @@ class StructuredOutputsConfig:
return json.dumps({key: value for key, value in self.__dict__.items()})
class RoutingReplayConfig:
"""Configuration for Routing Replay used in RL training"""
def __init__(self, args) -> None:
self.enable_routing_replay: bool = False
self.routing_store_type: str = "local"
# Local routing store
self.local_store_dir: str = "./routing_replay_output"
# RDMA routing store
# TODO: Add RDMA routing store configuration attributes here when the feature is implemented.
if args is not None:
for key, value in args.items():
if hasattr(self, key) and value != "None":
setattr(self, key, value)
def to_json_string(self):
"""
Convert routing replay config to json string.
"""
return json.dumps({key: value for key, value in self.__dict__.items()})
class FDConfig:
"""
The configuration class which contains all fastdeploy-related configuration. This
@@ -1517,6 +1542,7 @@ class FDConfig:
early_stop_config: Optional[Dict[str, Any]] = None,
tool_parser: str = None,
test_mode=False,
routing_replay_config: Optional[RoutingReplayConfig] = None,
):
self.model_config: ModelConfig = model_config # type: ignore
self.cache_config: CacheConfig = cache_config # type: ignore
@@ -1533,6 +1559,7 @@ class FDConfig:
self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config
self.router_config: RouterConfig = router_config
self.routing_replay_config = routing_replay_config
# Initialize cuda graph capture list
max_capture_shape = self.scheduler_config.max_num_seqs