mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[New][RL] Support Rollout Routing Replay (#5405)
* [RL] Support Rollout Routing Replay
* add routing indices cache
* fix config bug and moe forward bug
* R3 Support GLM
* support eb4.5
* fix merge bug
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* add routing replay ci
* support glm topk
* support other top_k
* fix ci bug
* pre-commit
* only support chatcmpl
* Revert "Revert "[RL] Support Rollout Routing Replay (#5321)" (#5402)"
This reverts commit c45e064f3d.
* Fix XPU and NPU bug
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
@@ -1484,6 +1484,31 @@ class StructuredOutputsConfig:
|
||||
return json.dumps({key: value for key, value in self.__dict__.items()})
|
||||
|
||||
|
||||
class RoutingReplayConfig:
    """Configuration for Routing Replay used in RL training.

    Attributes:
        enable_routing_replay: Routing replay switch; defaults to ``False``.
        routing_store_type: Name of the routing store backend; defaults to
            ``"local"``.
        local_store_dir: Directory used by the local routing store; defaults
            to ``"./routing_replay_output"``.
    """

    def __init__(self, args) -> None:
        """Initialize the defaults, then apply overrides from ``args``.

        Args:
            args: Mapping of option name to value, or ``None`` to keep every
                default. Keys that are not declared options are ignored, and
                a value of ``None`` or the string ``"None"`` leaves the
                corresponding default untouched.
        """
        self.enable_routing_replay: bool = False
        self.routing_store_type: str = "local"

        # Local routing store
        self.local_store_dir: str = "./routing_replay_output"

        # RDMA routing store
        # TODO: Add RDMA routing store configuration attributes here when the feature is implemented.

        if args is None:
            return

        # Only the option attributes declared above may be overridden.
        # A plain hasattr() check would also match methods, letting a key
        # such as "to_json_string" replace the method with a raw value.
        known_options = frozenset(vars(self))
        for key, value in args.items():
            if key not in known_options:
                continue
            # Both a real None and the literal string "None" mean "not set";
            # keeping the default preserves the declared bool/str types.
            if value is not None and value != "None":
                setattr(self, key, value)

    def to_json_string(self) -> str:
        """
        Convert routing replay config to json string.
        """
        return json.dumps(dict(vars(self)))
|
||||
|
||||
|
||||
class FDConfig:
|
||||
"""
|
||||
The configuration class which contains all fastdeploy-related configuration. This
|
||||
@@ -1517,6 +1542,7 @@ class FDConfig:
|
||||
early_stop_config: Optional[Dict[str, Any]] = None,
|
||||
tool_parser: str = None,
|
||||
test_mode=False,
|
||||
routing_replay_config: Optional[RoutingReplayConfig] = None,
|
||||
):
|
||||
self.model_config: ModelConfig = model_config # type: ignore
|
||||
self.cache_config: CacheConfig = cache_config # type: ignore
|
||||
@@ -1533,6 +1559,7 @@ class FDConfig:
|
||||
self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
|
||||
self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config
|
||||
self.router_config: RouterConfig = router_config
|
||||
self.routing_replay_config = routing_replay_config
|
||||
|
||||
# Initialize cuda graph capture list
|
||||
max_capture_shape = self.scheduler_config.max_num_seqs
|
||||
|
||||
Reference in New Issue
Block a user