[Cherry-Pick][RL] R3 Support RDMA Store(#5467) (#5468)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled

* [RL] R3 support rdma store

* refine code

* refine notes

* disable prefix cache

* fix ci bug

* support preempted task and put cpu tensor
This commit is contained in:
RAM
2025-12-17 09:50:40 +08:00
committed by GitHub
parent 53158b7f8d
commit c19af496cb
3 changed files with 69 additions and 4 deletions

View File

@@ -14,9 +14,11 @@
# limitations under the License.
"""
import asyncio
import copy
import os
import shutil
import time
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
@@ -247,6 +249,11 @@ class RoutingReplayManager:
rollout_id = reversed_tmp_str[-1][::-1]
return rollout_id
def clear_request(self, batch_id: int):
"""Clear the routing indices of the request"""
self._clear_table_slot(batch_id)
self.routing_batch_to_request.pop(batch_id, None)
class RoutingStoreBase(ABC):
"""Base class for routing store"""
@@ -283,6 +290,7 @@ class RoutingStoreLocal(RoutingStoreBase):
def __init__(self, fd_config) -> None:
super().__init__(fd_config=fd_config)
self.local_store_dir = fd_config.routing_replay_config.local_store_dir
self.clear_store()
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
"""Put the routing indices into store"""
@@ -330,8 +338,55 @@ class RoutingStoreLocal(RoutingStoreBase):
class RoutingStoreRDMA(RoutingStoreBase):
"""Routing Store using RDMA"""
def __init__(self) -> None:
super().__init__()
def __init__(self, fd_config) -> None:
super().__init__(fd_config=fd_config)
try:
# Only used in RLHF
from p2pstore import P2PClient, P2PConfig
except ModuleNotFoundError:
raise ModuleNotFoundError(" RoutingStoreRDMA and p2pstore only support in RLHF. ")
rdma_store_server = fd_config.routing_replay_config.rdma_store_server
p2pConfig = P2PConfig(metadata_server=rdma_store_server)
self.p2p_client = P2PClient(p2pConfig)
self.clear_store()
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
"""Put the routing indices into store"""
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
# async put
time_before_put = time.perf_counter()
routing_indices_pin = routing_indices.pin_memory()
routing_indices_np = routing_indices_pin.numpy()
asyncio.run(self.p2p_client.put(rdma_rollout_key, routing_indices_np))
print(f"Success put with key {rdma_rollout_key}, time cost is {time.perf_counter()-time_before_put} s")
def get(
self,
rollout_id: str,
layer_idx: int = None,
) -> paddle.Tensor:
"""Get the routing indices from store"""
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
# sync get
tmp_routing = asyncio.run(self.p2p_client.get(rdma_rollout_key))
return tmp_routing
def clear(
self,
rollout_id: str,
layer_idx: int = None,
) -> None:
"""Clear the routing indices of the request"""
rdma_rollout_key = f"{rollout_id}_{layer_idx}"
# sync delete
asyncio.run(self.p2p_client.delete(rdma_rollout_key))
def clear_store(self):
"""Clear the routing indices store"""
# sync clear routing store
asyncio.run(self.p2p_client.clear())
def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase: