[RL] R3 Support RDMA Store (#5467)

* [RL] R3 support rdma store

* refine notes

* refine code

* disable prefix cache

* support preempted task and put cpu tensor
This commit is contained in:
RAM
2025-12-16 16:50:13 +08:00
committed by GitHub
parent a30b4da260
commit 6fc5eccf83
3 changed files with 69 additions and 4 deletions

View File

@@ -1498,14 +1498,17 @@ class RoutingReplayConfig:
"""Configuration for Routing Replay used in RL training"""
def __init__(self, args) -> None:
self.enable_routing_replay: bool = False
# Routing store type: local/rdma
self.routing_store_type: str = "local"
# Local routing store
self.local_store_dir: str = "./routing_replay_output"
# RDMA routing store
# TODO: Add RDMA routing store configuration attributes here when the feature is implemented.
self.rdma_store_server: str = ""
if args is not None:
for key, value in args.items():
@@ -1698,7 +1701,9 @@ class FDConfig:
self.cache_config.postprocess(self.scheduler_config.max_num_batched_tokens, self.scheduler_config.max_num_seqs)
if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.cache_config.enable_prefix_caching = False
if self.routing_replay_config is not None and self.routing_replay_config.enable_routing_replay:
# TODO(gongshaotian): R3 support prefix caching
self.cache_config.enable_prefix_caching = False
if (
self.structured_outputs_config is not None
and self.structured_outputs_config.guided_decoding_backend != "off"

View File

@@ -14,9 +14,11 @@
# limitations under the License.
"""
import asyncio
import copy
import os
import shutil
import time
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
@@ -247,6 +249,11 @@ class RoutingReplayManager:
rollout_id = reversed_tmp_str[-1][::-1]
return rollout_id
def clear_request(self, batch_id: int):
"""Clear the routing indices of the request"""
self._clear_table_slot(batch_id)
self.routing_batch_to_request.pop(batch_id, None)
class RoutingStoreBase(ABC):
"""Base class for routing store"""
@@ -283,6 +290,7 @@ class RoutingStoreLocal(RoutingStoreBase):
def __init__(self, fd_config) -> None:
super().__init__(fd_config=fd_config)
self.local_store_dir = fd_config.routing_replay_config.local_store_dir
self.clear_store()
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
"""Put the routing indices into store"""
@@ -330,8 +338,55 @@ class RoutingStoreLocal(RoutingStoreBase):
class RoutingStoreRDMA(RoutingStoreBase):
"""Routing Store using RDMA"""
def __init__(self) -> None:
super().__init__()
def __init__(self, fd_config) -> None:
super().__init__(fd_config=fd_config)
try:
# Only used in RLHF
from p2pstore import P2PClient, P2PConfig
except ModuleNotFoundError:
raise ModuleNotFoundError(" RoutingStoreRDMA and p2pstore only support in RLHF. ")
rdma_store_server = fd_config.routing_replay_config.rdma_store_server
p2pConfig = P2PConfig(metadata_server=rdma_store_server)
self.p2p_client = P2PClient(p2pConfig)
self.clear_store()
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
"""Put the routing indices into store"""
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
# async put
time_before_put = time.perf_counter()
routing_indices_pin = routing_indices.pin_memory()
routing_indices_np = routing_indices_pin.numpy()
asyncio.run(self.p2p_client.put(rdma_rollout_key, routing_indices_np))
print(f"Success put with key {rdma_rollout_key}, time cost is {time.perf_counter()-time_before_put} s")
def get(
self,
rollout_id: str,
layer_idx: int = None,
) -> paddle.Tensor:
"""Get the routing indices from store"""
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
# sync get
tmp_routing = asyncio.run(self.p2p_client.get(rdma_rollout_key))
return tmp_routing
def clear(
self,
rollout_id: str,
layer_idx: int = None,
) -> None:
"""Clear the routing indices of the request"""
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
# sync delete
asyncio.run(self.p2p_client.delete(rdma_rollout_key))
def clear_store(self):
"""Clear the routing indices store"""
# sync clear routing store
asyncio.run(self.p2p_client.clear())
def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase:

View File

@@ -746,6 +746,11 @@ class GPUModelRunner(ModelRunnerBase):
self.prompt_logprobs_reqs.pop(request.request_id, None)
self.in_progress_prompt_logprobs.pop(request.request_id, None)
self.forward_batch_reqs_list[idx] = None
# Routing Replay
if self.fd_config.routing_replay_config.enable_routing_replay:
self.routing_replay_manager.clear_request(batch_id=idx)
continue
assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens