mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Cherry-Pick][RL] R3 Support RDMA Store (#5454)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [RL] R3 support rdma store * refine notes * refine code * support preempted task and put cpu tensor
This commit is contained in:
@@ -1167,14 +1167,17 @@ class RoutingReplayConfig:
|
||||
"""Configuration for Routing Replay used in RL training"""
|
||||
|
||||
def __init__(self, args) -> None:
|
||||
|
||||
self.enable_routing_replay: bool = False
|
||||
|
||||
# Routing store type: local/rdma
|
||||
self.routing_store_type: str = "local"
|
||||
|
||||
# Local routing store
|
||||
self.local_store_dir: str = "./routing_replay_output"
|
||||
|
||||
# RDMA routing store
|
||||
pass
|
||||
self.rdma_store_server: str = ""
|
||||
|
||||
if args is not None:
|
||||
for key, value in args.items():
|
||||
|
||||
@@ -14,9 +14,11 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@@ -232,6 +234,11 @@ class RoutingReplayManager:
|
||||
rollout_id = reversed_tmp_str[-1][::-1]
|
||||
return rollout_id
|
||||
|
||||
def clear_request(self, batch_id: int):
    """Drop every trace of the request bound to *batch_id*.

    Frees the routing-table slot first, then removes the
    batch-id -> request-id mapping; a missing mapping is a no-op.
    """
    self._clear_table_slot(batch_id)
    if batch_id in self.routing_batch_to_request:
        del self.routing_batch_to_request[batch_id]
|
||||
|
||||
|
||||
class RoutingStoreBase(ABC):
|
||||
"""Base class for routing store"""
|
||||
@@ -268,6 +275,7 @@ class RoutingStoreLocal(RoutingStoreBase):
|
||||
def __init__(self, fd_config) -> None:
    """Filesystem-backed routing store.

    Reads the target directory from ``fd_config.routing_replay_config``
    and starts from an empty store.
    """
    super().__init__(fd_config=fd_config)
    # Directory where routing indices are persisted for this run.
    self.local_store_dir = fd_config.routing_replay_config.local_store_dir
    # Drop any previously stored routing data so a stale run's files are
    # not replayed into this one (see `clear_store`).
    self.clear_store()
|
||||
|
||||
def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
|
||||
"""Put the routing indices into store"""
|
||||
@@ -315,8 +323,55 @@ class RoutingStoreLocal(RoutingStoreBase):
|
||||
class RoutingStoreRDMA(RoutingStoreBase):
    """Routing store backed by an RDMA key/value service (p2pstore).

    Entries are keyed by ``f"{rollout_id}_{layer_idx}"``; values are the
    routing-index tensors, staged through pinned host memory before the
    transfer. NOTE(review): the flattened diff carried a second, stale
    ``__init__(self)`` that dropped ``fd_config``; it was dead code
    (shadowed by the real constructor) and has been removed.
    """

    def __init__(self, fd_config) -> None:
        """Connect to the RDMA metadata server configured in *fd_config*.

        Raises:
            ModuleNotFoundError: if the optional ``p2pstore`` package
                (only available in the RLHF environment) is missing.
        """
        super().__init__(fd_config=fd_config)
        try:
            # Only used in RLHF
            from p2pstore import P2PClient, P2PConfig
        except ModuleNotFoundError:
            raise ModuleNotFoundError(" RoutingStoreRDMA and p2pstore only support in RLHF. ")

        rdma_store_server = fd_config.routing_replay_config.rdma_store_server
        p2p_config = P2PConfig(metadata_server=rdma_store_server)
        self.p2p_client = P2PClient(p2p_config)
        # Start from an empty store so entries left over from a previous
        # run cannot be replayed into this one.
        self.clear_store()

    def put(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int) -> None:
        """Put the routing indices into store"""
        rdma_rollout_key = f"{rollout_id}_{layer_idx}"

        # async put
        time_before_put = time.perf_counter()
        # Stage through pinned host memory, then hand a numpy view of the
        # host buffer to the RDMA client.
        routing_indices_pin = routing_indices.pin_memory()
        routing_indices_np = routing_indices_pin.numpy()
        asyncio.run(self.p2p_client.put(rdma_rollout_key, routing_indices_np))
        print(f"Success put with key {rdma_rollout_key}, time cost is {time.perf_counter()-time_before_put} s")

    def get(
        self,
        rollout_id: str,
        layer_idx: Optional[int] = None,
    ) -> paddle.Tensor:
        """Get the routing indices from store"""
        rdma_rollout_key = f"{rollout_id}_{layer_idx}"
        # sync get
        tmp_routing = asyncio.run(self.p2p_client.get(rdma_rollout_key))
        return tmp_routing

    def clear(
        self,
        rollout_id: str,
        layer_idx: Optional[int] = None,
    ) -> None:
        """Clear the routing indices of the request"""
        rdma_rollout_key = f"{rollout_id}_{layer_idx}"
        # sync delete
        asyncio.run(self.p2p_client.delete(rdma_rollout_key))

    def clear_store(self):
        """Clear the routing indices store"""
        # sync clear routing store
        asyncio.run(self.p2p_client.clear())
|
||||
|
||||
|
||||
def get_routing_store(fd_config: FDConfig) -> RoutingStoreBase:
|
||||
|
||||
@@ -75,6 +75,9 @@ if not (current_platform.is_dcu() or current_platform.is_iluvatar()):
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.input.ernie4_5_vl_processor import DataProcessor
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
|
||||
RoutingReplayManager,
|
||||
)
|
||||
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
|
||||
from fastdeploy.worker.model_runner_base import ModelRunnerBase
|
||||
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
|
||||
@@ -163,6 +166,11 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
|
||||
logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")
|
||||
|
||||
# Rollout routing replay config
|
||||
self.routing_replay_manager = None
|
||||
if self.fd_config.routing_replay_config.enable_routing_replay:
|
||||
self.routing_replay_manager = RoutingReplayManager(fd_config=self.fd_config)
|
||||
|
||||
def exist_prefill(self):
|
||||
"""
|
||||
check whether prefill stage exist
|
||||
@@ -313,11 +321,18 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
|
||||
self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
|
||||
self.share_inputs["is_block_step"][idx : idx + 1] = False
|
||||
self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids)
|
||||
self.share_inputs["step_idx"][idx : idx + 1] = (
|
||||
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
|
||||
)
|
||||
self.share_inputs["pre_ids"][idx : idx + 1] = -1
|
||||
has_prefill_task = True
|
||||
|
||||
# Routing Replay
|
||||
if self.fd_config.routing_replay_config.enable_routing_replay:
|
||||
if prefill_start_index == 0:
|
||||
self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id)
|
||||
|
||||
elif request.task_type.value == RequestType.DECODE.value: # decode task
|
||||
logger.debug(f"Handle decode request {request} at idx {idx}")
|
||||
encoder_block_num = len(request.block_tables)
|
||||
@@ -338,6 +353,11 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
|
||||
self.share_inputs["is_block_step"][idx : idx + 1] = False
|
||||
has_preempted_task = True
|
||||
|
||||
# Routing Replay
|
||||
if self.fd_config.routing_replay_config.enable_routing_replay:
|
||||
self.routing_replay_manager.clear_request(batch_id=idx)
|
||||
|
||||
continue
|
||||
|
||||
assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
|
||||
@@ -716,6 +736,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["bad_tokens_len"] = paddle.full([max_num_seqs], 1, dtype="int64")
|
||||
self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
|
||||
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
|
||||
self.share_inputs["is_chunk_step"] = paddle.full([max_num_seqs], False, dtype="bool").cpu()
|
||||
self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32")
|
||||
self.share_inputs["step_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32")
|
||||
self.share_inputs["step_lens"] = paddle.full([1], 0, dtype="int32")
|
||||
@@ -972,6 +993,9 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
Initialize forward meta and attention meta data
|
||||
"""
|
||||
# Initialize forward meta
|
||||
routing_replay_table = None
|
||||
if self.routing_replay_manager is not None:
|
||||
routing_replay_table = self.routing_replay_manager.get_routing_table()
|
||||
self.forward_meta = ForwardMeta(
|
||||
input_ids=self.share_inputs["input_ids"],
|
||||
ids_remove_padding=self.share_inputs["ids_remove_padding"],
|
||||
@@ -989,6 +1013,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
|
||||
block_tables=self.share_inputs["block_tables"],
|
||||
caches=self.share_inputs["caches"],
|
||||
routing_replay_table=routing_replay_table,
|
||||
)
|
||||
|
||||
# Update Batch type for cuda graph
|
||||
@@ -1314,6 +1339,9 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
|
||||
break
|
||||
|
||||
if self.fd_config.routing_replay_config.enable_routing_replay:
|
||||
self.routing_replay_manager.clear_routing_table()
|
||||
|
||||
def _update_chunked_prefill(self, tasks):
|
||||
"""
|
||||
Update chunked prefill related parameters
|
||||
@@ -1694,6 +1722,17 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.seq_lens_this_time_buffer[:num_running_requests].copy_(
|
||||
self.share_inputs["seq_lens_this_time"][:num_running_requests], False
|
||||
)
|
||||
|
||||
# Routing replay
|
||||
if self.fd_config.routing_replay_config.enable_routing_replay:
|
||||
if (
|
||||
not self.exist_prefill()
|
||||
and not self.exist_decode()
|
||||
and self.share_inputs["is_block_step"].sum() == 0
|
||||
and self.share_inputs["is_chunk_step"].sum() == 0
|
||||
):
|
||||
self.routing_replay_manager.put_table_to_store()
|
||||
|
||||
return None
|
||||
|
||||
def _add_cache(self, model_forward_batch) -> None:
|
||||
|
||||
Reference in New Issue
Block a user